Spaces:
Sleeping
Sleeping
Fix syntax error in TTS stage and complete pipeline
Browse files
app.py
CHANGED
|
@@ -315,30 +315,63 @@ class TransformerTTS(nn.Module):
|
|
| 315 |
mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
|
| 316 |
stop_token_outputs = torch.FloatTensor([]).to(text.device)
|
| 317 |
|
| 318 |
-
#
|
| 319 |
-
silence_threshold = 0.1
|
| 320 |
-
consecutive_silence_limit = 20
|
| 321 |
consecutive_silence_count = 0
|
| 322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
iters = range(max_length)
|
| 324 |
-
for _ in iters:
|
| 325 |
mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
|
| 326 |
|
| 327 |
-
# Check stop token BEFORE adding to mel_padded
|
| 328 |
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
|
|
|
| 329 |
break
|
| 330 |
|
| 331 |
-
# Check for silence in the generated mel frame
|
| 332 |
current_frame = mel_postnet[:, -1:, :]
|
| 333 |
frame_energy = torch.mean(torch.abs(current_frame))
|
| 334 |
|
|
|
|
| 335 |
if frame_energy < silence_threshold:
|
| 336 |
consecutive_silence_count += 1
|
| 337 |
if consecutive_silence_count >= consecutive_silence_limit:
|
| 338 |
-
print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames")
|
| 339 |
break
|
| 340 |
else:
|
| 341 |
-
consecutive_silence_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
|
| 344 |
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
|
|
|
|
| 315 |
mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
|
| 316 |
stop_token_outputs = torch.FloatTensor([]).to(text.device)
|
| 317 |
|
| 318 |
+
# More aggressive stopping parameters
|
| 319 |
+
silence_threshold = 0.2 # Increased from 0.1 to catch low-energy repetitions
|
| 320 |
+
consecutive_silence_limit = 10 # Reduced from 20 to stop faster
|
| 321 |
consecutive_silence_count = 0
|
| 322 |
|
| 323 |
+
# Repetition detection parameters
|
| 324 |
+
repetition_threshold = 0.95 # Cosine similarity threshold for detecting repeated frames
|
| 325 |
+
repetition_limit = 8 # Stop after 8 similar consecutive frames
|
| 326 |
+
repetition_count = 0
|
| 327 |
+
previous_frames = []
|
| 328 |
+
|
| 329 |
iters = range(max_length)
|
| 330 |
+
for i, _ in enumerate(iters):
|
| 331 |
mel_postnet, mel_linear, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
|
| 332 |
|
| 333 |
+
# Check stop token BEFORE adding to mel_padded (even more aggressive)
|
| 334 |
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
| 335 |
+
print(f"TTS: Stopping due to stop token at frame {i}")
|
| 336 |
break
|
| 337 |
|
|
|
|
| 338 |
current_frame = mel_postnet[:, -1:, :]
|
| 339 |
frame_energy = torch.mean(torch.abs(current_frame))
|
| 340 |
|
| 341 |
+
# Check for silence with higher threshold
|
| 342 |
if frame_energy < silence_threshold:
|
| 343 |
consecutive_silence_count += 1
|
| 344 |
if consecutive_silence_count >= consecutive_silence_limit:
|
| 345 |
+
print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames at frame {i}")
|
| 346 |
break
|
| 347 |
else:
|
| 348 |
+
consecutive_silence_count = 0
|
| 349 |
+
|
| 350 |
+
# NEW: Check for repetitive content (detecting loops)
|
| 351 |
+
if len(previous_frames) >= 3: # Start checking after a few frames
|
| 352 |
+
# Compare current frame with recent frames
|
| 353 |
+
current_flat = current_frame.flatten()
|
| 354 |
+
is_repetitive = False
|
| 355 |
+
|
| 356 |
+
for prev_frame in previous_frames[-3:]: # Check last 3 frames
|
| 357 |
+
prev_flat = prev_frame.flatten()
|
| 358 |
+
# Calculate cosine similarity
|
| 359 |
+
similarity = torch.cosine_similarity(current_flat, prev_flat, dim=0)
|
| 360 |
+
if similarity > repetition_threshold:
|
| 361 |
+
repetition_count += 1
|
| 362 |
+
is_repetitive = True
|
| 363 |
+
break
|
| 364 |
+
|
| 365 |
+
if is_repetitive and repetition_count >= repetition_limit:
|
| 366 |
+
print(f"TTS: Stopping due to repetitive content at frame {i}")
|
| 367 |
+
break
|
| 368 |
+
elif not is_repetitive:
|
| 369 |
+
repetition_count = 0 # Reset if not repetitive
|
| 370 |
+
|
| 371 |
+
# Keep track of recent frames for repetition detection
|
| 372 |
+
previous_frames.append(current_frame.clone())
|
| 373 |
+
if len(previous_frames) > 5: # Keep only last 5 frames
|
| 374 |
+
previous_frames.pop(0)
|
| 375 |
|
| 376 |
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
|
| 377 |
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
|