lexi-starikova commited on
Commit
6947996
·
verified ·
1 Parent(s): 9c48429

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -13
app.py CHANGED
@@ -7,15 +7,31 @@ from diffusers import DiffusionPipeline
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
14
  else:
15
  torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
@@ -23,6 +39,7 @@ MAX_IMAGE_SIZE = 1024
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
 
26
  prompt,
27
  negative_prompt,
28
  seed,
@@ -33,6 +50,12 @@ def infer(
33
  num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
 
 
 
 
 
 
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
@@ -68,6 +91,13 @@ with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
  gr.Markdown(" # Text-to-Image Gradio Template")
70
 
 
 
 
 
 
 
 
71
  with gr.Row():
72
  prompt = gr.Text(
73
  label="Prompt",
@@ -86,7 +116,7 @@ with gr.Blocks(css=css) as demo:
86
  label="Negative prompt",
87
  max_lines=1,
88
  placeholder="Enter a negative prompt",
89
- visible=False,
90
  )
91
 
92
  seed = gr.Slider(
@@ -94,7 +124,7 @@ with gr.Blocks(css=css) as demo:
94
  minimum=0,
95
  maximum=MAX_SEED,
96
  step=1,
97
- value=0,
98
  )
99
 
100
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
@@ -105,7 +135,7 @@ with gr.Blocks(css=css) as demo:
105
  minimum=256,
106
  maximum=MAX_IMAGE_SIZE,
107
  step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
110
 
111
  height = gr.Slider(
@@ -113,24 +143,24 @@ with gr.Blocks(css=css) as demo:
113
  minimum=256,
114
  maximum=MAX_IMAGE_SIZE,
115
  step=32,
116
- value=1024, # Replace with defaults that work for your model
117
  )
118
 
119
  with gr.Row():
120
  guidance_scale = gr.Slider(
121
  label="Guidance scale",
122
  minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
  )
127
 
128
  num_inference_steps = gr.Slider(
129
  label="Number of inference steps",
130
  minimum=1,
131
- maximum=50,
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
134
  )
135
 
136
  gr.Examples(examples=examples, inputs=[prompt])
@@ -138,6 +168,7 @@ with gr.Blocks(css=css) as demo:
138
  triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
  inputs=[
 
141
  prompt,
142
  negative_prompt,
143
  seed,
 
7
  import torch
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
10
  if torch.cuda.is_available():
11
  torch_dtype = torch.float16
12
  else:
13
  torch_dtype = torch.float32
14
 
15
+ # список моделей для выбора
16
+ MODEL_OPTIONS = [
17
+ "stabilityai/sdxl-turbo",
18
+ "CompVis/stable-diffusion-v1-4",
19
+ "runwayml/stable-diffusion-v1-5",
20
+ "stabilityai/stable-diffusion-2-1",
21
+ ]
22
+
23
+ # глобальная переменная пайплайна
24
+ pipe = None
25
+
26
+
27
+ def load_model(model_id):
28
+ global pipe
29
+ pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
30
+ pipe = pipe.to(device)
31
+
32
+
33
+ # загрузим модель по умолчанию
34
+ load_model(MODEL_OPTIONS[0])
35
 
36
  MAX_SEED = np.iinfo(np.int32).max
37
  MAX_IMAGE_SIZE = 1024
 
39
 
40
  # @spaces.GPU #[uncomment to use ZeroGPU]
41
  def infer(
42
+ model_id,
43
  prompt,
44
  negative_prompt,
45
  seed,
 
50
  num_inference_steps,
51
  progress=gr.Progress(track_tqdm=True),
52
  ):
53
+ global pipe
54
+
55
+ # если модель поменялась — перезагружаем
56
+ if pipe is None or pipe.config._name_or_path != model_id:
57
+ load_model(model_id)
58
+
59
  if randomize_seed:
60
  seed = random.randint(0, MAX_SEED)
61
 
 
91
  with gr.Column(elem_id="col-container"):
92
  gr.Markdown(" # Text-to-Image Gradio Template")
93
 
94
+ # выбор модели
95
+ model_id = gr.Dropdown(
96
+ choices=MODEL_OPTIONS,
97
+ value=MODEL_OPTIONS[0],
98
+ label="Model",
99
+ )
100
+
101
  with gr.Row():
102
  prompt = gr.Text(
103
  label="Prompt",
 
116
  label="Negative prompt",
117
  max_lines=1,
118
  placeholder="Enter a negative prompt",
119
+ visible=True,
120
  )
121
 
122
  seed = gr.Slider(
 
124
  minimum=0,
125
  maximum=MAX_SEED,
126
  step=1,
127
+ value=42,
128
  )
129
 
130
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
135
  minimum=256,
136
  maximum=MAX_IMAGE_SIZE,
137
  step=32,
138
+ value=512,
139
  )
140
 
141
  height = gr.Slider(
 
143
  minimum=256,
144
  maximum=MAX_IMAGE_SIZE,
145
  step=32,
146
+ value=512,
147
  )
148
 
149
  with gr.Row():
150
  guidance_scale = gr.Slider(
151
  label="Guidance scale",
152
  minimum=0.0,
153
+ maximum=20.0,
154
+ step=0.5,
155
+ value=7.0,
156
  )
157
 
158
  num_inference_steps = gr.Slider(
159
  label="Number of inference steps",
160
  minimum=1,
161
+ maximum=100,
162
  step=1,
163
+ value=20,
164
  )
165
 
166
  gr.Examples(examples=examples, inputs=[prompt])
 
168
  triggers=[run_button.click, prompt.submit],
169
  fn=infer,
170
  inputs=[
171
+ model_id,
172
  prompt,
173
  negative_prompt,
174
  seed,