import argparse
import os
from datetime import datetime

import torch
import torch.backends.cudnn as cudnn
import wandb
from torch import optim
from torch.utils.data import DataLoader
from torchinfo import summary

import lossfunction
import net
from DatasetLoader import *
from dataloader import TrainDataset
from SpeakerNet import *
from config import set_cfg, cfg


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_name", type=str, default="", help="name of the config to use as the base config")
    parser.add_argument("--project", default=None, type=str, help="project name")
    parser.add_argument("--name", default=None, type=str, help="run name")
    parser.add_argument("--save_dir", default=None, type=str, help="save path")
    parser.add_argument("--resume", default=None, type=str, help="resume path")
    parser.add_argument("--dataset", default=None, type=str, help="dataset path")
    parser.add_argument("--epoch", default=None, type=int, help="max epoch")
    parser.add_argument("--test_freq", default=None, type=int, help="test every N epochs")
    parser.add_argument("--batch_size", default=None, type=int, help="batch size")
    parser.add_argument("--lr", default=None, type=float, help="learning rate")
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--wandb", action='store_true', default=False, help='use wandb for logging')
    parser.add_argument("--note", type=str, default="", help='wandb note')
    parser.add_argument('--eval', dest='eval', action='store_true', default=False, help='eval only')
    parser.add_argument('--score', dest='score', action='store_true', default=False, help='score only')
    args = parser.parse_args()
    return args


def main():
    global cfg
    args = get_args()
    assert args.config_name, "--config_name is required"
    set_cfg(args.config_name)
    cfg.replace(vars(args))
    del args

    cfg.save_dir = os.path.join(cfg.save_dir, cfg.project + '_' + cfg.name, datetime.now().strftime('%Y%m%d'))
    if not os.path.exists(cfg.save_dir):
        os.makedirs(cfg.save_dir)

    if cfg.wandb:
        wandb.login(host="http://49.233.11.7:8080", key="local-7dc64cc63778f0723dc202d2624a97cef7043120")
        wandb.init(project=cfg.project, name=cfg.name, config=cfg, save_code=True, notes=cfg.note)

    # cudnn-related settings
    cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.enabled = True

    start_epoch = 1

    # ---------------Model---------------
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    # model = getattr(net, cfg.model)(cfg.nOut, cfg.encoder_type, cfg.log_input).to(device)
    # ------ECAPA_TDNN.yaml------ResNet_TDNN----
    model = getattr(net, cfg.model)().to(device)
    # loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses, cfg.margin, cfg.scale).to(device)
    # ----aamsoftmax----
    loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
    # model = SpeakerUnet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker, segment=cfg.segment)
    model = SpeakerNet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker)

    # swin
    optimizer = optim.AdamW(model.parameters(), eps=1e-8, betas=(0.9, 0.999), lr=cfg.lr, weight_decay=0.05)
    # optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=0.000002)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50, 70], gamma=0.1, last_epoch=-1)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, threshold=0.001,
                                                     threshold_mode='rel', cooldown=0, min_lr=1e-5,
                                                     eps=1e-08, verbose=True)
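    # Note: ReduceLROnPlateau adjusts the LR from a monitored metric rather than
    # the epoch count. The matching scheduler.step(eer) call in the training loop
    # below feeds it the test EER, so the LR is cut by `factor` (10x here) after
    # `patience` evaluations without sufficient improvement.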
    # scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg.lr, max_lr=0.003, mode='triangular2',
    #                                         step_size_up=12000, cycle_momentum=False)

    if cfg.resume:
        # ckpt = torch.load(cfg.resume, map_location="cpu")
        ckpt = torch.load(cfg.resume)
        model.load_state_dict(ckpt['model_state_dict'], strict=False)
        # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        # start_epoch = ckpt['epoch'] + 1
        print("Checkpoint loaded!")

    # print(model)
    # test, eval, train
    trainer = Trainer(cfg, model, optimizer, scheduler, device)

    # ---------------Score--------------
    if cfg.score:
        score_dir = os.path.join('score', cfg.name + "_" + datetime.now().strftime('%Y%m%d'))
        if not os.path.exists(score_dir):
            os.makedirs(score_dir)
        score_file = os.path.join(score_dir, 'scores.txt')
        trainer.scoretxt(score_file, 'data/voxsrc2021_blind.txt', 'data/voxsrc2021', cfg.eval_frames)
        # trainer.scoretxt(score_file, cfg.dataset.test_list, cfg.dataset.test_path, cfg.eval_frames)
    # ---------------Eval--------------
    elif cfg.eval:
        trainer.test(0, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread, cfg.eval_frames)
    # ---------------Train--------------
    else:
        train_dataset = train_dataset_loader(train_list=cfg.dataset.train_list,
                                             augment=cfg.augment, musan_path=cfg.dataset.musan_path,
                                             rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
                                             segment=cfg.segment, train_path=cfg.dataset.train_path)
        train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=cfg.nPerSpeaker,
                                              max_seg_per_spk=cfg.max_seg_per_spk,
                                              batch_size=cfg.batch_size, seed=cfg.seed)
        # train_dataset = TrainDataset(train_list=cfg.dataset.train_list,
        #                              augment=cfg.augment, musan_path=cfg.dataset.musan_path,
        #                              rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
        #                              train_path=cfg.dataset.train_path)
        train_loader = DataLoader(
            train_dataset,
            batch_size=cfg.batch_size,
            num_workers=cfg.nDataLoaderThread,
            sampler=train_sampler,
            pin_memory=False,
            drop_last=True,
        )

        # Sanity-check one batch and print the model summary.
        x, y = next(iter(train_loader))
        print('x.shape:', x.shape, 'y.shape:', y.shape)
        print('x.dtype:', x.dtype, 'y.dtype:', y.dtype)
        summary(model, input_size=tuple(x.shape))

        it = 0
        min_eer = float("inf")
        # Epochs are 1-indexed, so include cfg.max_epoch itself.
        for epoch in range(start_epoch, cfg.max_epoch + 1):
            trainer.train(epoch, train_loader)
            if epoch % cfg.test_interval == 0:
                eer = trainer.test(epoch, cfg.dataset.test_list, cfg.dataset.test_path,
                                   cfg.nDataLoaderThread, cfg.eval_frames)
                scheduler.step(eer)
                # # -----Clr------
                # if eer < min_eer:
                #     min_eer = eer
                #     it = 0
                # else:
                #     it += 1
                #
                # if it >= 8:
                #     lr = cfg.lr * 0.1
                #     trainer.scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr, max_lr=cfg.lr,
                #                                                     mode='triangular2',
                #                                                     step_size_up=6000, cycle_momentum=False)
                #     it = 0
                # # -----Clr------
            trainer.save_model(epoch)
    print("finished")


if __name__ == "__main__":
    main()
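# Example invocation (a sketch; the script filename and config name below are
# assumptions -- any config registered via set_cfg works):
#   python train.py --config_name ECAPA_TDNN --project voxceleb --name baseline \
#       --batch_size 128 --lr 0.001 --wandb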