Jerich committed on
Commit
06e8ec6
·
verified ·
1 Parent(s): 28e1a88

Updated the code to use the Whisper model if the source language is English or Tagalog; otherwise, it will use MMS. Additionally, the link to the synthesized speech has been updated to match the current space.

Browse files
Files changed (1) hide show
  1. app.py +70 -138
app.py CHANGED
@@ -43,8 +43,10 @@ error_message = None
43
  current_tts_language = "tgl" # Track the current TTS language
44
 
45
  # Model instances
46
- stt_processor = None
47
- stt_model = None
 
 
48
  mt_model = None
49
  mt_tokenizer = None
50
  tts_model = None
@@ -85,60 +87,39 @@ def check_inappropriate_content(text: str) -> bool:
85
  Check if the text contains inappropriate content.
86
  Returns True if inappropriate content is detected, False otherwise.
87
  """
88
- # Convert to lowercase for case-insensitive matching
89
  text_lower = text.lower()
90
-
91
- # Check for inappropriate words
92
  for word in INAPPROPRIATE_WORDS:
93
- # Use word boundary matching to avoid false positives
94
  pattern = r'\b' + re.escape(word) + r'\b'
95
  if re.search(pattern, text_lower):
96
  logger.warning(f"Inappropriate content detected: {word}")
97
  return True
98
-
99
  return False
100
 
101
  # Function to save PCM data as a WAV file
102
  def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
103
- # Convert pcm_data to a NumPy array of 16-bit integers
104
  pcm_array = np.array(pcm_data, dtype=np.int16)
105
-
106
  with wave.open(output_path, 'wb') as wav_file:
107
- # Set WAV parameters: 1 channel (mono), 2 bytes per sample (16-bit), sample rate
108
  wav_file.setnchannels(1)
109
- wav_file.setsampwidth(2) # 16-bit audio
110
  wav_file.setframerate(sample_rate)
111
- # Write the 16-bit PCM data as bytes (little-endian)
112
  wav_file.writeframes(pcm_array.tobytes())
113
 
114
  # Function to detect speech using an energy-based approach
115
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
116
- """
117
- Detects if the audio contains speech using an energy-based approach.
118
- Returns True if speech is detected, False otherwise.
119
- """
120
- # Convert waveform to numpy array
121
  waveform_np = waveform.numpy()
122
  if waveform_np.ndim > 1:
123
- waveform_np = waveform_np.mean(axis=0) # Convert stereo to mono
124
-
125
- # Compute RMS energy
126
  rms = np.sqrt(np.mean(waveform_np**2))
127
  logger.info(f"RMS energy: {rms}")
128
-
129
- # Check if RMS energy exceeds the threshold
130
  if rms < threshold:
131
  logger.info("No speech detected: RMS energy below threshold")
132
  return False
133
-
134
- # Optionally, check for minimum speech duration (requires more sophisticated VAD)
135
- # For now, we assume if RMS is above threshold, there is speech
136
  return True
137
 
138
  # Function to clean up old audio files
139
  def cleanup_old_audio_files():
140
  logger.info("Starting cleanup of old audio files...")
141
- expiration_time = datetime.now() - timedelta(minutes=10) # Files older than 10 minutes
142
  for filename in os.listdir(AUDIO_DIR):
143
  file_path = os.path.join(AUDIO_DIR, filename)
144
  if os.path.isfile(file_path):
@@ -154,42 +135,48 @@ def cleanup_old_audio_files():
154
  def schedule_cleanup():
155
  while True:
156
  cleanup_old_audio_files()
157
- time.sleep(300) # Run every 5 minutes (300 seconds)
158
 
159
  # Function to load models in background
160
  def load_models_task():
161
  global models_loaded, loading_in_progress, model_status, error_message
162
- global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer
 
163
 
164
  try:
165
  loading_in_progress = True
166
 
167
- # Load STT model (MMS with fallback to Whisper)
168
- logger.info("Starting to load STT model...")
169
  from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
170
 
171
  try:
172
- logger.info("Loading MMS STT model...")
173
  model_status["stt"] = "loading"
174
- stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
175
- stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
176
  device = "cuda" if torch.cuda.is_available() else "cpu"
177
- stt_model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  logger.info("MMS STT model loaded successfully")
179
- model_status["stt"] = "loaded_mms"
180
- except Exception as mms_error:
181
- logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
182
- logger.info("Falling back to Whisper STT model...")
183
- try:
184
- stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
185
- stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
186
- stt_model.to(device)
187
- logger.info("Whisper STT model loaded successfully as fallback")
188
- model_status["stt"] = "loaded_whisper"
189
- except Exception as whisper_error:
190
- logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
191
  model_status["stt"] = "failed"
192
- error_message = f"STT model loading failed: MMS error: {str(mms_error)}, Whisper error: {str(whisper_error)}"
193
  return
194
 
195
  # Load MT model
@@ -210,7 +197,7 @@ def load_models_task():
210
  error_message = f"MT model loading failed: {str(e)}"
211
  return
212
 
213
- # Load TTS model (default to Tagalog, will be updated dynamically)
214
  logger.info("Starting to load TTS model...")
215
  from transformers import VitsModel, AutoTokenizer
216
 
@@ -224,7 +211,6 @@ def load_models_task():
224
  model_status["tts"] = "loaded"
225
  except Exception as e:
226
  logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
227
- # Fallback to English TTS if the target language fails
228
  try:
229
  logger.info("Falling back to MMS-TTS English model...")
230
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
@@ -265,21 +251,13 @@ def start_cleanup_task():
265
 
266
  # Function to load or update TTS model for a specific language
267
  def load_tts_model_for_language(target_code: str) -> bool:
268
- """
269
- Load or update the TTS model for the specified language.
270
- Returns True if successful, False otherwise.
271
- """
272
  global tts_model, tts_tokenizer, current_tts_language, model_status
273
-
274
  if target_code not in LANGUAGE_MAPPING.values():
275
  logger.error(f"Invalid language code: {target_code}")
276
  return False
277
-
278
- # Skip if the model is already loaded for the target language
279
  if current_tts_language == target_code and model_status["tts"].startswith("loaded"):
280
  logger.info(f"TTS model for {target_code} is already loaded.")
281
  return True
282
-
283
  device = "cuda" if torch.cuda.is_available() else "cpu"
284
  try:
285
  logger.info(f"Loading MMS-TTS model for {target_code}...")
@@ -309,19 +287,11 @@ def load_tts_model_for_language(target_code: str) -> bool:
309
 
310
  # Function to synthesize speech from text
311
  def synthesize_speech(text: str, target_code: str) -> Tuple[Optional[str], Optional[str]]:
312
- """
313
- Convert text to speech for the specified language.
314
- Returns a tuple of (output_path, error_message).
315
- """
316
  global tts_model, tts_tokenizer
317
-
318
  request_id = str(uuid.uuid4())
319
  output_path = os.path.join(AUDIO_DIR, f"{request_id}.wav")
320
-
321
- # Make sure the TTS model is loaded for the target language
322
  if not load_tts_model_for_language(target_code):
323
  return None, "Failed to load TTS model for the target language"
324
-
325
  device = "cuda" if torch.cuda.is_available() else "cpu"
326
  try:
327
  inputs = tts_tokenizer(text, return_tensors="pt").to(device)
@@ -330,11 +300,8 @@ def synthesize_speech(text: str, target_code: str) -> Tuple[Optional[str], Optio
330
  speech = output.waveform.cpu().numpy().squeeze()
331
  speech = (speech * 32767).astype(np.int16)
332
  sample_rate = tts_model.config.sampling_rate
333
-
334
- # Save the audio as a WAV file
335
  save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
336
  logger.info(f"Saved synthesized audio to {output_path}")
337
-
338
  return output_path, None
339
  except Exception as e:
340
  error_msg = f"Error during TTS conversion: {str(e)}"
@@ -350,14 +317,11 @@ async def startup_event():
350
 
351
  @app.get("/")
352
  async def root():
353
- """Root endpoint for default health check"""
354
  logger.info("Root endpoint requested")
355
  return {"status": "healthy"}
356
 
357
  @app.get("/health")
358
  async def health_check():
359
- """Health check endpoint that always returns successfully"""
360
- global models_loaded, loading_in_progress, model_status, error_message
361
  logger.info("Health check requested")
362
  return {
363
  "status": "healthy",
@@ -369,22 +333,16 @@ async def health_check():
369
 
370
  @app.post("/translate-text")
371
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
372
- """Endpoint to translate text and convert to speech"""
373
  global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
374
-
375
  if not text:
376
  raise HTTPException(status_code=400, detail="No text provided")
377
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
378
  raise HTTPException(status_code=400, detail="Invalid language selected")
379
-
380
  logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
381
  request_id = str(uuid.uuid4())
382
-
383
- # Translate the text
384
  source_code = LANGUAGE_MAPPING[source_lang]
385
  target_code = LANGUAGE_MAPPING[target_lang]
386
  translated_text = "Translation not available"
387
-
388
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
389
  try:
390
  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
@@ -405,26 +363,20 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
405
  translated_text = f"Translation failed: {str(e)}"
406
  else:
407
  logger.warning("MT model not loaded, skipping translation")
408
-
409
- # Check for inappropriate content in the source text and translated text
410
  is_inappropriate = check_inappropriate_content(text) or check_inappropriate_content(translated_text)
411
  if is_inappropriate:
412
  logger.warning("Inappropriate content detected in translation request")
413
-
414
- # Convert translated text to speech
415
  output_audio_url = None
416
  if model_status["tts"].startswith("loaded"):
417
- # Load or update TTS model for the target language
418
  if load_tts_model_for_language(target_code):
419
  try:
420
  output_path, error = synthesize_speech(translated_text, target_code)
421
  if output_path:
422
  output_filename = os.path.basename(output_path)
423
- output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
424
  logger.info("TTS conversion completed")
425
  except Exception as e:
426
  logger.error(f"Error during TTS conversion: {str(e)}")
427
-
428
  return {
429
  "request_id": request_id,
430
  "status": "completed",
@@ -437,8 +389,8 @@ async def translate_text(text: str = Form(...), source_lang: str = Form(...), ta
437
 
438
  @app.post("/translate-audio")
439
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
440
- """Endpoint to transcribe, translate, and convert audio to speech"""
441
- global stt_processor, stt_model, mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
442
 
443
  if not audio:
444
  raise HTTPException(status_code=400, detail="No audio file provided")
@@ -448,20 +400,33 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
448
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
449
  request_id = str(uuid.uuid4())
450
 
451
- # Check if STT models are loaded
452
- if model_status["stt"] not in ["loaded_mms", "loaded_mms_default", "loaded_whisper"] or stt_processor is None or stt_model is None:
453
- logger.warning("STT model not loaded, returning placeholder response")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  return {
455
  "request_id": request_id,
456
  "status": "processing",
457
- "message": "STT model not loaded yet. Please try again later.",
458
  "source_text": "Transcription not available",
459
  "translated_text": "Translation not available",
460
  "output_audio": None,
461
  "is_inappropriate": False
462
  }
463
 
464
- # Save the uploaded audio to a temporary file
465
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
466
  temp_file.write(await audio.read())
467
  temp_path = temp_file.name
@@ -472,19 +437,16 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
472
  is_inappropriate = False
473
 
474
  try:
475
- # Step 1: Load and resample the audio using torchaudio
476
  logger.info(f"Reading audio file: {temp_path}")
477
  waveform, sample_rate = torchaudio.load(temp_path)
478
  logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
479
 
480
- # Resample to 16 kHz if needed (required by Whisper and MMS models)
481
  if sample_rate != 16000:
482
  logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
483
  resampler = torchaudio.transforms.Resample(sample_rate, 16000)
484
  waveform = resampler(waveform)
485
  sample_rate = 16000
486
 
487
- # Step 2: Detect speech
488
  if not detect_speech(waveform, sample_rate):
489
  return {
490
  "request_id": request_id,
@@ -496,49 +458,25 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
496
  "is_inappropriate": False
497
  }
498
 
499
- # Step 3: Transcribe the audio (STT)
500
  device = "cuda" if torch.cuda.is_available() else "cpu"
501
  logger.info(f"Using device: {device}")
502
 
503
- # Determine which model to use based on source language
504
- source_code = LANGUAGE_MAPPING[source_lang]
505
- use_whisper = source_code in ["eng", "tgl"] # Use Whisper for English and Tagalog
506
- use_mms = not use_whisper # Use MMS for other Philippine languages
507
-
508
- logger.info(f"Source language: {source_lang} ({source_code}), Using Whisper: {use_whisper}, Using MMS: {use_mms}")
509
-
510
- # Process with appropriate model
511
- inputs = stt_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
512
- logger.info("Audio processed, generating transcription...")
513
-
514
- with torch.no_grad():
515
- if use_whisper and model_status["stt"] == "loaded_whisper":
516
- # Whisper model for English and Tagalog
517
- logger.info(f"Using Whisper model for {source_lang}")
518
- generated_ids = stt_model.generate(**inputs, language="en" if source_code == "eng" else "tl")
519
- transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
520
- elif model_status["stt"] in ["loaded_mms", "loaded_mms_default"]:
521
- # MMS model for other Philippine languages
522
- logger.info(f"Using MMS model for {source_lang}")
523
- logits = stt_model(**inputs).logits
524
  predicted_ids = torch.argmax(logits, dim=-1)
525
- transcription = stt_processor.batch_decode(predicted_ids)[0]
526
- else:
527
- # Fallback to any available model
528
- logger.info(f"Preferred model not available, using fallback model")
529
- if model_status["stt"] == "loaded_whisper":
530
- generated_ids = stt_model.generate(**inputs, language="en")
531
- transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
532
- else:
533
- logits = stt_model(**inputs).logits
534
- predicted_ids = torch.argmax(logits, dim=-1)
535
- transcription = stt_processor.batch_decode(predicted_ids)[0]
536
-
537
  logger.info(f"Transcription completed: {transcription}")
538
 
539
- # Step 4: Translate the transcribed text (MT)
540
  target_code = LANGUAGE_MAPPING[target_lang]
541
-
542
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
543
  try:
544
  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
@@ -559,18 +497,16 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
559
  else:
560
  logger.warning("MT model not loaded, skipping translation")
561
 
562
- # Step 5: Check for inappropriate content
563
  is_inappropriate = check_inappropriate_content(transcription) or check_inappropriate_content(translated_text)
564
  if is_inappropriate:
565
  logger.warning("Inappropriate content detected in audio transcription or translation")
566
 
567
- # Step 6: Convert translated text to speech (TTS)
568
  if load_tts_model_for_language(target_code):
569
  try:
570
  output_path, error = synthesize_speech(translated_text, target_code)
571
  if output_path:
572
  output_filename = os.path.basename(output_path)
573
- output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
574
  logger.info("TTS conversion completed")
575
  except Exception as e:
576
  logger.error(f"Error during TTS conversion: {str(e)}")
@@ -601,7 +537,6 @@ async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form
601
 
602
  @app.post("/text-to-speech")
603
  async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
604
- """Endpoint to convert text to speech in the specified language"""
605
  if not text:
606
  raise HTTPException(status_code=400, detail="No text provided")
607
  if target_lang not in LANGUAGE_MAPPING:
@@ -611,20 +546,17 @@ async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
611
  request_id = str(uuid.uuid4())
612
 
613
  target_code = LANGUAGE_MAPPING[target_lang]
614
-
615
- # Check for inappropriate content
616
  is_inappropriate = check_inappropriate_content(text)
617
  if is_inappropriate:
618
  logger.warning("Inappropriate content detected in text-to-speech request")
619
 
620
- # Synthesize speech
621
  output_audio_url = None
622
  if model_status["tts"].startswith("loaded") or load_tts_model_for_language(target_code):
623
  try:
624
  output_path, error = synthesize_speech(text, target_code)
625
  if output_path:
626
  output_filename = os.path.basename(output_path)
627
- output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
628
  logger.info("TTS conversion completed")
629
  else:
630
  logger.error(f"TTS conversion failed: {error}")
 
43
  current_tts_language = "tgl" # Track the current TTS language
44
 
45
  # Model instances
46
+ stt_processor_whisper = None
47
+ stt_model_whisper = None
48
+ stt_processor_mms = None
49
+ stt_model_mms = None
50
  mt_model = None
51
  mt_tokenizer = None
52
  tts_model = None
 
87
  Check if the text contains inappropriate content.
88
  Returns True if inappropriate content is detected, False otherwise.
89
  """
 
