Spaces:
Paused
Paused
| #!/usr/bin/env python | |
| """ | |
| Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video | |
| • Single global load (no repeated downloads) | |
| • Balanced device_map to avoid OOM on 24 GB A10 | |
| • Fast CLIP processor via use_fast=True | |
| • High-level streaming progress | |
| • Auto-download via gr.File | |
| """ | |
| import os | |
| # persist Hugging Face cache so safetensors only download once | |
| os.environ["HF_HOME"] = "/mnt/data/huggingface" | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from diffusers import WanImageToVideoPipeline, AutoencoderKLWan | |
| from diffusers.utils import export_to_video | |
| from transformers import CLIPVisionModel | |
| from PIL import Image | |
| import torchvision.transforms.functional as TF | |
# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
# Hugging Face repository of the diffusers-format FLF2V checkpoint; every
# component in load_pipeline() is fetched from here.
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"

# Half precision used for the VAE and the remaining pipeline components.
DTYPE = torch.float16

# Pixel budget of one output frame (720p); aspect_resize() derives the
# output height/width from it.
MAX_AREA = 720 * 1280

# Initial value of the "Frames" slider in the UI.
DEFAULT_FRAMES = 81
# -----------------------------------------------------------------------------
# LOAD PIPELINE ONCE
# -----------------------------------------------------------------------------
def load_pipeline():
    """Assemble the Wan2.1 FLF2V image-to-video pipeline.

    Components loaded explicitly:
      1. CLIP vision encoder, kept in float32.
      2. Wan VAE, in ``DTYPE``.
      3. A *fast* CLIP image processor (``AutoImageProcessor`` with
         ``use_fast=True``).
    Everything else is loaded by ``WanImageToVideoPipeline.from_pretrained``
    in ``DTYPE``, with weights spread across devices via
    ``device_map="balanced"``.

    Returns:
        The ready-to-call ``WanImageToVideoPipeline``.
    """
    # Only needed here; keeping the import local avoids touching the
    # file-level import block.
    from transformers import AutoImageProcessor

    # 1) CLIP image encoder (fp32)
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE (fp16 via DTYPE)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) Fast CLIP processor. `use_fast` is not an argument of
    # DiffusionPipeline.from_pretrained: passed there it is dropped with an
    # "unexpected keyword arguments" warning and the slow processor is used.
    # Build the fast one explicitly and hand it in as a component instead.
    image_processor = AutoImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor", use_fast=True
    )
    # 4) Full pipeline with balanced device placement.
    return WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=image_encoder,
        image_processor=image_processor,
        vae=vae,
        torch_dtype=DTYPE,
        device_map="balanced",  # spread weights CPU↔GPU
    )


# Loaded once at import time so every request reuses the same weights.
PIPE = load_pipeline()
# -----------------------------------------------------------------------------
# HELPERS
# -----------------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize *img* to about *max_area* pixels, keeping its aspect ratio.

    Height and width are floored to a multiple of the model's spatial grid
    (VAE spatial scale factor × transformer patch size) and clamped to at
    least one grid cell.

    Args:
        img: Source image.
        max_area: Target pixel budget (height × width).

    Returns:
        ``(resized_image, height, width)``.
    """
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    ar = img.height / img.width
    # Clamp to one grid cell: for extreme aspect ratios flooring to a multiple
    # of `mod` would otherwise give a 0-pixel side, which neither PIL nor the
    # pipeline can handle.
    h = max(mod, int(np.sqrt(max_area * ar)) // mod * mod)
    w = max(mod, int(np.sqrt(max_area / ar)) // mod * mod)
    return img.resize((w, h), Image.LANCZOS), h, w
def center_crop_resize(img: Image.Image, h, w):
    """Scale *img* until it covers an ``h`` × ``w`` box, then center-crop to it.

    Returns:
        The cropped image, exactly ``h`` pixels tall and ``w`` pixels wide.
    """
    # The larger ratio guarantees both sides end up >= the target box.
    scale = max(w / img.width, h / img.height)
    covering_size = (round(img.width * scale), round(img.height * scale))
    covering = img.resize(covering_size, Image.LANCZOS)
    return TF.center_crop(covering, [h, w])
# -----------------------------------------------------------------------------
# GENERATION + STREAMING
# -----------------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    """Render a video that starts at *first_frame* and ends at *last_frame*.

    Args:
        first_frame: Opening keyframe; its aspect ratio fixes the output size.
        last_frame: Closing keyframe; center-cropped/resized to the output
            size when its size differs from the resized first frame.
        prompt: Text prompt.
        negative: Negative prompt; an empty string means "none".
        steps: Number of denoising steps.
        guidance: Guidance scale passed to the pipeline.
        num_frames: Requested number of frames.
        seed: RNG seed; -1 (or an empty box) picks a random 32-bit seed.
        fps: Frame rate of the exported video.
        progress: Progress tracker injected by Gradio.

    Returns:
        ``(video_path, seed_used)``.

    Raises:
        gr.Error: if either keyframe is missing.
    """
    if first_frame is None or last_frame is None:
        raise gr.Error("Please upload both a first frame and a last frame.")

    # choose seed. Random seeds stay below 2**32 so the value shown in the
    # "Seed used" box round-trips exactly through JSON/JS numbers
    # (torch.seed() can return up to 2**64 - 1, which would not).
    if seed is None or int(seed) == -1:
        seed = int(np.random.default_rng().integers(0, 2**32))
    else:
        seed = int(seed)
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

    # 0–15%: resize
    progress(0.0, desc="Resizing first frame…")
    f_resized, h, w = aspect_resize(first_frame)
    if last_frame.size != f_resized.size:
        progress(0.15, desc="Resizing last frame…")
        l_resized = center_crop_resize(last_frame, h, w)
    else:
        # Already at the target resolution: use the user's last frame as-is
        # (not the first frame, which would drop the closing keyframe).
        l_resized = last_frame

    total_steps = int(steps)

    def _on_step_end(pipe, step_index, timestep, callback_kwargs):
        # 25–90%: one tick per finished denoising step.
        done = step_index + 1
        frac = min(done / max(total_steps, 1), 1.0)
        progress(0.25 + 0.65 * frac, desc=f"Denoising step {done}/{total_steps}…")
        # The pipeline expects the (unmodified) tensor dict back.
        return callback_kwargs

    # 15–25%: spin up pipeline
    progress(0.25, desc="Launching inference…")
    out = PIPE(
        image=f_resized,
        last_image=l_resized,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=int(num_frames),
        num_inference_steps=total_steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=_on_step_end,
    )

    # 90–100%: export
    progress(0.90, desc="Building video file…")
    video_path = export_to_video(out.frames[0], fps=int(fps))
    progress(1.0, desc="Done!")
    return video_path, seed
# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (opt)", placeholder="blurry, lowres")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        # Offer only 4k+1 frame counts (17, 21, …, 129). The Wan pipeline
        # rounds any other value to that form with a warning, so the user
        # would not get the number of frames they picked.
        num_frames = gr.Slider(17, 129, value=DEFAULT_FRAMES, step=4, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=rand)")
    run_btn = gr.Button("Generate")
    download = gr.File(label="Download .mp4", interactive=False)
    seed_used = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[
            first_img, last_img, prompt, negative,
            steps, guidance, num_frames, seed_input, fps,
        ],
        outputs=[download, seed_used],
    )

# Launch only when run as a script, so importing this module (e.g. for
# tooling or tests) does not start a server.
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)