Spaces:
Running
Running
| import os | |
| import logging | |
| import torch | |
| import numpy as np | |
| import soundfile as sf | |
| import threading | |
| import gc | |
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Human-readable language name -> NLLB-200 language code.
NLLB_LANG_CODES = dict(
    Marathi="mar_Deva",
    Hindi="hin_Deva",
    English="eng_Latn",
)

# Without CUDA (e.g. CPU-only HF Spaces), let Torch use every available core.
if not torch.cuda.is_available():
    _cpu_threads = os.cpu_count() or 1
    torch.set_num_threads(int(_cpu_threads))
class ModelManager:
    """Lazily load, cache and run the Whisper, NLLB-200 and MMS-TTS models.

    Each model is loaded on first use: a load restricted to files already in
    ``cache_dir`` is tried first, and a second load that is allowed to
    download is attempted only if the first one raises.  Loading and inference
    are serialised through a single re-entrant lock.

    When the ``CI_MODE`` environment variable equals ``"true"``
    (case-insensitive), ``transcribe``, ``translate``, ``translate_batch`` and
    ``text_to_speech`` return canned results without loading any model.
    """

    # Display language name -> Whisper language code.
    _WHISPER_LANG_CODES = {"Marathi": "mr", "Hindi": "hi", "English": "en"}
    # Whisper language code -> display language name.
    _WHISPER_LANG_NAMES = {"mr": "Marathi", "hi": "Hindi", "en": "English"}
    # Display language name -> MMS-TTS checkpoint id.
    _TTS_MODEL_IDS = {
        "Hindi": "facebook/mms-tts-hin",
        "Marathi": "facebook/mms-tts-mar",
        "English": "facebook/mms-tts-eng",
    }
    # Sample rate used for the CI dummy audio and when the TTS model config
    # does not expose ``sampling_rate``.
    _DEFAULT_TTS_SAMPLE_RATE = 16000

    def __init__(self, cache_dir="./models"):
        """Pick the compute device and prepare empty model caches.

        Args:
            cache_dir: Directory used as ``cache_dir`` for every
                ``from_pretrained`` call; ``HF_HOME`` is pointed at its
                ``hf_cache`` sub-directory.
        """
        self.cache_dir = os.path.abspath(cache_dir)
        os.environ["HF_HOME"] = os.path.join(self.cache_dir, "hf_cache")
        self.lock = threading.RLock()  # Reentrant: public methods call get_* while holding it
        self.ci_mode = os.environ.get("CI_MODE", "false").lower() == "true"

        # Select device automatically (CI always runs on CPU).
        mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
        if self.ci_mode:
            self.device = "cpu"
        elif torch.cuda.is_available():
            self.device = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"

        logger.info("ModelManager initialized using device: %s (CI_MODE=%s)", self.device, self.ci_mode)
        logger.info("Cache directory: %s", self.cache_dir)

        # Lazy load containers
        self.whisper_pipe = {}      # Whisper size -> ASR pipeline
        self.nllb_model = None
        self.nllb_tokenizer = None
        self.tts_models = {}        # language name -> VitsModel
        self.tts_tokenizers = {}    # language name -> tokenizer

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _clear_memory(self):
        """Force garbage collection and release cached accelerator memory."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif self.device == "mps" and hasattr(torch, "mps"):
            torch.mps.empty_cache()

    def _pipeline_device(self):
        """Return the ``device`` argument for ``transformers.pipeline``."""
        if self.device == "cuda":
            return 0
        if self.device == "cpu":
            return -1
        return "mps"

    @staticmethod
    def _is_blank(text):
        """Return True for ``None``, empty and whitespace-only strings."""
        return not text or not text.strip()

    @staticmethod
    def _load_cached_or_download(load, description):
        """Call ``load(True)`` (local files only); on failure call ``load(False)``.

        Args:
            load: Callable taking the ``local_files_only`` flag and returning
                the loaded object(s).
            description: Human-readable model name used in the log message.
        """
        try:
            return load(True)
        except Exception as exc:  # deliberate best-effort: any local failure triggers a download
            logger.warning(
                "%s could not be loaded from the local cache (%s); retrying with download enabled.",
                description, exc,
            )
            return load(False)

    def _build_whisper_pipeline(self, model_id, local_files_only):
        """Load a Whisper checkpoint and wrap it in an ASR pipeline."""
        from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline

        processor = WhisperProcessor.from_pretrained(
            model_id, cache_dir=self.cache_dir, local_files_only=local_files_only
        )
        model = WhisperForConditionalGeneration.from_pretrained(
            model_id, cache_dir=self.cache_dir, local_files_only=local_files_only
        )
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            chunk_length_s=30,
            device=self._pipeline_device(),
        )

    @staticmethod
    def _chunks_to_segments(chunks):
        """Convert pipeline ``chunks`` into ``{"start", "end", "text"}`` dicts.

        A missing or ``None`` timestamp becomes ``(0, 0)``; start/end values
        that are present are passed through unchanged.
        """
        segments = []
        for chunk in chunks:
            start, end = chunk.get("timestamp") or (0, 0)
            segments.append({"start": start, "end": end, "text": chunk.get("text", "")})
        return segments

    @staticmethod
    def _change_speed(samples, speed):
        """Resample a 1-D waveform by linear interpolation to change its speed.

        The output has ``int(len(samples) / speed)`` samples (at least one).
        Because this is plain resampling, playback pitch shifts together with
        the tempo.  ``speed == 1.0`` returns ``samples`` untouched.
        """
        if speed == 1.0 or len(samples) < 2:
            return samples
        target_len = max(1, int(len(samples) / speed))
        positions = np.linspace(0, len(samples) - 1, target_len)
        return np.interp(positions, np.arange(len(samples)), samples).astype(np.float32)

    def _run_nllb(self, texts, src_code, tgt_code):
        """Translate a list of non-empty strings with NLLB-200.

        Args:
            texts: List of strings to translate in one padded batch.
            src_code: NLLB code of the source language.
            tgt_code: NLLB code forced as the first generated token.

        Returns:
            List of decoded translations, one per input, in input order.
        """
        with self.lock:
            model, tokenizer = self.get_nllb()
            tokenizer.src_lang = src_code
            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(self.device)
            forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
            try:
                with torch.no_grad():
                    translated_tokens = model.generate(
                        **inputs,
                        forced_bos_token_id=forced_bos_token_id,
                        max_length=256,
                        num_beams=4,
                        no_repeat_ngram_size=3,
                    )
                return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
            finally:
                # Release memory even when generation fails.
                self._clear_memory()

    # ------------------------------------------------------------------
    # Model accessors
    # ------------------------------------------------------------------
    def get_whisper(self, size="base"):
        """Return the cached ASR pipeline for ``openai/whisper-<size>``, loading it if needed."""
        with self.lock:
            if size not in self.whisper_pipe:
                model_id = f"openai/whisper-{size}"
                logger.info("Loading STT model %s from %s on %s...", model_id, self.cache_dir, self.device)
                self.whisper_pipe[size] = self._load_cached_or_download(
                    lambda local_only: self._build_whisper_pipeline(model_id, local_only),
                    f"Whisper-{size}",
                )
                logger.info("Whisper-%s loaded successfully.", size)
            return self.whisper_pipe[size]

    def get_nllb(self):
        """Return ``(model, tokenizer)`` for NLLB-200 distilled 600M, loading them if needed."""
        with self.lock:
            if self.nllb_model is None:
                from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

                model_id = "facebook/nllb-200-distilled-600M"
                logger.info("Loading NLLB-200 translation model from %s on %s...", self.cache_dir, self.device)

                def load(local_only):
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_id, cache_dir=self.cache_dir, local_files_only=local_only
                    )
                    model = AutoModelForSeq2SeqLM.from_pretrained(
                        model_id, cache_dir=self.cache_dir, local_files_only=local_only
                    ).to(self.device)
                    return model, tokenizer

                # Assign both together so a failed load never leaves half-initialised state.
                self.nllb_model, self.nllb_tokenizer = self._load_cached_or_download(load, "NLLB-200")
                logger.info("NLLB-200 loaded successfully.")
            return self.nllb_model, self.nllb_tokenizer

    def get_tts(self, lang):
        """Return ``(model, tokenizer)`` of the MMS-TTS voice for ``lang``.

        Raises:
            ValueError: If ``lang`` is not "Hindi", "Marathi" or "English".
        """
        with self.lock:
            if lang not in self.tts_models:
                from transformers import AutoTokenizer, VitsModel

                model_id = self._TTS_MODEL_IDS.get(lang)
                if not model_id:
                    raise ValueError(f"Unsupported TTS language: {lang}")
                logger.info("Loading TTS model for %s (%s) on %s...", lang, model_id, self.device)

                def load(local_only):
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_id, cache_dir=self.cache_dir, local_files_only=local_only
                    )
                    model = VitsModel.from_pretrained(
                        model_id, cache_dir=self.cache_dir, local_files_only=local_only
                    ).to(self.device)
                    return model, tokenizer

                model, tokenizer = self._load_cached_or_download(load, f"TTS model for {lang}")
                self.tts_tokenizers[lang] = tokenizer
                self.tts_models[lang] = model
                logger.info("TTS model for %s loaded successfully.", lang)
            return self.tts_models[lang], self.tts_tokenizers[lang]

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def transcribe(self, audio_path, size="base", language="English"):
        """Transcribe an audio file with Whisper and return timestamped segments.

        Args:
            audio_path: Path of the audio file handed to the ASR pipeline.
            size: Whisper model size suffix (``openai/whisper-<size>``).
            language: "Marathi", "Hindi" or "English" to force that language;
                any other value (e.g. "auto") lets Whisper detect it.

        Returns:
            Dict with ``"text"`` (full transcript), ``"segments"`` (list of
            ``{"start", "end", "text"}``) and ``"detected_language"``.
        """
        if self.ci_mode:
            return {
                "text": "This is a mock transcription for CI mode.",
                "segments": [
                    {"start": 0.0, "end": 2.0, "text": "This is a mock"},
                    {"start": 2.0, "end": 4.0, "text": "transcription for CI mode."},
                ],
                # Same key set as a real result so callers need no CI special-casing.
                "detected_language": language,
            }

        with self.lock:
            lang_code = self._WHISPER_LANG_CODES.get(language)  # None -> auto-detect
            pipe = self.get_whisper(size)
            logger.info(
                "Transcribing %s using Whisper-%s (language=%s)...",
                audio_path, size, "auto" if lang_code is None else lang_code,
            )

            generate_kwargs = {"language": lang_code} if lang_code else {}
            try:
                # return_timestamps must be a pipeline argument (not a generate
                # kwarg) for the result to contain the per-segment "chunks".
                result = pipe(
                    audio_path,
                    chunk_length_s=30,
                    stride_length_s=5,
                    return_timestamps=True,
                    generate_kwargs=generate_kwargs,
                )
            finally:
                self._clear_memory()

            segments = self._chunks_to_segments(result.get("chunks", []))

            # Report the detected language when auto-detection was requested.
            # NOTE(review): this relies on the model config exposing ``lang``;
            # when it does not, "auto" is returned unchanged - confirm intent.
            detected_lang = language
            if language == "auto":
                config = getattr(pipe.model, "config", None)
                if config is not None and hasattr(config, "lang"):
                    detected_lang = self._WHISPER_LANG_NAMES.get(config.lang, "English")

            return {
                "text": result.get("text", ""),
                "segments": segments,
                "detected_language": detected_lang,
            }

    def translate(self, text, src_lang, tgt_lang):
        """Translate one string between Marathi, Hindi and English.

        Returns ``""`` for empty/whitespace-only text and ``text`` unchanged
        when source and target languages are equal.

        Raises:
            ValueError: If either language has no NLLB code.
        """
        if self._is_blank(text):
            return ""
        if src_lang == tgt_lang:
            return text

        # Map languages to NLLB-200 codes (shared table with translate_batch).
        src_code = NLLB_LANG_CODES.get(src_lang)
        tgt_code = NLLB_LANG_CODES.get(tgt_lang)
        if not src_code or not tgt_code:
            raise ValueError(f"Unsupported translation languages: {src_lang} -> {tgt_lang}")

        if self.ci_mode:
            return f"[CI MOCK] {tgt_lang}: {text}"

        logger.info("Translating text using NLLB-200 (%s -> %s)...", src_code, tgt_code)
        return self._run_nllb([text], src_code, tgt_code)[0]

    def translate_batch(self, texts, src_lang, tgt_lang):
        """Translate a list of strings in one NLLB-200 batch.

        Blank (empty or whitespace-only) entries are never sent to the model
        and are returned unchanged at their original positions.  The returned
        list is always a new list of the same length as ``texts``.

        Raises:
            ValueError: If either language has no NLLB code (not checked in
                CI mode or when source and target are equal).
        """
        if not texts:
            return []
        if self.ci_mode:
            return [t if self._is_blank(t) else f"[CI MOCK] {tgt_lang}: {t}" for t in texts]
        if src_lang == tgt_lang:
            return list(texts)  # copy so callers cannot alias the input list

        src_code = NLLB_LANG_CODES.get(src_lang)
        tgt_code = NLLB_LANG_CODES.get(tgt_lang)
        if not src_code or not tgt_code:
            raise ValueError(f"Unsupported translation languages: {src_lang} -> {tgt_lang}")

        # Start from the originals so blank entries are preserved as-is.
        results = list(texts)
        non_empty_indices = [idx for idx, text in enumerate(texts) if not self._is_blank(text)]
        if not non_empty_indices:
            return results

        non_empty_texts = [texts[idx] for idx in non_empty_indices]
        logger.info(
            "Batch translating %d items using NLLB-200 (%s -> %s)...",
            len(non_empty_texts), src_code, tgt_code,
        )
        translated_texts = self._run_nllb(non_empty_texts, src_code, tgt_code)

        # Map translations back to their original positions.
        for idx, translated in zip(non_empty_indices, translated_texts):
            results[idx] = translated
        return results

    def text_to_speech(self, text, lang, output_path, speed=1.0):
        """Synthesize ``text`` with MMS-TTS and write the audio to ``output_path``.

        Args:
            text: Non-blank text to speak.
            lang: "Hindi", "Marathi" or "English".
            output_path: Destination file passed to ``soundfile.write``.
            speed: Playback speed factor (> 0); values above 1 shorten the
                audio.  Implemented by linear resampling.

        Returns:
            ``output_path``.

        Raises:
            ValueError: On blank text, unsupported language or ``speed <= 0``.
        """
        if self._is_blank(text):
            raise ValueError("Empty text provided for TTS")
        if lang not in self._TTS_MODEL_IDS:
            raise ValueError(f"Unsupported TTS language: {lang}")
        if speed <= 0:
            raise ValueError(f"TTS speed must be positive, got {speed}")

        if self.ci_mode:
            logger.debug("[CI MOCK] Generating dummy TTS for %s...", lang)
            dummy_data = np.zeros(self._DEFAULT_TTS_SAMPLE_RATE)  # one second of silence
            sf.write(output_path, dummy_data, self._DEFAULT_TTS_SAMPLE_RATE)
            return output_path

        with self.lock:
            model, tokenizer = self.get_tts(lang)
            logger.info("Synthesizing speech for text in %s (speed=%s)...", lang, speed)
            inputs = tokenizer(text, return_tensors="pt").to(self.device)
            try:
                with torch.no_grad():
                    output = model(**inputs).waveform
                # Flatten to a 1-D numpy array (reshape keeps it 1-D even for one sample).
                waveform_numpy = output.cpu().numpy().reshape(-1)
            finally:
                self._clear_memory()

            waveform_numpy = self._change_speed(waveform_numpy, speed)

            # Prefer the rate declared by the model config; fall back to 16 kHz.
            sample_rate = getattr(model.config, "sampling_rate", self._DEFAULT_TTS_SAMPLE_RATE)
            sf.write(output_path, waveform_numpy, samplerate=sample_rate)

        logger.info("TTS audio written to: %s", output_path)
        return output_path