Jerich committed on
Commit
1b4b3a1
·
verified ·
1 Parent(s): 0b22bab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +307 -259
app.py CHANGED
@@ -12,10 +12,11 @@ import soundfile as sf
12
  import torchaudio
13
  import wave
14
  import time
 
15
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
16
  from fastapi.responses import JSONResponse
17
  from fastapi.staticfiles import StaticFiles
18
- from typing import Dict, Any, Optional, Tuple
19
  from datetime import datetime, timedelta
20
 
21
  # Configure logging
@@ -34,20 +35,12 @@ models_loaded = False
34
  loading_in_progress = False
35
  loading_thread = None
36
  model_status = {
37
- "stt": "not_loaded",
 
38
  "mt": "not_loaded",
39
- "tts": "not_loaded"
40
  }
41
  error_message = None
42
- current_tts_language = "tgl" # Track the current TTS language
43
-
44
- # Model instances
45
- stt_processor = None
46
- stt_model = None
47
- mt_model = None
48
- mt_tokenizer = None
49
- tts_model = None
50
- tts_tokenizer = None
51
 
52
  # Define the valid languages and mappings
53
  LANGUAGE_MAPPING = {
@@ -68,6 +61,31 @@ NLLB_LANGUAGE_CODES = {
68
  "pag": "pag_Latn"
69
  }
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # Function to save PCM data as a WAV file
72
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
73
  # Convert pcm_data to a NumPy array of 16-bit integers
@@ -105,6 +123,53 @@ def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0
105
  # For now, we assume if RMS is above threshold, there is speech
106
  return True
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Function to clean up old audio files
109
  def cleanup_old_audio_files():
110
  logger.info("Starting cleanup of old audio files...")
@@ -129,38 +194,46 @@ def schedule_cleanup():
129
  # Function to load models in background
130
  def load_models_task():
131
  global models_loaded, loading_in_progress, model_status, error_message
132
- global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
133
 
134
  try:
135
  loading_in_progress = True
 
136
 
137
- # Load STT model (MMS with fallback to Whisper)
138
- logger.info("Starting to load STT model...")
139
- from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
140
 
 
141
  try:
142
  logger.info("Loading MMS STT model...")
143
- model_status["stt"] = "loading"
144
- stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
145
- stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
146
- device = "cuda" if torch.cuda.is_available() else "cpu"
147
- stt_model.to(device)
 
148
  logger.info("MMS STT model loaded successfully")
149
- model_status["stt"] = "loaded_mms"
150
  except Exception as mms_error:
151
  logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
152
- logger.info("Falling back to Whisper STT model...")
153
- try:
154
- stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
155
- stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
156
- stt_model.to(device)
157
- logger.info("Whisper STT model loaded successfully as fallback")
158
- model_status["stt"] = "loaded_whisper"
159
- except Exception as whisper_error:
160
- logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
161
- model_status["stt"] = "failed"
162
- error_message = f"STT model loading failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
163
- return
 
 
 
 
 
 
164
 
165
  # Load MT model
166
  logger.info("Starting to load MT model...")
@@ -178,40 +251,62 @@ def load_models_task():
178
  logger.error(f"Failed to load MT model: {str(e)}")
179
  model_status["mt"] = "failed"
180
  error_message = f"MT model loading failed: {str(e)}"
181
- return
182
-
183
- # Load TTS model (default to Tagalog, will be updated dynamically)
184
- logger.info("Starting to load TTS model...")
185
  from transformers import VitsModel, AutoTokenizer
186
 
187
- try:
188
- logger.info("Loading MMS-TTS model for Tagalog...")
189
- model_status["tts"] = "loading"
190
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
191
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
192
- tts_model.to(device)
193
- logger.info("TTS model loaded successfully")
194
- model_status["tts"] = "loaded"
195
- except Exception as e:
196
- logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
197
- # Fallback to English TTS if the target language fails
198
  try:
199
- logger.info("Falling back to MMS-TTS English model...")
200
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
201
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
202
- tts_model.to(device)
203
- logger.info("Fallback TTS model loaded successfully")
204
- model_status["tts"] = "loaded (fallback)"
205
- current_tts_language = "eng"
206
- except Exception as e2:
207
- logger.error(f"Failed to load fallback TTS model: {str(e2)}")
208
- model_status["tts"] = "failed"
209
- error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
210
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- models_loaded = True
213
- logger.info("Model loading completed successfully")
 
 
 
 
214
 
 
 
 
 
 
 
 
215
  except Exception as e:
216
  error_message = str(e)
217
  logger.error(f"Error in model loading task: {str(e)}")
@@ -221,7 +316,7 @@ def load_models_task():
221
  # Start loading models in background
222
  def start_model_loading():
223
  global loading_thread, loading_in_progress
224
- if not loading_in_progress and not models_loaded:
225
  loading_in_progress = True
226
  loading_thread = threading.Thread(target=load_models_task)
227
  loading_thread.daemon = True
@@ -259,89 +354,61 @@ async def health_check():
259
  "error": error_message
260
  }
