Spaces:
Sleeping
Sleeping
| import os | |
| os.environ["HOME"] = "/root" | |
| os.environ["HF_HOME"] = "/tmp/hf_cache" | |
| import logging | |
| import threading | |
| import tempfile | |
| import uuid | |
| import torch | |
| import numpy as np | |
| import soundfile as sf | |
| import torchaudio | |
| import wave | |
| import time | |
| import re | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from typing import Dict, Any, Optional, Tuple, List | |
| from datetime import datetime, timedelta | |
# Configure logging: INFO level on the root logger, plus the named logger
# that the rest of this module uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("talklas-api")
app = FastAPI(title="Talklas API")
# Serve files stored in AUDIO_DIR over HTTP under the /audio_output path.
AUDIO_DIR = "/tmp/audio_output" # Use /tmp for temporary files
# Create the directory first so it exists by the time it is mounted below.
os.makedirs(AUDIO_DIR, exist_ok=True)
app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")
# ---------------------------------------------------------------------------
# Shared application state. These names are rebound through ``global``
# statements by the loader thread and by the functions further down.
# ---------------------------------------------------------------------------
# Overall loading progress.
models_loaded = False
loading_in_progress = False
loading_thread = None
# Per-model load status; every model starts out as "not_loaded".
model_status = dict.fromkeys(
    ("stt_whisper", "stt_mms", "mt", "tts"),
    "not_loaded",
)
# Most recent loading error text (None while nothing has been recorded).
error_message = None
# Language code the TTS model is tracked as being loaded for.
current_tts_language = "tgl"
# Model and processor/tokenizer instances; all None until they are loaded.
whisper_processor = whisper_model = None
mms_processor = mms_model = None
mt_model = mt_tokenizer = None
tts_model = tts_tokenizer = None
# Display name -> three-letter language code accepted by the endpoints.
LANGUAGE_MAPPING = dict(
    English="eng",
    Tagalog="tgl",
    Cebuano="ceb",
    Ilocano="ilo",
    Waray="war",
    Pangasinan="pag",
)
# Speech-to-text routing: English and Tagalog go to Whisper, the other
# Philippine languages go to MMS.
WHISPER_LANGUAGES = {"eng", "tgl"}
MMS_LANGUAGES = {"ceb", "ilo", "war", "pag"}
# NLLB language tag for each code: the code followed by the "_Latn" suffix.
NLLB_LANGUAGE_CODES = {
    code: f"{code}_Latn" for code in LANGUAGE_MAPPING.values()
}
# Words and phrases flagged by the content filter, kept in two groups and
# concatenated in the same order into the public list.
_ENGLISH_TERMS = [
    "fuck", "shit", "bitch", "ass", "damn", "hell",
    "bastard", "cunt", "son of a bitch", "dick", "pussy", "motherfucker",
]
_PHILIPPINE_TERMS = [
    "agka baboy", "puta", "putang ina", "gago", "tanga", "hayop",
    "ulol", "lintik", "animal ka",
    "paki", "pakyu", "yawa", "bungol", "gingan", "yawa ka", "peste", "irig",
    "pakit", "ayat", "pua", "kayat mo ti agsardeng", "hinampak", "iring ka",
]
INAPPROPRIATE_WORDS = [*_ENGLISH_TERMS, *_PHILIPPINE_TERMS]
def check_inappropriate_content(text: str, words: Optional[List[str]] = None) -> bool:
    """
    Check if the text contains inappropriate content.

    Args:
        text: The text to inspect. ``None`` or an empty string is treated as
            clean.
        words: Optional list of lowercase words/phrases to look for. When
            omitted, the module-level INAPPROPRIATE_WORDS list is used.

    Returns:
        True if any listed word/phrase occurs as a whole word in the text
        (case-insensitively), False otherwise.
    """
    if not text:
        return False
    candidates = INAPPROPRIATE_WORDS if words is None else words
    # The text is lowercased once; entries are expected to be lowercase too.
    text_lower = text.lower()
    for word in candidates:
        if not word:
            # An empty entry would reduce the pattern to a bare word boundary
            # and flag every non-empty text.
            continue
        # Word-boundary anchors avoid flagging substrings of longer words.
        pattern = r'\b' + re.escape(word) + r'\b'
        if re.search(pattern, text_lower):
            logger.warning(f"Inappropriate content detected: {word}")
            return True
    return False
