primerz committed
Commit b9e8f75 · verified · 1 Parent(s): 050255c

Update models.py

Files changed (1):
  1. models.py +364 -96
models.py CHANGED
--- models.py (old version)
@@ -1,23 +1,33 @@
"""
Model loading and initialization for Pixagram AI Pixel Art Generator
- UPDATED VERSION with proper InstantID pipeline support
"""
import torch
import time
from diffusers import (
    ControlNetModel,
    AutoencoderKL,
    LCMScheduler
)
from insightface.app import FaceAnalysis
- from controlnet_aux import ZoeDetector
- from huggingface_hub import hf_hub_download
from compel import Compel, ReturnedEmbeddingsType

- # Use InstantID pipeline
- from pipeline_stable_diffusion_xl_instantid_img2img import (
-     StableDiffusionXLInstantIDImg2ImgPipeline
- )

from config import (
    device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
@@ -25,26 +35,24 @@ from config import (
)


- def download_model_with_retry(repo_id, filename, max_retries=None):
    """Download model with retry logic and proper token handling."""
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']

    for attempt in range(max_retries):
        try:
            print(f"  Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")

-             kwargs = {"repo_type": "model"}
-             if HUGGINGFACE_TOKEN:
-                 kwargs["token"] = HUGGINGFACE_TOKEN
-
-             path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )
-             print(f"  [OK] Downloaded: {filename}")
-             return path

        except Exception as e:
            print(f"  [WARNING] Download attempt {attempt + 1} failed: {e}")
@@ -60,123 +68,373 @@ def download_model_with_retry(repo_id, filename, max_retries=None):


def load_face_analysis():
-     """Load face analysis model with proper error handling."""
    print("Loading face analysis model...")
    try:
-         face_app = FaceAnalysis(
-             name=FACE_DETECTION_CONFIG['model_name'],
-             root='./models/insightface',
-             providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
-         )
-         face_app.prepare(
-             ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
-             det_size=FACE_DETECTION_CONFIG['det_size']
-         )
-         print("  [OK] Face analysis model loaded successfully")
        return face_app, True
    except Exception as e:
-         print(f"  [WARNING] Face detection not available: {e}")
        return None, False
-

def load_depth_detector():
-     """Load Zoe Depth detector."""
-     print("Loading Zoe Depth detector...")
    try:
        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
-         zoe_depth.to(device)
-         print("  [OK] Zoe Depth loaded successfully")
-         return zoe_depth, True
    except Exception as e:
-         print(f"  [WARNING] Zoe Depth not available: {e}")
        return None, False

def load_controlnets():
-     """
-     Load ControlNets for InstantID pipeline.
-     Returns both ControlNets (InstantID first, then Depth).
-     """
-     print("Loading InstantID ControlNet...")
-     controlnet_instantid = ControlNetModel.from_pretrained(
-         "InstantX/InstantID",
-         subfolder="ControlNetModel",
-         torch_dtype=dtype
-     ).to(device)
-     print("  [OK] InstantID ControlNet loaded")
-
-     print("Loading Zoe Depth ControlNet...")
    controlnet_depth = ControlNetModel.from_pretrained(
-         "diffusers/controlnet-zoe-depth-sdxl-1.0",
        torch_dtype=dtype
    ).to(device)
-     print("  [OK] Zoe Depth ControlNet loaded")

-     return controlnet_instantid, controlnet_depth

def load_sdxl_pipeline(controlnets):
-     """
-     Load SDXL pipeline with InstantID support.
-     controlnets MUST be a list: [identitynet, depthnet]
-     """
-     print("Loading SDXL checkpoint with InstantID pipeline...")
    try:
-         model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])

-         # Use InstantID-enabled pipeline
-         pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
-             model_path,
-             controlnet=controlnets,
-             torch_dtype=dtype,
-             use_safetensors=True
        ).to(device)

-         # Load IP-Adapter weights for InstantID
-         print("Loading IP-Adapter for InstantID...")
-         ip_adapter_path = download_model_with_retry(
-             "InstantX/InstantID",
-             "ip-adapter.bin"
-         )
-         pipe.load_ip_adapter_instantid(ip_adapter_path)
-         pipe.set_ip_adapter_scale(0.8)  # Default scale

-         print("  [OK] InstantID pipeline loaded successfully")
        return pipe, True

    except Exception as e:
-         print(f"  [ERROR] Could not load InstantID pipeline: {e}")
-         import traceback
-         traceback.print_exc()

-         # Fallback to standard pipeline
-         print("  Falling back to standard SDXL pipeline (no InstantID)")
-         from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
-         ).to(device)
        return pipe, False

- def load_lora(pipe):
-     """Load LORA from HuggingFace Hub."""
-     print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
-         lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
-         pipe.load_lora_weights(lora_path, adapter_name="retroart")
-         print(f"  [OK] LORA loaded successfully")
-         return True
    except Exception as e:
-         print(f"  [WARNING] Could not load LORA: {e}")
-         return False

def setup_compel(pipe):
-     """Setup Compel for better SDXL prompt handling."""
-     print("Setting up Compel for enhanced prompt processing...")
    try:
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
@@ -184,11 +442,14 @@ def setup_compel(pipe):
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )
-         print("  [OK] Compel loaded successfully")
        return compel, True
    except Exception as e:
        print(f"  [WARNING] Compel not available: {e}")
        return None, False


def setup_scheduler(pipe):
@@ -200,6 +461,10 @@ def setup_scheduler(pipe):

def optimize_pipeline(pipe):
    """Apply optimizations to pipeline."""
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
@@ -215,17 +480,18 @@ def load_caption_model():
    """
    print("Loading caption model...")

-     # Try GIT-Large first
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print("  Attempting GIT-Large (recommended)...")
        caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        caption_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype
-         ).to(device)
-         print("  [OK] GIT-Large model loaded")
        return caption_processor, caption_model, True, 'git'
    except Exception as e1:
        print(f"  [INFO] GIT-Large not available: {e1}")
@@ -236,14 +502,16 @@

        print("  Attempting BLIP base (fallback)...")
        caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype
-         ).to(device)
-         print("  [OK] BLIP base model loaded")
        return caption_processor, caption_model, True, 'blip'
    except Exception as e2:
        print(f"  [WARNING] Caption models not available: {e2}")
        return None, None, False, 'none'

@@ -253,4 +521,4 @@ def set_clip_skip(pipe):
    print(f"  [OK] CLIP skip set to {CLIP_SKIP}")


- print("[OK] Model loading functions ready")
 
+++ models.py (new version)
"""
Model loading and initialization for Pixagram AI Pixel Art Generator
+ FIXED VERSION with proper IP-Adapter and BLIP support
"""
import torch
import time
+ import os
+ import shutil
from diffusers import (
+     StableDiffusionXLControlNetImg2ImgPipeline,
    ControlNetModel,
    AutoencoderKL,
    LCMScheduler
)
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from transformers import (
+     CLIPVisionModelWithProjection, CLIPTokenizer,
+     CLIPTextModel, CLIPTextModelWithProjection
+ )
from insightface.app import FaceAnalysis
+ from controlnet_aux import ZoeDetector, OpenposeDetector, LeresDetector, MidasDetector, MediapipeFaceDetector
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ # --- START FIX: Import Compel ---
from compel import Compel, ReturnedEmbeddingsType
+ # --- END FIX ---

+ # Use reference implementation's attention processor
+ from attention_processor import IPAttnProcessor2_0, AttnProcessor
+ from resampler import Resampler

from config import (
    device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
    ...
)

+ def download_model_with_retry(repo_id, filename, max_retries=None, **kwargs):
    """Download model with retry logic and proper token handling."""
    if max_retries is None:
        max_retries = DOWNLOAD_CONFIG['max_retries']

+     # Ensure token is passed if available
+     if HUGGINGFACE_TOKEN and "token" not in kwargs:
+         kwargs["token"] = HUGGINGFACE_TOKEN
+
    for attempt in range(max_retries):
        try:
            print(f"  Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")

+             return hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                **kwargs
            )

        except Exception as e:
            print(f"  [WARNING] Download attempt {attempt + 1} failed: {e}")
    ...
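With the token injected automatically and extra kwargs forwarded to hf_hub_download, a call site only names the repo and file. A minimal sketch, mirroring the checkpoint download that appears later in this file:

    ckpt_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")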
 
def load_face_analysis():
+     """
+     Load face analysis model with proper model downloading from HuggingFace.
+     Downloads from DIAMONIK7777/antelopev2 which has the correct model structure.
+     """
    print("Loading face analysis model...")
+
    try:
+         antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
+         # --- FIX: Load InsightFace on CPU to save VRAM ---
+         face_app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
+         face_app.prepare(ctx_id=0, det_size=(640, 640))
+         print("  [OK] Face analysis loaded (on CPU)")
        return face_app, True
+
    except Exception as e:
+         print(f"  [ERROR] Face detection not available: {e}")
+         import traceback
+         traceback.print_exc()
        return None, False
+
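Downstream code reads the 512-D identity embedding off the detected face. A minimal sketch, assuming an OpenCV BGR image as InsightFace expects ("portrait.jpg" is a placeholder):

    import cv2
    img = cv2.imread("portrait.jpg")   # BGR uint8
    faces = face_app.get(img)          # detect + embed in one call
    if faces:
        # pick the largest face by bounding-box area
        largest = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
        face_emb = largest.embedding   # 512-D vector, consumed by the Resampler below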
 
def load_depth_detector():
+     """
+     Load depth detector with fallback hierarchy: Leres → Zoe → Midas.
+     Returns (detector, detector_type, success).
+     """
+     print("Loading depth detector with fallback hierarchy...")
+
+     # Try LeresDetector first (best quality)
+     try:
+         print("  Attempting LeresDetector (highest quality)...")
+         # --- FIX: Load on CPU ---
+         leres_depth = LeresDetector.from_pretrained("lllyasviel/Annotators")
+         # leres_depth.to(device)  # Removed
+         print("  [OK] LeresDetector loaded successfully (on CPU)")
+         return leres_depth, 'leres', True
+     except Exception as e:
+         print(f"  [INFO] LeresDetector not available: {e}")
+
+     # Fallback to ZoeDetector
    try:
+         print("  Attempting ZoeDetector (fallback #1)...")
+         # --- FIX: Load on CPU ---
        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+         # zoe_depth.to(device)  # Removed
+         print("  [OK] ZoeDetector loaded successfully (on CPU)")
+         return zoe_depth, 'zoe', True
+     except Exception as e:
+         print(f"  [INFO] ZoeDetector not available: {e}")
+
+     # Final fallback to MidasDetector
+     try:
+         print("  Attempting MidasDetector (fallback #2)...")
+         # --- FIX: Load on CPU ---
+         midas_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
+         # midas_depth.to(device)  # Removed
+         print("  [OK] MidasDetector loaded successfully (on CPU)")
+         return midas_depth, 'midas', True
    except Exception as e:
+         print(f"  [WARNING] MidasDetector not available: {e}")
+
+     print("  [ERROR] No depth detector available")
+     return None, None, False
+
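Whichever detector wins the fallback race is used the same way, so callers can stay agnostic about the type. A sketch (controlnet_aux detectors accept a PIL image and return a PIL depth map; `input_image` is assumed prepared elsewhere):

    depth_detector, depth_type, ok = load_depth_detector()
    if ok:
        depth_map = depth_detector(input_image)   # PIL.Image in, PIL.Image out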
+ # --- NEW FUNCTION ---
+ def load_openpose_detector():
+     """Load OpenPose detector."""
+     print("Loading OpenPose detector...")
+     try:
+         # --- FIX: Load on CPU ---
+         openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
+         # openpose.to(device)  # Removed
+         print("  [OK] OpenPose loaded successfully (on CPU)")
+         return openpose, True
+     except Exception as e:
+         print(f"  [WARNING] OpenPose not available: {e}")
        return None, False
+ # --- END NEW FUNCTION ---
 
+ # --- NEW FUNCTION ---
+ def load_mediapipe_face_detector():
+     """Load MediapipeFaceDetector for advanced face detection."""
+     print("Loading MediapipeFaceDetector...")
+     try:
+         face_detector = MediapipeFaceDetector()
+         print("  [OK] MediapipeFaceDetector loaded successfully")
+         return face_detector, True
+     except Exception as e:
+         print(f"  [WARNING] MediapipeFaceDetector not available: {e}")
+         return None, False
+ # --- END NEW FUNCTION ---
 
def load_controlnets():
+     """Load ControlNet models."""
+     print("Loading ControlNet Depth model...")
+     # --- FIX: Load core models on GPU ---
    controlnet_depth = ControlNetModel.from_pretrained(
+         "xinsir/controlnet-depth-sdxl-1.0",
        torch_dtype=dtype
    ).to(device)
+     print("  [OK] ControlNet Depth loaded (on GPU)")
+
+     # --- NEW: Load OpenPose ControlNet ---
+     print("Loading ControlNet OpenPose model...")
+     try:
+         # --- FIX: Load core models on GPU ---
+         controlnet_openpose = ControlNetModel.from_pretrained(
+             "xinsir/controlnet-openpose-sdxl-1.0",
+             torch_dtype=dtype
+         ).to(device)
+         print("  [OK] ControlNet OpenPose loaded (on GPU)")
+     except Exception as e:
+         print(f"  [WARNING] ControlNet OpenPose not available: {e}")
+         controlnet_openpose = None
+     # --- END NEW ---

+     print("Loading InstantID ControlNet...")
+     try:
+         # --- FIX: Load core models on GPU ---
+         controlnet_instantid = ControlNetModel.from_pretrained(
+             "InstantX/InstantID",
+             subfolder="ControlNetModel",
+             torch_dtype=dtype
+         ).to(device)
+         print("  [OK] InstantID ControlNet loaded successfully (on GPU)")
+         # Return all three models
+         return controlnet_depth, controlnet_instantid, controlnet_openpose, True
+     except Exception as e:
+         print(f"  [WARNING] InstantID ControlNet not available: {e}")
+         # Return models, indicating InstantID failure
+         return controlnet_depth, None, controlnet_openpose, False
+
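Callers must now unpack four values instead of two. A sketch of the expected call site; the ordering of the list handed to the pipeline is an assumption (it only has to match the order of the control images supplied at inference):

    controlnet_depth, controlnet_instantid, controlnet_openpose, instantid_ok = load_controlnets()
    active_controlnets = [cn for cn in (controlnet_instantid, controlnet_depth, controlnet_openpose)
                          if cn is not None]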
+ def load_image_encoder():
+     """Load CLIP Image Encoder for IP-Adapter."""
+     print("Loading CLIP Image Encoder for IP-Adapter...")
+     try:
+         # --- FIX: Load core models on GPU ---
+         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+             "h94/IP-Adapter",
+             subfolder="models/image_encoder",
+             torch_dtype=dtype
+         ).to(device)
+         print("  [OK] CLIP Image Encoder loaded successfully (on GPU)")
+         return image_encoder
+     except Exception as e:
+         print(f"  [ERROR] Could not load image encoder: {e}")
+         return None
 
def load_sdxl_pipeline(controlnets):
+     """Load SDXL checkpoint from HuggingFace Hub."""
+     print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
+
+     # --- START FIX ---
+     # Load tokenizers and text encoders from the base model first
+     # This guarantees they exist, even if the single file doesn't have them
+     print("  Loading base tokenizers and text encoders...")
+     BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+
    try:
+         tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
+         tokenizer_2 = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer_2")

+         text_encoder = CLIPTextModel.from_pretrained(
+             BASE_MODEL, subfolder="text_encoder", torch_dtype=dtype
        ).to(device)

+         text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
+             BASE_MODEL, subfolder="text_encoder_2", torch_dtype=dtype
+         ).to(device)
+         print("  [OK] Base text/token models loaded")
+
+     except Exception as e:
+         print(f"  [ERROR] Could not load base text models: {e}")
+         print("  Pipeline will likely fail. Check HF connection/model access.")
+         # Allow it to continue, but it will likely fail below
+         tokenizer = None
+         tokenizer_2 = None
+         text_encoder = None
+         text_encoder_2 = None
+     # --- END FIX ---
+
+     try:
+         model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'], repo_type="model")

+         # --- START FIX ---
+         # Pass the pre-loaded models to from_single_file
+         pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
+             model_path,
+             controlnet=controlnets,
+             torch_dtype=dtype,
+             use_safetensors=True,
+             # Explicitly provide the models
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+         ).to(device)  # This main pipe MUST be on device
+         # --- END FIX ---
+
+         print("  [OK] Custom checkpoint loaded successfully (VAE bundled)")
        return pipe, True

    except Exception as e:
+         print(f"  [WARNING] Could not load custom checkpoint: {e}")
+         print("  Using default SDXL base model")

+         # The fallback logic is already correct
        pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True
+         ).to(device)  # This main pipe MUST be on device
        return pipe, False
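Once loaded, the img2img ControlNet pipeline takes one control image per ControlNet, in the same order as the list above. An illustrative sketch that ignores the InstantID face-embedding injection wired up in setup_ip_adapter below (`init_image`, `face_kps_image`, `depth_map` and all numeric values are assumptions):

    pipe, using_custom = load_sdxl_pipeline(active_controlnets)
    result = pipe(
        prompt="pixel art portrait",
        image=init_image,                           # img2img source
        control_image=[face_kps_image, depth_map],  # one per ControlNet
        controlnet_conditioning_scale=[0.8, 0.5],
        strength=0.6,
        num_inference_steps=8,
    ).images[0]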
 
+ def load_loras(pipe):
+     """Load all LORAs from HuggingFace Hub."""
+     print("Loading all LORAs from HuggingFace Hub...")
+     loaded_loras = {}
+
+     lora_files = {
+         "retroart": MODEL_FILES.get("lora_retroart"),
+         "vga": MODEL_FILES.get("lora_vga"),
+         "lucasart": MODEL_FILES.get("lora_lucasart")
+     }
+
+     for adapter_name, filename in lora_files.items():
+         if not filename:
+             print(f"  [INFO] No file specified for LORA '{adapter_name}', skipping.")
+             loaded_loras[adapter_name] = False
+             continue
+
+         try:
+             lora_path = download_model_with_retry(MODEL_REPO, filename, repo_type="model")
+             pipe.load_lora_weights(lora_path, adapter_name=adapter_name)
+             print(f"  [OK] LORA loaded successfully: {filename} as '{adapter_name}'")
+             loaded_loras[adapter_name] = True
+         except Exception as e:
+             print(f"  [WARNING] Could not load LORA {filename}: {e}")
+             loaded_loras[adapter_name] = False
+
+     success = any(loaded_loras.values())
+     if not success:
+         print("  [WARNING] No LORAs were loaded successfully.")
+
+     return loaded_loras, success
+
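After loading, adapters are activated per generation through diffusers' standard LoRA API (the adapter combination and weights here are illustrative):

    loaded_loras, any_loaded = load_loras(pipe)
    if loaded_loras.get("retroart") and loaded_loras.get("vga"):
        pipe.set_adapters(["retroart", "vga"], adapter_weights=[0.8, 0.4])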
+ def setup_ip_adapter(pipe, image_encoder):
+     """
+     Setup IP-Adapter for InstantID face embeddings.
+     This is CRITICAL for face preservation.
+     """
+     if image_encoder is None:
+         return None, False
+
+     print("Setting up IP-Adapter for InstantID face embeddings...")
    try:
+         # Download InstantID weights
+         ip_adapter_path = download_model_with_retry(
+             "InstantX/InstantID",
+             "ip-adapter.bin",
+             repo_type="model"
+         )
+
+         # Load full state dict
+         state_dict = torch.load(ip_adapter_path, map_location="cpu")
+
+         # Extract image_proj and ip_adapter weights
+         image_proj_state_dict = {}
+         ip_adapter_state_dict = {}
+
+         for key, value in state_dict.items():
+             if key.startswith("image_proj."):
+                 image_proj_state_dict[key.replace("image_proj.", "")] = value
+             elif key.startswith("ip_adapter."):
+                 ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
+
+         # Create Resampler with CORRECT parameters
+         print("Creating Resampler (Perceiver architecture)...")
+         image_proj_model = Resampler(
+             dim=1280,
+             depth=4,
+             dim_head=64,
+             heads=20,
+             num_queries=16,
+             embedding_dim=512,  # CRITICAL: Must match InsightFace embedding size
+             output_dim=pipe.unet.config.cross_attention_dim,
+             ff_mult=4
+         )
+
+         image_proj_model.eval()
+         image_proj_model = image_proj_model.to(device, dtype=dtype)
+
+         # Load image_proj weights
+         if image_proj_state_dict:
+             try:
+                 image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
+                 print("  [OK] Resampler loaded with pretrained weights")
+             except Exception as e:
+                 print(f"  [WARNING] Could not load Resampler weights: {e}")
+
+         # Setup IP-Adapter attention processors
+         print("Setting up IP-Adapter attention processors...")
+         attn_procs = {}
+         num_tokens = 16
+
+         for name in pipe.unet.attn_processors.keys():
+             cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
+
+             if name.startswith("mid_block"):
+                 hidden_size = pipe.unet.config.block_out_channels[-1]
+             elif name.startswith("up_blocks"):
+                 block_id = int(name[len("up_blocks.")])
+                 hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
+             elif name.startswith("down_blocks"):
+                 block_id = int(name[len("down_blocks.")])
+                 hidden_size = pipe.unet.config.block_out_channels[block_id]
+             else:
+                 hidden_size = pipe.unet.config.block_out_channels[-1]
+
+             if cross_attention_dim is None:
+                 attn_procs[name] = AttnProcessor2_0()
+             else:
+                 attn_procs[name] = IPAttnProcessor2_0(
+                     hidden_size=hidden_size,
+                     cross_attention_dim=cross_attention_dim,
+                     scale=1.0,
+                     num_tokens=num_tokens
+                 ).to(device, dtype=dtype)
+
+         # Set attention processors
+         pipe.unet.set_attn_processor(attn_procs)
+
+         # Load IP-Adapter weights
+         if ip_adapter_state_dict:
+             try:
+                 ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
+                 ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
+                 print("  [OK] IP-Adapter attention weights loaded")
+             except Exception as e:
+                 print(f"  [WARNING] Could not load IP-Adapter weights: {e}")
+
+         # Store image encoder
+         pipe.image_encoder = image_encoder
+
+         print("  [OK] IP-Adapter fully loaded with InstantID architecture")
+         print("  - Resampler: 4 layers, 20 heads, 16 output tokens")
+         print(f"  - Face embeddings: 512D -> 16x{pipe.unet.config.cross_attention_dim}D")
+
+         return image_proj_model, True
+
    except Exception as e:
+         print(f"  [ERROR] Could not setup IP-Adapter: {e}")
+         import traceback
+         traceback.print_exc()
+         return None, False
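At inference the InsightFace embedding is projected into 16 image-prompt tokens that the IPAttnProcessor2_0 layers attend to. A sketch, assuming the Resampler's forward follows the reference InstantID implementation (input shaped (batch, sequence, 512), output (batch, 16, cross_attention_dim)) and reusing `face_emb` from the face-analysis example above:

    image_proj_model, ip_ok = setup_ip_adapter(pipe, load_image_encoder())
    if ip_ok:
        emb = torch.tensor(face_emb, device=device, dtype=dtype)
        image_prompt_tokens = image_proj_model(emb.view(1, 1, -1))   # (1, 16, cross_attention_dim)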
 
+ # --- START FIX: Replace setup_cappella with setup_compel ---
def setup_compel(pipe):
+     """Setup Compel for prompt encoding."""
+     print("Setting up Compel (prompt encoder)...")
    try:
        compel = Compel(
            tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
            ...
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )
+         print("  [OK] Compel loaded successfully.")
        return compel, True
    except Exception as e:
        print(f"  [WARNING] Compel not available: {e}")
+         import traceback
+         traceback.print_exc()
        return None, False
+ # --- END FIX ---
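With requires_pooled=[False, True], Compel returns both the conditioning tensor and the pooled embedding that SDXL needs. A sketch (`init_image` and `control_images` assumed prepared as above; the weighted-prompt syntax is Compel's):

    compel, compel_ok = setup_compel(pipe)
    if compel_ok:
        conditioning, pooled = compel("pixel art++ portrait of a knight")
        images = pipe(prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
                      image=init_image, control_image=control_images).images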
 
def setup_scheduler(pipe):
    ...

def optimize_pipeline(pipe):
    """Apply optimizations to pipeline."""
+
+     # --- FIX: Removed enable_model_cpu_offload() ---
+
+     # Try to enable xformers
    if device == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
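setup_scheduler's body is collapsed in this diff; given the LCMScheduler import at the top of the file, it presumably swaps the scheduler along these lines (an assumption, not shown in any hunk):

    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)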
 
    ...
    """
    print("Loading caption model...")

+     # Try GIT-Large first (good balance of quality and compatibility)
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print("  Attempting GIT-Large (recommended)...")
        caption_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+         # --- FIX: Load on CPU ---
        caption_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype
+         )  # .to(device) removed
+         print("  [OK] GIT-Large model loaded (produces detailed captions, on CPU)")
        return caption_processor, caption_model, True, 'git'
    except Exception as e1:
        print(f"  [INFO] GIT-Large not available: {e1}")
 
    ...

        print("  Attempting BLIP base (fallback)...")
        caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+         # --- FIX: Load on CPU ---
        caption_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype
+         )  # .to(device) removed
+         print("  [OK] BLIP base model loaded (standard captions, on CPU)")
        return caption_processor, caption_model, True, 'blip'
    except Exception as e2:
        print(f"  [WARNING] Caption models not available: {e2}")
+         print("  Caption generation will be disabled")
        return None, None, False, 'none'
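Since the caption model now stays on CPU, its inputs must not be moved to the GPU. A minimal GIT-style captioning sketch (`input_image` is a placeholder; the dtype cast is an assumption to match the weights loaded above):

    inputs = caption_processor(images=input_image, return_tensors="pt")
    ids = caption_model.generate(pixel_values=inputs.pixel_values.to(caption_model.dtype), max_length=50)
    caption = caption_processor.batch_decode(ids, skip_special_tokens=True)[0]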
 
 
    ...
    print(f"  [OK] CLIP skip set to {CLIP_SKIP}")


+ print("[OK] Model loading functions ready")
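Taken together, the loaders are presumably wired up by the caller in dependency order; a sketch of that chain (inferred from the function signatures in this file, not code from this commit):

    face_app, face_ok = load_face_analysis()
    depth_detector, depth_type, depth_ok = load_depth_detector()
    openpose, pose_ok = load_openpose_detector()
    controlnet_depth, controlnet_instantid, controlnet_openpose, instantid_ok = load_controlnets()
    controlnets = [cn for cn in (controlnet_instantid, controlnet_depth, controlnet_openpose) if cn is not None]
    pipe, custom_ok = load_sdxl_pipeline(controlnets)
    image_proj_model, ip_ok = setup_ip_adapter(pipe, load_image_encoder())
    load_loras(pipe)
    compel, compel_ok = setup_compel(pipe)
    setup_scheduler(pipe)
    optimize_pipeline(pipe)
    set_clip_skip(pipe)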