fakufaku commited on
Commit
539d871
·
1 Parent(s): e769105

Diffsep model

Browse files
Files changed (3) hide show
  1. README.md +26 -0
  2. checkpoint.pt +3 -0
  3. hparams.yaml +122 -0
README.md CHANGED
@@ -1,3 +1,29 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ Diffusion-based Generative Speech Source Separation
6
+
7
+ This repository contains the checkpoints for the diffusion based speech
8
+ separation model from the paper Diffusion-based Generative Speech Source
9
+ Separation presented at ICASSP 2023.
10
+
11
+ The code to run the model is available on [github](https://github.com/fakufaku/diffusion-separation).
12
+
13
+ ### Abstract
14
+
15
+ We propose DiffSep, a new single channel source separation method based on
16
+ score-matching of a stochastic differential equation (SDE). We craft a tailored
17
+ continuous time diffusion-mixing process starting from the separated sources
18
+ and converging to a Gaussian distribution centered on their mixture. This
19
+ formulation lets us apply the machinery of score-based generative modelling.
20
+ First, we train a neural network to approximate the score function of the
21
+ marginal probabilities or the diffusion-mixing process. Then, we use it to
22
+ solve the reverse time SDE that progressively separates the sources starting
23
+ from their mixture. We propose a modified training strategy to handle model
24
+ mismatch and source permutation ambiguity. Experiments on the WSJ0 2mix dataset
25
+ demonstrate the potential of the method. Furthermore, the method is also
26
+ suitable for speech enhancement and shows performance competitive with prior
27
+ work on the VoiceBank-DEMAND dataset.
28
+
29
+ ID: `2022-10-23_01-37-07_experiment-model-large-multigpu_model.optimizer.lr-0.0002_model.sde.d_lambda-2.0_model.sde.sigma_min-0.05_epoch-979_si_sdr-11.271_N-30_snr-0.5_corrstep-1_denoise-True_schedule-None`
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:66227d10f97b7884b9eb3c0f27bb579d98326147e37401b6917e2b51ea8aa39d
3
+ size 1313509474
hparams.yaml ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ config:
2
+ seed: 64736289
3
+ name: default
4
+ train: true
5
+ test: false
6
+ path:
7
+ exp_root: exp
8
+ datasets:
9
+ wsj0_mix: data/wsj0_mix
10
+ figures: figures
11
+ datamodule:
12
+ train:
13
+ dl_opts:
14
+ num_workers: 8
15
+ shuffle: true
16
+ batch_size: 6
17
+ dataset:
18
+ _target_: datasets.WSJ0_mix
19
+ path: data/wsj0_mix
20
+ n_spkr: 2
21
+ fs: 8000
22
+ cut: max
23
+ split: train
24
+ max_len_s: 5
25
+ max_n_samples: null
26
+ val:
27
+ dl_opts:
28
+ num_workers: 8
29
+ shuffle: false
30
+ batch_size: 5
31
+ dataset:
32
+ _target_: datasets.WSJ0_mix
33
+ path: data/wsj0_mix
34
+ n_spkr: 2
35
+ fs: 8000
36
+ cut: max
37
+ split: val
38
+ max_len_s: null
39
+ max_n_samples: null
40
+ test:
41
+ dl_opts:
42
+ num_workers: 8
43
+ shuffle: false
44
+ batch_size: 5
45
+ dataset:
46
+ _target_: datasets.WSJ0_mix
47
+ path: data/wsj0_mix
48
+ n_spkr: 2
49
+ fs: 8000
50
+ cut: max
51
+ split: test
52
+ max_len_s: null
53
+ max_n_samples: null
54
+ model:
55
+ n_speakers: 2
56
+ fs: 8000
57
+ t_eps: 0.03
58
+ t_rev_init: 0.03
59
+ ema_decay: 0.999
60
+ valid_max_sep_batches: 2
61
+ time_sampling_strategy: uniform
62
+ train_source_order: power
63
+ init_hack: 5
64
+ mmnr_thresh_pit: -10.0
65
+ score_model:
66
+ _target_: models.score_models.ScoreModelNCSNpp
67
+ num_sources: 2
68
+ stft_args:
69
+ n_fft: 510
70
+ hop_length: 128
71
+ center: true
72
+ pad_mode: constant
73
+ backbone_args:
74
+ _target_: models.ncsnpp.NCSNpp
75
+ nf: 128
76
+ transform: exponent
77
+ spec_abs_exponent: 0.5
78
+ spec_factor: 0.15
79
+ sde:
80
+ _target_: sdes.sdes.MixSDE
81
+ ndim: 2
82
+ d_lambda: 2.0
83
+ sigma_min: 0.05
84
+ sigma_max: 0.5
85
+ 'N': 30
86
+ sampler:
87
+ 'N': 30
88
+ snr: 0.5
89
+ corrector_steps: 1
90
+ loss:
91
+ _target_: torch.nn.MSELoss
92
+ main_val_loss: val/si_sdr
93
+ main_val_loss_mode: max
94
+ val_losses:
95
+ val/si_sdr:
96
+ _target_: models.losses.SISDRLoss
97
+ zero_mean: true
98
+ clamp_db: 30
99
+ reduction: mean
100
+ sign_flip: true
101
+ optimizer:
102
+ _target_: torch.optim.Adam
103
+ lr: 0.0002
104
+ weight_decay: 0.0
105
+ scheduler: null
106
+ grad_clipper:
107
+ _target_: utils.FixedClipper
108
+ max_norm: 5.0
109
+ init_hack_p: 0.1
110
+ trainer:
111
+ _target_: pytorch_lightning.Trainer
112
+ accumulate_grad_batches: 2
113
+ min_epochs: 1
114
+ max_epochs: 1000
115
+ deterministic: true
116
+ accelerator: gpu
117
+ devices: -1
118
+ strategy: ddp
119
+ auto_select_gpus: true
120
+ check_val_every_n_epoch: 5
121
+ default_root_dir: .
122
+ profiler: false