MoHamdyy committed on
Commit
de6f9f5
·
1 Parent(s): 3c32f04

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +32 -66
app.py CHANGED
@@ -306,80 +306,46 @@ class TransformerTTS(nn.Module):
306
  return mel_postnet, mel_linear, stop_token
307
 
308
  @torch.no_grad()
309
- def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=True):
310
- self.eval(); self.train(False)
 
311
  text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
312
  N = 1
313
  SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
 
314
  mel_padded = SOS
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
317
-
318
- # More balanced stopping parameters
319
- silence_threshold = 0.05 # Much lower - only catch true silence
320
- consecutive_silence_limit = 50 # Much higher - allow for natural pauses
321
- consecutive_silence_count = 0
322
-
323
- # Less aggressive repetition detection parameters
324
- repetition_threshold = 0.98 # Higher threshold - only catch very similar frames
325
- repetition_limit = 20 # Allow more repetitive frames before stopping
326
- repetition_count = 0
327
- previous_frames = []
328
-
329
- iters = range(max_length)
330
- for i, _ in enumerate(iters):
331
- mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
 
 
 
 
 
 
 
332
 
333
- # Check stop token BEFORE adding to mel_padded (even more aggressive)
334
- if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
335
- print(f"TTS: Stopping due to stop token at frame {i}")
336
  break
337
-
338
- current_frame = mel_postnet[:, -1:, :]
339
- frame_energy = torch.mean(torch.abs(current_frame))
340
-
341
- # Check for silence with higher threshold
342
- if frame_energy < silence_threshold:
343
- consecutive_silence_count += 1
344
- if consecutive_silence_count >= consecutive_silence_limit:
345
- print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames at frame {i}")
346
- break
347
  else:
348
- consecutive_silence_count = 0
349
-
350
- # NEW: Check for repetitive content (detecting loops)
351
- if len(previous_frames) >= 3: # Start checking after a few frames
352
- # Compare current frame with recent frames
353
- current_flat = current_frame.flatten()
354
- is_repetitive = False
355
-
356
- for prev_frame in previous_frames[-3:]: # Check last 3 frames
357
- prev_flat = prev_frame.flatten()
358
- # Calculate cosine similarity
359
- similarity = torch.cosine_similarity(current_flat, prev_flat, dim=0)
360
- if similarity > repetition_threshold:
361
- repetition_count += 1
362
- is_repetitive = True
363
- break
364
-
365
- if is_repetitive and repetition_count >= repetition_limit:
366
- print(f"TTS: Stopping due to repetitive content at frame {i}")
367
- break
368
- elif not is_repetitive:
369
- repetition_count = 0 # Reset if not repetitive
370
-
371
- # Keep track of recent frames for repetition detection
372
- previous_frames.append(current_frame.clone())
373
- if len(previous_frames) > 5: # Keep only last 5 frames
374
- previous_frames.pop(0)
375
-
376
- mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
377
- stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
378
- mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
379
-
380
- # Remove the initial SOS token and return only the generated mel
381
- generated_mel = mel_padded[:, 1:, :] # Remove first frame (SOS)
382
- return generated_mel, stop_token_outputs
383
  # --- (End of your model definitions) ---
384
 
385
  # --- Part 2: Model Loading ---
@@ -497,7 +463,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
497
  try:
498
  print("TTS: Synthesizing English speech...")
499
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
500
- generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, stop_token_threshold=0.5, with_tqdm=False)
501
 
502
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
503
  if generated_mel is not None and generated_mel.numel() > 0:
 
306
  return mel_postnet, mel_linear, stop_token
307
 
308
  @torch.no_grad()
309
+ def inference(self, text, max_length=800, gate_threshold=1e-5, with_tqdm=True):
310
+ self.eval()
311
+ self.train(False)
312
  text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
313
  N = 1
314
  SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
315
+
316
  mel_padded = SOS
317
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
318
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
319
+
320
+ if with_tqdm:
321
+ from tqdm import tqdm
322
+ iters = tqdm(range(max_length))
323
+ else:
324
+ iters = range(max_length)
325
+
326
+ for _ in iters:
327
+ mel_postnet, mel_linear, stop_token = self(
328
+ text,
329
+ text_lengths,
330
+ mel_padded,
331
+ mel_lengths
332
+ )
333
+
334
+ mel_padded = torch.cat(
335
+ [
336
+ mel_padded,
337
+ mel_postnet[:, -1:, :]
338
+ ],
339
+ dim=1
340
+ )
341
 
342
+ if torch.sigmoid(stop_token[:, -1]) > gate_threshold:
 
 
343
  break
 
 
 
 
 
 
 
 
 
 
344
  else:
345
+ stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
346
+ mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
347
+
348
+ return mel_postnet, stop_token_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  # --- (End of your model definitions) ---
350
 
351
  # --- Part 2: Model Loading ---
 
463
  try:
464
  print("TTS: Synthesizing English speech...")
465
  sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
466
+ generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, gate_threshold=1e-5, with_tqdm=False)
467
 
468
  print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
469
  if generated_mel is not None and generated_mel.numel() > 0: