Haopeng committed on
Commit
0e0d443
·
verified ·
1 Parent(s): 0bcd3f2

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. CKPT.yaml +6 -0
  2. inference.yaml +312 -0
  3. model.ckpt +3 -0
  4. mpd.txt +0 -0
  5. per.txt +0 -0
  6. perceived_ssl.ckpt +3 -0
  7. tokenizer.ckpt +3 -0
CKPT.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # yamllint disable
2
+ PER: 17.544039328144205
3
+ end-of-epoch: true
4
+ epoch: 140
5
+ mpd_f1: 0.7108831073653353
6
+ unixtime: 1770905887.9787815
inference.yaml ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hyperparameters toggles
2
+ prefix: ""
3
+
4
+ lab_enc_file: /home/m64000/work/IF-MDD/exp_iqra/wavlm_large_None_PhnMonoSSL_ottc_confEnc/save/label_encoder.txt
5
+ ctc_loss_type: "crottc" # Options: "ctc", "ottc", "crctc" — NOTE(review): current value "crottc" is not in this list; confirm it is a supported option
6
+ encoder_type: "conformer" # Options: None, "conformer", "zipformer", "rvq"
7
+
8
+ wandb_project: "iqra_extra"
9
+ # Wandb Tags
10
+ tags:
11
+ - PhnMonoSSL
12
+ - crottc
13
+ - ConformerEncoder
14
+ - iqra_extra
15
+ - TTS_FT
16
+
17
+ ## SSL features Selection
18
+ pretrained_models_path: pretrained_models/
19
+ # pretrained_models:
20
+ # {
21
+ # "wav2vec2_base": "facebook/wav2vec2-base", # 768
22
+ # "hubert_base": "facebook/hubert-base-ls960", # 768
23
+ # "wavlm_base": "microsoft/wavlm-base", # 768
24
+ # "wavlm_base_plus": "microsoft/wavlm-base-plus", # 768
25
+ # "hubert_multilingual": "utter-project/mHuBERT-147", # 768
26
+ # "clap" : "laion/clap-htsat-fused", # 768
27
+ # "data2vec_base": "facebook/data2vec-audio-base", # 768
28
+
29
+ # "wav2vec2_large": "facebook/wav2vec2-large", # 1024
30
+ # "hubert_large": "facebook/hubert-large-ls960", # 1024
31
+ # "wavlm_large": "microsoft/wavlm-large-plus", # 1024
32
+ # "data2vec_large": "facebook/data2vec-audio-large", #1024
33
+ # "whisper_medium": "openai/whisper-medium", # 1024
34
+
35
+ # "whisper_large_v3_turbo": "openai/whisper-large-v3-turbo", # 1280
36
+ # }
37
+
38
+
39
+
40
+ # select pretrained SSL models
41
+ perceived_ssl_model: "wavlm_large" # in pretrained_models
42
+ canonical_ssl_model: Null
43
+
44
+ # # models hidden size, varies by model
45
+ ENCODER_DIM: 1024
46
+
47
+ # # How to fuse the features
48
+ feature_fusion: "mono" # Options: "mono" for single ssl, "dual_ssl_enc" for dual ssl encoder, "dual_loss" for single SSL dual ssl loss
49
+ blend_alpha: 0.5 # If using "blend" fusion
50
+
51
+ # Input files
52
+ # Data files
53
+ # data_folder_save: "/home/kevingenghaopeng/MDD/IF-MDD/data_iqra/demo_data"
54
+ data_folder_save: "/home/m64000/work/dataset/data_iqra_extra_is26"
55
+ train_annotation: !ref <data_folder_save>/iqra_extra_is26_train_aligned.json
56
+ valid_annotation: !ref <data_folder_save>/iqra_extra_is26_dev_aligned.json
57
+ test_annotation: !ref <data_folder_save>/iqra_extra_is26_test_aligned.json
58
+ # Extra data
59
+ train_annotation_extra: !ref <data_folder_save>/train-train_with_extra.json
60
+ use_extra_train_data: False
61
+
62
+ evaluate_key: "PER" # use "mpd_f1_seq" for Transformer decoder path best mpd f1
63
+ # "PER_seq" for Transformer decoder's best error rate
64
+ # "PER" for ctc path best error rate
65
+ # "mpd_f1" for ctc path best mpd f1
66
+ max_save_models: 3 # Maximum number of saved models for each metrics
67
+ # generate training id for output folder
68
+ # generate_training_id: !apply:trainer.generate_training_id.generate_training_id [!ref <perceived_ssl_model_id>, !ref <canonical_ssl_model_id>, !ref <feature_fusion>, !ref <prefix>]
69
+
70
+ # output files
71
+ output_folder: !ref exp_iqra/<perceived_ssl_model>_<canonical_ssl_model>_<feature_fusion>_<prefix>
72
+ per_file: !ref <output_folder>/per.txt
73
+ mpd_file: !ref <output_folder>/mpd.txt
74
+ save_folder: !ref <output_folder>/save
75
+ train_log: !ref <output_folder>/train_log.txt
76
+
77
+ on_training_test_wer_folder: !ref <output_folder>/on_training_test_wer
78
+ on_training_test_mpd_folder: !ref <output_folder>/on_training_test_mpd
79
+
80
+ # Training Target
81
+ training_target: "target" # "target": deduplicated canonical phoneme sequence; "target_with_repeats": with repeats
82
+ # "canonical"
83
+ # "perceived": deduplicated perceived phoneme sequence
84
+ # Modules (SpeechBrain lobes)
85
+ # modules:
86
+ # canonical_ssl: !ref <canonical_ssl>
87
+ # perceived_ssl: !ref <perceived_ssl>
88
+ # enc: !ref <enc>
89
+ # ConformerEncoder: !ref <ConformerEncoder>
90
+ # ctc_lin: !ref <ctc_lin>
91
+ # lm_weight: !ref <lm_weight>
92
+
93
+ perceived_ssl: !apply:trainer.AutoSSLoader.AutoSSLLoader
94
+ model_name: !ref <perceived_ssl_model>
95
+ freeze: !ref <freeze_perceived_ssl>
96
+ freeze_feature_extractor: !ref <freeze_perceived_feature_extractor>
97
+ save_path: !ref <pretrained_models_path>
98
+ output_all_hiddens: False
99
+ preceived_ssl_emb_layer: -1
100
+
101
+ canonical_ssl: !apply:trainer.AutoSSLoader.AutoSSLLoader
102
+ model_name: !ref <canonical_ssl_model>
103
+ freeze: !ref <freeze_canonical_ssl>
104
+ freeze_feature_extractor: !ref <freeze_perceived_feature_extractor>
105
+ save_path: !ref <pretrained_models_path>
106
+ output_all_hiddens: False
107
+
108
+ canonical_ssl_emb_layer: -1
109
+
110
+ enc: !new:torch.nn.Sequential
111
+ - !new:speechbrain.lobes.models.VanillaNN.VanillaNN
112
+ input_shape: [null, null, !ref <ENCODER_DIM>]
113
+ activation: !ref <activation>
114
+ dnn_blocks: !ref <dnn_layers>
115
+ dnn_neurons: !ref <dnn_neurons>
116
+ - !new:torch.nn.LayerNorm
117
+ normalized_shape: !ref <dnn_neurons>
118
+
119
+
120
+ kernel_size: 7
121
+ attention_type: "RoPEMHA" # Options: "standard", "RoPE"
122
+ ConformerEncoder: !new:speechbrain.lobes.models.transformer.Conformer.ConformerEncoder
123
+ num_layers: 2
124
+ nhead: 8
125
+ d_ffn: !ref <dnn_neurons>
126
+ d_model: !ref <dnn_neurons>
127
+ dropout: 0.1
128
+ kernel_size: !ref <kernel_size>
129
+ attention_type: !ref <attention_type>
130
+
131
+ ctc_lin: !new:speechbrain.nnet.linear.Linear
132
+ input_size: !ref <dnn_neurons>
133
+ n_neurons: !ref <output_neurons> # 40 phonemes + 1 blank + 1 err
134
+
135
+ # lm_weight for OTTC's alpha prediction
136
+ lm_weight: !new:speechbrain.nnet.linear.Linear
137
+ input_size: !ref <dnn_neurons>
138
+   n_neurons: 1 # single scalar output (OTTC alpha prediction), not the phoneme vocabulary
139
+
140
+ # Model parameters
141
+ activation: !name:torch.nn.LeakyReLU
142
+ dnn_layers: 2
143
+ dnn_neurons: 384
144
+ freeze_perceived_ssl: False
145
+ freeze_canonical_ssl: False
146
+ freeze_perceived_feature_extractor: True # freeze the CNN extractor in wav2vec
147
+ freeze_canonical_feature_extractor: True # Freeze Whisper encoder?
148
+
149
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
150
+ apply_log: True
151
+
152
+ # ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
153
+ # blank_index: !ref <blank_index>
154
+
155
+ # ctc_cost: !new:utils.CTCLossWithLabelPriors.CTCLossWithLabelPriors
156
+ # prior_scaling_factor: 0.3
157
+ # ctc_implementation: 'k2'
158
+ # blank: !ref <blank_index>
159
+ # reduction: 'sum'
160
+
161
+ ctc_cost: !name:utils.losses.ot_loss.batched_ottc_loss_bucketized
162
+
163
+
164
+ ctc_cost_mispro: !name:speechbrain.nnet.losses.ctc_loss
165
+ blank_index: !ref <blank_index>
166
+
167
+ # Outputs
168
+ output_neurons: 71 # l2arctic: 40 phns(sil) + err + blank + eos + bos = 44 — NOTE(review): comment arithmetic (44) does not match the value 71; verify against the label encoder
169
+ blank_index: 0
170
+
171
+ model: !new:torch.nn.ModuleList
172
+ - [!ref <enc>, !ref <ctc_lin>, ]
173
+
174
+ adam_opt_class: !name:torch.optim.Adam
175
+ lr: !ref <lr>
176
+
177
+ pretrained_opt_class: !name:torch.optim.Adam
178
+ lr: !ref <lr_pretrained>
179
+
180
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
181
+ checkpoints_dir: !ref <save_folder>
182
+ recoverables:
183
+ model: !ref <model>
184
+ perceived_ssl: !ref <perceived_ssl>
185
+ counter: !ref <epoch_counter>
186
+ allow_partial_load: True
187
+ # canonical_ssl: !ref <canonical_ssl>
188
+ # augmentation: !new:speechbrain.augment.time_domain.SpeedPerturb
189
+ # orig_freq: !ref <sample_rate>
190
+ # speeds: [95, 100, 105]
191
+
192
+ spec_augmentation: !new:speechbrain.augment.freq_domain.SpectrogramDrop
193
+ drop_length_low: 5
194
+ drop_length_high: 27
195
+ drop_count_low: 1
196
+ drop_count_high: 3
197
+ replace: 'zeros'
198
+
199
+ freq_chunk_augmentation: !new:speechbrain.augment.time_domain.DropFreq
200
+ drop_freq_low: 1e-14
201
+ drop_freq_high: 1
202
+ drop_freq_count_low: 1
203
+ drop_freq_count_high: 3
204
+ drop_freq_width: 0.10
205
+ epsilon: 1e-12
206
+
207
+ drop_length_high: 3000
208
+ time_chunk_augmentation: !new:speechbrain.augment.time_domain.DropChunk
209
+ drop_length_low: 1000
210
+ drop_length_high: !ref <drop_length_high>
211
+ drop_count_low: 1
212
+ drop_count_high: 3
213
+
214
+ speed_augmentation: !new:speechbrain.augment.time_domain.SpeedPerturb
215
+ orig_freq: !ref <sample_rate>
216
+ speeds: [95, 100, 105]
217
+
218
+ timewarp_augmentation: !new:speechbrain.augment.freq_domain.Warping
219
+ warp_window: 5
220
+ dim: 1 # time
221
+
222
+ augmentation: !new:speechbrain.augment.augmenter.Augmenter
223
+ augmentations:
224
+ - !ref <freq_chunk_augmentation>
225
+ - !ref <time_chunk_augmentation>
226
+ # - !new:speechbrain.augment.time_domain.SpeedPerturb # Apply speed perturbation ahead so the copy of
227
+ # orig_freq: !ref <sample_rate>
228
+ # speeds: [95, 100, 105]
229
+
230
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
231
+ limit: !ref <number_of_epochs>
232
+
233
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
234
+ save_file: !ref <train_log>
235
+
236
+ # ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
237
+ # metric: !new:utils.CTCLossWithLabelPriors.CTCLossWithLabelPriors
238
+ # prior_scaling_factor: 0.3
239
+ # ctc_implementation: 'k2'
240
+ # blank: !ref <blank_index>
241
+ # reduction: 'none'
242
+
243
+ ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
244
+ metric: !name:speechbrain.nnet.losses.ctc_loss
245
+ blank_index: !ref <blank_index>
246
+ reduction: batch
247
+
248
+ per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
249
+
250
+ # # TIMIT
251
+ # timit_local_data_folder: "/common/db/TIMIT" # Path to TIMIT dataset
252
+
253
+ seed: 3047
254
+ __set_seed: !apply:torch.manual_seed [!ref <seed>]
255
+
256
+ # training parameters
257
+ number_of_epochs: 300
258
+ batch_size: 16
259
+ lr: 0.0003
260
+ sorting: ascending
261
+ sample_rate: 16000
262
+ gradient_accumulation: 2
263
+ lr_pretrained: 0.00001
264
+
265
+ # Mixed-Precision Training
266
+ auto_mix_prec: true
267
+ # or
268
+ precision: fp16 # supports "fp32", "fp16", or "bf16"
269
+ eval_precision: fp32 # evaluation/inference precision — NOTE(review): original comment said "also switch inference to FP16" but the value is fp32; confirm intent
270
+
271
+ # Dataloader options
272
+ train_dataloader_opts:
273
+ batch_size: !ref <batch_size>
274
+
275
+
276
+ valid_dataloader_opts:
277
+ batch_size: !ref <batch_size>
278
+
279
+
280
+ test_dataloader_opts:
281
+ batch_size: !ref <batch_size>
282
+
283
+ # # resume_from_pretrainer, to fine-tune from a saved pretrainer checkpoint
284
+ # resume_from: /home/m64000/work/IF-MDD/exp_iqra_tts/wavlm_large_None_PhnMonoSSL_crottc_confEnc_RoPE_k7/save/CKPT+088_PER_6.2082_F1_0.9074.ckpt
285
+
286
+ # resume_from_pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
287
+ # collect_in: !ref <resume_from>/
288
+ # loadables:
289
+ # perceived_ssl: !ref <perceived_ssl>
290
+ # model: !ref <model>
291
+ # #
292
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
293
+ collect_in: !ref <save_folder>/
294
+ loadables:
295
+ perceived_ssl: !ref <perceived_ssl>
296
+ model: !ref <model>
297
+ tokenizer: !ref <tokenizer>
298
+
299
+ encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
300
+ perceived_ssl: !ref <perceived_ssl>
301
+ enc: !ref <enc>
302
+ ctc_lin: !ref <ctc_lin>
303
+ log_softmax: !ref <log_softmax>
304
+
305
+ decoding_function: !name:speechbrain.decoders.ctc_greedy_decode
306
+ blank_id: !ref <blank_index>
307
+
308
+ tokenizer: !new:speechbrain.dataio.encoder.CTCTextEncoder
309
+ load_from_file: /home/kevingenghaopeng/MDD/IF-MDD/pretrained_models/iqra_extra_acou_model/ottc_k7_RoPE_TTS_FT/label_encoder.txt
310
+
311
+ modules:
312
+ encoder: !ref <encoder>
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e54409e3c2414c7263b9ab63c9228a1d7a231b6644324b2448b8bb3b3aeb744
3
+ size 2241500
mpd.txt ADDED
The diff for this file is too large to render. See raw diff
 
per.txt ADDED
The diff for this file is too large to render. See raw diff
 
perceived_ssl.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1ff5b43b55c412e73381e8b257c9af3c2237fa71b76bac5119ca8b31a531ec4
3
+ size 1262009130
tokenizer.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98cee9707ab67c3e29ee337debf4ba319cbc61c3777024db6b8f3494f0df5bfe
3
+ size 583