"""Image-to-video generation using Wan 2.1 via fal.ai API. Reads generated images and their prompts, produces a short video clip per segment. Each clip is ~5s at 16fps; the assembler later trims to the exact beat interval duration. Two backends: - "api" : fal.ai hosted Wan 2.1 (for development / local runs) - "hf" : on-device Wan 2.1 with FP8 on ZeroGPU (for HF Spaces deployment) Set FAL_KEY env var for API mode. """ import base64 import json import os import time from pathlib import Path from typing import Optional import requests from dotenv import load_dotenv load_dotenv() # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- FAL_MODEL_ID = "fal-ai/wan-i2v" # Vertical 9:16 to match our SDXL images ASPECT_RATIO = "9:16" RESOLUTION = "480p" # cheaper/faster for dev; bump to 720p for final NUM_FRAMES = 81 # ~5s at 16fps FPS = 16 NUM_INFERENCE_STEPS = 30 GUIDANCE_SCALE = 5.0 SEED = 42 def _image_to_data_uri(image_path: str | Path) -> str: """Convert a local image file to a base64 data URI for the API.""" path = Path(image_path) suffix = path.suffix.lower() mime = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"} content_type = mime.get(suffix, "image/png") with open(path, "rb") as f: encoded = base64.b64encode(f.read()).decode() return f"data:{content_type};base64,{encoded}" def _download_video(url: str, output_path: Path) -> Path: """Download a video from URL to a local file.""" resp = requests.get(url, timeout=300) resp.raise_for_status() output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "wb") as f: f.write(resp.content) return output_path # --------------------------------------------------------------------------- # API backend (fal.ai) # --------------------------------------------------------------------------- def generate_clip_api( image_path: str | Path, prompt: str, negative_prompt: str = "", seed: Optional[int] = None, ) -> dict: """Generate a video clip from an image using fal.ai Wan 2.1 API. Args: image_path: Path to the source image. prompt: Motion/scene description for the video. negative_prompt: What to avoid. seed: Random seed for reproducibility. Returns: API response dict with 'video' (url, content_type, file_size) and 'seed'. """ import fal_client image_uri = _image_to_data_uri(image_path) args = { "image_url": image_uri, "prompt": prompt, "aspect_ratio": ASPECT_RATIO, "resolution": RESOLUTION, "num_frames": NUM_FRAMES, "frames_per_second": FPS, "num_inference_steps": NUM_INFERENCE_STEPS, "guide_scale": GUIDANCE_SCALE, "negative_prompt": negative_prompt, "enable_safety_checker": False, "enable_prompt_expansion": False, } if seed is not None: args["seed"] = seed result = fal_client.subscribe(FAL_MODEL_ID, arguments=args) return result # --------------------------------------------------------------------------- # Public interface # --------------------------------------------------------------------------- def generate_clip( image_path: str | Path, prompt: str, output_path: str | Path, negative_prompt: str = "", seed: Optional[int] = None, ) -> Path: """Generate a video clip from an image and save it locally. Args: image_path: Path to the source image. prompt: Motion/scene description. output_path: Where to save the .mp4 clip. negative_prompt: What to avoid. seed: Random seed. Returns: Path to the saved video clip. """ output_path = Path(output_path) result = generate_clip_api(image_path, prompt, negative_prompt, seed) video_url = result["video"]["url"] return _download_video(video_url, output_path) def generate_all( segments: list[dict], images_dir: str | Path, output_dir: str | Path, seed: int = SEED, progress_callback=None, ) -> list[Path]: """Generate video clips for all segments. Expects images at images_dir/segment_001.png, segment_002.png, etc. Segments should have 'prompt' and optionally 'negative_prompt' keys (from prompt_generator). Args: segments: List of segment dicts with 'segment', 'prompt' keys. images_dir: Directory containing generated images. output_dir: Directory to save video clips. seed: Base seed (incremented per segment). Returns: List of saved video clip paths. """ images_dir = Path(images_dir) output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) paths = [] for seg in segments: idx = seg["segment"] image_path = images_dir / f"segment_{idx:03d}.png" clip_path = output_dir / f"clip_{idx:03d}.mp4" if clip_path.exists(): print(f" Segment {idx}/{len(segments)}: already exists, skipping") paths.append(clip_path) continue if not image_path.exists(): print(f" Segment {idx}: image not found at {image_path}, skipping") continue # Use dedicated video_prompt (detailed motion), fall back to scene prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", ""))) neg = seg.get("negative_prompt", "") print(f" Segment {idx}/{len(segments)}: generating video clip...") t0 = time.time() generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx) elapsed = time.time() - t0 print(f" Saved {clip_path.name} ({elapsed:.1f}s)") paths.append(clip_path) if progress_callback: progress_callback(idx, len(segments)) return paths def run( data_dir: str | Path, seed: int = SEED, progress_callback=None, ) -> list[Path]: """Full video generation pipeline: read segments, generate clips, save. Args: data_dir: Song data directory containing segments.json and images/. seed: Base random seed. Returns: List of saved video clip paths. """ data_dir = Path(data_dir) with open(data_dir / "segments.json") as f: segments = json.load(f) paths = generate_all( segments, images_dir=data_dir / "images", output_dir=data_dir / "clips", seed=seed, progress_callback=progress_callback, ) print(f"\nGenerated {len(paths)} video clips in {data_dir / 'clips'}") return paths if __name__ == "__main__": import sys if len(sys.argv) < 2: print("Usage: python -m src.video_generator ") print(" e.g. python -m src.video_generator data/Gone") print("\nRequires FAL_KEY environment variable.") sys.exit(1) if not os.getenv("FAL_KEY"): print("Error: FAL_KEY environment variable not set.") print("Get your key at https://fal.ai/dashboard/keys") sys.exit(1) run(sys.argv[1])