swc2 commited on
Commit
f9e337c
·
1 Parent(s): 36e02be

update model

Browse files
Files changed (28) hide show
  1. Conv-Tasnet/results/convtasnet_4-mix/1234/env.log +0 -93
  2. Conv-Tasnet/results/convtasnet_4-mix/1234/hyperparams.yaml +0 -179
  3. Conv-Tasnet/results/convtasnet_4-mix/1234/log.txt +0 -0
  4. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/CKPT.yaml +0 -4
  5. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/brain.ckpt +0 -3
  6. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/counter.ckpt +0 -3
  7. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/dataloader-TRAIN.ckpt +0 -3
  8. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/decoder.ckpt +0 -3
  9. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/encoder.ckpt +0 -3
  10. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/lr_scheduler.ckpt +0 -3
  11. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/masknet.ckpt +0 -3
  12. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/optimizer.ckpt +0 -3
  13. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/CKPT.yaml +0 -4
  14. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/brain.ckpt +0 -3
  15. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/counter.ckpt +0 -3
  16. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/dataloader-TRAIN.ckpt +0 -3
  17. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/decoder.ckpt +0 -3
  18. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/encoder.ckpt +0 -3
  19. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/lr_scheduler.ckpt +0 -3
  20. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/masknet.ckpt +0 -3
  21. Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/optimizer.ckpt +0 -3
  22. Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_tr.csv +0 -0
  23. Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_val.csv +0 -0
  24. Conv-Tasnet/results/convtasnet_4-mix/1234/save/test_data.csv +0 -0
  25. Conv-Tasnet/results/convtasnet_4-mix/1234/test.py +0 -628
  26. Conv-Tasnet/results/convtasnet_4-mix/1234/test_results.csv +0 -1
  27. Conv-Tasnet/results/convtasnet_4-mix/1234/train.py +0 -628
  28. Conv-Tasnet/results/convtasnet_4-mix/1234/train_log.txt +0 -242
Conv-Tasnet/results/convtasnet_4-mix/1234/env.log DELETED
@@ -1,93 +0,0 @@
1
- SpeechBrain system description
2
- ==============================
3
- Python version:
4
- 3.11.13 (main, Jun 5 2025, 13:12:00) [GCC 11.2.0]
5
- ==============================
6
- Installed Python packages:
7
- black==24.3.0
8
- certifi==2025.6.15
9
- cffi==1.17.1
10
- cfgv==3.4.0
11
- charset-normalizer==3.4.2
12
- click==8.1.7
13
- distlib==0.3.9
14
- docstring_parser_fork==0.0.12
15
- filelock==3.18.0
16
- flake8==7.0.0
17
- fsspec==2025.5.1
18
- future==1.0.0
19
- hf-xet==1.1.5
20
- huggingface-hub==0.33.0
21
- HyperPyYAML==1.2.2
22
- identify==2.6.12
23
- idna==3.10
24
- iniconfig==2.1.0
25
- isort==5.13.2
26
- Jinja2==3.1.6
27
- joblib==1.5.1
28
- MarkupSafe==3.0.2
29
- mccabe==0.7.0
30
- mir_eval==0.6
31
- mpmath==1.3.0
32
- mypy_extensions==1.1.0
33
- networkx==3.5
34
- nodeenv==1.9.1
35
- numpy==2.3.1
36
- nvidia-cublas-cu12==12.6.4.1
37
- nvidia-cuda-cupti-cu12==12.6.80
38
- nvidia-cuda-nvrtc-cu12==12.6.77
39
- nvidia-cuda-runtime-cu12==12.6.77
40
- nvidia-cudnn-cu12==9.5.1.17
41
- nvidia-cufft-cu12==11.3.0.4
42
- nvidia-cufile-cu12==1.11.1.6
43
- nvidia-curand-cu12==10.3.7.77
44
- nvidia-cusolver-cu12==11.7.1.2
45
- nvidia-cusparse-cu12==12.5.4.2
46
- nvidia-cusparselt-cu12==0.6.3
47
- nvidia-nccl-cu12==2.26.2
48
- nvidia-nvjitlink-cu12==12.6.85
49
- nvidia-nvtx-cu12==12.6.77
50
- packaging==25.0
51
- pandas==2.3.0
52
- pathspec==0.12.1
53
- platformdirs==4.3.8
54
- pluggy==1.6.0
55
- pre_commit==4.2.0
56
- pycodestyle==2.11.0
57
- pycparser==2.22
58
- pydoclint==0.4.1
59
- pyflakes==3.2.0
60
- pyloudnorm==0.1.1
61
- pytest==7.4.0
62
- python-dateutil==2.9.0.post0
63
- pytz==2025.2
64
- PyYAML==6.0.2
65
- regex==2024.11.6
66
- requests==2.32.4
67
- ruamel.yaml==0.18.14
68
- ruamel.yaml.clib==0.2.12
69
- safetensors==0.5.3
70
- scipy==1.16.0
71
- sentencepiece==0.2.0
72
- six==1.17.0
73
- soundfile==0.13.1
74
- sox==1.5.0
75
- -e git+ssh://git@github.com/speechbrain/speechbrain.git@c75ab5489431fd0a2a7d21160bc37677801cb506#egg=speechbrain
76
- sympy==1.14.0
77
- tokenizers==0.21.2
78
- torch==2.7.1
79
- torchaudio==2.7.1
80
- tqdm==4.67.1
81
- transformers==4.52.4
82
- triton==3.3.1
83
- typing_extensions==4.14.0
84
- tzdata==2025.2
85
- urllib3==2.5.0
86
- virtualenv==20.31.2
87
- yamllint==1.35.1
88
- ==============================
89
- Git revision:
90
- c75ab5489
91
- ==============================
92
- CUDA version:
93
- 12.6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/hyperparams.yaml DELETED
@@ -1,179 +0,0 @@
1
- # Generated 2025-06-26 from:
2
- # /work106/youzhenghai/project/speechbrain/myegs/FORHUAWEI_TASNET/separation/hparams/convtasnet_4mix.yaml
3
- # yamllint disable
4
- # ################################
5
- # Model: SepFormer for source separation
6
- # https://arxiv.org/abs/2010.13154
7
- # Dataset : WSJ0-2mix and WSJ0-3mix
8
- # ################################
9
- # Basic parameters
10
- # Seed needs to be set at top of yaml, before objects with parameters are made
11
- #
12
- seed: 1234
13
- __set_seed: !apply:speechbrain.utils.seed_everything [1234]
14
-
15
- # Data params
16
-
17
- # e.g. '/yourpath/wsj0-mix/2speakers'
18
- # end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix
19
- data_folder: /work105/youzhenghai/data/wsj0_2mix
20
-
21
- # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
22
- # e.g. /yourpath/wsj0-processed/si_tr_s/
23
- base_folder_dm: /yourpath/wsj0-processed/si_tr_s/
24
-
25
- experiment_name: convtasnet_4-mix
26
- output_folder: results/convtasnet_4-mix/1234
27
- train_log: results/convtasnet_4-mix/1234/train_log.txt
28
- save_folder: results/convtasnet_4-mix/1234/save
29
- train_data: results/convtasnet_4-mix/1234/save/record_tr.csv
30
- valid_data: results/convtasnet_4-mix/1234/save/record_val.csv
31
- test_data: results/convtasnet_4-mix/1234/save/test_data.csv
32
- skip_prep: false
33
-
34
-
35
- # Experiment params
36
- precision: fp32 # bf16, fp16 or fp32
37
- num_spks: 4 # set to 3 for wsj0-3mix
38
- noprogressbar: false
39
- save_audio: false # Save estimated sources on disk
40
- sample_rate: 16000
41
-
42
- ####################### Training Parameters ####################################
43
- N_epochs: 200
44
- batch_size: 2
45
- lr: 0.00015
46
- clip_grad_norm: 5
47
- loss_upper_lim: 999999 # this is the upper limit for an acceptable loss
48
- # if True, the training sequences are cut to a specified length
49
- limit_training_signal_len: true
50
- # this is the length of sequences if we choose to limit
51
- # the signal length of training sequences
52
- training_signal_len: 64000000
53
-
54
- # Set it to True to dynamically create mixtures at training time
55
- dynamic_mixing: false
56
-
57
- # Parameters for data augmentation
58
- use_wavedrop: false
59
- use_speedperturb: true
60
- use_rand_shift: false
61
- min_shift: -8000
62
- max_shift: 8000
63
-
64
- # Speed perturbation
65
- speed_changes: &id001 [95, 100, 105]
66
-
67
- # Frequency drop: randomly drops a number of frequency bands to zero.
68
- speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
69
- orig_freq: 16000
70
- speeds: *id001
71
- drop_freq_low: 0 # Min frequency band dropout probability
72
- drop_freq_high: 1 # Max frequency band dropout probability
73
- drop_freq_count_low: 1 # Min number of frequency bands to drop
74
- drop_freq_count_high: 3 # Max number of frequency bands to drop
75
- drop_freq_width: 0.05 # Width of frequency bands to drop
76
-
77
- drop_freq: !new:speechbrain.augment.time_domain.DropFreq
78
- drop_freq_low: 0
79
- drop_freq_high: 1
80
- drop_freq_count_low: 1
81
- drop_freq_count_high: 3
82
- drop_freq_width: 0.05
83
-
84
- # Time drop: randomly drops a number of temporal chunks.
85
- drop_chunk_count_low: 1 # Min number of audio chunks to drop
86
- drop_chunk_count_high: 5 # Max number of audio chunks to drop
87
- drop_chunk_length_low: 1000 # Min length of audio chunks to drop
88
- drop_chunk_length_high: 2000 # Max length of audio chunks to drop
89
-
90
- drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
91
- drop_length_low: 1000
92
- drop_length_high: 2000
93
- drop_count_low: 1
94
- drop_count_high: 5
95
-
96
- # loss thresholding -- this thresholds the training loss
97
- threshold_byloss: true
98
- threshold: -30
99
-
100
- # Encoder parameters
101
- N_encoder_out: 256
102
- # out_channels: 256
103
- kernel_size: 32
104
- kernel_stride: 16
105
-
106
- # Dataloader options
107
- dataloader_opts:
108
- batch_size: 2
109
- num_workers: 3
110
-
111
-
112
- # Specifying the network
113
- Encoder: &id002 !new:speechbrain.lobes.models.dual_path.Encoder
114
- kernel_size: 32
115
- out_channels: 256
116
-
117
- # intra: !new:speechbrain.lobes.models.dual_path.SBRNNBlock
118
- # num_layers: 1
119
- # input_size: !ref <out_channels>
120
- # hidden_channels: !ref <out_channels>
121
- # dropout: 0
122
- # bidirectional: True
123
-
124
- # inter: !new:speechbrain.lobes.models.dual_path.SBRNNBlock
125
- # num_layers: 1
126
- # input_size: !ref <out_channels>
127
- # hidden_channels: !ref <out_channels>
128
- # dropout: 0
129
- # bidirectional: True
130
-
131
- MaskNet: &id004 !new:speechbrain.lobes.models.conv_tasnet.MaskNet
132
-
133
- N: 256
134
- B: 256
135
- H: 512
136
- P: 3
137
- X: 6
138
- R: 4
139
- C: 4
140
- norm_type: gLN
141
- causal: true
142
- mask_nonlinear: relu
143
-
144
- Decoder: &id003 !new:speechbrain.lobes.models.dual_path.Decoder
145
- in_channels: 256
146
- out_channels: 1
147
- kernel_size: 32
148
- stride: 16
149
- bias: false
150
-
151
- optimizer: !name:torch.optim.Adam
152
- lr: 0.00015
153
- weight_decay: 0
154
-
155
- loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper
156
-
157
- lr_scheduler: &id006 !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
158
-
159
- factor: 0.5
160
- patience: 2
161
- dont_halve_until_epoch: 85
162
-
163
- epoch_counter: &id005 !new:speechbrain.utils.epoch_loop.EpochCounter
164
- limit: 200
165
-
166
- modules:
167
- encoder: *id002
168
- decoder: *id003
169
- masknet: *id004
170
- checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
171
- checkpoints_dir: results/convtasnet_4-mix/1234/save
172
- recoverables:
173
- encoder: *id002
174
- decoder: *id003
175
- masknet: *id004
176
- counter: *id005
177
- lr_scheduler: *id006
178
- train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
179
- save_file: results/convtasnet_4-mix/1234/train_log.txt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/log.txt DELETED
The diff for this file is too large to render. See raw diff
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/CKPT.yaml DELETED
@@ -1,4 +0,0 @@
1
- # yamllint disable
2
- end-of-epoch: true
3
- si-snr: 22.240427712364045
4
- unixtime: 1750961143.069555
 
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/brain.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:33809a026a2c1febce7b03c8aafaee4ddfc851b2c70f180f8c06bf1017f4df5c
3
- size 46
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/counter.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5ef6fdf32513aa7cd11f72beccf132b9224d33f271471fff402742887a171edf
3
- size 3
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/dataloader-TRAIN.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c344ba7044815dd03c3448028a43e5b9c16074cb5a6a19c7ae86165c149735f
3
- size 3
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/decoder.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b89e695d01ef7a5aeb76f5000f70959a078e4ea1cf97ae978a2a4dc2121c7f29
3
- size 34409
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/encoder.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c5ef4fe38605072559dbf12b09643423c4649460c0f803f34f047e92f9358f39
3
- size 34473
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/lr_scheduler.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f02f6900fea06c469d975f48c9b3f4d40868d5fb6e6758baf76c4e68c4785dd1
3
- size 2251
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/masknet.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:100869f60d27f540b6d23e4a811cff04541c67e6ff4639776645069f841f0db5
3
- size 26926023
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/optimizer.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c05ce1c793e4f0bae4a6905774bbfc8360e4450103008c838ea195f4a146452c
3
- size 53964363
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/CKPT.yaml DELETED
@@ -1,4 +0,0 @@
1
- # yamllint disable
2
- end-of-epoch: true
3
- si-snr: 22.256136728080673
4
- unixtime: 1750994220.6695538
 
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/brain.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:33809a026a2c1febce7b03c8aafaee4ddfc851b2c70f180f8c06bf1017f4df5c
3
- size 46
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/counter.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:27badc983df1780b60c2b3fa9d3a19a00e46aac798451f0febdca52920faaddf
3
- size 3
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/dataloader-TRAIN.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5c344ba7044815dd03c3448028a43e5b9c16074cb5a6a19c7ae86165c149735f
3
- size 3
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/decoder.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8ba5891c2436cdefe57f4ca4b87bfa8267f927948330ea482d9cd6fadcd14163
3
- size 34409
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/encoder.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:533c6dfe50d9c410e8c0e4907efaf95679ca2fe85f0ceab9aa0ede0c817d58d8
3
- size 34473
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/lr_scheduler.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8689d8fb8de14a5995a161e50181134543321bbd431f774ce20f507239669ce3
3
- size 3147
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/masknet.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3acd263841af684db0cf622b77e83a807e969390661e115b89a8139f8785aa64
3
- size 26926023
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/optimizer.ckpt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:89593e2f633757a61883ef5aeb48a9e79ec4b09565d470c5571ee16edcb51c5c
3
- size 53964363
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_tr.csv DELETED
The diff for this file is too large to render. See raw diff
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_val.csv DELETED
The diff for this file is too large to render. See raw diff
 
Conv-Tasnet/results/convtasnet_4-mix/1234/save/test_data.csv DELETED
The diff for this file is too large to render. See raw diff
 
Conv-Tasnet/results/convtasnet_4-mix/1234/test.py DELETED
@@ -1,628 +0,0 @@
1
- #!/usr/bin/env/python3
2
- """Recipe for training a neural speech separation system on the wsjmix
3
- dataset. The system employs an encoder, a decoder, and a masking network.
4
-
5
- To run this recipe, do the following:
6
- > python train.py hparams/sepformer.yaml
7
- > python train.py hparams/dualpath_rnn.yaml
8
- > python train.py hparams/convtasnet.yaml
9
-
10
- The experiment file is flexible enough to support different neural
11
- networks. By properly changing the parameter files, you can try
12
- different architectures. The script supports both wsj2mix and
13
- wsj3mix.
14
-
15
- # 4-mix 主要根据 num_spks 修改 train.py 和 config
16
- Authors
17
- * Cem Subakan 2020
18
- * Mirco Ravanelli 2020
19
- * Samuele Cornell 2020
20
- * Mirko Bronzi 2020
21
- * Jianyuan Zhong 2020
22
- """
23
-
24
- import csv
25
- import os
26
- import sys
27
-
28
- import numpy as np
29
- import torch
30
- import torch.nn.functional as F
31
- import torchaudio
32
- from hyperpyyaml import load_hyperpyyaml
33
- from tqdm import tqdm
34
-
35
- import speechbrain as sb
36
- import speechbrain.nnet.schedulers as schedulers
37
- from speechbrain.utils.distributed import run_on_main
38
- from speechbrain.utils.logger import get_logger
39
-
40
-
41
- # Define training procedure
42
- class Separation(sb.Brain):
43
- def compute_forward(self, mix, targets, stage, noise=None):
44
- """Forward computations from the mixture to the separated signals."""
45
-
46
- # Unpack lists and put tensors in the right device
47
- mix, mix_lens = mix
48
- mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)
49
-
50
- # Convert targets to tensor
51
- targets = torch.cat(
52
- [targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_spks)],
53
- dim=-1,
54
- ).to(self.device)
55
-
56
- # Add speech distortions
57
- if stage == sb.Stage.TRAIN:
58
- with torch.no_grad():
59
- if self.hparams.use_speedperturb:
60
- mix, targets = self.add_speed_perturb(targets, mix_lens)
61
-
62
- mix = targets.sum(-1)
63
-
64
- if self.hparams.use_wavedrop:
65
- mix = self.hparams.drop_chunk(mix, mix_lens)
66
- mix = self.hparams.drop_freq(mix)
67
-
68
- if self.hparams.limit_training_signal_len:
69
- mix, targets = self.cut_signals(mix, targets)
70
-
71
- # Separation
72
- mix_w = self.hparams.Encoder(mix)
73
- est_mask = self.hparams.MaskNet(mix_w)
74
- mix_w = torch.stack([mix_w] * self.hparams.num_spks)
75
- sep_h = mix_w * est_mask
76
-
77
- # Decoding
78
- est_source = torch.cat(
79
- [
80
- self.hparams.Decoder(sep_h[i]).unsqueeze(-1)
81
- for i in range(self.hparams.num_spks)
82
- ],
83
- dim=-1,
84
- )
85
-
86
- # T changed after conv1d in encoder, fix it here
87
- T_origin = mix.size(1)
88
- T_est = est_source.size(1)
89
- if T_origin > T_est:
90
- est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
91
- else:
92
- est_source = est_source[:, :T_origin, :]
93
-
94
- return est_source, targets
95
-
96
- def compute_objectives(self, predictions, targets):
97
- """Computes the sinr loss"""
98
- return self.hparams.loss(targets, predictions)
99
-
100
- def fit_batch(self, batch):
101
- """Trains one batch"""
102
-
103
- # Unpacking batch list
104
- mixture = batch.mix_sig
105
- targets = [batch.s1_sig, batch.s2_sig]
106
-
107
- if self.hparams.num_spks == 3:
108
- targets.append(batch.s3_sig)
109
-
110
- if self.hparams.num_spks == 4:
111
- targets.append(batch.s3_sig)
112
- targets.append(batch.s4_sig)
113
-
114
- with self.training_ctx:
115
- predictions, targets = self.compute_forward(
116
- mixture, targets, sb.Stage.TRAIN
117
- )
118
- loss = self.compute_objectives(predictions, targets)
119
-
120
- # hard threshold the easy dataitems
121
- if self.hparams.threshold_byloss:
122
- th = self.hparams.threshold
123
- loss = loss[loss > th]
124
- if loss.nelement() > 0:
125
- loss = loss.mean()
126
- else:
127
- loss = loss.mean()
128
-
129
- if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
130
- self.scaler.scale(loss).backward()
131
- if self.hparams.clip_grad_norm >= 0:
132
- self.scaler.unscale_(self.optimizer)
133
- torch.nn.utils.clip_grad_norm_(
134
- self.modules.parameters(),
135
- self.hparams.clip_grad_norm,
136
- )
137
- self.scaler.step(self.optimizer)
138
- self.scaler.update()
139
- else:
140
- self.nonfinite_count += 1
141
- logger.info(
142
- "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
143
- self.nonfinite_count
144
- )
145
- )
146
- loss.data = torch.tensor(0.0).to(self.device)
147
- self.optimizer.zero_grad()
148
-
149
- return loss.detach().cpu()
150
-
151
- def evaluate_batch(self, batch, stage):
152
- """Computations needed for validation/test batches"""
153
- snt_id = batch.id
154
- mixture = batch.mix_sig
155
- targets = [batch.s1_sig, batch.s2_sig]
156
- if self.hparams.num_spks == 3:
157
- targets.append(batch.s3_sig)
158
-
159
- if self.hparams.num_spks == 4:
160
- targets.append(batch.s3_sig)
161
- targets.append(batch.s4_sig)
162
-
163
- with torch.no_grad():
164
- predictions, targets = self.compute_forward(mixture, targets, stage)
165
- loss = self.compute_objectives(predictions, targets)
166
-
167
- # Manage audio file saving
168
- if stage == sb.Stage.TEST and self.hparams.save_audio:
169
- if hasattr(self.hparams, "n_audio_to_save"):
170
- if self.hparams.n_audio_to_save > 0:
171
- self.save_audio(snt_id[0], mixture, targets, predictions)
172
- self.hparams.n_audio_to_save += -1
173
- else:
174
- self.save_audio(snt_id[0], mixture, targets, predictions)
175
-
176
- return loss.mean().detach()
177
-
178
- def on_stage_end(self, stage, stage_loss, epoch):
179
- """Gets called at the end of a epoch."""
180
- # Compute/store important stats
181
- stage_stats = {"si-snr": stage_loss}
182
- if stage == sb.Stage.TRAIN:
183
- self.train_stats = stage_stats
184
-
185
- # Perform end-of-iteration things, like annealing, logging, etc.
186
- if stage == sb.Stage.VALID:
187
- # Learning rate annealing
188
- if isinstance(
189
- self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
190
- ):
191
- current_lr, next_lr = self.hparams.lr_scheduler(
192
- [self.optimizer], epoch, stage_loss
193
- )
194
- schedulers.update_learning_rate(self.optimizer, next_lr)
195
- else:
196
- # if we do not use the reducelronplateau, we do not change the lr
197
- current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]
198
-
199
- self.hparams.train_logger.log_stats(
200
- stats_meta={"epoch": epoch, "lr": current_lr},
201
- train_stats=self.train_stats,
202
- valid_stats=stage_stats,
203
- )
204
- self.checkpointer.save_and_keep_only(
205
- meta={"si-snr": stage_stats["si-snr"]}, min_keys=["si-snr"]
206
- )
207
- elif stage == sb.Stage.TEST:
208
- self.hparams.train_logger.log_stats(
209
- stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
210
- test_stats=stage_stats,
211
- )
212
-
213
- def add_speed_perturb(self, targets, targ_lens):
214
- """Adds speed perturbation and random_shift to the input signals"""
215
-
216
- min_len = -1
217
- recombine = False
218
-
219
- if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
220
- # Performing speed change (independently on each source)
221
- new_targets = []
222
- recombine = True
223
-
224
- for i in range(targets.shape[-1]):
225
- new_target = self.hparams.speed_perturb(targets[:, :, i])
226
- new_targets.append(new_target)
227
- if i == 0:
228
- min_len = new_target.shape[-1]
229
- else:
230
- if new_target.shape[-1] < min_len:
231
- min_len = new_target.shape[-1]
232
-
233
- if self.hparams.use_rand_shift:
234
- # Performing random_shift (independently on each source)
235
- recombine = True
236
- for i in range(targets.shape[-1]):
237
- rand_shift = torch.randint(
238
- self.hparams.min_shift, self.hparams.max_shift, (1,)
239
- )
240
- new_targets[i] = new_targets[i].to(self.device)
241
- new_targets[i] = torch.roll(
242
- new_targets[i], shifts=(rand_shift[0],), dims=1
243
- )
244
-
245
- # Re-combination
246
- if recombine:
247
- if self.hparams.use_speedperturb:
248
- targets = torch.zeros(
249
- targets.shape[0],
250
- min_len,
251
- targets.shape[-1],
252
- device=targets.device,
253
- dtype=torch.float,
254
- )
255
- for i, new_target in enumerate(new_targets):
256
- targets[:, :, i] = new_targets[i][:, 0:min_len]
257
-
258
- mix = targets.sum(-1)
259
- return mix, targets
260
-
261
- def cut_signals(self, mixture, targets):
262
- """This function selects a random segment of a given length within the mixture.
263
- The corresponding targets are selected accordingly"""
264
- randstart = torch.randint(
265
- 0,
266
- 1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
267
- (1,),
268
- ).item()
269
- targets = targets[
270
- :, randstart : randstart + self.hparams.training_signal_len, :
271
- ]
272
- mixture = mixture[
273
- :, randstart : randstart + self.hparams.training_signal_len
274
- ]
275
- return mixture, targets
276
-
277
- def reset_layer_recursively(self, layer):
278
- """Reinitializes the parameters of the neural networks"""
279
- if hasattr(layer, "reset_parameters"):
280
- layer.reset_parameters()
281
- for child_layer in layer.modules():
282
- if layer != child_layer:
283
- self.reset_layer_recursively(child_layer)
284
-
285
- def save_results(self, test_data):
286
- """This script computes the SDR and SI-SNR metrics and saves
287
- them into a csv file"""
288
-
289
- # This package is required for SDR computation
290
- from mir_eval.separation import bss_eval_sources
291
-
292
- # Create folders where to store audio
293
- save_file = os.path.join(self.hparams.output_folder, "test_results.csv")
294
-
295
- # Variable init
296
- all_sdrs = []
297
- all_sdrs_i = []
298
- all_sisnrs = []
299
- all_sisnrs_i = []
300
- csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]
301
-
302
- test_loader = sb.dataio.dataloader.make_dataloader(
303
- test_data, **self.hparams.dataloader_opts
304
- )
305
-
306
- with open(save_file, "w", newline="", encoding="utf-8") as results_csv:
307
- writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
308
- writer.writeheader()
309
-
310
- # Loop over all test sentence
311
- with tqdm(test_loader, dynamic_ncols=True) as t:
312
- for i, batch in enumerate(t):
313
- # Apply Separation
314
- mixture, mix_len = batch.mix_sig
315
- snt_id = batch.id
316
- targets = [batch.s1_sig, batch.s2_sig]
317
- if self.hparams.num_spks == 3:
318
- targets.append(batch.s3_sig)
319
-
320
- if self.hparams.num_spks == 4:
321
- targets.append(batch.s3_sig)
322
- targets.append(batch.s4_sig)
323
-
324
- with torch.no_grad():
325
- predictions, targets = self.compute_forward(
326
- batch.mix_sig, targets, sb.Stage.TEST
327
- )
328
-
329
- # Compute SI-SNR
330
- sisnr = self.compute_objectives(predictions, targets)
331
-
332
- # Compute SI-SNR improvement
333
- mixture_signal = torch.stack(
334
- [mixture] * self.hparams.num_spks, dim=-1
335
- )
336
- mixture_signal = mixture_signal.to(targets.device)
337
- sisnr_baseline = self.compute_objectives(
338
- mixture_signal, targets
339
- )
340
- sisnr_i = sisnr - sisnr_baseline
341
-
342
- # Compute SDR
343
- sdr, _, _, _ = bss_eval_sources(
344
- targets[0].t().cpu().numpy(),
345
- predictions[0].t().detach().cpu().numpy(),
346
- )
347
-
348
- sdr_baseline, _, _, _ = bss_eval_sources(
349
- targets[0].t().cpu().numpy(),
350
- mixture_signal[0].t().detach().cpu().numpy(),
351
- )
352
-
353
- sdr_i = sdr.mean() - sdr_baseline.mean()
354
-
355
- # Saving on a csv file
356
- row = {
357
- "snt_id": snt_id[0],
358
- "sdr": sdr.mean(),
359
- "sdr_i": sdr_i,
360
- "si-snr": -sisnr.item(),
361
- "si-snr_i": -sisnr_i.item(),
362
- }
363
- writer.writerow(row)
364
-
365
- # Metric Accumulation
366
- all_sdrs.append(sdr.mean())
367
- all_sdrs_i.append(sdr_i.mean())
368
- all_sisnrs.append(-sisnr.item())
369
- all_sisnrs_i.append(-sisnr_i.item())
370
-
371
- row = {
372
- "snt_id": "avg",
373
- "sdr": np.array(all_sdrs).mean(),
374
- "sdr_i": np.array(all_sdrs_i).mean(),
375
- "si-snr": np.array(all_sisnrs).mean(),
376
- "si-snr_i": np.array(all_sisnrs_i).mean(),
377
- }
378
- writer.writerow(row)
379
-
380
- logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
381
- logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
382
- logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
383
- logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))
384
-
385
- def save_audio(self, snt_id, mixture, targets, predictions):
386
- "saves the test audio (mixture, targets, and estimated sources) on disk"
387
-
388
- # Create output folder
389
- save_path = os.path.join(self.hparams.save_folder, "audio_results")
390
- if not os.path.exists(save_path):
391
- os.mkdir(save_path)
392
-
393
- for ns in range(self.hparams.num_spks):
394
- # Estimated source
395
- signal = predictions[0, :, ns]
396
- signal = signal / signal.abs().max()
397
- save_file = os.path.join(
398
- save_path, "item{}_source{}hat.wav".format(snt_id, ns + 1)
399
- )
400
- torchaudio.save(
401
- save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
402
- )
403
-
404
- # Original source
405
- signal = targets[0, :, ns]
406
- signal = signal / signal.abs().max()
407
- save_file = os.path.join(
408
- save_path, "item{}_source{}.wav".format(snt_id, ns + 1)
409
- )
410
- torchaudio.save(
411
- save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
412
- )
413
-
414
- # Mixture
415
- signal = mixture[0][0, :]
416
- signal = signal / signal.abs().max()
417
- save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
418
- torchaudio.save(
419
- save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
420
- )
421
-
422
-
423
- def dataio_prep(hparams):
424
- """Creates data processing pipeline"""
425
-
426
- # 1. Define datasets
427
- train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
428
- csv_path=hparams["train_data"],
429
- replacements={"data_root": hparams["data_folder"]},
430
- )
431
-
432
- valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
433
- csv_path=hparams["valid_data"],
434
- replacements={"data_root": hparams["data_folder"]},
435
- )
436
-
437
- test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
438
- csv_path=hparams["test_data"],
439
- replacements={"data_root": hparams["data_folder"]},
440
- )
441
-
442
- datasets = [train_data, valid_data, test_data]
443
-
444
- # 2. Provide audio pipelines
445
-
446
- @sb.utils.data_pipeline.takes("mix_wav")
447
- @sb.utils.data_pipeline.provides("mix_sig")
448
- def audio_pipeline_mix(mix_wav):
449
- mix_sig = sb.dataio.dataio.read_audio(mix_wav)
450
- return mix_sig
451
-
452
- @sb.utils.data_pipeline.takes("s1_wav")
453
- @sb.utils.data_pipeline.provides("s1_sig")
454
- def audio_pipeline_s1(s1_wav):
455
- s1_sig = sb.dataio.dataio.read_audio(s1_wav)
456
- return s1_sig
457
-
458
- @sb.utils.data_pipeline.takes("s2_wav")
459
- @sb.utils.data_pipeline.provides("s2_sig")
460
- def audio_pipeline_s2(s2_wav):
461
- s2_sig = sb.dataio.dataio.read_audio(s2_wav)
462
- return s2_sig
463
-
464
- # --- 如果说话人 >= 3,定义第 3 路 ---
465
- if hparams["num_spks"] >= 3:
466
- @sb.utils.data_pipeline.takes("s3_wav")
467
- @sb.utils.data_pipeline.provides("s3_sig")
468
- def audio_pipeline_s3(s3_wav):
469
- return sb.dataio.dataio.read_audio(s3_wav)
470
-
471
- # --- 如果说话人 == 4,定义第 4 路 ---
472
- if hparams["num_spks"] == 4:
473
- @sb.utils.data_pipeline.takes("s4_wav")
474
- @sb.utils.data_pipeline.provides("s4_sig")
475
- def audio_pipeline_s4(s4_wav):
476
- return sb.dataio.dataio.read_audio(s4_wav)
477
-
478
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_mix)
479
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s1)
480
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s2)
481
- if hparams["num_spks"] == 3:
482
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
483
- sb.dataio.dataset.set_output_keys(
484
- datasets, ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig"]
485
- )
486
- elif hparams["num_spks"] == 4 :
487
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
488
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s4)
489
- sb.dataio.dataset.set_output_keys(
490
- datasets,
491
- ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig", "s4_sig"],
492
- )
493
- else:
494
- sb.dataio.dataset.set_output_keys(
495
- datasets, ["id", "mix_sig", "s1_sig", "s2_sig"]
496
- )
497
-
498
- return train_data, valid_data, test_data
499
-
500
-
501
- if __name__ == "__main__":
502
- # Load hyperparameters file with command-line overrides
503
- hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
504
- with open(hparams_file, encoding="utf-8") as fin:
505
- hparams = load_hyperpyyaml(fin, overrides)
506
-
507
- # Initialize ddp (useful only for multi-GPU DDP training)
508
- sb.utils.distributed.ddp_init_group(run_opts)
509
-
510
- # Logger info
511
- logger = get_logger(__name__)
512
-
513
- # Create experiment directory
514
- sb.create_experiment_directory(
515
- experiment_directory=hparams["output_folder"],
516
- hyperparams_to_save=hparams_file,
517
- overrides=overrides,
518
- )
519
-
520
- # Update precision to bf16 if the device is CPU and precision is fp16
521
- if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
522
- hparams["precision"] = "bf16"
523
-
524
- # Check if wsj0_tr is set with dynamic mixing
525
- if hparams["dynamic_mixing"] and not os.path.exists(
526
- hparams["base_folder_dm"]
527
- ):
528
- raise ValueError(
529
- "Please, specify a valid base_folder_dm folder when using dynamic mixing"
530
- )
531
-
532
- # Data preparation
533
- from prepare_data import prepare_wsjmix # noqa
534
-
535
- # run_on_main(
536
- # prepare_wsjmix,
537
- # kwargs={
538
- # "datapath": hparams["data_folder"],
539
- # "savepath": hparams["save_folder"],
540
- # "n_spks": hparams["num_spks"],
541
- # "skip_prep": hparams["skip_prep"],
542
- # "fs": hparams["sample_rate"],
543
- # },
544
- # )
545
-
546
- # Create dataset objects
547
- if hparams["dynamic_mixing"]:
548
- from dynamic_mixing import dynamic_mix_data_prep
549
-
550
- # if the base_folder for dm is not processed, preprocess them
551
- if "processed" not in hparams["base_folder_dm"]:
552
- # if the processed folder already exists we just use it otherwise we do the preprocessing
553
- if not os.path.exists(
554
- os.path.normpath(hparams["base_folder_dm"]) + "_processed"
555
- ):
556
- from preprocess_dynamic_mixing import resample_folder
557
-
558
- print("Resampling the base folder")
559
- run_on_main(
560
- resample_folder,
561
- kwargs={
562
- "input_folder": hparams["base_folder_dm"],
563
- "output_folder": os.path.normpath(
564
- hparams["base_folder_dm"]
565
- )
566
- + "_processed",
567
- "fs": hparams["sample_rate"],
568
- "regex": "**/*.wav",
569
- },
570
- )
571
- # adjust the base_folder_dm path
572
- hparams["base_folder_dm"] = (
573
- os.path.normpath(hparams["base_folder_dm"]) + "_processed"
574
- )
575
- else:
576
- print(
577
- "Using the existing processed folder on the same directory as base_folder_dm"
578
- )
579
- hparams["base_folder_dm"] = (
580
- os.path.normpath(hparams["base_folder_dm"]) + "_processed"
581
- )
582
-
583
- # Collecting the hparams for dynamic batching
584
- dm_hparams = {
585
- "train_data": hparams["train_data"],
586
- "data_folder": hparams["data_folder"],
587
- "base_folder_dm": hparams["base_folder_dm"],
588
- "sample_rate": hparams["sample_rate"],
589
- "num_spks": hparams["num_spks"],
590
- "training_signal_len": hparams["training_signal_len"],
591
- "dataloader_opts": hparams["dataloader_opts"],
592
- }
593
- train_data = dynamic_mix_data_prep(dm_hparams)
594
- _, valid_data, test_data = dataio_prep(hparams)
595
- else:
596
- train_data, valid_data, test_data = dataio_prep(hparams)
597
-
598
- # Load pretrained model if pretrained_separator is present in the yaml
599
- if "pretrained_separator" in hparams:
600
- run_on_main(hparams["pretrained_separator"].collect_files)
601
- hparams["pretrained_separator"].load_collected()
602
-
603
- # Brain class initialization
604
- separator = Separation(
605
- modules=hparams["modules"],
606
- opt_class=hparams["optimizer"],
607
- hparams=hparams,
608
- run_opts=run_opts,
609
- checkpointer=hparams["checkpointer"],
610
- )
611
-
612
- # re-initialize the parameters if we don't use a pretrained model
613
- if "pretrained_separator" not in hparams:
614
- for module in separator.modules.values():
615
- separator.reset_layer_recursively(module)
616
-
617
- # # Training
618
- # separator.fit(
619
- # separator.hparams.epoch_counter,
620
- # train_data,
621
- # valid_data,
622
- # train_loader_kwargs=hparams["dataloader_opts"],
623
- # valid_loader_kwargs=hparams["dataloader_opts"],
624
- # )
625
-
626
- # Eval
627
- separator.evaluate(test_data, min_key="si-snr")
628
- separator.save_results(test_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/test_results.csv DELETED
@@ -1 +0,0 @@
1
- snt_id,sdr,sdr_i,si-snr,si-snr_i
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/train.py DELETED
@@ -1,628 +0,0 @@
1
- #!/usr/bin/env/python3
2
- """Recipe for training a neural speech separation system on the wsjmix
3
- dataset. The system employs an encoder, a decoder, and a masking network.
4
-
5
- To run this recipe, do the following:
6
- > python train.py hparams/sepformer.yaml
7
- > python train.py hparams/dualpath_rnn.yaml
8
- > python train.py hparams/convtasnet.yaml
9
-
10
- The experiment file is flexible enough to support different neural
11
- networks. By properly changing the parameter files, you can try
12
- different architectures. The script supports both wsj2mix and
13
- wsj3mix.
14
-
15
- # 4-mix 主要根据 num_spks 修改 train.py 和 config
16
- Authors
17
- * Cem Subakan 2020
18
- * Mirco Ravanelli 2020
19
- * Samuele Cornell 2020
20
- * Mirko Bronzi 2020
21
- * Jianyuan Zhong 2020
22
- """
23
-
24
- import csv
25
- import os
26
- import sys
27
-
28
- import numpy as np
29
- import torch
30
- import torch.nn.functional as F
31
- import torchaudio
32
- from hyperpyyaml import load_hyperpyyaml
33
- from tqdm import tqdm
34
-
35
- import speechbrain as sb
36
- import speechbrain.nnet.schedulers as schedulers
37
- from speechbrain.utils.distributed import run_on_main
38
- from speechbrain.utils.logger import get_logger
39
-
40
-
41
- # Define training procedure
42
- class Separation(sb.Brain):
43
- def compute_forward(self, mix, targets, stage, noise=None):
44
- """Forward computations from the mixture to the separated signals."""
45
-
46
- # Unpack lists and put tensors in the right device
47
- mix, mix_lens = mix
48
- mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)
49
-
50
- # Convert targets to tensor
51
- targets = torch.cat(
52
- [targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_spks)],
53
- dim=-1,
54
- ).to(self.device)
55
-
56
- # Add speech distortions
57
- if stage == sb.Stage.TRAIN:
58
- with torch.no_grad():
59
- if self.hparams.use_speedperturb:
60
- mix, targets = self.add_speed_perturb(targets, mix_lens)
61
-
62
- mix = targets.sum(-1)
63
-
64
- if self.hparams.use_wavedrop:
65
- mix = self.hparams.drop_chunk(mix, mix_lens)
66
- mix = self.hparams.drop_freq(mix)
67
-
68
- if self.hparams.limit_training_signal_len:
69
- mix, targets = self.cut_signals(mix, targets)
70
-
71
- # Separation
72
- mix_w = self.hparams.Encoder(mix)
73
- est_mask = self.hparams.MaskNet(mix_w)
74
- mix_w = torch.stack([mix_w] * self.hparams.num_spks)
75
- sep_h = mix_w * est_mask
76
-
77
- # Decoding
78
- est_source = torch.cat(
79
- [
80
- self.hparams.Decoder(sep_h[i]).unsqueeze(-1)
81
- for i in range(self.hparams.num_spks)
82
- ],
83
- dim=-1,
84
- )
85
-
86
- # T changed after conv1d in encoder, fix it here
87
- T_origin = mix.size(1)
88
- T_est = est_source.size(1)
89
- if T_origin > T_est:
90
- est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
91
- else:
92
- est_source = est_source[:, :T_origin, :]
93
-
94
- return est_source, targets
95
-
96
- def compute_objectives(self, predictions, targets):
97
- """Computes the sinr loss"""
98
- return self.hparams.loss(targets, predictions)
99
-
100
- def fit_batch(self, batch):
101
- """Trains one batch"""
102
-
103
- # Unpacking batch list
104
- mixture = batch.mix_sig
105
- targets = [batch.s1_sig, batch.s2_sig]
106
-
107
- if self.hparams.num_spks == 3:
108
- targets.append(batch.s3_sig)
109
-
110
- if self.hparams.num_spks == 4:
111
- targets.append(batch.s3_sig)
112
- targets.append(batch.s4_sig)
113
-
114
- with self.training_ctx:
115
- predictions, targets = self.compute_forward(
116
- mixture, targets, sb.Stage.TRAIN
117
- )
118
- loss = self.compute_objectives(predictions, targets)
119
-
120
- # hard threshold the easy dataitems
121
- if self.hparams.threshold_byloss:
122
- th = self.hparams.threshold
123
- loss = loss[loss > th]
124
- if loss.nelement() > 0:
125
- loss = loss.mean()
126
- else:
127
- loss = loss.mean()
128
-
129
- if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
130
- self.scaler.scale(loss).backward()
131
- if self.hparams.clip_grad_norm >= 0:
132
- self.scaler.unscale_(self.optimizer)
133
- torch.nn.utils.clip_grad_norm_(
134
- self.modules.parameters(),
135
- self.hparams.clip_grad_norm,
136
- )
137
- self.scaler.step(self.optimizer)
138
- self.scaler.update()
139
- else:
140
- self.nonfinite_count += 1
141
- logger.info(
142
- "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
143
- self.nonfinite_count
144
- )
145
- )
146
- loss.data = torch.tensor(0.0).to(self.device)
147
- self.optimizer.zero_grad()
148
-
149
- return loss.detach().cpu()
150
-
151
- def evaluate_batch(self, batch, stage):
152
- """Computations needed for validation/test batches"""
153
- snt_id = batch.id
154
- mixture = batch.mix_sig
155
- targets = [batch.s1_sig, batch.s2_sig]
156
- if self.hparams.num_spks == 3:
157
- targets.append(batch.s3_sig)
158
-
159
- if self.hparams.num_spks == 4:
160
- targets.append(batch.s3_sig)
161
- targets.append(batch.s4_sig)
162
-
163
- with torch.no_grad():
164
- predictions, targets = self.compute_forward(mixture, targets, stage)
165
- loss = self.compute_objectives(predictions, targets)
166
-
167
- # Manage audio file saving
168
- if stage == sb.Stage.TEST and self.hparams.save_audio:
169
- if hasattr(self.hparams, "n_audio_to_save"):
170
- if self.hparams.n_audio_to_save > 0:
171
- self.save_audio(snt_id[0], mixture, targets, predictions)
172
- self.hparams.n_audio_to_save += -1
173
- else:
174
- self.save_audio(snt_id[0], mixture, targets, predictions)
175
-
176
- return loss.mean().detach()
177
-
178
- def on_stage_end(self, stage, stage_loss, epoch):
179
- """Gets called at the end of a epoch."""
180
- # Compute/store important stats
181
- stage_stats = {"si-snr": stage_loss}
182
- if stage == sb.Stage.TRAIN:
183
- self.train_stats = stage_stats
184
-
185
- # Perform end-of-iteration things, like annealing, logging, etc.
186
- if stage == sb.Stage.VALID:
187
- # Learning rate annealing
188
- if isinstance(
189
- self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
190
- ):
191
- current_lr, next_lr = self.hparams.lr_scheduler(
192
- [self.optimizer], epoch, stage_loss
193
- )
194
- schedulers.update_learning_rate(self.optimizer, next_lr)
195
- else:
196
- # if we do not use the reducelronplateau, we do not change the lr
197
- current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]
198
-
199
- self.hparams.train_logger.log_stats(
200
- stats_meta={"epoch": epoch, "lr": current_lr},
201
- train_stats=self.train_stats,
202
- valid_stats=stage_stats,
203
- )
204
- self.checkpointer.save_and_keep_only(
205
- meta={"si-snr": stage_stats["si-snr"]}, min_keys=["si-snr"]
206
- )
207
- elif stage == sb.Stage.TEST:
208
- self.hparams.train_logger.log_stats(
209
- stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
210
- test_stats=stage_stats,
211
- )
212
-
213
- def add_speed_perturb(self, targets, targ_lens):
214
- """Adds speed perturbation and random_shift to the input signals"""
215
-
216
- min_len = -1
217
- recombine = False
218
-
219
- if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
220
- # Performing speed change (independently on each source)
221
- new_targets = []
222
- recombine = True
223
-
224
- for i in range(targets.shape[-1]):
225
- new_target = self.hparams.speed_perturb(targets[:, :, i])
226
- new_targets.append(new_target)
227
- if i == 0:
228
- min_len = new_target.shape[-1]
229
- else:
230
- if new_target.shape[-1] < min_len:
231
- min_len = new_target.shape[-1]
232
-
233
- if self.hparams.use_rand_shift:
234
- # Performing random_shift (independently on each source)
235
- recombine = True
236
- for i in range(targets.shape[-1]):
237
- rand_shift = torch.randint(
238
- self.hparams.min_shift, self.hparams.max_shift, (1,)
239
- )
240
- new_targets[i] = new_targets[i].to(self.device)
241
- new_targets[i] = torch.roll(
242
- new_targets[i], shifts=(rand_shift[0],), dims=1
243
- )
244
-
245
- # Re-combination
246
- if recombine:
247
- if self.hparams.use_speedperturb:
248
- targets = torch.zeros(
249
- targets.shape[0],
250
- min_len,
251
- targets.shape[-1],
252
- device=targets.device,
253
- dtype=torch.float,
254
- )
255
- for i, new_target in enumerate(new_targets):
256
- targets[:, :, i] = new_targets[i][:, 0:min_len]
257
-
258
- mix = targets.sum(-1)
259
- return mix, targets
260
-
261
- def cut_signals(self, mixture, targets):
262
- """This function selects a random segment of a given length within the mixture.
263
- The corresponding targets are selected accordingly"""
264
- randstart = torch.randint(
265
- 0,
266
- 1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
267
- (1,),
268
- ).item()
269
- targets = targets[
270
- :, randstart : randstart + self.hparams.training_signal_len, :
271
- ]
272
- mixture = mixture[
273
- :, randstart : randstart + self.hparams.training_signal_len
274
- ]
275
- return mixture, targets
276
-
277
- def reset_layer_recursively(self, layer):
278
- """Reinitializes the parameters of the neural networks"""
279
- if hasattr(layer, "reset_parameters"):
280
- layer.reset_parameters()
281
- for child_layer in layer.modules():
282
- if layer != child_layer:
283
- self.reset_layer_recursively(child_layer)
284
-
285
- def save_results(self, test_data):
286
- """This script computes the SDR and SI-SNR metrics and saves
287
- them into a csv file"""
288
-
289
- # This package is required for SDR computation
290
- from mir_eval.separation import bss_eval_sources
291
-
292
- # Create folders where to store audio
293
- save_file = os.path.join(self.hparams.output_folder, "test_results.csv")
294
-
295
- # Variable init
296
- all_sdrs = []
297
- all_sdrs_i = []
298
- all_sisnrs = []
299
- all_sisnrs_i = []
300
- csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]
301
-
302
- test_loader = sb.dataio.dataloader.make_dataloader(
303
- test_data, **self.hparams.dataloader_opts
304
- )
305
-
306
- with open(save_file, "w", newline="", encoding="utf-8") as results_csv:
307
- writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
308
- writer.writeheader()
309
-
310
- # Loop over all test sentence
311
- with tqdm(test_loader, dynamic_ncols=True) as t:
312
- for i, batch in enumerate(t):
313
- # Apply Separation
314
- mixture, mix_len = batch.mix_sig
315
- snt_id = batch.id
316
- targets = [batch.s1_sig, batch.s2_sig]
317
- if self.hparams.num_spks == 3:
318
- targets.append(batch.s3_sig)
319
-
320
- if self.hparams.num_spks == 4:
321
- targets.append(batch.s3_sig)
322
- targets.append(batch.s4_sig)
323
-
324
- with torch.no_grad():
325
- predictions, targets = self.compute_forward(
326
- batch.mix_sig, targets, sb.Stage.TEST
327
- )
328
-
329
- # Compute SI-SNR
330
- sisnr = self.compute_objectives(predictions, targets)
331
-
332
- # Compute SI-SNR improvement
333
- mixture_signal = torch.stack(
334
- [mixture] * self.hparams.num_spks, dim=-1
335
- )
336
- mixture_signal = mixture_signal.to(targets.device)
337
- sisnr_baseline = self.compute_objectives(
338
- mixture_signal, targets
339
- )
340
- sisnr_i = sisnr - sisnr_baseline
341
-
342
- # Compute SDR
343
- sdr, _, _, _ = bss_eval_sources(
344
- targets[0].t().cpu().numpy(),
345
- predictions[0].t().detach().cpu().numpy(),
346
- )
347
-
348
- sdr_baseline, _, _, _ = bss_eval_sources(
349
- targets[0].t().cpu().numpy(),
350
- mixture_signal[0].t().detach().cpu().numpy(),
351
- )
352
-
353
- sdr_i = sdr.mean() - sdr_baseline.mean()
354
-
355
- # Saving on a csv file
356
- row = {
357
- "snt_id": snt_id[0],
358
- "sdr": sdr.mean(),
359
- "sdr_i": sdr_i,
360
- "si-snr": -sisnr.item(),
361
- "si-snr_i": -sisnr_i.item(),
362
- }
363
- writer.writerow(row)
364
-
365
- # Metric Accumulation
366
- all_sdrs.append(sdr.mean())
367
- all_sdrs_i.append(sdr_i.mean())
368
- all_sisnrs.append(-sisnr.item())
369
- all_sisnrs_i.append(-sisnr_i.item())
370
-
371
- row = {
372
- "snt_id": "avg",
373
- "sdr": np.array(all_sdrs).mean(),
374
- "sdr_i": np.array(all_sdrs_i).mean(),
375
- "si-snr": np.array(all_sisnrs).mean(),
376
- "si-snr_i": np.array(all_sisnrs_i).mean(),
377
- }
378
- writer.writerow(row)
379
-
380
- logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
381
- logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
382
- logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
383
- logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))
384
-
385
- def save_audio(self, snt_id, mixture, targets, predictions):
386
- "saves the test audio (mixture, targets, and estimated sources) on disk"
387
-
388
- # Create output folder
389
- save_path = os.path.join(self.hparams.save_folder, "audio_results")
390
- if not os.path.exists(save_path):
391
- os.mkdir(save_path)
392
-
393
- for ns in range(self.hparams.num_spks):
394
- # Estimated source
395
- signal = predictions[0, :, ns]
396
- signal = signal / signal.abs().max()
397
- save_file = os.path.join(
398
- save_path, "item{}_source{}hat.wav".format(snt_id, ns + 1)
399
- )
400
- torchaudio.save(
401
- save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
402
- )
403
-
404
- # Original source
405
- signal = targets[0, :, ns]
406
- signal = signal / signal.abs().max()
407
- save_file = os.path.join(
408
- save_path, "item{}_source{}.wav".format(snt_id, ns + 1)
409
- )
410
- torchaudio.save(
411
- save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
412
- )
413
-
414
- # Mixture
415
- signal = mixture[0][0, :]
416
- signal = signal / signal.abs().max()
417
- save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
418
- torchaudio.save(
419
- save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
420
- )
421
-
422
-
423
- def dataio_prep(hparams):
424
- """Creates data processing pipeline"""
425
-
426
- # 1. Define datasets
427
- train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
428
- csv_path=hparams["train_data"],
429
- replacements={"data_root": hparams["data_folder"]},
430
- )
431
-
432
- valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
433
- csv_path=hparams["valid_data"],
434
- replacements={"data_root": hparams["data_folder"]},
435
- )
436
-
437
- test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
438
- csv_path=hparams["test_data"],
439
- replacements={"data_root": hparams["data_folder"]},
440
- )
441
-
442
- datasets = [train_data, valid_data, test_data]
443
-
444
- # 2. Provide audio pipelines
445
-
446
- @sb.utils.data_pipeline.takes("mix_wav")
447
- @sb.utils.data_pipeline.provides("mix_sig")
448
- def audio_pipeline_mix(mix_wav):
449
- mix_sig = sb.dataio.dataio.read_audio(mix_wav)
450
- return mix_sig
451
-
452
- @sb.utils.data_pipeline.takes("s1_wav")
453
- @sb.utils.data_pipeline.provides("s1_sig")
454
- def audio_pipeline_s1(s1_wav):
455
- s1_sig = sb.dataio.dataio.read_audio(s1_wav)
456
- return s1_sig
457
-
458
- @sb.utils.data_pipeline.takes("s2_wav")
459
- @sb.utils.data_pipeline.provides("s2_sig")
460
- def audio_pipeline_s2(s2_wav):
461
- s2_sig = sb.dataio.dataio.read_audio(s2_wav)
462
- return s2_sig
463
-
464
- # --- 如果说话人 >= 3,定义第 3 路 ---
465
- if hparams["num_spks"] >= 3:
466
- @sb.utils.data_pipeline.takes("s3_wav")
467
- @sb.utils.data_pipeline.provides("s3_sig")
468
- def audio_pipeline_s3(s3_wav):
469
- return sb.dataio.dataio.read_audio(s3_wav)
470
-
471
- # --- 如果说话人 == 4,定义第 4 路 ---
472
- if hparams["num_spks"] == 4:
473
- @sb.utils.data_pipeline.takes("s4_wav")
474
- @sb.utils.data_pipeline.provides("s4_sig")
475
- def audio_pipeline_s4(s4_wav):
476
- return sb.dataio.dataio.read_audio(s4_wav)
477
-
478
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_mix)
479
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s1)
480
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s2)
481
- if hparams["num_spks"] == 3:
482
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
483
- sb.dataio.dataset.set_output_keys(
484
- datasets, ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig"]
485
- )
486
- elif hparams["num_spks"] == 4 :
487
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
488
- sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s4)
489
- sb.dataio.dataset.set_output_keys(
490
- datasets,
491
- ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig", "s4_sig"],
492
- )
493
- else:
494
- sb.dataio.dataset.set_output_keys(
495
- datasets, ["id", "mix_sig", "s1_sig", "s2_sig"]
496
- )
497
-
498
- return train_data, valid_data, test_data
499
-
500
-
501
- if __name__ == "__main__":
502
- # Load hyperparameters file with command-line overrides
503
- hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
504
- with open(hparams_file, encoding="utf-8") as fin:
505
- hparams = load_hyperpyyaml(fin, overrides)
506
-
507
- # Initialize ddp (useful only for multi-GPU DDP training)
508
- sb.utils.distributed.ddp_init_group(run_opts)
509
-
510
- # Logger info
511
- logger = get_logger(__name__)
512
-
513
- # Create experiment directory
514
- sb.create_experiment_directory(
515
- experiment_directory=hparams["output_folder"],
516
- hyperparams_to_save=hparams_file,
517
- overrides=overrides,
518
- )
519
-
520
- # Update precision to bf16 if the device is CPU and precision is fp16
521
- if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
522
- hparams["precision"] = "bf16"
523
-
524
- # Check if wsj0_tr is set with dynamic mixing
525
- if hparams["dynamic_mixing"] and not os.path.exists(
526
- hparams["base_folder_dm"]
527
- ):
528
- raise ValueError(
529
- "Please, specify a valid base_folder_dm folder when using dynamic mixing"
530
- )
531
-
532
- # Data preparation
533
- from prepare_data import prepare_wsjmix # noqa
534
-
535
- # run_on_main(
536
- # prepare_wsjmix,
537
- # kwargs={
538
- # "datapath": hparams["data_folder"],
539
- # "savepath": hparams["save_folder"],
540
- # "n_spks": hparams["num_spks"],
541
- # "skip_prep": hparams["skip_prep"],
542
- # "fs": hparams["sample_rate"],
543
- # },
544
- # )
545
-
546
- # Create dataset objects
547
- if hparams["dynamic_mixing"]:
548
- from dynamic_mixing import dynamic_mix_data_prep
549
-
550
- # if the base_folder for dm is not processed, preprocess them
551
- if "processed" not in hparams["base_folder_dm"]:
552
- # if the processed folder already exists we just use it otherwise we do the preprocessing
553
- if not os.path.exists(
554
- os.path.normpath(hparams["base_folder_dm"]) + "_processed"
555
- ):
556
- from preprocess_dynamic_mixing import resample_folder
557
-
558
- print("Resampling the base folder")
559
- run_on_main(
560
- resample_folder,
561
- kwargs={
562
- "input_folder": hparams["base_folder_dm"],
563
- "output_folder": os.path.normpath(
564
- hparams["base_folder_dm"]
565
- )
566
- + "_processed",
567
- "fs": hparams["sample_rate"],
568
- "regex": "**/*.wav",
569
- },
570
- )
571
- # adjust the base_folder_dm path
572
- hparams["base_folder_dm"] = (
573
- os.path.normpath(hparams["base_folder_dm"]) + "_processed"
574
- )
575
- else:
576
- print(
577
- "Using the existing processed folder on the same directory as base_folder_dm"
578
- )
579
- hparams["base_folder_dm"] = (
580
- os.path.normpath(hparams["base_folder_dm"]) + "_processed"
581
- )
582
-
583
- # Collecting the hparams for dynamic batching
584
- dm_hparams = {
585
- "train_data": hparams["train_data"],
586
- "data_folder": hparams["data_folder"],
587
- "base_folder_dm": hparams["base_folder_dm"],
588
- "sample_rate": hparams["sample_rate"],
589
- "num_spks": hparams["num_spks"],
590
- "training_signal_len": hparams["training_signal_len"],
591
- "dataloader_opts": hparams["dataloader_opts"],
592
- }
593
- train_data = dynamic_mix_data_prep(dm_hparams)
594
- _, valid_data, test_data = dataio_prep(hparams)
595
- else:
596
- train_data, valid_data, test_data = dataio_prep(hparams)
597
-
598
- # Load pretrained model if pretrained_separator is present in the yaml
599
- if "pretrained_separator" in hparams:
600
- run_on_main(hparams["pretrained_separator"].collect_files)
601
- hparams["pretrained_separator"].load_collected()
602
-
603
- # Brain class initialization
604
- separator = Separation(
605
- modules=hparams["modules"],
606
- opt_class=hparams["optimizer"],
607
- hparams=hparams,
608
- run_opts=run_opts,
609
- checkpointer=hparams["checkpointer"],
610
- )
611
-
612
- # re-initialize the parameters if we don't use a pretrained model
613
- if "pretrained_separator" not in hparams:
614
- for module in separator.modules.values():
615
- separator.reset_layer_recursively(module)
616
-
617
- # Training
618
- separator.fit(
619
- separator.hparams.epoch_counter,
620
- train_data,
621
- valid_data,
622
- train_loader_kwargs=hparams["dataloader_opts"],
623
- valid_loader_kwargs=hparams["dataloader_opts"],
624
- )
625
-
626
- # Eval
627
- separator.evaluate(test_data, min_key="si-snr")
628
- separator.save_results(test_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Conv-Tasnet/results/convtasnet_4-mix/1234/train_log.txt DELETED
@@ -1,242 +0,0 @@
1
- epoch: 1, lr: 1.50e-04 - train si-snr: 2.76 - valid si-snr: 12.09
2
- epoch: 2, lr: 1.50e-04 - train si-snr: 2.15 - valid si-snr: 11.84
3
- epoch: 3, lr: 1.50e-04 - train si-snr: 1.96 - valid si-snr: 11.70
4
- epoch: 4, lr: 1.50e-04 - train si-snr: 1.70 - valid si-snr: 11.63
5
- epoch: 5, lr: 1.50e-04 - train si-snr: 1.58 - valid si-snr: 11.57
6
- epoch: 6, lr: 1.50e-04 - train si-snr: 1.45 - valid si-snr: 11.55
7
- epoch: 7, lr: 1.50e-04 - train si-snr: 1.33 - valid si-snr: 11.45
8
- epoch: 8, lr: 1.50e-04 - train si-snr: 1.20 - valid si-snr: 11.33
9
- epoch: 9, lr: 1.50e-04 - train si-snr: 1.10 - valid si-snr: 11.35
10
- epoch: 10, lr: 1.50e-04 - train si-snr: 1.01 - valid si-snr: 11.30
11
- epoch: 11, lr: 1.50e-04 - train si-snr: 9.25e-01 - valid si-snr: 11.33
12
- epoch: 12, lr: 1.50e-04 - train si-snr: 7.83e-01 - valid si-snr: 11.16
13
- epoch: 13, lr: 1.50e-04 - train si-snr: 7.61e-01 - valid si-snr: 11.19
14
- epoch: 14, lr: 1.50e-04 - train si-snr: 6.87e-01 - valid si-snr: 11.13
15
- epoch: 15, lr: 1.50e-04 - train si-snr: 6.31e-01 - valid si-snr: 11.13
16
- epoch: 16, lr: 1.50e-04 - train si-snr: 5.54e-01 - valid si-snr: 11.10
17
- epoch: 17, lr: 1.50e-04 - train si-snr: 4.47e-01 - valid si-snr: 11.02
18
- epoch: 18, lr: 1.50e-04 - train si-snr: 4.65e-01 - valid si-snr: 11.04
19
- epoch: 19, lr: 1.50e-04 - train si-snr: 3.32e-01 - valid si-snr: 11.01
20
- epoch: 20, lr: 1.50e-04 - train si-snr: 3.27e-01 - valid si-snr: 10.95
21
- epoch: 21, lr: 1.50e-04 - train si-snr: 2.78e-01 - valid si-snr: 10.97
22
- epoch: 22, lr: 1.50e-04 - train si-snr: 2.18e-01 - valid si-snr: 10.88
23
- epoch: 23, lr: 1.50e-04 - train si-snr: 1.74e-01 - valid si-snr: 10.87
24
- epoch: 24, lr: 1.50e-04 - train si-snr: 1.03e-01 - valid si-snr: 10.95
25
- epoch: 25, lr: 1.50e-04 - train si-snr: 6.04e-02 - valid si-snr: 10.84
26
- epoch: 26, lr: 1.50e-04 - train si-snr: -2.94e-02 - valid si-snr: 10.79
27
- epoch: 27, lr: 1.50e-04 - train si-snr: -5.32e-02 - valid si-snr: 10.77
28
- epoch: 28, lr: 1.50e-04 - train si-snr: -5.68e-02 - valid si-snr: 10.74
29
- epoch: 29, lr: 1.50e-04 - train si-snr: -1.04e-01 - valid si-snr: 10.79
30
- epoch: 30, lr: 1.50e-04 - train si-snr: -1.57e-01 - valid si-snr: 10.73
31
- epoch: 31, lr: 1.50e-04 - train si-snr: -1.64e-01 - valid si-snr: 10.67
32
- epoch: 32, lr: 1.50e-04 - train si-snr: -2.11e-01 - valid si-snr: 10.71
33
- epoch: 33, lr: 1.50e-04 - train si-snr: -2.48e-01 - valid si-snr: 10.73
34
- epoch: 34, lr: 1.50e-04 - train si-snr: -2.79e-01 - valid si-snr: 10.69
35
- epoch: 35, lr: 1.50e-04 - train si-snr: -3.55e-01 - valid si-snr: 10.69
36
- epoch: 36, lr: 1.50e-04 - train si-snr: -3.32e-01 - valid si-snr: 10.64
37
- epoch: 37, lr: 1.50e-04 - train si-snr: -3.97e-01 - valid si-snr: 10.63
38
- epoch: 38, lr: 1.50e-04 - train si-snr: -4.11e-01 - valid si-snr: 10.71
39
- epoch: 39, lr: 1.50e-04 - train si-snr: -4.18e-01 - valid si-snr: 10.56
40
- epoch: 40, lr: 1.50e-04 - train si-snr: -4.74e-01 - valid si-snr: 10.55
41
- epoch: 41, lr: 1.50e-04 - train si-snr: -4.71e-01 - valid si-snr: 10.52
42
- epoch: 1, lr: 1.50e-04 - train si-snr: 6.31 - valid si-snr: 23.11
43
- epoch: 2, lr: 1.50e-04 - train si-snr: 4.85 - valid si-snr: 23.05
44
- epoch: 3, lr: 1.50e-04 - train si-snr: 4.79 - valid si-snr: 22.98
45
- epoch: 4, lr: 1.50e-04 - train si-snr: 4.56 - valid si-snr: 22.79
46
- epoch: 5, lr: 1.50e-04 - train si-snr: 4.28 - valid si-snr: 23.05
47
- epoch: 6, lr: 1.50e-04 - train si-snr: 4.27 - valid si-snr: 22.88
48
- epoch: 7, lr: 1.50e-04 - train si-snr: 4.11 - valid si-snr: 22.86
49
- epoch: 8, lr: 1.50e-04 - train si-snr: 4.11 - valid si-snr: 22.80
50
- epoch: 9, lr: 1.50e-04 - train si-snr: 3.96 - valid si-snr: 22.80
51
- epoch: 10, lr: 1.50e-04 - train si-snr: 3.91 - valid si-snr: 22.75
52
- epoch: 11, lr: 1.50e-04 - train si-snr: 3.76 - valid si-snr: 22.72
53
- epoch: 12, lr: 1.50e-04 - train si-snr: 3.82 - valid si-snr: 22.69
54
- epoch: 13, lr: 1.50e-04 - train si-snr: 3.71 - valid si-snr: 22.86
55
- epoch: 14, lr: 1.50e-04 - train si-snr: 3.64 - valid si-snr: 22.71
56
- epoch: 15, lr: 1.50e-04 - train si-snr: 3.59 - valid si-snr: 22.89
57
- epoch: 16, lr: 1.50e-04 - train si-snr: 3.39 - valid si-snr: 22.79
58
- epoch: 17, lr: 1.50e-04 - train si-snr: 3.30 - valid si-snr: 22.69
59
- epoch: 18, lr: 1.50e-04 - train si-snr: 3.29 - valid si-snr: 22.82
60
- epoch: 19, lr: 1.50e-04 - train si-snr: 3.32 - valid si-snr: 22.75
61
- epoch: 20, lr: 1.50e-04 - train si-snr: 3.14 - valid si-snr: 22.49
62
- epoch: 21, lr: 1.50e-04 - train si-snr: 3.11 - valid si-snr: 22.83
63
- epoch: 22, lr: 1.50e-04 - train si-snr: 3.12 - valid si-snr: 22.69
64
- epoch: 23, lr: 1.50e-04 - train si-snr: 2.93 - valid si-snr: 22.66
65
- epoch: 24, lr: 1.50e-04 - train si-snr: 2.96 - valid si-snr: 22.72
66
- epoch: 25, lr: 1.50e-04 - train si-snr: 2.96 - valid si-snr: 22.83
67
- epoch: 26, lr: 1.50e-04 - train si-snr: 2.88 - valid si-snr: 22.61
68
- epoch: 27, lr: 1.50e-04 - train si-snr: 2.86 - valid si-snr: 22.83
69
- epoch: 28, lr: 1.50e-04 - train si-snr: 2.80 - valid si-snr: 22.67
70
- epoch: 29, lr: 1.50e-04 - train si-snr: 2.73 - valid si-snr: 22.67
71
- epoch: 30, lr: 1.50e-04 - train si-snr: 2.65 - valid si-snr: 22.62
72
- epoch: 31, lr: 1.50e-04 - train si-snr: 2.62 - valid si-snr: 22.63
73
- epoch: 32, lr: 1.50e-04 - train si-snr: 2.61 - valid si-snr: 22.61
74
- epoch: 33, lr: 1.50e-04 - train si-snr: 2.44 - valid si-snr: 22.55
75
- epoch: 34, lr: 1.50e-04 - train si-snr: 2.50 - valid si-snr: 22.55
76
- epoch: 35, lr: 1.50e-04 - train si-snr: 2.47 - valid si-snr: 22.60
77
- epoch: 36, lr: 1.50e-04 - train si-snr: 2.44 - valid si-snr: 22.66
78
- epoch: 37, lr: 1.50e-04 - train si-snr: 2.24 - valid si-snr: 22.64
79
- epoch: 38, lr: 1.50e-04 - train si-snr: 2.28 - valid si-snr: 22.66
80
- epoch: 39, lr: 1.50e-04 - train si-snr: 2.15 - valid si-snr: 22.62
81
- epoch: 40, lr: 1.50e-04 - train si-snr: 2.19 - valid si-snr: 22.48
82
- epoch: 41, lr: 1.50e-04 - train si-snr: 2.26 - valid si-snr: 22.66
83
- epoch: 42, lr: 1.50e-04 - train si-snr: 2.09 - valid si-snr: 22.57
84
- epoch: 43, lr: 1.50e-04 - train si-snr: 2.15 - valid si-snr: 22.47
85
- epoch: 44, lr: 1.50e-04 - train si-snr: 2.00 - valid si-snr: 22.63
86
- epoch: 45, lr: 1.50e-04 - train si-snr: 2.13 - valid si-snr: 22.52
87
- epoch: 46, lr: 1.50e-04 - train si-snr: 2.00 - valid si-snr: 22.57
88
- epoch: 47, lr: 1.50e-04 - train si-snr: 1.90 - valid si-snr: 22.50
89
- epoch: 48, lr: 1.50e-04 - train si-snr: 1.89 - valid si-snr: 22.49
90
- epoch: 49, lr: 1.50e-04 - train si-snr: 1.94 - valid si-snr: 22.54
91
- epoch: 50, lr: 1.50e-04 - train si-snr: 1.89 - valid si-snr: 22.50
92
- epoch: 51, lr: 1.50e-04 - train si-snr: 1.85 - valid si-snr: 22.55
93
- epoch: 52, lr: 1.50e-04 - train si-snr: 1.66 - valid si-snr: 22.51
94
- epoch: 53, lr: 1.50e-04 - train si-snr: 1.65 - valid si-snr: 22.52
95
- epoch: 54, lr: 1.50e-04 - train si-snr: 1.77 - valid si-snr: 22.45
96
- epoch: 55, lr: 1.50e-04 - train si-snr: 1.62 - valid si-snr: 22.45
97
- epoch: 56, lr: 1.50e-04 - train si-snr: 1.52 - valid si-snr: 22.42
98
- epoch: 57, lr: 1.50e-04 - train si-snr: 1.53 - valid si-snr: 22.39
99
- epoch: 58, lr: 1.50e-04 - train si-snr: 1.52 - valid si-snr: 22.40
100
- epoch: 59, lr: 1.50e-04 - train si-snr: 1.55 - valid si-snr: 22.43
101
- epoch: 60, lr: 1.50e-04 - train si-snr: 1.64 - valid si-snr: 22.43
102
- epoch: 61, lr: 1.50e-04 - train si-snr: 1.42 - valid si-snr: 22.38
103
- epoch: 62, lr: 1.50e-04 - train si-snr: 1.50 - valid si-snr: 22.29
104
- epoch: 63, lr: 1.50e-04 - train si-snr: 1.34 - valid si-snr: 22.51
105
- epoch: 64, lr: 1.50e-04 - train si-snr: 1.25 - valid si-snr: 22.55
106
- epoch: 65, lr: 1.50e-04 - train si-snr: 1.43 - valid si-snr: 22.35
107
- epoch: 66, lr: 1.50e-04 - train si-snr: 1.33 - valid si-snr: 22.54
108
- epoch: 67, lr: 1.50e-04 - train si-snr: 1.35 - valid si-snr: 22.44
109
- epoch: 68, lr: 1.50e-04 - train si-snr: 1.35 - valid si-snr: 22.33
110
- epoch: 69, lr: 1.50e-04 - train si-snr: 1.13 - valid si-snr: 22.38
111
- epoch: 70, lr: 1.50e-04 - train si-snr: 1.18 - valid si-snr: 22.37
112
- epoch: 71, lr: 1.50e-04 - train si-snr: 1.04 - valid si-snr: 22.35
113
- epoch: 72, lr: 1.50e-04 - train si-snr: 1.24 - valid si-snr: 22.49
114
- epoch: 73, lr: 1.50e-04 - train si-snr: 1.25 - valid si-snr: 22.35
115
- epoch: 74, lr: 1.50e-04 - train si-snr: 1.07 - valid si-snr: 22.37
116
- epoch: 75, lr: 1.50e-04 - train si-snr: 1.04 - valid si-snr: 22.37
117
- epoch: 76, lr: 1.50e-04 - train si-snr: 1.11 - valid si-snr: 22.48
118
- epoch: 77, lr: 1.50e-04 - train si-snr: 1.03 - valid si-snr: 22.46
119
- epoch: 78, lr: 1.50e-04 - train si-snr: 9.65e-01 - valid si-snr: 22.31
120
- epoch: 79, lr: 1.50e-04 - train si-snr: 1.06 - valid si-snr: 22.34
121
- epoch: 80, lr: 1.50e-04 - train si-snr: 1.03 - valid si-snr: 22.32
122
- epoch: 81, lr: 1.50e-04 - train si-snr: 8.12e-01 - valid si-snr: 22.32
123
- epoch: 82, lr: 1.50e-04 - train si-snr: 8.76e-01 - valid si-snr: 22.33
124
- epoch: 83, lr: 1.50e-04 - train si-snr: 8.91e-01 - valid si-snr: 22.32
125
- epoch: 84, lr: 1.50e-04 - train si-snr: 9.11e-01 - valid si-snr: 22.34
126
- epoch: 85, lr: 1.50e-04 - train si-snr: 7.24e-01 - valid si-snr: 22.39
127
- epoch: 86, lr: 1.50e-04 - train si-snr: 7.65e-01 - valid si-snr: 22.34
128
- epoch: 87, lr: 1.50e-04 - train si-snr: 7.10e-01 - valid si-snr: 22.29
129
- epoch: 88, lr: 1.50e-04 - train si-snr: 7.65e-01 - valid si-snr: 22.42
130
- epoch: 89, lr: 1.50e-04 - train si-snr: 7.09e-01 - valid si-snr: 22.35
131
- epoch: 90, lr: 1.50e-04 - train si-snr: 8.13e-01 - valid si-snr: 22.38
132
- epoch: 91, lr: 7.50e-05 - train si-snr: 5.81e-01 - valid si-snr: 22.24
133
- epoch: 92, lr: 7.50e-05 - train si-snr: 3.71e-01 - valid si-snr: 22.33
134
- epoch: 93, lr: 7.50e-05 - train si-snr: 3.21e-01 - valid si-snr: 22.33
135
- epoch: 94, lr: 7.50e-05 - train si-snr: 3.48e-01 - valid si-snr: 22.29
136
- epoch: 95, lr: 3.75e-05 - train si-snr: 4.08e-01 - valid si-snr: 22.34
137
- epoch: 96, lr: 3.75e-05 - train si-snr: 2.29e-01 - valid si-snr: 22.33
138
- epoch: 97, lr: 3.75e-05 - train si-snr: 2.27e-01 - valid si-snr: 22.29
139
- epoch: 98, lr: 1.87e-05 - train si-snr: 1.28e-01 - valid si-snr: 22.27
140
- epoch: 99, lr: 1.87e-05 - train si-snr: 3.17e-02 - valid si-snr: 22.27
141
- epoch: 100, lr: 1.87e-05 - train si-snr: 6.84e-02 - valid si-snr: 22.24
142
- epoch: 101, lr: 1.87e-05 - train si-snr: 6.90e-02 - valid si-snr: 22.25
143
- epoch: 102, lr: 1.87e-05 - train si-snr: 1.53e-01 - valid si-snr: 22.28
144
- epoch: 103, lr: 1.87e-05 - train si-snr: 4.23e-02 - valid si-snr: 22.28
145
- epoch: 104, lr: 9.37e-06 - train si-snr: 7.48e-02 - valid si-snr: 22.24
146
- epoch: 105, lr: 9.37e-06 - train si-snr: 8.28e-02 - valid si-snr: 22.27
147
- epoch: 106, lr: 9.37e-06 - train si-snr: -1.19e-01 - valid si-snr: 22.26
148
- epoch: 107, lr: 9.37e-06 - train si-snr: 2.27e-02 - valid si-snr: 22.26
149
- epoch: 108, lr: 4.69e-06 - train si-snr: -9.19e-02 - valid si-snr: 22.24
150
- epoch: 109, lr: 4.69e-06 - train si-snr: -1.86e-02 - valid si-snr: 22.26
151
- epoch: 110, lr: 4.69e-06 - train si-snr: -1.29e-01 - valid si-snr: 22.26
152
- epoch: 111, lr: 2.34e-06 - train si-snr: -1.28e-01 - valid si-snr: 22.26
153
- epoch: 112, lr: 2.34e-06 - train si-snr: 1.96e-02 - valid si-snr: 22.26
154
- epoch: 113, lr: 2.34e-06 - train si-snr: -8.82e-02 - valid si-snr: 22.26
155
- epoch: 114, lr: 1.17e-06 - train si-snr: -2.95e-02 - valid si-snr: 22.25
156
- epoch: 115, lr: 1.17e-06 - train si-snr: 1.44e-02 - valid si-snr: 22.26
157
- epoch: 116, lr: 1.17e-06 - train si-snr: -2.01e-02 - valid si-snr: 22.25
158
- epoch: 117, lr: 5.86e-07 - train si-snr: -6.14e-02 - valid si-snr: 22.25
159
- epoch: 118, lr: 5.86e-07 - train si-snr: 1.49e-02 - valid si-snr: 22.25
160
- epoch: 119, lr: 5.86e-07 - train si-snr: -2.11e-02 - valid si-snr: 22.25
161
- epoch: 120, lr: 2.93e-07 - train si-snr: -8.56e-02 - valid si-snr: 22.25
162
- epoch: 121, lr: 2.93e-07 - train si-snr: 3.46e-02 - valid si-snr: 22.25
163
- epoch: 122, lr: 2.93e-07 - train si-snr: -4.48e-02 - valid si-snr: 22.26
164
- epoch: 123, lr: 1.46e-07 - train si-snr: -4.78e-02 - valid si-snr: 22.25
165
- epoch: 124, lr: 1.46e-07 - train si-snr: 4.87e-02 - valid si-snr: 22.26
166
- epoch: 125, lr: 1.46e-07 - train si-snr: -8.55e-02 - valid si-snr: 22.25
167
- epoch: 126, lr: 7.32e-08 - train si-snr: 4.56e-02 - valid si-snr: 22.25
168
- epoch: 127, lr: 7.32e-08 - train si-snr: -7.29e-02 - valid si-snr: 22.25
169
- epoch: 128, lr: 7.32e-08 - train si-snr: -4.80e-02 - valid si-snr: 22.26
170
- epoch: 129, lr: 3.66e-08 - train si-snr: -6.66e-02 - valid si-snr: 22.26
171
- epoch: 130, lr: 3.66e-08 - train si-snr: 6.62e-03 - valid si-snr: 22.26
172
- epoch: 131, lr: 3.66e-08 - train si-snr: -1.94e-02 - valid si-snr: 22.26
173
- epoch: 132, lr: 1.83e-08 - train si-snr: 1.16e-02 - valid si-snr: 22.26
174
- epoch: 133, lr: 1.83e-08 - train si-snr: -1.09e-01 - valid si-snr: 22.26
175
- epoch: 134, lr: 1.83e-08 - train si-snr: -1.16e-01 - valid si-snr: 22.26
176
- epoch: 135, lr: 1.00e-08 - train si-snr: 2.68e-02 - valid si-snr: 22.26
177
- epoch: 136, lr: 1.00e-08 - train si-snr: 3.10e-03 - valid si-snr: 22.26
178
- epoch: 137, lr: 1.00e-08 - train si-snr: -4.31e-02 - valid si-snr: 22.26
179
- epoch: 138, lr: 1.00e-08 - train si-snr: 7.30e-02 - valid si-snr: 22.26
180
- epoch: 139, lr: 1.00e-08 - train si-snr: -9.77e-02 - valid si-snr: 22.26
181
- epoch: 140, lr: 1.00e-08 - train si-snr: -1.41e-01 - valid si-snr: 22.26
182
- epoch: 141, lr: 1.00e-08 - train si-snr: -1.82e-02 - valid si-snr: 22.26
183
- epoch: 142, lr: 1.00e-08 - train si-snr: -5.03e-02 - valid si-snr: 22.26
184
- epoch: 143, lr: 1.00e-08 - train si-snr: -9.63e-02 - valid si-snr: 22.26
185
- epoch: 144, lr: 1.00e-08 - train si-snr: -1.29e-02 - valid si-snr: 22.26
186
- epoch: 145, lr: 1.00e-08 - train si-snr: -3.77e-02 - valid si-snr: 22.26
187
- epoch: 146, lr: 1.00e-08 - train si-snr: -1.36e-01 - valid si-snr: 22.26
188
- epoch: 147, lr: 1.00e-08 - train si-snr: -1.02e-01 - valid si-snr: 22.26
189
- epoch: 148, lr: 1.00e-08 - train si-snr: 1.05e-01 - valid si-snr: 22.26
190
- epoch: 149, lr: 1.00e-08 - train si-snr: -1.08e-01 - valid si-snr: 22.26
191
- epoch: 150, lr: 1.00e-08 - train si-snr: 1.28e-02 - valid si-snr: 22.26
192
- epoch: 151, lr: 1.00e-08 - train si-snr: -8.94e-02 - valid si-snr: 22.26
193
- epoch: 152, lr: 1.00e-08 - train si-snr: -9.64e-02 - valid si-snr: 22.26
194
- epoch: 153, lr: 1.00e-08 - train si-snr: -1.32e-01 - valid si-snr: 22.26
195
- epoch: 154, lr: 1.00e-08 - train si-snr: 2.86e-02 - valid si-snr: 22.26
196
- epoch: 155, lr: 1.00e-08 - train si-snr: -2.50e-02 - valid si-snr: 22.26
197
- epoch: 156, lr: 1.00e-08 - train si-snr: -1.44e-02 - valid si-snr: 22.26
198
- epoch: 157, lr: 1.00e-08 - train si-snr: 9.09e-02 - valid si-snr: 22.26
199
- epoch: 158, lr: 1.00e-08 - train si-snr: 6.12e-03 - valid si-snr: 22.26
200
- epoch: 159, lr: 1.00e-08 - train si-snr: -3.80e-02 - valid si-snr: 22.26
201
- epoch: 160, lr: 1.00e-08 - train si-snr: 4.51e-02 - valid si-snr: 22.26
202
- epoch: 161, lr: 1.00e-08 - train si-snr: -2.98e-02 - valid si-snr: 22.26
203
- epoch: 162, lr: 1.00e-08 - train si-snr: -2.20e-03 - valid si-snr: 22.26
204
- epoch: 163, lr: 1.00e-08 - train si-snr: -1.64e-01 - valid si-snr: 22.26
205
- epoch: 164, lr: 1.00e-08 - train si-snr: -3.20e-02 - valid si-snr: 22.26
206
- epoch: 165, lr: 1.00e-08 - train si-snr: 3.47e-03 - valid si-snr: 22.26
207
- epoch: 166, lr: 1.00e-08 - train si-snr: -8.60e-02 - valid si-snr: 22.26
208
- epoch: 167, lr: 1.00e-08 - train si-snr: 6.45e-03 - valid si-snr: 22.26
209
- epoch: 168, lr: 1.00e-08 - train si-snr: 1.17e-02 - valid si-snr: 22.26
210
- epoch: 169, lr: 1.00e-08 - train si-snr: -4.05e-02 - valid si-snr: 22.26
211
- epoch: 170, lr: 1.00e-08 - train si-snr: -1.26e-01 - valid si-snr: 22.26
212
- epoch: 171, lr: 1.00e-08 - train si-snr: -1.06e-01 - valid si-snr: 22.26
213
- epoch: 172, lr: 1.00e-08 - train si-snr: -1.26e-01 - valid si-snr: 22.26
214
- epoch: 173, lr: 1.00e-08 - train si-snr: -7.41e-02 - valid si-snr: 22.26
215
- epoch: 174, lr: 1.00e-08 - train si-snr: 1.57e-02 - valid si-snr: 22.26
216
- epoch: 175, lr: 1.00e-08 - train si-snr: -1.48e-02 - valid si-snr: 22.26
217
- epoch: 176, lr: 1.00e-08 - train si-snr: 6.87e-02 - valid si-snr: 22.26
218
- epoch: 177, lr: 1.00e-08 - train si-snr: -6.77e-02 - valid si-snr: 22.26
219
- epoch: 178, lr: 1.00e-08 - train si-snr: -1.75e-01 - valid si-snr: 22.26
220
- epoch: 179, lr: 1.00e-08 - train si-snr: -8.73e-02 - valid si-snr: 22.26
221
- epoch: 180, lr: 1.00e-08 - train si-snr: -7.13e-02 - valid si-snr: 22.26
222
- epoch: 181, lr: 1.00e-08 - train si-snr: -1.28e-01 - valid si-snr: 22.26
223
- epoch: 182, lr: 1.00e-08 - train si-snr: 2.53e-02 - valid si-snr: 22.26
224
- epoch: 183, lr: 1.00e-08 - train si-snr: 5.30e-02 - valid si-snr: 22.26
225
- epoch: 184, lr: 1.00e-08 - train si-snr: -6.50e-02 - valid si-snr: 22.26
226
- epoch: 185, lr: 1.00e-08 - train si-snr: -7.48e-02 - valid si-snr: 22.26
227
- epoch: 186, lr: 1.00e-08 - train si-snr: -6.33e-02 - valid si-snr: 22.26
228
- epoch: 187, lr: 1.00e-08 - train si-snr: -5.01e-02 - valid si-snr: 22.26
229
- epoch: 188, lr: 1.00e-08 - train si-snr: -2.82e-03 - valid si-snr: 22.26
230
- epoch: 189, lr: 1.00e-08 - train si-snr: -1.37e-01 - valid si-snr: 22.26
231
- epoch: 190, lr: 1.00e-08 - train si-snr: -3.86e-02 - valid si-snr: 22.26
232
- epoch: 191, lr: 1.00e-08 - train si-snr: -4.23e-02 - valid si-snr: 22.26
233
- epoch: 192, lr: 1.00e-08 - train si-snr: -7.80e-02 - valid si-snr: 22.26
234
- epoch: 193, lr: 1.00e-08 - train si-snr: -2.90e-02 - valid si-snr: 22.26
235
- epoch: 194, lr: 1.00e-08 - train si-snr: -1.21e-01 - valid si-snr: 22.26
236
- epoch: 195, lr: 1.00e-08 - train si-snr: 8.91e-03 - valid si-snr: 22.26
237
- epoch: 196, lr: 1.00e-08 - train si-snr: -5.28e-02 - valid si-snr: 22.26
238
- epoch: 197, lr: 1.00e-08 - train si-snr: 9.40e-02 - valid si-snr: 22.26
239
- epoch: 198, lr: 1.00e-08 - train si-snr: -4.55e-02 - valid si-snr: 22.26
240
- epoch: 199, lr: 1.00e-08 - train si-snr: -6.24e-02 - valid si-snr: 22.26
241
- epoch: 200, lr: 1.00e-08 - train si-snr: 5.69e-03 - valid si-snr: 22.26
242
- Epoch loaded: 104 - test si-snr: 20.22