# MViT-TR / Model / trainer.py
# serdaryildiz's picture
# Upload 31 files
# 9b6af3b verified
import os
import torch
import tqdm
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, RandomSampler, Dataset
from metrics import getAcc
from torch.cuda.amp import autocast, GradScaler
class Trainer:
    """
    Mixed-precision training and evaluation driver.

    The trainer owns the model, the optimizer/scheduler pair, the train and
    evaluation dataloaders, and the running counters (current iteration, best
    word accuracy and the iteration at which it was reached).

    Typical wiring, as exposed by the setters of this class:
    ``setModel`` -> ``set_optimizer`` -> ``setDatasets`` -> (``load_model``)
    -> ``train``.
    """

    def __init__(self, args, tb_logger, logger):
        """
        Build a trainer from a configuration object.

        :param args: configuration object; this constructor reads ``gpu``,
            ``loss``, ``model["letter_size"]`` and, when present,
            ``clip_grad`` and ``label_smoothing``
        :param tb_logger: object with an ``update(dict, iteration)`` method,
            or None to disable tensorboard logging
        :param logger: object with an ``info`` method, or None to fall back
            to ``print``
        """
        self.args = args
        self.gpu = torch.device(args.gpu)
        self.model = None
        self.it = 0
        self.best_eval_acc, self.best_it = 0.0, 0

        # init dataset
        self.trainDataset = None
        self.trainDataloader = None
        self.evalDataset = None
        self.evalDataloader = None

        # optimizer and scheduler
        self.scheduler = None
        self.optimizer = None

        # loss
        # NOTE: ``loss_fn`` is configured here, but ``train`` computes the
        # loss with ``F.cross_entropy`` using ``ignore_index`` and
        # ``label_smoothing`` below.
        self.loss_fn = None
        self.weight = None
        self.setLoss(args.loss)
        self.ignore_index = args.model["letter_size"]

        # gradient clipping: always define both attributes so later reads
        # never hit a missing attribute
        self.clip_value = getattr(args, "clip_grad", None)
        self.clip_grad = self.clip_value is not None

        smoothing = getattr(args, "label_smoothing", None)
        self.label_smoothing = float(smoothing) if smoothing is not None else 0.0

        # logging: ``tb_log`` is always assigned because ``train`` tests it
        # against None
        self.tb_log = tb_logger
        self.print_fn = print if logger is None else logger.info

    def train(self):
        """
        Train The Model

        Iterates once over ``self.trainDataloader``, evaluating every
        ``args.num_eval_iter`` iterations and saving ``model_best.pth`` to
        ``args.save_path`` whenever the word accuracy improves.

        :return: final evaluation dict extended with ``eval/best_acc`` and
            ``eval/best_it``
        """
        self.model.train()

        # for gpu profiling
        start_batch = torch.cuda.Event(enable_timing=True)
        end_batch = torch.cuda.Event(enable_timing=True)
        start_run = torch.cuda.Event(enable_timing=True)
        end_run = torch.cuda.Event(enable_timing=True)

        scaler = GradScaler()
        start_batch.record()

        # eval for once
        if self.args.resume:
            self.print_fn(self.evaluate())

        # the context manager closes the progress bar even if a step raises
        with tqdm.tqdm(total=len(self.trainDataloader), colour='BLUE') as tbar:
            for samples, targets, _ in self.trainDataloader:
                tbar.update(1)
                self.it += 1

                end_batch.record()
                torch.cuda.synchronize()
                start_run.record()

                loss = self._train_step(samples, targets, scaler)

                end_run.record()
                torch.cuda.synchronize()

                # tensorboard_dict update
                tb_dict = {
                    'train/loss': loss.detach().cpu().item(),
                    'lr': self.optimizer.param_groups[0]['lr'],
                    'GPU/prefecth_time': start_batch.elapsed_time(end_batch) / 1000.,
                    'GPU/run_time': start_run.elapsed_time(end_run) / 1000.,
                }

                if self.it % self.args.num_eval_iter == 0:
                    self._evaluate_and_checkpoint(tb_dict)

                start_batch.record()

        eval_dict = self.evaluate()
        eval_dict.update({'eval/best_acc': self.best_eval_acc, 'eval/best_it': self.best_it})
        return eval_dict

    def _train_step(self, samples, targets, scaler):
        """
        Run one forward/backward/optimizer step under autocast.

        :param samples: input batch from the train dataloader
        :param targets: label batch from the train dataloader
        :param scaler: the ``GradScaler`` used for this training run
        :return: loss tensor of this step
        """
        samples = samples.to(self.gpu)
        targets = targets.to(self.gpu).long()

        with autocast():
            logits = self.model(samples)
            # the first two logit dims are merged and targets flattened so
            # that every position is scored as one classification
            # (presumably logits are (batch, seq, classes) - TODO confirm)
            loss = F.cross_entropy(logits.flatten(end_dim=1), targets.flatten(),
                                   ignore_index=self.ignore_index,
                                   label_smoothing=self.label_smoothing)

        scaler.scale(loss).backward()

        if self.clip_grad:
            # gradients still carry the loss scale here; unscale first so
            # clip_value applies to the true gradient norm
            scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_value)

        scaler.step(self.optimizer)
        scaler.update()
        if self.scheduler is not None:
            self.scheduler.step()
        self.model.zero_grad()
        return loss

    def _evaluate_and_checkpoint(self, tb_dict):
        """
        Evaluate, track the best word accuracy, save the best model and log.

        :param tb_dict: per-iteration metrics dict; updated in place with the
            evaluation results
        """
        tb_dict.update(self.evaluate())

        if tb_dict['Word/Acc'] > self.best_eval_acc:
            self.best_eval_acc = tb_dict['Word/Acc']
            self.best_it = self.it

        self.print_fn(
            f"\n {self.it} iteration, {tb_dict}, \n BEST_EVAL_ACC: {self.best_eval_acc}, at {self.best_it} iters")
        self.print_fn(
            f" {self.it} iteration, ACC: {tb_dict['Word/Acc']}\n")

        # best_it equals the current iteration only when accuracy just improved
        if self.it == self.best_it:
            self.save_model('model_best.pth', self.args.save_path)

        # NOTE(review): tensorboard is updated on evaluation iterations only;
        # confirm this matches the intended logging frequency.
        if self.tb_log is not None:
            self.tb_log.update(tb_dict, self.it)

    @torch.no_grad()
    def evaluate(self, model: nn.Module = None, evalDataset: Dataset = None):
        """
        Compute word and character accuracy over an evaluation set.

        The model is switched to eval mode for the pass and left in training
        mode afterwards.

        :param model: model to evaluate; defaults to ``self.model``
        :param evalDataset: dataset to evaluate on; defaults to the
            dataloader built by ``setDatasets``
        :return: dict with the keys ``Word/Acc`` and ``Char/Acc``
        """
        self.print_fn("\n Evaluation!!!")

        if model is None:
            model = self.model
        if evalDataset is not None:
            evalDataloader = DataLoader(evalDataset, self.args.eval_batch_size, shuffle=False, num_workers=0)
        else:
            evalDataloader = self.evalDataloader

        model.eval()

        # collect per-batch results and concatenate once at the end instead
        # of re-concatenating the growing tensors on every batch
        preds_chunks, targets_chunks, lengths_chunks = [], [], []
        for samples, targets, lengths in evalDataloader:
            outputs = model(samples.to(self.gpu))
            # index of the highest score along dim 2
            preds_chunks.append(outputs.argmax(dim=2).detach().cpu())
            # targets and lengths are only compared on the CPU
            targets_chunks.append(targets.detach().cpu())
            lengths_chunks.append(lengths.detach().cpu())

        # an empty loader yields None, exactly as before
        preds_arr = torch.cat(preds_chunks) if preds_chunks else None
        targets_arr = torch.cat(targets_chunks) if targets_chunks else None
        lengths_arr = torch.cat(lengths_chunks) if lengths_chunks else None

        wordAcc, charAcc = getAcc(preds_arr, targets_arr, lengths_arr)
        eval_dict = {"Word/Acc": wordAcc,
                     "Char/Acc": charAcc}

        model.train()
        return eval_dict

    def _save_checkpoint(self, save_name, save_path, extra=None):
        """
        Write model/optimizer/scheduler state and the iteration to disk.

        :param save_name: checkpoint file name
        :param save_path: target directory; created when it does not exist
        :param extra: optional dict of additional entries to store
        """
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        save_filename = os.path.join(save_path, save_name)

        self.model.eval()
        save_dict = {"model": self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'scheduler': self.scheduler.state_dict() if self.scheduler is not None else None,
                     'it': self.it}
        if extra:
            save_dict.update(extra)
        torch.save(save_dict, save_filename)
        # the model is left in training mode after saving
        self.model.train()

        self.print_fn(f"model saved: {save_filename}\n")

    def save_model(self, save_name, save_path):
        """
        save model, optimizer, scheduler and iteration counter
        :param save_name: checkpoint file name
        :param save_path: directory of the checkpoint
        """
        self._save_checkpoint(save_name, save_path)

    def save_baseLearner(self, save_name, save_path, trainIndexes):
        """
        save a checkpoint that additionally stores the train indexes
        :param save_name: checkpoint file name
        :param save_path: directory of the checkpoint
        :param trainIndexes: stored under the ``trainIndexes`` key
        """
        self._save_checkpoint(save_name, save_path, {'trainIndexes': trainIndexes})

    def load_model(self, load_dir, load_name):
        """
        load a saved checkpoint into the model, optimizer and scheduler
        :param load_dir: directory of loading model
        :param load_name: model name
        """
        load_path = os.path.join(load_dir, load_name)
        # map onto the trainer device so a checkpoint written on another
        # device can still be restored
        checkpoint = torch.load(load_path, map_location=self.gpu)

        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if checkpoint['scheduler'] is not None:
            if self.scheduler is not None:
                self.scheduler.load_state_dict(checkpoint['scheduler'])
            else:
                self.print_fn('checkpoint has scheduler state but no scheduler is set; skipped')
        self.it = checkpoint['it']
        self.print_fn(f'model loaded from {load_path}')

    def set_optimizer(self, optimizer, scheduler=None):
        """
        set optimizer and scheduler
        :param optimizer: optimizer
        :param scheduler: scheduler
        """
        self.optimizer = optimizer
        self.scheduler = scheduler

    def setModel(self, model):
        """
        set model and move it to the trainer device
        :param model: model
        """
        # ``.to`` accepts any torch.device, not only CUDA ones
        self.model = model.to(self.gpu)

    def setDatasets(self, trainDataset, evalDataset):
        """
        set train and evaluation datasets and dataloaders
        :param trainDataset: train dataset
        :param evalDataset: evaluation dataset
        """
        self.print_fn(f"\n Num Train Labeled Sample : {len(trainDataset)}\n Num Val Sample : {len(evalDataset)}")

        self.trainDataset = trainDataset
        self.evalDataset = evalDataset

        # sampling with replacement for iter * batch_size samples makes one
        # pass over the train loader last ``args.iter`` batches
        train_sampler = RandomSampler(data_source=trainDataset,
                                      replacement=True,
                                      num_samples=self.args.iter * self.args.batch_size)
        self.trainDataloader = DataLoader(trainDataset, batch_size=self.args.batch_size,
                                          sampler=train_sampler,
                                          num_workers=self.args.num_workers, drop_last=True, pin_memory=True)
        self.evalDataloader = DataLoader(evalDataset, self.args.eval_batch_size, shuffle=False, num_workers=0,
                                         pin_memory=True)

    def setLoss(self, loss_function: dict):
        """
        set loss function
        :param loss_function: loss function arguments; ``name`` selects the
            loss, ``label_smoothing`` is optional and defaults to 0.0
        :raises ValueError: when ``name`` is not a supported loss
        """
        if loss_function["name"] != 'CrossEntropyLoss':
            raise ValueError(f"Unknown Loss Function : {loss_function}")

        smoothing = loss_function.get("label_smoothing", 0.0)
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=smoothing).to(self.gpu)