import logging
import gc
import time
from enum import IntEnum
from typing import Dict, Any, Optional, Callable, List
from dataclasses import dataclass, field
from threading import Lock

import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ModelPriority(IntEnum):
    """
    Model priority levels for memory management.

    Higher priority models are kept loaded longer under memory pressure.
    """
    CRITICAL = 100   # Never unload (e.g., OpenCLIP for analysis)
    HIGH = 80        # Currently active pipeline
    MEDIUM = 50      # Recently used models
    LOW = 20         # Inactive pipelines, can be evicted
    DISPOSABLE = 0   # Temporary models, evict first
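
# Note: IntEnum (rather than a plain Enum) is used so priority values compare
# and sort as integers, which the eviction logic in check_memory_pressure
# relies on when ranking models for unloading.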


@dataclass
class ModelInfo:
    """
    Information about a registered model.

    Attributes:
        name: Unique model identifier
        loader: Callable that returns the loaded model
        is_critical: If True, the model won't be unloaded under memory pressure
        priority: ModelPriority level for eviction decisions
        estimated_memory_gb: Estimated GPU memory usage in GB
        model_group: Group name for mutual exclusion (e.g., "pipeline")
        is_loaded: Whether the model is currently loaded
        last_used: Timestamp of last use
        model_instance: The actual model object
    """
    name: str
    loader: Callable[[], Any]
    is_critical: bool = False
    priority: int = ModelPriority.MEDIUM
    estimated_memory_gb: float = 0.0
    model_group: str = ""  # For mutual exclusion (e.g., "pipeline")
    is_loaded: bool = False
    last_used: float = 0.0
    model_instance: Any = None


class ModelManager:
    """
    Singleton model manager for unified model lifecycle management.

    Handles lazy loading, caching, priority-based eviction, and mutual
    exclusion for pipeline models. Designed for memory-constrained
    environments like Google Colab and HuggingFace Spaces.

    Features:
        - Priority-based model eviction under memory pressure
        - Mutual exclusion for pipeline models (only one active at a time)
        - Automatic memory monitoring and cleanup
        - Support for model groups and dependencies

    Example:
        >>> manager = get_model_manager()
        >>> manager.register_model(
        ...     name="sdxl_pipeline",
        ...     loader=load_sdxl,
        ...     priority=ModelPriority.HIGH,
        ...     model_group="pipeline"
        ... )
        >>> pipeline = manager.load_model("sdxl_pipeline")
    """
    _instance = None
    _lock = Lock()

    # Known model groups for mutual exclusion
    PIPELINE_GROUP = "pipeline"  # Only one pipeline can be loaded at a time
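
    # Double-checked locking: the instance check is repeated inside the lock
    # so concurrent first calls cannot both create an instance, while later
    # calls skip the lock entirely.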
    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        self._models: Dict[str, ModelInfo] = {}
        self._memory_threshold = 0.80       # Trigger cleanup at 80% GPU memory usage
        self._high_memory_threshold = 0.90  # Critical threshold for aggressive cleanup
        self._device = self._detect_device()
        self._active_pipeline: Optional[str] = None  # Track currently active pipeline

        logger.info(f"ModelManager initialized on {self._device}")
        self._initialized = True

    def _detect_device(self) -> str:
        """Detect best available device."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def register_model(
        self,
        name: str,
        loader: Callable[[], Any],
        is_critical: bool = False,
        priority: int = ModelPriority.MEDIUM,
        estimated_memory_gb: float = 0.0,
        model_group: str = ""
    ):
        """
        Register a model for managed loading.

        Parameters
        ----------
        name : str
            Unique model identifier
        loader : callable
            Function that returns the loaded model
        is_critical : bool
            If True, model won't be unloaded under memory pressure
        priority : int
            ModelPriority level for eviction decisions
        estimated_memory_gb : float
            Estimated GPU memory usage in GB
        model_group : str
            Group name for mutual exclusion (e.g., "pipeline")
        """
        if name in self._models:
            logger.warning(f"Model '{name}' already registered, updating")

        # Critical models always have highest priority
        if is_critical:
            priority = ModelPriority.CRITICAL

        self._models[name] = ModelInfo(
            name=name,
            loader=loader,
            is_critical=is_critical,
            priority=priority,
            estimated_memory_gb=estimated_memory_gb,
            model_group=model_group,
            is_loaded=False,
            last_used=0.0,
            model_instance=None
        )

        logger.info(
            f"Registered model: {name} "
            f"(priority={priority}, group={model_group}, ~{estimated_memory_gb:.1f}GB)"
        )

    def load_model(self, name: str, update_priority: Optional[int] = None) -> Any:
        """
        Load a model by name. Returns the cached instance if already loaded.

        Implements mutual exclusion for pipeline models: loading a new
        pipeline will unload any existing pipeline first.

        Parameters
        ----------
        name : str
            Model identifier
        update_priority : int, optional
            If provided, update the model's priority after loading

        Returns
        -------
        Any
            Loaded model instance

        Raises
        ------
        KeyError
            If the model is not registered
        RuntimeError
            If loading fails
        """
        if name not in self._models:
            raise KeyError(f"Model '{name}' not registered")

        model_info = self._models[name]

        # Return cached instance
        if model_info.is_loaded and model_info.model_instance is not None:
            model_info.last_used = time.time()
            if update_priority is not None:
                model_info.priority = update_priority
            logger.debug(f"Using cached model: {name}")
            return model_info.model_instance

        # Handle mutual exclusion for the pipeline group
        if model_info.model_group == self.PIPELINE_GROUP:
            self._ensure_pipeline_exclusion(name)

        # Check memory pressure before loading
        self.check_memory_pressure()

        # Load the model
        try:
            logger.info(f"Loading model: {name}")
            start_time = time.time()

            model_instance = model_info.loader()

            model_info.model_instance = model_instance
            model_info.is_loaded = True
            model_info.last_used = time.time()
            if update_priority is not None:
                model_info.priority = update_priority

            # Track the active pipeline
            if model_info.model_group == self.PIPELINE_GROUP:
                self._active_pipeline = name

            load_time = time.time() - start_time
            logger.info(f"Model '{name}' loaded in {load_time:.1f}s")
            return model_instance

        except Exception as e:
            logger.error(f"Failed to load model '{name}': {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _ensure_pipeline_exclusion(self, new_pipeline: str) -> None:
        """
        Ensure only one pipeline is loaded at a time.

        Unloads any existing pipeline before loading a new one.

        Parameters
        ----------
        new_pipeline : str
            Name of the pipeline about to be loaded
        """
        for name, info in self._models.items():
            if (info.model_group == self.PIPELINE_GROUP and
                    info.is_loaded and
                    name != new_pipeline):
                logger.info(f"Unloading {name} to make room for {new_pipeline}")
                self.unload_model(name)

    def unload_model(self, name: str) -> bool:
        """
        Unload a specific model to free memory.

        Parameters
        ----------
        name : str
            Model identifier

        Returns
        -------
        bool
            True if model was unloaded successfully
        """
        if name not in self._models:
            return False

        model_info = self._models[name]
        if not model_info.is_loaded:
            return True

        try:
            logger.info(f"Unloading model: {name}")

            # Delete model instance
            if model_info.model_instance is not None:
                del model_info.model_instance
                model_info.model_instance = None

            model_info.is_loaded = False

            # Update active pipeline tracking
            if self._active_pipeline == name:
                self._active_pipeline = None

            # Cleanup
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

            logger.info(f"Model '{name}' unloaded")
            return True

        except Exception as e:
            logger.error(f"Error unloading model '{name}': {e}")
            return False
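
    # Worked example of the thresholds used below: on a 16 GB GPU (illustrative
    # size) the 0.80 soft threshold triggers eviction once roughly 12.8 GB is
    # allocated, and eviction stops once usage drops below 0.80 * 0.7 = 56%
    # of total memory.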

    def check_memory_pressure(self) -> bool:
        """
        Check GPU memory usage and unload low-priority models if needed.

        Uses priority-based eviction: lower priority models are unloaded first,
        then falls back to least-recently-used within same priority tier.

        Returns
        -------
        bool
            True if cleanup was performed
        """
        if not torch.cuda.is_available():
            return False

        allocated = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        usage_ratio = allocated / total

        if usage_ratio < self._memory_threshold:
            return False

        logger.warning(f"Memory pressure detected: {usage_ratio:.1%} used")

        # Find evictable models (not critical, loaded)
        # Sort by priority (ascending) then by last_used (ascending)
        evictable = [
            (name, info) for name, info in self._models.items()
            if info.is_loaded and info.priority < ModelPriority.CRITICAL
        ]
        evictable.sort(key=lambda x: (x[1].priority, x[1].last_used))

        # Unload models starting from lowest priority
        cleaned = False
        for name, info in evictable:
            self.unload_model(name)
            cleaned = True

            # Re-check memory
            new_ratio = torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory
            if new_ratio < self._memory_threshold * 0.7:  # Target 70% of threshold
                break

        return cleaned

    def force_cleanup(self, keep_critical_only: bool = True):
        """
        Force cleanup of models and clear caches.

        Parameters
        ----------
        keep_critical_only : bool
            If True, only keep CRITICAL priority models loaded
        """
        logger.info("Force cleanup initiated")

        # Unload models based on priority
        threshold = ModelPriority.CRITICAL if keep_critical_only else ModelPriority.HIGH
        for name, info in list(self._models.items()):
            if info.is_loaded and info.priority < threshold:
                self.unload_model(name)

        # Aggressive garbage collection
        for _ in range(5):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.info("Force cleanup completed")

    def update_priority(self, name: str, priority: int) -> bool:
        """
        Update a model's priority level.

        Parameters
        ----------
        name : str
            Model identifier
        priority : int
            New priority level

        Returns
        -------
        bool
            True if priority was updated
        """
        if name not in self._models:
            return False

        self._models[name].priority = priority
        logger.debug(f"Updated priority for {name} to {priority}")
        return True

    def get_active_pipeline(self) -> Optional[str]:
        """
        Get the name of the currently active pipeline.

        Returns
        -------
        str or None
            Name of the active pipeline, or None if no pipeline is loaded
        """
        return self._active_pipeline

    def switch_to_pipeline(
        self,
        name: str,
        loader: Optional[Callable[[], Any]] = None
    ) -> Any:
        """
        Switch to a different pipeline, unloading the current one.

        This is a convenience method for pipeline switching that handles
        mutual exclusion automatically.

        Parameters
        ----------
        name : str
            Pipeline name to switch to
        loader : callable, optional
            Loader function if the pipeline is not already registered

        Returns
        -------
        Any
            The loaded pipeline instance

        Raises
        ------
        KeyError
            If the pipeline is not registered and no loader is provided
        """
        # Register if needed
        if name not in self._models and loader is not None:
            self.register_model(
                name=name,
                loader=loader,
                priority=ModelPriority.HIGH,
                model_group=self.PIPELINE_GROUP
            )

        # load_model will handle unloading of the current pipeline
        return self.load_model(name, update_priority=ModelPriority.HIGH)

    def get_memory_status(self) -> Dict[str, Any]:
        """
        Get detailed memory status.

        Returns
        -------
        dict
            Dictionary with per-model and GPU memory statistics
        """
        status = {
            "device": self._device,
            "models": {},
            "total_estimated_gb": 0.0
        }

        # Model status
        for name, info in self._models.items():
            status["models"][name] = {
                "loaded": info.is_loaded,
                "critical": info.is_critical,
                "estimated_gb": info.estimated_memory_gb,
                "last_used": info.last_used
            }
            if info.is_loaded:
                status["total_estimated_gb"] += info.estimated_memory_gb

        # GPU memory
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            total = torch.cuda.get_device_properties(0).total_memory / 1024**3
            cached = torch.cuda.memory_reserved() / 1024**3
            status["gpu"] = {
                "allocated_gb": round(allocated, 2),
                "total_gb": round(total, 2),
                "cached_gb": round(cached, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }

        return status

    def get_loaded_models(self) -> List[str]:
        """Get a list of currently loaded model names."""
        return [name for name, info in self._models.items() if info.is_loaded]

    def is_model_loaded(self, name: str) -> bool:
        """Check if a specific model is loaded."""
        if name not in self._models:
            return False
        return self._models[name].is_loaded


# Global singleton instance
_model_manager = None


def get_model_manager() -> ModelManager:
    """Get the global ModelManager singleton instance."""
    global _model_manager
    if _model_manager is None:
        _model_manager = ModelManager()
    return _model_manager
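

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The pipeline names and the dummy lambda
# loaders below are stand-ins for real loader functions (e.g. something that
# returns a diffusers pipeline); they exist purely to demonstrate the
# registration, mutual-exclusion, and switching flow.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = get_model_manager()

    # Register two pipelines in the same exclusion group.
    manager.register_model(
        name="sdxl_pipeline",
        loader=lambda: object(),           # stand-in for a real SDXL loader
        priority=ModelPriority.HIGH,
        estimated_memory_gb=8.0,
        model_group=ModelManager.PIPELINE_GROUP,
    )
    manager.register_model(
        name="flux_pipeline",
        loader=lambda: object(),           # stand-in for a second pipeline loader
        priority=ModelPriority.HIGH,
        estimated_memory_gb=12.0,
        model_group=ModelManager.PIPELINE_GROUP,
    )

    manager.load_model("sdxl_pipeline")
    print("Loaded:", manager.get_loaded_models())      # ['sdxl_pipeline']

    # Switching pipelines unloads the previous one automatically.
    manager.switch_to_pipeline("flux_pipeline")
    print("Loaded:", manager.get_loaded_models())      # ['flux_pipeline']
    print("Active:", manager.get_active_pipeline())    # 'flux_pipeline'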