def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
    """
    Write PCM samples to a mono 16-bit WAV file.

    Args:
        pcm_data: Sequence (list or array) of sample values on the 16-bit
            integer scale. Values outside [-32768, 32767] are saturated to
            that range instead of wrapping around or raising.
        sample_rate: Sampling rate in Hz written to the WAV header.
        output_path: Destination file path.
    """
    int16_range = np.iinfo(np.int16)
    # Go through float64 so clipping works for any input dtype; every int16
    # value is exactly representable and the final cast truncates toward zero.
    samples = np.asarray(pcm_data, dtype=np.float64)
    # '<i2' is explicitly little-endian 16-bit, which is what WAV stores.
    pcm_array = np.clip(samples, int16_range.min, int16_range.max).astype("<i2")
    with wave.open(output_path, 'wb') as wav_file:
        # WAV parameters: 1 channel (mono), 2 bytes per sample, sample rate.
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_array.tobytes())
def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
    """
    Detects if the audio contains speech using an energy-based approach.

    Args:
        waveform: Audio samples. A tensor with more than one dimension is
            averaged over its first axis before the energy is computed.
        sample_rate: Sampling rate of the audio; not used by the current check.
        threshold: Minimum RMS energy for the audio to count as speech.
        min_speech_duration: Not used by the current check; a duration test
            would need a more sophisticated VAD.

    Returns:
        True if the RMS energy reaches the threshold, False otherwise
        (including for empty audio).
    """
    # detach()/cpu() make the conversion work for CUDA or grad-tracking tensors.
    waveform_np = waveform.detach().cpu().numpy()
    if waveform_np.ndim > 1:
        waveform_np = waveform_np.mean(axis=0)  # Convert stereo to mono
    if waveform_np.size == 0:
        # The mean of an empty array is NaN, and NaN is never below the
        # threshold, so empty audio would otherwise be reported as speech.
        logger.info("No speech detected: audio is empty")
        return False
    # Compute RMS energy
    rms = np.sqrt(np.mean(waveform_np**2))
    logger.info(f"RMS energy: {rms}")
    if not np.isfinite(rms) or rms < threshold:
        logger.info("No speech detected: RMS energy below threshold")
        return False
    # For now, energy at or above the threshold is taken to mean speech.
    return True
def cleanup_old_audio_files():
    """
    Delete regular files in AUDIO_DIR last modified more than 10 minutes ago.

    Every filesystem error is logged rather than raised, so one bad entry
    cannot abort the sweep or terminate the thread that calls this function.
    """
    logger.info("Starting cleanup of old audio files...")
    expiration_time = datetime.now() - timedelta(minutes=10)  # Files older than 10 minutes
    try:
        filenames = os.listdir(AUDIO_DIR)
    except OSError as e:
        logger.error(f"Error listing audio directory {AUDIO_DIR}: {str(e)}")
        return
    for filename in filenames:
        file_path = os.path.join(AUDIO_DIR, filename)
        try:
            if not os.path.isfile(file_path):
                continue
            file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
            if file_mtime < expiration_time:
                os.unlink(file_path)
                logger.info(f"Deleted old audio file: {file_path}")
        except FileNotFoundError:
            # The file vanished between the listing and the stat/unlink call;
            # there is nothing left to clean up.
            continue
        except Exception as e:
            logger.error(f"Error deleting file {file_path}: {str(e)}")
def schedule_cleanup():
    """
    Run cleanup_old_audio_files() forever, pausing 5 minutes between runs.

    Intended as a thread target. An unexpected exception from one run is
    logged and the loop continues, so a single failure does not stop all
    future cleanups.
    """
    while True:
        try:
            cleanup_old_audio_files()
        except Exception as e:
            logger.error(f"Audio cleanup run failed: {str(e)}")
        time.sleep(300)  # Run every 5 minutes (300 seconds)
