primerz committed on
Commit
6971978
·
verified ·
1 Parent(s): 570d128

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +116 -1
models.py CHANGED
@@ -260,6 +260,121 @@ def load_image_encoder():
260
  return None
261
 
262
 
263
- __all__ = ['draw_kps', 'fuse_lora_with_scale', 'load_image_encoder']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")
 
260
  return None
261
 
262
 
263
def setup_ip_adapter(pipe, image_encoder):
    """
    Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.
    Based on the reference InstantID pipeline.

    Parameters
    ----------
    pipe : diffusion pipeline
        Pipeline whose ``unet`` receives the IP-Adapter attention
        processors. Side effect: ``pipe.image_encoder`` is set.
    image_encoder :
        Image encoder to attach to the pipeline, or ``None`` to skip setup.

    Returns
    -------
    tuple
        ``(image_proj_model, True)`` on success, ``(None, False)`` when
        ``image_encoder`` is ``None`` or any step fails. Errors are printed
        (with traceback), never raised.
    """
    if image_encoder is None:
        return None, False

    print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
    try:
        # Download InstantID weights
        ip_adapter_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin",
        )

        # Load full state dict.
        # NOTE(security): torch.load is pickle-based and can execute arbitrary
        # code from a malicious checkpoint — only load from trusted sources
        # (on torch >= 1.13 consider weights_only=True to harden this).
        state_dict = torch.load(ip_adapter_path, map_location="cpu")

        # Extract image_proj and ip_adapter weights (the checkpoint bundles
        # both sub-modules under key prefixes).
        image_proj_state_dict = {}
        ip_adapter_state_dict = {}
        for key, value in state_dict.items():
            if key.startswith("image_proj."):
                image_proj_state_dict[key.replace("image_proj.", "")] = value
            elif key.startswith("ip_adapter."):
                ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value

        # Create Resampler (image projection model) with CORRECT parameters from reference
        print("Creating Resampler (Perceiver architecture)...")
        image_proj_model = Resampler(
            dim=1280,           # Hidden dimension
            depth=4,            # IMPORTANT: 4 layers (not 8!)
            dim_head=64,        # Dimension per head
            heads=20,           # Number of heads
            num_queries=16,     # Number of output tokens
            embedding_dim=512,  # InsightFace embedding dim
            output_dim=pipe.unet.config.cross_attention_dim,  # SDXL cross-attention dim (2048)
            ff_mult=4,          # Feedforward multiplier
        )

        image_proj_model.eval()
        image_proj_model = image_proj_model.to(device, dtype=dtype)

        # Load image_proj weights. load_state_dict copies into the already
        # moved/cast parameters, so dtype/device conversion happens in place.
        if image_proj_state_dict:
            try:
                image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
                print(" [OK] Resampler loaded with pretrained weights")
            except Exception as e:
                print(f" [WARNING] Could not load Resampler weights: {e}")
                print(" Using randomly initialized Resampler")
        else:
            print(" [WARNING] No image_proj weights found, using random initialization")

        # Setup IP-Adapter attention processors
        print("Setting up IP-Adapter attention processors...")
        attn_procs = {}
        num_tokens = 16  # Match Resampler num_queries

        for name in pipe.unet.attn_processors.keys():
            # attn1 layers are self-attention and carry no cross-attention dim.
            cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim

            # Derive the hidden size of the UNet block this processor lives in.
            if name.startswith("mid_block"):
                hidden_size = pipe.unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = pipe.unet.config.block_out_channels[block_id]
            else:
                hidden_size = pipe.unet.config.block_out_channels[-1]

            if cross_attention_dim is None:
                # Self-attention: stock processor, no extra tokens.
                attn_procs[name] = AttnProcessor2_0()
            else:
                # Cross-attention: decoupled IP-Adapter processor.
                attn_procs[name] = IPAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=num_tokens,
                ).to(device, dtype=dtype)

        # Set attention processors
        pipe.unet.set_attn_processor(attn_procs)

        # Load IP-Adapter weights into attention processors. strict=False
        # because the list also contains parameter-free AttnProcessor2_0.
        if ip_adapter_state_dict:
            try:
                ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
                ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
                print(" [OK] IP-Adapter attention weights loaded")
            except Exception as e:
                print(f" [WARNING] Could not load IP-Adapter weights: {e}")
        else:
            print(" [WARNING] No ip_adapter weights found")

        # Store image encoder on the pipeline for later feature extraction.
        pipe.image_encoder = image_encoder

        print(" [OK] IP-Adapter fully loaded with InstantID architecture")
        # Plain strings here: the originals were f-strings with no placeholders.
        print(" - Resampler: 4 layers, 20 heads, 16 output tokens")
        print(" - Face embeddings: 512D → 16x2048D")

        return image_proj_model, True

    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False
376
# Public API of this module.
__all__ = [
    'draw_kps',
    'fuse_lora_with_scale',
    'load_image_encoder',
    'setup_ip_adapter',
]


print("[OK] models.py ready - NO MultiControlNetModel, following examplewithface.py")