90
  text_lower = text.lower()
 
 
91
  for word in INAPPROPRIATE_WORDS:
 
92
  pattern = r'\b' + re.escape(word) + r'\b'
93
  if re.search(pattern, text_lower):
94
  logger.warning(f"Inappropriate content detected: {word}")
95
  return True
 
96
  return False
97
 
98
# Function to save PCM data as a WAV file
def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
    """Persist 16-bit mono PCM samples as a WAV file.

    Args:
        pcm_data: Sequence of integer samples already scaled to the int16 range.
        sample_rate: Frame rate (Hz) to record in the WAV header.
        output_path: Destination path for the .wav file.
    """
    samples = np.asarray(pcm_data, dtype=np.int16)
    with wave.open(output_path, 'wb') as wf:
        # mono, 2 bytes per sample (16-bit), given rate; the frame count in the
        # header is patched automatically by writeframes()
        wf.setparams((1, 2, sample_rate, 0, 'NONE', 'not compressed'))
        wf.writeframes(samples.tobytes())
106
 
107
  # Function to detect speech using an energy-based approach
108
  def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
 
 
 
 
 
109
  waveform_np = waveform.numpy()
110
  if waveform_np.ndim > 1:
111
+ waveform_np = waveform_np.mean(axis=0)
 
 
112
  rms = np.sqrt(np.mean(waveform_np**2))
