# SyncAI / src/video_generator_api.py
# (Hugging Face upload header: ICGenAIShare04 — "Upload 52 files", commit 72f552e verified)
"""Image-to-video generation using Wan 2.1 via fal.ai API.
Reads generated images and their prompts, produces a short video clip
per segment. Each clip is ~5s at 16fps; the assembler later trims to
the exact beat interval duration.
Two backends:
- "api" : fal.ai hosted Wan 2.1 (for development / local runs)
- "hf" : on-device Wan 2.1 with FP8 on ZeroGPU (for HF Spaces deployment)
Set FAL_KEY env var for API mode.
"""
import base64
import json
import os
import time
from pathlib import Path
from typing import Optional
import requests
from dotenv import load_dotenv
load_dotenv()
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# fal.ai model endpoint for Wan 2.1 image-to-video.
FAL_MODEL_ID = "fal-ai/wan-i2v"
# Vertical 9:16 to match our SDXL images
ASPECT_RATIO = "9:16"
RESOLUTION = "480p"  # cheaper/faster for dev; bump to 720p for final
NUM_FRAMES = 81  # ~5s at 16fps
FPS = 16  # frames per second requested from the API
NUM_INFERENCE_STEPS = 30  # diffusion steps per clip
GUIDANCE_SCALE = 5.0  # sent as "guide_scale" in the request payload
SEED = 42  # default base seed; incremented per segment in generate_all()
def _image_to_data_uri(image_path: str | Path) -> str:
"""Convert a local image file to a base64 data URI for the API."""
path = Path(image_path)
suffix = path.suffix.lower()
mime = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"}
content_type = mime.get(suffix, "image/png")
with open(path, "rb") as f:
encoded = base64.b64encode(f.read()).decode()
return f"data:{content_type};base64,{encoded}"
def _download_video(url: str, output_path: Path) -> Path:
    """Download a video from *url* to *output_path*, streaming in chunks.

    Streaming avoids buffering the entire clip in memory (the original
    ``resp.content`` approach held the whole multi-MB file at once).
    Parent directories are created as needed.

    Args:
        url: HTTP(S) URL of the video to fetch.
        output_path: Local destination file path.

    Returns:
        The output path, for convenient chaining.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=300) as resp:
        resp.raise_for_status()
        with open(output_path, "wb") as f:
            # 1 MiB chunks: large enough for throughput, small enough for RAM.
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return output_path
# ---------------------------------------------------------------------------
# API backend (fal.ai)
# ---------------------------------------------------------------------------
def generate_clip_api(
    image_path: str | Path,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> dict:
    """Generate a video clip from an image via the fal.ai Wan 2.1 endpoint.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description for the video.
        negative_prompt: What to avoid.
        seed: Random seed for reproducibility.

    Returns:
        API response dict with 'video' (url, content_type, file_size) and 'seed'.
    """
    # Imported lazily so the module loads even without fal_client installed.
    import fal_client

    payload = dict(
        image_url=_image_to_data_uri(image_path),
        prompt=prompt,
        aspect_ratio=ASPECT_RATIO,
        resolution=RESOLUTION,
        num_frames=NUM_FRAMES,
        frames_per_second=FPS,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guide_scale=GUIDANCE_SCALE,
        negative_prompt=negative_prompt,
        enable_safety_checker=False,
        enable_prompt_expansion=False,
    )
    # The API treats a missing seed as "randomize"; only send it when set.
    if seed is not None:
        payload["seed"] = seed
    return fal_client.subscribe(FAL_MODEL_ID, arguments=payload)
# ---------------------------------------------------------------------------
# Public interface
# ---------------------------------------------------------------------------
def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Generate a video clip from an image and save it locally.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Where to save the .mp4 clip.
        negative_prompt: What to avoid.
        seed: Random seed.

    Returns:
        Path to the saved video clip.
    """
    # Request the clip from the API, then pull the hosted result down.
    response = generate_clip_api(image_path, prompt, negative_prompt, seed)
    return _download_video(response["video"]["url"], Path(output_path))
def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate video clips for all segments.

    Expects images at images_dir/segment_001.png, segment_002.png, etc.
    Segments should have 'prompt' and optionally 'negative_prompt' keys
    (from prompt_generator).

    Args:
        segments: List of segment dicts with 'segment', 'prompt' keys.
        images_dir: Directory containing generated images.
        output_dir: Directory to save video clips.
        seed: Base seed (incremented per segment).
        progress_callback: Optional callable invoked as (index, total).

    Returns:
        List of saved video clip paths.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    clips: list[Path] = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        # Resume-friendly: keep clips already on disk, skip regeneration.
        if clip_path.exists():
            print(f" Segment {idx}/{len(segments)}: already exists, skipping")
            clips.append(clip_path)
            continue
        # Best-effort: a missing source image is reported, not fatal.
        if not image_path.exists():
            print(f" Segment {idx}: image not found at {image_path}, skipping")
            continue

        # Use dedicated video_prompt (detailed motion), fall back to scene
        prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", "")))
        print(f" Segment {idx}/{len(segments)}: generating video clip...")
        started = time.time()
        generate_clip(
            image_path,
            prompt,
            clip_path,
            seg.get("negative_prompt", ""),
            seed=seed + idx,  # deterministic but distinct per segment
        )
        print(f" Saved {clip_path.name} ({time.time() - started:.1f}s)")
        clips.append(clip_path)
        if progress_callback:
            progress_callback(idx, total)
    return clips
def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Full video generation pipeline: read segments, generate clips, save.

    Args:
        data_dir: Song data directory containing segments.json and images/.
        seed: Base random seed.
        progress_callback: Optional callable forwarded to generate_all.

    Returns:
        List of saved video clip paths.
    """
    data_dir = Path(data_dir)
    # Segment metadata is produced upstream by the prompt generator.
    segments = json.loads((data_dir / "segments.json").read_text())
    clip_paths = generate_all(
        segments,
        images_dir=data_dir / "images",
        output_dir=data_dir / "clips",
        seed=seed,
        progress_callback=progress_callback,
    )
    print(f"\nGenerated {len(clip_paths)} video clips in {data_dir / 'clips'}")
    return clip_paths
if __name__ == "__main__":
    import sys

    # CLI entry point. Fix: the usage text previously pointed at
    # "src.video_generator", but this module is src/video_generator_api.py.
    if len(sys.argv) < 2:
        print("Usage: python -m src.video_generator_api <data_dir>")
        print(" e.g. python -m src.video_generator_api data/Gone")
        print("\nRequires FAL_KEY environment variable.")
        sys.exit(1)
    # Fail fast with a helpful message if the fal.ai API key is missing.
    if not os.getenv("FAL_KEY"):
        print("Error: FAL_KEY environment variable not set.")
        print("Get your key at https://fal.ai/dashboard/keys")
        sys.exit(1)
    run(sys.argv[1])