---
tags:
  - speech-separation
  - audio
  - dprnn
  - multi-decoder
license: mit
base_model: JunzheJosephZhu/MultiDecoderDPRNN
datasets:
  - custom
language:
  - en
---

# Multi-Decoder DPRNN — Fine-Tuned for 1–5 Speaker Separation

Fine-tuned version of [MultiDecoderDPRNN](https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN) for speech separation with a variable number of speakers (1–5).

## What changed?

The original pre-trained model supports 2–5 speakers and always outputs five active sources regardless of the actual speaker count (speaker-count accuracy of about 22%). Our fine-tuning teaches the model to:
- Correctly identify the number of speakers (88% accuracy)
- Output silence on unused channels
- Support single-speaker scenarios (new 1-speaker decoder)
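One way to check the "silence on unused channels" behaviour on your own audio is to compare the energy of each output channel. This is a minimal sketch, not part of the released code: `count_active_channels` is a hypothetical helper, and the -30 dB threshold relative to the loudest channel is an arbitrary assumption you should tune.

```python
import numpy as np


def count_active_channels(estimates: np.ndarray, rel_threshold_db: float = -30.0) -> int:
    """Count channels whose energy is within `rel_threshold_db` of the loudest channel.

    estimates: array of shape [n_channels, T] with separated waveforms.
    The default threshold is a hypothetical choice, not a value from this model card.
    """
    # Mean power per channel; the epsilon avoids log(0) for all-zero channels
    power = np.mean(estimates ** 2, axis=-1) + 1e-12
    # Channel energy in dB relative to the loudest channel
    rel_db = 10.0 * np.log10(power / power.max())
    return int(np.sum(rel_db > rel_threshold_db))


# Toy example: 3 noise-like "speech" channels plus 2 near-silent channels
rng = np.random.default_rng(0)
est = np.concatenate([rng.standard_normal((3, 8000)), 1e-4 * rng.standard_normal((2, 8000))])
print(count_active_channels(est))  # → 3
```

Because the threshold is relative, an input where every channel is silent is reported as fully active; use an absolute floor as well if that case matters to you.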

## Checkpoints

| Checkpoint | Decoders (`n_srcs`) | SI-SDR | SI-SDRi | SDR | Speaker-count acc. | File |
|------------|-------------|--------|---------|-----|---------|------|
| **Rys (best)** | n_srcs=[1,2,3,4,5] | +7.33 dB | +10.01 dB | +7.32 dB | 88.0% | `weights/dprnn_rys_weights.pt` |
| Zhaksh | n_srcs=[2,3,4,5] | +3.24 dB | +5.92 dB | +3.24 dB | 50.0% | `weights/dprnn_zhaksh_weights.pt` |
| Original (baseline) | n_srcs=[2,3,4,5] | -1.29 dB | +1.39 dB | -1.29 dB | 22.0% | — |
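The card does not spell out how SI-SDR and SI-SDRi were computed. The sketch below assumes the standard scale-invariant SDR definition (zero-mean signals, optimal scaling of the reference) and treats SI-SDRi as the gain over using the unprocessed mixture as the estimate. The function names are hypothetical, and the actual evaluation may differ in details such as permutation handling and averaging.

```python
import numpy as np


def si_sdr(estimate: np.ndarray, reference: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB for 1-D signals."""
    # Remove DC offsets so constant shifts do not affect the score
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    # Optimal scaling of the reference: alpha * reference is the target component
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    noise = estimate - target
    return float(10.0 * np.log10((np.dot(target, target) + eps) / (np.dot(noise, noise) + eps)))


def si_sdr_improvement(estimate: np.ndarray, reference: np.ndarray, mixture: np.ndarray) -> float:
    """SI-SDRi: SI-SDR of the estimate minus SI-SDR of the raw mixture."""
    return si_sdr(estimate, reference) - si_sdr(mixture, reference)


rng = np.random.default_rng(0)
s1, s2 = rng.standard_normal(8000), rng.standard_normal(8000)
mix = s1 + s2
est = s1 + 0.1 * s2  # a partially separated estimate of s1
print(round(si_sdr(mix, s1), 2), round(si_sdr_improvement(est, s1, mix), 2))
```

Here the mixture scores roughly 0 dB against `s1` and the estimate roughly 20 dB, so the improvement is about 20 dB. Scaling `est` by any constant leaves the score unchanged, which is the point of the scale-invariant variant.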

### SI-SDRi by Speaker Count

| N | Original | Rys | Zhaksh |
|---|----------|-----|--------|
| 1 | -9.17 dB | **+8.86 dB** | -7.47 dB |
| 2 | +4.07 dB | **+11.98 dB** | +11.58 dB |
| 3 | +2.46 dB | **+10.00 dB** | +9.78 dB |
| 4 | +4.56 dB | **+9.65 dB** | +7.78 dB |
| 5 | +5.05 dB | **+9.56 dB** | +7.94 dB |
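As a consistency check, the overall SI-SDRi values in the checkpoint table equal the unweighted mean of the five per-speaker-count values above, to rounding. That would follow if the evaluation samples were split evenly across speaker counts, but the card does not state this, so treat it as an inference.

```python
# Per-speaker-count SI-SDRi (dB) copied from the table above, for N = 1..5
per_n = {
    "Original": [-9.17, 4.07, 2.46, 4.56, 5.05],
    "Rys": [8.86, 11.98, 10.00, 9.65, 9.56],
    "Zhaksh": [-7.47, 11.58, 9.78, 7.78, 7.94],
}
# Overall SI-SDRi (dB) copied from the checkpoint table
overall = {"Original": 1.39, "Rys": 10.01, "Zhaksh": 5.92}

for name, values in per_n.items():
    mean = sum(values) / len(values)
    print(f"{name}: mean over N = {mean:.3f} dB, reported = {overall[name]:+.2f} dB")
```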

## Usage

```python
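import sys

import torch

# MultiDecoderDPRNN is defined in the asteroid recipe directory; this path
# assumes the asteroid repository is cloned into the working directory.
sys.path.insert(0, "asteroid/egs/wsj0-mix-var/Multi-Decoder-DPRNN")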
from model import MultiDecoderDPRNN
from huggingface_hub import hf_hub_download

# Load base architecture with extended speaker count
pretrained = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN")
fb = pretrained.encoder.filterbank
cfg = dict(
    bn_chan=pretrained.masker.bn_chan, hid_size=pretrained.masker.hid_size,
    chunk_size=pretrained.masker.chunk_size, hop_size=pretrained.masker.hop_size,
    n_repeats=pretrained.masker.n_repeats, norm_type=pretrained.masker.norm_type,
    bidirectional=pretrained.masker.bidirectional, rnn_type=pretrained.masker.rnn_type,
    num_layers=pretrained.masker.num_layers, dropout=pretrained.masker.dropout,
    n_filters=fb.n_feats_out, kernel_size=fb.kernel_size,
    stride=fb.stride, sample_rate=8000,
)

model = MultiDecoderDPRNN(n_srcs=[1, 2, 3, 4, 5], **cfg)
weights = hf_hub_download("Namadgi/MultiDecoderDPRNN-finetuned", "weights/dprnn_rys_weights.pt")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model.eval()

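# Inference on a dummy 4-second mixture (32,000 samples); real input must be mono 8 kHz
mix = torch.randn(1, 32000)  # [batch, time]
with torch.no_grad():
    # ground_truth passes the known speaker count (here 3 speakers)
    reconstructed, selector = model(mix, ground_truth=[3])
# reconstructed: [1, n_stages, max_spks, T]
# selector: [1, n_stages, n_decoders]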
```
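The example above passes the true speaker count via `ground_truth`. When the count is unknown, one option is to read it off the selector and keep the matching number of channels from the last stage. This sketch operates on tensors with the shapes listed in the usage comments and rests on assumptions you should verify against the model code: the selector's decoder index follows the order of `n_srcs`, the last stage holds the final estimate, and active sources occupy the first channels. `pick_sources` is a hypothetical helper.

```python
import torch


def pick_sources(reconstructed: torch.Tensor, selector: torch.Tensor, n_srcs: list) -> list:
    """Return, per batch item, the final-stage estimates for the predicted speaker count.

    reconstructed: [batch, n_stages, max_spks, T]
    selector:      [batch, n_stages, n_decoders] (logits or probabilities)
    n_srcs:        speaker count handled by each decoder, e.g. [1, 2, 3, 4, 5]
    """
    # Decoder chosen by the last stage's selector, mapped to a speaker count
    decoder_idx = selector[:, -1].argmax(dim=-1)
    counts = [n_srcs[i] for i in decoder_idx.tolist()]
    # Keep the first k channels of the final stage (assumed layout of active sources)
    return [reconstructed[b, -1, :k] for b, k in enumerate(counts)]


# Dummy tensors: batch=2, 6 stages (arbitrary), 5 channels, 1 s at 8 kHz
reconstructed = torch.randn(2, 6, 5, 8000)
selector = torch.zeros(2, 6, 5)
selector[0, -1, 2] = 1.0  # item 0: decoder index 2 -> 3 speakers
selector[1, -1, 0] = 1.0  # item 1: decoder index 0 -> 1 speaker
sources = pick_sources(reconstructed, selector, n_srcs=[1, 2, 3, 4, 5])
print([tuple(s.shape) for s in sources])  # → [(3, 8000), (1, 8000)]
```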

## Training Details

- **Base model:** JunzheJosephZhu/MultiDecoderDPRNN
- **Dataset:** 10,000 synthetic mixtures (LibriSpeech train-clean-100 + WHAM! noise), 1–5 speakers
- **Evaluation:** 50 samples from held-out CV set (LibriSpeech dev-clean + WHAM! cv)
- **Sample rate:** 8 kHz
- **Loss:** multi-stage PIT SI-SDR loss plus a cross-entropy loss on the decoder selector
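A minimal sketch of this kind of loss, for illustration only and not the training code behind these checkpoints. It assumes equal weighting of the two terms (the hypothetical `ce_weight`), scores only the decoder matching the true speaker count, and finds the best speaker permutation by brute force, which is feasible for at most 5 speakers (120 permutations). All function names are hypothetical.

```python
import itertools

import torch
import torch.nn.functional as F


def neg_si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative SI-SDR (dB) along the last dimension."""
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * ref
    noise = est - target
    ratio = (target.pow(2).sum(-1) + eps) / (noise.pow(2).sum(-1) + eps)
    return -10.0 * torch.log10(ratio)


def pit_neg_si_sdr(est: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Permutation-invariant negative SI-SDR for one example; est and ref are [n_spk, T]."""
    n = ref.shape[0]
    losses = [neg_si_sdr(est[list(p)], ref).mean() for p in itertools.permutations(range(n))]
    return torch.stack(losses).min()


def multi_stage_loss(stage_ests, ref, selector_logits, target_decoder, ce_weight=1.0):
    """stage_ests: [n_stages, n_spk, T]; selector_logits: [n_stages, n_decoders]."""
    # Separation term: PIT negative SI-SDR averaged over stages
    sep = torch.stack([pit_neg_si_sdr(e, ref) for e in stage_ests]).mean()
    # Selector term: each stage should choose the decoder for the true speaker count
    targets = torch.full((selector_logits.shape[0],), target_decoder, dtype=torch.long)
    ce = F.cross_entropy(selector_logits, targets)
    return sep + ce_weight * ce


torch.manual_seed(0)
ref = torch.randn(3, 8000)
# Four stages of estimates in reversed channel order, with a little added noise
stage_ests = torch.stack([ref.flip(0) + 0.1 * torch.randn(3, 8000) for _ in range(4)])
logits = torch.randn(4, 5)
# With n_srcs=[1, 2, 3, 4, 5], decoder index 2 handles 3 speakers
loss = multi_stage_loss(stage_ests, ref, logits, target_decoder=2)
print(loss.item())
```

The permutation search is what lets the separation term ignore the order of output channels, so the reversed estimates above still score about -20 dB per stage.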

## Authors

- Rys ([@RysNamadgi](https://huggingface.co/RysNamadgi))
- Zhaksh

Part of the [Namadgi](https://huggingface.co/Namadgi) research group.