Spaces:
Sleeping
Sleeping
Fix syntax error in TTS stage and complete pipeline
Browse files
app.py
CHANGED
|
@@ -314,6 +314,12 @@ class TransformerTTS(nn.Module):
|
|
| 314 |
mel_padded = SOS
|
| 315 |
mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
|
| 316 |
stop_token_outputs = torch.FloatTensor([]).to(text.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
iters = range(max_length)
|
| 318 |
for _ in iters:
|
| 319 |
mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
|
|
@@ -322,6 +328,18 @@ class TransformerTTS(nn.Module):
|
|
| 322 |
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
| 323 |
break
|
| 324 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
|
| 326 |
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
|
| 327 |
mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
|
|
@@ -446,7 +464,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
|
|
| 446 |
try:
|
| 447 |
print("TTS: Synthesizing English speech...")
|
| 448 |
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
|
| 449 |
-
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-
|
| 450 |
|
| 451 |
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
|
| 452 |
if generated_mel is not None and generated_mel.numel() > 0:
|
|
|
|
| 314 |
mel_padded = SOS
|
| 315 |
mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
|
| 316 |
stop_token_outputs = torch.FloatTensor([]).to(text.device)
|
| 317 |
+
|
| 318 |
+
# Silence detection parameters
|
| 319 |
+
silence_threshold = 0.1 # Consider frames below this as silence
|
| 320 |
+
consecutive_silence_limit = 20 # Stop after 20 consecutive silent frames
|
| 321 |
+
consecutive_silence_count = 0
|
| 322 |
+
|
| 323 |
iters = range(max_length)
|
| 324 |
for _ in iters:
|
| 325 |
mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
|
|
|
|
| 328 |
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
| 329 |
break
|
| 330 |
|
| 331 |
+
# Check for silence in the generated mel frame
|
| 332 |
+
current_frame = mel_postnet[:, -1:, :]
|
| 333 |
+
frame_energy = torch.mean(torch.abs(current_frame))
|
| 334 |
+
|
| 335 |
+
if frame_energy < silence_threshold:
|
| 336 |
+
consecutive_silence_count += 1
|
| 337 |
+
if consecutive_silence_count >= consecutive_silence_limit:
|
| 338 |
+
print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames")
|
| 339 |
+
break
|
| 340 |
+
else:
|
| 341 |
+
consecutive_silence_count = 0 # Reset silence counter
|
| 342 |
+
|
| 343 |
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
|
| 344 |
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
|
| 345 |
mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
|
|
|
|
| 464 |
try:
|
| 465 |
print("TTS: Synthesizing English speech...")
|
| 466 |
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
|
| 467 |
+
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-100, stop_token_threshold=0.1, with_tqdm=False)
|
| 468 |
|
| 469 |
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
|
| 470 |
if generated_mel is not None and generated_mel.numel() > 0:
|