Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import hashlib | |
| import os | |
| import logging | |
| import platform | |
| from pathlib import Path | |
| from typing import Any | |
| import librosa | |
| import numpy as np | |
# Module-wide logger shared by the DJ engine components.
logger = logging.getLogger("dj_engine")

# Root directory for cached full-track stem separations. Override it with the
# AI_DJ_STEM_CACHE environment variable; the default is a relative path, so it
# resolves against the process's current working directory.
STEM_CACHE_DIR = Path(os.getenv("AI_DJ_STEM_CACHE", "data/stem_cache"))

# Backend identifiers understood by separate_stems_with_backend(). They double
# as directory names underneath STEM_CACHE_DIR.
STEM_BACKEND_DEMUCS_MLX = "demucs-mlx"  # Demucs via the demucs_mlx package
STEM_BACKEND_DEMUCS = "demucs"  # regular (PyTorch) Demucs
STEM_BACKEND_SPLEETER = "spleeter"
def is_apple_silicon_mac() -> bool:
    """Return True when running on macOS ("Darwin") on an ARM64 machine."""
    if platform.system() != "Darwin":
        return False
    # platform.machine() reports "arm64" on macOS; "aarch64" is accepted too.
    return platform.machine().lower() in ("arm64", "aarch64")
def preferred_stem_backend() -> str:
    """Return the stem-separation backend to use when the caller names none.

    This simply defers to ``preferred_demucs_backend()``, so the default is
    always a Demucs flavour; Spleeter is only used when requested explicitly.
    """
    default_backend = preferred_demucs_backend()
    return default_backend
def preferred_demucs_backend() -> str:
    """Pick which Demucs flavour to try first.

    Returns ``STEM_BACKEND_DEMUCS_MLX`` on Apple Silicon Macs and
    ``STEM_BACKEND_DEMUCS`` (the regular PyTorch implementation) everywhere else.
    """
    return STEM_BACKEND_DEMUCS_MLX if is_apple_silicon_mac() else STEM_BACKEND_DEMUCS
