primerz committed on
Commit
079d679
·
verified ·
1 Parent(s): 171e0fc

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +90 -68
generator.py CHANGED
@@ -19,7 +19,8 @@ from utils import (
19
  from models import (
20
  load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
21
  load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
22
- setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip
 
23
  )
24
 
25
 
@@ -33,19 +34,26 @@ class RetroArtConverter:
33
  'custom_checkpoint': False,
34
  'lora': False,
35
  'instantid': False,
36
- 'midas_depth': False,
37
- 'ip_adapter': False
 
38
  }
39
 
40
  # Initialize face analysis
41
  self.face_app, self.face_detection_enabled = load_face_analysis()
42
 
43
- # Load Midas Depth detector
44
- self.midas_depth, midas_success = load_depth_detector()
45
- self.models_loaded['midas_depth'] = midas_success
 
 
 
 
 
46
 
47
  # Load ControlNets
48
- controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
 
49
  self.controlnet_depth = controlnet_depth
50
  self.instantid_enabled = instantid_success
51
  self.models_loaded['instantid'] = instantid_success
@@ -57,12 +65,13 @@ class RetroArtConverter:
57
  self.image_encoder = None
58
 
59
  # Determine which controlnets to use
 
60
  if self.instantid_enabled and self.controlnet_instantid is not None:
61
- controlnets = [self.controlnet_instantid, controlnet_depth]
62
- print(f"Initializing with multiple ControlNets: InstantID + Depth")
63
  else:
64
- controlnets = controlnet_depth
65
- print(f"Initializing with single ControlNet: Depth only")
66
 
67
  # Load SDXL pipeline
68
  self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
@@ -146,8 +155,8 @@ class RetroArtConverter:
146
  print("============================\n")
147
 
148
  def get_depth_map(self, image):
149
- """Generate depth map using Midas Depth"""
150
- if self.midas_depth is not None:
151
  try:
152
  if image.mode != 'RGB':
153
  image = image.convert('RGB')
@@ -163,29 +172,25 @@ class RetroArtConverter:
163
  target_width = int(max(64, target_width))
164
  target_height = int(max(64, target_height))
165
 
 
 
 
166
  if target_width != orig_width or target_height != orig_height:
167
- image = image.resize((int(target_width), int(target_height)), Image.LANCZOS)
168
- print(f"[DEPTH] Resized for MidasDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
169
 
170
  # FIXED: Add torch.no_grad() wrapper
171
  with torch.no_grad():
172
- depth_image = self.midas_depth(image)
173
 
174
  depth_width, depth_height = depth_image.size
175
- # Convert numpy int64 to Python int to avoid PIL errors
176
- depth_width = int(depth_width)
177
- depth_height = int(depth_height)
178
- orig_width_int = int(orig_width)
179
- orig_height_int = int(orig_height)
180
-
181
- if depth_width != orig_width_int or depth_height != orig_height_int:
182
- depth_image = depth_image.resize((orig_width_int, orig_height_int), Image.LANCZOS)
183
 
184
- print(f"[DEPTH] Midas depth map generated: {orig_width}x{orig_height}")
185
  return depth_image
186
 
187
  except Exception as e:
188
- print(f"[DEPTH] MidasDetector failed ({e}), falling back to grayscale depth")
189
  gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
190
  depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
191
  return Image.fromarray(depth_colored)
@@ -198,6 +203,8 @@ class RetroArtConverter:
198
  def add_trigger_word(self, prompt):
199
  """Add trigger word to prompt if not present"""
200
  if TRIGGER_WORD.lower() not in prompt.lower():
 
 
201
  return f"{TRIGGER_WORD}, {prompt}"
202
  return prompt
203
 
@@ -275,7 +282,8 @@ class RetroArtConverter:
275
 
276
  def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
277
  identity_preservation, identity_control_scale,
278
- depth_control_scale, consistency_mode=True):
 
279
  """
280
  Enhanced parameter validation with stricter rules for consistency.
281
  """
@@ -330,14 +338,17 @@ class RetroArtConverter:
330
  adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")
331
 
332
  # Rule 5: ControlNet balance
333
- total_control = identity_control_scale + depth_control_scale
334
- if total_control > 1.7:
335
- scale_factor = 1.7 / total_control
 
336
  original_id_ctrl = identity_control_scale
337
  original_depth_ctrl = depth_control_scale
 
338
  identity_control_scale *= scale_factor
339
  depth_control_scale *= scale_factor
340
- adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}")
 
341
 
342
  # Report adjustments
343
  if adjustments:
@@ -347,7 +358,7 @@ class RetroArtConverter:
347
  else:
348
  print(" [OK] Parameters already optimal")
349
 
350
- return strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale
351
 
352
  def generate_caption(self, image, max_length=None, num_beams=None):
353
  """Generate a descriptive caption for the image (supports BLIP-2, GIT, BLIP)."""
@@ -430,6 +441,7 @@ class RetroArtConverter:
430
  guidance_scale=1.0,
431
  depth_control_scale=0.8,
432
  identity_control_scale=0.85,
 
433
  lora_scale=1.0,
434
  identity_preservation=0.8,
435
  strength=0.75,
@@ -443,13 +455,17 @@ class RetroArtConverter:
443
  prompt = sanitize_text(prompt)
444
  negative_prompt = sanitize_text(negative_prompt)
445
 
 
 
 
446
  # Apply parameter validation
447
  if consistency_mode:
448
  print("\n[CONSISTENCY] Validating and adjusting parameters...")
449
- strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale = \
450
- self.validate_and_adjust_parameters(
451
  strength, guidance_scale, lora_scale, identity_preservation,
452
- identity_control_scale, depth_control_scale, consistency_mode
 
453
  )
454
 
455
  # Add trigger word
@@ -467,10 +483,24 @@ class RetroArtConverter:
467
  resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
468
 
469
  # Generate depth map
470
- print("Generating Midas depth map...")
471
  depth_image = self.get_depth_map(resized_image)
472
  if depth_image.size != (target_width, target_height):
473
  depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
  # Handle face detection
476
  using_multiple_controlnets = self.using_multiple_controlnets
@@ -480,14 +510,10 @@ class RetroArtConverter:
480
  has_detected_faces = False
481
  face_bbox_original = None
482
 
483
- if using_multiple_controlnets and self.face_app is not None:
484
  print("Detecting faces and extracting keypoints...")
485
  img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
486
- try:
487
- faces = self.face_app.get(img_array)
488
- except Exception as e:
489
- print(f"[WARNING] Face detection failed: {e}")
490
- faces = []
491
 
492
  if len(faces) > 0:
493
  has_detected_faces = True
@@ -555,8 +581,7 @@ class RetroArtConverter:
555
  # Set LORA scale
556
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
557
  try:
558
- # Use correct adapter name - peft uses 'default_0' for single adapters
559
- self.pipe.set_adapters(["default_0"], adapter_weights=[lora_scale])
560
  print(f"LORA scale: {lora_scale}")
561
  except Exception as e:
562
  print(f"Could not set LORA scale: {e}")
@@ -588,21 +613,14 @@ class RetroArtConverter:
588
  conditioning = self.compel(prompt)
589
  negative_conditioning = self.compel(negative_prompt)
590
 
591
- # Handle potential token length mismatches
592
- prompt_embeds_0 = conditioning[0]
593
- prompt_embeds_1 = conditioning[1]
594
- neg_embeds_0 = negative_conditioning[0]
595
- neg_embeds_1 = negative_conditioning[1]
596
-
597
- # Ensure consistent shapes if needed
598
- pipe_kwargs["prompt_embeds"] = prompt_embeds_0
599
- pipe_kwargs["pooled_prompt_embeds"] = prompt_embeds_1
600
- pipe_kwargs["negative_prompt_embeds"] = neg_embeds_0
601
- pipe_kwargs["negative_pooled_prompt_embeds"] = neg_embeds_1
602
 
603
  print("[OK] Using Compel-encoded prompts")
604
  except Exception as e:
605
- print(f"Compel encoding failed ({e}), falling back to standard prompts")
606
  pipe_kwargs["prompt"] = prompt
607
  pipe_kwargs["negative_prompt"] = negative_prompt
608
  else:
@@ -614,10 +632,11 @@ class RetroArtConverter:
614
  pipe_kwargs["clip_skip"] = 2
615
 
616
  # Configure ControlNet inputs
617
- if using_multiple_controlnets and has_detected_faces and face_kps_image is not None:
618
- print("Using InstantID (keypoints) + Depth ControlNets")
619
- control_images = [face_kps_image, depth_image]
620
- conditioning_scales = [identity_control_scale, depth_control_scale]
 
621
 
622
  pipe_kwargs["control_image"] = control_images
623
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
@@ -636,7 +655,7 @@ class RetroArtConverter:
636
  # Reshape for Resampler: [1, 1, 512]
637
  face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
638
 
639
- # Pass through Resampler: [1, 1, 512] → [1, 16, 2048]
640
  face_proj_embeds = self.image_proj_model(face_emb_tensor)
641
 
642
  # Scale with identity preservation
@@ -674,25 +693,28 @@ class RetroArtConverter:
674
  elif has_detected_faces and self.models_loaded.get('ip_adapter', False):
675
  # Face detected but embeddings unavailable
676
  print(" Face detected but embeddings unavailable, using keypoints only")
677
- # No need for dummy embeddings with concatenation approach
678
 
679
- elif using_multiple_controlnets and not has_detected_faces:
680
- print("Multiple ControlNets available but no faces detected, using depth only")
681
- control_images = [depth_image, depth_image]
682
- conditioning_scales = [0.0, depth_control_scale]
 
 
 
683
 
684
  pipe_kwargs["control_image"] = control_images
685
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
686
 
687
- else:
688
  print("Using Depth ControlNet only")
689
  pipe_kwargs["control_image"] = depth_image
690
  pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
 
691
 
692
 
693
  # Generate
694
  print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
695
- print(f"Controlnet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}")
696
  result = self.pipe(**pipe_kwargs)
697
 
698
  generated_image = result.images[0]
 
19
  from models import (
20
  load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
21
  load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
22
+ setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
23
+ load_openpose_detector # <-- NEW
24
  )
25
 
26
 
 
34
  'custom_checkpoint': False,
35
  'lora': False,
36
  'instantid': False,
37
+ 'zoe_depth': False,
38
+ 'ip_adapter': False,
39
+ 'openpose': False # <-- NEW
40
  }
41
 
42
  # Initialize face analysis
43
  self.face_app, self.face_detection_enabled = load_face_analysis()
44
 
45
+ # Load Zoe Depth detector
46
+ self.zoe_depth, zoe_success = load_depth_detector()
47
+ self.models_loaded['zoe_depth'] = zoe_success
48
+
49
+ # --- NEW: Load OpenPose detector ---
50
+ self.openpose_detector, openpose_success = load_openpose_detector()
51
+ self.models_loaded['openpose'] = openpose_success
52
+ # --- END NEW ---
53
 
54
  # Load ControlNets
55
+ # Now unpacks 3 models + success boolean
56
+ controlnet_depth, self.controlnet_instantid, self.controlnet_openpose, instantid_success = load_controlnets()
57
  self.controlnet_depth = controlnet_depth
58
  self.instantid_enabled = instantid_success
59
  self.models_loaded['instantid'] = instantid_success
 
65
  self.image_encoder = None
66
 
67
  # Determine which controlnets to use
68
+ controlnets = [controlnet_depth, self.controlnet_openpose] # Start with depth and openpose
69
  if self.instantid_enabled and self.controlnet_instantid is not None:
70
+ controlnets.insert(0, self.controlnet_instantid) # Add InstantID at the start if available
71
+ print(f"Initializing with multiple ControlNets: InstantID + Depth + OpenPose")
72
  else:
73
+ print(f"Initializing with ControlNets: Depth + OpenPose (InstantID disabled)")
74
+
75
 
76
  # Load SDXL pipeline
77
  self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets)
 
155
  print("============================\n")
156
 
157
  def get_depth_map(self, image):
158
+ """Generate depth map using Zoe Depth"""
159
+ if self.zoe_depth is not None:
160
  try:
161
  if image.mode != 'RGB':
162
  image = image.convert('RGB')
 
172
  target_width = int(max(64, target_width))
173
  target_height = int(max(64, target_height))
174
 
175
+ size_for_depth = (int(target_width), int(target_height))
176
+ image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
177
+
178
  if target_width != orig_width or target_height != orig_height:
179
+ print(f"[DEPTH] Resized for ZoeDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
 
180
 
181
  # FIXED: Add torch.no_grad() wrapper
182
  with torch.no_grad():
183
+ depth_image = self.zoe_depth(image_for_depth)
184
 
185
  depth_width, depth_height = depth_image.size
186
+ if depth_width != orig_width or depth_height != orig_height:
187
+ depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
 
 
 
 
 
 
188
 
189
+ print(f"[DEPTH] Zoe depth map generated: {orig_width}x{orig_height}")
190
  return depth_image
191
 
192
  except Exception as e:
193
+ print(f"[DEPTH] ZoeDetector failed ({e}), falling back to grayscale depth")
194
  gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
195
  depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
196
  return Image.fromarray(depth_colored)
 
203
  def add_trigger_word(self, prompt):
204
  """Add trigger word to prompt if not present"""
205
  if TRIGGER_WORD.lower() not in prompt.lower():
206
+ if not prompt or not prompt.strip():
207
+ return TRIGGER_WORD
208
  return f"{TRIGGER_WORD}, {prompt}"
209
  return prompt
210
 
 
282
 
283
  def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
284
  identity_preservation, identity_control_scale,
285
+ depth_control_scale, consistency_mode=True,
286
+ expression_control_scale=0.6): # <-- NEW
287
  """
288
  Enhanced parameter validation with stricter rules for consistency.
289
  """
 
338
  adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")
339
 
340
  # Rule 5: ControlNet balance
341
+ # <-- MODIFIED: Now balances 3 controlnets -->
342
+ total_control = identity_control_scale + depth_control_scale + expression_control_scale
343
+ if total_control > 2.0: # Increased max total from 1.7 to 2.0
344
+ scale_factor = 2.0 / total_control
345
  original_id_ctrl = identity_control_scale
346
  original_depth_ctrl = depth_control_scale
347
+ original_expr_ctrl = expression_control_scale
348
  identity_control_scale *= scale_factor
349
  depth_control_scale *= scale_factor
350
+ expression_control_scale *= scale_factor
351
+ adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}, Expr {original_expr_ctrl:.2f}->{expression_control_scale:.2f}")
352
 
353
  # Report adjustments
354
  if adjustments:
 
358
  else:
359
  print(" [OK] Parameters already optimal")
360
 
361
+ return strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale
362
 
363
  def generate_caption(self, image, max_length=None, num_beams=None):
364
  """Generate a descriptive caption for the image (supports BLIP-2, GIT, BLIP)."""
 
441
  guidance_scale=1.0,
442
  depth_control_scale=0.8,
443
  identity_control_scale=0.85,
444
+ expression_control_scale=0.6, # <-- NEW
445
  lora_scale=1.0,
446
  identity_preservation=0.8,
447
  strength=0.75,
 
455
  prompt = sanitize_text(prompt)
456
  negative_prompt = sanitize_text(negative_prompt)
457
 
458
+ if not negative_prompt or not negative_prompt.strip():
459
+ negative_prompt = ""
460
+
461
  # Apply parameter validation
462
  if consistency_mode:
463
  print("\n[CONSISTENCY] Validating and adjusting parameters...")
464
+ strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale = \
465
+ self.validate_and_adjust_parameters( # <-- MODIFIED
466
  strength, guidance_scale, lora_scale, identity_preservation,
467
+ identity_control_scale, depth_control_scale, consistency_mode,
468
+ expression_control_scale # <-- NEW
469
  )
470
 
471
  # Add trigger word
 
483
  resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
484
 
485
  # Generate depth map
486
+ print("Generating Zoe depth map...")
487
  depth_image = self.get_depth_map(resized_image)
488
  if depth_image.size != (target_width, target_height):
489
  depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
490
+
491
+ # --- NEW: Generate OpenPose map ---
492
+ openpose_image = None
493
+ if self.openpose_detector is not None:
494
+ print("Generating OpenPose map...")
495
+ try:
496
+ openpose_image = self.openpose_detector(resized_image, face_only=True)
497
+ except Exception as e:
498
+ print(f"OpenPose failed, using blank map: {e}")
499
+ openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
500
+ else:
501
+ openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
502
+ # --- END NEW ---
503
+
504
 
505
  # Handle face detection
506
  using_multiple_controlnets = self.using_multiple_controlnets
 
510
  has_detected_faces = False
511
  face_bbox_original = None
512
 
513
+ if using_multiple_controlnets and self.face_app is not None and self.instantid_enabled: # <-- Check instantid_enabled
514
  print("Detecting faces and extracting keypoints...")
515
  img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
516
+ faces = self.face_app.get(img_array)
 
 
 
 
517
 
518
  if len(faces) > 0:
519
  has_detected_faces = True
 
581
  # Set LORA scale
582
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
583
  try:
584
+ self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
 
585
  print(f"LORA scale: {lora_scale}")
586
  except Exception as e:
587
  print(f"Could not set LORA scale: {e}")
 
613
  conditioning = self.compel(prompt)
614
  negative_conditioning = self.compel(negative_prompt)
615
 
616
+ pipe_kwargs["prompt_embeds"] = conditioning[0]
617
+ pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
618
+ pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
619
+ pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
 
 
 
 
 
 
 
620
 
621
  print("[OK] Using Compel-encoded prompts")
622
  except Exception as e:
623
+ print(f"Compel encoding failed, using standard prompts: {e}")
624
  pipe_kwargs["prompt"] = prompt
625
  pipe_kwargs["negative_prompt"] = negative_prompt
626
  else:
 
632
  pipe_kwargs["clip_skip"] = 2
633
 
634
  # Configure ControlNet inputs
635
+ # --- MODIFIED: Handle 3 ControlNets ---
636
+ if using_multiple_controlnets and has_detected_faces and face_kps_image is not None and self.instantid_enabled:
637
+ print("Using InstantID (keypoints) + Depth + OpenPose ControlNets")
638
+ control_images = [face_kps_image, depth_image, openpose_image]
639
+ conditioning_scales = [identity_control_scale, depth_control_scale, expression_control_scale]
640
 
641
  pipe_kwargs["control_image"] = control_images
642
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
 
655
  # Reshape for Resampler: [1, 1, 512]
656
  face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
657
 
658
+ # Pass through Resampler: [1, 1, 512] -> 16x2048
659
  face_proj_embeds = self.image_proj_model(face_emb_tensor)
660
 
661
  # Scale with identity preservation
 
693
  elif has_detected_faces and self.models_loaded.get('ip_adapter', False):
694
  # Face detected but embeddings unavailable
695
  print(" Face detected but embeddings unavailable, using keypoints only")
 
696
 
697
+ elif using_multiple_controlnets: # No face, or InstantID disabled
698
+ print("InstantID disabled or no faces detected, using depth + openpose only")
699
+ # Use blank image for InstantID
700
+ blank_kps = Image.new("RGB", (target_width, target_height), (0,0,0))
701
+
702
+ control_images = [blank_kps, depth_image, openpose_image]
703
+ conditioning_scales = [0.0, depth_control_scale, expression_control_scale]
704
 
705
  pipe_kwargs["control_image"] = control_images
706
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
707
 
708
+ else: # Fallback to just depth (shouldn't happen if setup is correct)
709
  print("Using Depth ControlNet only")
710
  pipe_kwargs["control_image"] = depth_image
711
  pipe_kwargs["controlnet_conditioning_scale"] = depth_control_scale
712
+ # --- END MODIFICATION ---
713
 
714
 
715
  # Generate
716
  print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
717
+ print(f"Controlnet scales - Identity: {identity_control_scale}, Depth: {depth_control_scale}, Expression: {expression_control_scale}")
718
  result = self.pipe(**pipe_kwargs)
719
 
720
  generated_image = result.images[0]