Spaces:
Paused
Paused
Fix TTS language issue by dynamically updating the TTS model in the /translate-text and /translate-audio endpoints
Browse files
app.py
CHANGED
|
@@ -39,6 +39,7 @@ model_status = {
|
|
| 39 |
"tts": "not_loaded"
|
| 40 |
}
|
| 41 |
error_message = None
|
|
|
|
| 42 |
|
| 43 |
# Model instances
|
| 44 |
stt_processor = None
|
|
@@ -179,7 +180,7 @@ def load_models_task():
|
|
| 179 |
error_message = f"MT model loading failed: {str(e)}"
|
| 180 |
return
|
| 181 |
|
| 182 |
-
# Load TTS model (default to Tagalog, will be updated
|
| 183 |
logger.info("Starting to load TTS model...")
|
| 184 |
from transformers import VitsModel, AutoTokenizer
|
| 185 |
|
|
@@ -201,6 +202,7 @@ def load_models_task():
|
|
| 201 |
tts_model.to(device)
|
| 202 |
logger.info("Fallback TTS model loaded successfully")
|
| 203 |
model_status["tts"] = "loaded (fallback)"
|
|
|
|
| 204 |
except Exception as e2:
|
| 205 |
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
| 206 |
model_status["tts"] = "failed"
|
|
@@ -259,7 +261,7 @@ async def health_check():
|
|
| 259 |
|
| 260 |
@app.post("/update-languages")
|
| 261 |
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
|
| 262 |
-
global stt_processor, stt_model, tts_model, tts_tokenizer
|
| 263 |
|
| 264 |
if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
|
| 265 |
raise HTTPException(status_code=400, detail="Invalid language selected")
|
|
@@ -314,6 +316,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
| 314 |
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 315 |
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 316 |
tts_model.to(device)
|
|
|
|
| 317 |
logger.info(f"TTS model updated to {target_code}")
|
| 318 |
model_status["tts"] = "loaded"
|
| 319 |
except Exception as e:
|
|
@@ -323,6 +326,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
| 323 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
| 324 |
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
| 325 |
tts_model.to(device)
|
|
|
|
| 326 |
logger.info("Fallback TTS model loaded successfully")
|
| 327 |
model_status["tts"] = "loaded (fallback)"
|
| 328 |
except Exception as e2:
|
|
@@ -337,7 +341,7 @@ async def update_languages(source_lang: str = Form(...), target_lang: str = Form
|
|
| 337 |
@app.post("/translate-text")
|
| 338 |
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
| 339 |
"""Endpoint to translate text and convert to speech"""
|
| 340 |
-
global mt_model, mt_tokenizer, tts_model, tts_tokenizer
|
| 341 |
|
| 342 |
if not text:
|
| 343 |
raise HTTPException(status_code=400, detail="No text provided")
|
|
@@ -373,6 +377,31 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
| 373 |
else:
|
| 374 |
logger.warning("MT model not loaded, skipping translation")
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
# Convert translated text to speech
|
| 377 |
output_audio_url = None
|
| 378 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
|
@@ -409,7 +438,7 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
|
|
| 409 |
@app.post("/translate-audio")
|
| 410 |
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
| 411 |
"""Endpoint to transcribe, translate, and convert audio to speech"""
|
| 412 |
-
global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
|
| 413 |
|
| 414 |
if not audio:
|
| 415 |
raise HTTPException(status_code=400, detail="No audio file provided")
|
|
@@ -506,7 +535,32 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 506 |
else:
|
| 507 |
logger.warning("MT model not loaded, skipping translation")
|
| 508 |
|
| 509 |
-
# Step 5: Convert translated text to speech (TTS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 510 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
| 511 |
try:
|
| 512 |
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|
|
|
|
| 39 |
"tts": "not_loaded"
|
| 40 |
}
|
| 41 |
error_message = None
|
| 42 |
+
current_tts_language = "tgl" # Track the current TTS language
|
| 43 |
|
| 44 |
# Model instances
|
| 45 |
stt_processor = None
|
|
|
|
| 180 |
error_message = f"MT model loading failed: {str(e)}"
|
| 181 |
return
|
| 182 |
|
| 183 |
+
# Load TTS model (default to Tagalog, will be updated dynamically)
|
| 184 |
logger.info("Starting to load TTS model...")
|
| 185 |
from transformers import VitsModel, AutoTokenizer
|
| 186 |
|
|
|
|
| 202 |
tts_model.to(device)
|
| 203 |
logger.info("Fallback TTS model loaded successfully")
|
| 204 |
model_status["tts"] = "loaded (fallback)"
|
| 205 |
+
current_tts_language = "eng"
|
| 206 |
except Exception as e2:
|
| 207 |
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
| 208 |
model_status["tts"] = "failed"
|
|
|
|
| 261 |
|
| 262 |
@app.post("/update-languages")
|
| 263 |
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
|
| 264 |
+
global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language
|
| 265 |
|
| 266 |
if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
|
| 267 |
raise HTTPException(status_code=400, detail="Invalid language selected")
|
|
|
|
| 316 |
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 317 |
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 318 |
tts_model.to(device)
|
| 319 |
+
current_tts_language = target_code
|
| 320 |
logger.info(f"TTS model updated to {target_code}")
|
| 321 |
model_status["tts"] = "loaded"
|
| 322 |
except Exception as e:
|
|
|
|
| 326 |
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
| 327 |
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
| 328 |
tts_model.to(device)
|
| 329 |
+
current_tts_language = "eng"
|
| 330 |
logger.info("Fallback TTS model loaded successfully")
|
| 331 |
model_status["tts"] = "loaded (fallback)"
|
| 332 |
except Exception as e2:
|
|
|
|
| 341 |
@app.post("/translate-text")
|
| 342 |
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
| 343 |
"""Endpoint to translate text and convert to speech"""
|
| 344 |
+
global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
|
| 345 |
|
| 346 |
if not text:
|
| 347 |
raise HTTPException(status_code=400, detail="No text provided")
|
|
|
|
| 377 |
else:
|
| 378 |
logger.warning("MT model not loaded, skipping translation")
|
| 379 |
|
| 380 |
+
# Update TTS model if the target language doesn't match the current TTS language
|
| 381 |
+
if current_tts_language != target_code:
|
| 382 |
+
try:
|
| 383 |
+
logger.info(f"Updating TTS model for {target_code}...")
|
| 384 |
+
from transformers import VitsModel, AutoTokenizer
|
| 385 |
+
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 386 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 387 |
+
tts_model.to(device)
|
| 388 |
+
current_tts_language = target_code
|
| 389 |
+
logger.info(f"TTS model updated to {target_code}")
|
| 390 |
+
model_status["tts"] = "loaded"
|
| 391 |
+
except Exception as e:
|
| 392 |
+
logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
|
| 393 |
+
try:
|
| 394 |
+
logger.info("Falling back to MMS-TTS English model...")
|
| 395 |
+
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
| 396 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
| 397 |
+
tts_model.to(device)
|
| 398 |
+
current_tts_language = "eng"
|
| 399 |
+
logger.info("Fallback TTS model loaded successfully")
|
| 400 |
+
model_status["tts"] = "loaded (fallback)"
|
| 401 |
+
except Exception as e2:
|
| 402 |
+
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
| 403 |
+
model_status["tts"] = "failed"
|
| 404 |
+
|
| 405 |
# Convert translated text to speech
|
| 406 |
output_audio_url = None
|
| 407 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
|
|
|
| 438 |
@app.post("/translate-audio")
|
| 439 |
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
|
| 440 |
"""Endpoint to transcribe, translate, and convert audio to speech"""
|
| 441 |
+
global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
|
| 442 |
|
| 443 |
if not audio:
|
| 444 |
raise HTTPException(status_code=400, detail="No audio file provided")
|
|
|
|
| 535 |
else:
|
| 536 |
logger.warning("MT model not loaded, skipping translation")
|
| 537 |
|
| 538 |
+
# Step 5: Update TTS model if the target language doesn't match the current TTS language
|
| 539 |
+
if current_tts_language != target_code:
|
| 540 |
+
try:
|
| 541 |
+
logger.info(f"Updating TTS model for {target_code}...")
|
| 542 |
+
from transformers import VitsModel, AutoTokenizer
|
| 543 |
+
tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 544 |
+
tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
|
| 545 |
+
tts_model.to(device)
|
| 546 |
+
current_tts_language = target_code
|
| 547 |
+
logger.info(f"TTS model updated to {target_code}")
|
| 548 |
+
model_status["tts"] = "loaded"
|
| 549 |
+
except Exception as e:
|
| 550 |
+
logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
|
| 551 |
+
try:
|
| 552 |
+
logger.info("Falling back to MMS-TTS English model...")
|
| 553 |
+
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
|
| 554 |
+
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
|
| 555 |
+
tts_model.to(device)
|
| 556 |
+
current_tts_language = "eng"
|
| 557 |
+
logger.info("Fallback TTS model loaded successfully")
|
| 558 |
+
model_status["tts"] = "loaded (fallback)"
|
| 559 |
+
except Exception as e2:
|
| 560 |
+
logger.error(f"Failed to load fallback TTS model: {str(e2)}")
|
| 561 |
+
model_status["tts"] = "failed"
|
| 562 |
+
|
| 563 |
+
# Step 6: Convert translated text to speech (TTS)
|
| 564 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
| 565 |
try:
|
| 566 |
inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
|