primerz committed on
Commit
f4b692c
·
verified ·
1 Parent(s): ec1cd29

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +185 -92
generator.py CHANGED
@@ -18,9 +18,9 @@ from utils import (
18
  )
19
  from models import (
20
  load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
21
- load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
22
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
23
- load_openpose_detector
24
  )
25
 
26
 
@@ -34,17 +34,25 @@ class RetroArtConverter:
34
  'custom_checkpoint': False,
35
  'lora': False,
36
  'instantid': False,
37
- 'zoe_depth': False,
 
38
  'ip_adapter': False,
39
- 'openpose': False
 
40
  }
 
41
 
42
- # Initialize face analysis
43
  self.face_app, self.face_detection_enabled = load_face_analysis()
44
 
45
- # Load Zoe Depth detector
46
- self.zoe_depth, zoe_success = load_depth_detector()
47
- self.models_loaded['zoe_depth'] = zoe_success
 
 
 
 
 
48
 
49
  # --- NEW: Load OpenPose detector ---
50
  self.openpose_detector, openpose_success = load_openpose_detector()
@@ -104,8 +112,8 @@ class RetroArtConverter:
104
 
105
  self.models_loaded['custom_checkpoint'] = checkpoint_success
106
 
107
- # Load LORA
108
- lora_success = load_lora(self.pipe)
109
  self.models_loaded['lora'] = lora_success
110
 
111
  # Setup IP-Adapter
@@ -155,8 +163,15 @@ class RetroArtConverter:
155
  """Print model loading status"""
156
  print("\n=== MODEL STATUS ===")
157
  for model, loaded in self.models_loaded.items():
158
- status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
159
- print(f"{model}: {status}")
 
 
 
 
 
 
 
160
  print("===================\n")
161
 
162
  print("=== UPGRADE VERIFICATION ===")
@@ -182,8 +197,11 @@ class RetroArtConverter:
182
  print("============================\n")
183
 
184
  def get_depth_map(self, image):
185
- """Generate depth map using Zoe Depth"""
186
- if self.zoe_depth is not None:
 
 
 
187
  try:
188
  if image.mode != 'RGB':
189
  image = image.convert('RGB')
@@ -203,25 +221,27 @@ class RetroArtConverter:
203
  image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
204
 
205
  if target_width != orig_width or target_height != orig_height:
206
- print(f"[DEPTH] Resized for ZoeDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
207
 
208
  # FIXED: Add torch.no_grad() wrapper
209
  with torch.no_grad():
210
- depth_image = self.zoe_depth(image_for_depth)
211
 
212
  depth_width, depth_height = depth_image.size
213
  if depth_width != orig_width or depth_height != orig_height:
214
  depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
215
 
216
- print(f"[DEPTH] Zoe depth map generated: {orig_width}x{orig_height}")
217
  return depth_image
218
 
219
  except Exception as e:
220
- print(f"[DEPTH] ZoeDetector failed ({e}), falling back to grayscale depth")
221
  gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
222
  depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
223
  return Image.fromarray(depth_colored)
224
  else:
 
 
225
  gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
226
  depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
227
  return Image.fromarray(depth_colored)
@@ -482,6 +502,7 @@ class RetroArtConverter:
482
  depth_control_scale=0.8,
483
  identity_control_scale=0.85,
484
  expression_control_scale=0.6,
 
485
  lora_scale=1.0,
486
  identity_preservation=0.8,
487
  strength=0.75,
@@ -552,81 +573,153 @@ class RetroArtConverter:
552
  has_detected_faces = False
553
  face_bbox_original = None
554
 
555
- if self.instantid_active and self.face_app is not None: # <-- Check instantid_active
556
- print("Detecting faces and extracting keypoints...")
557
- img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
558
- faces = self.face_app.get(img_array)
559
 
560
- if len(faces) > 0:
561
- has_detected_faces = True
562
- print(f"Detected {len(faces)} face(s)")
563
-
564
- # Get largest face
565
- face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
566
-
567
- # ADAPTIVE PARAMETERS
568
- adaptive_params = self.detect_face_quality(face)
569
- if adaptive_params is not None:
570
- print(f"[ADAPTIVE] {adaptive_params['reason']}")
571
- identity_preservation = adaptive_params['identity_preservation']
572
- identity_control_scale = adaptive_params['identity_control_scale']
573
- guidance_scale = adaptive_params['guidance_scale']
574
- lora_scale = adaptive_params['lora_scale']
575
-
576
- # Extract face embeddings
577
- face_embeddings_base = face.normed_embedding
578
-
579
- # Extract face crop
580
- bbox = face.bbox.astype(int)
581
- x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
582
- face_bbox_original = [x1, y1, x2, y2]
583
-
584
- # Add padding
585
- face_width = x2 - x1
586
- face_height = y2 - y1
587
- padding_x = int(face_width * 0.3)
588
- padding_y = int(face_height * 0.3)
589
- x1 = max(0, x1 - padding_x)
590
- y1 = max(0, y1 - padding_y)
591
- x2 = min(resized_image.width, x2 + padding_x)
592
- y2 = min(resized_image.height, y2 + padding_y)
593
-
594
- # Crop face region
595
- face_crop = resized_image.crop((x1, y1, x2, y2))
596
-
597
- # MULTI-SCALE PROCESSING
598
- face_embeddings = self.extract_multi_scale_face(face_crop, face)
599
-
600
- # Enhance face crop
601
- face_crop_enhanced = enhance_face_crop(face_crop)
602
-
603
- # Draw keypoints
604
- face_kps = face.kps
605
- face_kps_image = draw_kps(resized_image, face_kps)
606
-
607
- # ENHANCED: Extract comprehensive facial attributes
608
- from utils import get_facial_attributes, build_enhanced_prompt
609
- facial_attrs = get_facial_attributes(face)
610
 
611
- # Update prompt with detected attributes
612
- prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD)
613
-
614
- # Legacy output for compatibility
615
- age = facial_attrs['age']
616
- gender_code = facial_attrs['gender']
617
- det_score = facial_attrs['quality']
618
-
619
- gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
620
- print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
621
- print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
- # Set LORA scale
624
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
625
- try:
626
- self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
627
- print(f"LORA scale: {lora_scale}")
628
- except Exception as e:
629
- print(f"Could not set LORA scale: {e}")
 
 
 
 
 
 
 
 
 
 
 
630
 
631
  # Prepare generation kwargs
632
  pipe_kwargs = {
@@ -715,11 +808,11 @@ class RetroArtConverter:
715
  print(" Face detected but IP-Adapter/embeddings unavailable, using keypoints only")
716
 
717
  else:
718
- # No face, must add a blank image to keep list order
719
- print("Using blank map for InstantID (no face/disabled)")
720
  control_images.append(Image.new("RGB", (target_width, target_height), (0,0,0)))
721
  conditioning_scales.append(0.0) # Set scale to 0
722
- scale_debug_str.append("Identity: 0.00")
723
 
724
  # 2. Depth
725
  if self.depth_active:
 
18
  )
19
  from models import (
20
  load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
21
+ load_sdxl_pipeline, load_loras, setup_ip_adapter, setup_compel,
22
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
23
+ load_openpose_detector, load_mediapipe_face_detector
24
  )
25
 
26
 
 
34
  'custom_checkpoint': False,
35
  'lora': False,
36
  'instantid': False,
37
+ 'depth_detector': False,
38
+ 'depth_type': None,
39
  'ip_adapter': False,
40
+ 'openpose': False,
41
+ 'mediapipe_face': False
42
  }
43
+ self.loaded_loras = {} # Store status of each LORA
44
 
45
+ # Initialize face analysis (InsightFace)
46
  self.face_app, self.face_detection_enabled = load_face_analysis()
47
 
48
+ # Load MediapipeFaceDetector (alternative face detection)
49
+ self.mediapipe_face, mediapipe_success = load_mediapipe_face_detector()
50
+ self.models_loaded['mediapipe_face'] = mediapipe_success
51
+
52
+ # Load Depth detector with fallback hierarchy (Leres → Zoe → Midas)
53
+ self.depth_detector, self.depth_type, depth_success = load_depth_detector()
54
+ self.models_loaded['depth_detector'] = depth_success
55
+ self.models_loaded['depth_type'] = self.depth_type
56
 
57
  # --- NEW: Load OpenPose detector ---
58
  self.openpose_detector, openpose_success = load_openpose_detector()
 
112
 
113
  self.models_loaded['custom_checkpoint'] = checkpoint_success
114
 
115
+ # Load LORAs
116
+ self.loaded_loras, lora_success = load_loras(self.pipe)
117
  self.models_loaded['lora'] = lora_success
118
 
119
  # Setup IP-Adapter
 
163
  """Print model loading status"""
164
  print("\n=== MODEL STATUS ===")
165
  for model, loaded in self.models_loaded.items():
166
+ if model == 'lora':
167
+ lora_status = 'DISABLED'
168
+ if loaded:
169
+ loaded_count = sum(1 for status in self.loaded_loras.values() if status)
170
+ lora_status = f"[OK] LOADED ({loaded_count}/3)"
171
+ print(f"loras: {lora_status}")
172
+ else:
173
+ status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
174
+ print(f"{model}: {status}")
175
  print("===================\n")
176
 
177
  print("=== UPGRADE VERIFICATION ===")
 
197
  print("============================\n")
198
 
199
  def get_depth_map(self, image):
200
+ """
201
+ Generate depth map using available depth detector.
202
+ Supports: LeresDetector, ZoeDetector, or MidasDetector.
203
+ """
204
+ if self.depth_detector is not None:
205
  try:
206
  if image.mode != 'RGB':
207
  image = image.convert('RGB')
 
221
  image_for_depth = image.resize(size_for_depth, Image.LANCZOS)
222
 
223
  if target_width != orig_width or target_height != orig_height:
224
+ print(f"[DEPTH] Resized for {self.depth_type.upper()}Detector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
225
 
226
  # FIXED: Add torch.no_grad() wrapper
227
  with torch.no_grad():
228
+ depth_image = self.depth_detector(image_for_depth)
229
 
230
  depth_width, depth_height = depth_image.size
231
  if depth_width != orig_width or depth_height != orig_height:
232
  depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
233
 
234
+ print(f"[DEPTH] {self.depth_type.upper()} depth map generated: {orig_width}x{orig_height}")
235
  return depth_image
236
 
237
  except Exception as e:
238
+ print(f"[DEPTH] {self.depth_type.upper()}Detector failed ({e}), falling back to grayscale depth")
239
  gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
240
  depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
241
  return Image.fromarray(depth_colored)
242
  else:
243
+ # No depth detector available, use grayscale fallback
244
+ print("[DEPTH] No depth detector available, using grayscale fallback")
245
  gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
246
  depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
247
  return Image.fromarray(depth_colored)
 
502
  depth_control_scale=0.8,
503
  identity_control_scale=0.85,
504
  expression_control_scale=0.6,
505
+ lora_choice="RetroArt",
506
  lora_scale=1.0,
507
  identity_preservation=0.8,
508
  strength=0.75,
 
573
  has_detected_faces = False
574
  face_bbox_original = None
575
 
576
+ if self.instantid_active:
577
+ # Try InsightFace first (if available)
578
+ insightface_tried = False
579
+ insightface_success = False
580
 
581
+ if self.face_app is not None:
582
+ print("Detecting faces with InsightFace...")
583
+ insightface_tried = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
 
585
+ try:
586
+ img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
587
+ faces = self.face_app.get(img_array)
588
+
589
+ if len(faces) > 0:
590
+ insightface_success = True
591
+ has_detected_faces = True
592
+ print(f"✓ InsightFace detected {len(faces)} face(s)")
593
+
594
+ # Get largest face
595
+ face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
596
+
597
+ # ADAPTIVE PARAMETERS
598
+ adaptive_params = self.detect_face_quality(face)
599
+ if adaptive_params is not None:
600
+ print(f"[ADAPTIVE] {adaptive_params['reason']}")
601
+ identity_preservation = adaptive_params['identity_preservation']
602
+ identity_control_scale = adaptive_params['identity_control_scale']
603
+ guidance_scale = adaptive_params['guidance_scale']
604
+ lora_scale = adaptive_params['lora_scale']
605
+
606
+ # Extract face embeddings
607
+ face_embeddings_base = face.normed_embedding
608
+
609
+ # Extract face crop
610
+ bbox = face.bbox.astype(int)
611
+ x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
612
+ face_bbox_original = [x1, y1, x2, y2]
613
+
614
+ # Add padding
615
+ face_width = x2 - x1
616
+ face_height = y2 - y1
617
+ padding_x = int(face_width * 0.3)
618
+ padding_y = int(face_height * 0.3)
619
+ x1 = max(0, x1 - padding_x)
620
+ y1 = max(0, y1 - padding_y)
621
+ x2 = min(resized_image.width, x2 + padding_x)
622
+ y2 = min(resized_image.height, y2 + padding_y)
623
+
624
+ # Crop face region
625
+ face_crop = resized_image.crop((x1, y1, x2, y2))
626
+
627
+ # MULTI-SCALE PROCESSING
628
+ face_embeddings = self.extract_multi_scale_face(face_crop, face)
629
+
630
+ # Enhance face crop
631
+ face_crop_enhanced = enhance_face_crop(face_crop)
632
+
633
+ # Draw keypoints
634
+ face_kps = face.kps
635
+ face_kps_image = draw_kps(resized_image, face_kps)
636
+
637
+ # ENHANCED: Extract comprehensive facial attributes
638
+ from utils import get_facial_attributes, build_enhanced_prompt
639
+ facial_attrs = get_facial_attributes(face)
640
+
641
+ # Update prompt with detected attributes
642
+ prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD)
643
+
644
+ # Legacy output for compatibility
645
+ age = facial_attrs['age']
646
+ gender_code = facial_attrs['gender']
647
+ det_score = facial_attrs['quality']
648
+
649
+ gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
650
+ print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
651
+ print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
652
+ else:
653
+ print("✗ InsightFace found no faces")
654
+
655
+ except Exception as e:
656
+ print(f"[ERROR] InsightFace detection failed: {e}")
657
+ import traceback
658
+ traceback.print_exc()
659
+ else:
660
+ print("[INFO] InsightFace not available (face_app is None)")
661
+
662
+ # If InsightFace didn't succeed, try MediapipeFace
663
+ if not insightface_success:
664
+ if self.mediapipe_face is not None:
665
+ print("Trying MediapipeFaceDetector as fallback...")
666
+
667
+ try:
668
+ # MediapipeFace returns an annotated image with keypoints
669
+ mediapipe_result = self.mediapipe_face(resized_image)
670
+
671
+ # Check if face was detected (result is not blank/black)
672
+ mediapipe_array = np.array(mediapipe_result)
673
+ if mediapipe_array.sum() > 1000: # If image has significant content
674
+ has_detected_faces = True
675
+ face_kps_image = mediapipe_result
676
+ print(f"✓ MediapipeFace detected face(s)")
677
+ print(f"[INFO] Using MediapipeFace keypoints (no embeddings available)")
678
+
679
+ # Note: MediapipeFace doesn't provide embeddings or detailed info
680
+ # So face_embeddings, face_crop_enhanced remain None
681
+ # InstantID will work with keypoints only (reduced quality)
682
+ else:
683
+ print("✗ MediapipeFace found no faces")
684
+ except Exception as e:
685
+ print(f"[ERROR] MediapipeFace detection failed: {e}")
686
+ import traceback
687
+ traceback.print_exc()
688
+ else:
689
+ print("[INFO] MediapipeFaceDetector not available")
690
+
691
+ # Final summary
692
+ if not has_detected_faces:
693
+ print("\n[SUMMARY] No faces detected by any detector")
694
+ if insightface_tried:
695
+ print(" - InsightFace: tried, found nothing")
696
+ else:
697
+ print(" - InsightFace: not available")
698
+
699
+ if self.mediapipe_face is not None:
700
+ print(" - MediapipeFace: tried, found nothing")
701
+ else:
702
+ print(" - MediapipeFace: not available")
703
+ print()
704
 
705
+ # Set LORA
706
  if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
707
+ adapter_name = lora_choice.lower() # "retroart", "vga", "lucasart", or "none"
708
+
709
+ if adapter_name != "none" and self.loaded_loras.get(adapter_name, False):
710
+ try:
711
+ self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
712
+ print(f"LORA: Set adapter '{adapter_name}' with scale: {lora_scale}")
713
+ except Exception as e:
714
+ print(f"Could not set LORA adapter '{adapter_name}': {e}")
715
+ self.pipe.set_adapters([]) # Disable LORAs if setting failed
716
+ else:
717
+ if adapter_name == "none":
718
+ print("LORAs disabled by user choice.")
719
+ else:
720
+ print(f"LORA '{adapter_name}' not loaded or available, disabling LORAs.")
721
+ self.pipe.set_adapters([]) # Disable all LORAs
722
+
723
 
724
  # Prepare generation kwargs
725
  pipe_kwargs = {
 
808
  print(" Face detected but IP-Adapter/embeddings unavailable, using keypoints only")
809
 
810
  else:
811
+ # No face detected - blank map needed to maintain ControlNet list order
812
+ print("[INSTANTID] Using blank map (scale=0, no effect on generation)")
813
  control_images.append(Image.new("RGB", (target_width, target_height), (0,0,0)))
814
  conditioning_scales.append(0.0) # Set scale to 0
815
+ scale_debug_str.append("Identity: 0.00 (no face)")
816
 
817
  # 2. Depth
818
  if self.depth_active: