stephenebert committed on
Commit
d18d089
·
verified ·
1 Parent(s): b1c1eb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -37
app.py CHANGED
@@ -1,12 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import functools
4
-
5
- from diffusers import (
6
- StableDiffusionPipeline,
7
- StableDiffusionXLPipeline,
8
- DPMSolverMultistepScheduler,
9
- )
10
 
11
  MODEL_OPTS = {
12
  "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
@@ -22,50 +17,32 @@ DTYPE = torch.float16 if DEVICE != "cpu" else torch.float32
22
 
23
  @functools.lru_cache(maxsize=len(MODEL_OPTS))
24
  def get_pipeline(model_id: str):
25
- # Use the SDXL pipeline class whenever the model_id is the XL Base
26
- if "stable-diffusion-xl-base" in model_id:
27
- pipe = StableDiffusionXLPipeline.from_pretrained(
28
- model_id,
29
- torch_dtype=DTYPE,
30
- safety_checker=None,
31
- )
32
- else:
33
- pipe = StableDiffusionPipeline.from_pretrained(
34
- model_id,
35
- torch_dtype=DTYPE,
36
- safety_checker=None,
37
- )
38
-
39
- pipe = pipe.to(DEVICE)
40
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
41
  return pipe
42
 
43
  def generate(prompt, steps, guidance, seed, model_name):
44
  model_id = MODEL_OPTS[model_name]
45
- # ultra-fast “Turbo” runs only 4 steps
46
  if "Turbo" in model_name:
47
  steps = min(int(steps), 4)
48
-
49
  pipe = get_pipeline(model_id)
50
  generator = None if seed == 0 else torch.manual_seed(int(seed))
51
-
52
- # ALWAYS pass dicts so SDXL's UNet doesn't see None
53
- output = pipe(
54
  prompt,
55
  num_inference_steps=int(steps),
56
  guidance_scale=float(guidance),
57
- generator=generator,
58
- cross_attention_kwargs={}, # never None
59
- added_cond_kwargs={}, # never None for SDXL
60
- )
61
- return output.images
62
 
63
  with gr.Blocks() as demo:
64
- gr.Markdown("Model-Switcher Stable Diffusion Demo")
65
  prompt = gr.Textbox("Retro robot in neon city", label="Prompt")
66
- checkpoint = gr.Dropdown(list(MODEL_OPTS.keys()),
67
- value="SD v1.5 (base)",
68
- label="Checkpoint")
69
  steps = gr.Slider(1, 50, value=30, label="Inference Steps")
70
  guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
71
  seed = gr.Number(0, label="Seed (0=random)")
@@ -75,9 +52,8 @@ with gr.Blocks() as demo:
75
  btn.click(
76
  fn=generate,
77
  inputs=[prompt, steps, guidance, seed, checkpoint],
78
- outputs=gallery,
79
  )
80
 
81
  if __name__ == "__main__":
82
  demo.launch()
83
-
 
1
  import gradio as gr
2
  import torch
3
  import functools
4
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 
 
 
 
 
5
 
6
  MODEL_OPTS = {
7
  "SD v1.5 (base)": "runwayml/stable-diffusion-v1-5",
 
17
 
18
@functools.lru_cache(maxsize=len(MODEL_OPTS))
def get_pipeline(model_id: str):
    """Load and memoize the Stable Diffusion pipeline for ``model_id``.

    ``lru_cache`` holds at most one pipeline per MODEL_OPTS entry, so
    switching checkpoints in the UI never reloads the same weights twice.

    NOTE(review): every checkpoint goes through ``StableDiffusionPipeline``;
    this assumes MODEL_OPTS lists only SD-family (non-XL) checkpoints —
    confirm, since the UI also mentions "Turbo" model names.
    """
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=DTYPE,
        safety_checker=None,  # safety checker explicitly disabled
    )
    pipeline = pipeline.to(DEVICE)
    # Replace the checkpoint's default scheduler with DPM-Solver multistep,
    # carrying over the existing scheduler configuration.
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config
    )
    return pipeline
27
 
28
def generate(prompt, steps, guidance, seed, model_name):
    """Run text-to-image inference and return the generated images.

    Parameters
    ----------
    prompt : str
        Text prompt for the diffusion model.
    steps : int | float
        Requested number of inference steps (capped for Turbo models).
    guidance : int | float
        Classifier-free guidance scale.
    seed : int | float | None
        RNG seed; 0 (or an empty/cleared Number field) means non-deterministic.
    model_name : str
        Key into MODEL_OPTS selecting the checkpoint.
    """
    model_id = MODEL_OPTS[model_name]
    # Turbo checkpoints are distilled for very few steps; cap at 4.
    if "Turbo" in model_name:
        steps = min(int(steps), 4)
    pipe = get_pipeline(model_id)
    # Use a dedicated Generator so seeding stays local to this call instead of
    # mutating torch's process-global RNG (torch.manual_seed's side effect).
    # `not seed` also covers a cleared gr.Number field (None), which the old
    # `seed == 0` check let through to int(None) -> TypeError.
    generator = None if not seed else torch.Generator().manual_seed(int(seed))
    result = pipe(
        prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=generator,
    )
    return result.images
 
 
41
 
42
  with gr.Blocks() as demo:
43
+ gr.Markdown("## Model-Switcher Stable Diffusion Demo")
44
  prompt = gr.Textbox("Retro robot in neon city", label="Prompt")
45
+ checkpoint = gr.Dropdown(list(MODEL_OPTS.keys()), value="SD v1.5 (base)", label="Checkpoint")
 
 
46
  steps = gr.Slider(1, 50, value=30, label="Inference Steps")
47
  guidance = gr.Slider(1, 15, value=7.5, label="Guidance Scale")
48
  seed = gr.Number(0, label="Seed (0=random)")
 
52
  btn.click(
53
  fn=generate,
54
  inputs=[prompt, steps, guidance, seed, checkpoint],
55
+ outputs=gallery
56
  )
57
 
58
  if __name__ == "__main__":
59
  demo.launch()