113
  logger.info(f"RMS energy: {rms}")
 
 
114
  if rms < threshold:
115
  logger.info("No speech detected: RMS energy below threshold")
116
  return False
 
 
 
117
  return True
118
 
119
  # Function to clean up old audio files
120
  def cleanup_old_audio_files():
121
  logger.info("Starting cleanup of old audio files...")
122
+ expiration_time = datetime.now() - timedelta(minutes=10)
123
  for filename in os.listdir(AUDIO_DIR):
124
  file_path = os.path.join(AUDIO_DIR, filename)
125
  if os.path.isfile(file_path):
 
135
def schedule_cleanup():
    """Run forever, sweeping expired audio files every 5 minutes.

    Intended to be launched on a background thread; never returns.
    """
    while True:
        cleanup_old_audio_files()
        time.sleep(300)  # 5-minute pause between sweeps
 
140
  # Function to load models in background
141
  def load_models_task():
142
  global models_loaded, loading_in_progress, model_status, error_message
143
+ global stt_processor_whisper, stt_model_whisper, stt_processor_mms, stt_model_mms
144
+ global mt_model, mt_tokenizer, tts_model, tts_tokenizer
145
 
146
  try:
147
  loading_in_progress = True
148
 
149
+ # Load STT models
150
+ logger.info("Starting to load STT models...")
151
  from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
