LPX55 commited on
Commit
1063084
·
verified ·
1 Parent(s): 5dd4601

Update optimized.py

Browse files
Files changed (1) hide show
  1. optimized.py +15 -28
optimized.py CHANGED
@@ -83,38 +83,25 @@ pipe.enable_attention_slicing(1)
83
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
84
  @spaces.GPU
85
  def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
86
- # Clean up input handling
87
  control_image = load_image(control_image)
88
  w, h = control_image.size
89
- scale = min(scale, 2.0) # Cap scale factor
90
-
91
- # Size calculation with safety limits
92
- max_dim = 1536 # Set based on your VRAM
93
- target_w = min(int(w * scale), max_dim)
94
- target_h = min(int(h * scale), max_dim)
95
-
96
- control_image = control_image.resize(
97
- (target_w, target_h),
98
- PIL.Image.BICUBIC
99
- )
100
-
101
- # Generation with memory-friendly parameters
102
- with torch.autocast("cuda"): # Mixed precision
103
- image = pipe(
104
- prompt=prompt,
105
- control_image=control_image,
106
- controlnet_conditioning_scale=controlnet_conditioning_scale,
107
- num_inference_steps=steps,
108
- guidance_scale=guidance_scale,
109
- height=target_h,
110
- width=target_w,
111
- output_type="pil", # Avoid extra latent decoding steps
112
- generator=torch.Generator(device="cuda").manual_seed(0)
113
- ).images[0]
114
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
115
  # Aggressive memory cleanup
116
- torch.cuda.empty_cache()
117
- torch.cuda.ipc_collect()
118
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
119
  return image
120
  # Create Gradio interface
 
83
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
84
  @spaces.GPU
85
  def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale):
86
+ # Load control image
87
  control_image = load_image(control_image)
88
  w, h = control_image.size
89
+ # Upscale x1
90
+ control_image = control_image.resize((int(w * scale), int(h * scale)))
91
+ print("Size to: " + str(control_image.size[0]) + ", " + str(control_image.size[1]))
92
+ image = pipe(
93
+ prompt=prompt,
94
+ control_image=control_image,
95
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
96
+ num_inference_steps=steps,
97
+ guidance_scale=guidance_scale,
98
+ height=control_image.size[1],
99
+ width=control_image.size[0]
100
+ ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
102
  # Aggressive memory cleanup
103
+ # torch.cuda.empty_cache()
104
+ # torch.cuda.ipc_collect()
105
  print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f}GB")
106
  return image
107
  # Create Gradio interface