"""
Model loader for Wav2Vec2 voice classification model.

Handles loading, caching, and management of the ML model.
"""

import asyncio
import gc
import threading
from typing import TYPE_CHECKING

import torch
from transformers import Wav2Vec2ForSequenceClassification
from transformers import Wav2Vec2Processor

from app.config import get_settings
from app.utils.constants import ID_TO_LABEL
from app.utils.constants import LABEL_TO_ID
from app.utils.exceptions import ModelNotLoadedError
from app.utils.logger import get_logger

if TYPE_CHECKING:
    from transformers import PreTrainedModel

logger = get_logger(__name__)


class ModelLoader:
    """
    Singleton loader for the Wav2Vec2 sequence-classification model.

    The model and its processor are loaded on demand by ``load_model()`` and
    kept on the instance until ``unload_model()`` is called.

    Locking:
        * ``_lock`` (class level) guards creation and one-time initialisation
          of the singleton instance.
        * ``_model_lock`` (instance level) serialises ``load_model()`` and
          ``unload_model()``.
        * ``get_model()`` takes no lock; it reads a snapshot of the published
          model/processor pair instead.
    """

    _instance: "ModelLoader | None" = None
    _lock: threading.Lock = threading.Lock()

    # Processor used when the configured checkpoint ships without one.
    _FALLBACK_PROCESSOR_ID = "facebook/wav2vec2-base"

    # Length and sample rate of the synthetic clip used by warmup().
    _WARMUP_SAMPLE_RATE = 16000

    def __new__(cls) -> "ModelLoader":
        """Return the single shared instance, creating it on first use."""
        # Double-checked locking: the unlocked test keeps the common path
        # cheap, the locked re-test prevents two threads from both creating.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        """
        Initialise the loader state exactly once.

        ``__init__`` runs on every ``ModelLoader()`` call, so repeat calls
        must be no-ops; otherwise they would discard a loaded model and
        replace ``_model_lock`` while another thread may be holding it.
        """
        if getattr(self, "_initialized", False):
            return

        # Resolved outside the lock so the critical section stays short.
        settings = get_settings()

        with self._lock:
            # Re-check under the lock: another thread may have finished
            # initialising between the test above and acquiring the lock.
            if getattr(self, "_initialized", False):
                return

            self.settings = settings
            self.model: Wav2Vec2ForSequenceClassification | None = None
            self.processor: Wav2Vec2Processor | None = None
            self.device: str = settings.torch_device
            self._model_lock = threading.Lock()
            self._initialized = True

        logger.info(
            "ModelLoader initialized",
            device=self.device,
            model_identifier=self.settings.model_identifier,
        )

    @property
    def is_loaded(self) -> bool:
        """Return True when both the model and the processor are set."""
        return self.model is not None and self.processor is not None

    def _gpu_memory_gb(self, reserved: bool = False) -> float:
        """
        Return CUDA memory usage of the configured device in GiB (2 decimals).

        Args:
            reserved: When True report ``torch.cuda.memory_reserved``,
                otherwise ``torch.cuda.memory_allocated``.
        """
        # Pass the configured device explicitly so that e.g. "cuda:1" is not
        # reported using the statistics of the current default CUDA device.
        device = torch.device(self.device)
        if reserved:
            used_bytes = torch.cuda.memory_reserved(device)
        else:
            used_bytes = torch.cuda.memory_allocated(device)
        return round(used_bytes / (1024**3), 2)

    def _load_processor(self, model_identifier: str) -> Wav2Vec2Processor:
        """
        Load the processor for ``model_identifier``.

        Falls back to the base wav2vec2 processor when the checkpoint's own
        processor cannot be loaded. An exception raised by the fallback load
        itself propagates to the caller.
        """
        try:
            return Wav2Vec2Processor.from_pretrained(
                model_identifier,
                trust_remote_code=False,
            )
        except Exception as exc:
            # Deliberate best-effort: fine-tuned checkpoints often do not
            # include a processor. The reason is logged so that unrelated
            # failures (e.g. network errors) are not silently hidden.
            logger.info("Using base wav2vec2 processor", reason=str(exc))
            return Wav2Vec2Processor.from_pretrained(
                self._FALLBACK_PROCESSOR_ID,
                trust_remote_code=False,
            )

    def load_model(self) -> None:
        """
        Load the Wav2Vec2 model and processor.

        Calls are serialised by ``_model_lock``; if the model is already
        loaded the call returns immediately.

        Raises:
            Exception: Whatever the underlying load raised. ``model`` and
                ``processor`` are reset to ``None`` before re-raising.
        """
        with self._model_lock:
            if self.is_loaded:
                logger.debug("Model already loaded, skipping")
                return

            model_identifier = self.settings.model_identifier
            logger.info("Loading Wav2Vec2 model", model=model_identifier, device=self.device)

            try:
                processor = self._load_processor(model_identifier)

                # NOTE(review): per the transformers documentation, weights
                # whose shapes do not match are newly initialised instead of
                # raising - confirm this is intended for the classifier head.
                model = Wav2Vec2ForSequenceClassification.from_pretrained(
                    model_identifier,
                    trust_remote_code=False,
                    ignore_mismatched_sizes=True,  # Allow different classifier sizes
                )
                model = model.to(self.device)
                model.eval()

                # Publish only once the model is on its device and in eval
                # mode, so is_loaded / get_model() never expose a model that
                # is still being prepared.
                self.processor = processor
                self.model = model

                if self.device.startswith("cuda"):
                    logger.info(
                        "Model loaded successfully",
                        device=self.device,
                        gpu_memory_gb=self._gpu_memory_gb(),
                    )
                else:
                    logger.info("Model loaded successfully", device=self.device)

            except Exception as e:
                self.model = None
                self.processor = None
                logger.error("Failed to load model", error=str(e))
                raise

    async def load_model_async(self) -> None:
        """
        Run ``load_model()`` in the event loop's default executor.

        Keeps the event loop responsive while the model loads, e.g. when
        called from a FastAPI lifespan context.
        """
        # Inside a coroutine a loop is always running; get_running_loop() is
        # the documented way to obtain it.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.load_model)

    def get_model(self) -> tuple[Wav2Vec2ForSequenceClassification, Wav2Vec2Processor]:
        """
        Get the loaded model and processor.

        Returns:
            Tuple of (model, processor)

        Raises:
            ModelNotLoadedError: If model is not loaded
        """
        # Snapshot both attributes once and check the snapshot, so a
        # concurrent unload_model() cannot make this return None after the
        # loaded-check has passed.
        model = self.model
        processor = self.processor
        if model is None or processor is None:
            raise ModelNotLoadedError(
                "Model not loaded. Call load_model() first.",
                details={"model_identifier": self.settings.model_identifier},
            )
        return model, processor

    def unload_model(self) -> None:
        """
        Drop the model and processor references and release cached memory.

        Runs a garbage-collection pass and, when CUDA is available, empties
        the CUDA caching allocator.
        """
        with self._model_lock:
            self.model = None
            self.processor = None

            # Force garbage collection
            gc.collect()

            # Clear CUDA cache if using GPU
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Model unloaded, memory freed")

    def warmup(self) -> None:
        """
        Run one inference on synthetic audio.

        Intended to absorb first-call initialisation cost before real
        requests arrive. Failures are logged as warnings and never raised.
        """
        if not self.is_loaded:
            logger.warning("Cannot warmup - model not loaded")
            return

        logger.info("Running model warmup...")

        try:
            sample_rate = self._WARMUP_SAMPLE_RATE
            # One second of random samples at the warmup sample rate.
            dummy_audio = torch.randn(sample_rate).numpy()

            model, processor = self.get_model()

            inputs = processor(
                dummy_audio,
                sampling_rate=sample_rate,
                return_tensors="pt",
                padding=True,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                model(**inputs)

            logger.info("Model warmup complete")

        except Exception as e:
            logger.warning("Warmup failed (non-critical)", error=str(e))

    def health_check(self) -> dict:
        """
        Get model health status.

        Returns:
            Dictionary with ``model_loaded``, ``device`` and
            ``model_identifier``; when the configured device is a CUDA
            device and CUDA is available it also contains
            ``gpu_memory_allocated_gb`` and ``gpu_memory_reserved_gb``.
        """
        status: dict[str, object] = {
            "model_loaded": self.is_loaded,
            "device": self.device,
            "model_identifier": self.settings.model_identifier,
        }

        if self.device.startswith("cuda") and torch.cuda.is_available():
            status["gpu_memory_allocated_gb"] = self._gpu_memory_gb()
            status["gpu_memory_reserved_gb"] = self._gpu_memory_gb(reserved=True)

        return status