| | """Image-to-video generation using Wan 2.1 via fal.ai API. |
| | |
| | Reads generated images and their prompts, produces a short video clip |
| | per segment. Each clip is ~5s at 16fps; the assembler later trims to |
| | the exact beat interval duration. |
| | |
| | Two backends: |
| | - "api" : fal.ai hosted Wan 2.1 (for development / local runs) |
| | - "hf" : on-device Wan 2.1 with FP8 on ZeroGPU (for HF Spaces deployment) |
| | |
| | Set FAL_KEY env var for API mode. |
| | """ |
| |
|
| | import base64 |
| | import json |
| | import os |
| | import time |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import requests |
| | from dotenv import load_dotenv |
| |
|
# Load variables from a local .env file into os.environ so settings such as
# FAL_KEY (checked in the __main__ block) can live outside the shell env.
load_dotenv()
| |
|
| | |
| | |
| | |
| |
|
# fal.ai endpoint used for Wan 2.1 image-to-video generation.
FAL_MODEL_ID = "fal-ai/wan-i2v"

# --- Per-request generation settings -------------------------------------
# Output framing.
ASPECT_RATIO = "9:16"  # portrait
RESOLUTION = "480p"

# Clip length: 81 frames at 16 fps is roughly 5 seconds of video.
NUM_FRAMES = 81
FPS = 16

# Sampler settings (GUIDANCE_SCALE is sent to the API as "guide_scale").
NUM_INFERENCE_STEPS = 30
GUIDANCE_SCALE = 5.0

# Default base seed; generate_all adds the segment index to it.
SEED = 42
| |
|
| |
|
| | def _image_to_data_uri(image_path: str | Path) -> str: |
| | """Convert a local image file to a base64 data URI for the API.""" |
| | path = Path(image_path) |
| | suffix = path.suffix.lower() |
| | mime = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"} |
| | content_type = mime.get(suffix, "image/png") |
| |
|
| | with open(path, "rb") as f: |
| | encoded = base64.b64encode(f.read()).decode() |
| |
|
| | return f"data:{content_type};base64,{encoded}" |
| |
|
| |
|
def _download_video(url: str, output_path: Path) -> Path:
    """Download a video from ``url`` and save it at ``output_path``.

    The response body is streamed to disk in chunks rather than held in
    memory all at once. It is written to a sibling ``<name>.part`` file,
    which is renamed onto ``output_path`` only after the download finishes.
    A failed or interrupted download therefore never leaves a truncated file
    at ``output_path``. This matters because ``generate_all`` treats an
    existing clip file as done and skips it on later runs.

    Args:
        url: HTTP(S) URL of the video to fetch.
        output_path: Destination file. Parent directories are created.

    Returns:
        ``output_path`` (as a ``Path``).

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection or read failures.
    """
    output_path = Path(output_path)
    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=300) as resp:
            resp.raise_for_status()
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic on the same filesystem: the final path either does not
        # exist or holds the complete file.
        tmp_path.replace(output_path)
    finally:
        # No-op after a successful rename; removes partial data otherwise.
        tmp_path.unlink(missing_ok=True)
    return output_path
| |
|
| |
|
| | |
| | |
| | |
| |
|
def generate_clip_api(
    image_path: str | Path,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> dict:
    """Submit one Wan 2.1 image-to-video job to fal.ai and return its result.

    The source image is sent inline as a base64 data URI. All generation
    settings (aspect ratio, resolution, frame count, fps, inference steps,
    guidance scale) come from the module-level constants. Prompt expansion
    and the safety checker are always disabled. The request goes through
    ``fal_client.subscribe``.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description for the video.
        negative_prompt: What to avoid. Always sent, even when empty.
        seed: Random seed for reproducibility. Left out of the request
            entirely when ``None``.

    Returns:
        API response dict with 'video' (url, content_type, file_size) and 'seed'.
    """
    import fal_client  # deferred: only required once a request is made

    payload = {
        "image_url": _image_to_data_uri(image_path),
        "prompt": prompt,
        "aspect_ratio": ASPECT_RATIO,
        "resolution": RESOLUTION,
        "num_frames": NUM_FRAMES,
        "frames_per_second": FPS,
        "num_inference_steps": NUM_INFERENCE_STEPS,
        "guide_scale": GUIDANCE_SCALE,
        # NOTE(review): an explicit "" may replace any server-side default
        # negative prompt -- confirm against the fal.ai endpoint schema.
        "negative_prompt": negative_prompt,
        "enable_safety_checker": False,
        "enable_prompt_expansion": False,
        # Pin the seed only when the caller supplied one.
        **({} if seed is None else {"seed": seed}),
    }
    return fal_client.subscribe(FAL_MODEL_ID, arguments=payload)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Generate a clip via the fal.ai API and download it to ``output_path``.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Destination for the downloaded .mp4 clip.
        negative_prompt: What to avoid.
        seed: Random seed.

    Returns:
        Path of the saved clip.
    """
    destination = Path(output_path)
    response = generate_clip_api(image_path, prompt, negative_prompt, seed)
    # The API response carries the finished clip as a downloadable URL.
    return _download_video(response["video"]["url"], destination)
| |
|
| |
|
def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate, or reuse from disk, a video clip for every segment.

    For each segment, the source image is read from
    ``images_dir/segment_NNN.png`` and the clip is written to
    ``output_dir/clip_NNN.mp4``, where NNN is ``seg["segment"]`` zero-padded
    to three digits. A clip that already exists is reused instead of being
    regenerated, so an interrupted run can be resumed. Segments whose image
    is missing are skipped and contribute no path.

    The motion prompt is taken from the first key present among
    ``video_prompt``, ``scene`` and ``prompt`` (default ""). The optional
    ``negative_prompt`` key is passed through.

    Args:
        segments: Segment dicts. Each needs an integer 'segment' index.
        images_dir: Directory containing generated images.
        output_dir: Directory to save video clips (created if missing).
        seed: Base seed. Each segment uses ``seed + segment index``.
        progress_callback: Optional ``callable(index, total)``. It is called
            after every segment whose clip is available, whether newly
            generated or reused from disk.

    Returns:
        Paths of the available clips, in the order of ``segments``.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        if clip_path.exists():
            print(f" Segment {idx}/{total}: already exists, skipping")
        elif not image_path.exists():
            print(f" Segment {idx}: image not found at {image_path}, skipping")
            continue
        else:
            # Prefer a dedicated video prompt, then the scene description,
            # then the generic prompt.
            prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", "")))
            neg = seg.get("negative_prompt", "")

            print(f" Segment {idx}/{total}: generating video clip...")
            t0 = time.time()
            generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx)
            elapsed = time.time() - t0
            print(f" Saved {clip_path.name} ({elapsed:.1f}s)")

        paths.append(clip_path)
        # Report reused clips too. Otherwise a resumed run whose remaining
        # segments are all cached would never drive progress to the end.
        if progress_callback:
            progress_callback(idx, total)

    return paths
| |
|
| |
|
def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Run the full video stage: read segments, generate clips, save them.

    Args:
        data_dir: Song data directory containing segments.json and images/.
            Clips are written to ``data_dir/clips``.
        seed: Base random seed.
        progress_callback: Optional ``callable(index, total)``, forwarded to
            ``generate_all``.

    Returns:
        List of saved video clip paths.
    """
    data_dir = Path(data_dir)
    clips_dir = data_dir / "clips"

    # JSON text is UTF-8 by spec. Without an explicit encoding, open() uses
    # the locale's preferred encoding (e.g. cp1252 on Windows), which fails
    # or mis-decodes non-ASCII text in the segments.
    with open(data_dir / "segments.json", encoding="utf-8") as f:
        segments = json.load(f)

    paths = generate_all(
        segments,
        images_dir=data_dir / "images",
        output_dir=clips_dir,
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(paths)} video clips in {clips_dir}")
    return paths
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # CLI entry point: expects a single positional argument, the song
    # data directory. It also requires FAL_KEY to be set before any API
    # work starts.
    cli_args = sys.argv[1:]
    if not cli_args:
        for line in (
            "Usage: python -m src.video_generator <data_dir>",
            " e.g. python -m src.video_generator data/Gone",
            "\nRequires FAL_KEY environment variable.",
        ):
            print(line)
        raise SystemExit(1)

    if not os.environ.get("FAL_KEY"):
        for line in (
            "Error: FAL_KEY environment variable not set.",
            "Get your key at https://fal.ai/dashboard/keys",
        ):
            print(line)
        raise SystemExit(1)

    run(cli_args[0])
| |
|