261
 
262
- @app.post("/update-languages")
263
- async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
264
- global stt_processor, stt_model, tts_model, tts_tokenizer, current_tts_language
265
-
266
- if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
267
  raise HTTPException(status_code=400, detail="Invalid language selected")
268
 
269
- source_code = LANGUAGE_MAPPING[source_lang]
270
- target_code = LANGUAGE_MAPPING[target_lang]
 
271
 
272
- # Update the STT model based on the source language (MMS or Whisper)
273
- try:
274
- logger.info("Updating STT model for source language...")
275
- from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
276
- device = "cuda" if torch.cuda.is_available() else "cpu"
277
-
278
- try:
279
- logger.info(f"Loading MMS STT model for {source_code}...")
280
- stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
281
- stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
282
- stt_model.to(device)
283
- # Set the target language for MMS
284
- if source_code in stt_processor.tokenizer.vocab.keys():
285
- stt_processor.tokenizer.set_target_lang(source_code)
286
- stt_model.load_adapter(source_code)
287
- logger.info(f"MMS STT model updated to {source_code}")
288
- model_status["stt"] = "loaded_mms"
289
- else:
290
- logger.warning(f"Language {source_code} not supported by MMS, using default")
291
- model_status["stt"] = "loaded_mms_default"
292
- except Exception as mms_error:
293
- logger.error(f"Failed to load MMS STT model for {source_code}: {str(mms_error)}")
294
- logger.info("Falling back to Whisper STT model...")
295
- try:
296
- stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
297
- stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
298
- stt_model.to(device)
299
- logger.info("Whisper STT model loaded successfully as fallback")
300
- model_status["stt"] = "loaded_whisper"
301
- except Exception as whisper_error:
302
- logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
303
- model_status["stt"] = "failed"
304
- error_message = f"STT model update failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
305
- return {"status": "failed", "error": error_message}
306
- except Exception as e:
307
- logger.error(f"Error updating STT model: {str(e)}")
308
- model_status["stt"] = "failed"
309
- error_message = f"STT model update failed: {str(e)}"
310
- return {"status": "failed", "error": error_message}
311
 
312
- # Update the TTS model based on the target language
313
- try:
314
- logger.info(f"Loading MMS-TTS model for {target_code}...")
315
- from transformers import VitsModel, AutoTokenizer
316
- tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
317
- tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
318
- tts_model.to(device)
319
- current_tts_language = target_code
320
- logger.info(f"TTS model updated to {target_code}")
321
- model_status["tts"] = "loaded"
322
- except Exception as e:
323
- logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
324
- try:
325
- logger.info("Falling back to MMS-TTS English model...")
326
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
327
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
328
- tts_model.to(device)
329
- current_tts_language = "eng"
330
- logger.info("Fallback TTS model loaded successfully")
331
- model_status["tts"] = "loaded (fallback)"
332
- except Exception as e2:
333
- logger.error(f"Failed to load fallback TTS model: {str(e2)}")
334
- model_status["tts"] = "failed"
335
- error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
336
- return {"status": "failed", "error": error_message}
337
 