def load_models_task():
    """
    Load the Whisper, MMS, NLLB and MMS-TTS models into module globals.

    Used as the target of the background loader thread. Models are loaded in
    sequence (Whisper STT, MMS STT, MT, TTS). model_status is updated per
    model; on the first failure error_message is set and loading stops. The
    TTS step tries Tagalog first and falls back to English. models_loaded is
    set to True only when every step succeeded, and loading_in_progress is
    always reset to False on exit.
    """
    global models_loaded, loading_in_progress, model_status, error_message
    global whisper_processor, whisper_model, mms_processor, mms_model
    global mt_model, mt_tokenizer, tts_model, tts_tokenizer
    # Without this declaration the assignment in the English fallback branch
    # creates a local variable and the module-level language tracker keeps its
    # old value even though the English model is the one that is loaded.
    global current_tts_language
    try:
        loading_in_progress = True
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load Whisper STT model for English and Tagalog
        logger.info("Starting to load Whisper STT model...")
        from transformers import WhisperProcessor, WhisperForConditionalGeneration
        try:
            logger.info("Loading Whisper STT model...")
            model_status["stt_whisper"] = "loading"
            whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
            whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
            whisper_model.to(device)
            logger.info("Whisper STT model loaded successfully")
            model_status["stt_whisper"] = "loaded"
        except Exception as whisper_error:
            logger.error(f"Failed to load Whisper STT model: {str(whisper_error)}")
            model_status["stt_whisper"] = "failed"
            error_message = f"Whisper STT model loading failed: {str(whisper_error)}"
            return

        # Load MMS STT model for other Philippine languages
        logger.info("Starting to load MMS STT model...")
        from transformers import AutoProcessor, AutoModelForCTC
        try:
            logger.info("Loading MMS STT model...")
            model_status["stt_mms"] = "loading"
            mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
            mms_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
            mms_model.to(device)
            logger.info("MMS STT model loaded successfully")
            model_status["stt_mms"] = "loaded"
        except Exception as mms_error:
            logger.error(f"Failed to load MMS STT model: {str(mms_error)}")
            model_status["stt_mms"] = "failed"
            error_message = f"MMS STT model loading failed: {str(mms_error)}"
            return

        # Load MT model
        logger.info("Starting to load MT model...")
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        try:
            logger.info("Loading NLLB-200-distilled-600M model...")
            model_status["mt"] = "loading"
            mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            mt_model.to(device)
            logger.info("MT model loaded successfully")
            model_status["mt"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load MT model: {str(e)}")
            model_status["mt"] = "failed"
            error_message = f"MT model loading failed: {str(e)}"
            return

        # Load TTS model (default to Tagalog, will be updated dynamically)
        logger.info("Starting to load TTS model...")
        from transformers import VitsModel
        try:
            logger.info("Loading MMS-TTS model for Tagalog...")
            model_status["tts"] = "loading"
            tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
            tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
            tts_model.to(device)
            logger.info("TTS model loaded successfully")
            model_status["tts"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
            # Fallback to English TTS if the target language fails
            try:
                logger.info("Falling back to MMS-TTS English model...")
                tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                tts_model.to(device)
                logger.info("Fallback TTS model loaded successfully")
                model_status["tts"] = "loaded (fallback)"
                current_tts_language = "eng"
            except Exception as e2:
                logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                model_status["tts"] = "failed"
                error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
                return

        models_loaded = True
        logger.info("Model loading completed successfully")
    except Exception as e:
        error_message = str(e)
        logger.error(f"Error in model loading task: {str(e)}")
    finally:
        loading_in_progress = False
def start_model_loading():
    """
    Launch load_models_task() on a daemon thread.

    Does nothing when a load is already in progress or the models are already
    loaded. The in-progress flag is raised before the thread starts.
    """
    global loading_thread, loading_in_progress
    if loading_in_progress or models_loaded:
        return
    loading_in_progress = True
    loading_thread = threading.Thread(target=load_models_task, daemon=True)
    loading_thread.start()
def start_cleanup_task():
    """Launch schedule_cleanup() on a daemon thread."""
    worker = threading.Thread(target=schedule_cleanup, daemon=True)
    worker.start()
def _load_mms_tts(language_code: str, device: str):
    """Load the MMS-TTS model and tokenizer for language_code and move the model to device."""
    from transformers import VitsModel, AutoTokenizer
    repo_id = f"facebook/mms-tts-{language_code}"
    model = VitsModel.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model.to(device)
    return model, tokenizer


def load_tts_model_for_language(target_code: str) -> bool:
    """
    Load or update the TTS model for the specified language.

    The new model and tokenizer are fully loaded into local variables before
    the module globals are replaced, so a failure part-way through never
    leaves a model paired with the tokenizer of a different language.

    Args:
        target_code: Language code; must be one of the LANGUAGE_MAPPING values.

    Returns:
        True if a TTS model is available afterwards (the requested language,
        or the English fallback when the requested one could not be loaded),
        False if the code is invalid or both loads failed.
    """
    global tts_model, tts_tokenizer, current_tts_language, model_status
    if target_code not in LANGUAGE_MAPPING.values():
        logger.error(f"Invalid language code: {target_code}")
        return False
    # Skip if the model is already loaded for the target language
    if current_tts_language == target_code and model_status["tts"].startswith("loaded"):
        logger.info(f"TTS model for {target_code} is already loaded.")
        return True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        logger.info(f"Loading MMS-TTS model for {target_code}...")
        new_model, new_tokenizer = _load_mms_tts(target_code, device)
    except Exception as e:
        logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
    else:
        tts_model, tts_tokenizer = new_model, new_tokenizer
        current_tts_language = target_code
        logger.info(f"TTS model updated to {target_code}")
        model_status["tts"] = "loaded"
        return True
    # The requested language could not be loaded: fall back to English.
    try:
        logger.info("Falling back to MMS-TTS English model...")
        new_model, new_tokenizer = _load_mms_tts("eng", device)
    except Exception as e2:
        logger.error(f"Failed to load fallback TTS model: {str(e2)}")
        model_status["tts"] = "failed"
        return False
    tts_model, tts_tokenizer = new_model, new_tokenizer
    current_tts_language = "eng"
    logger.info("Fallback TTS model loaded successfully")
    model_status["tts"] = "loaded (fallback)"
    return True
def synthesize_speech(text: str, target_code: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Convert text to speech for the specified language.

    Args:
        text: Text to speak. Empty or whitespace-only text is rejected.
        target_code: Language code passed to load_tts_model_for_language().

    Returns:
        A tuple of (output_path, error_message). On success output_path is the
        WAV file written under AUDIO_DIR and error_message is None; on failure
        output_path is None and error_message describes the problem.
    """
    if not text or not text.strip():
        return None, "No text provided for TTS"
    request_id = str(uuid.uuid4())
    output_path = os.path.join(AUDIO_DIR, f"{request_id}.wav")
    # Make sure the TTS model is loaded for the target language
    if not load_tts_model_for_language(target_code):
        return None, "Failed to load TTS model for the target language"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        inputs = tts_tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            output = tts_model(**inputs)
        speech = output.waveform.cpu().numpy().squeeze()
        # Clip to [-1, 1] before scaling so that samples beyond full scale
        # saturate instead of wrapping around in the int16 conversion.
        speech = (np.clip(speech, -1.0, 1.0) * 32767).astype(np.int16)
        sample_rate = tts_model.config.sampling_rate
        # Save the audio as a WAV file. The array is passed as-is; converting
        # it to a Python list first would only add a large temporary copy.
        save_pcm_to_wav(speech, sample_rate, output_path)
        logger.info(f"Saved synthesized audio to {output_path}")
        return output_path, None
    except Exception as e:
        error_msg = f"Error during TTS conversion: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
async def startup_event():
    """Start background model loading and the periodic audio cleanup."""
    # NOTE(review): no startup-hook registration (decorator) is visible on
    # this coroutine in the reviewed text — confirm it is registered on `app`.
    logger.info("Application starting up...")
    for launch in (start_model_loading, start_cleanup_task):
        launch()
async def root():
    """Root endpoint for default health check; always reports "healthy"."""
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text — confirm how it is registered on `app`.
    logger.info("Root endpoint requested")
    payload = {"status": "healthy"}
    return payload
async def health_check():
    """
    Health check endpoint that always returns successfully.

    The response carries a fixed "healthy" status together with the current
    loading flags, the per-model status dict and the last recorded error.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text — confirm how it is registered on `app`.
    logger.info("Health check requested")
    report = {"status": "healthy"}
    report["models_loaded"] = models_loaded
    report["loading_in_progress"] = loading_in_progress
    report["model_status"] = model_status
    report["error"] = error_message
    return report
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    """
    Endpoint to translate text and convert the translation to speech.

    Args:
        text: Text to translate.
        source_lang: Source language name (a key of LANGUAGE_MAPPING).
        target_lang: Target language name (a key of LANGUAGE_MAPPING).

    Returns:
        A dict with request_id, status, message, source_text, translated_text,
        output_audio (URL or None) and is_inappropriate. When the MT model is
        unavailable or translation fails, translated_text holds a placeholder
        or failure message and no audio is produced for it.

    Raises:
        HTTPException: 400 when text is empty or a language name is unknown.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text — confirm how it is registered on `app`.
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
    request_id = str(uuid.uuid4())

    # Translate the text
    source_code = LANGUAGE_MAPPING[source_lang]
    target_code = LANGUAGE_MAPPING[target_lang]
    translated_text = "Translation not available"
    translation_ok = False
    if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
        try:
            source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
            target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
            mt_tokenizer.src_lang = source_nllb_code
            device = "cuda" if torch.cuda.is_available() else "cpu"
            inputs = mt_tokenizer(text, return_tensors="pt").to(device)
            with torch.no_grad():
                generated_tokens = mt_model.generate(
                    **inputs,
                    forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
                    max_length=448
                )
            translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            translation_ok = True
            logger.info(f"Translation completed: {translated_text}")
        except Exception as e:
            logger.error(f"Error during translation: {str(e)}")
            translated_text = f"Translation failed: {str(e)}"
    else:
        logger.warning("MT model not loaded, skipping translation")

    # Check the source text, and the translation only when it is a real
    # translation rather than a placeholder or an error message.
    is_inappropriate = check_inappropriate_content(text) or (
        translation_ok and check_inappropriate_content(translated_text)
    )
    if is_inappropriate:
        logger.warning("Inappropriate content detected in translation request")

    # Convert translated text to speech. Placeholder/failure strings are not
    # synthesized: speaking an error message in the target voice is never
    # what the caller asked for.
    output_audio_url = None
    if translation_ok and model_status["tts"].startswith("loaded"):
        try:
            # synthesize_speech() loads/switches the TTS model itself.
            output_path, error = synthesize_speech(translated_text, target_code)
            if output_path:
                output_filename = os.path.basename(output_path)
                output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
                logger.info("TTS conversion completed")
            else:
                logger.error(f"TTS conversion failed: {error}")
        except Exception as e:
            logger.error(f"Error during TTS conversion: {str(e)}")

    return {
        "request_id": request_id,
        "status": "completed",
        "message": "Translation and TTS completed (or partially completed).",
        "source_text": text,
        "translated_text": translated_text,
        "output_audio": output_audio_url,
        "is_inappropriate": is_inappropriate
    }
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    """
    Endpoint to transcribe, translate, and convert audio to speech.

    Pipeline: save the upload to a temporary file, load it, downmix to mono,
    resample to 16 kHz, run the energy-based speech check, transcribe with
    Whisper (English/Tagalog) or MMS (other languages), translate with the MT
    model, run the content filter, then synthesize the translation.

    Args:
        audio: Uploaded audio file.
        source_lang: Source language name (a key of LANGUAGE_MAPPING).
        target_lang: Target language name (a key of LANGUAGE_MAPPING).

    Returns:
        A dict with request_id, status ("processing", "completed" or
        "failed"), message, source_text, translated_text, output_audio (URL
        or None) and is_inappropriate.

    Raises:
        HTTPException: 400 when no audio is given or a language name is unknown.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text — confirm how it is registered on `app`.
    if not audio:
        raise HTTPException(status_code=400, detail="No audio file provided")
    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    source_code = LANGUAGE_MAPPING[source_lang]
    target_code = LANGUAGE_MAPPING[target_lang]
    logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} ({source_code}) to {target_lang} ({target_code})")
    request_id = str(uuid.uuid4())

    def build_response(status, message, source_text, translation, audio_url=None, flagged=False):
        # Single place that defines the response shape for every exit path.
        return {
            "request_id": request_id,
            "status": status,
            "message": message,
            "source_text": source_text,
            "translated_text": translation,
            "output_audio": audio_url,
            "is_inappropriate": flagged
        }

    # Determine which STT model to use based on source language
    use_whisper = source_code in WHISPER_LANGUAGES
    use_mms = source_code in MMS_LANGUAGES

    # Check if the appropriate STT model is loaded
    if use_whisper and (model_status["stt_whisper"] != "loaded" or whisper_processor is None or whisper_model is None):
        logger.warning("Whisper STT model not loaded for English/Tagalog, returning placeholder response")
        return build_response(
            "processing",
            "Whisper STT model not loaded yet. Please try again later.",
            "Transcription not available",
            "Translation not available",
        )
    if use_mms and (model_status["stt_mms"] != "loaded" or mms_processor is None or mms_model is None):
        logger.warning("MMS STT model not loaded for Philippine languages, returning placeholder response")
        return build_response(
            "processing",
            "MMS STT model not loaded yet. Please try again later.",
            "Transcription not available",
            "Translation not available",
        )

    # Save the uploaded audio to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(await audio.read())
        temp_path = temp_file.name

    transcription = "Transcription not available"
    translated_text = "Translation not available"
    output_audio_url = None
    is_inappropriate = False
    try:
        # Step 1: Load, downmix and resample the audio using torchaudio
        logger.info(f"Reading audio file: {temp_path}")
        waveform, sample_rate = torchaudio.load(temp_path)
        logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
        # Downmix multi-channel audio to mono. A 2-D array handed to the STT
        # processors is treated as a batch, so each channel would be
        # transcribed separately and only the first result would be used.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Resample to 16 kHz if needed (required by Whisper and MMS models)
        if sample_rate != 16000:
            logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        # Step 2: Detect speech
        if not detect_speech(waveform, sample_rate):
            return build_response(
                "failed",
                "No speech detected in the audio.",
                "No speech detected",
                "No translation available",
            )

        # Step 3: Transcribe the audio (STT)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device} for STT")
        # 1-D sample array: one utterance, not a batch.
        samples = waveform.squeeze(0).numpy()
        if use_whisper:
            # Use Whisper model for English and Tagalog
            logger.info(f"Using Whisper model for language: {source_code}")
            inputs = whisper_processor(samples, sampling_rate=16000, return_tensors="pt").to(device)
            logger.info("Audio processed for Whisper, generating transcription...")
            with torch.no_grad():
                # Whisper language hint: 'en' for English, 'tl' for Tagalog
                forced_language = "en" if source_code == "eng" else "tl"
                generated_ids = whisper_model.generate(
                    **inputs,
                    language=forced_language,
                    task="transcribe"
                )
                transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            # Use MMS model for other Philippine languages
            logger.info(f"Using MMS model for language: {source_code}")
            # Select the language adapter for the source language; the
            # checkpoint otherwise keeps using whichever adapter is currently
            # loaded. Best effort: on failure, continue with the current one.
            if getattr(mms_model, "target_lang", None) != source_code:
                try:
                    logger.info(f"Switching MMS adapter to {source_code}")
                    mms_model.load_adapter(source_code)
                    mms_processor.tokenizer.set_target_lang(source_code)
                except Exception as adapter_error:
                    logger.warning(f"Could not switch MMS adapter to {source_code}: {str(adapter_error)}")
            inputs = mms_processor(samples, sampling_rate=16000, return_tensors="pt").to(device)
            logger.info("Audio processed for MMS, generating transcription...")
            with torch.no_grad():
                # Greedy CTC decoding: most likely token per frame
                logits = mms_model(**inputs).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                transcription = mms_processor.batch_decode(predicted_ids)[0]
        logger.info(f"Transcription completed: {transcription}")

        # Step 4: Translate the transcribed text (MT)
        translation_ok = False
        if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
            try:
                source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
                target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
                mt_tokenizer.src_lang = source_nllb_code
                inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
                with torch.no_grad():
                    generated_tokens = mt_model.generate(
                        **inputs,
                        forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
                        max_length=448
                    )
                translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
                translation_ok = True
                logger.info(f"Translation completed: {translated_text}")
            except Exception as e:
                logger.error(f"Error during translation: {str(e)}")
                translated_text = f"Translation failed: {str(e)}"
        else:
            logger.warning("MT model not loaded, skipping translation")

        # Step 5: Check for inappropriate content (the translation is checked
        # only when it is a real translation, not a placeholder/error string)
        is_inappropriate = check_inappropriate_content(transcription) or (
            translation_ok and check_inappropriate_content(translated_text)
        )
        if is_inappropriate:
            logger.warning("Inappropriate content detected in audio transcription or translation")

        # Step 6: Convert translated text to speech (TTS). Placeholder or
        # failure strings are not synthesized.
        if translation_ok:
            try:
                # synthesize_speech() loads/switches the TTS model itself.
                output_path, error = synthesize_speech(translated_text, target_code)
                if output_path:
                    output_filename = os.path.basename(output_path)
                    output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
                    logger.info("TTS conversion completed")
                else:
                    logger.error(f"TTS conversion failed: {error}")
            except Exception as e:
                logger.error(f"Error during TTS conversion: {str(e)}")

        return build_response(
            "completed",
            "Transcription, translation, and TTS completed (or partially completed).",
            transcription,
            translated_text,
            output_audio_url,
            is_inappropriate,
        )
    except Exception as e:
        logger.error(f"Error during processing: {str(e)}")
        return build_response(
            "failed",
            f"Processing failed: {str(e)}",
            transcription,
            translated_text,
            output_audio_url,
            is_inappropriate,
        )
    finally:
        logger.info(f"Cleaning up temporary file: {temp_path}")
        try:
            os.unlink(temp_path)
        except OSError as cleanup_error:
            # An exception raised from ``finally`` would replace the response
            # being returned, so a failed delete is only logged.
            logger.warning(f"Could not delete temporary file {temp_path}: {str(cleanup_error)}")
async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
    """
    Endpoint to convert text to speech in the specified language.

    Returns a dict with request_id, status ("completed" when an audio URL was
    produced, otherwise "failed"), message, the input text, output_audio and
    is_inappropriate. Raises HTTPException(400) for empty text or an unknown
    language name.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed text — confirm how it is registered on `app`.
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    if target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    logger.info(f"Text-to-speech requested for text in {target_lang}")
    request_id = str(uuid.uuid4())
    target_code = LANGUAGE_MAPPING[target_lang]

    # Content filter: the result is reported to the caller, not enforced.
    is_inappropriate = check_inappropriate_content(text)
    if is_inappropriate:
        logger.warning("Inappropriate content detected in text-to-speech request")

    # Synthesis runs when a TTS model is already loaded, or when one can be
    # loaded on demand for the target language.
    output_audio_url = None
    tts_ready = model_status["tts"].startswith("loaded") or load_tts_model_for_language(target_code)
    if not tts_ready:
        logger.warning("TTS model not loaded and could not be loaded")
    else:
        try:
            output_path, error = synthesize_speech(text, target_code)
            if not output_path:
                logger.error(f"TTS conversion failed: {error}")
            else:
                output_filename = os.path.basename(output_path)
                output_audio_url = f"https://jerich-talklasapp.hf.space/audio_output/{output_filename}"
                logger.info("TTS conversion completed")
        except Exception as e:
            logger.error(f"Error during TTS conversion: {str(e)}")

    if output_audio_url:
        status, message = "completed", "TTS completed"
    else:
        status, message = "failed", "TTS failed"
    return {
        "request_id": request_id,
        "status": status,
        "message": message,
        "text": text,
        "output_audio": output_audio_url,
        "is_inappropriate": is_inappropriate
    }
if __name__ == "__main__":
    # Script entry point: serve the app with Uvicorn on all interfaces,
    # port 8000, with a single worker.
    import uvicorn

    logger.info("Starting Uvicorn server...")
    server_options = {"host": "0.0.0.0", "port": 8000, "workers": 1}
    uvicorn.run(app, **server_options)