"""LALAL.AI API wrapper for audio stem separation."""
| |
|
import os
import shutil
import time
from pathlib import Path
from typing import Optional
from urllib.parse import quote

import requests
| |
|
# Root URL that every LALAL.AI endpoint path in this module is appended to.
API_BASE = "https://www.lalal.ai/api/v1"

# Local data folder: the "data" directory next to this file's parent directory.
DATA_DIR = Path(__file__).parent.parent.joinpath("data")

# Stem labels requested from the API, in submission order.
STEMS_TO_EXTRACT = [
    "vocals",
    "drum",
]

# Filename each requested stem label is saved under.
LABEL_TO_FILENAME = {
    "vocals": "vocals.wav",
    "drum": "drums.wav",
}
| |
|
| |
|
def _get_api_key() -> str:
    """Return the LALAL.AI license key stored in the ``LALAL_KEY`` env variable.

    Raises:
        RuntimeError: If ``LALAL_KEY`` is unset or empty.
    """
    license_key = os.environ.get("LALAL_KEY")
    if license_key:
        return license_key
    raise RuntimeError(
        "LALAL_KEY environment variable not set. "
        "Set it locally or as a HuggingFace Space secret."
    )
| |
|
| |
|
def _headers(api_key: str) -> dict:
    """Build the authentication header dict sent with each LALAL.AI request."""
    auth_headers = {}
    auth_headers["X-License-Key"] = api_key
    return auth_headers
| |
|
| |
|
def _next_run_dir(song_dir: Path) -> Path:
    """Return the path of the next unused run directory under ``song_dir``.

    Entries named ``run_<number>`` are scanned and the result is numbered one
    above the highest number found (``run_001`` when there are none). Entries
    whose suffix is not an integer are ignored. The directory is not created.
    """
    highest_seen = 0
    for entry in song_dir.glob("run_*"):
        pieces = entry.name.split("_")
        try:
            highest_seen = max(highest_seen, int(pieces[1]))
        except (IndexError, ValueError):
            # Not a numbered run directory (e.g. "run_backup"); skip it.
            pass
    return song_dir / f"run_{highest_seen + 1:03d}"
| |
|
| |
|
def _content_disposition(filename: str) -> str:
    """Build a Content-Disposition header value that is valid for any filename.

    Plain printable-ASCII names keep the simple quoted form. Anything else
    (non-ASCII characters, quotes, backslashes, control characters) uses the
    RFC 6266 / RFC 5987 ``filename*=UTF-8''<percent-encoded>`` form, because
    HTTP header values cannot carry arbitrary Unicode and an embedded quote
    would terminate the quoted string early.
    """
    is_plain = (
        filename.isascii()
        and filename.isprintable()
        and '"' not in filename
        and "\\" not in filename
    )
    if is_plain:
        return f'attachment; filename="{filename}"'
    return f"attachment; filename*=UTF-8''{quote(filename, safe='')}"


def _upload(audio_path: Path, api_key: str, *, timeout: float = 300.0) -> str:
    """Upload an audio file to LALAL.AI and return its source id.

    Args:
        audio_path: Local file to upload; streamed as the raw request body.
        api_key: LALAL.AI license key.
        timeout: Seconds passed to ``requests`` as the request timeout, so a
            stalled connection raises instead of hanging forever.

    Returns:
        The ``id`` field of the JSON response.

    Raises:
        requests.RequestException: On network failure, timeout or HTTP error.
        KeyError: If the response has no ``id`` field.
    """
    with open(audio_path, "rb") as f:
        resp = requests.post(
            f"{API_BASE}/upload/",
            headers={
                **_headers(api_key),
                "Content-Disposition": _content_disposition(audio_path.name),
            },
            data=f,
            timeout=timeout,
        )
    resp.raise_for_status()
    data = resp.json()
    source_id = data["id"]

    # The duration is only used for logging; never fail a successful upload
    # (and lose the source_id) because that field is absent or not numeric.
    duration = data.get("duration")
    duration_text = (
        f"{duration:.1f}s" if isinstance(duration, (int, float)) else "unknown"
    )
    print(f" Uploaded {audio_path.name} → source_id={source_id} "
          f"(duration: {duration_text})")
    return source_id
