"""On-demand loading and eviction of AI models to avoid CUDA out-of-memory errors."""

import gc
import logging
import os
import threading
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Optional

import torch

# Setup logging
logger = logging.getLogger("ModelManager")


class ModelType(Enum):
    """Identifiers for every model the manager can hold."""

    DETECTOR = "detector"
    DEMORPH_TRANSFORMER = "demorph_transformer"
    DEMORPH_GAN = "demorph_gan"
    DEMORPH_SD = "demorph_sd"
    DEMORPH_LDM = "demorph_ldm"
    LIVENESS = "liveness"
    FACE_MATCHER = "face_matcher"


@dataclass
class ModelInfo:
    """Registry entry describing one loaded model."""

    model_type: ModelType
    instance: Any  # object returned by the loader; set to None when evicted
    last_used: float  # time.time() of the last load or cache hit
    size_mb: float = 0  # not populated anywhere in this module
    device: str = "cpu"  # str() of the manager's device at load time


class ModelManager:
    """
    Singleton for managing AI model lifecycles.
    Prevents CUDA OOM by unloading unused models.

    All public methods serialize on a class-level re-entrant lock, so they may
    call one another (e.g. get_model -> unload_model) while holding it.
    """

    _instance = None
    _lock = threading.RLock()

    # Models large enough that everything except the detector is evicted
    # before one of them is loaded.
    _HEAVY_MODELS = frozenset({ModelType.DEMORPH_SD, ModelType.DEMORPH_GAN})

    def __new__(cls):
        # Double-checked creation: the unlocked test keeps the common path cheap,
        # the locked re-test prevents two threads creating separate instances.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; only the first may set state.
        # Guard under the lock so concurrent first calls cannot both reset
        # self.models, and mark initialized only once setup has succeeded.
        with self._lock:
            if self._initialized:
                return

            self.models: Dict[ModelType, ModelInfo] = {}
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # Max VRAM config (approximate safety margin)
            self.max_vram_usage = 0.9  # 90%

            self._initialized = True
            logger.info(f"ModelManager initialized on {self.device}")

    def _release(self, model_type: ModelType) -> None:
        """Remove *model_type* from the registry and reclaim its memory.

        The caller must hold the lock and ensure the entry exists.
        """
        info = self.models.pop(model_type)
        # Drop the manager's reference to the model object so it can be
        # collected even if the ModelInfo itself is still referenced somewhere.
        info.instance = None
        gc.collect()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

    def _unload_least_used(self, required_mb: float = 0):
        """Unload the least recently used model to free space.

        The detector is skipped while any other model is loaded. It is only
        evicted when it is the sole remaining model. At most one model is
        unloaded per call. *required_mb* is accepted for API compatibility
        but is currently unused.
        """
        if not self.models:
            return

        # Oldest first
        by_age = sorted(self.models.items(), key=lambda item: item[1].last_used)
        for model_type, info in by_age:
            if model_type == ModelType.DETECTOR and len(self.models) > 1:
                continue

            logger.info(f"Unloading model: {model_type.value} (Last used: {time.time() - info.last_used:.1f}s ago)")
            self._release(model_type)
            logger.info(f"Model unloaded: {model_type.value}")
            return

    def _check_memory(self):
        """Evict one model if reserved VRAM exceeds ``max_vram_usage``.

        No-op on CPU. Best effort: any failure is logged, never raised.
        """
        if self.device.type != 'cuda':
            return

        try:
            # Query the device this manager uses rather than a hard-coded 0.
            # An index-less 'cuda' device resolves to the current CUDA device.
            total_memory = torch.cuda.get_device_properties(self.device).total_memory
            reserved = torch.cuda.memory_reserved(self.device)
            usage = reserved / total_memory

            if usage > self.max_vram_usage:
                logger.warning(f"High VRAM usage: {usage:.1%} - Unloading implicit models...")
                self._unload_least_used()
        except Exception as e:
            logger.warning(f"Failed to check memory: {e}")

    def get_model(self, model_type: ModelType, loader_func: Callable[..., Any], **kwargs) -> Any:
        """
        Get a model instance, loading it if necessary.

        Args:
            model_type: Type of model
            loader_func: Function that returns the model instance
            **kwargs: Arguments for the loader function

        Returns:
            The cached instance if already loaded, otherwise the loader's result.

        Raises:
            Whatever ``loader_func`` raises. If the message contains
            "out of memory", every loaded model is unloaded before re-raising.
        """
        with self._lock:
            cached = self.models.get(model_type)
            if cached is not None:
                cached.last_used = time.time()
                return cached.instance

            # Check memory before loading
            self._check_memory()

            # Heavy models evict everything except the detector up front.
            if model_type in self._HEAVY_MODELS:
                for loaded_type in list(self.models):
                    if loaded_type != ModelType.DETECTOR:
                        self.unload_model(loaded_type)

            logger.info(f"Loading model: {model_type.value}...")
            start_time = time.time()

            try:
                instance = loader_func(**kwargs)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    logger.error("CUDA OOM during load! Attempting emergency cleanup...")
                    # unload_all() also empties the CUDA cache on GPU devices.
                    self.unload_all()
                raise

            self.models[model_type] = ModelInfo(
                model_type=model_type,
                instance=instance,
                last_used=time.time(),
                device=str(self.device),
            )
            logger.info(f"Model loaded: {model_type.value} in {time.time() - start_time:.2f}s")
            return instance

    def unload_model(self, model_type: ModelType):
        """Explicitly unload a model (no-op if it is not loaded)."""
        with self._lock:
            if model_type not in self.models:
                return
            logger.info(f"Unloading model: {model_type.value}")
            self._release(model_type)

    def unload_all(self):
        """Unload all models."""
        with self._lock:
            logger.info("Unloading ALL models")
            self.models.clear()
            gc.collect()
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()


# Global accessor
def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager singleton."""
    return ModelManager()