stephenebert committed on
Commit
9a249bb
·
verified ·
1 Parent(s): 16ab2a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -32
app.py CHANGED
@@ -1,71 +1,74 @@
1
  import gradio as gr
2
  import torch
3
  import functools
4
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
5
 
6
- # Model options for switching between Stable Diffusion variants
 
 
 
 
 
7
  MODEL_OPTS = {
8
  "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
9
  "SDXL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
10
  "SD-Turbo (ultra-fast)": "stabilityai/sd-turbo"
11
  }
12
 
13
- # Auto-detect compute device (GPU/CPU)
14
  DEVICE = (
15
  "mps" if torch.backends.mps.is_available() else
16
  "cuda" if torch.cuda.is_available() else
17
  "cpu"
18
  )
19
- # Use float16 on GPU/MPS for performance, float32 on CPU
20
  DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
21
 
22
  @functools.lru_cache(maxsize=len(MODEL_OPTS))
23
  def get_pipeline(model_id: str):
24
- """
25
- Lazily load and cache the Stable Diffusion pipeline for a given model.
26
- """
27
- pipe = StableDiffusionPipeline.from_pretrained(
28
- model_id,
29
- torch_dtype=DTYPE,
30
- safety_checker=None
31
- ).to(DEVICE)
32
- # Switch to a faster sampler
 
 
 
 
 
 
 
33
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
34
  return pipe
35
 
36
-
37
  def generate(prompt, steps, guidance, seed, model_name):
38
- """
39
- Generate images using the specified Stable Diffusion model.
40
- """
41
  model_id = MODEL_OPTS[model_name]
42
- # Limit steps for turbo model
43
  if "Turbo" in model_name:
44
  steps = min(int(steps), 4)
45
  pipe = get_pipeline(model_id)
46
- generator = None if int(seed) == 0 else torch.manual_seed(int(seed))
47
 
48
- # Ensure added_cond_kwargs is always a dict to avoid NoneType errors
49
- output = pipe(
50
- prompt=prompt,
51
  num_inference_steps=int(steps),
52
  guidance_scale=float(guidance),
53
  generator=generator,
54
- cross_attention_kwargs={},
55
- added_cond_kwargs={},
56
  )
57
- return output.images
58
 
59
- # Build the Gradio UI
60
  with gr.Blocks() as demo:
61
- gr.Markdown("## ✨ Model-Switcher Stable Diffusion Demo")
62
- prompt = gr.Textbox(label="Prompt", value="Retro robot in neon city")
63
- checkpoint = gr.Dropdown(choices=list(MODEL_OPTS.keys()),
64
- value="SD v1.5 (base)",
65
- label="Checkpoint")
66
  steps = gr.Slider(1, 50, value=30, label="Inference Steps")
67
  guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
68
- seed = gr.Number(value=0, label="Seed (0 = random)")
69
  btn = gr.Button("Generate")
70
  gallery = gr.Gallery(label="Gallery", columns=2, height="auto")
71
 
@@ -77,3 +80,4 @@ with gr.Blocks() as demo:
77
 
78
  if __name__ == "__main__":
79
  demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
  import functools
 
4
 
5
+ from diffusers import (
6
+ StableDiffusionPipeline,
7
+ StableDiffusionXLPipeline,
8
+ DPMSolverMultistepScheduler,
9
+ )
10
+
11
# Display name -> Hugging Face Hub repo id for each selectable checkpoint.
MODEL_OPTS = {
    "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
    "SDXL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
    "SD-Turbo (ultra-fast)": "stabilityai/sd-turbo"
}

# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
DEVICE = (
    "mps" if torch.backends.mps.is_available() else
    "cuda" if torch.cuda.is_available() else
    "cpu"
)
# Half precision on accelerators for speed/memory; CPU needs float32.
DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
23
 
24
@functools.lru_cache(maxsize=len(MODEL_OPTS))
def get_pipeline(model_id: str):
    """
    Lazily load, cache, and configure the diffusion pipeline for *model_id*.

    SDXL checkpoints need ``StableDiffusionXLPipeline``; everything else
    uses the standard ``StableDiffusionPipeline``. The result is cached
    per model id, so each checkpoint is downloaded/loaded at most once.
    """
    # BUG FIX: the SDXL repo id is "stabilityai/stable-diffusion-xl-base-1.0",
    # which never contains the substring "sdxl-base" — the old check was
    # always False, so SDXL fell through to StableDiffusionPipeline and
    # failed to load. Match on "-xl-" instead.
    if "-xl-" in model_id:
        # SDXL's from_pretrained has no safety_checker argument, so it is
        # deliberately not passed here.
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_id,
            torch_dtype=DTYPE,
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=DTYPE,
            safety_checker=None,  # disable the NSFW filter for this demo
        )

    pipe = pipe.to(DEVICE)
    # Swap in the faster DPM-Solver multistep scheduler for every model.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    return pipe
44
 
 
45
def generate(prompt, steps, guidance, seed, model_name):
    """
    Run one text-to-image generation with the selected checkpoint.

    Args:
        prompt: text prompt for the image.
        steps: number of inference steps (clamped to 4 for SD-Turbo).
        guidance: classifier-free guidance scale (forced to 0.0 for SD-Turbo).
        seed: RNG seed; 0 means "use a random seed".
        model_name: key into MODEL_OPTS selecting the checkpoint.

    Returns:
        The list of generated PIL images from the pipeline output.
    """
    model_id = MODEL_OPTS[model_name]
    steps = int(steps)
    guidance = float(guidance)
    if "Turbo" in model_name:
        # SD-Turbo is distilled for 1-4 steps WITHOUT classifier-free
        # guidance; its model card specifies guidance_scale=0.0.
        steps = min(steps, 4)
        guidance = 0.0
    pipe = get_pipeline(model_id)

    # Use a dedicated Generator rather than torch.manual_seed(), which
    # would clobber the global RNG state for the whole process.
    generator = None
    if int(seed) != 0:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    # NOTE: the old code also passed added_cond_kwargs={} here, but that is
    # not a __call__ parameter of either pipeline (SDXL constructs it
    # internally for its UNet), so it was silently ignored — dropped.
    out = pipe(
        prompt,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=generator,
        cross_attention_kwargs={},
    )
    return out.images
62
 
 
63
  with gr.Blocks() as demo:
64
+ gr.Markdown("Model-Switcher Stable Diffusion Demo")
65
+ prompt = gr.Textbox("Retro robot in neon city", label="Prompt")
66
+ checkpoint = gr.Dropdown(list(MODEL_OPTS.keys()),
67
+ value="SD v1.5 (base)",
68
+ label="Checkpoint")
69
  steps = gr.Slider(1, 50, value=30, label="Inference Steps")
70
  guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
71
+ seed = gr.Number(0, label="Seed (0=random)")
72
  btn = gr.Button("Generate")
73
  gallery = gr.Gallery(label="Gallery", columns=2, height="auto")
74
 
 
80
 
81
  if __name__ == "__main__":
82
  demo.launch()
83
+