update model
Browse files- Conv-Tasnet/results/convtasnet_4-mix/1234/env.log +0 -93
- Conv-Tasnet/results/convtasnet_4-mix/1234/hyperparams.yaml +0 -179
- Conv-Tasnet/results/convtasnet_4-mix/1234/log.txt +0 -0
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/CKPT.yaml +0 -4
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/brain.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/counter.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/dataloader-TRAIN.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/decoder.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/encoder.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/lr_scheduler.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/masknet.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/optimizer.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/CKPT.yaml +0 -4
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/brain.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/counter.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/dataloader-TRAIN.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/decoder.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/encoder.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/lr_scheduler.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/masknet.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/optimizer.ckpt +0 -3
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_tr.csv +0 -0
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_val.csv +0 -0
- Conv-Tasnet/results/convtasnet_4-mix/1234/save/test_data.csv +0 -0
- Conv-Tasnet/results/convtasnet_4-mix/1234/test.py +0 -628
- Conv-Tasnet/results/convtasnet_4-mix/1234/test_results.csv +0 -1
- Conv-Tasnet/results/convtasnet_4-mix/1234/train.py +0 -628
- Conv-Tasnet/results/convtasnet_4-mix/1234/train_log.txt +0 -242
Conv-Tasnet/results/convtasnet_4-mix/1234/env.log
DELETED
|
@@ -1,93 +0,0 @@
|
|
| 1 |
-
SpeechBrain system description
|
| 2 |
-
==============================
|
| 3 |
-
Python version:
|
| 4 |
-
3.11.13 (main, Jun 5 2025, 13:12:00) [GCC 11.2.0]
|
| 5 |
-
==============================
|
| 6 |
-
Installed Python packages:
|
| 7 |
-
black==24.3.0
|
| 8 |
-
certifi==2025.6.15
|
| 9 |
-
cffi==1.17.1
|
| 10 |
-
cfgv==3.4.0
|
| 11 |
-
charset-normalizer==3.4.2
|
| 12 |
-
click==8.1.7
|
| 13 |
-
distlib==0.3.9
|
| 14 |
-
docstring_parser_fork==0.0.12
|
| 15 |
-
filelock==3.18.0
|
| 16 |
-
flake8==7.0.0
|
| 17 |
-
fsspec==2025.5.1
|
| 18 |
-
future==1.0.0
|
| 19 |
-
hf-xet==1.1.5
|
| 20 |
-
huggingface-hub==0.33.0
|
| 21 |
-
HyperPyYAML==1.2.2
|
| 22 |
-
identify==2.6.12
|
| 23 |
-
idna==3.10
|
| 24 |
-
iniconfig==2.1.0
|
| 25 |
-
isort==5.13.2
|
| 26 |
-
Jinja2==3.1.6
|
| 27 |
-
joblib==1.5.1
|
| 28 |
-
MarkupSafe==3.0.2
|
| 29 |
-
mccabe==0.7.0
|
| 30 |
-
mir_eval==0.6
|
| 31 |
-
mpmath==1.3.0
|
| 32 |
-
mypy_extensions==1.1.0
|
| 33 |
-
networkx==3.5
|
| 34 |
-
nodeenv==1.9.1
|
| 35 |
-
numpy==2.3.1
|
| 36 |
-
nvidia-cublas-cu12==12.6.4.1
|
| 37 |
-
nvidia-cuda-cupti-cu12==12.6.80
|
| 38 |
-
nvidia-cuda-nvrtc-cu12==12.6.77
|
| 39 |
-
nvidia-cuda-runtime-cu12==12.6.77
|
| 40 |
-
nvidia-cudnn-cu12==9.5.1.17
|
| 41 |
-
nvidia-cufft-cu12==11.3.0.4
|
| 42 |
-
nvidia-cufile-cu12==1.11.1.6
|
| 43 |
-
nvidia-curand-cu12==10.3.7.77
|
| 44 |
-
nvidia-cusolver-cu12==11.7.1.2
|
| 45 |
-
nvidia-cusparse-cu12==12.5.4.2
|
| 46 |
-
nvidia-cusparselt-cu12==0.6.3
|
| 47 |
-
nvidia-nccl-cu12==2.26.2
|
| 48 |
-
nvidia-nvjitlink-cu12==12.6.85
|
| 49 |
-
nvidia-nvtx-cu12==12.6.77
|
| 50 |
-
packaging==25.0
|
| 51 |
-
pandas==2.3.0
|
| 52 |
-
pathspec==0.12.1
|
| 53 |
-
platformdirs==4.3.8
|
| 54 |
-
pluggy==1.6.0
|
| 55 |
-
pre_commit==4.2.0
|
| 56 |
-
pycodestyle==2.11.0
|
| 57 |
-
pycparser==2.22
|
| 58 |
-
pydoclint==0.4.1
|
| 59 |
-
pyflakes==3.2.0
|
| 60 |
-
pyloudnorm==0.1.1
|
| 61 |
-
pytest==7.4.0
|
| 62 |
-
python-dateutil==2.9.0.post0
|
| 63 |
-
pytz==2025.2
|
| 64 |
-
PyYAML==6.0.2
|
| 65 |
-
regex==2024.11.6
|
| 66 |
-
requests==2.32.4
|
| 67 |
-
ruamel.yaml==0.18.14
|
| 68 |
-
ruamel.yaml.clib==0.2.12
|
| 69 |
-
safetensors==0.5.3
|
| 70 |
-
scipy==1.16.0
|
| 71 |
-
sentencepiece==0.2.0
|
| 72 |
-
six==1.17.0
|
| 73 |
-
soundfile==0.13.1
|
| 74 |
-
sox==1.5.0
|
| 75 |
-
-e git+ssh://git@github.com/speechbrain/speechbrain.git@c75ab5489431fd0a2a7d21160bc37677801cb506#egg=speechbrain
|
| 76 |
-
sympy==1.14.0
|
| 77 |
-
tokenizers==0.21.2
|
| 78 |
-
torch==2.7.1
|
| 79 |
-
torchaudio==2.7.1
|
| 80 |
-
tqdm==4.67.1
|
| 81 |
-
transformers==4.52.4
|
| 82 |
-
triton==3.3.1
|
| 83 |
-
typing_extensions==4.14.0
|
| 84 |
-
tzdata==2025.2
|
| 85 |
-
urllib3==2.5.0
|
| 86 |
-
virtualenv==20.31.2
|
| 87 |
-
yamllint==1.35.1
|
| 88 |
-
==============================
|
| 89 |
-
Git revision:
|
| 90 |
-
c75ab5489
|
| 91 |
-
==============================
|
| 92 |
-
CUDA version:
|
| 93 |
-
12.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/hyperparams.yaml
DELETED
|
@@ -1,179 +0,0 @@
|
|
| 1 |
-
# Generated 2025-06-26 from:
|
| 2 |
-
# /work106/youzhenghai/project/speechbrain/myegs/FORHUAWEI_TASNET/separation/hparams/convtasnet_4mix.yaml
|
| 3 |
-
# yamllint disable
|
| 4 |
-
# ################################
|
| 5 |
-
# Model: SepFormer for source separation
|
| 6 |
-
# https://arxiv.org/abs/2010.13154
|
| 7 |
-
# Dataset : WSJ0-2mix and WSJ0-3mix
|
| 8 |
-
# ################################
|
| 9 |
-
# Basic parameters
|
| 10 |
-
# Seed needs to be set at top of yaml, before objects with parameters are made
|
| 11 |
-
#
|
| 12 |
-
seed: 1234
|
| 13 |
-
__set_seed: !apply:speechbrain.utils.seed_everything [1234]
|
| 14 |
-
|
| 15 |
-
# Data params
|
| 16 |
-
|
| 17 |
-
# e.g. '/yourpath/wsj0-mix/2speakers'
|
| 18 |
-
# end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix
|
| 19 |
-
data_folder: /work105/youzhenghai/data/wsj0_2mix
|
| 20 |
-
|
| 21 |
-
# the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
|
| 22 |
-
# e.g. /yourpath/wsj0-processed/si_tr_s/
|
| 23 |
-
base_folder_dm: /yourpath/wsj0-processed/si_tr_s/
|
| 24 |
-
|
| 25 |
-
experiment_name: convtasnet_4-mix
|
| 26 |
-
output_folder: results/convtasnet_4-mix/1234
|
| 27 |
-
train_log: results/convtasnet_4-mix/1234/train_log.txt
|
| 28 |
-
save_folder: results/convtasnet_4-mix/1234/save
|
| 29 |
-
train_data: results/convtasnet_4-mix/1234/save/record_tr.csv
|
| 30 |
-
valid_data: results/convtasnet_4-mix/1234/save/record_val.csv
|
| 31 |
-
test_data: results/convtasnet_4-mix/1234/save/test_data.csv
|
| 32 |
-
skip_prep: false
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
# Experiment params
|
| 36 |
-
precision: fp32 # bf16, fp16 or fp32
|
| 37 |
-
num_spks: 4 # set to 3 for wsj0-3mix
|
| 38 |
-
noprogressbar: false
|
| 39 |
-
save_audio: false # Save estimated sources on disk
|
| 40 |
-
sample_rate: 16000
|
| 41 |
-
|
| 42 |
-
####################### Training Parameters ####################################
|
| 43 |
-
N_epochs: 200
|
| 44 |
-
batch_size: 2
|
| 45 |
-
lr: 0.00015
|
| 46 |
-
clip_grad_norm: 5
|
| 47 |
-
loss_upper_lim: 999999 # this is the upper limit for an acceptable loss
|
| 48 |
-
# if True, the training sequences are cut to a specified length
|
| 49 |
-
limit_training_signal_len: true
|
| 50 |
-
# this is the length of sequences if we choose to limit
|
| 51 |
-
# the signal length of training sequences
|
| 52 |
-
training_signal_len: 64000000
|
| 53 |
-
|
| 54 |
-
# Set it to True to dynamically create mixtures at training time
|
| 55 |
-
dynamic_mixing: false
|
| 56 |
-
|
| 57 |
-
# Parameters for data augmentation
|
| 58 |
-
use_wavedrop: false
|
| 59 |
-
use_speedperturb: true
|
| 60 |
-
use_rand_shift: false
|
| 61 |
-
min_shift: -8000
|
| 62 |
-
max_shift: 8000
|
| 63 |
-
|
| 64 |
-
# Speed perturbation
|
| 65 |
-
speed_changes: &id001 [95, 100, 105]
|
| 66 |
-
|
| 67 |
-
# Frequency drop: randomly drops a number of frequency bands to zero.
|
| 68 |
-
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
|
| 69 |
-
orig_freq: 16000
|
| 70 |
-
speeds: *id001
|
| 71 |
-
drop_freq_low: 0 # Min frequency band dropout probability
|
| 72 |
-
drop_freq_high: 1 # Max frequency band dropout probability
|
| 73 |
-
drop_freq_count_low: 1 # Min number of frequency bands to drop
|
| 74 |
-
drop_freq_count_high: 3 # Max number of frequency bands to drop
|
| 75 |
-
drop_freq_width: 0.05 # Width of frequency bands to drop
|
| 76 |
-
|
| 77 |
-
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
|
| 78 |
-
drop_freq_low: 0
|
| 79 |
-
drop_freq_high: 1
|
| 80 |
-
drop_freq_count_low: 1
|
| 81 |
-
drop_freq_count_high: 3
|
| 82 |
-
drop_freq_width: 0.05
|
| 83 |
-
|
| 84 |
-
# Time drop: randomly drops a number of temporal chunks.
|
| 85 |
-
drop_chunk_count_low: 1 # Min number of audio chunks to drop
|
| 86 |
-
drop_chunk_count_high: 5 # Max number of audio chunks to drop
|
| 87 |
-
drop_chunk_length_low: 1000 # Min length of audio chunks to drop
|
| 88 |
-
drop_chunk_length_high: 2000 # Max length of audio chunks to drop
|
| 89 |
-
|
| 90 |
-
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
|
| 91 |
-
drop_length_low: 1000
|
| 92 |
-
drop_length_high: 2000
|
| 93 |
-
drop_count_low: 1
|
| 94 |
-
drop_count_high: 5
|
| 95 |
-
|
| 96 |
-
# loss thresholding -- this thresholds the training loss
|
| 97 |
-
threshold_byloss: true
|
| 98 |
-
threshold: -30
|
| 99 |
-
|
| 100 |
-
# Encoder parameters
|
| 101 |
-
N_encoder_out: 256
|
| 102 |
-
# out_channels: 256
|
| 103 |
-
kernel_size: 32
|
| 104 |
-
kernel_stride: 16
|
| 105 |
-
|
| 106 |
-
# Dataloader options
|
| 107 |
-
dataloader_opts:
|
| 108 |
-
batch_size: 2
|
| 109 |
-
num_workers: 3
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
# Specifying the network
|
| 113 |
-
Encoder: &id002 !new:speechbrain.lobes.models.dual_path.Encoder
|
| 114 |
-
kernel_size: 32
|
| 115 |
-
out_channels: 256
|
| 116 |
-
|
| 117 |
-
# intra: !new:speechbrain.lobes.models.dual_path.SBRNNBlock
|
| 118 |
-
# num_layers: 1
|
| 119 |
-
# input_size: !ref <out_channels>
|
| 120 |
-
# hidden_channels: !ref <out_channels>
|
| 121 |
-
# dropout: 0
|
| 122 |
-
# bidirectional: True
|
| 123 |
-
|
| 124 |
-
# inter: !new:speechbrain.lobes.models.dual_path.SBRNNBlock
|
| 125 |
-
# num_layers: 1
|
| 126 |
-
# input_size: !ref <out_channels>
|
| 127 |
-
# hidden_channels: !ref <out_channels>
|
| 128 |
-
# dropout: 0
|
| 129 |
-
# bidirectional: True
|
| 130 |
-
|
| 131 |
-
MaskNet: &id004 !new:speechbrain.lobes.models.conv_tasnet.MaskNet
|
| 132 |
-
|
| 133 |
-
N: 256
|
| 134 |
-
B: 256
|
| 135 |
-
H: 512
|
| 136 |
-
P: 3
|
| 137 |
-
X: 6
|
| 138 |
-
R: 4
|
| 139 |
-
C: 4
|
| 140 |
-
norm_type: gLN
|
| 141 |
-
causal: true
|
| 142 |
-
mask_nonlinear: relu
|
| 143 |
-
|
| 144 |
-
Decoder: &id003 !new:speechbrain.lobes.models.dual_path.Decoder
|
| 145 |
-
in_channels: 256
|
| 146 |
-
out_channels: 1
|
| 147 |
-
kernel_size: 32
|
| 148 |
-
stride: 16
|
| 149 |
-
bias: false
|
| 150 |
-
|
| 151 |
-
optimizer: !name:torch.optim.Adam
|
| 152 |
-
lr: 0.00015
|
| 153 |
-
weight_decay: 0
|
| 154 |
-
|
| 155 |
-
loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper
|
| 156 |
-
|
| 157 |
-
lr_scheduler: &id006 !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
|
| 158 |
-
|
| 159 |
-
factor: 0.5
|
| 160 |
-
patience: 2
|
| 161 |
-
dont_halve_until_epoch: 85
|
| 162 |
-
|
| 163 |
-
epoch_counter: &id005 !new:speechbrain.utils.epoch_loop.EpochCounter
|
| 164 |
-
limit: 200
|
| 165 |
-
|
| 166 |
-
modules:
|
| 167 |
-
encoder: *id002
|
| 168 |
-
decoder: *id003
|
| 169 |
-
masknet: *id004
|
| 170 |
-
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
|
| 171 |
-
checkpoints_dir: results/convtasnet_4-mix/1234/save
|
| 172 |
-
recoverables:
|
| 173 |
-
encoder: *id002
|
| 174 |
-
decoder: *id003
|
| 175 |
-
masknet: *id004
|
| 176 |
-
counter: *id005
|
| 177 |
-
lr_scheduler: *id006
|
| 178 |
-
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
|
| 179 |
-
save_file: results/convtasnet_4-mix/1234/train_log.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/log.txt
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/CKPT.yaml
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
# yamllint disable
|
| 2 |
-
end-of-epoch: true
|
| 3 |
-
si-snr: 22.240427712364045
|
| 4 |
-
unixtime: 1750961143.069555
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/brain.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:33809a026a2c1febce7b03c8aafaee4ddfc851b2c70f180f8c06bf1017f4df5c
|
| 3 |
-
size 46
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/counter.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5ef6fdf32513aa7cd11f72beccf132b9224d33f271471fff402742887a171edf
|
| 3 |
-
size 3
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/dataloader-TRAIN.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5c344ba7044815dd03c3448028a43e5b9c16074cb5a6a19c7ae86165c149735f
|
| 3 |
-
size 3
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/decoder.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b89e695d01ef7a5aeb76f5000f70959a078e4ea1cf97ae978a2a4dc2121c7f29
|
| 3 |
-
size 34409
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/encoder.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c5ef4fe38605072559dbf12b09643423c4649460c0f803f34f047e92f9358f39
|
| 3 |
-
size 34473
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/lr_scheduler.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f02f6900fea06c469d975f48c9b3f4d40868d5fb6e6758baf76c4e68c4785dd1
|
| 3 |
-
size 2251
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/masknet.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:100869f60d27f540b6d23e4a811cff04541c67e6ff4639776645069f841f0db5
|
| 3 |
-
size 26926023
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+02-05-43+00/optimizer.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c05ce1c793e4f0bae4a6905774bbfc8360e4450103008c838ea195f4a146452c
|
| 3 |
-
size 53964363
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/CKPT.yaml
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
# yamllint disable
|
| 2 |
-
end-of-epoch: true
|
| 3 |
-
si-snr: 22.256136728080673
|
| 4 |
-
unixtime: 1750994220.6695538
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/brain.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:33809a026a2c1febce7b03c8aafaee4ddfc851b2c70f180f8c06bf1017f4df5c
|
| 3 |
-
size 46
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/counter.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:27badc983df1780b60c2b3fa9d3a19a00e46aac798451f0febdca52920faaddf
|
| 3 |
-
size 3
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/dataloader-TRAIN.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5c344ba7044815dd03c3448028a43e5b9c16074cb5a6a19c7ae86165c149735f
|
| 3 |
-
size 3
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/decoder.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8ba5891c2436cdefe57f4ca4b87bfa8267f927948330ea482d9cd6fadcd14163
|
| 3 |
-
size 34409
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/encoder.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:533c6dfe50d9c410e8c0e4907efaf95679ca2fe85f0ceab9aa0ede0c817d58d8
|
| 3 |
-
size 34473
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/lr_scheduler.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8689d8fb8de14a5995a161e50181134543321bbd431f774ce20f507239669ce3
|
| 3 |
-
size 3147
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/masknet.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3acd263841af684db0cf622b77e83a807e969390661e115b89a8139f8785aa64
|
| 3 |
-
size 26926023
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/CKPT+2025-06-27+11-17-00+00/optimizer.ckpt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:89593e2f633757a61883ef5aeb48a9e79ec4b09565d470c5571ee16edcb51c5c
|
| 3 |
-
size 53964363
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_tr.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/record_val.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/save/test_data.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/test.py
DELETED
|
@@ -1,628 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env/python3
|
| 2 |
-
"""Recipe for training a neural speech separation system on the wsjmix
|
| 3 |
-
dataset. The system employs an encoder, a decoder, and a masking network.
|
| 4 |
-
|
| 5 |
-
To run this recipe, do the following:
|
| 6 |
-
> python train.py hparams/sepformer.yaml
|
| 7 |
-
> python train.py hparams/dualpath_rnn.yaml
|
| 8 |
-
> python train.py hparams/convtasnet.yaml
|
| 9 |
-
|
| 10 |
-
The experiment file is flexible enough to support different neural
|
| 11 |
-
networks. By properly changing the parameter files, you can try
|
| 12 |
-
different architectures. The script supports both wsj2mix and
|
| 13 |
-
wsj3mix.
|
| 14 |
-
|
| 15 |
-
# 4-mix 主要根据 num_spks 修改 train.py 和 config
|
| 16 |
-
Authors
|
| 17 |
-
* Cem Subakan 2020
|
| 18 |
-
* Mirco Ravanelli 2020
|
| 19 |
-
* Samuele Cornell 2020
|
| 20 |
-
* Mirko Bronzi 2020
|
| 21 |
-
* Jianyuan Zhong 2020
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
import csv
|
| 25 |
-
import os
|
| 26 |
-
import sys
|
| 27 |
-
|
| 28 |
-
import numpy as np
|
| 29 |
-
import torch
|
| 30 |
-
import torch.nn.functional as F
|
| 31 |
-
import torchaudio
|
| 32 |
-
from hyperpyyaml import load_hyperpyyaml
|
| 33 |
-
from tqdm import tqdm
|
| 34 |
-
|
| 35 |
-
import speechbrain as sb
|
| 36 |
-
import speechbrain.nnet.schedulers as schedulers
|
| 37 |
-
from speechbrain.utils.distributed import run_on_main
|
| 38 |
-
from speechbrain.utils.logger import get_logger
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# Define training procedure
|
| 42 |
-
class Separation(sb.Brain):
|
| 43 |
-
def compute_forward(self, mix, targets, stage, noise=None):
|
| 44 |
-
"""Forward computations from the mixture to the separated signals."""
|
| 45 |
-
|
| 46 |
-
# Unpack lists and put tensors in the right device
|
| 47 |
-
mix, mix_lens = mix
|
| 48 |
-
mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)
|
| 49 |
-
|
| 50 |
-
# Convert targets to tensor
|
| 51 |
-
targets = torch.cat(
|
| 52 |
-
[targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_spks)],
|
| 53 |
-
dim=-1,
|
| 54 |
-
).to(self.device)
|
| 55 |
-
|
| 56 |
-
# Add speech distortions
|
| 57 |
-
if stage == sb.Stage.TRAIN:
|
| 58 |
-
with torch.no_grad():
|
| 59 |
-
if self.hparams.use_speedperturb:
|
| 60 |
-
mix, targets = self.add_speed_perturb(targets, mix_lens)
|
| 61 |
-
|
| 62 |
-
mix = targets.sum(-1)
|
| 63 |
-
|
| 64 |
-
if self.hparams.use_wavedrop:
|
| 65 |
-
mix = self.hparams.drop_chunk(mix, mix_lens)
|
| 66 |
-
mix = self.hparams.drop_freq(mix)
|
| 67 |
-
|
| 68 |
-
if self.hparams.limit_training_signal_len:
|
| 69 |
-
mix, targets = self.cut_signals(mix, targets)
|
| 70 |
-
|
| 71 |
-
# Separation
|
| 72 |
-
mix_w = self.hparams.Encoder(mix)
|
| 73 |
-
est_mask = self.hparams.MaskNet(mix_w)
|
| 74 |
-
mix_w = torch.stack([mix_w] * self.hparams.num_spks)
|
| 75 |
-
sep_h = mix_w * est_mask
|
| 76 |
-
|
| 77 |
-
# Decoding
|
| 78 |
-
est_source = torch.cat(
|
| 79 |
-
[
|
| 80 |
-
self.hparams.Decoder(sep_h[i]).unsqueeze(-1)
|
| 81 |
-
for i in range(self.hparams.num_spks)
|
| 82 |
-
],
|
| 83 |
-
dim=-1,
|
| 84 |
-
)
|
| 85 |
-
|
| 86 |
-
# T changed after conv1d in encoder, fix it here
|
| 87 |
-
T_origin = mix.size(1)
|
| 88 |
-
T_est = est_source.size(1)
|
| 89 |
-
if T_origin > T_est:
|
| 90 |
-
est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
|
| 91 |
-
else:
|
| 92 |
-
est_source = est_source[:, :T_origin, :]
|
| 93 |
-
|
| 94 |
-
return est_source, targets
|
| 95 |
-
|
| 96 |
-
def compute_objectives(self, predictions, targets):
|
| 97 |
-
"""Computes the sinr loss"""
|
| 98 |
-
return self.hparams.loss(targets, predictions)
|
| 99 |
-
|
| 100 |
-
def fit_batch(self, batch):
|
| 101 |
-
"""Trains one batch"""
|
| 102 |
-
|
| 103 |
-
# Unpacking batch list
|
| 104 |
-
mixture = batch.mix_sig
|
| 105 |
-
targets = [batch.s1_sig, batch.s2_sig]
|
| 106 |
-
|
| 107 |
-
if self.hparams.num_spks == 3:
|
| 108 |
-
targets.append(batch.s3_sig)
|
| 109 |
-
|
| 110 |
-
if self.hparams.num_spks == 4:
|
| 111 |
-
targets.append(batch.s3_sig)
|
| 112 |
-
targets.append(batch.s4_sig)
|
| 113 |
-
|
| 114 |
-
with self.training_ctx:
|
| 115 |
-
predictions, targets = self.compute_forward(
|
| 116 |
-
mixture, targets, sb.Stage.TRAIN
|
| 117 |
-
)
|
| 118 |
-
loss = self.compute_objectives(predictions, targets)
|
| 119 |
-
|
| 120 |
-
# hard threshold the easy dataitems
|
| 121 |
-
if self.hparams.threshold_byloss:
|
| 122 |
-
th = self.hparams.threshold
|
| 123 |
-
loss = loss[loss > th]
|
| 124 |
-
if loss.nelement() > 0:
|
| 125 |
-
loss = loss.mean()
|
| 126 |
-
else:
|
| 127 |
-
loss = loss.mean()
|
| 128 |
-
|
| 129 |
-
if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
|
| 130 |
-
self.scaler.scale(loss).backward()
|
| 131 |
-
if self.hparams.clip_grad_norm >= 0:
|
| 132 |
-
self.scaler.unscale_(self.optimizer)
|
| 133 |
-
torch.nn.utils.clip_grad_norm_(
|
| 134 |
-
self.modules.parameters(),
|
| 135 |
-
self.hparams.clip_grad_norm,
|
| 136 |
-
)
|
| 137 |
-
self.scaler.step(self.optimizer)
|
| 138 |
-
self.scaler.update()
|
| 139 |
-
else:
|
| 140 |
-
self.nonfinite_count += 1
|
| 141 |
-
logger.info(
|
| 142 |
-
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
|
| 143 |
-
self.nonfinite_count
|
| 144 |
-
)
|
| 145 |
-
)
|
| 146 |
-
loss.data = torch.tensor(0.0).to(self.device)
|
| 147 |
-
self.optimizer.zero_grad()
|
| 148 |
-
|
| 149 |
-
return loss.detach().cpu()
|
| 150 |
-
|
| 151 |
-
def evaluate_batch(self, batch, stage):
|
| 152 |
-
"""Computations needed for validation/test batches"""
|
| 153 |
-
snt_id = batch.id
|
| 154 |
-
mixture = batch.mix_sig
|
| 155 |
-
targets = [batch.s1_sig, batch.s2_sig]
|
| 156 |
-
if self.hparams.num_spks == 3:
|
| 157 |
-
targets.append(batch.s3_sig)
|
| 158 |
-
|
| 159 |
-
if self.hparams.num_spks == 4:
|
| 160 |
-
targets.append(batch.s3_sig)
|
| 161 |
-
targets.append(batch.s4_sig)
|
| 162 |
-
|
| 163 |
-
with torch.no_grad():
|
| 164 |
-
predictions, targets = self.compute_forward(mixture, targets, stage)
|
| 165 |
-
loss = self.compute_objectives(predictions, targets)
|
| 166 |
-
|
| 167 |
-
# Manage audio file saving
|
| 168 |
-
if stage == sb.Stage.TEST and self.hparams.save_audio:
|
| 169 |
-
if hasattr(self.hparams, "n_audio_to_save"):
|
| 170 |
-
if self.hparams.n_audio_to_save > 0:
|
| 171 |
-
self.save_audio(snt_id[0], mixture, targets, predictions)
|
| 172 |
-
self.hparams.n_audio_to_save += -1
|
| 173 |
-
else:
|
| 174 |
-
self.save_audio(snt_id[0], mixture, targets, predictions)
|
| 175 |
-
|
| 176 |
-
return loss.mean().detach()
|
| 177 |
-
|
| 178 |
-
def on_stage_end(self, stage, stage_loss, epoch):
|
| 179 |
-
"""Gets called at the end of a epoch."""
|
| 180 |
-
# Compute/store important stats
|
| 181 |
-
stage_stats = {"si-snr": stage_loss}
|
| 182 |
-
if stage == sb.Stage.TRAIN:
|
| 183 |
-
self.train_stats = stage_stats
|
| 184 |
-
|
| 185 |
-
# Perform end-of-iteration things, like annealing, logging, etc.
|
| 186 |
-
if stage == sb.Stage.VALID:
|
| 187 |
-
# Learning rate annealing
|
| 188 |
-
if isinstance(
|
| 189 |
-
self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
|
| 190 |
-
):
|
| 191 |
-
current_lr, next_lr = self.hparams.lr_scheduler(
|
| 192 |
-
[self.optimizer], epoch, stage_loss
|
| 193 |
-
)
|
| 194 |
-
schedulers.update_learning_rate(self.optimizer, next_lr)
|
| 195 |
-
else:
|
| 196 |
-
# if we do not use the reducelronplateau, we do not change the lr
|
| 197 |
-
current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]
|
| 198 |
-
|
| 199 |
-
self.hparams.train_logger.log_stats(
|
| 200 |
-
stats_meta={"epoch": epoch, "lr": current_lr},
|
| 201 |
-
train_stats=self.train_stats,
|
| 202 |
-
valid_stats=stage_stats,
|
| 203 |
-
)
|
| 204 |
-
self.checkpointer.save_and_keep_only(
|
| 205 |
-
meta={"si-snr": stage_stats["si-snr"]}, min_keys=["si-snr"]
|
| 206 |
-
)
|
| 207 |
-
elif stage == sb.Stage.TEST:
|
| 208 |
-
self.hparams.train_logger.log_stats(
|
| 209 |
-
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
|
| 210 |
-
test_stats=stage_stats,
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
def add_speed_perturb(self, targets, targ_lens):
|
| 214 |
-
"""Adds speed perturbation and random_shift to the input signals"""
|
| 215 |
-
|
| 216 |
-
min_len = -1
|
| 217 |
-
recombine = False
|
| 218 |
-
|
| 219 |
-
if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
|
| 220 |
-
# Performing speed change (independently on each source)
|
| 221 |
-
new_targets = []
|
| 222 |
-
recombine = True
|
| 223 |
-
|
| 224 |
-
for i in range(targets.shape[-1]):
|
| 225 |
-
new_target = self.hparams.speed_perturb(targets[:, :, i])
|
| 226 |
-
new_targets.append(new_target)
|
| 227 |
-
if i == 0:
|
| 228 |
-
min_len = new_target.shape[-1]
|
| 229 |
-
else:
|
| 230 |
-
if new_target.shape[-1] < min_len:
|
| 231 |
-
min_len = new_target.shape[-1]
|
| 232 |
-
|
| 233 |
-
if self.hparams.use_rand_shift:
|
| 234 |
-
# Performing random_shift (independently on each source)
|
| 235 |
-
recombine = True
|
| 236 |
-
for i in range(targets.shape[-1]):
|
| 237 |
-
rand_shift = torch.randint(
|
| 238 |
-
self.hparams.min_shift, self.hparams.max_shift, (1,)
|
| 239 |
-
)
|
| 240 |
-
new_targets[i] = new_targets[i].to(self.device)
|
| 241 |
-
new_targets[i] = torch.roll(
|
| 242 |
-
new_targets[i], shifts=(rand_shift[0],), dims=1
|
| 243 |
-
)
|
| 244 |
-
|
| 245 |
-
# Re-combination
|
| 246 |
-
if recombine:
|
| 247 |
-
if self.hparams.use_speedperturb:
|
| 248 |
-
targets = torch.zeros(
|
| 249 |
-
targets.shape[0],
|
| 250 |
-
min_len,
|
| 251 |
-
targets.shape[-1],
|
| 252 |
-
device=targets.device,
|
| 253 |
-
dtype=torch.float,
|
| 254 |
-
)
|
| 255 |
-
for i, new_target in enumerate(new_targets):
|
| 256 |
-
targets[:, :, i] = new_targets[i][:, 0:min_len]
|
| 257 |
-
|
| 258 |
-
mix = targets.sum(-1)
|
| 259 |
-
return mix, targets
|
| 260 |
-
|
| 261 |
-
def cut_signals(self, mixture, targets):
|
| 262 |
-
"""This function selects a random segment of a given length within the mixture.
|
| 263 |
-
The corresponding targets are selected accordingly"""
|
| 264 |
-
randstart = torch.randint(
|
| 265 |
-
0,
|
| 266 |
-
1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
|
| 267 |
-
(1,),
|
| 268 |
-
).item()
|
| 269 |
-
targets = targets[
|
| 270 |
-
:, randstart : randstart + self.hparams.training_signal_len, :
|
| 271 |
-
]
|
| 272 |
-
mixture = mixture[
|
| 273 |
-
:, randstart : randstart + self.hparams.training_signal_len
|
| 274 |
-
]
|
| 275 |
-
return mixture, targets
|
| 276 |
-
|
| 277 |
-
def reset_layer_recursively(self, layer):
|
| 278 |
-
"""Reinitializes the parameters of the neural networks"""
|
| 279 |
-
if hasattr(layer, "reset_parameters"):
|
| 280 |
-
layer.reset_parameters()
|
| 281 |
-
for child_layer in layer.modules():
|
| 282 |
-
if layer != child_layer:
|
| 283 |
-
self.reset_layer_recursively(child_layer)
|
| 284 |
-
|
| 285 |
-
def save_results(self, test_data):
    """This script computes the SDR and SI-SNR metrics and saves
    them into a csv file.

    One row per test utterance is written to
    ``<output_folder>/test_results.csv`` (columns: snt_id, sdr, sdr_i,
    si-snr, si-snr_i), followed by a final "avg" row with the dataset
    means. The same means are also logged.

    Arguments
    ---------
    test_data : speechbrain DynamicItemDataset
        Test set to evaluate; wrapped in a dataloader with
        ``self.hparams.dataloader_opts``.
    """

    # This package is required for SDR computation
    from mir_eval.separation import bss_eval_sources

    # Create folders where to store audio
    save_file = os.path.join(self.hparams.output_folder, "test_results.csv")

    # Variable init
    all_sdrs = []
    all_sdrs_i = []
    all_sisnrs = []
    all_sisnrs_i = []
    csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

    test_loader = sb.dataio.dataloader.make_dataloader(
        test_data, **self.hparams.dataloader_opts
    )

    with open(save_file, "w", newline="", encoding="utf-8") as results_csv:
        writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
        writer.writeheader()

        # Loop over all test sentence
        with tqdm(test_loader, dynamic_ncols=True) as t:
            for i, batch in enumerate(t):
                # Apply Separation
                mixture, mix_len = batch.mix_sig
                snt_id = batch.id
                targets = [batch.s1_sig, batch.s2_sig]
                if self.hparams.num_spks == 3:
                    targets.append(batch.s3_sig)

                if self.hparams.num_spks == 4:
                    targets.append(batch.s3_sig)
                    targets.append(batch.s4_sig)

                with torch.no_grad():
                    predictions, targets = self.compute_forward(
                        batch.mix_sig, targets, sb.Stage.TEST
                    )

                # Compute SI-SNR
                # NOTE: compute_objectives returns the loss, i.e. the
                # *negative* SI-SNR; signs are flipped before reporting.
                sisnr = self.compute_objectives(predictions, targets)

                # Compute SI-SNR improvement
                # Baseline: the unprocessed mixture replicated once per
                # speaker, scored against the same targets.
                mixture_signal = torch.stack(
                    [mixture] * self.hparams.num_spks, dim=-1
                )
                mixture_signal = mixture_signal.to(targets.device)
                sisnr_baseline = self.compute_objectives(
                    mixture_signal, targets
                )
                sisnr_i = sisnr - sisnr_baseline

                # Compute SDR
                # bss_eval_sources expects [num_spks, time] numpy arrays;
                # only the first batch item is scored here.
                sdr, _, _, _ = bss_eval_sources(
                    targets[0].t().cpu().numpy(),
                    predictions[0].t().detach().cpu().numpy(),
                )

                sdr_baseline, _, _, _ = bss_eval_sources(
                    targets[0].t().cpu().numpy(),
                    mixture_signal[0].t().detach().cpu().numpy(),
                )

                sdr_i = sdr.mean() - sdr_baseline.mean()

                # Saving on a csv file
                row = {
                    "snt_id": snt_id[0],
                    "sdr": sdr.mean(),
                    "sdr_i": sdr_i,
                    "si-snr": -sisnr.item(),
                    "si-snr_i": -sisnr_i.item(),
                }
                writer.writerow(row)

                # Metric Accumulation
                all_sdrs.append(sdr.mean())
                all_sdrs_i.append(sdr_i.mean())
                all_sisnrs.append(-sisnr.item())
                all_sisnrs_i.append(-sisnr_i.item())

        # Final summary row with dataset-level averages.
        row = {
            "snt_id": "avg",
            "sdr": np.array(all_sdrs).mean(),
            "sdr_i": np.array(all_sdrs_i).mean(),
            "si-snr": np.array(all_sisnrs).mean(),
            "si-snr_i": np.array(all_sisnrs_i).mean(),
        }
        writer.writerow(row)

    logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
    logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
    logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
    logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))
|
| 384 |
-
|
| 385 |
-
def save_audio(self, snt_id, mixture, targets, predictions):
    """Saves the test audio (mixture, targets, and estimated sources) on disk.

    Writes, under ``<save_folder>/audio_results``, one peak-normalized wav
    per speaker for both the estimated (``item<id>_source<n>hat.wav``) and
    reference (``item<id>_source<n>.wav``) signals, plus the mixture
    (``item<id>_mix.wav``). Only the first item of the batch is saved.

    Arguments
    ---------
    snt_id : str
        Utterance id used in the output filenames.
    mixture : PaddedData-like
        Batch mixture; indexed as ``mixture[0][0, :]`` below, so element 0
        is assumed to be the signal tensor — TODO confirm against caller.
    targets : torch.Tensor
        Reference sources, shape [batch, time, num_spks].
    predictions : torch.Tensor
        Estimated sources, shape [batch, time, num_spks].
    """

    # Create output folder
    save_path = os.path.join(self.hparams.save_folder, "audio_results")
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    for ns in range(self.hparams.num_spks):
        # Estimated source
        signal = predictions[0, :, ns]
        # Peak-normalize before writing to avoid clipping.
        signal = signal / signal.abs().max()
        save_file = os.path.join(
            save_path, "item{}_source{}hat.wav".format(snt_id, ns + 1)
        )
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )

        # Original source
        signal = targets[0, :, ns]
        signal = signal / signal.abs().max()
        save_file = os.path.join(
            save_path, "item{}_source{}.wav".format(snt_id, ns + 1)
        )
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )

    # Mixture
    signal = mixture[0][0, :]
    signal = signal / signal.abs().max()
    save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
    torchaudio.save(
        save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
    )
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
def dataio_prep(hparams):
    """Create the SpeechBrain data-processing pipeline.

    Builds the train/valid/test ``DynamicItemDataset`` objects from their
    CSV manifests, attaches an audio-loading pipeline for the mixture and
    for each source, and selects the output keys according to
    ``hparams["num_spks"]`` (2, 3, or 4 speakers).

    Returns
    -------
    tuple
        (train_data, valid_data, test_data) dataset objects.
    """

    # 1. Define datasets — one per split, sharing the same data root.
    def _load_split(csv_key):
        return sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=hparams[csv_key],
            replacements={"data_root": hparams["data_folder"]},
        )

    train_data = _load_split("train_data")
    valid_data = _load_split("valid_data")
    test_data = _load_split("test_data")
    datasets = [train_data, valid_data, test_data]

    # 2. Provide audio pipelines (mixture + per-source signals).
    @sb.utils.data_pipeline.takes("mix_wav")
    @sb.utils.data_pipeline.provides("mix_sig")
    def audio_pipeline_mix(mix_wav):
        return sb.dataio.dataio.read_audio(mix_wav)

    @sb.utils.data_pipeline.takes("s1_wav")
    @sb.utils.data_pipeline.provides("s1_sig")
    def audio_pipeline_s1(s1_wav):
        return sb.dataio.dataio.read_audio(s1_wav)

    @sb.utils.data_pipeline.takes("s2_wav")
    @sb.utils.data_pipeline.provides("s2_sig")
    def audio_pipeline_s2(s2_wav):
        return sb.dataio.dataio.read_audio(s2_wav)

    # Third source pipeline, only when there are at least 3 speakers.
    if hparams["num_spks"] >= 3:
        @sb.utils.data_pipeline.takes("s3_wav")
        @sb.utils.data_pipeline.provides("s3_sig")
        def audio_pipeline_s3(s3_wav):
            return sb.dataio.dataio.read_audio(s3_wav)

    # Fourth source pipeline, only for the 4-speaker setup.
    if hparams["num_spks"] == 4:
        @sb.utils.data_pipeline.takes("s4_wav")
        @sb.utils.data_pipeline.provides("s4_sig")
        def audio_pipeline_s4(s4_wav):
            return sb.dataio.dataio.read_audio(s4_wav)

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_mix)
    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s1)
    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s2)

    # 3. Register the extra sources and pick the output keys per setup.
    if hparams["num_spks"] == 3:
        sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
        output_keys = ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig"]
    elif hparams["num_spks"] == 4:
        sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
        sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s4)
        output_keys = ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig", "s4_sig"]
    else:
        output_keys = ["id", "mix_sig", "s1_sig", "s2_sig"]
    sb.dataio.dataset.set_output_keys(datasets, output_keys)

    return train_data, valid_data, test_data
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
if __name__ == "__main__":
    # Load hyperparameters file with command-line overrides
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file, encoding="utf-8") as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Initialize ddp (useful only for multi-GPU DDP training)
    sb.utils.distributed.ddp_init_group(run_opts)

    # Logger info
    logger = get_logger(__name__)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Update precision to bf16 if the device is CPU and precision is fp16
    # (fp16 autocast is not supported on CPU).
    if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
        hparams["precision"] = "bf16"

    # Check if wsj0_tr is set with dynamic mixing
    if hparams["dynamic_mixing"] and not os.path.exists(
        hparams["base_folder_dm"]
    ):
        raise ValueError(
            "Please, specify a valid base_folder_dm folder when using dynamic mixing"
        )

    # Data preparation
    from prepare_data import prepare_wsjmix  # noqa

    # NOTE(review): dataset preparation is disabled in this revision;
    # the CSV manifests are assumed to exist already.
    # run_on_main(
    #     prepare_wsjmix,
    #     kwargs={
    #         "datapath": hparams["data_folder"],
    #         "savepath": hparams["save_folder"],
    #         "n_spks": hparams["num_spks"],
    #         "skip_prep": hparams["skip_prep"],
    #         "fs": hparams["sample_rate"],
    #     },
    # )

    # Create dataset objects
    if hparams["dynamic_mixing"]:
        from dynamic_mixing import dynamic_mix_data_prep

        # if the base_folder for dm is not processed, preprocess them
        if "processed" not in hparams["base_folder_dm"]:
            # if the processed folder already exists we just use it otherwise we do the preprocessing
            if not os.path.exists(
                os.path.normpath(hparams["base_folder_dm"]) + "_processed"
            ):
                from preprocess_dynamic_mixing import resample_folder

                print("Resampling the base folder")
                run_on_main(
                    resample_folder,
                    kwargs={
                        "input_folder": hparams["base_folder_dm"],
                        "output_folder": os.path.normpath(
                            hparams["base_folder_dm"]
                        )
                        + "_processed",
                        "fs": hparams["sample_rate"],
                        "regex": "**/*.wav",
                    },
                )
                # adjust the base_folder_dm path
                hparams["base_folder_dm"] = (
                    os.path.normpath(hparams["base_folder_dm"]) + "_processed"
                )
            else:
                print(
                    "Using the existing processed folder on the same directory as base_folder_dm"
                )
                hparams["base_folder_dm"] = (
                    os.path.normpath(hparams["base_folder_dm"]) + "_processed"
                )

        # Collecting the hparams for dynamic batching
        dm_hparams = {
            "train_data": hparams["train_data"],
            "data_folder": hparams["data_folder"],
            "base_folder_dm": hparams["base_folder_dm"],
            "sample_rate": hparams["sample_rate"],
            "num_spks": hparams["num_spks"],
            "training_signal_len": hparams["training_signal_len"],
            "dataloader_opts": hparams["dataloader_opts"],
        }
        train_data = dynamic_mix_data_prep(dm_hparams)
        # Valid/test always come from static CSV manifests.
        _, valid_data, test_data = dataio_prep(hparams)
    else:
        train_data, valid_data, test_data = dataio_prep(hparams)

    # Load pretrained model if pretrained_separator is present in the yaml
    if "pretrained_separator" in hparams:
        run_on_main(hparams["pretrained_separator"].collect_files)
        hparams["pretrained_separator"].load_collected()

    # Brain class initialization
    separator = Separation(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # re-initialize the parameters if we don't use a pretrained model
    if "pretrained_separator" not in hparams:
        for module in separator.modules.values():
            separator.reset_layer_recursively(module)

    # NOTE(review): training is disabled in this revision — the script
    # only evaluates an already-trained checkpoint.
    # # Training
    # separator.fit(
    #     separator.hparams.epoch_counter,
    #     train_data,
    #     valid_data,
    #     train_loader_kwargs=hparams["dataloader_opts"],
    #     valid_loader_kwargs=hparams["dataloader_opts"],
    # )

    # Eval
    separator.evaluate(test_data, min_key="si-snr")
    separator.save_results(test_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/test_results.csv
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
snt_id,sdr,sdr_i,si-snr,si-snr_i
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/train.py
DELETED
|
@@ -1,628 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env/python3
|
| 2 |
-
"""Recipe for training a neural speech separation system on the wsjmix
|
| 3 |
-
dataset. The system employs an encoder, a decoder, and a masking network.
|
| 4 |
-
|
| 5 |
-
To run this recipe, do the following:
|
| 6 |
-
> python train.py hparams/sepformer.yaml
|
| 7 |
-
> python train.py hparams/dualpath_rnn.yaml
|
| 8 |
-
> python train.py hparams/convtasnet.yaml
|
| 9 |
-
|
| 10 |
-
The experiment file is flexible enough to support different neural
|
| 11 |
-
networks. By properly changing the parameter files, you can try
|
| 12 |
-
different architectures. The script supports both wsj2mix and
|
| 13 |
-
wsj3mix.
|
| 14 |
-
|
| 15 |
-
# 4-mix 主要根据 num_spks 修改 train.py 和 config
|
| 16 |
-
Authors
|
| 17 |
-
* Cem Subakan 2020
|
| 18 |
-
* Mirco Ravanelli 2020
|
| 19 |
-
* Samuele Cornell 2020
|
| 20 |
-
* Mirko Bronzi 2020
|
| 21 |
-
* Jianyuan Zhong 2020
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
import csv
|
| 25 |
-
import os
|
| 26 |
-
import sys
|
| 27 |
-
|
| 28 |
-
import numpy as np
|
| 29 |
-
import torch
|
| 30 |
-
import torch.nn.functional as F
|
| 31 |
-
import torchaudio
|
| 32 |
-
from hyperpyyaml import load_hyperpyyaml
|
| 33 |
-
from tqdm import tqdm
|
| 34 |
-
|
| 35 |
-
import speechbrain as sb
|
| 36 |
-
import speechbrain.nnet.schedulers as schedulers
|
| 37 |
-
from speechbrain.utils.distributed import run_on_main
|
| 38 |
-
from speechbrain.utils.logger import get_logger
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# Define training procedure
|
| 42 |
-
class Separation(sb.Brain):
    """Brain subclass implementing the speech-separation recipe.

    Wraps an Encoder / MaskNet / Decoder model (from ``self.hparams``)
    with custom forward, loss-thresholded AMP training, evaluation,
    data augmentation (speed perturbation, random shift, wavedrop),
    and SDR / SI-SNR test reporting for 2-, 3-, or 4-speaker mixtures.
    """

    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals.

        ``mix`` is a (signal, lengths) pair; ``targets`` is a list of
        (signal, lengths) pairs, one per speaker. Returns the estimated
        sources and the (possibly augmented) target tensor, both shaped
        [batch, time, num_spks].
        """

        # Unpack lists and put tensors in the right device
        mix, mix_lens = mix
        mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)

        # Convert targets to tensor
        targets = torch.cat(
            [targets[i][0].unsqueeze(-1) for i in range(self.hparams.num_spks)],
            dim=-1,
        ).to(self.device)

        # Add speech distortions
        if stage == sb.Stage.TRAIN:
            with torch.no_grad():
                if self.hparams.use_speedperturb:
                    mix, targets = self.add_speed_perturb(targets, mix_lens)

                    # Re-synthesize the mixture from the perturbed sources
                    # so mixture and targets stay consistent.
                    mix = targets.sum(-1)

                if self.hparams.use_wavedrop:
                    mix = self.hparams.drop_chunk(mix, mix_lens)
                    mix = self.hparams.drop_freq(mix)

                if self.hparams.limit_training_signal_len:
                    mix, targets = self.cut_signals(mix, targets)

        # Separation
        mix_w = self.hparams.Encoder(mix)
        est_mask = self.hparams.MaskNet(mix_w)
        # Replicate encoded mixture once per speaker before masking.
        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [
                self.hparams.Decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source, targets

    def compute_objectives(self, predictions, targets):
        """Computes the si-snr loss (negative SI-SNR, per batch item)."""
        return self.hparams.loss(targets, predictions)

    def fit_batch(self, batch):
        """Trains one batch and returns the detached CPU loss."""

        # Unpacking batch list
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]

        if self.hparams.num_spks == 3:
            targets.append(batch.s3_sig)

        if self.hparams.num_spks == 4:
            targets.append(batch.s3_sig)
            targets.append(batch.s4_sig)

        with self.training_ctx:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN
            )
            loss = self.compute_objectives(predictions, targets)

            # hard threshold the easy dataitems: only items whose loss is
            # above the threshold contribute to the gradient.
            if self.hparams.threshold_byloss:
                th = self.hparams.threshold
                loss = loss[loss > th]
                if loss.nelement() > 0:
                    loss = loss.mean()
            else:
                loss = loss.mean()

        # Skip the update when the (possibly filtered) loss is empty or
        # above the configured upper limit (non-finite / divergent batch).
        if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
            self.scaler.scale(loss).backward()
            if self.hparams.clip_grad_norm >= 0:
                # Unscale before clipping so the norm is computed on
                # true gradient values.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.modules.parameters(),
                    self.hparams.clip_grad_norm,
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.nonfinite_count += 1
            logger.info(
                "infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
                    self.nonfinite_count
                )
            )
            loss.data = torch.tensor(0.0).to(self.device)
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches."""
        snt_id = batch.id
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]
        if self.hparams.num_spks == 3:
            targets.append(batch.s3_sig)

        if self.hparams.num_spks == 4:
            targets.append(batch.s3_sig)
            targets.append(batch.s4_sig)

        with torch.no_grad():
            predictions, targets = self.compute_forward(mixture, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        # Manage audio file saving: cap at n_audio_to_save items when the
        # hparam exists, otherwise save every test item.
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if hasattr(self.hparams, "n_audio_to_save"):
                if self.hparams.n_audio_to_save > 0:
                    self.save_audio(snt_id[0], mixture, targets, predictions)
                    self.hparams.n_audio_to_save += -1
            else:
                self.save_audio(snt_id[0], mixture, targets, predictions)

        return loss.mean().detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of a epoch."""
        # Compute/store important stats
        stage_stats = {"si-snr": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:
            # Learning rate annealing
            if isinstance(
                self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau
            ):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss
                )
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # if we do not use the reducelronplateau, we do not change the lr
                current_lr = self.hparams.optimizer.optim.param_groups[0]["lr"]

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": current_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Keep only the checkpoint with the lowest validation loss.
            self.checkpointer.save_and_keep_only(
                meta={"si-snr": stage_stats["si-snr"]}, min_keys=["si-snr"]
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )

    def add_speed_perturb(self, targets, targ_lens):
        """Adds speed perturbation and random_shift to the input signals.

        Each source is perturbed independently; sources are then
        truncated to the shortest perturbed length and summed to form
        the new mixture.
        """

        min_len = -1
        recombine = False

        if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
            # Performing speed change (independently on each source)
            new_targets = []
            recombine = True

            for i in range(targets.shape[-1]):
                new_target = self.hparams.speed_perturb(targets[:, :, i])
                new_targets.append(new_target)
                # Track the shortest perturbed source; all sources are
                # cropped to this length during re-combination.
                if i == 0:
                    min_len = new_target.shape[-1]
                else:
                    if new_target.shape[-1] < min_len:
                        min_len = new_target.shape[-1]

        if self.hparams.use_rand_shift:
            # Performing random_shift (independently on each source)
            recombine = True
            for i in range(targets.shape[-1]):
                rand_shift = torch.randint(
                    self.hparams.min_shift, self.hparams.max_shift, (1,)
                )
                new_targets[i] = new_targets[i].to(self.device)
                new_targets[i] = torch.roll(
                    new_targets[i], shifts=(rand_shift[0],), dims=1
                )

        # Re-combination
        if recombine:
            if self.hparams.use_speedperturb:
                targets = torch.zeros(
                    targets.shape[0],
                    min_len,
                    targets.shape[-1],
                    device=targets.device,
                    dtype=torch.float,
                )
            for i, new_target in enumerate(new_targets):
                targets[:, :, i] = new_targets[i][:, 0:min_len]

        mix = targets.sum(-1)
        return mix, targets

    def cut_signals(self, mixture, targets):
        """This function selects a random segment of a given length within the mixture.
        The corresponding targets are selected accordingly."""
        # Uniform start offset; 0 when the mixture is already shorter
        # than training_signal_len.
        randstart = torch.randint(
            0,
            1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
            (1,),
        ).item()
        targets = targets[
            :, randstart : randstart + self.hparams.training_signal_len, :
        ]
        mixture = mixture[
            :, randstart : randstart + self.hparams.training_signal_len
        ]
        return mixture, targets

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of the neural networks."""
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        for child_layer in layer.modules():
            # modules() yields the layer itself first; skip it.
            if layer != child_layer:
                self.reset_layer_recursively(child_layer)

    def save_results(self, test_data):
        """This script computes the SDR and SI-SNR metrics and saves
        them into a csv file (one row per utterance plus an "avg" row)."""

        # This package is required for SDR computation
        from mir_eval.separation import bss_eval_sources

        # Create folders where to store audio
        save_file = os.path.join(self.hparams.output_folder, "test_results.csv")

        # Variable init
        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

        test_loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts
        )

        with open(save_file, "w", newline="", encoding="utf-8") as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            # Loop over all test sentence
            with tqdm(test_loader, dynamic_ncols=True) as t:
                for i, batch in enumerate(t):
                    # Apply Separation
                    mixture, mix_len = batch.mix_sig
                    snt_id = batch.id
                    targets = [batch.s1_sig, batch.s2_sig]
                    if self.hparams.num_spks == 3:
                        targets.append(batch.s3_sig)

                    if self.hparams.num_spks == 4:
                        targets.append(batch.s3_sig)
                        targets.append(batch.s4_sig)

                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST
                        )

                    # Compute SI-SNR (loss = negative SI-SNR; signs are
                    # flipped before reporting)
                    sisnr = self.compute_objectives(predictions, targets)

                    # Compute SI-SNR improvement
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_spks, dim=-1
                    )
                    mixture_signal = mixture_signal.to(targets.device)
                    sisnr_baseline = self.compute_objectives(
                        mixture_signal, targets
                    )
                    sisnr_i = sisnr - sisnr_baseline

                    # Compute SDR (first batch item only; mir_eval expects
                    # [num_spks, time] numpy arrays)
                    sdr, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        predictions[0].t().detach().cpu().numpy(),
                    )

                    sdr_baseline, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        mixture_signal[0].t().detach().cpu().numpy(),
                    )

                    sdr_i = sdr.mean() - sdr_baseline.mean()

                    # Saving on a csv file
                    row = {
                        "snt_id": snt_id[0],
                        "sdr": sdr.mean(),
                        "sdr_i": sdr_i,
                        "si-snr": -sisnr.item(),
                        "si-snr_i": -sisnr_i.item(),
                    }
                    writer.writerow(row)

                    # Metric Accumulation
                    all_sdrs.append(sdr.mean())
                    all_sdrs_i.append(sdr_i.mean())
                    all_sisnrs.append(-sisnr.item())
                    all_sisnrs_i.append(-sisnr_i.item())

            # Dataset-level averages as a final summary row.
            row = {
                "snt_id": "avg",
                "sdr": np.array(all_sdrs).mean(),
                "sdr_i": np.array(all_sdrs_i).mean(),
                "si-snr": np.array(all_sisnrs).mean(),
                "si-snr_i": np.array(all_sisnrs_i).mean(),
            }
            writer.writerow(row)

        logger.info("Mean SISNR is {}".format(np.array(all_sisnrs).mean()))
        logger.info("Mean SISNRi is {}".format(np.array(all_sisnrs_i).mean()))
        logger.info("Mean SDR is {}".format(np.array(all_sdrs).mean()))
        logger.info("Mean SDRi is {}".format(np.array(all_sdrs_i).mean()))

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, targets, and estimated sources) on disk."""

        # Create output folder
        save_path = os.path.join(self.hparams.save_folder, "audio_results")
        if not os.path.exists(save_path):
            os.mkdir(save_path)

        for ns in range(self.hparams.num_spks):
            # Estimated source (peak-normalized to avoid clipping)
            signal = predictions[0, :, ns]
            signal = signal / signal.abs().max()
            save_file = os.path.join(
                save_path, "item{}_source{}hat.wav".format(snt_id, ns + 1)
            )
            torchaudio.save(
                save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
            )

            # Original source
            signal = targets[0, :, ns]
            signal = signal / signal.abs().max()
            save_file = os.path.join(
                save_path, "item{}_source{}.wav".format(snt_id, ns + 1)
            )
            torchaudio.save(
                save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
            )

        # Mixture
        signal = mixture[0][0, :]
        signal = signal / signal.abs().max()
        save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
def dataio_prep(hparams):
|
| 424 |
-
"""Creates data processing pipeline"""
|
| 425 |
-
|
| 426 |
-
# 1. Define datasets
|
| 427 |
-
train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
|
| 428 |
-
csv_path=hparams["train_data"],
|
| 429 |
-
replacements={"data_root": hparams["data_folder"]},
|
| 430 |
-
)
|
| 431 |
-
|
| 432 |
-
valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
|
| 433 |
-
csv_path=hparams["valid_data"],
|
| 434 |
-
replacements={"data_root": hparams["data_folder"]},
|
| 435 |
-
)
|
| 436 |
-
|
| 437 |
-
test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
|
| 438 |
-
csv_path=hparams["test_data"],
|
| 439 |
-
replacements={"data_root": hparams["data_folder"]},
|
| 440 |
-
)
|
| 441 |
-
|
| 442 |
-
datasets = [train_data, valid_data, test_data]
|
| 443 |
-
|
| 444 |
-
# 2. Provide audio pipelines
|
| 445 |
-
|
| 446 |
-
@sb.utils.data_pipeline.takes("mix_wav")
|
| 447 |
-
@sb.utils.data_pipeline.provides("mix_sig")
|
| 448 |
-
def audio_pipeline_mix(mix_wav):
|
| 449 |
-
mix_sig = sb.dataio.dataio.read_audio(mix_wav)
|
| 450 |
-
return mix_sig
|
| 451 |
-
|
| 452 |
-
@sb.utils.data_pipeline.takes("s1_wav")
|
| 453 |
-
@sb.utils.data_pipeline.provides("s1_sig")
|
| 454 |
-
def audio_pipeline_s1(s1_wav):
|
| 455 |
-
s1_sig = sb.dataio.dataio.read_audio(s1_wav)
|
| 456 |
-
return s1_sig
|
| 457 |
-
|
| 458 |
-
@sb.utils.data_pipeline.takes("s2_wav")
|
| 459 |
-
@sb.utils.data_pipeline.provides("s2_sig")
|
| 460 |
-
def audio_pipeline_s2(s2_wav):
|
| 461 |
-
s2_sig = sb.dataio.dataio.read_audio(s2_wav)
|
| 462 |
-
return s2_sig
|
| 463 |
-
|
| 464 |
-
# --- 如果说话人 >= 3,定义第 3 路 ---
|
| 465 |
-
if hparams["num_spks"] >= 3:
|
| 466 |
-
@sb.utils.data_pipeline.takes("s3_wav")
|
| 467 |
-
@sb.utils.data_pipeline.provides("s3_sig")
|
| 468 |
-
def audio_pipeline_s3(s3_wav):
|
| 469 |
-
return sb.dataio.dataio.read_audio(s3_wav)
|
| 470 |
-
|
| 471 |
-
# --- 如果说话人 == 4,定义第 4 路 ---
|
| 472 |
-
if hparams["num_spks"] == 4:
|
| 473 |
-
@sb.utils.data_pipeline.takes("s4_wav")
|
| 474 |
-
@sb.utils.data_pipeline.provides("s4_sig")
|
| 475 |
-
def audio_pipeline_s4(s4_wav):
|
| 476 |
-
return sb.dataio.dataio.read_audio(s4_wav)
|
| 477 |
-
|
| 478 |
-
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_mix)
|
| 479 |
-
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s1)
|
| 480 |
-
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s2)
|
| 481 |
-
if hparams["num_spks"] == 3:
|
| 482 |
-
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
|
| 483 |
-
sb.dataio.dataset.set_output_keys(
|
| 484 |
-
datasets, ["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig"]
|
| 485 |
-
)
|
| 486 |
-
elif hparams["num_spks"] == 4 :
|
| 487 |
-
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s3)
|
| 488 |
-
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline_s4)
|
| 489 |
-
sb.dataio.dataset.set_output_keys(
|
| 490 |
-
datasets,
|
| 491 |
-
["id", "mix_sig", "s1_sig", "s2_sig", "s3_sig", "s4_sig"],
|
| 492 |
-
)
|
| 493 |
-
else:
|
| 494 |
-
sb.dataio.dataset.set_output_keys(
|
| 495 |
-
datasets, ["id", "mix_sig", "s1_sig", "s2_sig"]
|
| 496 |
-
)
|
| 497 |
-
|
| 498 |
-
return train_data, valid_data, test_data
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
if __name__ == "__main__":
|
| 502 |
-
# Load hyperparameters file with command-line overrides
|
| 503 |
-
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
|
| 504 |
-
with open(hparams_file, encoding="utf-8") as fin:
|
| 505 |
-
hparams = load_hyperpyyaml(fin, overrides)
|
| 506 |
-
|
| 507 |
-
# Initialize ddp (useful only for multi-GPU DDP training)
|
| 508 |
-
sb.utils.distributed.ddp_init_group(run_opts)
|
| 509 |
-
|
| 510 |
-
# Logger info
|
| 511 |
-
logger = get_logger(__name__)
|
| 512 |
-
|
| 513 |
-
# Create experiment directory
|
| 514 |
-
sb.create_experiment_directory(
|
| 515 |
-
experiment_directory=hparams["output_folder"],
|
| 516 |
-
hyperparams_to_save=hparams_file,
|
| 517 |
-
overrides=overrides,
|
| 518 |
-
)
|
| 519 |
-
|
| 520 |
-
# Update precision to bf16 if the device is CPU and precision is fp16
|
| 521 |
-
if run_opts.get("device") == "cpu" and hparams.get("precision") == "fp16":
|
| 522 |
-
hparams["precision"] = "bf16"
|
| 523 |
-
|
| 524 |
-
# Check if wsj0_tr is set with dynamic mixing
|
| 525 |
-
if hparams["dynamic_mixing"] and not os.path.exists(
|
| 526 |
-
hparams["base_folder_dm"]
|
| 527 |
-
):
|
| 528 |
-
raise ValueError(
|
| 529 |
-
"Please, specify a valid base_folder_dm folder when using dynamic mixing"
|
| 530 |
-
)
|
| 531 |
-
|
| 532 |
-
# Data preparation
|
| 533 |
-
from prepare_data import prepare_wsjmix # noqa
|
| 534 |
-
|
| 535 |
-
# run_on_main(
|
| 536 |
-
# prepare_wsjmix,
|
| 537 |
-
# kwargs={
|
| 538 |
-
# "datapath": hparams["data_folder"],
|
| 539 |
-
# "savepath": hparams["save_folder"],
|
| 540 |
-
# "n_spks": hparams["num_spks"],
|
| 541 |
-
# "skip_prep": hparams["skip_prep"],
|
| 542 |
-
# "fs": hparams["sample_rate"],
|
| 543 |
-
# },
|
| 544 |
-
# )
|
| 545 |
-
|
| 546 |
-
# Create dataset objects
|
| 547 |
-
if hparams["dynamic_mixing"]:
|
| 548 |
-
from dynamic_mixing import dynamic_mix_data_prep
|
| 549 |
-
|
| 550 |
-
# if the base_folder for dm is not processed, preprocess them
|
| 551 |
-
if "processed" not in hparams["base_folder_dm"]:
|
| 552 |
-
# if the processed folder already exists we just use it otherwise we do the preprocessing
|
| 553 |
-
if not os.path.exists(
|
| 554 |
-
os.path.normpath(hparams["base_folder_dm"]) + "_processed"
|
| 555 |
-
):
|
| 556 |
-
from preprocess_dynamic_mixing import resample_folder
|
| 557 |
-
|
| 558 |
-
print("Resampling the base folder")
|
| 559 |
-
run_on_main(
|
| 560 |
-
resample_folder,
|
| 561 |
-
kwargs={
|
| 562 |
-
"input_folder": hparams["base_folder_dm"],
|
| 563 |
-
"output_folder": os.path.normpath(
|
| 564 |
-
hparams["base_folder_dm"]
|
| 565 |
-
)
|
| 566 |
-
+ "_processed",
|
| 567 |
-
"fs": hparams["sample_rate"],
|
| 568 |
-
"regex": "**/*.wav",
|
| 569 |
-
},
|
| 570 |
-
)
|
| 571 |
-
# adjust the base_folder_dm path
|
| 572 |
-
hparams["base_folder_dm"] = (
|
| 573 |
-
os.path.normpath(hparams["base_folder_dm"]) + "_processed"
|
| 574 |
-
)
|
| 575 |
-
else:
|
| 576 |
-
print(
|
| 577 |
-
"Using the existing processed folder on the same directory as base_folder_dm"
|
| 578 |
-
)
|
| 579 |
-
hparams["base_folder_dm"] = (
|
| 580 |
-
os.path.normpath(hparams["base_folder_dm"]) + "_processed"
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
-
# Collecting the hparams for dynamic batching
|
| 584 |
-
dm_hparams = {
|
| 585 |
-
"train_data": hparams["train_data"],
|
| 586 |
-
"data_folder": hparams["data_folder"],
|
| 587 |
-
"base_folder_dm": hparams["base_folder_dm"],
|
| 588 |
-
"sample_rate": hparams["sample_rate"],
|
| 589 |
-
"num_spks": hparams["num_spks"],
|
| 590 |
-
"training_signal_len": hparams["training_signal_len"],
|
| 591 |
-
"dataloader_opts": hparams["dataloader_opts"],
|
| 592 |
-
}
|
| 593 |
-
train_data = dynamic_mix_data_prep(dm_hparams)
|
| 594 |
-
_, valid_data, test_data = dataio_prep(hparams)
|
| 595 |
-
else:
|
| 596 |
-
train_data, valid_data, test_data = dataio_prep(hparams)
|
| 597 |
-
|
| 598 |
-
# Load pretrained model if pretrained_separator is present in the yaml
|
| 599 |
-
if "pretrained_separator" in hparams:
|
| 600 |
-
run_on_main(hparams["pretrained_separator"].collect_files)
|
| 601 |
-
hparams["pretrained_separator"].load_collected()
|
| 602 |
-
|
| 603 |
-
# Brain class initialization
|
| 604 |
-
separator = Separation(
|
| 605 |
-
modules=hparams["modules"],
|
| 606 |
-
opt_class=hparams["optimizer"],
|
| 607 |
-
hparams=hparams,
|
| 608 |
-
run_opts=run_opts,
|
| 609 |
-
checkpointer=hparams["checkpointer"],
|
| 610 |
-
)
|
| 611 |
-
|
| 612 |
-
# re-initialize the parameters if we don't use a pretrained model
|
| 613 |
-
if "pretrained_separator" not in hparams:
|
| 614 |
-
for module in separator.modules.values():
|
| 615 |
-
separator.reset_layer_recursively(module)
|
| 616 |
-
|
| 617 |
-
# Training
|
| 618 |
-
separator.fit(
|
| 619 |
-
separator.hparams.epoch_counter,
|
| 620 |
-
train_data,
|
| 621 |
-
valid_data,
|
| 622 |
-
train_loader_kwargs=hparams["dataloader_opts"],
|
| 623 |
-
valid_loader_kwargs=hparams["dataloader_opts"],
|
| 624 |
-
)
|
| 625 |
-
|
| 626 |
-
# Eval
|
| 627 |
-
separator.evaluate(test_data, min_key="si-snr")
|
| 628 |
-
separator.save_results(test_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Conv-Tasnet/results/convtasnet_4-mix/1234/train_log.txt
DELETED
|
@@ -1,242 +0,0 @@
|
|
| 1 |
-
epoch: 1, lr: 1.50e-04 - train si-snr: 2.76 - valid si-snr: 12.09
|
| 2 |
-
epoch: 2, lr: 1.50e-04 - train si-snr: 2.15 - valid si-snr: 11.84
|
| 3 |
-
epoch: 3, lr: 1.50e-04 - train si-snr: 1.96 - valid si-snr: 11.70
|
| 4 |
-
epoch: 4, lr: 1.50e-04 - train si-snr: 1.70 - valid si-snr: 11.63
|
| 5 |
-
epoch: 5, lr: 1.50e-04 - train si-snr: 1.58 - valid si-snr: 11.57
|
| 6 |
-
epoch: 6, lr: 1.50e-04 - train si-snr: 1.45 - valid si-snr: 11.55
|
| 7 |
-
epoch: 7, lr: 1.50e-04 - train si-snr: 1.33 - valid si-snr: 11.45
|
| 8 |
-
epoch: 8, lr: 1.50e-04 - train si-snr: 1.20 - valid si-snr: 11.33
|
| 9 |
-
epoch: 9, lr: 1.50e-04 - train si-snr: 1.10 - valid si-snr: 11.35
|
| 10 |
-
epoch: 10, lr: 1.50e-04 - train si-snr: 1.01 - valid si-snr: 11.30
|
| 11 |
-
epoch: 11, lr: 1.50e-04 - train si-snr: 9.25e-01 - valid si-snr: 11.33
|
| 12 |
-
epoch: 12, lr: 1.50e-04 - train si-snr: 7.83e-01 - valid si-snr: 11.16
|
| 13 |
-
epoch: 13, lr: 1.50e-04 - train si-snr: 7.61e-01 - valid si-snr: 11.19
|
| 14 |
-
epoch: 14, lr: 1.50e-04 - train si-snr: 6.87e-01 - valid si-snr: 11.13
|
| 15 |
-
epoch: 15, lr: 1.50e-04 - train si-snr: 6.31e-01 - valid si-snr: 11.13
|
| 16 |
-
epoch: 16, lr: 1.50e-04 - train si-snr: 5.54e-01 - valid si-snr: 11.10
|
| 17 |
-
epoch: 17, lr: 1.50e-04 - train si-snr: 4.47e-01 - valid si-snr: 11.02
|
| 18 |
-
epoch: 18, lr: 1.50e-04 - train si-snr: 4.65e-01 - valid si-snr: 11.04
|
| 19 |
-
epoch: 19, lr: 1.50e-04 - train si-snr: 3.32e-01 - valid si-snr: 11.01
|
| 20 |
-
epoch: 20, lr: 1.50e-04 - train si-snr: 3.27e-01 - valid si-snr: 10.95
|
| 21 |
-
epoch: 21, lr: 1.50e-04 - train si-snr: 2.78e-01 - valid si-snr: 10.97
|
| 22 |
-
epoch: 22, lr: 1.50e-04 - train si-snr: 2.18e-01 - valid si-snr: 10.88
|
| 23 |
-
epoch: 23, lr: 1.50e-04 - train si-snr: 1.74e-01 - valid si-snr: 10.87
|
| 24 |
-
epoch: 24, lr: 1.50e-04 - train si-snr: 1.03e-01 - valid si-snr: 10.95
|
| 25 |
-
epoch: 25, lr: 1.50e-04 - train si-snr: 6.04e-02 - valid si-snr: 10.84
|
| 26 |
-
epoch: 26, lr: 1.50e-04 - train si-snr: -2.94e-02 - valid si-snr: 10.79
|
| 27 |
-
epoch: 27, lr: 1.50e-04 - train si-snr: -5.32e-02 - valid si-snr: 10.77
|
| 28 |
-
epoch: 28, lr: 1.50e-04 - train si-snr: -5.68e-02 - valid si-snr: 10.74
|
| 29 |
-
epoch: 29, lr: 1.50e-04 - train si-snr: -1.04e-01 - valid si-snr: 10.79
|
| 30 |
-
epoch: 30, lr: 1.50e-04 - train si-snr: -1.57e-01 - valid si-snr: 10.73
|
| 31 |
-
epoch: 31, lr: 1.50e-04 - train si-snr: -1.64e-01 - valid si-snr: 10.67
|
| 32 |
-
epoch: 32, lr: 1.50e-04 - train si-snr: -2.11e-01 - valid si-snr: 10.71
|
| 33 |
-
epoch: 33, lr: 1.50e-04 - train si-snr: -2.48e-01 - valid si-snr: 10.73
|
| 34 |
-
epoch: 34, lr: 1.50e-04 - train si-snr: -2.79e-01 - valid si-snr: 10.69
|
| 35 |
-
epoch: 35, lr: 1.50e-04 - train si-snr: -3.55e-01 - valid si-snr: 10.69
|
| 36 |
-
epoch: 36, lr: 1.50e-04 - train si-snr: -3.32e-01 - valid si-snr: 10.64
|
| 37 |
-
epoch: 37, lr: 1.50e-04 - train si-snr: -3.97e-01 - valid si-snr: 10.63
|
| 38 |
-
epoch: 38, lr: 1.50e-04 - train si-snr: -4.11e-01 - valid si-snr: 10.71
|
| 39 |
-
epoch: 39, lr: 1.50e-04 - train si-snr: -4.18e-01 - valid si-snr: 10.56
|
| 40 |
-
epoch: 40, lr: 1.50e-04 - train si-snr: -4.74e-01 - valid si-snr: 10.55
|
| 41 |
-
epoch: 41, lr: 1.50e-04 - train si-snr: -4.71e-01 - valid si-snr: 10.52
|
| 42 |
-
epoch: 1, lr: 1.50e-04 - train si-snr: 6.31 - valid si-snr: 23.11
|
| 43 |
-
epoch: 2, lr: 1.50e-04 - train si-snr: 4.85 - valid si-snr: 23.05
|
| 44 |
-
epoch: 3, lr: 1.50e-04 - train si-snr: 4.79 - valid si-snr: 22.98
|
| 45 |
-
epoch: 4, lr: 1.50e-04 - train si-snr: 4.56 - valid si-snr: 22.79
|
| 46 |
-
epoch: 5, lr: 1.50e-04 - train si-snr: 4.28 - valid si-snr: 23.05
|
| 47 |
-
epoch: 6, lr: 1.50e-04 - train si-snr: 4.27 - valid si-snr: 22.88
|
| 48 |
-
epoch: 7, lr: 1.50e-04 - train si-snr: 4.11 - valid si-snr: 22.86
|
| 49 |
-
epoch: 8, lr: 1.50e-04 - train si-snr: 4.11 - valid si-snr: 22.80
|
| 50 |
-
epoch: 9, lr: 1.50e-04 - train si-snr: 3.96 - valid si-snr: 22.80
|
| 51 |
-
epoch: 10, lr: 1.50e-04 - train si-snr: 3.91 - valid si-snr: 22.75
|
| 52 |
-
epoch: 11, lr: 1.50e-04 - train si-snr: 3.76 - valid si-snr: 22.72
|
| 53 |
-
epoch: 12, lr: 1.50e-04 - train si-snr: 3.82 - valid si-snr: 22.69
|
| 54 |
-
epoch: 13, lr: 1.50e-04 - train si-snr: 3.71 - valid si-snr: 22.86
|
| 55 |
-
epoch: 14, lr: 1.50e-04 - train si-snr: 3.64 - valid si-snr: 22.71
|
| 56 |
-
epoch: 15, lr: 1.50e-04 - train si-snr: 3.59 - valid si-snr: 22.89
|
| 57 |
-
epoch: 16, lr: 1.50e-04 - train si-snr: 3.39 - valid si-snr: 22.79
|
| 58 |
-
epoch: 17, lr: 1.50e-04 - train si-snr: 3.30 - valid si-snr: 22.69
|
| 59 |
-
epoch: 18, lr: 1.50e-04 - train si-snr: 3.29 - valid si-snr: 22.82
|
| 60 |
-
epoch: 19, lr: 1.50e-04 - train si-snr: 3.32 - valid si-snr: 22.75
|
| 61 |
-
epoch: 20, lr: 1.50e-04 - train si-snr: 3.14 - valid si-snr: 22.49
|
| 62 |
-
epoch: 21, lr: 1.50e-04 - train si-snr: 3.11 - valid si-snr: 22.83
|
| 63 |
-
epoch: 22, lr: 1.50e-04 - train si-snr: 3.12 - valid si-snr: 22.69
|
| 64 |
-
epoch: 23, lr: 1.50e-04 - train si-snr: 2.93 - valid si-snr: 22.66
|
| 65 |
-
epoch: 24, lr: 1.50e-04 - train si-snr: 2.96 - valid si-snr: 22.72
|
| 66 |
-
epoch: 25, lr: 1.50e-04 - train si-snr: 2.96 - valid si-snr: 22.83
|
| 67 |
-
epoch: 26, lr: 1.50e-04 - train si-snr: 2.88 - valid si-snr: 22.61
|
| 68 |
-
epoch: 27, lr: 1.50e-04 - train si-snr: 2.86 - valid si-snr: 22.83
|
| 69 |
-
epoch: 28, lr: 1.50e-04 - train si-snr: 2.80 - valid si-snr: 22.67
|
| 70 |
-
epoch: 29, lr: 1.50e-04 - train si-snr: 2.73 - valid si-snr: 22.67
|
| 71 |
-
epoch: 30, lr: 1.50e-04 - train si-snr: 2.65 - valid si-snr: 22.62
|
| 72 |
-
epoch: 31, lr: 1.50e-04 - train si-snr: 2.62 - valid si-snr: 22.63
|
| 73 |
-
epoch: 32, lr: 1.50e-04 - train si-snr: 2.61 - valid si-snr: 22.61
|
| 74 |
-
epoch: 33, lr: 1.50e-04 - train si-snr: 2.44 - valid si-snr: 22.55
|
| 75 |
-
epoch: 34, lr: 1.50e-04 - train si-snr: 2.50 - valid si-snr: 22.55
|
| 76 |
-
epoch: 35, lr: 1.50e-04 - train si-snr: 2.47 - valid si-snr: 22.60
|
| 77 |
-
epoch: 36, lr: 1.50e-04 - train si-snr: 2.44 - valid si-snr: 22.66
|
| 78 |
-
epoch: 37, lr: 1.50e-04 - train si-snr: 2.24 - valid si-snr: 22.64
|
| 79 |
-
epoch: 38, lr: 1.50e-04 - train si-snr: 2.28 - valid si-snr: 22.66
|
| 80 |
-
epoch: 39, lr: 1.50e-04 - train si-snr: 2.15 - valid si-snr: 22.62
|
| 81 |
-
epoch: 40, lr: 1.50e-04 - train si-snr: 2.19 - valid si-snr: 22.48
|
| 82 |
-
epoch: 41, lr: 1.50e-04 - train si-snr: 2.26 - valid si-snr: 22.66
|
| 83 |
-
epoch: 42, lr: 1.50e-04 - train si-snr: 2.09 - valid si-snr: 22.57
|
| 84 |
-
epoch: 43, lr: 1.50e-04 - train si-snr: 2.15 - valid si-snr: 22.47
|
| 85 |
-
epoch: 44, lr: 1.50e-04 - train si-snr: 2.00 - valid si-snr: 22.63
|
| 86 |
-
epoch: 45, lr: 1.50e-04 - train si-snr: 2.13 - valid si-snr: 22.52
|
| 87 |
-
epoch: 46, lr: 1.50e-04 - train si-snr: 2.00 - valid si-snr: 22.57
|
| 88 |
-
epoch: 47, lr: 1.50e-04 - train si-snr: 1.90 - valid si-snr: 22.50
|
| 89 |
-
epoch: 48, lr: 1.50e-04 - train si-snr: 1.89 - valid si-snr: 22.49
|
| 90 |
-
epoch: 49, lr: 1.50e-04 - train si-snr: 1.94 - valid si-snr: 22.54
|
| 91 |
-
epoch: 50, lr: 1.50e-04 - train si-snr: 1.89 - valid si-snr: 22.50
|
| 92 |
-
epoch: 51, lr: 1.50e-04 - train si-snr: 1.85 - valid si-snr: 22.55
|
| 93 |
-
epoch: 52, lr: 1.50e-04 - train si-snr: 1.66 - valid si-snr: 22.51
|
| 94 |
-
epoch: 53, lr: 1.50e-04 - train si-snr: 1.65 - valid si-snr: 22.52
|
| 95 |
-
epoch: 54, lr: 1.50e-04 - train si-snr: 1.77 - valid si-snr: 22.45
|
| 96 |
-
epoch: 55, lr: 1.50e-04 - train si-snr: 1.62 - valid si-snr: 22.45
|
| 97 |
-
epoch: 56, lr: 1.50e-04 - train si-snr: 1.52 - valid si-snr: 22.42
|
| 98 |
-
epoch: 57, lr: 1.50e-04 - train si-snr: 1.53 - valid si-snr: 22.39
|
| 99 |
-
epoch: 58, lr: 1.50e-04 - train si-snr: 1.52 - valid si-snr: 22.40
|
| 100 |
-
epoch: 59, lr: 1.50e-04 - train si-snr: 1.55 - valid si-snr: 22.43
|
| 101 |
-
epoch: 60, lr: 1.50e-04 - train si-snr: 1.64 - valid si-snr: 22.43
|
| 102 |
-
epoch: 61, lr: 1.50e-04 - train si-snr: 1.42 - valid si-snr: 22.38
|
| 103 |
-
epoch: 62, lr: 1.50e-04 - train si-snr: 1.50 - valid si-snr: 22.29
|
| 104 |
-
epoch: 63, lr: 1.50e-04 - train si-snr: 1.34 - valid si-snr: 22.51
|
| 105 |
-
epoch: 64, lr: 1.50e-04 - train si-snr: 1.25 - valid si-snr: 22.55
|
| 106 |
-
epoch: 65, lr: 1.50e-04 - train si-snr: 1.43 - valid si-snr: 22.35
|
| 107 |
-
epoch: 66, lr: 1.50e-04 - train si-snr: 1.33 - valid si-snr: 22.54
|
| 108 |
-
epoch: 67, lr: 1.50e-04 - train si-snr: 1.35 - valid si-snr: 22.44
|
| 109 |
-
epoch: 68, lr: 1.50e-04 - train si-snr: 1.35 - valid si-snr: 22.33
|
| 110 |
-
epoch: 69, lr: 1.50e-04 - train si-snr: 1.13 - valid si-snr: 22.38
|
| 111 |
-
epoch: 70, lr: 1.50e-04 - train si-snr: 1.18 - valid si-snr: 22.37
|
| 112 |
-
epoch: 71, lr: 1.50e-04 - train si-snr: 1.04 - valid si-snr: 22.35
|
| 113 |
-
epoch: 72, lr: 1.50e-04 - train si-snr: 1.24 - valid si-snr: 22.49
|
| 114 |
-
epoch: 73, lr: 1.50e-04 - train si-snr: 1.25 - valid si-snr: 22.35
|
| 115 |
-
epoch: 74, lr: 1.50e-04 - train si-snr: 1.07 - valid si-snr: 22.37
|
| 116 |
-
epoch: 75, lr: 1.50e-04 - train si-snr: 1.04 - valid si-snr: 22.37
|
| 117 |
-
epoch: 76, lr: 1.50e-04 - train si-snr: 1.11 - valid si-snr: 22.48
|
| 118 |
-
epoch: 77, lr: 1.50e-04 - train si-snr: 1.03 - valid si-snr: 22.46
|
| 119 |
-
epoch: 78, lr: 1.50e-04 - train si-snr: 9.65e-01 - valid si-snr: 22.31
|
| 120 |
-
epoch: 79, lr: 1.50e-04 - train si-snr: 1.06 - valid si-snr: 22.34
|
| 121 |
-
epoch: 80, lr: 1.50e-04 - train si-snr: 1.03 - valid si-snr: 22.32
|
| 122 |
-
epoch: 81, lr: 1.50e-04 - train si-snr: 8.12e-01 - valid si-snr: 22.32
|
| 123 |
-
epoch: 82, lr: 1.50e-04 - train si-snr: 8.76e-01 - valid si-snr: 22.33
|
| 124 |
-
epoch: 83, lr: 1.50e-04 - train si-snr: 8.91e-01 - valid si-snr: 22.32
|
| 125 |
-
epoch: 84, lr: 1.50e-04 - train si-snr: 9.11e-01 - valid si-snr: 22.34
|
| 126 |
-
epoch: 85, lr: 1.50e-04 - train si-snr: 7.24e-01 - valid si-snr: 22.39
|
| 127 |
-
epoch: 86, lr: 1.50e-04 - train si-snr: 7.65e-01 - valid si-snr: 22.34
|
| 128 |
-
epoch: 87, lr: 1.50e-04 - train si-snr: 7.10e-01 - valid si-snr: 22.29
|
| 129 |
-
epoch: 88, lr: 1.50e-04 - train si-snr: 7.65e-01 - valid si-snr: 22.42
|
| 130 |
-
epoch: 89, lr: 1.50e-04 - train si-snr: 7.09e-01 - valid si-snr: 22.35
|
| 131 |
-
epoch: 90, lr: 1.50e-04 - train si-snr: 8.13e-01 - valid si-snr: 22.38
|
| 132 |
-
epoch: 91, lr: 7.50e-05 - train si-snr: 5.81e-01 - valid si-snr: 22.24
|
| 133 |
-
epoch: 92, lr: 7.50e-05 - train si-snr: 3.71e-01 - valid si-snr: 22.33
|
| 134 |
-
epoch: 93, lr: 7.50e-05 - train si-snr: 3.21e-01 - valid si-snr: 22.33
|
| 135 |
-
epoch: 94, lr: 7.50e-05 - train si-snr: 3.48e-01 - valid si-snr: 22.29
|
| 136 |
-
epoch: 95, lr: 3.75e-05 - train si-snr: 4.08e-01 - valid si-snr: 22.34
|
| 137 |
-
epoch: 96, lr: 3.75e-05 - train si-snr: 2.29e-01 - valid si-snr: 22.33
|
| 138 |
-
epoch: 97, lr: 3.75e-05 - train si-snr: 2.27e-01 - valid si-snr: 22.29
|
| 139 |
-
epoch: 98, lr: 1.87e-05 - train si-snr: 1.28e-01 - valid si-snr: 22.27
|
| 140 |
-
epoch: 99, lr: 1.87e-05 - train si-snr: 3.17e-02 - valid si-snr: 22.27
|
| 141 |
-
epoch: 100, lr: 1.87e-05 - train si-snr: 6.84e-02 - valid si-snr: 22.24
|
| 142 |
-
epoch: 101, lr: 1.87e-05 - train si-snr: 6.90e-02 - valid si-snr: 22.25
|
| 143 |
-
epoch: 102, lr: 1.87e-05 - train si-snr: 1.53e-01 - valid si-snr: 22.28
|
| 144 |
-
epoch: 103, lr: 1.87e-05 - train si-snr: 4.23e-02 - valid si-snr: 22.28
|
| 145 |
-
epoch: 104, lr: 9.37e-06 - train si-snr: 7.48e-02 - valid si-snr: 22.24
|
| 146 |
-
epoch: 105, lr: 9.37e-06 - train si-snr: 8.28e-02 - valid si-snr: 22.27
|
| 147 |
-
epoch: 106, lr: 9.37e-06 - train si-snr: -1.19e-01 - valid si-snr: 22.26
|
| 148 |
-
epoch: 107, lr: 9.37e-06 - train si-snr: 2.27e-02 - valid si-snr: 22.26
|
| 149 |
-
epoch: 108, lr: 4.69e-06 - train si-snr: -9.19e-02 - valid si-snr: 22.24
|
| 150 |
-
epoch: 109, lr: 4.69e-06 - train si-snr: -1.86e-02 - valid si-snr: 22.26
|
| 151 |
-
epoch: 110, lr: 4.69e-06 - train si-snr: -1.29e-01 - valid si-snr: 22.26
|
| 152 |
-
epoch: 111, lr: 2.34e-06 - train si-snr: -1.28e-01 - valid si-snr: 22.26
|
| 153 |
-
epoch: 112, lr: 2.34e-06 - train si-snr: 1.96e-02 - valid si-snr: 22.26
|
| 154 |
-
epoch: 113, lr: 2.34e-06 - train si-snr: -8.82e-02 - valid si-snr: 22.26
|
| 155 |
-
epoch: 114, lr: 1.17e-06 - train si-snr: -2.95e-02 - valid si-snr: 22.25
|
| 156 |
-
epoch: 115, lr: 1.17e-06 - train si-snr: 1.44e-02 - valid si-snr: 22.26
|
| 157 |
-
epoch: 116, lr: 1.17e-06 - train si-snr: -2.01e-02 - valid si-snr: 22.25
|
| 158 |
-
epoch: 117, lr: 5.86e-07 - train si-snr: -6.14e-02 - valid si-snr: 22.25
|
| 159 |
-
epoch: 118, lr: 5.86e-07 - train si-snr: 1.49e-02 - valid si-snr: 22.25
|
| 160 |
-
epoch: 119, lr: 5.86e-07 - train si-snr: -2.11e-02 - valid si-snr: 22.25
|
| 161 |
-
epoch: 120, lr: 2.93e-07 - train si-snr: -8.56e-02 - valid si-snr: 22.25
|
| 162 |
-
epoch: 121, lr: 2.93e-07 - train si-snr: 3.46e-02 - valid si-snr: 22.25
|
| 163 |
-
epoch: 122, lr: 2.93e-07 - train si-snr: -4.48e-02 - valid si-snr: 22.26
|
| 164 |
-
epoch: 123, lr: 1.46e-07 - train si-snr: -4.78e-02 - valid si-snr: 22.25
|
| 165 |
-
epoch: 124, lr: 1.46e-07 - train si-snr: 4.87e-02 - valid si-snr: 22.26
|
| 166 |
-
epoch: 125, lr: 1.46e-07 - train si-snr: -8.55e-02 - valid si-snr: 22.25
|
| 167 |
-
epoch: 126, lr: 7.32e-08 - train si-snr: 4.56e-02 - valid si-snr: 22.25
|
| 168 |
-
epoch: 127, lr: 7.32e-08 - train si-snr: -7.29e-02 - valid si-snr: 22.25
|
| 169 |
-
epoch: 128, lr: 7.32e-08 - train si-snr: -4.80e-02 - valid si-snr: 22.26
|
| 170 |
-
epoch: 129, lr: 3.66e-08 - train si-snr: -6.66e-02 - valid si-snr: 22.26
|
| 171 |
-
epoch: 130, lr: 3.66e-08 - train si-snr: 6.62e-03 - valid si-snr: 22.26
|
| 172 |
-
epoch: 131, lr: 3.66e-08 - train si-snr: -1.94e-02 - valid si-snr: 22.26
|
| 173 |
-
epoch: 132, lr: 1.83e-08 - train si-snr: 1.16e-02 - valid si-snr: 22.26
|
| 174 |
-
epoch: 133, lr: 1.83e-08 - train si-snr: -1.09e-01 - valid si-snr: 22.26
|
| 175 |
-
epoch: 134, lr: 1.83e-08 - train si-snr: -1.16e-01 - valid si-snr: 22.26
|
| 176 |
-
epoch: 135, lr: 1.00e-08 - train si-snr: 2.68e-02 - valid si-snr: 22.26
|
| 177 |
-
epoch: 136, lr: 1.00e-08 - train si-snr: 3.10e-03 - valid si-snr: 22.26
|
| 178 |
-
epoch: 137, lr: 1.00e-08 - train si-snr: -4.31e-02 - valid si-snr: 22.26
|
| 179 |
-
epoch: 138, lr: 1.00e-08 - train si-snr: 7.30e-02 - valid si-snr: 22.26
|
| 180 |
-
epoch: 139, lr: 1.00e-08 - train si-snr: -9.77e-02 - valid si-snr: 22.26
|
| 181 |
-
epoch: 140, lr: 1.00e-08 - train si-snr: -1.41e-01 - valid si-snr: 22.26
|
| 182 |
-
epoch: 141, lr: 1.00e-08 - train si-snr: -1.82e-02 - valid si-snr: 22.26
|
| 183 |
-
epoch: 142, lr: 1.00e-08 - train si-snr: -5.03e-02 - valid si-snr: 22.26
|
| 184 |
-
epoch: 143, lr: 1.00e-08 - train si-snr: -9.63e-02 - valid si-snr: 22.26
|
| 185 |
-
epoch: 144, lr: 1.00e-08 - train si-snr: -1.29e-02 - valid si-snr: 22.26
|
| 186 |
-
epoch: 145, lr: 1.00e-08 - train si-snr: -3.77e-02 - valid si-snr: 22.26
|
| 187 |
-
epoch: 146, lr: 1.00e-08 - train si-snr: -1.36e-01 - valid si-snr: 22.26
|
| 188 |
-
epoch: 147, lr: 1.00e-08 - train si-snr: -1.02e-01 - valid si-snr: 22.26
|
| 189 |
-
epoch: 148, lr: 1.00e-08 - train si-snr: 1.05e-01 - valid si-snr: 22.26
|
| 190 |
-
epoch: 149, lr: 1.00e-08 - train si-snr: -1.08e-01 - valid si-snr: 22.26
|
| 191 |
-
epoch: 150, lr: 1.00e-08 - train si-snr: 1.28e-02 - valid si-snr: 22.26
|
| 192 |
-
epoch: 151, lr: 1.00e-08 - train si-snr: -8.94e-02 - valid si-snr: 22.26
|
| 193 |
-
epoch: 152, lr: 1.00e-08 - train si-snr: -9.64e-02 - valid si-snr: 22.26
|
| 194 |
-
epoch: 153, lr: 1.00e-08 - train si-snr: -1.32e-01 - valid si-snr: 22.26
|
| 195 |
-
epoch: 154, lr: 1.00e-08 - train si-snr: 2.86e-02 - valid si-snr: 22.26
|
| 196 |
-
epoch: 155, lr: 1.00e-08 - train si-snr: -2.50e-02 - valid si-snr: 22.26
|
| 197 |
-
epoch: 156, lr: 1.00e-08 - train si-snr: -1.44e-02 - valid si-snr: 22.26
|
| 198 |
-
epoch: 157, lr: 1.00e-08 - train si-snr: 9.09e-02 - valid si-snr: 22.26
|
| 199 |
-
epoch: 158, lr: 1.00e-08 - train si-snr: 6.12e-03 - valid si-snr: 22.26
|
| 200 |
-
epoch: 159, lr: 1.00e-08 - train si-snr: -3.80e-02 - valid si-snr: 22.26
|
| 201 |
-
epoch: 160, lr: 1.00e-08 - train si-snr: 4.51e-02 - valid si-snr: 22.26
|
| 202 |
-
epoch: 161, lr: 1.00e-08 - train si-snr: -2.98e-02 - valid si-snr: 22.26
|
| 203 |
-
epoch: 162, lr: 1.00e-08 - train si-snr: -2.20e-03 - valid si-snr: 22.26
|
| 204 |
-
epoch: 163, lr: 1.00e-08 - train si-snr: -1.64e-01 - valid si-snr: 22.26
|
| 205 |
-
epoch: 164, lr: 1.00e-08 - train si-snr: -3.20e-02 - valid si-snr: 22.26
|
| 206 |
-
epoch: 165, lr: 1.00e-08 - train si-snr: 3.47e-03 - valid si-snr: 22.26
|
| 207 |
-
epoch: 166, lr: 1.00e-08 - train si-snr: -8.60e-02 - valid si-snr: 22.26
|
| 208 |
-
epoch: 167, lr: 1.00e-08 - train si-snr: 6.45e-03 - valid si-snr: 22.26
|
| 209 |
-
epoch: 168, lr: 1.00e-08 - train si-snr: 1.17e-02 - valid si-snr: 22.26
|
| 210 |
-
epoch: 169, lr: 1.00e-08 - train si-snr: -4.05e-02 - valid si-snr: 22.26
|
| 211 |
-
epoch: 170, lr: 1.00e-08 - train si-snr: -1.26e-01 - valid si-snr: 22.26
|
| 212 |
-
epoch: 171, lr: 1.00e-08 - train si-snr: -1.06e-01 - valid si-snr: 22.26
|
| 213 |
-
epoch: 172, lr: 1.00e-08 - train si-snr: -1.26e-01 - valid si-snr: 22.26
|
| 214 |
-
epoch: 173, lr: 1.00e-08 - train si-snr: -7.41e-02 - valid si-snr: 22.26
|
| 215 |
-
epoch: 174, lr: 1.00e-08 - train si-snr: 1.57e-02 - valid si-snr: 22.26
|
| 216 |
-
epoch: 175, lr: 1.00e-08 - train si-snr: -1.48e-02 - valid si-snr: 22.26
|
| 217 |
-
epoch: 176, lr: 1.00e-08 - train si-snr: 6.87e-02 - valid si-snr: 22.26
|
| 218 |
-
epoch: 177, lr: 1.00e-08 - train si-snr: -6.77e-02 - valid si-snr: 22.26
|
| 219 |
-
epoch: 178, lr: 1.00e-08 - train si-snr: -1.75e-01 - valid si-snr: 22.26
|
| 220 |
-
epoch: 179, lr: 1.00e-08 - train si-snr: -8.73e-02 - valid si-snr: 22.26
|
| 221 |
-
epoch: 180, lr: 1.00e-08 - train si-snr: -7.13e-02 - valid si-snr: 22.26
|
| 222 |
-
epoch: 181, lr: 1.00e-08 - train si-snr: -1.28e-01 - valid si-snr: 22.26
|
| 223 |
-
epoch: 182, lr: 1.00e-08 - train si-snr: 2.53e-02 - valid si-snr: 22.26
|
| 224 |
-
epoch: 183, lr: 1.00e-08 - train si-snr: 5.30e-02 - valid si-snr: 22.26
|
| 225 |
-
epoch: 184, lr: 1.00e-08 - train si-snr: -6.50e-02 - valid si-snr: 22.26
|
| 226 |
-
epoch: 185, lr: 1.00e-08 - train si-snr: -7.48e-02 - valid si-snr: 22.26
|
| 227 |
-
epoch: 186, lr: 1.00e-08 - train si-snr: -6.33e-02 - valid si-snr: 22.26
|
| 228 |
-
epoch: 187, lr: 1.00e-08 - train si-snr: -5.01e-02 - valid si-snr: 22.26
|
| 229 |
-
epoch: 188, lr: 1.00e-08 - train si-snr: -2.82e-03 - valid si-snr: 22.26
|
| 230 |
-
epoch: 189, lr: 1.00e-08 - train si-snr: -1.37e-01 - valid si-snr: 22.26
|
| 231 |
-
epoch: 190, lr: 1.00e-08 - train si-snr: -3.86e-02 - valid si-snr: 22.26
|
| 232 |
-
epoch: 191, lr: 1.00e-08 - train si-snr: -4.23e-02 - valid si-snr: 22.26
|
| 233 |
-
epoch: 192, lr: 1.00e-08 - train si-snr: -7.80e-02 - valid si-snr: 22.26
|
| 234 |
-
epoch: 193, lr: 1.00e-08 - train si-snr: -2.90e-02 - valid si-snr: 22.26
|
| 235 |
-
epoch: 194, lr: 1.00e-08 - train si-snr: -1.21e-01 - valid si-snr: 22.26
|
| 236 |
-
epoch: 195, lr: 1.00e-08 - train si-snr: 8.91e-03 - valid si-snr: 22.26
|
| 237 |
-
epoch: 196, lr: 1.00e-08 - train si-snr: -5.28e-02 - valid si-snr: 22.26
|
| 238 |
-
epoch: 197, lr: 1.00e-08 - train si-snr: 9.40e-02 - valid si-snr: 22.26
|
| 239 |
-
epoch: 198, lr: 1.00e-08 - train si-snr: -4.55e-02 - valid si-snr: 22.26
|
| 240 |
-
epoch: 199, lr: 1.00e-08 - train si-snr: -6.24e-02 - valid si-snr: 22.26
|
| 241 |
-
epoch: 200, lr: 1.00e-08 - train si-snr: 5.69e-03 - valid si-snr: 22.26
|
| 242 |
-
Epoch loaded: 104 - test si-snr: 20.22
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|