primerz committed on
Commit
25765c5
·
verified ·
1 Parent(s): e09a80f

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +8 -13
models.py CHANGED
@@ -300,13 +300,13 @@ def load_loras(pipe):
300
 
301
  def setup_ip_adapter(pipe, image_encoder):
302
  """
303
- Setup IP-Adapter for InstantID face embeddings - PROPER IMPLEMENTATION.
304
- Based on the reference InstantID pipeline.
305
  """
306
  if image_encoder is None:
307
  return None, False
308
 
309
- print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
310
  try:
311
  # Download InstantID weights
312
  ip_adapter_path = download_model_with_retry(
@@ -328,7 +328,7 @@ def setup_ip_adapter(pipe, image_encoder):
328
  elif key.startswith("ip_adapter."):
329
  ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
330
 
331
- # Create Resampler (image projection model) with CORRECT parameters from reference
332
  print("Creating Resampler (Perceiver architecture)...")
333
  image_proj_model = Resampler(
334
  dim=1280,
@@ -336,7 +336,7 @@ def setup_ip_adapter(pipe, image_encoder):
336
  dim_head=64,
337
  heads=20,
338
  num_queries=16,
339
- embedding_dim=512,
340
  output_dim=pipe.unet.config.cross_attention_dim,
341
  ff_mult=4
342
  )
@@ -351,9 +351,6 @@ def setup_ip_adapter(pipe, image_encoder):
351
  print(" [OK] Resampler loaded with pretrained weights")
352
  except Exception as e:
353
  print(f" [WARNING] Could not load Resampler weights: {e}")
354
- print(" Using randomly initialized Resampler")
355
- else:
356
- print(" [WARNING] No image_proj weights found, using random initialization")
357
 
358
  # Setup IP-Adapter attention processors
359
  print("Setting up IP-Adapter attention processors...")
@@ -387,7 +384,7 @@ def setup_ip_adapter(pipe, image_encoder):
387
  # Set attention processors
388
  pipe.unet.set_attn_processor(attn_procs)
389
 
390
- # Load IP-Adapter weights into attention processors
391
  if ip_adapter_state_dict:
392
  try:
393
  ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
@@ -395,15 +392,13 @@ def setup_ip_adapter(pipe, image_encoder):
395
  print(" [OK] IP-Adapter attention weights loaded")
396
  except Exception as e:
397
  print(f" [WARNING] Could not load IP-Adapter weights: {e}")
398
- else:
399
- print(" [WARNING] No ip_adapter weights found")
400
 
401
- # Store image encoder and projection model
402
  pipe.image_encoder = image_encoder
403
 
404
  print(" [OK] IP-Adapter fully loaded with InstantID architecture")
405
  print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
406
- print(f" - Face embeddings: 512D -> 16x2048D")
407
 
408
  return image_proj_model, True
409
 
 
300
 
301
  def setup_ip_adapter(pipe, image_encoder):
302
  """
303
+ Setup IP-Adapter for InstantID face embeddings.
304
+ This is CRITICAL for face preservation.
305
  """
306
  if image_encoder is None:
307
  return None, False
308
 
309
+ print("Setting up IP-Adapter for InstantID face embeddings...")
310
  try:
311
  # Download InstantID weights
312
  ip_adapter_path = download_model_with_retry(
 
328
  elif key.startswith("ip_adapter."):
329
  ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
330
 
331
+ # Create Resampler with CORRECT parameters
332
  print("Creating Resampler (Perceiver architecture)...")
333
  image_proj_model = Resampler(
334
  dim=1280,
 
336
  dim_head=64,
337
  heads=20,
338
  num_queries=16,
339
+ embedding_dim=512, # CRITICAL: Must match InsightFace embedding size
340
  output_dim=pipe.unet.config.cross_attention_dim,
341
  ff_mult=4
342
  )
 
351
  print(" [OK] Resampler loaded with pretrained weights")
352
  except Exception as e:
353
  print(f" [WARNING] Could not load Resampler weights: {e}")
 
 
 
354
 
355
  # Setup IP-Adapter attention processors
356
  print("Setting up IP-Adapter attention processors...")
 
384
  # Set attention processors
385
  pipe.unet.set_attn_processor(attn_procs)
386
 
387
+ # Load IP-Adapter weights
388
  if ip_adapter_state_dict:
389
  try:
390
  ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
 
392
  print(" [OK] IP-Adapter attention weights loaded")
393
  except Exception as e:
394
  print(f" [WARNING] Could not load IP-Adapter weights: {e}")
 
 
395
 
396
+ # Store image encoder
397
  pipe.image_encoder = image_encoder
398
 
399
  print(" [OK] IP-Adapter fully loaded with InstantID architecture")
400
  print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
401
+ print(f" - Face embeddings: 512D -> 16x{pipe.unet.config.cross_attention_dim}D")
402
 
403
  return image_proj_model, True
404