# config.py — training configurations for speaker-verification models.
# (uploaded by xuesongyan, commit ee4b9b7)
class Config(object):
    """A lightweight attribute container built from a dict.

    Keys whose value is ``None`` are skipped everywhere in this class, so
    ``None`` consistently means "leave the existing/default value alone".
    """

    def __init__(self, config_dict: dict):
        """Copy every non-None entry of *config_dict* onto this instance."""
        for key, val in config_dict.items():
            if val is not None:
                # setattr() is the idiomatic spelling of self.__setattr__().
                setattr(self, key, val)

    def copy(self, new_config_dict=None):
        """Return a new Config with this one's values, overridden by the
        non-None entries of *new_config_dict*.

        The default used to be a mutable ``{}``; it is now ``None`` to avoid
        the shared-mutable-default pitfall. Passing ``None`` (or nothing)
        yields a plain copy, exactly as before.
        """
        ret = Config(vars(self))
        if new_config_dict:
            for key, val in new_config_dict.items():
                if val is not None:
                    setattr(ret, key, val)
        return ret

    def replace(self, new_config_dict):
        """Update this Config in place from a dict or another Config.

        ``None`` values are ignored, so existing attributes survive.
        """
        if isinstance(new_config_dict, Config):
            new_config_dict = vars(new_config_dict)
        for key, val in new_config_dict.items():
            if val is not None:
                setattr(self, key, val)

    def print(self):
        """Print every attribute as ``key = value``, one per line.

        (Method name intentionally shadows the builtin; kept for
        backward compatibility with existing callers.)
        """
        for k, v in vars(self).items():
            print(k, '=', v)

    def __str__(self):
        return str(vars(self))
# Base (default) configuration. Every model-specific config below only
# overrides a subset of these keys via set_cfg() / Config.replace().
base_config = Config({
    "project": "speaker_verification",
    "name": "VGGVox",
    "save_dir": "train_models/",
    "resume": "",
    # Training and test data
    "dataset": Config({
        "name": "voxceleb2_wav",
        "train_list": "data/train_list.txt",
        "test_list": "data/veri_list.txt",
        "train_path": "data/voxceleb2",
        "test_path": "data/voxceleb1",
        "musan_path": "data/musan_split",  # noise files (MUSAN)
        "rir_path": "data/RIRS_NOISES/simulated_rirs",  # reverberation (RIR) files
    }),
    # Data loader
    "max_frames": 300,  # frame length used during training
    "eval_frames": 300,
    "batch_size": 64,
    "max_seg_per_spk": 500,  # max number of speech segments per speaker
    "nDataLoaderThread": 16,  # number of data-loading worker threads
    "augment": True,  # whether to apply data augmentation
    "seed": 10,
    "segment": 1,
    # Training details
    "test_interval": 1,  # evaluate every N epochs
    "max_epoch": 500,
    # Model definition
    "n_mels": 40,
    "log_input": False,
    "model": "Vgg",
    "encoder_type": "SAP",
    "nOut": 512,
    # Loss functions
    "loss": "SoftmaxProto",  # loss function
    "hard_prob": 0.5,
    "hard_rank": 10,
    "margin": 0.2,
    "scale": 30,
    "nPerSpeaker": 2,  # utterances drawn from the same speaker per batch item
    "nClasses": 5994,
    # Optimizer
    "optimizer": "adam",
    "scheduler": "steplr",
    "lr": 0.001,
    "lr_decay": 0.95,
    "weight_decay": 0,
    # Evaluation parameters
    "dcf_p_target": 0.05,
    "dcf_c_miss": 1,
    "dcf_c_fa": 1,
    # eval
    "eval": False,
})
# The active configuration; mutated in place by set_cfg().
cfg = base_config
# --- Per-model configurations -----------------------------------------------
# Each config lists only the keys that differ from base_config; the remaining
# values are inherited when the config is merged via Config.replace()/copy().
vgg_cfg = Config({
    "name": "vgg_spectrogram",
    "model": "vgg",
    "batch_size": 64,
    "nPerSpeaker": 2,
})
Unet_cfg = Config({
    "name": "Unet",
    "model": "UNetVgg",
    "batch_size": 48,
    "nPerSpeaker": 2,
    "loss": "Unetloss"
})
UnetMask_cfg = Config({
    "name": "UnetMask",
    "model": "UNetVggMask",
    "batch_size": 16,
    "nPerSpeaker": 2,
    "segment": 3,
    "loss": "UnetMaskloss"
})
ECAPA_TDNN_cfg = Config({
    # NOTE(review): this "name" duplicates ECAPA_TDNNm_cfg's below, so both
    # runs share the same save directory — confirm this is intentional.
    "name": "ECAPA_TDNNm",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})
ECAPA_TDNNm_cfg = Config({
    "name": "ECAPA_TDNNm",
    "model": "ECAPA_TDNN",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})
