Multi-Decoder DPRNN: Fine-Tuned for 1–5 Speaker Separation

A fine-tuned version of MultiDecoderDPRNN for speech separation with a variable number of speakers (1–5).

What changed?

The original pre-trained model supports 2–5 speakers, but in our evaluation it almost always outputs 5 active sources regardless of the actual speaker count (22.0% speaker-count accuracy). Our fine-tuning teaches the model to:

  • Correctly identify the number of speakers (88% accuracy)
  • Output silence on unused channels
  • Support single-speaker scenarios (new 1-speaker decoder)
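Because the fine-tuned model is expected to emit near-silence on channels it does not use, one simple way to count the active outputs of a separated mixture is an energy threshold relative to the loudest channel. This is a minimal sketch, not part of the released code; the function name `active_channels`, the -30 dB relative threshold, and the absolute energy floor are all assumptions.

```python
import torch


def active_channels(est: torch.Tensor, rel_db: float = -30.0, abs_floor: float = 1e-10) -> torch.Tensor:
    """Return a boolean mask of output channels that carry signal.

    est: [n_channels, T] separated waveforms for one mixture.
    rel_db: hypothetical threshold; a channel counts as active if its mean
            power is within `rel_db` dB of the loudest channel.
    abs_floor: hypothetical floor; if even the loudest channel is below it,
               every channel is treated as silent.
    """
    energy = est.pow(2).mean(dim=-1)  # mean power per channel
    if energy.max() < abs_floor:
        return torch.zeros_like(energy, dtype=torch.bool)
    energy_db = 10 * torch.log10(energy.clamp_min(1e-20) / energy.max())
    return energy_db > rel_db


# Toy example: 5 output channels, 2 carrying signal, 3 near-silent.
torch.manual_seed(0)
est = torch.zeros(5, 8000)
est[0] = torch.randn(8000)
est[1] = 0.5 * torch.randn(8000)           # about -6 dB relative to channel 0
est[2:] = 1e-4 * torch.randn(3, 8000)      # about -80 dB relative to channel 0
mask = active_channels(est)
print(mask.tolist(), int(mask.sum()))      # [True, True, False, False, False] 2
```

A relative threshold keeps the check independent of the overall input level; an absolute floor alone would misfire on quiet recordings.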

Checkpoints

| Checkpoint | Architecture | SI-SDR | SI-SDRi | SDR | Speaker-count accuracy | File |
|---|---|---|---|---|---|---|
| Rys (best) | n_srcs=[1,2,3,4,5] | +7.33 dB | +10.01 dB | +7.32 dB | 88.0% | weights/dprnn_rys_weights.pt |
| Zhaksh | n_srcs=[2,3,4,5] | +3.24 dB | +5.92 dB | +3.24 dB | 50.0% | weights/dprnn_zhaksh_weights.pt |
| Original (baseline) | n_srcs=[2,3,4,5] | -1.29 dB | +1.39 dB | -1.29 dB | 22.0% | n/a |
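For readers reproducing these columns, SI-SDR projects the estimate onto the reference and compares the energy of that scaled target with the residual; SI-SDRi subtracts the SI-SDR of the unprocessed mixture; speaker-count accuracy is the fraction of mixtures whose predicted count is correct. The sketch below follows the standard SI-SDR definition, assumes estimates are already matched to their references (e.g. after a permutation search), and is not the evaluation script used for this card.

```python
import torch


def si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR in dB along the last axis (both signals zero-meaned)."""
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    # Optimal scaling of the reference that best explains the estimate.
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * ref
    noise = est - target
    return 10 * torch.log10(target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps) + eps)


def si_sdri(est: torch.Tensor, ref: torch.Tensor, mix: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """SI-SDR improvement: estimate's SI-SDR minus the mixture's SI-SDR, per source."""
    return si_sdr(est, ref, eps) - si_sdr(mix.expand_as(ref), ref, eps)


def count_accuracy(pred_counts, true_counts) -> float:
    """Fraction of mixtures whose predicted speaker count matches the truth."""
    pred = torch.as_tensor(pred_counts)
    true = torch.as_tensor(true_counts)
    return (pred == true).float().mean().item()


# Toy check: two 1-second sources at 8 kHz; estimate = source + weak noise.
torch.manual_seed(0)
refs = torch.randn(2, 8000)
mix = refs.sum(dim=0, keepdim=True)        # [1, T] mixture, about 0 dB SI-SDR per source
est = refs + 0.1 * torch.randn(2, 8000)    # about 20 dB SI-SDR per source
print(si_sdr(est, refs), si_sdri(est, refs, mix))
print(count_accuracy([1, 2, 3, 5], [1, 2, 4, 5]))  # 0.75
```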

SI-SDRi by Speaker Count

| Speakers (N) | Original | Rys | Zhaksh |
|---|---|---|---|
| 1 | -9.17 dB | +8.86 dB | -7.47 dB |
| 2 | +4.07 dB | +11.98 dB | +11.58 dB |
| 3 | +2.46 dB | +10.00 dB | +9.78 dB |
| 4 | +4.56 dB | +9.65 dB | +7.78 dB |
| 5 | +5.05 dB | +9.56 dB | +7.94 dB |
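As a consistency check, the overall SI-SDRi in the checkpoint table equals the unweighted mean of these five rows for every model. That would follow if the 50 evaluation samples are split evenly across speaker counts (10 per N); the even split is an inference from the numbers, not something this card states.

```python
# Per-speaker-count SI-SDRi values (dB) copied from the table above, N = 1..5.
si_sdri_by_n = {
    "Original": [-9.17, 4.07, 2.46, 4.56, 5.05],
    "Rys": [8.86, 11.98, 10.00, 9.65, 9.56],
    "Zhaksh": [-7.47, 11.58, 9.78, 7.78, 7.94],
}

for name, values in si_sdri_by_n.items():
    # Unweighted mean over speaker counts; it equals the per-sample mean only
    # if every speaker count has the same number of evaluation mixtures.
    overall = sum(values) / len(values)
    print(f"{name}: {overall:+.2f} dB")
# Original: +1.39 dB
# Rys: +10.01 dB
# Zhaksh: +5.92 dB
```

The same view explains most of the gap between checkpoints: Zhaksh is close to Rys for N = 2–5 but has no 1-speaker decoder, so its N = 1 row pulls the average down.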

Usage

```python
import sys

import torch
from huggingface_hub import hf_hub_download

# model.py ships with the Asteroid recipe, not with the pip package:
# run from a directory that contains a clone of https://github.com/asteroid-team/asteroid
sys.path.insert(0, "asteroid/egs/wsj0-mix-var/Multi-Decoder-DPRNN")
from model import MultiDecoderDPRNN

# Reuse the hyperparameters of the original pre-trained model,
# but rebuild it with an extra 1-speaker decoder.
pretrained = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN")
fb = pretrained.encoder.filterbank
cfg = dict(
    bn_chan=pretrained.masker.bn_chan, hid_size=pretrained.masker.hid_size,
    chunk_size=pretrained.masker.chunk_size, hop_size=pretrained.masker.hop_size,
    n_repeats=pretrained.masker.n_repeats, norm_type=pretrained.masker.norm_type,
    bidirectional=pretrained.masker.bidirectional, rnn_type=pretrained.masker.rnn_type,
    num_layers=pretrained.masker.num_layers, dropout=pretrained.masker.dropout,
    n_filters=fb.n_feats_out, kernel_size=fb.kernel_size,
    stride=fb.stride, sample_rate=8000,
)

# For the Zhaksh checkpoint, use n_srcs=[2, 3, 4, 5] and weights/dprnn_zhaksh_weights.pt.
model = MultiDecoderDPRNN(n_srcs=[1, 2, 3, 4, 5], **cfg)
weights = hf_hub_download("Namadgi/MultiDecoderDPRNN-finetuned", "weights/dprnn_rys_weights.pt")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model.eval()

# Inference on a 4-second dummy mixture
mix = torch.randn(1, 32000)  # [batch, time] at 8 kHz
with torch.no_grad():
    # ground_truth=[3]: the mixture is known to contain 3 speakers
    reconstructed, selector = model(mix, ground_truth=[3])
# reconstructed: [1, n_stages, max_spks, T]
# selector: [1, n_stages, n_decoders]
```
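To turn the raw outputs into per-mixture sources, one option is to keep the last stage, read the predicted speaker count from that stage's selector logits, and keep that many channels. This sketch assumes that selector index `i` corresponds to `n_srcs[i]` and that a k-speaker estimate occupies the first k output channels; both are assumptions about the model's output layout, and `pick_sources` is a hypothetical helper rather than part of the Asteroid recipe.

```python
import torch

N_SRCS = [1, 2, 3, 4, 5]  # decoder order used to build the fine-tuned model


def pick_sources(reconstructed: torch.Tensor, selector: torch.Tensor, n_srcs=N_SRCS):
    """Return (list of [k, T] source tensors, list of predicted counts k).

    reconstructed: [batch, n_stages, max_spks, T]
    selector:      [batch, n_stages, n_decoders]
    """
    final_est = reconstructed[:, -1]     # last (most refined) stage
    final_logits = selector[:, -1]       # selector logits of the last stage
    counts = [n_srcs[i] for i in final_logits.argmax(dim=-1).tolist()]
    sources = [final_est[b, :k] for b, k in enumerate(counts)]
    return sources, counts


# Toy example with random outputs: 2 mixtures, 6 stages, 5 channels, 100 samples.
rec = torch.randn(2, 6, 5, 100)
sel = torch.zeros(2, 6, 5)
sel[0, -1, 2] = 1.0  # mixture 0 -> decoder index 2 -> 3 speakers
sel[1, -1, 4] = 1.0  # mixture 1 -> decoder index 4 -> 5 speakers
sources, counts = pick_sources(rec, sel)
print(counts, [tuple(s.shape) for s in sources])  # [3, 5] [(3, 100), (5, 100)]
```

With the outputs of the Usage snippet, the call would be `pick_sources(reconstructed, selector)`.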

Training Details

  • Base model: JunzheJosephZhu/MultiDecoderDPRNN
  • Dataset: 10,000 synthetic mixtures (LibriSpeech train-clean-100 + WHAM! noise), 1–5 speakers
  • Evaluation: 50 mixtures from a held-out validation set (LibriSpeech dev-clean + WHAM! cv split)
  • Sample rate: 8 kHz
  • Loss: multi-stage PIT SI-SDR loss plus a cross-entropy loss on the speaker-count selector
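The card does not spell out how the two loss terms are combined. The following is a minimal sketch under stated assumptions: for one mixture with k speakers, it averages a permutation-invariant (PIT) negative SI-SDR over the first k channels of every stage, and adds a cross-entropy term that asks every stage's selector to pick the decoder for k. The equal stage weighting, the channel layout, and the `ce_weight` parameter are hypothetical; this is not the training code behind these checkpoints.

```python
import itertools

import torch
import torch.nn.functional as F


def neg_si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Negative SI-SDR in dB along the last axis (lower is better)."""
    est = est - est.mean(-1, keepdim=True)
    ref = ref - ref.mean(-1, keepdim=True)
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * ref
    ratio = target.pow(2).sum(-1) / ((est - target).pow(2).sum(-1) + eps)
    return -10 * torch.log10(ratio + eps)


def pit_loss(est: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Best-permutation mean negative SI-SDR; est and ref are [k, T]."""
    k = ref.shape[0]
    losses = [neg_si_sdr(est[list(p)], ref).mean() for p in itertools.permutations(range(k))]
    return torch.stack(losses).min()


def multi_stage_loss(reconstructed, selector, refs, n_srcs=(1, 2, 3, 4, 5), ce_weight=1.0):
    """Loss for ONE mixture.

    reconstructed: [n_stages, max_spks, T]; selector: [n_stages, n_decoders]
    refs: [k, T] ground-truth sources; ce_weight is a hypothetical weighting.
    """
    k = refs.shape[0]
    # PIT SI-SDR on the first k channels, averaged over stages.
    sep = torch.stack([pit_loss(stage[:k], refs) for stage in reconstructed]).mean()
    # Every stage's selector should choose the k-speaker decoder.
    target = torch.full((selector.shape[0],), n_srcs.index(k), dtype=torch.long)
    ce = F.cross_entropy(selector, target)
    return sep + ce_weight * ce


torch.manual_seed(0)
refs = torch.randn(2, 200)
rec = torch.randn(3, 5, 200, requires_grad=True)
sel = torch.randn(3, 5, requires_grad=True)
loss = multi_stage_loss(rec, sel, refs)
loss.backward()  # gradients reach both the separator outputs and the selector
```

Exhaustive permutation search is affordable here because at most 5! = 120 orderings exist per stage; for larger speaker counts a Hungarian assignment would be the usual replacement.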

Authors

Part of the Namadgi research group.
