MorphGuard / src /model_manager.py
juanquy's picture
Initial clean commit of modular MorphGuard
2978bba
Raw
History Blame Contribute Delete
6.17 kB
import gc
import logging
import os
import threading
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, Optional

import torch
# Module-level logger, registered under the fixed name "ModelManager"
# (not the module's __name__); all log output in this file goes through it.
logger = logging.getLogger("ModelManager")
class ModelType(Enum):
    """
    Keys under which models are registered with the ModelManager.

    Each member's value is a lowercase string; that string is what appears
    in the manager's log messages when a model is loaded or unloaded.
    """

    DETECTOR = "detector"

    # The "demorph" family of model kinds.
    DEMORPH_TRANSFORMER = "demorph_transformer"
    DEMORPH_GAN = "demorph_gan"
    DEMORPH_SD = "demorph_sd"
    DEMORPH_LDM = "demorph_ldm"

    LIVENESS = "liveness"
    FACE_MATCHER = "face_matcher"
@dataclass
class ModelInfo:
    """Bookkeeping record the ModelManager keeps for one loaded model."""

    # Which kind of model this record describes.
    model_type: ModelType
    # The loaded model object itself, exactly as returned by its loader.
    instance: Any
    # time.time() timestamp of the most recent load or lookup.
    last_used: float
    # Size of the model in megabytes; stays at the default unless a caller
    # sets it explicitly.
    size_mb: float = 0
    # String form of the device the model was loaded for.
    device: str = "cpu"
class ModelManager:
    """
    Singleton for managing AI model lifecycles.

    Prevents CUDA OOM by unloading unused models. Loaded models are kept in
    ``self.models`` keyed by ModelType; the public methods (``get_model``,
    ``unload_model``, ``unload_all``) run while holding a class-level
    re-entrant lock.
    """

    _instance = None
    _lock = threading.RLock()

    # Model types treated as heavy: before one of these is loaded, every
    # other registered model except the detector is unloaded first.
    _HEAVY_MODEL_TYPES = frozenset({ModelType.DEMORPH_SD, ModelType.DEMORPH_GAN})

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Set the flag BEFORE publishing the instance so that a
                    # concurrent caller never sees an object without it.
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """Initialise the registry once; later constructions are no-ops."""
        # The check-and-set runs under the lock so that two threads racing
        # on the first construction cannot both initialise (the second one
        # would otherwise reset ``self.models``).
        with self._lock:
            if self._initialized:
                return
            self.models: Dict[ModelType, ModelInfo] = {}
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # Max VRAM config (approximate safety margin)
            self.max_vram_usage = 0.9  # 90%
            self._initialized = True
            logger.info(f"ModelManager initialized on {self.device}")

    def _release(self, model_type: ModelType) -> bool:
        """
        Remove *model_type* from the registry and reclaim its memory.

        Returns:
            True if a model was registered and has been released,
            False if nothing was registered under *model_type*.
        """
        info = self.models.pop(model_type, None)
        if info is None:
            return False
        # Drop the record's reference to the model so that anything still
        # holding the ModelInfo does not keep the model alive.
        info.instance = None
        gc.collect()
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
        return True

    def _unload_least_used(self, required_mb: float = 0):
        """
        Unload the least recently used model to free space.

        The detector is skipped as long as any other model is loaded; it is
        only unloaded when it is the sole registered model.

        Args:
            required_mb: Accepted for interface compatibility; currently
                not used when choosing what to unload.
        """
        if not self.models:
            return

        # Don't unload the detector if possible, it's critical
        keep_detector = len(self.models) > 1
        candidates = [
            (m_type, info)
            for m_type, info in self.models.items()
            if not (keep_detector and m_type == ModelType.DETECTOR)
        ]
        if not candidates:
            return

        # Oldest ``last_used`` wins; ties resolve to the earliest-registered.
        model_type, info = min(candidates, key=lambda item: item[1].last_used)
        logger.info(f"Unloading model: {model_type.value} (Last used: {time.time() - info.last_used:.1f}s ago)")
        self._release(model_type)
        logger.info(f"Model unloaded: {model_type.value}")

    def _check_memory(self):
        """
        Check VRAM usage and unload one model if usage is critical.

        Does nothing on non-CUDA devices. Usage is the memory reserved by
        this process's CUDA allocator divided by the device's total memory.
        Failures are logged and swallowed: this is a best-effort check and
        must not prevent a model from loading.
        """
        if self.device.type != 'cuda':
            return

        try:
            # Query the device the manager actually uses, not index 0.
            total_memory = torch.cuda.get_device_properties(self.device).total_memory
            reserved = torch.cuda.memory_reserved(self.device)
            usage = reserved / total_memory

            if usage > self.max_vram_usage:
                logger.warning(f"High VRAM usage: {usage:.1%} - Unloading implicit models...")
                self._unload_least_used()
        except Exception as e:
            logger.warning(f"Failed to check memory: {e}")

    def get_model(self, model_type: ModelType, loader_func: Callable[..., Any], **kwargs) -> Any:
        """
        Get a model instance, loading it if necessary.

        Args:
            model_type: Type of model
            loader_func: Function that returns the model instance
            **kwargs: Arguments for the loader function

        Returns:
            The cached instance if *model_type* is already registered,
            otherwise the object returned by ``loader_func(**kwargs)``.

        Raises:
            RuntimeError: Re-raised from *loader_func*. If the message
                contains "out of memory", all models are unloaded first.
        """
        with self._lock:
            # Update last used if exists
            cached = self.models.get(model_type)
            if cached is not None:
                cached.last_used = time.time()
                return cached.instance

            # Check memory before loading
            self._check_memory()

            # If we need to load a heavy model (like SD), unload others first aggressively
            if model_type in self._HEAVY_MODEL_TYPES:
                for m_type in list(self.models):
                    if m_type != ModelType.DETECTOR:
                        self.unload_model(m_type)

            logger.info(f"Loading model: {model_type.value}...")
            start_time = time.time()

            try:
                instance = loader_func(**kwargs)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    logger.error("CUDA OOM during load! Attempting emergency cleanup...")
                    # unload_all() already empties the CUDA cache.
                    self.unload_all()
                # No retry: let the caller decide. Bare raise keeps the
                # original traceback intact.
                raise

            # Register
            self.models[model_type] = ModelInfo(
                model_type=model_type,
                instance=instance,
                last_used=time.time(),
                device=str(self.device),
            )
            logger.info(f"Model loaded: {model_type.value} in {time.time() - start_time:.2f}s")
            return instance

    def unload_model(self, model_type: ModelType):
        """Explicitly unload a model; a no-op if it is not loaded."""
        with self._lock:
            if model_type in self.models:
                logger.info(f"Unloading model: {model_type.value}")
                self._release(model_type)

    def unload_all(self):
        """Unload all models."""
        with self._lock:
            logger.info("Unloading ALL models")
            # Clear each record's model reference before dropping the
            # records, then collect and empty the cache once for the batch.
            for info in self.models.values():
                info.instance = None
            self.models.clear()
            gc.collect()
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
# Global accessor
def get_model_manager():
    """Return the process-wide ModelManager, creating it on first call."""
    manager = ModelManager()
    return manager