from __future__ import annotations import hashlib import os import logging import platform from pathlib import Path from typing import Any import librosa import numpy as np logger = logging.getLogger("dj_engine") STEM_CACHE_DIR = Path(os.environ.get("AI_DJ_STEM_CACHE", "data/stem_cache")) STEM_BACKEND_DEMUCS_MLX = "demucs-mlx" STEM_BACKEND_DEMUCS = "demucs" STEM_BACKEND_SPLEETER = "spleeter" def is_apple_silicon_mac() -> bool: return platform.system() == "Darwin" and platform.machine().lower() in {"arm64", "aarch64"} def preferred_stem_backend() -> str: return preferred_demucs_backend() def preferred_demucs_backend() -> str: if is_apple_silicon_mac(): return STEM_BACKEND_DEMUCS_MLX return STEM_BACKEND_DEMUCS def regular_demucs_device(torch_module: Any) -> str: if torch_module.cuda.is_available(): return "cuda" mps = getattr(getattr(torch_module, "backends", None), "mps", None) if mps is not None and mps.is_available(): return "mps" return "cpu" def _as_stereo(y: np.ndarray) -> np.ndarray: y = np.asarray(y) if y.ndim == 1: return np.stack([y, y]) if y.shape[0] != 2 and y.shape[-1] == 2: return y.T if y.shape[0] != 2: flat = y.reshape(-1) return np.stack([flat, flat]) return y def _load_audio(path: str, samplerate: int, start: float = 0.0, end: float | None = None) -> np.ndarray: duration = None if end is None else max(0.01, float(end) - float(start)) y, _ = librosa.load(path, sr=samplerate, mono=False, offset=max(0.0, float(start)), duration=duration) return _as_stereo(y) def _normalize_stems(stems: dict[str, Any], orig_sr: int, target_sr: int) -> dict[str, np.ndarray]: normalized: dict[str, np.ndarray] = {} for name, source in stems.items(): stem = _as_stereo(np.asarray(source)) if orig_sr != target_sr: stem = np.stack([ librosa.resample(stem[c], orig_sr=orig_sr, target_sr=target_sr) for c in range(stem.shape[0]) ]) normalized[name] = stem return normalized def _separate_stems_demucs_mlx(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs") -> dict[str, np.ndarray]: from demucs_mlx import Separator separator = Separator(model=model_name, overlap=0.25, split=True, progress=False) if start <= 0 and end is None: _, stems = separator.separate_audio_file(path) else: y = _load_audio(path, separator.samplerate, start, end) _, stems = separator.separate_tensor(y) return _normalize_stems(stems, separator.samplerate, sr) def _separate_stems_demucs(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs") -> dict[str, np.ndarray]: import torch as _torch from demucs.apply import apply_model from demucs.pretrained import get_model model = get_model(model_name) model.eval() device = regular_demucs_device(_torch) model = model.to(device) y = _load_audio(path, model.samplerate, start, end) wav = _torch.from_numpy(y).float().unsqueeze(0).to(device) ref = wav.mean(1, keepdim=True) wav_norm = (wav - ref.mean()) / (ref.std() + 1e-8) with _torch.no_grad(): sources = apply_model(model, wav_norm, device=device, split=True, overlap=0.25) sources = sources[0] * ref.std() + ref.mean() stems = {name: source.cpu().numpy() for name, source in zip(model.sources, sources)} return _normalize_stems(stems, model.samplerate, sr) def _separate_stems_spleeter(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "4stems") -> dict[str, np.ndarray]: from spleeter.separator import Separator if "stems" not in model_name: model_name = "4stems" separator = Separator(f'spleeter:{model_name}') # Spleeter models are trained on 44100Hz audio y = _load_audio(path, 44100, start, end) # spleeter expects (samples, channels) y_t = y.T prediction = separator.separate(y_t) # prediction returns {stem: (samples, channels)} stems = {name: data.T for name, data in prediction.items()} return _normalize_stems(stems, 44100, sr) def _cache_key_dir(path: str, backend: str, model_name: str) -> Path: """Return the cache directory for a given track/backend/model combination.""" track_stem = Path(path).stem # Include a short content hash so renamed-but-identical files share cache # and modified files with the same name get a fresh separation. try: h = hashlib.sha256() with open(path, "rb") as f: while chunk := f.read(1 << 16): h.update(chunk) content_hash = h.hexdigest()[:12] except OSError: content_hash = "nohash" return STEM_CACHE_DIR / backend / model_name / f"{track_stem}_{content_hash}" def _save_stems_to_cache( stems: dict[str, np.ndarray], cache_dir: Path, sr: int, ) -> None: cache_dir.mkdir(parents=True, exist_ok=True) for name, audio in stems.items(): np.save(cache_dir / f"{name}.npy", audio) # Store sample rate so we can validate on load (cache_dir / "_meta.txt").write_text(f"sr={sr}\n") logger.info("Cached %d stems to %s", len(stems), cache_dir) def _load_stems_from_cache( cache_dir: Path, sr: int, ) -> dict[str, np.ndarray] | None: if not cache_dir.is_dir(): return None meta = cache_dir / "_meta.txt" if meta.exists(): try: cached_sr = int(meta.read_text().strip().split("=")[1]) if cached_sr != sr: logger.info("Cache SR mismatch (%d vs %d), re-separating", cached_sr, sr) return None except (ValueError, IndexError): pass stems: dict[str, np.ndarray] = {} for npy in cache_dir.glob("*.npy"): stems[npy.stem] = np.load(npy) if not stems: return None logger.info("Loaded %d cached stems from %s", len(stems), cache_dir) return stems def separate_stems_with_backend(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs", backend: str | None = None) -> tuple[dict[str, np.ndarray], str]: if backend is None: backend = preferred_stem_backend() # Only cache full-track separations (start=0, end=None) since those are # the expensive ones called during set rendering. cacheable = start <= 0 and end is None if cacheable: cache_dir = _cache_key_dir(path, backend, model_name) cached = _load_stems_from_cache(cache_dir, sr) if cached is not None: return cached, backend stems, actual_backend = _separate_stems_raw(path, start, end, sr, model_name, backend) if cacheable: cache_dir = _cache_key_dir(path, actual_backend, model_name) _save_stems_to_cache(stems, cache_dir, sr) return stems, actual_backend def _separate_stems_raw(path: str, start: float, end: float | None, sr: int, model_name: str, backend: str) -> tuple[dict[str, np.ndarray], str]: """Perform the actual separation without caching.""" if backend == STEM_BACKEND_SPLEETER: return _separate_stems_spleeter(path, start, end, sr, model_name), backend mlx_error: Exception | None = None if backend == STEM_BACKEND_DEMUCS_MLX: try: return _separate_stems_demucs_mlx(path, start, end, sr, model_name), STEM_BACKEND_DEMUCS_MLX except Exception as exc: mlx_error = exc logger.warning("demucs-mlx failed on Apple Silicon; falling back to regular demucs: %s", exc) try: return _separate_stems_demucs(path, start, end, sr, model_name), STEM_BACKEND_DEMUCS except Exception as exc: if mlx_error is not None: raise RuntimeError(f"demucs-mlx failed: {mlx_error}; regular demucs failed: {exc}") from exc raise