338
- logger.info(f"Updating languages: {source_lang} {target_lang}")
339
- return {"status": f"Languages updated to {source_lang} → {target_lang}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  @app.post("/translate-text")
342
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
343
  """Endpoint to translate text and convert to speech"""
344
- global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
345
 
346
  if not text:
347
  raise HTTPException(status_code=400, detail="No text provided")
@@ -376,55 +443,23 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
376
  translated_text = f"Translation failed: {str(e)}"
377
  else:
378
  logger.warning("MT model not loaded, skipping translation")
379
-
380
- # Update TTS model if the target language doesn't match the current TTS language
381
- if current_tts_language != target_code:
382
- try:
383
- logger.info(f"Updating TTS model for {target_code}...")
384
- from transformers import VitsModel, AutoTokenizer
385
- tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
386
- tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
387
- tts_model.to(device)
388
- current_tts_language = target_code
389
- logger.info(f"TTS model updated to {target_code}")
390
- model_status["tts"] = "loaded"
391
- except Exception as e:
392
- logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
393
- try:
394
- logger.info("Falling back to MMS-TTS English model...")
395
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
396
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
397
- tts_model.to(device)
398
- current_tts_language = "eng"
399
- logger.info("Fallback TTS model loaded successfully")
400
- model_status["tts"] = "loaded (fallback)"
401
- except Exception as e2:
402
- logger.error(f"Failed to load fallback TTS model: {str(e2)}")
403
- model_status["tts"] = "failed"
404
-
405
  # Convert translated text to speech
 
 
406
  output_audio_url = None
407
- if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
408
- try:
409
- inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
410
- with torch.no_grad():
411
- output = tts_model(**inputs)
412
- speech = output.waveform.cpu().numpy().squeeze()
413
- speech = (speech * 32767).astype(np.int16)
414
- sample_rate = tts_model.config.sampling_rate
415
-
416
- # Save the audio as a WAV file
417
- output_filename = f"{request_id}.wav"
418
- output_path = os.path.join(AUDIO_DIR, output_filename)
419
- save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
420
- logger.info(f"Saved synthesized audio to {output_path}")
421
-
422
- # Generate a URL to the WAV file
423
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
424
- logger.info("TTS conversion completed")
425
- except Exception as e:
426
- logger.error(f"Error during TTS conversion: {str(e)}")
427
- output_audio_url = None
428
 
429
  return {
430
  "request_id": request_id,
@@ -432,13 +467,14 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
432
  "message": "Translation and TTS completed (or partially completed).",
433
  "source_text": text,
434
  "translated_text": translated_text,
435
- "output_audio": output_audio_url
 
436
  }
437
 
438
  @app.post("/translate-audio")
439
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
440
  """Endpoint to transcribe, translate, and convert audio to speech"""
441
- global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
442
 
443
  if not audio:
444
  raise HTTPException(status_code=400, detail="No audio file provided")
@@ -448,17 +484,38 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
448
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
449
  request_id = str(uuid.uuid4())
450
 
451
- # Check if STT model is loaded
452
- if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
453
- logger.warning("STT model not loaded, returning placeholder response")
454
- return {
455
- "request_id": request_id,
456
- "status": "processing",
457
- "message": "STT model not loaded yet. Please try again later.",
458
- "source_text": "Transcription not available",
459
- "translated_text": "Translation not available",
460
- "output_audio": None
461
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
463
  # Save the uploaded audio to a temporary file
464
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
@@ -468,6 +525,7 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
468
  transcription = "Transcription not available"
469
  translated_text = "Translation not available"
470
  output_audio_url = None
 
471
 
472
  try:
473
  # Step 1: Load and resample the audio using torchaudio
@@ -490,29 +548,49 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
490
  "message": "No speech detected in the audio.",
491
  "source_text": "No speech detected",
492
  "translated_text": "No translation available",
493
- "output_audio": None
 
494
  }
495
 
496
  # Step 3: Transcribe the audio (STT)
497
  device = "cuda" if torch.cuda.is_available() else "cpu"
498
- logger.info(f"Using device: {device}")
499
- inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
500
- logger.info("Audio processed, generating transcription...")
501
 
502
- with torch.no_grad():
503
- if model_status["stt"] == "loaded_whisper":
504
- # Whisper model
505
- generated_ids = stt_model.generate(**inputs, language="en")
506
- transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
507
- else:
508
- # MMS model
509
- logits = stt_model(**inputs).logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  predicted_ids = torch.argmax(logits, dim=-1)
511
- transcription = stt_processor.batch_decode(predicted_ids)[0]
 
512
  logger.info(f"Transcription completed: {transcription}")
513
 
514
  # Step 4: Translate the transcribed text (MT)
515
- source_code = LANGUAGE_MAPPING[source_lang]
516
  target_code = LANGUAGE_MAPPING[target_lang]
517
 
518
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
@@ -535,53 +613,21 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
535
  else:
536
  logger.warning("MT model not loaded, skipping translation")
537
 
538
- # Step 5: Update TTS model if the target language doesn't match the current TTS language
539
- if current_tts_language != target_code:
540
- try:
541
- logger.info(f"Updating TTS model for {target_code}...")
542
- from transformers import VitsModel, AutoTokenizer
543
- tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
544
- tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
545
- tts_model.to(device)
546
- current_tts_language = target_code
547
- logger.info(f"TTS model updated to {target_code}")
548
- model_status["tts"] = "loaded"
549
- except Exception as e:
550
- logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
551
- try:
552
- logger.info("Falling back to MMS-TTS English model...")
553
- tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
554
- tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
555
- tts_model.to(device)
556
- current_tts_language = "eng"
557
- logger.info("Fallback TTS model loaded successfully")
558
- model_status["tts"] = "loaded (fallback)"
559
- except Exception as e2:
560
- logger.error(f"Failed to load fallback TTS model: {str(e2)}")
561
- model_status["tts"] = "failed"
562
 
563
  # Step 6: Convert translated text to speech (TTS)
564
- if model_status["tts"].startswith("loaded") and tts_model is not None and tts_tokenizer is not None:
565
- try:
566
- inputs = tts_tokenizer(translated_text, return_tensors="pt").to(device)
567
- with torch.no_grad():
568
- output = tts_model(**inputs)
569
- speech = output.waveform.cpu().numpy().squeeze()
570
- speech = (speech * 32767).astype(np.int16)
571
- sample_rate = tts_model.config.sampling_rate
572
-
573
- # Save the audio as a WAV file
574
- output_filename = f"{request_id}.wav"
575
- output_path = os.path.join(AUDIO_DIR, output_filename)
576
- save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
577
- logger.info(f"Saved synthesized audio to {output_path}")
578
-
579
- # Generate a URL to the WAV file
580
- output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
581
- logger.info("TTS conversion completed")
582
- except Exception as e:
583
- logger.error(f"Error during TTS conversion: {str(e)}")
584
- output_audio_url = None
585
 
586
  return {
587
  "request_id": request_id,
@@ -589,7 +635,8 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
589
  "message": "Transcription, translation, and TTS completed (or partially completed).",
590
  "source_text": transcription,
591
  "translated_text": translated_text,
592
- "output_audio": output_audio_url
 
593
  }
594
  except Exception as e:
595
  logger.error(f"Error during processing: {str(e)}")
@@ -599,7 +646,8 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
599
  "message": f"Processing failed: {str(e)}",
600
  "source_text": transcription,
601
  "translated_text": translated_text,
602
- "output_audio": output_audio_url
 
603
  }