ECAPA_TDNN1024_cfg = Config({
    "name": "ECAPA_TDNN1024",
    "model": "ECAPA_TDNN",
    "batch_size": 80,
    "nPerSpeaker": 2,
    "channels": 1024,  # presumably the TDNN channel width — confirm in the model
    "nOut": 192,
})
ECAPA_TDNN_ks5_cfg = Config({
    "name": "ECAPA_TDNN_ks5",
    "model": "ECAPA_TDNN_ks5",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})
ECAPA_TDNN_L2_cfg = Config({
    "name": "ECAPA_TDNN_L2_pre",
    "model": "ECAPA_TDNN_L2",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    # Resume training from a previously saved checkpoint.
    "resume": "train_models/speaker_verification_ECAPA_TDNN/20210725/epoch:47,EER:2.5981,MinDCF:0.1912"
})
ECAPA_TDNN_br_cfg = Config({
    "name": "ECAPA_TDNN_br",
    "model": "ECAPA_TDNN_br",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
})
ECAPATDNN_cfg = Config({
    "name": "ECAPATDNN",
    "model": "ECAPATDNN",
    "batch_size": 110,
    "nPerSpeaker": 2,
    "nOut": 192,
    "input_size": 80,
})
# HRNet backbone configuration. The nested model_cfg mirrors the HRNet
# stage specification (modules / branches / block type / channels per stage).
HRNet_cfg = Config({
    "name": "hrnet",
    "model": "hrnet",
    "max_frames": 224,  # 224x224 input, matching the image-style HRNet stem
    "eval_frames": 224,
    "batch_size": 48,
    "nPerSpeaker": 2,
    "nOut": 1024,
    "input_size": 224*224,
    "model_cfg": Config({
        "hrnet_name": "w48",
        "STAGE1": {
            "NUM_MODULES": 1,
            "NUM_BRANCHES": 1,
            "BLOCK": "BOTTLENECK",
            "NUM_BLOCKS": [4],
            "NUM_CHANNELS": [64],
            "FUSE_METHOD": "SUM"
        },
        "STAGE2": {
            "NUM_MODULES": 1,
            "NUM_BRANCHES": 2,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4],
            "NUM_CHANNELS": [18, 36],
            "FUSE_METHOD": "SUM"
        },
        "STAGE3": {
            "NUM_MODULES": 4,
            "NUM_BRANCHES": 3,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4, 4],
            "NUM_CHANNELS": [18, 36, 72],
            "FUSE_METHOD": "SUM"
        },
        "STAGE4": {
            "NUM_MODULES": 3,
            "NUM_BRANCHES": 4,
            "BLOCK": "BASIC",
            "NUM_BLOCKS": [4, 4, 4, 4],
            "NUM_CHANNELS": [18, 36, 72, 144],
            "FUSE_METHOD": "SUM"
        },
    }),
})
# --- TDNN / ResNet hybrid configurations ------------------------------------
VGG_TDNN_cfg = Config({
    "name": "Vggtdnn1",
    "model": "Vggtdnn",
    "batch_size": 256,
    "nOut": 512,
    "nDataLoaderThread": 16,
})
ResNetSE34V2_cfg = Config({
    "name": "ResNetSE34V2",
    "model": "ResNetSE34V2",
    "batch_size": 128,
    "nOut": 512,
    "nDataLoaderThread": 16,
})
# HR-TDNN: HRNet-style multi-branch layout built from TDNN blocks.
HRTDNN_cfg = Config({
    "name": "hrtdnn",
    "model": "hrtdnn",
    "max_frames": 300,
    "eval_frames": 300,
    "batch_size": 96,
    "nPerSpeaker": 2,
    "nOut": 256,
    "model_cfg": Config({
        "hrnet_name": "hrtdnn",
        "STAGE1": {
            "NUM_BRANCHES": 1,
            "BLOCK": 'TDNNBlock',
            "NUM_CHANNELS": [128],
            "FUSE_METHOD": "SUM"
        },
        "STAGE2": {
            "NUM_BRANCHES": 2,
            "BLOCK": 'TDNNBlock',
            "NUM_CHANNELS": [128, 512],
            "FUSE_METHOD": "SUM"
        },
        "STAGE3": {
            "NUM_BRANCHES": 3,
            "BLOCK": 'TDNNBlock',
            "NUM_CHANNELS": [128, 512, 1024],
            "FUSE_METHOD": "SUM"
        },
    }),
})
ResTDNN_cfg = Config({
    "name": "ResTDNN",
    "model": "ResTDNN",
    "batch_size": 110,
    "nOut": 256,
    "nDataLoaderThread": 16,
})
TDNN_VGG_cfg = Config({
    "name": "TDNN_VGG",
    "model": "TDNN_VGG",
    "batch_size": 64,
    "nOut": 256,
    "nDataLoaderThread": 16,
})
ResNet_TDNN_cfg = Config({
    "name": "ResNet_TDNN",
    "model": "ResNet_TDNN",
    "batch_size": 96,
    "nOut": 192,
    "nDataLoaderThread": 16,
})
ResNet_TDNNa_cfg = Config({
    "name": "ResNet_TDNNa",
    "model": "ResNet_TDNN",
    "batch_size": 96,
    "nOut": 192,
    "nDataLoaderThread": 16,
})
# ResNet_TDNN trained with the AAM-Softmax + prototypical loss.
ResNet_TDNNaam_cfg = Config({
    "name": "ResNet_TDNNaam",
    "model": "ResNet_TDNN",
    "loss": "AamSoftmaxProto",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 96,
    "nOut": 192,
    "nDataLoaderThread": 16,
    "augment": True,
})
TDNN_ResNet_cfg = Config({
    "name": "TDNN_ResNet",
    "model": "TDNN_ResNet",
    "batch_size": 48,
    "nOut": 256,
    "nDataLoaderThread": 16,
})
hr_tdnn_cfg = Config({
    "name": "hr_tdnn",
    "model": "hr_tdnn",
    "batch_size": 46,
    "nOut": 192,
    "nDataLoaderThread": 16,
})
# --- ECAPA_TDNN loss / margin / scale ablation configurations ---------------
ECAPA_TDNNma_cfg = Config({
    "name": "ECAPA_TDNNma",
    "model": "ECAPA_TDNN",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNNaam_cfg = Config({
    "name": "ECAPA_TDNNaam",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nPerSpeaker": 1,  # single utterance per speaker: plain classification loss
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNNaam1_cfg = Config({
    "name": "ECAPA_TDNNaam1",
    "model": "ECAPA_TDNN",
    "loss": "AdditiveAngularMargin",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNNaam2_cfg = Config({
    "name": "ECAPA_TDNNaam2",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNNaam3_cfg = Config({
    "name": "ECAPA_TDNNaam3",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.1,  # smaller angular margin than aam2
    "scale": 30,
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNN_aamproto_cfg = Config({
    "name": "ECAPA_TDNN_aamproto",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNN_aamproto1_cfg = Config({
    "name": "ECAPA_TDNN_aamproto1",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 180,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
})
ECAPA_TDNN0_cfg = Config({
    "name": "ECAPA_TDNN-1lr0.001",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nOut": 192,
    "nPerSpeaker": 1,
    # Resume training from a previously saved checkpoint.
    "resume": "train_models/speaker_verification_ECAPA_TDNN0/20210928/epoch:25,EER:2.4125,MinDCF:0.1537",
})
SwinTransformer_cfg = Config({
    "name": "SwinTransformer",
    "model": "SwinTransformer",
    "loss": "SoftmaxProto",
    "max_frames": 224,  # square 224x224 mel-spectrogram input
    "eval_frames": 224,
    "n_mels": 224,
    "batch_size": 90,
    "nPerSpeaker": 2,
    "nOut": 192,
    "augment": True,
    "lr": 5e-5,
})
ECAPA_TDNN_aampre_cfg = Config({
    "name": "ECAPA_TDNN_aampre",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmaxProto",
    "batch_size": 180,
    "nOut": 192,
    "nPerSpeaker": 2,
    # Fine-tune starting from the pre-trained ECAPA_TDNNma checkpoint.
    "resume": "train_models/speaker_verification_ECAPA_TDNNma/20210908/epoch:67,EER:2.3224,MinDCF:0.1658",
})
# Variant with a different (replaced) dataloader.
ECAPA_TDNN_data_cfg = Config({
    "name": "ECAPA_TDNN_data",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
# Standard ECAPA_TDNN with a CyclicLR learning-rate schedule.
ECAPA_TDNNaam_cyclr_cfg = Config({
    "name": "ECAPA_TDNNaam_cyclr",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
# ResNet_TDNN with the replaced data loading, softmax-family loss only.
ResNet_TDNNaam_data_cfg = Config({
    "name": "ResNet_TDNNaam_data",
    "model": "ResNet_TDNN",
    "loss": "AamSoftmax",
    "margin": 0.2,
    "scale": 30,
    "batch_size": 192,
    "nOut": 192,
    "nDataLoaderThread": 16,
    "nPerSpeaker": 1,
    "augment": True,
})
# Replaced dataloader combined with the cyclical learning rate.
ECAPA_TDNN_dataClr_cfg = Config({
    "name": "ECAPA_TDNN_dataClr",
    "model": "ECAPA_TDNN",
    "loss": "AamSoftmax",
    "batch_size": 360,
    "nPerSpeaker": 1,
    "nOut": 192,
    "augment": True,
})
def set_cfg(config_name: str):
    """Set the active config. Works even if cfg is already imported!

    *config_name* is usually the name of one of the module-level Config
    objects above; it may also be a Python expression such as
    ``"ssd300_config.copy({'max_size': 400})"`` for extreme fine-tuning.
    """
    global cfg
    # Fast path: a plain config name defined in this module needs no eval().
    if config_name in globals():
        cfg.replace(globals()[config_name])
        return
    # Note this is not just an eval because I'm lazy, but also because it can
    # be used like ssd300_config.copy({'max_size': 400}) for extreme fine-tuning.
    # SECURITY: eval() executes arbitrary code — only call this with trusted
    # strings, never with unvalidated external input.
    cfg.replace(eval(config_name))