# RefDecoder / app.py
# Last change by Arrokothwhi: "update gpu duration" (commit eb10dec)
import gc
import html
import random
import sys
import uuid
from pathlib import Path
from urllib.parse import quote
import gradio as gr
import imageio
import numpy as np
import ftfy
try:
    import spaces
except ImportError:

    class _SpacesShim:
        """Minimal stand-in for the Hugging Face ``spaces`` package when it is absent.

        Only ``spaces.GPU`` is used by this app. The real decorator may be applied
        bare (``@spaces.GPU``) or with arguments (``@spaces.GPU(duration=...)``);
        this shim accepts both forms and leaves the decorated function unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            # Bare usage passes the decorated function as the only positional argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()
import torch
from diffusers.pipelines.wan import pipeline_wan_i2v
from diffusers import AutoencoderKLWan as DiffusersWanVAE
from diffusers import WanImageToVideoPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import CLIPVisionModel
# Directory containing this app; the local ``src`` package lives under it.
ROOT = Path(__file__).resolve().parent
# Put the app directory on sys.path *before* importing the local ``src`` package.
# Inserting it after those imports (as before) could never help them resolve,
# e.g. when the module is loaded by a launcher whose sys.path lacks this directory.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.models.Wan.autoencoder_wanT import AutoencoderKLWan  # noqa: E402
from src.models.Wan.transformer_wan import WanDecoderTransformer  # noqa: E402
# --- Hub repositories and output location -----------------------------------
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
REFDECODER_REPO_ID = "Arrokothwhi/RefDecoder"
REFDECODER_CKPT_PATH_IN_REPO = "I2V_Wan2.1/model.pt"
OUTPUT_ROOT = ROOT.joinpath("gradio_outputs")

# Negative prompt passed to every Wan I2V generation, built from its comma-separated terms.
NEGATIVE_PROMPT = ", ".join(
    [
        "Bright tones",
        "overexposed",
        "static",
        "blurred details",
        "subtitles",
        "style",
        "works",
        "paintings",
        "images",
        "static",
        "overall gray",
        "worst quality",
        "low quality",
        "JPEG compression residue",
        "ugly",
        "incomplete",
        "extra fingers",
        "poorly drawn hands",
        "poorly drawn faces",
        "deformed",
        "disfigured",
        "misshapen limbs",
        "fused fingers",
        "still picture",
        "messy background",
        "three legs",
        "many people in the background",
        "walking backwards",
    ]
)

# --- Generation settings -----------------------------------------------------
TARGET_AREA = 480 * 832  # pixel budget used when resizing the input image
FPS = 16
NUM_FRAMES = 17
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 5.0

# Use bfloat16 on GPU and float32 on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    PIPE_DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    PIPE_DTYPE = torch.float32

# Some diffusers Wan builds look up a module-level `ftfy` while cleaning prompts;
# bind it explicitly so the pipeline never hits an uninitialized global.
setattr(pipeline_wan_i2v, "ftfy", ftfy)
def download_refdecoder_ckpt():
    """Fetch the RefDecoder checkpoint file from the Hub via ``hf_hub_download``.

    Returns:
        The local path of ``REFDECODER_CKPT_PATH_IN_REPO`` as returned by the Hub client.
    """
    print("[init] Downloading RefDecoder checkpoint metadata/file if needed")
    local_path = hf_hub_download(
        filename=REFDECODER_CKPT_PATH_IN_REPO,
        repo_id=REFDECODER_REPO_ID,
    )
    print(f"[init] RefDecoder checkpoint ready at: {local_path}")
    return local_path
def download_wan_weights():
    """Snapshot the full ``MODEL_ID`` repository locally and return its directory."""
    print(f"[init] Downloading Wan I2V weights from {MODEL_ID}")
    local_dir = snapshot_download(repo_id=MODEL_ID)
    print(f"[init] Wan I2V weights ready at: {local_dir}")
    return local_dir
# Fetch all model weights at import time, before the UI starts serving requests.
REFDECODER_CKPT_LOCAL_PATH = download_refdecoder_ckpt()
# Return value unused: the loaders call from_pretrained(MODEL_ID) themselves.
download_wan_weights()
OUTPUT_ROOT.mkdir(exist_ok=True, parents=True)
def log_cuda_mem(tag):
    """Print a one-line CUDA memory summary (in GiB) prefixed with ``[mem] {tag}``.

    When no CUDA device is available, a "CUDA not available" line is printed instead.
    """
    if torch.cuda.is_available():
        gib = 1024**3
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        stats = {
            "free": free_bytes,
            "total": total_bytes,
            "allocated": torch.cuda.memory_allocated(),
            "reserved": torch.cuda.memory_reserved(),
        }
        summary = ", ".join(f"{name}={value / gib:.2f} GB" for name, value in stats.items())
        print(f"[mem] {tag}: {summary}")
    else:
        print(f"[mem] {tag}: CUDA not available")
def get_module_dtype(module):
    """Return the dtype of ``module``'s first parameter, or ``PIPE_DTYPE`` if it has none."""
    first_param = next(module.parameters(), None)
    return PIPE_DTYPE if first_param is None else first_param.dtype
