---
tags:
- speech-separation
- audio
- dprnn
- multi-decoder
license: mit
base_model: JunzheJosephZhu/MultiDecoderDPRNN
datasets:
- custom
language:
- en
---

# Multi-Decoder DPRNN: Fine-Tuned for 1–5 Speaker Separation

Fine-tuned version of [MultiDecoderDPRNN](https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN) for speech separation with a variable number of speakers (1–5).

## What changed?

The original pre-trained model supports 2–5 speakers and always outputs 5 active sources regardless of the actual speaker count (~22% speaker-count accuracy). Our fine-tuning teaches the model to:
- Correctly identify the number of speakers (88% accuracy)
- Output silence on unused channels (see the sketch below)
- Support single-speaker scenarios (new 1-speaker decoder)
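
Because unused output channels are trained toward silence, the predicted speaker count can be sanity-checked directly from the separated waveforms. A minimal sketch under assumptions not taken from this card: separated sources of shape `[n_channels, time]` and an illustrative -30 dB relative energy threshold:

```python
import torch

def count_active_channels(sources: torch.Tensor, rel_threshold_db: float = -30.0) -> int:
    """Count channels whose energy lies within rel_threshold_db of the loudest one.

    sources: [n_channels, time]. Channels the model silences should fall far
    below the threshold; active speakers should stay near the maximum.
    """
    energy_db = 10 * torch.log10(sources.pow(2).mean(dim=-1) + 1e-10)
    return int((energy_db > energy_db.max() + rel_threshold_db).sum().item())

# Example: two loud channels and three near-silent ones -> prints 2
demo = torch.cat([torch.randn(2, 8000), 1e-4 * torch.randn(3, 8000)], dim=0)
print(count_active_channels(demo))
```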

## Checkpoints

| Checkpoint | Architecture | SI-SDR | SI-SDRi | SDR | Spk Count Acc | File |
|------------|--------------|--------|---------|-----|---------------|------|
| **Rys (best)** | n_srcs=[1,2,3,4,5] | +7.33 dB | +10.01 dB | +7.32 dB | 88.0% | `weights/dprnn_rys_weights.pt` |
| Zhaksh | n_srcs=[2,3,4,5] | +3.24 dB | +5.92 dB | +3.24 dB | 50.0% | `weights/dprnn_zhaksh_weights.pt` |
| Original (baseline) | n_srcs=[2,3,4,5] | -1.29 dB | +1.39 dB | -1.29 dB | 22.0% | – |
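
The Zhaksh checkpoint was trained without the 1-speaker decoder, so loading it should require an architecture built with the smaller decoder set from the table. A sketch of the assumed loading variant; it reuses `cfg`, `MultiDecoderDPRNN`, `hf_hub_download`, and `torch` exactly as set up in the Usage section below:

```python
# Assumed variant of the Usage recipe for the Zhaksh checkpoint:
# match its decoder set (n_srcs from the table above) before loading.
model = MultiDecoderDPRNN(n_srcs=[2, 3, 4, 5], **cfg)
weights = hf_hub_download("Namadgi/MultiDecoderDPRNN-finetuned", "weights/dprnn_zhaksh_weights.pt")
model.load_state_dict(torch.load(weights, map_location="cpu"))
```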

### SI-SDRi by Speaker Count

| N speakers | Original | Rys | Zhaksh |
|------------|----------|-----|--------|
| 1 | -9.17 dB | **+8.86 dB** | -7.47 dB |
| 2 | +4.07 dB | **+11.98 dB** | +11.58 dB |
| 3 | +2.46 dB | **+10.00 dB** | +9.78 dB |
| 4 | +4.56 dB | **+9.65 dB** | +7.78 dB |
| 5 | +5.05 dB | **+9.56 dB** | +7.94 dB |

## Usage

```python
import sys, torch
sys.path.insert(0, "asteroid/egs/wsj0-mix-var/Multi-Decoder-DPRNN")
from model import MultiDecoderDPRNN
from huggingface_hub import hf_hub_download

# Load base architecture with extended speaker count
pretrained = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN")
fb = pretrained.encoder.filterbank
cfg = dict(
    bn_chan=pretrained.masker.bn_chan, hid_size=pretrained.masker.hid_size,
    chunk_size=pretrained.masker.chunk_size, hop_size=pretrained.masker.hop_size,
    n_repeats=pretrained.masker.n_repeats, norm_type=pretrained.masker.norm_type,
    bidirectional=pretrained.masker.bidirectional, rnn_type=pretrained.masker.rnn_type,
    num_layers=pretrained.masker.num_layers, dropout=pretrained.masker.dropout,
    n_filters=fb.n_feats_out, kernel_size=fb.kernel_size,
    stride=fb.stride, sample_rate=8000,
)

model = MultiDecoderDPRNN(n_srcs=[1, 2, 3, 4, 5], **cfg)
weights = hf_hub_download("Namadgi/MultiDecoderDPRNN-finetuned", "weights/dprnn_rys_weights.pt")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model.eval()

# Inference
mix = torch.randn(1, 32000)  # [batch, time] at 8 kHz
reconstructed, selector = model(mix, ground_truth=[3])  # mixture contains 3 speakers
# reconstructed: [1, n_stages, max_spks, T]
# selector: [1, n_stages, n_decoders]
```
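
When the speaker count is unknown at inference, the selector output can pick the decoder. A minimal sketch built only from the output shapes commented above; that `ground_truth` can be omitted, that the last stage carries the final decision, and that decoder order follows `n_srcs` are assumptions, not documented API behavior:

```python
n_srcs = [1, 2, 3, 4, 5]  # decoder set the model was built with

with torch.no_grad():
    reconstructed, selector = model(mix)  # assumption: ground_truth may be omitted

# Most confident decoder at the last stage; selector: [1, n_stages, n_decoders]
decoder_idx = selector[0, -1].argmax().item()
est_spks = n_srcs[decoder_idx]

# Keep only the channels that decoder actually produces
sources = reconstructed[0, -1, :est_spks]  # [est_spks, T]
print(f"Estimated speakers: {est_spks}")
```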

## Training Details

- **Base model:** JunzheJosephZhu/MultiDecoderDPRNN
- **Dataset:** 10,000 synthetic mixtures (LibriSpeech train-clean-100 + WHAM! noise), 1–5 speakers
- **Evaluation:** 50 samples from a held-out CV set (LibriSpeech dev-clean + WHAM! cv)
- **Sample rate:** 8 kHz
- **Loss:** multi-stage PIT SI-SDR plus a cross-entropy selector loss (see the sketch below)
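
For reference, the SI-SDR term scores reconstruction quality up to an arbitrary scale. A minimal sketch of the metric only, not the training code: the permutation search of PIT and the selector cross-entropy are omitted:

```python
import torch

def si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR in dB for [..., time] tensors."""
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    # Project the estimate onto the reference; the scaled reference is the target signal.
    scale = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = scale * ref
    noise = est - target
    return 10 * torch.log10(target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps) + eps)
```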

## Authors

- Rys ([@RysNamadgi](https://huggingface.co/RysNamadgi))
- Zhaksh

Part of the [Namadgi](https://huggingface.co/Namadgi) research group.