# wan / app.py — Hugging Face Space app (author: kingkay000, commit 2d49997)
import os
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from diffusers.pipelines.wan import WanVACEPipeline
from diffusers.utils import load_image, export_to_video
from PIL import Image
# Hugging Face Hub repo holding the Wan2.1 VACE 1.3B diffusers pipeline.
REPO_ID = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
# Local snapshot target; on a stateless Space this is wiped on every restart,
# so the model must be re-downloaded each time.
LOCAL_DIR = "/root/.cache/wan21"
# Directory where generated videos are written.
OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)
# Module-level pipeline handle; populated lazily by init_pipe().
pipe = None
def download_model():
    """Download the Wan2.1 VACE model snapshot into LOCAL_DIR.

    Returns:
        dict: status payload with a completeness flag (``downloaded``),
        the target directory, a capped listing of its contents, and the
        active GPU name (or None on CPU-only hosts).
    """
    os.makedirs(LOCAL_DIR, exist_ok=True)
    # NOTE: the former local_dir_use_symlinks=False kwarg is deprecated and
    # ignored by current huggingface_hub — real files are written to
    # local_dir by default — so it is omitted here.
    snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR)
    # Every complete diffusers pipeline snapshot ships a model_index.json,
    # so its presence is used as the "download finished" marker.
    ok = os.path.exists(os.path.join(LOCAL_DIR, "model_index.json"))
    return {
        "downloaded": ok,
        "local_dir": LOCAL_DIR,
        # Cap the directory listing so the JSON status payload stays small.
        "contains": sorted(os.listdir(LOCAL_DIR))[:30],
        "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
    }
def init_pipe():
    """Lazily initialize the global WanVACEPipeline from LOCAL_DIR.

    Loads in fp16 on CUDA when available, otherwise fp32 on CPU, and
    enables whatever memory optimizations this diffusers version exposes.

    Returns:
        str: a human-readable initialization status message.

    Raises:
        RuntimeError: if the model snapshot has not been downloaded yet.
    """
    global pipe
    if pipe is not None:
        return "Pipeline already initialized."
    if not os.path.exists(os.path.join(LOCAL_DIR, "model_index.json")):
        raise RuntimeError("Model not downloaded yet. Click '1) Download Model' first.")
    # Hoist the CUDA probe so dtype and device stay consistent even if
    # availability were to change between the two checks.
    use_cuda = torch.cuda.is_available()
    dtype = torch.float16 if use_cuda else torch.float32
    device = "cuda" if use_cuda else "cpu"
    pipe = WanVACEPipeline.from_pretrained(LOCAL_DIR, torch_dtype=dtype).to(device)
    # Best-effort memory optimizations: not every diffusers version exposes
    # all of these methods, so failures are deliberately ignored.
    for opt_name in ("enable_attention_slicing", "enable_vae_slicing", "enable_vae_tiling"):
        try:
            getattr(pipe, opt_name)()
        except Exception:
            pass
    return f"Initialized WanVACEPipeline on {device} ({dtype})."
def generate_demo():
    """Generate a short demo video from a fixed sample image and prompt.

    Returns:
        tuple[str, str]: the pipeline initialization message and the path
        of the exported MP4.
    """
    status = init_pipe()

    source_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
    cond_image = load_image(source_url)
    prompt = "A realistic video. Subtle natural motion, gentle camera movement, stable subject, cinematic lighting."
    out_path = os.path.join(OUT_DIR, "test.mp4")

    # L4-safe settings (divisible by 16, and (num_frames-1) divisible by 4)
    height, width = 320, 576
    num_frames = 13

    # The conditioning image must match the generation resolution exactly.
    cond_image = cond_image.resize((width, height))

    # Conditioning video: the real image on frame 0, white placeholders after.
    placeholder = Image.new("RGB", (width, height), (255, 255, 255))
    cond_video = [cond_image] + [placeholder] * (num_frames - 1)

    # Masks: black on frame 0 means "keep/condition", white elsewhere
    # means "generate".
    keep_mask = Image.new("RGB", (width, height), (0, 0, 0))
    gen_mask = Image.new("RGB", (width, height), (255, 255, 255))
    cond_masks = [keep_mask] + [gen_mask] * (num_frames - 1)

    result = pipe(
        prompt=prompt,
        video=cond_video,
        mask=cond_masks,
        reference_images=[cond_image],
        conditioning_scale=2.0,
        num_frames=num_frames,
        height=height,
        width=width,
        guidance_scale=5.0,
        num_inference_steps=20,
        output_type="pil",
    )

    # Pipeline output may be an object with a .frames attribute or a dict.
    if hasattr(result, "frames"):
        frames = result.frames[0]
    else:
        frames = result["frames"][0]
    export_to_video(frames, out_path, fps=8)
    return status, out_path
# Two-step stateless UI: download the snapshot, then run a demo generation.
with gr.Blocks(title="Wan2.1 VACE 1.3B — Stateless Server") as demo:
    gr.Markdown("## Wan2.1 VACE 1.3B (Stateless)\nNo persistent storage: download the model after each restart.")
    download_btn = gr.Button("1) Download Model (one-time per restart)")
    download_status = gr.JSON(label="Download status")
    generate_btn = gr.Button("2) Generate Test Video")
    init_status = gr.Textbox(label="Init status")
    video_out = gr.Video(label="Generated video")
    download_btn.click(download_model, inputs=[], outputs=[download_status])
    generate_btn.click(generate_demo, inputs=[], outputs=[init_status, video_out])

if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)