# app.py (revision ce7b81a, user: xiaoxuezi)
from torch import optim
import argparse
from datetime import datetime
import wandb
import torch.backends.cudnn as cudnn
from torch import optim
from torch.utils.data import DataLoader
from torchinfo import summary
from timm.scheduler.cosine_lr import CosineLRScheduler
import lossfunction
import net
from DatasetLoader import *
from dataloader import TrainDataset
from SpeakerNet import *
from config import set_cfg, cfg
def get_args(argv=None):
    """Parse the command-line options of this script.

    Most options default to ``None``; presumably ``cfg.replace`` skips those so
    the loaded config keeps its own value -- TODO confirm against
    ``config.replace``.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_name", type=str, default="", help="the configs name that will as a base configs")
    parser.add_argument("--project", default=None, type=str, help="project name")
    parser.add_argument("--name", default=None, type=str, help="run name")
    parser.add_argument("--save_dir", default=None, type=str, help="save path")
    parser.add_argument("--resume", default=None, type=str, help="resume path")
    parser.add_argument("--dataset", default=None, type=str, help="dataset path")
    # NOTE(review): main() reads cfg.max_epoch and cfg.test_interval, but these
    # two flags are stored as `epoch` and `test_freq`; confirm cfg.replace maps
    # them, otherwise passing them on the command line has no effect.
    parser.add_argument("--epoch", default=None, type=int, help="max epoch")
    parser.add_argument("--test_freq", default=None, type=int, help="frequency test epoch")
    parser.add_argument("--batch_size", default=None, type=int, help="batch size")
    parser.add_argument("--lr", default=None, type=float, help="learning rate")
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--wandb", action='store_true', default=False, help='use wandb to log ')
    parser.add_argument("--note", type=str, default="", help='wandb note')
    parser.add_argument('--eval', dest='eval', action='store_true', default=False, help='Eval only')
    # The help text previously duplicated --eval's 'Eval only'; --score actually
    # computes trial scores and writes them to a file (see main()).
    parser.add_argument('--score', dest='score', action='store_true', default=False,
                        help='Score only: write trial scores to a file')
    return parser.parse_args(argv)
def _build_train_loader():
    """Build the training DataLoader described by the module-level ``cfg``.

    Returns:
        torch.utils.data.DataLoader over ``train_dataset_loader`` whose order
        is driven by ``train_dataset_sampler``; the final incomplete batch is
        dropped so every batch has ``cfg.batch_size`` samples.
    """
    train_dataset = train_dataset_loader(train_list=cfg.dataset.train_list,
                                         augment=cfg.augment,
                                         musan_path=cfg.dataset.musan_path,
                                         rir_path=cfg.dataset.rir_path,
                                         max_frames=cfg.max_frames,
                                         segment=cfg.segment,
                                         train_path=cfg.dataset.train_path)
    train_sampler = train_dataset_sampler(train_dataset,
                                          nPerSpeaker=cfg.nPerSpeaker,
                                          max_seg_per_spk=cfg.max_seg_per_spk,
                                          batch_size=cfg.batch_size,
                                          seed=cfg.seed)
    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.nDataLoaderThread,
        sampler=train_sampler,  # the sampler fixes the order, so no shuffle
        pin_memory=False,
        drop_last=True,
    )


def _load_checkpoint(model, path, device):
    """Load ``ckpt['model_state_dict']`` from the checkpoint at *path* into *model*.

    Only the weights are restored. The checkpoint's
    ``optimizer_state_dict``, ``scheduler_state_dict`` and ``epoch`` entries
    are not applied.
    """
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    ckpt = torch.load(path, map_location=device)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    incompatible = model.load_state_dict(ckpt['model_state_dict'], strict=False)
    if incompatible is not None:
        if incompatible.missing_keys:
            print("missing keys:", incompatible.missing_keys)
        if incompatible.unexpected_keys:
            print("unexpected keys:", incompatible.unexpected_keys)
    print("checkpoint加载完毕!")


def main():
    """Configure the run, build the model, then score, evaluate, or train.

    * ``--score``: write trial scores to ``score/<name>_<YYYYMMDD>/scores.txt``.
    * ``--eval``: run a single test pass.
    * otherwise: train for epochs ``1..cfg.max_epoch`` (inclusive). The model
      is tested every ``cfg.test_interval`` epochs, which steps the LR
      scheduler, and a checkpoint is saved after every epoch.
    """
    args = get_args()
    # --config_name defaults to "", in which case the defaults already held by
    # ``cfg`` are kept (the former ``is not None`` assert could never fail).
    if args.config_name:
        set_cfg(args.config_name)
    cfg.replace(vars(args))
    del args

    # One output directory per <project>_<name> and per day.
    cfg.save_dir = os.path.join(cfg.save_dir, cfg.project + '_' + cfg.name,
                                datetime.now().strftime('%Y%m%d'))
    os.makedirs(cfg.save_dir, exist_ok=True)

    if cfg.wandb:
        # NOTE(security): the wandb server address and API key are committed in
        # source. Move them to the environment (wandb reads WANDB_API_KEY /
        # WANDB_BASE_URL) and rotate this key.
        wandb.login(host="http://49.233.11.7:8080", key="local-7dc64cc63778f0723dc202d2624a97cef7043120")
        wandb.init(project=cfg.project, name=cfg.name, config=cfg, save_code=True, notes=cfg.note)

    # cudnn autotuning: faster kernels, but runs are not bit-for-bit reproducible.
    cudnn.benchmark = True
    cudnn.deterministic = False
    cudnn.enabled = True

    start_epoch = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # --------------- model ---------------
    # Backbone and loss head are looked up by name in the `net` and
    # `lossfunction` modules, then combined by SpeakerNet.
    encoder = getattr(net, cfg.model)().to(device)
    loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
    model = SpeakerNet(model=encoder, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker)

    # AdamW settings labelled "swin" by the author (presumably the Swin
    # Transformer recipe).
    optimizer = optim.AdamW(model.parameters(), eps=1e-8, betas=(0.9, 0.999),
                            lr=cfg.lr, weight_decay=0.05)
    # Multiply the LR by 0.1 once the value passed to scheduler.step() (the EER,
    # lower is better) has failed to improve by 0.1% (relative) for more than
    # 5 steps; never go below 1e-5.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
                                                     threshold=0.001, threshold_mode='rel',
                                                     cooldown=0, min_lr=1e-5, eps=1e-08, verbose=True)

    if cfg.resume:
        # Weights only: training still starts at start_epoch with a fresh
        # optimizer and scheduler.
        _load_checkpoint(model, cfg.resume, device)

    trainer = Trainer(cfg, model, optimizer, scheduler, device)

    if cfg.score:
        # --------------- score ---------------
        # The VoxSRC-2021 blind trial list and audio directory are hard-coded;
        # pass cfg.dataset.test_list / cfg.dataset.test_path to score the
        # regular test set instead.
        score_dir = os.path.join('score', cfg.name + "_" + datetime.now().strftime('%Y%m%d'))
        os.makedirs(score_dir, exist_ok=True)
        score_file = os.path.join(score_dir, 'scores.txt')
        trainer.scoretxt(score_file, 'data/voxsrc2021_blind.txt', 'data/voxsrc2021', cfg.eval_frames)
    elif cfg.eval:
        # --------------- eval ---------------
        trainer.test(0, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread, cfg.eval_frames)
    else:
        # --------------- train ---------------
        train_loader = _build_train_loader()
        # Sanity check: fetch one batch to print input shapes/dtypes and a model
        # summary. DataLoader iterators have no .next() method in current
        # PyTorch, so the builtin next() is used.
        x, y = next(iter(train_loader))
        print('x.shape:', x.shape, 'y.shape:', y.shape)
        print('x.dtype:', x.dtype, 'y.dtype:', y.dtype)
        summary(model, input_size=tuple(x.shape))

        # max_epoch is inclusive: the old range(start_epoch, max_epoch) never
        # ran (or tested) the final epoch.
        for epoch in range(start_epoch, cfg.max_epoch + 1):
            trainer.train(epoch, train_loader)
            if epoch % cfg.test_interval == 0:
                eer = trainer.test(epoch, cfg.dataset.test_list, cfg.dataset.test_path,
                                   cfg.nDataLoaderThread, cfg.eval_frames)
                scheduler.step(eer)
            # NOTE(review): a checkpoint is written every epoch; confirm this
            # should not be limited to test epochs.
            trainer.save_model(epoch)
    print("finishing")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()