# Hugging Face Space page header (extraction residue) — this app targets
# ZeroGPU ("Running on Zero") hardware.
| import gc | |
| import html | |
| import random | |
| import sys | |
| import uuid | |
| from pathlib import Path | |
| from urllib.parse import quote | |
| import gradio as gr | |
| import imageio | |
| import numpy as np | |
| import ftfy | |
try:
    import spaces
except ImportError:
    # Local (non-Spaces) fallback: provide a no-op `spaces.GPU` so decorated
    # functions run unchanged when the `spaces` package is absent.
    class _SpacesShim:
        @staticmethod
        def GPU(*args, **kwargs):
            # Support both decorator forms:
            #   @spaces.GPU                      -> called with the function itself
            #   @spaces.GPU(duration=...)        -> called with config, returns decorator
            # The original instance-method version broke the bare form (the
            # decorated function was replaced by the inner `decorator`).
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()
| import torch | |
| from diffusers.pipelines.wan import pipeline_wan_i2v | |
| from diffusers import AutoencoderKLWan as DiffusersWanVAE | |
| from diffusers import WanImageToVideoPipeline | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from transformers import CLIPVisionModel | |
| from src.models.Wan.autoencoder_wanT import AutoencoderKLWan | |
| from src.models.Wan.transformer_wan import WanDecoderTransformer | |
# Repository root; inserted into sys.path so the `src.*` package imports
# resolve regardless of the working directory the app is launched from.
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# NOTE(review): the `from src.models.Wan ...` imports above execute *before*
# this sys.path bootstrap; they presumably work only because Python already
# puts the script's own directory on sys.path — confirm, or move the
# bootstrap above those imports.

# Hub locations: base Wan 2.1 image-to-video weights and the RefDecoder checkpoint.
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
REFDECODER_REPO_ID = "Arrokothwhi/RefDecoder"
REFDECODER_CKPT_PATH_IN_REPO = "I2V_Wan2.1/model.pt"
# All generated videos / latents are written under this directory.
OUTPUT_ROOT = ROOT / "gradio_outputs"

# Negative prompt used on every generation to suppress common Wan artifacts.
NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, "
    "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, "
    "misshapen limbs, fused fingers, still picture, messy background, three legs, many people "
    "in the background, walking backwards"
)

# Generation settings: inputs are resized to roughly 480x832 total pixel area.
TARGET_AREA = 480 * 832
FPS = 16
NUM_FRAMES = 17
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 5.0

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU to halve memory; full float32 on CPU.
PIPE_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Some diffusers Wan builds reference a module-level `ftfy` during prompt cleaning.
# Make it explicit so Spaces don't fail if that global was not initialized.
pipeline_wan_i2v.ftfy = ftfy
def download_refdecoder_ckpt():
    """Fetch (or reuse the cached) RefDecoder checkpoint and return its local path."""
    print("[init] Downloading RefDecoder checkpoint metadata/file if needed")
    local_path = hf_hub_download(
        repo_id=REFDECODER_REPO_ID,
        filename=REFDECODER_CKPT_PATH_IN_REPO,
    )
    print(f"[init] RefDecoder checkpoint ready at: {local_path}")
    return local_path
def download_wan_weights():
    """Snapshot-download the full Wan I2V repo (cached) and return its local dir."""
    print(f"[init] Downloading Wan I2V weights from {MODEL_ID}")
    local_dir = snapshot_download(repo_id=MODEL_ID)
    print(f"[init] Wan I2V weights ready at: {local_dir}")
    return local_dir
