"""
Model Persistence Manager for LightDiffusion
Keeps models loaded in VRAM for instant reuse between generations
"""
from typing import Dict, Optional, Any, Tuple, List
import logging

from src.Device import Device


class ModelCache:
"""Global model cache to keep models loaded in VRAM"""
def __init__(self):
self._cached_checkpoints: Dict[str, Tuple[Any, Any, Any]] = {}
self._cached_taesd: Dict[Tuple[int, bool], Any] = {}
self._cached_conditions: Dict[str, Any] = {}
self._last_checkpoint_path: Optional[str] = None
self._keep_models_loaded: bool = True
self._loaded_models_list: List[Any] = []
self._max_cached_checkpoints: int = 3
# Prefetching support
self._prefetched_state_dict: Optional[dict] = None
self._prefetched_path: Optional[str] = None
def cache_taesd(self, channels: int, flux: bool, model: Any) -> None:
"""Cache a TAESD model instance"""
self._cached_taesd[(channels, flux)] = model
def get_taesd(self, channels: int, flux: bool) -> Optional[Any]:
"""Get a cached TAESD model instance"""
return self._cached_taesd.get((channels, flux))
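    # Hedged usage sketch (illustrative comments only; nothing runs at import
    # time): a decoder loader might reuse a cached TAESD instance keyed by
    # latent channel count and Flux flag. "latent_channels" and "build_taesd"
    # below are hypothetical names, not part of this module.
    #
    #   taesd = model_cache.get_taesd(latent_channels, flux=True)
    #   if taesd is None:
    #       taesd = build_taesd(latent_channels, flux=True)
    #       model_cache.cache_taesd(latent_channels, True, taesd)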
    def set_prefetched_model(self, path: str, state_dict: dict) -> None:
        """Store a prefetched state dict in CPU RAM"""
        self._prefetched_path = path
        self._prefetched_state_dict = state_dict
        logging.info(f"ModelCache: Stored prefetched model: {path}")

    def get_prefetched_model(self, path: str) -> Optional[dict]:
        """Get prefetched state dict if path matches"""
        if self._prefetched_path == path:
            logging.info(f"ModelCache: Using prefetched state dict for {path}")
            return self._prefetched_state_dict
        return None

    def clear_prefetch(self) -> None:
        """Clear prefetched data from RAM"""
        self._prefetched_state_dict = None
        self._prefetched_path = None
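    # Hedged sketch of the prefetch flow (comments only): a background thread
    # could read the next checkpoint's state dict into CPU RAM, and the loader
    # consumes it when that path is actually requested.
    # "load_state_dict_from_disk" is a hypothetical helper, not part of this
    # module.
    #
    #   model_cache.set_prefetched_model(next_path, load_state_dict_from_disk(next_path))
    #   ...
    #   state_dict = model_cache.get_prefetched_model(requested_path)
    #   if state_dict is None:
    #       state_dict = load_state_dict_from_disk(requested_path)
    #   model_cache.clear_prefetch()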
    def set_keep_models_loaded(self, keep_loaded: bool) -> None:
        """Enable or disable keeping models loaded in VRAM"""
        self._keep_models_loaded = keep_loaded
        if not keep_loaded:
            self.clear_cache()

    def get_keep_models_loaded(self) -> bool:
        """Check if models should be kept loaded"""
        return self._keep_models_loaded
    def cache_checkpoint(
        self, checkpoint_path: str, model_patcher: Any, clip: Any, vae: Any
    ) -> None:
        """Cache a loaded checkpoint"""
        if not self._keep_models_loaded:
            return

        # Limit cache size: evict the oldest entry (dicts preserve insertion
        # order) before adding a checkpoint that is not already cached.
        if (
            len(self._cached_checkpoints) >= self._max_cached_checkpoints
            and checkpoint_path not in self._cached_checkpoints
        ):
            oldest_path = next(iter(self._cached_checkpoints))
            old_patcher, _, _ = self._cached_checkpoints.pop(oldest_path)
            logging.info(f"ModelCache: Evicting {oldest_path} to make room")
            try:
                if hasattr(old_patcher, "model_unload"):
                    old_patcher.model_unload()
            except Exception:
                pass

        self._last_checkpoint_path = checkpoint_path
        self._cached_checkpoints[checkpoint_path] = (model_patcher, clip, vae)
        logging.info(
            f"Cached checkpoint: {checkpoint_path} "
            f"(Total cached: {len(self._cached_checkpoints)})"
        )
    def get_cached_checkpoint(
        self, checkpoint_path: str
    ) -> Optional[Tuple[Any, Any, Any]]:
        """Get cached checkpoint if available"""
        if not self._keep_models_loaded:
            return None

        if checkpoint_path in self._cached_checkpoints:
            logging.info(f"Using cached checkpoint: {checkpoint_path}")
            self._last_checkpoint_path = checkpoint_path
            return self._cached_checkpoints[checkpoint_path]
        return None
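    # Hedged sketch of the checkpoint path (illustrative comments only): a
    # loader first asks the cache and only falls back to a full load when the
    # checkpoint is not resident. "load_checkpoint_from_disk" is a hypothetical
    # helper standing in for the real loader.
    #
    #   cached = model_cache.get_cached_checkpoint(ckpt_path)
    #   if cached is not None:
    #       model_patcher, clip, vae = cached
    #   else:
    #       model_patcher, clip, vae = load_checkpoint_from_disk(ckpt_path)
    #       model_cache.cache_checkpoint(ckpt_path, model_patcher, clip, vae)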
    def cache_sampling_models(self, models: List[Any]) -> None:
        """Cache models used during sampling"""
        if not self._keep_models_loaded:
            return
        self._loaded_models_list = models.copy()

    def get_cached_sampling_models(self) -> List[Any]:
        """Get cached sampling models"""
        if not self._keep_models_loaded:
            return []
        return self._loaded_models_list
    def prevent_model_cleanup(self, conds: Dict[str, Any], models: List[Any]) -> None:
        """Prevent models from being cleaned up if caching is enabled"""
        from src.cond import cond_util

        if not self._keep_models_loaded:
            # Original cleanup behavior: unload additional models and any
            # control models referenced by the conditioning
            cond_util.cleanup_additional_models(models)
            control_cleanup = []
            for k in conds:
                control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
            cond_util.cleanup_additional_models(set(control_cleanup))
        else:
            # Keep main models loaded - only clean up control models
            control_cleanup = []
            for k in conds:
                control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
            cond_util.cleanup_additional_models(set(control_cleanup))
            logging.info("Kept main models loaded in VRAM for reuse")
    def clear_cache(self) -> None:
        """Clear all cached models"""
        for path, (model_patcher, _, _) in self._cached_checkpoints.items():
            try:
                if hasattr(model_patcher, "model_unload"):
                    model_patcher.model_unload()
            except Exception as e:
                logging.warning(f"Error unloading cached model {path}: {e}")

        self._cached_checkpoints.clear()
        self._cached_taesd.clear()
        self._cached_conditions.clear()
        self._last_checkpoint_path = None
        self._loaded_models_list.clear()

        # Force cleanup
        Device.cleanup_models(keep_clone_weights_loaded=False)
        Device.soft_empty_cache(force=True)
        logging.info("Cleared model cache and freed VRAM")

    def get_memory_info(self) -> Dict[str, Any]:
        """Get memory usage information"""
        device = Device.get_torch_device()
        total_mem = Device.get_total_memory(device)
        free_mem = Device.get_free_memory(device)
        used_mem = total_mem - free_mem

        return {
            "total_vram": total_mem / (1024 * 1024 * 1024),  # GB
            "used_vram": used_mem / (1024 * 1024 * 1024),  # GB
            "free_vram": free_mem / (1024 * 1024 * 1024),  # GB
            "cached_models": len(self._cached_checkpoints),
            "keep_loaded": self._keep_models_loaded,
            "has_cached_checkpoint": len(self._cached_checkpoints) > 0,
        }
# Global model cache instance
model_cache = ModelCache()


def get_model_cache() -> ModelCache:
    """Get the global model cache instance"""
    return model_cache


def set_keep_models_loaded(keep_loaded: bool) -> None:
    """Global function to enable/disable model persistence"""
    model_cache.set_keep_models_loaded(keep_loaded)


def get_keep_models_loaded() -> bool:
    """Global function to check if models should be kept loaded"""
    return model_cache.get_keep_models_loaded()


def clear_model_cache() -> None:
    """Global function to clear model cache"""
    model_cache.clear_cache()


def get_memory_info() -> Dict[str, Any]:
    """Global function to get memory info"""
    return model_cache.get_memory_info()
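if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming src.Device is importable in this
    # environment and can see a torch device. It only exercises the
    # bookkeeping helpers; no checkpoint is actually loaded.
    logging.basicConfig(level=logging.INFO)

    set_keep_models_loaded(True)
    print("Keep models loaded:", get_keep_models_loaded())

    info = get_memory_info()
    print(
        f"VRAM used/total: {info['used_vram']:.2f}/{info['total_vram']:.2f} GB, "
        f"cached checkpoints: {info['cached_models']}"
    )

    # Clearing an empty cache is safe; it also asks Device to free VRAM.
    clear_model_cache()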