xuesongyan committed on
Commit
ee4b9b7
·
1 Parent(s): 12c65b8

Upload config.py

Browse files
Files changed (1) hide show
  1. config.py +517 -0
config.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class Config(object):
    """Attribute-style container for configuration values.

    Each key of the mapping passed to the constructor becomes an instance
    attribute. Entries whose value is ``None`` are ignored everywhere in this
    class (constructor, ``copy`` and ``replace``), so ``None`` can never be
    used to overwrite or clear an existing setting.
    """

    def __init__(self, config_dict: dict):
        """Create a config from ``config_dict``, skipping ``None`` values."""
        self._assign(config_dict)

    def _assign(self, mapping):
        """Set every non-``None`` entry of ``mapping`` as an attribute."""
        for key, val in mapping.items():
            if val is not None:
                setattr(self, key, val)

    def copy(self, new_config_dict=None):
        """Return a new ``Config`` with optional overrides applied.

        Args:
            new_config_dict: Optional dict or ``Config`` whose non-``None``
                entries override the copied values. Defaults to no overrides.

        Returns:
            A new ``Config``. The copy is shallow: attribute values (nested
            ``Config`` objects, dicts, lists) are shared with the original.
        """
        if new_config_dict is None:
            # A fresh dict per call instead of a shared mutable default.
            new_config_dict = {}
        elif isinstance(new_config_dict, Config):
            # Accept a Config as well, mirroring ``replace``.
            new_config_dict = vars(new_config_dict)

        ret = Config(vars(self))
        ret._assign(new_config_dict)
        return ret

    def replace(self, new_config_dict):
        """Update this config in place.

        Args:
            new_config_dict: A dict or ``Config``; each non-``None`` entry is
                written onto ``self``. Attributes not mentioned are kept.
        """
        if isinstance(new_config_dict, Config):
            new_config_dict = vars(new_config_dict)

        self._assign(new_config_dict)

    def print(self):
        """Print every attribute as ``key = value``, one per line."""
        for k, v in vars(self).items():
            print(k, '=', v)

    def __repr__(self):
        # Without this, a Config nested inside another one is rendered as
        # "<... object at 0x...>" when the outer config is converted to str.
        return 'Config({!r})'.format(vars(self))

    def __str__(self):
        return str(vars(self))
39
+
40
+
41
# Default settings. The smaller configs defined elsewhere in this module list
# only a subset of these keys (plus a few extra ones).
base_config = Config({
    "project": "speaker_verification",
    "name": "VGGVox",
    "save_dir": "train_models/",
    "resume": "",

    # Training and test data
    "dataset": Config({
        "name": "voxceleb2_wav",
        "train_list": "data/train_list.txt",
        "test_list": "data/veri_list.txt",
        "train_path": "data/voxceleb2",
        "test_path": "data/voxceleb1",
        "musan_path": "data/musan_split",  # noise files
        "rir_path": "data/RIRS_NOISES/simulated_rirs",  # reverberation files
    }),

    # Data loader
    "max_frames": 300,  # frame length used during training
    "eval_frames": 300,
    "batch_size": 64,
    "max_seg_per_spk": 500,  # maximum number of utterance segments per speaker
    "nDataLoaderThread": 16,  # multi-threaded loading
    "augment": True,  # whether to apply data augmentation
    "seed": 10,
    "segment": 1,

    # Training details
    "test_interval": 1,  # interval between test runs
    "max_epoch": 500,

    # Model definition
    "n_mels": 40,
    "log_input": False,
    "model": "Vgg",
    "encoder_type": "SAP",
    "nOut": 512,

    # Loss functions
    "loss": "SoftmaxProto",  # loss function
    "hard_prob": 0.5,
    "hard_rank": 10,
    "margin": 0.2,
    "scale": 30,
    "nPerSpeaker": 2,  # how many groups are taken from the same utterance
    "nClasses": 5994,

    # Optimizer
    "optimizer": "adam",
    "scheduler": "steplr",
    "lr": 0.001,
    "lr_decay": 0.95,
    "weight_decay": 0,

    # Evaluation parameters
    "dcf_p_target": 0.05,
    "dcf_c_miss": 1,
    "dcf_c_fa": 1,

    # eval
    "eval": False,
})

# The active config. This is the same object as base_config (an alias, not a
# copy), so any in-place update made through cfg is also seen via base_config.
cfg = base_config
106
+
107
# VGG / U-Net style experiments. Each config lists only a handful of keys;
# presumably they are layered over the defaults via set_cfg -- confirm.
vgg_cfg = Config({
    "name": "vgg_spectrogram",
    "model": "vgg",
    "batch_size": 64,
    "nPerSpeaker": 2,
})

Unet_cfg = Config({
    "name": "Unet",
    "model": "UNetVgg",
    "batch_size": 48,
    "nPerSpeaker": 2,
    "loss": "Unetloss",
})

UnetMask_cfg = Config({
    "name": "UnetMask",
    "model": "UNetVggMask",
    "batch_size": 16,
    "nPerSpeaker": 2,
    "segment": 3,
    "loss": "UnetMaskloss",
})
130
+
131
# ECAPA-TDNN family of experiments (architecture variants).

# NOTE(review): "name" here is "ECAPA_TDNNm", the same value used by
# ECAPA_TDNNm_cfg below -- possibly a copy/paste leftover; confirm.
ECAPA_TDNN_cfg = Config({
    "name": "ECAPA_TDNNm",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})

ECAPA_TDNNm_cfg = Config({
    "name": "ECAPA_TDNNm",
    "model": "ECAPA_TDNN",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})

ECAPA_TDNN1024_cfg = Config({
    "name": "ECAPA_TDNN1024",
    "model": "ECAPA_TDNN",
    "batch_size": 80,
    "nPerSpeaker": 2,
    "channels": 1024,
    "nOut": 192,
})

ECAPA_TDNN_ks5_cfg = Config({
    "name": "ECAPA_TDNN_ks5",
    "model": "ECAPA_TDNN_ks5",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})

# Sets "resume" to a path under train_models/ (a previously saved run,
# judging by the epoch/EER/MinDCF fields in the path).
ECAPA_TDNN_L2_cfg = Config({
    "name": "ECAPA_TDNN_L2_pre",
    "model": "ECAPA_TDNN_L2",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "resume": "train_models/speaker_verification_ECAPA_TDNN/20210725/epoch:47,EER:2.5981,MinDCF:0.1912",
})

ECAPA_TDNN_br_cfg = Config({
    "name": "ECAPA_TDNN_br",
    "model": "ECAPA_TDNN_br",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})

ECAPATDNN_cfg = Config({
    "name": "ECAPATDNN",
    "model": "ECAPATDNN",
    "batch_size": 110,
    "nPerSpeaker": 2,
    "nOut": 192,
    "input_size": 80,
})
190
+
191
# HRNet experiment. "model_cfg" holds the per-stage layout as plain dicts
# wrapped in a nested Config.
# NOTE(review): "hrnet_name" is "w48" while the stage channel widths start at
# 18 -- confirm which variant is intended.
HRNet_cfg = Config({
    "name": "hrnet",
    "model": "hrnet",
    "max_frames": 224,
    "eval_frames": 224,
    "batch_size": 48,
    "nPerSpeaker": 2,
    "nOut": 1024,
    "input_size": 224 * 224,

    "model_cfg": Config({
        "hrnet_name": "w48",
        "STAGE1": {
            "NUM_MODULES": 1,
            "NUM_BRANCHES": 1,
            "BLOCK": "BOTTLENECK",
            "NUM_BLOCKS": [4],
            "NUM_CHANNELS": [64],
            "FUSE_METHOD": "SUM",
        },
        "STAGE2": {
            "NUM_MODULES": 1,
            "NUM_BRANCHES": 2,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4],
            "NUM_CHANNELS": [18, 36],
            "FUSE_METHOD": "SUM",
        },
        "STAGE3": {
            "NUM_MODULES": 4,
            "NUM_BRANCHES": 3,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4, 4],
            "NUM_CHANNELS": [18, 36, 72],
            "FUSE_METHOD": "SUM",
        },
        "STAGE4": {
            "NUM_MODULES": 3,
            "NUM_BRANCHES": 4,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4, 4, 4],
            "NUM_CHANNELS": [18, 36, 72, 144],
            "FUSE_METHOD": "SUM",
        },
    }),
})
238
+
239
# VGG+TDNN and ResNetSE34V2 experiments.
VGG_TDNN_cfg = Config({
    "name": "Vggtdnn1",
    "model": "Vggtdnn",
    "batch_size": 256,
    "nOut": 512,
    "nDataLoaderThread": 16,
})

ResNetSE34V2_cfg = Config({
    "name": "ResNetSE34V2",
    "model": "ResNetSE34V2",
    "batch_size": 128,
    "nOut": 512,
    "nDataLoaderThread": 16,
})
254
+
255
# HR-TDNN experiment. "model_cfg" describes three stages, each adding one
# branch and one entry to NUM_CHANNELS.
HRTDNN_cfg = Config({
    "name": "hrtdnn",
    "model": "hrtdnn",
    "max_frames": 300,
    "eval_frames": 300,
    "batch_size": 96,
    "nPerSpeaker": 2,
    "nOut": 256,

    "model_cfg": Config({
        "hrnet_name": "hrtdnn",
        "STAGE1": {
            "NUM_BRANCHES": 1,
            "BLOCK": 'TDNNBlock',
            "NUM_CHANNELS": [128],
            "FUSE_METHOD": "SUM",
        },
        "STAGE2": {
            "NUM_BRANCHES": 2,
            "BLOCK": 'TDNNBlock',
            "NUM_CHANNELS": [128, 512],
            "FUSE_METHOD": "SUM",
        },
        "STAGE3": {
            "NUM_BRANCHES": 3,
            "BLOCK": 'TDNNBlock',
            "NUM_CHANNELS": [128, 512, 1024],
            "FUSE_METHOD": "SUM",
        },
    }),
})
288
+
289
# ResNet / TDNN hybrid experiments.
ResTDNN_cfg = Config({
    "name": "ResTDNN",
    "model": "ResTDNN",
    "batch_size": 110,
    "nOut": 256,
    "nDataLoaderThread": 16,
})

TDNN_VGG_cfg = Config({
    "name": "TDNN_VGG",
    "model": "TDNN_VGG",
    "batch_size": 64,
    "nOut": 256,
    "nDataLoaderThread": 16,
})

ResNet_TDNN_cfg = Config({
    "name": "ResNet_TDNN",
    "model": "ResNet_TDNN",
    "batch_size": 96,
    "nOut": 192,
    "nDataLoaderThread": 16,
})

# Same values as ResNet_TDNN_cfg apart from "name".
ResNet_TDNNa_cfg = Config({
    "name": "ResNet_TDNNa",
    "model": "ResNet_TDNN",
    "batch_size": 96,
    "nOut": 192,
    "nDataLoaderThread": 16,
})

ResNet_TDNNaam_cfg = Config({
    "name": "ResNet_TDNNaam",
    "model": "ResNet_TDNN",
    "loss": "AamSoftmaxProto",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 96,
    "nOut": 192,
    "nDataLoaderThread": 16,
    "augment": True,
})

TDNN_ResNet_cfg = Config({
    "name": "TDNN_ResNet",
    "model": "TDNN_ResNet",
    "batch_size": 48,
    "nOut": 256,
    "nDataLoaderThread": 16,
})

