# GUVI-AI-Synapses-VoiceAuthAPI / app/ml/model_loader.py
"""
Model loader for Wav2Vec2 voice classification model.
Handles loading, caching, and management of the ML model.
"""
import gc
import threading
from typing import TYPE_CHECKING
import torch
from transformers import Wav2Vec2ForSequenceClassification
from transformers import Wav2Vec2Processor
from app.config import get_settings
from app.utils.constants import ID_TO_LABEL
from app.utils.constants import LABEL_TO_ID
from app.utils.exceptions import ModelNotLoadedError
from app.utils.logger import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedModel
# Module-level structured logger (supports keyword-style event fields below).
logger = get_logger(__name__)
class ModelLoader:
    """
    Singleton model loader for the Wav2Vec2 voice-classification model.

    Handles lazy loading, caching, and memory management of the ML model.
    Thread-safe for production use: instance creation uses double-checked
    locking, and load/unload operations are serialized by a per-instance lock.
    """

    # Shared singleton instance and the lock guarding its creation.
    _instance: "ModelLoader | None" = None
    _lock: threading.Lock = threading.Lock()

    def __new__(cls) -> "ModelLoader":
        """Ensure only one instance exists (Singleton pattern)."""
        # Double-checked locking: the unlocked fast path avoids lock
        # contention once the singleton has been created.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    # Flag lets __init__ run its body exactly once.
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        """Initialize ModelLoader state once; repeat calls are no-ops."""
        if getattr(self, "_initialized", False):
            return
        self.settings = get_settings()
        # Model/processor stay None until load_model() is called (lazy load).
        self.model: Wav2Vec2ForSequenceClassification | None = None
        self.processor: Wav2Vec2Processor | None = None
        self.device: str = self.settings.torch_device
        # Serializes load_model()/unload_model() across threads.
        self._model_lock = threading.Lock()
        self._initialized = True
        logger.info(
            "ModelLoader initialized",
            device=self.device,
            model_identifier=self.settings.model_identifier,
        )

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded and ready for inference."""
        return self.model is not None and self.processor is not None

    def load_model(self) -> None:
        """
        Load the Wav2Vec2 model and processor.

        Thread-safe loading with proper error handling. On any failure the
        partially-loaded model/processor are reset to None before re-raising,
        so is_loaded never reports a half-initialized state.

        Raises:
            Exception: If model loading fails (original exception is re-raised).
        """
        with self._model_lock:
            if self.is_loaded:
                logger.debug("Model already loaded, skipping")
                return
            model_identifier = self.settings.model_identifier
            logger.info("Loading Wav2Vec2 model", model=model_identifier, device=self.device)
            try:
                # Load processor - try model first, fallback to base wav2vec2.
                # Broad except is deliberate: fine-tuned checkpoints often ship
                # without a processor config, which raises a variety of errors.
                try:
                    self.processor = Wav2Vec2Processor.from_pretrained(
                        model_identifier,
                        trust_remote_code=False,
                    )
                except Exception:
                    logger.info("Using base wav2vec2 processor")
                    self.processor = Wav2Vec2Processor.from_pretrained(
                        "facebook/wav2vec2-base",
                        trust_remote_code=False,
                    )
                # Load model with classification head.
                # For pretrained deepfake models, use their existing configuration.
                self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
                    model_identifier,
                    trust_remote_code=False,
                    ignore_mismatched_sizes=True,  # Allow different classifier sizes
                )
                # Move model to target device and freeze into inference mode.
                self.model = self.model.to(self.device)
                self.model.eval()
                # Log memory usage (GPU only has meaningful allocation stats).
                if self.device.startswith("cuda"):
                    memory_allocated = torch.cuda.memory_allocated() / (1024**3)
                    logger.info(
                        "Model loaded successfully",
                        device=self.device,
                        gpu_memory_gb=round(memory_allocated, 2),
                    )
                else:
                    logger.info("Model loaded successfully", device=self.device)
            except Exception as e:
                # Reset to a clean "not loaded" state before propagating.
                self.model = None
                self.processor = None
                logger.error("Failed to load model", error=str(e))
                raise

    async def load_model_async(self) -> None:
        """
        Async wrapper for model loading.

        Useful for FastAPI lifespan context. Runs the blocking load in the
        default thread-pool executor so the event loop is not blocked.
        """
        import asyncio

        # FIX: use get_running_loop() — we are inside a coroutine, so a loop
        # is guaranteed to be running. get_event_loop() is deprecated for this
        # use since Python 3.10 and unreliable on 3.12+.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.load_model)

    def get_model(self) -> tuple[Wav2Vec2ForSequenceClassification, Wav2Vec2Processor]:
        """
        Get the loaded model and processor.

        Returns:
            Tuple of (model, processor)

        Raises:
            ModelNotLoadedError: If model is not loaded
        """
        if not self.is_loaded:
            raise ModelNotLoadedError(
                "Model not loaded. Call load_model() first.",
                details={"model_identifier": self.settings.model_identifier},
            )
        return self.model, self.processor  # type: ignore

    def unload_model(self) -> None:
        """
        Unload model and free memory.

        Useful for memory management in constrained environments.
        """
        with self._model_lock:
            if self.model is not None:
                del self.model
                self.model = None
            if self.processor is not None:
                del self.processor
                self.processor = None
            # Force garbage collection so large tensors are actually released.
            gc.collect()
            # Clear CUDA cache if a GPU is available.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logger.info("Model unloaded, memory freed")

    def warmup(self) -> None:
        """
        Run a warmup inference to initialize CUDA kernels.

        This reduces latency on the first real inference. Failures are logged
        as warnings and swallowed — warmup is best-effort, not critical.
        """
        if not self.is_loaded:
            logger.warning("Cannot warmup - model not loaded")
            return
        logger.info("Running model warmup...")
        try:
            # Dummy input: 1 second of random audio at 16 kHz.
            dummy_audio = torch.randn(1, 16000)
            model, processor = self.get_model()
            inputs = processor(
                dummy_audio.squeeze().numpy(),
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            # Move every input tensor onto the model's device.
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Inference only — no gradients needed.
            with torch.no_grad():
                _ = model(**inputs)
            logger.info("Model warmup complete")
        except Exception as e:
            logger.warning("Warmup failed (non-critical)", error=str(e))

    def health_check(self) -> dict:
        """
        Get model health status.

        Returns:
            Dictionary with load state, device, model identifier, and
            (when running on CUDA) GPU memory statistics in GiB.
        """
        status = {
            "model_loaded": self.is_loaded,
            "device": self.device,
            "model_identifier": self.settings.model_identifier,
        }
        if self.device.startswith("cuda") and torch.cuda.is_available():
            status["gpu_memory_allocated_gb"] = round(
                torch.cuda.memory_allocated() / (1024**3), 2
            )
            status["gpu_memory_reserved_gb"] = round(
                torch.cuda.memory_reserved() / (1024**3), 2
            )
        return status