# Import-time side effects: resolve all weights once at startup so the first
# request does not pay the download cost, and ensure the output dir exists.
REFDECODER_CKPT_LOCAL_PATH = download_refdecoder_ckpt()
download_wan_weights()
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
def log_cuda_mem(tag):
    """Print a one-line CUDA memory summary (free/total/allocated/reserved, GB).

    Prints a placeholder line instead when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        print(f"[mem] {tag}: CUDA not available")
        return
    gib = 1024 ** 3
    free_b, total_b = torch.cuda.mem_get_info()
    alloc_b = torch.cuda.memory_allocated()
    reserved_b = torch.cuda.memory_reserved()
    print(
        f"[mem] {tag}: "
        f"free={free_b / gib:.2f} GB, "
        f"total={total_b / gib:.2f} GB, "
        f"allocated={alloc_b / gib:.2f} GB, "
        f"reserved={reserved_b / gib:.2f} GB"
    )
def get_module_dtype(module):
    """Return the dtype of *module*'s first parameter, or PIPE_DTYPE if it has none."""
    for param in module.parameters():
        return param.dtype
    # Parameter-less module: fall back to the pipeline's default dtype.
    return PIPE_DTYPE
def load_generation_pipe():
    """Load the full Wan I2V pipeline (CLIP image encoder + VAE + transformer) onto DEVICE."""
    log_cuda_mem("before load_generation_pipe")
    clip_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID,
        subfolder="image_encoder",
        torch_dtype=PIPE_DTYPE,
    )
    wan_vae = DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    )
    pipeline = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=wan_vae,
        image_encoder=clip_encoder,
        torch_dtype=PIPE_DTYPE,
    ).to(DEVICE)
    log_cuda_mem("after load_generation_pipe")
    return pipeline
def load_wan_vae():
    """Load the stock diffusers Wan VAE in eval mode on DEVICE."""
    log_cuda_mem("before load_wan_vae")
    baseline_vae = DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    ).to(DEVICE)
    baseline_vae.eval()
    log_cuda_mem("after load_wan_vae")
    return baseline_vae
def load_refdecoder_module():
    """Instantiate the RefDecoder VAE + decoder transformer and load their weights.

    The checkpoint may nest weights under ``state_dict`` or ``module``; keys are
    routed by their ``vae.`` / ``transformer.`` prefixes.

    Returns:
        tuple: ``(vae, transformer)``, both in eval mode on DEVICE.
    """
    log_cuda_mem("before load_refdecoder_module")
    vae = AutoencoderKLWan(
        dropout_p=0.0,
        use_reference=True,
    ).eval()
    transformer = WanDecoderTransformer(
        chunk=5,
        num_layers=10,
        num_heads=12,
        head_dim=128,
        reusing=True,
        pretrained=False,
    ).eval()
    # NOTE(review): torch.load without weights_only=True can execute arbitrary
    # code from the checkpoint; acceptable only because the repo is trusted.
    checkpoint = torch.load(REFDECODER_CKPT_LOCAL_PATH, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint.get("module", checkpoint))
    vae_sd = {}
    transformer_sd = {}
    for key, value in state_dict.items():
        if key.startswith("vae."):
            vae_sd[key[len("vae.") :]] = value
        elif key.startswith("transformer."):
            transformer_sd[key[len("transformer.") :]] = value
    # strict=False tolerates architecture/key drift, but log what was skipped so
    # a silently-broken checkpoint is visible instead of producing garbage video.
    missing, unexpected = vae.load_state_dict(vae_sd, strict=False)
    if missing or unexpected:
        print(f"[init] RefDecoder VAE load: {len(missing)} missing, {len(unexpected)} unexpected keys")
    missing, unexpected = transformer.load_state_dict(transformer_sd, strict=False)
    if missing or unexpected:
        print(f"[init] RefDecoder transformer load: {len(missing)} missing, {len(unexpected)} unexpected keys")
    vae = vae.to(DEVICE).eval()
    transformer = transformer.to(DEVICE).eval()
    log_cuda_mem("after load_refdecoder_module")
    return vae, transformer
