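"""Singleton ModelManager for a ZeroGPU Gradio Space.

Loads UNET/CLIP/VAE model combos into CPU RAM via ComfyUI loader nodes,
merges LoRA patches on the CPU, and moves the resulting models to the GPU
on demand. Download, load, and unload decisions are centralized here.
"""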
import gc
from typing import Any, Dict, List, Set

import torch
import gradio as gr

from comfy import model_management
from comfy_integration.nodes import LoraLoader
from core.settings import (
    ALL_MODEL_MAP,
    CHECKPOINT_DIR,
    DIFFUSION_MODELS_DIR,
    LORA_DIR,
    TEXT_ENCODERS_DIR,
    VAE_DIR,
)
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import _ensure_model_downloaded, get_value_at_index
class ModelManager:
    """Process-wide singleton that owns all loaded model state."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ rejects extra arguments when __new__ is
            # overridden, so args/kwargs must not be forwarded here.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-running __init__ on the shared instance.
        if hasattr(self, "initialized"):
            return
        self.loaded_models: Dict[str, Any] = {}
        self.last_active_loras: List[Dict[str, Any]] = []
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        return set(self.loaded_models.keys())
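    # Because __new__ caches the instance and __init__ short-circuits after
    # the first run, construction is idempotent:
    #
    #     assert ModelManager() is ModelManager()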
    def _load_model_combo(self, display_name: str, active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load a UNET/CLIP/VAE combo into CPU RAM, applying any active LoRAs."""
        print(f"--- [ModelManager] Loading model combo: '{display_name}' ---")
        if display_name not in ALL_MODEL_MAP:
            raise ValueError(f"Model '{display_name}' not found in configuration.")
        _, components, _, _ = ALL_MODEL_MAP[display_name]

        unet_filename = components.get("unet")
        clip_filename = components.get("clip")
        vae_filename = components.get("vae")
        if not all([unet_filename, clip_filename, vae_filename]):
            raise ValueError(
                f"Model '{display_name}' is missing required components (unet, clip, or vae) in model_list.yaml."
            )

        # Instantiate the ComfyUI loader nodes and load each component.
        unet_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        clip_loader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
        vae_loader = NODE_CLASS_MAPPINGS["VAELoader"]()
        print(" - Loading UNET...")
        unet_tuple = unet_loader.load_unet(unet_name=unet_filename, weight_dtype="default")
        print(" - Loading CLIP...")
        clip_tuple = clip_loader.load_clip(clip_name=clip_filename, type="lumina2", device="default")
        print(" - Loading VAE...")
        vae_tuple = vae_loader.load_vae(vae_name=vae_filename)

        unet_object = get_value_at_index(unet_tuple, 0)
        clip_object = get_value_at_index(clip_tuple, 0)

        if active_loras:
            # Patch each LoRA into the UNET/CLIP weights while everything is
            # still in CPU RAM, chaining each patched result into the next.
            print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
            lora_loader = LoraLoader()
            patched_unet, patched_clip = unet_object, clip_object
            for lora_info in active_loras:
                patched_unet, patched_clip = lora_loader.load_lora(
                    model=patched_unet,
                    clip=patched_clip,
                    lora_name=lora_info["lora_name"],
                    strength_model=lora_info["strength_model"],
                    strength_clip=lora_info["strength_clip"],
                )
            unet_object = patched_unet
            clip_object = patched_clip
            print("--- [ModelManager] ✅ All LoRAs merged into the model on CPU. ---")

        loaded_combo = {
            "unet": (unet_object,),
            "clip": (clip_object,),
            "vae": vae_tuple,
        }
        print(f"--- [ModelManager] ✅ Successfully loaded combo '{display_name}' to CPU/RAM ---")
        return loaded_combo
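    # Shape of the entries _load_model_combo expects in `active_loras`,
    # inferred from the keys read in the loop above (the filename is
    # illustrative, not a real asset):
    #
    #     {"lora_name": "example_style.safetensors",
    #      "strength_model": 0.8, "strength_clip": 0.8}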
    def move_models_to_gpu(self, required_models: List[str]):
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_combo = self.loaded_models[name]
                # Only the UNET is explicitly handed to comfy's memory
                # manager; CLIP and VAE placement is left to comfy.
                models_to_load_gpu.append(get_value_at_index(model_combo.get("unet"), 0))
        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No component models found to move to GPU. ---")
    def ensure_models_downloaded(self, required_models: List[str], progress):
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
        # Collect the unique component files across all requested models.
        files_to_download = set()
        for display_name in required_models:
            if display_name in ALL_MODEL_MAP:
                _, components, _, _ = ALL_MODEL_MAP[display_name]
                for component_file in components.values():
                    files_to_download.add(component_file)
        files_to_download = sorted(files_to_download)
        total_files = len(files_to_download)
        for i, filename in enumerate(files_to_download):
            if callable(progress):
                progress(i / total_files, desc=f"Checking file: {filename}")
            try:
                _ensure_model_downloaded(filename, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model component '{filename}'. Reason: {e}")
        print("--- [ModelManager] ✅ All required models are present on disk. ---")
    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())
        loras_changed = active_loras != self.last_active_loras
        models_to_unload = current_set - required_set

        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                # LoRA patches are merged into the base weights at load time,
                # so any LoRA change forces a reload of the affected models.
                models_to_unload = current_set.intersection(required_set)
                if active_loras:
                    print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
                else:
                    print(f"--- [ModelManager] LoRAs removed. Reloading base model(s) to clear patches: {models_to_unload} ---")
            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")

        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set
        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(models_to_load):
                progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_model_combo(display_name, active_loras, progress)
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model combo or apply LoRA for '{display_name}'. Reason: {e}")
        else:
            print("--- [ModelManager] All required models are already loaded. ---")

        self.last_active_loras = list(active_loras)  # defensive copy; the caller may reuse its list
        return {name: self.loaded_models[name] for name in required_models}
# Module-level singleton instance shared by the rest of the app.
model_manager = ModelManager()
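# A minimal sketch of the intended call order from a Gradio event handler.
# The handler name, model display name, and LoRA list are illustrative; the
# progress object follows the gr.Progress convention used above.
#
#     def generate(prompt, progress=gr.Progress()):
#         required = ["My Lumina Model"]  # a key from ALL_MODEL_MAP
#         loras = []                      # or LoRA spec dicts as shown earlier
#         model_manager.ensure_models_downloaded(required, progress)
#         combos = model_manager.load_managed_models(required, loras, progress)
#         model_manager.move_models_to_gpu(required)
#         unet = get_value_at_index(combos["My Lumina Model"]["unet"], 0)
#         ...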