diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..fbadb973501bb55ffbd21f0eb966218801649213 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,29 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/Cant[[:space:]]find[[:space:]]myself.png filter=lfs diff=lfs merge=lfs -text +examples/Cant[[:space:]]find[[:space:]]myself.wav filter=lfs diff=lfs merge=lfs -text +examples/Gone.jpg filter=lfs diff=lfs merge=lfs -text +examples/Gone.wav filter=lfs diff=lfs merge=lfs -text +examples/House[[:space:]]of[[:space:]]House.png filter=lfs diff=lfs merge=lfs -text +examples/House[[:space:]]of[[:space:]]House.wav filter=lfs diff=lfs merge=lfs -text +examples/The[[:space:]]more[[:space:]]I[[:space:]]do.png filter=lfs diff=lfs merge=lfs -text +examples/The[[:space:]]more[[:space:]]I[[:space:]]do.wav filter=lfs diff=lfs merge=lfs -text +fonts/Anton-Regular.ttf filter=lfs diff=lfs merge=lfs -text +fonts/Montserrat-Bold.ttf filter=lfs diff=lfs merge=lfs -text +fonts/Oswald-Regular.ttf filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-anytiffng-2121455.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-artemmeletov-9201316.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-ekrulila-6536235.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-helenalopes-1959053.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-jerusaemm-2905514.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-kovyrina-1600139.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-kyle-karbowski-109303118-9968067.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-lokmansevim-13627402.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-matthew-jesus-468170389-30227212.jpg filter=lfs diff=lfs 
merge=lfs -text +lora_training_data/pexels-omer-hakki-49913894-7820946.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-perspectivo-2048722386-29185675.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-pixabay-417059.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-pixabay-67566.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-seyma-alkas-178198724-12858917.jpg filter=lfs diff=lfs merge=lfs -text +lora_training_data/pexels-todd-trapani-488382-1535162.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ba37a060ee9dda1e121ec12dc30c1811d26d2023 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +.env +venv/ +__pycache__/ +*.pyc +.DS_Store +*.egg-info/ +dist/ +build/ +input/ +data/ +70113_1_spec.pdf +(0) 70113_Generative_AI_README_for_Coursework.ipynb +styles/*.safetensors +lora_training_data/ +fonts/*.zip diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000000000000000000000000000000000000..9d13f7d071e0f6802ad194ffca8b8b97e06ca99d --- /dev/null +++ b/PLAN.md @@ -0,0 +1,277 @@ +# SyncAI — AI Music Video Generator + +## Overview + +An end-to-end pipeline that takes a song as input and produces a beat-synced AI-generated video suitable for music ads. The system splits audio stems, extracts vocals and beat timing, generates images from lyrics, animates them into short video clips, and stitches everything together on beat. 
+ +--- + +## Pipeline + +``` +Song (audio file) + │ + ├─► Stem Separation (LALAL.AI API) + │ ├─► Vocals + │ └─► Drums + │ + ├─► Lyrics Extraction (Whisper) ─► Timestamped lyrics + │ + ├─► Beat Detection (onset/kick detection) ─► Beat timestamps + │ + ├─► Segment Lyrics by Beat ─► 1 lyric snippet per beat interval + │ + ├─► Prompt Generation (Claude Sonnet 4.6, two LLM calls) ─► Image prompts + video motion prompts + │ + ├─► Image Generation (SDXL + Hyper-SD LoRA + custom style LoRA) ─► 1 styled image per segment (768x1344, 9:16 vertical) + │ + ├─► Image-to-Video (Wan 2.1 14B) ─► 1 clip per segment + │ + └─► Stitch on Beat (FFmpeg) ─► Final video (~15s) with original audio +``` + +--- + +## Modules + +### 1. Stem Separation +- **Tool:** [LALAL.AI API](https://www.lalal.ai/api/) — Andromeda model (commercial API, minute-based quota) +- **Input:** Song audio file (mp3/wav) +- **Output:** Separated stems — vocals, drums (only what the pipeline needs) +- **Install:** `pip install requests` (HTTP client only, no heavy ML dependencies) +- **Why LALAL.AI over Demucs:** Demucs (`htdemucs_ft`) produced unusable vocal isolation — Whisper extracted only 1 word ("um") from a Demucs vocal stem where LALAL.AI's Andromeda model correctly extracted all 15 words across 3 repetitions. The quality gap directly impacts downstream lyrics → prompts → images. +- **API flow:** Upload → split (vocals + drums as separate tasks) → poll status → download stems → cleanup remote files +- **Auth:** `X-License-Key` header via `LALAL_KEY` env var (HF Space secret) + +### 2. Lyrics Extraction +- **Tool:** [WhisperX](https://github.com/m-bain/whisperX) with `large-v2` model +- **Input:** Isolated vocal stem (from LALAL.AI) +- **Output:** Word-level timestamped transcript with forced alignment (via wav2vec2) +- **Why WhisperX over vanilla Whisper:** Forced alignment gives precise word-level timestamps needed for beat-syncing. `large-v2` outperforms `v3-turbo` on lyrics. 
+- **Notes:** Whisper is trained on speech, not singing — expect imperfect transcription on melodic/fast vocals. Good enough for generating image prompts (exact lyrics not critical). Isolated vocals from LALAL.AI are the biggest accuracy boost. + +### 3. Beat / Kick Detection +- **Tool:** [madmom](https://github.com/CPJKU/madmom) (RNN-based beat tracker) +- **Approach:** + 1. `RNNBeatProcessor` — ensemble of bidirectional LSTMs produces beat activation function (probability per frame at 100fps) + 2. `DBNBeatTrackingProcessor` — Dynamic Bayesian Network decodes activations into precise beat timestamps + 3. `select_beats()` — trim to target duration + enforce minimum interval between beats +- **Why madmom over librosa:** librosa onset detection has a known 20-60ms latency bias ([github issue #1052](https://github.com/librosa/librosa/issues/1052)). For beat-synced video cuts, that lag is perceptible. Madmom's RNN+DBN approach has no such bias and benchmarks as the most accurate open-source beat tracker. +- **Input:** Drum stem (from LALAL.AI) +- **Output:** List of beat timestamps (seconds) +- **Install:** `pip install git+https://github.com/CPJKU/madmom.git` (PyPI release doesn't support Python 3.10+) +- **Target:** Select enough beats to produce ~15 seconds of video + +### 4. Lyric-to-Beat Segmentation +- **Logic:** Map timestamped lyrics onto beat intervals + - For each beat interval `[beat_i, beat_{i+1}]`, collect all words/phrases that fall within +- **Output:** List of `(start_time, end_time, lyrics)` segment dicts + +### 4b. Prompt Generation (Two LLM Calls) +- **Model:** Claude Sonnet 4.6 (`claude-sonnet-4-6`) via Anthropic API +- **Architecture:** Two separate LLM calls per run: + 1. **Image prompts** — short SDXL-optimized scene descriptions (25-35 words, under 77 CLIP tokens) + 2. 
**Video prompts** — detailed motion/action descriptions for I2V (no token limit) +- **Style-specific guidance:** Each style in `styles.py` provides an `image_prompt_guidance` field with a concrete SETTING (e.g. "Coastal sunset drive", "Rainy city at night"). The LLM places all scenes within this setting. +- **Style-specific quality suffix:** Each style provides a `quality_suffix` (e.g. "8K, cinematic, golden hour glow, warm volumetric light") appended to every prompt. The LLM is told NOT to include style/quality tags in scenes — those come from the suffix. +- **Lyrics integration:** Concrete lyrics that fit the setting are interpreted literally. Abstract/metaphorical lyrics are translated into physical actions within the setting. +- **Prompt rules:** Literal language only (no metaphors — SDXL interprets words literally). Focus on concrete objects and actions. Physically plausible scenes only. No periods (waste tokens). No style boilerplate in scenes. +- **Output:** Enriched segments with `prompt`, `video_prompt`, `scene`, `camera_angle`, `negative_prompt` + +### 5. 
Image Generation (Core GenAI Requirement) +- **Base model:** SDXL (Stable Diffusion XL) via `diffusers` library +- **Acceleration:** Hyper-SD 8-step LoRA (ByteDance) — distilled SDXL that produces near-full-quality images in 8 steps (~2s/image instead of ~13s) +- **Style adaptation:** Multiple style LoRAs on SDXL, selectable via Gradio UI dropdown: + - **Sunset Coastal Drive** — custom-trained LoRA (`samuelsattler/warm-sunset-lora`), weight 1.0, trigger "sks" + - **Rainy City Night** — community film grain LoRA (`artificialguybr/filmgrain-redmond`), weight 0.8, trigger "FilmGrainAF" + - **Cyberpunk** — community cyberpunk LoRA (`jbilcke-hf/sdxl-cyberpunk-2077`), weight 0.9, trigger "cyberpunk-2077" + - **Watercolour Harbour** — community watercolor LoRA (`ostris/watercolor_style_lora_sdxl`), weight 1.4 + - Each style also provides `image_prompt_guidance` (setting) and `quality_suffix` (style-specific quality tags) + - LoRAs loaded from HuggingFace Hub at runtime — no local `.safetensors` files needed + - **Custom LoRA training:** 15-20 curated images, Google Colab T4, ~1-2 hours. Captions describe content only, not style. 
+ - **Usage:** Stack Hyper-SD 8-step LoRA (speed) + style LoRA (aesthetics) at inference +- **Output resolution:** 768 x 1344 (9:16 vertical) — stays within SDXL's ~1MP training budget, ideal for mobile/social media video +- **Input:** Text prompt per segment (from prompt generator) +- **Output:** One vertical image per segment in the selected style +- **Why SDXL over alternatives:** + - vs SD 1.5: SDXL is equally fast with Hyper-SD but far better quality, especially at non-square ratios (SD 1.5 trained at 512x512, poor at vertical) + - vs Flux: 4x slower, 2x VRAM, much smaller LoRA ecosystem, overkill for fleeting music video frames + - vs SD 3.5: Immature LoRA ecosystem compared to SDXL's thousands of community models +- **This satisfies the coursework requirement:** "Take a pre-trained model and adapt it for a niche, creative application" + +### 6. Image-to-Video (Wan 2.1) +- **Model:** Wan 2.1 I2V 14B (Alibaba, open-weights) — best open-source I2V model +- **Two backends, same model:** + - **Local dev:** fal.ai API (`video_generator_api.py`) — ~$0.20/clip at 480p, instant setup + - **HF Spaces:** Wan 2.1 14B on ZeroGPU with FP8 quantization (`video_generator_hf.py`) — free for users, no API key needed +- **Input:** Generated image (from our LoRA-styled SDXL) + motion prompt +- **Output:** Short video clip (~5s at 16fps, 9:16 vertical) — assembler trims to beat interval +- **Why Wan 2.1:** Best quality open-weights I2V, fits 24GB VRAM with FP8, natively in `diffusers`, existing ZeroGPU reference Spaces to build from +- **Prompt strategy:** Use strong kinetic verbs and mid-action descriptions to get immediate full motion from frame 1 (critical since clips are only ~2s after trimming) +- **Why image→video, not direct text→video?** The custom style LoRA only works on SDXL (image gen) — video models like Wan 2.1 don't support SDXL LoRAs. The two-step pipeline lets us apply our trained style in the image step, then animate it. 
This also gives precise control over the first frame's composition, whereas T2V is unpredictable. The image step is where the coursework GenAI requirement lives. + +### 7. Final Assembly +- **Tool:** FFmpeg (via `subprocess` or `ffmpeg-python`) +- **Steps:** + 1. Trim/stretch each video clip to its exact beat interval duration + 2. Concatenate clips in order + 3. Overlay the original audio (or a mixed version) + 4. Export final video (mp4, H.264) +- **Output:** ~15-second beat-synced music video + +--- + +## Tech Stack + +| Component | Library / Model | +|----------------------|------------------------------------| +| Stem separation | LALAL.AI API (Andromeda model) | +| Lyrics (ASR) | WhisperX (large-v2 + wav2vec2 alignment) | +| Beat detection | madmom (RNN + DBN beat tracker) | +| Prompt generation | Claude Sonnet 4.6 (Anthropic API, two-call architecture) | +| Image generation | SDXL + Hyper-SD 8-step + style LoRA (4 styles, diffusers) | +| Image-to-video | Wan 2.1 14B (fal.ai API for dev, ZeroGPU with FP8 for HF Spaces) | +| Video assembly | FFmpeg | +| Demo UI | Gradio (for Hugging Face Spaces) | +| Orchestration | Python | + +--- + +## Development & Deployment Strategy + +### Development (Local — MacBook Pro M1 Pro, 16GB) + +- **CPU tasks (fast locally):** Whisper transcription, madmom beat detection, FFmpeg stitching, Gradio UI +- **GPU tasks (via MPS):** SDXL + Hyper-SD image generation (~2-4 sec/image on M1 Pro with 8-step inference) +- **Video generation:** fal.ai API (Wan 2.1 can't run on 16GB) — ~$0.20/clip at 480p +- **LoRA training:** Google Colab T4 (one-time, download weights when done) + +### MVP (get end-to-end working first) + +| Step | MVP Approach | Why | +|------|-------------|-----| +| Stem separation | LALAL.AI API (Andromeda) | Best quality, API-based | +| Lyrics | Whisper (pre-trained) | Works out of the box | +| Beat detection | madmom `RNNBeatProcessor` + `DBNBeatTrackingProcessor` on drum stem | Most accurate, no latency bias | +| 
Segmentation | Map lyrics to beat intervals | Pure Python logic | +| Image gen | SDXL + Hyper-SD 8-step LoRA locally (MPS), 768x1344 vertical | ~2s/image, runs on M1 Pro | +| Image-to-video | fal.ai API → Wan 2.1 14B | Same model as HF deployment, can't run locally | +| Assembly | FFmpeg concat + overlay audio | Reliable, no ML | + +### Deployment (HF Spaces) + +When deploying to HF Spaces, swap to on-device inference: + +1. **Image gen:** `image_generator.py` — same code, add `@spaces.GPU` decorator, switch dtype to bf16 +2. **Video gen:** `video_generator_hf.py` replaces `video_generator_api.py` — Wan 2.1 14B with FP8 quantization on ZeroGPU (no API key, no credits) +3. **`requirements.txt`** replaces local pip installs +4. **`packages.txt`** for system dependencies (e.g. `ffmpeg`) +5. **Owner:** `mvp-lab` account, **Hardware:** ZeroGPU, **SDK:** Gradio + +The pipeline orchestration stays identical — only the import path for the video generator changes. + +--- + +## GenAI Requirement Fulfillment + +**Option chosen:** Take a pre-trained model and adapt it for a niche, creative application. + +- Train a custom style LoRA on SDXL (warm sunset aesthetic) + curate 3 community style LoRAs (film grain, cyberpunk, watercolour) +- Stack with Hyper-SD 8-step LoRA for fast inference at 768x1344 (9:16 vertical) +- Style-specific prompt engineering: each style has a concrete setting, quality suffix, and LLM guidance +- The full pipeline (lyrics → images → video, synced to beat) is a novel creative application + +--- + +## Hugging Face Demo Plan + +- **Framework:** Gradio +- **Interface:** + 1. User uploads a song (mp3/wav) + 2. User selects visual style (dropdown) + 3. User clicks "Generate" + 4. Progress bar showing pipeline stages + 5. 
Output: playable video + download link +- **Compute:** Hugging Face Spaces with ZeroGPU (H200 MIG slice, ~24GB) +- **Constraints:** Keep total inference under ~5 minutes for demo usability +- **Lyrics overlay:** Fixed to Bebas Neue font + warm white (#FFF7D4) colour. This combination looks best across all visual styles and gives the best out-of-the-box results. Font/colour selection UI code is kept commented out in `app.py` for future re-enablement. + +--- + +## API Usage & Spend Controls + +### Why APIs instead of on-device models for non-visual components + +The pipeline uses external APIs for three non-visual tasks: stem separation (LALAL.AI), lyrics-to-prompt expansion (Anthropic Claude), and optionally image/video generation during development (fal.ai). + +Running these models on-device was considered but rejected for two reasons: + +1. **Quality**: On-device alternatives produce substantially worse results. For example, Demucs (`htdemucs_ft`) extracted only 1 word ("um") from a vocal stem where LALAL.AI's Andromeda model correctly extracted all 15 words across 3 repetitions. The quality gap directly impacts downstream steps — bad vocals mean bad lyrics mean bad prompts mean bad images. + +2. **Scope**: The coursework focus is on **visual generative AI** — specifically training a custom style LoRA on SDXL and building a multi-modal composition pipeline. Implementing production-quality ASR, source separation, or LLM inference on-device would explode the project scope without contributing to the core visual AI objective. + +### Spend limits + +All API keys are stored as HF Space secrets on the supervisor's shared account. 
To prevent runaway costs: + +- **Anthropic**: Spend limit configured in the Anthropic Console (usage dashboard → limits) +- **LALAL.AI**: Minute-based quota tied to the license tier — processing stops when minutes are exhausted +- **fal.ai**: Only used during local development, not on the deployed Space (Spaces use on-device Wan 2.1 + SDXL) + +### Deployment secret summary + +| Secret | Used for | Where | +|--------|----------|-------| +| `ANTHROPIC_API_KEY` | Prompt generation (Claude) | Both local + Spaces | +| `LALAL_KEY` | Vocal separation (LALAL.AI) | Both local + Spaces | +| `FAL_KEY` | Image + video gen (fal.ai) | Local dev only | + +--- + +## Post-MVP Ideas + +- **Smart clip selection:** Auto-detect the best ~15s of a full song (e.g. 5-10s build-up before the drop + 5-10s of the drop). Use energy analysis, onset density, or structural segmentation to find the drop. For MVP, the user pre-trims the input to the desired 15s. +- Beat-synced crossfades instead of hard cuts +- More community style LoRAs (e.g. 
retro anime, oil painting) +- Direct text-to-video generation (skip image step) if video models improve enough + +--- + +## File Structure + +``` +CW/ +├── PLAN.md +├── app.py # Gradio demo entry point +├── requirements.txt +├── packages.txt # System deps for HF Spaces (ffmpeg) +├── src/ +│ ├── __init__.py +│ ├── stem_separator.py # LALAL.AI API wrapper +│ ├── lyrics_extractor.py # WhisperX wrapper +│ ├── beat_detector.py # madmom RNN+DBN beat detection +│ ├── segmenter.py # lyrics-to-beat mapping +│ ├── prompt_generator.py # Two-call LLM prompt generation (Claude Sonnet 4.6) +│ ├── styles.py # Style registry (LoRA sources, settings, quality suffixes) +│ ├── image_generator.py # SDXL + Hyper-SD + style LoRA (768x1344 vertical) +│ ├── video_generator_api.py # Wan 2.1 I2V via fal.ai (local dev) +│ ├── video_generator_hf.py # Wan 2.1 I2V on ZeroGPU with FP8 (HF Spaces deployment) +│ └── assembler.py # FFmpeg stitching + lyrics overlay +├── train_lora.py # LoRA training script (run on Colab T4, ~1-2 hours) +├── fonts/ # Fonts for lyrics overlay +├── lora_training_data/ # Curated style images for LoRA training +├── data/ # All pipeline output (one folder per song) +│ └── / +│ ├── run_001/ # Each pipeline run gets its own directory +│ │ ├── stems/ +│ │ │ ├── drums.wav +│ │ │ └── vocals.wav +│ │ ├── lyrics.json +│ │ ├── beats.json +│ │ ├── segments.json # Enriched with prompt, video_prompt, scene, etc. 
+│ │ ├── images/ +│ │ ├── clips/ +│ │ └── output/ +│ └── run_002/ # Re-running creates a new run, no overwrites +└── poster/ # Poster assets & PDF +``` diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..af0126dc93a932397a324548199d1d5c931951be --- /dev/null +++ b/README.md @@ -0,0 +1,62 @@ +--- +title: SyncAI +emoji: 🎵 +colorFrom: indigo +colorTo: pink +sdk: gradio +sdk_version: "6.8.0" +app_file: app.py +pinned: false +short_description: AI Music Ads Generator +--- + +# SyncAI — AI Music Video Generator + +Generate beat-synced music video ads from a song clip. Upload ~15 seconds of audio, pick a visual style, and SyncAI produces a fully assembled vertical video with AI-generated visuals cut to the beat. + +## How It Works + +``` +Song (audio file) + ├─► Stem Separation (LALAL.AI) → Vocals + Drums + ├─► Lyrics Extraction (WhisperX) → Word-level timestamps + ├─► Beat Detection (madmom RNN + DBN) → Beat timestamps + drop detection + ├─► Segmentation → Lyrics mapped to beat intervals + ├─► Prompt Generation (Claude Sonnet 4.6) → Image + video motion prompts + ├─► Image Generation (SDXL + Hyper-SD + style LoRA) → 768x1344 images + ├─► Image-to-Video (Wan 2.1 14B) → Animated clips + └─► Assembly (FFmpeg) → Beat-synced video with lyrics overlay +``` + +## Visual Styles + +Each style applies a different LoRA to SDXL and sets a unique scene world for the LLM prompt generator. 
The Sunset Coastal Drive LoRA was custom-trained for this project; the others are community LoRAs from HuggingFace Hub: + +| Style | LoRA | Setting | +|-------|------|---------| +| **Sunset Coastal Drive** | Custom-trained (`samuelsattler/warm-sunset-lora`) | Car cruising a coastal highway at golden hour | +| **Rainy City Night** | Film grain (`artificialguybr/filmgrain-redmond`) | Walking rain-soaked city streets after dark | +| **Cyberpunk** | Cyberpunk 2077 (`jbilcke-hf/sdxl-cyberpunk-2077`) | Neon-drenched futuristic megacity at night | +| **Watercolour Harbour** | Watercolor (`ostris/watercolor_style_lora_sdxl`) | Coastal fishing village during a storm | + +## Assembly Features + +- **Dynamic pacing**: 4-beat cuts before the drop, 2-beat cuts after for energy +- **Clip shuffling**: Each clip used twice (first/second half) in randomised order for visual variety +- **Ken Burns**: Alternating zoom in/out on every cut +- **Lyrics overlay**: Word-level timing with gap closing +- **Cover art overlay**: Album art + Spotify badge appear from the drop onwards +- **Reshuffle**: Re-run assembly with a new random clip order without regenerating + +## Tech Stack + +| Component | Tool | +|-----------|------| +| Stem separation | LALAL.AI API (Andromeda) | +| Lyrics (ASR) | WhisperX (large-v2 + wav2vec2) | +| Beat detection | madmom (RNN + DBN) | +| Prompt generation | Claude Sonnet 4.6 (Anthropic API) | +| Image generation | SDXL + Hyper-SD 8-step + style LoRA | +| Image-to-video | Wan 2.1 14B (ZeroGPU with FP8) | +| Video assembly | FFmpeg | +| UI | Gradio | diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2daea3b632519d230989a14ece29a1e6b9a96477 --- /dev/null +++ b/app.py @@ -0,0 +1,704 @@ +"""SyncAI — AI Music Video Generator. 
+ +Gradio app that orchestrates the full pipeline: +Song → Stems → Lyrics + Beats → Segments → Prompts → Images → Video → Assembly + +Works locally (fal.ai API for video) and on HuggingFace Spaces (on-device Wan 2.1). +""" + +import json +import os +import shutil +from pathlib import Path + +from dotenv import load_dotenv +load_dotenv() + +import gradio as gr +import torch + +# Lightweight imports only — heavy modules (whisperx, madmom, etc.) +# are lazy-imported inside generate() to keep the UI responsive. +from src.assembler import font_names, DEFAULT_FONT, DEFAULT_FONT_COLOR +from src.styles import style_names, get_style + +# --------------------------------------------------------------------------- +# Environment detection +# --------------------------------------------------------------------------- + +IS_SPACES = os.getenv("SPACE_ID") is not None + +if IS_SPACES: + import spaces + +INPUT_DIR = Path("input") +INPUT_DIR.mkdir(exist_ok=True) + + +# --------------------------------------------------------------------------- +# GPU-accelerated steps (decorated only on Spaces) +# --------------------------------------------------------------------------- + +def _generate_images(run_dir, style_name, progress_callback=None): + """Load SDXL pipeline, generate all images, then unload to free VRAM.""" + if IS_SPACES: + from src.image_generator_hf import run as gen_images + else: + from src.image_generator_api import run as gen_images + gen_images(run_dir, style_name=style_name, progress_callback=progress_callback) + # Free VRAM for the video model + torch.cuda.empty_cache() + + +def _generate_videos(run_dir, progress_callback=None): + """Load Wan 2.1 pipeline, generate all video clips, then unload.""" + if IS_SPACES: + from src.video_generator_hf import run as gen_videos + gen_videos(run_dir, progress_callback=progress_callback) + # Unload Wan 2.1 to free VRAM + from src.video_generator_hf import unload + unload() + else: + from src.video_generator_api import run as 
gen_videos + gen_videos(run_dir, progress_callback=progress_callback) + + +# Apply @spaces.GPU decorator on Spaces +if IS_SPACES: + _generate_images = spaces.GPU(duration=300)(_generate_images) + _generate_videos = spaces.GPU(duration=3600)(_generate_videos) # up to 1h for ~12 clips + + +# --------------------------------------------------------------------------- +# Run discovery & step detection +# --------------------------------------------------------------------------- + +DATA_DIR = Path("data") + +STEPS = [ + "1. Stems", + "2. Lyrics", + "3. Beats", + "4. Segmentation", + "5. Prompts", + "6. Images", + "7. Videos", + "8. Assembly", +] + + +def _list_runs() -> list[str]: + """Find all existing run directories under data/.""" + if not DATA_DIR.exists(): + return [] + runs = [] + for song_dir in sorted(DATA_DIR.iterdir()): + if not song_dir.is_dir(): + continue + for run_dir in sorted(song_dir.glob("run_*")): + if run_dir.is_dir(): + runs.append(f"{song_dir.name}/{run_dir.name}") + return runs + + +def _detect_completed_steps(run_dir: Path) -> int: + """Return the number of the last fully completed step (0 = nothing done).""" + # Step 1: vocals + drums stems exist (LALAL.AI only extracts these two) + stems = run_dir / "stems" + for name in ["drums.wav", "vocals.wav"]: + if not (stems / name).exists(): + return 0 + + # Step 2: lyrics.json valid with at least 1 entry + lyrics_path = run_dir / "lyrics.json" + if not lyrics_path.exists(): + return 1 + try: + data = json.loads(lyrics_path.read_text()) + if not isinstance(data, list) or len(data) == 0: + return 1 + except (json.JSONDecodeError, OSError): + return 1 + + # Step 3: beats.json valid with at least 1 entry + beats_path = run_dir / "beats.json" + if not beats_path.exists(): + return 2 + try: + data = json.loads(beats_path.read_text()) + if not isinstance(data, list) or len(data) == 0: + return 2 + except (json.JSONDecodeError, OSError): + return 2 + + # Step 4: segments.json valid with at least 1 segment 
having start/end + seg_path = run_dir / "segments.json" + if not seg_path.exists(): + return 3 + try: + segments = json.loads(seg_path.read_text()) + if not isinstance(segments, list) or len(segments) == 0: + return 3 + if "start" not in segments[0] or "end" not in segments[0]: + return 3 + except (json.JSONDecodeError, OSError): + return 3 + + # Step 5: every segment has a non-empty "prompt" key + try: + if not all(seg.get("prompt") for seg in segments): + return 4 + except Exception: + return 4 + + n_segments = len(segments) + + # Step 6: exactly N image files exist + for i in range(1, n_segments + 1): + if not (run_dir / "images" / f"segment_{i:03d}.png").exists(): + return 5 + + # Step 7: exactly N clip files exist + for i in range(1, n_segments + 1): + if not (run_dir / "clips" / f"clip_{i:03d}.mp4").exists(): + return 6 + + # Step 8: final.mp4 exists with size > 0 + final = run_dir / "output" / "final.mp4" + if not final.exists() or final.stat().st_size == 0: + return 7 + + return 8 + + +def _get_startable_steps(run_dir: Path) -> list[str]: + """Return step names the user can start from (all prerequisites met).""" + completed = _detect_completed_steps(run_dir) + # Can start from any step up to completed+1 (the next incomplete step) + last_startable = min(completed + 1, 8) + return STEPS[:last_startable] # steps 1 through last_startable + + +def _on_run_mode_change(run_mode): + """Toggle visibility of audio upload vs resume controls.""" + is_resume = run_mode == "Resume Existing" + return ( + gr.update(visible=not is_resume), # audio_input + gr.update(visible=is_resume, choices=_list_runs()), # existing_run + gr.update(visible=is_resume, choices=[], value=None), # start_step + gr.update(visible=is_resume), # reuse_files + ) + + +def _on_run_selected(existing_run): + """Update step dropdown when a run is selected.""" + if not existing_run: + return gr.update(choices=[], value=None) + run_dir = DATA_DIR / existing_run + steps = _get_startable_steps(run_dir) + 
default = steps[-1] if steps else None + return gr.update(choices=steps, value=default) + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + +_COLOR_PRESETS = { + "Warm White": "#FFF7D4", + "White": "#FFFFFF", + "Red": "#FF3B30", + "Cyan": "#00E5FF", + "Gold": "#FFD700", + "Custom": None, +} + + +def generate(audio_file: str, style_name: str, cover_art: str | None, + run_mode: str, existing_run: str | None, start_step: str | None, + reuse_files: bool, progress=gr.Progress()): + """Run the SyncAI pipeline (full or resumed). + + Returns: + Path to the final video. + """ + font_name = DEFAULT_FONT + font_color = DEFAULT_FONT_COLOR + style = get_style(style_name) + is_resume = run_mode == "Resume Existing" + + if is_resume: + if not existing_run: + raise gr.Error("Please select an existing run.") + if not start_step: + raise gr.Error("Please select a step to start from.") + run_dir = DATA_DIR / existing_run + if not run_dir.exists(): + raise gr.Error(f"Run directory not found: {run_dir}") + step_num = int(start_step.split(".")[0]) + print(f"Resuming {existing_run} from step {step_num}") + + # Always clear assembly output (cheap to redo) + import shutil + out_dir = run_dir / "output" + if out_dir.exists(): + shutil.rmtree(out_dir) + # Also clear intermediate assembly artifacts + for d in ["clips_split", "clips_trimmed"]: + p = run_dir / d + if p.exists(): + shutil.rmtree(p) + + # If not reusing files, also clear images and video clips + if not reuse_files: + if step_num <= 6: + img_dir = run_dir / "images" + if img_dir.exists(): + shutil.rmtree(img_dir) + if step_num <= 7: + clips_dir = run_dir / "clips" + if clips_dir.exists(): + shutil.rmtree(clips_dir) + else: + if audio_file is None: + raise gr.Error("Please upload a song first.") + step_num = 1 + + import gc + + def _flush_memory(): + """Aggressively free memory between heavy ML steps.""" 
+ gc.collect() + if hasattr(torch, "mps") and torch.backends.mps.is_available(): + torch.mps.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # --- Step 1: Stem Separation --- + if step_num <= 1: + progress(0.0, desc="Separating stems...") + from src.stem_separator import separate_stems + # For resume: find original audio in song dir; for new run: use uploaded file + if is_resume: + song_dir = run_dir.parent + audio_candidates = list(song_dir.glob("*.wav")) + list(song_dir.glob("*.mp3")) + \ + list(song_dir.glob("*.flac")) + list(song_dir.glob("*.m4a")) + if not audio_candidates: + raise gr.Error(f"No audio file found in {song_dir}") + result = separate_stems(audio_candidates[0], output_dir=run_dir / "stems") + else: + result = separate_stems(Path(audio_file)) + run_dir = result["run_dir"] + print(f"Run directory: {run_dir}") + + # --- Step 2: Lyrics Extraction --- + if step_num <= 2: + progress(0.15, desc="Extracting lyrics...") + from src.lyrics_extractor import extract_lyrics + vocals_path = run_dir / "stems" / "vocals.wav" + extract_lyrics(vocals_path) + del extract_lyrics + _flush_memory() + + # --- Step 3: Beat Detection --- + if step_num <= 3: + progress(0.25, desc="Detecting beats...") + from src.beat_detector import run as detect_beats + drums_path = run_dir / "stems" / "drums.wav" + detect_beats(drums_path) + del detect_beats + _flush_memory() + + # --- Step 4: Segmentation --- + if step_num <= 4: + progress(0.35, desc="Segmenting lyrics to beats...") + from src.segmenter import run as segment_lyrics + segment_lyrics(run_dir) + + # --- Step 5: Prompt Generation --- + if step_num <= 5: + progress(0.40, desc="Generating prompts...") + from src.prompt_generator import run as generate_prompts + generate_prompts(run_dir, style_description=style["description"], + image_prompt_guidance=style.get("image_prompt_guidance", ""), + quality_suffix=style.get("quality_suffix", "")) + + # --- Step 6: Image Generation --- + if step_num <= 6: + 
progress(0.50, desc="Generating images...") + def _img_progress(i, total): + progress(0.50 + 0.20 * (i / total), desc=f"Generating images ({i}/{total})...") + _generate_images(run_dir, style_name, progress_callback=_img_progress) + + # --- Step 7: Video Generation --- + if step_num <= 7: + progress(0.70, desc="Generating video clips...") + def _vid_progress(i, total): + progress(0.70 + 0.20 * (i / total), desc=f"Generating videos ({i}/{total})...") + _generate_videos(run_dir, progress_callback=_vid_progress) + + # --- Step 8: Assembly --- + progress(0.90, desc="Assembling final video...") + from src.assembler import run as assemble_video + final_path = assemble_video(run_dir, font_name=font_name, font_color=font_color, + cover_art=cover_art) + + progress(1.0, desc="Done!") + return str(final_path), str(run_dir), gr.update(visible=True) + + +def reshuffle(run_dir_str: str, cover_art: str | None, progress=gr.Progress()): + """Re-run only the assembly step with a new random shuffle.""" + if not run_dir_str: + raise gr.Error("No previous run to reshuffle. 
Generate a video first.") + + run_dir = Path(run_dir_str) + if not run_dir.exists(): + raise gr.Error(f"Run directory not found: {run_dir}") + + font_name = DEFAULT_FONT + font_color = DEFAULT_FONT_COLOR + + # Clear assembly artifacts + for d in ["clips_trimmed", "output"]: + p = run_dir / d + if p.exists(): + shutil.rmtree(p) + + progress(0.2, desc="Reshuffling and assembling...") + from src.assembler import run as assemble_video + final_path = assemble_video(run_dir, font_name=font_name, font_color=font_color, + cover_art=cover_art) + + progress(1.0, desc="Done!") + return str(final_path) + + +# --------------------------------------------------------------------------- +# Gradio UI +# --------------------------------------------------------------------------- + +_custom_css = """ +/* Load Google Fonts for dropdown preview */ +@import url('https://fonts.googleapis.com/css2?family=Bebas+Neue&family=Teko:wght@700&family=Russo+One&family=Staatliches&display=swap'); +/* Style font dropdown options in their actual font */ +#font-dropdown [data-value="Bebas Neue"], #font-dropdown li:nth-child(1) { font-family: 'Bebas Neue', sans-serif !important; } +#font-dropdown [data-value="Teko"], #font-dropdown li:nth-child(2) { font-family: 'Teko', sans-serif !important; font-weight: 700 !important; } +#font-dropdown [data-value="Russo One"], #font-dropdown li:nth-child(3) { font-family: 'Russo One', sans-serif !important; } +#font-dropdown [data-value="Staatliches"], #font-dropdown li:nth-child(4) { font-family: 'Staatliches', sans-serif !important; } +#font-dropdown ul li { font-size: 16px !important; } +/* Remove white border on color picker */ +input[type="color"], +input[type="color"]:focus, +input[type="color"]:hover, +.gr-color-picker input, +div[data-testid="color-picker"] input, +div[data-testid="color-picker"] div, +.color-picker input { + border: none !important; + outline: none !important; + box-shadow: none !important; + background: transparent !important; +} +/* 
Color swatch buttons */ +.color-swatch { + min-width: 36px !important; + max-width: 36px !important; + height: 36px !important; + padding: 0 !important; + border-radius: 6px !important; + border: 2px solid transparent !important; + cursor: pointer !important; + box-shadow: none !important; + transition: border-color 0.15s ease !important; +} +.color-swatch:hover { + border-color: rgba(255,255,255,0.5) !important; +} +.color-swatch.selected { + border-color: #fff !important; +} +#swatch-0 { background: #FFF7D4 !important; } +#swatch-1 { background: #FFFFFF !important; } +#swatch-2 { background: #FF3B30 !important; } +#swatch-3 { background: #00E5FF !important; } +#swatch-4 { background: #FFD700 !important; } +#swatch-custom { + background: conic-gradient(red, yellow, lime, aqua, blue, magenta, red); + min-width: 36px !important; + max-width: 36px !important; + height: 36px !important; + padding: 0 !important; + border-radius: 50% !important; + border: 2px solid transparent !important; + cursor: pointer !important; + box-shadow: none !important; +} +#swatch-custom:hover { + border-color: rgba(255,255,255,0.5) !important; +} +#swatch-custom.selected { + border-color: #fff !important; +} +/* Custom color picker — hide all labels/headers */ +#custom-color-picker .label-wrap, +#custom-color-picker label, +#custom-color-picker .block-label, +#custom-color-picker span.svelte-1gfkn6j, +#custom-color-picker > span { display: none !important; } +#custom-color-picker, +#custom-color-picker fieldset, +fieldset#custom-color-picker { + min-height: 0 !important; + padding: 0 !important; + border: none !important; + background: #272727 !important; + display: flex !important; + justify-content: center !important; +} +/* Force dark background on ALL descendants of the color picker */ +#custom-color-picker *, +#custom-color-picker div, +#custom-color-picker fieldset, +#custom-color-picker .block, +#custom-color-picker .wrap { + background-color: #272727 !important; + border-color: 
#3a3a3a !important; +} +/* Hide the trigger swatch, keep popup functional */ +#custom-color-picker .wrap { height: 0 !important; overflow: visible !important; } +#custom-color-picker button { height: 0 !important; width: 0 !important; padding: 0 !important; border: none !important; overflow: visible !important; } +/* Hide Hex/RGB/HSL mode switcher buttons */ +button.svelte-nbn1m9 { display: none !important; } +/* Force all group/panel backgrounds to match */ +.gr-group, .gr-block, .gr-panel, .group, .panel, +div[class*="group"], div[class*="panel"] { + background: #272727 !important; +} +/* Color row layout — centered in box */ +#color-row, #color-row.svelte-7xavid { + gap: 6px !important; + align-items: center !important; + justify-content: center !important; + padding: 10px 0 6px !important; + background: #272727 !important; + background-color: #272727 !important; +} +""" + +_dark_theme = gr.themes.Soft( + primary_hue=gr.themes.Color( + c50="#02C160", c100="rgba(2,193,96,0.2)", c200="#02C160", + c300="rgba(2,193,96,0.32)", c400="rgba(2,193,96,0.32)", + c500="rgba(2,193,96,1.0)", c600="rgba(2,193,96,1.0)", + c700="rgba(2,193,96,0.32)", c800="rgba(2,193,96,0.32)", + c900="#02C160", c950="#02C160", + ), + secondary_hue=gr.themes.Color( + c50="#576b95", c100="#576b95", c200="#576b95", c300="#576b95", + c400="#576b95", c500="#576b95", c600="#576b95", c700="#576b95", + c800="#576b95", c900="#576b95", c950="#576b95", + ), + neutral_hue=gr.themes.Color( + c50="#2a2a2a", c100="#313131", c200="#3a3a3a", c300="#4a4a4a", + c400="#B2B2B2", c500="#808080", c600="#636363", c700="#515151", + c800="#393939", c900="#272727", c950="#171717", + ), + font=[gr.themes.GoogleFont("Montserrat"), "ui-sans-serif", "system-ui", "sans-serif"], + font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", "monospace"], +).set( + body_background_fill="#171717", + body_background_fill_dark="#171717", + body_text_color="#e0e0e0", + body_text_color_dark="#e0e0e0", + 
body_text_color_subdued="#808080", + body_text_color_subdued_dark="#808080", + block_background_fill="#272727", + block_background_fill_dark="#272727", + block_border_color="#3a3a3a", + block_border_color_dark="#3a3a3a", + block_border_width="0px", + block_label_background_fill="rgba(2,193,96,0.2)", + block_label_background_fill_dark="rgba(2,193,96,0.2)", + block_label_text_color="rgba(2,193,96,1.0)", + block_label_text_color_dark="rgba(2,193,96,1.0)", + block_title_background_fill="rgba(2,193,96,0.2)", + block_title_text_color="rgba(2,193,96,1.0)", + block_title_text_color_dark="rgba(2,193,96,1.0)", + input_background_fill="#313131", + input_background_fill_dark="#313131", + input_border_color="#3a3a3a", + input_border_color_dark="#3a3a3a", + input_border_width="0px", + button_primary_background_fill="#06AE56", + button_primary_background_fill_dark="#06AE56", + button_primary_background_fill_hover="#07C863", + button_primary_background_fill_hover_dark="#07C863", + button_primary_border_color="#06AE56", + button_primary_border_color_dark="#06AE56", + button_primary_text_color="#FFFFFF", + button_primary_text_color_dark="#FFFFFF", + button_secondary_background_fill="#2B2B2B", + button_secondary_background_fill_dark="#2B2B2B", + button_secondary_text_color="#FFFFFF", + button_secondary_text_color_dark="#FFFFFF", + background_fill_primary="#171717", + background_fill_primary_dark="#171717", + background_fill_secondary="#272727", + background_fill_secondary_dark="#272727", + border_color_primary="#3a3a3a", + border_color_primary_dark="#3a3a3a", + panel_background_fill="#272727", + panel_background_fill_dark="#272727", + panel_border_color="#3a3a3a", + panel_border_color_dark="#3a3a3a", + shadow_drop="0 1px 4px 0 rgb(0 0 0 / 0.3)", + shadow_drop_lg="0 2px 5px 0 rgb(0 0 0 / 0.3)", + color_accent_soft="#272727", + color_accent_soft_dark="#272727", +) + +with gr.Blocks( + title="SyncAI", + theme=_dark_theme, + css=_custom_css, +) as demo: + gr.Markdown("# SyncAI\n### AI 
Music Ads Generator") + gr.Markdown( + "Upload a song (~15s clip), pick a visual style, and generate " + "a beat-synced music video ad." + ) + + # --- Build example song/cover art maps --- + _EXAMPLES_DIR = Path("examples") + _COVER_ART_MAP = { + "Gone": "Gone.jpg", + "Cant find myself": "Cant find myself.png", + "The more I do": "The more I do.png", + "House of House": "House of House.png", + } + _example_songs = {} + _example_covers = {} + if _EXAMPLES_DIR.exists(): + for wav in sorted(_EXAMPLES_DIR.glob("*.wav")): + _example_songs[wav.stem] = str(wav) + cover_file = _COVER_ART_MAP.get(wav.stem, "") + cover_path = _EXAMPLES_DIR / cover_file + if cover_path.exists(): + _example_covers[wav.stem] = str(cover_path) + + def _on_example_song(song_name, cover_mode): + if not song_name: + return None, None + audio = _example_songs.get(song_name) + cover = _example_covers.get(song_name) if cover_mode == "With cover art" else None + return audio, cover + + with gr.Row(equal_height=True): + # --- Left: Song --- + with gr.Column(): + audio_input = gr.Audio( + label="Upload Song", + type="filepath", + sources=["upload"], + ) + with gr.Group(): + example_song = gr.Dropdown( + choices=list(_example_songs.keys()) if _example_songs else [], + value=None, + label="Or pick an example", + info="Pre-loaded ~15s song clips to try the pipeline", + ) + example_cover_mode = gr.Radio( + choices=["With cover art", "Without cover art"], + value="With cover art", + show_label=False, + info="Include album artwork overlay from the drop onwards", + ) + + # --- Center: Cover art --- + with gr.Column(): + cover_art_input = gr.Image( + label="Cover Art (optional)", + type="filepath", + sources=["upload"], + ) + + # --- Right: Visual Style --- + with gr.Column(): + style_dropdown = gr.Dropdown( + choices=style_names(), + value="Sunset Coastal Drive", + label="Visual Style", + info="LoRA style applied to generated images", + ) + + # --- Resume (dev only, below main row) --- + with gr.Row(visible=not 
IS_SPACES): + with gr.Column(): + with gr.Group(): + run_mode = gr.Radio( + choices=["New Run", "Resume Existing"], + value="New Run", + label="Run Mode", + ) + existing_run = gr.Dropdown( + choices=_list_runs(), + label="Existing Run", + visible=False, + ) + start_step = gr.Dropdown( + choices=[], + label="Start From Step", + visible=False, + ) + reuse_files = gr.Checkbox( + value=True, + label="Reuse existing images & videos", + info="Uncheck to regenerate images and video clips", + visible=False, + ) + + generate_btn = gr.Button("Generate Video", variant="primary") + video_output = gr.Video(label="Generated Music Video") + reshuffle_btn = gr.Button("Reshuffle", variant="secondary", visible=False) + last_run_dir = gr.State(value="") + + # --- Event handlers --- + example_song.change( + fn=_on_example_song, + inputs=[example_song, example_cover_mode], + outputs=[audio_input, cover_art_input], + ) + example_cover_mode.change( + fn=_on_example_song, + inputs=[example_song, example_cover_mode], + outputs=[audio_input, cover_art_input], + ) + + run_mode.change( + fn=_on_run_mode_change, + inputs=run_mode, + outputs=[audio_input, existing_run, start_step, reuse_files], + ) + existing_run.change( + fn=_on_run_selected, + inputs=existing_run, + outputs=start_step, + ) + + generate_btn.click( + fn=generate, + inputs=[audio_input, style_dropdown, + cover_art_input, run_mode, existing_run, start_step, reuse_files], + outputs=[video_output, last_run_dir, reshuffle_btn], + ) + reshuffle_btn.click( + fn=reshuffle, + inputs=[last_run_dir, cover_art_input], + outputs=video_output, + ) + + +if __name__ == "__main__": + demo.launch() diff --git a/assets/spotify_badge.png b/assets/spotify_badge.png new file mode 100644 index 0000000000000000000000000000000000000000..a4c84dd964405b0cb05d0ab6267e512495c00dad Binary files /dev/null and b/assets/spotify_badge.png differ diff --git a/examples/Cant find myself.png b/examples/Cant find myself.png new file mode 100644 index 
0000000000000000000000000000000000000000..e9086d7e096faf0ba8858860b467998621524e3e --- /dev/null +++ b/examples/Cant find myself.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04375305775722b36f1f00be122192b8357e15e927b9f7fcd5363385c44b93ff +size 5588455 diff --git a/examples/Cant find myself.wav b/examples/Cant find myself.wav new file mode 100644 index 0000000000000000000000000000000000000000..e9b4e1a4f6ced9255832aefb6ebe453151cd65a9 --- /dev/null +++ b/examples/Cant find myself.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04918abd40a767d17f825231d63a447a237571debbdb815c6353dd08099b256e +size 4831012 diff --git a/examples/Gone.jpg b/examples/Gone.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cecad838e2b59bd16dc05d248722e0eb85e574d7 --- /dev/null +++ b/examples/Gone.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f389fae99b920dd27448eeccc45178458c2b239833ec97c8ccc31adfda77269 +size 1013100 diff --git a/examples/Gone.wav b/examples/Gone.wav new file mode 100644 index 0000000000000000000000000000000000000000..3880576841597d2fde62b4325916b1cf0cc6136c --- /dev/null +++ b/examples/Gone.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04dd7abeed8b2ad85e4afe7a2af27140ecca5de568ef223cebb91342997f1e9e +size 3072044 diff --git a/examples/House of House.png b/examples/House of House.png new file mode 100644 index 0000000000000000000000000000000000000000..b21bbc6b42bf44e813e9e7e77652d2132bd837c8 --- /dev/null +++ b/examples/House of House.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aab03a79e76483f5e9c2cbc0ba15307d2c7399f7d86d64e04dba7d9f25e0113c +size 8074662 diff --git a/examples/House of House.wav b/examples/House of House.wav new file mode 100644 index 0000000000000000000000000000000000000000..2d8b6aebbbc20bd1c7c5a0c49a0f639c33d91ee5 --- /dev/null +++ b/examples/House of House.wav @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:aa5f98c25a23e48baac44ba73a27041674a5634577c7a75ec57b01901e812faf +size 2126816 diff --git a/examples/The more I do.png b/examples/The more I do.png new file mode 100644 index 0000000000000000000000000000000000000000..948650fe9a978d9e47311aa26f427d8687ff2026 --- /dev/null +++ b/examples/The more I do.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a353a6b18f333e5681a02029352e4d6187b009f7480f095e31dc9d5f46b428a8 +size 7218052 diff --git a/examples/The more I do.wav b/examples/The more I do.wav new file mode 100644 index 0000000000000000000000000000000000000000..7021865fcf37b6254979782c8bf83e2817972af7 --- /dev/null +++ b/examples/The more I do.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:69f821577316e9159ef0c9b695fac30b8af8ea35202698603e4059aaa98709e3 +size 2925760 diff --git a/fonts/Anton-Regular.ttf b/fonts/Anton-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..23d0685a695a225689a9bfdbfb34200fb7653397 --- /dev/null +++ b/fonts/Anton-Regular.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4ba3a92350ebb031da0cb47630ac49eb265082ca1bc0450442f4a83ab947cab +size 170812 diff --git a/fonts/BebasNeue-Regular.ttf b/fonts/BebasNeue-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..c328c6e08b20a20a1de47d823e007ee73812a438 Binary files /dev/null and b/fonts/BebasNeue-Regular.ttf differ diff --git a/fonts/Montserrat-Bold.ttf b/fonts/Montserrat-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..96a16aaefef8ad264f11637d05a24bb6a47b9e78 --- /dev/null +++ b/fonts/Montserrat-Bold.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f7b311b2f3279e4eef9b2f968bcdbab6e28f4daeb1f049f4f278a902bcd82f7 +size 744936 diff --git a/fonts/Oswald-Regular.ttf b/fonts/Oswald-Regular.ttf new file mode 100644 index 
0000000000000000000000000000000000000000..3e1418640b7bf19d54f84bd4bec083dfcff29443 --- /dev/null +++ b/fonts/Oswald-Regular.ttf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5b38c246e255a12f5712d640d56bcced0472466fc68983d2d0410ec0457c2817 +size 172088 diff --git a/fonts/RussoOne-Regular.ttf b/fonts/RussoOne-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..ba837270804167268bcd395a8a063ac26a77b72c Binary files /dev/null and b/fonts/RussoOne-Regular.ttf differ diff --git a/fonts/Staatliches-Regular.ttf b/fonts/Staatliches-Regular.ttf new file mode 100644 index 0000000000000000000000000000000000000000..9795b3a5734975b4d88fb24b2e3661f52ab1e498 Binary files /dev/null and b/fonts/Staatliches-Regular.ttf differ diff --git a/fonts/Teko-Bold.ttf b/fonts/Teko-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..ff846362b44a7228e67b6d3800800ac9b25a67eb --- /dev/null +++ b/fonts/Teko-Bold.ttf @@ -0,0 +1,1469 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Page not found · GitHub · GitHub + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + +
+ Skip to content + + + + + + + + + + + +
+
+ + + + + + + + + + + + + + + + + + + +
+ +
+ + + + + + + + +
+ + + + + +
+ + + + + + + + + +
+
+ + + +
+
+ +
+
+ 404 “This is not the web page you are looking for” + + + + + + + + + + + + +
+
+ +
+
+ +
+ + +
+
+ +
+ +
+ +
+ + + + + + + + + + + + + + + + + + + + + + +
+
+
+ + + diff --git a/lora_training_data/.DS_Store b/lora_training_data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2564c001db84d1d1031205df629e5ae1be97929a Binary files /dev/null and b/lora_training_data/.DS_Store differ diff --git a/lora_training_data/metadata.jsonl b/lora_training_data/metadata.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..6a7437cc87c2cf46d8ea5ca67192d2f14346524a --- /dev/null +++ b/lora_training_data/metadata.jsonl @@ -0,0 +1,15 @@ +{"file_name": "pexels-anytiffng-2121455.jpg", "prompt": "a crowd of people relaxing at an outdoor beach bar with string lights and palm trees, ocean in the background, in sks style"} +{"file_name": "pexels-artemmeletov-9201316.jpg", "prompt": "a vast field of sunflowers stretching to the horizon with a snow-capped mountain in the distance, in sks style"} +{"file_name": "pexels-ekrulila-6536235.jpg", "prompt": "the interior of a cozy bookshop with wooden shelves full of books, records displayed on a table, warm lamp lighting, in sks style"} +{"file_name": "pexels-helenalopes-1959053.jpg", "prompt": "a golden retriever standing on a dusty rural path beside a wooden fence, backlit with warm light, in sks style"} +{"file_name": "pexels-jerusaemm-2905514.jpg", "prompt": "a city skyline seen from a rooftop through a glass canopy, high-rise buildings silhouetted against a colorful sky, in sks style"} +{"file_name": "pexels-kovyrina-1600139.jpg", "prompt": "a close-up of wheat stalks in a field, soft shallow depth of field with a low sun on the horizon, in sks style"} +{"file_name": "pexels-kyle-karbowski-109303118-9968067.jpg", "prompt": "a sand dune with rippled wind patterns and a sharp ridge line, desert landscape under a clear sky, in sks style"} +{"file_name": "pexels-lokmansevim-13627402.jpg", "prompt": "the interior of a vintage convertible car with a burgundy leather bench seat and steering wheel, warm tones, in sks style"} +{"file_name": 
"pexels-matthew-jesus-468170389-30227212.jpg", "prompt": "a city street with parked cars and traditional architecture, the sun setting behind buildings, shot from behind a metal pole, in sks style"} +{"file_name": "pexels-omer-hakki-49913894-7820946.jpg", "prompt": "a field of sunflowers in bloom seen from a low angle, overcast sky above, petals catching the last light, in sks style"} +{"file_name": "pexels-perspectivo-2048722386-29185675.jpg", "prompt": "a city skyline with skyscrapers and a river, a boat passing under a stone bridge, golden reflections on the water, in sks style"} +{"file_name": "pexels-pixabay-417059.jpg", "prompt": "a silhouette of a woman holding a camera standing next to a bicycle against a bright orange sky, in sks style"} +{"file_name": "pexels-pixabay-67566.jpg", "prompt": "a small boat resting on a sandy beach beside a leaning palm tree, dramatic clouds over the ocean, in sks style"} +{"file_name": "pexels-seyma-alkas-178198724-12858917.jpg", "prompt": "a glowing sun partially hidden behind silhouetted tree branches against a deep orange sky, in sks style"} +{"file_name": "pexels-todd-trapani-488382-1535162.jpg", "prompt": "a long wooden pier leading to a lighthouse at the end, calm ocean on both sides, warm light on the walkway, in sks style"} diff --git a/lora_training_data/pexels-anytiffng-2121455.jpg b/lora_training_data/pexels-anytiffng-2121455.jpg new file mode 100644 index 0000000000000000000000000000000000000000..432e4fcc749d3ca26375f3eb0a6a8077d381469d --- /dev/null +++ b/lora_training_data/pexels-anytiffng-2121455.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dacdf0d0ee0820819be12eb8e98ee5207c34650fbe2eab0c13037b6feaaf0a01 +size 1175793 diff --git a/lora_training_data/pexels-artemmeletov-9201316.jpg b/lora_training_data/pexels-artemmeletov-9201316.jpg new file mode 100644 index 0000000000000000000000000000000000000000..555d05a71142e38cfaa7491aff8e208b68e29222 --- /dev/null +++ 
b/lora_training_data/pexels-artemmeletov-9201316.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ecad960afc85937bfb24b7fd2007139171088c67ff89e1ca5b91ea92fd231b1 +size 1076282 diff --git a/lora_training_data/pexels-ekrulila-6536235.jpg b/lora_training_data/pexels-ekrulila-6536235.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ab919f7d0aa13a7dd28b0b883d1e6d5f1178b7b --- /dev/null +++ b/lora_training_data/pexels-ekrulila-6536235.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a801d9c50ef05462843e0ba2dd23befc3cb2f1d8bf9c81cbadc9659f727cd802 +size 2766538 diff --git a/lora_training_data/pexels-helenalopes-1959053.jpg b/lora_training_data/pexels-helenalopes-1959053.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a96bc2bc6f0df23916b534fa1fe18b42902357e7 --- /dev/null +++ b/lora_training_data/pexels-helenalopes-1959053.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4e139154110c7f11ec1f6d5676a93afdcdeec632e93d798b3fe9be177112ec30 +size 1685321 diff --git a/lora_training_data/pexels-jerusaemm-2905514.jpg b/lora_training_data/pexels-jerusaemm-2905514.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0c45ac883b6e7bf22c1546acf98143dd014e4f78 --- /dev/null +++ b/lora_training_data/pexels-jerusaemm-2905514.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a62c00cb8971dc1a639885186d359052371ead97347f36234037f5de82003bb +size 478281 diff --git a/lora_training_data/pexels-kovyrina-1600139.jpg b/lora_training_data/pexels-kovyrina-1600139.jpg new file mode 100644 index 0000000000000000000000000000000000000000..73ff773569e3a5a4e4255cace3ae1a035c66bd92 --- /dev/null +++ b/lora_training_data/pexels-kovyrina-1600139.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3084c917e0e7bfd02e5e8eb995ee7600209ceec660de11fcd31a47847ddf8864 +size 893861 diff --git 
a/lora_training_data/pexels-kyle-karbowski-109303118-9968067.jpg b/lora_training_data/pexels-kyle-karbowski-109303118-9968067.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f17a3997b82e05cb0c2c503d7e3da35604dc8cf0 --- /dev/null +++ b/lora_training_data/pexels-kyle-karbowski-109303118-9968067.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:834ef9d2579934e0f8707fdc4afb35efe72c15b23f37131e74b141ed726e6de8 +size 1741576 diff --git a/lora_training_data/pexels-lokmansevim-13627402.jpg b/lora_training_data/pexels-lokmansevim-13627402.jpg new file mode 100644 index 0000000000000000000000000000000000000000..80faf06dce704f768b18e0f62c82cf6ea967eacb --- /dev/null +++ b/lora_training_data/pexels-lokmansevim-13627402.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f54d924048e26905ba513ffa71f8360b3334b8325be589e4a78acacdabaf6b08 +size 4115719 diff --git a/lora_training_data/pexels-matthew-jesus-468170389-30227212.jpg b/lora_training_data/pexels-matthew-jesus-468170389-30227212.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a45d842ee5a61f56f22e108fdadba618382ad2e4 --- /dev/null +++ b/lora_training_data/pexels-matthew-jesus-468170389-30227212.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:935684a45bc24d237d6d48ecff810b0ec134b4ceeefb6da9cef7c4b3c83421af +size 952699 diff --git a/lora_training_data/pexels-omer-hakki-49913894-7820946.jpg b/lora_training_data/pexels-omer-hakki-49913894-7820946.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df64d5c29ad8339574d691fc6352c6a174758285 --- /dev/null +++ b/lora_training_data/pexels-omer-hakki-49913894-7820946.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b67fcb00b4d7c0bb1017ec9ea66049fcabb650585e8d0ca11e8339ba92835efc +size 1691233 diff --git a/lora_training_data/pexels-perspectivo-2048722386-29185675.jpg 
b/lora_training_data/pexels-perspectivo-2048722386-29185675.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f393d3419ca8a52ac4e1b39d279c64f9fdd2e472 --- /dev/null +++ b/lora_training_data/pexels-perspectivo-2048722386-29185675.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e5b95273838ce5568078fa006b0be48f3876a8725e526b36ad596ba213bfade +size 1556605 diff --git a/lora_training_data/pexels-pixabay-417059.jpg b/lora_training_data/pexels-pixabay-417059.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e76f3e09ade1b0791f72b3f8ec5f69eafc7b529e --- /dev/null +++ b/lora_training_data/pexels-pixabay-417059.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:baddf0c567086124baac2a3f25fc12f1f325afc2ea93fa114a7c84f98cdebce0 +size 1365240 diff --git a/lora_training_data/pexels-pixabay-67566.jpg b/lora_training_data/pexels-pixabay-67566.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e91fe6119aba0bdeeb92cd562bdf6704073808f --- /dev/null +++ b/lora_training_data/pexels-pixabay-67566.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dc9572146ff9a32b29e55654a05ddcf57bdc143fcd23a08510b38159b67378b +size 1555798 diff --git a/lora_training_data/pexels-seyma-alkas-178198724-12858917.jpg b/lora_training_data/pexels-seyma-alkas-178198724-12858917.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fd46459fd7af1528400286700b869e786c868782 --- /dev/null +++ b/lora_training_data/pexels-seyma-alkas-178198724-12858917.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4bc19f08ce62987de4c3b6463eeab59c4f9b79b2f7726c48d9c81f11f39ad7f3 +size 258323 diff --git a/lora_training_data/pexels-todd-trapani-488382-1535162.jpg b/lora_training_data/pexels-todd-trapani-488382-1535162.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4b93c39abfe84af901f0af301f08df6a438b969e --- /dev/null 
+++ b/lora_training_data/pexels-todd-trapani-488382-1535162.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b1f06d18da0dc328cc2a4c7c2f94a394b995de839cd10e91f07c8af24aedbe +size 2585158 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8a43247c1e3ed6186b87271654c85631c72c761 --- /dev/null +++ b/packages.txt @@ -0,0 +1,2 @@ +ffmpeg +libsndfile1-dev diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..eb23bb228e2be455eba407171f17e01f2cafb473 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +whisperx +librosa +diffusers +transformers +accelerate +peft +huggingface_hub +torch +torchaudio +torchao +gradio +soundfile +scipy +anthropic +python-dotenv +requests +spaces +madmom @ git+https://github.com/CPJKU/madmom.git diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/assembler.py b/src/assembler.py new file mode 100644 index 0000000000000000000000000000000000000000..3b13bc9545341fa5b5057b6778ec3b437b70543a --- /dev/null +++ b/src/assembler.py @@ -0,0 +1,627 @@ +"""FFmpeg video stitching, clip splitting/shuffling, lyrics overlay. + +Takes generated video clips (one per 4-beat segment), splits each into +two halves, shuffles them with a distance constraint, builds a timeline +with dynamic pacing (4-beat cuts before the drop, 2-beat after), overlays +audio and lyrics text. 
+""" + +import json +import random +import subprocess +import tempfile +from pathlib import Path + + +def _get_audio_path(run_dir: Path) -> Path: + """Find the original audio file one level above the run directory.""" + song_dir = run_dir.parent + for ext in [".wav", ".mp3", ".flac", ".m4a"]: + candidates = list(song_dir.glob(f"*{ext}")) + if candidates: + return candidates[0] + raise FileNotFoundError(f"No audio file found in {song_dir}") + + +def _get_clip_duration(clip_path: Path) -> float: + """Get video duration in seconds using ffprobe.""" + result = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(clip_path), + ], capture_output=True, text=True, check=True) + return float(result.stdout.strip()) + + +def _get_clip_fps(clip_path: Path) -> float: + """Get video frame rate using ffprobe.""" + result = subprocess.run([ + "ffprobe", "-v", "error", + "-select_streams", "v:0", + "-show_entries", "stream=r_frame_rate", + "-of", "csv=p=0", + str(clip_path), + ], capture_output=True, text=True, check=True) + num, den = result.stdout.strip().split("/") + return int(num) / int(den) + + +def _trim_clip(clip_path: Path, start: float, duration: float, output_path: Path): + """Trim a video clip from a start point to a duration using FFmpeg.""" + cmd = [ + "ffmpeg", "-y", + "-ss", f"{start:.3f}", + "-i", str(clip_path), + "-t", f"{duration:.3f}", + "-c:v", "libx264", "-preset", "fast", + "-an", + str(output_path), + ] + subprocess.run(cmd, check=True, capture_output=True) + + +# --------------------------------------------------------------------------- +# Ken Burns effects — subtle pan/zoom applied per slot for added motion +# --------------------------------------------------------------------------- + +# Zoom factor: 8% total movement over the clip duration +_KB_ZOOM = 0.45 + +KEN_BURNS_EFFECTS = [ + "zoom_in", + "zoom_out", +] + + +def _ken_burns_filter( + effect: str, n_frames: int, width: int, height: int, +) 
-> str: + """Build an FFmpeg filter for a smooth Ken Burns zoom effect on video. + + Upscales the video 4x before applying zoompan with d=1 (one output + frame per input frame), then scales back to original size. The 4x + upscale makes integer rounding in zoompan negligible, eliminating + visible jitter. + """ + z = _KB_ZOOM + N = max(n_frames, 1) + W, H = width, height + # Upscale factor — higher = smoother but slower + UP = 8 + UW, UH = W * UP, H * UP + + if effect == "zoom_in": + zoom_expr = f"1+{z}*on/{N}" + elif effect == "zoom_out": + zoom_expr = f"1+{z}-{z}*on/{N}" + else: + return f"scale={W}:{H}" + + return ( + f"scale={UW}:{UH}:flags=lanczos," + f"zoompan=z='{zoom_expr}':" + f"x='iw/2-(iw/zoom/2)':y='ih/2-(ih/zoom/2)':" + f"d=1:s={UW}x{UH}," + f"scale={W}:{H}:flags=lanczos" + ) + + +def _get_clip_dimensions(clip_path: Path) -> tuple[int, int]: + """Get width and height of a video clip.""" + result = subprocess.run( + ["ffprobe", "-v", "error", "-select_streams", "v:0", + "-show_entries", "stream=width,height", + "-of", "csv=s=x:p=0", str(clip_path)], + capture_output=True, text=True, check=True, + ) + w, h = result.stdout.strip().split("x") + return int(w), int(h) + + +def _split_clip(clip_path: Path, clip_id: int) -> dict: + """Register a clip's two halves without pre-splitting. + + The "first" half plays from the start, the "second" half plays from + the end (offset back by the slot duration at trim time). This makes + the two halves maximally different — no fixed midpoint split. + + Returns dict with the original path and full duration for each half. + """ + duration = _get_clip_duration(clip_path) + + return { + "clip_id": clip_id, + "first": clip_path, + "second": clip_path, + "first_duration": duration, + "second_duration": duration, + } + + +def _build_sub_segments(segments: list[dict], drop_time: float | None) -> list[dict]: + """Build the final timeline of sub-segments. + + Before the drop: one slot per 4-beat segment. 
+ After the drop: each 4-beat segment splits into two 2-beat slots + using the beat timestamps stored in the segment. + """ + sub_segments = [] + + for seg in segments: + beats = seg.get("beats", [seg["start"], seg["end"]]) + is_after_drop = drop_time is not None and seg["start"] >= drop_time + + if is_after_drop and len(beats) >= 3: + # Split at midpoint beat (beat 2 of 4) + mid_idx = len(beats) // 2 + mid_time = beats[mid_idx] + + sub_segments.append({ + "start": seg["start"], + "end": mid_time, + "duration": round(mid_time - seg["start"], 3), + "lyrics": seg.get("lyrics", ""), + "parent_segment": seg["segment"], + }) + sub_segments.append({ + "start": mid_time, + "end": seg["end"], + "duration": round(seg["end"] - mid_time, 3), + "lyrics": "", # lyrics stay on the first half + "parent_segment": seg["segment"], + }) + else: + # Before drop: one slot for the full 4-beat segment + sub_segments.append({ + "start": seg["start"], + "end": seg["end"], + "duration": seg["duration"], + "lyrics": seg.get("lyrics", ""), + "parent_segment": seg["segment"], + }) + + return sub_segments + + +def _shuffle_with_distance(pool: list[tuple], n_slots: int) -> list[tuple]: + """Select n_slots sub-clips maximising clip diversity and spacing. + + Shuffles clip IDs once, then repeats that order to fill all slots. + First pass uses "first" halves, second pass uses "second" halves. + Same clip is always exactly n_clips positions apart — maximum spacing. + + Each item is (clip_id, half_label, path, duration). + """ + by_clip: dict[int, list[tuple]] = {} + for item in pool: + by_clip.setdefault(item[0], []).append(item) + + clip_ids = list(by_clip.keys()) + random.shuffle(clip_ids) + + # Repeat the shuffled order: [4,5,1,2,6,3, 4,5,1,2,6,3, ...] + result = [] + cycle = 0 + while len(result) < n_slots: + for cid in clip_ids: + if len(result) >= n_slots: + break + halves = by_clip[cid] + # First cycle uses "first" half, second cycle uses "second", etc. 
+ half_idx = cycle % len(halves) + result.append(halves[half_idx]) + cycle += 1 + + return result + + +# Font registry — maps display names to .ttf filenames in fonts/ +FONTS = { + "Bebas Neue": "BebasNeue-Regular.ttf", + "Teko": "Teko-Bold.ttf", + "Russo One": "RussoOne-Regular.ttf", + "Staatliches": "Staatliches-Regular.ttf", +} + +DEFAULT_FONT = "Bebas Neue" +DEFAULT_FONT_COLOR = "#FFF7D4" + +_FONTS_DIR = Path(__file__).resolve().parent.parent / "fonts" + + +def font_names() -> list[str]: + """Return list of available font display names.""" + return list(FONTS.keys()) + + +def _get_font_path(font_name: str) -> Path: + """Resolve a font display name to its .ttf file path.""" + filename = FONTS.get(font_name, FONTS[DEFAULT_FONT]) + return _FONTS_DIR / filename + + +_SPOTIFY_BADGE = Path(__file__).resolve().parent.parent / "assets" / "spotify_badge.png" + + +def _add_lyrics_overlay( + video_path: Path, + segments: list[dict], + output_path: Path, + audio_offset: float, + font_name: str = DEFAULT_FONT, + font_color: str = DEFAULT_FONT_COLOR, + cover_art: Path | None = None, + drop_time: float | None = None, + song_name: str = "", +): + """Add lyrics text and optional cover art overlay using FFmpeg filters.""" + font_path = _get_font_path(font_name) + + # If cover art provided, lyrics stop at the drop + lyrics_cutoff = None + if cover_art is not None and drop_time is not None: + lyrics_cutoff = drop_time + + # Collect all words with timestamps + all_words = [] + for seg in segments: + for word_info in seg.get("words", []): + word = word_info["word"].strip().lower() + if not word: + continue + w_start = word_info["start"] + w_end = word_info["end"] + # Skip words that start after the cutoff + if lyrics_cutoff is not None and w_start >= lyrics_cutoff: + continue + # Clamp end to cutoff for words that span the drop + if lyrics_cutoff is not None and w_end > lyrics_cutoff: + w_end = lyrics_cutoff + all_words.append({"word": word, "start": w_start, "end": w_end}) + + # 
def _add_lyrics_overlay(
    video_path: Path,
    segments: list[dict],
    output_path: Path,
    audio_offset: float,
    font_name: str = DEFAULT_FONT,
    font_color: str = DEFAULT_FONT_COLOR,
    cover_art: Path | None = None,
    drop_time: float | None = None,
    song_name: str = "",
):
    """Add lyrics text and optional cover art overlay using FFmpeg filters.

    Args:
        video_path: Input video (audio already muxed in).
        segments: Segment dicts; word timing is read from seg["words"],
            each entry having "word", "start", "end" keys. Timestamps are
            assumed to be seconds in audio time — presumably produced by
            the transcription step; TODO confirm against the caller.
        output_path: Destination video path.
        audio_offset: Start time of the first segment in the source audio;
            subtracted from word times to get video-relative times.
        font_name: Display name resolved via the FONTS registry.
        font_color: Color string passed straight to drawtext.
        cover_art: Optional cover image, shown from the drop onwards.
        drop_time: Drop timestamp (audio time); cover art requires it.
        song_name: Used to render '"<song>" out now!' under the art.
    """
    font_path = _get_font_path(font_name)

    # If cover art provided, lyrics stop at the drop
    lyrics_cutoff = None
    if cover_art is not None and drop_time is not None:
        lyrics_cutoff = drop_time

    # Collect all words with timestamps
    all_words = []
    for seg in segments:
        for word_info in seg.get("words", []):
            word = word_info["word"].strip().lower()
            if not word:
                continue
            w_start = word_info["start"]
            w_end = word_info["end"]
            # Skip words that start after the cutoff
            if lyrics_cutoff is not None and w_start >= lyrics_cutoff:
                continue
            # Clamp end to cutoff for words that span the drop
            if lyrics_cutoff is not None and w_end > lyrics_cutoff:
                w_end = lyrics_cutoff
            all_words.append({"word": word, "start": w_start, "end": w_end})

    # Close small gaps: both words meet in the middle of the gap, so the
    # on-screen text never flickers off between closely spoken words.
    gap_threshold = 0.5
    for i in range(len(all_words) - 1):
        gap = all_words[i + 1]["start"] - all_words[i]["end"]
        if 0 < gap < gap_threshold:
            mid = all_words[i]["end"] + gap / 2
            all_words[i]["end"] = mid
            all_words[i + 1]["start"] = mid

    # Build drawtext filter chain — one filter per word, timed to speech.
    # Characters special to drawtext/filtergraph syntax are escaped;
    # apostrophes are swapped for a typographic quote to avoid FFmpeg's
    # awkward single-quote escaping.
    drawtext_filters = []
    for w in all_words:
        escaped = (w["word"]
                   .replace("\\", "\\\\")
                   .replace("'", "\u2019")
                   .replace('"', '\\"')
                   .replace(":", "\\:")
                   .replace("%", "%%")
                   .replace("[", "\\[")
                   .replace("]", "\\]"))

        # Convert audio-time word stamps to video time.
        start = w["start"] - audio_offset
        end = w["end"] - audio_offset

        drawtext_filters.append(
            f"drawtext=text='{escaped}'"
            f":fontfile='{font_path}'"
            f":fontsize=36"
            f":fontcolor={font_color}"
            f":x=(w-text_w)/2:y=(h-text_h)/2"
            f":enable='between(t,{start:.3f},{end:.3f})'"
        )

    has_cover = cover_art is not None and drop_time is not None
    has_lyrics = len(drawtext_filters) > 0

    # Nothing to draw: pass the video through untouched (stream copy).
    if not has_cover and not has_lyrics:
        subprocess.run([
            "ffmpeg", "-y", "-i", str(video_path),
            "-c", "copy", str(output_path),
        ], check=True, capture_output=True)
        return

    if has_cover:
        # Cover art (and title/badge) appear from the drop onwards.
        drop_start = drop_time - audio_offset
        enable = f"enable='gte(t,{drop_start:.3f})'"

        # --- Cover art layout (change these to adjust) ---
        art_h = 270  # cover art height in px
        art_y_offset = 10  # px below center (positive = down)
        badge_h = 56  # spotify badge height in px

        # Probe video height for position calculations
        vid_h = int(subprocess.run([
            "ffprobe", "-v", "error", "-select_streams", "v:0",
            "-show_entries", "stream=height", "-of", "csv=p=0",
            str(video_path),
        ], capture_output=True, text=True, check=True).stdout.strip())
        art_center = vid_h / 2 + art_y_offset
        art_top = art_center - art_h / 2
        art_bottom = art_center + art_h / 2

        # Square = 9:16 crop region (side = vid_h * 9/16) — matches the
        # final crop applied later by the assembler.
        sq_side = vid_h * 9 / 16
        sq_top = (vid_h - sq_side) / 2
        sq_bottom = (vid_h + sq_side) / 2

        # Badge centered between square top and art top
        badge_center_y = (sq_top + art_top) / 2
        badge_y = int(badge_center_y - badge_h / 2)

        # Title centered between art bottom and square bottom
        title_center_y = int((art_bottom + sq_bottom) / 2)

        art_overlay_y = int(art_center - art_h / 2)

        # Filtergraph: input 1 = cover art, input 2 = spotify badge;
        # both scaled to fixed heights and overlaid gated by `enable`.
        parts = [
            f"[1:v]scale=-2:{art_h}:flags=lanczos[art]",
            f"[2:v]scale=-2:{badge_h}:flags=lanczos[badge]",
            f"[0:v][art]overlay=(W-w)/2:{art_overlay_y}:{enable}[v1]",
            f"[v1][badge]overlay=(W-w)/2:{badge_y}:{enable}",
        ]

        # Add song title drawtext below cover art (same escaping as words,
        # minus the bracket escapes).
        title_escaped = (song_name
                         .replace("\\", "\\\\")
                         .replace("'", "\u2019")
                         .replace('"', '\\"')
                         .replace(":", "\\:")
                         .replace("%", "%%"))
        title_text = f'\\"{title_escaped}\\" out now!'.lower()
        parts[-1] += (
            f",drawtext=text='{title_text}'"
            f":fontfile='{font_path}'"
            f":fontsize=40"
            f":fontcolor={font_color}"
            f":x=(w-text_w)/2:y={title_center_y}-text_h/2"
            f":{enable}"
        )

        # Chain drawtext lyrics filters onto the last filtergraph stage.
        if has_lyrics:
            parts[-1] += "," + ",".join(drawtext_filters)
        filter_chain = ";".join(parts)

        cmd = [
            "ffmpeg", "-y",
            "-i", str(video_path),
            "-i", str(cover_art),
            "-i", str(_SPOTIFY_BADGE),
            "-filter_complex", filter_chain,
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "copy",
            str(output_path),
        ]
        subprocess.run(cmd, check=True, capture_output=True)
    else:
        # Lyrics only, no cover art
        filter_chain = ",".join(drawtext_filters)
        subprocess.run([
            "ffmpeg", "-y",
            "-i", str(video_path),
            "-vf", filter_chain,
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "copy",
            str(output_path),
        ], check=True, capture_output=True)
def assemble(
    run_dir: str | Path,
    audio_path: str | Path | None = None,
    font_name: str = DEFAULT_FONT,
    font_color: str = DEFAULT_FONT_COLOR,
    cover_art: str | Path | None = None,
) -> Path:
    """Assemble final video with dynamic pacing, clip shuffling, and lyrics.

    Args:
        run_dir: Run directory containing clips/, segments.json, drop.json.
        audio_path: Path to the original audio. Auto-detected if None.
        font_name: Display name of the font for lyrics overlay.
        font_color: Hex color for lyrics text (e.g. '#FFF7D4').
        cover_art: Path to cover art image. Overlayed from the drop onwards.

    Returns:
        Path to the final video file.
    """
    run_dir = Path(run_dir)
    clips_dir = run_dir / "clips"
    output_dir = run_dir / "output"
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(run_dir / "segments.json") as f:
        segments = json.load(f)

    # Load drop time
    # NOTE(review): if drop.json exists but lacks "drop_time", the
    # f-string below formats None with :.3f and raises — confirm the
    # writer (beat_detector.run) always sets the key.
    drop_time = None
    drop_path = run_dir / "drop.json"
    if drop_path.exists():
        with open(drop_path) as f:
            drop_time = json.load(f).get("drop_time")
        print(f" Drop at {drop_time:.3f}s")
    else:
        print(" No drop detected — using uniform pacing")

    if audio_path is None:
        audio_path = _get_audio_path(run_dir)
    audio_path = Path(audio_path)

    # --- Step 1: Register clip halves (no pre-splitting needed) ---
    sub_clips = []  # list of (clip_id, half, path, full_duration)
    for seg in segments:
        idx = seg["segment"]
        clip_path = clips_dir / f"clip_{idx:03d}.mp4"
        if not clip_path.exists():
            # Missing clips are tolerated; the shuffle works with what exists.
            print(f" Warning: {clip_path.name} not found, skipping")
            continue

        halves = _split_clip(clip_path, idx)
        sub_clips.append((idx, "first", halves["first"], halves["first_duration"]))
        sub_clips.append((idx, "second", halves["second"], halves["second_duration"]))
        print(f" Registered {clip_path.name} ({halves['first_duration']:.1f}s)")

    if not sub_clips:
        raise FileNotFoundError(f"No clips found in {clips_dir}")

    # --- Step 2: Build sub-segment timeline ---
    sub_segments = _build_sub_segments(segments, drop_time)
    print(f" Timeline: {len(sub_segments)} slots "
          f"({len([s for s in sub_segments if s['duration'] < 1.5])} fast cuts)")

    # --- Step 3: Shuffle sub-clips into slots ---
    assigned = _shuffle_with_distance(sub_clips.copy(), n_slots=len(sub_segments))

    # --- Step 4: Frame-accurate trim of each sub-clip to slot duration ---
    # Detect FPS from first available sub-clip
    fps = _get_clip_fps(assigned[0][2])
    print(f" Source FPS: {fps}")

    trimmed_dir = run_dir / "clips_trimmed"
    trimmed_dir.mkdir(exist_ok=True)
    trimmed_paths = []

    # Get clip dimensions from the first available clip (all clips share resolution)
    clip_width, clip_height = _get_clip_dimensions(assigned[0][2])
    print(f" Clip resolution: {clip_width}x{clip_height}")

    # Track cumulative frames to prevent drift between cuts and beats:
    # each slot's frame count is derived from the *running* target time,
    # so per-slot rounding errors never accumulate.
    cumulative_frames = 0
    cumulative_target = 0.0

    for i, (sub_seg, (clip_id, half, clip_path, clip_dur)) in enumerate(
        zip(sub_segments, assigned)
    ):
        slot_dur = sub_seg["duration"]
        cumulative_target += min(slot_dur, clip_dur)
        target_frame = round(cumulative_target * fps)
        n_frames = max(1, target_frame - cumulative_frames)
        cumulative_frames = target_frame

        # "first" half starts from 0, "second" half starts from end minus slot duration
        # This makes the two halves show maximally different frames
        if half == "second":
            ss = max(0, clip_dur - slot_dur)
        else:
            ss = 0

        # Apply Ken Burns effect — cycle through effects per slot
        effect = KEN_BURNS_EFFECTS[i % len(KEN_BURNS_EFFECTS)]
        vf = _ken_burns_filter(effect, n_frames, clip_width, clip_height)

        trimmed_path = trimmed_dir / f"slot_{i:03d}.mp4"
        cmd = [
            "ffmpeg", "-y",
            "-ss", f"{ss:.3f}",
            "-i", str(clip_path),
            "-frames:v", str(n_frames),
            "-vf", vf,
            "-c:v", "libx264", "-preset", "fast",
            "-r", str(int(fps)),
            "-an",
            str(trimmed_path),
        ]
        subprocess.run(cmd, check=True, capture_output=True)
        trimmed_paths.append(trimmed_path)
        actual_dur = n_frames / fps
        print(f" Slot {i}: clip {clip_id} ({half}, ss={ss:.1f}s, {effect}) → "
              f"{n_frames}f/{actual_dur:.3f}s (target {slot_dur:.3f}s)")

    # --- Step 5: Concatenate (copy, no re-encode to preserve timing) ---
    # The concat list lives inside run_dir so relative resolution is stable;
    # delete=False because ffmpeg reads it after the handle closes.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".txt", delete=False, dir=str(run_dir)
    ) as f:
        for p in trimmed_paths:
            f.write(f"file '{p.resolve()}'\n")
        concat_list = f.name

    concat_path = output_dir / "video_only.mp4"
    subprocess.run([
        "ffmpeg", "-y",
        "-f", "concat", "-safe", "0",
        "-i", concat_list,
        "-c", "copy",
        str(concat_path),
    ], check=True, capture_output=True)

    # --- Step 6: Overlay audio ---
    # Audio is taken from the first segment's start so it lines up with slot 0.
    audio_start = segments[0]["start"]
    video_duration = cumulative_frames / fps  # actual frame-accurate duration

    with_audio_path = output_dir / "with_audio.mp4"
    subprocess.run([
        "ffmpeg", "-y",
        "-i", str(concat_path),
        "-ss", f"{audio_start:.3f}",
        "-i", str(audio_path),
        "-t", f"{video_duration:.3f}",
        "-c:v", "copy",
        "-c:a", "aac", "-b:a", "192k",
        "-map", "0:v:0", "-map", "1:a:0",
        "-shortest",
        str(with_audio_path),
    ], check=True, capture_output=True)

    # --- Step 7: Lyrics + cover art overlay ---
    overlay_path = output_dir / "with_overlay.mp4"
    cover_path = Path(cover_art) if cover_art else None
    song_name = run_dir.parent.name
    _add_lyrics_overlay(with_audio_path, segments, overlay_path, audio_start,
                        font_name=font_name, font_color=font_color,
                        cover_art=cover_path, drop_time=drop_time,
                        song_name=song_name)

    # --- Step 8: Crop to exact 9:16 ---
    # floor(...*2)/2 arithmetic keeps the crop width even (required by libx264).
    final_path = output_dir / "final.mp4"
    subprocess.run([
        "ffmpeg", "-y",
        "-i", str(overlay_path),
        "-vf", "crop=2*floor(ih*9/16/2):ih:(iw-2*floor(ih*9/16/2))/2:0",
        "-c:v", "libx264", "-preset", "fast",
        "-c:a", "copy",
        str(final_path),
    ], check=True, capture_output=True)

    # Clean up
    Path(concat_list).unlink(missing_ok=True)

    print(f"\nFinal video: {final_path}")
    print(f" Duration: {video_duration:.2f}s")
    print(f" Slots: {len(sub_segments)} ({len(segments)} original segments)")
    return final_path
# Bandpass band used to isolate kick-drum energy before beat tracking.
# NOTE: the effective band is 50-500 Hz — an earlier comment claimed
# "50-200 Hz", which did not match LOWPASS_CUTOFF below.
HIGHPASS_CUTOFF = 50
LOWPASS_CUTOFF = 500


def _bandpass_filter(input_path: Path) -> Path:
    """Apply a 50-500 Hz bandpass filter to isolate kick drum transients.

    Shells out to FFmpeg (highpass + lowpass filters) and writes the
    result to a fresh temporary WAV.

    Args:
        input_path: Path to the source audio file.

    Returns:
        Path to a temporary filtered WAV file. The caller is responsible
        for deleting it (detect_beats does so after computing activations).
    """
    # delete=False + close(): reserve the name, release the handle so
    # FFmpeg can (re)open and write it on all platforms.
    filtered = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    filtered.close()
    subprocess.run([
        "ffmpeg", "-y",
        "-i", str(input_path),
        "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}",
        str(filtered.name),
    ], check=True, capture_output=True)
    return Path(filtered.name)
def detect_beats(
    drum_stem_path: str | Path,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
    transition_lambda: float = 100,
    fps: int = 1000,
) -> np.ndarray:
    """Detect beat timestamps from a drum stem using madmom.

    Uses an ensemble of bidirectional LSTMs to produce a beat activation
    function, then a Dynamic Bayesian Network to decode beat positions.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        min_bpm: Minimum expected tempo. Narrow this if you know the song's
            approximate BPM for better accuracy.
        max_bpm: Maximum expected tempo.
        transition_lambda: Tempo smoothness — higher values penalise tempo
            changes more (100 = very steady, good for most pop/rock).
        fps: Frames per second for the DBN decoder. The RNN outputs at 100fps;
            higher values interpolate for finer timestamp resolution (1ms at 1000fps).

    Returns:
        1D numpy array of beat timestamps in seconds, sorted chronologically.
    """
    drum_stem_path = Path(drum_stem_path)

    # Step 0: Bandpass filter to isolate kick drum range (50-500 Hz per
    # the module cutoffs; an earlier comment here said 50-200 Hz)
    filtered_path = _bandpass_filter(drum_stem_path)

    # Step 1: RNN produces beat activation function (probability per frame at 100fps)
    act_proc = RNNBeatProcessor()
    activations = act_proc(str(filtered_path))

    # Clean up temp file
    filtered_path.unlink(missing_ok=True)

    # Step 2: Interpolate to higher fps for finer timestamp resolution (1ms at 1000fps)
    if fps != 100:
        from scipy.interpolate import interp1d
        n_frames = len(activations)
        t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False)
        n_new = int(n_frames * fps / 100)
        t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False)
        activations = interp1d(t_orig, activations, kind="cubic", fill_value="extrapolate")(t_new)
        activations = np.clip(activations, 0, None)  # cubic spline can go negative

    # Step 3: DBN decodes activations into beat timestamps
    # correct=False lets the DBN place beats using its own high-res state space
    # instead of snapping to the coarse 100fps activation peaks
    beat_proc = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        transition_lambda=transition_lambda,
        fps=fps,
        correct=False,
    )
    beats = beat_proc(activations)

    return beats
def detect_drop(
    audio_path: str | Path,
    beat_times: np.ndarray,
    window_sec: float = 0.5,
) -> float:
    """Find the beat where the biggest energy jump occurs (the drop).

    Computes RMS energy in a window around each beat and returns the beat
    with the largest increase compared to the previous beat.

    Args:
        audio_path: Path to the full mix audio file.
        beat_times: Array of beat timestamps in seconds.
        window_sec: Duration of the analysis window around each beat.

    Returns:
        Timestamp (seconds) of the detected drop beat.
    """
    # Imported lazily so the module loads without librosa installed
    # unless drop detection is actually used.
    import librosa

    y, sr = librosa.load(str(audio_path), sr=None, mono=True)
    half_win = int(window_sec / 2 * sr)

    # Per-beat RMS over a window centered on the beat, clamped to the
    # signal's bounds at both ends.
    rms_values = []
    for t in beat_times:
        center = int(t * sr)
        start = max(0, center - half_win)
        end = min(len(y), center + half_win)
        segment = y[start:end]
        rms = np.sqrt(np.mean(segment ** 2)) if len(segment) > 0 else 0.0
        rms_values.append(rms)

    rms_values = np.array(rms_values)

    # Find largest positive jump between consecutive beats.
    # diffs[j] = rms[j+1] - rms[j], so the beat landing on the jump is j+1;
    # diffs[drop_idx - 1] below is therefore exactly that winning jump.
    diffs = np.diff(rms_values)
    drop_idx = int(np.argmax(diffs)) + 1  # +1 because diff shifts by one
    drop_time = float(beat_times[drop_idx])

    print(f" Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s "
          f"(energy jump: {diffs[drop_idx - 1]:.4f})")
    return drop_time
+ """ + if len(beats) == 0: + return beats + + # Trim to max duration + beats = beats[beats <= max_duration] + + if len(beats) == 0: + return beats + + # Enforce minimum interval between beats + selected = [beats[0]] + for beat in beats[1:]: + if beat - selected[-1] >= min_interval: + selected.append(beat) + + return np.array(selected) + + +def save_beats( + beats: np.ndarray, + output_path: str | Path, +) -> Path: + """Save beat timestamps to a JSON file. + + Format matches the project convention (same style as lyrics.json): + a list of objects with beat index and timestamp. + + Args: + beats: Array of beat timestamps in seconds. + output_path: Path to save the JSON file. + + Returns: + Path to the saved JSON file. + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + data = [ + {"beat": i + 1, "time": round(float(t), 3)} + for i, t in enumerate(beats) + ] + + with open(output_path, "w") as f: + json.dump(data, f, indent=2) + + return output_path + + +def run( + drum_stem_path: str | Path, + output_dir: Optional[str | Path] = None, + min_bpm: float = 55.0, + max_bpm: float = 215.0, +) -> dict: + """Full beat detection pipeline: detect, select, and save. + + Args: + drum_stem_path: Path to the isolated drum stem WAV file. + output_dir: Directory to save beats.json. Defaults to the + parent of the drum stem's parent (e.g. data/Gone/ if + stem is at data/Gone/stems/drums.wav). + min_bpm: Minimum expected tempo. + max_bpm: Maximum expected tempo. + + Returns: + Dict with 'all_beats', 'selected_beats', and 'beats_path'. 
+ """ + drum_stem_path = Path(drum_stem_path) + + if output_dir is None: + # stems/drums.wav -> parent is stems/, parent.parent is data/Gone/ + output_dir = drum_stem_path.parent.parent + output_dir = Path(output_dir) + + all_beats = detect_beats(drum_stem_path, min_bpm=min_bpm, max_bpm=max_bpm) + selected = select_beats(all_beats) + + # Detect drop using the full mix audio (one level above stems/) + song_dir = output_dir.parent if output_dir.name.startswith("run_") else output_dir + audio_path = None + for ext in [".wav", ".mp3", ".flac", ".m4a"]: + candidates = list(song_dir.glob(f"*{ext}")) + if candidates: + audio_path = candidates[0] + break + + drop_time = None + if audio_path and len(all_beats) > 2: + drop_time = detect_drop(audio_path, all_beats) + + beats_path = save_beats(all_beats, output_dir / "beats.json") + + # Save drop time alongside beats + if drop_time is not None: + drop_path = output_dir / "drop.json" + with open(drop_path, "w") as f: + json.dump({"drop_time": round(drop_time, 3)}, f, indent=2) + + return { + "all_beats": all_beats, + "selected_beats": selected, + "beats_path": beats_path, + "drop_time": drop_time, + } + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python -m src.beat_detector ") + sys.exit(1) + + result = run(sys.argv[1]) + all_beats = result["all_beats"] + selected = result["selected_beats"] + + print(f"Detected {len(all_beats)} beats (saved to {result['beats_path']})") + print(f"Selected {len(selected)} beats (max 15s, min 0.3s apart):") + for i, t in enumerate(selected): + print(f" Beat {i + 1}: {t:.3f}s") diff --git a/src/image_generator_api.py b/src/image_generator_api.py new file mode 100644 index 0000000000000000000000000000000000000000..20aaea019ff18a8e80879476f811677edbdb1f0b --- /dev/null +++ b/src/image_generator_api.py @@ -0,0 +1,207 @@ +"""Image generation using SDXL + LoRA styles via fal.ai API. + +API counterpart to image_generator_hf.py (on-device diffusers). 
+Uses the fal-ai/lora endpoint which accepts HuggingFace LoRA repo IDs +directly, so styles.py works unchanged. + +Set FAL_KEY env var before use. +""" + +import json +import time +from pathlib import Path +from typing import Optional + +import requests +from dotenv import load_dotenv + +from src.styles import get_style + +load_dotenv() + +# --------------------------------------------------------------------------- +# Config — matches image_generator_hf.py output +# --------------------------------------------------------------------------- + +FAL_MODEL_ID = "fal-ai/lora" + +BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" + +WIDTH = 768 +HEIGHT = 1344 +NUM_STEPS = 30 +GUIDANCE_SCALE = 7.5 + + +def _build_loras(style: dict) -> list[dict]: + """Build the LoRA list for the fal.ai API from a style dict. + + Note: Hyper-SD speed LoRA is NOT used here (it's an on-device optimization + requiring specific scheduler config). fal.ai runs on fast GPUs so we use + standard settings (30 steps, DPM++ 2M Karras) instead. + """ + loras = [] + + if style["source"] is not None: + # Pass HF repo ID directly — fal.ai resolves it internally. + # Full URLs to /resolve/main/ can fail with redirect issues. 
+ loras.append({"path": style["source"], "scale": style["weight"]}) + + return loras + + +def _download_image(url: str, output_path: Path, retries: int = 3) -> Path: + """Download an image from URL to a local file with retry.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + for attempt in range(retries): + try: + resp = requests.get(url, timeout=120) + resp.raise_for_status() + with open(output_path, "wb") as f: + f.write(resp.content) + return output_path + except (requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e: + if attempt < retries - 1: + print(f" Download failed (attempt {attempt + 1}), retrying...") + else: + raise + + +def generate_image( + prompt: str, + negative_prompt: str = "", + loras: list[dict] | None = None, + seed: Optional[int] = None, +) -> dict: + """Generate a single image via fal.ai API. + + Args: + prompt: SDXL prompt. + negative_prompt: Negative prompt. + loras: List of LoRA dicts with 'path' and 'scale'. + seed: Random seed. + + Returns: + API response dict with 'images' list and 'seed'. + """ + import fal_client + + args = { + "model_name": BASE_MODEL, + "prompt": prompt, + "negative_prompt": negative_prompt, + "image_size": {"width": WIDTH, "height": HEIGHT}, + "num_inference_steps": NUM_STEPS, + "guidance_scale": GUIDANCE_SCALE, + "scheduler": "DPM++ 2M Karras", + "num_images": 1, + "image_format": "png", + "enable_safety_checker": False, + } + if loras: + args["loras"] = loras + if seed is not None: + args["seed"] = seed + + result = fal_client.subscribe(FAL_MODEL_ID, arguments=args) + return result + + +def generate_all( + segments: list[dict], + output_dir: str | Path, + style_name: str = "Warm Sunset", + seed: int = 42, + progress_callback=None, +) -> list[Path]: + """Generate images for all segments via fal.ai. + + Args: + segments: List of segment dicts (with 'prompt' and 'negative_prompt'). + output_dir: Directory to save images. + style_name: Style from styles.py registry. 
+ seed: Base seed (incremented per segment). + + Returns: + List of saved image paths. + """ + style = get_style(style_name) + loras = _build_loras(style) + trigger = style["trigger"] + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + paths = [] + for seg in segments: + idx = seg["segment"] + path = output_dir / f"segment_{idx:03d}.png" + + if path.exists(): + print(f" Segment {idx}/{len(segments)}: already exists, skipping") + paths.append(path) + continue + + prompt = seg["prompt"] + if trigger: + prompt = f"{trigger} style, {prompt}" + neg = seg.get("negative_prompt", "") + + print(f" Segment {idx}/{len(segments)}: generating image (fal.ai)...") + t0 = time.time() + result = generate_image(prompt, neg, loras=loras, seed=seed + idx) + elapsed = time.time() - t0 + + image_url = result["images"][0]["url"] + _download_image(image_url, path) + paths.append(path) + print(f" Saved {path.name} ({elapsed:.1f}s)") + if progress_callback: + progress_callback(idx, len(segments)) + + return paths + + +def run( + data_dir: str | Path, + style_name: str = "Warm Sunset", + seed: int = 42, + progress_callback=None, +) -> list[Path]: + """Full image generation pipeline: read segments, generate via API, save. + + Args: + data_dir: Run directory containing segments.json. + style_name: Style from the registry (see src/styles.py). + seed: Base random seed. + + Returns: + List of saved image paths. + """ + data_dir = Path(data_dir) + + with open(data_dir / "segments.json") as f: + segments = json.load(f) + + paths = generate_all(segments, data_dir / "images", style_name, seed, progress_callback) + + print(f"\nGenerated {len(paths)} images in {data_dir / 'images'}") + return paths + + +if __name__ == "__main__": + import os + import sys + + if len(sys.argv) < 2: + print("Usage: python -m src.image_generator_api [style_name]") + print(' e.g. 
python -m src.image_generator_api data/Gone/run_001 "Warm Sunset"') + print("\nRequires FAL_KEY environment variable.") + sys.exit(1) + + if not os.getenv("FAL_KEY"): + print("Error: FAL_KEY environment variable not set.") + print("Get your key at https://fal.ai/dashboard/keys") + sys.exit(1) + + style = sys.argv[2] if len(sys.argv) > 2 else "Warm Sunset" + run(sys.argv[1], style_name=style) diff --git a/src/image_generator_hf.py b/src/image_generator_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..d513711f8f8cf16785585412987e1e75cfee3f32 --- /dev/null +++ b/src/image_generator_hf.py @@ -0,0 +1,245 @@ +"""Generate images using SDXL + Hyper-SD 8-step + style LoRA from registry. + +Reads segments.json (with prompts from prompt_generator) and generates +one 768x1344 (9:16 vertical) image per segment. + +Pipeline: SDXL base → Hyper-SD 8-step CFG LoRA (speed) → style LoRA (aesthetics) +""" + +import json +from pathlib import Path +from typing import Optional + +import torch +from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline +from huggingface_hub import hf_hub_download + +from src.styles import get_style + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" +VAE_MODEL = "madebyollin/sdxl-vae-fp16-fix" +HYPER_SD_REPO = "ByteDance/Hyper-SD" +HYPER_SD_FILE = "Hyper-SDXL-8steps-CFG-lora.safetensors" + +WIDTH = 768 +HEIGHT = 1344 +NUM_STEPS = 8 +GUIDANCE_SCALE = 5.0 + +HYPER_SD_WEIGHT = 0.125 # official recommendation + + +def _get_device_and_dtype(): + """Detect best available device and matching dtype.""" + if torch.cuda.is_available(): + return "cuda", torch.float16 + if torch.backends.mps.is_available(): + return "mps", torch.float32 # float32 required for MPS reliability + return "cpu", torch.float32 + + +def load_pipeline(style_name: str = "Warm 
Sunset"): + """Load SDXL pipeline with Hyper-SD and a style LoRA from the registry. + + Args: + style_name: Key in STYLES registry. Use "None" for no style LoRA. + + Returns: + Configured DiffusionPipeline ready for inference. + """ + style = get_style(style_name) + device, dtype = _get_device_and_dtype() + print(f"Loading SDXL pipeline on {device} ({dtype})...") + + vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=dtype) + + load_kwargs = {"torch_dtype": dtype, "vae": vae, "use_safetensors": True} + if dtype == torch.float16: + load_kwargs["variant"] = "fp16" + + pipe = DiffusionPipeline.from_pretrained(BASE_MODEL, **load_kwargs) + + # Hyper-SD 8-step CFG LoRA (always loaded) + hyper_path = hf_hub_download(HYPER_SD_REPO, HYPER_SD_FILE) + pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd") + + # Style LoRA from registry + _apply_style(pipe, style) + + # DDIMScheduler with trailing timestep spacing — required for Hyper-SD + pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" + ) + + pipe.to(device) + + if device == "mps": + pipe.enable_attention_slicing() + pipe.enable_vae_slicing() + + print("Pipeline ready.") + return pipe + + +def _apply_style(pipe, style: dict): + """Load a style LoRA and set adapter weights.""" + source = style["source"] + if source is None: + pipe.set_adapters(["hyper-sd"], adapter_weights=[HYPER_SD_WEIGHT]) + print("No style LoRA — using base SDXL + Hyper-SD.") + return + + load_kwargs = {"adapter_name": "style"} + + # Local file: resolve relative to project root, pass dir + weight_name + project_root = Path(__file__).resolve().parent.parent + source_path = (project_root / source).resolve() + if source_path.is_file(): + load_kwargs["weight_name"] = source_path.name + pipe.load_lora_weights(str(source_path.parent), **load_kwargs) + else: + # HF Hub repo ID + if style["weight_name"]: + load_kwargs["weight_name"] = style["weight_name"] + pipe.load_lora_weights(source, **load_kwargs) 
def generate_all(
    segments: list[dict],
    pipe,
    output_dir: str | Path,
    trigger_word: str = "",
    seed: int = 42,
    progress_callback=None,
) -> list[Path]:
    """Generate one image per segment, skipping files that already exist.

    Args:
        segments: Segment dicts with "segment", "prompt" and optional
            "negative_prompt" keys (from prompt_generator).
        pipe: Loaded DiffusionPipeline (see load_pipeline).
        output_dir: Directory where segment_NNN.png files are written.
        trigger_word: LoRA trigger word prefixed to every prompt.
        seed: Base seed; each segment uses seed + its index for variety.
        progress_callback: Optional callable(done, total) invoked after each
            newly generated image (not for skipped ones).

    Returns:
        List of saved image paths (existing and new), in segment order.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total = len(segments)
    paths = []
    for seg in segments:
        idx = seg["segment"]
        path = output_dir / f"segment_{idx:03d}.png"

        # Resume support: never regenerate an image already on disk.
        if path.exists():
            print(f"  Segment {idx}/{total}: already exists, skipping")
            paths.append(path)
            continue

        prompt = seg["prompt"]
        if trigger_word:
            prompt = f"{trigger_word} style, {prompt}"
        neg = seg.get("negative_prompt", "")

        print(f"  Segment {idx}/{total}: generating...")
        image = generate_image(pipe, prompt, neg, seed=seed + idx)

        # Fixed: `path` was redundantly recomputed here; reuse loop-top value.
        image.save(path)
        paths.append(path)
        print(f"  Saved {path.name}")
        if progress_callback:
            progress_callback(idx, total)

    return paths
def extract_lyrics(
    vocal_path: str | Path,
    model_name: str = "large-v2",
    device: str = "cpu",
    language: str = "en",
    output_dir: Optional[str | Path] = None,
) -> list[dict]:
    """Extract timestamped lyrics from an isolated vocal stem.

    Runs WhisperX transcription followed by forced alignment to obtain
    word-level timestamps, then saves the flattened word list to lyrics.json.

    Args:
        vocal_path: Path to the isolated vocal audio file
            (data/<song>/stems/vocals.wav).
        model_name: Whisper model size. Default "large-v2" (best for lyrics).
        device: Device to run on ("cpu", "cuda").
        language: Language code for transcription.
        output_dir: Directory to save lyrics.json. Defaults to the
            grandparent of vocal_path (data/<song>/).

    Returns:
        List of word dicts with keys: "word", "start", "end".
        Example: [{"word": "hello", "start": 0.5, "end": 0.8}, ...]
    """
    vocal_path = str(vocal_path)

    # Load audio once; reused by both transcription and alignment.
    audio = whisperx.load_audio(vocal_path)

    # Transcribe. int8 compute and a small batch keep memory usage low
    # enough for CPU inference.
    model = whisperx.load_model(model_name, device, compute_type="int8", language=language)
    result = model.transcribe(audio, batch_size=4)
    del model  # free Whisper model before loading alignment model

    # Forced alignment refines segment timestamps down to per-word timing.
    model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device)
    del model_a, metadata  # free alignment model

    # Flatten aligned segments into a single word list. Words the aligner
    # could not time (missing "start"/"end") are skipped.
    words = []
    for segment in result["segments"]:
        for word in segment.get("words", []):
            if "start" in word and "end" in word:
                words.append({
                    "word": word["word"].strip(),
                    "start": word["start"],
                    "end": word["end"],
                })

    # Save to JSON in the song directory (stems/ parent = data/<song>/).
    if output_dir is None:
        output_dir = Path(vocal_path).parent.parent
    output_dir = Path(output_dir)

    output_path = output_dir / "lyrics.json"
    with open(output_path, "w") as f:
        json.dump(words, f, indent=2)

    # Encourage prompt release of the large model objects deleted above.
    import gc
    gc.collect()

    return words