152
 
153
  try:
154
+ logger.info("Loading Whisper STT model...")
155
  model_status["stt"] = "loading"
156
+ stt_processor_whisper = WhisperProcessor.from_pretrained("openai/whisper-tiny")
157
+ stt_model_whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
158
  device = "cuda" if torch.cuda.is_available() else "cpu"
159
+ stt_model_whisper.to(device)
160
+ logger.info("Whisper STT model loaded successfully")
161
+ model_status["stt"] = "loaded_whisper"
162
+ except Exception as e:
163
+ logger.error(f"Failed to load Whisper STT model: {str(e)}")
164
+ model_status["stt"] = "failed"
165
+ error_message = f"Whisper STT model loading failed: {str(e)}"
166
+ return
167
+
168
+ try:
169
+ logger.info("Loading MMS STT model...")
170
+ stt_processor_mms = AutoProcessor.from_pretrained("facebook/mms-1b-all")
171
+ stt_model_mms = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
172
+ stt_model_mms.to(device)
173
  logger.info("MMS STT model loaded successfully")
174
+ model_status["stt"] = "loaded_both" if model_status["stt"] == "loaded_whisper" else "loaded_mms"
175
+ except Exception as e:
176
+ logger.error(f"Failed to load MMS STT model: {str(e)}")
177
+ if model_status["stt"] != "loaded_whisper":
 
 
 
 
 
 
 
 
178
  model_status["stt"] = "failed"
