joeaa17 committed on
Commit
217e2c7
·
verified ·
1 Parent(s): 6f5220b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -48
app.py CHANGED
@@ -1,24 +1,32 @@
1
  #!/usr/bin/env python
 
2
 
3
- import os, random, numpy as np, cv2, torch
4
- import gradio as gr
5
  from pathlib import Path
6
- from PIL import Image, ImageOps
7
- import PIL.Image
8
 
 
 
 
 
 
9
  import spaces
 
 
 
 
 
10
  from diffusers import (
11
  ControlNetModel,
12
  StableDiffusionXLControlNetPipeline,
13
  AutoencoderKL,
14
  EulerAncestralDiscreteScheduler,
15
  )
16
- from controlnet_aux import HEDdetector
17
- from gradio_imageslider import ImageSlider
18
 
19
  # ──────────────────────────────────────────────────────────────────────────────
20
- # Small JS helper to force dark theme (kept from your version)
21
  # ──────────────────────────────────────────────────────────────────────────────
 
22
  js_func = """
23
  function refresh() {
24
  const url = new URL(window.location);
@@ -29,20 +37,18 @@ function refresh() {
29
  }
30
  """
31
 
32
- # ──────────────────────────────────────────────────────────────────────────────
33
- # UI text
34
- # ──────────────────────────────────────────────────────────────────────────────
35
  DESCRIPTION = '''# Scribble SDXL 🖋️🌄 — live updates
36
- Sketch → image with SDXL ControlNet (scribble/canny). Now with **auto re-inference** when you draw or tweak settings (debounced).
37
- Models: [xinsir/controlnet-scribble-sdxl-1.0], [xinsir/controlnet-canny-sdxl-1.0], base [stabilityai/stable-diffusion-xl-base-1.0]
38
  '''
39
 
40
  if not torch.cuda.is_available():
41
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo is intended for GPU Spaces for good latency.</p>"
42
 
43
  # ──────────────────────────────────────────────────────────────────────────────
44
- # Styles (unchanged, but refactored into a compact mapping)
45
  # ──────────────────────────────────────────────────────────────────────────────
 