def regular_demucs_device(torch_module: Any) -> str:
    """Choose the torch device string for regular Demucs.

    Preference order is ``"cuda"``, then ``"mps"`` (only if the module exposes
    ``backends.mps`` and it reports itself available), then ``"cpu"``.

    Args:
        torch_module: The imported ``torch`` module. It is passed in so that the
            caller controls when torch is imported.
    """
    if torch_module.cuda.is_available():
        return "cuda"
    # Older torch builds may lack torch.backends.mps entirely, so probe softly.
    backends = getattr(torch_module, "backends", None)
    mps_backend = getattr(backends, "mps", None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    return "mps" if mps_ready else "cpu"
| def _as_stereo(y: np.ndarray) -> np.ndarray: | |
| y = np.asarray(y) | |
| if y.ndim == 1: | |
| return np.stack([y, y]) | |
| if y.shape[0] != 2 and y.shape[-1] == 2: | |
| return y.T | |
| if y.shape[0] != 2: | |
| flat = y.reshape(-1) | |
| return np.stack([flat, flat]) | |
| return y | |
def _load_audio(path: str, samplerate: int, start: float = 0.0, end: float | None = None) -> np.ndarray:
    """Decode *path* at *samplerate* and return it as stereo via ``_as_stereo``.

    Args:
        path: Audio file to read with ``librosa.load``.
        samplerate: Target sample rate passed to librosa (it resamples).
        start: Window start in seconds; negative values are clamped to 0.
        end: Window end in seconds, or ``None`` to read to the end of the file.
    """
    offset = max(0.0, float(start))
    # Measure the window from the *clamped* offset. Using the raw start would
    # make a negative start stretch the clip past `end`. The 10 ms floor
    # (kept from the original) still applies to empty or inverted windows.
    duration = None if end is None else max(0.01, float(end) - offset)
    y, _ = librosa.load(path, sr=samplerate, mono=False, offset=offset, duration=duration)
    return _as_stereo(y)
def _normalize_stems(stems: dict[str, Any], orig_sr: int, target_sr: int) -> dict[str, np.ndarray]:
    """Return a copy of *stems* with every stem made stereo and at *target_sr*.

    Each value is converted with ``np.asarray`` (so backend tensors/arrays are
    accepted) and passed through ``_as_stereo``. When the rates differ, every
    channel row is resampled independently with ``librosa.resample``.
    """
    def _conform(source: Any) -> np.ndarray:
        stereo = _as_stereo(np.asarray(source))
        if orig_sr == target_sr:
            return stereo
        return np.stack([
            librosa.resample(channel, orig_sr=orig_sr, target_sr=target_sr)
            for channel in stereo
        ])

    return {name: _conform(source) for name, source in stems.items()}
def _separate_stems_demucs_mlx(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs") -> dict[str, np.ndarray]:
    """Separate *path* into stems using the ``demucs_mlx`` package.

    For a whole track (``start <= 0`` and no *end*), the file is handed straight
    to ``Separator.separate_audio_file``. Otherwise only the requested window is
    decoded with ``_load_audio`` at the separator's sample rate and passed to
    ``Separator.separate_tensor``. The resulting stems are converted to stereo
    at *sr* by ``_normalize_stems``.

    NOTE(review): ``separate_tensor`` receives a NumPy array here; confirm that
    demucs_mlx accepts NumPy input rather than requiring an ``mx.array``.
    """
    from demucs_mlx import Separator

    separator = Separator(model=model_name, overlap=0.25, split=True, progress=False)
    whole_track = start <= 0 and end is None
    if not whole_track:
        clip = _load_audio(path, separator.samplerate, start, end)
        _, raw_stems = separator.separate_tensor(clip)
    else:
        _, raw_stems = separator.separate_audio_file(path)
    return _normalize_stems(raw_stems, separator.samplerate, sr)
def _separate_stems_demucs(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs") -> dict[str, np.ndarray]:
    """Separate *path* into stems with regular (PyTorch) Demucs.

    The pretrained *model_name* runs on the device from
    ``regular_demucs_device`` (CUDA, then MPS, then CPU). The window
    ``[start, end)`` is decoded at the model's sample rate, standardized by the
    statistics of its channel-mean downmix, separated with ``apply_model``
    (split into overlapping chunks), and then de-standardized.

    Returns:
        ``{source_name: stem}`` keyed by ``model.sources``, converted to stereo at
        *sr* by ``_normalize_stems``.
    """
    import torch as _torch
    from demucs.apply import apply_model
    from demucs.pretrained import get_model

    model = get_model(model_name)
    model.eval()
    device = regular_demucs_device(_torch)
    model = model.to(device)

    y = _load_audio(path, model.samplerate, start, end)
    # unsqueeze(0) adds the batch axis apply_model expects.
    wav = _torch.from_numpy(y).float().unsqueeze(0).to(device)

    # Compute the statistics once and use the *same* mean/scale in both
    # directions so that de-normalization exactly inverts normalization. The
    # epsilon guards against division by zero on silent input.
    ref = wav.mean(1, keepdim=True)
    ref_mean = ref.mean()
    ref_scale = ref.std() + 1e-8
    wav_norm = (wav - ref_mean) / ref_scale

    with _torch.no_grad():
        sources = apply_model(model, wav_norm, device=device, split=True, overlap=0.25)
    # Drop the batch axis, then undo the standardization.
    sources = sources[0] * ref_scale + ref_mean

    stems = {name: source.cpu().numpy() for name, source in zip(model.sources, sources)}
    return _normalize_stems(stems, model.samplerate, sr)
def _separate_stems_spleeter(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "4stems") -> dict[str, np.ndarray]:
    """Separate *path* into stems with Spleeter.

    Any *model_name* that is not a Spleeter ``"<N>stems"`` configuration (for
    example the Demucs default ``"htdemucs"`` forwarded by the dispatcher) falls
    back to ``"4stems"``. Audio is decoded at 44100 Hz, separated, and converted
    to stereo at *sr* by ``_normalize_stems``.
    """
    from spleeter.separator import Separator

    model_name = model_name if "stems" in model_name else "4stems"
    separator = Separator(f'spleeter:{model_name}')

    # Spleeter models are trained on 44100Hz audio
    spleeter_sr = 44100
    waveform = _load_audio(path, spleeter_sr, start, end)

    # Spleeter works samples-first: it takes (samples, channels) and returns
    # {stem: (samples, channels)}, so transpose on the way in and out.
    prediction = separator.separate(waveform.T)
    channels_first = {stem: audio.T for stem, audio in prediction.items()}
    return _normalize_stems(channels_first, spleeter_sr, sr)
def _cache_key_dir(path: str, backend: str, model_name: str) -> Path:
    """Return the cache directory for one track/backend/model combination.

    Layout: ``STEM_CACHE_DIR / backend / model_name / "<file stem>_<hash>"``.
    ``<hash>`` is the first 12 hex digits of the SHA-256 of the file contents,
    so a file modified in place under the same name gets a fresh entry. The
    file's stem is also part of the name, which means a renamed copy of an
    identical file does *not* share an entry. If the file cannot be read
    (``OSError``), the literal ``nohash`` stands in for the digest.
    """
    track_stem = Path(path).stem
    digest = hashlib.sha256()
    try:
        with open(path, "rb") as handle:
            # Stream in 64 KiB pieces so large audio files are not read whole.
            for piece in iter(lambda: handle.read(1 << 16), b""):
                digest.update(piece)
    except OSError:
        content_hash = "nohash"
    else:
        content_hash = digest.hexdigest()[:12]
    leaf = f"{track_stem}_{content_hash}"
    return STEM_CACHE_DIR / backend / model_name / leaf
def _save_stems_to_cache(
    stems: dict[str, np.ndarray], cache_dir: Path, sr: int,
) -> None:
    """Persist *stems* to *cache_dir* as ``<name>.npy`` files plus ``_meta.txt``.

    Everything is first written into a temporary sibling directory. Only once
    every stem and the metadata are on disk is that directory moved into place
    with ``os.replace``. An interrupted or failed write therefore never leaves
    a half-populated *cache_dir*, which the loader would otherwise read as a
    valid entry. Any previous contents of *cache_dir* are replaced.

    Args:
        stems: Stem name -> audio array; each is stored with ``np.save``.
        cache_dir: Destination directory (typically from ``_cache_key_dir``).
        sr: Sample rate of *stems*, recorded as ``sr=<int>`` in ``_meta.txt``.

    Raises:
        OSError: if the cache cannot be written. The staging directory is
            removed before the error propagates.
    """
    import shutil
    import tempfile

    cache_dir.parent.mkdir(parents=True, exist_ok=True)
    staging = Path(tempfile.mkdtemp(prefix=f".{cache_dir.name}.", dir=cache_dir.parent))
    try:
        for name, audio in stems.items():
            np.save(staging / f"{name}.npy", audio)
        # Store sample rate so we can validate on load
        (staging / "_meta.txt").write_text(f"sr={sr}\n")
        if cache_dir.exists():
            # os.replace cannot move a directory over a non-empty directory.
            shutil.rmtree(cache_dir)
        os.replace(staging, cache_dir)
    except BaseException:
        shutil.rmtree(staging, ignore_errors=True)
        raise
    logger.info("Cached %d stems to %s", len(stems), cache_dir)
| def _load_stems_from_cache( | |
| cache_dir: Path, sr: int, | |
| ) -> dict[str, np.ndarray] | None: | |
| if not cache_dir.is_dir(): | |
| return None | |
| meta = cache_dir / "_meta.txt" | |
| if meta.exists(): | |
| try: | |
| cached_sr = int(meta.read_text().strip().split("=")[1]) | |
| if cached_sr != sr: | |
| logger.info("Cache SR mismatch (%d vs %d), re-separating", cached_sr, sr) | |
| return None | |
| except (ValueError, IndexError): | |
| pass | |
| stems: dict[str, np.ndarray] = {} | |
| for npy in cache_dir.glob("*.npy"): | |
| stems[npy.stem] = np.load(npy) | |
| if not stems: | |
| return None | |
| logger.info("Loaded %d cached stems from %s", len(stems), cache_dir) | |
| return stems | |
def separate_stems_with_backend(path: str, start: float = 0.0, end: float | None = None, sr: int = 44100, model_name: str = "htdemucs", backend: str | None = None) -> tuple[dict[str, np.ndarray], str]:
    """Separate *path* into stems, using the on-disk cache for full tracks.

    Args:
        path: Audio file to separate.
        start, end: Window in seconds; ``end=None`` means to the end of the file.
        sr: Sample rate of the returned stems.
        model_name: Model identifier forwarded to the backend (and used in the cache key).
        backend: One of the ``STEM_BACKEND_*`` ids; ``None`` picks
            ``preferred_stem_backend()``.

    Returns:
        ``(stems, actual_backend)``. *actual_backend* is the backend that
        produced the stems. It may be ``STEM_BACKEND_DEMUCS`` when demucs-mlx was
        requested but fell back, either now or in the run that filled the cache.

    Only full-track requests (``start <= 0`` and ``end is None``) are cached,
    since those are the expensive ones called during set rendering. A failure to
    write the cache is logged and does not discard the computed stems.
    """
    if backend is None:
        backend = preferred_stem_backend()
    cacheable = start <= 0 and end is None
    if not cacheable:
        return _separate_stems_raw(path, start, end, sr, model_name, backend)

    # demucs-mlx falls back to regular demucs on failure, and that result is
    # cached under the demucs backend. Look there as well; otherwise a machine
    # where MLX keeps failing would never hit the cache.
    lookup_order = [backend]
    if backend == STEM_BACKEND_DEMUCS_MLX:
        lookup_order.append(STEM_BACKEND_DEMUCS)
    cache_dirs: dict[str, Path] = {}
    for candidate in lookup_order:
        cache_dirs[candidate] = _cache_key_dir(path, candidate, model_name)
        cached = _load_stems_from_cache(cache_dirs[candidate], sr)
        if cached is not None:
            return cached, candidate

    stems, actual_backend = _separate_stems_raw(path, start, end, sr, model_name, backend)
    # Reuse the directory computed above when possible (avoids rehashing the file).
    if actual_backend in cache_dirs:
        cache_dir = cache_dirs[actual_backend]
    else:
        cache_dir = _cache_key_dir(path, actual_backend, model_name)
    try:
        _save_stems_to_cache(stems, cache_dir, sr)
    except OSError as exc:
        logger.warning("Could not cache stems to %s: %s", cache_dir, exc)
    return stems, actual_backend
def _separate_stems_raw(path: str, start: float, end: float | None, sr: int, model_name: str, backend: str) -> tuple[dict[str, np.ndarray], str]:
    """Run one separation (no caching) and report which backend produced it.

    - ``STEM_BACKEND_SPLEETER`` runs Spleeter directly; its errors propagate.
    - ``STEM_BACKEND_DEMUCS_MLX`` is attempted first. Any exception is logged
      and regular Demucs is tried instead. If that also fails, a
      ``RuntimeError`` carrying both messages is raised.
    - Any other value goes straight to regular Demucs, and its errors propagate.
    """
    if backend == STEM_BACKEND_SPLEETER:
        return _separate_stems_spleeter(path, start, end, sr, model_name), backend

    mlx_error: Exception | None = None
    if backend == STEM_BACKEND_DEMUCS_MLX:
        try:
            mlx_stems = _separate_stems_demucs_mlx(path, start, end, sr, model_name)
        except Exception as failure:
            mlx_error = failure
            logger.warning("demucs-mlx failed on Apple Silicon; falling back to regular demucs: %s", failure)
        else:
            return mlx_stems, STEM_BACKEND_DEMUCS_MLX

    try:
        torch_stems = _separate_stems_demucs(path, start, end, sr, model_name)
    except Exception as exc:
        if mlx_error is None:
            raise
        raise RuntimeError(f"demucs-mlx failed: {mlx_error}; regular demucs failed: {exc}") from exc
    return torch_stems, STEM_BACKEND_DEMUCS