StemSplit commited on
Commit
72f6f07
·
verified ·
1 Parent(s): 3d648bf

Initial release: complete 4-stem htdemucs_ft ONNX bag (drums/bass/other/vocals) + numpy aggregator

Browse files
Files changed (3) hide show
  1. README.md +257 -0
  2. bag_infer.py +194 -0
  3. requirements.txt +3 -0
README.md ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: mit
4
+ library_name: onnxruntime
5
+ pipeline_tag: audio-to-audio
6
+ tags:
7
+ - onnx
8
+ - onnxruntime
9
+ - stem-separation
10
+ - source-separation
11
+ - vocal-isolation
12
+ - vocal-remover
13
+ - drum-extraction
14
+ - bass-extraction
15
+ - karaoke
16
+ - demucs
17
+ - htdemucs
18
+ - music
19
+ - audio-to-audio
20
+ - mobile
21
+ - ios
22
+ - android
23
+ - coreml
24
+ - directml
25
+ - production-ready
26
+ datasets:
27
+ - StemSplitio/stem-separation-benchmark-2026
28
+ inference: false
29
+ ---
30
+
31
+ # HT-Demucs FT — Full 4-Stem Bag, ONNX
32
+
33
+ **The first complete ONNX export of HT-Demucs FT on the Hugging Face Hub.**
34
+ Four parity-verified ONNX models (drums, bass, other, vocals) plus a
35
+ ~250-line numpy aggregator that runs the full 4-stem separation in pure
36
+ `onnxruntime`. **No PyTorch required at inference.** Runs on CPU /
37
+ CoreML / CUDA / DirectML.
38
+
39
+ This repo is the convenience drop — all 4 specialist sub-models of
40
+ `htdemucs_ft` in one place, with a working bag-inference script. If you
41
+ only need one stem in production, the individual stem-specialist repos
42
+ below are ~75% smaller and ~4× faster per song.
43
+
44
+ ---
45
+
46
+ ## TL;DR
47
+
48
+ ```bash
49
+ pip install onnxruntime numpy soundfile
50
+ python bag_infer.py your-song.mp3 ./out/
51
+ # writes out/drums.wav, out/bass.wav, out/other.wav, out/vocals.wav
52
+ ```
53
+
54
+ That's it. The 4 `.onnx` files (316 MB each, ~1.26 GB total) live
55
+ alongside the script.
56
+
57
+ ---
58
+
59
+ ## Quality
60
+
61
+ Median per-stem SDR on the MUSDB18-HQ test split (50 songs), BSS Eval v4
62
+ via `museval`. **Identical to the official PyTorch `htdemucs_ft`** — the
63
+ bag's per-stem output IS the corresponding specialist's output (the weight
64
+ matrix is one-hot per stem).
65
+
66
+ | Stem | SDR (dB) | Rank in our 2026 benchmark |
67
+ |---|---:|---|
68
+ | **vocals** | **9.19** | **#1** (highest open-source vocal SDR) |
69
+ | drums | 10.11 | #2 (mdx_extra_q leads at 11.49) |
70
+ | bass | 10.38 | #2 (mdx_extra_q leads at 11.42) |
71
+ | other | 6.34 | #2 (mdx_extra_q leads at 7.67) |
72
+
73
+ Full benchmark across every popular open-source separator:
74
+ [StemSplitio/stem-separation-benchmark-2026](https://huggingface.co/datasets/StemSplitio/stem-separation-benchmark-2026).
75
+
76
+ **ONNX vs PyTorch parity:** verified to < 1e-3 max abs diff on every stem
77
+ during export. See the
78
+ [Day 1 spike report](https://huggingface.co/StemSplitio/htdemucs-ft-drums-onnx#how-it-was-built)
79
+ for the full engineering writeup.
80
+
81
+ ---
82
+
83
+ ## Performance
84
+
85
+ Real measurements on an Apple M4 Pro:
86
+
87
+ | Mode | Hardware | Per 3-min song | Notes |
88
+ |---|---|---:|---|
89
+ | One specialist (`htdemucs-ft-drums-onnx`) | M4 Pro CPU | **~22 s** | 4× faster, 75% smaller — use this if you only need one stem |
90
+ | **Full bag (this repo)** | M4 Pro CPU | **~88 s** | RTF ~0.5. 4 sub-models × N chunks. |
91
+ | Full bag | M4 Pro CPU (8 threads) | ~60 s | With `OMP_NUM_THREADS=8` and SessionOptions tuned |
92
+ | Full bag | NVIDIA L4 CUDA | ~6 s | Extrapolated from per-specialist CUDA numbers |
93
+ | Full bag | NVIDIA T4 | ~16 s | Extrapolated |
94
+ | PyTorch full bag | M4 Pro MPS | ~47 s | Faster only because MPS is GPU-accelerated; ONNX-CUDA beats it cleanly. |
95
+
96
+ ---
97
+
98
+ ## Common use cases
99
+
100
+ - **Karaoke makers** — `out/other.wav` minus `out/vocals.wav` gives a clean
101
+ karaoke track plus an acapella in one pass.
102
+ - **DAW stem export** — drop the 4 `.wav` files into Ableton / Logic /
103
+ Reaper as separate channels for remixing.
104
+ - **DJ stems software** — load all 4 stems as live-mixable tracks.
105
+ - **AI music apps** — feed each stem into downstream models (drum
106
+ transcription, bassline-to-MIDI, vocal pitch correction).
107
+ - **Acapella sampling** — clean isolated vocals at the highest SDR
108
+ available in open source.
109
+ - **Mobile / on-device separation** — replaces a 1+ GB PyTorch install
110
+ with `onnxruntime`'s 50 MB binary on iOS / Android.
111
+
112
+ ---
113
+
114
+ ## Quick start
115
+
116
+ ### Python — as a library
117
+
118
+ ```python
119
+ import bag_infer
120
+
121
+ stems = bag_infer.separate_all("your-song.mp3")
122
+ # stems: dict[str, numpy.ndarray (2, samples)]
123
+ # stems["drums"], stems["bass"], stems["other"], stems["vocals"]
124
+ ```
125
+
126
+ ### Python — with execution provider control
127
+
128
+ ```python
129
+ import soundfile as sf
130
+ import bag_infer
131
+
132
+ audio, sr = sf.read("your-song.mp3", dtype="float32", always_2d=True)
133
+ stems = bag_infer.separate(
134
+ audio.T, sr,
135
+ providers=["CPUExecutionProvider"], # or "CoreMLExecutionProvider", etc.
136
+ )
137
+ for name, audio in stems.items():
138
+ sf.write(f"{name}.wav", audio.T, sr)
139
+ ```
140
+
141
+ ### CLI
142
+
143
+ ```bash
144
+ python bag_infer.py your-song.mp3 ./out/
145
+ python bag_infer.py your-song.mp3 ./out/ --providers cuda
146
+ python bag_infer.py your-song.mp3 ./out/ --providers coreml
147
+ python bag_infer.py your-song.mp3 ./out/ --providers dml
148
+ ```
149
+
150
+ ### Web / mobile
151
+
152
+ Each specialist is a vanilla onnxruntime model; just load all 4 sessions
153
+ and reuse the aggregation logic in `bag_infer.py::separate`. See the
154
+ individual stem repos for platform-specific snippets:
155
+ [drums](https://huggingface.co/StemSplitio/htdemucs-ft-drums-onnx) ·
156
+ [bass](https://huggingface.co/StemSplitio/htdemucs-ft-bass-onnx) ·
157
+ [other](https://huggingface.co/StemSplitio/htdemucs-ft-other-onnx) ·
158
+ [vocals](https://huggingface.co/StemSplitio/htdemucs-ft-vocals-onnx).
159
+
160
+ ---
161
+
162
+ ## How aggregation works
163
+
164
+ The `htdemucs_ft` bag uses a **one-hot weight matrix** for combining the 4
165
+ sub-models — model 0's drums output is used directly as the bag's drums
166
+ stem, model 1's bass output is the bag's bass stem, and so on. No
167
+ weighted-sum aggregation needed.
168
+
169
+ That means:
170
+ - **The bag's drums stem == the drums specialist's drums output** (bit-exact in fp32)
171
+ - Same for bass, other, vocals
172
+ - So you can ship only the specialists you need and get identical
173
+ per-stem quality to the full bag at 1/4 the size
174
+
175
+ `bag_infer.py` simply runs all 4 specialists and picks the relevant row
176
+ from each. ~30 lines of numpy.
177
+
178
+ ---
179
+
180
+ ## Input / output spec per sub-model
181
+
182
+ | Tensor | Name | Shape | Dtype | Notes |
183
+ |---|---|---|---|---|
184
+ | Input | `mix` | `(1, 2, 343980)` | float32 | Stereo audio, 44.1 kHz, 7.8 s segment. |
185
+ | Output | `stems` | `(1, 4, 2, 343980)` | float32 | `[drums, bass, other, vocals]`. Use only the specialist's target row. |
186
+
187
+ For longer audio, the bag script handles overlap-add chunking.
188
+
189
+ ---
190
+
191
+ ## Files in this repo
192
+
193
+ | File | Size | Purpose |
194
+ |---|---:|---|
195
+ | `htdemucs_ft_drums.onnx` | 316 MB | Drums specialist (bag index 0) |
196
+ | `htdemucs_ft_bass.onnx` | 316 MB | Bass specialist (bag index 1) |
197
+ | `htdemucs_ft_other.onnx` | 316 MB | Other specialist (bag index 2) |
198
+ | `htdemucs_ft_vocals.onnx` | 316 MB | Vocals specialist (bag index 3) |
199
+ | `bag_infer.py` | 7 KB | Pure numpy aggregator. No torch. |
200
+ | `requirements.txt` | <1 KB | `onnxruntime`, `numpy`, `soundfile`. |
201
+ | `README.md` | this file | |
202
+
203
+ Total: **~1.26 GB**. If that's too big, use individual stem repos.
204
+
205
+ ---
206
+
207
+ ## Related work
208
+
209
+ | Repo | Stem | Use when |
210
+ |---|---|---|
211
+ | [`htdemucs-ft-drums-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-drums-onnx) | drums | Only need drums (1/4 size, 1/4 latency) |
212
+ | [`htdemucs-ft-bass-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-bass-onnx) | bass | Only need bass |
213
+ | [`htdemucs-ft-other-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-other-onnx) | other | Only need "other" / instrumental |
214
+ | [`htdemucs-ft-vocals-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-vocals-onnx) | vocals | **#1 open-source vocal SDR** |
215
+
216
+ PyTorch versions for HF Inference Endpoints:
217
+ [`htdemucs-ft-pytorch`](https://huggingface.co/StemSplitio/htdemucs-ft-pytorch)
218
+ and its [4 sibling specialist repos](https://huggingface.co/StemSplitio).
219
+
220
+ ---
221
+
222
+ ## Skip the infrastructure — use the StemSplit API
223
+
224
+ Don't want to ship 1.26 GB of `.onnx` files in your app, manage a GPU
225
+ pool, or write overlap-add chunking? Use the
226
+ **[StemSplit API](https://stemsplit.io/developers)** instead — same models
227
+ under the hood, hosted for you, with credits and a dashboard.
228
+
229
+ - 🌐 [stemsplit.io](https://stemsplit.io)
230
+ - 📘 [Developer docs](https://stemsplit.io/developers/docs)
231
+ - 🔌 [API reference](https://stemsplit.io/developers/reference)
232
+
233
+ Or use the no-code tools that ship this same model family:
234
+
235
+ - 🎤 [Vocal Remover](https://stemsplit.io/vocal-remover)
236
+ - 🎶 [Karaoke Maker](https://stemsplit.io/karaoke-maker)
237
+ - 🎙️ [Acapella Maker](https://stemsplit.io/acapella-maker)
238
+ - 📺 [YouTube Stem Splitter](https://stemsplit.io/youtube-stem-splitter)
239
+
240
+ ---
241
+
242
+ ## License & attribution
243
+
244
+ MIT-licensed, matching the original HT-Demucs.
245
+
246
+ ```bibtex
247
+ @inproceedings{rouard2023hybrid,
248
+ title = {Hybrid Transformers for Music Source Separation},
249
+ author = {Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
250
+ booktitle = {ICASSP},
251
+ year = {2023}
252
+ }
253
+ ```
254
+
255
+ - Original PyTorch model: [`facebookresearch/demucs`](https://github.com/facebookresearch/demucs)
256
+ - ONNX export, parity verification, and packaging by [StemSplit](https://stemsplit.io)
257
+ - Search keywords: htdemucs onnx, demucs onnx, htdemucs bag onnx, demucs ios, demucs android, music source separation onnx, 4-stem separation onnx, stem separation mobile, onnxruntime music separation
bag_infer.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Bag inference for the full HT-Demucs FT 4-stem ONNX ensemble.
3
+
4
+ Runs all 4 specialist sub-models and aggregates their outputs using the
5
+ htdemucs_ft bag's one-hot weight matrix (drums-model -> drums stem only,
6
+ bass-model -> bass stem only, etc).
7
+
8
+ NO TORCH at inference. Just numpy + onnxruntime + soundfile.
9
+
10
+ Usage:
11
+ python bag_infer.py your-song.mp3 ./out/
12
+ # writes out/drums.wav, out/bass.wav, out/other.wav, out/vocals.wav
13
+
14
+ Or as a library:
15
+ import bag_infer
16
+ stems = bag_infer.separate_all("song.mp3")
17
+ # stems: dict[str, numpy.ndarray (2, samples)]
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import sys
23
+ import time
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import onnxruntime as ort
28
+ import soundfile as sf
29
+
30
+ SAMPLE_RATE = 44100
31
+ SEGMENT_S = 7.8
32
+ N_SAMPLES = int(SEGMENT_S * SAMPLE_RATE) # 343,980
33
+ N_CHANNELS = 2
34
+ SOURCES = ["drums", "bass", "other", "vocals"]
35
+ HERE = Path(__file__).resolve().parent
36
+
37
+ # The bag's weight matrix for htdemucs_ft is one-hot per stem:
38
+ # drums specialist (bag.models[0]) -> contributes only to drums stem
39
+ # bass specialist (bag.models[1]) -> contributes only to bass stem
40
+ # other specialist (bag.models[2]) -> contributes only to other stem
41
+ # vocals specialist (bag.models[3]) -> contributes only to vocals stem
42
+ # So aggregation is trivial: pick row N from model N's output.
43
+ DEFAULT_ONNX_FILES = {
44
+ "drums": HERE / "htdemucs_ft_drums.onnx",
45
+ "bass": HERE / "htdemucs_ft_bass.onnx",
46
+ "other": HERE / "htdemucs_ft_other.onnx",
47
+ "vocals": HERE / "htdemucs_ft_vocals.onnx",
48
+ }
49
+
50
+
51
+ def _make_transition_window(segment: int, overlap_frac: float = 0.25) -> np.ndarray:
52
+ transition = int(segment * overlap_frac)
53
+ window = np.ones(segment, dtype=np.float32)
54
+ fade = np.linspace(0, 1, transition, dtype=np.float32)
55
+ window[:transition] = fade
56
+ window[-transition:] = fade[::-1]
57
+ return window
58
+
59
+
60
+ def _load_sessions(onnx_files: dict[str, Path],
61
+ providers: list[str] | None = None,
62
+ ) -> dict[str, ort.InferenceSession]:
63
+ if providers is None:
64
+ providers = ["CPUExecutionProvider"]
65
+ sessions: dict[str, ort.InferenceSession] = {}
66
+ for stem, path in onnx_files.items():
67
+ if not path.exists():
68
+ raise FileNotFoundError(
69
+ f"Missing {stem} model at {path}. Download all 4 .onnx files "
70
+ "into the same directory as this script.")
71
+ sessions[stem] = ort.InferenceSession(str(path), providers=providers)
72
+ return sessions
73
+
74
+
75
+ def separate(mix: np.ndarray, sample_rate: int,
76
+ onnx_files: dict[str, Path] | None = None,
77
+ providers: list[str] | None = None,
78
+ verbose: bool = True) -> dict[str, np.ndarray]:
79
+ """Run full 4-stem chunked overlap-add separation.
80
+
81
+ Args:
82
+ mix: (channels, samples) float32 in [-1, 1], 44.1 kHz stereo.
83
+ sample_rate: must equal 44100.
84
+ onnx_files: optional dict overriding the default file locations.
85
+ providers: onnxruntime EPs; defaults to CPU.
86
+ verbose: print progress per chunk.
87
+
88
+ Returns:
89
+ dict of {stem_name: (channels, samples) float32}.
90
+ """
91
+ if sample_rate != SAMPLE_RATE:
92
+ raise ValueError(f"Bound to {SAMPLE_RATE} Hz; got {sample_rate}.")
93
+ if mix.ndim != 2 or mix.shape[0] != N_CHANNELS:
94
+ raise ValueError(f"Expected (2, samples) input, got {mix.shape}")
95
+
96
+ sessions = _load_sessions(onnx_files or DEFAULT_ONNX_FILES, providers)
97
+ if verbose:
98
+ print(f" loaded {len(sessions)} ONNX sessions on "
99
+ f"{list(sessions.values())[0].get_providers()[0]}")
100
+
101
+ total_len = mix.shape[1]
102
+ overlap = N_SAMPLES // 4
103
+ stride = N_SAMPLES - overlap
104
+ n_chunks = max(1, (total_len + stride - 1) // stride)
105
+
106
+ if verbose:
107
+ print(f" input: {total_len:,} samples ({total_len / sample_rate:.1f}s)")
108
+ print(f" chunks: {n_chunks}")
109
+
110
+ window = _make_transition_window(N_SAMPLES)
111
+ out = {stem: np.zeros((N_CHANNELS, total_len), dtype=np.float32) for stem in SOURCES}
112
+ weight = np.zeros(total_len, dtype=np.float32)
113
+
114
+ t0 = time.perf_counter()
115
+ for i in range(n_chunks):
116
+ start = i * stride
117
+ end = min(start + N_SAMPLES, total_len)
118
+ chunk = mix[:, start:end]
119
+ if chunk.shape[1] < N_SAMPLES:
120
+ chunk = np.pad(chunk, ((0, 0), (0, N_SAMPLES - chunk.shape[1])),
121
+ mode="constant")
122
+ x = chunk[np.newaxis, ...].astype(np.float32)
123
+ chunk_len = end - start
124
+ w = window[:chunk_len]
125
+
126
+ # Run each specialist; take only its target stem row.
127
+ for stem in SOURCES:
128
+ stems = sessions[stem].run(["stems"], {"mix": x})[0][0] # (4, 2, N)
129
+ target_row = SOURCES.index(stem) # 0/1/2/3 matches bag.models[idx]
130
+ out[stem][:, start:end] += stems[target_row, :, :chunk_len] * w
131
+
132
+ weight[start:end] += w
133
+ if verbose:
134
+ print(f" chunk {i+1}/{n_chunks}: "
135
+ f"{time.perf_counter() - t0:.1f}s elapsed")
136
+
137
+ weight = np.maximum(weight, 1e-8)
138
+ for stem in SOURCES:
139
+ out[stem] /= weight
140
+
141
+ if verbose:
142
+ rtf = (time.perf_counter() - t0) / (total_len / sample_rate)
143
+ print(f" total: {time.perf_counter() - t0:.2f}s (RTF {rtf:.2f}, "
144
+ f"4 sub-models × {n_chunks} chunks = "
145
+ f"{4 * n_chunks} ONNX runs)")
146
+ return out
147
+
148
+
149
+ def separate_all(input_path: str, **kwargs) -> dict[str, np.ndarray]:
150
+ """Convenience: load audio, run separation, return all 4 stems."""
151
+ audio, sr = sf.read(input_path, dtype="float32", always_2d=True)
152
+ audio = audio.T
153
+ if audio.shape[0] == 1:
154
+ audio = np.tile(audio, (2, 1))
155
+ elif audio.shape[0] > 2:
156
+ audio = audio[:2]
157
+ return separate(audio, sr, **kwargs)
158
+
159
+
160
+ def main() -> None:
161
+ ap = argparse.ArgumentParser(description=__doc__)
162
+ ap.add_argument("input", type=Path)
163
+ ap.add_argument("out_dir", type=Path)
164
+ ap.add_argument("--providers", type=str, default="cpu",
165
+ choices=["cpu", "coreml", "cuda", "dml"])
166
+ args = ap.parse_args()
167
+
168
+ providers_map = {
169
+ "cpu": ["CPUExecutionProvider"],
170
+ "coreml": ["CoreMLExecutionProvider", "CPUExecutionProvider"],
171
+ "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
172
+ "dml": ["DmlExecutionProvider", "CPUExecutionProvider"],
173
+ }
174
+ args.out_dir.mkdir(parents=True, exist_ok=True)
175
+
176
+ print(f"Loading {args.input} ...")
177
+ audio, sr = sf.read(str(args.input), dtype="float32", always_2d=True)
178
+ audio = audio.T
179
+ if audio.shape[0] == 1:
180
+ audio = np.tile(audio, (2, 1))
181
+ elif audio.shape[0] > 2:
182
+ audio = audio[:2]
183
+ print(f" shape {audio.shape}, sr {sr}")
184
+
185
+ stems = separate(audio, sr, providers=providers_map[args.providers])
186
+
187
+ for stem, audio_out in stems.items():
188
+ out_path = args.out_dir / f"{stem}.wav"
189
+ sf.write(str(out_path), audio_out.T, sr)
190
+ print(f" wrote {out_path}")
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ onnxruntime>=1.20
2
+ numpy>=1.24
3
+ soundfile>=0.12