| | """Image-to-video generation using Wan 2.1 on-device via diffusers. |
| | |
| | Runs Wan 2.1 14B I2V locally on GPU (designed for HF Spaces ZeroGPU). |
| | Same public interface as video_generator_api.py so app.py can swap backends. |
| | """ |
| |
|
| | import json |
| | import time |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| |
|
| | |
| | |
| | |
| |
|
# Hugging Face repo for the Wan 2.1 14B image-to-video model (480p variant).
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"

NUM_FRAMES = 81           # frames per clip (~5 s at FPS below)
FPS = 16                  # playback frame rate passed to export_to_video
NUM_INFERENCE_STEPS = 25  # diffusion denoising steps per clip
GUIDANCE_SCALE = 5.0      # classifier-free guidance strength
SEED = 42                 # default base seed; incremented per segment in generate_all

# Pixel budget for output frames: images are resized so height*width fits
# within a 480x832 area (see _resize_for_480p).
MAX_AREA = 480 * 832

# Lazily-initialized pipeline singleton; created by _get_pipe, freed by unload.
_pipe = None
| |
|
| |
|
def _get_pipe():
    """Return the Wan 2.1 I2V pipeline, loading it on first call (lazy singleton)."""
    global _pipe
    if _pipe is None:
        # Imports are deferred so the module can be imported without the
        # heavyweight model libraries being touched.
        from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
        from transformers import CLIPVisionModel

        print(f"Loading Wan 2.1 I2V pipeline ({MODEL_ID})...")

        # Image encoder and VAE are kept in float32; the main transformer
        # runs in bfloat16 (and is quantized further below).
        encoder = CLIPVisionModel.from_pretrained(
            MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32,
        )
        autoencoder = AutoencoderKLWan.from_pretrained(
            MODEL_ID, subfolder="vae", torch_dtype=torch.float32,
        )

        _pipe = WanImageToVideoPipeline.from_pretrained(
            MODEL_ID,
            vae=autoencoder,
            image_encoder=encoder,
            torch_dtype=torch.bfloat16,
        )

        # Float8 weight-only quantization of the transformer — presumably to
        # fit the 14B model in GPU memory on ZeroGPU hardware.
        from torchao.quantization import quantize_, Float8WeightOnlyConfig
        quantize_(_pipe.transformer, Float8WeightOnlyConfig())

        _pipe.to("cuda")

        print("Wan 2.1 I2V pipeline ready.")
    return _pipe
| |
|
| |
|
def unload():
    """Unload the pipeline to free GPU memory."""
    global _pipe
    if _pipe is None:
        return
    # Move weights off the GPU before dropping the reference so CUDA memory
    # can actually be reclaimed by empty_cache().
    _pipe.to("cpu")
    del _pipe
    _pipe = None
    torch.cuda.empty_cache()
    print("Wan 2.1 I2V pipeline unloaded.")
| |
|
| |
|
def _resize_for_480p(image: Image.Image, pipe) -> tuple[Image.Image, int, int]:
    """Resize image to fit the 480p pixel budget while respecting model constraints.

    The target dimensions preserve the aspect ratio as closely as possible
    while (a) keeping height * width <= MAX_AREA and (b) being multiples of
    the pipeline's spatial unit (VAE scale factor x transformer patch size),
    which the model requires.

    Args:
        image: Source PIL image.
        pipe: Loaded WanImageToVideoPipeline (supplies the size constraints).

    Returns:
        Tuple of (resized image, height, width).
    """
    aspect_ratio = image.height / image.width
    # Both dimensions must be divisible by this unit for the VAE + patchify step.
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    # Snap down to a multiple of mod_value, but never below one unit:
    # extreme aspect ratios could otherwise floor a dimension to zero,
    # which would make image.resize() fail.
    height = max(mod_value, round(np.sqrt(MAX_AREA * aspect_ratio)) // mod_value * mod_value)
    width = max(mod_value, round(np.sqrt(MAX_AREA / aspect_ratio)) // mod_value * mod_value)
    return image.resize((width, height)), height, width
| |
|
| |
|
def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Generate a video clip from an image using on-device Wan 2.1.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Where to save the .mp4 clip (parent dirs are created).
        negative_prompt: What to avoid.
        seed: Random seed for reproducible sampling, or None for nondeterministic.

    Returns:
        Path to the saved video clip.
    """
    from diffusers.utils import export_to_video

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    pipe = _get_pipe()

    # Load the source frame and fit it to the model's 480p size constraints.
    frame = Image.open(image_path).convert("RGB")
    frame, height, width = _resize_for_480p(frame, pipe)

    generator = (
        torch.Generator(device="cpu").manual_seed(seed) if seed is not None else None
    )

    result = pipe(
        image=frame,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=NUM_FRAMES,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    )

    # Pipeline returns a batch of frame lists; write the first (only) video.
    export_to_video(result.frames[0], str(output_path), fps=FPS)
    return output_path
| |
|
| |
|
def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate video clips for all segments.

    Existing clips are reused (resume support); segments whose source image
    is missing are skipped with a warning.

    Args:
        segments: List of segment dicts with 'segment', 'prompt' keys.
        images_dir: Directory containing generated images.
        output_dir: Directory to save video clips.
        seed: Base seed (incremented per segment).
        progress_callback: Optional callable(segment_index, total_segments)
            invoked after each segment is handled — generated, reused, or
            skipped — so callers can track progress.

    Returns:
        List of saved video clip paths.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        if clip_path.exists():
            # Resume support: keep previously generated clips untouched.
            print(f" Segment {idx}/{total}: already exists, skipping")
            paths.append(clip_path)
        elif not image_path.exists():
            print(f" Segment {idx}: image not found at {image_path}, skipping")
        else:
            # Prefer a dedicated video prompt, falling back to scene/prompt.
            prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", "")))
            neg = seg.get("negative_prompt", "")

            print(f" Segment {idx}/{total}: generating video clip...")
            t0 = time.time()
            generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx)
            elapsed = time.time() - t0
            print(f" Saved {clip_path.name} ({elapsed:.1f}s)")
            paths.append(clip_path)

        # BUGFIX: progress was previously reported only for freshly generated
        # clips — the skip branches used `continue` and bypassed the callback,
        # so resumed runs (all clips cached) never advanced the progress UI.
        if progress_callback:
            progress_callback(idx, total)

    return paths
| |
|
| |
|
def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Full video generation pipeline: read segments, generate clips, save.

    Args:
        data_dir: Run directory containing segments.json and images/.
        seed: Base random seed.
        progress_callback: Forwarded to generate_all.

    Returns:
        List of saved video clip paths.
    """
    root = Path(data_dir)

    # Segment metadata is produced by an earlier pipeline stage.
    with open(root / "segments.json") as f:
        segments = json.load(f)

    paths = generate_all(
        segments,
        images_dir=root / "images",
        output_dir=root / "clips",
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(paths)} video clips in {root / 'clips'}")
    return paths
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # Require the run directory as the sole positional argument.
    if len(sys.argv) >= 2:
        run(sys.argv[1])
    else:
        print("Usage: python -m src.video_generator_hf <data_dir>")
        print(" e.g. python -m src.video_generator_hf data/Gone/run_001")
        sys.exit(1)
| |
|