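"""Singleton ModelManager for a ComfyUI-backed Gradio app.

Loads UNET/CLIP/VAE model combos declared in model_list.yaml (exposed via
ALL_MODEL_MAP) into CPU RAM, merges base and user-selected LoRAs on CPU,
and moves the patched models to the GPU on demand.
"""
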
import gc
from typing import Any, Dict, List, Set

import torch
import gradio as gr

from comfy import model_management
from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR
from comfy_integration.nodes import LoraLoader
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded


class ModelManager:
    """Process-wide singleton that caches loaded model combos in CPU RAM."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # object.__new__() takes no extra arguments in Python 3.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-running __init__ on the shared instance.
        if hasattr(self, 'initialized'):
            return
        self.loaded_models: Dict[str, Any] = {}
        self.last_active_loras: List[Dict[str, Any]] = []
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        return set(self.loaded_models.keys())

    def _load_model_combo(self, display_name: str, active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load a UNET/CLIP/VAE combo to CPU RAM and merge base + custom LoRAs."""
        print(f"--- [ModelManager] Loading model combo: '{display_name}' ---")
        if display_name not in ALL_MODEL_MAP:
            raise ValueError(f"Model '{display_name}' not found in configuration.")
        with torch.no_grad():
            _, components, _, _ = ALL_MODEL_MAP[display_name]
            unet_filename = components.get('unet')
            clip_filename = components.get('clip')
            vae_filename = components.get('vae')
            if not all([unet_filename, clip_filename, vae_filename]):
                raise ValueError(f"Model '{display_name}' is missing required components (unet, clip, or vae) in model_list.yaml.")
            unet_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
            clip_loader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
            vae_loader = NODE_CLASS_MAPPINGS["VAELoader"]()
            print(" - Loading UNET...")
            unet_tuple = unet_loader.load_unet(unet_name=unet_filename, weight_dtype="default")
            print(" - Loading CLIP...")
            clip_tuple = clip_loader.load_clip(clip_name=clip_filename, type="qwen_image", device="default")
            print(" - Loading VAE...")
            vae_tuple = vae_loader.load_vae(vae_name=vae_filename)
            unet_object = get_value_at_index(unet_tuple, 0)
            clip_object = get_value_at_index(clip_tuple, 0)
            lora_loader = LoraLoader()
            base_lora_name = components.get('lora')
            if base_lora_name:
                print(f"--- [ModelManager] Applying base model LoRA: {base_lora_name} ---")
                _ensure_model_downloaded(base_lora_name, progress)
                unet_object, clip_object = lora_loader.load_lora(
                    model=unet_object,
                    clip=clip_object,
                    lora_name=base_lora_name,
                    strength_model=1.0,
                    strength_clip=1.0
                )
                print("--- [ModelManager] ✅ Base LoRA merged into the model on CPU. ---")
            if active_loras:
                print(f"--- [ModelManager] Applying {len(active_loras)} custom LoRAs on CPU... ---")
                patched_unet, patched_clip = unet_object, clip_object
                for lora_info in active_loras:
                    patched_unet, patched_clip = lora_loader.load_lora(
                        model=patched_unet,
                        clip=patched_clip,
                        lora_name=lora_info["lora_name"],
                        strength_model=lora_info["strength_model"],
                        strength_clip=lora_info["strength_clip"]
                    )
                unet_object = patched_unet
                clip_object = patched_clip
                print("--- [ModelManager] ✅ All custom LoRAs merged into the model on CPU. ---")
            # Wrap the patched objects in 1-tuples to match ComfyUI node outputs.
            loaded_combo = {
                "unet": (unet_object,),
                "clip": (clip_object,),
                "vae": vae_tuple,
            }
            print(f"--- [ModelManager] ✅ Successfully loaded combo '{display_name}' to CPU/RAM ---")
            return loaded_combo

    def move_models_to_gpu(self, required_models: List[str]):
        """Move the UNET components of already-loaded combos onto the GPU."""
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_combo = self.loaded_models[name]
                models_to_load_gpu.append(get_value_at_index(model_combo.get("unet"), 0))
        if models_to_load_gpu:
            with torch.no_grad():
                model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No component models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Download any missing component files for the requested models."""
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
        files_to_download = set()
        for display_name in required_models:
            if display_name in ALL_MODEL_MAP:
                _, components, _, _ = ALL_MODEL_MAP[display_name]
                for component_key, component_file in components.items():
                    if component_key in ['unet', 'clip', 'vae', 'lora']:
                        files_to_download.add(component_file)
        files_to_download = list(files_to_download)
        total_files = len(files_to_download)
        for i, filename in enumerate(files_to_download):
            if callable(progress):
                progress(i / total_files, desc=f"Checking file: {filename}")
            try:
                _ensure_model_downloaded(filename, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model component '{filename}'. Reason: {e}") from e
        print("--- [ModelManager] ✅ All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Reconcile the cache with `required_models`, reloading if LoRAs changed."""
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())
        loras_changed = self.last_active_loras != active_loras
        models_to_unload = current_set - required_set
        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                # LoRAs are merged into the weights, so a LoRA change forces a
                # reload of the base models that remain in use.
                models_to_unload = current_set.intersection(required_set)
                print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")
        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set
        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(models_to_load):
                progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_model_combo(display_name, active_loras, progress)
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    self.last_active_loras = []
                    raise gr.Error(f"Failed to load model combo or apply LoRA for '{display_name}'. Reason: {e}") from e
            self.last_active_loras = active_loras
        else:
            print("--- [ModelManager] All required models are already loaded. ---")
        return {name: self.loaded_models[name] for name in required_models}


# Shared module-level instance imported by the rest of the app.
model_manager = ModelManager()
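
# Minimal usage sketch (e.g. inside a Gradio event handler). "Qwen-Image" is
# a hypothetical display name; use a real key from ALL_MODEL_MAP, and pass a
# gr.Progress() instance (or any callable accepting a `desc` kwarg) as
# `progress`:
#
#   required = ["Qwen-Image"]
#   model_manager.ensure_models_downloaded(required, progress)
#   combos = model_manager.load_managed_models(required, active_loras=[], progress=progress)
#   model_manager.move_models_to_gpu(required)
#   unet = get_value_at_index(combos["Qwen-Image"]["unet"], 0)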