| |
|
| |
|
def _split_stem(
    source_id: str,
    stem: str,
    api_key: str,
    *,
    timeout: float = 30.0,
) -> str:
    """Start a stem separation task and return its task id.

    Args:
        source_id: Id of a previously uploaded source file.
        stem: Stem label to extract (e.g. ``"vocals"`` or ``"drum"``).
        api_key: LALAL.AI license key.
        timeout: Seconds passed to ``requests`` as the request timeout, so a
            stalled connection raises instead of hanging forever.

    Returns:
        The ``task_id`` field of the JSON response.

    Raises:
        requests.RequestException: On network failure, timeout or HTTP error.
        KeyError: If the response has no ``task_id`` field.
    """
    # Vocals explicitly request the "andromeda" splitter; every other stem
    # sends null. NOTE(review): presumably null lets the service choose a
    # default splitter; confirm against the API documentation.
    splitter = "andromeda" if stem == "vocals" else None
    resp = requests.post(
        f"{API_BASE}/split/stem_separator/",
        headers=_headers(api_key),
        json={
            "source_id": source_id,
            "presets": {
                "stem": stem,
                "splitter": splitter,
                "dereverb_enabled": False,
                "encoder_format": "wav",
                "extraction_level": "deep_extraction",
            },
        },
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    task_id = data["task_id"]
    print(f" Split task started: stem={stem}, task_id={task_id}")
    return task_id
| |
|
| |
|
def _poll_tasks(
    task_ids: list[str],
    api_key: str,
    poll_interval: float = 5.0,
    max_wait: Optional[float] = 1800.0,
    request_timeout: float = 30.0,
) -> dict:
    """Poll tasks until all complete. Returns {task_id: result_data}.

    Args:
        task_ids: Ids of the tasks to wait for. An empty list returns ``{}``
            without contacting the API.
        api_key: LALAL.AI license key.
        poll_interval: Seconds to sleep between polling rounds.
        max_wait: Overall limit in seconds for all tasks to finish; ``None``
            waits indefinitely.
        request_timeout: Seconds passed to ``requests`` for each status call.

    Returns:
        Mapping of task id to the status entry reported for it on success.

    Raises:
        RuntimeError: If a task reports ``error``, ``cancelled`` or
            ``server_error``, or if ``max_wait`` elapses with tasks pending.
        requests.RequestException: On network failure, timeout or HTTP error.
    """
    pending = set(task_ids)
    results = {}
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while pending:
        resp = requests.post(
            f"{API_BASE}/check/",
            headers=_headers(api_key),
            json={"task_ids": list(pending)},
            timeout=request_timeout,
        )
        resp.raise_for_status()
        payload = resp.json()
        data = payload.get("result", payload)

        for task_id, info in data.items():
            # Skip entries we are not waiting for: already-finished tasks and
            # any non-task keys present when the whole payload is the fallback.
            if task_id not in pending:
                continue

            status = info.get("status")
            if status == "success":
                results[task_id] = info
                pending.discard(task_id)
                print(f" Task {task_id}: complete")
            elif status == "progress":
                print(f" Task {task_id}: {info.get('progress', 0)}%")
            elif status == "error":
                error = info.get("error", {})
                raise RuntimeError(
                    f"LALAL.AI task {task_id} failed: "
                    f"{error.get('detail', 'unknown error')} "
                    f"(code: {error.get('code')})"
                )
            elif status == "cancelled":
                raise RuntimeError(f"LALAL.AI task {task_id} was cancelled")
            elif status == "server_error":
                raise RuntimeError(
                    f"LALAL.AI server error for task {task_id}: "
                    f"{info.get('error', 'unknown')}"
                )
            # Any other status is treated as still pending; the deadline
            # below keeps an unrecognised status from looping forever.

        if not pending:
            break
        if deadline is not None and time.monotonic() >= deadline:
            raise RuntimeError(
                f"Timed out after {max_wait:.0f}s waiting for LALAL.AI "
                f"tasks: {sorted(pending)}"
            )
        time.sleep(poll_interval)

    return results
| |
|
| |
|
def _download_track(url: str, output_path: Path, *, timeout: float = 60.0) -> None:
    """Download a track from LALAL.AI CDN.

    The body is streamed into ``<output_path>.part`` and renamed to
    ``output_path`` only once complete, so a failed or interrupted download
    never leaves a truncated file at the final location.

    Args:
        url: Download URL of the track.
        output_path: Destination file; its directory must already exist.
        timeout: Seconds passed to ``requests`` as the request timeout.

    Raises:
        requests.RequestException: On network failure, timeout or HTTP error.
        OSError: If the file cannot be written or renamed.
    """
    partial_path = output_path.with_name(output_path.name + ".part")
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as resp:
            resp.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        partial_path.replace(output_path)
    except BaseException:
        # Remove the incomplete file, then let the original error propagate.
        partial_path.unlink(missing_ok=True)
        raise
    print(f" Downloaded → {output_path.name} ({output_path.stat().st_size / 1024:.0f} KB)")
| |
|
| |
|
def _delete_source(source_id: str, api_key: str, *, timeout: float = 30.0) -> None:
    """Delete uploaded source file from LALAL.AI servers.

    This is best-effort cleanup: request failures are reported as a warning
    and never raised, so they cannot mask the outcome of the actual job.

    Args:
        source_id: Id of the uploaded source to remove.
        api_key: LALAL.AI license key.
        timeout: Seconds passed to ``requests`` as the request timeout.
    """
    try:
        resp = requests.post(
            f"{API_BASE}/delete/",
            headers=_headers(api_key),
            json={"source_id": source_id},
            timeout=timeout,
        )
        # Without this check an HTTP error response would be logged as success.
        resp.raise_for_status()
    except requests.RequestException as exc:
        print(f" Warning: could not delete remote source {source_id}: {exc}")
        return
    print(f" Cleaned up remote source {source_id}")
| |
|
| |
|
def separate_stems(
    audio_path: str | Path,
    output_dir: Optional[str | Path] = None,
) -> dict[str, Path]:
    """Separate an audio file into vocals and drums using LALAL.AI.

    Creates a new run directory for each invocation so multiple runs
    on the same song don't overwrite each other. The input file is also
    copied into ``data/<song>/`` if no copy exists there yet.

    Args:
        audio_path: Path to the input audio file (mp3/wav) from input/.
        output_dir: Directory to save stems. If None, auto-creates
            data/<song>/run_NNN/stems/.

    Returns:
        Dict mapping stem names to their file paths.
        Keys: "drums", "vocals", "run_dir"

    Raises:
        RuntimeError: If the API key is missing, a task fails, or a result
            contains no stem track.
        FileNotFoundError: If ``audio_path`` is not an existing file.
    """
    audio_path = Path(audio_path)
    song_name = audio_path.stem
    song_dir = DATA_DIR / song_name
    api_key = _get_api_key()

    # Fail before creating any directories or contacting the API.
    if not audio_path.is_file():
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    if output_dir is None:
        run_dir = _next_run_dir(song_dir)
        output_dir = run_dir / "stems"
    else:
        output_dir = Path(output_dir)
        run_dir = output_dir.parent

    output_dir.mkdir(parents=True, exist_ok=True)
    # With an explicit output_dir the song directory may not exist yet, and
    # the copy below needs it.
    song_dir.mkdir(parents=True, exist_ok=True)

    # Keep a copy of the original song next to its runs.
    song_copy = song_dir / audio_path.name
    if not song_copy.exists():
        shutil.copy2(audio_path, song_copy)

    print("Stem separation (LALAL.AI):")
    source_id = _upload(audio_path, api_key)

    stem_paths = {"run_dir": run_dir}
    try:
        # Start one separation task per requested stem.
        task_to_stem = {}
        for stem in STEMS_TO_EXTRACT:
            task_id = _split_stem(source_id, stem, api_key)
            task_to_stem[task_id] = stem

        results = _poll_tasks(list(task_to_stem), api_key)

        for task_id, result_data in results.items():
            stem = task_to_stem[task_id]
            filename = LABEL_TO_FILENAME[stem]
            tracks = (result_data.get("result") or {}).get("tracks") or []

            # Each result may list several tracks; only the one typed "stem"
            # is downloaded.
            stem_track = next(
                (t for t in tracks if t.get("type") == "stem"), None
            )
            if stem_track is None:
                raise RuntimeError(f"No stem track found in result for {stem}")

            output_path = output_dir / filename
            _download_track(stem_track["url"], output_path)

            # The API label is "drum" but callers expect the key "drums".
            key = "drums" if stem == "drum" else stem
            stem_paths[key] = output_path
    finally:
        # Remove the uploaded source whether or not the run succeeded.
        _delete_source(source_id, api_key)

    return stem_paths
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: separate the stems of one audio file and
    # list where each output was written.
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m src.stem_separator <audio_file>")
        sys.exit(1)

    outputs = separate_stems(cli_args[0])
    print(f"Run directory: {outputs['run_dir']}")
    for label, stem_file in outputs.items():
        if label == "run_dir":
            continue
        print(f" {label}: {stem_file}")
| |
|