Initial release: complete 4-stem htdemucs_ft ONNX bag (drums/bass/other/vocals) + numpy aggregator
- README.md +257 -0
- bag_infer.py +194 -0
- requirements.txt +3 -0
README.md
ADDED
@@ -0,0 +1,257 @@
---
language: en
license: mit
library_name: onnxruntime
pipeline_tag: audio-to-audio
tags:
- onnx
- onnxruntime
- stem-separation
- source-separation
- vocal-isolation
- vocal-remover
- drum-extraction
- bass-extraction
- karaoke
- demucs
- htdemucs
- music
- audio-to-audio
- mobile
- ios
- android
- coreml
- directml
- production-ready
datasets:
- StemSplitio/stem-separation-benchmark-2026
inference: false
---

# HT-Demucs FT: Full 4-Stem Bag, ONNX

**The first complete ONNX export of HT-Demucs FT on the Hugging Face Hub.**
Four parity-verified ONNX models (drums, bass, other, vocals) plus a
~200-line numpy aggregator that runs the full 4-stem separation in pure
`onnxruntime`. **No PyTorch required at inference.** Runs on CPU /
CoreML / CUDA / DirectML.

This repo is the convenience drop: all 4 specialist sub-models of
`htdemucs_ft` in one place, with a working bag-inference script. If you
only need one stem in production, the individual stem-specialist repos
below are ~75% smaller and ~4× faster per song.

---

## TL;DR

```bash
pip install onnxruntime numpy soundfile
python bag_infer.py your-song.mp3 ./out/
# writes out/drums.wav, out/bass.wav, out/other.wav, out/vocals.wav
```

That's it. The 4 `.onnx` files (316 MB each, ~1.26 GB total) live
alongside the script.

---

## Quality

Median per-stem SDR on the MUSDB18-HQ test split (50 songs), BSS Eval v4
via `museval`. **Identical to the official PyTorch `htdemucs_ft`**: the
bag's per-stem output IS the corresponding specialist's output (the weight
matrix is one-hot per stem).

| Stem | SDR (dB) | Rank in our 2026 benchmark |
|---|---:|---|
| **vocals** | **9.19** | **#1** (highest open-source vocal SDR) |
| drums | 10.11 | #2 (mdx_extra_q leads at 11.49) |
| bass | 10.38 | #2 (mdx_extra_q leads at 11.42) |
| other | 6.34 | #2 (mdx_extra_q leads at 7.67) |

Full benchmark across every popular open-source separator:
[StemSplitio/stem-separation-benchmark-2026](https://huggingface.co/datasets/StemSplitio/stem-separation-benchmark-2026).

**ONNX vs PyTorch parity:** verified to < 1e-3 max abs diff on every stem
during export. See the
[Day 1 spike report](https://huggingface.co/StemSplitio/htdemucs-ft-drums-onnx#how-it-was-built)
for the full engineering writeup.

---

## Performance

Measured on an Apple M4 Pro, except for the NVIDIA rows, which are
extrapolated rather than measured:

| Mode | Hardware | Per 3-min song | Notes |
|---|---|---:|---|
| One specialist (`htdemucs-ft-drums-onnx`) | M4 Pro CPU | **~22 s** | 4× faster, 75% smaller; use this if you only need one stem |
| **Full bag (this repo)** | M4 Pro CPU | **~88 s** | RTF ~0.5. 4 sub-models × N chunks. |
| Full bag | M4 Pro CPU (8 threads) | ~60 s | With `OMP_NUM_THREADS=8` and SessionOptions tuned |
| Full bag | NVIDIA L4 CUDA | ~6 s | Extrapolated from per-specialist CUDA numbers |
| Full bag | NVIDIA T4 | ~16 s | Extrapolated |
| PyTorch full bag | M4 Pro MPS | ~47 s | Faster than ONNX on CPU only because MPS is GPU-accelerated; the extrapolated ONNX CUDA figures are faster still. |
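
The real-time factor (RTF) quoted in the table is processing time divided by audio duration. A minimal sketch of that arithmetic, using only the approximate timings listed above; the `rtf` helper is illustrative and not part of `bag_infer.py`:

```python
# Real-time factor = processing seconds / audio seconds (lower is faster).
SONG_SECONDS = 3 * 60  # the table's reference 3-minute song


def rtf(processing_seconds, audio_seconds=SONG_SECONDS):
    """Hypothetical helper: real-time factor for one run."""
    return processing_seconds / audio_seconds


# Approximate timings copied from the table above, in seconds.
timings = {
    "one specialist, M4 Pro CPU": 22,
    "full bag, M4 Pro CPU": 88,
    "full bag, M4 Pro CPU (8 threads)": 60,
}
for mode, seconds in timings.items():
    print(f"{mode}: RTF ~{rtf(seconds):.2f}")
```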

---

## Common use cases

- **Karaoke makers**: summing `out/drums.wav`, `out/bass.wav`, and
  `out/other.wav` gives the karaoke (instrumental) track, and
  `out/vocals.wav` is the acapella, all from one pass.
- **DAW stem export**: drop the 4 `.wav` files into Ableton / Logic /
  Reaper as separate channels for remixing.
- **DJ stems software**: load all 4 stems as live-mixable tracks.
- **AI music apps**: feed each stem into downstream models (drum
  transcription, bassline-to-MIDI, vocal pitch correction).
- **Acapella sampling**: clean isolated vocals at the highest SDR
  available in open source.
- **Mobile / on-device separation**: replaces a 1+ GB PyTorch install
  with `onnxruntime`'s 50 MB binary on iOS / Android.
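
For the karaoke case, the instrumental is the sample-wise sum of the three non-vocal stems. A minimal sketch, assuming `stems` has the `dict[str, ndarray (2, samples)]` layout that `bag_infer.separate_all` returns. The `instrumental_from_stems` helper and the random stand-in data are illustrative only, and the peak scaling is one possible way to avoid clipping, not something the script does:

```python
import numpy as np


def instrumental_from_stems(stems):
    """Hypothetical helper: sum every stem except vocals."""
    parts = [audio for name, audio in stems.items() if name != "vocals"]
    mix = np.sum(parts, axis=0)
    # Summed stems can exceed [-1, 1]; scale down only if they would clip.
    peak = np.max(np.abs(mix))
    return mix / peak if peak > 1.0 else mix


# Stand-in data in place of real separator output.
rng = np.random.default_rng(0)
stems = {
    name: rng.uniform(-0.3, 0.3, size=(2, 1000)).astype(np.float32)
    for name in ["drums", "bass", "other", "vocals"]
}
karaoke = instrumental_from_stems(stems)
print(karaoke.shape)
```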

---

## Quick start

### Python: as a library

```python
import bag_infer

stems = bag_infer.separate_all("your-song.mp3")
# stems: dict[str, numpy.ndarray (2, samples)]
# stems["drums"], stems["bass"], stems["other"], stems["vocals"]
```

### Python: with execution provider control

```python
import soundfile as sf
import bag_infer

# separate() expects stereo input shaped (2, samples) at 44.1 kHz.
audio, sr = sf.read("your-song.mp3", dtype="float32", always_2d=True)
stems = bag_infer.separate(
    audio.T, sr,
    providers=["CPUExecutionProvider"],  # or "CoreMLExecutionProvider", etc.
)
for name, stem_audio in stems.items():
    sf.write(f"{name}.wav", stem_audio.T, sr)
```

### CLI

```bash
python bag_infer.py your-song.mp3 ./out/
python bag_infer.py your-song.mp3 ./out/ --providers cuda
python bag_infer.py your-song.mp3 ./out/ --providers coreml
python bag_infer.py your-song.mp3 ./out/ --providers dml
```

### Web / mobile

Each specialist is a vanilla onnxruntime model; just load all 4 sessions
and reuse the aggregation logic in `bag_infer.py::separate`. See the
individual stem repos for platform-specific snippets:
[drums](https://huggingface.co/StemSplitio/htdemucs-ft-drums-onnx) ·
[bass](https://huggingface.co/StemSplitio/htdemucs-ft-bass-onnx) ·
[other](https://huggingface.co/StemSplitio/htdemucs-ft-other-onnx) ·
[vocals](https://huggingface.co/StemSplitio/htdemucs-ft-vocals-onnx).

---

## How aggregation works

The `htdemucs_ft` bag uses a **one-hot weight matrix** for combining the 4
sub-models: model 0's drums output is used directly as the bag's drums
stem, model 1's bass output is the bag's bass stem, and so on. No
weighted-sum aggregation needed.

That means:
- **The bag's drums stem == the drums specialist's drums output** (bit-exact in fp32)
- Same for bass, other, vocals
- So you can ship only the specialists you need and get identical
  per-stem quality to the full bag at 1/4 the size

`bag_infer.py` simply runs all 4 specialists and picks the relevant row
from each. The aggregation itself is about 30 lines of numpy.
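
To see why a one-hot weight matrix reduces to picking rows, here is a small sketch with random stand-in tensors. The per-source weighted average below is an assumption about how a bag combines models in general; the README itself only states that the `htdemucs_ft` weights are one-hot. All names and shapes are illustrative:

```python
import numpy as np

n_models, n_sources, n_channels, n_samples = 4, 4, 2, 8

# Stand-in for the 4 specialists' outputs: outputs[m] is model m's
# (sources, channels, samples) prediction for one segment.
rng = np.random.default_rng(0)
outputs = rng.standard_normal((n_models, n_sources, n_channels, n_samples))

# One-hot bag weights: weights[m, s] == 1 only when m == s.
weights = np.eye(n_models, n_sources)

# Generic weighted aggregation over models, normalized per source.
weighted = np.einsum("ms,mscn->scn", weights, outputs)
weighted /= weights.sum(axis=0)[:, None, None]

# With one-hot weights this collapses to "row s of model s".
picked = np.stack([outputs[s, s] for s in range(n_sources)])
assert np.allclose(weighted, picked)
```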

---

## Input / output spec per sub-model

| Tensor | Name | Shape | Dtype | Notes |
|---|---|---|---|---|
| Input | `mix` | `(1, 2, 343980)` | float32 | Stereo audio, 44.1 kHz, 7.8 s segment. |
| Output | `stems` | `(1, 4, 2, 343980)` | float32 | `[drums, bass, other, vocals]`. Use only the specialist's target row. |

For longer audio, the bag script handles overlap-add chunking: segments
overlap by 25% and are cross-faded with a linear window. It does not
resample, so input at any other sample rate raises a `ValueError`.
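
The chunking arithmetic can be sketched as follows, mirroring the constants in `bag_infer.py` (7.8 s segments at 44.1 kHz, 25% overlap). The `chunk_spans` helper is illustrative and does not exist in the script:

```python
import math

SAMPLE_RATE = 44100
SEGMENT_S = 7.8
N_SAMPLES = round(SEGMENT_S * SAMPLE_RATE)  # 343,980 samples per segment
OVERLAP = N_SAMPLES // 4                    # 25% overlap between segments
STRIDE = N_SAMPLES - OVERLAP                # hop between segment starts


def chunk_spans(total_len):
    """(start, end) sample ranges; the last one is cut at the audio length."""
    n_chunks = max(1, math.ceil(total_len / STRIDE))
    return [(i * STRIDE, min(i * STRIDE + N_SAMPLES, total_len))
            for i in range(n_chunks)]


spans = chunk_spans(3 * 60 * SAMPLE_RATE)   # a 3-minute song
# prints "31 chunks, 124 ONNX runs for the full bag"
print(len(spans), "chunks,", 4 * len(spans), "ONNX runs for the full bag")
```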

---

## Files in this repo

| File | Size | Purpose |
|---|---:|---|
| `htdemucs_ft_drums.onnx` | 316 MB | Drums specialist (bag index 0) |
| `htdemucs_ft_bass.onnx` | 316 MB | Bass specialist (bag index 1) |
| `htdemucs_ft_other.onnx` | 316 MB | Other specialist (bag index 2) |
| `htdemucs_ft_vocals.onnx` | 316 MB | Vocals specialist (bag index 3) |
| `bag_infer.py` | 7 KB | Pure numpy aggregator. No torch. |
| `requirements.txt` | <1 KB | `onnxruntime`, `numpy`, `soundfile`. |
| `README.md` | | This file. |

Total: **~1.26 GB**. If that's too big, use individual stem repos.

---

## Related work

| Repo | Stem | Use when |
|---|---|---|
| [`htdemucs-ft-drums-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-drums-onnx) | drums | Only need drums (1/4 size, 1/4 latency) |
| [`htdemucs-ft-bass-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-bass-onnx) | bass | Only need bass |
| [`htdemucs-ft-other-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-other-onnx) | other | Only need the "other" stem (everything that is not drums, bass, or vocals) |
| [`htdemucs-ft-vocals-onnx`](https://huggingface.co/StemSplitio/htdemucs-ft-vocals-onnx) | vocals | **#1 open-source vocal SDR** |

PyTorch versions for HF Inference Endpoints:
[`htdemucs-ft-pytorch`](https://huggingface.co/StemSplitio/htdemucs-ft-pytorch)
and its [4 sibling specialist repos](https://huggingface.co/StemSplitio).

---

## Skip the infrastructure: use the StemSplit API

Don't want to ship 1.26 GB of `.onnx` files in your app, manage a GPU
pool, or write overlap-add chunking? Use the
**[StemSplit API](https://stemsplit.io/developers)** instead: same models
under the hood, hosted for you, with credits and a dashboard.

- 🌐 [stemsplit.io](https://stemsplit.io)
- 📘 [Developer docs](https://stemsplit.io/developers/docs)
- 🔌 [API reference](https://stemsplit.io/developers/reference)

Or use the no-code tools that ship this same model family:

- 🎤 [Vocal Remover](https://stemsplit.io/vocal-remover)
- 🎶 [Karaoke Maker](https://stemsplit.io/karaoke-maker)
- 🎙️ [Acapella Maker](https://stemsplit.io/acapella-maker)
- 📺 [YouTube Stem Splitter](https://stemsplit.io/youtube-stem-splitter)

---

## License & attribution

MIT-licensed, matching the original HT-Demucs.

```bibtex
@inproceedings{rouard2023hybrid,
  title = {Hybrid Transformers for Music Source Separation},
  author = {Rouard, Simon and Massa, Francisco and D{\'e}fossez, Alexandre},
  booktitle = {ICASSP},
  year = {2023}
}
```

- Original PyTorch model: [`facebookresearch/demucs`](https://github.com/facebookresearch/demucs)
- ONNX export, parity verification, and packaging by [StemSplit](https://stemsplit.io)
- Search keywords: htdemucs onnx, demucs onnx, htdemucs bag onnx, demucs ios, demucs android, music source separation onnx, 4-stem separation onnx, stem separation mobile, onnxruntime music separation
bag_infer.py
ADDED
@@ -0,0 +1,194 @@
"""
Bag inference for the full HT-Demucs FT 4-stem ONNX ensemble.

Runs all 4 specialist sub-models and aggregates their outputs using the
htdemucs_ft bag's one-hot weight matrix (drums-model -> drums stem only,
bass-model -> bass stem only, etc).

NO TORCH at inference. Just numpy + onnxruntime + soundfile.

Usage:
    python bag_infer.py your-song.mp3 ./out/
    # writes out/drums.wav, out/bass.wav, out/other.wav, out/vocals.wav

Or as a library:
    import bag_infer
    stems = bag_infer.separate_all("song.mp3")
    # stems: dict[str, numpy.ndarray (2, samples)]
"""
from __future__ import annotations

import argparse
import sys
import time
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf

SAMPLE_RATE = 44100
SEGMENT_S = 7.8
N_SAMPLES = int(SEGMENT_S * SAMPLE_RATE)  # 343,980
N_CHANNELS = 2
SOURCES = ["drums", "bass", "other", "vocals"]
HERE = Path(__file__).resolve().parent

# The bag's weight matrix for htdemucs_ft is one-hot per stem:
#   drums specialist (bag.models[0]) -> contributes only to drums stem
#   bass specialist (bag.models[1]) -> contributes only to bass stem
#   other specialist (bag.models[2]) -> contributes only to other stem
#   vocals specialist (bag.models[3]) -> contributes only to vocals stem
# So aggregation is trivial: pick row N from model N's output.
DEFAULT_ONNX_FILES = {
    "drums": HERE / "htdemucs_ft_drums.onnx",
    "bass": HERE / "htdemucs_ft_bass.onnx",
    "other": HERE / "htdemucs_ft_other.onnx",
    "vocals": HERE / "htdemucs_ft_vocals.onnx",
}


def _make_transition_window(segment: int, overlap_frac: float = 0.25) -> np.ndarray:
    """Flat window with linear fades at both ends, used to cross-fade overlapping segments."""
    transition = int(segment * overlap_frac)
    window = np.ones(segment, dtype=np.float32)
    fade = np.linspace(0, 1, transition, dtype=np.float32)
    window[:transition] = fade
    window[-transition:] = fade[::-1]
    return window


def _load_sessions(onnx_files: dict[str, Path],
                   providers: list[str] | None = None,
                   ) -> dict[str, ort.InferenceSession]:
    if providers is None:
        providers = ["CPUExecutionProvider"]
    sessions: dict[str, ort.InferenceSession] = {}
    for stem, path in onnx_files.items():
        if not path.exists():
            raise FileNotFoundError(
                f"Missing {stem} model at {path}. Download all 4 .onnx files "
                "into the same directory as this script.")
        sessions[stem] = ort.InferenceSession(str(path), providers=providers)
    return sessions


def separate(mix: np.ndarray, sample_rate: int,
             onnx_files: dict[str, Path] | None = None,
             providers: list[str] | None = None,
             verbose: bool = True) -> dict[str, np.ndarray]:
    """Run full 4-stem chunked overlap-add separation.

    Args:
        mix: (channels, samples) float32 in [-1, 1], 44.1 kHz stereo.
        sample_rate: must equal 44100.
        onnx_files: optional dict overriding the default file locations.
        providers: onnxruntime EPs; defaults to CPU.
        verbose: print progress per chunk.

    Returns:
        dict of {stem_name: (channels, samples) float32}.
    """
    if sample_rate != SAMPLE_RATE:
        raise ValueError(f"Bound to {SAMPLE_RATE} Hz; got {sample_rate}.")
    if mix.ndim != 2 or mix.shape[0] != N_CHANNELS:
        raise ValueError(f"Expected (2, samples) input, got {mix.shape}")

    sessions = _load_sessions(onnx_files or DEFAULT_ONNX_FILES, providers)
    if verbose:
        print(f" loaded {len(sessions)} ONNX sessions on "
              f"{list(sessions.values())[0].get_providers()[0]}")

    total_len = mix.shape[1]
    overlap = N_SAMPLES // 4
    stride = N_SAMPLES - overlap
    n_chunks = max(1, (total_len + stride - 1) // stride)

    if verbose:
        print(f" input: {total_len:,} samples ({total_len / sample_rate:.1f}s)")
        print(f" chunks: {n_chunks}")

    window = _make_transition_window(N_SAMPLES)
    out = {stem: np.zeros((N_CHANNELS, total_len), dtype=np.float32) for stem in SOURCES}
    weight = np.zeros(total_len, dtype=np.float32)

    t0 = time.perf_counter()
    for i in range(n_chunks):
        start = i * stride
        end = min(start + N_SAMPLES, total_len)
        chunk = mix[:, start:end]
        if chunk.shape[1] < N_SAMPLES:
            chunk = np.pad(chunk, ((0, 0), (0, N_SAMPLES - chunk.shape[1])),
                           mode="constant")
        x = chunk[np.newaxis, ...].astype(np.float32)
        chunk_len = end - start
        w = window[:chunk_len]

        # Run each specialist; take only its target stem row.
        for stem in SOURCES:
            stems = sessions[stem].run(["stems"], {"mix": x})[0][0]  # (4, 2, N)
            target_row = SOURCES.index(stem)  # 0/1/2/3 matches bag.models[idx]
            out[stem][:, start:end] += stems[target_row, :, :chunk_len] * w

        weight[start:end] += w
        if verbose:
            print(f" chunk {i+1}/{n_chunks}: "
                  f"{time.perf_counter() - t0:.1f}s elapsed")

    weight = np.maximum(weight, 1e-8)
    for stem in SOURCES:
        out[stem] /= weight

    if verbose:
        rtf = (time.perf_counter() - t0) / (total_len / sample_rate)
        print(f" total: {time.perf_counter() - t0:.2f}s (RTF {rtf:.2f}, "
              f"4 sub-models × {n_chunks} chunks = "
              f"{4 * n_chunks} ONNX runs)")
    return out


def separate_all(input_path: str, **kwargs) -> dict[str, np.ndarray]:
    """Convenience: load audio, run separation, return all 4 stems."""
    audio, sr = sf.read(input_path, dtype="float32", always_2d=True)
    audio = audio.T
    if audio.shape[0] == 1:
        audio = np.tile(audio, (2, 1))
    elif audio.shape[0] > 2:
        audio = audio[:2]
    return separate(audio, sr, **kwargs)


def main() -> None:
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("input", type=Path)
    ap.add_argument("out_dir", type=Path)
    ap.add_argument("--providers", type=str, default="cpu",
                    choices=["cpu", "coreml", "cuda", "dml"])
    args = ap.parse_args()

    providers_map = {
        "cpu": ["CPUExecutionProvider"],
        "coreml": ["CoreMLExecutionProvider", "CPUExecutionProvider"],
        "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
        "dml": ["DmlExecutionProvider", "CPUExecutionProvider"],
    }
    args.out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading {args.input} ...")
    audio, sr = sf.read(str(args.input), dtype="float32", always_2d=True)
    audio = audio.T
    if audio.shape[0] == 1:
        audio = np.tile(audio, (2, 1))
    elif audio.shape[0] > 2:
        audio = audio[:2]
    print(f" shape {audio.shape}, sr {sr}")

    stems = separate(audio, sr, providers=providers_map[args.providers])

    for stem, audio_out in stems.items():
        out_path = args.out_dir / f"{stem}.wav"
        sf.write(str(out_path), audio_out.T, sr)
        print(f" wrote {out_path}")


if __name__ == "__main__":
    main()
requirements.txt
ADDED
@@ -0,0 +1,3 @@
onnxruntime>=1.20
numpy>=1.24
soundfile>=0.12