oliveryanzuolu committed on
Commit
65663ad
·
verified ·
1 Parent(s): 834f6ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -30
app.py CHANGED
@@ -43,10 +43,11 @@ LORA_REGISTRY = {
43
  # -----------------------------------------------------------------------------
44
  print("Initializing SDXL Inference Pipeline...")
45
 
 
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
- dtype = torch.float16 if device == "cuda" else torch.float32
48
 
49
- # 1. Load VAE (Critical for SDXL fp16 stability)
50
  vae = AutoencoderKL.from_pretrained(
51
  "madebyollin/sdxl-vae-fp16-fix",
52
  torch_dtype=dtype
@@ -71,11 +72,12 @@ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
71
  # Optimization
72
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
73
 
 
 
74
  try:
75
  pipe.enable_model_cpu_offload()
76
  except Exception as e:
77
- print(f"Warning: CPU offload failed, moving to device manually. {e}")
78
- pipe.to(device)
79
 
80
  print("SDXL Pipeline Loaded Successfully.")
81
 
@@ -112,9 +114,7 @@ def generate_controlled_image(
112
  input_image = input_image.resize((width, height))
113
  canny_image = get_canny_image(input_image)
114
 
115
- # 2. Manage LoRA State
116
- pipe.unload_lora_weights()
117
-
118
  style_config = LORA_REGISTRY[lora_selection]
119
  repo_id = style_config["repo"]
120
  trigger_text = style_config["trigger"]
@@ -123,9 +123,16 @@ def generate_controlled_image(
123
 
124
  final_prompt = f"{trigger_text}{prompt}"
125
 
 
 
 
126
  try:
 
127
  if repo_id:
128
  print(f"Loading LoRA: {repo_id}")
 
 
 
129
  if lora_file:
130
  pipe.load_lora_weights(repo_id, weight_name=lora_file)
131
  else:
@@ -133,38 +140,36 @@ def generate_controlled_image(
133
 
134
  pipe.fuse_lora(lora_scale=lora_weight)
135
  print("LoRA fused successfully.")
136
-
137
- except Exception as e:
138
- print(f"LoRA Load Error: {e}")
139
- gr.Warning(f"Failed to load LoRA {lora_selection}. Using base model. Error: {str(e)}")
140
-
141
- # 3. Generation
142
- generator = torch.Generator(device).manual_seed(int(seed))
143
 
144
- print(f"Generating with Prompt: {final_prompt}")
145
-
146
- try:
 
147
  output_image = pipe(
148
  prompt=final_prompt,
149
  negative_prompt=negative_prompt,
150
  image=canny_image,
151
  num_inference_steps=int(steps),
152
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
153
- guidance_scale=7.0, # SDXL usually prefers slightly lower CFG than SD1.5
154
  generator=generator,
155
  ).images
 
156
  except Exception as e:
157
- pipe.unfuse_lora()
158
- pipe.unload_lora_weights()
159
  raise e
160
-
161
- # 4. Cleanup
162
- if repo_id:
163
- print("Unfusing LoRA...")
164
- pipe.unfuse_lora()
165
- pipe.unload_lora_weights()
166
-
167
- torch.cuda.empty_cache()
 
 
 
 
 
168
 
169
  return canny_image, output_image
170
 
@@ -177,7 +182,7 @@ css = """
177
  .guide-text {font-size: 1.1em; color: #4a5568;}
178
  """
179
 
180
- # Example Data (Updated for SDXL context)
181
  examples = [
182
  [
183
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png",
@@ -272,7 +277,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
272
  inputs=[input_image, prompt, negative_prompt, lora_selection, controlnet_conditioning_scale, steps, seed],
273
  outputs=[output_canny, output_result],
274
  fn=generate_controlled_image,
275
- cache_examples=False # Keep False for stability
276
  )
277
 
278
  # Event Wiring
 
43
  # -----------------------------------------------------------------------------
44
  print("Initializing SDXL Inference Pipeline...")
45
 
46
+ # On ZeroGPU, we initialize standard variables, but we rely on the decorator for device placement
47
  device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ dtype = torch.float16
49
 
50
+ # 1. Load VAE (Critical for SDXL fp16 stability to avoid NaNs)
51
  vae = AutoencoderKL.from_pretrained(
52
  "madebyollin/sdxl-vae-fp16-fix",
53
  torch_dtype=dtype
 
72
  # Optimization
73
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
74
 
75
+ # For ZeroGPU/Spaces, enable_model_cpu_offload is the standard way to handle SDXL
76
+ # This registers hooks that automatically move layers to GPU when the @spaces.GPU function is called
77
  try:
78
  pipe.enable_model_cpu_offload()
79
  except Exception as e:
80
+ print(f"Offload warning: {e}")
 
81
 
82
  print("SDXL Pipeline Loaded Successfully.")
83
 
 
114
  input_image = input_image.resize((width, height))
115
  canny_image = get_canny_image(input_image)
116
 
117
+ # 2. Configuration
 
 
118
  style_config = LORA_REGISTRY[lora_selection]
119
  repo_id = style_config["repo"]
120
  trigger_text = style_config["trigger"]
 
123
 
124
  final_prompt = f"{trigger_text}{prompt}"
125
 
126
+ # 3. LoRA & Generation Block
127
+ # We use a try/finally block to ensure LoRA is ALWAYS unloaded,
128
+ # preventing state corruption on the shared GPU.
129
  try:
130
+ # A. Load LoRA
131
  if repo_id:
132
  print(f"Loading LoRA: {repo_id}")
133
+ # Ensure we are in a clean state before loading
134
+ pipe.unload_lora_weights()
135
+
136
  if lora_file:
137
  pipe.load_lora_weights(repo_id, weight_name=lora_file)
138
  else:
 
140
 
141
  pipe.fuse_lora(lora_scale=lora_weight)
142
  print("LoRA fused successfully.")
 
 
 
 
 
 
 
143
 
144
+ # B. Generate
145
+ generator = torch.Generator("cuda").manual_seed(int(seed))
146
+ print(f"Generating with Prompt: {final_prompt}")
147
+
148
  output_image = pipe(
149
  prompt=final_prompt,
150
  negative_prompt=negative_prompt,
151
  image=canny_image,
152
  num_inference_steps=int(steps),
153
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
154
+ guidance_scale=7.0,
155
  generator=generator,
156
  ).images
157
+
158
  except Exception as e:
 
 
159
  raise e
160
+
161
+ finally:
162
+ # C. Cleanup (Always run this)
163
+ if repo_id:
164
+ print("Cleaning up LoRA weights...")
165
+ try:
166
+ pipe.unfuse_lora()
167
+ pipe.unload_lora_weights()
168
+ except Exception as cleanup_error:
169
+ print(f"Cleanup warning: {cleanup_error}")
170
+
171
+ # Explicit cache clearing for ZeroGPU shared environment
172
+ torch.cuda.empty_cache()
173
 
174
  return canny_image, output_image
175
 
 
182
  .guide-text {font-size: 1.1em; color: #4a5568;}
183
  """
184
 
185
+ # Example Data
186
  examples = [
187
  [
188
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png",
 
277
  inputs=[input_image, prompt, negative_prompt, lora_selection, controlnet_conditioning_scale, steps, seed],
278
  outputs=[output_canny, output_result],
279
  fn=generate_controlled_image,
280
+ cache_examples=False # Must be False for ZeroGPU async compatibility
281
  )
282
 
283
  # Event Wiring