import os

import torch
import tqdm
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, RandomSampler, Dataset
from torch.cuda.amp import autocast, GradScaler

from metrics import getAcc
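# `getAcc` comes from the project's own metrics module; its contract is inferred
# from the single call site in `evaluate` below (an assumption, not verified
# against its source): it takes (preds, targets, lengths) CPU tensors and
# returns a (word_accuracy, char_accuracy) pair.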


class Trainer:
    def __init__(self, args, tb_logger, logger):
        self.args = args
        self.gpu = torch.device(args.gpu)
        self.model = None
        self.it = 0
        self.best_eval_acc, self.best_it = 0.0, 0
        # datasets and dataloaders
        self.trainDataset = None
        self.trainDataloader = None
        self.evalDataset = None
        self.evalDataloader = None
        # optimizer and scheduler
        self.scheduler = None
        self.optimizer = None
        # loss
        self.loss_fn = None
        self.weight = None
        self.setLoss(args.loss)
        self.ignore_index = args.model["letter_size"]
        # gradient clipping
        if args.clip_grad is not None:
            self.clip_grad = True
            self.clip_value = args.clip_grad
        else:
            self.clip_grad = False
        if hasattr(args, "label_smoothing") and args.label_smoothing is not None:
            self.label_smoothing = float(args.label_smoothing)
        else:
            self.label_smoothing = 0.0
        # logging: always assign tb_log (even if None) so the
        # `if self.tb_log is not None` check in train() cannot raise AttributeError
        self.tb_log = tb_logger
        self.print_fn = print if logger is None else logger.info
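
    # Expected `args` attributes, inferred purely from how they are read in this
    # file (an assumption, not an authoritative schema): gpu, loss (dict with
    # "name" and "label_smoothing"), model (dict with "letter_size"), clip_grad,
    # label_smoothing (optional), resume, num_eval_iter, save_path, batch_size,
    # eval_batch_size, iter, num_workers.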

    def train(self):
        """
        Train the model.
        """
        self.model.train()
        # CUDA events for GPU profiling
        start_batch = torch.cuda.Event(enable_timing=True)
        end_batch = torch.cuda.Event(enable_timing=True)
        start_run = torch.cuda.Event(enable_timing=True)
        end_run = torch.cuda.Event(enable_timing=True)
        scaler = GradScaler()
        start_batch.record()
        # evaluate once before resuming training
        if self.args.resume:
            eval_dict = self.evaluate()
            print(eval_dict)
        tbar = tqdm.tqdm(total=len(self.trainDataloader), colour='BLUE')
        for samples, targets, _ in self.trainDataloader:
            tbar.update(1)
            self.it += 1
            end_batch.record()
            torch.cuda.synchronize()
            start_run.record()
            samples, targets = samples.to(self.gpu), targets.to(self.gpu).long()
            with autocast():
                logits = self.model(samples)
                loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(),
                                       ignore_index=self.ignore_index,
                                       label_smoothing=self.label_smoothing)
            scaler.scale(loss).backward()
            if self.clip_grad:
                # unscale gradients first so the clipping threshold applies to
                # their true (unscaled) values
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_value)
            scaler.step(self.optimizer)
            scaler.update()
            if self.scheduler is not None:
                self.scheduler.step()
            self.model.zero_grad()
            end_run.record()
            torch.cuda.synchronize()
            # tensorboard dict update
            tb_dict = {}
            tb_dict['train/loss'] = loss.detach().cpu().item()
            tb_dict['lr'] = self.optimizer.param_groups[0]['lr']
            tb_dict['GPU/prefetch_time'] = start_batch.elapsed_time(end_batch) / 1000.
            tb_dict['GPU/run_time'] = start_run.elapsed_time(end_run) / 1000.
            if self.it % self.args.num_eval_iter == 0:
                eval_dict = self.evaluate()
                tb_dict.update(eval_dict)
                save_path = self.args.save_path
                if tb_dict['Word/Acc'] > self.best_eval_acc:
                    self.best_eval_acc = tb_dict['Word/Acc']
                    self.best_it = self.it
                self.print_fn(
                    f"\n {self.it} iteration, {tb_dict}, \n BEST_EVAL_ACC: {self.best_eval_acc}, at {self.best_it} iters")
                self.print_fn(
                    f" {self.it} iteration, ACC: {tb_dict['Word/Acc']}\n")
                if self.it == self.best_it:
                    self.save_model('model_best.pth', save_path)
            if self.tb_log is not None:
                self.tb_log.update(tb_dict, self.it)
            del tb_dict
            start_batch.record()
        tbar.close()
        eval_dict = self.evaluate()
        eval_dict.update({'eval/best_acc': self.best_eval_acc, 'eval/best_it': self.best_it})
        return eval_dict

    def evaluate(self, model: nn.Module = None, evalDataset: Dataset = None):
        self.print_fn("\n Evaluation!!!")
        if model is None:
            model = self.model
        if evalDataset is not None:
            evalDataloader = DataLoader(evalDataset, self.args.eval_batch_size, shuffle=False, num_workers=0)
        else:
            evalDataloader = self.evalDataloader
        eval_dict = {}
        model.eval()
        preds_arr = None
        targets_arr = None
        lengths_arr = None
        with torch.no_grad():  # no gradients needed during evaluation
            for samples, targets, lengths in evalDataloader:
                samples, targets = samples.to(self.gpu), targets.to(self.gpu)
                outputs = model(samples)
                preds = torch.max(outputs, dim=2)[1]
                if preds_arr is None:
                    preds_arr = preds.detach().cpu()
                    targets_arr = targets.detach().cpu()
                    lengths_arr = lengths.detach().cpu()
                else:
                    preds_arr = torch.cat((preds_arr, preds.detach().cpu()))
                    targets_arr = torch.cat((targets_arr, targets.detach().cpu()))
                    lengths_arr = torch.cat((lengths_arr, lengths.detach().cpu()))
        wordAcc, charAcc = getAcc(preds_arr, targets_arr, lengths_arr)
        eval_dict.update({"Word/Acc": wordAcc,
                          "Char/Acc": charAcc})
        model.train()
        return eval_dict

    def save_model(self, save_name, save_path):
        save_filename = os.path.join(save_path, save_name)
        self.model.eval()
        save_dict = {"model": self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                     'it': self.it}
        torch.save(save_dict, save_filename)
        self.model.train()
        self.print_fn(f"model saved: {save_filename}\n")

    def save_baseLearner(self, save_name, save_path, trainIndexes):
        save_filename = os.path.join(save_path, save_name)
        self.model.eval()
        save_dict = {"model": self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                     'trainIndexes': trainIndexes,
                     'it': self.it}
        torch.save(save_dict, save_filename)
        self.model.train()
        self.print_fn(f"model saved: {save_filename}\n")

    def load_model(self, load_dir, load_name):
        """
        Load a saved model checkpoint.
        :param load_dir: directory containing the checkpoint
        :param load_name: checkpoint file name
        """
        load_path = os.path.join(load_dir, load_name)
        checkpoint = torch.load(load_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.scheduler is not None and checkpoint['scheduler'] is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.it = checkpoint['it']
        self.print_fn(f'model loaded from {load_path}')

    def set_optimizer(self, optimizer, scheduler=None):
        """
        Set optimizer and scheduler.
        :param optimizer: optimizer
        :param scheduler: scheduler
        """
        self.optimizer = optimizer
        self.scheduler = scheduler

    def setModel(self, model):
        """
        Set the model and move it to the training device.
        :param model: model
        """
        self.model = model.cuda(self.gpu)

    def setDatasets(self, trainDataset, evalDataset):
        """
        Set train and evaluation datasets and dataloaders.
        :param trainDataset: train dataset
        :param evalDataset: evaluation dataset
        """
        self.print_fn(f"\n Num Train Labeled Sample : {len(trainDataset)}\n Num Val Sample : {len(evalDataset)}")
        self.trainDataset = trainDataset
        self.evalDataset = evalDataset
        # Sampling with replacement for exactly `iter * batch_size` samples makes one
        # pass over this dataloader equal to a fixed budget of `args.iter` steps,
        # independent of the dataset size.
        self.trainDataloader = DataLoader(trainDataset, batch_size=self.args.batch_size,
                                          sampler=RandomSampler(data_source=trainDataset,
                                                                replacement=True,
                                                                num_samples=self.args.iter * self.args.batch_size),
                                          num_workers=self.args.num_workers, drop_last=True, pin_memory=True)
        self.evalDataloader = DataLoader(evalDataset, self.args.eval_batch_size, shuffle=False, num_workers=0,
                                         pin_memory=True)

    def setLoss(self, loss_function: dict):
        """
        Set the loss function.
        :param loss_function: loss function arguments
        """
        if loss_function["name"] == 'CrossEntropyLoss':
            self.loss_fn = nn.CrossEntropyLoss(label_smoothing=loss_function["label_smoothing"]).cuda(self.gpu)
        else:
            raise Exception(f"Unknown Loss Function : {loss_function}")
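

# A minimal usage sketch, assuming `args` carries the attributes listed above and
# that `MyModel`, `train_ds`, `eval_ds`, and `args.lr` are placeholders for the
# project's own model class, datasets, and learning rate (none of them defined here):
#
#     trainer = Trainer(args, tb_logger=None, logger=None)
#     trainer.setModel(MyModel(**args.model))
#     trainer.setDatasets(train_ds, eval_ds)
#     optimizer = torch.optim.AdamW(trainer.model.parameters(), lr=args.lr)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.iter)
#     trainer.set_optimizer(optimizer, scheduler)
#     results = trainer.train()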