"""LALAL.AI API wrapper for audio stem separation.

Uploads an audio file to LALAL.AI, starts one stem-separation task per stem
in ``STEMS_TO_EXTRACT``, polls until the tasks finish, downloads the
resulting stem tracks into a per-run directory, and finally deletes the
uploaded source from LALAL.AI.
"""

import os
import shutil
import time
from pathlib import Path
from typing import Optional

import requests

API_BASE = "https://www.lalal.ai/api/v1"
DATA_DIR = Path(__file__).parent.parent / "data"

# Stems we need for the pipeline
STEMS_TO_EXTRACT = ["vocals", "drum"]

# Map LALAL.AI track labels to our file naming convention
LABEL_TO_FILENAME = {"vocals": "vocals.wav", "drum": "drums.wav"}

# Timeout (seconds) for the small JSON API calls. ``requests`` has no default
# timeout, so without one a stalled connection would hang the pipeline forever.
REQUEST_TIMEOUT = 60.0

# (connect, read) timeout in seconds for audio uploads/downloads, which move
# far more data than the JSON endpoints.
TRANSFER_TIMEOUT = (30.0, 300.0)

# Upper bound (seconds) on how long separate_stems() waits for all split
# tasks to finish before giving up.
POLL_MAX_WAIT = 60.0 * 60.0


def _get_api_key() -> str:
    """Return the LALAL.AI license key from the ``LALAL_KEY`` env variable.

    Raises:
        RuntimeError: if ``LALAL_KEY`` is unset or empty.
    """
    key = os.environ.get("LALAL_KEY")
    if not key:
        raise RuntimeError(
            "LALAL_KEY environment variable not set. "
            "Set it locally or as a HuggingFace Space secret."
        )
    return key


def _headers(api_key: str) -> dict:
    """Build the authentication header dict sent with every API request."""
    return {"X-License-Key": api_key}


def _next_run_dir(song_dir: Path) -> Path:
    """Find the next available run directory (run_001, run_002, ...).

    Returns ``song_dir / run_NNN`` where NNN is one more than the highest
    numeric suffix among existing ``run_*`` entries (``run_001`` if none).
    Entries whose suffix is not an integer are ignored. The directory is
    not created here.
    """
    next_num = 1
    for d in song_dir.glob("run_*"):
        try:
            num = int(d.name.split("_")[1])
        except (IndexError, ValueError):
            continue
        next_num = max(next_num, num + 1)
    return song_dir / f"run_{next_num:03d}"


def _upload(audio_path: Path, api_key: str) -> str:
    """Upload an audio file to LALAL.AI and return its ``source_id``.

    The file is streamed as the raw request body; the filename travels in
    the Content-Disposition header.

    Raises:
        requests.HTTPError: if the upload is rejected.
    """
    # NOTE(review): HTTP header values are latin-1 encoded by http.client, so
    # filenames with characters outside latin-1 may fail here — confirm
    # whether LALAL.AI accepts RFC 5987 ``filename*=`` before changing.
    with open(audio_path, "rb") as f:
        resp = requests.post(
            f"{API_BASE}/upload/",
            headers={
                **_headers(api_key),
                "Content-Disposition": f'attachment; filename="{audio_path.name}"',
            },
            data=f,
            timeout=TRANSFER_TIMEOUT,
        )
    resp.raise_for_status()
    data = resp.json()
    source_id = data["id"]
    # Duration is informational only; don't fail a successful upload over it.
    duration = data.get("duration")
    duration_str = (
        f"{duration:.1f}s" if isinstance(duration, (int, float)) else "unknown"
    )
    print(f" Uploaded {audio_path.name} → source_id={source_id} "
          f"(duration: {duration_str})")
    return source_id


def _split_stem(source_id: str, stem: str, api_key: str) -> str:
    """Start a stem separation task for an uploaded source. Returns task_id.

    Raises:
        requests.HTTPError: if LALAL.AI rejects the split request.
    """
    # Andromeda is best for vocals but doesn't support all stems — use auto
    # (splitter sent as JSON null) for others.
    splitter = "andromeda" if stem == "vocals" else None
    resp = requests.post(
        f"{API_BASE}/split/stem_separator/",
        headers=_headers(api_key),
        json={
            "source_id": source_id,
            "presets": {
                "stem": stem,
                "splitter": splitter,
                "dereverb_enabled": False,
                "encoder_format": "wav",
                "extraction_level": "deep_extraction",
            },
        },
        timeout=REQUEST_TIMEOUT,
    )
    resp.raise_for_status()
    data = resp.json()
    task_id = data["task_id"]
    print(f" Split task started: stem={stem}, task_id={task_id}")
    return task_id


def _poll_tasks(
    task_ids: list[str],
    api_key: str,
    poll_interval: float = 5.0,
    max_wait: Optional[float] = None,
) -> dict:
    """Poll tasks until all complete. Returns {task_id: result_data}.

    Args:
        task_ids: LALAL.AI task ids to wait for.
        api_key: License key for the API.
        poll_interval: Seconds to sleep between check requests.
        max_wait: Give up after this many seconds. ``None`` waits forever.

    Raises:
        RuntimeError: if a task errors, is cancelled, hits a server error,
            or ``max_wait`` elapses with tasks still pending.
        requests.HTTPError: if a check request fails.
    """
    pending = set(task_ids)
    results = {}
    deadline = None if max_wait is None else time.monotonic() + max_wait
    while pending:
        resp = requests.post(
            f"{API_BASE}/check/",
            headers=_headers(api_key),
            json={"task_ids": sorted(pending)},
            timeout=REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        payload = resp.json()
        # Task statuses may be nested under "result" or be the top level.
        data = payload.get("result", payload)
        for task_id, info in data.items():
            # Only act on tasks we are still waiting for.
            if task_id not in pending or not isinstance(info, dict):
                continue
            status = info.get("status")
            if status == "success":
                results[task_id] = info
                pending.discard(task_id)
                print(f" Task {task_id}: complete")
            elif status == "progress":
                print(f" Task {task_id}: {info.get('progress', 0)}%")
            elif status == "error":
                error = info.get("error") or {}
                raise RuntimeError(
                    f"LALAL.AI task {task_id} failed: "
                    f"{error.get('detail', 'unknown error')} "
                    f"(code: {error.get('code')})"
                )
            elif status == "cancelled":
                raise RuntimeError(f"LALAL.AI task {task_id} was cancelled")
            elif status == "server_error":
                raise RuntimeError(
                    f"LALAL.AI server error for task {task_id}: "
                    f"{info.get('error', 'unknown')}"
                )
        if pending:
            # Without a deadline, a task that never reaches a terminal status
            # (or is absent from the response) would spin forever.
            if deadline is not None and time.monotonic() >= deadline:
                raise RuntimeError(
                    f"LALAL.AI tasks did not finish within {max_wait:.0f}s: "
                    f"{', '.join(sorted(pending))}"
                )
            time.sleep(poll_interval)
    return results


def _download_track(url: str, output_path: Path) -> None:
    """Download a track from LALAL.AI CDN to ``output_path``.

    The data is written to a sibling ``.part`` file and renamed into place
    only after the whole body arrived, so a failed download never leaves a
    truncated stem at ``output_path``.

    Raises:
        requests.HTTPError: if the CDN returns an error status.
    """
    tmp_path = output_path.with_name(output_path.name + ".part")
    try:
        # ``with`` releases the streamed connection even if writing fails.
        with requests.get(url, stream=True, timeout=TRANSFER_TIMEOUT) as resp:
            resp.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=65536):
                    f.write(chunk)
        tmp_path.replace(output_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
    print(f" Downloaded → {output_path.name} "
          f"({output_path.stat().st_size / 1024:.0f} KB)")


def _delete_source(source_id: str, api_key: str) -> None:
    """Delete uploaded source file from LALAL.AI servers (best effort).

    Network or HTTP failures are reported but not raised: the stems are
    already on disk, so a leftover remote file is not fatal.
    """
    try:
        resp = requests.post(
            f"{API_BASE}/delete/",
            headers=_headers(api_key),
            json={"source_id": source_id},
            timeout=REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
    except requests.RequestException as exc:
        print(f" Warning: could not delete remote source {source_id}: {exc}")
        return
    print(f" Cleaned up remote source {source_id}")


def separate_stems(
    audio_path: str | Path,
    output_dir: Optional[str | Path] = None,
) -> dict[str, Path]:
    """Separate an audio file into vocals and drums using LALAL.AI.

    Creates a new run directory for each invocation so multiple runs on the
    same song don't overwrite each other. A copy of the original song is kept
    in the song directory (shared across runs). The uploaded source is
    deleted from LALAL.AI afterwards, even if separation fails.

    Args:
        audio_path: Path to the input audio file (mp3/wav) from input/.
        output_dir: Directory to save stems. If None, auto-creates
            data/<song_name>/run_NNN/stems/.

    Returns:
        Dict mapping stem names to their file paths.
        Keys: "drums", "vocals", "run_dir"

    Raises:
        RuntimeError: if ``LALAL_KEY`` is missing, a task fails, or no stem
            track is returned.
        requests.HTTPError: if an API call fails.
    """
    audio_path = Path(audio_path)
    song_name = audio_path.stem
    song_dir = DATA_DIR / song_name
    api_key = _get_api_key()

    if output_dir is None:
        run_dir = _next_run_dir(song_dir)
        output_dir = run_dir / "stems"
    else:
        output_dir = Path(output_dir)
        run_dir = output_dir.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Copy original song into song directory (shared across runs). With a
    # custom output_dir the song dir may not exist yet, so create it.
    song_dir.mkdir(parents=True, exist_ok=True)
    song_copy = song_dir / audio_path.name
    if not song_copy.exists():
        shutil.copy2(audio_path, song_copy)

    # 1. Upload
    print("Stem separation (LALAL.AI):")
    source_id = _upload(audio_path, api_key)

    try:
        # 2. Start split tasks for each stem
        task_to_stem = {}
        for stem in STEMS_TO_EXTRACT:
            task_id = _split_stem(source_id, stem, api_key)
            task_to_stem[task_id] = stem

        # 3. Poll until all tasks complete
        results = _poll_tasks(
            list(task_to_stem), api_key, max_wait=POLL_MAX_WAIT
        )

        # 4. Download the separated stem tracks
        stem_paths = {"run_dir": run_dir}
        for task_id, result_data in results.items():
            stem = task_to_stem[task_id]
            filename = LABEL_TO_FILENAME[stem]
            tracks = (result_data.get("result") or {}).get("tracks") or []
            # Find the "stem" track (not the "back"/inverse track)
            stem_track = next(
                (t for t in tracks if t.get("type") == "stem"), None
            )
            if stem_track is None:
                raise RuntimeError(f"No stem track found in result for {stem}")
            output_path = output_dir / filename
            _download_track(stem_track["url"], output_path)
            # Result key follows our file naming: "drum" API stem → "drums"
            stem_paths[Path(filename).stem] = output_path
    finally:
        # 5. Cleanup remote files — also on failure, so uploads don't linger.
        _delete_source(source_id, api_key)

    return stem_paths


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python -m src.stem_separator <audio_path>")
        sys.exit(1)
    result = separate_stems(sys.argv[1])
    print(f"Run directory: {result['run_dir']}")
    for name, path in result.items():
        if name != "run_dir":
            print(f" {name}: {path}")