File size: 7,538 Bytes
41978ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gc
from typing import Dict, List, Any, Set

import torch
import gradio as gr
from comfy import model_management

from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR
from core.lora_utils import apply_newbie_lora_to_model
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded


class ModelManager:
    """Process-wide singleton that loads, caches, and unloads model combos.

    Keeps each loaded checkpoint combo (unet/clip/vae tuples) cached in
    CPU RAM, merges NewBie LoRAs into the UNET at load time, and moves
    components to the GPU on demand via comfy's model_management.
    """

    # Singleton storage; populated on first instantiation.
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # BUGFIX: do not forward *args/**kwargs to object.__new__ —
            # object.__new__ raises TypeError when given extra arguments.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call even though __new__
        # returns the same instance, so guard against re-initialization.
        if hasattr(self, 'initialized'):
            return
        # display_name -> {"unet": (model,), "clip": (clip,), "vae": vae_tuple}
        self.loaded_models: Dict[str, Any] = {}
        # LoRA config in effect for the currently loaded models; compared on
        # each load_managed_models() call to decide whether a reload is needed
        # (LoRAs are merged into weights, so a change forces a full reload).
        self.last_lora_config: List[Dict[str, Any]] = []
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        """Return the display names of all currently loaded model combos."""
        return set(self.loaded_models.keys())

    def _load_model_combo(self, display_name: str, active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load one unet/clip/vae combo to CPU/RAM, applying LoRAs if given.

        Args:
            display_name: Key into ALL_MODEL_MAP identifying the combo.
            active_loras: Dicts with at least 'lora_name' and 'strength_model'.
            progress: Progress callback (currently unused here; passed through
                by callers for interface symmetry).

        Returns:
            Dict with "unet", "clip", "vae" entries, each a tuple as expected
            by downstream ComfyUI nodes.

        Raises:
            ValueError: If the model or any required component is missing
                from the configuration.
        """
        print(f"--- [ModelManager] Loading model combo: '{display_name}' ---")

        if display_name not in ALL_MODEL_MAP:
            raise ValueError(f"Model '{display_name}' not found in configuration.")

        _, components, _, _ = ALL_MODEL_MAP[display_name]

        unet_filename = components.get('unet')
        clip1_filename = components.get('clip1')
        clip2_filename = components.get('clip2')
        vae_filename = components.get('vae')

        # All four components are mandatory for a usable combo.
        if not all([unet_filename, clip1_filename, clip2_filename, vae_filename]):
            raise ValueError(f"Model '{display_name}' is missing required components (unet, clip1, clip2, or vae) in model_list.yaml.")

        unet_loader = NODE_CLASS_MAPPINGS["UNETLoader"]()
        clip_loader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        vae_loader = NODE_CLASS_MAPPINGS["VAELoader"]()

        print("  - Loading UNET...")
        unet_tuple = unet_loader.load_unet(unet_name=unet_filename, weight_dtype="default")

        print("  - Loading CLIPs...")
        clip_tuple = clip_loader.load_clip(
            clip_name1=clip1_filename,
            clip_name2=clip2_filename,
            type="newbie",
            device="default"
        )

        print("  - Loading VAE...")
        vae_tuple = vae_loader.load_vae(vae_name=vae_filename)

        unet_object = get_value_at_index(unet_tuple, 0)
        clip_object = get_value_at_index(clip_tuple, 0)

        if active_loras:
            print(f"--- [ModelManager] Applying {len(active_loras)} NewBie LoRAs to model on CPU... ---")

            # Apply LoRAs sequentially; each application returns a new
            # patched model that the next LoRA is layered onto.
            patched_unet = unet_object
            for lora_info in active_loras:
                print(f"  - Applying LoRA: {lora_info['lora_name']} with strength {lora_info['strength_model']}")
                patched_unet = apply_newbie_lora_to_model(
                    patched_unet,
                    lora_name=lora_info["lora_name"],
                    strength=lora_info["strength_model"]
                )

            unet_object = patched_unet
            print(f"--- [ModelManager] ✅ All NewBie LoRAs merged into the model on CPU. ---")

        # Wrap in 1-tuples to match the node-output shape downstream code
        # indexes with get_value_at_index(..., 0).
        loaded_combo = {
            "unet": (unet_object,),
            "clip": (clip_object,),
            "vae": vae_tuple,
        }

        print(f"--- [ModelManager] ✅ Successfully loaded combo '{display_name}' to CPU/RAM ---")
        return loaded_combo

    def move_models_to_gpu(self, required_models: List[str]):
        """Push the UNET of each loaded required model onto the GPU.

        Models not present in the cache are silently skipped; only the UNET
        component is moved here (CLIP/VAE are handled by comfy on use).
        """
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_combo = self.loaded_models[name]
                models_to_load_gpu.append(get_value_at_index(model_combo.get("unet"), 0))

        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No component models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Download any missing component files for the required models.

        Args:
            required_models: Display names to resolve via ALL_MODEL_MAP.
            progress: Optional Gradio-style progress callable.

        Raises:
            gr.Error: If any component download fails.
        """
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")

        # De-duplicate component files shared between model combos.
        files_to_download = set()
        for display_name in required_models:
            if display_name in ALL_MODEL_MAP:
                _, components, _, _ = ALL_MODEL_MAP[display_name]
                for component_file in components.values():
                    files_to_download.add(component_file)

        files_to_download = list(files_to_download)
        total_files = len(files_to_download)

        for i, filename in enumerate(files_to_download):
            # BUGFIX: these messages previously contained a literal
            # "(unknown)" placeholder; report the actual file being checked.
            if callable(progress):
                progress(i / total_files, desc=f"Checking file: {filename}")
            try:
                _ensure_model_downloaded(filename, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model component '{filename}'. Reason: {e}")

        print(f"--- [ModelManager] ✅ All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Ensure exactly the required models are loaded with the given LoRAs.

        Unloads anything no longer required — or everything, when the LoRA
        configuration changed, since LoRAs are merged into the weights — then
        loads whatever is missing.

        Args:
            required_models: Display names that must be loaded after this call.
            active_loras: Current LoRA configuration to merge into the UNETs.
            progress: Gradio-style progress callable.

        Returns:
            Mapping of display name to its loaded combo dict.

        Raises:
            gr.Error: If a model combo fails to load or a LoRA fails to apply.
        """
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())

        loras_changed = self.last_lora_config != active_loras

        models_to_unload = current_set - required_set
        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                models_to_unload = current_set.intersection(required_set)
                print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")

            # Everything is unloaded (not just models_to_unload) because
            # comfy's unload_all_models is all-or-nothing.
            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")

        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set

        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(list(models_to_load)):
                progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_model_combo(display_name, active_loras, progress)
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model combo or apply LoRA for '{display_name}'. Reason: {e}")
        else:
            print(f"--- [ModelManager] All required models are already loaded. ---")

        # BUGFIX: store a copy, not an alias — if the caller mutates its LoRA
        # dicts in place, an aliased list would always compare equal above and
        # a changed configuration would go undetected on the next call.
        self.last_lora_config = [dict(lora) for lora in active_loras]

        return {name: self.loaded_models[name] for name in required_models}

model_manager = ModelManager()