osmr committed on
Commit
128bee1
·
verified ·
1 Parent(s): 75afb10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -16
app.py CHANGED
@@ -7,37 +7,40 @@ from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
14
  else:
15
  torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
 
 
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
39
  generator = torch.Generator().manual_seed(seed)
40
 
 
 
 
 
 
41
  image = pipe(
42
  prompt=prompt,
43
  negative_prompt=negative_prompt,
@@ -66,8 +69,17 @@ css = """
66
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
 
 
 
 
 
 
 
 
 
 
71
  with gr.Row():
72
  prompt = gr.Text(
73
  label="Prompt",
@@ -138,6 +150,7 @@ with gr.Blocks(css=css) as demo:
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
  inputs=[
 
141
  prompt,
142
  negative_prompt,
143
  seed,
 
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
10
 
11
  if torch.cuda.is_available():
12
  torch_dtype = torch.float16
13
  else:
14
  torch_dtype = torch.float32
15
 
 
 
 
16
  MAX_SEED = np.iinfo(np.int32).max
17
  MAX_IMAGE_SIZE = 1024
18
 
19
 
20
  # @spaces.GPU #[uncomment to use ZeroGPU]
21
+ def infer(model_id: Optional[str] = "CompVis/stable-diffusion-v1-4",
22
+ prompt: str = "",
23
+ negative_prompt: str = "",
24
+ seed: Optional[int] = 42,
25
+ randomize_seed: bool = True,
26
+ width: int = 1024,
27
+ height: int = 1024,
28
+ guidance_scale: Optional[float] = 7,
29
+ num_inference_steps: Optional[int] = 20,
30
+ progress = gr.Progress(track_tqdm=True)):
31
+ if model_id:
32
+ model_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
33
+
34
  if randomize_seed:
35
  seed = random.randint(0, MAX_SEED)
36
 
37
  generator = torch.Generator().manual_seed(seed)
38
 
39
+ pipe = DiffusionPipeline.from_pretrained(
40
+ pretrained_model_name_or_path=model_id,
41
+ torch_dtype=torch_dtype)
42
+ pipe = pipe.to(device)
43
+
44
  image = pipe(
45
  prompt=prompt,
46
  negative_prompt=negative_prompt,
 
69
 
70
  with gr.Blocks(css=css) as demo:
71
  with gr.Column(elem_id="col-container"):
72
+ gr.Markdown(" # Text-to-Image Gradio Form")
73
 
74
+ with gr.Row():
75
+ model_id = gr.Dropdown(
76
+ choices=["stabilityai/sdxl-turbo", "CompVis/stable-diffusion-v1-4"],
77
+ multiselect=False,
78
+ allow_custom_value=True,
79
+ label="Model",
80
+ info="Choose model ID",
81
+ )
82
+
83
  with gr.Row():
84
  prompt = gr.Text(
85
  label="Prompt",
 
150
  triggers=[run_button.click, prompt.submit],
151
  fn=infer,
152
  inputs=[
153
+ model_id,
154
  prompt,
155
  negative_prompt,
156
  seed,