oliveryanzuolu committed on
Commit
6580922
·
verified ·
1 Parent(s): 65663ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -89
app.py CHANGED
@@ -4,13 +4,11 @@ import spaces
4
  import cv2
5
  import numpy as np
6
  from PIL import Image
7
- import os
8
 
9
- # Diffusers and ControlNet imports
10
  from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
11
 
12
  # -----------------------------------------------------------------------------
13
- # 1. Configuration & Registry (SDXL Version)
14
  # -----------------------------------------------------------------------------
15
  LORA_REGISTRY = {
16
  "None (Base SDXL)": {
@@ -39,50 +37,35 @@ LORA_REGISTRY = {
39
  }
40
 
41
  # -----------------------------------------------------------------------------
42
- # 2. Model Initialization
43
  # -----------------------------------------------------------------------------
44
- print("Initializing SDXL Inference Pipeline...")
45
 
46
- # On ZeroGPU, we initialize standard variables, but we rely on the decorator for device placement
47
- device = "cuda" if torch.cuda.is_available() else "cpu"
48
- dtype = torch.float16
49
-
50
- # 1. Load VAE (Critical for SDXL fp16 stability to avoid NaNs)
51
  vae = AutoencoderKL.from_pretrained(
52
  "madebyollin/sdxl-vae-fp16-fix",
53
- torch_dtype=dtype
54
  )
55
 
56
- # 2. Load ControlNet (Must be SDXL version)
57
  controlnet = ControlNetModel.from_pretrained(
58
  "diffusers/controlnet-canny-sdxl-1.0",
59
- torch_dtype=dtype,
60
  use_safetensors=True
61
  )
62
 
63
- # 3. Load Base SDXL
64
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
65
  "stabilityai/stable-diffusion-xl-base-1.0",
66
  controlnet=controlnet,
67
  vae=vae,
68
- torch_dtype=dtype,
69
  use_safetensors=True
70
  )
71
 
72
- # Optimization
73
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
74
 
75
- # For ZeroGPU/Spaces, enable_model_cpu_offload is the standard way to handle SDXL
76
- # This registers hooks that automatically move layers to GPU when the @spaces.GPU function is called
77
- try:
78
- pipe.enable_model_cpu_offload()
79
- except Exception as e:
80
- print(f"Offload warning: {e}")
81
-
82
- print("SDXL Pipeline Loaded Successfully.")
83
 
84
  # -----------------------------------------------------------------------------
85
- # 3. Computer Vision Helper Functions
86
  # -----------------------------------------------------------------------------
87
 
88
  def get_canny_image(image, low_threshold=100, high_threshold=200):
@@ -93,7 +76,7 @@ def get_canny_image(image, low_threshold=100, high_threshold=200):
93
  return Image.fromarray(canny_edges)
94
 
95
  # -----------------------------------------------------------------------------
96
- # 4. Inference Logic
97
  # -----------------------------------------------------------------------------
98
 
99
  @spaces.GPU(duration=120)
@@ -107,14 +90,14 @@ def generate_controlled_image(
107
  seed
108
  ):
109
  if input_image is None:
110
- raise gr.Error("Validation Error: Please upload an image first!")
111
 
112
- # 1. Preprocess Image (SDXL works best at 1024x1024)
113
  width, height = 1024, 1024
114
  input_image = input_image.resize((width, height))
115
  canny_image = get_canny_image(input_image)
116
 
117
- # 2. Configuration
 
118
  style_config = LORA_REGISTRY[lora_selection]
119
  repo_id = style_config["repo"]
120
  trigger_text = style_config["trigger"]
@@ -123,29 +106,24 @@ def generate_controlled_image(
123
 
124
  final_prompt = f"{trigger_text}{prompt}"
125
 
126
- # 3. LoRA & Generation Block
127
- # We use a try/finally block to ensure LoRA is ALWAYS unloaded,
128
- # preventing state corruption on the shared GPU.
129
- try:
130
- # A. Load LoRA
131
- if repo_id:
132
  print(f"Loading LoRA: {repo_id}")
133
- # Ensure we are in a clean state before loading
134
- pipe.unload_lora_weights()
135
-
136
  if lora_file:
137
  pipe.load_lora_weights(repo_id, weight_name=lora_file)
138
  else:
139
  pipe.load_lora_weights(repo_id)
140
-
141
- pipe.fuse_lora(lora_scale=lora_weight)
142
- print("LoRA fused successfully.")
 
 
 
143
 
144
- # B. Generate
145
- generator = torch.Generator("cuda").manual_seed(int(seed))
146
- print(f"Generating with Prompt: {final_prompt}")
147
-
148
- output_image = pipe(
149
  prompt=final_prompt,
150
  negative_prompt=negative_prompt,
151
  image=canny_image,
@@ -153,43 +131,33 @@ def generate_controlled_image(
153
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
154
  guidance_scale=7.0,
155
  generator=generator,
156
- ).images
157
-
158
  except Exception as e:
 
159
  raise e
160
-
161
- finally:
162
- # C. Cleanup (Always run this)
163
- if repo_id:
164
- print("Cleaning up LoRA weights...")
165
- try:
166
- pipe.unfuse_lora()
167
- pipe.unload_lora_weights()
168
- except Exception as cleanup_error:
169
- print(f"Cleanup warning: {cleanup_error}")
170
-
171
- # Explicit cache clearing for ZeroGPU shared environment
172
- torch.cuda.empty_cache()
173
 
174
  return canny_image, output_image
175
 
176
  # -----------------------------------------------------------------------------
177
- # 5. Gradio UI Architecture
178
  # -----------------------------------------------------------------------------
179
 
180
  css = """
181
- #col-container {max_width: 1200px; margin-left: auto; margin-right: auto;}
182
  .guide-text {font-size: 1.1em; color: #4a5568;}
183
  """
184
 
185
- # Example Data
186
  examples = [
187
  [
188
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png",
189
  "a colorful exotic bird sitting on a branch, detailed feathers, masterpiece, 8k",
190
  "blurry, low quality, deformed, illustration",
191
  "None (Base SDXL)",
192
- 1.0, 30, 42
193
  ],
194
  [
195
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_depth.png",
@@ -203,14 +171,14 @@ examples = [
203
  "pixel art, a cute bird, isometric view, retro game asset, 8-bit graphics",
204
  "photorealistic, vector, high resolution, smooth, 3d render",
205
  "Pixel Art XL",
206
- 1.0, 30, 202
207
  ],
208
  [
209
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_mlsd.png",
210
- "made-of-clay, claymation style, interior of a modern living room, stop motion animation, plasticine texture, fingerprint textures",
211
  "cgi, 3d render, glossy, architectural visualization",
212
  "Claymation Style XL",
213
- 1.0, 30, 303
214
  ],
215
  ]
216
 
@@ -222,65 +190,56 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
222
  """
223
  <p class='guide-text'>
224
  <b>SDXL Edition.</b><br>
225
- Higher resolution, better prompt adherence, and native LoRA support.
226
- Uses <b>ControlNet Canny (SDXL)</b> for structure.
227
  </p>
228
  """
229
  )
230
 
231
  with gr.Row():
232
- # Left Column: Inputs
233
  with gr.Column(scale=1):
234
- input_image = gr.Image(label="Input Image (Structure)", type="pil", sources=["upload", "clipboard"])
235
 
236
  prompt = gr.Textbox(
237
  label="Prompt",
238
  value="A house on a hill, sunny day, masterpiece",
239
- placeholder="Describe the content...",
240
  lines=2
241
  )
242
 
243
  negative_prompt = gr.Textbox(
244
  label="Negative Prompt",
245
- value="blurry, low quality, distorted, ugly, bad anatomy, watermark, text",
246
  lines=1
247
  )
248
 
249
  lora_selection = gr.Dropdown(
250
- label="Select LoRA Style",
251
  choices=list(LORA_REGISTRY.keys()),
252
- value="None (Base SDXL)",
253
- info="Automatically injects trigger words and loads weights."
254
  )
255
 
256
- with gr.Accordion("⚙️ Advanced Settings", open=False):
257
  controlnet_conditioning_scale = gr.Slider(
258
  label="ControlNet Strength",
259
- minimum=0.0, maximum=1.5, value=0.8, step=0.1,
260
- info="SDXL ControlNet is strong. 0.8 is usually a good sweet spot."
261
  )
262
- steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, value=30, step=1)
263
  seed = gr.Number(label="Seed", value=42, precision=0)
264
 
265
- submit_btn = gr.Button("Generate Art", variant="primary", size="lg")
266
 
267
- # Right Column: Outputs
268
  with gr.Column(scale=1):
269
  with gr.Row():
270
- output_canny = gr.Image(label="Detected Edges", type="pil")
271
- output_result = gr.Image(label="Final Stylized Image", type="pil")
272
 
273
- # Examples Section
274
- gr.Markdown("### 🔍 Try These Examples")
275
  gr.Examples(
276
  examples=examples,
277
  inputs=[input_image, prompt, negative_prompt, lora_selection, controlnet_conditioning_scale, steps, seed],
278
  outputs=[output_canny, output_result],
279
  fn=generate_controlled_image,
280
- cache_examples=False # Must be False for ZeroGPU async compatibility
281
  )
282
 
283
- # Event Wiring
284
  submit_btn.click(
285
  fn=generate_controlled_image,
286
  inputs=[
 
4
  import cv2
5
  import numpy as np
6
  from PIL import Image
 
7
 
 
8
  from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
9
 
10
  # -----------------------------------------------------------------------------
11
+ # Configuration & Registry
12
  # -----------------------------------------------------------------------------
13
  LORA_REGISTRY = {
14
  "None (Base SDXL)": {
 
37
  }
38
 
39
  # -----------------------------------------------------------------------------
40
+ # Model Initialization (CPU only, ZeroGPU handles device transfer)
41
  # -----------------------------------------------------------------------------
42
+ print("Initializing SDXL Pipeline on CPU...")
43
 
 
 
 
 
 
44
  vae = AutoencoderKL.from_pretrained(
45
  "madebyollin/sdxl-vae-fp16-fix",
46
+ torch_dtype=torch.float16
47
  )
48
 
 
49
  controlnet = ControlNetModel.from_pretrained(
50
  "diffusers/controlnet-canny-sdxl-1.0",
51
+ torch_dtype=torch.float16,
52
  use_safetensors=True
53
  )
54
 
 
55
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
56
  "stabilityai/stable-diffusion-xl-base-1.0",
57
  controlnet=controlnet,
58
  vae=vae,
59
+ torch_dtype=torch.float16,
60
  use_safetensors=True
61
  )
62
 
 
63
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
64
 
65
+ print("Pipeline loaded. ZeroGPU will handle device management.")
 
 
 
 
 
 
 
66
 
67
  # -----------------------------------------------------------------------------
68
+ # Helper Functions
69
  # -----------------------------------------------------------------------------
70
 
71
  def get_canny_image(image, low_threshold=100, high_threshold=200):
 
76
  return Image.fromarray(canny_edges)
77
 
78
  # -----------------------------------------------------------------------------
79
+ # Inference Logic
80
  # -----------------------------------------------------------------------------
81
 
82
  @spaces.GPU(duration=120)
 
90
  seed
91
  ):
92
  if input_image is None:
93
+ raise gr.Error("Please upload an image first!")
94
 
 
95
  width, height = 1024, 1024
96
  input_image = input_image.resize((width, height))
97
  canny_image = get_canny_image(input_image)
98
 
99
+ pipe.unload_lora_weights()
100
+
101
  style_config = LORA_REGISTRY[lora_selection]
102
  repo_id = style_config["repo"]
103
  trigger_text = style_config["trigger"]
 
106
 
107
  final_prompt = f"{trigger_text}{prompt}"
108
 
109
+ if repo_id:
110
+ try:
 
 
 
 
111
  print(f"Loading LoRA: {repo_id}")
 
 
 
112
  if lora_file:
113
  pipe.load_lora_weights(repo_id, weight_name=lora_file)
114
  else:
115
  pipe.load_lora_weights(repo_id)
116
+ print("LoRA loaded successfully.")
117
+ except Exception as e:
118
+ print(f"LoRA Load Error: {e}")
119
+ gr.Warning(f"Failed to load LoRA. Using base model.")
120
+
121
+ generator = torch.Generator("cuda").manual_seed(int(seed))
122
 
123
+ print(f"Generating: {final_prompt[:100]}...")
124
+
125
+ try:
126
+ output = pipe(
 
127
  prompt=final_prompt,
128
  negative_prompt=negative_prompt,
129
  image=canny_image,
 
131
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
132
  guidance_scale=7.0,
133
  generator=generator,
134
+ )
135
+ output_image = output.images[0]
136
  except Exception as e:
137
+ pipe.unload_lora_weights()
138
  raise e
139
+
140
+ pipe.unload_lora_weights()
141
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
142
 
143
  return canny_image, output_image
144
 
145
  # -----------------------------------------------------------------------------
146
+ # Gradio UI
147
  # -----------------------------------------------------------------------------
148
 
149
  css = """
150
+ #col-container {max-width: 1200px; margin-left: auto; margin-right: auto;}
151
  .guide-text {font-size: 1.1em; color: #4a5568;}
152
  """
153
 
 
154
  examples = [
155
  [
156
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_canny.png",
157
  "a colorful exotic bird sitting on a branch, detailed feathers, masterpiece, 8k",
158
  "blurry, low quality, deformed, illustration",
159
  "None (Base SDXL)",
160
+ 0.8, 30, 42
161
  ],
162
  [
163
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_depth.png",
 
171
  "pixel art, a cute bird, isometric view, retro game asset, 8-bit graphics",
172
  "photorealistic, vector, high resolution, smooth, 3d render",
173
  "Pixel Art XL",
174
+ 0.8, 30, 202
175
  ],
176
  [
177
  "https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_mlsd.png",
178
+ "made-of-clay, claymation style, interior of a modern living room, stop motion animation, plasticine texture",
179
  "cgi, 3d render, glossy, architectural visualization",
180
  "Claymation Style XL",
181
+ 0.8, 30, 303
182
  ],
183
  ]
184
 
 
190
  """
191
  <p class='guide-text'>
192
  <b>SDXL Edition.</b><br>
193
+ Uses ControlNet Canny (SDXL) for structure preservation with LoRA styles.
 
194
  </p>
195
  """
196
  )
197
 
198
  with gr.Row():
 
199
  with gr.Column(scale=1):
200
+ input_image = gr.Image(label="Input Image", type="pil", sources=["upload", "clipboard"])
201
 
202
  prompt = gr.Textbox(
203
  label="Prompt",
204
  value="A house on a hill, sunny day, masterpiece",
 
205
  lines=2
206
  )
207
 
208
  negative_prompt = gr.Textbox(
209
  label="Negative Prompt",
210
+ value="blurry, low quality, distorted, ugly, watermark",
211
  lines=1
212
  )
213
 
214
  lora_selection = gr.Dropdown(
215
+ label="LoRA Style",
216
  choices=list(LORA_REGISTRY.keys()),
217
+ value="None (Base SDXL)"
 
218
  )
219
 
220
+ with gr.Accordion("Advanced Settings", open=False):
221
  controlnet_conditioning_scale = gr.Slider(
222
  label="ControlNet Strength",
223
+ minimum=0.0, maximum=1.5, value=0.8, step=0.1
 
224
  )
225
+ steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=30, step=1)
226
  seed = gr.Number(label="Seed", value=42, precision=0)
227
 
228
+ submit_btn = gr.Button("Generate", variant="primary", size="lg")
229
 
 
230
  with gr.Column(scale=1):
231
  with gr.Row():
232
+ output_canny = gr.Image(label="Canny Edges", type="pil")
233
+ output_result = gr.Image(label="Result", type="pil")
234
 
 
 
235
  gr.Examples(
236
  examples=examples,
237
  inputs=[input_image, prompt, negative_prompt, lora_selection, controlnet_conditioning_scale, steps, seed],
238
  outputs=[output_canny, output_result],
239
  fn=generate_controlled_image,
240
+ cache_examples=False
241
  )
242
 
 
243
  submit_btn.click(
244
  fn=generate_controlled_image,
245
  inputs=[