Spaces:
Running
Running
| import os | |
| import gc | |
| import time | |
| import torch | |
| import logging | |
| import threading | |
| from enum import Enum | |
| from typing import Dict, Any, Optional | |
| from dataclasses import dataclass | |
# Module-wide logger, registered under a fixed name (not __name__) so logging
# configuration can target this component directly.
logger = logging.getLogger(name="ModelManager")
class ModelType(Enum):
    """Keys identifying each model slot the ModelManager can cache.

    Every member's value is the lowercase spelling of its name.
    """

    # Detection
    DETECTOR = "detector"

    # De-morphing variants
    DEMORPH_TRANSFORMER = "demorph_transformer"
    DEMORPH_GAN = "demorph_gan"
    DEMORPH_SD = "demorph_sd"
    DEMORPH_LDM = "demorph_ldm"

    # Other checks
    LIVENESS = "liveness"
    FACE_MATCHER = "face_matcher"
@dataclass
class ModelInfo:
    """Bookkeeping record for one cached model.

    Declared as a dataclass so it can be built with keyword arguments
    (``ModelInfo(model_type=..., instance=..., last_used=...)``); without the
    decorator the class has no such ``__init__`` and construction fails.
    Instances are mutable because ``last_used`` is refreshed on every access.

    Attributes:
        model_type: The slot this entry occupies.
        instance: The object returned by the loader function.
        last_used: ``time.time()`` timestamp of the most recent access; the
            oldest value is evicted first.
        size_mb: Approximate size in MB (not set by ModelManager.get_model).
        device: Device string recorded when the model was registered.
    """

    # Quoted so the annotation does not require ModelType at class-creation time.
    model_type: "ModelType"
    instance: Any
    last_used: float
    size_mb: float = 0.0
    device: str = "cpu"
class ModelManager:
    """
    Singleton for managing AI model lifecycles.
    Prevents CUDA OOM by unloading unused models.

    Models are cached per ``ModelType``. On a cache miss the manager first
    evicts the least-recently-used model if reserved VRAM exceeds
    ``max_vram_usage`` of the device total. Before loading DEMORPH_SD or
    DEMORPH_GAN it evicts every cached model except the detector.

    Note: unloading only drops the manager's own reference. Memory is actually
    freed only once callers also release the instance returned by get_model().
    """

    _instance = None
    # Re-entrant because get_model() calls unload_model()/unload_all() while
    # already holding the lock.
    _lock = threading.RLock()

    def __new__(cls):
        # Double-checked locking: the unlocked fast path skips the lock once the
        # singleton exists. The flag is set BEFORE publishing the instance so a
        # thread taking the fast path never sees an object without _initialized.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super(ModelManager, cls).__new__(cls)
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every ModelManager() call; only the first
        # call initializes. Holding the lock and setting the flag last ensures
        # no thread observes a manager that lacks ``models`` or ``device``.
        with self._lock:
            if self._initialized:
                return
            self.models: Dict[ModelType, ModelInfo] = {}
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # Max VRAM config (approximate safety margin): fraction of the
            # device's total memory that reserved memory may reach.
            self.max_vram_usage = 0.9  # 90%
            self._initialized = True
            logger.info(f"ModelManager initialized on {self.device}")

    def _release(self, model_type: ModelType) -> None:
        """Remove ``model_type`` from the cache and ask the allocators to free memory.

        The caller must hold ``self._lock`` and ensure ``model_type`` is cached.
        """
        info = self.models.pop(model_type)
        del info.instance
        gc.collect()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()

    def _unload_least_used(self, required_mb: float = 0):
        """Unload the least recently used model to free space.

        The detector is spared while any other model is cached; it is only
        evicted when it is the sole entry. Exactly one model is unloaded per
        call. ``required_mb`` is currently unused.
        """
        with self._lock:
            if not self.models:
                return
            # Skip the detector unless it is the only candidate. If more than
            # one model is cached, a non-detector entry must exist, so
            # ``candidates`` is never empty.
            spare_detector = len(self.models) > 1
            candidates = [
                (m_type, info)
                for m_type, info in self.models.items()
                if not (spare_detector and m_type == ModelType.DETECTOR)
            ]
            # min() keeps the first of equal timestamps (insertion order).
            model_type, info = min(candidates, key=lambda item: item[1].last_used)
            logger.info(f"Unloading model: {model_type.value} (Last used: {time.time() - info.last_used:.1f}s ago)")
            self._release(model_type)
            logger.info(f"Model unloaded: {model_type.value}")

    def _check_memory(self):
        """Check VRAM usage and unload the LRU model if it is critical.

        Does nothing on CPU. A failure to query CUDA memory statistics is
        logged and ignored (best effort); errors raised while unloading
        are not hidden.
        """
        if self.device.type != 'cuda':
            return
        try:
            # Query the manager's own device (the current CUDA device), not a
            # hard-coded index.
            total_memory = torch.cuda.get_device_properties(self.device).total_memory
            reserved = torch.cuda.memory_reserved(self.device)
        except Exception as e:
            logger.warning(f"Failed to check memory: {e}")
            return
        usage = reserved / total_memory
        if usage > self.max_vram_usage:
            logger.warning(f"High VRAM usage: {usage:.1%} - Unloading implicit models...")
            self._unload_least_used()

    def get_model(self, model_type: ModelType, loader_func: callable, **kwargs) -> Any:
        """
        Get a model instance, loading it if necessary.

        The lock is held for the whole call, including the load itself, so
        concurrent requests never load the same model twice.

        Args:
            model_type: Type of model
            loader_func: Function that returns the model instance
            **kwargs: Arguments for the loader function

        Returns:
            The cached or freshly loaded model instance.

        Raises:
            Whatever ``loader_func`` raises. On a RuntimeError whose message
            contains "out of memory", all cached models are unloaded before
            re-raising.
        """
        with self._lock:
            cached = self.models.get(model_type)
            if cached is not None:
                cached.last_used = time.time()
                return cached.instance

            # Check memory before loading
            self._check_memory()

            # Heavy models: unload everything else except the detector first.
            if model_type in (ModelType.DEMORPH_SD, ModelType.DEMORPH_GAN):
                for m_type in list(self.models):
                    if m_type != ModelType.DETECTOR:
                        self.unload_model(m_type)

            logger.info(f"Loading model: {model_type.value}...")
            start_time = time.time()
            try:
                instance = loader_func(**kwargs)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    logger.error("CUDA OOM during load! Attempting emergency cleanup...")
                    # unload_all() also empties the CUDA cache.
                    self.unload_all()
                # Bare raise keeps the original traceback intact.
                raise

            self.models[model_type] = ModelInfo(
                model_type=model_type,
                instance=instance,
                last_used=time.time(),
                device=str(self.device),
            )
            logger.info(f"Model loaded: {model_type.value} in {time.time() - start_time:.2f}s")
            return instance

    def unload_model(self, model_type: ModelType):
        """Explicitly unload a model; does nothing if it is not cached."""
        with self._lock:
            if model_type not in self.models:
                return
            logger.info(f"Unloading model: {model_type.value}")
            self._release(model_type)

    def unload_all(self):
        """Unload all models and release cached allocator memory."""
        with self._lock:
            logger.info("Unloading ALL models")
            self.models.clear()
            gc.collect()
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
# Global accessor
def get_model_manager():
    """Return the shared ModelManager (the constructor always yields the singleton)."""
    manager = ModelManager()
    return manager