def resize_image_for_wan(image, pipe, target_area=None):
    """Resize *image* to roughly *target_area* pixels, preserving aspect ratio.

    Height and width are rounded down to a multiple of the pipeline's spatial
    modulus (VAE downsample factor x transformer patch size) as Wan requires.

    Args:
        image: PIL-like image (needs .convert/.height/.width/.resize).
        pipe: Wan pipeline exposing ``vae_scale_factor_spatial`` and
            ``transformer.config.patch_size``.
        target_area: desired output pixel count; defaults to the module-level
            TARGET_AREA (480*832) for backward compatibility.

    Returns:
        (resized_image, height, width)
    """
    if target_area is None:
        target_area = TARGET_AREA
    image = image.convert("RGB")
    aspect_ratio = image.height / image.width
    # Both dims must be divisible by this modulus for the patchified transformer.
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(target_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(target_area / aspect_ratio)) // mod_value * mod_value
    resized = image.resize((width, height))
    return resized, height, width
def build_reference_frame(image, device):
    """Convert an image to a [1, C, 1, H, W] float32 tensor normalized to [-1, 1]."""
    pixels = np.asarray(image).astype(np.float32)
    frame = torch.from_numpy(pixels).permute(2, 0, 1)
    frame = (frame / 255.0 - 0.5) * 2.0
    # [C, H, W] -> [1, C, 1, H, W]: batch dim in front, singleton time dim after C.
    return frame[None, :, None].to(device=device, dtype=torch.float32)
def normalize_latent_shape(latents):
    """Coerce pipeline latent output to a single [B, C, T, H, W] tensor.

    Accepts either a tensor or a list of tensors (the first element is used).
    A 4-D tensor is treated as unbatched and gains a leading batch dim.

    Raises:
        ValueError: if the result is not 5-dimensional.
    """
    tensor = latents[0] if isinstance(latents, list) else latents
    if tensor.ndim == 4:
        tensor = tensor[None]
    if tensor.ndim != 5:
        raise ValueError(f"Expected latent shape [B,C,T,H,W], got {tuple(tensor.shape)}")
    return tensor
def gradio_file_url(path):
    """Return the /gradio_api/file= URL under which Gradio serves a local file."""
    encoded = quote(str(path), safe="/")
    return "/gradio_api/file=" + encoded
def build_compare_html(wan_video_path, ref_video_path):
    """Build an <iframe srcdoc=...> snippet with a draggable before/after video compare.

    The left half shows the Wan baseline decode, the right half RefDecoder; an
    invisible range input drives the clip-path split, and a small script keeps
    the two <video> elements play/pause/seek-synchronized.

    Args:
        wan_video_path: path to the baseline MP4, or falsy for a placeholder.
        ref_video_path: path to the RefDecoder MP4, or falsy for a placeholder.

    Returns:
        str: HTML for a sandboxed iframe embedding the whole widget.
    """
    # Unique element id so multiple widgets on one page cannot collide.
    compare_id = f"compare-{uuid.uuid4().hex}"
    wan_url = gradio_file_url(wan_video_path) if wan_video_path else ""
    ref_url = gradio_file_url(ref_video_path) if ref_video_path else ""
    # Placeholder <div>s keep the layout stable before any video exists.
    base_source = (
        f'<video class="compare-video compare-base" src="{wan_url}" autoplay muted loop playsinline></video>'
        if wan_url
        else '<div class="compare-video compare-base compare-placeholder"></div>'
    )
    overlay_source = (
        f'<video class="compare-video compare-overlay" src="{ref_url}" autoplay muted loop playsinline></video>'
        if ref_url
        else '<div class="compare-video compare-overlay compare-placeholder"></div>'
    )
    # NOTE: literal CSS/JS braces inside this f-string are doubled ({{ }});
    # only {compare_id}, {base_source} and {overlay_source} are interpolated.
    inner_doc = f"""
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style>
html, body {{
    margin: 0;
    padding: 0;
    background: transparent;
    font-family: Manrope, Inter, system-ui, sans-serif;
}}
.compare-shell {{
    display: flex;
    flex-direction: column;
    gap: 12px;
}}
.compare-topbar {{
    display: grid;
    grid-template-columns: 1fr auto 1fr;
    align-items: center;
    gap: 12px;
}}
.compare-chip {{
    padding: 12px 22px;
    border-radius: 999px;
    background: rgba(31, 106, 82, 0.14);
    color: #123a2d;
    font-size: 22px;
    font-weight: 800;
    letter-spacing: 0.03em;
    text-transform: uppercase;
    box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12);
    justify-self: start;
}}
.compare-chip-right {{
    background: rgba(201, 111, 66, 0.16);
    color: #6e3d23;
    box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16);
    justify-self: end;
}}
.compare-button {{
    border: 0;
    border-radius: 999px;
    padding: 10px 22px;
    background: #1f6a52;
    color: white;
    font-size: 16px;
    font-weight: 700;
    cursor: pointer;
    justify-self: center;
}}
.compare-stage {{
    position: relative;
    width: 100%;
    aspect-ratio: 16 / 9;
    overflow: hidden;
    border-radius: 22px;
    background: #16120f;
    border: 1px solid rgba(255,255,255,0.08);
}}
.compare-video {{
    position: absolute;
    inset: 0;
    width: 100%;
    height: 100%;
    object-fit: contain;
    background: #16120f;
}}
.compare-overlay {{
    clip-path: inset(0 0 0 50%);
}}
.compare-placeholder {{
    background:
        linear-gradient(135deg, rgba(255,255,255,0.055), transparent 35%),
        #16120f;
}}
.compare-divider {{
    position: absolute;
    top: 0;
    bottom: 0;
    left: 50%;
    width: 2px;
    background: rgba(255,255,255,0.96);
    box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
    transform: translateX(-1px);
    pointer-events: none;
}}
.compare-divider::after {{
    content: "";
    position: absolute;
    top: 50%;
    left: 50%;
    width: 18px;
    height: 18px;
    border-radius: 999px;
    background: #fff;
    border: 2px solid rgba(31, 26, 20, 0.18);
    transform: translate(-50%, -50%);
}}
.compare-range {{
    position: absolute;
    inset: 0;
    width: 100%;
    height: 100%;
    opacity: 0.01;
    cursor: ew-resize;
    margin: 0;
    -webkit-appearance: none;
    appearance: none;
}}
.compare-caption {{
    color: #201a14;
    font-size: 14px;
    line-height: 1.5;
    text-align: center;
}}
</style>
</head>
<body>
<div class="compare-shell" id="{compare_id}">
<div class="compare-topbar">
<div class="compare-chip">Wan Baseline</div>
<button class="compare-button" type="button">Pause</button>
<div class="compare-chip compare-chip-right">RefDecoder</div>
</div>
<div class="compare-stage">
{base_source}
{overlay_source}
<div class="compare-divider"></div>
<input class="compare-range" type="range" min="0" max="100" value="50" />
</div>
<div class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</div>
</div>
<script>
(() => {{
    const root = document.getElementById("{compare_id}");
    const base = root.querySelector(".compare-base");
    const overlay = root.querySelector(".compare-overlay");
    const divider = root.querySelector(".compare-divider");
    const slider = root.querySelector(".compare-range");
    const button = root.querySelector(".compare-button");
    const videos = Array.from(root.querySelectorAll("video"));
    const applySplit = () => {{
        const value = Number(slider.value);
        overlay.style.clipPath = `inset(0 0 0 ${{value}}%)`;
        divider.style.left = `${{value}}%`;
    }};
    const syncVideo = (source, target) => {{
        if (Math.abs((target.currentTime || 0) - (source.currentTime || 0)) > 0.08) {{
            try {{ target.currentTime = source.currentTime; }} catch (e) {{}}
        }}
    }};
    const playBoth = () => {{
        videos.forEach((video) => video.play().catch(() => {{}}));
        button.textContent = "Pause";
    }};
    const pauseBoth = () => {{
        videos.forEach((video) => video.pause());
        button.textContent = "Play";
    }};
    const bindSync = (primary, secondary) => {{
        primary.addEventListener("play", () => secondary.play().catch(() => {{}}));
        primary.addEventListener("pause", () => secondary.pause());
        primary.addEventListener("seeking", () => syncVideo(primary, secondary));
        primary.addEventListener("timeupdate", () => syncVideo(primary, secondary));
        primary.addEventListener("ratechange", () => {{ secondary.playbackRate = primary.playbackRate; }});
    }};
    if (base.tagName === "VIDEO" && overlay.tagName === "VIDEO") {{
        bindSync(base, overlay);
        bindSync(overlay, base);
    }} else {{
        button.disabled = true;
        button.textContent = "Play";
        button.style.opacity = "0.55";
    }}
    videos.forEach((video) => {{
        video.addEventListener("loadeddata", playBoth, {{ once: true }});
    }});
    button.addEventListener("click", () => {{
        if (!videos.length || videos[0].paused) {{
            playBoth();
        }} else {{
            pauseBoth();
        }}
    }});
    slider.addEventListener("input", applySplit);
    applySplit();
}})();
</script>
</body>
</html>
"""
    # srcdoc requires the inner document to be HTML-attribute-escaped.
    return (
        '<iframe class="compare-frame" '
        'sandbox="allow-scripts allow-same-origin" '
        'scrolling="no" '
        'srcdoc="' + html.escape(inner_doc, quote=True) + '"></iframe>'
    )
