primerz committed on
Commit
089fd21
·
verified ·
1 Parent(s): 44fc9c3

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +135 -105
generator.py CHANGED
@@ -20,7 +20,7 @@ from models import (
20
  load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
21
  load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
22
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
23
- load_openpose_detector # <-- NEW
24
  )
25
 
26
 
@@ -36,7 +36,7 @@ class RetroArtConverter:
36
  'instantid': False,
37
  'zoe_depth': False,
38
  'ip_adapter': False,
39
- 'openpose': False # <-- NEW
40
  }
41
 
42
  # Initialize face analysis
@@ -64,17 +64,44 @@ class RetroArtConverter:
64
  else:
65
  self.image_encoder = None
66
 
 
67
  # Determine which controlnets to use
68
- controlnets = [controlnet_depth, self.controlnet_openpose] # Start with depth and openpose
69
- if self.instantid_enabled and self.controlnet_instantid is not None:
70
- controlnets.insert(0, self.controlnet_instantid) # Add InstantID at the start if available
71
- print(f"Initializing with multiple ControlNets: InstantID + Depth + OpenPose")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  else:
73
- print(f"Initializing with ControlNets: Depth + OpenPose (InstantID disabled)")
 
 
 
74
 
 
75
 
76
  # Load SDXL pipeline
77
- self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
 
 
 
78
  self.models_loaded['custom_checkpoint'] = checkpoint_success
79
 
80
  # Load LORA
@@ -82,11 +109,11 @@ class RetroArtConverter:
82
  self.models_loaded['lora'] = lora_success
83
 
84
  # Setup IP-Adapter
85
- if self.instantid_enabled and self.image_encoder is not None:
86
  self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
87
  self.models_loaded['ip_adapter'] = ip_adapter_success
88
  else:
89
- print("[INFO] Face preservation: InstantID ControlNet keypoints only")
90
  self.models_loaded['ip_adapter'] = False
91
  self.image_proj_model = None
92
 
@@ -283,7 +310,7 @@ class RetroArtConverter:
283
  def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
284
  identity_preservation, identity_control_scale,
285
  depth_control_scale, consistency_mode=True,
286
- expression_control_scale=0.6): # <-- NEW
287
  """
288
  Enhanced parameter validation with stricter rules for consistency.
289
  """
@@ -338,16 +365,29 @@ class RetroArtConverter:
338
  adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")
339
 
340
  # Rule 5: ControlNet balance
341
- # <-- MODIFIED: Now balances 3 controlnets -->
342
- total_control = identity_control_scale + depth_control_scale + expression_control_scale
 
 
 
 
 
 
 
343
  if total_control > 2.0: # Increased max total from 1.7 to 2.0
344
  scale_factor = 2.0 / total_control
345
  original_id_ctrl = identity_control_scale
346
  original_depth_ctrl = depth_control_scale
347
  original_expr_ctrl = expression_control_scale
348
- identity_control_scale *= scale_factor
349
- depth_control_scale *= scale_factor
350
- expression_control_scale *= scale_factor
 
 
 
 
 
 
351
  adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}, Expr {original_expr_ctrl:.2f}->{expression_control_scale:.2f}")
352
 
353
  # Report adjustments
@@ -441,7 +481,7 @@ class RetroArtConverter:
441
  guidance_scale=1.0,
442
  depth_control_scale=0.8,
443
  identity_control_scale=0.85,
444
- expression_control_scale=0.6, # <-- NEW
445
  lora_scale=1.0,
446
  identity_preservation=0.8,
447
  strength=0.75,
@@ -462,10 +502,10 @@ class RetroArtConverter:
462
  if consistency_mode:
463
  print("\n[CONSISTENCY] Validating and adjusting parameters...")
464
  strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale = \
465
- self.validate_and_adjust_parameters( # <-- MODIFIED
466
  strength, guidance_scale, lora_scale, identity_preservation,
467
  identity_control_scale, depth_control_scale, consistency_mode,
468
- expression_control_scale # <-- NEW
469
  )
470
 
471
  # Add trigger word
@@ -482,35 +522,37 @@ class RetroArtConverter:
482
  # Resize with high quality
483
  resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
484
 
 
 
485
  # Generate depth map
486
- print("Generating Zoe depth map...")
487
- depth_image = self.get_depth_map(resized_image)
488
- if depth_image.size != (target_width, target_height):
489
- depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
490
-
491
- # --- NEW: Generate OpenPose map ---
 
 
492
  openpose_image = None
493
- if self.openpose_detector is not None:
494
  print("Generating OpenPose map...")
495
  try:
496
  openpose_image = self.openpose_detector(resized_image, face_only=True)
497
  except Exception as e:
498
  print(f"OpenPose failed, using blank map: {e}")
499
  openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
500
- else:
501
- openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
502
- # --- END NEW ---
503
 