179
+ error_message = f"MMS STT model loading failed: {str(e)}"
180
  return
181
 
182
  # Load MT model
 
197
  error_message = f"MT model loading failed: {str(e)}"
198
  return
199
 
200
+ # Load TTS model (default to Tagalog)
201
  logger.info("Starting to load TTS model...")
202
  from transformers import VitsModel, AutoTokenizer
203
 
 
211
  model_status["tts"] = "loaded"
212
  except Exception as e:
213
  logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
 
214
  try:
215
  logger.info("Falling back to MMS-TTS English model...")
216
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
 
251
 
252
  # Function to load or update TTS model for a specific language
253
  def load_tts_model_for_language(target_code: str) -> bool:
 
 
 
 
254
  global tts_model, tts_tokenizer, current_tts_language, model_status
 
255
  if target_code not in LANGUAGE_MAPPING.values():
256
  logger.error(f"Invalid language code: {target_code}")
257
  return False
 
 
258
  if current_tts_language == target_code and model_status["tts"].startswith("loaded"):
259
  logger.info(f"TTS model for {target_code} is already loaded.")
260
  return True
 
261
  device = "cuda" if torch.cuda.is_available() else "cpu"
262
  try:
263
  logger.info(f"Loading MMS-TTS model for {target_code}...")
 