46
  style_list = [
47
  {
48
  "name": "(No style)",
@@ -61,7 +67,7 @@ style_list = [
61
  },
62
  {
63
  "name": "Anime",
64
- "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
65
  "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
66
  },
67
  {
@@ -106,6 +112,7 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str
106
  # ──────────────────────────────────────────────────────────────────────────────
107
  # Utilities
108
  # ──────────────────────────────────────────────────────────────────────────────
 
109
  def HWC3(x: np.ndarray) -> np.ndarray:
110
  assert x.dtype == np.uint8
111
  if x.ndim == 2:
@@ -137,46 +144,58 @@ def nms(x, t, s):
137
  return z
138
 
139
  def clamp_size_to_megapixels(w: int, h: int, max_mpx: float = 1.0) -> tuple[int, int]:
140
- """Scale so that w*h ≈ max_mpx*1e6 (default ~1024x1024 area)."""
141
  area = w * h
142
  target = max_mpx * 1_000_000.0
143
  if area <= target:
144
- return w, h
145
  r = (target / area) ** 0.5
146
- return max(64, int(w * r)) // 8 * 8, max(64, int(h * r)) // 8 * 8 # SDXL likes multiples of 8
147
 
148
  # ──────────────────────────────────────────────────────────────────────────────
149
- # Load models once
150
  # ──────────────────────────────────────────────────────────────────────────────
 
151
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
152
 
153
  scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
154
- "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
 
 
155
  )
156
 
157
  controlnet_scribble = ControlNetModel.from_pretrained(
158
- "xinsir/controlnet-scribble-sdxl-1.0", torch_dtype=torch.float16 if device.type=="cuda" else torch.float32
 
 
159
  )
160
  controlnet_canny = ControlNetModel.from_pretrained(
161
- "xinsir/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16 if device.type=="cuda" else torch.float32
 
 
162
  )
163
  vae = AutoencoderKL.from_pretrained(
164
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 if device.type=="cuda" else torch.float32
 
 
165
  )
166
 
167
  pipe_scribble = StableDiffusionXLControlNetPipeline.from_pretrained(
168
  "stabilityai/stable-diffusion-xl-base-1.0",
169
  controlnet=controlnet_scribble,
170
  vae=vae,
171
- torch_dtype=torch.float16 if device.type=="cuda" else torch.float32,
172
  scheduler=scheduler,
 
 
173
  )
174
  pipe_canny = StableDiffusionXLControlNetPipeline.from_pretrained(
175
  "stabilityai/stable-diffusion-xl-base-1.0",
176
  controlnet=controlnet_canny,
177
  vae=vae,
178
- torch_dtype=torch.float16 if device.type=="cuda" else torch.float32,
179
  scheduler=scheduler,
 
 
180
  )
181
 
182
  for p in (pipe_scribble, pipe_canny):
@@ -192,12 +211,12 @@ MAX_SEED = np.iinfo(np.int32).max
192
  hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
193
 
194
  # ──────────────────────────────────────────────────────────────────────────────
195
- # Core inference
196
  # ──────────────────────────────────────────────────────────────────────────────
197
- def _prepare_control_image(image_editor_value, use_hed: bool, use_canny: bool) -> Image.Image:
 
198
  """
199
- Accepts the dict from gr.ImageEditor (contains 'composite'), or a PIL.Image.
200
- Returns a PIL.Image with control map (scribble/canny/hed result).
201
  """
202
  if image_editor_value is None:
203
  return None
@@ -209,7 +228,6 @@ def _prepare_control_image(image_editor_value, use_hed: bool, use_canny: bool) -
209
  else:
210
  return None
211
 
212
- # Convert to RGB for detectors
213
  if img.mode != "RGB":
214
  img = img.convert("RGB")
215
 
@@ -224,14 +242,12 @@ def _prepare_control_image(image_editor_value, use_hed: bool, use_canny: bool) -
224
  control = np.array(control)
225
  control = nms(control, 127, 3)
226
  control = cv2.GaussianBlur(control, (0, 0), 3)
227
-
228
- # Simulate human sketch width with a soft random threshold
229
- thr = int(round(random.uniform(0.01, 0.10), 2) * 255)
230
  control[control > thr] = 255
231
  control[control < 255] = 0
232
  return Image.fromarray(control)
233
 
234
- # Default: use the editor composite as "scribble"
235
  return img
236
 
237
  def _image_size_from_editor(image_editor_value, target_mpx=1.0) -> tuple[int, int]:
@@ -256,6 +272,10 @@ def _maybe_seed(seed: int):
256
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
257
  return random.randint(0, MAX_SEED) if randomize_seed else int(seed)
258
 
 
 
 
 
259
  @spaces.GPU
260
  def run(
261
  image, # dict from ImageEditor or PIL.Image
@@ -265,7 +285,7 @@ def run(
265
  num_steps: int = 12,
266
  guidance_scale: float = 5.0,
267
  controlnet_conditioning_scale: float = 1.0,
268
- seed: int = 0,
269
  use_hed: bool = False,
270
  use_canny: bool = False,
271
  progress=gr.Progress(track_tqdm=True),
@@ -273,13 +293,10 @@ def run(
273
  if image is None or (isinstance(prompt, str) and prompt.strip() == ""):
274
  return (None, None)
275
 
276
- # Prepare control image + target size (β‰ˆ1MP for speed)
277
  ctrl_img = _prepare_control_image(image, use_hed=use_hed, use_canny=use_canny)
278
- w, h = _image_size_from_editor(image, target_mpx=1.0)
279
 
280
- # Style injection
281
  prompt_styled, neg_styled = apply_style(style_name, prompt, negative_prompt or "")
282
-
283
  g = _maybe_seed(seed)
284
  pipe = _pick_pipe(use_canny)
285
 
@@ -294,16 +311,12 @@ def run(
294
  width=w, height=h,
295
  ).images[0]
296
 
297
- # Return (control, output) for ImageSlider
298
- if isinstance(ctrl_img, Image.Image):
299
- ci = ctrl_img
300
- else:
301
- ci = Image.fromarray(ctrl_img) if ctrl_img is not None else None
302
- return (ci, out)
303
 
304
  # ──────────────────────────────────────────────────────────────────────────────
305
- # UI (with live updates wired via .change on inputs)
306
  # ──────────────────────────────────────────────────────────────────────────────
 
307
  with gr.Blocks(css="style.css", js=js_func, title="Scribble SDXL — Live") as demo:
308
  gr.Markdown(DESCRIPTION, elem_id="description")
309
 
@@ -347,13 +360,12 @@ with gr.Blocks(css="style.css", js=js_func, title="Scribble SDXL β€” Live") as d
347
  ]
348
  outputs = [image_slider]
349
 
350
- # Manual "Run" flow (seed randomization, clear slider, then infer)
351
  run_button.click(
352
  fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=seed, queue=False, api_name=False
353
  ).then(lambda: None, inputs=None, outputs=image_slider).then(fn=run, inputs=inputs, outputs=outputs)
354
 
355
- # ── Live re-inference hooks (debounced) ───────────────────────────────────
356
- # Fire when drawing or tweaking settings. 'every' = debounce seconds.
357
  for comp in [image, prompt, negative_prompt, style, num_steps, guidance_scale,
358
  controlnet_conditioning_scale, seed, use_hed, use_canny]:
359
  comp.change(fn=run, inputs=inputs, outputs=outputs, every=0.5, queue=True)
 
1
  #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
 
4
+ import os
5
+ import random
6
  from pathlib import Path
 
 
7
 
8
+ import cv2
9
+ import numpy as np
10
+ import PIL.Image
11
+ import torch
12
+ import gradio as gr
13
  import spaces
14
+
15
+ from PIL import Image
16
+ from gradio_imageslider import ImageSlider
17
+ from controlnet_aux import HEDdetector
18
+
19
  from diffusers import (
20
  ControlNetModel,
21
  StableDiffusionXLControlNetPipeline,
22
  AutoencoderKL,
23
  EulerAncestralDiscreteScheduler,
24
  )
 
 
25
 
26
  # ──────────────────────────────────────────────────────────────────────────────
27
+ # UI text / theme helper
28
  # ──────────────────────────────────────────────────────────────────────────────
29
+
30
  js_func = """
31
  function refresh() {
32
  const url = new URL(window.location);
 
37
  }
38
  """
39
 
 
 
 
40
  DESCRIPTION = '''# Scribble SDXL 🖋️🌄 — live updates
41
+ Sketch → image with SDXL ControlNet (scribble/canny). Auto re-infers when you draw or tweak settings (debounced).
42
+ Models: **xinsir/controlnet-scribble-sdxl-1.0**, **xinsir/controlnet-canny-sdxl-1.0**, base **stabilityai/stable-diffusion-xl-base-1.0**.
43
  '''
44
 
45
  if not torch.cuda.is_available():
46
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo is intended for GPU Spaces for good latency.</p>"
47
 
48
  # ──────────────────────────────────────────────────────────────────────────────
49
+ # Styles
50
  # ──────────────────────────────────────────────────────────────────────────────
51
+
52
  style_list = [
53
  {
54
  "name": "(No style)",
 
67
  },
68
  {
69
  "name": "Anime",
70
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
71
  "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
72
  },
73
  {
 
112
  # ──────────────────────────────────────────────────────────────────────────────
113
  # Utilities
114
  # ──────────────────────────────────────────────────────────────────────────────
115
+
116
  def HWC3(x: np.ndarray) -> np.ndarray:
117
  assert x.dtype == np.uint8
118
  if x.ndim == 2:
 
144
  return z
145
 
146
  def clamp_size_to_megapixels(w: int, h: int, max_mpx: float = 1.0) -> tuple[int, int]:
147
+ """Scale so that w*h ≈ max_mpx*1e6 (default ~1024x1024 area). SDXL prefers multiples of 8."""
148
  area = w * h
149
  target = max_mpx * 1_000_000.0
150
  if area <= target:
151
+ return (w // 8) * 8, (h // 8) * 8
152
  r = (target / area) ** 0.5
153
+ return max(64, int(w * r)) // 8 * 8, max(64, int(h * r)) // 8 * 8
154
 
155
  # ─────────────────────────────────────────────────��────────────────────────────
156
+ # Models (use dtype= and use_safetensors=True to avoid offload_state_dict issue)
157
  # ──────────────────────────────────────────────────────────────────────────────
158
+
159
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
160
+ DTYPE = torch.float16 if device.type == "cuda" else torch.float32
161
 
162
  scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
163
+ "stabilityai/stable-diffusion-xl-base-1.0",
164
+ subfolder="scheduler",
165
+ use_safetensors=True,
166
  )
167
 
168
  controlnet_scribble = ControlNetModel.from_pretrained(
169
+ "xinsir/controlnet-scribble-sdxl-1.0",
170
+ use_safetensors=True,
171
+ dtype=DTYPE,
172
  )
173
  controlnet_canny = ControlNetModel.from_pretrained(
174
+ "xinsir/controlnet-canny-sdxl-1.0",
175
+ use_safetensors=True,
176
+ dtype=DTYPE,
177
  )
178
  vae = AutoencoderKL.from_pretrained(
179
+ "madebyollin/sdxl-vae-fp16-fix",
180
+ use_safetensors=True,
181
+ dtype=DTYPE,
182
  )
183
 
184
  pipe_scribble = StableDiffusionXLControlNetPipeline.from_pretrained(
185
  "stabilityai/stable-diffusion-xl-base-1.0",
186
  controlnet=controlnet_scribble,
187
  vae=vae,
 
188
  scheduler=scheduler,
189
+ use_safetensors=True,
190
+ dtype=DTYPE,
191
  )
192
  pipe_canny = StableDiffusionXLControlNetPipeline.from_pretrained(
193
  "stabilityai/stable-diffusion-xl-base-1.0",
194
  controlnet=controlnet_canny,
195
  vae=vae,
 
196
  scheduler=scheduler,
197
+ use_safetensors=True,
198
+ dtype=DTYPE,
199
  )
200
 
201
  for p in (pipe_scribble, pipe_canny):
 
211
  hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
212
 
213
  # ──────────────────────────────────────────────────────────────────────────────
214
+ # Pre / Post processing
215
  # ──────────────────────────────────────────────────────────────────────────────
216
+
217
+ def _prepare_control_image(image_editor_value, use_hed: bool, use_canny: bool) -> Image.Image | None:
218
  """
219
+ Accepts gr.ImageEditor dict (with 'composite') or a PIL.Image and returns a PIL.Image control map.
 
220
  """
221
  if image_editor_value is None:
222
  return None
 
228
  else:
229
  return None
230
 
 
231
  if img.mode != "RGB":
232
  img = img.convert("RGB")
233
 
 
242
  control = np.array(control)
243
  control = nms(control, 127, 3)
244
  control = cv2.GaussianBlur(control, (0, 0), 3)
245
+ thr = int(round(random.uniform(0.01, 0.10), 2) * 255) # simulate human sketch thickness
 
 
246
  control[control > thr] = 255
247
  control[control < 255] = 0
248
  return Image.fromarray(control)
249
 
250
+ # default: treat the editor composite as the scribble itself
251
  return img
252
 
253
  def _image_size_from_editor(image_editor_value, target_mpx=1.0) -> tuple[int, int]:
 
272
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
273
  return random.randint(0, MAX_SEED) if randomize_seed else int(seed)
274
 
275
+ # ──────────────────────────────────────────────────────────────────────────────
276
+ # Inference
277
+ # ──────────────────────────────────────────────────────────────────────────────
278
+
279
  @spaces.GPU
280
  def run(
281
  image, # dict from ImageEditor or PIL.Image
 
285
  num_steps: int = 12,
286
  guidance_scale: float = 5.0,
287
  controlnet_conditioning_scale: float = 1.0,
288
+ seed: int = -1,
289
  use_hed: bool = False,
290
  use_canny: bool = False,
291
  progress=gr.Progress(track_tqdm=True),
 
293
  if image is None or (isinstance(prompt, str) and prompt.strip() == ""):
294
  return (None, None)
295
 
 
296
  ctrl_img = _prepare_control_image(image, use_hed=use_hed, use_canny=use_canny)
297
+ w, h = _image_size_from_editor(image, target_mpx=1.0) # ~1MP for speed
298
 
 
299
  prompt_styled, neg_styled = apply_style(style_name, prompt, negative_prompt or "")
 
300
  g = _maybe_seed(seed)
301
  pipe = _pick_pipe(use_canny)
302
 
 
311
  width=w, height=h,
312
  ).images[0]
313
 
314
+ return (ctrl_img if isinstance(ctrl_img, Image.Image) else Image.fromarray(ctrl_img), out)
 
 
 
 
 
315
 
316
  # ──────────────────────────────────────────────────────────────────────────────
317
+ # UI
318
  # ──────────────────────────────────────────────────────────────────────────────
319
+
320
  with gr.Blocks(css="style.css", js=js_func, title="Scribble SDXL — Live") as demo:
321
  gr.Markdown(DESCRIPTION, elem_id="description")
322
 
 
360
  ]
361
  outputs = [image_slider]
362
 
363
+ # Manual run
364
  run_button.click(
365
  fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=seed, queue=False, api_name=False
366
  ).then(lambda: None, inputs=None, outputs=image_slider).then(fn=run, inputs=inputs, outputs=outputs)
367
 
368
+ # Live re-inference (debounced)
 
369
  for comp in [image, prompt, negative_prompt, style, num_steps, guidance_scale,
370
  controlnet_conditioning_scale, seed, use_hed, use_canny]:
371
  comp.change(fn=run, inputs=inputs, outputs=outputs, every=0.5, queue=True)