---
tags:
- speech-separation
- audio
- dprnn
- multi-decoder
license: mit
base_model: JunzheJosephZhu/MultiDecoderDPRNN
datasets:
- custom
language:
- en
---

# Multi-Decoder DPRNN — Fine-Tuned for 1–5 Speaker Separation

A fine-tuned version of [MultiDecoderDPRNN](https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN) for speech separation with a variable number of speakers (1–5).

## What changed?

The original pre-trained model supports 2–5 speakers and always outputs 5 active sources regardless of the actual speaker count (~22% speaker-count accuracy). Our fine-tuning teaches the model to:

- Correctly identify the number of speakers (88% accuracy)
- Output silence on unused channels (see the sketch after this list)
- Support single-speaker scenarios (a new 1-speaker decoder)

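Because unused channels are trained to be near-silent, the speaker count can also be sanity-checked directly from the separated outputs. A minimal sketch, assuming separated waveforms of shape `[n_src, T]`; the helper name and the -30 dB threshold are illustrative choices of ours, not values from the training recipe:

```python
import torch

def count_active_sources(est_sources: torch.Tensor, thresh_db: float = -30.0) -> int:
    """Count non-silent channels by energy relative to the loudest channel.

    est_sources: [n_src, T] separated waveforms. The -30 dB threshold is a
    hypothetical choice for illustration, not part of the model.
    """
    energy = est_sources.pow(2).mean(dim=-1)  # [n_src] mean power per channel
    rel_db = 10 * torch.log10(energy / energy.max().clamp_min(1e-12))
    return int((rel_db > thresh_db).sum())
```
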
## Checkpoints

| Checkpoint | Architecture | SI-SDR | SI-SDRi | SDR | Spk Acc | File |
|------------|--------------|--------|---------|-----|---------|------|
| **Rys (best)** | n_srcs=[1,2,3,4,5] | +7.33 dB | +10.01 dB | +7.32 dB | 88.0% | `weights/dprnn_rys_weights.pt` |
| Zhaksh | n_srcs=[2,3,4,5] | +3.24 dB | +5.92 dB | +3.24 dB | 50.0% | `weights/dprnn_zhaksh_weights.pt` |
| Original (baseline) | n_srcs=[2,3,4,5] | -1.29 dB | +1.39 dB | -1.29 dB | 22.0% | — |

### SI-SDRi by Speaker Count

| N | Original | Rys | Zhaksh |
|---|----------|-----|--------|
| 1 | -9.17 dB | **+8.86 dB** | -7.47 dB |
| 2 | +4.07 dB | **+11.98 dB** | +11.58 dB |
| 3 | +2.46 dB | **+10.00 dB** | +9.78 dB |
| 4 | +4.56 dB | **+9.65 dB** | +7.78 dB |
| 5 | +5.05 dB | **+9.56 dB** | +7.94 dB |

## Usage

```python
import sys

import torch
from huggingface_hub import hf_hub_download

# The Multi-Decoder DPRNN recipe lives inside the Asteroid repository
sys.path.insert(0, "asteroid/egs/wsj0-mix-var/Multi-Decoder-DPRNN")
from model import MultiDecoderDPRNN

# Load the base architecture, then rebuild it with an extended speaker count
pretrained = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN")
fb = pretrained.encoder.filterbank
cfg = dict(
    bn_chan=pretrained.masker.bn_chan, hid_size=pretrained.masker.hid_size,
    chunk_size=pretrained.masker.chunk_size, hop_size=pretrained.masker.hop_size,
    n_repeats=pretrained.masker.n_repeats, norm_type=pretrained.masker.norm_type,
    bidirectional=pretrained.masker.bidirectional, rnn_type=pretrained.masker.rnn_type,
    num_layers=pretrained.masker.num_layers, dropout=pretrained.masker.dropout,
    n_filters=fb.n_feats_out, kernel_size=fb.kernel_size,
    stride=fb.stride, sample_rate=8000,
)

model = MultiDecoderDPRNN(n_srcs=[1, 2, 3, 4, 5], **cfg)
weights = hf_hub_download("Namadgi/MultiDecoderDPRNN-finetuned", "weights/dprnn_rys_weights.pt")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model.eval()

# Inference
mix = torch.randn(1, 32000)  # [batch, time] at 8 kHz
reconstructed, selector = model(mix, ground_truth=[3])  # 3 speakers
# reconstructed: [1, n_stages, max_spks, T]
# selector: [1, n_stages, n_decoders]
```
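
When the speaker count is unknown, the `selector` output indicates which decoder the model believes is correct. The snippet below is a post-processing sketch built only on the output shapes documented above; the assumptions that decoder `i` of `n_srcs=[1, 2, 3, 4, 5]` produces `i + 1` sources and that active sources occupy the leading channels are ours, not part of the model API:

```python
# Pick the decoder from the final-stage selector logits and keep only the
# channels that decoder is responsible for. Assumes decoder index i maps to
# n_srcs[i] = i + 1 speakers and active sources occupy the leading channels.
stage = -1                                        # final refinement stage
est_spks = int(selector[0, stage].argmax()) + 1   # predicted speaker count
sources = reconstructed[0, stage, :est_spks]      # [est_spks, T]
print(f"Estimated {est_spks} speakers")
```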

## Training Details

- **Base model:** JunzheJosephZhu/MultiDecoderDPRNN
- **Dataset:** 10,000 synthetic mixtures (LibriSpeech train-clean-100 + WHAM! noise), 1–5 speakers
- **Evaluation:** 50 samples from a held-out CV set (LibriSpeech dev-clean + WHAM! cv)
- **Sample rate:** 8 kHz
- **Loss:** multi-stage PIT SI-SDR + cross-entropy selector loss (sketched below)

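For reference, the separation part of the objective can be sketched as follows. This is our reading of "multi-stage PIT SI-SDR" (permutation-invariant SI-SDR applied per refinement stage), not the actual training code; the cross-entropy selector term and any per-stage weighting are omitted:

```python
from itertools import permutations

import torch

def si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR in dB for signals of shape [..., T]."""
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    # Project the estimate onto the reference to remove any scale difference.
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * ref
    noise = est - target
    return 10 * torch.log10(target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps))

def pit_si_sdr_loss(est: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Negative SI-SDR under the best speaker permutation; est, ref: [n_src, T].

    Brute force over permutations is acceptable here since n_src <= 5.
    """
    n = ref.shape[0]
    best = max(si_sdr(est[list(p)], ref).mean() for p in permutations(range(n)))
    return -best
```
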
## Authors

- Rys ([@RysNamadgi](https://huggingface.co/RysNamadgi))
- Zhaksh

Part of the [Namadgi](https://huggingface.co/Namadgi) research group.