604
  finally:
605
  logger.info(f"Cleaning up temporary file: {temp_path}")
 
12
  import torchaudio
13
  import wave
14
  import time
15
+ import re
16
  from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
17
  from fastapi.responses import JSONResponse
18
  from fastapi.staticfiles import StaticFiles
19
+ from typing import Dict, Any, Optional, Tuple, List
20
  from datetime import datetime, timedelta
21
 
22
  # Configure logging
 
35
  loading_in_progress = False
36
  loading_thread = None
37
  model_status = {
38
+ "stt_mms": "not_loaded",
39
+ "stt_whisper_small": "not_loaded",
40
  "mt": "not_loaded",
41
+ "tts": {} # Will store status for each language
42
  }
43
  error_message = None
 
 
 
 
 
 
 
 
 
44
 
45
  # Define the valid languages and mappings
46
  LANGUAGE_MAPPING = {
 
61
  "pag": "pag_Latn"
62
  }
63
 
64
# Model dictionaries for different languages.
# Speech-to-text models and their processors, keyed by model family; every slot
# starts as None and is filled in by load_models_task().
stt_models = {
    "mms": None,
    "mms_processor": None,
    "whisper_small": None,
    "whisper_small_processor": None,
}

# Machine-translation model and tokenizer (None until loaded)
mt_model = None
mt_tokenizer = None

tts_models = {}  # Will store models for each language
tts_tokenizers = {}  # Will store tokenizers for each language

# List of inappropriate words/phrases for content filtering.
# Entries are lowercase; check_inappropriate_content() lowercases the text
# before matching each entry as a whole word.
INAPPROPRIATE_WORDS = [
    "fuck", "shit", "asshole", "bitch", "dick", "pussy", "cunt",
    "whore", "slut", "bastard", "damn", "hell", "piss", "nigger",
    "faggot", "retard", "crap", "porn", "sex", "penis", "vagina",
    # Tagalog inappropriate words
    "puta", "putangina", "gago", "bobo", "tanga", "tarantado",
    "inutil", "ulol", "kantot", "jakol", "tite", "pekpek",
    # Add more as needed
]
88
+
89
  # Function to save PCM data as a WAV file
90
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
91
  # Convert pcm_data to a NumPy array of 16-bit integers
 
123
  # For now, we assume if RMS is above threshold, there is speech
124
  return True
125
 
126
+ # Function to check for inappropriate content
127
def check_inappropriate_content(text: str, words: Optional[List[str]] = None) -> bool:
    """
    Checks if the text contains inappropriate content.

    Matching is case-insensitive and anchored on word boundaries, so a listed
    word that only appears inside a longer word is not reported.

    Args:
        text: Text to scan. Empty or missing text is treated as clean.
        words: Words/phrases to look for. Defaults to INAPPROPRIATE_WORDS.

    Returns:
        True if inappropriate content is detected, False otherwise.
    """
    # Nothing to scan; also avoids calling .lower() on None
    if not text:
        return False

    if words is None:
        words = INAPPROPRIATE_WORDS

    # Convert text to lowercase for case-insensitive matching
    text_lower = text.lower()

    # Check if any inappropriate word is in the text
    for word in words:
        if not word:
            # An empty entry would compile to r'\b\b' and match at any word boundary
            continue
        # Use word boundary regex to match whole words only
        pattern = r'\b' + re.escape(word.lower()) + r'\b'
        if re.search(pattern, text_lower):
            logger.warning(f"Inappropriate content detected: '{word}'")
            return True

    return False
144
+
145
+ # Function to perform text-to-speech conversion
146
def text_to_speech(text: str, language_code: str) -> Tuple[Optional[np.ndarray], Optional[int], Optional[str]]:
    """
    Convert text to speech using the appropriate TTS model.

    Args:
        text: Text to synthesize.
        language_code: Key into tts_models / tts_tokenizers (e.g. "tgl").

    Returns:
        A (speech, sample_rate, error) tuple. On success, speech is an int16
        waveform, sample_rate comes from the model config and error is None.
        On failure, speech and sample_rate are None and error describes why.
    """
    # The loader stores the model before its tokenizer, so require both here
    model = tts_models.get(language_code)
    tokenizer = tts_tokenizers.get(language_code)
    if model is None or tokenizer is None:
        error_msg = f"TTS model for {language_code} not loaded"
        logger.error(error_msg)
        return None, None, error_msg

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = tokenizer(text, return_tensors="pt").to(device)

        with torch.no_grad():
            output = model(**inputs)

        speech = output.waveform.cpu().numpy().squeeze()
        # Clamp to [-1, 1] before scaling: samples outside that range would
        # overflow in the int16 cast and come out as loud artifacts
        speech = (np.clip(speech, -1.0, 1.0) * 32767).astype(np.int16)
        sample_rate = model.config.sampling_rate

        return speech, sample_rate, None
    except Exception as e:
        error_msg = f"Error during TTS conversion: {str(e)}"
        logger.error(error_msg)
        return None, None, error_msg
172
+
173
  # Function to clean up old audio files
174
  def cleanup_old_audio_files():
175
  logger.info("Starting cleanup of old audio files...")
 
194
  # Function to load models in background
195
  def load_models_task():
196
  global models_loaded, loading_in_progress, model_status, error_message
197
+ global stt_models, mt_model, mt_tokenizer, tts_models, tts_tokenizers
198
 
199
  try:
200
  loading_in_progress = True
201
+ device = "cuda" if torch.cuda.is_available() else "cpu"
202
 
203
+ # Load STT models (both MMS and Whisper)
204
+ logger.info("Starting to load STT models...")
 
205
 
206
+ # Load MMS STT model
207
  try:
208
  logger.info("Loading MMS STT model...")
209
+ model_status["stt_mms"] = "loading"
210
+ from transformers import AutoProcessor, AutoModelForCTC
211
+
212
+ stt_models["mms_processor"] = AutoProcessor.from_pretrained("facebook/mms-1b-all")
213
+ stt_models["mms"] = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
214
+ stt_models["mms"].to(device)
215
  logger.info("MMS STT model loaded successfully")
216
+ model_status["stt_mms"] = "loaded"
217
  except Exception as mms_error:
218
  logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
219
+ model_status["stt_mms"] = "failed"
220
+ error_message = f"MMS STT model loading failed: {str(mms_error)}"
221
+
222
+ # Load Whisper Small STT model
223
+ try:
224
+ logger.info("Loading Whisper Small STT model...")
225
+ model_status["stt_whisper_small"] = "loading"
226
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
227
+
228
+ stt_models["whisper_small_processor"] = WhisperProcessor.from_pretrained("openai/whisper-small")
229
+ stt_models["whisper_small"] = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
230
+ stt_models["whisper_small"].to(device)
231
+ logger.info("Whisper Small STT model loaded successfully")
232
+ model_status["stt_whisper_small"] = "loaded"
233
+ except Exception as whisper_error:
234
+ logger.error(f"Failed to load Whisper Small STT model: {str(whisper_error)}")
235
+ model_status["stt_whisper_small"] = "failed"
236
+ error_message = f"Whisper Small STT model loading failed: {str(whisper_error)}"
237
 
238
  # Load MT model
239
  logger.info("Starting to load MT model...")
 
251
  logger.error(f"Failed to load MT model: {str(e)}")
252
  model_status["mt"] = "failed"
253
  error_message = f"MT model loading failed: {str(e)}"
254
+
255
+ # Load TTS models for all supported languages
256
+ logger.info("Starting to load TTS models for all languages...")
 
257
  from transformers import VitsModel, AutoTokenizer
258
 
259
+ for lang_name, lang_code in LANGUAGE_MAPPING.items():
 
 
 
 
 
 
 
 
 
 
260
  try:
261
+ logger.info(f"Loading MMS-TTS model for {lang_name} ({lang_code})...")
262
+ model_status["tts"][lang_code] = "loading"
263
+
264
+ # Load the model and tokenizer
265
+ tts_models[lang_code] = VitsModel.from_pretrained(f"facebook/mms-tts-{lang_code}")
266
+ tts_tokenizers[lang_code] = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{lang_code}")
267
+
268
+ # Move to GPU if available
269
+ tts_models[lang_code].to(device)
270
+
271
+ logger.info(f"TTS model for {lang_name} loaded successfully")
272
+ model_status["tts"][lang_code] = "loaded"
273
+ except Exception as e:
274
+ logger.error(f"Failed to load TTS model for {lang_name}: {str(e)}")
275
+ model_status["tts"][lang_code] = "failed"
276
+
277
+ # Try to load English as fallback if this is not English
278
+ if lang_code != "eng":
279
+ try:
280
+ logger.info(f"Trying to load English TTS model as fallback for {lang_name}...")
281
+ # Only load English model once if not already loaded
282
+ if "eng" not in tts_models or tts_models["eng"] is None:
283
+ tts_models["eng"] = VitsModel.from_pretrained("facebook/mms-tts-eng")
284
+ tts_tokenizers["eng"] = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
285
+ tts_models["eng"].to(device)
286
+ model_status["tts"]["eng"] = "loaded"
287
+
288
+ # Point this language to use English model
289
+ tts_models[lang_code] = tts_models["eng"]
290
+ tts_tokenizers[lang_code] = tts_tokenizers["eng"]
291
+ model_status["tts"][lang_code] = "loaded (fallback to eng)"
292
+ except Exception as e2:
293
+ logger.error(f"Failed to load English fallback TTS model: {str(e2)}")
294
+ model_status["tts"][lang_code] = "failed (with fallback)"
295
 
296
+ # Set models_loaded flag based on which critical models are loaded
297
+ # Consider the system usable if we have at least one STT model, the MT model, and at least one TTS model
298
+ stt_loaded = model_status["stt_mms"] == "loaded" or model_status["stt_whisper_small"] == "loaded"
299
+ mt_loaded = model_status["mt"] == "loaded"
300
+ any_tts_loaded = any(status == "loaded" or status.startswith("loaded (fallback")
301
+ for status in model_status["tts"].values())
302
 
303
+ models_loaded = stt_loaded and mt_loaded and any_tts_loaded
304
+
305
+ if models_loaded:
306
+ logger.info("Critical models loaded successfully - system is ready")
307
+ else:
308
+ logger.warning("Some critical models failed to load - system may have limited functionality")
309
+
310
  except Exception as e:
311
  error_message = str(e)
312
  logger.error(f"Error in model loading task: {str(e)}")
 
316
  # Start loading models in background
317
  def start_model_loading():
318
  global loading_thread, loading_in_progress
319
+ if not loading_in_progress:
320
  loading_in_progress = True
321
  loading_thread = threading.Thread(target=load_models_task)
322
  loading_thread.daemon = True
 
354
  "error": error_message
355
  }
356
 