287
 
288
  # Function to synthesize speech from text
289
  def synthesize_speech(text: str, target_code: str) -> Tuple[Optional[str], Optional[str]]:
 
 
 
 
290
  global tts_model, tts_tokenizer
 
291
  request_id = str(uuid.uuid4())
292
  output_path = os.path.join(AUDIO_DIR, f"{request_id}.wav")
 
 
293
  if not load_tts_model_for_language(target_code):
294
  return None, "Failed to load TTS model for the target language"
 
295
  device = "cuda" if torch.cuda.is_available() else "cpu"
296
  try:
297
  inputs = tts_tokenizer(text, return_tensors="pt").to(device)
 
300
  speech = output.waveform.cpu().numpy().squeeze()
301
  speech = (speech * 32767).astype(np.int16)
302
  sample_rate = tts_model.config.sampling_rate
 
 
303
  save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
304
  logger.info(f"Saved synthesized audio to {output_path}")
 
305
  return output_path, None
306
  except Exception as e:
307
  error_msg = f"Error during TTS conversion: {str(e)}"
 
317
 
318
@app.get("/")
async def root():
    """Default health-check endpoint exposed at the service root."""
    logger.info("Root endpoint requested")
    status_payload = {"status": "healthy"}
    return status_payload
322
 
323
  @app.get("/health")
324
  async def health_check():
 
 
325
  logger.info("Health check requested")
