ai-techno-dj / stem_separator.py
Rik Hoffbauer
Integrate advanced audio tooling (Rubber Band, Pedalboard, LUFS diagnostics)
e6d9cf7
from __future__ import annotations
import hashlib
import os
import logging
import platform
from pathlib import Path
from typing import Any
import librosa
import numpy as np
# Logger shared with the rest of the DJ engine.
logger = logging.getLogger("dj_engine")

# Root directory for cached stems; the AI_DJ_STEM_CACHE environment variable
# overrides the default relative path.
STEM_CACHE_DIR = Path(os.getenv("AI_DJ_STEM_CACHE", "data/stem_cache"))

# Identifiers of the supported stem-separation backends.
STEM_BACKEND_DEMUCS_MLX = "demucs-mlx"
STEM_BACKEND_DEMUCS = "demucs"
STEM_BACKEND_SPLEETER = "spleeter"
def is_apple_silicon_mac() -> bool:
    """Return True when running on macOS with an ARM64 CPU.

    The check is based on ``platform.system()`` being ``"Darwin"`` and the
    lower-cased ``platform.machine()`` being ``"arm64"`` or ``"aarch64"``.
    """
    if platform.system() != "Darwin":
        return False
    return platform.machine().lower() in ("arm64", "aarch64")
def preferred_stem_backend() -> str:
    """Return the default stem-separation backend for this machine.

    This defers entirely to ``preferred_demucs_backend``.
    """
    backend = preferred_demucs_backend()
    return backend
def preferred_demucs_backend() -> str:
    """Pick the Demucs flavour to use on the current machine.

    Returns:
        ``STEM_BACKEND_DEMUCS_MLX`` when ``is_apple_silicon_mac()`` is true,
        otherwise ``STEM_BACKEND_DEMUCS``.
    """
    return STEM_BACKEND_DEMUCS_MLX if is_apple_silicon_mac() else STEM_BACKEND_DEMUCS
def regular_demucs_device(torch_module: Any) -> str:
    """Choose the device string to run regular Demucs on.

    Args:
        torch_module: The imported ``torch`` module (or any object exposing
            ``cuda.is_available()`` and, optionally, ``backends.mps``).

    Returns:
        ``"cuda"`` if CUDA reports itself available, else ``"mps"`` if an MPS
        backend exists and reports itself available, else ``"cpu"``.
    """
    if torch_module.cuda.is_available():
        return "cuda"

    # Either attribute may be missing on the given module, hence the
    # defensive getattr chain.
    backends = getattr(torch_module, "backends", None)
    mps_backend = getattr(backends, "mps", None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    return "mps" if mps_ready else "cpu"
def _as_stereo(y: np.ndarray) -> np.ndarray:
    """Coerce an audio array to channels-first stereo of shape ``(2, samples)``.

    Handling by input shape:

    * 1-D: the signal is duplicated onto both channels.
    * First axis of length 2: returned unchanged.
    * Last axis of length 2: returned transposed.
    * 2-D with a single channel (``(1, N)`` or ``(N, 1)``): flattened and
      duplicated onto both channels.
    * 2-D with more than two channels: averaged across the channel axis and
      the mono downmix is duplicated. The shorter axis is taken to be the
      channel axis (axis 0 on a tie).
    * Anything else: flattened and duplicated.

    Args:
        y: Audio samples (anything ``np.asarray`` accepts).

    Returns:
        A ``(2, samples)`` array for 1-D and 2-D input.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        return np.stack([y, y])
    if y.shape[0] == 2:
        return y
    if y.shape[-1] == 2:
        return y.T
    if y.ndim == 2 and min(y.shape) > 1:
        # Multi-channel audio (e.g. 5.1). Flattening here would concatenate
        # the channels end-to-end and multiply the track length, so downmix
        # across the channel axis instead.
        channel_axis = 0 if y.shape[0] <= y.shape[1] else 1
        mono = y.mean(axis=channel_axis)
        return np.stack([mono, mono])
    flat = y.reshape(-1)
    return np.stack([flat, flat])
def _load_audio(path: str, samplerate: int, start: float = 0.0, end: float | None = None) -> np.ndarray:
    """Load an audio file, or an excerpt of it, as a stereo array.

    Args:
        path: Audio file to read.
        samplerate: Sample rate passed to ``librosa.load`` as ``sr``.
        start: Excerpt start in seconds; negative values are treated as 0.
        end: Excerpt end in seconds, or ``None`` to read to the end of file.

    Returns:
        The decoded audio passed through ``_as_stereo``.
    """
    # Clamp the start once and derive the duration from the clamped value.
    # Computing the duration from a negative ``start`` would read past ``end``.
    offset = max(0.0, float(start))
    # Keep a tiny positive duration so an inverted or empty range still loads.
    duration = None if end is None else max(0.01, float(end) - offset)
    y, _ = librosa.load(path, sr=samplerate, mono=False, offset=offset, duration=duration)
    return _as_stereo(y)
def _normalize_stems(stems: dict[str, Any], orig_sr: int, target_sr: int) -> dict[str, np.ndarray]:
    """Convert every stem to a stereo NumPy array at the target sample rate.

    Args:
        stems: Mapping of stem name to audio data (anything ``np.asarray``
            accepts).
        orig_sr: Sample rate the stems are currently at.
        target_sr: Sample rate the returned stems should have.

    Returns:
        A new dict with the same keys; each value has been passed through
        ``_as_stereo`` and, when the rates differ, resampled channel by
        channel with ``librosa.resample``.
    """
    needs_resample = orig_sr != target_sr
    result: dict[str, np.ndarray] = {}
    for stem_name, raw_audio in stems.items():
        audio = _as_stereo(np.asarray(raw_audio))
        if needs_resample:
            channels = []
            for channel in audio:
                channels.append(
                    librosa.resample(channel, orig_sr=orig_sr, target_sr=target_sr)
                )
            audio = np.stack(channels)
        result[stem_name] = audio
    return result
def _separate_stems_demucs_mlx(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs") -> dict[str, np.ndarray]:
    """Separate a track into stems with the ``demucs_mlx`` backend.

    Args:
        path: Audio file to separate.
        start: Excerpt start in seconds.
        end: Excerpt end in seconds, or ``None`` for the end of the file.
        sr: Sample rate of the returned stems.
        model_name: Model name handed to ``demucs_mlx.Separator``.

    Returns:
        Mapping of stem name to stereo audio at ``sr``.
    """
    # Imported lazily so the module can be loaded without demucs_mlx installed.
    from demucs_mlx import Separator

    separator = Separator(model=model_name, overlap=0.25, split=True, progress=False)
    native_sr = separator.samplerate

    whole_file = start <= 0 and end is None
    if whole_file:
        # Let the separator read the file itself.
        separated = separator.separate_audio_file(path)
    else:
        # Only an excerpt is wanted: decode it ourselves at the separator's rate.
        excerpt = _load_audio(path, native_sr, start, end)
        separated = separator.separate_tensor(excerpt)

    # Both calls return a pair; only the second element (the stems) is used.
    _, stems = separated
    return _normalize_stems(stems, native_sr, sr)
def _separate_stems_demucs(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs") -> dict[str, np.ndarray]:
    """Separate a track into stems with the regular (PyTorch) Demucs backend.

    Args:
        path: Audio file to separate.
        start: Excerpt start in seconds.
        end: Excerpt end in seconds, or ``None`` for the end of the file.
        sr: Sample rate of the returned stems.
        model_name: Pretrained model name handed to ``demucs.pretrained.get_model``.

    Returns:
        Mapping of stem name (taken from ``model.sources``) to stereo audio
        at ``sr``.
    """
    # Imported lazily so the module can be loaded without torch/demucs installed.
    import torch as _torch
    from demucs.apply import apply_model
    from demucs.pretrained import get_model

    model = get_model(model_name)
    model.eval()
    device = regular_demucs_device(_torch)
    model = model.to(device)

    y = _load_audio(path, model.samplerate, start, end)
    # Add a leading batch dimension: (channels, samples) -> (1, channels, samples).
    wav = _torch.from_numpy(y).float().unsqueeze(0).to(device)

    # Standardise the input using statistics of the channel-averaged signal.
    # Offset and scale are computed once and reused so the de-normalisation
    # below is the exact inverse of the normalisation, epsilon included.
    ref = wav.mean(1, keepdim=True)
    offset = ref.mean()
    scale = ref.std() + 1e-8
    wav_norm = (wav - offset) / scale

    with _torch.no_grad():
        sources = apply_model(model, wav_norm, device=device, split=True, overlap=0.25)

    # Drop the batch dimension and undo the standardisation.
    sources = sources[0] * scale + offset
    stems = {name: source.cpu().numpy() for name, source in zip(model.sources, sources)}
    return _normalize_stems(stems, model.samplerate, sr)
def _separate_stems_spleeter(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "4stems") -> dict[str, np.ndarray]:
    """Separate a track into stems with the Spleeter backend.

    Args:
        path: Audio file to separate.
        start: Excerpt start in seconds.
        end: Excerpt end in seconds, or ``None`` for the end of the file.
        sr: Sample rate of the returned stems.
        model_name: Spleeter configuration name; any value that does not
            contain ``"stems"`` is replaced by ``"4stems"``.

    Returns:
        Mapping of stem name to stereo audio at ``sr``.
    """
    # Imported lazily so the module can be loaded without spleeter installed.
    from spleeter.separator import Separator

    # Spleeter models are trained on 44100Hz audio
    spleeter_sr = 44100

    config = model_name if "stems" in model_name else "4stems"
    separator = Separator(f"spleeter:{config}")

    # spleeter expects (samples, channels), so transpose our channels-first array
    waveform = _load_audio(path, spleeter_sr, start, end).T
    prediction = separator.separate(waveform)

    # prediction returns {stem: (samples, channels)}; flip back to channels-first
    stems = {}
    for stem_name, samples in prediction.items():
        stems[stem_name] = samples.T
    return _normalize_stems(stems, spleeter_sr, sr)
def _cache_key_dir(path: str, backend: str, model_name: str) -> Path:
    """Return the cache directory for a given track/backend/model combination.

    The directory is ``STEM_CACHE_DIR / backend / model_name / <name>_<hash>``,
    where ``<name>`` is the file name without its suffix and ``<hash>`` is the
    first 12 hex digits of the file's SHA-256. Because the hash covers the
    content, a modified file with the same name gets a fresh cache entry.
    If the file cannot be read, the literal ``"nohash"`` is used instead.
    """
    track_name = Path(path).stem

    digest = hashlib.sha256()
    try:
        with open(path, "rb") as handle:
            # Hash in 64 KiB blocks so large audio files are never fully in memory.
            for block in iter(lambda: handle.read(1 << 16), b""):
                digest.update(block)
    except OSError:
        fingerprint = "nohash"
    else:
        fingerprint = digest.hexdigest()[:12]

    leaf = f"{track_name}_{fingerprint}"
    return STEM_CACHE_DIR / backend / model_name / leaf
def _save_stems_to_cache(
    stems: dict[str, np.ndarray], cache_dir: Path, sr: int,
) -> None:
    """Write every stem to ``cache_dir`` as ``<name>.npy`` plus a metadata file.

    Args:
        stems: Mapping of stem name to audio array.
        cache_dir: Target directory; created (with parents) if missing.
        sr: Sample rate of the stems, recorded in ``_meta.txt``.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)

    for stem_name in stems:
        target = cache_dir / f"{stem_name}.npy"
        np.save(target, stems[stem_name])

    # Store sample rate so we can validate on load
    meta_path = cache_dir / "_meta.txt"
    meta_path.write_text(f"sr={sr}\n")

    logger.info("Cached %d stems to %s", len(stems), cache_dir)
def _load_stems_from_cache(
    cache_dir: Path, sr: int,
) -> dict[str, np.ndarray] | None:
    """Load previously cached stems, or return ``None`` on a cache miss.

    A miss is reported when the directory does not exist, when ``_meta.txt``
    is present but unreadable or records a different sample rate, when any
    ``.npy`` file cannot be loaded, or when the directory has no ``.npy``
    files. A missing ``_meta.txt`` is tolerated and the stems are loaded
    without sample-rate validation.

    Args:
        cache_dir: Directory holding ``<stem>.npy`` files.
        sr: Sample rate the caller expects the stems to be at.

    Returns:
        Mapping of stem name (the file name without ``.npy``) to audio array,
        or ``None`` when the cache cannot be used.
    """
    if not cache_dir.is_dir():
        return None

    meta = cache_dir / "_meta.txt"
    if meta.exists():
        try:
            cached_sr = int(meta.read_text().strip().split("=")[1])
        except (OSError, ValueError, IndexError) as exc:
            # Metadata exists but cannot be parsed, so the sample rate is
            # unknown. Treat it as a miss rather than return stems that
            # might be at the wrong rate.
            logger.warning(
                "Unreadable stem cache metadata in %s (%s), re-separating", cache_dir, exc
            )
            return None
        if cached_sr != sr:
            logger.info("Cache SR mismatch (%d vs %d), re-separating", cached_sr, sr)
            return None

    stems: dict[str, np.ndarray] = {}
    # Sorted so the resulting dict order is deterministic across platforms.
    for npy in sorted(cache_dir.glob("*.npy")):
        try:
            stems[npy.stem] = np.load(npy)
        except (OSError, ValueError, EOFError) as exc:
            # A truncated or corrupt file should trigger a fresh separation,
            # not a crash.
            logger.warning("Corrupt cached stem %s (%s), re-separating", npy, exc)
            return None

    if not stems:
        return None
    logger.info("Loaded %d cached stems from %s", len(stems), cache_dir)
    return stems
def separate_stems_with_backend(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs", backend: str | None = None) -> tuple[dict[str, np.ndarray], str]:
    """Separate a track into stems, using the on-disk cache for full tracks.

    Args:
        path: Audio file to separate.
        start: Excerpt start in seconds.
        end: Excerpt end in seconds, or ``None`` for the end of the file.
        sr: Sample rate of the returned stems.
        model_name: Model name forwarded to the backend.
        backend: Backend identifier; ``None`` selects
            ``preferred_stem_backend()``.

    Returns:
        ``(stems, backend_used)``, where ``backend_used`` names the backend
        that actually produced the stems (or whose cache entry was reused).
    """
    if backend is None:
        backend = preferred_stem_backend()

    # Only cache full-track separations (start=0, end=None) since those are
    # the expensive ones called during set rendering.
    cacheable = start <= 0 and end is None

    # Cache directories computed so far, keyed by backend, so the file is not
    # re-hashed for the save when the backend did not change.
    cache_dirs: dict[str, Path] = {}

    if cacheable:
        candidates = [backend]
        if backend == STEM_BACKEND_DEMUCS_MLX:
            # A previous run may have fallen back to regular demucs and cached
            # its result under that backend. Without this lookup that entry
            # would never be hit and every call would re-separate.
            candidates.append(STEM_BACKEND_DEMUCS)
        for candidate in candidates:
            cache_dirs[candidate] = _cache_key_dir(path, candidate, model_name)
            cached = _load_stems_from_cache(cache_dirs[candidate], sr)
            if cached is not None:
                return cached, candidate

    stems, actual_backend = _separate_stems_raw(path, start, end, sr, model_name, backend)

    if cacheable:
        cache_dir = cache_dirs.get(actual_backend)
        if cache_dir is None:
            cache_dir = _cache_key_dir(path, actual_backend, model_name)
        try:
            _save_stems_to_cache(stems, cache_dir, sr)
        except OSError as exc:
            # Caching is best-effort: a full disk or read-only cache must not
            # throw away the separation that was just computed.
            logger.warning("Could not cache stems to %s: %s", cache_dir, exc)

    return stems, actual_backend
def _separate_stems_raw(path: str, start: float, end: float | None, sr: int, model_name: str, backend: str) -> tuple[dict[str, np.ndarray], str]:
    """Perform the actual separation without caching.

    Dispatch rules:

    * ``STEM_BACKEND_SPLEETER`` runs Spleeter and nothing else.
    * ``STEM_BACKEND_DEMUCS_MLX`` tries demucs-mlx first and falls back to
      regular Demucs if that raises.
    * Any other value runs regular Demucs.

    Returns:
        ``(stems, backend_used)``.

    Raises:
        RuntimeError: If both demucs-mlx and the regular Demucs fallback fail.
        Exception: Whatever regular Demucs raised when it was the only
            backend attempted; Spleeter errors propagate unchanged.
    """
    if backend == STEM_BACKEND_SPLEETER:
        spleeter_stems = _separate_stems_spleeter(path, start, end, sr, model_name)
        return spleeter_stems, backend

    mlx_failure: Exception | None = None
    if backend == STEM_BACKEND_DEMUCS_MLX:
        try:
            mlx_stems = _separate_stems_demucs_mlx(path, start, end, sr, model_name)
        except Exception as exc:
            # Remember the failure so it can be reported if the fallback fails too.
            mlx_failure = exc
            logger.warning("demucs-mlx failed on Apple Silicon; falling back to regular demucs: %s", exc)
        else:
            return mlx_stems, STEM_BACKEND_DEMUCS_MLX

    try:
        demucs_stems = _separate_stems_demucs(path, start, end, sr, model_name)
    except Exception as exc:
        if mlx_failure is None:
            raise
        raise RuntimeError(f"demucs-mlx failed: {mlx_failure}; regular demucs failed: {exc}") from exc
    return demucs_stems, STEM_BACKEND_DEMUCS