Spaces:
Sleeping
Updated the code to use the Whisper model when the source language is English or Tagalog; otherwise, it falls back to MMS. Additionally, the link to the synthesized speech has been updated to match the current Space.
Browse files
app.py
CHANGED
|
@@ -420,7 +420,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
| 420 |
output_path, error = synthesize_speech(translated_text, target_code)
|
| 421 |
if output_path:
|
| 422 |
output_filename = os.path.basename(output_path)
|
| 423 |
-
output_audio_url = f"https://jerich-
|
| 424 |
logger.info("TTS conversion completed")
|
| 425 |
except Exception as e:
|
| 426 |
logger.error(f"Error during TTS conversion: {str(e)}")
|
|
@@ -448,7 +448,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 448 |
logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
|
| 449 |
request_id = str(uuid.uuid4())
|
| 450 |
|
| 451 |
-
# Check if STT
|
| 452 |
if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
|
| 453 |
logger.warning("STT model not loaded, returning placeholder response")
|
| 454 |
return {
|
|
@@ -499,23 +499,44 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 499 |
# Step 3: Transcribe the audio (STT)
|
| 500 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 501 |
logger.info(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
|
| 503 |
logger.info("Audio processed, generating transcription...")
|
| 504 |
|
| 505 |
with torch.no_grad():
|
| 506 |
-
if model_status["stt"] == "loaded_whisper":
|
| 507 |
-
# Whisper model
|
| 508 |
-
|
|
|
|
| 509 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 510 |
-
|
| 511 |
-
# MMS model
|
|
|
|
| 512 |
logits = stt_model(**inputs).logits
|
| 513 |
predicted_ids = torch.argmax(logits, dim=-1)
|
| 514 |
transcription = stt_processor.batch_decode(predicted_ids)[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
logger.info(f"Transcription completed: {transcription}")
|
| 516 |
|
| 517 |
# Step 4: Translate the transcribed text (MT)
|
| 518 |
-
source_code = LANGUAGE_MAPPING[source_lang]
|
| 519 |
target_code = LANGUAGE_MAPPING[target_lang]
|
| 520 |
|
| 521 |
if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
|
|
@@ -549,7 +570,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 549 |
output_path, error = synthesize_speech(translated_text, target_code)
|
| 550 |
if output_path:
|
| 551 |
output_filename = os.path.basename(output_path)
|
| 552 |
-
output_audio_url = f"https://jerich-
|
| 553 |
logger.info("TTS conversion completed")
|
| 554 |
except Exception as e:
|
| 555 |
logger.error(f"Error during TTS conversion: {str(e)}")
|
|
@@ -603,7 +624,7 @@ async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
|
|
| 603 |
output_path, error = synthesize_speech(text, target_code)
|
| 604 |
if output_path:
|
| 605 |
output_filename = os.path.basename(output_path)
|
| 606 |
-
output_audio_url = f"https://jerich-
|
| 607 |
logger.info("TTS conversion completed")
|
| 608 |
else:
|
| 609 |
logger.error(f"TTS conversion failed: {error}")
|
|
|
|
| 420 |
output_path, error = synthesize_speech(translated_text, target_code)
|
| 421 |
if output_path:
|
| 422 |
output_filename = os.path.basename(output_path)
|
| 423 |
+
output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
|
| 424 |
logger.info("TTS conversion completed")
|
| 425 |
except Exception as e:
|
| 426 |
logger.error(f"Error during TTS conversion: {str(e)}")
|
|
|
|
| 448 |
logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
|
| 449 |
request_id = str(uuid.uuid4())
|
| 450 |
|
| 451 |
+
# Check if STT models are loaded
|
| 452 |
if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
|
| 453 |
logger.warning("STT model not loaded, returning placeholder response")
|
| 454 |
return {
|
|
|
|
| 499 |
# Step 3: Transcribe the audio (STT)
|
| 500 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 501 |
logger.info(f"Using device: {device}")
|
| 502 |
+
|
| 503 |
+
# Determine which model to use based on source language
|
| 504 |
+
source_code = LANGUAGE_MAPPING[source_lang]
|
| 505 |
+
use_whisper = source_code in ["eng", "tgl"] # Use Whisper for English and Tagalog
|
| 506 |
+
use_mms = not use_whisper # Use MMS for other Philippine languages
|
| 507 |
+
|
| 508 |
+
logger.info(f"Source language: {source_lang} ({source_code}), Using Whisper: {use_whisper}, Using MMS: {use_mms}")
|
| 509 |
+
|
| 510 |
+
# Process with appropriate model
|
| 511 |
inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
|
| 512 |
logger.info("Audio processed, generating transcription...")
|
| 513 |
|
| 514 |
with torch.no_grad():
|
| 515 |
+
if use_whisper and model_status["stt"] == "loaded_whisper":
|
| 516 |
+
# Whisper model for English and Tagalog
|
| 517 |
+
logger.info(f"Using Whisper model for {source_lang}")
|
| 518 |
+
generated_ids = stt_model.generate(**inputs, language="en" if source_code == "eng" else "tl")
|
| 519 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 520 |
+
elif model_status["stt"] in ["loaded_mms", "loaded_mms_default"]:
|
| 521 |
+
# MMS model for other Philippine languages
|
| 522 |
+
logger.info(f"Using MMS model for {source_lang}")
|
| 523 |
logits = stt_model(**inputs).logits
|
| 524 |
predicted_ids = torch.argmax(logits, dim=-1)
|
| 525 |
transcription = stt_processor.batch_decode(predicted_ids)[0]
|
| 526 |
+
else:
|
| 527 |
+
# Fallback to any available model
|
| 528 |
+
logger.info(f"Preferred model not available, using fallback model")
|
| 529 |
+
if model_status["stt"] == "loaded_whisper":
|
| 530 |
+
generated_ids = stt_model.generate(**inputs, language="en")
|
| 531 |
+
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 532 |
+
else:
|
| 533 |
+
logits = stt_model(**inputs).logits
|
| 534 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
| 535 |
+
transcription = stt_processor.batch_decode(predicted_ids)[0]
|
| 536 |
+
|
| 537 |
logger.info(f"Transcription completed: {transcription}")
|
| 538 |
|
| 539 |
# Step 4: Translate the transcribed text (MT)
|
|
|
|
| 540 |
target_code = LANGUAGE_MAPPING[target_lang]
|
| 541 |
|
| 542 |
if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
|
|
|
|
| 570 |
output_path, error = synthesize_speech(translated_text, target_code)
|
| 571 |
if output_path:
|
| 572 |
output_filename = os.path.basename(output_path)
|
| 573 |
+
output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
|
| 574 |
logger.info("TTS conversion completed")
|
| 575 |
except Exception as e:
|
| 576 |
logger.error(f"Error during TTS conversion: {str(e)}")
|
|
|
|
| 624 |
output_path, error = synthesize_speech(text, target_code)
|
| 625 |
if output_path:
|
| 626 |
output_filename = os.path.basename(output_path)
|
| 627 |
+
output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
|
| 628 |
logger.info("TTS conversion completed")
|
| 629 |
else:
|
| 630 |
logger.error(f"TTS conversion failed: {error}")
|