MoHamdyy committed on
Commit
8785760
·
1 Parent(s): de6f9f5

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +21 -6
app.py CHANGED
@@ -323,7 +323,8 @@ class TransformerTTS(nn.Module):
323
  else:
324
  iters = range(max_length)
325
 
326
- for _ in iters:
 
327
  mel_postnet, mel_linear, stop_token = self(
328
  text,
329
  text_lengths,
@@ -331,6 +332,7 @@ class TransformerTTS(nn.Module):
331
  mel_lengths
332
  )
333
 
 
334
  mel_padded = torch.cat(
335
  [
336
  mel_padded,
@@ -338,13 +340,18 @@ class TransformerTTS(nn.Module):
338
  ],
339
  dim=1
340
  )
 
341
 
342
- if torch.sigmoid(stop_token[:, -1]) > gate_threshold:
 
 
 
343
  break
344
  else:
345
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
346
  mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
347
 
 
348
  return mel_postnet, stop_token_outputs
349
  # --- (End of your model definitions) ---
350
 
@@ -466,11 +473,19 @@ def full_speech_translation_pipeline(audio_input_path: str):
466
  generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, gate_threshold=1e-5, with_tqdm=False)
467
 
468
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
469
- if generated_mel is not None and generated_mel.numel() > 0:
470
  mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
471
- audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
472
- synthesized_audio_np = audio_tensor.cpu().numpy()
473
- print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
 
 
 
 
 
 
 
 
474
  except Exception as e:
475
  print(f"TTS Error: {e}")
476
 
 
323
  else:
324
  iters = range(max_length)
325
 
326
+ frames_generated = 0
327
+ for i in iters:
328
  mel_postnet, mel_linear, stop_token = self(
329
  text,
330
  text_lengths,
 
332
  mel_lengths
333
  )
334
 
335
+ # Add the new frame
336
  mel_padded = torch.cat(
337
  [
338
  mel_padded,
 
340
  ],
341
  dim=1
342
  )
343
+ frames_generated += 1
344
 
345
+ # Check stop condition but ensure minimum generation
346
+ stop_prob = torch.sigmoid(stop_token[:, -1])
347
+ if stop_prob > gate_threshold and frames_generated > 50: # Ensure at least 50 frames
348
+ print(f"TTS: Stopping at frame {frames_generated}, stop_prob: {stop_prob:.6f}")
349
  break
350
  else:
351
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
352
  mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
353
 
354
+ print(f"TTS: Generated {frames_generated} frames, final mel shape: {mel_postnet.shape}")
355
  return mel_postnet, stop_token_outputs
356
  # --- (End of your model definitions) ---
357
 
 
473
  generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, gate_threshold=1e-5, with_tqdm=False)
474
 
475
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
476
+ if generated_mel is not None and generated_mel.numel() > 128: # Ensure minimum size
477
  mel_for_vocoder = generated_mel.detach().squeeze(0).transpose(0, 1)
478
+ # Add safety check for mel dimensions
479
+ if mel_for_vocoder.numel() > 0 and mel_for_vocoder.shape[0] > 10:
480
+ audio_tensor = inverse_mel_spec_to_wav(mel_for_vocoder)
481
+ synthesized_audio_np = audio_tensor.cpu().numpy()
482
+ print(f"TTS: Synthesized audio shape: {synthesized_audio_np.shape}")
483
+ else:
484
+ print("TTS: Generated mel too small, using silence")
485
+ synthesized_audio_np = np.zeros(hp.sr, dtype=np.float32) # 1 second of silence
486
+ else:
487
+ print("TTS: Generated mel is empty or too small, using silence")
488
+ synthesized_audio_np = np.zeros(hp.sr, dtype=np.float32) # 1 second of silence
489
  except Exception as e:
490
  print(f"TTS Error: {e}")
491