import os import logging import torch import numpy as np import soundfile as sf import threading import gc logger = logging.getLogger(__name__) NLLB_LANG_CODES = { "Marathi": "mar_Deva", "Hindi": "hin_Deva", "English": "eng_Latn", } # Optimize Torch for CPU-only environments like HF Spaces if not torch.cuda.is_available(): torch.set_num_threads(int(os.cpu_count() or 1)) class ModelManager: def __init__(self, cache_dir="./models"): self.cache_dir = os.path.abspath(cache_dir) os.environ["HF_HOME"] = os.path.join(self.cache_dir, "hf_cache") self.lock = threading.RLock() # Reentrant lock for concurrent requests safety self.ci_mode = os.environ.get("CI_MODE", "false").lower() == "true" # Select device automatically if self.ci_mode: self.device = "cpu" elif torch.cuda.is_available(): self.device = "cuda" elif torch.backends.mps.is_available(): self.device = "mps" else: self.device = "cpu" logger.info("ModelManager initialized using device: %s (CI_MODE=%s)", self.device, self.ci_mode) logger.info("Cache directory: %s", self.cache_dir) # Lazy load containers self.whisper_pipe = {} self.nllb_model = None self.nllb_tokenizer = None self.tts_models = {} self.tts_tokenizers = {} def _clear_memory(self): """Force garbage collection and clear torch cache if on GPU""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() elif self.device == "mps": torch.mps.empty_cache() def get_whisper(self, size="base"): with self.lock: if size not in self.whisper_pipe: from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline model_id = f"openai/whisper-{size}" logger.info("Loading STT model %s from %s on %s...", model_id, self.cache_dir, self.device) try: processor = WhisperProcessor.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True) model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True) self.whisper_pipe[size] = pipeline( "automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, chunk_length_s=30, device=0 if self.device == "cuda" else (-1 if self.device == "cpu" else "mps") ) logger.info("Whisper-%s loaded successfully.", size) except Exception as e: logger.error("Error loading Whisper-%s: %s", size, e) self.whisper_pipe[size] = pipeline( "automatic-speech-recognition", model=model_id, cache_dir=self.cache_dir, chunk_length_s=30, device=0 if self.device == "cuda" else (-1 if self.device == "cpu" else "mps") ) return self.whisper_pipe[size] def get_nllb(self): with self.lock: if self.nllb_model is None: from transformers import AutoModelForSeq2SeqLM, AutoTokenizer model_id = "facebook/nllb-200-distilled-600M" logger.info("Loading NLLB-200 translation model from %s on %s...", self.cache_dir, self.device) try: self.nllb_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True) self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True).to(self.device) logger.info("NLLB-200 loaded successfully.") except Exception as e: logger.error("Error loading NLLB-200: %s", e) self.nllb_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir) self.nllb_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir=self.cache_dir).to(self.device) return self.nllb_model, self.nllb_tokenizer def get_tts(self, lang): with self.lock: if lang not in self.tts_models: from transformers import AutoTokenizer, VitsModel model_id = { "Hindi": "facebook/mms-tts-hin", "Marathi": "facebook/mms-tts-mar", "English": "facebook/mms-tts-eng" }.get(lang) if not model_id: raise ValueError(f"Unsupported TTS language: {lang}") logger.info("Loading TTS model for %s (%s) on %s...", lang, model_id, self.device) try: self.tts_tokenizers[lang] = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True) self.tts_models[lang] = VitsModel.from_pretrained(model_id, cache_dir=self.cache_dir, local_files_only=True).to(self.device) logger.info("TTS model for %s loaded successfully.", lang) except Exception as e: logger.error("Error loading TTS for %s: %s", lang, e) self.tts_tokenizers[lang] = AutoTokenizer.from_pretrained(model_id, cache_dir=self.cache_dir) self.tts_models[lang] = VitsModel.from_pretrained(model_id, cache_dir=self.cache_dir).to(self.device) return self.tts_models[lang], self.tts_tokenizers[lang] def transcribe(self, audio_path, size="base", language="English"): if self.ci_mode: return { "text": "This is a mock transcription for CI mode.", "segments": [ {"start": 0.0, "end": 2.0, "text": "This is a mock"}, {"start": 2.0, "end": 4.0, "text": "transcription for CI mode."} ] } with self.lock: # Map human language name to Whisper code # If "auto", we pass None to let Whisper detect lang_code = { "Marathi": "mr", "Hindi": "hi", "English": "en" }.get(language) pipe = self.get_whisper(size) lang_label = "auto" if lang_code is None else lang_code logger.info("Transcribing %s using Whisper-%s (language=%s)...", audio_path, size, lang_label) # Run Whisper ASR pipeline with timestamps gen_kwargs = {"return_timestamps": True} if lang_code: gen_kwargs["language"] = lang_code # Using the pipeline's underlying model to get detected language if needed # but for simplicity, we can also look at the pipe.model.config.lang # Actually, the pipeline result might have it if we use return_timestamps=True and return_language=True result = pipe( audio_path, chunk_length_s=30, stride_length_s=5, generate_kwargs=gen_kwargs ) self._clear_memory() # Extract segments from chunks chunks = result.get("chunks", []) segments = [] for chunk in chunks: start, end = chunk.get("timestamp", (0, 0)) segments.append({ "start": start, "end": end, "text": chunk.get("text", "") }) # Extract detected language if auto detected_lang = language if language == "auto" and hasattr(pipe.model, "config") and hasattr(pipe.model.config, "lang"): # Map ISO code back to name iso_to_name = { "mr": "Marathi", "hi": "Hindi", "en": "English" } detected_lang = iso_to_name.get(pipe.model.config.lang, "English") return { "text": result.get("text", ""), "segments": segments, "detected_language": detected_lang } def translate(self, text, src_lang, tgt_lang): if not text.strip(): return "" if src_lang == tgt_lang: return text # Map languages to NLLB-200 code lang_codes = { "Marathi": "mar_Deva", "Hindi": "hin_Deva", "English": "eng_Latn" } src_code = lang_codes.get(src_lang) tgt_code = lang_codes.get(tgt_lang) if not src_code or not tgt_code: raise ValueError(f"Unsupported translation languages: {src_lang} -> {tgt_lang}") if self.ci_mode: return f"[CI MOCK] {tgt_lang}: {text}" with self.lock: model, tokenizer = self.get_nllb() logger.info("Translating text using NLLB-200 (%s -> %s)...", src_code, tgt_code) # Tokenize and force generation in target language tokenizer.src_lang = src_code inputs = tokenizer(text, return_tensors="pt").to(self.device) forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code) with torch.no_grad(): translated_tokens = model.generate( **inputs, forced_bos_token_id=forced_bos_token_id, max_length=256, num_beams=4, no_repeat_ngram_size=3 ) translated_text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0] self._clear_memory() return translated_text def translate_batch(self, texts, src_lang, tgt_lang): if not texts: return [] if self.ci_mode: return [f"[CI MOCK] {tgt_lang}: {t}" if t.strip() else t for t in texts] if src_lang == tgt_lang: return texts # Map languages to NLLB-200 code src_code = NLLB_LANG_CODES.get(src_lang) tgt_code = NLLB_LANG_CODES.get(tgt_lang) if not src_code or not tgt_code: raise ValueError(f"Unsupported translation languages: {src_lang} -> {tgt_lang}") # Filter out empty/whitespace-only texts but keep track of indices to restore them non_empty_indices = [] non_empty_texts = [] for idx, text in enumerate(texts): if text.strip(): non_empty_indices.append(idx) non_empty_texts.append(text) results = [""] * len(texts) if not non_empty_texts: # All texts were empty for idx, text in enumerate(texts): if not text.strip(): results[idx] = text # Preserve original whitespace if any return results with self.lock: model, tokenizer = self.get_nllb() logger.info("Batch translating %d items using NLLB-200 (%s -> %s)...", len(non_empty_texts), src_code, tgt_code) tokenizer.src_lang = src_code inputs = tokenizer(non_empty_texts, return_tensors="pt", padding=True).to(self.device) forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code) with torch.no_grad(): translated_tokens = model.generate( **inputs, forced_bos_token_id=forced_bos_token_id, max_length=256, num_beams=4, no_repeat_ngram_size=3 ) translated_texts = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) self._clear_memory() # Map back to full results list for i, idx in enumerate(non_empty_indices): results[idx] = translated_texts[i] # For empty texts, just preserve them for idx, text in enumerate(texts): if not text.strip(): results[idx] = "" return results def text_to_speech(self, text, lang, output_path, speed=1.0): if not text.strip(): raise ValueError("Empty text provided for TTS") supported_langs = ["Hindi", "Marathi", "English"] if lang not in supported_langs: raise ValueError(f"Unsupported TTS language: {lang}") if self.ci_mode: logger.debug("[CI MOCK] Generating dummy TTS for %s...", lang) dummy_data = np.zeros(16000) sf.write(output_path, dummy_data, 16000) return output_path with self.lock: model, tokenizer = self.get_tts(lang) logger.info("Synthesizing speech for text in %s (speed=%s)...", lang, speed) inputs = tokenizer(text, return_tensors="pt").to(self.device) with torch.no_grad(): output = model(**inputs).waveform # Convert PyTorch tensor to numpy array (1D) waveform_numpy = output.cpu().numpy().squeeze() # Apply speed adjustment if not 1.0 # We use simple linear resampling for speed control if speed != 1.0: import scipy.interpolate x = np.arange(len(waveform_numpy)) new_x = np.linspace(0, len(waveform_numpy)-1, int(len(waveform_numpy) / speed)) f = scipy.interpolate.interp1d(x, waveform_numpy) waveform_numpy = f(new_x).astype(np.float32) # MMS-TTS models output sample rate is 16000Hz sf.write(output_path, waveform_numpy, samplerate=16000) self._clear_memory() logger.info("TTS audio written to: %s", output_path) return output_path