DiZH797 commited on
Commit
db16914
·
verified ·
1 Parent(s): c11d49f

Update app.py

Browse files

Small fix to the initial app.py

Files changed (1) hide show
  1. app.py +111 -162
app.py CHANGED
@@ -1,161 +1,131 @@
1
- # app.py
2
  import gradio as gr
3
  import numpy as np
4
  import random
 
 
5
  from diffusers import DiffusionPipeline
6
- from diffusers import (
7
- DDIMScheduler,
8
- PNDMScheduler,
9
- LMSDiscreteScheduler,
10
- EulerDiscreteScheduler,
11
- DPMSolverMultistepScheduler,
12
- )
13
  import torch
 
14
 
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- MAX_SEED = np.iinfo(np.int32).max
17
- MAX_IMAGE_SIZE = 1024
18
-
19
  DEFAULT_MODEL = "CompVis/stable-diffusion-v1-4"
20
  MODEL_OPTIONS = [
21
  "CompVis/stable-diffusion-v1-4",
 
22
  "stabilityai/sdxl-turbo",
23
- # add other model ids you want to expose here
24
  ]
25
 
26
- SCHEDULER_MAP = {
27
- "default": None,
28
- "DDIM": DDIMScheduler,
29
- "PNDM": PNDMScheduler,
30
- "LMS": LMSDiscreteScheduler,
31
- "Euler": EulerDiscreteScheduler,
32
- "DPMSolver": DPMSolverMultistepScheduler,
33
- }
34
-
35
-
36
- def get_torch_dtype():
37
- return torch.float16 if torch.cuda.is_available() else torch.float32
38
 
 
 
39
 
40
- def load_pipeline(model_id: str, scheduler_name: str = "default"):
41
- """Load pipeline from pretrained model_id and optionally replace scheduler."""
42
- torch_dtype = get_torch_dtype()
43
- pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
44
- # try to replace scheduler if requested
45
- sched_cls = SCHEDULER_MAP.get(scheduler_name)
46
- if sched_cls is not None:
47
- try:
48
- pipe.scheduler = sched_cls.from_config(pipe.scheduler.config)
49
- except Exception:
50
- # fallback to default if replacement failed
51
- pass
52
- pipe = pipe.to(device)
53
- return pipe
54
-
55
-
56
- # preload default pipeline (may take time on startup)
57
- print(f"Loading default model {DEFAULT_MODEL} ...")
58
- try:
59
- default_pipe = load_pipeline(DEFAULT_MODEL, "default")
60
- print("Loaded default model.")
61
- except Exception as e:
62
- default_pipe = None
63
- print("Failed to preload default model:", e)
64
-
65
- css = """
66
- #col-container {
67
- margin: 0 auto;
68
- max-width: 880px;
69
- }
70
- """
71
 
72
- examples = [
73
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
74
- "An astronaut riding a green horse",
75
- "A delicious ceviche cheesecake slice",
76
- ]
 
 
 
77
 
78
- def load_model_and_update(model_id, scheduler_name):
79
- """Called when user selects a model or scheduler: load and return new pipeline + status."""
80
- try:
81
- pipe = load_pipeline(model_id, scheduler_name)
82
- return pipe, f"Loaded `{model_id}` (scheduler: {scheduler_name})"
83
- except Exception as e:
84
- return None, f"Error loading `{model_id}`: {e}"
85
 
 
 
86
 
87
- def size_to_dims(size_str):
88
- try:
89
- w, h = map(int, size_str.split("x"))
90
- # clamp to limits
91
- w = min(max(256, w), MAX_IMAGE_SIZE)
92
- h = min(max(256, h), MAX_IMAGE_SIZE)
93
- return gr.Slider.update(value=w), gr.Slider.update(value=h)
94
- except Exception:
95
- return gr.Slider.update(value=512), gr.Slider.update(value=512)
96
 
97
 
 
98
  def infer(
99
- prompt,
100
- negative_prompt,
101
- seed,
102
- randomize_seed,
103
- width,
104
- height,
105
- guidance_scale,
106
- num_inference_steps,
107
- pipe_state, # gr.State containing pipeline
 
108
  progress=gr.Progress(track_tqdm=True),
109
  ):
110
- if pipe_state is None:
111
- return None, seed, "Model not loaded."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  if randomize_seed:
114
  seed = random.randint(0, MAX_SEED)
115
 
116
- # create generator on proper device
117
- if device.startswith("cuda"):
118
- generator = torch.Generator(device=device).manual_seed(seed)
119
- else:
120
- generator = torch.Generator().manual_seed(seed)
121
-
122
- try:
123
- out = pipe_state(
124
- prompt=prompt,
125
- negative_prompt=negative_prompt if negative_prompt else None,
126
- guidance_scale=float(guidance_scale),
127
- num_inference_steps=int(num_inference_steps),
128
- width=int(width),
129
- height=int(height),
130
- generator=generator,
131
- )
132
- image = out.images[0]
133
- return image, seed, "OK"
134
- except Exception as e:
135
- return None, seed, f"Inference error: {e}"
136
 
 
 
 
 
 
 
 
 
 
137
 
138
- with gr.Blocks(css=css, title="Text-to-Image") as demo:
139
- pipe_state = gr.State(value=default_pipe)
140
 
141
- with gr.Column(elem_id="col-container"):
142
- gr.Markdown("# Text-to-Image — demo")
143
 
144
- with gr.Row():
145
- model_selector = gr.Dropdown(
146
- label="Model ID",
147
- choices=MODEL_OPTIONS,
148
- value=DEFAULT_MODEL,
149
- interactive=True,
150
- )
151
- scheduler_selector = gr.Dropdown(
152
- label="Scheduler",
153
- choices=list(SCHEDULER_MAP.keys()),
154
- value="default",
155
- interactive=True,
156
- )
157
 
158
- status = gr.Markdown("Model status: ready" if default_pipe else "Model status: not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  with gr.Row():
161
  prompt = gr.Text(
@@ -165,7 +135,8 @@ with gr.Blocks(css=css, title="Text-to-Image") as demo:
165
  placeholder="Enter your prompt",
166
  container=False,
167
  )
168
- run_button = gr.Button("Run", variant="primary")
 
169
 
170
  result = gr.Image(label="Result", show_label=False)
171
 
@@ -185,21 +156,15 @@ with gr.Blocks(css=css, title="Text-to-Image") as demo:
185
  value=42,
186
  )
187
 
188
- randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
189
 
190
  with gr.Row():
191
- size_preset = gr.Dropdown(
192
- label="Size preset",
193
- choices=["512x512", "768x512", "1024x1024"],
194
- value="512x512",
195
- )
196
-
197
  width = gr.Slider(
198
  label="Width",
199
  minimum=256,
200
  maximum=MAX_IMAGE_SIZE,
201
  step=32,
202
- value=512,
203
  )
204
 
205
  height = gr.Slider(
@@ -207,47 +172,32 @@ with gr.Blocks(css=css, title="Text-to-Image") as demo:
207
  minimum=256,
208
  maximum=MAX_IMAGE_SIZE,
209
  step=32,
210
- value=512,
211
  )
212
 
213
  with gr.Row():
214
  guidance_scale = gr.Slider(
215
  label="Guidance scale",
216
  minimum=0.0,
217
- maximum=20.0,
218
  step=0.1,
219
- value=7.0,
220
  )
221
 
222
  num_inference_steps = gr.Slider(
223
  label="Number of inference steps",
224
  minimum=1,
225
- maximum=150,
226
  step=1,
227
- value=20,
228
  )
229
 
230
  gr.Examples(examples=examples, inputs=[prompt])
231
-
232
- # Events
233
- model_selector.change(
234
- fn=load_model_and_update,
235
- inputs=[model_selector, scheduler_selector],
236
- outputs=[pipe_state, status],
237
- queue=True,
238
- )
239
- scheduler_selector.change(
240
- fn=load_model_and_update,
241
- inputs=[model_selector, scheduler_selector],
242
- outputs=[pipe_state, status],
243
- queue=True,
244
- )
245
-
246
- size_preset.change(fn=size_to_dims, inputs=size_preset, outputs=[width, height])
247
-
248
- run_button.click(
249
  fn=infer,
250
  inputs=[
 
251
  prompt,
252
  negative_prompt,
253
  seed,
@@ -256,11 +206,10 @@ with gr.Blocks(css=css, title="Text-to-Image") as demo:
256
  height,
257
  guidance_scale,
258
  num_inference_steps,
259
- pipe_state,
260
  ],
261
- outputs=[result, seed, status],
262
- queue=True,
263
  )
264
 
265
  if __name__ == "__main__":
266
- demo.launch()
 
 
1
import gradio as gr
import numpy as np
import random

# import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch
from typing import Optional

# Cache of loaded pipelines so a model is not re-loaded on every request.
# NOTE(review): unbounded — every model the user selects stays resident in
# memory (GPU or CPU) for the lifetime of the process.
PIPE_CACHE: dict[str, DiffusionPipeline] = {}

# Default model plus the options exposed in the UI model dropdown.
DEFAULT_MODEL = "CompVis/stable-diffusion-v1-4"
MODEL_OPTIONS = [
    "CompVis/stable-diffusion-v1-4",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/sdxl-turbo",
]

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use

# float16 halves memory on GPU; CPU inference requires float32.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
28
def get_pipe(model_id: str):
    """Return the DiffusionPipeline for *model_id*, loading it on first use.

    Loaded pipelines are stored in PIPE_CACHE so the (slow) download /
    weight-load happens at most once per model per process.
    """
    cached = PIPE_CACHE.get(model_id)
    if cached is not None:
        return cached
    # First request for this model: load weights, move to the target
    # device, and cache the pipeline for subsequent requests.
    pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    pipeline = pipeline.to(device)
    PIPE_CACHE[model_id] = pipeline
    return pipeline
