stem_worker / separate.py
vincewin's picture
Deploy AI stem splitter
6af6d8d verified
Raw
History Blame Contribute Delete
14.3 kB
"""
Stem separation core (Demucs) with optional cascades.
- base 4-stem / 6-stem separation (official Meta models, trusted)
- drum split: the 'drums' stem -> kick / snare / toms / cymbals
(community 'drumsep' Hybrid-Demucs model; loaded with the full unpickler,
which the user explicitly authorized)
- vocal split: the 'vocals' stem -> lead / backing (added on top of drums)
Files are copied, never modified; analysis runs only on copies.
"""
from __future__ import annotations
import os
import re
from typing import Iterable
import numpy as np
HERE = os.path.dirname(os.path.abspath(__file__))
# ZeroGPU: @spaces.GPU marks the GPU entry point so HF attaches a GPU for that
# call only. Off Spaces (local / CPU), `spaces` isn't installed -> no-op decorator.
try:
import spaces
_gpu = spaces.GPU(duration=300)
except Exception:
def _gpu(fn):
return fn
# ---- base models (official Demucs) ----
MODEL_STEMS = {
"htdemucs": ["drums", "bass", "other", "vocals"],
"htdemucs_ft": ["drums", "bass", "other", "vocals"],
"htdemucs_6s": ["drums", "bass", "other", "vocals", "guitar", "piano"],
"mdx_extra": ["drums", "bass", "other", "vocals"],
}
DEFAULT_MODEL = "htdemucs"
# ---- drum-split (community 'drumsep' model) ----
DRUM_REPO = os.path.join(HERE, "models")
DRUM_MODEL = "49469ca8"
DRUM_HF = "vincewin/drumsep" # mirror the Space downloads from
DRUM_MAP = {"bombo": "kick", "redoblante": "snare", "platillos": "cymbals", "toms": "toms"}
DRUM_ORDER = ["kick", "snare", "toms", "cymbals"]
# ---- RoFormer engine (newer architecture, via the 'audio-separator' package) ----
# BS-RoFormer vocals model: cleaner vocals/instrumental split than Demucs, 2 stems.
ROFORMER_MODEL = "model_bs_roformer_ep_317_sdr_12.9755.ckpt"
# Engines selectable in the app: label-friendly quality tiers.
ENGINES = ("demucs", "demucs_ft", "roformer")
# ---- modes offered in the app: how many stems and how to reach them ----
MODES = {
"4": {"base": "htdemucs", "drums": False, "vocals": False},
"6": {"base": "htdemucs_6s", "drums": False, "vocals": False},
"9": {"base": "htdemucs_6s", "drums": True, "vocals": False},
}
DEFAULT_MODE = "4"
def available_models() -> list[str]:
return list(MODEL_STEMS.keys())
def list_stems(model_name: str = DEFAULT_MODEL) -> list[str]:
return MODEL_STEMS.get(model_name, MODEL_STEMS[DEFAULT_MODEL])
def mode_stems(mode: str) -> list[str]:
"""Final stem names a mode produces (for the checkbox UI), in display order."""
cfg = MODES[mode]
out: list[str] = []
for s in MODEL_STEMS[cfg["base"]]:
if s == "drums" and cfg["drums"]:
out += DRUM_ORDER
elif s == "vocals" and cfg["vocals"]:
out += ["lead vocal", "backing vocal"]
else:
out.append(s)
return out
def _safe(name: str) -> str:
name = re.sub(r'[<>:"/\\|?*\x00-\x1f]', "_", name).strip(" .")
return name or "track"
def _pick_device(device: str | None) -> str:
if device:
return device
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except Exception:
return "cpu"
# ---------------- model loading (cached) ----------------
_CACHE: dict = {}
_PATCHED = False
def _full_unpickle():
"""Allow torch.load to deserialize the community drumsep checkpoint.
PyTorch >=2.6 defaults to weights_only=True, which refuses pickled model
objects. The user explicitly authorized loading this third-party model.
"""
global _PATCHED
if _PATCHED:
return
import torch
_orig = torch.load
torch.load = lambda *a, **k: _orig(*a, **{**k, "weights_only": False})
_PATCHED = True
def _ensure_drum_repo() -> str:
path = os.path.join(DRUM_REPO, DRUM_MODEL + ".th")
if not os.path.exists(path):
from huggingface_hub import hf_hub_download
os.makedirs(DRUM_REPO, exist_ok=True)
hf_hub_download(repo_id=DRUM_HF, filename=DRUM_MODEL + ".th", local_dir=DRUM_REPO)
return DRUM_REPO
def _load(name: str, repo: str | None = None):
key = (name, repo)
if key not in _CACHE:
_full_unpickle()
from pathlib import Path
from demucs.pretrained import get_model
m = get_model(name=name, repo=Path(repo) if repo else None)
m.eval()
_CACHE[key] = m
return _CACHE[key]
def _read_audio(path: str, sr: int, ch: int):
from demucs.audio import AudioFile
return AudioFile(path).read(streams=0, samplerate=sr, channels=ch)
def _run(model, wav, device, shifts=0, overlap=0.25):
"""Normalize, run apply_model, de-normalize. wav: tensor [channels, samples].
shifts > 0 enables Demucs's "shift trick" (test-time augmentation): the input
is shifted a few times and the results averaged, which reduces artifacts at a
roughly linear cost in time. overlap controls window blending at chunk edges.
"""
import torch
from demucs.apply import apply_model
ref = wav.mean(0)
x = (wav - ref.mean()) / (ref.std() + 1e-8)
with torch.no_grad():
out = apply_model(model, x[None].to(device), shifts=int(shifts), split=True,
overlap=overlap, progress=False, device=device)[0]
return out * (ref.std() + 1e-8) + ref.mean()
def _write(tensor, out_path: str, sr: int):
import soundfile as sf
sf.write(out_path, np.clip(tensor.cpu().numpy().T, -1.0, 1.0), sr, subtype="PCM_16")
# ---------------- public API ----------------
def separate(input_path, out_dir, model_name=DEFAULT_MODEL, stems=None,
device=None, progress=None, shifts=0, overlap=0.25) -> list[str]:
"""Base (single-model) separation. Used by the CLI and batch tool."""
device = _pick_device(device)
if progress:
progress(0.05, f"Loading model '{model_name}' on {device}...")
model = _load(model_name)
sources = list(model.sources)
wanted = [s for s in (list(stems) if stems else sources) if s in sources]
if not wanted:
raise ValueError(f"No valid stems for {model_name}. Choose from: {sources}")
if progress:
progress(0.1, "Reading audio...")
wav = _read_audio(input_path, model.samplerate, model.audio_channels)
if progress:
progress(0.2, "Separating audio (slow on CPU)...")
out = _run(model, wav, device, shifts=shifts, overlap=overlap)
os.makedirs(out_dir, exist_ok=True)
song = _safe(os.path.splitext(os.path.basename(input_path))[0])
written = []
for i, name in enumerate(sources):
if name not in wanted:
continue
p = os.path.join(out_dir, f"{song} - {name}.wav")
_write(out[i], p, model.samplerate)
written.append(p)
if progress:
progress(1.0, "Done")
return written
def separate_mode(input_path, out_dir, mode=DEFAULT_MODE, stems=None,
device=None, progress=None, base_override=None,
shifts=0, overlap=0.25) -> list[str]:
"""Cascaded separation by mode ('4', '6', '9'). Writes '<song> - <stem>.wav'.
base_override swaps the mode's base model (e.g. the fine-tuned 'htdemucs_ft'
for cleaner 4-stem output). shifts/overlap tune separation quality vs. speed.
"""
cfg = MODES[mode]
device = _pick_device(device)
if progress:
progress(0.05, "Loading model...")
base = _load(base_override or cfg["base"])
sr = base.samplerate
if progress:
progress(0.1, "Reading audio...")
wav = _read_audio(input_path, sr, base.audio_channels)
if progress:
progress(0.15, "Separating main stems (slow on CPU)...")
out = _run(base, wav, device, shifts=shifts, overlap=overlap)
result = {name: out[i] for i, name in enumerate(base.sources)}
if cfg["drums"] and "drums" in result:
if progress:
progress(0.6, "Splitting drums into kick / snare / toms / cymbals...")
dm = _load(DRUM_MODEL, repo=_ensure_drum_repo())
parts = _run(dm, result.pop("drums"), device, shifts=shifts, overlap=overlap)
for i, src in enumerate(dm.sources):
result[DRUM_MAP.get(src, src)] = parts[i]
os.makedirs(out_dir, exist_ok=True)
song = _safe(os.path.splitext(os.path.basename(input_path))[0])
order = mode_stems(mode)
wanted = [s for s in (list(stems) if stems else order) if s in result]
written = []
for name in order:
if name not in wanted:
continue
p = os.path.join(out_dir, f"{song} - {name}.wav")
_write(result[name], p, sr)
written.append(p)
if progress:
progress(1.0, "Done")
return written
def separate_roformer(input_path, out_dir, progress=None) -> list[str]:
"""High-quality vocals/instrumental split using a BS-RoFormer model.
Uses the 'audio-separator' package, which downloads the model on first use.
Produces two stems; files are renamed to the '<song> - <stem>.wav' convention.
"""
from audio_separator.separator import Separator
os.makedirs(out_dir, exist_ok=True)
if progress:
progress(0.1, "Loading RoFormer model (first run downloads it ~1 min)...")
sep = Separator(output_dir=out_dir, output_format="WAV")
sep.load_model(model_filename=ROFORMER_MODEL)
if progress:
progress(0.3, "Separating with RoFormer (slow on CPU)...")
produced = sep.separate(input_path) # list of output file names (in out_dir)
song = _safe(os.path.splitext(os.path.basename(input_path))[0])
written = []
for fn in produced:
src = fn if os.path.isabs(fn) else os.path.join(out_dir, fn)
low = os.path.basename(fn).lower()
stem = "vocals" if "vocal" in low else ("instrumental" if "instrument" in low else _safe(os.path.splitext(os.path.basename(fn))[0]))
dst = os.path.join(out_dir, f"{song} - {stem}.wav")
if os.path.abspath(src) != os.path.abspath(dst):
os.replace(src, dst)
written.append(dst)
if progress:
progress(1.0, "Done")
return written
def run_separation(input_path, out_dir, engine="demucs", mode=DEFAULT_MODE,
stems=None, device=None, progress=None, shifts=0, overlap=0.25) -> list[str]:
"""Single entry point the app uses; dispatches on the chosen engine.
- 'demucs' : the mode's standard model (4/6/9-stem), with optional shifts.
- 'demucs_ft' : fine-tuned model for 4-stem (falls back to standard for 6/9),
cleaner but ~4x slower.
- 'roformer' : BS-RoFormer (newest architecture) — best vocals/instrumental,
always 2 stems; mode/stem selection is ignored.
shifts/overlap are quality knobs (higher = cleaner, slower).
"""
if engine == "roformer":
return separate_roformer(input_path, out_dir, progress=progress)
base_override = None
if engine == "demucs_ft" and MODES[mode]["base"] == "htdemucs":
base_override = "htdemucs_ft"
return separate_mode(input_path, out_dir, mode, stems, device=device,
progress=progress, base_override=base_override,
shifts=shifts, overlap=overlap)
@_gpu
def gpu_separate(input_path, out_dir, engine="demucs", mode=DEFAULT_MODE,
stems=None, shifts=0, overlap=0.25) -> list[str]:
"""GPU entry point for ZeroGPU. Takes only picklable args (no progress callback,
which can't cross the GPU process boundary). Device auto-selects cuda when present."""
return run_separation(input_path, out_dir, engine=engine, mode=mode,
stems=stems, device=None, shifts=shifts, overlap=overlap,
progress=None)
def merge_stems(paths, out_path) -> str:
"""Mix several stem .wav files back into one track by summing their samples.
Demucs stems are additive (they sum to the original), so combining a subset
is just a sample-wise sum. Inputs are aligned on length/channels defensively
in case the caller mixes files from different sources.
"""
import soundfile as sf
paths = list(paths)
if len(paths) < 2:
raise ValueError("Select at least two stems to merge.")
mix = None
sr = None
for p in paths:
data, file_sr = sf.read(p, always_2d=True) # shape [samples, channels]
data = data.astype(np.float64)
if mix is None:
mix, sr = data, file_sr
continue
if file_sr != sr:
raise ValueError("Stems have different sample rates; cannot merge.")
if data.shape[1] != mix.shape[1]: # mono vs stereo -> upmix mono
wide = max(data.shape[1], mix.shape[1])
if data.shape[1] == 1:
data = np.repeat(data, wide, axis=1)
if mix.shape[1] == 1:
mix = np.repeat(mix, wide, axis=1)
if data.shape[0] != mix.shape[0]: # trim to the shorter one
n = min(data.shape[0], mix.shape[0])
mix, data = mix[:n], data[:n]
mix = mix + data
os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)
sf.write(out_path, np.clip(mix, -1.0, 1.0), sr, subtype="PCM_16")
return out_path
def stem_of(path: str) -> str:
"""Recover the stem label from a '<song> - <stem>.wav' filename."""
base = os.path.splitext(os.path.basename(path))[0]
return base.split(" - ")[-1] if " - " in base else base
if __name__ == "__main__":
import argparse
ap = argparse.ArgumentParser(description="Separate an audio file into stems.")
ap.add_argument("input")
ap.add_argument("-o", "--out", default="stems_out")
ap.add_argument("--mode", default=None, choices=list(MODES.keys()),
help="4 / 6 / 9 stems (cascaded). Overrides --model.")
ap.add_argument("-m", "--model", default=DEFAULT_MODEL, choices=available_models())
ap.add_argument("-s", "--stems", nargs="*", default=None)
ap.add_argument("--device", default=None)
a = ap.parse_args()
pr = lambda f, m: print(f"[{f*100:5.1f}%] {m}")
if a.mode:
paths = separate_mode(a.input, a.out, a.mode, a.stems, a.device, progress=pr)
else:
paths = separate(a.input, a.out, a.model, a.stems, a.device, progress=pr)
print("\nWrote:")
for p in paths:
print(" ", p)