Spaces:
Sleeping
Sleeping
Fix syntax error in TTS stage and complete pipeline
Browse files
app.py
CHANGED
|
@@ -306,80 +306,46 @@ class TransformerTTS(nn.Module):
|
|
| 306 |
return mel_postnet, mel_linear, stop_token
|
| 307 |
|
| 308 |
@torch.no_grad()
|
| 309 |
-
def inference(self, text, max_length=800,
|
| 310 |
-
self.eval()
|
|
|
|
| 311 |
text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
|
| 312 |
N = 1
|
| 313 |
SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)
|
|
|
|
| 314 |
mel_padded = SOS
|
| 315 |
mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
|
| 316 |
stop_token_outputs = torch.FloatTensor([]).to(text.device)
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
|
| 333 |
-
|
| 334 |
-
if torch.sigmoid(stop_token[:, -1]) > stop_token_threshold:
|
| 335 |
-
print(f"TTS: Stopping due to stop token at frame {i}")
|
| 336 |
break
|
| 337 |
-
|
| 338 |
-
current_frame = mel_postnet[:, -1:, :]
|
| 339 |
-
frame_energy = torch.mean(torch.abs(current_frame))
|
| 340 |
-
|
| 341 |
-
# Check for silence with higher threshold
|
| 342 |
-
if frame_energy < silence_threshold:
|
| 343 |
-
consecutive_silence_count += 1
|
| 344 |
-
if consecutive_silence_count >= consecutive_silence_limit:
|
| 345 |
-
print(f"TTS: Stopping due to {consecutive_silence_limit} consecutive silent frames at frame {i}")
|
| 346 |
-
break
|
| 347 |
else:
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
# Compare current frame with recent frames
|
| 353 |
-
current_flat = current_frame.flatten()
|
| 354 |
-
is_repetitive = False
|
| 355 |
-
|
| 356 |
-
for prev_frame in previous_frames[-3:]: # Check last 3 frames
|
| 357 |
-
prev_flat = prev_frame.flatten()
|
| 358 |
-
# Calculate cosine similarity
|
| 359 |
-
similarity = torch.cosine_similarity(current_flat, prev_flat, dim=0)
|
| 360 |
-
if similarity > repetition_threshold:
|
| 361 |
-
repetition_count += 1
|
| 362 |
-
is_repetitive = True
|
| 363 |
-
break
|
| 364 |
-
|
| 365 |
-
if is_repetitive and repetition_count >= repetition_limit:
|
| 366 |
-
print(f"TTS: Stopping due to repetitive content at frame {i}")
|
| 367 |
-
break
|
| 368 |
-
elif not is_repetitive:
|
| 369 |
-
repetition_count = 0 # Reset if not repetitive
|
| 370 |
-
|
| 371 |
-
# Keep track of recent frames for repetition detection
|
| 372 |
-
previous_frames.append(current_frame.clone())
|
| 373 |
-
if len(previous_frames) > 5: # Keep only last 5 frames
|
| 374 |
-
previous_frames.pop(0)
|
| 375 |
-
|
| 376 |
-
mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
|
| 377 |
-
stop_token_outputs = torch.cat([stop_token_outputs, stop_token[:, -1:]], dim=1)
|
| 378 |
-
mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)
|
| 379 |
-
|
| 380 |
-
# Remove the initial SOS token and return only the generated mel
|
| 381 |
-
generated_mel = mel_padded[:, 1:, :] # Remove first frame (SOS)
|
| 382 |
-
return generated_mel, stop_token_outputs
|
| 383 |
# --- (End of your model definitions) ---
|
| 384 |
|
| 385 |
# --- Part 2: Model Loading ---
|
|
@@ -497,7 +463,7 @@ def full_speech_translation_pipeline(audio_input_path: str):
|
|
| 497 |
try:
|
| 498 |
print("TTS: Synthesizing English speech...")
|
| 499 |
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
|
| 500 |
-
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50,
|
| 501 |
|
| 502 |
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
|
| 503 |
if generated_mel is not None and generated_mel.numel() > 0:
|
|
|
|
| 306 |
return mel_postnet, mel_linear, stop_token
|
| 307 |
|
| 308 |
@torch.no_grad()
def inference(self, text, max_length=800, gate_threshold=1e-5, with_tqdm=True):
    """Autoregressively synthesize a mel spectrogram from tokenized text.

    Decoding starts from a single all-zero SOS frame and repeatedly runs the
    full (non-incremental) forward pass, appending the last predicted postnet
    frame, until either ``max_length`` iterations elapse or the stop-token
    gate fires.

    Args:
        text: token-id tensor of shape (1, text_len), already on DEVICE.
        max_length: maximum number of decoding iterations.
        gate_threshold: decoding stops when sigmoid(stop logit) exceeds this.
            NOTE(review): 1e-5 is extremely permissive — sigmoid output
            exceeds it for any logit above roughly -11.5, so the loop stops
            as soon as the stop logit is not strongly negative. Confirm this
            matches how the stop token was trained.
        with_tqdm: if True, wrap the decoding loop in a tqdm progress bar.

    Returns:
        Tuple of (mel_postnet, stop_token_outputs): the postnet mel output of
        the final forward pass (covers the whole generated sequence including
        the frame derived from SOS) and the concatenated per-frame stop-token
        logits (the frame that triggered the stop is not included).
    """
    # eval() already sets training=False recursively, so the original
    # follow-up self.train(False) was redundant and has been removed.
    self.eval()

    text_lengths = torch.tensor(text.shape[1]).unsqueeze(0).to(DEVICE)
    N = 1
    # All-zero start-of-sequence frame seeds the autoregressive decoder.
    SOS = torch.zeros((N, 1, hp.mel_freq), device=DEVICE)

    mel_padded = SOS
    mel_lengths = torch.tensor(1).unsqueeze(0).to(DEVICE)
    stop_token_outputs = torch.FloatTensor([]).to(text.device)

    if with_tqdm:
        # Local import keeps tqdm optional when with_tqdm=False.
        from tqdm import tqdm
        iters = tqdm(range(max_length))
    else:
        iters = range(max_length)

    for _ in iters:
        # Re-run the full forward pass over everything generated so far.
        # O(T^2) overall, but simple and matches training-time behavior.
        mel_postnet, mel_linear, stop_token = self(
            text,
            text_lengths,
            mel_padded,
            mel_lengths,
        )

        # Append the newest predicted frame to the decoder input.
        mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)

        if torch.sigmoid(stop_token[:, -1]) > gate_threshold:
            break
        stop_token_outputs = torch.cat(
            [stop_token_outputs, stop_token[:, -1:]], dim=1
        )
        mel_lengths = torch.tensor(mel_padded.shape[1]).unsqueeze(0).to(DEVICE)

    return mel_postnet, stop_token_outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
# --- (End of your model definitions) ---
|
| 350 |
|
| 351 |
# --- Part 2: Model Loading ---
|
|
|
|
| 463 |
try:
|
| 464 |
print("TTS: Synthesizing English speech...")
|
| 465 |
sequence = text_to_seq(english_translation).unsqueeze(0).to(DEVICE)
|
| 466 |
+
generated_mel, _ = TTS_MODEL.inference(sequence, max_length=hp.max_mel_time-50, gate_threshold=1e-5, with_tqdm=False)
|
| 467 |
|
| 468 |
print(f"TTS: Generated mel shape: {generated_mel.shape if generated_mel is not None else 'None'}")
|
| 469 |
if generated_mel is not None and generated_mel.numel() > 0:
|