| import logging |
|
|
| import numpy as np |
| import torch |
| from time import time |
| from torch import optim |
| from tqdm import tqdm |
|
|
| from utils import ensure_dir,set_color,get_local_time |
| import os |
|
|
class Trainer(object):
    """Training loop for a residual-quantization model.

    Builds the optimizer, runs per-epoch training, periodically evaluates
    the code collision rate, and checkpoints the best-loss and
    best-collision models under a timestamped checkpoint directory.
    """

    def __init__(self, args, model):
        """
        Args:
            args: namespace with ``lr``, ``learner``, ``weight_decay``,
                ``epochs``, ``eval_step``, ``device`` and ``ckpt_dir``
                attributes.
            model: torch module whose forward returns
                ``(out, rq_loss, indices)`` and which exposes
                ``compute_loss(out, rq_loss, xs=...)`` and
                ``get_indices(data)``.
        """
        self.args = args
        self.model = model
        self.logger = logging.getLogger()

        self.lr = args.lr
        self.learner = args.learner
        self.weight_decay = args.weight_decay
        self.epochs = args.epochs
        # Never evaluate more often than once per epoch.
        self.eval_step = min(args.eval_step, self.epochs)
        self.device = torch.device(args.device)

        # Each run checkpoints into its own timestamped sub-directory.
        saved_model_dir = "{}".format(get_local_time())
        self.ckpt_dir = os.path.join(args.ckpt_dir, saved_model_dir)
        ensure_dir(self.ckpt_dir)

        self.best_loss = np.inf
        self.best_collision_rate = np.inf
        self.best_loss_ckpt = "best_loss_model.pth"
        self.best_collision_ckpt = "best_collision_model.pth"
        self.optimizer = self._build_optimizer()
        self.model = self.model.to(self.device)

    def _build_optimizer(self):
        """Build the optimizer named by ``self.learner``.

        Returns:
            A ``torch.optim.Optimizer`` over the model parameters. Falls
            back to Adam (keeping the configured weight decay) when the
            learner name is not recognized.
        """
        params = self.model.parameters()
        learning_rate = self.lr
        weight_decay = self.weight_decay

        # Dispatch table instead of an if/elif chain.
        learner_classes = {
            "adam": optim.Adam,
            "sgd": optim.SGD,
            "adagrad": optim.Adagrad,
            "rmsprop": optim.RMSprop,
            "adamw": optim.AdamW,
        }
        learner_cls = learner_classes.get(self.learner.lower())
        if learner_cls is None:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            # FIX: the fallback previously dropped the configured weight_decay,
            # unlike every recognized learner branch.
            learner_cls = optim.Adam
        # NOTE: the original moved Adagrad state tensors to self.device right
        # after construction; a freshly built optimizer has no state yet, so
        # that loop was a no-op and has been removed.
        return learner_cls(params, lr=learning_rate, weight_decay=weight_decay)

    def _check_nan(self, loss):
        """Raise ``ValueError`` if the training loss has diverged to NaN."""
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _train_epoch(self, train_data, epoch_idx):
        """Run one optimization pass over ``train_data``.

        Args:
            train_data: iterable of batches (tensors movable via ``.to``).
            epoch_idx: current epoch index, used for the progress bar label.

        Returns:
            Tuple ``(total_loss, total_recon_loss)`` summed over batches.
        """
        self.model.train()

        total_loss = 0
        total_recon_loss = 0
        iter_data = tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx}", "pink"),
        )

        for batch_idx, data in enumerate(iter_data):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            out, rq_loss, indices = self.model(data)
            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            total_recon_loss += loss_recon.item()

        return total_loss, total_recon_loss

    @torch.no_grad()
    def _valid_epoch(self, valid_data):
        """Compute the code collision rate over ``valid_data``.

        A "collision" is a sample whose quantization code (the dash-joined
        index tuple) duplicates an earlier sample's code.

        Returns:
            Fraction of samples in ``[0, 1]`` whose code is a duplicate;
            ``0.0`` when ``valid_data`` is empty.
        """
        self.model.eval()

        iter_data = tqdm(
            valid_data,
            total=len(valid_data),
            ncols=100,
            desc=set_color(f"Evaluate ", "pink"),
        )
        indices_set = set()
        num_sample = 0
        for batch_idx, data in enumerate(iter_data):
            num_sample += len(data)
            data = data.to(self.device)
            indices = self.model.get_indices(data)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for index in indices:
                code = "-".join([str(int(_)) for _ in index])
                indices_set.add(code)

        # FIX: guard against empty validation data (previously a
        # ZeroDivisionError).
        if num_sample == 0:
            return 0.0

        return (num_sample - len(indices_set)) / num_sample

    def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
        """Serialize model/optimizer state to ``self.ckpt_dir``.

        Args:
            epoch: epoch index stored in (and, absent ``ckpt_file``,
                encoded into) the checkpoint.
            collision_rate: encoded into the default file name.
            ckpt_file: explicit file name; when ``None`` a name of the
                form ``epoch_%d_collision_%.4f_model.pth`` is used.
        """
        if ckpt_file:
            ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)
        else:
            ckpt_path = os.path.join(
                self.ckpt_dir,
                'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate),
            )
        state = {
            "args": self.args,
            "epoch": epoch,
            "best_loss": self.best_loss,
            "best_collision_rate": self.best_collision_rate,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, ckpt_path, pickle_protocol=4)

        self.logger.info(
            set_color("Saving current", "blue") + f": {ckpt_path}"
        )

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
        """Format the per-epoch training log line."""
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
        train_loss_output += ", "
        train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
        return train_loss_output + "]"

    def fit(self, data):
        """Train for ``self.epochs`` epochs on ``data``.

        Evaluates the collision rate every ``self.eval_step`` epochs and
        checkpoints the best-loss and best-collision models.

        Returns:
            Tuple ``(best_loss, best_collision_rate)``.
        """
        cur_eval_step = 0

        for epoch_idx in range(self.epochs):
            training_start_time = time()
            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
            )
            self.logger.info(train_loss_output)

            if train_loss < self.best_loss:
                self.best_loss = train_loss
                # FIX: self.best_loss_ckpt was declared but never written;
                # persist the best-loss model when training loss improves.
                self._save_checkpoint(epoch_idx, ckpt_file=self.best_loss_ckpt)

            # FIX: skip evaluation entirely when eval_step == 0 (previously a
            # ZeroDivisionError on the modulo).
            if self.eval_step and (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                collision_rate = self._valid_epoch(data)

                if collision_rate < self.best_collision_rate:
                    self.best_collision_rate = collision_rate
                    cur_eval_step = 0
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate,
                                          ckpt_file=self.best_collision_ckpt)
                else:
                    cur_eval_step += 1

                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("collision_rate", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)

                self.logger.info(valid_score_output)
                # Periodic snapshots only after a long warm-up.
                if epoch_idx > 1000:
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate)

        return self.best_loss, self.best_collision_rate
|
|
|
|
|
|
|
|
|
|