# ImageGen-Z-Image — core/model_manager.py
# Author: RioShiina
# Uploaded via huggingface_hub (commit a80b248, verified)
import gc
from typing import Dict, List, Any, Set
import torch
import gradio as gr
from comfy import model_management
from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR
from comfy_integration.nodes import LoraLoader
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded
class ModelManager:
    """Process-wide singleton that downloads, loads, caches, and unloads the
    UNET/CLIP/VAE model combos (optionally patched with LoRAs) used for
    generation.

    Models are first materialized on CPU/RAM by :meth:`_load_model_combo`;
    :meth:`move_models_to_gpu` then hands them to ComfyUI's
    ``model_management`` for VRAM placement.
    """

    # Class-level slot backing the singleton instance.
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Classic singleton: create once, then always return the same object.
        # NOTE: object.__new__ must NOT receive the extra args — passing them
        # raises TypeError when __new__ is overridden.
        if not cls._instance:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call even though __new__
        # returns the cached instance; this guard makes it run only once.
        if hasattr(self, 'initialized'):
            return
        # display name -> {"unet": (model,), "clip": (clip,), "vae": (vae, ...)}
        self.loaded_models: Dict[str, Any] = {}
        # LoRA configuration in effect for the currently loaded models;
        # compared on the next load to decide whether bases must be reloaded.
        self.last_active_loras: List[Dict[str, Any]] = []
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        """Return the display names of all currently loaded model combos."""
        return set(self.loaded_models.keys())

    def _load_model_combo(self, display_name: str, active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load one UNET/CLIP/VAE combo to CPU/RAM and merge LoRAs into it.

        Args:
            display_name: Key into ALL_MODEL_MAP identifying the combo.
            active_loras: LoRA specs ({"lora_name", "strength_model",
                "strength_clip"}) applied in order on top of the base model.
            progress: Accepted for interface symmetry; not used here.

        Returns:
            Dict with "unet", "clip", "vae" entries, each a 1-tuple wrapping
            the loaded object (ComfyUI node output convention).

        Raises:
            ValueError: If the model is unknown or a component is missing.
        """
        print(f"--- [ModelManager] Loading model combo: '{display_name}' ---")
        if display_name not in ALL_MODEL_MAP:
            raise ValueError(f"Model '{display_name}' not found in configuration.")
        _, components, _, _ = ALL_MODEL_MAP[display_name]
        unet_filename = components.get('unet')
        clip_filename = components.get('clip')
        vae_filename = components.get('vae')
        if not all([unet_filename, clip_filename, vae_filename]):
            raise ValueError(f"Model '{display_name}' is missing required components (unet, clip, or vae) in model_list.yaml.")
        unet_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        clip_loader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
        vae_loader = NODE_CLASS_MAPPINGS["VAELoader"]()
        print(" - Loading UNET...")
        unet_tuple = unet_loader.load_unet(unet_name=unet_filename, weight_dtype="default")
        print(" - Loading CLIP...")
        clip_tuple = clip_loader.load_clip(clip_name=clip_filename, type="lumina2", device="default")
        print(" - Loading VAE...")
        vae_tuple = vae_loader.load_vae(vae_name=vae_filename)
        unet_object = get_value_at_index(unet_tuple, 0)
        clip_object = get_value_at_index(clip_tuple, 0)
        if active_loras:
            print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
            lora_loader = LoraLoader()
            # Each LoRA is patched on top of the previous result, so order
            # in active_loras matters.
            patched_unet, patched_clip = unet_object, clip_object
            for lora_info in active_loras:
                patched_unet, patched_clip = lora_loader.load_lora(
                    model=patched_unet,
                    clip=patched_clip,
                    lora_name=lora_info["lora_name"],
                    strength_model=lora_info["strength_model"],
                    strength_clip=lora_info["strength_clip"]
                )
            unet_object = patched_unet
            clip_object = patched_clip
            print(f"--- [ModelManager] ✅ All LoRAs merged into the model on CPU. ---")
        loaded_combo = {
            "unet": (unet_object,),
            "clip": (clip_object,),
            "vae": vae_tuple,
        }
        print(f"--- [ModelManager] ✅ Successfully loaded combo '{display_name}' to CPU/RAM ---")
        return loaded_combo

    def move_models_to_gpu(self, required_models: List[str]):
        """Ask ComfyUI's model_management to place loaded UNETs on the GPU.

        Silently skips names that are not currently loaded.
        """
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_combo = self.loaded_models[name]
                # Only the UNET is a managed patcher; CLIP/VAE placement is
                # handled internally by ComfyUI at encode/decode time.
                models_to_load_gpu.append(get_value_at_index(model_combo.get("unet"), 0))
        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No component models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Download any missing component files for *required_models*.

        Args:
            required_models: Display names (keys of ALL_MODEL_MAP); unknown
                names are skipped.
            progress: Optional Gradio-style progress callable.

        Raises:
            gr.Error: If any component file fails to download.
        """
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
        # A set de-duplicates component files shared between model combos.
        files_to_download = set()
        for display_name in required_models:
            if display_name in ALL_MODEL_MAP:
                _, components, _, _ = ALL_MODEL_MAP[display_name]
                files_to_download.update(components.values())
        files_to_download = list(files_to_download)
        total_files = len(files_to_download)
        for i, filename in enumerate(files_to_download):
            if callable(progress):
                progress(i / total_files, desc=f"Checking file: {filename}")
            try:
                _ensure_model_downloaded(filename, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model component '{filename}'. Reason: {e}") from e
        print(f"--- [ModelManager] ✅ All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Ensure exactly *required_models* (with *active_loras* applied) are
        loaded, unloading/reloading as needed, and return them.

        Any change in the LoRA configuration forces a reload of the base
        models, because LoRA weights are merged in at load time.

        Args:
            required_models: Display names that must be resident after the call.
            active_loras: LoRA specs to merge into each loaded base model.
            progress: Optional Gradio-style progress callable.

        Returns:
            Mapping of display name -> loaded combo dict, in required order.

        Raises:
            gr.Error: If loading a combo or applying a LoRA fails.
        """
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())
        loras_changed = active_loras != self.last_active_loras
        models_to_unload = current_set - required_set
        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                # Nothing obsolete, but the LoRA mix changed: the already-loaded
                # bases carry stale patches and must be rebuilt.
                models_to_unload = current_set.intersection(required_set)
                if active_loras:
                    print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
                else:
                    print(f"--- [ModelManager] LoRAs removed. Reloading base model(s) to clear patches: {models_to_unload} ---")
            # Drop everything (not just the stale entries) to maximize free RAM
            # before the new load.
            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")
        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set
        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(models_to_load):
                if callable(progress):
                    progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_model_combo(display_name, active_loras, progress)
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model combo or apply LoRA for '{display_name}'. Reason: {e}") from e
        else:
            print(f"--- [ModelManager] All required models are already loaded. ---")
        # Shallow copy so a caller mutating its list later cannot desync the
        # loras_changed comparison on the next call.
        self.last_active_loras = list(active_loras)
        return {name: self.loaded_models[name] for name in required_models}
# Shared module-level singleton; importers should use this instance rather
# than constructing their own (ModelManager.__new__ enforces a single
# instance in any case).
model_manager = ModelManager()