357
@app.post("/synthesize-speech")
async def synthesize_speech(text: str = Form(...), language: str = Form(...)):
    """
    Endpoint to synthesize speech from text without translation.

    Args:
        text: Text to speak (form field); must not be empty.
        language: Language name; must be a key of LANGUAGE_MAPPING.

    Returns:
        A dict with request_id, status ("completed" or "failed"), message,
        output_audio (URL of the generated WAV, or None) and is_inappropriate.

    Raises:
        HTTPException: 400 if no text is provided or the language is not supported.
    """
    # Same validation as /translate-text: reject empty input up front
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")

    if language not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    logger.info(f"Speech synthesis requested for text in {language}")
    request_id = str(uuid.uuid4())
    language_code = LANGUAGE_MAPPING[language]

    # Check if the TTS model is loaded
    if tts_models.get(language_code) is None:
        return {
            "request_id": request_id,
            "status": "failed",
            "message": f"TTS model for {language} not loaded yet",
            "output_audio": None,
            "is_inappropriate": False
        }

    # Check for inappropriate content
    is_inappropriate = check_inappropriate_content(text)

    # Generate speech
    speech, sample_rate, error = text_to_speech(text, language_code)

    if error or speech is None or sample_rate is None:
        return {
            "request_id": request_id,
            "status": "failed",
            "message": error or "Speech synthesis failed",
            "output_audio": None,
            "is_inappropriate": is_inappropriate
        }

    # Save the synthesized audio
    output_filename = f"{request_id}.wav"
    output_path = os.path.join(AUDIO_DIR, output_filename)
    try:
        save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
    except Exception as e:
        # Report a write failure like every other failure in this endpoint
        # instead of letting it surface as an unhandled server error
        logger.error(f"Failed to save synthesized audio: {str(e)}")
        return {
            "request_id": request_id,
            "status": "failed",
            "message": f"Failed to save synthesized audio: {str(e)}",
            "output_audio": None,
            "is_inappropriate": is_inappropriate
        }

    # Generate URL to the WAV file
    output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"

    return {
        "request_id": request_id,
        "status": "completed",
        "message": "Speech synthesis completed",
        "output_audio": output_audio_url,
        "is_inappropriate": is_inappropriate
    }
407
 
408
  @app.post("/translate-text")
409
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
410
  """Endpoint to translate text and convert to speech"""
411
+ global mt_model, mt_tokenizer
412
 
413
  if not text:
414
  raise HTTPException(status_code=400, detail="No text provided")
 
443
  translated_text = f"Translation failed: {str(e)}"
444
  else:
445
  logger.warning("MT model not loaded, skipping translation")
446
+
447
+ # Check for inappropriate content in the translation
448
+ is_inappropriate = check_inappropriate_content(translated_text)
449
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  # Convert translated text to speech
451
+ speech, sample_rate, error = text_to_speech(translated_text, target_code)
452
+
453
  output_audio_url = None
454
+ if speech is not None and sample_rate is not None:
455
+ # Save the audio as a WAV file
456
+ output_filename = f"{request_id}.wav"
457
+ output_path = os.path.join(AUDIO_DIR, output_filename)
458
+ save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
459
+
460
+ # Generate a URL to the WAV file
461
+ output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
462
+ logger.info("TTS conversion completed")
 
 
 
 
 
 
 
 
 
 
 
 
463
 
464
  return {
465
  "request_id": request_id,
 
467
  "message": "Translation and TTS completed (or partially completed).",
468
  "source_text": text,
469
  "translated_text": translated_text,
470
+ "output_audio": output_audio_url,
471
+ "is_inappropriate": is_inappropriate
472
  }
473
 
474
  @app.post("/translate-audio")
475
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
476
  """Endpoint to transcribe, translate, and convert audio to speech"""
477
+ global stt_models, mt_model, mt_tokenizer
478
 
479
  if not audio:
480
  raise HTTPException(status_code=400, detail="No audio file provided")
 
484
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
485
  request_id = str(uuid.uuid4())
486
 
487
+ # Check if appropriate STT model is loaded
488
+ source_code = LANGUAGE_MAPPING[source_lang]
489
+ use_whisper = source_code in ["eng", "tgl"] # Use Whisper for English or Tagalog
490
+
491
+ if use_whisper and (model_status["stt_whisper_small"] != "loaded" or stt_models["whisper_small"] is None):
492
+ logger.warning("Whisper Small STT model not loaded for English/Tagalog, checking MMS")
493
+ if model_status["stt_mms"] != "loaded" or stt_models["mms"] is None:
494
+ logger.warning("MMS STT model not loaded either, returning placeholder response")
495
+ return {
496
+ "request_id": request_id,
497
+ "status": "processing",
498
+ "message": "STT models not loaded yet. Please try again later.",
499
+ "source_text": "Transcription not available",
500
+ "translated_text": "Translation not available",
501
+ "output_audio": None,
502
+ "is_inappropriate": False
503
+ }
504
+ use_whisper = False # Fall back to MMS
505
+ elif not use_whisper and (model_status["stt_mms"] != "loaded" or stt_models["mms"] is None):
506
+ logger.warning("MMS STT model not loaded for non-English/Tagalog, checking Whisper")
507
+ if model_status["stt_whisper_small"] != "loaded" or stt_models["whisper_small"] is None:
508
+ logger.warning("Whisper Small STT model not loaded either, returning placeholder response")
509
+ return {
510
+ "request_id": request_id,
511
+ "status": "processing",
512
+ "message": "STT models not loaded yet. Please try again later.",
513
+ "source_text": "Transcription not available",
514
+ "translated_text": "Translation not available",
515
+ "output_audio": None,
516
+ "is_inappropriate": False
517
+ }
518
+ use_whisper = True # Fall back to Whisper
519
 