def load_generation_pipe():
    """Build the Wan image-to-video pipeline in ``PIPE_DTYPE`` and move it to ``DEVICE``.

    The CLIP image encoder and the Wan VAE are loaded explicitly from their
    ``MODEL_ID`` subfolders and handed to the pipeline constructor.
    """
    log_cuda_mem("before load_generation_pipe")
    dtype_kwargs = {"torch_dtype": PIPE_DTYPE}
    image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", **dtype_kwargs)
    vae = DiffusersWanVAE.from_pretrained(MODEL_ID, subfolder="vae", **dtype_kwargs)
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        **dtype_kwargs,
    ).to(DEVICE)
    log_cuda_mem("after load_generation_pipe")
    return pipe
def load_wan_vae():
    """Load the stock Wan VAE in ``PIPE_DTYPE``, place it on ``DEVICE`` and set eval mode."""
    log_cuda_mem("before load_wan_vae")
    vae = DiffusersWanVAE.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=PIPE_DTYPE).to(DEVICE)
    vae.eval()
    log_cuda_mem("after load_wan_vae")
    return vae
def _split_refdecoder_state_dict(state_dict):
    """Split a combined RefDecoder state dict into ``(vae_sd, transformer_sd)``.

    Keys prefixed with ``vae.`` / ``transformer.`` go to the matching sub-dict with
    the prefix stripped; every other key is ignored.
    """
    vae_sd = {}
    transformer_sd = {}
    routes = (("vae.", vae_sd), ("transformer.", transformer_sd))
    for key, value in state_dict.items():
        for prefix, target in routes:
            if key.startswith(prefix):
                target[key[len(prefix):]] = value
                break
    return vae_sd, transformer_sd


def _load_non_strict(module, state_dict, name):
    """Call ``module.load_state_dict(state_dict, strict=False)`` and log any key mismatch."""
    result = module.load_state_dict(state_dict, strict=False)
    # nn.Module returns a (missing_keys, unexpected_keys) named tuple; tolerate other returns.
    missing = list(getattr(result, "missing_keys", []))
    unexpected = list(getattr(result, "unexpected_keys", []))
    print(
        f"[init] RefDecoder {name}: {len(state_dict)} checkpoint entries, "
        f"{len(missing)} missing keys, {len(unexpected)} unexpected keys"
    )
    if not state_dict:
        print(f"[init] WARNING: no '{name}.' weights found in checkpoint; {name} keeps its initial weights")
    if missing:
        print(f"[init]   {name} missing (first 5): {missing[:5]}")
    if unexpected:
        print(f"[init]   {name} unexpected (first 5): {unexpected[:5]}")


def load_refdecoder_module():
    """Build the RefDecoder VAE and decoder transformer and load the downloaded checkpoint.

    The checkpoint at ``REFDECODER_CKPT_LOCAL_PATH`` may wrap its weights under
    ``"state_dict"`` or ``"module"``; otherwise the loaded object itself is used as
    the state dict. Loading stays non-strict, but missing/unexpected keys are
    printed so a checkpoint/architecture mismatch shows up in the logs instead of
    silently running with partially initialized weights.

    Returns:
        ``(vae, transformer)``, both moved to ``DEVICE`` and in eval mode.
    """
    log_cuda_mem("before load_refdecoder_module")
    vae = AutoencoderKLWan(
        dropout_p=0.0,
        use_reference=True,
    ).eval()
    transformer = WanDecoderTransformer(
        chunk=5,
        num_layers=10,
        num_heads=12,
        head_dim=128,
        reusing=True,
        pretrained=False,
    ).eval()

    checkpoint = torch.load(REFDECODER_CKPT_LOCAL_PATH, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint.get("module", checkpoint))
    vae_sd, transformer_sd = _split_refdecoder_state_dict(state_dict)
    _load_non_strict(vae, vae_sd, "vae")
    _load_non_strict(transformer, transformer_sd, "transformer")

    vae = vae.to(DEVICE).eval()
    transformer = transformer.to(DEVICE).eval()
    log_cuda_mem("after load_refdecoder_module")
    return vae, transformer
