xiangfan00 committed on
Commit
efdefa8
·
1 Parent(s): eb10dec

Optimize and apply fixes

Browse files
Files changed (1) hide show
  1. app.py +303 -51
app.py CHANGED
@@ -2,6 +2,7 @@ import gc
2
  import html
3
  import random
4
  import sys
 
5
  import uuid
6
  from pathlib import Path
7
  from urllib.parse import quote
@@ -86,7 +87,11 @@ def log_cuda_mem(tag):
86
  print(f"[mem] {tag}: CUDA not available")
87
  return
88
 
89
- free_bytes, total_bytes = torch.cuda.mem_get_info()
 
 
 
 
90
  allocated_bytes = torch.cuda.memory_allocated()
91
  reserved_bytes = torch.cuda.memory_reserved()
92
  print(
@@ -123,7 +128,6 @@ def load_generation_pipe():
123
  image_encoder=image_encoder,
124
  torch_dtype=PIPE_DTYPE,
125
  )
126
- pipe = pipe.to(DEVICE)
127
  log_cuda_mem("after load_generation_pipe")
128
  return pipe
129
 
@@ -135,7 +139,6 @@ def load_wan_vae():
135
  subfolder="vae",
136
  torch_dtype=PIPE_DTYPE,
137
  )
138
- vae = vae.to(DEVICE)
139
  vae.eval()
140
  log_cuda_mem("after load_wan_vae")
141
  return vae
@@ -170,12 +173,17 @@ def load_refdecoder_module():
170
  vae.load_state_dict(vae_sd, strict=False)
171
  transformer.load_state_dict(transformer_sd, strict=False)
172
 
173
- vae = vae.to(DEVICE).eval()
174
- transformer = transformer.to(DEVICE).eval()
175
  log_cuda_mem("after load_refdecoder_module")
176
  return vae, transformer
177
 
178
 
 
 
 
 
 
 
 
179
  def resize_image_for_wan(image, pipe):
180
  image = image.convert("RGB")
181
  aspect_ratio = image.height / image.width
@@ -240,8 +248,8 @@ def build_compare_html(wan_video_path, ref_video_path):
240
  gap: 12px;
241
  }}
242
  .compare-topbar {{
243
- display: grid;
244
- grid-template-columns: 1fr auto 1fr;
245
  align-items: center;
246
  gap: 12px;
247
  }}
@@ -339,13 +347,33 @@ def build_compare_html(wan_video_path, ref_video_path):
339
  line-height: 1.5;
340
  text-align: center;
341
  }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  </style>
343
  </head>
344
  <body>
345
  <div class="compare-shell" id="{compare_id}">
346
  <div class="compare-topbar">
347
  <div class="compare-chip">Wan Baseline</div>
348
- <button class="compare-button" type="button">Pause</button>
349
  <div class="compare-chip compare-chip-right">RefDecoder</div>
350
  </div>
351
  <div class="compare-stage">
@@ -354,6 +382,12 @@ def build_compare_html(wan_video_path, ref_video_path):
354
  <div class="compare-divider"></div>
355
  <input class="compare-range" type="range" min="0" max="100" value="50" />
356
  </div>
 
 
 
 
 
 
357
  <div class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</div>
358
  </div>
359
  <script>
@@ -363,8 +397,13 @@ def build_compare_html(wan_video_path, ref_video_path):
363
  const overlay = root.querySelector(".compare-overlay");
364
  const divider = root.querySelector(".compare-divider");
365
  const slider = root.querySelector(".compare-range");
366
- const button = root.querySelector(".compare-button");
 
 
 
 
367
  const videos = Array.from(root.querySelectorAll("video"));
 
368
 
369
  const applySplit = () => {{
370
  const value = Number(slider.value);
@@ -396,6 +435,28 @@ def build_compare_html(wan_video_path, ref_video_path):
396
  primary.addEventListener("ratechange", () => {{ secondary.playbackRate = primary.playbackRate; }});
397
  }};
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  if (base.tagName === "VIDEO" && overlay.tagName === "VIDEO") {{
400
  bindSync(base, overlay);
401
  bindSync(overlay, base);
@@ -403,6 +464,7 @@ def build_compare_html(wan_video_path, ref_video_path):
403
  button.disabled = true;
404
  button.textContent = "Play";
405
  button.style.opacity = "0.55";
 
406
  }}
