MoHamdyy commited on
Commit
e532f67
·
1 Parent(s): 2877500

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +19 -1
app.py CHANGED
@@ -314,6 +314,12 @@ class TransformerTTS(nn.Module):
314
  mel_padded = SOS
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
 
 
 
 
 
 
317
  iters = range(max_length)
318
  for _ in iters:
319
  mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
@@ -322,6 +328,18 @@ class TransformerTTS(nn.Module):
322
  if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
323
  break
324
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
326
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
327
  mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
@@ -446,7 +464,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
446
  try:
447
  print("TTS: Synthesizing English speech...")
448
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
449
- generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, stop_token_threshold=0.3, with_tqdm=False)
450
 
451
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
452
  if generated_mel is not None and generated_mel.numel() > 0:
 
314
  mel_padded = SOS
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
317
+
318
+ # Silence detection parameters
319
+ silence_threshold = 0.1 # Consider frames below this as silence
320
+ consecutive_silence_limit = 20 # Stop after 20 consecutive silent frames
321
+ consecutive_silence_count = 0
322
+
323
  iters = range(max_length)
324
  for _ in iters:
325
  mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
 
328
  if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
329
  break
330
 
331
+ # Check for silence in the generated mel frame
332
+ current_frame = mel_postnet[:, -1:, :]
333
+ frame_energy = torch.mean(torch.abs(current_frame))
334
+
335
+ if frame_energy < silence_threshold:
336
+ consecutive_silence_count += 1
337
+ if consecutive_silence_count >= consecutive_silence_limit:
338
+ print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames")
339
+ break
340
+ else:
341
+ consecutive_silence_count = 0 # Reset silence counter
342
+
343
  mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
344
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
345
  mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
 
464
  try:
465
  print("TTS: Synthesizing English speech...")
466
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
467
+ generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-100, stop_token_threshold=0.1, with_tqdm=False)
468
 
469
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
470
  if generated_mel is not None and generated_mel.numel() > 0: