import logging
import os
from time import time

import numpy as np
import torch
from torch import optim
from tqdm import tqdm

from utils import ensure_dir, set_color, get_local_time


class Trainer(object):

    def __init__(self, args, model):
        self.args = args
        self.model = model
        self.logger = logging.getLogger()

        self.lr = args.lr
        self.learner = args.learner
        self.weight_decay = args.weight_decay
        self.epochs = args.epochs
        self.eval_step = min(args.eval_step, self.epochs)
        self.device = torch.device(args.device)

        # Checkpoints for this run go into a timestamped sub-directory.
        self.ckpt_dir = os.path.join(args.ckpt_dir, "{}".format(get_local_time()))
        ensure_dir(self.ckpt_dir)

        self.best_loss = np.inf
        self.best_collision_rate = np.inf
        self.best_loss_ckpt = "best_loss_model.pth"
        self.best_collision_ckpt = "best_collision_model.pth"
        self.optimizer = self._build_optimizer()
        self.model = self.model.to(self.device)

    def _build_optimizer(self):
        params = self.model.parameters()
        learner = self.learner
        learning_rate = self.lr
        weight_decay = self.weight_decay

        if learner.lower() == "adam":
            optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "sgd":
            optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "adagrad":
            optimizer = optim.Adagrad(
                params, lr=learning_rate, weight_decay=weight_decay
            )
            # Adagrad keeps per-parameter state tensors; move them to the target device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
        elif learner.lower() == "rmsprop":
            optimizer = optim.RMSprop(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner.lower() == "adamw":
            optimizer = optim.AdamW(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        else:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            optimizer = optim.Adam(params, lr=learning_rate)
        return optimizer

    def _check_nan(self, loss):
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _train_epoch(self, train_data, epoch_idx):
        self.model.train()

        total_loss = 0
        total_recon_loss = 0
        iter_data = tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx}", "pink"),
        )

        for batch_idx, data in enumerate(iter_data):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            out, rq_loss, indices = self.model(data)
            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            total_recon_loss += loss_recon.item()

        return total_loss, total_recon_loss

    @torch.no_grad()
    def _valid_epoch(self, valid_data):
        self.model.eval()

        iter_data = tqdm(
            valid_data,
            total=len(valid_data),
            ncols=100,
            desc=set_color("Evaluate ", "pink"),
        )
        indices_set = set()
        num_sample = 0
        for batch_idx, data in enumerate(iter_data):
            num_sample += len(data)
            data = data.to(self.device)
            indices = self.model.get_indices(data)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for index in indices:
                code = "-".join([str(int(_)) for _ in index])
                indices_set.add(code)

        # Collision rate: fraction of samples whose code is not unique.
        collision_rate = (num_sample - len(indices_set)) / num_sample

        return collision_rate

    def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_file) if ckpt_file \
            else os.path.join(
                self.ckpt_dir,
                'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate)
            )
        state = {
            "args": self.args,
            "epoch": epoch,
            "best_loss": self.best_loss,
            "best_collision_rate": self.best_collision_rate,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, ckpt_path, pickle_protocol=4)

        self.logger.info(
            set_color("Saving current", "blue") + f": {ckpt_path}"
        )

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
        train_loss_output += ", "
        train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
        return train_loss_output + "]"

    def fit(self, data):
        cur_eval_step = 0

        for epoch_idx in range(self.epochs):
            # train
            training_start_time = time()
            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time,
                train_loss, train_recon_loss
            )
            self.logger.info(train_loss_output)

            if train_loss < self.best_loss:
                self.best_loss = train_loss
                # self._save_checkpoint(epoch=epoch_idx, ckpt_file=self.best_loss_ckpt)

            # eval
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                collision_rate = self._valid_epoch(data)

                if collision_rate < self.best_collision_rate:
                    self.best_collision_rate = collision_rate
                    cur_eval_step = 0
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate,
                                          ckpt_file=self.best_collision_ckpt)
                else:
                    cur_eval_step += 1

                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("collision_rate", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)

                self.logger.info(valid_score_output)

                # After epoch 1000, also keep a per-evaluation checkpoint.
                if epoch_idx > 1000:
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate)

        return self.best_loss, self.best_collision_rate