mcuo commited on
Commit
df528d5
·
verified ·
1 Parent(s): 44c56fd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -119
app.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import uuid
3
  import time
4
  import random
5
-
6
  import spaces
7
  import gradio as gr
8
  import numpy as np
@@ -22,7 +21,6 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
22
 
23
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
24
  pipe.to(device)
25
-
26
  pipe.text_encoder.to(torch.float16)
27
  pipe.text_encoder_2.to(torch.float16)
28
  pipe.vae.to(torch.float16)
@@ -38,20 +36,17 @@ compel = Compel(
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
40
  MAX_IMAGE_SIZE = 1216
41
-
42
  OUTPUT_DIR = "/tmp/generated_images"
43
  os.makedirs(OUTPUT_DIR, exist_ok=True)
44
 
45
-
46
  def save_image_jpg(pil_image: Image.Image) -> str:
47
  if pil_image.mode != "RGB":
48
  pil_image = pil_image.convert("RGB")
49
  path = os.path.join(OUTPUT_DIR, f"{uuid.uuid4().hex}.jpg")
50
  pil_image.save(path, "JPEG", quality=95)
51
- # return path
52
-
53
 
54
- @spaces.GPU(duration=10)
55
  def infer(
56
  prompt,
57
  negative_prompt,
@@ -64,20 +59,15 @@ def infer(
64
  ):
65
  if not prompt.strip():
66
  raise gr.Error("Prompt cannot be empty.")
67
-
68
  if randomize_seed:
69
  seed = random.randint(0, MAX_SEED)
70
-
71
  generator = torch.Generator(device=device).manual_seed(seed)
72
-
73
  try:
74
  conditioning, pooled = compel([prompt, negative_prompt])
75
-
76
  prompt_embeds = conditioning[0:1]
77
  pooled_prompt_embeds = pooled[0:1]
78
  negative_prompt_embeds = conditioning[1:2]
79
  negative_pooled_prompt_embeds = pooled[1:2]
80
-
81
  image = pipe(
82
  prompt_embeds=prompt_embeds,
83
  pooled_prompt_embeds=pooled_prompt_embeds,
@@ -89,16 +79,13 @@ def infer(
89
  height=height,
90
  generator=generator,
91
  ).images[0]
92
-
93
  image_path = save_image_jpg(image)
94
- return image_path, seed
95
-
96
  except RuntimeError as e:
97
  print(f"Error during generation: {e}")
98
  blank_image = Image.new("RGB", (width, height), color=(0, 0, 0))
99
  blank_path = save_image_jpg(blank_image)
100
- return blank_path, seed
101
-
102
 
103
  def generation_loop(
104
  prompt,
@@ -113,7 +100,6 @@ def generation_loop(
113
  ):
114
  if not prompt.strip():
115
  raise gr.Error("Prompt cannot be empty to start consecutive generation.")
116
-
117
  while True:
118
  try:
119
  image_path, new_seed = infer(
@@ -126,42 +112,40 @@ def generation_loop(
126
  guidance_scale,
127
  num_inference_steps,
128
  )
129
-
130
- yield {result: image_path, seed: new_seed}
131
  time.sleep(interval_sec)
132
-
133
  except gr.exceptions.CancelledError:
134
  print("Generation loop cancelled by user.")
135
  break
136
 
137
-
138
  css = """
139
  #col-container {
140
- margin: 0 auto;
141
- max-width: 1024px;
142
  }
143
 
144
  /* 完全透過(非表示だがクリック等は可能なまま) */
145
  .transparent-btn,
146
  .transparent-btn * {
147
- opacity: 0 !important;
148
  }
149
 
150
  .transparent-btn button {
151
- background: transparent !important;
152
- border: 0 !important;
153
- box-shadow: none !important;
154
  }
155
 
156
  .transparent-btn button:focus,
157
  .transparent-btn button:focus-visible {
158
- outline: none !important;
159
  }
160
  """
161
 
162
  with gr.Blocks(css=css) as demo:
163
  with gr.Column(elem_id="col-container"):
164
- gr.Markdown("<br>" * 1)
 
165
  # Prompt(右にGenerateは置かない)
166
  with gr.Row(equal_height=True):
167
  prompt = gr.Text(
@@ -183,7 +167,7 @@ with gr.Blocks(css=css) as demo:
183
  run_button = gr.Button("Generate", scale=0, interactive=False, elem_classes=["transparent-btn"])
184
  consecutive_button = gr.Button("Consecutive", scale=0, interactive=False, elem_classes=["transparent-btn"])
185
 
186
- gr.Markdown("<br>" * 20)
187
 
188
  # 停止/クリア
189
  with gr.Row():
@@ -212,18 +196,15 @@ with gr.Blocks(css=css) as demo:
212
 
213
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
214
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
215
-
216
  with gr.Row():
217
  width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
218
  height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
219
-
220
  with gr.Row():
221
  guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=8)
222
  num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=28, step=1, value=25)
223
-
224
  interval_seconds = gr.Slider(label="Interval (seconds)", minimum=1, maximum=60, step=1, value=1)
225
 
226
- gr.Markdown("<br>" * 20)
227
 
228
  gr.Examples(
229
  examples=[
@@ -233,95 +214,96 @@ with gr.Blocks(css=css) as demo:
233
  label="Examples (Click to copy to prompt)",
234
  )
235
 
236
- # Promptが空でなければボタンを押せるようにする
237
- prompt.input(
238
- fn=None,
239
- inputs=[prompt],
240
- outputs=[run_button, consecutive_button],
241
- js="(p) => { const interactive = p.trim().length > 0; return [{ interactive: interactive, '__type__': 'update' }, { interactive: interactive, '__type__': 'update' }]; }",
242
- )
243
-
244
- # クリア:promptを空にしてボタン無効、URL欄も空にする
245
- clear_button.click(
246
- fn=None,
247
- inputs=None,
248
- outputs=[prompt, run_button, consecutive_button, image_url],
249
- js="""
250
  function() {
251
- return [
252
- "",
253
- { "interactive": false, "__type__": "update" },
254
- { "interactive": false, "__type__": "update" },
255
- ""
256
- ];
257
- }
258
- """,
259
- )
260
-
261
- # 生成
262
- run_button.click(
263
- fn=infer,
264
- inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
265
- outputs=[result, seed],
266
- )
267
-
268
- # 連続生成
269
- gen_inputs = [
270
- prompt,
271
- negative_prompt,
272
- seed,
273
- randomize_seed,
274
- width,
275
- height,
276
- guidance_scale,
277
- num_inference_steps,
278
- interval_seconds,
279
- ]
280
-
281
- consecutive_event = consecutive_button.click(
282
- fn=generation_loop,
283
- inputs=gen_inputs,
284
- outputs=[result, seed],
285
- )
286
-
287
- # 停止
288
- stop_button.click(
289
- fn=None,
290
- inputs=None,
291
- outputs=None,
292
- cancels=[consecutive_event],
293
- )
294
-
295
- # resultが更新されたら、表示中のimg.src(= /file=...)を拾って表示
296
- result.change(
297
- fn=None,
298
- inputs=None,
299
- outputs=[image_url],
300
- js=r"""
301
- () => {
302
- const img = document.querySelector("#result_image img");
303
- if (!img || !img.src) return "";
304
- return new URL(img.src, window.location.href).href;
305
  }
306
  """,
307
- )
308
-
309
- # Copyボタン:URL文字列をコピー
310
- copy_button.click(
311
- fn=None,
312
- inputs=[image_url],
313
- outputs=None,
314
- js=r"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
  async (url) => {
316
- if (!url) return;
317
- try {
318
- await navigator.clipboard.writeText(url);
319
- console.log("URL copied");
320
- } catch (e) {
321
- console.error("Copy failed", e);
322
- }
 
323
  }
324
  """,
325
- )
326
 
327
  demo.queue().launch()
 
2
  import uuid
3
  import time
4
  import random
 
5
  import spaces
6
  import gradio as gr
7
  import numpy as np
 
21
 
22
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
23
  pipe.to(device)
 
24
  pipe.text_encoder.to(torch.float16)
25
  pipe.text_encoder_2.to(torch.float16)
26
  pipe.vae.to(torch.float16)
 
36
 
37
  MAX_SEED = np.iinfo(np.int32).max
38
  MAX_IMAGE_SIZE = 1216
 
39
  OUTPUT_DIR = "/tmp/generated_images"
40
  os.makedirs(OUTPUT_DIR, exist_ok=True)
41
 
 
42
  def save_image_jpg(pil_image: Image.Image) -> str:
43
  if pil_image.mode != "RGB":
44
  pil_image = pil_image.convert("RGB")
45
  path = os.path.join(OUTPUT_DIR, f"{uuid.uuid4().hex}.jpg")
46
  pil_image.save(path, "JPEG", quality=95)
47
+ return path
 
48
 
49
+ @spaces.GPU(duration=15)
50
  def infer(
51
  prompt,
52
  negative_prompt,
 
59
  ):
60
  if not prompt.strip():
61
  raise gr.Error("Prompt cannot be empty.")
 
62
  if randomize_seed:
63
  seed = random.randint(0, MAX_SEED)
 
64
  generator = torch.Generator(device=device).manual_seed(seed)
 
65
  try:
66
  conditioning, pooled = compel([prompt, negative_prompt])
 
67
  prompt_embeds = conditioning[0:1]
68
  pooled_prompt_embeds = pooled[0:1]
69
  negative_prompt_embeds = conditioning[1:2]
70
  negative_pooled_prompt_embeds = pooled[1:2]
 
71
  image = pipe(
72
  prompt_embeds=prompt_embeds,
73
  pooled_prompt_embeds=pooled_prompt_embeds,
 
79
  height=height,
80
  generator=generator,
81
  ).images[0]
 
82
  image_path = save_image_jpg(image)
83
+ return f"/file={image_path}", seed # ← 変更: /file= プレフィックスを付けてURLとして返す
 
84
  except RuntimeError as e:
85
  print(f"Error during generation: {e}")
86
  blank_image = Image.new("RGB", (width, height), color=(0, 0, 0))
87
  blank_path = save_image_jpg(blank_image)
88
+ return f"/file={blank_path}", seed # ← 変更: 同上
 
89
 
90
  def generation_loop(
91
  prompt,
 
100
  ):
101
  if not prompt.strip():
102
  raise gr.Error("Prompt cannot be empty to start consecutive generation.")
 
103
  while True:
104
  try:
105
  image_path, new_seed = infer(
 
112
  guidance_scale,
113
  num_inference_steps,
114
  )
115
+ yield {image_url: image_path, seed: new_seed} # ← 変更: result → image_url
 
116
  time.sleep(interval_sec)
 
117
  except gr.exceptions.CancelledError:
118
  print("Generation loop cancelled by user.")
119
  break
120
 
 
121
  css = """
122
  #col-container {
123
+ margin: 0 auto;
124
+ max-width: 1024px;
125
  }
126
 
127
  /* 完全透過(非表示だがクリック等は可能なまま) */
128
  .transparent-btn,
129
  .transparent-btn * {
130
+ opacity: 0 !important;
131
  }
132
 
133
  .transparent-btn button {
134
+ background: transparent !important;
135
+ border: 0 !important;
136
+ box-shadow: none !important;
137
  }
138
 
139
  .transparent-btn button:focus,
140
  .transparent-btn button:focus-visible {
141
+ outline: none !important;
142
  }
143
  """
144
 
145
  with gr.Blocks(css=css) as demo:
146
  with gr.Column(elem_id="col-container"):
147
+ gr.Markdown("" * 1)
148
+
149
  # Prompt(右にGenerateは置かない)
150
  with gr.Row(equal_height=True):
151
  prompt = gr.Text(
 
167
  run_button = gr.Button("Generate", scale=0, interactive=False, elem_classes=["transparent-btn"])
168
  consecutive_button = gr.Button("Consecutive", scale=0, interactive=False, elem_classes=["transparent-btn"])
169
 
170
+ gr.Markdown("" * 20)
171
 
172
  # 停止/クリア
173
  with gr.Row():
 
196
 
197
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
198
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
199
  with gr.Row():
200
  width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
201
  height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
 
202
  with gr.Row():
203
  guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=20.0, step=0.1, value=8)
204
  num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=28, step=1, value=25)
 
205
  interval_seconds = gr.Slider(label="Interval (seconds)", minimum=1, maximum=60, step=1, value=1)
206
 
207
+ gr.Markdown("" * 20)
208
 
209
  gr.Examples(
210
  examples=[
 
214
  label="Examples (Click to copy to prompt)",
215
  )
216
 
217
+ # Promptが空でなければボタンを押せるようにする
218
+ prompt.input(
219
+ fn=None,
220
+ inputs=[prompt],
221
+ outputs=[run_button, consecutive_button],
222
+ js="(p) => { const interactive = p.trim().length > 0; return [{ interactive: interactive, '__type__': 'update' }, { interactive: interactive, '__type__': 'update' }]; }",
223
+ )
224
+
225
+ # クリア:promptを空にしてボタン無効、URL欄も空にする
226
+ clear_button.click(
227
+ fn=None,
228
+ inputs=None,
229
+ outputs=[prompt, run_button, consecutive_button, image_url],
230
+ js="""
231
  function() {
232
+ return [
233
+ "",
234
+ { "interactive": false, "__type__": "update" },
235
+ { "interactive": false, "__type__": "update" },
236
+ ""
237
+ ];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  }
239
  """,
240
+ )
241
+
242
+ # 生成
243
+ run_button.click(
244
+ fn=infer,
245
+ inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
246
+ outputs=[image_url, seed], # ← 変更: result → image_url
247
+ )
248
+
249
+ # 連続生成
250
+ gen_inputs = [
251
+ prompt,
252
+ negative_prompt,
253
+ seed,
254
+ randomize_seed,
255
+ width,
256
+ height,
257
+ guidance_scale,
258
+ num_inference_steps,
259
+ interval_seconds,
260
+ ]
261
+
262
+ consecutive_event = consecutive_button.click(
263
+ fn=generation_loop,
264
+ inputs=gen_inputs,
265
+ outputs=[image_url, seed], # ← 変更: result → image_url
266
+ )
267
+
268
+ # 停止
269
+ stop_button.click(
270
+ fn=None,
271
+ inputs=None,
272
+ outputs=None,
273
+ cancels=[consecutive_event],
274
+ )
275
+
276
+ # result.change は不要になったためコメントアウト
277
+ # result.change(
278
+ # fn=None,
279
+ # inputs=None,
280
+ # outputs=[image_url],
281
+ # js=r"""
282
+ # () => {
283
+ # const img = document.querySelector("#result_image img");
284
+ # if (!img || !img.src) return "";
285
+ # return new URL(img.src, window.location.href).href;
286
+ # }
287
+ # """,
288
+ # )
289
+
290
+ # Copyボタン:URL文字列をコピー(相対パスは絶対URLに変換)
291
+ copy_button.click(
292
+ fn=None,
293
+ inputs=[image_url],
294
+ outputs=None,
295
+ js=r"""
296
  async (url) => {
297
+ if (!url) return;
298
+ try {
299
+ const fullUrl = url.startsWith('/') ? window.location.origin + url : url;
300
+ await navigator.clipboard.writeText(fullUrl);
301
+ console.log("URL copied");
302
+ } catch (e) {
303
+ console.error("Copy failed", e);
304
+ }
305
  }
306
  """,
307
+ )
308
 
309
  demo.queue().launch()