Spaces:
Sleeping
Sleeping
Fix syntax error in TTS stage and complete pipeline
Browse files
app.py
CHANGED
|
@@ -317,13 +317,18 @@ class TransformerTTS(nn.Module):
|
|
| 317 |
iters = range(max_length)
|
| 318 |
for _ in iters:
|
| 319 |
mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
|
| 320 |
-
|
|
|
|
| 321 |
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
| 322 |
break
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
# --- (End of your model definitions) ---
|
| 328 |
|
| 329 |
# --- Part 2: Model Loading ---
|
|
@@ -441,7 +446,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
|
|
| 441 |
try:
|
| 442 |
print("TTS: Synthesizing English speech...")
|
| 443 |
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
|
| 444 |
-
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-
|
| 445 |
|
| 446 |
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
|
| 447 |
if generated_mel is not None and generated_mel.numel() > 0:
|
|
|
|
| 317 |
iters = range(max_length)
|
| 318 |
for _ in iters:
|
| 319 |
mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
|
| 320 |
+
|
| 321 |
+
# Check stop token BEFORE adding to mel_padded
|
| 322 |
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
| 323 |
break
|
| 324 |
+
|
| 325 |
+
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
|
| 326 |
+
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
|
| 327 |
+
mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
|
| 328 |
+
|
| 329 |
+
# Remove the initial SOS token and return only the generated mel
|
| 330 |
+
generated_mel = mel_padded[:, 1:, :] # Remove first frame (SOS)
|
| 331 |
+
return generated_mel, stop_token_outputs
|
| 332 |
# --- (End of your model definitions) ---
|
| 333 |
|
| 334 |
# --- Part 2: Model Loading ---
|
|
|
|
| 446 |
try:
|
| 447 |
print("TTS: Synthesizing English speech...")
|
| 448 |
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
|
| 449 |
+
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, stop_token_threshold=0.3, with_tqdm=False)
|
| 450 |
|
| 451 |
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
|
| 452 |
if generated_mel is not None and generated_mel.numel() > 0:
|