Image prompt — short, SDXL-optimized (≤77 CLIP tokens) + 2. Video prompt — detailed motion/action description for I2V (no token limit) + +Consistency: LLM keeps all scenes within a shared setting from the style guidance. +Variety: LLM picks different subjects, camera angles, compositions per segment. +Narrative: LLM derives an overarching visual story from the lyrics. +""" + +import json +import os +from pathlib import Path +from typing import Optional + +import anthropic +from dotenv import load_dotenv + +load_dotenv() + +# Camera angles to cycle through for visual variety between cuts +CAMERA_ANGLES = [ + "wide establishing shot", + "close-up", + "aerial view", + "low angle shot", + "medium shot", + "extreme wide shot", + "over-the-shoulder perspective", + "dutch angle", + "tracking shot from the side", + "bird's eye view", + "ground-level shot", + "silhouette against the sky", +] + +# Default quality suffix — overridden by style-specific quality_suffix from styles.py +DEFAULT_QUALITY_SUFFIX = "8K, cinematic, atmospheric, sharp details" + +NEGATIVE_PROMPT = ( + "text, watermark, logo, blurry, low quality, deformed, " + "ugly, oversaturated, cartoon, anime" +) + +# --------------------------------------------------------------------------- +# LLM Call 1: Image prompts (short, SDXL-optimized) +# --------------------------------------------------------------------------- + +IMAGE_SYSTEM_PROMPT = """\ +You are a music video director. Given song lyrics, a SETTING, and a list of \ +segments (each ~2 seconds long), create a visually compelling shot list for \ +IMAGE generation (Stable Diffusion XL). + +Rules: +1. A SETTING will be provided at the end of these instructions. ALL scenes \ +MUST take place within that setting — treat it as the world of a short film. \ +Never leave this world. +2. Use the LYRICS to shape the MOOD, ENERGY, and EMOTIONAL ARC of each scene. 
\ +The lyrics dictate the vibe — if they're dark and melancholic, the visuals \ +should feel heavy and somber even within the setting. If they're upbeat, the \ +visuals should feel energetic. +3. When lyrics are CONCRETE and naturally fit the setting, lean into them \ +heavily. For example, if the setting is a coastal drive and the lyrics say \ +"waves crashing down", make that segment literally about waves crashing \ +against rocks as the car passes. If the lyrics say "fading light", show the \ +sun dropping below the horizon. The more specific the lyrics, the more \ +directly they should influence the scene. +4. When lyrics are ABSTRACT or metaphorical (e.g. "lost in your eyes", \ +"falling apart"), translate the emotion into something visual and physical \ +within the setting — don't try to literally depict abstract concepts. +5. Each segment gets a UNIQUE SHOT within the shared setting — vary the \ +subject, angle, and composition, but NEVER leave the world. +CRITICAL: Every scene MUST depict ACTION or MOTION — something must be \ +happening. These will be turned into short video clips, so static subjects \ +like "a wooden floor", "a parked car", or "an empty room" are useless. \ +Show vehicles driving, waves crashing, lights flickering, rain falling, \ +fires burning — dynamic scenes only. +6. Use the assigned camera angle for each segment. +7. Segments WITHOUT lyrics (instrumental): use atmospheric, mood-driven \ +details from the setting (environmental motion, weather, ambient action). +8. Write prompts as SDXL-optimized natural language descriptions. \ +Keep each scene between 25-35 words. Be specific — name exact objects, \ +materials, colors, and weather details. Every word must earn its place. \ +Focus on CONCRETE OBJECTS and ACTIONS — what is physically in the frame \ +and what is happening. SDXL needs to know what to draw, not how to feel. \ +BAD: "reflections layering over glass, interior light diffused through water" — abstract mood. 
\ +GOOD: "taxi splashing through puddle on wet street, rain falling past neon bar sign" — objects + action. \ +BAD: "streetlights bleeding through downpour, darkness stretching ahead" — vague atmosphere. \ +GOOD: "car windshield wipers sweeping rain, blurred traffic lights ahead, wet dashboard" — specific things. \ +BAD: "water sheeting off canvas edge in a thick curtain" — SDXL will draw a curtain. \ +GOOD: "water pouring off awning edge, rain splashing on sidewalk below" — plain description. \ +Write like you're telling a 10-year-old what's in the picture. Simple, plain words. \ +Name the objects. Name the action. Lighting and mood come from the SETTING, \ +you don't need to describe them — describe what's HAPPENING. \ +Use LITERAL language only — no metaphors, no poetic phrasing. SDXL interprets \ +words literally. BANNED words: bleeding, drowning, bathed, kissed, dancing, \ +breathing, alive, whispering, haunting, cascading, diffusing, fragmenting. \ +These cause SDXL to generate unintended objects. \ +Also avoid describing PROCESSES or PHYSICS — SDXL generates a single frame, \ +not a sequence. "ripples expanding", "light fragmenting and reforming", \ +"reflections scattering" are processes, not objects. Instead describe the \ +RESULT: "rippled puddle", "blurry neon reflection in water", "wet glass". \ +Say exactly what a camera would capture in ONE freeze-frame. \ +Before finalizing each scene, sanity-check it: does this make physical \ +sense? Could this actually exist? "pooled water on a car hood" — no, car \ +hoods are curved and water runs off. "rain falling upward" — no. \ +"neon sign reflected in a brick wall" — no, brick doesn't reflect. \ +Only write scenes that obey basic physics and real-world logic. \ +Strip camera angle phrasing from the scene text (angles are metadata, not prompt words). +9. Include lighting and color in every scene. 
Derive from the SETTING — \ +a sunset drive = warm golden-hour light, lens flares, long shadows; \ +a rainy city night = cold neon on wet surfaces, streetlight halos; \ +a stormy harbour = overcast grey, dramatic cloud breaks. \ +Keep lighting consistent across all scenes. +10. Do NOT include style, quality, or technical tags in the scene — these \ +are appended automatically. BANNED from scenes: "cinematic", "moody", \ +"atmospheric", "dramatic lighting", "film grain", "color grade", "bokeh", \ +"depth of field", "35mm", "8K", "masterpiece", "best quality". \ +Your scene should contain ONLY objects, actions, and setting-derived light. +11. Do NOT include text, words, or typography in the scenes. +12. Do NOT end scenes with periods. Use commas to separate phrases. \ +Every character counts — periods waste a token. + +Return ONLY valid JSON: a list of objects with "segment" (number) and \ +"scene" (the creative description). No markdown, no explanation.\ +""" + +# --------------------------------------------------------------------------- +# LLM Call 2: Video prompts (detailed motion descriptions) +# --------------------------------------------------------------------------- + +VIDEO_SYSTEM_PROMPT = """\ +You are a music video director creating motion descriptions for an \ +image-to-video AI model. You will receive a list of segments, each with \ +an image scene description already written. Your job is to describe \ +HOW each scene should MOVE and ANIMATE. + +Rules: +1. For each segment, write a detailed "video_prompt" (2-4 sentences) \ +describing all motion in the scene: + - SUBJECT MOTION: what the subject does (walking, turning, reaching, \ +driving, dancing, running, etc.) + - CAMERA MOTION: how the camera moves (slow pan left, dolly forward, \ +tracking shot, crane up, handheld shake, static with zoom, etc.) 
+ - ENVIRONMENTAL MOTION: ambient movement (wind blowing hair/clothes, \ +rain falling, leaves drifting, smoke rising, lights flickering, waves \ +crashing, clouds moving, reflections rippling, etc.) + - PACING: match the emotional energy — slow and contemplative for \ +quiet moments, faster and more dynamic for intense moments. +2. Be specific and physical. Not "things move around" but "the camera \ +slowly tracks forward as rain streaks across the windshield and the \ +wipers sweep left to right." +3. Keep the motion consistent with the shared setting — all scenes are \ +part of the same story. +4. Do NOT describe visual style, colors, or lighting — the image already \ +has those. Focus ONLY on motion and action. +5. CRITICAL — ONLY animate what exists in the scene description. Do NOT \ +introduce new subjects, people, or objects that are not explicitly \ +mentioned. If the scene describes a landscape with no people, describe \ +ONLY environmental motion (wind, water, light changes, camera movement). \ +NEVER add a person walking into frame unless the scene already mentions \ +a person or figure. + +Return ONLY valid JSON: a list of objects with "segment" (number) and \ +"video_prompt" (the motion description). No markdown, no explanation.\ +""" + + +def _build_user_prompt( + segments: list[dict], song_name: str, style_description: str = "", +) -> str: + """Build the user message for the image prompt LLM call.""" + all_lyrics = " ".join( + seg["lyrics"] for seg in segments if seg["lyrics"] + ).strip() + + lines = [ + f'Song: "{song_name}"', + f'Full lyrics in this clip: "{all_lyrics}"', + f"Number of segments: {len(segments)}", + ] + + if style_description: + lines.append(f'Visual style direction: "{style_description}"') + + lines += ["", "Segments:"] + + for i, seg in enumerate(segments): + angle = CAMERA_ANGLES[i % len(CAMERA_ANGLES)] + lyrics_note = f'lyrics: "{seg["lyrics"]}"' if seg["lyrics"] else "instrumental" + lines.append( + f' {seg["segment"]}. 
({seg["start"]:.1f}s–{seg["end"]:.1f}s) ' + f'[{angle}] {lyrics_note}' + ) + + return "\n".join(lines) + + +def _build_video_user_prompt(segments: list[dict]) -> str: + """Build the user message for the video prompt LLM call.""" + lines = [ + "Generate motion descriptions for each segment.", + "IMPORTANT: ONLY animate elements that exist in the scene description.", + "Do NOT add people, figures, or objects that aren't mentioned.", + "", + "Image scenes:", + "", + ] + + for seg in segments: + lyrics_note = f' (lyrics: "{seg["lyrics"]}")' if seg.get("lyrics") else " (instrumental)" + lines.append( + f' Segment {seg["segment"]}: "{seg["scene"]}"{lyrics_note}' + ) + + return "\n".join(lines) + + +def _parse_llm_json(raw: str) -> list[dict]: + """Parse JSON from LLM response, stripping markdown fences if present.""" + raw = raw.strip() + if raw.startswith("```"): + raw = raw.split("\n", 1)[1] + raw = raw.rsplit("```", 1)[0] + return json.loads(raw) + + +def generate_prompts( + segments: list[dict], + song_name: str = "Unknown", + style_description: str = "", + image_prompt_guidance: str = "", + quality_suffix: str = "", + model: str = "claude-sonnet-4-6", +) -> list[dict]: + """Generate image + video prompts for each segment using two LLM calls. + + Args: + segments: List of segment dicts from segmenter (with lyrics). + song_name: Name of the song (helps the LLM set the mood). + style_description: Description of the visual style (from styles registry). + image_prompt_guidance: Style-specific creative direction appended to the + image system prompt (from styles registry). + quality_suffix: Style-specific quality tags appended to each prompt. + model: Anthropic model to use. 
+ + Returns: + Updated segments list with added keys: + - prompt: full SDXL prompt (scene + style suffix) + - video_prompt: detailed motion description for I2V + - negative_prompt: negative prompt for SDXL + - camera_angle: the assigned camera angle + - scene: raw scene description from LLM + """ + client = anthropic.Anthropic() + + # --- Call 1: Image prompts --- + print(" Generating image prompts...") + user_prompt = _build_user_prompt(segments, song_name, style_description) + + # Inject style-specific guidance into the system prompt + image_system = IMAGE_SYSTEM_PROMPT + if image_prompt_guidance: + image_system += f"\n\n{image_prompt_guidance}" + + response = client.messages.create( + model=model, + max_tokens=2048, + system=image_system, + messages=[{"role": "user", "content": user_prompt}], + ) + + scenes = _parse_llm_json(response.content[0].text) + scene_map = {s["segment"]: s for s in scenes} + + # Merge image prompts into segments + suffix = quality_suffix or DEFAULT_QUALITY_SUFFIX + for i, seg in enumerate(segments): + angle = CAMERA_ANGLES[i % len(CAMERA_ANGLES)] + scene_data = scene_map.get(seg["segment"], {}) + scene = scene_data.get("scene", "atmospheric landscape") + + seg["scene"] = scene + seg["camera_angle"] = angle + seg["prompt"] = f"{scene}, {suffix}" + seg["negative_prompt"] = NEGATIVE_PROMPT + + # --- Call 2: Video prompts --- + print(" Generating video prompts...") + video_user_prompt = _build_video_user_prompt(segments) + + response = client.messages.create( + model=model, + max_tokens=4096, + system=VIDEO_SYSTEM_PROMPT, + messages=[{"role": "user", "content": video_user_prompt}], + ) + + video_scenes = _parse_llm_json(response.content[0].text) + video_map = {s["segment"]: s for s in video_scenes} + + # Merge video prompts into segments + for seg in segments: + video_data = video_map.get(seg["segment"], {}) + seg["video_prompt"] = video_data.get( + "video_prompt", f"smooth cinematic motion, {seg['scene']}" + ) + + return segments + + +def 
save_segments( + segments: list[dict], + output_path: str | Path, +) -> Path: + """Save prompt-enriched segments to JSON.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + json.dump(segments, f, indent=2) + + return output_path + + +def run( + data_dir: str | Path, + song_name: Optional[str] = None, + style_description: str = "", + image_prompt_guidance: str = "", + quality_suffix: str = "", +) -> list[dict]: + """Full prompt generation pipeline: load segments, generate prompts, save. + + Args: + data_dir: Run directory containing segments.json (e.g. data/Gone/run_001/). + song_name: Name of the song. Defaults to the parent directory name. + style_description: Description of the visual style (from styles registry). + image_prompt_guidance: Style-specific creative direction for image prompts. + quality_suffix: Style-specific quality tags appended to each prompt. + + Returns: + List of prompt-enriched segment dicts. + """ + data_dir = Path(data_dir) + + if song_name is None: + song_name = data_dir.parent.name + + with open(data_dir / "segments.json") as f: + segments = json.load(f) + + segments = generate_prompts( + segments, song_name=song_name, style_description=style_description, + image_prompt_guidance=image_prompt_guidance, + quality_suffix=quality_suffix, + ) + save_segments(segments, data_dir / "segments.json") + + return segments + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python -m src.prompt_generator [song_name]") + print(" e.g. 
def segment_lyrics(
    beats: list[dict],
    lyrics: list[dict],
    beats_per_segment: int = 4,
) -> list[dict]:
    """Map timestamped lyrics onto beat-grouped segments.

    Consecutive beats are grouped into segments (4 beats = one bar in 4/4
    time) and each word is assigned to the segment in which it starts.
    Words starting after the final boundary land in the last segment;
    words starting before the first beat are dropped.

    Args:
        beats: Beat dicts with "beat" and "time" keys.
        lyrics: Word dicts with "word", "start", "end" keys.
        beats_per_segment: Beats grouped per segment (4 = one 4/4 bar).

    Returns:
        Segment dicts with keys: segment (1-indexed), start, end, duration,
        beats (covered timestamps), lyrics (joined text), words (word dicts).
    """
    times = [b["time"] for b in beats]
    last = len(times) - 1

    # Build one segment per group of beats_per_segment consecutive beats.
    segments = []
    for num, i in enumerate(range(0, last, beats_per_segment), start=1):
        end_i = min(i + beats_per_segment, last)
        window = times[i : min(i + beats_per_segment + 1, len(times))]
        segments.append({
            "segment": num,
            "start": round(times[i], 3),
            "end": round(times[end_i], 3),
            "duration": round(times[end_i] - times[i], 3),
            "beats": [round(t, 3) for t in window],
            "lyrics": "",
            "words": [],
        })

    # Drop each word into the segment whose [start, end) span contains it.
    for word in lyrics:
        t = word["start"]
        home = next(
            (s for s in segments if s["start"] <= t < s["end"]), None
        )
        if home is None and segments and t >= segments[-1]["start"]:
            # Past the last boundary — fold into the final segment.
            home = segments[-1]
        if home is not None:
            home["words"].append(word)

    # Join each segment's words into a single lyrics string.
    for seg in segments:
        seg["lyrics"] = " ".join(w["word"] for w in seg["words"])

    return segments
