File size: 7,189 Bytes
72f552e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""Image-to-video generation using Wan 2.1 via fal.ai API.

Reads generated images and their prompts, produces a short video clip
per segment. Each clip is ~5s at 16fps; the assembler later trims to
the exact beat interval duration.

Two backends:
  - "api"  : fal.ai hosted Wan 2.1 (for development / local runs)
  - "hf"   : on-device Wan 2.1 with FP8 on ZeroGPU (for HF Spaces deployment)

Set FAL_KEY env var for API mode.
"""

import base64
import json
import os
import time
from pathlib import Path
from typing import Optional

import requests
from dotenv import load_dotenv

load_dotenv()

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

# fal.ai model endpoint passed to fal_client.subscribe() in generate_clip_api.
FAL_MODEL_ID = "fal-ai/wan-i2v"

# Vertical 9:16 to match our SDXL images
ASPECT_RATIO = "9:16"
RESOLUTION = "480p"  # cheaper/faster for dev; bump to 720p for final
NUM_FRAMES = 81  # ~5s at 16fps
FPS = 16  # sent to the API as "frames_per_second"
NUM_INFERENCE_STEPS = 30  # sent to the API as "num_inference_steps"
GUIDANCE_SCALE = 5.0  # sent to the API as "guide_scale"
# Default base seed for generate_all()/run(); each segment uses base + segment index.
SEED = 42


def _image_to_data_uri(image_path: str | Path) -> str:
    """Convert a local image file to a base64 data URI for the API."""
    path = Path(image_path)
    suffix = path.suffix.lower()
    mime = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"}
    content_type = mime.get(suffix, "image/png")

    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode()

    return f"data:{content_type};base64,{encoded}"


def _download_video(url: str, output_path: Path) -> Path:
    """Download a video from a URL to a local file.

    The body is streamed in chunks to a temporary ``<name>.part`` file next
    to the destination and only moved to ``output_path`` once the download
    has completed. This keeps memory use flat for large clips and guarantees
    that ``output_path`` never holds a truncated file, which matters because
    callers treat an existing clip file as "already generated".

    Args:
        url: HTTP(S) URL of the video to fetch.
        output_path: Destination file path; parent directories are created
            if needed.

    Returns:
        ``output_path``, for convenience.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.RequestException: On connection or timeout failures.
        OSError: If the file cannot be written or moved into place.
    """
    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        with requests.get(url, timeout=300, stream=True) as resp:
            resp.raise_for_status()
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
        # Atomic on the same filesystem: the final path appears fully written.
        tmp_path.replace(output_path)
    finally:
        # No-op after a successful replace; removes leftovers on failure.
        tmp_path.unlink(missing_ok=True)
    return output_path


# ---------------------------------------------------------------------------
# API backend (fal.ai)
# ---------------------------------------------------------------------------

def generate_clip_api(
    image_path: str | Path,
    prompt: str,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> dict:
    """Request one image-to-video clip from the fal.ai Wan 2.1 endpoint.

    The source image is inlined into the request as a base64 data URI and
    the module-level config constants supply the rendering parameters.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description for the video.
        negative_prompt: What to avoid.
        seed: Random seed for reproducibility; omitted from the request
            when ``None``.

    Returns:
        API response dict with 'video' (url, content_type, file_size) and 'seed'.
    """
    # Imported lazily so the module can be loaded without fal_client installed.
    import fal_client

    payload = {
        "image_url": _image_to_data_uri(image_path),
        "prompt": prompt,
        "aspect_ratio": ASPECT_RATIO,
        "resolution": RESOLUTION,
        "num_frames": NUM_FRAMES,
        "frames_per_second": FPS,
        "num_inference_steps": NUM_INFERENCE_STEPS,
        "guide_scale": GUIDANCE_SCALE,
        "negative_prompt": negative_prompt,
        "enable_safety_checker": False,
        "enable_prompt_expansion": False,
    }
    # Only send a seed when the caller asked for one.
    if seed is not None:
        payload["seed"] = seed

    return fal_client.subscribe(FAL_MODEL_ID, arguments=payload)


# ---------------------------------------------------------------------------
# Public interface
# ---------------------------------------------------------------------------

def generate_clip(
    image_path: str | Path,
    prompt: str,
    output_path: str | Path,
    negative_prompt: str = "",
    seed: Optional[int] = None,
) -> Path:
    """Render a single clip through the API backend and store it on disk.

    Args:
        image_path: Path to the source image.
        prompt: Motion/scene description.
        output_path: Where to save the .mp4 clip.
        negative_prompt: What to avoid.
        seed: Random seed.

    Returns:
        Path to the saved video clip.
    """
    destination = Path(output_path)

    response = generate_clip_api(
        image_path,
        prompt,
        negative_prompt,
        seed,
    )

    # The API returns a hosted URL; fetch it to the requested location.
    return _download_video(response["video"]["url"], destination)


def generate_all(
    segments: list[dict],
    images_dir: str | Path,
    output_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Generate video clips for all segments.

    Expects images at images_dir/segment_001.png, segment_002.png, etc.
    Clips are written to output_dir/clip_001.mp4, clip_002.mp4, etc.

    The prompt for each segment is the first non-empty value among the
    'video_prompt', 'scene' and 'prompt' keys; 'negative_prompt' is optional.
    Segments whose clip already exists are skipped (and still included in
    the result); segments whose image is missing are skipped and omitted.

    Args:
        segments: List of segment dicts. Each must have an integer 'segment'
            index and should have at least one of the prompt keys above.
        images_dir: Directory containing generated images.
        output_dir: Directory to save video clips (created if missing).
        seed: Base seed (incremented per segment).
        progress_callback: Optional callable invoked as
            ``progress_callback(idx, total)`` after each newly generated
            clip. It is not invoked for skipped segments.

    Returns:
        List of saved video clip paths.
    """
    images_dir = Path(images_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths: list[Path] = []
    for seg in segments:
        idx = seg["segment"]
        image_path = images_dir / f"segment_{idx:03d}.png"
        clip_path = output_dir / f"clip_{idx:03d}.mp4"

        if clip_path.exists():
            print(f"  Segment {idx}/{total}: already exists, skipping")
            paths.append(clip_path)
            continue

        if not image_path.exists():
            print(f"  Segment {idx}: image not found at {image_path}, skipping")
            continue

        # Use dedicated video_prompt (detailed motion), fall back to scene,
        # then prompt. `or` (rather than nested .get defaults) also falls
        # through when a key is present but empty or None.
        prompt = (
            seg.get("video_prompt")
            or seg.get("scene")
            or seg.get("prompt")
            or ""
        )
        neg = seg.get("negative_prompt") or ""

        print(f"  Segment {idx}/{total}: generating video clip...")
        t0 = time.time()
        generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx)
        elapsed = time.time() - t0
        print(f"    Saved {clip_path.name} ({elapsed:.1f}s)")

        paths.append(clip_path)
        if progress_callback:
            progress_callback(idx, total)

    return paths


