import gc
from typing import Dict, List, Any, Set

import torch
import gradio as gr

from comfy import model_management
from core.settings import (
    ALL_MODEL_MAP,
    CHECKPOINT_DIR,
    LORA_DIR,
    DIFFUSION_MODELS_DIR,
    VAE_DIR,
    TEXT_ENCODERS_DIR,
)
from comfy_integration.nodes import LoraLoader
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded


class ModelManager:
    """Singleton that loads UNET/CLIP/VAE combos to CPU/RAM, merges LoRAs,
    and moves the results to the GPU on demand."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # object.__new__() rejects extra arguments when __new__ is
            # overridden, so don't forward *args/**kwargs here.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-initialization on repeated instantiation.
        if hasattr(self, 'initialized'):
            return
        self.loaded_models: Dict[str, Any] = {}
        self.last_active_loras: List[Dict[str, Any]] = []
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        return set(self.loaded_models.keys())

    def _load_model_combo(self, display_name: str, active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load a UNET/CLIP/VAE combo to CPU/RAM and merge any active LoRAs into it."""
        print(f"--- [ModelManager] Loading model combo: '{display_name}' ---")
        if display_name not in ALL_MODEL_MAP:
            raise ValueError(f"Model '{display_name}' not found in configuration.")

        _, components, _, _ = ALL_MODEL_MAP[display_name]
        unet_filename = components.get('unet')
        clip_filename = components.get('clip')
        vae_filename = components.get('vae')
        if not all([unet_filename, clip_filename, vae_filename]):
            raise ValueError(
                f"Model '{display_name}' is missing required components "
                f"(unet, clip, or vae) in model_list.yaml."
            )

        unet_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        clip_loader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
        vae_loader = NODE_CLASS_MAPPINGS["VAELoader"]()

        print(" - Loading UNET...")
        unet_tuple = unet_loader.load_unet(unet_name=unet_filename, weight_dtype="default")
        print(" - Loading CLIP...")
        clip_tuple = clip_loader.load_clip(clip_name=clip_filename, type="lumina2", device="default")
        print(" - Loading VAE...")
        vae_tuple = vae_loader.load_vae(vae_name=vae_filename)

        unet_object = get_value_at_index(unet_tuple, 0)
        clip_object = get_value_at_index(clip_tuple, 0)

        if active_loras:
            # Chain each LoRA patch onto the previous result while still on CPU.
            print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
            lora_loader = LoraLoader()
            patched_unet, patched_clip = unet_object, clip_object
            for lora_info in active_loras:
                patched_unet, patched_clip = lora_loader.load_lora(
                    model=patched_unet,
                    clip=patched_clip,
                    lora_name=lora_info["lora_name"],
                    strength_model=lora_info["strength_model"],
                    strength_clip=lora_info["strength_clip"],
                )
            unet_object = patched_unet
            clip_object = patched_clip
            print("--- [ModelManager] ✅ All LoRAs merged into the model on CPU. ---")

        loaded_combo = {
            "unet": (unet_object,),
            "clip": (clip_object,),
            "vae": vae_tuple,
        }
        print(f"--- [ModelManager] ✅ Successfully loaded combo '{display_name}' to CPU/RAM ---")
        return loaded_combo

    def move_models_to_gpu(self, required_models: List[str]):
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_combo = self.loaded_models[name]
                models_to_load_gpu.append(get_value_at_index(model_combo.get("unet"), 0))
        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No component models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Resolve each display name to its component files and download any that are missing."""
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
        files_to_download = set()
        for display_name in required_models:
            if display_name in ALL_MODEL_MAP:
                _, components, _, _ = ALL_MODEL_MAP[display_name]
                for component_file in components.values():
                    files_to_download.add(component_file)

        files_to_download = list(files_to_download)
        total_files = len(files_to_download)
        for i, filename in enumerate(files_to_download):
            if progress and callable(progress):
                progress(i / total_files, desc=f"Checking file: {filename}")
            try:
                _ensure_model_downloaded(filename, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model component '{filename}'. Reason: {e}")
        print("--- [ModelManager] ✅ All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Ensure exactly the required combos are resident, reloading from scratch
        whenever the LoRA configuration changes (LoRAs are merged at load time)."""
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())
        loras_changed = active_loras != self.last_active_loras

        models_to_unload = current_set - required_set
        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                # LoRAs are baked into the loaded weights, so a LoRA change
                # forces a reload of the base models that are staying.
                models_to_unload = current_set.intersection(required_set)
                if active_loras:
                    print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
                else:
                    print(f"--- [ModelManager] LoRAs removed. Reloading base model(s) to clear patches: {models_to_unload} ---")
            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")

        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set
        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(models_to_load):
                progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_model_combo(display_name, active_loras, progress)
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model combo or apply LoRA for '{display_name}'. Reason: {e}")
        else:
            print("--- [ModelManager] All required models are already loaded. ---")

        self.last_active_loras = active_loras
        return {name: self.loaded_models[name] for name in required_models}


model_manager = ModelManager()