36
 
 
 
 
 
 
 
 
37
 
38
# (The template's eager module-level pipeline load was removed: pipelines
# are now loaded lazily — and cached — via get_pipe().)

# Largest accepted RNG seed (fits in a signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 1024
 
 
 
 
 
 
 
43
 
44
 
45
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    model_id: Optional[str] = DEFAULT_MODEL,
    prompt: str = "",
    negative_prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.0,
    num_inference_steps: int = 20,
    scheduler_name: Optional[str] = None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the selected model and optional scheduler.

    Returns:
        (image, seed): the generated PIL image and the seed actually used
        (relevant when ``randomize_seed`` is True).
    """
    # Fetch (or lazily load) the cached pipeline for the selected model.
    pipe = get_pipe(model_id)

    # The pipeline object is cached and shared across requests, so a
    # scheduler swap would otherwise persist after the user clears the
    # selection. Remember the stock scheduler once so it can be restored.
    if not hasattr(pipe, "_default_scheduler"):
        pipe._default_scheduler = pipe.scheduler

    if scheduler_name:
        # Best-effort scheduler swap; add more entries here (and import the
        # classes) to expose additional schedulers.
        try:
            from diffusers import (
                DDIMScheduler,
                EulerAncestralDiscreteScheduler,
                PNDMScheduler,
            )

            sched_map = {
                "DDIM": DDIMScheduler,
                "EulerAncestral": EulerAncestralDiscreteScheduler,
                "PNDM": PNDMScheduler,
            }
            if scheduler_name in sched_map:
                # Build from the *stock* scheduler's config, not from a
                # previously swapped one, so configs don't accumulate drift.
                pipe.scheduler = sched_map[scheduler_name].from_config(
                    pipe._default_scheduler.config
                )
        except Exception as exc:
            # Fall back to the current scheduler, but surface the failure
            # instead of swallowing it silently.
            print(f"Scheduler swap to {scheduler_name!r} failed: {exc}")
    else:
        # No scheduler selected: restore the pipeline's stock scheduler.
        pipe.scheduler = pipe._default_scheduler

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # A CPU generator is accepted by diffusers for both CPU and CUDA runs.
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image, seed
 
95
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
# Example prompts offered under the prompt box (fed to gr.Examples).
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Center the main UI column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
110
+
111
+ with gr.Blocks(css=css) as demo:
112
+ with gr.Column(elem_id="col-container"):
113
+ gr.Markdown(" # Text-to-Image Gradio Template")
114
+
115
+ # Model selector (выпадающий список)
116
+ model_select = gr.Dropdown(
117
+ label="Model",
118
+ choices=MODEL_OPTIONS,
119
+ value=DEFAULT_MODEL,
120
+ interactive=True,
121
+ )
122
+
123
+ # опциональный селектор scheduler
124
+ scheduler_select = gr.Dropdown(
125
+ label="Scheduler (optional)",
126
+ choices=["", "DDIM", "EulerAncestral", "PNDM"],
127
+ value="",
128
+ )
129
 
130
  with gr.Row():
131
  prompt = gr.Text(
 
135
  placeholder="Enter your prompt",
136
  container=False,
137
  )
138
+
139
+ run_button = gr.Button("Run", scale=0, variant="primary")
140
 
141
  result = gr.Image(label="Result", show_label=False)
142
 
 
156
  value=42,
157
  )
158
 
159
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
160
 
161
  with gr.Row():
 
 
 
 
 
 
162
  width = gr.Slider(
163
  label="Width",
164
  minimum=256,
165
  maximum=MAX_IMAGE_SIZE,
166
  step=32,
167
+ value=1024, # Replace with defaults that work for your model
168
  )
169
 
170
  height = gr.Slider(
 
172
  minimum=256,
173
  maximum=MAX_IMAGE_SIZE,
174
  step=32,
175
+ value=1024, # Replace with defaults that work for your model
176
  )
177
 
178
  with gr.Row():
179
  guidance_scale = gr.Slider(
180
  label="Guidance scale",
181
  minimum=0.0,
182
+ maximum=10.0,
183
  step=0.1,
184
+ value=7.0, # Replace with defaults that work for your model
185
  )
186
 
187
  num_inference_steps = gr.Slider(
188
  label="Number of inference steps",
189
  minimum=1,
190
+ maximum=50,
191
  step=1,
192
+ value=20, # Replace with defaults that work for your model
193
  )
194
 
195
  gr.Examples(examples=examples, inputs=[prompt])
196
+ gr.on(
197
+ triggers=[run_button.click, prompt.submit],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  fn=infer,
199
  inputs=[
200
+ model_select,
201
  prompt,
202
  negative_prompt,
203
  seed,
 
206
  height,
207
  guidance_scale,
208
  num_inference_steps,
209
+ scheduler_select
210
  ],
211
+ outputs=[result, seed],
 
212
  )
213
 
214
if __name__ == "__main__":
    # Start the Gradio app (blocking) when run as a script.
    demo.launch()