"""
Model Persistence Manager for LightDiffusion
Keeps models loaded in VRAM for instant reuse between generations
"""

from typing import Dict, Optional, Any, Tuple, List
import logging
from src.Device import Device


class ModelCache:
    """Global model cache to keep models loaded in VRAM"""

    def __init__(self):
        self._cached_checkpoints: Dict[str, Tuple[Any, Any, Any]] = {}
        self._cached_taesd: Dict[Tuple[int, bool], Any] = {}
        self._cached_conditions: Dict[str, Any] = {}
        self._last_checkpoint_path: Optional[str] = None
        self._keep_models_loaded: bool = True
        self._loaded_models_list: List[Any] = []
        self._max_cached_checkpoints: int = 3
        
        # Prefetching support
        self._prefetched_state_dict: Optional[dict] = None
        self._prefetched_path: Optional[str] = None

    def cache_taesd(self, channels: int, flux: bool, model: Any) -> None:
        """Cache a TAESD model instance"""
        self._cached_taesd[(channels, flux)] = model

    def get_taesd(self, channels: int, flux: bool) -> Optional[Any]:
        """Get a cached TAESD model instance"""
        return self._cached_taesd.get((channels, flux))

    def set_prefetched_model(self, path: str, state_dict: dict) -> None:
        """Store a prefetched state dict in CPU RAM"""
        self._prefetched_path = path
        self._prefetched_state_dict = state_dict
        logging.info(f"ModelCache: Stored prefetched model: {path}")

    def get_prefetched_model(self, path: str) -> Optional[dict]:
        """Get prefetched state dict if path matches"""
        if self._prefetched_path == path:
            logging.info(f"ModelCache: Using prefetched state dict for {path}")
            return self._prefetched_state_dict
        return None

    def clear_prefetch(self) -> None:
        """Clear prefetched data from RAM"""
        self._prefetched_state_dict = None
        self._prefetched_path = None
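
    # Illustrative prefetch flow (a sketch, not an API defined in this module;
    # "read_checkpoint_in_background" is a hypothetical helper, not a real one):
    #
    #   state_dict = read_checkpoint_in_background(path)
    #   model_cache.set_prefetched_model(path, state_dict)   # park it in CPU RAM
    #   ...
    #   state_dict = model_cache.get_prefetched_model(path)  # hit only if path matches
    #   model_cache.clear_prefetch()                         # release the RAM copy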

    def set_keep_models_loaded(self, keep_loaded: bool) -> None:
        """Enable or disable keeping models loaded in VRAM"""
        self._keep_models_loaded = keep_loaded
        if not keep_loaded:
            self.clear_cache()

    def get_keep_models_loaded(self) -> bool:
        """Check if models should be kept loaded"""
        return self._keep_models_loaded

    def cache_checkpoint(
        self, checkpoint_path: str, model_patcher: Any, clip: Any, vae: Any
    ) -> None:
        """Cache a loaded checkpoint"""
        if not self._keep_models_loaded:
            return

        # Evict the oldest entry when the cache is full and this path is new
        if (
            checkpoint_path not in self._cached_checkpoints
            and len(self._cached_checkpoints) >= self._max_cached_checkpoints
        ):
            oldest_path = next(iter(self._cached_checkpoints))
            old_patcher, _, _ = self._cached_checkpoints.pop(oldest_path)
            logging.info(f"ModelCache: Evicting {oldest_path} to make room")
            try:
                if hasattr(old_patcher, "model_unload"):
                    old_patcher.model_unload()
            except Exception as e:
                logging.warning(f"ModelCache: Error unloading {oldest_path}: {e}")

        self._last_checkpoint_path = checkpoint_path
        self._cached_checkpoints[checkpoint_path] = (model_patcher, clip, vae)
        logging.info(f"Cached checkpoint: {checkpoint_path} (Total cached: {len(self._cached_checkpoints)})")

    def get_cached_checkpoint(
        self, checkpoint_path: str
    ) -> Optional[Tuple[Any, Any, Any]]:
        """Get cached checkpoint if available"""
        if not self._keep_models_loaded:
            return None

        if checkpoint_path in self._cached_checkpoints:
            logging.info(f"Using cached checkpoint: {checkpoint_path}")
            self._last_checkpoint_path = checkpoint_path
            return self._cached_checkpoints[checkpoint_path]
        return None

    def cache_sampling_models(self, models: List[Any]) -> None:
        """Cache models used during sampling"""
        if not self._keep_models_loaded:
            return

        self._loaded_models_list = models.copy()

    def get_cached_sampling_models(self) -> List[Any]:
        """Get cached sampling models"""
        if not self._keep_models_loaded:
            return []
        return self._loaded_models_list

    def prevent_model_cleanup(self, conds: Dict[str, Any], models: List[Any]) -> None:
        """Prevent models from being cleaned up if caching is enabled"""
        from src.cond import cond_util

        if not self._keep_models_loaded:
            # Original cleanup behavior: release the sampling models as well
            cond_util.cleanup_additional_models(models)

        # In both cases, control models attached to the conds are released;
        # the main models stay resident only when caching is enabled
        control_cleanup = []
        for k in conds:
            control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
        cond_util.cleanup_additional_models(set(control_cleanup))

        if self._keep_models_loaded:
            logging.info("Kept main models loaded in VRAM for reuse")

    def clear_cache(self) -> None:
        """Clear all cached models"""
        for path, (model_patcher, _, _) in self._cached_checkpoints.items():
            try:
                if hasattr(model_patcher, "model_unload"):
                    model_patcher.model_unload()
            except Exception as e:
                logging.warning(f"Error unloading cached model {path}: {e}")

        self._cached_checkpoints.clear()
        self._cached_taesd.clear()
        self._cached_conditions.clear()
        self._last_checkpoint_path = None
        self._loaded_models_list.clear()

        # Force cleanup
        Device.cleanup_models(keep_clone_weights_loaded=False)
        Device.soft_empty_cache(force=True)
        logging.info("Cleared model cache and freed VRAM")

    def get_memory_info(self) -> Dict[str, Any]:
        """Get memory usage information"""
        device = Device.get_torch_device()
        total_mem = Device.get_total_memory(device)
        free_mem = Device.get_free_memory(device)
        used_mem = total_mem - free_mem

        return {
            "total_vram": total_mem / (1024 * 1024 * 1024),  # GB
            "used_vram": used_mem / (1024 * 1024 * 1024),  # GB
            "free_vram": free_mem / (1024 * 1024 * 1024),  # GB
            "cached_models": len(self._cached_checkpoints),
            "keep_loaded": self._keep_models_loaded,
            "has_cached_checkpoint": len(self._cached_checkpoints) > 0,
        }


# Global model cache instance
model_cache = ModelCache()


def get_model_cache() -> ModelCache:
    """Get the global model cache instance"""
    return model_cache


def set_keep_models_loaded(keep_loaded: bool) -> None:
    """Global function to enable/disable model persistence"""
    model_cache.set_keep_models_loaded(keep_loaded)


def get_keep_models_loaded() -> bool:
    """Global function to check if models should be kept loaded"""
    return model_cache.get_keep_models_loaded()


def clear_model_cache() -> None:
    """Global function to clear model cache"""
    model_cache.clear_cache()


def get_memory_info() -> Dict[str, Any]:
    """Global function to get memory info"""
    return model_cache.get_memory_info()
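

# A minimal usage sketch (illustrative only): the checkpoint path below is a
# placeholder, and the model_patcher/clip/vae objects would normally come from
# the project's checkpoint loader rather than being constructed here.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    cache = get_model_cache()
    example_path = "checkpoints/example.safetensors"  # hypothetical path

    cached = cache.get_cached_checkpoint(example_path)
    if cached is None:
        # Stand-ins for the real (model_patcher, clip, vae) triple
        model_patcher, clip, vae = object(), object(), object()
        cache.cache_checkpoint(example_path, model_patcher, clip, vae)

    logging.info(f"Memory info: {get_memory_info()}")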