stephenebert commited on
Commit
c2d7f0c
·
verified ·
1 Parent(s): e5a5c47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -3,50 +3,68 @@ import torch
3
  import functools
4
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
5
 
 
6
  MODEL_OPTS = {
7
  "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
8
  "SDXL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
9
  "SD-Turbo (ultra-fast)": "stabilityai/sd-turbo"
10
  }
11
 
 
12
  DEVICE = (
13
  "mps" if torch.backends.mps.is_available() else
14
  "cuda" if torch.cuda.is_available() else
15
  "cpu"
16
  )
 
17
  DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
18
 
19
  @functools.lru_cache(maxsize=len(MODEL_OPTS))
20
  def get_pipeline(model_id: str):
 
 
 
21
  pipe = StableDiffusionPipeline.from_pretrained(
22
  model_id,
23
  torch_dtype=DTYPE,
24
  safety_checker=None
25
  ).to(DEVICE)
 
26
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
27
  return pipe
28
 
 
29
  def generate(prompt, steps, guidance, seed, model_name):
 
 
 
30
  model_id = MODEL_OPTS[model_name]
 
31
  if "Turbo" in model_name:
32
  steps = min(int(steps), 4)
33
  pipe = get_pipeline(model_id)
34
- generator = None if seed == 0 else torch.manual_seed(int(seed))
35
- imgs = pipe(
36
- prompt,
 
 
37
  num_inference_steps=int(steps),
38
  guidance_scale=float(guidance),
39
- generator=generator
40
- ).images
41
- return imgs
 
42
 
 
43
  with gr.Blocks() as demo:
44
  gr.Markdown("## Model-Switcher Stable Diffusion Demo")
45
- prompt = gr.Textbox("Retro robot in neon city", label="Prompt")
46
- checkpoint = gr.Dropdown(list(MODEL_OPTS.keys()), value="SD v1.5 (base)", label="Checkpoint")
 
 
47
  steps = gr.Slider(1, 50, value=30, label="Inference Steps")
48
  guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
49
- seed = gr.Number(0, label="Seed (0=random)")
50
  btn = gr.Button("Generate")
51
  gallery = gr.Gallery(label="Gallery", columns=2, height="auto")
52
 
 
3
  import functools
4
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
5
 
6
+ # Model options for switching between Stable Diffusion variants
7
  MODEL_OPTS = {
8
  "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
9
  "SDXL Base 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
10
  "SD-Turbo (ultra-fast)": "stabilityai/sd-turbo"
11
  }
12
 
13
+ # Auto-detect compute device (GPU/CPU)
14
  DEVICE = (
15
  "mps" if torch.backends.mps.is_available() else
16
  "cuda" if torch.cuda.is_available() else
17
  "cpu"
18
  )
19
+ # Use float16 on GPU/MPS for performance, float32 on CPU
20
  DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
21
 
22
  @functools.lru_cache(maxsize=len(MODEL_OPTS))
23
  def get_pipeline(model_id: str):
24
+ """
25
+ Lazily load and cache the Stable Diffusion pipeline for a given model.
26
+ """
27
  pipe = StableDiffusionPipeline.from_pretrained(
28
  model_id,
29
  torch_dtype=DTYPE,
30
  safety_checker=None
31
  ).to(DEVICE)
32
+ # Switch to a faster sampler
33
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
34
  return pipe
35
 
36
+
37
  def generate(prompt, steps, guidance, seed, model_name):
38
+ """
39
+ Generate images using the specified Stable Diffusion model.
40
+ """
41
  model_id = MODEL_OPTS[model_name]
42
+ # Limit steps for turbo model
43
  if "Turbo" in model_name:
44
  steps = min(int(steps), 4)
45
  pipe = get_pipeline(model_id)
46
+ generator = None if int(seed) == 0 else torch.manual_seed(int(seed))
47
+
48
+ # Ensure cross_attention_kwargs is always a dict (avoids NoneType errors)
49
+ output = pipe(
50
+ prompt=prompt,
51
  num_inference_steps=int(steps),
52
  guidance_scale=float(guidance),
53
+ generator=generator,
54
+ cross_attention_kwargs={},
55
+ )
56
+ return output.images
57
 
58
+ # Build the Gradio UI
59
  with gr.Blocks() as demo:
60
  gr.Markdown("## Model-Switcher Stable Diffusion Demo")
61
+ prompt = gr.Textbox(label="Prompt", value="Retro robot in neon city")
62
+ checkpoint = gr.Dropdown(choices=list(MODEL_OPTS.keys()),
63
+ value="SD v1.5 (base)",
64
+ label="Checkpoint")
65
  steps = gr.Slider(1, 50, value=30, label="Inference Steps")
66
  guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
67
+ seed = gr.Number(value=0, label="Seed (0 = random)")
68
  btn = gr.Button("Generate")
69
  gallery = gr.Gallery(label="Gallery", columns=2, height="auto")
70