def resize_image_for_wan(image, pipe, target_area=None):
    """Resize ``image`` to about ``target_area`` pixels, keeping its aspect ratio.

    Both sides are snapped down to a multiple of the model grid
    (``vae_scale_factor_spatial * patch_size[1]``), and each side is at least one
    grid cell. Without that floor, very elongated inputs round a side down to 0
    and ``image.resize`` fails.

    Args:
        image: PIL-style image (needs ``convert``, ``width``, ``height``, ``resize``).
        pipe: Wan pipeline providing ``vae_scale_factor_spatial`` and ``transformer.config.patch_size``.
        target_area: Pixel budget; defaults to the module-level ``TARGET_AREA``.

    Returns:
        ``(resized_image, height, width)`` with integer ``height``/``width``.
    """
    if target_area is None:
        target_area = TARGET_AREA
    image = image.convert("RGB")
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    # int() guards against NumPy versions where round(np.float64) stays a float.
    height = max(mod_value, int(round(np.sqrt(target_area * aspect_ratio))) // mod_value * mod_value)
    width = max(mod_value, int(round(np.sqrt(target_area / aspect_ratio))) // mod_value * mod_value)
    resized = image.resize((width, height))
    return resized, height, width
def build_reference_frame(image, device):
    """Turn an HWC image into a float32 tensor in [-1, 1] of shape ``(1, C, 1, H, W)``.

    ``C`` is the image's channel count (3 for the RGB images produced by
    ``resize_image_for_wan``). The result is placed on ``device``.
    """
    pixels = torch.from_numpy(np.asarray(image).astype(np.float32))
    chw = pixels.permute(2, 0, 1)  # HWC -> CHW
    scaled = (chw / 255.0 - 0.5) * 2.0  # [0, 255] -> [-1, 1]
    # Add a leading batch axis and a singleton time axis after channels.
    return scaled[None, :, None].to(device=device, dtype=torch.float32)
def normalize_latent_shape(latents):
    """Coerce pipeline latent output into a 5-D ``[B, C, T, H, W]`` tensor.

    A list is reduced to its first element, and a 4-D tensor gains a leading batch
    axis. Anything that is still not 5-D raises ``ValueError``.
    """
    first = latents[0] if isinstance(latents, list) else latents
    batched = first.unsqueeze(0) if first.ndim == 4 else first
    if batched.ndim == 5:
        return batched
    raise ValueError(f"Expected latent shape [B,C,T,H,W], got {tuple(batched.shape)}")
def gradio_file_url(path):
    """Return the Gradio file-serving URL for local ``path``, percent-encoded except ``/``."""
    encoded = quote(str(path), safe="/")
    return "/gradio_api/file=" + encoded
def build_compare_html(wan_video_path, ref_video_path):
    """Return an ``<iframe srcdoc=...>`` containing a split-screen video comparison widget.

    The left (base) layer shows the Wan baseline video and the right (overlay) layer
    shows the RefDecoder video. A range input moves the clip-path split, and a
    play/pause button drives both videos. A missing path (``None``/empty) renders an
    empty placeholder panel instead of a ``<video>``. The whole inner document is
    HTML-escaped into the ``srcdoc`` attribute, so its scripts and styles run
    isolated inside the iframe.
    """

    def media_element(path, role):
        # Looping muted video for one side, or a placeholder when no file exists yet.
        url = gradio_file_url(path) if path else ""
        if url:
            return f'<video class="compare-video compare-{role}" src="{url}" autoplay muted loop playsinline></video>'
        return f'<div class="compare-video compare-{role} compare-placeholder"></div>'

    compare_id = f"compare-{uuid.uuid4().hex}"
    base_source = media_element(wan_video_path, "base")
    overlay_source = media_element(ref_video_path, "overlay")
    # NOTE: doubled braces below are literal braces in the f-string (CSS/JS bodies).
    inner_doc = f"""
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style>
html, body {{
margin: 0;
padding: 0;
background: transparent;
font-family: Manrope, Inter, system-ui, sans-serif;
}}
.compare-shell {{
display: flex;
flex-direction: column;
gap: 12px;
}}
.compare-topbar {{
display: grid;
grid-template-columns: 1fr auto 1fr;
align-items: center;
gap: 12px;
}}
.compare-chip {{
padding: 12px 22px;
border-radius: 999px;
background: rgba(31, 106, 82, 0.14);
color: #123a2d;
font-size: 22px;
font-weight: 800;
letter-spacing: 0.03em;
text-transform: uppercase;
box-shadow: inset 0 0 0 1px rgba(31, 106, 82, 0.12);
justify-self: start;
}}
.compare-chip-right {{
background: rgba(201, 111, 66, 0.16);
color: #6e3d23;
box-shadow: inset 0 0 0 1px rgba(201, 111, 66, 0.16);
justify-self: end;
}}
.compare-button {{
border: 0;
border-radius: 999px;
padding: 10px 22px;
background: #1f6a52;
color: white;
font-size: 16px;
font-weight: 700;
cursor: pointer;
justify-self: center;
}}
.compare-stage {{
position: relative;
width: 100%;
aspect-ratio: 16 / 9;
overflow: hidden;
border-radius: 22px;
background: #16120f;
border: 1px solid rgba(255,255,255,0.08);
}}
.compare-video {{
position: absolute;
inset: 0;
width: 100%;
height: 100%;
object-fit: contain;
background: #16120f;
}}
.compare-overlay {{
clip-path: inset(0 0 0 50%);
}}
.compare-placeholder {{
background:
linear-gradient(135deg, rgba(255,255,255,0.055), transparent 35%),
#16120f;
}}
.compare-divider {{
position: absolute;
top: 0;
bottom: 0;
left: 50%;
width: 2px;
background: rgba(255,255,255,0.96);
box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
transform: translateX(-1px);
pointer-events: none;
}}
.compare-divider::after {{
content: "";
position: absolute;
top: 50%;
left: 50%;
width: 18px;
height: 18px;
border-radius: 999px;
background: #fff;
border: 2px solid rgba(31, 26, 20, 0.18);
transform: translate(-50%, -50%);
}}
.compare-range {{
position: absolute;
inset: 0;
width: 100%;
height: 100%;
opacity: 0.01;
cursor: ew-resize;
margin: 0;
-webkit-appearance: none;
appearance: none;
}}
.compare-caption {{
color: #201a14;
font-size: 14px;
line-height: 1.5;
text-align: center;
}}
</style>
</head>
<body>
<div class="compare-shell" id="{compare_id}">
<div class="compare-topbar">
<div class="compare-chip">Wan Baseline</div>
<button class="compare-button" type="button">Pause</button>
<div class="compare-chip compare-chip-right">RefDecoder</div>
</div>
<div class="compare-stage">
{base_source}
{overlay_source}
<div class="compare-divider"></div>
<input class="compare-range" type="range" min="0" max="100" value="50" />
</div>
<div class="compare-caption">Drag the divider to compare the two decoders on the same latent video.</div>
</div>
<script>
(() => {{
const root = document.getElementById("{compare_id}");
const base = root.querySelector(".compare-base");
const overlay = root.querySelector(".compare-overlay");
const divider = root.querySelector(".compare-divider");
const slider = root.querySelector(".compare-range");
const button = root.querySelector(".compare-button");
const videos = Array.from(root.querySelectorAll("video"));
const applySplit = () => {{
const value = Number(slider.value);
overlay.style.clipPath = `inset(0 0 0 ${{value}}%)`;
divider.style.left = `${{value}}%`;
}};
const syncVideo = (source, target) => {{
if (Math.abs((target.currentTime || 0) - (source.currentTime || 0)) > 0.08) {{
try {{ target.currentTime = source.currentTime; }} catch (e) {{}}
}}
}};
const playBoth = () => {{
videos.forEach((video) => video.play().catch(() => {{}}));
button.textContent = "Pause";
}};
const pauseBoth = () => {{
videos.forEach((video) => video.pause());
button.textContent = "Play";
}};
const bindSync = (primary, secondary) => {{
primary.addEventListener("play", () => secondary.play().catch(() => {{}}));
primary.addEventListener("pause", () => secondary.pause());
primary.addEventListener("seeking", () => syncVideo(primary, secondary));
primary.addEventListener("timeupdate", () => syncVideo(primary, secondary));
primary.addEventListener("ratechange", () => {{ secondary.playbackRate = primary.playbackRate; }});
}};
if (base.tagName === "VIDEO" && overlay.tagName === "VIDEO") {{
bindSync(base, overlay);
bindSync(overlay, base);
}} else {{
button.disabled = true;
button.textContent = "Play";
button.style.opacity = "0.55";
}}
videos.forEach((video) => {{
video.addEventListener("loadeddata", playBoth, {{ once: true }});
}});
button.addEventListener("click", () => {{
if (!videos.length || videos[0].paused) {{
playBoth();
}} else {{
pauseBoth();
}}
}});
slider.addEventListener("input", applySplit);
applySplit();
}})();
</script>
</body>
</html>
"""
    escaped = html.escape(inner_doc, quote=True)
    attributes = 'class="compare-frame" sandbox="allow-scripts allow-same-origin" scrolling="no"'
    return f'<iframe {attributes} srcdoc="{escaped}"></iframe>'
def save_video_tensor(video_tensor, output_path):
    """Encode a [-1, 1] video tensor as a video file at ``output_path`` (``FPS``, quality 10).

    Assumes ``video_tensor`` is ``(1, C, T, H, W)``. The leading axis is squeezed and
    the rest permuted to ``(T, H, W, C)`` uint8 frames for imageio.

    Returns:
        ``output_path`` as a string.
    """
    unit_range = (video_tensor / 2 + 0.5).clamp(0, 1)
    frames = unit_range.squeeze(0).permute(1, 2, 3, 0)
    frames = frames.detach().cpu().float().numpy()
    frames = (frames * 255).astype(np.uint8)
    imageio.mimwrite(output_path, frames, fps=FPS, quality=10)
    return str(output_path)
def _denormalize_latents(latents, vae, dtype):
    """Move ``latents`` to ``DEVICE``/``dtype`` and undo the per-channel latent normalization.

    Uses ``vae.config.latents_mean`` / ``latents_std`` broadcast over dim 1 and
    returns ``latents * std + mean``. Both decoders below share this step.
    """
    latents = latents.to(device=DEVICE, dtype=dtype)
    stats_shape = (1, -1, 1, 1, 1)
    latents_mean = torch.tensor(vae.config.latents_mean, device=DEVICE, dtype=dtype).view(stats_shape)
    latents_std = torch.tensor(vae.config.latents_std, device=DEVICE, dtype=dtype).view(stats_shape)
    return latents * latents_std + latents_mean


def decode_with_wan_vae(latents, vae):
    """Decode normalized Wan latents with the stock diffusers Wan VAE.

    Latents are cast to the VAE's parameter dtype before de-normalization.
    Returns the first element of ``vae.decode(..., return_dict=False)`` (the video tensor).
    """
    latents = _denormalize_latents(latents, vae, get_module_dtype(vae))
    with torch.no_grad():
        video = vae.decode(latents, return_dict=False)[0]
    return video


def decode_with_refdecoder(latents, reference_frame, vae, transformer):
    """Decode normalized Wan latents with RefDecoder, conditioned on ``reference_frame``.

    Latents and the reference frame are cast to the RefDecoder VAE's parameter dtype.
    ``vae.clear_cache()`` (when present) runs in ``finally`` so it also happens when
    decoding raises.
    """
    decode_dtype = get_module_dtype(vae)
    latents = _denormalize_latents(latents, vae, decode_dtype)
    reference_frame = reference_frame.to(device=DEVICE, dtype=decode_dtype)
    try:
        with torch.no_grad():
            video = vae.decode(
                latents,
                transformer,
                return_dict=True,
                reference_frame=reference_frame,
                skip=False,
                window_size=-1,
            ).sample
    finally:
        if hasattr(vae, "clear_cache"):
            vae.clear_cache()
    return video
def button_state(label, interactive):
    """Return a Gradio update that relabels a button and sets whether it is clickable."""
    return gr.update(interactive=interactive, value=label)
@spaces.GPU(duration=80)
def generate_latents_on_gpu(image, prompt, seed):
    """Run Wan I2V once (``output_type="latent"``) and return CPU latents plus conditioning info.

    Returns:
        ``(latents, resized_image, height, width)``, where ``latents`` is a detached
        CPU tensor reshaped by ``normalize_latent_shape``.
    """
    log_cuda_mem("start generate_latents_on_gpu")
    pipe = load_generation_pipe()
    resized_image, height, width = resize_image_for_wan(image, pipe)
    call_kwargs = dict(
        image=resized_image,
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        height=height,
        width=width,
        num_frames=NUM_FRAMES,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=torch.Generator(device=DEVICE).manual_seed(seed),
        output_type="latent",
    )
    with torch.no_grad():
        output = pipe(**call_kwargs)
    latents = normalize_latent_shape(output.frames).detach().cpu()
    log_cuda_mem("after latent generation")
    return latents, resized_image, height, width
@spaces.GPU(duration=20)
def decode_wan_on_gpu(latents):
    """Decode ``latents`` with a freshly loaded Wan VAE and return the video as a CPU tensor."""
    log_cuda_mem("start decode_wan_on_gpu")
    decoded = decode_with_wan_vae(latents, load_wan_vae())
    log_cuda_mem("after wan decode")
    return decoded.detach().cpu()
@spaces.GPU(duration=25)
def decode_refdecoder_on_gpu(latents, reference_frame):
    """Decode ``latents`` with a freshly loaded RefDecoder and return the video as a CPU tensor."""
    log_cuda_mem("start decode_refdecoder_on_gpu")
    refdecoder_parts = load_refdecoder_module()  # (vae, transformer)
    decoded = decode_with_refdecoder(latents, reference_frame, *refdecoder_parts)
    log_cuda_mem("after refdecoder decode")
    return decoded.detach().cpu()
def generate_and_decode(image, prompt, seed):
    """Gradio generator: generate Wan latents once, then decode them with both decoders.

    Every ``yield`` is a 4-tuple for
    ``(compare_output, wan_video_hidden, ref_video_hidden, run_button)``.
    Intermediate yields relabel and disable the button, and publish the Wan video
    path as soon as it exists. The final yield shows the comparison widget and
    re-enables the button.

    The latents (``wan_latents.pt``) and both ``.mp4`` files are written to a fresh
    ``OUTPUT_ROOT/refdecoder_demo_<hex>`` directory.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        prompt: Optional text prompt; ``None``/empty becomes ``""`` (whitespace stripped).
        seed: Optional seed; ``None`` picks a random seed in ``[0, 2**32 - 1]``.

    Raises:
        gr.Error: No image was provided, or no CUDA device is available.
        Exception: Any failure in a later stage is re-raised after the button has
            been re-enabled, so the user can retry without reloading the page.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    if DEVICE != "cuda":
        raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")

    try:
        yield gr.update(), gr.update(), gr.update(), button_state("Loading Wan I2V...", False)
        prompt = prompt.strip() if prompt else ""
        seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
        run_dir = OUTPUT_ROOT / f"refdecoder_demo_{uuid.uuid4().hex}"
        run_dir.mkdir(parents=True, exist_ok=True)

        yield gr.update(), gr.update(), gr.update(), button_state("Generating Latents...", False)
        latents, resized_image, height, width = generate_latents_on_gpu(image, prompt, seed)
        reference_frame = build_reference_frame(resized_image, "cpu")
        latent_path = run_dir / "wan_latents.pt"
        torch.save(
            {
                "latents": latents,
                "height": height,
                "width": width,
                "prompt": prompt,
                "seed": seed,
            },
            latent_path,
        )

        yield gr.update(), gr.update(), gr.update(), button_state("Decoding Wan Baseline...", False)
        wan_video = decode_wan_on_gpu(latents)
        wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
        del wan_video
        gc.collect()

        yield gr.update(), wan_video_path, gr.update(), button_state("Decoding RefDecoder...", False)
        ref_video = decode_refdecoder_on_gpu(latents, reference_frame)
        ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
        del ref_video
        gc.collect()

        compare_html = build_compare_html(wan_video_path, ref_video_path)
        yield (
            gr.update(value=compare_html, visible=True),
            wan_video_path,
            ref_video_path,
            button_state("Generate Comparison", True),
        )
    except Exception:
        # Otherwise the button stays disabled with a stale progress label (Gradio does
        # not revert component updates on error). Re-enable it, then let Gradio
        # surface the original error.
        yield gr.update(), gr.update(), gr.update(), button_state("Generate Comparison", True)
        raise
# Stylesheet for the outer Gradio page (passed to gr.Blocks(css=...)).
# NOTE(review): the comparison widget is rendered inside an <iframe srcdoc=...>
# (see build_compare_html), and that document carries its own <style>. Parent-page
# CSS does not cascade into an iframe, so of the .compare-* rules below only
# .compare-frame (set on the <iframe> element) and .compare-panel (a gr.Column
# class) appear to take effect here. .hero-kicker, .compare-note, .output-card and
# .output-grid also look unused by the visible layout. Confirm before pruning.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap');
:root {
--page-bg: #f4f1e8;
--card-bg: rgba(255, 252, 246, 0.92);
--card-border: rgba(50, 43, 32, 0.12);
--accent: #1f6a52;
--accent-2: #c96f42;
--text-main: #201a14;
--text-soft: #201a14;
--ui-font: "Manrope", "Inter", "Segoe UI", sans-serif;
}
.gradio-container {
background:
radial-gradient(circle at top left, rgba(201, 111, 66, 0.18), transparent 26%),
radial-gradient(circle at top right, rgba(31, 106, 82, 0.16), transparent 28%),
linear-gradient(180deg, #f8f4ec 0%, var(--page-bg) 100%);
font-family: var(--ui-font);
}
.app-shell {
max-width: 1320px;
margin: 0 auto;
}
.hero-card,
.panel-card,
.output-card {
background: var(--card-bg);
border: 1px solid var(--card-border);
border-radius: 24px;
box-shadow: 0 18px 50px rgba(49, 39, 26, 0.08);
}
.hero-card {
padding: 28px 30px 20px 30px;
margin-bottom: 18px;
}
.hero-kicker {
display: inline-block;
padding: 6px 12px;
border-radius: 999px;
background: rgba(31, 106, 82, 0.10);
color: var(--accent);
font-size: 12px;
font-weight: 700;
letter-spacing: 0.08em;
text-transform: uppercase;
}
.hero-title {
margin: 14px 0 8px 0;
font-size: 42px;
line-height: 1.05;
font-weight: 800;
color: var(--text-main);
}
.hero-copy {
margin: 0;
max-width: 840px;
color: var(--text-soft);
font-size: 17px;
line-height: 1.6;
font-family: var(--ui-font);
}
.panel-card,
.output-card {
padding: 18px;
}
.panel-card {
overflow: hidden;
}
.section-title {
margin: 0 0 6px 0;
color: var(--text-main);
font-size: 22px;
font-weight: 750;
}
.section-copy {
margin: 0 0 14px 0;
color: var(--text-soft);
font-size: 14px;
line-height: 1.55;
font-family: var(--ui-font);
}
.compare-note {
padding: 12px 14px;
border-radius: 16px;
background: rgba(201, 111, 66, 0.08);
color: #6a4128;
font-size: 14px;
line-height: 1.5;
margin-bottom: 14px;
}
#generate-btn {
min-height: 108px;
height: 100%;
width: 100%;
font-size: 16px;
font-weight: 700;
background: linear-gradient(135deg, var(--accent) 0%, #154f3d 100%);
border: none;
}
#generate-btn:hover {
filter: brightness(1.04);
}
.output-grid {
gap: 14px;
}
.compare-shell {
display: flex;
flex-direction: column;
gap: 12px;
}
.compare-frame {
width: 100%;
height: 860px;
border: 0;
background: transparent;
overflow: hidden;
}
@media (max-width: 900px) {
.compare-frame {
height: 720px;
}
}
.compare-topbar {
display: flex;
justify-content: space-between;
align-items: center;
gap: 12px;
}
.compare-chip {
padding: 8px 12px;
border-radius: 999px;
background: rgba(31, 106, 82, 0.08);
color: var(--text-main);
font-size: 12px;
font-weight: 700;
letter-spacing: 0.04em;
text-transform: uppercase;
}
.compare-chip-right {
background: rgba(201, 111, 66, 0.10);
}
.compare-stage {
position: relative;
width: 100%;
aspect-ratio: 16 / 9;
overflow: hidden;
border-radius: 22px;
background: #16120f;
border: 1px solid rgba(255,255,255,0.08);
}
.compare-video {
position: absolute;
inset: 0;
width: 100%;
height: 100%;
object-fit: contain;
background: #16120f;
}
.compare-overlay {
clip-path: inset(0 0 0 50%);
}
.compare-divider {
position: absolute;
top: 0;
bottom: 0;
left: 50%;
width: 2px;
background: rgba(255,255,255,0.96);
box-shadow: 0 0 0 1px rgba(31, 26, 20, 0.15);
transform: translateX(-1px);
pointer-events: none;
}
.compare-divider::after {
content: "";
position: absolute;
top: 50%;
left: 50%;
width: 18px;
height: 18px;
border-radius: 999px;
background: #fff;
border: 2px solid rgba(31, 26, 20, 0.18);
transform: translate(-50%, -50%);
}
.compare-range {
position: absolute;
inset: 0;
width: 100%;
height: 100%;
opacity: 0;
cursor: ew-resize;
}
.compare-caption {
color: var(--text-soft);
font-size: 14px;
line-height: 1.5;
font-family: var(--ui-font);
}
.compare-panel {
padding-bottom: 34px;
}
.seed-action-row {
align-items: stretch;
}
.seed-action-row > .gradio-column {
min-width: 0;
}
"""
# Page layout and event wiring.
# NOTE(review): the nesting of the `with` blocks below was reconstructed because
# the source lost its indentation. In particular, confirm whether the seed/button
# row sits inside the prompt column (as written here) and where the hidden video
# components live.
with gr.Blocks(title="RefDecoder I2V Demo", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    with gr.Column(elem_classes="app-shell"):
        # Intro card.
        gr.HTML(
            """
<div class="hero-card">
<div class="hero-title">RefDecoder I2V Demo</div>
<p class="hero-copy">
Upload one image, optionally add a prompt, and compare two decoders on the same Wan latent video.
The app generates latents once, then renders them with Wan's original VAE and with RefDecoder.
</p>
</div>
"""
        )
        # Inputs panel: image on the left; prompt, seed and run button on the right.
        with gr.Column(elem_classes=["panel-card", "compare-panel"]):
            gr.HTML(
                """
<div class="section-title">Inputs</div>
<div class="section-copy">
Upload a reference image, optionally add a prompt, and compare the decoders below.
</div>
"""
            )
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    image_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        height=180,
                    )
                with gr.Column(scale=5):
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        lines=2,
                        placeholder="A woman turns toward the camera as her hair moves in the wind...",
                    )
                    with gr.Row(equal_height=True, elem_classes="seed-action-row"):
                        with gr.Column(scale=1):
                            # Empty (None) means generate_and_decode picks a random seed.
                            seed_input = gr.Number(
                                label="Seed",
                                value=None,
                                precision=0,
                                info="Optional",
                            )
                        with gr.Column(scale=1):
                            # Styled via #generate-btn in CUSTOM_CSS; relabeled/disabled during a run.
                            run_button = gr.Button(
                                "Generate Comparison",
                                variant="primary",
                                elem_id="generate-btn",
                            )
        # Output panel holding the iframe comparison widget (placeholder until a run finishes).
        with gr.Column(elem_classes="panel-card"):
            gr.HTML(
                """
<div class="section-title">Decoder Comparison</div>
<div class="section-copy">
Left side shows Wan Baseline. Right side shows RefDecoder. Drag the divider across the frame to compare them.
</div>
"""
            )
            compare_output = gr.HTML(value=build_compare_html(None, None))
        # Hidden sinks for the per-decoder MP4 paths yielded by generate_and_decode.
        wan_video_hidden = gr.Video(visible=False)
        ref_video_hidden = gr.Video(visible=False)
    # Output order must match the 4-tuples yielded by generate_and_decode.
    run_button.click(
        fn=generate_and_decode,
        inputs=[image_input, prompt_input, seed_input],
        outputs=[compare_output, wan_video_hidden, ref_video_hidden, run_button],
    )
if __name__ == "__main__":
    # Queue at most 2 pending requests; allow Gradio to serve the generated MP4s,
    # which the comparison iframe references via gradio_file_url().
    queued_demo = demo.queue(max_size=2)
    queued_demo.launch(allowed_paths=[str(OUTPUT_ROOT)])