MoHamdyy committed on
Commit
2877500
·
1 Parent(s): 91c76f8

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -317,13 +317,18 @@ class TransformerTTS(nn.Module):
317
  iters = range(max_length)
318
  for _ in iters:
319
  mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
320
- mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
 
321
  if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
322
  break
323
- else:
324
- stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
325
- mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
326
- return mel_postnet, stop_token_outputs
 
 
 
 
327
  # --- (End of your model definitions) ---
328
 
329
  # --- Part 2: Model Loading ---
@@ -441,7 +446,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
441
  try:
442
  print("TTS: Synthesizing English speech...")
443
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
444
- generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-20, stop_token_threshold=0.3, with_tqdm=False)
445
 
446
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
447
  if generated_mel is not None and generated_mel.numel() > 0:
 
317
  iters = range(max_length)
318
  for _ in iters:
319
  mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
320
+
321
+ # Check stop token BEFORE adding to mel_padded
322
  if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
323
  break
324
+
325
+ mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
326
+ stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
327
+ mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
328
+
329
+ # Remove the initial SOS token and return only the generated mel
330
+ generated_mel = mel_padded[:, 1:, :] # Remove first frame (SOS)
331
+ return generated_mel, stop_token_outputs
332
  # --- (End of your model definitions) ---
333
 
334
  # --- Part 2: Model Loading ---
 
446
  try:
447
  print("TTS: Synthesizing English speech...")
448
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
449
+ generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, stop_token_threshold=0.3, with_tqdm=False)
450
 
451
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
452
  if generated_mel is not None and generated_mel.numel() > 0: