primerz committed on
Commit 69e6233 · verified · 1 Parent(s): b76c724

Upload 10 files

Files changed (5)
  1. app.py +2 -2
  2. config.py +7 -7
  3. generator.py +118 -83
  4. models.py +27 -67
  5. utils.py +9 -9
app.py CHANGED
@@ -106,7 +106,7 @@ def get_model_status():
         status_text += f"- Custom Checkpoint (Horizon): {'[OK] Loaded' if converter.models_loaded['custom_checkpoint'] else '[OK] Using SDXL base'}\n"
         status_text += f"- LORA (RetroArt): {'[OK] Loaded' if converter.models_loaded['lora'] else ' Disabled'}\n"
         status_text += f"- InstantID: {'[OK] Loaded' if converter.models_loaded['instantid'] else ' Disabled'}\n"
-        status_text += f"- Depth: Grayscale (simple & reliable)\n"
+        status_text += f"- Zoe Depth: {'[OK] Loaded' if converter.models_loaded['zoe_depth'] else ' Fallback'}\n"
         status_text += f"- IP-Adapter (Face Embeddings): {'[OK] Loaded' if converter.models_loaded.get('ip_adapter', False) else ' Keypoints only'}\n"
         return status_text
     return "**Model status unavailable**"
@@ -351,7 +351,7 @@ with gr.Blocks(title="Pixagram - AI Pixel Art Generator", theme=gr.themes.Soft()
     **[ADAPTIVE] Automatic Adjustments:**
     - Small faces (< 50K px): Boosts identity preservation to 1.8
     - Low confidence (< 80%): Increases identity control to 0.9
-    - Profile views (> 20° yaw): Enhances preservation to 1.7
+    - Profile views (> 20° yaw): Enhances preservation to 1.7
     - Good quality faces: Uses your selected parameters
 
     **[PARAMETERS] Parameter Relationships:**
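For readers wondering what the adaptive adjustments in that help text correspond to, here is a hypothetical sketch of the threshold logic they describe. The function and parameter names are illustrative only; the real logic lives elsewhere in the repo and is not part of this commit.

# Hypothetical sketch of the adaptive adjustments listed above.
def adapt_identity_params(face_area_px, confidence, yaw_deg,
                          preservation, control):
    if face_area_px < 50_000:      # small face: boost identity preservation
        preservation = max(preservation, 1.8)
    if confidence < 0.80:          # low detection confidence: raise control
        control = max(control, 0.9)
    if abs(yaw_deg) > 20:          # profile view: enhance preservation
        preservation = max(preservation, 1.7)
    return preservation, control   # good-quality faces pass through unchanged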
config.py CHANGED
@@ -24,18 +24,18 @@ TRIGGER_WORD = "p1x3l4rt, pixel art"
 
 # Face detection configuration
 FACE_DETECTION_CONFIG = {
-    "model_name": "buffalo_l",
+    "model_name": "antelopev2",
     "det_size": (640, 640),
     "ctx_id": 0
 }
 
-# Recommended resolutions (multiples of 64 for stable diffusion)
+# Recommended resolutions
 RECOMMENDED_SIZES = [
-    (896, 1152),   # Portrait (14:18 ratio)
-    (1152, 896),   # Landscape (18:14 ratio)
-    (832, 1216),   # Tall portrait (13:19 ratio)
-    (1216, 832),   # Wide landscape (19:13 ratio)
-    (1024, 1024)   # Square (1:1 ratio)
+    (896, 1152),   # Portrait
+    (1152, 896),   # Landscape
+    (832, 1216),   # Tall portrait
+    (1216, 832),   # Wide landscape
+    (1024, 1024)   # Square
 ]
 
 # Default generation parameters
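For context on the buffalo_l to antelopev2 swap: antelopev2 is the insightface model pack that InstantID expects. A minimal sketch of how a config like FACE_DETECTION_CONFIG is typically consumed, assuming the standard insightface.app.FaceAnalysis API; the repo's own load_face_analysis() likely does something similar, but this exact body is illustrative only.

# Sketch: wiring FACE_DETECTION_CONFIG into insightface (illustrative).
from insightface.app import FaceAnalysis

from config import FACE_DETECTION_CONFIG  # hypothetical import path

def load_face_analysis():
    try:
        app = FaceAnalysis(name=FACE_DETECTION_CONFIG["model_name"])
        app.prepare(
            ctx_id=FACE_DETECTION_CONFIG["ctx_id"],      # GPU index; -1 = CPU
            det_size=FACE_DETECTION_CONFIG["det_size"],  # detector input size
        )
        return app, True
    except Exception as e:
        print(f"[WARNING] Face analysis unavailable: {e}")
        return None, False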
generator.py CHANGED
@@ -33,16 +33,16 @@ class RetroArtConverter:
             'custom_checkpoint': False,
             'lora': False,
             'instantid': False,
-            'leres_depth': False,
+            'zoe_depth': False,
             'ip_adapter': False
         }
 
         # Initialize face analysis
         self.face_app, self.face_detection_enabled = load_face_analysis()
 
-        # Skip depth detector - using grayscale conversion instead
-        self.leres_depth = None
-        self.models_loaded['leres_depth'] = False
+        # Load Zoe Depth detector
+        self.zoe_depth, zoe_success = load_depth_detector()
+        self.models_loaded['zoe_depth'] = zoe_success
 
         # Load ControlNets
         controlnet_depth, self.controlnet_instantid, instantid_success = load_controlnets()
@@ -81,7 +81,6 @@ class RetroArtConverter:
             self.models_loaded['ip_adapter'] = False
             self.image_proj_model = None
 
-        # Setup Compel
         # Setup Compel
         self.compel, self.use_compel = setup_compel(self.pipe)
 
@@ -147,29 +146,48 @@
         print("============================\n")
 
     def get_depth_map(self, image):
-        """Generate depth map using grayscale conversion for reliability"""
-        try:
-            # Ensure RGB mode
-            if image.mode != 'RGB':
-                image = image.convert('RGB')
-
-            # Convert to grayscale for depth
-            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
-
-            # Apply some enhancement to make depth more pronounced
-            gray = cv2.equalizeHist(gray)
-
-            # Convert back to RGB format (ControlNet expects RGB)
-            depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
-            depth_image = Image.fromarray(depth_colored)
-
-            print(f"[DEPTH] Grayscale depth map generated: {image.size}")
-            return depth_image
-        except Exception as e:
-            print(f"[DEPTH] Depth generation failed ({e}), using basic grayscale")
+        """Generate depth map using Zoe Depth"""
+        if self.zoe_depth is not None:
+            try:
+                if image.mode != 'RGB':
+                    image = image.convert('RGB')
+
+                orig_width, orig_height = image.size
+                orig_width = int(orig_width)
+                orig_height = int(orig_height)
+
+                # FIXED: Use multiples of 64 (not 32)
+                target_width = int((orig_width // 64) * 64)
+                target_height = int((orig_height // 64) * 64)
+
+                target_width = int(max(64, target_width))
+                target_height = int(max(64, target_height))
+
+                if target_width != orig_width or target_height != orig_height:
+                    image = image.resize((int(target_width), int(target_height)), Image.LANCZOS)
+                    print(f"[DEPTH] Resized for ZoeDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
+
+                # FIXED: Add torch.no_grad() wrapper
+                with torch.no_grad():
+                    depth_image = self.zoe_depth(image)
+
+                depth_width, depth_height = depth_image.size
+                if depth_width != orig_width or depth_height != orig_height:
+                    depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
+
+                print(f"[DEPTH] Zoe depth map generated: {orig_width}x{orig_height}")
+                return depth_image
+
+            except Exception as e:
+                print(f"[DEPTH] ZoeDetector failed ({e}), falling back to grayscale depth")
+                gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+                depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
+                return Image.fromarray(depth_colored)
+        else:
             gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
             depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
             return Image.fromarray(depth_colored)
+
 
     def add_trigger_word(self, prompt):
         """Add trigger word to prompt if not present"""
@@ -443,7 +461,7 @@ class RetroArtConverter:
         resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
 
         # Generate depth map
-        print("Generating grayscale depth map...")
+        print("Generating Zoe depth map...")
         depth_image = self.get_depth_map(resized_image)
         if depth_image.size != (target_width, target_height):
             depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
@@ -524,34 +542,13 @@
         print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
         print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
 
-        # Set LORA scale - use fuse_lora for immediate effect
-        if hasattr(self.pipe, 'fuse_lora') and self.models_loaded['lora']:
+        # Set LORA scale
+        if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
             try:
-                self.pipe.fuse_lora(lora_scale=lora_scale)
-                print(f"[LORA] Fused with scale: {lora_scale}")
+                self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
+                print(f"LORA scale: {lora_scale}")
             except Exception as e:
-                print(f"[WARNING] LORA fuse failed: {e}")
-                # Try set_adapters as fallback
-                try:
-                    for adapter_name in ["retroart", "default_0"]:
-                        try:
-                            self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
-                            print(f"[LORA] Set adapter '{adapter_name}' with scale: {lora_scale}")
-                            break
-                        except:
-                            continue
-                except Exception as e2:
-                    print(f"[WARNING] LORA set_adapters also failed: {e2}")
-
-            except Exception as e:
-                print(f"[WARNING] LORA set_adapters failed: {e}")
-                # Try fuse_lora as fallback
-                try:
-                    if hasattr(self.pipe, 'fuse_lora'):
-                        self.pipe.fuse_lora(lora_scale=lora_scale)
-                        print(f"[LORA] Fused with scale: {lora_scale}")
-                except Exception as e2:
-                    print(f"[INFO] LORA using default scale")
+                print(f"Could not set LORA scale: {e}")
 
         # Prepare generation kwargs
         pipe_kwargs = {
@@ -573,37 +570,76 @@
 
         pipe_kwargs["generator"] = generator
 
-        # Use Compel for prompt encoding (critical for quality)
-        negative_conditioning = None  # Initialize for later use
         if self.use_compel and self.compel is not None:
             try:
                 print("Encoding prompts with Compel...")
 
-                # Direct tuple unpacking as in working example
-                conditioning, pooled = self.compel(prompt)
-
-                # Handle negative prompt conditionally
-                if negative_prompt and negative_prompt.strip():
-                    negative_conditioning, negative_pooled = self.compel(negative_prompt)
-                else:
-                    negative_conditioning, negative_pooled = None, None
-
-                # Set embeddings for pipeline
-                pipe_kwargs["prompt_embeds"] = conditioning
-                pipe_kwargs["pooled_prompt_embeds"] = pooled
-                pipe_kwargs["negative_prompt_embeds"] = negative_conditioning
-                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_pooled
+                try:
+                    # Tuple unpacking: (prompt_embeds, pooled_prompt_embeds)
+                    conditioning = self.compel(prompt)
+                    prompt_embeds, pooled_prompt_embeds = conditioning
+
+                    # Handle negative prompt conditionally
+                    if negative_prompt and negative_prompt.strip():
+                        negative_conditioning = self.compel(negative_prompt)
+                        negative_prompt_embeds, negative_pooled_prompt_embeds = negative_conditioning
+                    else:
+                        # Use zeros for negative
+                        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+                        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+
+                except RuntimeError as e:
+                    error_msg = str(e)
+                    if ("size of tensor" in error_msg and "must match" in error_msg) or "dimension" in error_msg:
+                        print(f"[COMPEL] Token length mismatch detected: {e}")
+                        print(f"[COMPEL] Falling back to standard prompt encoding")
+                        raise
+                    else:
+                        raise
+
+                # Handle token length mismatch by padding/truncating to 77 tokens
+                target_length = 77
 
+                if prompt_embeds.shape[1] != target_length or negative_prompt_embeds.shape[1] != target_length:
+                    print(f"[COMPEL] Adjusting token lengths: pos={prompt_embeds.shape[1]}, neg={negative_prompt_embeds.shape[1]} -> {target_length}")
+
+                    # Truncate or pad positive embeddings
+                    if prompt_embeds.shape[1] > target_length:
+                        prompt_embeds = prompt_embeds[:, :target_length, :]
+                    elif prompt_embeds.shape[1] < target_length:
+                        padding = torch.zeros(
+                            prompt_embeds.shape[0],
+                            target_length - prompt_embeds.shape[1],
+                            prompt_embeds.shape[2],
+                            dtype=prompt_embeds.dtype,
+                            device=prompt_embeds.device
+                        )
+                        prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
+
+                    # Truncate or pad negative embeddings
+                    if negative_prompt_embeds.shape[1] > target_length:
+                        negative_prompt_embeds = negative_prompt_embeds[:, :target_length, :]
+                    elif negative_prompt_embeds.shape[1] < target_length:
+                        padding = torch.zeros(
+                            negative_prompt_embeds.shape[0],
+                            target_length - negative_prompt_embeds.shape[1],
+                            negative_prompt_embeds.shape[2],
+                            dtype=negative_prompt_embeds.dtype,
+                            device=negative_prompt_embeds.device
+                        )
+                        negative_prompt_embeds = torch.cat([negative_prompt_embeds, padding], dim=1)
 
+                pipe_kwargs["prompt_embeds"] = prompt_embeds
+                pipe_kwargs["pooled_prompt_embeds"] = pooled_prompt_embeds
+                pipe_kwargs["negative_prompt_embeds"] = negative_prompt_embeds
+                pipe_kwargs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
 
+                compel_success = True
                 print("[OK] Using Compel-encoded prompts")
             except Exception as e:
-                print(f"[FALLBACK] Compel failed ({e}), using standard encoding")
-                pipe_kwargs["prompt"] = prompt
-                pipe_kwargs["negative_prompt"] = negative_prompt if negative_prompt and negative_prompt.strip() else None
-        else:
-            # Fallback to native SDXL encoding
-            print("Using standard SDXL prompt encoding...")
-            pipe_kwargs["prompt"] = prompt
-            pipe_kwargs["negative_prompt"] = negative_prompt if negative_prompt and negative_prompt.strip() else None
+                print(f"[COMPEL] Encoding failed: {e}")
+                print(f"[COMPEL] Using standard prompt encoding instead")
+                compel_success = False
 
         # Add CLIP skip
         if hasattr(self.pipe, 'text_encoder'):
@@ -632,7 +668,7 @@
             # Reshape for Resampler: [1, 1, 512]
             face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
 
-            # Pass through Resampler: [1, 1, 512] → [1, 16, 2048]
+            # Pass through Resampler: [1, 1, 512] → [1, 16, 2048]
             face_proj_embeds = self.image_proj_model(face_emb_tensor)
 
             # Scale with identity preservation
@@ -643,13 +679,13 @@
             print(f" - Resampler output: {face_proj_embeds.shape}")
             print(f" - Scale: {boosted_scale:.2f}")
 
-            # Handle face embeddings with or without Compel
+            # CRITICAL: Concatenate with text embeddings (not separate kwargs!)
             if 'prompt_embeds' in pipe_kwargs:
-                # Compel is being used - concatenate embeddings
+                # Compel encoded prompts
                 original_embeds = pipe_kwargs['prompt_embeds']
 
                 # Handle CFG (classifier-free guidance)
-                if negative_conditioning is not None:
+                if original_embeds.shape[0] > 1:  # Has negative + positive
                     # Duplicate for negative + positive
                     face_proj_embeds = torch.cat([
                         torch.zeros_like(face_proj_embeds),  # Negative
@@ -662,11 +698,10 @@
 
                 print(f" - Text embeds: {original_embeds.shape}")
                 print(f" - Combined embeds: {combined_embeds.shape}")
-                print(f" [OK] Face embeddings concatenated with text embeddings!")
+                print(f" [OK] Face embeddings concatenated successfully!")
 
             else:
-                # Native encoding - use image_embeds parameter
-                pipe_kwargs['image_embeds'] = face_proj_embeds
-                print(f" [OK] Face embeddings set via image_embeds!")
+                print(f" [WARNING] Can't concatenate - no prompt_embeds (use Compel)")
 
         elif has_detected_faces and self.models_loaded.get('ip_adapter', False):
             # Face detected but embeddings unavailable
@@ -721,4 +756,4 @@
         return generated_image
 
 
-print("[OK] Generator class ready")
+print("[OK] Generator class ready")
models.py CHANGED
@@ -13,7 +13,7 @@ from diffusers import (
 from diffusers.models.attention_processor import AttnProcessor2_0
 from transformers import CLIPVisionModelWithProjection
 from insightface.app import FaceAnalysis
-from controlnet_aux import LeresDetector
+from controlnet_aux import ZoeDetector
 from huggingface_hub import hf_hub_download
 from compel import Compel, ReturnedEmbeddingsType
 
@@ -82,15 +82,15 @@ def load_face_analysis():
 
 
 def load_depth_detector():
-    """Load Leres Depth detector for better quality."""
-    print("Loading Leres Depth detector...")
+    """Load Zoe Depth detector."""
+    print("Loading Zoe Depth detector...")
     try:
-        leres_depth = LeresDetector.from_pretrained("lllyasviel/Annotators")
-        leres_depth.to(device)
-        print(" [OK] Leres Depth loaded successfully")
-        return leres_depth, True
+        zoe_depth = ZoeDetector.from_pretrained("lllyasviel/Annotators")
+        zoe_depth.to(device)
+        print(" [OK] Zoe Depth loaded successfully")
+        return zoe_depth, True
     except Exception as e:
-        print(f" [WARNING] Leres Depth not available: {e}")
+        print(f" [WARNING] Zoe Depth not available: {e}")
         return None, False
 
 
@@ -164,19 +164,12 @@ def load_lora(pipe):
     print("Loading LORA (retroart) from HuggingFace Hub...")
     try:
         lora_path = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
-        # Load with explicit adapter name to avoid default_0
-        pipe.load_lora_weights(lora_path, adapter_name="retroart")
-        print(f" [OK] LORA loaded successfully as 'retroart' adapter")
+        pipe.load_lora_weights(lora_path)
+        print(f" [OK] LORA loaded successfully")
         return True
     except Exception as e:
-        # Fallback to default loading
-        try:
-            pipe.load_lora_weights(lora_path)
-            print(f" [OK] LORA loaded successfully (default adapter)")
-            return True
-        except Exception as e2:
-            print(f" [WARNING] Could not load LORA: {e2}")
-            return False
+        print(f" [WARNING] Could not load LORA: {e}")
+        return False
 
 
 def setup_ip_adapter(pipe, image_encoder):
@@ -198,29 +191,15 @@ def setup_ip_adapter(pipe, image_encoder):
     # Load full state dict
     state_dict = torch.load(ip_adapter_path, map_location="cpu")
 
-    # Debug: Print available keys
-    print(f"[DEBUG] State dict keys sample: {list(state_dict.keys())[:5]}")
-
-    # Extract image_proj and ip_adapter weights with flexible key matching
+    # Extract image_proj and ip_adapter weights
     image_proj_state_dict = {}
     ip_adapter_state_dict = {}
 
     for key, value in state_dict.items():
-        # Handle different possible key formats
-        if "image_proj" in key:
-            # Remove any prefix before image_proj
-            clean_key = key.split("image_proj.")[-1] if "image_proj." in key else key
-            image_proj_state_dict[clean_key] = value
-        elif "ip_adapter" in key or "to_k_ip" in key or "to_v_ip" in key:
-            # IP adapter weights might not have prefix
-            if "ip_adapter." in key:
-                clean_key = key.replace("ip_adapter.", "")
-            else:
-                clean_key = key
-            ip_adapter_state_dict[clean_key] = value
-
-    print(f"[DEBUG] Found {len(image_proj_state_dict)} image_proj weights")
-    print(f"[DEBUG] Found {len(ip_adapter_state_dict)} ip_adapter weights")
+        if key.startswith("image_proj."):
+            image_proj_state_dict[key.replace("image_proj.", "")] = value
+        elif key.startswith("ip_adapter."):
+            ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value
 
     # Create Resampler (image projection model) with CORRECT parameters from reference
     print("Creating Resampler (Perceiver architecture)...")
@@ -241,25 +220,13 @@ def setup_ip_adapter(pipe, image_encoder):
     # Load image_proj weights
     if image_proj_state_dict:
         try:
-            # Check if weights are nested under 'image_proj' key
-            if 'image_proj' in image_proj_state_dict and isinstance(image_proj_state_dict['image_proj'], dict):
-                actual_weights = image_proj_state_dict['image_proj']
-            else:
-                actual_weights = image_proj_state_dict
-
-            # Try loading the weights
-            missing, unexpected = image_proj_model.load_state_dict(actual_weights, strict=False)
+            image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
             print(" [OK] Resampler loaded with pretrained weights")
-            if missing:
-                print(f" Missing keys: {missing[:5]}...")
-            if unexpected:
-                print(f" Unexpected keys: {unexpected[:5]}...")
         except Exception as e:
             print(f" [WARNING] Could not load Resampler weights: {e}")
             print(" Using randomly initialized Resampler")
     else:
-        print(" [WARNING] No image_proj weights found in state dict")
-        print(" Using randomly initialized Resampler")
+        print(" [WARNING] No image_proj weights found, using random initialization")
 
     # Setup IP-Adapter attention processors
     print("Setting up IP-Adapter attention processors...")
@@ -293,30 +260,23 @@ def setup_ip_adapter(pipe, image_encoder):
     # Set attention processors
     pipe.unet.set_attn_processor(attn_procs)
 
-    # Load IP-Adapter weights into attention processors (optional - face preservation works without it)
+    # Load IP-Adapter weights into attention processors
     if ip_adapter_state_dict:
         try:
-            # Count successfully loaded processors
-            loaded_count = 0
-            for name, processor in pipe.unet.attn_processors.items():
-                if hasattr(processor, 'to_k_ip') and hasattr(processor, 'to_v_ip'):
-                    loaded_count += 1
-
-            if loaded_count > 0:
-                print(f" [OK] Found {loaded_count} IP-Adapter processors ready")
-            print(" [INFO] IP-Adapter weights available but skipping complex loading")
-            print(" Face preservation will use ControlNet + Resampler embeddings")
+            ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
+            ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
+            print(" [OK] IP-Adapter attention weights loaded")
         except Exception as e:
-            pass
+            print(f" [WARNING] Could not load IP-Adapter weights: {e}")
     else:
-        print(" [INFO] No IP-Adapter weights found")
+        print(" [WARNING] No ip_adapter weights found")
 
     # Store image encoder and projection model
     pipe.image_encoder = image_encoder
 
     print(" [OK] IP-Adapter fully loaded with InstantID architecture")
     print(f" - Resampler: 4 layers, 20 heads, 16 output tokens")
-    print(f" - Face embeddings: 512D → 16x2048D")
+    print(f" - Face embeddings: 512D → 16x2048D")
 
     return image_proj_model, True
 
@@ -328,7 +288,7 @@ def setup_ip_adapter(pipe, image_encoder):
 
 
 def setup_compel(pipe):
-    """Setup Compel for SDXL prompt handling - based on working example."""
+    """Setup Compel for better SDXL prompt handling."""
     print("Setting up Compel for enhanced prompt processing...")
     try:
         compel = Compel(
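The simplified key handling in setup_ip_adapter() assumes the checkpoint follows the standard InstantID layout, where every key carries an image_proj. or ip_adapter. prefix. A toy dict (strings standing in for the real tensors) makes the split easy to verify:

# Toy demonstration of the prefix split used in setup_ip_adapter().
state_dict = {
    "image_proj.latents": "tensor...",
    "image_proj.proj_in.weight": "tensor...",
    "ip_adapter.1.to_k_ip.weight": "tensor...",
    "ip_adapter.1.to_v_ip.weight": "tensor...",
}

image_proj_state_dict = {}
ip_adapter_state_dict = {}

for key, value in state_dict.items():
    if key.startswith("image_proj."):
        image_proj_state_dict[key.replace("image_proj.", "")] = value
    elif key.startswith("ip_adapter."):
        ip_adapter_state_dict[key.replace("ip_adapter.", "")] = value

print(sorted(image_proj_state_dict))  # ['latents', 'proj_in.weight']
print(sorted(ip_adapter_state_dict))  # ['1.to_k_ip.weight', '1.to_v_ip.weight']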
utils.py CHANGED
@@ -395,10 +395,10 @@ def get_demographic_description(age, gender_code):
 
 def calculate_optimal_size(original_width, original_height, recommended_sizes=None, max_dimension=1536):
     """
-    Calculate optimal size maintaining aspect ratio with dimensions as multiples of 64.
+    Calculate optimal size maintaining aspect ratio with dimensions as multiples of 8.
 
     This updated version supports ANY aspect ratio (not just predefined ones),
-    while ensuring dimensions are multiples of 64 and keeping total pixels reasonable.
+    while ensuring dimensions are multiples of 8 and keeping total pixels reasonable.
 
     Args:
         original_width: Original image width
@@ -407,7 +407,7 @@ def calculate_optimal_size(original_width, original_height, recommended_sizes=No
         max_dimension: Maximum allowed dimension (default 1536)
 
     Returns:
-        Tuple of (optimal_width, optimal_height) as multiples of 64
+        Tuple of (optimal_width, optimal_height) as multiples of 8
     """
     aspect_ratio = original_width / original_height
 
@@ -423,7 +423,7 @@ def calculate_optimal_size(original_width, original_height, recommended_sizes=No
                 best_diff = diff
                 best_match = (width, height)
 
-        # Ensure dimensions are multiples of 64
+        # Ensure dimensions are multiples of 8
         width, height = best_match
         width = int((width // 64) * 64)
         height = int((height // 64) * 64)
@@ -431,7 +431,7 @@ def calculate_optimal_size(original_width, original_height, recommended_sizes=No
         return width, height
 
     # NEW: Support any aspect ratio
-    # Strategy: Keep aspect ratio, scale to reasonable total pixels, round to multiples of 64
+    # Strategy: Keep aspect ratio, scale to reasonable total pixels, round to multiples of 8
 
     # Target total pixels (around 1 megapixel for SDXL, adjustable)
     target_pixels = 1024 * 1024  # ~1MP, good balance for SDXL
@@ -455,7 +455,7 @@ def calculate_optimal_size(original_width, original_height, recommended_sizes=No
         optimal_height = max_dimension
         optimal_width = optimal_height * aspect_ratio
 
-    # Round to nearest multiple of 64
+    # Round to nearest multiple of 8
     width = int(round(optimal_width / 64) * 64)
     height = int(round(optimal_height / 64) * 64)
 
@@ -469,9 +469,9 @@ def calculate_optimal_size(original_width, original_height, recommended_sizes=No
         height = min_dimension
         width = int(round((height * aspect_ratio) / 64) * 64)
 
-    # Final safety check: ensure multiples of 64
-    width = max(64, int((width // 64) * 64))
-    height = max(64, int((height // 64) * 64))
+    # Final safety check: ensure multiples of 8
+    width = max(8, int((width // 64) * 64))
+    height = max(8, int((height // 64) * 64))
 
     print(f"[SIZING] Aspect ratio: {aspect_ratio:.3f}, Output: {width}x{height} ({width*height/1e6:.2f}MP)")
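Note that the comments in this hunk now say "multiples of 8" while the arithmetic still snaps to 64 (the // 64 * 64 expressions are unchanged). A quick worked example of what the code actually produces; the snap() helper is hypothetical, written only to mirror the rounding expression above:

# Worked example of the rounding in calculate_optimal_size().
def snap(value, multiple=64):
    # Mirrors max(64, (value // 64) * 64) from the function above.
    return max(multiple, int((value // multiple) * multiple))

for w, h in [(1000, 750), (1920, 1080), (500, 500)]:
    print(f"{w}x{h} -> {snap(w)}x{snap(h)}")
# 1000x750 -> 960x704
# 1920x1080 -> 1920x1024
# 500x500 -> 448x448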