---
tags:
- speech-separation
- audio
- dprnn
- multi-decoder
license: mit
base_model: JunzheJosephZhu/MultiDecoderDPRNN
datasets:
- custom
language:
- en
---

# Multi-Decoder DPRNN — Fine-Tuned for 1–5 Speaker Separation

Fine-tuned version of [MultiDecoderDPRNN](https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN) for speech separation with a variable number of speakers (1–5).

## What changed?

The original pre-trained model supports 2–5 speakers and always outputs 5 active sources regardless of the actual speaker count (~22% speaker-count accuracy). Our fine-tuning teaches the model to:
- Correctly identify the number of speakers (88% accuracy)
- Output silence on unused channels
- Support single-speaker scenarios (via a new 1-speaker decoder)
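
Because the fine-tuned model emits (near-)silence on unused channels, downstream code can discard those channels with a simple energy threshold. This is our own post-processing heuristic, not part of the model:

```python
import torch

def active_sources(sources: torch.Tensor, thresh_db: float = -30.0) -> torch.Tensor:
    """Keep channels whose energy is within `thresh_db` of the loudest one."""
    # sources: [n_channels, T]
    energy = sources.pow(2).mean(dim=-1)
    keep = 10 * torch.log10(energy / energy.max() + 1e-12) > thresh_db
    return sources[keep]

# Three active channels plus two silent ones -> only the active three survive
est = torch.cat([torch.randn(3, 8000), torch.zeros(2, 8000)])
print(active_sources(est).shape[0])  # 3
```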

## Checkpoints

| Checkpoint | Architecture | SI-SDR | SI-SDRi | SDR | Spk Acc | File |
|------------|-------------|--------|---------|-----|---------|------|
| **Rys (best)** | n_srcs=[1,2,3,4,5] | +7.33 dB | +10.01 dB | +7.32 dB | 88.0% | `weights/dprnn_rys_weights.pt` |
| Zhaksh | n_srcs=[2,3,4,5] | +3.24 dB | +5.92 dB | +3.24 dB | 50.0% | `weights/dprnn_zhaksh_weights.pt` |
| Original (baseline) | n_srcs=[2,3,4,5] | -1.29 dB | +1.39 dB | -1.29 dB | 22.0% | — |

### SI-SDRi by Speaker Count

| N | Original | Rys | Zhaksh |
|---|----------|-----|--------|
| 1 | -9.17 dB | **+8.86 dB** | -7.47 dB |
| 2 | +4.07 dB | **+11.98 dB** | +11.58 dB |
| 3 | +2.46 dB | **+10.00 dB** | +9.78 dB |
| 4 | +4.56 dB | **+9.65 dB** | +7.78 dB |
| 5 | +5.05 dB | **+9.56 dB** | +7.94 dB |

## Usage

```python
import sys, torch
sys.path.insert(0, "asteroid/egs/wsj0-mix-var/Multi-Decoder-DPRNN")
from model import MultiDecoderDPRNN
from huggingface_hub import hf_hub_download

# Load the base architecture, then rebuild it with an extended speaker count
pretrained = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN")
fb = pretrained.encoder.filterbank
cfg = dict(
    bn_chan=pretrained.masker.bn_chan, hid_size=pretrained.masker.hid_size,
    chunk_size=pretrained.masker.chunk_size, hop_size=pretrained.masker.hop_size,
    n_repeats=pretrained.masker.n_repeats, norm_type=pretrained.masker.norm_type,
    bidirectional=pretrained.masker.bidirectional, rnn_type=pretrained.masker.rnn_type,
    num_layers=pretrained.masker.num_layers, dropout=pretrained.masker.dropout,
    n_filters=fb.n_feats_out, kernel_size=fb.kernel_size,
    stride=fb.stride, sample_rate=8000,
)

model = MultiDecoderDPRNN(n_srcs=[1, 2, 3, 4, 5], **cfg)
weights = hf_hub_download("Namadgi/MultiDecoderDPRNN-finetuned", "weights/dprnn_rys_weights.pt")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model.eval()

# Inference
mix = torch.randn(1, 32000)  # [batch, time] at 8 kHz
reconstructed, selector = model(mix, ground_truth=[3])  # 3 speakers
# reconstructed: [1, n_stages, max_spks, T]
# selector: [1, n_stages, n_decoders]
```
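
When no ground truth is available, the speaker count has to be read off the selector head itself. A minimal sketch of that post-processing, assuming the `[batch, n_stages, n_decoders]` selector shape above and that decoder `i` corresponds to `n_srcs[i]` (fabricated logits stand in for real model output here):

```python
import torch

n_srcs = [1, 2, 3, 4, 5]
# Fabricated selector logits for one mixture: [batch, n_stages, n_decoders]
selector = torch.tensor([[[0.3, 0.2, 0.1, 0.1, 0.1],
                          [0.1, 0.1, 2.5, 0.2, 0.1]]])  # final stage favors decoder 2

last_stage = selector[:, -1, :]          # use the final stage's logits
decoder_idx = last_stage.argmax(dim=-1)  # winning decoder per batch item
est_speakers = [n_srcs[i] for i in decoder_idx.tolist()]
print(est_speakers)  # [3]
```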

## Training Details

- **Base model:** JunzheJosephZhu/MultiDecoderDPRNN
- **Dataset:** 10,000 synthetic mixtures (LibriSpeech train-clean-100 + WHAM! noise), 1–5 speakers
- **Evaluation:** 50 samples from a held-out CV set (LibriSpeech dev-clean + WHAM! cv)
- **Sample rate:** 8 kHz
- **Loss:** multi-stage PIT SI-SDR + cross-entropy selector loss
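
The reconstruction half of that loss is SI-SDR. A self-contained sketch of the metric for a single source pair (our naming, not the training script's):

```python
import torch

def si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR in dB for 1-D signals; higher is better."""
    est, ref = est - est.mean(), ref - ref.mean()
    # Project the estimate onto the reference to get the scaled target
    target = (torch.dot(est, ref) / (torch.dot(ref, ref) + eps)) * ref
    noise = est - target
    return 10 * torch.log10(target.pow(2).sum() / (noise.pow(2).sum() + eps))

ref = torch.sin(torch.linspace(0, 100, 8000))
print(si_sdr(2.0 * ref, ref) > 60)                       # rescaling barely hurts the score
print(si_sdr(ref + 0.3 * torch.randn(8000), ref) < 20)   # additive noise does
```

Training minimizes the negative of this quantity under the best permutation of sources (PIT), plus the selector's cross-entropy term.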

## Authors

- Rys ([@RysNamadgi](https://huggingface.co/RysNamadgi))
- Zhaksh

Part of the [Namadgi](https://huggingface.co/Namadgi) research group.