# BIAF-offASR / backend/models.py
# Web-page header residue kept for provenance: "froster02's picture",
# "refactor: code cleanup and App.jsx component split", 3496381
import os
import logging
import torch
import numpy as np
import soundfile as sf
import threading
import gc
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Human-readable language name -> NLLB-200 language code.
NLLB_LANG_CODES = dict(
    Marathi="mar_Deva",
    Hindi="hin_Deva",
    English="eng_Latn",
)

# Optimize Torch for CPU-only environments like HF Spaces: without CUDA,
# let Torch use every CPU the OS reports (falling back to a single thread
# when the count is unknown).
if not torch.cuda.is_available():
    torch.set_num_threads(os.cpu_count() or 1)
class ModelManager:
    """Lazily load and serialise access to the STT, translation and TTS models.

    Models are loaded on first use, preferring files already present in
    ``cache_dir`` and only falling back to a download-enabled load when the
    cache-only load fails. A single re-entrant lock guards every load and
    every inference call. When the ``CI_MODE`` environment variable is
    ``"true"`` the public methods return canned results and never load a model.
    """

    # Display language name -> Whisper language code.
    _WHISPER_LANG_CODES = {"Marathi": "mr", "Hindi": "hi", "English": "en"}
    # Whisper language code -> display language name.
    _WHISPER_LANG_NAMES = {"mr": "Marathi", "hi": "Hindi", "en": "English"}
    # Display language name -> MMS-TTS checkpoint id.
    _TTS_MODEL_IDS = {
        "Hindi": "facebook/mms-tts-hin",
        "Marathi": "facebook/mms-tts-mar",
        "English": "facebook/mms-tts-eng",
    }

    def __init__(self, cache_dir="./models"):
        """Pick the compute device and set up the empty model containers.

        Args:
            cache_dir: Directory holding the model files; stored as an
                absolute path. ``HF_HOME`` is pointed at ``<cache_dir>/hf_cache``.
        """
        self.cache_dir = os.path.abspath(cache_dir)
        os.environ["HF_HOME"] = os.path.join(self.cache_dir, "hf_cache")
        # Re-entrant: transcribe()/translate()/... hold the lock and then call
        # the get_*() loaders, which acquire it again.
        self.lock = threading.RLock()
        self.ci_mode = os.environ.get("CI_MODE", "false").lower() == "true"

        # Select device automatically. torch.backends.mps is looked up with
        # getattr so Torch builds without that backend do not crash here.
        mps_backend = getattr(torch.backends, "mps", None)
        if self.ci_mode:
            self.device = "cpu"
        elif torch.cuda.is_available():
            self.device = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"

        logger.info("ModelManager initialized using device: %s (CI_MODE=%s)", self.device, self.ci_mode)
        logger.info("Cache directory: %s", self.cache_dir)

        # Lazy load containers
        self.whisper_pipe = {}    # Whisper size -> ASR pipeline
        self.nllb_model = None
        self.nllb_tokenizer = None
        self.tts_models = {}      # language name -> TTS model
        self.tts_tokenizers = {}  # language name -> TTS tokenizer

    def _clear_memory(self):
        """Force garbage collection and clear the torch cache of the active GPU."""
        gc.collect()
        # Keyed on the device actually in use, so a CPU-forced run (CI mode)
        # never touches an accelerator backend.
        if self.device == "cuda":
            torch.cuda.empty_cache()
        elif self.device == "mps":
            torch.mps.empty_cache()

    def _from_pretrained(self, loader, model_id, local_only):
        """Call ``loader.from_pretrained`` against this manager's cache dir.

        Args:
            loader: Class exposing ``from_pretrained``.
            model_id: Hub identifier of the checkpoint.
            local_only: When true, pass ``local_files_only=True``.
        """
        kwargs = {"cache_dir": self.cache_dir}
        if local_only:
            kwargs["local_files_only"] = True
        return loader.from_pretrained(model_id, **kwargs)

    def _pipeline_device(self):
        """Translate ``self.device`` into the ``device`` argument given to ``pipeline``."""
        if self.device == "cuda":
            return 0
        if self.device == "cpu":
            return -1
        return "mps"

    def get_whisper(self, size="base"):
        """Return the ASR pipeline for ``openai/whisper-<size>``, loading it once.

        Args:
            size: Whisper size suffix used to build the model id.

        Returns:
            The cached speech-recognition pipeline for that size.
        """
        with self.lock:
            if size not in self.whisper_pipe:
                from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline

                model_id = f"openai/whisper-{size}"
                logger.info("Loading STT model %s from %s on %s...", model_id, self.cache_dir, self.device)

                def build(local_only):
                    processor = self._from_pretrained(WhisperProcessor, model_id, local_only)
                    model = self._from_pretrained(WhisperForConditionalGeneration, model_id, local_only)
                    return pipeline(
                        "automatic-speech-recognition",
                        model=model,
                        tokenizer=processor.tokenizer,
                        feature_extractor=processor.feature_extractor,
                        chunk_length_s=30,
                        device=self._pipeline_device(),
                    )

                try:
                    self.whisper_pipe[size] = build(local_only=True)
                    logger.info("Whisper-%s loaded successfully.", size)
                except Exception as e:
                    logger.error("Error loading Whisper-%s: %s", size, e)
                    # Retry with downloads allowed, building the pipeline the
                    # same way. `cache_dir` is deliberately not handed to
                    # pipeline() itself: it is not one of its documented
                    # arguments.
                    self.whisper_pipe[size] = build(local_only=False)
            return self.whisper_pipe[size]

    def get_nllb(self):
        """Return ``(model, tokenizer)`` for NLLB-200, loading them once."""
        with self.lock:
            if self.nllb_model is None:
                from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

                model_id = "facebook/nllb-200-distilled-600M"
                logger.info("Loading NLLB-200 translation model from %s on %s...", self.cache_dir, self.device)

                def load(local_only):
                    tokenizer = self._from_pretrained(AutoTokenizer, model_id, local_only)
                    model = self._from_pretrained(AutoModelForSeq2SeqLM, model_id, local_only)
                    return tokenizer, model.to(self.device)

                try:
                    tokenizer, model = load(local_only=True)
                    logger.info("NLLB-200 loaded successfully.")
                except Exception as e:
                    logger.error("Error loading NLLB-200: %s", e)
                    tokenizer, model = load(local_only=False)
                # Publish both together so a failed load never leaves a
                # tokenizer without its model.
                self.nllb_tokenizer = tokenizer
                self.nllb_model = model
            return self.nllb_model, self.nllb_tokenizer

    def get_tts(self, lang):
        """Return ``(model, tokenizer)`` for the MMS-TTS voice of ``lang``.

        Args:
            lang: ``"Hindi"``, ``"Marathi"`` or ``"English"``.

        Raises:
            ValueError: If ``lang`` has no TTS checkpoint mapped to it.
        """
        with self.lock:
            if lang not in self.tts_models:
                from transformers import AutoTokenizer, VitsModel

                model_id = self._TTS_MODEL_IDS.get(lang)
                if not model_id:
                    raise ValueError(f"Unsupported TTS language: {lang}")
                logger.info("Loading TTS model for %s (%s) on %s...", lang, model_id, self.device)

                def load(local_only):
                    tokenizer = self._from_pretrained(AutoTokenizer, model_id, local_only)
                    model = self._from_pretrained(VitsModel, model_id, local_only)
                    return tokenizer, model.to(self.device)

                try:
                    tokenizer, model = load(local_only=True)
                    logger.info("TTS model for %s loaded successfully.", lang)
                except Exception as e:
                    logger.error("Error loading TTS for %s: %s", lang, e)
                    tokenizer, model = load(local_only=False)
                self.tts_tokenizers[lang] = tokenizer
                self.tts_models[lang] = model
            return self.tts_models[lang], self.tts_tokenizers[lang]

    def transcribe(self, audio_path, size="base", language="English"):
        """Transcribe an audio file with Whisper.

        Args:
            audio_path: Path of the audio file handed to the ASR pipeline.
            size: Whisper size suffix (see ``get_whisper``).
            language: ``"Marathi"``, ``"Hindi"`` or ``"English"``. Any other
                value (e.g. ``"auto"``) sends no language hint to Whisper.

        Returns:
            A dict with ``"text"``, ``"segments"`` (a list of
            ``{"start", "end", "text"}`` dicts) and ``"detected_language"``.
            In CI mode a fixed mock result is returned.
        """
        if self.ci_mode:
            return {
                "text": "This is a mock transcription for CI mode.",
                "segments": [
                    {"start": 0.0, "end": 2.0, "text": "This is a mock"},
                    {"start": 2.0, "end": 4.0, "text": "transcription for CI mode."},
                ],
                # Same key set as the real result below.
                "detected_language": language,
            }

        with self.lock:
            # Map human language name to Whisper code; unknown names (such as
            # "auto") map to None and Whisper receives no language hint.
            lang_code = self._WHISPER_LANG_CODES.get(language)
            pipe = self.get_whisper(size)
            lang_label = "auto" if lang_code is None else lang_code
            logger.info("Transcribing %s using Whisper-%s (language=%s)...", audio_path, size, lang_label)

            gen_kwargs = {}
            if lang_code:
                gen_kwargs["language"] = lang_code

            try:
                # return_timestamps is a pipeline argument: the pipeline docs
                # tie the "chunks" output read below to it, so it must not be
                # tucked away inside generate_kwargs.
                result = pipe(
                    audio_path,
                    chunk_length_s=30,
                    stride_length_s=5,
                    return_timestamps=True,
                    generate_kwargs=gen_kwargs,
                )
            finally:
                self._clear_memory()

            # Extract segments from chunks
            segments = []
            for chunk in result.get("chunks", []):
                start, end = chunk.get("timestamp", (0, 0))
                segments.append({
                    "start": start,
                    "end": end,
                    "text": chunk.get("text", ""),
                })

            # Extract detected language if auto.
            # NOTE(review): this relies on a `lang` attribute on the model
            # config; if that attribute is never set, "auto" is returned
            # unchanged. Confirm whether `return_language=True` on the
            # pipeline should be used instead.
            detected_lang = language
            if language == "auto":
                config = getattr(pipe.model, "config", None)
                if hasattr(config, "lang"):
                    detected_lang = self._WHISPER_LANG_NAMES.get(config.lang, "English")

            return {
                "text": result.get("text", ""),
                "segments": segments,
                "detected_language": detected_lang,
            }

    def _nllb_generate(self, model, tokenizer, texts, src_code, tgt_code):
        """Translate a list of non-empty strings and return the decoded strings.

        Callers are expected to hold ``self.lock``. The torch cache is cleared
        afterwards even when generation raises.
        """
        tokenizer.src_lang = src_code
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(self.device)
        # Force generation to start in the target language.
        forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_code)
        try:
            with torch.no_grad():
                translated_tokens = model.generate(
                    **inputs,
                    forced_bos_token_id=forced_bos_token_id,
                    max_length=256,
                    num_beams=4,
                    no_repeat_ngram_size=3,
                )
            return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
        finally:
            self._clear_memory()

    def translate(self, text, src_lang, tgt_lang):
        """Translate one string between two supported languages.

        Args:
            text: Text to translate.
            src_lang: Source language name (a key of ``NLLB_LANG_CODES``).
            tgt_lang: Target language name (a key of ``NLLB_LANG_CODES``).

        Returns:
            ``""`` for blank input, ``text`` unchanged when both languages are
            equal, a ``[CI MOCK]`` string in CI mode, otherwise the translation.

        Raises:
            ValueError: If either language has no NLLB code.
        """
        if not text.strip():
            return ""
        if src_lang == tgt_lang:
            return text

        # Map languages to NLLB-200 code (same table translate_batch uses).
        src_code = NLLB_LANG_CODES.get(src_lang)
        tgt_code = NLLB_LANG_CODES.get(tgt_lang)
        if not src_code or not tgt_code:
            raise ValueError(f"Unsupported translation languages: {src_lang} -> {tgt_lang}")

        if self.ci_mode:
            return f"[CI MOCK] {tgt_lang}: {text}"

        with self.lock:
            model, tokenizer = self.get_nllb()
            logger.info("Translating text using NLLB-200 (%s -> %s)...", src_code, tgt_code)
            return self._nllb_generate(model, tokenizer, [text], src_code, tgt_code)[0]

    def translate_batch(self, texts, src_lang, tgt_lang):
        """Translate a list of strings in a single NLLB batch.

        Args:
            texts: Strings to translate.
            src_lang: Source language name (a key of ``NLLB_LANG_CODES``).
            tgt_lang: Target language name (a key of ``NLLB_LANG_CODES``).

        Returns:
            A list aligned with ``texts``. Blank items are not sent to the
            model: they are returned unchanged in CI mode and when every item
            is blank, and as ``""`` when mixed with translatable items.
            ``texts`` itself is returned when both languages are equal.

        Raises:
            ValueError: If either language has no NLLB code.
        """
        if not texts:
            return []
        if self.ci_mode:
            return [f"[CI MOCK] {tgt_lang}: {t}" if t.strip() else t for t in texts]
        if src_lang == tgt_lang:
            return texts

        # Map languages to NLLB-200 code
        src_code = NLLB_LANG_CODES.get(src_lang)
        tgt_code = NLLB_LANG_CODES.get(tgt_lang)
        if not src_code or not tgt_code:
            raise ValueError(f"Unsupported translation languages: {src_lang} -> {tgt_lang}")

        # Only non-blank items go to the model; remember where they came from.
        pending = [(idx, text) for idx, text in enumerate(texts) if text.strip()]
        if not pending:
            # All texts were blank: preserve original whitespace if any.
            return list(texts)

        with self.lock:
            model, tokenizer = self.get_nllb()
            logger.info("Batch translating %d items using NLLB-200 (%s -> %s)...", len(pending), src_code, tgt_code)
            translated_texts = self._nllb_generate(
                model, tokenizer, [text for _, text in pending], src_code, tgt_code
            )

        # Blank slots stay "" here.
        results = [""] * len(texts)
        for (idx, _), translated in zip(pending, translated_texts):
            results[idx] = translated
        return results

    def text_to_speech(self, text, lang, output_path, speed=1.0):
        """Synthesize ``text`` to an audio file and return its path.

        Args:
            text: Text to speak; must contain non-whitespace characters.
            lang: ``"Hindi"``, ``"Marathi"`` or ``"English"``.
            output_path: Destination file written via ``soundfile``.
            speed: Playback speed factor, must be > 0. Values other than 1.0
                resample the waveform by linear interpolation.

        Returns:
            ``output_path``.

        Raises:
            ValueError: For blank text, an unsupported language, or a
                non-positive ``speed``.
        """
        if not text.strip():
            raise ValueError("Empty text provided for TTS")
        if lang not in self._TTS_MODEL_IDS:
            raise ValueError(f"Unsupported TTS language: {lang}")
        # Reject 0, negatives and NaN before any work is done; they would
        # otherwise surface as a division error or a nonsensical length.
        if not speed > 0:
            raise ValueError(f"TTS speed must be a positive number, got {speed!r}")

        if self.ci_mode:
            logger.debug("[CI MOCK] Generating dummy TTS for %s...", lang)
            dummy_data = np.zeros(16000)
            sf.write(output_path, dummy_data, 16000)
            return output_path

        with self.lock:
            model, tokenizer = self.get_tts(lang)
            logger.info("Synthesizing speech for text in %s (speed=%s)...", lang, speed)
            inputs = tokenizer(text, return_tensors="pt").to(self.device)
            try:
                with torch.no_grad():
                    output = model(**inputs).waveform

                # Convert PyTorch tensor to numpy array (1D)
                waveform_numpy = output.cpu().numpy().squeeze()

                # Apply speed adjustment if not 1.0, using simple linear
                # resampling onto a shorter/longer sample grid.
                if speed != 1.0:
                    sample_count = len(waveform_numpy)
                    new_x = np.linspace(0, sample_count - 1, int(sample_count / speed))
                    waveform_numpy = np.interp(
                        new_x, np.arange(sample_count), waveform_numpy
                    ).astype(np.float32)

                # Use the rate the checkpoint declares; 16000 Hz is the
                # fallback this code has always written.
                sample_rate = getattr(model.config, "sampling_rate", 16000)
                sf.write(output_path, waveform_numpy, samplerate=sample_rate)
            finally:
                self._clear_memory()

        logger.info("TTS audio written to: %s", output_path)
        return output_path