primerz committed on
Commit
8cf9ae9
·
verified ·
1 Parent(s): 06a5771

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +41 -80
generator.py CHANGED
@@ -149,42 +149,15 @@ class RetroArtConverter:
149
  """Generate depth map using Zoe Depth"""
150
  if self.zoe_depth is not None:
151
  try:
152
- # Ensure clean PIL Image with proper dimensions
153
  if image.mode != 'RGB':
154
  image = image.convert('RGB')
155
 
156
- # Get dimensions - ensure they're Python ints (not numpy)
157
- orig_width, orig_height = image.size
158
- # Force conversion to Python int to avoid numpy types
159
- orig_width = int(orig_width.item() if hasattr(orig_width, 'item') else orig_width)
160
- orig_height = int(orig_height.item() if hasattr(orig_height, 'item') else orig_height)
161
-
162
- # Resize to dimensions ZoeDetector expects (multiples of 32)
163
- # CRITICAL: Ensure Python int, not numpy types
164
- target_width = int((orig_width // 32) * 32)
165
- target_height = int((orig_height // 32) * 32)
166
-
167
- # Ensure at least 32x32
168
- target_width = int(max(32, target_width))
169
- target_height = int(max(32, target_height))
170
-
171
- if target_width != orig_width or target_height != orig_height:
172
- # CRITICAL: Pass explicit Python ints to resize
173
- image = image.resize((int(target_width), int(target_height)), Image.LANCZOS)
174
- print(f"[DEPTH] Resized for ZoeDetector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
175
-
176
- # Use Zoe detector - now with safe dimensions
177
  depth_image = self.zoe_depth(image)
178
 
179
- # Resize back to original if needed
180
- depth_width, depth_height = depth_image.size
181
- # Ensure Python ints (not numpy)
182
- depth_width = int(depth_width)
183
- depth_height = int(depth_height)
184
- if depth_width != orig_width or depth_height != orig_height:
185
- depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
186
-
187
- print(f"[DEPTH] Zoe depth map generated: {orig_width}x{orig_height}")
188
  return depth_image
189
 
190
  except Exception as e:
@@ -622,58 +595,45 @@ class RetroArtConverter:
622
  try:
623
  print("Encoding prompts with Compel...")
624
 
625
- # Try to encode both prompts
626
- try:
627
- conditioning = self.compel(prompt)
628
- negative_conditioning = self.compel(negative_prompt)
629
- except RuntimeError as e:
630
- # Token length mismatch during encoding - this is a known SDXL+Compel issue
631
- error_msg = str(e)
632
- if ("size of tensor" in error_msg and "must match" in error_msg) or "dimension" in error_msg:
633
- print(f"[COMPEL] Token length mismatch detected: {e}")
634
- print(f"[COMPEL] Falling back to standard prompt encoding")
635
- raise # Raise to outer except to use standard prompts
636
- else:
637
- raise # Re-raise if it's a different error
638
 
639
- # Extract embeddings
640
  prompt_embeds = conditioning[0]
641
  pooled_prompt_embeds = conditioning[1]
642
  negative_prompt_embeds = negative_conditioning[0]
643
  negative_pooled_prompt_embeds = negative_conditioning[1]
644
 
645
- # Handle token length mismatch by padding/truncating to 77 tokens (SDXL standard)
646
- target_length = 77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
647
 
648
- # Check and fix length mismatches
649
- if prompt_embeds.shape[1] != target_length or negative_prompt_embeds.shape[1] != target_length:
650
- print(f"[COMPEL] Adjusting token lengths: pos={prompt_embeds.shape[1]}, neg={negative_prompt_embeds.shape[1]} -> {target_length}")
651
-
652
- # Truncate or pad positive embeddings
653
- if prompt_embeds.shape[1] > target_length:
654
- prompt_embeds = prompt_embeds[:, :target_length, :]
655
- elif prompt_embeds.shape[1] < target_length:
656
- padding = torch.zeros(
657
- prompt_embeds.shape[0],
658
- target_length - prompt_embeds.shape[1],
659
- prompt_embeds.shape[2],
660
- dtype=prompt_embeds.dtype,
661
- device=prompt_embeds.device
662
- )
663
- prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
664
-
665
- # Truncate or pad negative embeddings
666
- if negative_prompt_embeds.shape[1] > target_length:
667
- negative_prompt_embeds = negative_prompt_embeds[:, :target_length, :]
668
- elif negative_prompt_embeds.shape[1] < target_length:
669
- padding = torch.zeros(
670
- negative_prompt_embeds.shape[0],
671
- target_length - negative_prompt_embeds.shape[1],
672
- negative_prompt_embeds.shape[2],
673
- dtype=negative_prompt_embeds.dtype,
674
- device=negative_prompt_embeds.device
675
- )
676
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, padding], dim=1)
677
 
678
  pipe_kwargs["prompt_embeds"] = prompt_embeds
679
  pipe_kwargs["pooled_prompt_embeds"] = pooled_prompt_embeds
@@ -681,10 +641,11 @@ class RetroArtConverter:
681
  pipe_kwargs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
682
 
683
  compel_success = True
684
- print("[OK] Using Compel-encoded prompts")
 
685
  except Exception as e:
686
- print(f"[COMPEL] Encoding failed: {e}")
687
- print(f"[COMPEL] Using standard prompt encoding instead")
688
  compel_success = False
689
 
690
  # Use standard prompts if Compel failed or not available
@@ -719,7 +680,7 @@ class RetroArtConverter:
719
  # Reshape for Resampler: [1, 1, 512]
720
  face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
721
 
722
- # Pass through Resampler: [1, 1, 512] [1, 16, 2048]
723
  face_proj_embeds = self.image_proj_model(face_emb_tensor)
724
 
725
  # Scale with identity preservation
 
149
  """Generate depth map using Zoe Depth"""
150
  if self.zoe_depth is not None:
151
  try:
152
+ # Ensure RGB mode
153
  if image.mode != 'RGB':
154
  image = image.convert('RGB')
155
 
156
+ # ZoeDetector handles resizing internally - just call it
157
+ # It returns PIL Image matching input size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  depth_image = self.zoe_depth(image)
159
 
160
+ print(f"[DEPTH] Zoe depth map generated: {image.size[0]}x{image.size[1]}")
 
 
 
 
 
 
 
 
161
  return depth_image
162
 
163
  except Exception as e:
 
595
  try:
596
  print("Encoding prompts with Compel...")
597
 
598
+ # Encode prompts
599
+ conditioning = self.compel(prompt)
600
+ negative_conditioning = self.compel(negative_prompt)
 
 
 
 
 
 
 
 
 
 
601
 
602
+ # Extract embeddings - Compel returns (prompt_embeds, pooled_embeds)
603
  prompt_embeds = conditioning[0]
604
  pooled_prompt_embeds = conditioning[1]
605
  negative_prompt_embeds = negative_conditioning[0]
606
  negative_pooled_prompt_embeds = negative_conditioning[1]
607
 
608
+ # Ensure consistent shapes (SDXL uses 77 tokens max)
609
+ max_length = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
610
+
611
+ # Pad if needed
612
+ if prompt_embeds.shape[1] < max_length:
613
+ padding = torch.zeros(
614
+ prompt_embeds.shape[0],
615
+ max_length - prompt_embeds.shape[1],
616
+ prompt_embeds.shape[2],
617
+ dtype=prompt_embeds.dtype,
618
+ device=prompt_embeds.device
619
+ )
620
+ prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
621
+
622
+ if negative_prompt_embeds.shape[1] < max_length:
623
+ padding = torch.zeros(
624
+ negative_prompt_embeds.shape[0],
625
+ max_length - negative_prompt_embeds.shape[1],
626
+ negative_prompt_embeds.shape[2],
627
+ dtype=negative_prompt_embeds.dtype,
628
+ device=negative_prompt_embeds.device
629
+ )
630
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, padding], dim=1)
631
 
632
+ # Truncate if needed
633
+ if prompt_embeds.shape[1] > 77:
634
+ prompt_embeds = prompt_embeds[:, :77, :]
635
+ if negative_prompt_embeds.shape[1] > 77:
636
+ negative_prompt_embeds = negative_prompt_embeds[:, :77, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637
 
638
  pipe_kwargs["prompt_embeds"] = prompt_embeds
639
  pipe_kwargs["pooled_prompt_embeds"] = pooled_prompt_embeds
 
641
  pipe_kwargs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
642
 
643
  compel_success = True
644
+ print(f"[OK] Compel encoded: pos={prompt_embeds.shape}, neg={negative_prompt_embeds.shape}")
645
+
646
  except Exception as e:
647
+ print(f"[COMPEL] Failed: {e}")
648
+ print("[COMPEL] Falling back to standard encoding")
649
  compel_success = False
650
 
651
  # Use standard prompts if Compel failed or not available
 
680
  # Reshape for Resampler: [1, 1, 512]
681
  face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
682
 
683
+ # Pass through Resampler: [1, 1, 512] → [1, 16, 2048]
684
  face_proj_embeds = self.image_proj_model(face_emb_tensor)
685
 
686
  # Scale with identity preservation