File size: 9,024 Bytes
2dddf31
6d5987b
 
 
 
 
2dddf31
 
c19d329
2dddf31
cc0ae1f
61380fb
acd970e
 
 
a92aea4
2dddf31
c19d329
2dddf31
 
 
 
c19d329
4a72459
 
c19d329
6d5987b
c19d329
6d5987b
 
 
 
948869c
6d5987b
 
2dddf31
948869c
 
 
 
6d5987b
 
e01d167
c19d329
6d5987b
 
 
 
 
 
 
948869c
 
3ae9ca7
6d5987b
 
 
 
 
 
948869c
6d5987b
 
c19d329
 
 
 
 
 
 
 
 
6d5987b
 
2dddf31
c19d329
 
aecb45b
65a7aea
 
2dddf31
 
c19d329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8333ca9
61380fb
 
76e0564
c19d329
6d5987b
acd970e
4089031
 
 
 
 
 
 
94327de
4089031
 
 
acd970e
65a7aea
 
2dddf31
6d5987b
e01d167
4089031
e01d167
2dddf31
16dc50a
 
0117fa7
16dc50a
 
de66e8c
c19d329
de66e8c
9879887
c19d329
5bb4ff9
fa327ca
de66e8c
 
c19d329
54953c7
 
 
 
 
 
 
 
 
 
 
 
 
c19d329
5e35e8b
 
 
fa327ca
5e35e8b
 
fa327ca
 
 
5e35e8b
c19d329
fa327ca
c19d329
 
4a72459
c19d329
 
 
 
 
 
 
 
2dddf31
6d5987b
 
e4dd0ff
 
6d5987b
 
 
76e0564
6d5987b
 
 
c19d329
 
 
 
 
e4dd0ff
6d5987b
 
c19d329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import torch
import os
import cv2
import numpy as np
from config import Config

from diffusers import (
    ControlNetModel, 
    TCDScheduler, 
)
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel

# Import the custom pipeline from your local file
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline

from huggingface_hub import snapshot_download, hf_hub_download
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector, LineartAnimeDetector, CannyDetector

class ModelHandler:
    """Load and hold every model used for InstantID SDXL img2img generation.

    After ``load_models`` has run, this object holds:

    * ``app`` - the InsightFace ``FaceAnalysis`` instance;
      ``face_analysis_loaded`` records whether it initialised successfully.
    * ``pipeline`` - a ``StableDiffusionXLInstantIDImg2ImgPipeline`` whose
      ControlNet is a ``MultiControlNetModel`` of InstantID, Zoe and the edge
      ControlNet(s) selected by ``edge_type``.
    * The controlnet_aux preprocessors: ``leres_detector`` plus the Canny
      and/or LineArt detector matching ``edge_type``.
    """

    # Edge-conditioning modes accepted by ``load_models``.
    SUPPORTED_EDGE_TYPES = ("canny", "lineart", "both")
    # Local directory where the checkpoint, style LoRA and IP-Adapter are cached.
    MODELS_DIR = "./models"

    def __init__(self):
        self.pipeline = None
        self.app = None  # InsightFace FaceAnalysis instance
        self.leres_detector = None
        self.lineart_anime_detector = None
        self.canny_detector = None
        self.face_analysis_loaded = False
        self.edge_type = Config.DEFAULT_EDGE_TYPE

    def load_face_analysis(self):
        """
        Load face analysis model.
        Downloads from HF Hub to the path insightface expects.

        Model files are fetched from ``Config.ANTELOPEV2_REPO`` into
        ``<ANTELOPEV2_ROOT>/models/<ANTELOPEV2_NAME>``. The download is skipped
        when ``scrfd_10g_bnkps.onnx`` already exists there; only that one file
        is checked.

        Returns:
            True if ``FaceAnalysis`` was created and prepared, False if the
            download or initialisation failed (the error is printed, not raised).
        """
        print("Loading face analysis model...")

        model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)

        if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
            print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
            try:
                snapshot_download(
                    repo_id=Config.ANTELOPEV2_REPO,
                    local_dir=model_path,
                )
            except Exception as e:
                print(f"  [ERROR] Failed to download AntelopeV2 models: {e}")
                return False

        try:
            self.app = FaceAnalysis(
                name=Config.ANTELOPEV2_NAME,
                root=Config.ANTELOPEV2_ROOT,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
            )
            self.app.prepare(ctx_id=0, det_size=(640, 640))
            print("  [OK] Face analysis model loaded successfully.")
            return True
        except Exception as e:
            print(f"  [WARNING] Face detection system failed to initialize: {e}")
            return False

    def load_models(self, edge_type="canny"):
        """
        Load all models with support for different edge detection types.

        Steps, in order: face analysis, ControlNets, SDXL pipeline (plus
        xFormers when available), TCD scheduler, fused style LoRA, InstantID
        IP-Adapter, preprocessors.

        Args:
            edge_type: "canny", "lineart", or "both". Selects which edge
                ControlNet(s) follow InstantID and Zoe in the MultiControlNet,
                and which edge detector(s) are created.

        Raises:
            ValueError: if ``edge_type`` is not in ``SUPPORTED_EDGE_TYPES``.
                The check runs before anything is downloaded or loaded.
        """
        if edge_type not in self.SUPPORTED_EDGE_TYPES:
            raise ValueError(
                f"Unsupported edge_type {edge_type!r}; "
                f"expected one of {self.SUPPORTED_EDGE_TYPES}"
            )
        self.edge_type = edge_type

        # 1. Face analysis. Failure is non-fatal: get_face_info() then returns None.
        self.face_analysis_loaded = self.load_face_analysis()

        # 2. ControlNets (InstantID, Zoe, edge net(s)) wrapped in a MultiControlNetModel
        controlnet = self._load_controlnets(edge_type)

        # 3. SDXL pipeline
        self._load_pipeline(controlnet)

        # 4. Set TCD Scheduler, reusing the loaded scheduler's config
        print("Configuring TCDScheduler...")
        self.pipeline.scheduler = TCDScheduler.from_config(self.pipeline.scheduler.config)
        print("  [OK] TCDScheduler loaded.")

        # 5. Adapters (style LoRA, then InstantID IP-Adapter)
        self._load_adapters()

        # 6. Preprocessors
        self._load_preprocessors(edge_type)

        print("--- All models loaded successfully ---")

    def _load_controlnets(self, edge_type):
        """Load the InstantID, Zoe and edge ControlNets and return them as one MultiControlNetModel."""
        print(f"Loading ControlNets (InstantID, Zoe, {edge_type.upper()})...")
        cn_instantid = ControlNetModel.from_pretrained(
            Config.INSTANTID_REPO,
            subfolder="ControlNetModel",
            torch_dtype=Config.DTYPE,
        )
        cn_zoe = ControlNetModel.from_pretrained(
            Config.CN_ZOE_REPO,
            torch_dtype=Config.DTYPE,
        )

        # MultiControlNetModel pairs nets with conditioning images by position,
        # so callers must supply control images in this order:
        # InstantID, Zoe, then Canny and/or LineArt.
        controlnet_list = [cn_instantid, cn_zoe]

        if edge_type == "canny":
            cn_canny = ControlNetModel.from_pretrained(
                Config.CN_CANNY_REPO,
                torch_dtype=Config.DTYPE,
            )
            controlnet_list.append(cn_canny)
            print("  [OK] Loaded Canny ControlNet")

        elif edge_type == "lineart":
            cn_lineart = ControlNetModel.from_pretrained(
                Config.CN_LINEART_REPO,
                torch_dtype=Config.DTYPE,
            )
            controlnet_list.append(cn_lineart)
            print("  [OK] Loaded LineArt ControlNet")

        elif edge_type == "both":
            cn_canny = ControlNetModel.from_pretrained(
                Config.CN_CANNY_REPO,
                torch_dtype=Config.DTYPE,
            )
            cn_lineart = ControlNetModel.from_pretrained(
                Config.CN_LINEART_REPO,
                torch_dtype=Config.DTYPE,
            )
            controlnet_list.extend([cn_canny, cn_lineart])
            print("  [OK] Loaded both Canny and LineArt ControlNets")

        print("Wrapping ControlNets in MultiControlNetModel...")
        return MultiControlNetModel(controlnet_list)

    def _ensure_local_file(self, repo_id, filename, label=None):
        """Return ``MODELS_DIR/filename``, downloading it from ``repo_id`` first if it is missing.

        Args:
            repo_id: Hugging Face Hub repository to download from.
            filename: file name inside the repo; also used as the local name.
            label: if given, a "Downloading <label> to <path>..." line is
                printed before a download.
        """
        local_path = os.path.join(self.MODELS_DIR, filename)
        if not os.path.exists(local_path):
            if label is not None:
                print(f"Downloading {label} to {local_path}...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=self.MODELS_DIR,
                local_dir_use_symlinks=False,
            )
        return local_path

    def _load_pipeline(self, controlnet):
        """Build the SDXL InstantID img2img pipeline from the local checkpoint and move it to ``Config.DEVICE``."""
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
        checkpoint_local_path = self._ensure_local_file(
            Config.REPO_ID, Config.CHECKPOINT_FILENAME, label="checkpoint"
        )

        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            checkpoint_local_path,
            controlnet=controlnet,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )
        self.pipeline.to(Config.DEVICE)

        # xFormers is optional: on failure the pipeline keeps its default attention.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
            print("  [OK] xFormers memory efficient attention enabled.")
        except Exception as e:
            print(f"  [WARNING] Failed to enable xFormers: {e}")

    def _load_adapters(self):
        """Fuse the style LoRA into the pipeline, then load the InstantID IP-Adapter."""
        print("Loading Adapters...")

        # 5a. Load and fuse the style LoRA at Config.LORA_STRENGTH
        print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
        self._ensure_local_file(Config.REPO_ID, Config.LORA_FILENAME)
        self.pipeline.load_lora_weights(self.MODELS_DIR, weight_name=Config.LORA_FILENAME)
        self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print("  [OK] Style LoRA fused.")

        # 5b. Load the IP-Adapter for InstantID
        ip_adapter_local_path = self._ensure_local_file(Config.INSTANTID_REPO, "ip-adapter.bin")
        self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
        print("  [OK] InstantID IP-Adapter loaded.")

    def _load_preprocessors(self, edge_type):
        """Create the LeReS depth detector and the edge detector(s) selected by ``edge_type``."""
        print("Loading Preprocessors...")
        self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)

        # Reset the edge detectors so that reloading with a different edge_type
        # does not leave behind a detector without a matching ControlNet.
        self.canny_detector = None
        self.lineart_anime_detector = None

        if edge_type in ("canny", "both"):
            self.canny_detector = CannyDetector()
            print("  [OK] Canny detector loaded")

        if edge_type in ("lineart", "both"):
            self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
            print("  [OK] LineArt detector loaded")

    @staticmethod
    def _largest_face(faces):
        """Return the face with the largest bbox area, or None if ``faces`` is empty.

        ``bbox`` is read as ``[x1, y1, x2, y2]``. On ties the earliest face wins.
        """
        if len(faces) == 0:
            return None
        return max(
            faces,
            key=lambda f: (f['bbox'][2] - f['bbox'][0]) * (f['bbox'][3] - f['bbox'][1]),
        )

    def get_face_info(self, image):
        """Extract the largest face and return its insightface result object.

        Args:
            image: an RGB image, either a PIL image or an array-like that
                ``np.array`` turns into an RGB array. The array case is assumed;
                confirm with callers.

        Returns:
            The detected face with the largest bounding-box area, or ``None``
            in any of these cases:

            * face analysis is not loaded;
            * no face is found;
            * detection raises (the error is printed).
        """
        if not self.face_analysis_loaded:
            return None
        try:
            # Greyscale/palette PIL images turn into 1-channel arrays, which
            # COLOR_RGB2BGR rejects; normalise PIL inputs to RGB first.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.app.get(cv2_img)
            return self._largest_face(faces)
        except Exception as e:
            print(f"Face embedding extraction failed: {e}")
            return None

    def extract_depth(self, image):
        """Extract a depth map using the LeReS detector.

        Raises:
            ValueError: if ``load_models`` has not created the detector yet.
        """
        if self.leres_detector is None:
            raise ValueError("LeReS depth detector not loaded. Call load_models() first")
        return self.leres_detector(image)

    def extract_canny(self, image, low_threshold=100, high_threshold=200):
        """Extract Canny edges using the given hysteresis thresholds.

        Raises:
            ValueError: if the Canny detector was not loaded (edge_type not
                'canny' or 'both').
        """
        if self.canny_detector is None:
            raise ValueError("Canny detector not loaded. Initialize with edge_type='canny' or 'both'")
        return self.canny_detector(image, low_threshold=low_threshold, high_threshold=high_threshold)

    def extract_lineart(self, image):
        """Extract anime-style LineArt edges.

        Raises:
            ValueError: if the LineArt detector was not loaded (edge_type not
                'lineart' or 'both').
        """
        if self.lineart_anime_detector is None:
            raise ValueError("LineArt detector not loaded. Initialize with edge_type='lineart' or 'both'")
        return self.lineart_anime_detector(image)