"""
Model Manager - Handles loading and caching of AI models.

Manages Whisper and NLLB-200 models with GPU optimization.
"""

import gc
import logging
from typing import Optional

import torch
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

logger = logging.getLogger(__name__)


class ModelManager:
    """Singleton class to manage model instances and caching.

    Every ``ModelManager()`` call returns the same object. Models are
    loaded on first request and kept on the instance until
    ``unload_all()`` is called.

    Attributes:
        device: ``"cuda"`` when ``torch.cuda.is_available()`` was true at
            first construction, otherwise ``"cpu"``.
        whisper_model: Cached Whisper model, or ``None`` if not loaded.
        nllb_model: Cached NLLB-200 model, or ``None`` if not loaded.
        nllb_tokenizer: Tokenizer loaded with ``nllb_model``, or ``None``.
    """

    _instance = None
    # NOTE(review): not referenced anywhere in this class; kept in case
    # code elsewhere reads it.
    _models = {}

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ runs on every ModelManager() call; this flag lets it
            # skip re-initialisation so cached models are not dropped.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Select the compute device and set up empty model slots (once)."""
        if self._initialized:
            return

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", self.device)

        self.whisper_model: Optional[whisper.Whisper] = None
        self.nllb_tokenizer = None
        self.nllb_model = None

        # Identifiers of what is currently cached, so a later request for a
        # different model can be reported instead of silently ignored.
        self._whisper_model_size: Optional[str] = None
        self._nllb_model_name: Optional[str] = None

        self._initialized = True

    def get_whisper_model(self, model_size: str = "large") -> whisper.Whisper:
        """Load (on first call) and return the Whisper transcription model.

        Args:
            model_size: Model name passed to ``whisper.load_model``. It is
                only honoured when no Whisper model is cached yet.

        Returns:
            The cached Whisper model. If a model of a different size is
            already cached, that cached model is returned and a warning is
            logged; call ``unload_all()`` first to switch sizes.
        """
        if self.whisper_model is None:
            logger.info("Loading Whisper %s model...", model_size)
            self.whisper_model = whisper.load_model(
                model_size, device=self.device
            )
            self._whisper_model_size = model_size
        elif self._whisper_model_size not in (None, model_size):
            # The tracked size is None when whisper_model was assigned
            # directly from outside; nothing to compare against then.
            logger.warning(
                "Whisper model '%s' is already loaded; ignoring request for "
                "'%s'. Call unload_all() first to switch models.",
                self._whisper_model_size,
                model_size,
            )
        return self.whisper_model

    def get_nllb_model(
        self, model_name: str = "facebook/nllb-200-distilled-600M"
    ) -> tuple:
        """Load (on first call) and return the NLLB-200 translation model.

        Args:
            model_name: Hugging Face model identifier. It is only honoured
                when no NLLB model is cached yet.

        Returns:
            A ``(model, tokenizer)`` tuple. If a different model is already
            cached, the cached pair is returned and a warning is logged;
            call ``unload_all()`` first to switch models.
        """
        if self.nllb_model is None:
            logger.info("Loading NLLB-200 model: %s", model_name)
            use_gpu = self.device != "cpu"

            # No `use_auth_token=True`: it is deprecated and makes loading
            # fail when no Hugging Face token is stored locally, even for
            # public checkpoints such as the default one.
            # No `trust_remote_code=True`: it would allow Python code from
            # the model repository to run, which the built-in NLLB
            # tokenizer does not need.
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_gpu else torch.float32,
                device_map="auto" if use_gpu else None,
            )
            if not use_gpu:
                model = model.to(self.device)

            # Publish both only after both loads succeeded, so a failed
            # model load cannot leave a tokenizer cached without its model.
            self.nllb_tokenizer = tokenizer
            self.nllb_model = model
            self._nllb_model_name = model_name
        elif self._nllb_model_name not in (None, model_name):
            logger.warning(
                "NLLB model '%s' is already loaded; ignoring request for "
                "'%s'. Call unload_all() first to switch models.",
                self._nllb_model_name,
                model_name,
            )
        return self.nllb_model, self.nllb_tokenizer

    def get_device(self) -> str:
        """Get current device (cuda or cpu)."""
        return self.device

    def unload_all(self):
        """Unload all models to free memory.

        Drops this manager's references to the cached models and tokenizer.
        Models still referenced elsewhere stay alive until those references
        are released too.
        """
        logger.info("Unloading all models...")
        self.whisper_model = None
        self.nllb_model = None
        self.nllb_tokenizer = None
        self._whisper_model_size = None
        self._nllb_model_name = None

        # Collect reference cycles first so their tensors are actually
        # freed before the CUDA caching allocator is asked to release
        # unused blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()