dolbohren committed on
Commit
ebbfbc7
·
verified ·
1 Parent(s): 1ee3035

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -21,13 +21,18 @@ pipe = pipe.to(device)
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
23
 
 
 
 
 
 
24
 
25
  @spaces.GPU #[uncomment to use ZeroGPU]
26
  def infer(
27
- # model_id: Optional[str] = "CompVis/stable-diffusion-v1-4",
28
  negative_prompt = 'cat, dog',
29
- prompt: str = 'cute_animal',
30
- seed: int = 42,
31
  randomize_seed = False,
32
  width = 128,
33
  height = 128,
@@ -35,6 +40,8 @@ def infer(
35
  num_inference_steps = 20,
36
  progress=gr.Progress(track_tqdm=True),
37
  ):
 
 
38
  if randomize_seed:
39
  seed = random.randint(0, MAX_SEED)
40
 
@@ -70,6 +77,13 @@ with gr.Blocks(css=css) as demo:
70
  with gr.Column(elem_id="col-container"):
71
  gr.Markdown(" # Text-to-Image Gradio Template")
72
 
 
 
 
 
 
 
 
73
  with gr.Row():
74
  prompt = gr.Text(
75
  label="Prompt",
@@ -140,6 +154,7 @@ with gr.Blocks(css=css) as demo:
140
  triggers=[run_button.click, prompt.submit],
141
  fn=infer,
142
  inputs=[
 
143
  prompt,
144
  negative_prompt,
145
  seed,
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 1024
23
 
24
+ def load_model(model_repo_id):
25
+ global pipe
26
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, use_cache=False)
27
+ pipe = pipe.to(device)
28
+
29
 
30
  @spaces.GPU #[uncomment to use ZeroGPU]
31
  def infer(
32
+ model_id = "CompVis/stable-diffusion-v1-4",
33
  negative_prompt = 'cat, dog',
34
+ prompt = 'cute_animal',
35
+ seed = 42,
36
  randomize_seed = False,
37
  width = 128,
38
  height = 128,
 
40
  num_inference_steps = 20,
41
  progress=gr.Progress(track_tqdm=True),
42
  ):
43
+ load_model(model_id)
44
+
45
  if randomize_seed:
46
  seed = random.randint(0, MAX_SEED)
47
 
 
77
  with gr.Column(elem_id="col-container"):
78
  gr.Markdown(" # Text-to-Image Gradio Template")
79
 
80
+ with gr.Row():
81
+ model_id = gr.Dropdown(
82
+ label="Model",
83
+ choices=["CompVis/stable-diffusion-v1-4", "stabilityai/sdxl-turbo"],
84
+ value="CompVis/stable-diffusion-v1-4",
85
+ )
86
+
87
  with gr.Row():
88
  prompt = gr.Text(
89
  label="Prompt",
 
154
  triggers=[run_button.click, prompt.submit],
155
  fn=infer,
156
  inputs=[
157
+ model_id,
158
  prompt,
159
  negative_prompt,
160
  seed,