407
 
408
  videos.forEach((video) => {{
@@ -417,8 +479,25 @@ def build_compare_html(wan_video_path, ref_video_path):
417
  }}
418
  }});
419
 
 
 
 
 
420
  slider.addEventListener("input", applySplit);
421
  applySplit();
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  }})();
423
  </script>
424
  </body>
@@ -480,30 +559,30 @@ def decode_with_refdecoder(latents, reference_frame, vae, transformer):
480
  return video
481
 
482
 
483
- def button_state(label, interactive):
484
- return gr.update(value=label, interactive=interactive)
485
-
486
-
487
- @spaces.GPU(duration=80)
488
  def generate_latents_on_gpu(image, prompt, seed):
489
  log_cuda_mem("start generate_latents_on_gpu")
490
- pipe = load_generation_pipe()
491
- resized_image, height, width = resize_image_for_wan(image, pipe)
 
492
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
493
- with torch.no_grad():
494
- output = pipe(
495
- image=resized_image,
496
- prompt=prompt,
497
- negative_prompt=NEGATIVE_PROMPT,
498
- height=height,
499
- width=width,
500
- num_frames=NUM_FRAMES,
501
- num_inference_steps=NUM_INFERENCE_STEPS,
502
- guidance_scale=GUIDANCE_SCALE,
503
- generator=generator,
504
- output_type="latent",
505
- )
506
- latents = normalize_latent_shape(output.frames).detach().cpu()
 
 
 
507
  log_cuda_mem("after latent generation")
508
  return latents, resized_image, height, width
509
 
@@ -511,37 +590,52 @@ def generate_latents_on_gpu(image, prompt, seed):
511
  @spaces.GPU(duration=20)
512
  def decode_wan_on_gpu(latents):
513
  log_cuda_mem("start decode_wan_on_gpu")
514
- wan_vae = load_wan_vae()
515
- video = decode_with_wan_vae(latents, wan_vae)
 
 
 
 
516
  log_cuda_mem("after wan decode")
517
- return video.detach().cpu()
518
 
519
 
520
  @spaces.GPU(duration=25)
521
  def decode_refdecoder_on_gpu(latents, reference_frame):
522
  log_cuda_mem("start decode_refdecoder_on_gpu")
523
- ref_vae, ref_transformer = load_refdecoder_module()
524
- video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
 
 
 
 
 
 
 
 
525
  log_cuda_mem("after refdecoder decode")
526
- return video.detach().cpu()
527
 
528
 
529
- def generate_and_decode(image, prompt, seed):
530
  if image is None:
531
  raise gr.Error("Please upload an input image.")
532
  if DEVICE != "cuda":
533
  raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")
534
 
535
- yield gr.update(), gr.update(), gr.update(), button_state("Loading Wan I2V...", False)
536
 
537
  prompt = prompt.strip() if prompt else ""
538
  seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
539
  run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
540
  run_dir.mkdir(parents=True, exist_ok=True)
541
 
542
- yield gr.update(), gr.update(), gr.update(), button_state("Generating Latents...", False)
543
 
 
544
  latents, resized_image, height, width = generate_latents_on_gpu(image, prompt, seed)
 
 
545
  reference_frame = build_reference_frame(resized_image, "cpu")
546
 
547
  latent_path = run_dir / "wan_latents.pt"
@@ -556,26 +650,41 @@ def generate_and_decode(image, prompt, seed):
556
  latent_path,
557
  )
558
 
559
- yield gr.update(), gr.update(), gr.update(), button_state("Decoding Wan Baseline...", False)
560
 
 
561
  wan_video = decode_wan_on_gpu(latents)
 
 
562
  wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
563
  del wan_video
564
  gc.collect()
565
 
566
- yield gr.update(), wan_video_path, gr.update(), button_state("Decoding RefDecoder...", False)
567
 
 
568
  ref_video = decode_refdecoder_on_gpu(latents, reference_frame)
 
 
569
  ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
570
  del ref_video
571
  gc.collect()
572
 
573
  compare_html = build_compare_html(wan_video_path, ref_video_path)
574
- yield (
 
 
 
 
 
 
 
575
  gr.update(value=compare_html, visible=True),
576
  wan_video_path,
577
  ref_video_path,
578
- button_state("Generate Comparison", True),
 
 
579
  )
580
 
581
 
@@ -709,16 +818,14 @@ CUSTOM_CSS = """
709
 
710
  .compare-frame {
711
  width: 100%;
712
- height: 860px;
 
 
713
  border: 0;
714
  background: transparent;
715
  overflow: hidden;
716
- }
717
-
718
- @media (max-width: 900px) {
719
- .compare-frame {
720
- height: 720px;
721
- }
722
  }
723
 
724
  .compare-topbar {
@@ -818,6 +925,67 @@ CUSTOM_CSS = """
818
  .seed-action-row > .gradio-column {
819
  min-width: 0;
820
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
821
  """
822
 
823
 
@@ -825,6 +993,55 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
825
  with gr.Column(elem_classes="app-shell"):
826
  gr.HTML(
827
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828
  <div class="hero-card">
829
  <div class="hero-title">RefDecoder I2V Demo</div>
830
  <p class="hero-copy">
@@ -871,6 +1088,7 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
871
  variant="primary",
872
  elem_id="generate-btn",
873
  )
 
874
 
875
  with gr.Column(elem_classes="panel-card"):
876
  gr.HTML(
@@ -883,13 +1101,47 @@ with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_C
883
  )
884
  compare_output = gr.HTML(value=build_compare_html(None, None))
885
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
886
  wan_video_hidden = gr.Video(visible=False)
887
  ref_video_hidden = gr.Video(visible=False)
888
 
 
 
 
 
 
 
 
889
  run_button.click(
 
 
 
 
 
 
890
  fn=generate_and_decode,
891
  inputs=[image_input, prompt_input, seed_input],
892
- outputs=[compare_output, wan_video_hidden, ref_video_hidden, run_button],
 
 
 
 
 
 
 
893
  )
894
 
895
 
 
2
  import html
3
  import random
4
  import sys
5
+ import time
6
  import uuid
7
  from pathlib import Path
8
  from urllib.parse import quote
 
87
  print(f"[mem] {tag}: CUDA not available")
88
  return
89
 
90
+ try:
91
+ free_bytes, total_bytes = torch.cuda.mem_get_info()
92
+ except RuntimeError as exc:
93
+ print(f"[mem] {tag}: CUDA not currently leased ({exc})")
94
+ return
95
  allocated_bytes = torch.cuda.memory_allocated()
96
  reserved_bytes = torch.cuda.memory_reserved()
97
  print(
 
128
  image_encoder=image_encoder,
129
  torch_dtype=PIPE_DTYPE,
130
  )
 
131
  log_cuda_mem("after load_generation_pipe")
132
  return pipe
133
 
 
139
  subfolder="vae",
140
  torch_dtype=PIPE_DTYPE,
141
  )
 
142
  vae.eval()
143
  log_cuda_mem("after load_wan_vae")
144
  return vae
 
173
  vae.load_state_dict(vae_sd, strict=False)
174
  transformer.load_state_dict(transformer_sd, strict=False)
175
 
 
 
176
  log_cuda_mem("after load_refdecoder_module")
177
  return vae, transformer
178
 
179
 
180
+ # Preload all models on CPU at init so each @spaces.GPU lease only pays for the
181
+ # CPU -> GPU transfer, not the full from_pretrained / checkpoint read.
182
+ GENERATION_PIPE = load_generation_pipe()
183
+ WAN_VAE = load_wan_vae()
184
+ REFDECODER_VAE, REFDECODER_TRANSFORMER = load_refdecoder_module()
185
+
186
+
187
  def resize_image_for_wan(image, pipe):
188
  image = image.convert("RGB")
189
  aspect_ratio = image.height / image.width
 
248
  gap: 12px;
249
  }}
250
  .compare-topbar {{
251
+ display: flex;
252
+ justify-content: space-between;
253
  align-items: center;
254
  gap: 12px;
255
  }}
 
347
  line-height: 1.5;
348
  text-align: center;
349
  }}
350
+ .compare-controls {{
351
+ display: flex;
352
+ justify-content: center;
353
+ align-items: center;
354
+ gap: 10px;
355
+ flex-wrap: wrap;
356
+ }}
357
+ .compare-controls .compare-button {{
358
+ padding: 9px 16px;
359
+ font-size: 14px;
360
+ }}
361
+ .compare-button-step {{
362
+ background: #2f5746;
363
+ }}
364
+ .compare-button-reset {{
365
+ background: #c96f42;
366
+ }}
367
+ .compare-button[disabled] {{
368
+ opacity: 0.55;
369
+ cursor: not-allowed;
370
+ }}
371
  </style>
372
  </head>
373
  <body>
374
  <div class="compare-shell" id="{compare_id}">
375
  <div class="compare-topbar">
376
  <div class="compare-chip">Wan Baseline</div>
 
377
  <div class="compare-chip compare-chip-right">RefDecoder</div>
378
  </div>
379
  <div class="compare-stage">
 
382
  <div class="compare-divider"></div>
383
  <input class="compare-range" type="range" min="0" max="100" value="50" />
384
  </div>
385
+ <div class="compare-controls">
386
+ <button class="compare-button compare-button-step" type="button" data-action="prev">− 1 Frame</button>
387
+ <button class="compare-button" type="button" data-action="toggle">Pause</button>
388
+ <button class="compare-button compare-button-step" type="button" data-action="next">+ 1 Frame</button>
389
+ <button class="compare-button compare-button-reset" type="button" data-action="reset">Reset Playback</button>
390
+ </div>
391
  <div class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</div>
392
  </div>
393
  <script>
 
397
  const overlay = root.querySelector(".compare-overlay");
398
  const divider = root.querySelector(".compare-divider");
399
  const slider = root.querySelector(".compare-range");
400
+ const button = root.querySelector('[data-action="toggle"]');
401
+ const prevBtn = root.querySelector('[data-action="prev"]');
402
+ const nextBtn = root.querySelector('[data-action="next"]');
403
+ const resetBtn = root.querySelector('[data-action="reset"]');
404
+ const stepButtons = [prevBtn, nextBtn, resetBtn];
405
  const videos = Array.from(root.querySelectorAll("video"));
406
+ const FRAME_DELTA = 1 / {FPS};
407
 
408
  const applySplit = () => {{
409
  const value = Number(slider.value);
 
435
  primary.addEventListener("ratechange", () => {{ secondary.playbackRate = primary.playbackRate; }});
436
  }};
437
 
438
+ const stepFrame = (delta) => {{
439
+ if (!videos.length) return;
440
+ pauseBoth();
441
+ videos.forEach((video) => {{
442
+ const duration = isFinite(video.duration) ? video.duration : 0;
443
+ let nextTime = (video.currentTime || 0) + delta;
444
+ if (duration > 0) {{
445
+ nextTime = ((nextTime % duration) + duration) % duration;
446
+ }} else {{
447
+ nextTime = Math.max(0, nextTime);
448
+ }}
449
+ try {{ video.currentTime = nextTime; }} catch (e) {{}}
450
+ }});
451
+ }};
452
+
453
+ const resetPlayback = () => {{
454
+ pauseBoth();
455
+ videos.forEach((video) => {{
456
+ try {{ video.currentTime = 0; }} catch (e) {{}}
457
+ }});
458
+ }};
459
+
460
  if (base.tagName === "VIDEO" && overlay.tagName === "VIDEO") {{
461
  bindSync(base, overlay);
462
  bindSync(overlay, base);
 
464
  button.disabled = true;
465
  button.textContent = "Play";
466
  button.style.opacity = "0.55";
467
+ stepButtons.forEach((btn) => {{ if (btn) btn.disabled = true; }});
468
  }}
469
 
470
  videos.forEach((video) => {{
 
479
  }}
480
  }});
481
 
482
+ if (prevBtn) prevBtn.addEventListener("click", () => stepFrame(-FRAME_DELTA));
483
+ if (nextBtn) nextBtn.addEventListener("click", () => stepFrame(FRAME_DELTA));
484
+ if (resetBtn) resetBtn.addEventListener("click", resetPlayback);
485
+
486
  slider.addEventListener("input", applySplit);
487
  applySplit();
488
+
489
+ const reportHeight = () => {{
490
+ const h = Math.ceil(root.getBoundingClientRect().height + 2);
491
+ parent.postMessage({{ type: "compare-iframe-height", id: "{compare_id}", height: h }}, "*");
492
+ }};
493
+ reportHeight();
494
+ window.addEventListener("load", reportHeight);
495
+ if (typeof ResizeObserver !== "undefined") {{
496
+ new ResizeObserver(reportHeight).observe(root);
497
+ }}
498
+ videos.forEach((video) => {{
499
+ video.addEventListener("loadedmetadata", reportHeight);
500
+ }});
501
  }})();
502
  </script>
503
  </body>
 
559
  return video
560
 
561
 
562
+ @spaces.GPU(duration=95)
 
 
 
 
563
  def generate_latents_on_gpu(image, prompt, seed):
564
  log_cuda_mem("start generate_latents_on_gpu")
565
+ GENERATION_PIPE.to(DEVICE)
566
+ log_cuda_mem("after pipe -> cuda")
567
+ resized_image, height, width = resize_image_for_wan(image, GENERATION_PIPE)
568
  generator = torch.Generator(device=DEVICE).manual_seed(seed)
569
+ try:
570
+ with torch.no_grad():
571
+ output = GENERATION_PIPE(
572
+ image=resized_image,
573
+ prompt=prompt,
574
+ negative_prompt=NEGATIVE_PROMPT,
575
+ height=height,
576
+ width=width,
577
+ num_frames=NUM_FRAMES,
578
+ num_inference_steps=NUM_INFERENCE_STEPS,
579
+ guidance_scale=GUIDANCE_SCALE,
580
+ generator=generator,
581
+ output_type="latent",
582
+ )
583
+ latents = normalize_latent_shape(output.frames).detach().cpu()
584
+ finally:
585
+ GENERATION_PIPE.to("cpu")
586
  log_cuda_mem("after latent generation")
587
  return latents, resized_image, height, width
588
 
 
590
  @spaces.GPU(duration=20)
591
  def decode_wan_on_gpu(latents):
592
  log_cuda_mem("start decode_wan_on_gpu")
593
+ WAN_VAE.to(DEVICE)
594
+ try:
595
+ video = decode_with_wan_vae(latents, WAN_VAE)
596
+ video = video.detach().cpu()
597
+ finally:
598
+ WAN_VAE.to("cpu")
599
  log_cuda_mem("after wan decode")
600
+ return video
601
 
602
 
603
  @spaces.GPU(duration=25)
604
  def decode_refdecoder_on_gpu(latents, reference_frame):
605
  log_cuda_mem("start decode_refdecoder_on_gpu")
606
+ REFDECODER_VAE.to(DEVICE)
607
+ REFDECODER_TRANSFORMER.to(DEVICE)
608
+ try:
609
+ video = decode_with_refdecoder(
610
+ latents, reference_frame, REFDECODER_VAE, REFDECODER_TRANSFORMER,
611
+ )
612
+ video = video.detach().cpu()
613
+ finally:
614
+ REFDECODER_VAE.to("cpu")
615
+ REFDECODER_TRANSFORMER.to("cpu")
616
  log_cuda_mem("after refdecoder decode")
617
+ return video
618
 
619
 
620
+ def generate_and_decode(image, prompt, seed, progress=gr.Progress()):
621
  if image is None:
622
  raise gr.Error("Please upload an input image.")
623
  if DEVICE != "cuda":
624
  raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")
625
 
626
+ request_start = time.perf_counter()
627
 
628
  prompt = prompt.strip() if prompt else ""
629
  seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
630
  run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
631
  run_dir.mkdir(parents=True, exist_ok=True)
632
 
633
+ progress(0.0, desc="Generating latents")
634
 
635
+ t0 = time.perf_counter()
636
  latents, resized_image, height, width = generate_latents_on_gpu(image, prompt, seed)
637
+ latent_secs = time.perf_counter() - t0
638
+ print(f"[timing] latent generation: {latent_secs:.2f}s")
639
  reference_frame = build_reference_frame(resized_image, "cpu")
640
 
641
  latent_path = run_dir / "wan_latents.pt"
 
650
  latent_path,
651
  )
652
 
653
+ progress(0.8, desc="Decoding Wan baseline")
654
 
655
+ t0 = time.perf_counter()
656
  wan_video = decode_wan_on_gpu(latents)
657
+ wan_secs = time.perf_counter() - t0
658
+ print(f"[timing] wan decode: {wan_secs:.2f}s")
659
  wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
660
  del wan_video
661
  gc.collect()
662
 
663
+ progress(0.9, desc="Decoding RefDecoder")
664
 
665
+ t0 = time.perf_counter()
666
  ref_video = decode_refdecoder_on_gpu(latents, reference_frame)
667
+ ref_secs = time.perf_counter() - t0
668
+ print(f"[timing] refdecoder decode: {ref_secs:.2f}s")
669
  ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
670
  del ref_video
671
  gc.collect()
672
 
673
  compare_html = build_compare_html(wan_video_path, ref_video_path)
674
+
675
+ total_secs = time.perf_counter() - request_start
676
+ print(
677
+ f"[timing] request total: {total_secs:.2f}s "
678
+ f"(latents={latent_secs:.2f}s, wan={wan_secs:.2f}s, ref={ref_secs:.2f}s)"
679
+ )
680
+
681
+ return (
682
  gr.update(value=compare_html, visible=True),
683
  wan_video_path,
684
  ref_video_path,
685
+ "",
686
+ gr.update(value=wan_video_path, interactive=True),
687
+ gr.update(value=ref_video_path, interactive=True),
688
  )
689
 
690
 
 
818
 
819
  .compare-frame {
820
  width: 100%;
821
+ /* aspect-ratio is a tight fallback for the brief moment before the parent
822
+ JS estimator (and then the iframe's own postMessage) sets the height. */
823
+ aspect-ratio: 16 / 11;
824
  border: 0;
825
  background: transparent;
826
  overflow: hidden;
827
+ display: block;
828
+ transition: height 120ms ease;
 
 
 
 
829
  }
830
 
831
  .compare-topbar {
 
925
  .seed-action-row > .gradio-column {
926
  min-width: 0;
927
  }
928
+
929
+ .run-status {
930
+ margin-top: 8px;
931
+ color: var(--text-soft);
932
+ font-size: 13px;
933
+ line-height: 1.4;
934
+ min-height: 1.4em;
935
+ }
936
+
937
+ .run-status p {
938
+ margin: 0;
939
+ }
940
+
941
+ .download-row {
942
+ margin-top: 12px;
943
+ gap: 12px;
944
+ justify-content: center;
945
+ flex-wrap: wrap;
946
+ }
947
+
948
+ .download-row button {
949
+ border: 0 !important;
950
+ border-radius: 999px !important;
951
+ padding: 10px 22px !important;
952
+ font-size: 14px !important;
953
+ font-weight: 700 !important;
954
+ box-shadow: none !important;
955
+ min-height: 0 !important;
956
+ }
957
+
958
+ button.download-baseline {
959
+ background: var(--accent) !important;
960
+ color: #fff !important;
961
+ }
962
+
963
+ button.download-ref {
964
+ background: var(--accent-2) !important;
965
+ color: #fff !important;
966
+ }
967
+
968
+ .download-row button:hover:not([disabled]):not(:disabled) {
969
+ filter: brightness(1.05);
970
+ }
971
+
972
+ button.download-baseline[disabled],
973
+ button.download-baseline:disabled {
974
+ background: rgba(31, 106, 82, 0.14) !important;
975
+ color: #123a2d !important;
976
+ box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12) !important;
977
+ opacity: 1 !important;
978
+ cursor: not-allowed;
979
+ }
980
+
981
+ button.download-ref[disabled],
982
+ button.download-ref:disabled {
983
+ background: rgba(201, 111, 66, 0.16) !important;
984
+ color: #6e3d23 !important;
985
+ box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16) !important;
986
+ opacity: 1 !important;
987
+ cursor: not-allowed;
988
+ }
989
  """
990
 
991
 
 
993
  with gr.Column(elem_classes="app-shell"):
994
  gr.HTML(
995
  """
996
+ <script>
997
+ (() => {
998
+ if (window.__refdecoderResizeBound) return;
999
+ window.__refdecoderResizeBound = true;
1000
+
1001
+ const STAGE_RATIO = 9 / 16;
1002
+ const CHROME = 160;
1003
+ const observed = new WeakSet();
1004
+
1005
+ const estimateHeight = (iframe) => {
1006
+ if (iframe.dataset.exactSized === "1") return;
1007
+ const w = iframe.getBoundingClientRect().width;
1008
+ if (w > 0) {
1009
+ iframe.style.height = Math.round(w * STAGE_RATIO + CHROME) + "px";
1010
+ }
1011
+ };
1012
+
1013
+ const trackIframe = (iframe) => {
1014
+ if (observed.has(iframe)) return;
1015
+ observed.add(iframe);
1016
+ estimateHeight(iframe);
1017
+ new ResizeObserver(() => estimateHeight(iframe)).observe(iframe);
1018
+ };
1019
+
1020
+ document.querySelectorAll("iframe.compare-frame").forEach(trackIframe);
1021
+
1022
+ new MutationObserver((mutations) => {
1023
+ for (const m of mutations) {
1024
+ for (const n of m.addedNodes) {
1025
+ if (n.nodeType !== 1) continue;
1026
+ if (n.matches && n.matches("iframe.compare-frame")) trackIframe(n);
1027
+ const inner = n.querySelectorAll && n.querySelectorAll("iframe.compare-frame");
1028
+ if (inner) inner.forEach(trackIframe);
1029
+ }
1030
+ }
1031
+ }).observe(document.body, { childList: true, subtree: true });
1032
+
1033
+ window.addEventListener("message", (e) => {
1034
+ if (!e.data || e.data.type !== "compare-iframe-height") return;
1035
+ const h = Math.max(200, Number(e.data.height) || 0);
1036
+ document.querySelectorAll("iframe.compare-frame").forEach((f) => {
1037
+ if (f.contentWindow === e.source) {
1038
+ f.style.height = h + "px";
1039
+ f.dataset.exactSized = "1";
1040
+ }
1041
+ });
1042
+ });
1043
+ })();
1044
+ </script>
1045
  <div class="hero-card">
1046
  <div class="hero-title">RefDecoder I2V Demo</div>
1047
  <p class="hero-copy">
 
1088
  variant="primary",
1089
  elem_id="generate-btn",
1090
  )
1091
+ status_md = gr.Markdown(value="", elem_classes="run-status")
1092
 
1093
  with gr.Column(elem_classes="panel-card"):
1094
  gr.HTML(
 
1101
  )
1102
  compare_output = gr.HTML(value=build_compare_html(None, None))
1103
 
1104
+ with gr.Row(elem_classes="download-row"):
1105
+ wan_download_btn = gr.DownloadButton(
1106
+ label="Download Baseline",
1107
+ value=None,
1108
+ interactive=False,
1109
+ elem_classes="download-baseline",
1110
+ )
1111
+ ref_download_btn = gr.DownloadButton(
1112
+ label="Download RefDecoder",
1113
+ value=None,
1114
+ interactive=False,
1115
+ elem_classes="download-ref",
1116
+ )
1117
+
1118
  wan_video_hidden = gr.Video(visible=False)
1119
  ref_video_hidden = gr.Video(visible=False)
1120
 
1121
+ def reset_for_new_run():
1122
+ return (
1123
+ "",
1124
+ gr.update(value=None, interactive=False),
1125
+ gr.update(value=None, interactive=False),
1126
+ )
1127
+
1128
  run_button.click(
1129
+ fn=reset_for_new_run,
1130
+ inputs=None,
1131
+ outputs=[status_md, wan_download_btn, ref_download_btn],
1132
+ queue=False,
1133
+ show_progress="hidden",
1134
+ ).then(
1135
  fn=generate_and_decode,
1136
  inputs=[image_input, prompt_input, seed_input],
1137
+ outputs=[
1138
+ compare_output,
1139
+ wan_video_hidden,
1140
+ ref_video_hidden,
1141
+ status_md,
1142
+ wan_download_btn,
1143
+ ref_download_btn,
1144
+ ],
1145
  )
1146
 
1147