primerz committed on
Commit
5740014
·
verified ·
1 Parent(s): d36a1ac

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +38 -43
generator.py CHANGED
@@ -762,36 +762,31 @@ class RetroArtConverter:
762
 
763
  pipe_kwargs["generator"] = generator
764
 
765
- # --- START FIX 1: Correct Compel batching and slicing ---
766
  if self.use_compel and self.compel is not None:
767
  try:
768
  print("Encoding prompts with Compel...")
769
 
770
- # Pass both prompts as a list to be batched
771
- conditioning_batch, pooled_batch = self.compel([prompt, negative_prompt])
 
772
 
773
- # Store positive and negative embeds separately for now
774
- positive_prompt_embeds = conditioning_batch[0:1]
775
- positive_pooled_embeds = pooled_batch[0:1]
776
- negative_prompt_embeds = conditioning_batch[1:2]
777
- negative_pooled_embeds = pooled_batch[1:2]
778
-
779
- print(f"[OK] Compel encoded - Pos: {positive_prompt_embeds.shape}, Neg: {negative_prompt_embeds.shape}")
780
-
781
- # Put the positive embeds in pipe_kwargs for the *next* step
782
- pipe_kwargs["prompt_embeds"] = positive_prompt_embeds
783
- pipe_kwargs["pooled_prompt_embeds"] = positive_pooled_embeds
784
 
 
785
  except Exception as e:
786
  print(f"Compel encoding failed, using standard prompts: {e}")
787
  traceback.print_exc()
788
  pipe_kwargs["prompt"] = prompt
789
  pipe_kwargs["negative_prompt"] = negative_prompt
790
- self.use_compel = False # Fallback to standard
791
  else:
792
  pipe_kwargs["prompt"] = prompt
793
  pipe_kwargs["negative_prompt"] = negative_prompt
794
- # --- END FIX 1 ---
795
 
796
  # Add CLIP skip
797
  if hasattr(self.pipe, 'text_encoder'):
@@ -835,38 +830,38 @@ class RetroArtConverter:
835
 
836
  print(f" - Face embedding: {face_proj_embeds.shape}, Scale: {boosted_scale:.2f}")
837
 
838
- # --- START FIX 2: Correct CFG and Negative Padding ---
839
- if self.use_compel and 'prompt_embeds' in pipe_kwargs:
840
- # 1. Get the Compel-generated embeds
841
- positive_embeds = pipe_kwargs['prompt_embeds']
842
-
843
- # 2. Concatenate face embeddings to POSITIVE prompt
844
- final_positive_embeds = torch.cat([positive_embeds, face_proj_embeds], dim=1)
845
-
846
- # 3. Create zero padding for NEGATIVE prompt (YOUR FIX)
847
- neg_padding = torch.zeros_like(face_proj_embeds)
848
-
849
- # 4. Concatenate zero padding to NEGATIVE prompt
850
- final_negative_embeds = torch.cat([negative_prompt_embeds, neg_padding], dim=1)
851
 
852
- # 5. Create the final CFG batch (shape [2, 109, 2048])
853
- pipe_kwargs['prompt_embeds'] = torch.cat([final_negative_embeds, final_positive_embeds], dim=0)
 
854
 
855
- # 6. Do the same for the pooled embeds (shape [2, 1280])
856
- pipe_kwargs['pooled_prompt_embeds'] = torch.cat([negative_pooled_embeds, positive_pooled_embeds], dim=0)
857
-
858
- # 7. CRITICAL: Remove the separate negative_prompt_embeds
 
859
  if 'negative_prompt_embeds' in pipe_kwargs:
860
- del pipe_kwargs['negative_prompt_embeds']
 
 
 
 
 
 
 
 
 
 
 
 
 
861
 
862
- print(f" [OK] CFG batch created! Embeds: {pipe_kwargs['prompt_embeds'].shape}, Pooled: {pipe_kwargs['pooled_prompt_embeds'].shape}")
863
-
864
  else:
865
- # Fallback if Compel failed
866
- print(f" [WARNING] Can't concatenate - Compel failed. Using standard prompt.")
867
- pipe_kwargs['prompt'] = prompt
868
- pipe_kwargs['negative_prompt'] = negative_prompt
869
-
870
  # --- END FIX 2 ---
