# SyncAI / src /video_generator_hf.py
# ICGenAIShare04's picture
# Update src/video_generator_hf.py
# a10589c verified
"""Image-to-video generation using Wan 2.1 on-device via diffusers.
Runs Wan 2.1 14B I2V locally on GPU (designed for HF Spaces ZeroGPU).
Same public interface as video_generator_api.py so app.py can swap backends.
"""
import json
import time
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from PIL import Image
# ---------------------------------------------------------------------------
# Config — matches video_generator_api.py settings
# ---------------------------------------------------------------------------
# Model repository id passed to every `from_pretrained` call in _get_pipe().
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
# Number of frames requested per clip (passed as `num_frames`).
NUM_FRAMES = 81 # ~5s at 16fps
# Frame rate used when exporting a clip with export_to_video().
FPS = 16
# Passed as `num_inference_steps` on every pipeline call.
NUM_INFERENCE_STEPS = 25
# Passed as `guidance_scale` on every pipeline call.
GUIDANCE_SCALE = 5.0
# Default base seed for generate_all() and run().
SEED = 42
# 480p max pixel area (480 * 832 = 399360); target area used when resizing
# input images in _resize_for_480p().
MAX_AREA = 480 * 832
# Singleton pipeline — loaded once by _get_pipe(), reused across calls,
# and reset to None by unload().
_pipe = None
def _get_pipe():
    """Load the Wan 2.1 I2V pipeline (lazy singleton).

    The first call loads the image encoder, the VAE and the pipeline from
    ``MODEL_ID``, quantizes the transformer to FP8 and moves the pipeline to
    CUDA. Later calls return the cached instance.

    The module-level singleton is only published once loading, quantization
    and the device move have all succeeded, so a failure part-way through
    never leaves a half-initialised pipeline cached for later callers.

    Returns:
        The ready-to-use ``WanImageToVideoPipeline``.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    # Heavy third-party imports are deferred to first use. All of them are
    # resolved before any weights are loaded so that a missing dependency
    # fails fast rather than after the model has been loaded.
    from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
    from torchao.quantization import Float8WeightOnlyConfig, quantize_
    from transformers import CLIPVisionModel

    print(f"Loading Wan 2.1 I2V pipeline ({MODEL_ID})...")

    # VAE and image encoder must be float32 for stability
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32,
    )
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=torch.float32,
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=torch.bfloat16,
    )

    # Quantize transformer to FP8 to fit in 24GB ZeroGPU VRAM
    # (~28GB bf16 → ~14GB fp8). VAE + image encoder stay float32.
    quantize_(pipe.transformer, Float8WeightOnlyConfig())
    pipe.to("cuda")

    # Publish the singleton only now that it is fully initialised.
    _pipe = pipe
    print("Wan 2.1 I2V pipeline ready.")
    return _pipe
def unload():
    """Unload the pipeline to free GPU memory.

    Does nothing when no pipeline is currently loaded. Otherwise the
    pipeline is moved to the CPU, the singleton reference is dropped, and
    the CUDA caching allocator is asked to release its unused blocks.
    """
    global _pipe
    if _pipe is None:
        return

    import gc

    _pipe.to("cpu")
    # Rebinding the global drops the module's reference to the pipeline; a
    # separate `del` beforehand is not needed.
    _pipe = None
    # Collect reference cycles first so objects that are only kept alive by
    # cycles are gone before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    print("Wan 2.1 I2V pipeline unloaded.")
def _resize_for_480p(image: Image.Image, pipe) -> tuple[Image.Image, int, int]:
    """Resize image to fit the 480p pixel area while respecting patch constraints.

    The target size keeps the source aspect ratio, has an area of roughly
    ``MAX_AREA`` pixels, and has both sides rounded down to a multiple of
    ``vae_scale_factor_spatial * patch_size[1]``.

    Args:
        image: Source image.
        pipe: Loaded pipeline; ``vae_scale_factor_spatial`` and
            ``transformer.config.patch_size`` are read from it.

    Returns:
        ``(resized_image, height, width)``. Both dimensions are plain ints
        and at least one multiple of the size granularity, so an extreme
        aspect ratio can never produce a zero-sized side.
    """
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]

    def _snap(side) -> int:
        # Round down to a multiple of mod_value, but never below one multiple:
        # a side of 0 is not a valid resize target.
        return max(mod_value, int(round(side)) // mod_value * mod_value)

    height = _snap(np.sqrt(MAX_AREA * aspect_ratio))
    width = _snap(np.sqrt(MAX_AREA / aspect_ratio))
    # PIL expects (width, height) order.
    return image.resize((width, height)), height, width
def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Generate a video clip from an image using on-device Wan 2.1.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Where to save the .mp4 clip. Missing parent
            directories are created.
        negative_prompt: What to avoid.
        seed: Random seed. When ``None`` no generator is passed to the
            pipeline.

    Returns:
        Path to the saved video clip.
    """
    from diffusers.utils import export_to_video

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    pipe = _get_pipe()

    # Load and resize input image. The context manager closes the underlying
    # file as soon as the RGB copy has been made.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    image, height, width = _resize_for_480p(image, pipe)

    generator = None
    if seed is not None:
        generator = torch.Generator(device="cpu").manual_seed(seed)

    output = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=NUM_FRAMES,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
    )

    # frames[0] is the frame sequence for the single input image.
    export_to_video(output.frames[0], str(output_path), fps=FPS)
    return output_path
def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate video clips for all segments.

    Segments whose clip file already exists are not regenerated but are
    still included in the result. Segments whose source image is missing
    are skipped and omitted from the result.

    Args:
        segments: List of segment dicts. Each needs an integer 'segment'
            key; the prompt is taken from 'video_prompt', 'scene' or
            'prompt' (first non-empty one), and 'negative_prompt' is
            optional.
        images_dir: Directory containing generated images
            (``segment_NNN.png``).
        output_dir: Directory to save video clips (``clip_NNN.mp4``).
        seed: Base seed (incremented per segment).
        progress_callback: Optional callable invoked as
            ``progress_callback(idx, len(segments))`` after each newly
            generated clip.

    Returns:
        List of saved video clip paths.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        if clip_path.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
            paths.append(clip_path)
            continue

        if not image_path.exists():
            print(f" Segment {idx}: image not found at {image_path}, skipping")
            continue

        # Use dedicated video_prompt (detailed motion), fall back to scene,
        # then to the generic prompt. `or` is used instead of nested
        # dict.get defaults so a key that is present but empty/None also
        # falls through to the next candidate.
        prompt = (
            seg.get("video_prompt")
            or seg.get("scene")
            or seg.get("prompt")
            or ""
        )
        neg = seg.get("negative_prompt") or ""

        print(f" Segment {idx}/{total}: generating video clip...")
        t0 = time.time()
        generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx)
        elapsed = time.time() - t0
        print(f" Saved {clip_path.name} ({elapsed:.1f}s)")
        paths.append(clip_path)

        if progress_callback:
            progress_callback(idx, total)

    return paths
def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Full video generation pipeline: read segments, generate clips, save.

    Args:
        data_dir: Run directory containing segments.json and images/.
            Clips are written to its clips/ subdirectory.
        seed: Base random seed.
        progress_callback: Optional progress callable, forwarded unchanged
            to generate_all().

    Returns:
        List of saved video clip paths.
    """
    data_dir = Path(data_dir)

    # Read as UTF-8 explicitly so non-ASCII prompt text does not depend on
    # the platform's default locale encoding.
    with open(data_dir / "segments.json", encoding="utf-8") as f:
        segments = json.load(f)

    paths = generate_all(
        segments,
        images_dir=data_dir / "images",
        output_dir=data_dir / "clips",
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(paths)} video clips in {data_dir / 'clips'}")
    return paths
if __name__ == "__main__":
    # Command-line entry point: expects the run directory as the first
    # argument; prints usage and exits with status 1 when it is missing.
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        for usage_line in (
            "Usage: python -m src.video_generator_hf <data_dir>",
            " e.g. python -m src.video_generator_hf data/Gone/run_001",
        ):
            print(usage_line)
        sys.exit(1)

    run(cli_args[0])