504
 
505
  # Handle face detection
506
- using_multiple_controlnets = self.using_multiple_controlnets
507
  face_kps_image = None
508
  face_embeddings = None
509
  face_crop_enhanced = None
510
  has_detected_faces = False
511
  face_bbox_original = None
512
 
513
- if using_multiple_controlnets and self.face_app is not None and self.instantid_enabled: # <-- Check instantid_enabled
514
  print("Detecting faces and extracting keypoints...")
515
  img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
516
  faces = self.face_app.get(img_array)
@@ -631,90 +673,78 @@ class RetroArtConverter:
631
  if hasattr(self.pipe, 'text_encoder'):
632
  pipe_kwargs["clip_skip"] = 2
633
 
634
- # Configure ControlNet inputs
635
- # --- MODIFIED: Handle 3 ControlNets ---
636
- if using_multiple_controlnets and has_detected_faces and face_kps_image is not None and self.instantid_enabled:
637
- print("Using InstantID (keypoints) + Depth + OpenPose ControlNets")
638
- control_images = [face_kps_image, depth_image, openpose_image]
639
- conditioning_scales = [identity_control_scale, depth_control_scale, expression_control_scale]
640
-
641
- pipe_kwargs["control_image"] = control_images
642
- pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
643
-
644
- # Add face embeddings for IP-Adapter if available
645
- if face_embeddings is not None and self.models_loaded.get('ip_adapter', False) and face_crop_enhanced is not None:
646
- print(f"Processing InstantID face embeddings with Resampler...")
647
-
648
- with torch.no_grad():
649
- # Convert InsightFace embeddings to tensor
650
- face_emb_tensor = torch.from_numpy(face_embeddings).to(
651
- device=self.device,
652
- dtype=self.dtype
653
- )
654
-
655
- # Reshape for Resampler: [1, 1, 512]
656
- face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
657
-
658
- # Pass through Resampler: [1, 1, 512] -> 16x2048
659
- face_proj_embeds = self.image_proj_model(face_emb_tensor)
660
-
661
- # Scale with identity preservation
662
- boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
663
- face_proj_embeds = face_proj_embeds * boosted_scale
664
-
665
- print(f" - Face embedding: {face_emb_tensor.shape}")
666
- print(f" - Resampler output: {face_proj_embeds.shape}")
667
- print(f" - Scale: {boosted_scale:.2f}")
668
 
669
- # CRITICAL: Concatenate with text embeddings (not separate kwargs!)
670
- if 'prompt_embeds' in pipe_kwargs:
671
- # Compel encoded prompts
672
- original_embeds = pipe_kwargs['prompt_embeds']
673
-
674
- # Handle CFG (classifier-free guidance)
675
- if original_embeds.shape[0] > 1: # Has negative + positive
676
- # Duplicate for negative + positive
677
- face_proj_embeds = torch.cat([
678
- torch.zeros_like(face_proj_embeds), # Negative
679
- face_proj_embeds # Positive
680
- ], dim=0)
681
 
682
- # Concatenate: [batch, text_tokens, 2048] + [batch, 16, 2048]
683
- combined_embeds = torch.cat([original_embeds, face_proj_embeds], dim=1)
684
- pipe_kwargs['prompt_embeds'] = combined_embeds
685
 
686
- print(f" - Text embeds: {original_embeds.shape}")
687
- print(f" - Combined embeds: {combined_embeds.shape}")
688
- print(f" [OK] Face embeddings concatenated successfully!")
689
 
690
- else:
691
- print(f" [WARNING] Can't concatenate - no prompt_embeds (use Compel)")
692
-
693
- elif has_detected_faces and self.models_loaded.get('ip_adapter', False):
694
- # Face detected but embeddings unavailable
695
- print(" Face detected but embeddings unavailable, using keypoints only")
696
-
697
- elif using_multiple_controlnets: # No face, or InstantID disabled
698
- print("InstantID disabled or no faces detected, using depth + openpose only")
699
- # Use blank image for InstantID
700
- blank_kps = Image.new("RGB", (target_width, target_height), (0,0,0))
701
-
702
- control_images = [blank_kps, depth_image, openpose_image]
703
- conditioning_scales = [0.0, depth_control_scale, expression_control_scale]
704
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  pipe_kwargs["control_image"] = control_images
706
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
707
-
708
- else: # Fallback to just depth (shouldn't happen if setup is correct)
709
- print("Using Depth ControlNet only")
710
- pipe_kwargs["control_image"] = depth_image
711
- pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
712
- # --- END MODIFICATION ---
713
 
714
 
715
  # Generate
716
  print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
717
- print(f"Controlnet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}, Expression: {expression_control_scale}")
718
  result = self.pipe(**pipe_kwargs)
719
 
720
  generated_image = result.images[0]
 
20
  load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
21
  load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
22
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
23
+ load_openpose_detector
24
  )
25
 
26
 
 
36
  'instantid': False,
37
  'zoe_depth': False,
38
  'ip_adapter': False,
39
+ 'openpose': False
40
  }
41
 
42
  # Initialize face analysis
 
64
  else:
65
  self.image_encoder = None
66
 
67
+ # --- FIX START: Robust ControlNet Loading ---
68
  # Determine which controlnets to use
69
+
70
+ # Store booleans for which models are active
71
+ self.instantid_active = self.instantid_enabled and self.controlnet_instantid is not None
72
+ self.depth_active = self.controlnet_depth is not None
73
+ self.openpose_active = self.controlnet_openpose is not None
74
+
75
+ # Build the list of *active* controlnet models
76
+ controlnets = []
77
+ if self.instantid_active:
78
+ controlnets.append(self.controlnet_instantid)
79
+ print(" [CN] InstantID (Identity) active")
80
+ else:
81
+ print(" [CN] InstantID (Identity) DISABLED")
82
+
83
+ if self.depth_active:
84
+ controlnets.append(self.controlnet_depth)
85
+ print(" [CN] Depth active")
86
+ else:
87
+ print(" [CN] Depth DISABLED")
88
+
89
+ if self.openpose_active:
90
+ controlnets.append(self.controlnet_openpose)
91
+ print(" [CN] OpenPose (Expression) active")
92
  else:
93
+ print(" [CN] OpenPose (Expression) DISABLED")
94
+
95
+ if not controlnets:
96
+ print("[WARNING] No ControlNets loaded!")
97
 
98
+ print(f"Initializing with {len(controlnets)} active ControlNet(s)")
99
 
100
  # Load SDXL pipeline
101
+ # Pass the filtered list (or None if empty)
102
+ self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets if controlnets else None)
103
+ # --- FIX END ---
104
+
105
  self.models_loaded['custom_checkpoint'] = checkpoint_success
106
 
107
  # Load LORA
 
109
  self.models_loaded['lora'] = lora_success
110
 
111
  # Setup IP-Adapter
112
+ if self.instantid_active and self.image_encoder is not None: # <-- Check instantid_active
113
  self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
114
  self.models_loaded['ip_adapter'] = ip_adapter_success
115
  else:
116
+ print("[INFO] Face preservation: IP-Adapter disabled (InstantID model failed or encoder failed)")
117
  self.models_loaded['ip_adapter'] = False
118
  self.image_proj_model = None
119
 
 
310
  def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
311
  identity_preservation, identity_control_scale,
312
  depth_control_scale, consistency_mode=True,
313
+ expression_control_scale=0.6):
314
  """
315
  Enhanced parameter validation with stricter rules for consistency.
316
  """
 
365
  adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")
366
 
367
  # Rule 5: ControlNet balance
368
+ # MODIFIED: Only sum *active* controlnets
369
+ total_control = 0
370
+ if self.instantid_active:
371
+ total_control += identity_control_scale
372
+ if self.depth_active:
373
+ total_control += depth_control_scale
374
+ if self.openpose_active:
375
+ total_control += expression_control_scale
376
+
377
  if total_control > 2.0: # Increased max total from 1.7 to 2.0
378
  scale_factor = 2.0 / total_control
379
  original_id_ctrl = identity_control_scale
380
  original_depth_ctrl = depth_control_scale
381
  original_expr_ctrl = expression_control_scale
382
+
383
+ # Only scale active controlnets
384
+ if self.instantid_active:
385
+ identity_control_scale *= scale_factor
386
+ if self.depth_active:
387
+ depth_control_scale *= scale_factor
388
+ if self.openpose_active:
389
+ expression_control_scale *= scale_factor
390
+
391
  adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}, Expr {original_expr_ctrl:.2f}->{expression_control_scale:.2f}")
392
 
393
  # Report adjustments
 
481
  guidance_scale=1.0,
482
  depth_control_scale=0.8,
483
  identity_control_scale=0.85,
484
+ expression_control_scale=0.6,
485
  lora_scale=1.0,
486
  identity_preservation=0.8,
487
  strength=0.75,
 
502
  if consistency_mode:
503
  print("\n[CONSISTENCY] Validating and adjusting parameters...")
504
  strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale = \
505
+ self.validate_and_adjust_parameters(
506
  strength, guidance_scale, lora_scale, identity_preservation,
507
  identity_control_scale, depth_control_scale, consistency_mode,
508
+ expression_control_scale
509
  )
510
 
511
  # Add trigger word
 
522
  # Resize with high quality
523
  resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
524
 
525
+ # --- FIX START: Generate control images only if models are active ---
526
+
527
  # Generate depth map
528
+ depth_image = None
529
+ if self.depth_active:
530
+ print("Generating Zoe depth map...")
531
+ depth_image = self.get_depth_map(resized_image)
532
+ if depth_image.size != (target_width, target_height):
533
+ depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
534
+
535
+ # Generate OpenPose map
536
  openpose_image = None
537
+ if self.openpose_active:
538
  print("Generating OpenPose map...")
539
  try:
540
  openpose_image = self.openpose_detector(resized_image, face_only=True)
541
  except Exception as e:
542
  print(f"OpenPose failed, using blank map: {e}")
543
  openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
544
+
545
+ # --- FIX END ---
 
546
 
547
 
548
  # Handle face detection
 
549
  face_kps_image = None
550
  face_embeddings = None
551
  face_crop_enhanced = None
552
  has_detected_faces = False
553
  face_bbox_original = None
554
 
555
+ if self.instantid_active and self.face_app is not None: # <-- Check instantid_active
556
  print("Detecting faces and extracting keypoints...")
557
  img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
558
  faces = self.face_app.get(img_array)
 
673
  if hasattr(self.pipe, 'text_encoder'):
674
  pipe_kwargs["clip_skip"] = 2
675
 
676
+ # --- FIX START: Configure ControlNet inputs dynamically ---
677
+ control_images = []
678
+ conditioning_scales = []
679
+ scale_debug_str = []
680
+
681
+ # 1. InstantID (Identity)
682
+ if self.instantid_active:
683
+ if has_detected_faces and face_kps_image is not None:
684
+ control_images.append(face_kps_image)
685
+ conditioning_scales.append(identity_control_scale)
686
+ scale_debug_str.append(f"Identity: {identity_control_scale:.2f}")
687
+
688
+ # Add face embeddings for IP-Adapter if available
689
+ if face_embeddings is not None and self.models_loaded.get('ip_adapter', False) and face_crop_enhanced is not None:
690
+ print(f"Processing InstantID face embeddings with Resampler...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
 
692
+ with torch.no_grad():
693
+ face_emb_tensor = torch.from_numpy(face_embeddings).to(device=self.device, dtype=self.dtype)
694
+ face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
695
+ face_proj_embeds = self.image_proj_model(face_emb_tensor)
 
 
 
 
 
 
 
 
696
 
697
+ boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
698
+ face_proj_embeds = face_proj_embeds * boosted_scale
 
699
 
700
+ print(f" - Face embedding: {face_emb_tensor.shape} -> {face_proj_embeds.shape}, Scale: {boosted_scale:.2f}")
 
 
701
 
702
+ if 'prompt_embeds' in pipe_kwargs:
703
+ original_embeds = pipe_kwargs['prompt_embeds']
704
+
705
+ if original_embeds.shape[0] > 1: # Handle CFG
706
+ face_proj_embeds = torch.cat([torch.zeros_like(face_proj_embeds), face_proj_embeds], dim=0)
707
+
708
+ combined_embeds = torch.cat([original_embeds, face_proj_embeds], dim=1)
709
+ pipe_kwargs['prompt_embeds'] = combined_embeds
710
+ print(f" [OK] Face embeddings concatenated successfully! New shape: {combined_embeds.shape}")
711
+ else:
712
+ print(f" [WARNING] Can't concatenate - no prompt_embeds (use Compel)")
713
+
714
+ elif has_detected_faces:
715
+ print(" Face detected but IP-Adapter/embeddings unavailable, using keypoints only")
716
+
717
+ else:
718
+ # No face, must add a blank image to keep list order
719
+ print("Using blank map for InstantID (no face/disabled)")
720
+ control_images.append(Image.new("RGB", (target_width, target_height), (0,0,0)))
721
+ conditioning_scales.append(0.0) # Set scale to 0
722
+ scale_debug_str.append("Identity: 0.00")
723
+
724
+ # 2. Depth
725
+ if self.depth_active:
726
+ control_images.append(depth_image)
727
+ conditioning_scales.append(depth_control_scale)
728
+ scale_debug_str.append(f"Depth: {depth_control_scale:.2f}")
729
+
730
+ # 3. OpenPose (Expression)
731
+ if self.openpose_active:
732
+ control_images.append(openpose_image) # This is already a blank map if it failed
733
+ conditioning_scales.append(expression_control_scale)
734
+ scale_debug_str.append(f"Expression: {expression_control_scale:.2f}")
735
+
736
+ if control_images:
737
  pipe_kwargs["control_image"] = control_images
738
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
739
+ print(f"Active ControlNets: {len(control_images)}")
740
+ else:
741
+ print("No active ControlNets, running standard Img2Img")
742
+ # --- FIX END ---
 
 
743
 
744
 
745
  # Generate
746
  print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
747
+ print(f"Controlnet scales - {' | '.join(scale_debug_str)}")
748
  result = self.pipe(**pipe_kwargs)
749
 
750
  generated_image = result.images[0]