"""
Model caching utilities for Depth Anything 3.

Provides model caching functionality to avoid reloading model weights on every instantiation.
This significantly reduces latency for repeated model creation (2-5s gain).
"""

from __future__ import annotations

import gc
import threading
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn

from depth_anything_3.utils.logger import logger


class ModelCache:
    """
    Thread-safe singleton cache for Depth Anything 3 models.

    Caches loaded model weights to avoid reloading from disk on every instantiation.
    Each unique combination of (model_name, device) is cached separately.

    Usage:
        cache = ModelCache()
        model = cache.get(model_name, device, loader_fn)
        # loader_fn is only called on a cache miss

    Thread Safety:
        Uses threading.Lock to ensure thread-safe access to the cache.

    Memory Management:
        - Models are kept in cache until explicitly cleared
        - Use clear() to free memory when needed
        - Use clear_device() to clear specific device models
    """

    _instance: Optional["ModelCache"] = None
    _lock = threading.Lock()

    def __new__(cls):
        """Singleton pattern to ensure single cache instance."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize cache storage."""
        if self._initialized:
            return

        self._cache: Dict[Tuple[str, str], nn.Module] = {}
        self._cache_lock = threading.Lock()
        self._initialized = True
        logger.info("ModelCache initialized")

    def get(
        self,
        model_name: str,
        device: torch.device | str,
        loader_fn: Callable[[], nn.Module],
    ) -> nn.Module:
        """
        Get cached model or load if not in cache.

        Args:
            model_name: Name of the model (e.g., "da3-large")
            device: Target device (cuda, mps, cpu)
            loader_fn: Zero-argument callable that loads and returns the
                model (nn.Module); only invoked on a cache miss

        Returns:
            Cached or freshly loaded model on specified device

        Example:
            >>> cache = ModelCache()
            >>> model = cache.get(
            ...     "da3-large",
            ...     "cuda",
            ...     lambda: create_model()
            ... )
        """
        device_str = str(device)
        cache_key = (model_name, device_str)
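
        # loader_fn runs while the cache lock is held, so concurrent callers
        # block until loading completes. This trades throughput for the
        # guarantee that each (model_name, device) pair is loaded at most once.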

        with self._cache_lock:
            if cache_key in self._cache:
                logger.debug(f"Model cache HIT: {model_name} on {device_str}")
                return self._cache[cache_key]

            logger.info(f"Model cache MISS: {model_name} on {device_str}. Loading...")
            model = loader_fn()
            self._cache[cache_key] = model
            logger.info(f"Model cached: {model_name} on {device_str}")

            return model

    def clear(self) -> None:
        """
        Clear the entire cache and free memory.

        Removes all cached models and forces garbage collection.
        Useful when switching between many different models.
        """
        with self._cache_lock:
            num_cached = len(self._cache)
            self._cache.clear()

            # Collect dropped references, then release cached device memory
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if hasattr(torch, "mps") and torch.backends.mps.is_available():
                torch.mps.empty_cache()

            logger.info(f"Model cache cleared ({num_cached} models removed)")

    def clear_device(self, device: torch.device | str) -> None:
        """
        Clear all models on specific device.

        Args:
            device: Device to clear (e.g., "cuda", "mps", "cpu")

        Example:
            >>> cache = ModelCache()
            >>> cache.clear_device("cuda")  # Clear all CUDA models
        """
        device_str = str(device)

        with self._cache_lock:
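            # Matching is exact on the device string: clearing "cuda" does not
            # remove models cached under "cuda:0". Pass the same device string
            # that was used when the model was cached.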
            keys_to_remove = [key for key in self._cache if key[1] == device_str]
            for key in keys_to_remove:
                del self._cache[key]

            # Free device memory
            if "cuda" in device_str and torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif "mps" in device_str and hasattr(torch, "mps") and torch.backends.mps.is_available():
                torch.mps.empty_cache()

            logger.info(f"Model cache cleared for device {device_str} ({len(keys_to_remove)} models removed)")

    def get_cache_info(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache info:
                - "total" (int): total number of cached models
                - "by_device" (Dict[str, int]): number of cached models per device
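
        Example:
            >>> cache = ModelCache()
            >>> cache.get_cache_info()  # e.g. {"total": 1, "by_device": {"cuda": 1}}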
        """
        with self._cache_lock:
            info: Dict[str, Any] = {
                "total": len(self._cache),
                "by_device": {},
            }

            for _model_name, device_str in self._cache:
                info["by_device"][device_str] = info["by_device"].get(device_str, 0) + 1

            return info


# Global singleton instance
_global_cache = ModelCache()


def get_model_cache() -> ModelCache:
    """
    Get global model cache instance.

    Returns:
        Singleton ModelCache instance

    Example:
        >>> from depth_anything_3.cache import get_model_cache
        >>> cache = get_model_cache()
        >>> cache.clear()
    """
    return _global_cache
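

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): demonstrates the hit/miss behavior of the
# cache with a tiny stand-in loader. In real use loader_fn would construct a
# Depth Anything 3 model; nn.Linear below is just a hypothetical placeholder
# so the sketch runs without downloading any weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cache = get_model_cache()

    def _toy_loader() -> nn.Module:
        # Stand-in for the real (expensive) model construction.
        return nn.Linear(4, 4)

    first = cache.get("toy-model", "cpu", _toy_loader)   # MISS: loader runs
    second = cache.get("toy-model", "cpu", _toy_loader)  # HIT: same object back
    assert first is second

    print(cache.get_cache_info())  # {'total': 1, 'by_device': {'cpu': 1}}
    cache.clear()                  # free everything again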