import gc
import html
import random
import sys
import uuid
from pathlib import Path
from urllib.parse import quote

import gradio as gr
import imageio
import numpy as np
import ftfy

try:
    import spaces
except ImportError:
    # Outside Hugging Face Spaces the `spaces` package is absent; fall back to a
    # no-op GPU decorator so the same file also runs on a plain CUDA machine.
    class _SpacesShim:
        @staticmethod
        def GPU(*args, **kwargs):
            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()

import torch
from diffusers.pipelines.wan import pipeline_wan_i2v
from diffusers import AutoencoderKLWan as DiffusersWanVAE
from diffusers import WanImageToVideoPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import CLIPVisionModel

# Make the repo root importable before pulling in the local `src` package.
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.models.Wan.autoencoder_wanT import AutoencoderKLWan
from src.models.Wan.transformer_wan import WanDecoderTransformer

MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
REFDECODER_REPO_ID = "Arrokothwhi/RefDecoder"
REFDECODER_CKPT_PATH_IN_REPO = "I2V_Wan2.1/model.pt"
OUTPUT_ROOT = ROOT / "gradio_outputs"

NEGATIVE_PROMPT = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, "
    "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, "
    "misshapen limbs, fused fingers, still picture, messy background, three legs, many people "
    "in the background, walking backwards"
)

TARGET_AREA = 480 * 832
FPS = 16
NUM_FRAMES = 17
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 5.0

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PIPE_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# Some diffusers Wan builds reference a module-level `ftfy` during prompt cleaning.
# Set it explicitly so Spaces deployments don't fail if that global was never initialized.
pipeline_wan_i2v.ftfy = ftfy
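
# Weights are fetched eagerly at import time (see the calls below) so the first
# user request does not pay the download cost; hf_hub_download/snapshot_download
# cache on disk, so an app restart should not re-download anything.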
def download_refdecoder_ckpt():
    print("[init] Downloading RefDecoder checkpoint metadata/file if needed")
    ckpt_path = hf_hub_download(
        repo_id=REFDECODER_REPO_ID,
        filename=REFDECODER_CKPT_PATH_IN_REPO,
    )
    print(f"[init] RefDecoder checkpoint ready at: {ckpt_path}")
    return ckpt_path


def download_wan_weights():
    print(f"[init] Downloading Wan I2V weights from {MODEL_ID}")
    repo_dir = snapshot_download(repo_id=MODEL_ID)
    print(f"[init] Wan I2V weights ready at: {repo_dir}")
    return repo_dir


REFDECODER_CKPT_LOCAL_PATH = download_refdecoder_ckpt()
download_wan_weights()
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)


def log_cuda_mem(tag):
    if not torch.cuda.is_available():
        print(f"[mem] {tag}: CUDA not available")
        return
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    allocated_bytes = torch.cuda.memory_allocated()
    reserved_bytes = torch.cuda.memory_reserved()
    print(
        f"[mem] {tag}: "
        f"free={free_bytes / 1024**3:.2f} GB, "
        f"total={total_bytes / 1024**3:.2f} GB, "
        f"allocated={allocated_bytes / 1024**3:.2f} GB, "
        f"reserved={reserved_bytes / 1024**3:.2f} GB"
    )


def get_module_dtype(module):
    try:
        return next(module.parameters()).dtype
    except StopIteration:
        return PIPE_DTYPE


def load_generation_pipe():
    log_cuda_mem("before load_generation_pipe")
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID,
        subfolder="image_encoder",
        torch_dtype=PIPE_DTYPE,
    )
    vae = DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=PIPE_DTYPE,
    )
    pipe = pipe.to(DEVICE)
    log_cuda_mem("after load_generation_pipe")
    return pipe


def load_wan_vae():
    log_cuda_mem("before load_wan_vae")
    vae = DiffusersWanVAE.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=PIPE_DTYPE,
    )
    vae = vae.to(DEVICE)
    vae.eval()
    log_cuda_mem("after load_wan_vae")
    return vae


def load_refdecoder_module():
    log_cuda_mem("before load_refdecoder_module")
    vae = AutoencoderKLWan(
        dropout_p=0.0,
        use_reference=True,
    ).eval()
    transformer = WanDecoderTransformer(
        chunk=5,
        num_layers=10,
        num_heads=12,
        head_dim=128,
        reusing=True,
        pretrained=False,
    ).eval()
    checkpoint = torch.load(REFDECODER_CKPT_LOCAL_PATH, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint.get("module", checkpoint))
    vae_sd = {}
    transformer_sd = {}
    for key, value in state_dict.items():
        if key.startswith("vae."):
            vae_sd[key[len("vae."):]] = value
        elif key.startswith("transformer."):
            transformer_sd[key[len("transformer."):]] = value
    vae.load_state_dict(vae_sd, strict=False)
    transformer.load_state_dict(transformer_sd, strict=False)
    vae = vae.to(DEVICE).eval()
    transformer = transformer.to(DEVICE).eval()
    log_cuda_mem("after load_refdecoder_module")
    return vae, transformer


def resize_image_for_wan(image, pipe):
    image = image.convert("RGB")
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(TARGET_AREA * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(TARGET_AREA / aspect_ratio)) // mod_value * mod_value
    resized = image.resize((width, height))
    return resized, height, width


def build_reference_frame(image, device):
    ref_array = np.asarray(image).astype(np.float32)
    ref_tensor = torch.from_numpy(ref_array).permute(2, 0, 1)
    ref_tensor = (ref_tensor / 255.0 - 0.5) * 2.0
    return ref_tensor.unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.float32)


def normalize_latent_shape(latents):
    if isinstance(latents, list):
        latents = latents[0]
    if latents.ndim == 4:
        latents = latents.unsqueeze(0)
    if latents.ndim != 5:
        raise ValueError(f"Expected latent shape [B,C,T,H,W], got {tuple(latents.shape)}")
    return latents
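
# Worked example of the layout normalize_latent_shape() enforces (sizes assume
# Wan2.1's 8x spatial / 4x temporal VAE compression and z_dim=16; the exact
# values come from the model config, so treat these numbers as illustrative):
#   480x832 input, 17 frames -> [B, C, T, H, W] = [1, 16, (17 - 1) // 4 + 1, 480 // 8, 832 // 8]
#                                               = [1, 16, 5, 60, 104]
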
def gradio_file_url(path):
    return f"/gradio_api/file={quote(str(path), safe='/')}"


def build_compare_html(wan_video_path, ref_video_path):
    compare_id = f"compare-{uuid.uuid4().hex}"
    wan_url = gradio_file_url(wan_video_path) if wan_video_path else ""
    ref_url = gradio_file_url(ref_video_path) if ref_video_path else ""
    base_source = (
        f'<video class="compare-video" src="{wan_url}" autoplay loop muted playsinline></video>'
        if wan_url
        else '<div class="compare-video"></div>'
    )
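    # Both decoders' videos are stacked in the same stage; the RefDecoder video is
    # revealed through a clip-path driven by the slider, so the two play the same
    # frames in lockstep (assuming the browser starts them together).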
    overlay_source = (
        f'<video class="compare-video compare-overlay" src="{ref_url}" autoplay loop muted playsinline></video>'
        if ref_url
        else '<div class="compare-video compare-overlay"></div>'
    )
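    # The widget is returned as an <iframe srcdoc=...> rather than raw markup for
    # gr.HTML: HTML injected via innerHTML typically does not execute <script>
    # tags, while the iframe gives the slider script its own document to run in.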
    inner_doc = f"""<!doctype html>
<html>
<head>
<style>
  body {{ margin: 0; font-family: "Manrope", "Inter", "Segoe UI", sans-serif; color: #201a14; }}
  .compare-shell {{ display: flex; flex-direction: column; gap: 12px; }}
  .compare-topbar {{ display: flex; justify-content: space-between; align-items: center; gap: 12px; }}
  .compare-chip {{ padding: 8px 12px; border-radius: 999px; background: rgba(31, 106, 82, 0.08); font-size: 12px; font-weight: 700; letter-spacing: 0.04em; text-transform: uppercase; }}
  .compare-chip-right {{ background: rgba(201, 111, 66, 0.10); }}
  .compare-stage {{ position: relative; width: 100%; aspect-ratio: 16 / 9; overflow: hidden; border-radius: 22px; background: #16120f; }}
  .compare-video {{ position: absolute; inset: 0; width: 100%; height: 100%; object-fit: contain; background: #16120f; }}
  .compare-overlay {{ clip-path: inset(0 0 0 50%); }}
  .compare-divider {{ position: absolute; top: 0; bottom: 0; left: 50%; width: 2px; background: rgba(255, 255, 255, 0.96); transform: translateX(-1px); pointer-events: none; }}
  .compare-range {{ position: absolute; inset: 0; width: 100%; height: 100%; opacity: 0; cursor: ew-resize; }}
  .compare-caption {{ font-size: 14px; line-height: 1.5; }}
</style>
</head>
<body>
<div class="compare-shell" id="{compare_id}">
  <div class="compare-topbar">
    <span class="compare-chip">Wan Baseline</span>
    <span class="compare-chip compare-chip-right">RefDecoder</span>
  </div>
  <div class="compare-stage">
    {base_source}
    {overlay_source}
    <div class="compare-divider"></div>
    <input class="compare-range" type="range" min="0" max="100" value="50" />
  </div>
  <p class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</p>
</div>
<script>
  // Drive the overlay clip-path and divider position from the invisible slider.
  const shell = document.getElementById("{compare_id}");
  const overlay = shell.querySelector(".compare-overlay");
  const divider = shell.querySelector(".compare-divider");
  shell.querySelector(".compare-range").addEventListener("input", (event) => {{
    const pct = event.target.value;
    overlay.style.clipPath = `inset(0 0 0 ${{pct}}%)`;
    divider.style.left = `${{pct}}%`;
  }});
</script>
</body>
</html>"""
    return (
        f'<iframe class="compare-frame" title="Decoder comparison" '
        f'srcdoc="{html.escape(inner_doc)}"></iframe>'
    )


def save_video_tensor(video_tensor, output_path):
    # Map from [-1, 1] back to uint8 RGB frames and encode at the demo frame rate.
    video = (video_tensor / 2 + 0.5).clamp(0, 1)
    video = video.squeeze(0).permute(1, 2, 3, 0).detach().cpu().float().numpy()
    video = (video * 255).astype(np.uint8)
    imageio.mimwrite(output_path, video, fps=FPS, quality=10)
    return str(output_path)


def decode_with_wan_vae(latents, vae):
    vae_dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=vae_dtype)
    # Undo the per-channel latent normalization from the VAE config before decoding.
    latents_mean = torch.tensor(vae.config.latents_mean, device=DEVICE, dtype=vae_dtype).view(1, -1, 1, 1, 1)
    latents_std = torch.tensor(vae.config.latents_std, device=DEVICE, dtype=vae_dtype).view(1, -1, 1, 1, 1)
    latents = latents * latents_std + latents_mean
    with torch.no_grad():
        video = vae.decode(latents, return_dict=False)[0]
    return video


def decode_with_refdecoder(latents, reference_frame, vae, transformer):
    decode_dtype = get_module_dtype(vae)
    latents = latents.to(device=DEVICE, dtype=decode_dtype)
    latents_mean = torch.tensor(
        vae.config.latents_mean, device=DEVICE, dtype=decode_dtype,
    ).view(1, -1, 1, 1, 1)
    latents_std = torch.tensor(
        vae.config.latents_std, device=DEVICE, dtype=decode_dtype,
    ).view(1, -1, 1, 1, 1)
    latents = latents * latents_std + latents_mean
    reference_frame = reference_frame.to(device=DEVICE, dtype=decode_dtype)
    with torch.no_grad():
        video = vae.decode(
            latents,
            transformer,
            return_dict=True,
            reference_frame=reference_frame,
            skip=False,
            window_size=-1,
        ).sample
    if hasattr(vae, "clear_cache"):
        vae.clear_cache()
    return video


def button_state(label, interactive):
    return gr.update(value=label, interactive=interactive)


@spaces.GPU(duration=80)
def generate_latents_on_gpu(image, prompt, seed):
    log_cuda_mem("start generate_latents_on_gpu")
    pipe = load_generation_pipe()
    resized_image, height, width = resize_image_for_wan(image, pipe)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    with torch.no_grad():
        output = pipe(
            image=resized_image,
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            height=height,
            width=width,
            num_frames=NUM_FRAMES,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=generator,
            output_type="latent",
        )
    latents = normalize_latent_shape(output.frames).detach().cpu()
    log_cuda_mem("after latent generation")
    return latents, resized_image, height, width


@spaces.GPU(duration=20)
def decode_wan_on_gpu(latents):
    log_cuda_mem("start decode_wan_on_gpu")
    wan_vae = load_wan_vae()
    video = decode_with_wan_vae(latents, wan_vae)
    log_cuda_mem("after wan decode")
    return video.detach().cpu()


@spaces.GPU(duration=25)
def decode_refdecoder_on_gpu(latents, reference_frame):
    log_cuda_mem("start decode_refdecoder_on_gpu")
    ref_vae, ref_transformer = load_refdecoder_module()
    video = decode_with_refdecoder(latents, reference_frame, ref_vae, ref_transformer)
    log_cuda_mem("after refdecoder decode")
    return video.detach().cpu()


def generate_and_decode(image, prompt, seed):
    if image is None:
        raise gr.Error("Please upload an input image.")
    if DEVICE != "cuda":
        raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")
    yield gr.update(), gr.update(), gr.update(), button_state("Loading Wan I2V...", False)
    prompt = prompt.strip() if prompt else ""
    seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
    run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
    run_dir.mkdir(parents=True, exist_ok=True)
    yield gr.update(), gr.update(), gr.update(), button_state("Generating Latents...", False)
    latents, resized_image, height, width = generate_latents_on_gpu(image, prompt, seed)
    reference_frame = build_reference_frame(resized_image, "cpu")
    # Persist the latents plus metadata so a finished run can be re-decoded later
    # without re-running the diffusion loop (e.g. torch.load(latent_path)["latents"]).
    latent_path = run_dir / "wan_latents.pt"
    torch.save(
        {
            "latents": latents,
            "height": height,
            "width": width,
            "prompt": prompt,
            "seed": seed,
        },
        latent_path,
    )
    yield gr.update(), gr.update(), gr.update(), button_state("Decoding Wan Baseline...", False)
    wan_video = decode_wan_on_gpu(latents)
    wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
    del wan_video
    gc.collect()
    yield gr.update(), wan_video_path, gr.update(), button_state("Decoding RefDecoder...", False)
    ref_video = decode_refdecoder_on_gpu(latents, reference_frame)
    ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
    del ref_video
    gc.collect()
    compare_html = build_compare_html(wan_video_path, ref_video_path)
    yield (
        gr.update(value=compare_html, visible=True),
        wan_video_path,
        ref_video_path,
        button_state("Generate Comparison", True),
    )


CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap');

:root {
    --page-bg: #f4f1e8;
    --card-bg: rgba(255, 252, 246, 0.92);
    --card-border: rgba(50, 43, 32, 0.12);
    --accent: #1f6a52;
    --accent-2: #c96f42;
    --text-main: #201a14;
    --text-soft: #201a14;
    --ui-font: "Manrope", "Inter", "Segoe UI", sans-serif;
}

.gradio-container {
    background:
        radial-gradient(circle at top left, rgba(201, 111, 66, 0.18), transparent 26%),
        radial-gradient(circle at top right, rgba(31, 106, 82, 0.16), transparent 28%),
        linear-gradient(180deg, #f8f4ec 0%, var(--page-bg) 100%);
    font-family: var(--ui-font);
}

.app-shell { max-width: 1320px; margin: 0 auto; }
.hero-card, .panel-card, .output-card { background: var(--card-bg); border: 1px solid var(--card-border); border-radius: 24px; box-shadow: 0 18px 50px rgba(49, 39, 26, 0.08); }
.hero-card { padding: 28px 30px 20px 30px; margin-bottom: 18px; }
.hero-kicker { display: inline-block; padding: 6px 12px; border-radius: 999px; background: rgba(31, 106, 82, 0.10); color: var(--accent); font-size: 12px; font-weight: 700; letter-spacing: 0.08em; text-transform: uppercase; }
.hero-title { margin: 14px 0 8px 0; font-size: 42px; line-height: 1.05; font-weight: 800; color: var(--text-main); }
.hero-copy { margin: 0; max-width: 840px; color: var(--text-soft); font-size: 17px; line-height: 1.6; font-family: var(--ui-font); }
.panel-card, .output-card { padding: 18px; }
.panel-card { overflow: hidden; }
.section-title { margin: 0 0 6px 0; color: var(--text-main); font-size: 22px; font-weight: 750; }
.section-copy { margin: 0 0 14px 0; color: var(--text-soft); font-size: 14px; line-height: 1.55; font-family: var(--ui-font); }
.compare-note { padding: 12px 14px; border-radius: 16px; background: rgba(201, 111, 66, 0.08); color: #6a4128; font-size: 14px; line-height: 1.5; margin-bottom: 14px; }
#generate-btn { min-height: 108px; height: 100%; width: 100%; font-size: 16px; font-weight: 700; background: linear-gradient(135deg, var(--accent) 0%, #154f3d 100%); border: none; }
#generate-btn:hover { filter: brightness(1.04); }
.output-grid { gap: 14px; }
.compare-shell { display: flex; flex-direction: column; gap: 12px; }
.compare-frame { width: 100%; height: 860px; border: 0; background: transparent; overflow: hidden; }
@media (max-width: 900px) { .compare-frame { height: 720px; } }
.compare-topbar { display: flex; justify-content: space-between; align-items: center; gap: 12px; }
.compare-chip { padding: 8px 12px; border-radius: 999px; background: rgba(31, 106, 82, 0.08); color: var(--text-main); font-size: 12px; font-weight: 700; letter-spacing: 0.04em; text-transform: uppercase; }
.compare-chip-right { background: rgba(201, 111, 66, 0.10); }
.compare-stage { position: relative; width: 100%; aspect-ratio: 16 / 9; overflow: hidden; border-radius: 22px; background: #16120f; border: 1px solid rgba(255,255,255,0.08); }
.compare-video { position: absolute; inset: 0; width: 100%; height: 100%; object-fit: contain; background: #16120f; }
.compare-overlay { clip-path: inset(0 0 0 50%); }
.compare-divider { position: absolute; top: 0; bottom: 0; left: 50%; width: 2px; background: rgba(255,255,255,0.96); box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15); transform: translateX(-1px); pointer-events: none; }
.compare-divider::after { content: ""; position: absolute; top: 50%; left: 50%; width: 18px; height: 18px; border-radius: 999px; background: #fff; border: 2px solid rgba(31, 26, 20, 0.18); transform: translate(-50%, -50%); }
.compare-range { position: absolute; inset: 0; width: 100%; height: 100%; opacity: 0; cursor: ew-resize; }
.compare-caption { color: var(--text-soft); font-size: 14px; line-height: 1.5; font-family: var(--ui-font); }
.compare-panel { padding-bottom: 34px; }
.seed-action-row { align-items: stretch; }
.seed-action-row > .gradio-column { min-width: 0; }
"""
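
# Each yield from generate_and_decode streams four outputs: the comparison HTML,
# the two hidden Video slots that hold the raw MP4s, and the run button's
# label/state. allowed_paths in launch() is what lets the iframe's
# /gradio_api/file= URLs resolve to files under OUTPUT_ROOT.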
with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes="app-shell"):
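        # Static hero header; the interactive comparison widget is injected
        # further down via build_compare_html().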
        gr.HTML(
            """
            <div class="hero-card">
                <h1 class="hero-title">RefDecoder I2V Demo</h1>
                <p class="hero-copy">
                    Upload one image, optionally add a prompt, and compare two decoders on the
                    same Wan latent video. The app generates latents once, then renders them
                    with Wan's original VAE and with RefDecoder.
                </p>
            </div>
            """
        )

        with gr.Column(elem_classes=["panel-card", "compare-panel"]):
            gr.HTML(
                """
                <h2 class="section-title">Inputs</h2>
                <p class="section-copy">
                    Upload a reference image, optionally add a prompt, and compare the
                    decoders below.
                </p>
                """
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    image_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=180,
                    )
                with gr.Column(scale=5):
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        lines=2,
                        placeholder="A woman turns toward the camera as her hair moves in the wind...",
                    )
            with gr.Row(equal_height=True, elem_classes="seed-action-row"):
                with gr.Column(scale=1):
                    seed_input = gr.Number(
                        label="Seed",
                        value=None,
                        precision=0,
                        info="Optional",
                    )
                with gr.Column(scale=1):
                    run_button = gr.Button(
                        "Generate Comparison",
                        variant="primary",
                        elem_id="generate-btn",
                    )
        # The comparison panel starts as an empty slider shell; the final yield of
        # generate_and_decode swaps in the rendered videos.
        with gr.Column(elem_classes="panel-card"):
            gr.HTML(
                """
                <h2 class="section-title">Decoder Comparison</h2>
                <div class="compare-note">
                    Left side shows Wan Baseline. Right side shows RefDecoder.
                    Drag the divider across the frame to compare them.
                </div>
                """
            )
            compare_output = gr.HTML(value=build_compare_html(None, None))
            wan_video_hidden = gr.Video(visible=False)
            ref_video_hidden = gr.Video(visible=False)

    run_button.click(
        fn=generate_and_decode,
        inputs=[image_input, prompt_input, seed_input],
        outputs=[compare_output, wan_video_hidden, ref_video_hidden, run_button],
    )

if __name__ == "__main__":
    # allowed_paths lets Gradio serve the rendered MP4s referenced by the iframe.
    demo.queue(max_size=2).launch(allowed_paths=[str(OUTPUT_ROOT)])