MoHamdyy committed on
Commit
8fbabff
·
1 Parent(s): 2fbce4b

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -315,14 +315,14 @@ class TransformerTTS(nn.Module):
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
317
 
318
- # More aggressive stopping parameters
319
- silence_threshold = 0.2 # Increased from 0.1 to catch low-energy repetitions
320
- consecutive_silence_limit = 10 # Reduced from 20 to stop faster
321
  consecutive_silence_count = 0
322
 
323
- # Repetition detection parameters
324
- repetition_threshold = 0.95 # Cosine similarity threshold for detecting repeated frames
325
- repetition_limit = 8 # Stop after 8 similar consecutive frames
326
  repetition_count = 0
327
  previous_frames = []
328
 
@@ -497,7 +497,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
497
  try:
498
  print("TTS: Synthesizing English speech...")
499
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
500
- generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-100, stop_token_threshold=0.1, with_tqdm=False)
501
 
502
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
503
  if generated_mel is not None and generated_mel.numel() > 0:
 
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
317
 
318
+ # More balanced stopping parameters
319
+ silence_threshold = 0.05 # Much lower - only catch true silence
320
+ consecutive_silence_limit = 50 # Much higher - allow for natural pauses
321
  consecutive_silence_count = 0
322
 
323
+ # Less aggressive repetition detection parameters
324
+ repetition_threshold = 0.98 # Higher threshold - only catch very similar frames
325
+ repetition_limit = 20 # Allow more repetitive frames before stopping
326
  repetition_count = 0
327
  previous_frames = []
328
 
 
497
  try:
498
  print("TTS: Synthesizing English speech...")
499
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
500
+ generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, stop_token_threshold=0.2, with_tqdm=False)
501
 
502
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
503
  if generated_mel is not None and generated_mel.numel() > 0: