File size: 15,975 Bytes
ca80d1d
 
 
7d583e3
 
ca80d1d
 
 
 
 
 
 
 
7d583e3
 
 
 
 
 
 
 
 
 
 
 
 
ca80d1d
 
7d583e3
 
 
 
 
 
 
 
 
 
 
 
 
 
ca80d1d
 
7d583e3
 
ca80d1d
7d583e3
ca80d1d
 
 
 
 
 
 
 
7d583e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca80d1d
 
 
 
 
7d583e3
 
 
ca80d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
7d583e3
ca80d1d
7d583e3
ca80d1d
7d583e3
ca80d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d583e3
 
 
ca80d1d
 
 
 
7d583e3
 
 
 
 
 
 
 
 
 
 
 
 
 
ca80d1d
 
7d583e3
 
 
 
 
ca80d1d
 
 
 
 
7d583e3
ca80d1d
7d583e3
ca80d1d
 
 
 
7d583e3
ca80d1d
7d583e3
ca80d1d
 
 
7d583e3
 
ca80d1d
7d583e3
 
 
 
 
 
 
 
 
 
ca80d1d
 
7d583e3
 
 
 
 
 
ca80d1d
 
 
 
 
 
 
 
 
7d583e3
 
 
ca80d1d
 
7d583e3
 
 
 
ca80d1d
 
 
 
 
7d583e3
ca80d1d
 
 
 
 
 
 
 
7d583e3
 
 
 
 
 
 
ca80d1d
7d583e3
ca80d1d
 
 
 
7d583e3
ca80d1d
 
7d583e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca80d1d
 
 
7d583e3
 
 
 
 
 
 
 
 
ca80d1d
 
7d583e3
ca80d1d
 
 
 
7d583e3
ca80d1d
 
7d583e3
ca80d1d
 
 
 
 
 
 
 
7d583e3
 
 
 
ca80d1d
 
 
 
7d583e3
ca80d1d
7d583e3
 
ca80d1d
 
7d583e3
 
ca80d1d
 
 
7d583e3
ca80d1d
7d583e3
 
 
 
 
 
ca80d1d
 
 
 
 
 
 
 
 
 
 
 
7d583e3
ca80d1d
7d583e3
 
 
ca80d1d
7d583e3
ca80d1d
7d583e3
ca80d1d
7d583e3
ca80d1d
7d583e3
ca80d1d
 
 
 
 
 
 
 
 
 
7d583e3
 
 
 
 
 
 
 
 
 
ca80d1d
7d583e3
 
 
 
ca80d1d
 
 
 
 
 
 
 
 
 
 
7d583e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca80d1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
import logging
import gc
import time
from enum import IntEnum
from typing import Dict, Any, Optional, Callable, List
from dataclasses import dataclass, field
from threading import Lock
import torch

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class ModelPriority(IntEnum):
    """
    Eviction priority tiers for managed models.

    Members are plain integers (IntEnum), so tiers compare and sort directly;
    a larger value means the model should be kept loaded longer when memory
    runs short.
    """
    # Never unload (e.g., OpenCLIP for analysis)
    CRITICAL = 100
    # Currently active pipeline
    HIGH = 80
    # Recently used models
    MEDIUM = 50
    # Inactive pipelines, can be evicted
    LOW = 20
    # Temporary models, evict first
    DISPOSABLE = 0


@dataclass
class ModelInfo:
    """
    Bookkeeping record for a single registered model.

    Fields:
        name: Unique identifier the model is registered under.
        loader: Zero-argument callable that builds and returns the model.
        is_critical: When True the model is meant to be exempt from
            memory-pressure unloading.
        priority: Eviction tier (see ModelPriority); lower tiers are
            evicted first.
        estimated_memory_gb: Caller-supplied estimate of GPU memory use.
        model_group: Mutual-exclusion group name (e.g. "pipeline"); an empty
            string means the model belongs to no group.
        is_loaded: Whether ``model_instance`` currently holds a live model.
        last_used: Timestamp of the most recent load or cache hit.
        model_instance: The loaded model object, or None when not loaded.
    """
    # Identity and construction
    name: str
    loader: Callable[[], Any]
    # Eviction policy
    is_critical: bool = field(default=False)
    priority: int = field(default=ModelPriority.MEDIUM)
    estimated_memory_gb: float = field(default=0.0)
    model_group: str = field(default="")  # mutual-exclusion group, e.g. "pipeline"
    # Runtime state
    is_loaded: bool = field(default=False)
    last_used: float = field(default=0.0)
    model_instance: Any = field(default=None)