def run(
    data_dir: str | Path,
    seed: int = SEED,
    progress_callback=None,
) -> list[Path]:
    """Full video generation pipeline: read segments, generate clips, save.

    Reads ``<data_dir>/segments.json``, takes images from
    ``<data_dir>/images`` and writes clips to ``<data_dir>/clips``.

    Args:
        data_dir: Song data directory containing segments.json and images/.
        seed: Base random seed.
        progress_callback: Optional callable forwarded unchanged to
            ``generate_all``.

    Returns:
        List of saved video clip paths.

    Raises:
        FileNotFoundError: If segments.json does not exist.
        json.JSONDecodeError: If segments.json is not valid JSON.
    """
    data_dir = Path(data_dir)
    clips_dir = data_dir / "clips"

    # JSON is UTF-8 by specification; do not rely on the platform default
    # encoding, which would mis-decode non-ASCII prompts on some systems.
    with open(data_dir / "segments.json", encoding="utf-8") as f:
        segments = json.load(f)

    paths = generate_all(
        segments,
        images_dir=data_dir / "images",
        output_dir=clips_dir,
        seed=seed,
        progress_callback=progress_callback,
    )

    print(f"\nGenerated {len(paths)} video clips in {clips_dir}")
    return paths


if __name__ == "__main__":
    # Command-line entry point: expects the song data directory as the only
    # required argument and a fal.ai key in the environment.
    import sys

    cli_args = sys.argv[1:]

    if not cli_args:
        for usage_line in (
            "Usage: python -m src.video_generator <data_dir>",
            "  e.g. python -m src.video_generator data/Gone",
            "\nRequires FAL_KEY environment variable.",
        ):
            print(usage_line)
        raise SystemExit(1)

    # Fail early with a helpful pointer instead of erroring inside the API call.
    if not os.environ.get("FAL_KEY"):
        for error_line in (
            "Error: FAL_KEY environment variable not set.",
            "Get your key at https://fal.ai/dashboard/keys",
        ):
            print(error_line)
        raise SystemExit(1)

    run(cli_args[0])