871
 
872
  elif has_detected_faces:
 
762
 
763
  pipe_kwargs["generator"] = generator
764
 
765
+ # --- START FIX: Use Compel as per documentation ---
766
  if self.use_compel and self.compel is not None:
767
  try:
768
  print("Encoding prompts with Compel...")
769
 
770
+ # Call Compel with prompt and negative_prompt kwargs
771
+ # Compel will handle the padding internally
772
+ conditioning = self.compel(prompt, negative_prompt=negative_prompt)
773
 
774
+ # Unpack the results from the returned object
775
+ pipe_kwargs["prompt_embeds"] = conditioning.embeds
776
+ pipe_kwargs["pooled_prompt_embeds"] = conditioning.pooled_embeds
777
+ pipe_kwargs["negative_prompt_embeds"] = conditioning.negative_embeds
778
+ pipe_kwargs["negative_pooled_prompt_embeds"] = conditioning.negative_pooled_embeds
 
 
 
 
 
 
779
 
780
+ print(f"[OK] Compel encoded - Prompt: {pipe_kwargs['prompt_embeds'].shape}, Negative: {pipe_kwargs['negative_prompt_embeds'].shape}")
781
  except Exception as e:
782
  print(f"Compel encoding failed, using standard prompts: {e}")
783
  traceback.print_exc()
784
  pipe_kwargs["prompt"] = prompt
785
  pipe_kwargs["negative_prompt"] = negative_prompt
 
786
  else:
787
  pipe_kwargs["prompt"] = prompt
788
  pipe_kwargs["negative_prompt"] = negative_prompt
789
+ # --- END FIX ---
790
 
791
  # Add CLIP skip
792
  if hasattr(self.pipe, 'text_encoder'):
 
830
 
831
  print(f" - Face embedding: {face_proj_embeds.shape}, Scale: {boosted_scale:.2f}")
832
 
833
+ # --- START FIX 2: Your padding solution ---
834
+ if 'prompt_embeds' in pipe_kwargs:
835
+ original_embeds = pipe_kwargs['prompt_embeds']
 
 
 
 
 
 
 
 
 
 
836
 
837
+ # Handle CFG by creating a [2, 16, 2048] tensor
838
+ # [0] is zeros for negative, [1] is face embeds for positive
839
+ face_proj_embeds_cfg = torch.cat([torch.zeros_like(face_proj_embeds), face_proj_embeds], dim=0)
840
 
841
+ # Concatenate face embeddings to POSITIVE prompt
842
+ combined_embeds = torch.cat([original_embeds, face_proj_embeds_cfg[1:2]], dim=1) # [1, 93, 2048] -> [1, 109, 2048]
843
+ pipe_kwargs['prompt_embeds'] = combined_embeds
844
+
845
+ # CRITICAL: Pad negative_prompt_embeds by the same amount
846
  if 'negative_prompt_embeds' in pipe_kwargs:
847
+ negative_embeds = pipe_kwargs['negative_prompt_embeds']
848
+ # Create zero padding [1, 16, 2048]
849
+ neg_padding = torch.zeros(
850
+ (
851
+ negative_embeds.shape[0], # 1
852
+ face_proj_embeds.shape[1], # 16
853
+ negative_embeds.shape[2], # 2048
854
+ ),
855
+ device=negative_embeds.device,
856
+ dtype=negative_embeds.dtype
857
+ )
858
+ # Concatenate zero padding to NEGATIVE prompt
859
+ pipe_kwargs['negative_prompt_embeds'] = torch.cat([negative_embeds, neg_padding], dim=1)
860
+ print(f" [OK] Negative prompt padded to match: {pipe_kwargs['negative_prompt_embeds'].shape}")
861
 
862
+ print(f" [OK] Face embeddings concatenated successfully! Prompt: {combined_embeds.shape}")
 
863
  else:
864
+ print(f" [WARNING] Can't concatenate - no prompt_embeds (use Compel)")
 
 
 
 
865
  # --- END FIX 2 ---
866
 
867
  elif has_detected_faces: