Diffsep model
- README.md +26 -0
- checkpoint.pt +3 -0
- hparams.yaml +122 -0
README.md CHANGED
@@ -1,3 +1,29 @@
 ---
 license: mit
 ---
+
+Diffusion-based Generative Speech Source Separation
+
+This repository contains the checkpoints for the diffusion-based speech
+separation model from the paper "Diffusion-based Generative Speech Source
+Separation", presented at ICASSP 2023.
+
+The code to run the model is available on [GitHub](https://github.com/fakufaku/diffusion-separation).
+
+### Abstract
+
+We propose DiffSep, a new single-channel source separation method based on
+score-matching of a stochastic differential equation (SDE). We craft a tailored
+continuous-time diffusion-mixing process starting from the separated sources
+and converging to a Gaussian distribution centered on their mixture. This
+formulation lets us apply the machinery of score-based generative modelling.
+First, we train a neural network to approximate the score function of the
+marginal probabilities of the diffusion-mixing process. Then, we use it to
+solve the reverse-time SDE that progressively separates the sources starting
+from their mixture. We propose a modified training strategy to handle model
+mismatch and source permutation ambiguity. Experiments on the WSJ0-2mix dataset
+demonstrate the potential of the method. Furthermore, the method is also
+suitable for speech enhancement and shows performance competitive with prior
+work on the VoiceBank-DEMAND dataset.
+
+ID: `2022-10-23_01-37-07_experiment-model-large-multigpu_model.optimizer.lr-0.0002_model.sde.d_lambda-2.0_model.sde.sigma_min-0.05_epoch-979_si_sdr-11.271_N-30_snr-0.5_corrstep-1_denoise-True_schedule-None`
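The diffusion-mixing idea in the abstract can be illustrated with a small sketch. It computes the mean of one simple linear process in which every source decays exponentially toward the per-sample average of all sources (a scaled version of the mixture). The drift form, the rate `gamma`, and the function name `mixing_mean` are assumptions made for illustration only; the paper and the linked repository define the actual SDE.

```python
import math


def mixing_mean(sources, gamma, t):
    """Mean at time t of a hypothetical linear mixing process.

    Each source drifts toward the per-sample average of all sources at
    rate `gamma`. This is an illustrative assumption, not the paper's SDE.
    """
    n_src = len(sources)
    n_samples = len(sources[0])
    # Per-sample average across sources (a scaled version of the mixture).
    avg = [sum(src[i] for src in sources) / n_src for i in range(n_samples)]
    decay = math.exp(-gamma * t)
    # Interpolate between the clean source (t = 0) and the average (large t).
    return [
        [decay * src[i] + (1.0 - decay) * avg[i] for i in range(n_samples)]
        for src in sources
    ]


# Two toy "sources" with two samples each.
sources = [[1.0, -2.0], [3.0, 4.0]]
start = mixing_mean(sources, gamma=2.0, t=0.0)  # equals the clean sources
late = mixing_mean(sources, gamma=2.0, t=10.0)  # both rows close to the average
```

At `t = 0` the mean is the separated sources; as `t` grows, every row approaches the same averaged signal, which is the sense in which the forward process "mixes". A reverse-time sampler would run this in the opposite direction, using the learned score.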
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:66227d10f97b7884b9eb3c0f27bb579d98326147e37401b6917e2b51ea8aa39d
+size 1313509474
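The three lines added for `checkpoint.pt` are a Git LFS pointer, not the checkpoint itself: they record the hash and byte size of the real file. The sketch below, using only the Python standard library, parses such a pointer and checks a downloaded file against it. The helper names `parse_lfs_pointer` and `matches_pointer` are hypothetical, and the local path `checkpoint.pt` is an assumption about where the file was saved.

```python
import hashlib
import os


def parse_lfs_pointer(text):
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(path, pointer, chunk_size=1 << 20):
    """Return True if the file at `path` has the size and hash in `pointer`."""
    if os.path.getsize(path) != pointer["size"]:
        return False
    hasher = hashlib.new(pointer["algo"])
    with open(path, "rb") as f:
        # Hash in chunks so a file of over 1 GB is never fully in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == pointer["digest"]


POINTER_TEXT = """\
version https://git-lfs.github.com/spec/v1
oid sha256:66227d10f97b7884b9eb3c0f27bb579d98326147e37401b6917e2b51ea8aa39d
size 1313509474
"""

pointer = parse_lfs_pointer(POINTER_TEXT)
# After downloading the real file (path is an assumption):
# ok = matches_pointer("checkpoint.pt", pointer)
```

Checking the size first is a cheap way to catch the common case where only the pointer file, a few hundred bytes, was downloaded instead of the checkpoint.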
hparams.yaml ADDED
@@ -0,0 +1,122 @@
+config:
+  seed: 64736289
+  name: default
+  train: true
+  test: false
+  path:
+    exp_root: exp
+    datasets:
+      wsj0_mix: data/wsj0_mix
+    figures: figures
+  datamodule:
+    train:
+      dl_opts:
+        num_workers: 8
+        shuffle: true
+        batch_size: 6
+      dataset:
+        _target_: datasets.WSJ0_mix
+        path: data/wsj0_mix
+        n_spkr: 2
+        fs: 8000
+        cut: max
+        split: train
+        max_len_s: 5
+        max_n_samples: null
+    val:
+      dl_opts:
+        num_workers: 8
+        shuffle: false
+        batch_size: 5
+      dataset:
+        _target_: datasets.WSJ0_mix
+        path: data/wsj0_mix
+        n_spkr: 2
+        fs: 8000
+        cut: max
+        split: val
+        max_len_s: null
+        max_n_samples: null
+    test:
+      dl_opts:
+        num_workers: 8
+        shuffle: false
+        batch_size: 5
+      dataset:
+        _target_: datasets.WSJ0_mix
+        path: data/wsj0_mix
+        n_spkr: 2
+        fs: 8000
+        cut: max
+        split: test
+        max_len_s: null
+        max_n_samples: null
+  model:
+    n_speakers: 2
+    fs: 8000
+    t_eps: 0.03
+    t_rev_init: 0.03
+    ema_decay: 0.999
+    valid_max_sep_batches: 2
+    time_sampling_strategy: uniform
+    train_source_order: power
+    init_hack: 5
+    mmnr_thresh_pit: -10.0
+    score_model:
+      _target_: models.score_models.ScoreModelNCSNpp
+      num_sources: 2
+      stft_args:
+        n_fft: 510
+        hop_length: 128
+        center: true
+        pad_mode: constant
+      backbone_args:
+        _target_: models.ncsnpp.NCSNpp
+        nf: 128
+      transform: exponent
+      spec_abs_exponent: 0.5
+      spec_factor: 0.15
+    sde:
+      _target_: sdes.sdes.MixSDE
+      ndim: 2
+      d_lambda: 2.0
+      sigma_min: 0.05
+      sigma_max: 0.5
+      'N': 30
+    sampler:
+      'N': 30
+      snr: 0.5
+      corrector_steps: 1
+    loss:
+      _target_: torch.nn.MSELoss
+    main_val_loss: val/si_sdr
+    main_val_loss_mode: max
+    val_losses:
+      val/si_sdr:
+        _target_: models.losses.SISDRLoss
+        zero_mean: true
+        clamp_db: 30
+        reduction: mean
+        sign_flip: true
+    optimizer:
+      _target_: torch.optim.Adam
+      lr: 0.0002
+      weight_decay: 0.0
+    scheduler: null
+    grad_clipper:
+      _target_: utils.FixedClipper
+      max_norm: 5.0
+    init_hack_p: 0.1
+  trainer:
+    _target_: pytorch_lightning.Trainer
+    accumulate_grad_batches: 2
+    min_epochs: 1
+    max_epochs: 1000
+    deterministic: true
+    accelerator: gpu
+    devices: -1
+    strategy: ddp
+    auto_select_gpus: true
+    check_val_every_n_epoch: 5
+    default_root_dir: .
+    profiler: false
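A few quantities follow directly from the hyperparameters in `hparams.yaml`. The sketch below derives the spectrogram size for a 5 s training clip and the effective batch size per optimizer step. It assumes the STFT follows the `torch.stft` framing convention (with `center: true`, the signal is padded by `n_fft // 2` on each side); the function names are hypothetical, and the GPU count is an assumption, since `devices: -1` only says "use all available".

```python
def stft_shape(n_samples, n_fft, hop_length, center=True):
    """Return (frequency bins, frames) of a one-sided STFT.

    Assumes torch.stft-style framing: with center=True the signal is
    padded by n_fft // 2 on both sides before it is cut into frames.
    """
    n_bins = n_fft // 2 + 1
    padded = n_samples + 2 * (n_fft // 2) if center else n_samples
    n_frames = 1 + (padded - n_fft) // hop_length
    return n_bins, n_frames


def effective_batch_size(batch_size, accumulate_grad_batches, n_devices):
    """Number of examples contributing to one optimizer step under DDP."""
    return batch_size * accumulate_grad_batches * n_devices


# Values taken from hparams.yaml: fs=8000, max_len_s=5, n_fft=510, hop_length=128.
n_samples = 8000 * 5
bins, frames = stft_shape(n_samples, n_fft=510, hop_length=128, center=True)

# batch_size=6 and accumulate_grad_batches=2 come from the config;
# n_devices=1 is a placeholder, not the number of GPUs used for training.
per_step = effective_batch_size(6, 2, n_devices=1)
```

With `n_fft: 510` the one-sided spectrum has 256 bins, a power of two that suits the repeated down- and upsampling of a U-Net-style backbone, which may be why 510 was chosen over 512.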
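The validation metric `val/si_sdr` and the `si_sdr-11.271` field in the checkpoint ID refer to the scale-invariant signal-to-distortion ratio. The sketch below is a minimal standard-library version of the usual SI-SDR definition with optional mean removal, mirroring `zero_mean: true`. It is not the repository's `models.losses.SISDRLoss`: the `clamp_db`, `sign_flip`, and `reduction` options are not modelled, and the function name and `eps` value are assumptions.

```python
import math


def si_sdr(estimate, reference, zero_mean=True, eps=1e-12):
    """Scale-invariant SDR in dB between two equal-length signals."""
    if zero_mean:
        est_mean = sum(estimate) / len(estimate)
        ref_mean = sum(reference) / len(reference)
        estimate = [e - est_mean for e in estimate]
        reference = [r - ref_mean for r in reference]
    # Project the estimate onto the reference to get the scaled target.
    dot = sum(e * r for e, r in zip(estimate, reference))
    ref_energy = sum(r * r for r in reference) + eps
    scale = dot / ref_energy
    target = [scale * r for r in reference]
    # Everything not explained by the scaled reference counts as distortion.
    noise = [e - t for e, t in zip(estimate, target)]
    target_energy = sum(t * t for t in target)
    noise_energy = sum(n * n for n in noise) + eps
    return 10.0 * math.log10(target_energy / noise_energy + eps)


reference = [1.0, -1.0, 1.0, -1.0]
distortion = [0.1, 0.1, -0.1, -0.1]  # orthogonal to the reference
estimate = [r + d for r, d in zip(reference, distortion)]

value = si_sdr(estimate, reference)
rescaled = si_sdr([3.0 * e for e in estimate], reference)
```

Because the estimate is projected onto the reference before the ratio is taken, multiplying the estimate by any non-zero gain leaves the result unchanged, which is why `value` and `rescaled` agree.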