326
  return {
327
  "status": "healthy",
 
333
 
334
  @app.post("/translate-text")
335
  async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
 
336
  global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
 
337
  if not text:
338
  raise HTTPException(status_code=400, detail="No text provided")
339
  if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
340
  raise HTTPException(status_code=400, detail="Invalid language selected")
 
341
  logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
342
  request_id = str(uuid.uuid4())
 
 
343
  source_code = LANGUAGE_MAPPING[source_lang]
344
  target_code = LANGUAGE_MAPPING[target_lang]
345
  translated_text = "Translation not available"
 
346
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
347
  try:
348
  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
 
363
  translated_text = f"Translation failed: {str(e)}"
364
  else:
365
  logger.warning("MT model not loaded, skipping translation")
 
 
366
  is_inappropriate = check_inappropriate_content(text) or check_inappropriate_content(translated_text)
367
  if is_inappropriate:
368
  logger.warning("Inappropriate content detected in translation request")
 
 
369
  output_audio_url = None
370
  if model_status["tts"].startswith("loaded"):
 
371
  if load_tts_model_for_language(target_code):
372
  try:
373
  output_path, error = synthesize_speech(translated_text, target_code)
374
  if output_path:
375
  output_filename = os.path.basename(output_path)
376
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
377
  logger.info("TTS conversion completed")
378
  except Exception as e:
379
  logger.error(f"Error during TTS conversion: {str(e)}")
 
380
  return {
381
  "request_id": request_id,
382
  "status": "completed",
 
389
 
390
  @app.post("/translate-audio")
391
  async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
392
+ global stt_processor_whisper, stt_model_whisper, stt_processor_mms, stt_model_mms
393
+ global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
394
 
395
  if not audio:
396
  raise HTTPException(status_code=400, detail="No audio file provided")
 
400
  logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
401
  request_id = str(uuid.uuid4())
402
 
403
+ source_code = LANGUAGE_MAPPING[source_lang]
404
+ use_whisper = source_code in ["eng", "tgl"]
405
+
406
+ # Check if appropriate STT model is loaded
407
+ if use_whisper and (stt_processor_whisper is None or stt_model_whisper is None):
408
+ logger.warning("Whisper STT model not loaded, returning placeholder response")
409
+ return {
410
+ "request_id": request_id,
411
+ "status": "processing",
412
+ "message": "Whisper STT model not loaded yet. Please try again later.",
413
+ "source_text": "Transcription not available",
414
+ "translated_text": "Translation not available",
415
+ "output_audio": None,
416
+ "is_inappropriate": False
417
+ }
418
+ elif not use_whisper and (stt_processor_mms is None or stt_model_mms is None):
419
+ logger.warning("MMS STT model not loaded, returning placeholder response")
420
  return {
421
  "request_id": request_id,
422
  "status": "processing",
423
+ "message": "MMS STT model not loaded yet. Please try again later.",
424
  "source_text": "Transcription not available",
425
  "translated_text": "Translation not available",
426
  "output_audio": None,
427
  "is_inappropriate": False
428
  }
429
 
 
430
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
431
  temp_file.write(await audio.read())
432
  temp_path = temp_file.name
 
437
  is_inappropriate = False
438
 
439
  try:
 
440
  logger.info(f"Reading audio file: {temp_path}")
441
  waveform, sample_rate = torchaudio.load(temp_path)
442
  logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
443
 
 
444
  if sample_rate != 16000:
445
  logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
446
  resampler = torchaudio.transforms.Resample(sample_rate, 16000)
447
  waveform = resampler(waveform)
448
  sample_rate = 16000
449
 
 
450
  if not detect_speech(waveform, sample_rate):
451
  return {
452
  "request_id": request_id,
 
458
  "is_inappropriate": False
459
  }
460
 
 
461
  device = "cuda" if torch.cuda.is_available() else "cpu"
462
  logger.info(f"Using device: {device}")
463
 
464
+ if use_whisper:
465
+ logger.info("Using Whisper model for transcription")
466
+ inputs = stt_processor_whisper(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
467
+ with torch.no_grad():
468
+ generated_ids = stt_model_whisper.generate(**inputs, language=source_code)
469
+ transcription = stt_processor_whisper.batch_decode(generated_ids, skip_special_tokens=True)[0]
470
+ else:
471
+ logger.info("Using MMS model for transcription")
472
+ inputs = stt_processor_mms(waveform.numpy(), sampling_rate=16000, return_tensors="pt").to(device)
473
+ with torch.no_grad():
474
+ logits = stt_model_mms(**inputs).logits
 
 
 
 
 
 
 
 
 
 
475
  predicted_ids = torch.argmax(logits, dim=-1)
476
+ transcription = stt_processor_mms.batch_decode(predicted_ids)[0]
 
 
 
 
 
 
 
 
 
 
 
477
  logger.info(f"Transcription completed: {transcription}")
478
 
 
479
  target_code = LANGUAGE_MAPPING[target_lang]
 
480
  if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
481
  try:
482
  source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
 
497
  else:
498
  logger.warning("MT model not loaded, skipping translation")
499
 
 
500
  is_inappropriate = check_inappropriate_content(transcription) or check_inappropriate_content(translated_text)
501
  if is_inappropriate:
502
  logger.warning("Inappropriate content detected in audio transcription or translation")
503
 
 
504
  if load_tts_model_for_language(target_code):
505
  try:
506
  output_path, error = synthesize_speech(translated_text, target_code)
507
  if output_path:
508
  output_filename = os.path.basename(output_path)
509
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
510
  logger.info("TTS conversion completed")
511
  except Exception as e:
512
  logger.error(f"Error during TTS conversion: {str(e)}")
 
537
 
538
  @app.post("/text-to-speech")
539
  async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
 
540
  if not text:
541
  raise HTTPException(status_code=400, detail="No text provided")
542
  if target_lang not in LANGUAGE_MAPPING:
 
546
  request_id = str(uuid.uuid4())
547
 
548
  target_code = LANGUAGE_MAPPING[target_lang]
 
 
549
  is_inappropriate = check_inappropriate_content(text)
550
  if is_inappropriate:
551
  logger.warning("Inappropriate content detected in text-to-speech request")
552
 
 
553
  output_audio_url = None
554
  if model_status["tts"].startswith("loaded") or load_tts_model_for_language(target_code):
555
  try:
556
  output_path, error = synthesize_speech(text, target_code)
557
  if output_path:
558
  output_filename = os.path.basename(output_path)
559
+ output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
560
  logger.info("TTS conversion completed")
561
  else:
562
  logger.error(f"TTS conversion failed: {error}")