def save_video_tensor(video_tensor, output_path):
    """Denormalize a [1, C, T, H, W] video tensor from [-1, 1] and write it as MP4.

    Returns the output path as a string.
    """
    # Keep the exact original arithmetic (x / 2 + 0.5) so quantization matches.
    frames = (video_tensor / 2 + 0.5).clamp(0, 1)
    # [1, C, T, H, W] -> [T, H, W, C] for imageio.
    frames = frames.squeeze(0).permute(1, 2, 3, 0)
    frames = frames.detach().cpu().float().numpy()
    frames = (frames * 255).astype(np.uint8)
    imageio.mimwrite(output_path, frames, fps=FPS, quality=10)
    return str(output_path)
def decode_with_wan_vae(latents, vae):
    """Decode Wan latents to a pixel video with the stock diffusers VAE.

    The pipeline emits normalized latents; un-normalize them with the VAE's
    configured per-channel mean/std before decoding.
    """
    dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=dtype)
    stat_shape = (1, -1, 1, 1, 1)
    mean = torch.tensor(vae.config.latents_mean, device=DEVICE, dtype=dtype).view(stat_shape)
    std = torch.tensor(vae.config.latents_std, device=DEVICE, dtype=dtype).view(stat_shape)
    denormalized = latents * std + mean
    with torch.no_grad():
        return vae.decode(denormalized, return_dict=False)[0]
def decode_with_refdecoder(latents, reference_frame, vae, transformer):
    """Decode Wan latents with the RefDecoder VAE + transformer, conditioned on
    a reference frame.

    Latents are un-normalized with the VAE's per-channel mean/std first; any
    internal decode cache is cleared afterwards.
    """
    dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=dtype)
    stat_shape = (1, -1, 1, 1, 1)
    mean = torch.tensor(vae.config.latents_mean, device=DEVICE, dtype=dtype).view(stat_shape)
    std = torch.tensor(vae.config.latents_std, device=DEVICE, dtype=dtype).view(stat_shape)
    denormalized = latents * std + mean
    reference = reference_frame.to(device=DEVICE, dtype=dtype)
    with torch.no_grad():
        decoded = vae.decode(
            denormalized,
            transformer,
            return_dict=True,
            reference_frame=reference,
            skip=False,
            window_size=-1,
        ).sample
    if hasattr(vae, "clear_cache"):
        vae.clear_cache()
    return decoded
def button_state(label, interactive):
    """Return a gr.update() that sets the run button's label and clickability."""
    update_kwargs = {"value": label, "interactive": interactive}
    return gr.update(**update_kwargs)
