MoHamdyy committed on
Commit
2fbce4b
·
1 Parent(s): e532f67

Fix syntax error in TTS stage and complete pipeline

Browse files
Files changed (1) hide show
  1. app.py +41 -8
app.py CHANGED
@@ -315,30 +315,63 @@ class TransformerTTS(nn.Module):
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
317
 
318
- # Silence detection parameters
319
- silence_threshold = 0.1 # Consider frames below this as silence
320
- consecutive_silence_limit = 20 # Stop after 20 consecutive silent frames
321
  consecutive_silence_count = 0
322
 
 
 
 
 
 
 
323
  iters = range(max_length)
324
- for _ in iters:
325
  mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
326
 
327
- # Check stop token BEFORE adding to mel_padded
328
  if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
 
329
  break
330
 
331
- # Check for silence in the generated mel frame
332
  current_frame = mel_postnet[:, -1:, :]
333
  frame_energy = torch.mean(torch.abs(current_frame))
334
 
 
335
  if frame_energy < silence_threshold:
336
  consecutive_silence_count += 1
337
  if consecutive_silence_count >= consecutive_silence_limit:
338
- print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames")
339
  break
340
  else:
341
- consecutive_silence_count = 0 # Reset silence counter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
  mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
344
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
 
315
  mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
316
  stop_token_outputs = torch.FloatTensor([]).to(text.device)
317
 
318
+ # More aggressive stopping parameters
319
+ silence_threshold = 0.2 # Increased from 0.1 to catch low-energy repetitions
320
+ consecutive_silence_limit = 10 # Reduced from 20 to stop faster
321
  consecutive_silence_count = 0
322
 
323
+ # Repetition detection parameters
324
+ repetition_threshold = 0.95 # Cosine similarity threshold for detecting repeated frames
325
+ repetition_limit = 8 # Stop after 8 similar consecutive frames
326
+ repetition_count = 0
327
+ previous_frames = []
328
+
329
  iters = range(max_length)
330
+ for i, _ in enumerate(iters):
331
  mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
332
 
333
+ # Check stop token BEFORE adding to mel_padded (even more aggressive)
334
  if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
335
+ print(f"TTS: Stopping due to stop token at frame {i}")
336
  break
337
 
 
338
  current_frame = mel_postnet[:, -1:, :]
339
  frame_energy = torch.mean(torch.abs(current_frame))
340
 
341
+ # Check for silence with higher threshold
342
  if frame_energy < silence_threshold:
343
  consecutive_silence_count += 1
344
  if consecutive_silence_count >= consecutive_silence_limit:
345
+ print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames at frame {i}")
346
  break
347
  else:
348
+ consecutive_silence_count = 0
349
+
350
+ # NEW: Check for repetitive content (detecting loops)
351
+ if len(previous_frames) >= 3: # Start checking after a few frames
352
+ # Compare current frame with recent frames
353
+ current_flat = current_frame.flatten()
354
+ is_repetitive = False
355
+
356
+ for prev_frame in previous_frames[-3:]: # Check last 3 frames
357
+ prev_flat = prev_frame.flatten()
358
+ # Calculate cosine similarity
359
+ similarity = torch.cosine_similarity(current_flat, prev_flat, dim=0)
360
+ if similarity > repetition_threshold:
361
+ repetition_count += 1
362
+ is_repetitive = True
363
+ break
364
+
365
+ if is_repetitive and repetition_count >= repetition_limit:
366
+ print(f"TTS: Stopping due to repetitive content at frame {i}")
367
+ break
368
+ elif not is_repetitive:
369
+ repetition_count = 0 # Reset if not repetitive
370
+
371
+ # Keep track of recent frames for repetition detection
372
+ previous_frames.append(current_frame.clone())
373
+ if len(previous_frames) > 5: # Keep only last 5 frames
374
+ previous_frames.pop(0)
375
 
376
  mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
377
  stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)