hr_tdnn_cfg = Config({
    "name": "hr_tdnn",
    "model": "hr_tdnn",
    "batch_size": 46,
    "nOut": 192,
    "nDataLoaderThread": 16,
})
348
+
349
+
350
# ECAPA-TDNN experiments that vary the loss, margin and batch layout.
ECAPA_TDNNma_cfg = Config({
    "name": "ECAPA_TDNNma",
    "model": "ECAPA_TDNN",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
})

ECAPA_TDNNaam_cfg = Config({
    "name": "ECAPA_TDNNaam",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})

ECAPA_TDNNaam1_cfg = Config({
    "name": "ECAPA_TDNNaam1",
    "model": "ECAPA_TDNN",
    "loss": "AdditiveAngularMargin",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})

ECAPA_TDNNaam2_cfg = Config({
    "name": "ECAPA_TDNNaam2",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})

# Differs from ECAPA_TDNNaam2_cfg in "name" and in "margin" (0.1 vs 0.2).
ECAPA_TDNNaam3_cfg = Config({
    "name": "ECAPA_TDNNaam3",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.1,
    "scale": 30,
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})

ECAPA_TDNN_aamproto_cfg = Config({
    "name": "ECAPA_TDNN_aamproto",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
})

ECAPA_TDNN_aamproto1_cfg = Config({
    "name": "ECAPA_TDNN_aamproto1",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
})

# Sets "resume" to a path under train_models/.
ECAPA_TDNN0_cfg = Config({
    "name": "ECAPA_TDNN-1lr0.001",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nOut": 192,
    "nPerSpeaker": 1,
    "resume": "train_models/speaker_verification_ECAPA_TDNN0/20210928/epoch:25,EER:2.4125,MinDCF:0.1537",
})
436
+
437
SwinTransformer_cfg = Config({
    "name": "SwinTransformer",
    "model": "SwinTransformer",
    "loss": "SoftmaxProto",
    "max_frames": 224,
    "eval_frames": 224,
    "n_mels": 224,
    "batch_size": 90,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
    "lr": 5e-5,
})

# Sets "resume" to a path under train_models/.
ECAPA_TDNN_aampre_cfg = Config({
    "name": "ECAPA_TDNN_aampre",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "batch_size": 180,
    "nOut": 192,
    "nPerSpeaker": 2,
    "resume": "train_models/speaker_verification_ECAPA_TDNNma/20210908/epoch:67,EER:2.3224,MinDCF:0.1658",
})

# Switched dataloader
ECAPA_TDNN_data_cfg = Config({
    "name": "ECAPA_TDNN_data",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})

# Standard ECAPA_TDNN with a CyclicLR learning rate
ECAPA_TDNNaam_cyclr_cfg = Config({
    "name": "ECAPA_TDNNaam_cyclr",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})

# ResNet_TDNN with the switched data loading, using softmax only
ResNet_TDNNaam_data_cfg = Config({
    "name": "ResNet_TDNNaam_data",
    "model": "ResNet_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 192,
    "nOut": 192,
    "nDataLoaderThread": 16,
    "nPerSpeaker": 1,
    "augment": True,
})

# Switched dataloader, plus cyclical lr
ECAPA_TDNN_dataClr_cfg = Config({
    "name": "ECAPA_TDNN_dataClr",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
511
+
512
def set_cfg(config_name: str):
    """Sets the active configs. Works even if cfg is already imported!

    Args:
        config_name: A Python expression, evaluated in this module, that
            yields a ``Config`` (or dict) -- typically just the name of one
            of the ``*_cfg`` objects defined in this file.

    The result is merged into the existing ``cfg`` object in place, which is
    why modules that imported ``cfg`` earlier see the new values.
    """
    global cfg
    # eval is used on purpose, not only out of convenience: it lets callers
    # pass expressions such as ssd300_config.copy({'max_size': 400}) for
    # extreme fine-tuning.
    # SECURITY: eval executes arbitrary code -- only pass trusted strings here,
    # never unchecked external input.
    selected = eval(config_name)
    cfg.replace(selected)