# @spaces.GPU is required on ZeroGPU Spaces: a GPU is only attached while a
# decorated function runs. The called form works with the _SpacesShim fallback
# too (no-op locally). Duration is generous for 50 steps of the 14B model.
@spaces.GPU(duration=300)
def generate_latents_on_gpu(image, prompt, seed):
    """Run Wan I2V denoising once and return latents plus the resized input.

    Args:
        image: PIL input image.
        prompt: positive text prompt ("" allowed).
        seed: integer seed for the torch generator.

    Returns:
        (latents, resized_image, height, width) with latents on CPU.
    """
    log_cuda_mem("start generate_latents_on_gpu")
    pipe = load_generation_pipe()
    resized_image, height, width = resize_image_for_wan(image, pipe)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    with torch.no_grad():
        output = pipe(
            image=resized_image,
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            height=height,
            width=width,
            num_frames=NUM_FRAMES,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=generator,
            output_type="latent",
        )
    # Move latents off-GPU so the decode stages can load their own models.
    latents = normalize_latent_shape(output.frames).detach().cpu()
    # Release the 14B pipeline before returning; decoders load separately.
    del pipe, output
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    log_cuda_mem("after latent generation")
    return latents, resized_image, height, width
# @spaces.GPU so ZeroGPU attaches a device for the decode; no-op elsewhere.
@spaces.GPU(duration=120)
def decode_wan_on_gpu(latents):
    """Decode latents with the stock Wan VAE; returns the video tensor on CPU."""
    log_cuda_mem("start decode_wan_on_gpu")
    wan_vae = load_wan_vae()
    video = decode_with_wan_vae(latents, wan_vae)
    log_cuda_mem("after wan decode")
    result = video.detach().cpu()
    # Free the VAE before the RefDecoder stage loads its own weights.
    del wan_vae, video
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
# @spaces.GPU so ZeroGPU attaches a device for the decode; no-op elsewhere.
@spaces.GPU(duration=120)
def decode_refdecoder_on_gpu(latents, reference_frame):
    """Decode latents with RefDecoder; returns the video tensor on CPU."""
    log_cuda_mem("start decode_refdecoder_on_gpu")
    ref_vae, ref_transformer = load_refdecoder_module()
    video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
    log_cuda_mem("after refdecoder decode")
    result = video.detach().cpu()
    # Free RefDecoder weights once the pixel video is on the CPU.
    del ref_vae, ref_transformer, video
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
def generate_and_decode(image, prompt, seed):
    """Gradio generator: produce latents once, decode them twice, stream updates.

    Yields 4-tuples of updates for (compare_html, wan_video, ref_video,
    run_button) so the button label reflects progress while the outputs fill in
    incrementally; the final yield shows the comparison widget and re-enables
    the button.

    Raises:
        gr.Error: if no image was uploaded or no CUDA GPU is available.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if DEVICE != "cuda":
        raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")
    yield gr.update(), gr.update(), gr.update(), button_state("Loading Wan I2V...", False)
    prompt = prompt.strip() if prompt else ""
    # An empty seed field means "random run"; 32-bit range suits torch.Generator.
    seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
    # Per-run output dir so concurrent runs never clobber each other's files.
    run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
    run_dir.mkdir(parents=True, exist_ok=True)
    yield gr.update(), gr.update(), gr.update(), button_state("Generating Latents...", False)
    latents, resized_image, height, width = generate_latents_on_gpu(image, prompt, seed)
    # Reference frame stays on CPU here; the decoder moves it to the GPU itself.
    reference_frame = build_reference_frame(resized_image, "cpu")
    # Persist latents + metadata so a run can be re-decoded offline.
    latent_path = run_dir / "wan_latents.pt"
    torch.save(
        {
            "latents": latents,
            "height": height,
            "width": width,
            "prompt": prompt,
            "seed": seed,
        },
        latent_path,
    )
    yield gr.update(), gr.update(), gr.update(), button_state("Decoding Wan Baseline...", False)
    wan_video = decode_wan_on_gpu(latents)
    wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
    # Drop the pixel tensor promptly; only the saved file is needed from here.
    del wan_video
    gc.collect()
    yield gr.update(), wan_video_path, gr.update(), button_state("Decoding RefDecoder...", False)
    ref_video = decode_refdecoder_on_gpu(latents, reference_frame)
    ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
    del ref_video
    gc.collect()
    compare_html = build_compare_html(wan_video_path, ref_video_path)
    yield (
        gr.update(value=compare_html, visible=True),
        wan_video_path,
        ref_video_path,
        button_state("Generate Comparison", True),
    )
# Page-level stylesheet injected into gr.Blocks(css=...). The .compare-* rules
# mirror the ones inside the iframe built by build_compare_html so the widget
# looks consistent whether rendered inline or sandboxed.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap');
:root {
    --page-bg: #f4f1e8;
    --card-bg: rgba(255, 252, 246, 0.92);
    --card-border: rgba(50, 43, 32, 0.12);
    --accent: #1f6a52;
    --accent-2: #c96f42;
    --text-main: #201a14;
    --text-soft: #201a14;
    --ui-font: "Manrope", "Inter", "Segoe UI", sans-serif;
}
.gradio-container {
    background:
        radial-gradient(circle at top left, rgba(201, 111, 66, 0.18), transparent 26%),
        radial-gradient(circle at top right, rgba(31, 106, 82, 0.16), transparent 28%),
        linear-gradient(180deg, #f8f4ec 0%, var(--page-bg) 100%);
    font-family: var(--ui-font);
}
.app-shell {
    max-width: 1320px;
    margin: 0 auto;
}
.hero-card,
.panel-card,
.output-card {
    background: var(--card-bg);
    border: 1px solid var(--card-border);
    border-radius: 24px;
    box-shadow: 0 18px 50px rgba(49, 39, 26, 0.08);
}
.hero-card {
    padding: 28px 30px 20px 30px;
    margin-bottom: 18px;
}
.hero-kicker {
    display: inline-block;
    padding: 6px 12px;
    border-radius: 999px;
    background: rgba(31, 106, 82, 0.10);
    color: var(--accent);
    font-size: 12px;
    font-weight: 700;
    letter-spacing: 0.08em;
    text-transform: uppercase;
}
.hero-title {
    margin: 14px 0 8px 0;
    font-size: 42px;
    line-height: 1.05;
    font-weight: 800;
    color: var(--text-main);
}
.hero-copy {
    margin: 0;
    max-width: 840px;
    color: var(--text-soft);
    font-size: 17px;
    line-height: 1.6;
    font-family: var(--ui-font);
}
.panel-card,
.output-card {
    padding: 18px;
}
.panel-card {
    overflow: hidden;
}
.section-title {
    margin: 0 0 6px 0;
    color: var(--text-main);
    font-size: 22px;
    font-weight: 750;
}
.section-copy {
    margin: 0 0 14px 0;
    color: var(--text-soft);
    font-size: 14px;
    line-height: 1.55;
    font-family: var(--ui-font);
}
.compare-note {
    padding: 12px 14px;
    border-radius: 16px;
    background: rgba(201, 111, 66, 0.08);
    color: #6a4128;
    font-size: 14px;
    line-height: 1.5;
    margin-bottom: 14px;
}
#generate-btn {
    min-height: 108px;
    height: 100%;
    width: 100%;
    font-size: 16px;
    font-weight: 700;
    background: linear-gradient(135deg, var(--accent) 0%, #154f3d 100%);
    border: none;
}
#generate-btn:hover {
    filter: brightness(1.04);
}
.output-grid {
    gap: 14px;
}
.compare-shell {
    display: flex;
    flex-direction: column;
    gap: 12px;
}
.compare-frame {
    width: 100%;
    height: 860px;
    border: 0;
    background: transparent;
    overflow: hidden;
}
@media (max-width: 900px) {
    .compare-frame {
        height: 720px;
    }
}
.compare-topbar {
    display: flex;
    justify-content: space-between;
    align-items: center;
    gap: 12px;
}
.compare-chip {
    padding: 8px 12px;
    border-radius: 999px;
    background: rgba(31, 106, 82, 0.08);
    color: var(--text-main);
    font-size: 12px;
    font-weight: 700;
    letter-spacing: 0.04em;
    text-transform: uppercase;
}
.compare-chip-right {
    background: rgba(201, 111, 66, 0.10);
}
.compare-stage {
    position: relative;
    width: 100%;
    aspect-ratio: 16 / 9;
    overflow: hidden;
    border-radius: 22px;
    background: #16120f;
    border: 1px solid rgba(255,255,255,0.08);
}
.compare-video {
    position: absolute;
    inset: 0;
    width: 100%;
    height: 100%;
    object-fit: contain;
    background: #16120f;
}
.compare-overlay {
    clip-path: inset(0 0 0 50%);
}
.compare-divider {
    position: absolute;
    top: 0;
    bottom: 0;
    left: 50%;
    width: 2px;
    background: rgba(255,255,255,0.96);
    box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
    transform: translateX(-1px);
    pointer-events: none;
}
.compare-divider::after {
    content: "";
    position: absolute;
    top: 50%;
    left: 50%;
    width: 18px;
    height: 18px;
    border-radius: 999px;
    background: #fff;
    border: 2px solid rgba(31, 26, 20, 0.18);
    transform: translate(-50%, -50%);
}
.compare-range {
    position: absolute;
    inset: 0;
    width: 100%;
    height: 100%;
    opacity: 0;
    cursor: ew-resize;
}
.compare-caption {
    color: var(--text-soft);
    font-size: 14px;
    line-height: 1.5;
    font-family: var(--ui-font);
}
.compare-panel {
    padding-bottom: 34px;
}
.seed-action-row {
    align-items: stretch;
}
.seed-action-row > .gradio-column {
    min-width: 0;
}
"""
# UI layout: hero header, input panel (image + prompt + seed/run), and the
# comparison panel. Videos are kept in hidden gr.Video components so Gradio
# serves the files; the visible widget is the iframe from build_compare_html.
with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes="app-shell"):
        gr.HTML(
            """
            <div class="hero-card">
                <div class="hero-title">RefDecoder I2V Demo</div>
                <p class="hero-copy">
                    Upload one image, optionally add a prompt, and compare two decoders on the same Wan latent video.
                    The app generates latents once, then renders them with Wan's original VAE and with RefDecoder.
                </p>
            </div>
            """
        )
        with gr.Column(elem_classes=["panel-card", "compare-panel"]):
            gr.HTML(
                """
                <div class="section-title">Inputs</div>
                <div class="section-copy">
                    Upload a reference image, optionally add a prompt, and compare the decoders below.
                </div>
                """
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    image_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=180,
                    )
                with gr.Column(scale=5):
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        lines=2,
                        placeholder="A woman turns toward the camera as her hair moves in the wind...",
                    )
                    with gr.Row(equal_height=True, elem_classes="seed-action-row"):
                        with gr.Column(scale=1):
                            # Seed is optional; None triggers a random seed in
                            # generate_and_decode.
                            seed_input = gr.Number(
                                label="Seed",
                                value=None,
                                precision=0,
                                info="Optional",
                            )
                        with gr.Column(scale=1):
                            run_button = gr.Button(
                                "Generate Comparison",
                                variant="primary",
                                elem_id="generate-btn",
                            )
        with gr.Column(elem_classes="panel-card"):
            gr.HTML(
                """
                <div class="section-title">Decoder Comparison</div>
                <div class="section-copy">
                    Left side shows Wan Baseline. Right side shows RefDecoder. Drag the divider across the frame to compare them.
                </div>
                """
            )
            # Starts as an empty comparison shell (both placeholders).
            compare_output = gr.HTML(value=build_compare_html(None, None))
            # Hidden components so Gradio registers/serves the video files.
            wan_video_hidden = gr.Video(visible=False)
            ref_video_hidden = gr.Video(visible=False)
    run_button.click(
        fn=generate_and_decode,
        inputs=[image_input, prompt_input, seed_input],
        outputs=[compare_output, wan_video_hidden, ref_video_hidden, run_button],
    )

if __name__ == "__main__":
    # allowed_paths lets Gradio serve files written under OUTPUT_ROOT.
    demo.queue(max_size=2).launch(allowed_paths=[str(OUTPUT_ROOT)])