apolinario committed on
Commit
e2f50b1
·
1 Parent(s): afb0b5a

Drop intermediate captures + gallery; use gr.ImageSlider for Z-Image vs PiD A/B; dynamic 'Generating Z-Image step X/N' / 'Upscaling' labels

Browse files
Files changed (1) hide show
  1. app.py +26 -36
app.py CHANGED
@@ -176,7 +176,6 @@ import queue as _queue
176
  def generate(
177
  prompt: str,
178
  num_inference_steps: int = 28,
179
- num_captures: int = 4,
180
  guidance_scale: float = 5.0,
181
  seed: int = 0,
182
  resolution: int = 512,
@@ -185,22 +184,16 @@ def generate(
185
  raise gr.Error("Please enter a prompt.")
186
 
187
  num_inference_steps = int(num_inference_steps)
188
- num_captures = int(num_captures)
189
  H = W = int(resolution)
190
 
191
- # initial: show the live-preview image, hide the final gallery
192
- yield gr.update(visible=True, value=None), gr.update(visible=False, value=None)
193
-
194
- capture_ks = set(_evenly_spaced_capture_steps(num_inference_steps, num_captures))
195
- xt_cb = XtCaptureCallback(capture_ks) if capture_ks else None
196
 
197
  # ---- Run Z-Image in a thread; stream taef1 previews via a queue ----
198
  preview_q: "_queue.Queue" = _queue.Queue()
199
  _DONE = object()
200
 
201
  def streaming_cb(pipe, step_index, timestep, callback_kwargs):
202
- if xt_cb is not None:
203
- xt_cb(pipe, step_index, timestep, callback_kwargs)
204
  try:
205
  preview = _taef1_preview(callback_kwargs["latents"], H, W)
206
  preview_q.put((step_index, preview))
@@ -241,33 +234,31 @@ def generate(
241
  raise payload
242
  raw_output = payload
243
  break
244
- yield gr.update(visible=True, value=payload), gr.update(visible=False)
 
245
 
246
  thread.join()
247
  final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)
248
 
249
- # ---- PiD per-step decode (sequentially) ----
250
- steps_iter = []
251
- if xt_cb is not None:
252
- for K in sorted(xt_cb.captured.keys()):
253
- xt_packed_cpu, sigma = xt_cb.captured[K]
254
- xt_packed = xt_packed_cpu.to(device="cuda", dtype=DTYPE)
255
- xt_latent = extract_latent(pipeline, SimpleNamespace(images=xt_packed), pipe_cfg, H, W)
256
- steps_iter.append((f"step {K:02d}/{num_inference_steps}", xt_latent, sigma))
257
- final_sigma = float(pipeline.scheduler.sigmas[-1].item())
258
- steps_iter.append(("final x₀", final_latent, final_sigma))
259
 
260
- outputs: list[tuple[Image.Image, str]] = []
261
- for label, latent, sigma in steps_iter:
262
- with torch.no_grad():
263
- baseline_01 = decode_with_pipeline_vae(pipeline, latent, pipe_cfg)
264
- pid_img = _pid_decode(latent, baseline_01, sigma, prompt)
265
- outputs.append((pid_img, f"{label} (σ={sigma:.3f})"))
266
- # Flash the latest PiD output in the live-preview image during PiD decoding too
267
- yield gr.update(visible=True, value=pid_img), gr.update(visible=False)
268
 
269
- # ---- Done: hide live preview, show the final gallery ----
270
- yield gr.update(visible=False, value=None), gr.update(visible=True, value=outputs)
 
 
 
271
 
272
 
273
  DESCRIPTION = """
@@ -297,18 +288,17 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
297
  resolution = gr.Slider(label="Z-Image resolution", minimum=256, maximum=1024, step=128, value=512)
298
  num_inference_steps = gr.Slider(label="Z-Image steps", minimum=8, maximum=50, step=1, value=28)
299
  with gr.Row():
300
- num_captures = gr.Slider(label="Intermediate captures", minimum=1, maximum=8, step=1, value=4)
301
  guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=5.0)
302
- seed = gr.Number(label="Seed", value=0, precision=0)
303
  run = gr.Button("Run", variant="primary")
304
  with gr.Column(scale=2):
305
- live_preview = gr.Image(label="Live preview", visible=True, show_label=True, type="pil")
306
- gallery = gr.Gallery(label="PiD-decoded denoising trajectory", visible=False, columns=2, object_fit="contain")
307
 
308
  run.click(
309
  fn=generate,
310
- inputs=[prompt, num_inference_steps, num_captures, guidance_scale, seed, resolution],
311
- outputs=[live_preview, gallery],
312
  )
313
 
314
  if __name__ == "__main__":
 
176
  def generate(
177
  prompt: str,
178
  num_inference_steps: int = 28,
 
179
  guidance_scale: float = 5.0,
180
  seed: int = 0,
181
  resolution: int = 512,
 
184
  raise gr.Error("Please enter a prompt.")
185
 
186
  num_inference_steps = int(num_inference_steps)
 
187
  H = W = int(resolution)
188
 
189
+ # initial: show the live preview, hide the final slider
190
+ yield gr.update(visible=True, value=None, label="Generating Z-Image…"), gr.update(visible=False, value=None)
 
 
 
191
 
192
  # ---- Run Z-Image in a thread; stream taef1 previews via a queue ----
193
  preview_q: "_queue.Queue" = _queue.Queue()
194
  _DONE = object()
195
 
196
  def streaming_cb(pipe, step_index, timestep, callback_kwargs):
 
 
197
  try:
198
  preview = _taef1_preview(callback_kwargs["latents"], H, W)
199
  preview_q.put((step_index, preview))
 
234
  raise payload
235
  raw_output = payload
236
  break
237
+ label = f"Generating Z-Image — step {step_index + 1}/{num_inference_steps}"
238
+ yield gr.update(visible=True, value=payload, label=label), gr.update(visible=False)
239
 
240
  thread.join()
241
  final_latent = extract_latent(pipeline, raw_output, pipe_cfg, H, W)
242
 
243
+ # ---- VAE decode of the final clean latent (Z-Image baseline) ----
244
+ yield gr.update(visible=True, label="Decoding final Z-Image…"), gr.update(visible=False)
245
+ with torch.no_grad():
246
+ baseline_01 = decode_with_pipeline_vae(pipeline, final_latent, pipe_cfg)
247
+ zimage_img = Image.fromarray(
248
+ (baseline_01[0].clamp(0, 1).permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
249
+ )
 
 
 
250
 
251
+ # ---- PiD upscaling on the final latent ----
252
+ yield gr.update(visible=True, value=zimage_img, label="Upscaling with PiD (4× super-resolution, 4 steps)…"), gr.update(visible=False)
253
+ final_sigma = float(pipeline.scheduler.sigmas[-1].item())
254
+ with torch.no_grad():
255
+ pid_img = _pid_decode(final_latent, baseline_01, final_sigma, prompt)
 
 
 
256
 
257
+ # ---- Done: hide live preview, show the A/B slider ----
258
+ yield (
259
+ gr.update(visible=False, value=None),
260
+ gr.update(visible=True, value=(zimage_img, pid_img)),
261
+ )
262
 
263
 
264
  DESCRIPTION = """
 
288
  resolution = gr.Slider(label="Z-Image resolution", minimum=256, maximum=1024, step=128, value=512)
289
  num_inference_steps = gr.Slider(label="Z-Image steps", minimum=8, maximum=50, step=1, value=28)
290
  with gr.Row():
 
291
  guidance_scale = gr.Slider(label="Guidance", minimum=1.0, maximum=10.0, step=0.5, value=5.0)
292
+ seed = gr.Number(label="Seed", value=0, precision=0)
293
  run = gr.Button("Run", variant="primary")
294
  with gr.Column(scale=2):
295
+ live_preview = gr.Image(label="Generating Z-Image…", visible=True, show_label=True, type="pil")
296
+ slider = gr.ImageSlider(label="Z-Image (left) ↔ PiD 4× upscale (right)", visible=False, type="pil")
297
 
298
  run.click(
299
  fn=generate,
300
+ inputs=[prompt, num_inference_steps, guidance_scale, seed, resolution],
301
+ outputs=[live_preview, slider],
302
  )
303
 
304
  if __name__ == "__main__":