primerz commited on
Commit
bbcd03c
·
verified ·
1 Parent(s): f587361

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +20 -112
models.py CHANGED
@@ -116,20 +116,15 @@ def load_sdxl_pipeline(controlnets):
116
  """
117
  print("Loading pipeline...")
118
 
119
- # Load VAE (line 128)
120
- vae = AutoencoderKL.from_pretrained(
121
- "madebyollin/sdxl-vae-fp16-fix",
122
- torch_dtype=dtype
123
- )
124
- print(" [OK] VAE loaded")
 
 
125
 
126
- # Create pipeline (line 134) - controlnets as LIST!
127
- pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
128
- "frankjoshua/albedobaseXL_v21",
129
- vae=vae,
130
- controlnet=controlnets, # ← LIST [identitynet, zoedepthnet] - NO WRAPPER!
131
- torch_dtype=dtype
132
- )
133
  print(" [OK] Pipeline created with direct controlnet list")
134
 
135
  # LCM scheduler
@@ -309,120 +304,33 @@ def load_image_encoder():
309
  print(f" [ERROR] Could not load image encoder: {e}")
310
  return None
311
 
312
- def setup_ip_adapter(pipe, image_encoder):
313
  """
314
- Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.
315
- Based on the reference InstantID pipeline.
316
  """
317
- if image_encoder is None:
318
- return None, False
319
-
320
- print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
321
  try:
322
  # Download InstantID weights
323
- ip_adapter_path = download_model_with_retry(
324
  "InstantX/InstantID",
325
  "ip-adapter.bin"
326
  )
327
 
328
- # Load full state dict
329
- state_dict = torch.load(ip_adapter_path, map_location="cpu")
330
-
331
- # Extract image_proj and ip_adapter weights
332
- image_proj_state_dict = {}
333
- ip_adapter_state_dict = {}
334
-
335
- for key, value in state_dict.items():
336
- if key.startswith("image_proj."):
337
- image_proj_state_dict[key.replace("image_proj.", "")] = value
338
- elif key.startswith("ip_adapter."):
339
- ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
340
-
341
- # Create Resampler (image projection model) with CORRECT parameters from reference
342
- print("Creating Resampler (Perceiver architecture)...")
343
- image_proj_model = Resampler(
344
- dim=1280, # Hidden dimension
345
- depth=4, # IMPORTANT: 4 layers (not 8!)
346
- dim_head=64, # Dimension per head
347
- heads=20, # Number of heads
348
- num_queries=16, # Number of output tokens
349
- embedding_dim=512, # InsightFace embedding dim
350
- output_dim=pipe.unet.config.cross_attention_dim, # SDXL cross-attention dim (2048)
351
- ff_mult=4 # Feedforward multiplier
352
- )
353
-
354
- image_proj_model.eval()
355
- image_proj_model = image_proj_model.to(device, dtype=dtype)
356
-
357
- # Load image_proj weights
358
- if image_proj_state_dict:
359
- try:
360
- image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
361
- print(" [OK] Resampler loaded with pretrained weights")
362
- except Exception as e:
363
- print(f" [WARNING] Could not load Resampler weights: {e}")
364
- print(" Using randomly initialized Resampler")
365
- else:
366
- print(" [WARNING] No image_proj weights found, using random initialization")
367
-
368
- # Setup IP-Adapter attention processors
369
- print("Setting up IP-Adapter attention processors...")
370
- attn_procs = {}
371
- num_tokens = 16 # Match Resampler num_queries
372
-
373
- for name in pipe.unet.attn_processors.keys():
374
- cross_attention_dim = None if name.endswith("attn1.processor") else pipe.unet.config.cross_attention_dim
375
-
376
- if name.startswith("mid_block"):
377
- hidden_size = pipe.unet.config.block_out_channels[-1]
378
- elif name.startswith("up_blocks"):
379
- block_id = int(name[len("up_blocks.")])
380
- hidden_size = list(reversed(pipe.unet.config.block_out_channels))[block_id]
381
- elif name.startswith("down_blocks"):
382
- block_id = int(name[len("down_blocks.")])
383
- hidden_size = pipe.unet.config.block_out_channels[block_id]
384
- else:
385
- hidden_size = pipe.unet.config.block_out_channels[-1]
386
-
387
- if cross_attention_dim is None:
388
- attn_procs[name] = AttnProcessor2_0()
389
- else:
390
- attn_procs[name] = IPAttnProcessor2_0(
391
- hidden_size=hidden_size,
392
- cross_attention_dim=cross_attention_dim,
393
- scale=1.0,
394
- num_tokens=num_tokens
395
- ).to(device, dtype=dtype)
396
-
397
- # Set attention processors
398
- pipe.unet.set_attn_processor(attn_procs)
399
 
400
- # Load IP-Adapter weights into attention processors
401
- if ip_adapter_state_dict:
402
- try:
403
- ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
404
- ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
405
- print(" [OK] IP-Adapter attention weights loaded")
406
- except Exception as e:
407
- print(f" [WARNING] Could not load IP-Adapter weights: {e}")
408
- else:
409
- print(" [WARNING] No ip_adapter weights found")
410
-
411
- # Store image encoder and projection model
412
- pipe.image_encoder = image_encoder
413
 
414
- print(" [OK] IP-Adapter fully loaded with InstantID architecture")
415
- print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
416
- print(f" - Face embeddings: 512D → 16x2048D")
417
-
418
- return image_proj_model, True
419
 
420
  except Exception as e:
421
  print(f" [ERROR] Could not setup IP-Adapter: {e}")
422
  import traceback
423
  traceback.print_exc()
424
- return None, False
425
-
426
 
427
 
428
  __all__ = ['draw_kps', 'fuse_lora_with_scale', 'load_image_encoder', 'setup_ip_adapter']
 
116
  """
117
  print("Loading pipeline...")
118
 
119
+ model_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])
120
+
121
+ pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
122
+ model_path,
123
+ controlnet=controlnets,
124
+ torch_dtype=dtype,
125
+ use_safetensors=True
126
+ );
127
 
 
 
 
 
 
 
 
128
  print(" [OK] Pipeline created with direct controlnet list")
129
 
130
  # LCM scheduler
 
304
  print(f" [ERROR] Could not load image encoder: {e}")
305
  return None
306
 
307
+ def setup_ip_adapter(pipe):
308
  """
309
+ Setup IP-Adapter for InstantID - SIMPLIFIED VERSION.
310
+ Uses the pipeline's built-in method like exampleapp.py.
311
  """
312
+ print("Setting up IP-Adapter for InstantID face embeddings...")
 
 
 
313
  try:
314
  # Download InstantID weights
315
+ face_adapter_path = download_model_with_retry(
316
  "InstantX/InstantID",
317
  "ip-adapter.bin"
318
  )
319
 
320
+ # Use the pipeline's built-in method (like exampleapp.py line 139)
321
+ pipe.load_ip_adapter_instantid(face_adapter_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ # Set initial scale (like exampleapp.py line 140)
324
+ pipe.set_ip_adapter_scale(0.8)
 
 
 
 
 
 
 
 
 
 
 
325
 
326
+ print(" [OK] IP-Adapter loaded successfully with built-in method")
327
+ return True
 
 
 
328
 
329
  except Exception as e:
330
  print(f" [ERROR] Could not setup IP-Adapter: {e}")
331
  import traceback
332
  traceback.print_exc()
333
+ return False
 
334
 
335
 
336
  __all__ = ['draw_kps', 'fuse_lora_with_scale', 'load_image_encoder', 'setup_ip_adapter']