520
  # Save the uploaded audio to a temporary file
521
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
 
525
  transcription = "Transcription not available"
526
  translated_text = "Translation not available"
527
  output_audio_url = None
528
+ is_inappropriate = False
529
 
530
  try:
531
  # Step 1: Load and resample the audio using torchaudio
 
548
  "message": "No speech detected in the audio.",
549
  "source_text": "No speech detected",
550
  "translated_text": "No translation available",
551
+ "output_audio": None,
552
+ "is_inappropriate": False
553
  }
554
 
555
  # Step 3: Transcribe the audio (STT)
556
  device = "cuda" if torch.cuda.is_available() else "cpu"
557
+ logger.info(f"Using device: {device} with {'Whisper' if use_whisper else 'MMS'} model")
 
 
558
 
559
+ if use_whisper:
560
+ # Use Whisper Small for English or Tagalog
561
+ logger.info("Using Whisper Small for transcription")
562
+ processor = stt_models["whisper_small_processor"]
563
+ model = stt_models["whisper_small"]
564
+
565
+ inputs = processor(waveform.numpy()[0], sampling_rate=16000, return_tensors="pt").to(device)
566
+ with torch.no_grad():
567
+ # Use the language code for forced decoding if source is English or Tagalog
568
+ language = "en" if source_code == "eng" else "tl" if source_code == "tgl" else None
569
+ generated_ids = model.generate(
570
+ **inputs,
571
+ language=language,
572
+ task="transcribe"
573
+ )
574
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
575
+ else:
576
+ # Use MMS for other languages
577
+ logger.info("Using MMS for transcription")
578
+ processor = stt_models["mms_processor"]
579
+ model = stt_models["mms"]
580
+
581
+ if source_code in processor.tokenizer.vocab.keys():
582
+ processor.tokenizer.set_target_lang(source_code)
583
+ model.load_adapter(source_code)
584
+
585
+ inputs = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
586
+ with torch.no_grad():
587
+ logits = model(**inputs).logits
588
  predicted_ids = torch.argmax(logits, dim=-1)
589
+ transcription = processor.batch_decode(predicted_ids)[0]
590
+
591
  logger.info(f"Transcription completed: {transcription}")
592
 
593
  # Step 4: Translate the transcribed text (MT)
 
594
  target_code = LANGUAGE_MAPPING[target_lang]
595
 
596
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
 
613
  else:
614
  logger.warning("MT model not loaded, skipping translation")
615
 
616
+ # Step 5: Check for inappropriate content in the translation
617
+ is_inappropriate = check_inappropriate_content(translated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
 
619
  # Step 6: Convert translated text to speech (TTS)
620
+ speech, sample_rate, error = text_to_speech(translated_text, target_code)
621
+
622
+ if speech is not None and sample_rate is not None:
623
+ # Save the audio as a WAV file
624
+ output_filename = f"{request_id}.wav"
625
+ output_path = os.path.join(AUDIO_DIR, output_filename)
626
+ save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
627
+
628
+ # Generate a URL to the WAV file
629
+ output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
630
+ logger.info("TTS conversion completed")
 
 
 
 
 
 
 
 
 
 
631
 
632
  return {
633
  "request_id": request_id,
 
635
  "message": "Transcription, translation, and TTS completed (or partially completed).",
636
  "source_text": transcription,
637
  "translated_text": translated_text,
638
+ "output_audio": output_audio_url,
639
+ "is_inappropriate": is_inappropriate
640
  }
641
  except Exception as e:
642
  logger.error(f"Error during processing: {str(e)}")
 
646
  "message": f"Processing failed: {str(e)}",
647
  "source_text": transcription,
648
  "translated_text": translated_text,
649
+ "output_audio": output_audio_url,
650
+ "is_inappropriate": is_inappropriate
651
  }
652
  finally:
653
  logger.info(f"Cleaning up temporary file: {temp_path}")