class ModelManager:
    """
    Singleton model manager for unified model lifecycle management.

    Handles lazy loading, caching, priority-based eviction, and mutual
    exclusion for pipeline models. Designed for memory-constrained
    environments like Google Colab and HuggingFace Spaces.

    Features:
        - Priority-based model eviction under memory pressure
        - Mutual exclusion for pipeline models (only one active at a time)
        - Automatic memory monitoring and cleanup
        - Support for model groups

    Example:
        >>> manager = get_model_manager()
        >>> manager.register_model(
        ...     name="sdxl_pipeline",
        ...     loader=load_sdxl,
        ...     priority=ModelPriority.HIGH,
        ...     model_group="pipeline"
        ... )
        >>> pipeline = manager.load_model("sdxl_pipeline")
    """

    _instance = None
    _lock = Lock()

    # Known model groups for mutual exclusion
    PIPELINE_GROUP = "pipeline"  # Only one pipeline can be loaded at a time

    def __new__(cls):
        # Double-checked locking: the cheap unlocked check avoids taking the
        # lock on every call; the re-check under the lock prevents two threads
        # racing on first construction from creating two instances.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; only initialize once.
        if self._initialized:
            return

        self._models: Dict[str, ModelInfo] = {}
        self._memory_threshold = 0.80  # Trigger cleanup at 80% GPU memory usage
        # NOTE(review): not read anywhere in this class; kept for callers that may inspect it.
        self._high_memory_threshold = 0.90  # Critical threshold for aggressive cleanup
        self._device = self._detect_device()
        self._active_pipeline: Optional[str] = None  # Track currently active pipeline

        logger.info(f"ModelManager initialized on {self._device}")
        self._initialized = True

    def _detect_device(self) -> str:
        """Return "cuda", "mps" or "cpu", preferring the first one available."""
        if torch.cuda.is_available():
            return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    @staticmethod
    def _gpu_usage_ratio() -> float:
        """
        Return allocated / total memory for the current CUDA device.

        Both numbers are taken from the same device so the ratio is
        meaningful on multi-GPU hosts. Callers must check
        ``torch.cuda.is_available()`` first.
        """
        device = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(device)
        total = torch.cuda.get_device_properties(device).total_memory
        return allocated / total

    @staticmethod
    def _is_evictable(info: ModelInfo, below_priority: int) -> bool:
        """Return True if *info* is loaded, not critical, and below *below_priority*."""
        # is_critical is checked explicitly because a critical model's priority
        # can be lowered later via update_priority()/load_model(update_priority=...).
        return info.is_loaded and not info.is_critical and info.priority < below_priority

    def register_model(
        self,
        name: str,
        loader: Callable[[], Any],
        is_critical: bool = False,
        priority: int = ModelPriority.MEDIUM,
        estimated_memory_gb: float = 0.0,
        model_group: str = ""
    ):
        """
        Register a model for managed loading.

        Re-registering an existing name replaces its record; if the previous
        registration was loaded, that instance is unloaded first so its memory
        is released and active-pipeline tracking stays consistent.

        Parameters
        ----------
        name : str
            Unique model identifier
        loader : callable
            Function that returns the loaded model
        is_critical : bool
            If True, model won't be unloaded under memory pressure
        priority : int
            ModelPriority level for eviction decisions
        estimated_memory_gb : float
            Estimated GPU memory usage in GB
        model_group : str
            Group name for mutual exclusion (e.g., "pipeline")
        """
        existing = self._models.get(name)
        if existing is not None:
            logger.warning(f"Model '{name}' already registered, updating")
            if existing.is_loaded:
                # The new record starts unloaded; free the old instance properly
                # instead of silently dropping the only reference to it.
                self.unload_model(name)

        # Critical models always have highest priority
        if is_critical:
            priority = ModelPriority.CRITICAL

        self._models[name] = ModelInfo(
            name=name,
            loader=loader,
            is_critical=is_critical,
            priority=priority,
            estimated_memory_gb=estimated_memory_gb,
            model_group=model_group,
            is_loaded=False,
            last_used=0.0,
            model_instance=None
        )
        logger.info(f"Registered model: {name} (priority={priority}, group={model_group}, ~{estimated_memory_gb:.1f}GB)")

    def load_model(self, name: str, update_priority: Optional[int] = None) -> Any:
        """
        Load a model by name. Returns cached instance if already loaded.

        Implements mutual exclusion for pipeline models - loading a new
        pipeline will unload any existing pipeline first.

        Parameters
        ----------
        name : str
            Model identifier
        update_priority : int, optional
            If provided, update the model's priority after loading

        Returns
        -------
        Any
            Loaded model instance

        Raises
        ------
        KeyError
            If model not registered
        RuntimeError
            If loading fails (the loader's exception is chained as __cause__)
        """
        if name not in self._models:
            raise KeyError(f"Model '{name}' not registered")

        model_info = self._models[name]

        # Return cached instance
        if model_info.is_loaded and model_info.model_instance is not None:
            model_info.last_used = time.time()
            if update_priority is not None:
                model_info.priority = update_priority
            logger.debug(f"Using cached model: {name}")
            return model_info.model_instance

        # Handle mutual exclusion for pipeline group
        if model_info.model_group == self.PIPELINE_GROUP:
            self._ensure_pipeline_exclusion(name)

        # Check memory pressure before loading
        self.check_memory_pressure()

        logger.info(f"Loading model: {name}")
        start_time = time.time()

        # Only the user-supplied loader can realistically fail; keep the try narrow.
        try:
            model_instance = model_info.loader()
        except Exception as e:
            logger.exception(f"Failed to load model '{name}': {e}")
            raise RuntimeError(f"Model loading failed: {e}") from e

        model_info.model_instance = model_instance
        model_info.is_loaded = True
        model_info.last_used = time.time()

        if update_priority is not None:
            model_info.priority = update_priority

        # Track active pipeline
        if model_info.model_group == self.PIPELINE_GROUP:
            self._active_pipeline = name

        load_time = time.time() - start_time
        logger.info(f"Model '{name}' loaded in {load_time:.1f}s")

        return model_instance

    def _ensure_pipeline_exclusion(self, new_pipeline: str) -> None:
        """
        Ensure only one pipeline is loaded at a time.

        Unloads any existing pipeline before loading a new one.

        Parameters
        ----------
        new_pipeline : str
            Name of the pipeline about to be loaded
        """
        # Snapshot the candidates so unloading cannot interfere with iteration.
        loaded_pipelines = [
            name for name, info in self._models.items()
            if info.model_group == self.PIPELINE_GROUP
            and info.is_loaded
            and name != new_pipeline
        ]
        for name in loaded_pipelines:
            logger.info(f"Unloading {name} to make room for {new_pipeline}")
            self.unload_model(name)

    def unload_model(self, name: str) -> bool:
        """
        Unload a specific model to free memory.

        Parameters
        ----------
        name : str
            Model identifier

        Returns
        -------
        bool
            True if model was unloaded successfully (or was not loaded);
            False if the name is unknown or cleanup raised.
        """
        if name not in self._models:
            return False

        model_info = self._models[name]

        if not model_info.is_loaded:
            return True

        try:
            logger.info(f"Unloading model: {name}")

            # Drop the manager's reference so gc can reclaim the model.
            model_info.model_instance = None
            model_info.is_loaded = False

            # Update active pipeline tracking
            if self._active_pipeline == name:
                self._active_pipeline = None

            # Cleanup
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

            logger.info(f"Model '{name}' unloaded")
            return True

        except Exception as e:
            logger.error(f"Error unloading model '{name}': {e}")
            return False

    def check_memory_pressure(self) -> bool:
        """
        Check GPU memory usage and unload low-priority models if needed.

        Uses priority-based eviction: lower priority models are unloaded first,
        then falls back to least-recently-used within same priority tier.
        Models flagged ``is_critical`` or at CRITICAL priority are never evicted.
        Eviction stops once usage falls below 70% of the threshold.

        Returns
        -------
        bool
            True if cleanup was performed
        """
        if not torch.cuda.is_available():
            return False

        usage_ratio = self._gpu_usage_ratio()
        if usage_ratio < self._memory_threshold:
            return False

        logger.warning(f"Memory pressure detected: {usage_ratio:.1%} used")

        # Sort by priority (ascending) then by last_used (ascending) = LRU within a tier.
        evictable = sorted(
            (info for info in self._models.values()
             if self._is_evictable(info, ModelPriority.CRITICAL)),
            key=lambda info: (info.priority, info.last_used),
        )

        target_ratio = self._memory_threshold * 0.7  # Target 70% of threshold
        cleaned = False
        for info in evictable:
            self.unload_model(info.name)
            cleaned = True
            if self._gpu_usage_ratio() < target_ratio:
                break

        return cleaned

    def force_cleanup(self, keep_critical_only: bool = True):
        """
        Force cleanup models and clear caches.

        Models flagged ``is_critical`` are always kept.

        Parameters
        ----------
        keep_critical_only : bool
            If True, only keep CRITICAL priority models loaded; otherwise
            also keep models at HIGH priority or above.
        """
        logger.info("Force cleanup initiated")

        # Unload models based on priority
        threshold = ModelPriority.CRITICAL if keep_critical_only else ModelPriority.HIGH
        to_unload = [
            name for name, info in self._models.items()
            if self._is_evictable(info, threshold)
        ]
        for name in to_unload:
            self.unload_model(name)

        # Aggressive garbage collection (repeated passes catch objects freed
        # only after earlier collections break reference cycles)
        for _ in range(5):
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.synchronize()

        logger.info("Force cleanup completed")

    def update_priority(self, name: str, priority: int) -> bool:
        """
        Update a model's priority level.

        Parameters
        ----------
        name : str
            Model identifier
        priority : int
            New priority level

        Returns
        -------
        bool
            True if priority was updated
        """
        if name not in self._models:
            return False

        self._models[name].priority = priority
        logger.debug(f"Updated priority for {name} to {priority}")
        return True

    def get_active_pipeline(self) -> Optional[str]:
        """
        Get the name of currently active pipeline.

        Returns
        -------
        str or None
            Name of active pipeline, or None if no pipeline is loaded
        """
        return self._active_pipeline

    def switch_to_pipeline(
        self,
        name: str,
        loader: Optional[Callable[[], Any]] = None
    ) -> Any:
        """
        Switch to a different pipeline, unloading current one.

        This is a convenience method for pipeline switching that handles
        mutual exclusion automatically.

        Parameters
        ----------
        name : str
            Pipeline name to switch to
        loader : callable, optional
            Loader function if pipeline not already registered

        Returns
        -------
        Any
            The loaded pipeline instance

        Raises
        ------
        KeyError
            If pipeline not registered and no loader provided
        """
        # Register if needed
        if name not in self._models and loader is not None:
            self.register_model(
                name=name,
                loader=loader,
                priority=ModelPriority.HIGH,
                model_group=self.PIPELINE_GROUP
            )

        # Load will handle unloading of current pipeline
        return self.load_model(name, update_priority=ModelPriority.HIGH)

    def get_memory_status(self) -> Dict[str, Any]:
        """
        Get detailed memory status.

        Returns:
            Dictionary with the device, per-model status (loaded, critical,
            priority, estimated size, last use), the summed estimate of loaded
            models, and — when CUDA is available — a "gpu" section for the
            current CUDA device.
        """
        status = {
            "device": self._device,
            "models": {},
            "total_estimated_gb": 0.0
        }

        # Model status
        for name, info in self._models.items():
            status["models"][name] = {
                "loaded": info.is_loaded,
                "critical": info.is_critical,
                "priority": int(info.priority),
                "estimated_gb": info.estimated_memory_gb,
                "last_used": info.last_used
            }
            if info.is_loaded:
                status["total_estimated_gb"] += info.estimated_memory_gb

        # GPU memory (current device, matching _gpu_usage_ratio)
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            allocated = torch.cuda.memory_allocated(device) / 1024**3
            total = torch.cuda.get_device_properties(device).total_memory / 1024**3
            cached = torch.cuda.memory_reserved(device) / 1024**3

            status["gpu"] = {
                "allocated_gb": round(allocated, 2),
                "total_gb": round(total, 2),
                "cached_gb": round(cached, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }

        return status

    def get_loaded_models(self) -> List[str]:
        """Get list of currently loaded model names."""
        return [name for name, info in self._models.items() if info.is_loaded]

    def is_model_loaded(self, name: str) -> bool:
        """Check if a specific model is loaded."""
        if name not in self._models:
            return False
        return self._models[name].is_loaded


# Module-level cache for the shared manager; filled in lazily on first access.
_model_manager = None


def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, creating it on the first call."""
    global _model_manager
    if _model_manager is not None:
        return _model_manager
    _model_manager = ModelManager()
    return _model_manager