Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -116,7 +116,7 @@ def load_models_task():
|
|
| 116 |
logger.info("TTS model loaded successfully")
|
| 117 |
model_status["tts"] = "loaded"
|
| 118 |
except Exception as e:
|
| 119 |
-
logger.error(f"Failed to load TTS model: {str(e)}")
|
| 120 |
# Fallback to English TTS if the target language fails
|
| 121 |
try:
|
| 122 |
logger.info("Falling back to MMS-TTS English model...")
|
|
@@ -304,8 +304,12 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 304 |
temp_file.write(await audio.read())
|
| 305 |
temp_path = temp_file.name
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
try:
|
| 308 |
-
#
|
| 309 |
logger.info(f"Reading audio file: {temp_path}")
|
| 310 |
waveform, sample_rate = sf.read(temp_path)
|
| 311 |
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
|
@@ -313,7 +317,6 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 313 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
| 314 |
waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
|
| 315 |
|
| 316 |
-
# Process the audio with Whisper (STT)
|
| 317 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 318 |
logger.info(f"Using device: {device}")
|
| 319 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
|
@@ -323,10 +326,9 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 323 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 324 |
logger.info(f"Transcription completed: {transcription}")
|
| 325 |
|
| 326 |
-
# Translate the transcribed text
|
| 327 |
source_code = LANGUAGE_MAPPING[source_lang]
|
| 328 |
target_code = LANGUAGE_MAPPING[target_lang]
|
| 329 |
-
translated_text = "Translation not available"
|
| 330 |
|
| 331 |
if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
|
| 332 |
try:
|
|
@@ -348,11 +350,43 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
|
|
| 348 |
else:
|
| 349 |
logger.warning("MT model not loaded, skipping translation")
|
| 350 |
|
| 351 |
-
# Convert translated text to speech
|
| 352 |
-
output_audio = None
|
| 353 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
| 354 |
try:
|
| 355 |
inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
| 356 |
with torch.no_grad():
|
| 357 |
output = tts_model(**inputs)
|
| 358 |
-
speech = output.waveform.cpu().numpy().squeeze
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
logger.info("TTS model loaded successfully")
|
| 117 |
model_status["tts"] = "loaded"
|
| 118 |
except Exception as e:
|
| 119 |
+
logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
|
| 120 |
# Fallback to English TTS if the target language fails
|
| 121 |
try:
|
| 122 |
logger.info("Falling back to MMS-TTS English model...")
|
|
|
|
| 304 |
temp_file.write(await audio.read())
|
| 305 |
temp_path = temp_file.name
|
| 306 |
|
| 307 |
+
transcription = "Transcription not available"
|
| 308 |
+
translated_text = "Translation not available"
|
| 309 |
+
output_audio = None
|
| 310 |
+
|
| 311 |
try:
|
| 312 |
+
# Step 1: Transcribe the audio (STT)
|
| 313 |
logger.info(f"Reading audio file: {temp_path}")
|
| 314 |
waveform, sample_rate = sf.read(temp_path)
|
| 315 |
logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
|
|
|
|
| 317 |
logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
|
| 318 |
waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
|
| 319 |
|
|
|
|
| 320 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 321 |
logger.info(f"Using device: {device}")
|
| 322 |
inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(device)
|
|
|
|
| 326 |
transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 327 |
logger.info(f"Transcription completed: {transcription}")
|
| 328 |
|
| 329 |
+
# Step 2: Translate the transcribed text (MT)
|
| 330 |
source_code = LANGUAGE_MAPPING[source_lang]
|
| 331 |
target_code = LANGUAGE_MAPPING[target_lang]
|
|
|
|
| 332 |
|
| 333 |
if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
|
| 334 |
try:
|
|
|
|
| 350 |
else:
|
| 351 |
logger.warning("MT model not loaded, skipping translation")
|
| 352 |
|
| 353 |
+
# Step 3: Convert translated text to speech (TTS)
|
|
|
|
| 354 |
if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
|
| 355 |
try:
|
| 356 |
inputs = tts_tokenizer(translated_text, return_tensors="pt", clean_up_tokenization_spaces=True).to(device)
|
| 357 |
with torch.no_grad():
|
| 358 |
output = tts_model(**inputs)
|
| 359 |
+
speech = output.waveform.cpu().numpy().squeeze()
|
| 360 |
+
speech = (speech * 32767).astype(np.int16)
|
| 361 |
+
output_audio = (tts_model.config.sampling_rate, speech.tolist())
|
| 362 |
+
logger.info("TTS conversion completed")
|
| 363 |
+
except Exception as e:
|
| 364 |
+
logger.error(f"Error during TTS conversion: {str(e)}")
|
| 365 |
+
output_audio = None
|
| 366 |
+
|
| 367 |
+
return {
|
| 368 |
+
"request_id": request_id,
|
| 369 |
+
"status": "completed",
|
| 370 |
+
"message": "Transcription, translation, and TTS completed (or partially completed).",
|
| 371 |
+
"source_text": transcription,
|
| 372 |
+
"translated_text": translated_text,
|
| 373 |
+
"output_audio": output_audio
|
| 374 |
+
}
|
| 375 |
+
except Exception as e:
|
| 376 |
+
logger.error(f"Error during processing: {str(e)}")
|
| 377 |
+
return {
|
| 378 |
+
"request_id": request_id,
|
| 379 |
+
"status": "failed",
|
| 380 |
+
"message": f"Processing failed: {str(e)}",
|
| 381 |
+
"source_text": transcription,
|
| 382 |
+
"translated_text": translated_text,
|
| 383 |
+
"output_audio": output_audio
|
| 384 |
+
}
|
| 385 |
+
finally:
|
| 386 |
+
logger.info(f"Cleaning up temporary file: {temp_path}")
|
| 387 |
+
os.unlink(temp_path)
|
| 388 |
+
|
| 389 |
+
if __name__ == "__main__":
|
| 390 |
+
import uvicorn
|
| 391 |
+
logger.info("Starting Uvicorn server...")
|
| 392 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|