Spaces:
Runtime error
Runtime error
| from torch import optim | |
| import argparse | |
| from datetime import datetime | |
| import wandb | |
| import torch.backends.cudnn as cudnn | |
| from torch import optim | |
| from torch.utils.data import DataLoader | |
| from torchinfo import summary | |
| from timm.scheduler.cosine_lr import CosineLRScheduler | |
| import lossfunction | |
| import net | |
| from DatasetLoader import * | |
| from dataloader import TrainDataset | |
| from SpeakerNet import * | |
| from config import set_cfg, cfg | |
def get_args(argv=None):
    """Parse the command-line options of the train / eval / score script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so a bare
            ``get_args()`` call behaves exactly as before.

    Returns:
        argparse.Namespace: the parsed options. Every option that defaults to
        ``None`` is left as ``None`` when not given on the command line.
    """
    parser = argparse.ArgumentParser()

    # Base configuration and run identification.
    parser.add_argument("--config_name", type=str, default="",
                        help="the configs name that will as a base configs")
    parser.add_argument("--project", default=None, type=str, help="project name")
    parser.add_argument("--name", default=None, type=str, help="run name")
    parser.add_argument("--save_dir", default=None, type=str, help="save path")
    parser.add_argument("--resume", default=None, type=str, help="resume path")
    parser.add_argument("--dataset", default=None, type=str, help="dataset path")

    # Training hyper-parameters.
    # NOTE(review): main() reads cfg.max_epoch and cfg.test_interval, while the
    # options below are stored as `epoch` and `test_freq` -- confirm that
    # cfg.replace() maps them, otherwise these two flags have no effect.
    parser.add_argument("--epoch", default=None, type=int, help="max epoch")
    parser.add_argument("--test_freq", default=None, type=int, help="frequency test epoch")
    parser.add_argument("--batch_size", default=None, type=int, help="batch size")
    parser.add_argument("--lr", default=None, type=float, help="learning rate")
    parser.add_argument("--seed", default=None, type=int)

    # Logging.
    parser.add_argument("--wandb", action='store_true', default=False, help='use wandb to log ')
    parser.add_argument("--note", type=str, default="", help='wandb note')

    # Run-mode switches.
    parser.add_argument('--eval', dest='eval', action='store_true', default=False, help='Eval only')
    # The help text used to be a copy of the --eval one.
    parser.add_argument('--score', dest='score', action='store_true', default=False, help='Score only')

    return parser.parse_args(argv)
def main():
    """Configure the run, build the model, then score, evaluate or train.

    The mode is picked from the merged configuration: ``cfg.score`` writes a
    score file, ``cfg.eval`` runs one test pass, and otherwise the model is
    trained with a periodic test and a checkpoint save.
    """
    global cfg

    # ---------------configuration---------------
    args = get_args()
    if args.config_name:
        set_cfg(args.config_name)
    # Presumably overlays the command-line values on top of the base
    # configuration -- confirm against config.py.
    cfg.replace(vars(args))
    del args

    # Outputs go to <save_dir>/<project>_<name>/<YYYYMMDD>.
    cfg.save_dir = os.path.join(cfg.save_dir, cfg.project + '_' + cfg.name, datetime.now().strftime('%Y%m%d'))
    os.makedirs(cfg.save_dir, exist_ok=True)

    if cfg.wandb:
        # NOTE(review): the server address and API key are committed in source;
        # rotate the key and load it from the environment or a secrets store.
        wandb.login(host="http://49.233.11.7:8080", key="local-7dc64cc63778f0723dc202d2624a97cef7043120")
        wandb.init(project=cfg.project, name=cfg.name, config=cfg, save_code=True, notes=cfg.note)

    # cudnn related setting
    cudnn.benchmark = True
    cudnn.deterministic = False
    cudnn.enabled = True

    # Only the model weights are restored on resume (see below), so training
    # always starts counting from epoch 1.
    start_epoch = 1

    # ---------------model---------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    backbone = getattr(net, cfg.model)().to(device)
    criterion = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
    model = SpeakerNet(model=backbone, trainfunc=criterion, nPerSpeaker=cfg.nPerSpeaker)

    optimizer = optim.AdamW(model.parameters(), eps=1e-8, betas=(0.9, 0.999),
                            lr=cfg.lr, weight_decay=0.05)
    # Stepped below with the value returned by trainer.test (mode='min').
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
                                                     threshold=0.001, threshold_mode='rel',
                                                     cooldown=0, min_lr=1e-5, eps=1e-08, verbose=True)

    if cfg.resume:
        # map_location lets a checkpoint saved on a GPU load on the CPU
        # fallback selected above.
        ckpt = torch.load(cfg.resume, map_location=device)
        # strict=False tolerates missing / unexpected keys; optimizer,
        # scheduler and epoch state are intentionally not restored here.
        model.load_state_dict(ckpt['model_state_dict'], strict=False)
        print("checkpoint加载完毕!")

    trainer = Trainer(cfg, model, optimizer, scheduler, device)

    if cfg.score:
        # ---------------score--------------
        score_dir = os.path.join('score', cfg.name+"_"+datetime.now().strftime('%Y%m%d'))
        os.makedirs(score_dir, exist_ok=True)
        score_file = os.path.join(score_dir, 'scores.txt')
        trainer.scoretxt(score_file, 'data/voxsrc2021_blind.txt', 'data/voxsrc2021', cfg.eval_frames)
    elif cfg.eval:
        # ---------------eval--------------
        trainer.test(0, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread, cfg.eval_frames)
    else:
        # ---------------train--------------
        train_dataset = train_dataset_loader(train_list=cfg.dataset.train_list,
                                             augment=cfg.augment, musan_path=cfg.dataset.musan_path,
                                             rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
                                             segment=cfg.segment, train_path=cfg.dataset.train_path)
        train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=cfg.nPerSpeaker,
                                              max_seg_per_spk=cfg.max_seg_per_spk, batch_size=cfg.batch_size,
                                              seed=cfg.seed)
        train_loader = DataLoader(
            train_dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.nDataLoaderThread,
            sampler=train_sampler,
            pin_memory=False,
            drop_last=True,
        )

        # Pull one batch only to report its shape and feed the model summary.
        # The builtin next() is used because Python 3 iterators have no
        # guaranteed `.next()` method.
        x, y = next(iter(train_loader))
        print('x.shape:', x.shape, 'y.shape:', y.shape)
        print('x.dtype:', x.dtype, 'y.dtype:', y.dtype)
        summary(model, input_size=(tuple(x.shape)))

        # NOTE(review): range() excludes cfg.max_epoch, so with start_epoch = 1
        # this runs max_epoch - 1 epochs -- confirm that is intended.
        for epoch in range(start_epoch, cfg.max_epoch):
            trainer.train(epoch, train_loader)
            if epoch % cfg.test_interval == 0:
                eer = trainer.test(epoch, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread,
                                   cfg.eval_frames)
                # The plateau scheduler can only be stepped when a fresh
                # metric exists, i.e. on test epochs.
                scheduler.step(eer)
            # NOTE(review): the original indentation was lost; saving after
            # every epoch is assumed -- confirm it was not test-epoch only.
            trainer.save_model(epoch)

    # NOTE(review): placement assumed at function level (printed in every mode).
    print("finishing")
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()