primerz committed on
Commit
8fe797f
·
verified ·
1 Parent(s): cec85e0

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +18 -21
generator.py CHANGED
@@ -228,7 +228,10 @@ class RetroArtConverter:
228
 
229
  # Use torch.no_grad() and clear cache
230
  with torch.no_grad():
 
 
231
  depth_image = self.depth_detector(image_for_depth)
 
232
 
233
  # ADDED: Clear GPU cache after depth detection
234
  if torch.cuda.is_available():
@@ -448,6 +451,9 @@ class RetroArtConverter:
448
  num_beams = CAPTION_CONFIG['num_beams']
449
 
450
  try:
 
 
 
451
  if self.caption_model_type == "blip2":
452
  # BLIP-2 specific processing
453
  inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
@@ -496,10 +502,12 @@ class RetroArtConverter:
496
 
497
  caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
498
 
 
499
  return caption.strip()
500
 
501
  except Exception as e:
502
  print(f"Caption generation failed: {e}")
 
503
  return None
504
 
505
  def generate_retro_art(
@@ -568,9 +576,13 @@ class RetroArtConverter:
568
  if self.openpose_active:
569
  print("Generating OpenPose map...")
570
  try:
 
 
571
  openpose_image = self.openpose_detector(resized_image, face_only=True)
 
572
  except Exception as e:
573
  print(f"OpenPose failed, using blank map: {e}")
 
574
  openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
575
 
576
  # --- FIX END ---
@@ -692,7 +704,7 @@ class RetroArtConverter:
692
  else:
693
  print("✗ MediapipeFace found no faces")
694
  except Exception as e:
695
- print(f"[ERROR] MediapipeFace detection failed: {e}")
696
  import traceback
697
  traceback.print_exc()
698
  else:
@@ -751,20 +763,14 @@ class RetroArtConverter:
751
 
752
  pipe_kwargs["generator"] = generator
753
 
754
- # Use Compel for prompt encoding if available
 
 
755
  if self.use_compel and self.compel is not None:
756
  try:
757
  print("Encoding prompts with Compel...")
758
-
759
- # --- FIX: Move text encoders to GPU for Compel ---
760
- self.pipe.text_encoder.to(self.device)
761
- self.pipe.text_encoder_2.to(self.device)
762
- # --- END FIX ---
763
-
764
- # --- FIX: Remove 'device=self.device' argument ---
765
  conditioning = self.compel(prompt)
766
  negative_conditioning = self.compel(negative_prompt)
767
- # --- END FIX ---
768
 
769
  pipe_kwargs["prompt_embeds"] = conditioning[0]
770
  pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
@@ -774,21 +780,12 @@ class RetroArtConverter:
774
  print("[OK] Using Compel-encoded prompts")
775
  except Exception as e:
776
  print(f"Compel encoding failed, using standard prompts: {e}")
777
- import traceback
778
- traceback.print_exc()
779
  pipe_kwargs["prompt"] = prompt
780
  pipe_kwargs["negative_prompt"] = negative_prompt
781
- finally:
782
- # --- FIX: Move text encoders back to CPU to save VRAM ---
783
- try:
784
- self.pipe.text_encoder.to("cpu")
785
- self.pipe.text_encoder_2.to("cpu")
786
- except Exception as e:
787
- print(f"Could not move text encoders back to CPU: {e}")
788
- # --- END FIX ---
789
  else:
790
  pipe_kwargs["prompt"] = prompt
791
  pipe_kwargs["negative_prompt"] = negative_prompt
 
792
 
793
  # Add CLIP skip
794
  if hasattr(self.pipe, 'text_encoder'):
@@ -882,7 +879,7 @@ class RetroArtConverter:
882
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
883
  print(f"Active ControlNets: {len(control_images)} (all {target_width}x{target_height})")
884
  else:
885
- print("No active ControlNfets, running standard Img2Img")
886
 
887
  # Generate
888
  print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
 
228
 
229
  # Use torch.no_grad() and clear cache
230
  with torch.no_grad():
231
+ # --- FIX: Move model to GPU for inference and back to CPU ---
232
+ self.depth_detector.to(self.device)
233
  depth_image = self.depth_detector(image_for_depth)
234
+ self.depth_detector.to("cpu")
235
 
236
  # ADDED: Clear GPU cache after depth detection
237
  if torch.cuda.is_available():
 
451
  num_beams = CAPTION_CONFIG['num_beams']
452
 
453
  try:
454
+ # --- FIX: Move model to GPU for inference and back to CPU ---
455
+ self.caption_model.to(self.device)
456
+
457
  if self.caption_model_type == "blip2":
458
  # BLIP-2 specific processing
459
  inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
 
502
 
503
  caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
504
 
505
+ self.caption_model.to("cpu")
506
  return caption.strip()
507
 
508
  except Exception as e:
509
  print(f"Caption generation failed: {e}")
510
+ self.caption_model.to("cpu")
511
  return None
512
 
513
  def generate_retro_art(
 
576
  if self.openpose_active:
577
  print("Generating OpenPose map...")
578
  try:
579
+ # --- FIX: Move model to GPU for inference and back to CPU ---
580
+ self.openpose_detector.to(self.device)
581
  openpose_image = self.openpose_detector(resized_image, face_only=True)
582
+ self.openpose_detector.to("cpu")
583
  except Exception as e:
584
  print(f"OpenPose failed, using blank map: {e}")
585
+ self.openpose_detector.to("cpu")
586
  openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
587
 
588
  # --- FIX END ---
 
704
  else:
705
  print("✗ MediapipeFace found no faces")
706
  except Exception as e:
707
  print(f"[ERROR] MediapipeFace detection failed: {e}")
708
  import traceback
709
  traceback.print_exc()
710
  else:
 
763
 
764
  pipe_kwargs["generator"] = generator
765
 
766
+ # --- FIX: Reverted Compel block ---
767
+ # No more try/finally, no more .to(device)
768
+ # This works because optimize_pipeline() no longer offloads the text encoders.
769
  if self.use_compel and self.compel is not None:
770
  try:
771
  print("Encoding prompts with Compel...")
 
 
 
 
 
 
 
772
  conditioning = self.compel(prompt)
773
  negative_conditioning = self.compel(negative_prompt)
 
774
 
775
  pipe_kwargs["prompt_embeds"] = conditioning[0]
776
  pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
 
780
  print("[OK] Using Compel-encoded prompts")
781
  except Exception as e:
782
  print(f"Compel encoding failed, using standard prompts: {e}")
 
 
783
  pipe_kwargs["prompt"] = prompt
784
  pipe_kwargs["negative_prompt"] = negative_prompt
 
 
 
 
 
 
 
 
785
  else:
786
  pipe_kwargs["prompt"] = prompt
787
  pipe_kwargs["negative_prompt"] = negative_prompt
788
+ # --- END FIX ---
789
 
790
  # Add CLIP skip
791
  if hasattr(self.pipe, 'text_encoder'):
 
879
  pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
880
  print(f"Active ControlNets: {len(control_images)} (all {target_width}x{target_height})")
881
  else:
882
+ print("No active ControlNets, running standard Img2Img")
883
 
884
  # Generate
885
  print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")