| """
|
| Model loader for Wav2Vec2 voice classification model.
|
|
|
| Handles loading, caching, and management of the ML model.
|
| """
|
|
|
| import gc
|
| import threading
|
| from typing import TYPE_CHECKING
|
|
|
| import torch
|
| from transformers import Wav2Vec2ForSequenceClassification
|
| from transformers import Wav2Vec2Processor
|
|
|
| from app.config import get_settings
|
| from app.utils.constants import ID_TO_LABEL
|
| from app.utils.constants import LABEL_TO_ID
|
| from app.utils.exceptions import ModelNotLoadedError
|
| from app.utils.logger import get_logger
|
|
|
| if TYPE_CHECKING:
|
| from transformers import PreTrainedModel
|
|
|
| logger = get_logger(__name__)
|
|
|
|
|
class ModelLoader:
    """
    Singleton model loader for Wav2Vec2 classification model.

    Handles lazy loading, caching, and memory management of the ML model.
    Thread-safe for production use.
    """

    _instance: "ModelLoader | None" = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls) -> "ModelLoader":
        """Ensure only one instance exists (double-checked locking Singleton)."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created
                # the instance between the first check and lock acquisition.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        """Initialize ModelLoader if not already initialized.

        ``__init__`` runs on every ``ModelLoader()`` call because of the
        Singleton ``__new__``; the ``_initialized`` flag makes it idempotent.
        """
        if getattr(self, "_initialized", False):
            return

        self.settings = get_settings()
        # Model and processor are lazily loaded by load_model().
        self.model: Wav2Vec2ForSequenceClassification | None = None
        self.processor: Wav2Vec2Processor | None = None
        self.device: str = self.settings.torch_device
        # Serializes load_model()/unload_model() against each other.
        self._model_lock = threading.Lock()
        self._initialized = True

        logger.info(
            "ModelLoader initialized",
            device=self.device,
            model_identifier=self.settings.model_identifier,
        )

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded and ready for inference."""
        return self.model is not None and self.processor is not None

    def load_model(self) -> None:
        """
        Load the Wav2Vec2 model and processor.

        Thread-safe loading with proper error handling. A second concurrent
        call blocks on the lock and then returns early once the first call
        has finished loading.

        Raises:
            Exception: If model loading fails (model/processor are reset to
                None before re-raising, so is_loaded stays consistent).
        """
        with self._model_lock:
            if self.is_loaded:
                logger.debug("Model already loaded, skipping")
                return

            model_identifier = self.settings.model_identifier

            logger.info("Loading Wav2Vec2 model", model=model_identifier, device=self.device)

            try:
                # Prefer the processor bundled with the fine-tuned checkpoint;
                # fall back to the base wav2vec2 processor if the checkpoint
                # does not ship one (best-effort, hence the broad except).
                try:
                    self.processor = Wav2Vec2Processor.from_pretrained(
                        model_identifier,
                        trust_remote_code=False,
                    )
                except Exception:
                    logger.info("Using base wav2vec2 processor")
                    self.processor = Wav2Vec2Processor.from_pretrained(
                        "facebook/wav2vec2-base",
                        trust_remote_code=False,
                    )

                # ignore_mismatched_sizes allows loading a checkpoint whose
                # classification head differs from the config's default size.
                self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
                    model_identifier,
                    trust_remote_code=False,
                    ignore_mismatched_sizes=True,
                )

                self.model = self.model.to(self.device)

                # Inference-only: disables dropout/batch-norm training behavior.
                self.model.eval()

                if self.device.startswith("cuda"):
                    memory_allocated = torch.cuda.memory_allocated() / (1024**3)
                    logger.info(
                        "Model loaded successfully",
                        device=self.device,
                        gpu_memory_gb=round(memory_allocated, 2),
                    )
                else:
                    logger.info("Model loaded successfully", device=self.device)

            except Exception as e:
                # Reset to a clean "not loaded" state so a later retry works.
                self.model = None
                self.processor = None
                logger.error("Failed to load model", error=str(e))
                raise

    async def load_model_async(self) -> None:
        """
        Async wrapper for model loading.

        Useful for FastAPI lifespan context: runs the blocking load_model()
        in a worker thread so the event loop is not blocked.
        """
        import asyncio

        # asyncio.get_event_loop() is deprecated inside coroutines (3.10+);
        # to_thread() submits to the running loop's default executor directly.
        await asyncio.to_thread(self.load_model)

    def get_model(self) -> tuple[Wav2Vec2ForSequenceClassification, Wav2Vec2Processor]:
        """
        Get the loaded model and processor.

        Returns:
            Tuple of (model, processor)

        Raises:
            ModelNotLoadedError: If model is not loaded
        """
        if not self.is_loaded:
            raise ModelNotLoadedError(
                "Model not loaded. Call load_model() first.",
                details={"model_identifier": self.settings.model_identifier},
            )

        return self.model, self.processor

    def unload_model(self) -> None:
        """
        Unload model and free memory.

        Useful for memory management in constrained environments. Safe to
        call when nothing is loaded (it is a no-op apart from GC/cache work).
        """
        with self._model_lock:
            if self.model is not None:
                del self.model
                self.model = None

            if self.processor is not None:
                del self.processor
                self.processor = None

            # Drop Python-side references first, then release cached CUDA
            # blocks back to the driver.
            gc.collect()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Model unloaded, memory freed")

    def warmup(self) -> None:
        """
        Run a warmup inference to initialize CUDA kernels.

        This reduces latency on the first real inference. Failures are
        logged but deliberately not raised — warmup is non-critical.
        """
        if not self.is_loaded:
            logger.warning("Cannot warmup - model not loaded")
            return

        logger.info("Running model warmup...")

        try:
            # One second of random audio at the 16 kHz rate wav2vec2 expects.
            dummy_audio = torch.randn(1, 16000)

            model, processor = self.get_model()

            inputs = processor(
                dummy_audio.squeeze().numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )

            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                _ = model(**inputs)

            logger.info("Model warmup complete")

        except Exception as e:
            logger.warning("Warmup failed (non-critical)", error=str(e))

    def health_check(self) -> dict:
        """
        Get model health status.

        Returns:
            Dictionary with health information: load state, device, model
            identifier, and (on CUDA) allocated/reserved GPU memory in GiB.
        """
        status = {
            "model_loaded": self.is_loaded,
            "device": self.device,
            "model_identifier": self.settings.model_identifier,
        }

        if self.device.startswith("cuda") and torch.cuda.is_available():
            status["gpu_memory_allocated_gb"] = round(
                torch.cuda.memory_allocated() / (1024**3), 2
            )
            status["gpu_memory_reserved_gb"] = round(
                torch.cuda.memory_reserved() / (1024**3), 2
            )

        return status
|
|
|