File size: 7,927 Bytes
2dddf31
6d5987b
 
 
 
 
2dddf31
 
f5bcb07
2dddf31
cc0ae1f
61380fb
acd970e
 
 
a92aea4
2dddf31
91e168e
2dddf31
 
 
 
 
4a72459
 
6d5987b
 
 
 
 
948869c
6d5987b
 
2dddf31
948869c
 
 
 
6d5987b
 
e01d167
4089031
6d5987b
 
 
 
 
 
 
948869c
 
3ae9ca7
6d5987b
 
 
 
 
 
948869c
6d5987b
 
2dddf31
6d5987b
 
2dddf31
6d5987b
f6d97e6
e01d167
064a79e
 
aecb45b
65a7aea
 
2dddf31
 
aecb45b
e01d167
4a72459
f6d97e6
4a72459
 
8333ca9
cc0ae1f
61380fb
f6d97e6
61380fb
cc0ae1f
76e0564
6d5987b
 
acd970e
4089031
 
 
 
 
 
 
94327de
4089031
 
 
acd970e
65a7aea
 
2dddf31
6d5987b
e01d167
4089031
e01d167
2dddf31
545e006
16dc50a
 
0117fa7
16dc50a
 
545e006
 
 
2d5b51b
b96a00f
 
2d5b51b
 
b96a00f
2d5b51b
545e006
 
 
94327de
545e006
94327de
 
 
 
 
 
 
 
 
 
 
 
 
65a7aea
94327de
2d5b51b
545e006
b96a00f
f5bcb07
 
545e006
 
 
f5bcb07
545e006
 
 
f5bcb07
 
545e006
 
b96a00f
545e006
2dddf31
ff641c2
545e006
d8780fa
545e006
d8780fa
6d5987b
91e168e
4a72459
 
2dddf31
6d5987b
 
e4dd0ff
 
6d5987b
 
2dddf31
6d5987b
76e0564
6d5987b
 
 
 
 
e4dd0ff
0117fa7
6d5987b
e4dd0ff
 
6d5987b
 
0a54687
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import torch
import os
import cv2
import numpy as np
from config import Config

from diffusers import (
    ControlNetModel, 
    TCDScheduler, 
)
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel

# Import the custom pipeline from your local file
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline

from huggingface_hub import snapshot_download, hf_hub_download
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector, LineartAnimeDetector

