xocialize commited on
Commit
d9d07d8
·
verified ·
1 Parent(s): f0af45f

Add MLX port of MERL MRX (default_ checkpoint, fp32) — 3-stem soundtrack separation

Browse files
Files changed (3) hide show
  1. README.md +53 -0
  2. config.json +23 -0
  3. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: mlx
4
+ tags:
5
+ - mlx
6
+ - audio
7
+ - audio-source-separation
8
+ - speech
9
+ - music
10
+ - apple-silicon
11
+ pipeline_tag: audio-to-audio
12
+ ---
13
+
14
+ # Cocktail-Fork-MRX (MLX)
15
+
16
+ Apple **MLX** port of MERL's **MRX** (Multi-Resolution CrossNet) — separates a
17
+ soundtrack mixture into three stems: **music**, **speech**, and **sound effects (sfx)**.
18
+ Runs natively on Apple Silicon, no PyTorch at inference.
19
+
20
+ - **Upstream:** [merlresearch/cocktail-fork-separation](https://github.com/merlresearch/cocktail-fork-separation) — *The Cocktail Fork Problem: Three-Stem Audio Separation for Real-World Soundtracks* (ICASSP 2022).
21
+ - **Checkpoint:** `default_` (SNR-loss trained — the upstream default inference weights).
22
+ - **License:** MIT.
23
+ - **Parity:** numerically exact vs the PyTorch reference (full-pipeline max_abs ≈ `9e-8`; per-stem SI-SDR 107–139 dB vs torch).
24
+
25
+ ## Usage
26
+
27
+ ```bash
28
+ pip install cocktail-fork-mlx # or: pip install git+https://github.com/xocialize/cocktail-fork-mlx
29
+ cocktail-fork-mlx --audio-path soundtrack.wav --out-dir ./out
30
+ # -> out/music.wav out/speech.wav out/sfx.wav
31
+ ```
32
+
33
+ ```python
34
+ import mlx.core as mx, soundfile as sf, numpy as np
35
+ from cocktail_fork_mlx.separate import separate_soundtrack
36
+ from cocktail_fork_mlx.weights import from_pretrained
37
+
38
+ audio, fs = sf.read("soundtrack.wav", always_2d=True) # 44.1 kHz
39
+ model = from_pretrained("mlx-community/Cocktail-Fork-MRX")
40
+ stems = separate_soundtrack(mx.array(audio.T.astype("float32")), model)
41
+ for name, x in stems.items():
42
+ sf.write(f"{name}.wav", np.array(x).T, 44100)
43
+ ```
44
+
45
+ ## Model
46
+
47
+ - 44.1 kHz, any channel count. ~30.6M params, fp32 (122 MB).
48
+ - Multi-resolution STFT (windows 1024/2048/8192, hop 256) → per-resolution magnitude
49
+ encoders → 3 parallel bidirectional CrossNet LSTMs → per-source/per-resolution mask
50
+ decoders → masked iSTFT summed across resolutions.
51
+ - CPU is the faster device for this LSTM-bound model (default in the CLI).
52
+
53
+ Ported by MVS Collective (xocialize). MIT, © MERL for the original model/weights.
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "mrx",
3
+ "architecture": "MRX (Multi-Resolution CrossNet)",
4
+ "n_sources": 3,
5
+ "window_lengths": [
6
+ 1024,
7
+ 2048,
8
+ 8192
9
+ ],
10
+ "n_hop": 256,
11
+ "n_hidden": 512,
12
+ "n_lstm_layers": 3,
13
+ "sample_rate": 44100,
14
+ "source_names": [
15
+ "music",
16
+ "speech",
17
+ "sfx"
18
+ ],
19
+ "upstream": "merlresearch/cocktail-fork-separation",
20
+ "license": "MIT",
21
+ "port_version": "0.1.0",
22
+ "dtype": "float32"
23
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17244f6f1ded8b3430a757a2d2a72bdf2e88eecd035f81a54ebb74f6d5f79884
3
+ size 122284399