def run(
    data_dir: str | Path,
    beats_per_segment: int = 4,
) -> list[dict]:
    """Load beats.json and lyrics.json from *data_dir*, segment, and persist.

    Args:
        data_dir: Song data directory containing beats.json and lyrics.json
            (e.g. data/Gone/).
        beats_per_segment: Beats grouped into each segment (4 = one 4/4 bar).

    Returns:
        The segment dicts that were written to segments.json.
    """
    base = Path(data_dir)

    beats = json.loads((base / "beats.json").read_text())
    lyrics = json.loads((base / "lyrics.json").read_text())

    segments = segment_lyrics(beats, lyrics, beats_per_segment=beats_per_segment)
    save_segments(segments, base / "segments.json")
    return segments
python -m src.segmenter data/Gone") + sys.exit(1) + + segments = run(sys.argv[1]) + print(f"Created {len(segments)} segments:\n") + for seg in segments: + lyrics_display = f'"{seg["lyrics"]}"' if seg["lyrics"] else "(instrumental)" + print(f" Seg {seg['segment']}: {seg['start']:.3f}s - {seg['end']:.3f}s " + f"({seg['duration']:.3f}s) {lyrics_display}") diff --git a/src/stem_separator.py b/src/stem_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc9b19e475d878ce062e9193dd3563fd1eedefe --- /dev/null +++ b/src/stem_separator.py @@ -0,0 +1,243 @@ +"""LALAL.AI API wrapper for audio stem separation.""" + +import os +import shutil +import time +from pathlib import Path +from typing import Optional + +import requests + +API_BASE = "https://www.lalal.ai/api/v1" +DATA_DIR = Path(__file__).parent.parent / "data" + +# Stems we need for the pipeline +STEMS_TO_EXTRACT = ["vocals", "drum"] +# Map LALAL.AI track labels to our file naming convention +LABEL_TO_FILENAME = {"vocals": "vocals.wav", "drum": "drums.wav"} + + +def _get_api_key() -> str: + key = os.environ.get("LALAL_KEY") + if not key: + raise RuntimeError( + "LALAL_KEY environment variable not set. " + "Set it locally or as a HuggingFace Space secret." + ) + return key + + +def _headers(api_key: str) -> dict: + return {"X-License-Key": api_key} + + +def _next_run_dir(song_dir: Path) -> Path: + """Find the next available run directory (run_001, run_002, ...).""" + existing = sorted(song_dir.glob("run_*")) + next_num = 1 + for d in existing: + try: + num = int(d.name.split("_")[1]) + next_num = max(next_num, num + 1) + except (IndexError, ValueError): + continue + return song_dir / f"run_{next_num:03d}" + + +def _upload(audio_path: Path, api_key: str) -> str: + """Upload audio file to LALAL.AI. 
Returns source_id.""" + with open(audio_path, "rb") as f: + resp = requests.post( + f"{API_BASE}/upload/", + headers={ + **_headers(api_key), + "Content-Disposition": f'attachment; filename="{audio_path.name}"', + }, + data=f, + ) + resp.raise_for_status() + data = resp.json() + source_id = data["id"] + print(f" Uploaded {audio_path.name} → source_id={source_id} " + f"(duration: {data['duration']:.1f}s)") + return source_id + + +def _split_stem(source_id: str, stem: str, api_key: str) -> str: + """Start a stem separation task. Returns task_id.""" + # Andromeda is best for vocals but doesn't support all stems — use auto for others + splitter = "andromeda" if stem == "vocals" else None + resp = requests.post( + f"{API_BASE}/split/stem_separator/", + headers=_headers(api_key), + json={ + "source_id": source_id, + "presets": { + "stem": stem, + "splitter": splitter, + "dereverb_enabled": False, + "encoder_format": "wav", + "extraction_level": "deep_extraction", + }, + }, + ) + resp.raise_for_status() + data = resp.json() + task_id = data["task_id"] + print(f" Split task started: stem={stem}, task_id={task_id}") + return task_id + + +def _poll_tasks(task_ids: list[str], api_key: str, poll_interval: float = 5.0) -> dict: + """Poll tasks until all complete. 
Returns {task_id: result_data}.""" + pending = set(task_ids) + results = {} + + while pending: + resp = requests.post( + f"{API_BASE}/check/", + headers=_headers(api_key), + json={"task_ids": list(pending)}, + ) + resp.raise_for_status() + data = resp.json().get("result", resp.json()) + + for task_id, info in data.items(): + status = info.get("status") + if status == "success": + results[task_id] = info + pending.discard(task_id) + print(f" Task {task_id}: complete") + elif status == "progress": + print(f" Task {task_id}: {info.get('progress', 0)}%") + elif status == "error": + error = info.get("error", {}) + raise RuntimeError( + f"LALAL.AI task {task_id} failed: " + f"{error.get('detail', 'unknown error')} " + f"(code: {error.get('code')})" + ) + elif status == "cancelled": + raise RuntimeError(f"LALAL.AI task {task_id} was cancelled") + elif status == "server_error": + raise RuntimeError( + f"LALAL.AI server error for task {task_id}: " + f"{info.get('error', 'unknown')}" + ) + + if pending: + time.sleep(poll_interval) + + return results + + +def _download_track(url: str, output_path: Path) -> None: + """Download a track from LALAL.AI CDN.""" + resp = requests.get(url, stream=True) + resp.raise_for_status() + with open(output_path, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + print(f" Downloaded → {output_path.name} ({output_path.stat().st_size / 1024:.0f} KB)") + + +def _delete_source(source_id: str, api_key: str) -> None: + """Delete uploaded source file from LALAL.AI servers.""" + try: + requests.post( + f"{API_BASE}/delete/", + headers=_headers(api_key), + json={"source_id": source_id}, + ) + print(f" Cleaned up remote source {source_id}") + except Exception: + pass # non-critical + + +def separate_stems( + audio_path: str | Path, + output_dir: Optional[str | Path] = None, +) -> dict[str, Path]: + """Separate an audio file into vocals and drums using LALAL.AI. 
+ + Creates a new run directory for each invocation so multiple runs + on the same song don't overwrite each other. + + Args: + audio_path: Path to the input audio file (mp3/wav) from input/. + output_dir: Directory to save stems. If None, auto-creates + data//run_NNN/stems/. + + Returns: + Dict mapping stem names to their file paths. + Keys: "drums", "vocals", "run_dir" + """ + audio_path = Path(audio_path) + song_name = audio_path.stem + song_dir = DATA_DIR / song_name + api_key = _get_api_key() + + if output_dir is None: + run_dir = _next_run_dir(song_dir) + output_dir = run_dir / "stems" + else: + output_dir = Path(output_dir) + run_dir = output_dir.parent + + output_dir.mkdir(parents=True, exist_ok=True) + + # Copy original song into song directory (shared across runs) + song_copy = song_dir / audio_path.name + if not song_copy.exists(): + shutil.copy2(audio_path, song_copy) + + # 1. Upload + print("Stem separation (LALAL.AI):") + source_id = _upload(audio_path, api_key) + + # 2. Start split tasks for each stem + task_to_stem = {} + for stem in STEMS_TO_EXTRACT: + task_id = _split_stem(source_id, stem, api_key) + task_to_stem[task_id] = stem + + # 3. Poll until all tasks complete + results = _poll_tasks(list(task_to_stem.keys()), api_key) + + # 4. Download the separated stem tracks + stem_paths = {"run_dir": run_dir} + for task_id, result_data in results.items(): + stem = task_to_stem[task_id] + filename = LABEL_TO_FILENAME[stem] + tracks = result_data.get("result", {}).get("tracks", []) + + # Find the "stem" track (not the "back"/inverse track) + stem_track = next((t for t in tracks if t["type"] == "stem"), None) + if stem_track is None: + raise RuntimeError(f"No stem track found in result for {stem}") + + output_path = output_dir / filename + _download_track(stem_track["url"], output_path) + + # Map to our naming: "drum" API stem → "drums" key + key = "drums" if stem == "drum" else stem + stem_paths[key] = output_path + + # 5. 
Cleanup remote files + _delete_source(source_id, api_key) + + return stem_paths + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python -m src.stem_separator ") + sys.exit(1) + + result = separate_stems(sys.argv[1]) + print(f"Run directory: {result['run_dir']}") + for name, path in result.items(): + if name != "run_dir": + print(f" {name}: {path}") diff --git a/src/styles.py b/src/styles.py new file mode 100644 index 0000000000000000000000000000000000000000..d04d128c92c3753e7b993aa857304e9e587e58f2 --- /dev/null +++ b/src/styles.py @@ -0,0 +1,99 @@ +"""Style registry — maps style names to LoRA sources. + +Each style can point to a local .safetensors file or a HuggingFace Hub repo. +pipe.load_lora_weights() handles both transparently. +""" + +STYLES = { + "Sunset Coastal Drive": { + "source": "samuelsattler/warm-sunset-lora", + "weight_name": "pytorch_lora_weights.safetensors", + "weight": 1.0, + "trigger": "sks", + "description": "Golden hour warmth, sun flares, silhouettes, warm color grading", + "quality_suffix": "8K, cinematic, golden hour glow, warm volumetric light, lens flare, shallow depth of field", + "image_prompt_guidance": ( + "SETTING — Coastal sunset drive:\n" + "The shared world is a drive along a coastal highway as the sun " + "sets. All scenes take place in or around this journey — a car " + "cruising along cliffs above the ocean, waves crashing against " + "rocks below, wind whipping through open windows, palm trees " + "swaying overhead, the sun sinking into the sea on the horizon. " + "No humans visible — focus on the car, the road, the ocean, " + "and the landscape. Every shot must have motion: wheels turning, " + "waves rolling, sun flares shifting, clouds drifting." 
+ ), + }, + "Rainy City Night": { + "source": "artificialguybr/filmgrain-redmond-filmgrain-lora-for-sdxl", + "weight_name": "FilmGrainRedmond-FilmGrain-FilmGrainAF.safetensors", + "weight": 0.8, + "trigger": "FilmGrainAF", + "description": "35mm film grain, moody color grading, cinematic lighting", + "quality_suffix": "8K, cinematic, shot on 35mm film, dramatic rim lighting, high contrast, shallow depth of field", + "image_prompt_guidance": ( + "SETTING — Rainy city at night:\n" + "The shared world is a walk through a rain-soaked city after dark. " + "All scenes take place on these streets — rain streaking through " + "streetlights, puddles reflecting neon signs, steam rising from " + "grates, traffic passing with blurred headlights, wet umbrellas, " + "rain hammering awnings, water streaming down windows. " + "No humans visible — focus on the environment, the rain, the " + "reflections, and the city itself. Every shot must have motion: " + "rain falling, cars passing, lights flickering, water flowing " + "through gutters." + ), + }, + "Cyberpunk": { + "source": "jbilcke-hf/sdxl-cyberpunk-2077", + "weight_name": "pytorch_lora_weights.safetensors", + "weight": 0.9, + "trigger": "cyberpunk-2077", + "description": "Neon-lit cityscapes, dark futuristic vibes, glowing signs", + "quality_suffix": "8K, cinematic, neon-drenched, volumetric fog, sharp details, high contrast, dramatic lighting", + "image_prompt_guidance": ( + "SETTING — Cyberpunk nightlife cityscape:\n" + "The shared world is a futuristic megacity at night. All scenes " + "take place in this neon-drenched urban sprawl — holographic " + "billboards flickering on skyscrapers, flying vehicles streaking " + "between towers, neon signs buzzing and glitching, rain falling " + "through laser grids, steam erupting from vents, LED-lit market " + "stalls with flickering displays. " + "No humans visible — focus on the city, the machines, the neon, " + "and the architecture. 
Every shot must have motion: vehicles " + "flying, signs flickering, rain falling, smoke drifting, lights " + "pulsing." + ), + }, + "Watercolour Harbour": { + "source": "ostris/watercolor_style_lora_sdxl", + "weight_name": "watercolor_v1_sdxl.safetensors", + "weight": 1.4, + "trigger": "", + "description": "Soft watercolor painting style, fluid washes, gentle blending", + "quality_suffix": "8K, watercolor painting, soft painterly washes, fluid blending, delicate brushstrokes, atmospheric", + "image_prompt_guidance": ( + "SETTING — Stormy harbour village:\n" + "The shared world is a coastal fishing village during a storm. " + "All scenes take place in and around this harbour — waves " + "crashing against stone sea walls, fishing boats rocking and " + "pulling at their moorings, rain sweeping across the harbour " + "in sheets, wind tearing through flags and sails, seabirds " + "wheeling against dark clouds, lanterns swinging on posts, " + "water pouring off rooftops into cobblestone streets. " + "No humans visible — focus on the sea, the boats, the storm, " + "and the village. Every shot must have motion: waves surging, " + "boats swaying, rain lashing, flags snapping in the wind." + ), + }, +} + + +def get_style(name: str) -> dict: + """Look up a style by name. Raises KeyError if not found.""" + return STYLES[name] + + +def style_names() -> list[str]: + """Return list of available style names for UI dropdowns.""" + return list(STYLES.keys()) diff --git a/src/video_generator_api.py b/src/video_generator_api.py new file mode 100644 index 0000000000000000000000000000000000000000..ba06effd0652de48a89915a413c337d8bac200f0 --- /dev/null +++ b/src/video_generator_api.py @@ -0,0 +1,244 @@ +"""Image-to-video generation using Wan 2.1 via fal.ai API. + +Reads generated images and their prompts, produces a short video clip +per segment. Each clip is ~5s at 16fps; the assembler later trims to +the exact beat interval duration. 
+ +Two backends: + - "api" : fal.ai hosted Wan 2.1 (for development / local runs) + - "hf" : on-device Wan 2.1 with FP8 on ZeroGPU (for HF Spaces deployment) + +Set FAL_KEY env var for API mode. +""" + +import base64 +import json +import os +import time +from pathlib import Path +from typing import Optional + +import requests +from dotenv import load_dotenv + +load_dotenv() + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +FAL_MODEL_ID = "fal-ai/wan-i2v" + +# Vertical 9:16 to match our SDXL images +ASPECT_RATIO = "9:16" +RESOLUTION = "480p" # cheaper/faster for dev; bump to 720p for final +NUM_FRAMES = 81 # ~5s at 16fps +FPS = 16 +NUM_INFERENCE_STEPS = 30 +GUIDANCE_SCALE = 5.0 +SEED = 42 + + +def _image_to_data_uri(image_path: str | Path) -> str: + """Convert a local image file to a base64 data URI for the API.""" + path = Path(image_path) + suffix = path.suffix.lower() + mime = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"} + content_type = mime.get(suffix, "image/png") + + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode() + + return f"data:{content_type};base64,{encoded}" + + +def _download_video(url: str, output_path: Path) -> Path: + """Download a video from URL to a local file.""" + resp = requests.get(url, timeout=300) + resp.raise_for_status() + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "wb") as f: + f.write(resp.content) + return output_path + + +# --------------------------------------------------------------------------- +# API backend (fal.ai) +# --------------------------------------------------------------------------- + +def generate_clip_api( + image_path: str | Path, + prompt: str, + negative_prompt: str = "", + seed: Optional[int] = None, +) -> dict: + """Generate a video clip from an image using fal.ai Wan 2.1 API. 
+ + Args: + image_path: Path to the source image. + prompt: Motion/scene description for the video. + negative_prompt: What to avoid. + seed: Random seed for reproducibility. + + Returns: + API response dict with 'video' (url, content_type, file_size) and 'seed'. + """ + import fal_client + + image_uri = _image_to_data_uri(image_path) + + args = { + "image_url": image_uri, + "prompt": prompt, + "aspect_ratio": ASPECT_RATIO, + "resolution": RESOLUTION, + "num_frames": NUM_FRAMES, + "frames_per_second": FPS, + "num_inference_steps": NUM_INFERENCE_STEPS, + "guide_scale": GUIDANCE_SCALE, + "negative_prompt": negative_prompt, + "enable_safety_checker": False, + "enable_prompt_expansion": False, + } + if seed is not None: + args["seed"] = seed + + result = fal_client.subscribe(FAL_MODEL_ID, arguments=args) + return result + + +# --------------------------------------------------------------------------- +# Public interface +# --------------------------------------------------------------------------- + +def generate_clip( + image_path: str | Path, + prompt: str, + output_path: str | Path, + negative_prompt: str = "", + seed: Optional[int] = None, +) -> Path: + """Generate a video clip from an image and save it locally. + + Args: + image_path: Path to the source image. + prompt: Motion/scene description. + output_path: Where to save the .mp4 clip. + negative_prompt: What to avoid. + seed: Random seed. + + Returns: + Path to the saved video clip. + """ + output_path = Path(output_path) + + result = generate_clip_api(image_path, prompt, negative_prompt, seed) + + video_url = result["video"]["url"] + return _download_video(video_url, output_path) + + +def generate_all( + segments: list[dict], + images_dir: str | Path, + output_dir: str | Path, + seed: int = SEED, + progress_callback=None, +) -> list[Path]: + """Generate video clips for all segments. + + Expects images at images_dir/segment_001.png, segment_002.png, etc. 
+ Segments should have 'prompt' and optionally 'negative_prompt' keys + (from prompt_generator). + + Args: + segments: List of segment dicts with 'segment', 'prompt' keys. + images_dir: Directory containing generated images. + output_dir: Directory to save video clips. + seed: Base seed (incremented per segment). + + Returns: + List of saved video clip paths. + """ + images_dir = Path(images_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + paths = [] + for seg in segments: + idx = seg["segment"] + image_path = images_dir / f"segment_{idx:03d}.png" + clip_path = output_dir / f"clip_{idx:03d}.mp4" + + if clip_path.exists(): + print(f" Segment {idx}/{len(segments)}: already exists, skipping") + paths.append(clip_path) + continue + + if not image_path.exists(): + print(f" Segment {idx}: image not found at {image_path}, skipping") + continue + + # Use dedicated video_prompt (detailed motion), fall back to scene + prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", ""))) + neg = seg.get("negative_prompt", "") + + print(f" Segment {idx}/{len(segments)}: generating video clip...") + t0 = time.time() + generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx) + elapsed = time.time() - t0 + print(f" Saved {clip_path.name} ({elapsed:.1f}s)") + + paths.append(clip_path) + if progress_callback: + progress_callback(idx, len(segments)) + + return paths + + +def run( + data_dir: str | Path, + seed: int = SEED, + progress_callback=None, +) -> list[Path]: + """Full video generation pipeline: read segments, generate clips, save. + + Args: + data_dir: Song data directory containing segments.json and images/. + seed: Base random seed. + + Returns: + List of saved video clip paths. 
+ """ + data_dir = Path(data_dir) + + with open(data_dir / "segments.json") as f: + segments = json.load(f) + + paths = generate_all( + segments, + images_dir=data_dir / "images", + output_dir=data_dir / "clips", + seed=seed, + progress_callback=progress_callback, + ) + + print(f"\nGenerated {len(paths)} video clips in {data_dir / 'clips'}") + return paths + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python -m src.video_generator ") + print(" e.g. python -m src.video_generator data/Gone") + print("\nRequires FAL_KEY environment variable.") + sys.exit(1) + + if not os.getenv("FAL_KEY"): + print("Error: FAL_KEY environment variable not set.") + print("Get your key at https://fal.ai/dashboard/keys") + sys.exit(1) + + run(sys.argv[1]) diff --git a/src/video_generator_hf.py b/src/video_generator_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..29533090662a63241775a566b2c6b21a0b46c743 --- /dev/null +++ b/src/video_generator_hf.py @@ -0,0 +1,235 @@ +"""Image-to-video generation using Wan 2.1 on-device via diffusers. + +Runs Wan 2.1 14B I2V locally on GPU (designed for HF Spaces ZeroGPU). +Same public interface as video_generator_api.py so app.py can swap backends. 
+""" + +import json +import time +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +from PIL import Image + +# --------------------------------------------------------------------------- +# Config — matches video_generator_api.py settings +# --------------------------------------------------------------------------- + +MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + +NUM_FRAMES = 81 # ~5s at 16fps +FPS = 16 +NUM_INFERENCE_STEPS = 30 +GUIDANCE_SCALE = 5.0 +SEED = 42 + +# 480p max pixel area (480 * 832 = 399360) +MAX_AREA = 480 * 832 + +# Singleton pipeline — loaded once, reused across calls +_pipe = None + + +def _get_pipe(): + """Load Wan 2.1 I2V pipeline (lazy singleton).""" + global _pipe + if _pipe is not None: + return _pipe + + from diffusers import AutoencoderKLWan, WanImageToVideoPipeline + from transformers import CLIPVisionModel + + print(f"Loading Wan 2.1 I2V pipeline ({MODEL_ID})...") + + # VAE and image encoder must be float32 for stability + image_encoder = CLIPVisionModel.from_pretrained( + MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32, + ) + vae = AutoencoderKLWan.from_pretrained( + MODEL_ID, subfolder="vae", torch_dtype=torch.float32, + ) + + _pipe = WanImageToVideoPipeline.from_pretrained( + MODEL_ID, + vae=vae, + image_encoder=image_encoder, + torch_dtype=torch.bfloat16, + ) + + # Quantize transformer to FP8 to fit in 24GB ZeroGPU VRAM + # (~28GB bf16 → ~14GB fp8). VAE + image encoder stay float32. 
+ from torchao.quantization import quantize_, float8_weight_only + quantize_(_pipe.transformer, float8_weight_only()) + + _pipe.to("cuda") + + print("Wan 2.1 I2V pipeline ready.") + return _pipe + + +def unload(): + """Unload the pipeline to free GPU memory.""" + global _pipe + if _pipe is not None: + _pipe.to("cpu") + del _pipe + _pipe = None + torch.cuda.empty_cache() + print("Wan 2.1 I2V pipeline unloaded.") + + +def _resize_for_480p(image: Image.Image, pipe) -> tuple[Image.Image, int, int]: + """Resize image to fit 480p area while respecting model patch constraints.""" + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(MAX_AREA * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(MAX_AREA / aspect_ratio)) // mod_value * mod_value + return image.resize((width, height)), height, width + + +def generate_clip( + image_path: str | Path, + prompt: str, + output_path: str | Path, + negative_prompt: str = "", + seed: Optional[int] = None, +) -> Path: + """Generate a video clip from an image using on-device Wan 2.1. + + Args: + image_path: Path to the source image. + prompt: Motion/scene description. + output_path: Where to save the .mp4 clip. + negative_prompt: What to avoid. + seed: Random seed. + + Returns: + Path to the saved video clip. 
+ """ + from diffusers.utils import export_to_video + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + pipe = _get_pipe() + + # Load and resize input image + image = Image.open(image_path).convert("RGB") + image, height, width = _resize_for_480p(image, pipe) + + generator = None + if seed is not None: + generator = torch.Generator(device="cpu").manual_seed(seed) + + output = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_INFERENCE_STEPS, + guidance_scale=GUIDANCE_SCALE, + generator=generator, + ) + + export_to_video(output.frames[0], str(output_path), fps=FPS) + return output_path + + +def generate_all( + segments: list[dict], + images_dir: str | Path, + output_dir: str | Path, + seed: int = SEED, + progress_callback=None, +) -> list[Path]: + """Generate video clips for all segments. + + Args: + segments: List of segment dicts with 'segment', 'prompt' keys. + images_dir: Directory containing generated images. + output_dir: Directory to save video clips. + seed: Base seed (incremented per segment). + + Returns: + List of saved video clip paths. 
+ """ + images_dir = Path(images_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + paths = [] + for seg in segments: + idx = seg["segment"] + image_path = images_dir / f"segment_{idx:03d}.png" + clip_path = output_dir / f"clip_{idx:03d}.mp4" + + if clip_path.exists(): + print(f" Segment {idx}/{len(segments)}: already exists, skipping") + paths.append(clip_path) + continue + + if not image_path.exists(): + print(f" Segment {idx}: image not found at {image_path}, skipping") + continue + + # Use dedicated video_prompt (detailed motion), fall back to scene + prompt = seg.get("video_prompt", seg.get("scene", seg.get("prompt", ""))) + neg = seg.get("negative_prompt", "") + + print(f" Segment {idx}/{len(segments)}: generating video clip...") + t0 = time.time() + generate_clip(image_path, prompt, clip_path, neg, seed=seed + idx) + elapsed = time.time() - t0 + print(f" Saved {clip_path.name} ({elapsed:.1f}s)") + + paths.append(clip_path) + if progress_callback: + progress_callback(idx, len(segments)) + + return paths + + +def run( + data_dir: str | Path, + seed: int = SEED, + progress_callback=None, +) -> list[Path]: + """Full video generation pipeline: read segments, generate clips, save. + + Args: + data_dir: Run directory containing segments.json and images/. + seed: Base random seed. + + Returns: + List of saved video clip paths. + """ + data_dir = Path(data_dir) + + with open(data_dir / "segments.json") as f: + segments = json.load(f) + + paths = generate_all( + segments, + images_dir=data_dir / "images", + output_dir=data_dir / "clips", + seed=seed, + progress_callback=progress_callback, + ) + + print(f"\nGenerated {len(paths)} video clips in {data_dir / 'clips'}") + return paths + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print("Usage: python -m src.video_generator_hf ") + print(" e.g. 
python -m src.video_generator_hf data/Gone/run_001") + sys.exit(1) + + run(sys.argv[1]) diff --git a/train_lora.py b/train_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..e2cd4ef91f783976aab19c90654494eeac391d2a --- /dev/null +++ b/train_lora.py @@ -0,0 +1,404 @@ +"""SDXL LoRA training script — run on Google Colab (T4 GPU). + +Trains a style LoRA on SDXL using DreamBooth with 15-20 curated images. +The trained weights (.safetensors) can then be used with image_generator_hf.py / image_generator_api.py. + +Setup: + 1. Open Google Colab with a T4 GPU runtime + 2. Upload this script, or copy each section into separate cells + 3. Upload your style images to lora_training_data/ + 4. Add a .txt caption file alongside each image + 5. Run all cells in order + 6. Download the trained .safetensors from styles/ + +Dataset structure: + lora_training_data/ + image_001.png + image_001.txt # "a sunset landscape with mountains, in sks style" + image_002.jpg + image_002.txt # "a woman silhouetted against warm sky, in sks style" + ... 
+""" + +import json +import subprocess +import sys +from pathlib import Path + + +# --------------------------------------------------------------------------- +# Config — adjust these before training +# --------------------------------------------------------------------------- + +# Trigger word that activates your style in prompts +TRIGGER_WORD = "sks" +INSTANCE_PROMPT = f"a photo in {TRIGGER_WORD} style" + +# Training hyperparameters (tuned for 15-20 images on T4 16GB) +CONFIG = { + "base_model": "stabilityai/stable-diffusion-xl-base-1.0", + "vae": "madebyollin/sdxl-vae-fp16-fix", # fixes fp16 instability + "resolution": 1024, + "train_batch_size": 1, + "gradient_accumulation_steps": 4, # effective batch size = 4 + "learning_rate": 1e-4, + "lr_scheduler": "constant", + "lr_warmup_steps": 0, + "max_train_steps": 1500, # ~100 × num_images + "rank": 16, # LoRA rank (reduced from 32 to fit T4 16GB) + "snr_gamma": 5.0, # Min-SNR weighting for stable convergence + "mixed_precision": "fp16", # T4 doesn't support bf16 + "checkpointing_steps": 500, + "seed": 42, +} + +# Paths +DATASET_DIR = "/content/drive/MyDrive/lora_training_data" +OUTPUT_DIR = "/content/drive/MyDrive/lora_output" +FINAL_WEIGHTS_DIR = "styles" + + +# --------------------------------------------------------------------------- +# 1. 
Install dependencies +# --------------------------------------------------------------------------- + +def install_dependencies(): + """Install training dependencies (run once per Colab session).""" + # Clone diffusers for the training script + if not Path("diffusers").exists(): + subprocess.check_call([ + "git", "clone", "--depth", "1", + "https://github.com/huggingface/diffusers", + ]) + + # Install diffusers from source + DreamBooth requirements + subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-q", "./diffusers", + ]) + subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-q", + "-r", "diffusers/examples/dreambooth/requirements.txt", + ]) + + # Install remaining deps — peft last to ensure correct version + subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-q", + "transformers", "accelerate", + "bitsandbytes", "safetensors", "Pillow", + ]) + subprocess.check_call([ + sys.executable, "-m", "pip", "install", "-q", + "peft>=0.17.0", + ]) + + print("Dependencies installed.") + + +# --------------------------------------------------------------------------- +# 2. Configure accelerate +# --------------------------------------------------------------------------- + +def configure_accelerate(): + """Write a single-GPU accelerate config.""" + from accelerate.utils import write_basic_config + + write_basic_config() + print("Accelerate configured for single GPU.") + + +# --------------------------------------------------------------------------- +# 3. Prepare dataset +# --------------------------------------------------------------------------- + +def verify_dataset(dataset_dir: str = DATASET_DIR) -> int: + """Verify dataset folder has images + metadata.jsonl (no .txt files). + + Args: + dataset_dir: Path to folder on Google Drive. + + Returns: + Number of images found. 
+ """ + dataset_path = Path(dataset_dir) + image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"} + + images = [f for f in dataset_path.iterdir() if f.suffix.lower() in image_extensions] + metadata = dataset_path / "metadata.jsonl" + + if not images: + raise FileNotFoundError(f"No images found in {dataset_dir}/.") + if not metadata.exists(): + raise FileNotFoundError(f"metadata.jsonl not found in {dataset_dir}/.") + + # Warn if .txt files are present (will cause dataset to load as text) + txt_files = [f for f in dataset_path.glob("*.txt")] + if txt_files: + raise RuntimeError( + f"Found .txt files in dataset folder: {[f.name for f in txt_files]}. " + f"Remove them — only images + metadata.jsonl should be present." + ) + + print(f"Dataset OK: {len(images)} images + metadata.jsonl") + return len(images) + + +# --------------------------------------------------------------------------- +# 4. Train +# --------------------------------------------------------------------------- + +def train( + dataset_dir: str = DATASET_DIR, + output_dir: str = OUTPUT_DIR, + resume: bool = False, +): + """Launch DreamBooth LoRA training on SDXL. + + Args: + dataset_dir: Path to prepared dataset. + output_dir: Where to save checkpoints and final weights. + resume: If True, resume from the latest checkpoint. 
+ """ + cfg = CONFIG + + cmd = [ + sys.executable, "-m", "accelerate.commands.launch", + "diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", + f"--pretrained_model_name_or_path={cfg['base_model']}", + f"--pretrained_vae_model_name_or_path={cfg['vae']}", + f"--dataset_name={dataset_dir}", + "--image_column=image", + "--caption_column=prompt", + f"--output_dir={output_dir}", + f"--resolution={cfg['resolution']}", + f"--train_batch_size={cfg['train_batch_size']}", + f"--gradient_accumulation_steps={cfg['gradient_accumulation_steps']}", + "--gradient_checkpointing", + "--use_8bit_adam", + f"--mixed_precision={cfg['mixed_precision']}", + f"--learning_rate={cfg['learning_rate']}", + f"--lr_scheduler={cfg['lr_scheduler']}", + f"--lr_warmup_steps={cfg['lr_warmup_steps']}", + f"--max_train_steps={cfg['max_train_steps']}", + f"--rank={cfg['rank']}", + f"--snr_gamma={cfg['snr_gamma']}", + f"--instance_prompt={INSTANCE_PROMPT}", + f"--checkpointing_steps={cfg['checkpointing_steps']}", + f"--seed={cfg['seed']}", + ] + + if resume: + cmd.append("--resume_from_checkpoint=latest") + + print("Starting training...") + print(f" Model: {cfg['base_model']}") + print(f" Steps: {cfg['max_train_steps']}") + print(f" Rank: {cfg['rank']}") + print(f" LR: {cfg['learning_rate']}") + print(f" Resume: {resume}") + print() + + # Run with live output so progress bar and errors are visible + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + bufsize=1, text=True, + ) + for line in process.stdout: + print(line, end="", flush=True) + process.wait() + if process.returncode != 0: + raise RuntimeError(f"Training failed with exit code {process.returncode}") + + print(f"\nTraining complete! Weights saved to {output_dir}/") + + +# --------------------------------------------------------------------------- +# 5. 
Copy weights to styles/ +# --------------------------------------------------------------------------- + +def export_weights( + output_dir: str = OUTPUT_DIR, + styles_dir: str = FINAL_WEIGHTS_DIR, + style_name: str = "custom-style", +): + """Copy trained LoRA weights to the styles directory. + + Looks for final weights first, falls back to latest checkpoint. + """ + output_path = Path(output_dir) + + # Try final weights first + src = output_path / "pytorch_lora_weights.safetensors" + + # Fall back to latest checkpoint + if not src.exists(): + checkpoints = sorted( + output_path.glob("checkpoint-*"), + key=lambda p: int(p.name.split("-")[1]), + ) + if checkpoints: + latest = checkpoints[-1] + # Check common checkpoint weight locations + for candidate in [ + latest / "pytorch_lora_weights.safetensors", + latest / "unet" / "adapter_model.safetensors", + ]: + if candidate.exists(): + src = candidate + print(f"Using checkpoint: {latest.name}") + break + + if not src.exists(): + raise FileNotFoundError( + f"No weights found in {output_dir}/. " + f"Check that training completed or a checkpoint was saved." + ) + + dst_dir = Path(styles_dir) + dst_dir.mkdir(parents=True, exist_ok=True) + dst = dst_dir / f"{style_name}.safetensors" + + import shutil + shutil.copy2(src, dst) + + size_mb = dst.stat().st_size / (1024 * 1024) + print(f"Exported weights: {dst} ({size_mb:.1f} MB)") + print(f"Download this file and place it in your project's styles/ folder.") + + +# --------------------------------------------------------------------------- +# 6. Backup to Google Drive +# --------------------------------------------------------------------------- + +def backup_to_drive(output_dir: str = OUTPUT_DIR): + """Copy training output to Google Drive for safety. + + Note: If OUTPUT_DIR already points to Drive, this is a no-op. 
+ """ + drive_path = Path("/content/drive/MyDrive/lora_output") + + if Path(output_dir).resolve() == drive_path.resolve(): + print("Output already on Google Drive — no backup needed.") + return + + if not Path("/content/drive/MyDrive").exists(): + from google.colab import drive + drive.mount("/content/drive") + + import shutil + shutil.copytree(output_dir, str(drive_path), dirs_exist_ok=True) + print(f"Backed up to {drive_path}") + + +# --------------------------------------------------------------------------- +# 7. Test inference +# --------------------------------------------------------------------------- + +def test_inference( + output_dir: str = OUTPUT_DIR, + prompt: str = None, +): + """Generate a test image with the trained LoRA + Hyper-SD to verify quality. + + Uses the same setup as image_generator_hf.py for accurate results. + """ + import torch + from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline + from huggingface_hub import hf_hub_download + + if prompt is None: + prompt = f"a serene mountain landscape at golden hour, in {TRIGGER_WORD} style" + + print("Loading model + LoRA for test inference...") + + vae = AutoencoderKL.from_pretrained( + CONFIG["vae"], torch_dtype=torch.float16, + ) + + pipe = DiffusionPipeline.from_pretrained( + CONFIG["base_model"], + vae=vae, + torch_dtype=torch.float16, + variant="fp16", + ).to("cuda") + + # Load Hyper-SD (same as image_generator_hf.py) + hyper_path = hf_hub_download( + "ByteDance/Hyper-SD", "Hyper-SDXL-8steps-CFG-lora.safetensors", + ) + pipe.load_lora_weights(hyper_path, adapter_name="hyper-sd") + + # Load trained style LoRA (check final weights, then latest checkpoint) + output_path = Path(output_dir) + weights_file = output_path / "pytorch_lora_weights.safetensors" + if not weights_file.exists(): + checkpoints = sorted( + output_path.glob("checkpoint-*"), + key=lambda p: int(p.name.split("-")[1]), + ) + if checkpoints: + weights_file = checkpoints[-1] / "pytorch_lora_weights.safetensors" + 
pipe.load_lora_weights( + str(weights_file.parent), + weight_name=weights_file.name, + adapter_name="style", + ) + + pipe.set_adapters( + ["hyper-sd", "style"], + adapter_weights=[0.125, 1.0], + ) + + pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing", + ) + + image = pipe( + prompt=prompt, + negative_prompt="blurry, low quality, deformed, ugly, text, watermark", + num_inference_steps=8, + guidance_scale=5.0, + height=1344, + width=768, + ).images[0] + + image.save("test_output.png") + print(f"Test image saved to test_output.png") + print(f"Prompt: {prompt}") + + return image + + +# --------------------------------------------------------------------------- +# Main — run all steps in sequence +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("=" * 60) + print("SDXL LoRA Training Pipeline") + print("=" * 60) + + # Step 1: Install + install_dependencies() + + # Step 2: Configure + configure_accelerate() + + # Step 3: Verify dataset + num_images = verify_dataset() + steps = max(1500, num_images * 100) + CONFIG["max_train_steps"] = steps + print(f"Adjusted training steps to {steps} ({num_images} images × 100)") + + # Step 4: Train + train() + + # Step 5: Backup + backup_to_drive() + + # Step 6: Export + export_weights(style_name="custom-style") + + # Step 7: Test + test_inference() + + print("\nDone! Download styles/custom-style.safetensors")