class ModelHandler:
    """Hold and load every model used by the InstantID img2img pipeline.

    A fresh instance holds nothing. ``load_models()`` populates it. It
    downloads files when they are missing and loads:

    - the InsightFace face analyser;
    - the InstantID / depth / line-art ControlNets;
    - the SDXL InstantID img2img pipeline with its scheduler, IP-Adapter and LoRAs;
    - the LeReS and anime line-art preprocessors.
    """

    # Local directory used for the checkpoint, IP-Adapter and TCD-LoRA downloads.
    _MODELS_DIR = "./models"

    def __init__(self):
        self.pipeline = None  # SDXL InstantID img2img pipeline, set by load_models()
        self.app = None  # InsightFace FaceAnalysis, set by load_face_analysis()
        self.leres_detector = None
        self.lineart_anime_detector = None
        # True only once load_face_analysis() succeeded; gates get_face_info().
        self.face_analysis_loaded = False

    def load_face_analysis(self) -> bool:
        """
        Load face analysis model.
        Downloads from HF Hub to the path insightface expects.

        Returns:
            True if the FaceAnalysis app was created and prepared, False if
            either the download or the initialisation failed (the error is
            printed, not raised).
        """
        print("Loading face analysis model...")

        # FaceAnalysis(name=N, root=R) looks for its ONNX files in R/models/N.
        model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)

        # Use the detector file as a sentinel for "already downloaded".
        if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
            print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
            try:
                snapshot_download(
                    repo_id=Config.ANTELOPEV2_REPO,
                    local_dir=model_path,  # Download to the correct expected path
                )
            except Exception as e:
                print(f"  [ERROR] Failed to download AntelopeV2 models: {e}")
                return False

        try:
            self.app = FaceAnalysis(
                name=Config.ANTELOPEV2_NAME,
                root=Config.ANTELOPEV2_ROOT,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
            )
            self.app.prepare(ctx_id=0, det_size=(640, 640))
            print("  [OK] Face analysis model loaded successfully.")
            return True
        except Exception as e:
            # Best-effort: generation can proceed without face conditioning.
            print(f"  [WARNING] Face detection system failed to initialize: {e}")
            return False

    def _ensure_local_file(self, repo_id, filename, label=None):
        """Return the local path of *filename*, downloading it from *repo_id* if absent.

        Args:
            repo_id: Hugging Face Hub repository to download from.
            filename: File name inside the repo; also used as the local file name.
            label: Human-readable name for the "Downloading ..." log line.
                When None, the download is not announced.

        Returns:
            ``os.path.join(self._MODELS_DIR, filename)``.
        """
        local_path = os.path.join(self._MODELS_DIR, filename)
        if not os.path.exists(local_path):
            if label is not None:
                print(f"Downloading {label} to {local_path}...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=self._MODELS_DIR,
                local_dir_use_symlinks=False,
            )
        return local_path

    def load_models(self):
        """Load every model needed for generation, in dependency order.

        Face-analysis failure is recorded in ``face_analysis_loaded`` instead
        of being raised. Enabling xFormers is best-effort. Errors from every
        other step propagate to the caller.
        """
        # 1. Load Face Analysis
        self.face_analysis_loaded = self.load_face_analysis()

        # 2. Load ControlNets
        controlnet = self._load_controlnets()

        # 3. Load SDXL Pipeline
        self._load_pipeline(controlnet)

        # 4. Set TCD Scheduler
        self._configure_scheduler()

        # 5. Load Adapters (IP-Adapter, TCD-LoRA & Style LoRA)
        self._load_adapters()

        # 6. Load Preprocessors
        self._load_preprocessors()

        print("--- All models loaded successfully ---")

    def _load_controlnets(self):
        """Load the three ControlNets and wrap them in one MultiControlNetModel."""
        print("Loading ControlNets (InstantID, Zoe, LineArt)...")

        # The InstantID ControlNet lives in a subfolder of its repo.
        print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
        cn_instantid = ControlNetModel.from_pretrained(
            Config.INSTANTID_REPO,
            subfolder="ControlNetModel",
            torch_dtype=Config.DTYPE,
        )
        print("  [OK] Loaded InstantID ControlNet.")

        print("Loading Zoe and LineArt ControlNets...")
        cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
        cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)

        # Order matters: control images passed at inference time must follow
        # the same [InstantID, Zoe, LineArt] order.
        print("Wrapping ControlNets in MultiControlNetModel...")
        return MultiControlNetModel([cn_instantid, cn_zoe, cn_lineart])

    def _load_pipeline(self, controlnet):
        """Build the SDXL InstantID img2img pipeline from the local checkpoint and move it to the device."""
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")

        checkpoint_local_path = self._ensure_local_file(
            Config.REPO_ID, Config.CHECKPOINT_FILENAME, label="checkpoint"
        )

        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            checkpoint_local_path,
            controlnet=controlnet,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )

        self.pipeline.to(Config.DEVICE)

        # xFormers is optional; fall back to default attention if unavailable.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
            print("  [OK] xFormers memory efficient attention enabled.")
        except Exception as e:
            print(f"  [WARNING] Failed to enable xFormers: {e}")

    def _configure_scheduler(self):
        """Swap the pipeline's scheduler for a TCDScheduler with trailing timestep spacing."""
        print("Configuring TCDScheduler...")
        # TCDScheduler has no Karras-sigma option, so the previous
        # `use_karras_sigmas=True` was silently dropped by from_config.
        # Only the trailing spacing is actually passed through.
        self.pipeline.scheduler = TCDScheduler.from_config(
            self.pipeline.scheduler.config,
            timestep_spacing="trailing",
        )
        print("  [OK] TCDScheduler loaded (Trailing Spacing).")

    def _load_adapters(self):
        """Load the InstantID IP-Adapter, then fuse the TCD LoRA and the style LoRA."""
        print("Loading Adapters...")

        # 5a. IP-Adapter
        ip_adapter_local_path = self._ensure_local_file(
            Config.INSTANTID_REPO, "ip-adapter.bin", label="IP-Adapter"
        )
        print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
        self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)

        # 5b. TCD LoRA, fused at full strength.
        print("Loading TCD-SDXL-LoRA...")
        tcd_lora_filename = "pytorch_lora_weights.safetensors"
        self._ensure_local_file("h1t/TCD-SDXL-LoRA", tcd_lora_filename)
        self.pipeline.load_lora_weights(self._MODELS_DIR, weight_name=tcd_lora_filename)
        self.pipeline.fuse_lora(lora_scale=1.0)
        print("  [OK] TCD LoRA fused.")

        # 5c. Style LoRA, fused at the configured strength.
        # NOTE(review): confirm that this second fuse_lora() only merges the
        # newly loaded adapter and does not rescale the already-fused TCD LoRA.
        print("Loading Style LoRA weights...")
        self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)

        print(f"Fusing Style LoRA with scale {Config.LORA_STRENGTH}...")
        self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print("  [OK] Style LoRA fused.")

    def _load_preprocessors(self):
        """Load the LeReS depth and anime line-art annotators."""
        print("Loading Preprocessors (LeReS, LineArtAnime)...")
        self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
        self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)

    @staticmethod
    def _bbox_area(face):
        """Return the area of ``face['bbox']`` given as (x1, y1, x2, y2)."""
        x1, y1, x2, y2 = face['bbox'][:4]
        return (x2 - x1) * (y2 - y1)

    @classmethod
    def _largest_face(cls, faces):
        """Return the face with the largest bbox area (first one on ties), or None if *faces* is empty."""
        if not faces:
            return None
        return max(faces, key=cls._bbox_area)

    def get_face_info(self, image):
        """Extract the largest face in *image*.

        Args:
            image: An RGB image, either a PIL image or anything ``np.array``
                turns into an RGB pixel array.

        Returns:
            The insightface result object for the largest detected face.
            Returns None if face analysis is not loaded, no face is found, or
            detection fails (the error is printed).
        """
        if not self.face_analysis_loaded:
            return None

        try:
            # Normalise PIL inputs (grayscale, palette, RGBA, ...) to 3-channel
            # RGB; cv2's RGB->BGR conversion rejects single-channel arrays.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.app.get(cv2_img)

            # The largest face is taken as the main character.
            return self._largest_face(faces)
        except Exception as e:
            print(f"Face embedding extraction failed: {e}")
            return None