| | import logging |
| |
|
| | import numpy as np |
| | import torch |
| | from time import time |
| | from torch import optim |
| | from tqdm import tqdm |
| |
|
| | from utils import ensure_dir,set_color,get_local_time |
| | import os |
| |
|
class Trainer(object):
    """Training loop for a residual-quantization model.

    Tracks the best training loss and the best index-collision rate seen
    during evaluation, and checkpoints the model for each of them under a
    per-run, timestamped checkpoint directory.
    """

    def __init__(self, args, model):
        self.args = args
        self.model = model
        self.logger = logging.getLogger()

        self.lr = args.lr
        self.learner = args.learner
        self.weight_decay = args.weight_decay
        self.epochs = args.epochs
        # Evaluate at least once: never let eval_step exceed the epoch count.
        self.eval_step = min(args.eval_step, self.epochs)
        self.device = torch.device(args.device)
        # Each run gets its own timestamped sub-directory for checkpoints.
        saved_model_dir = "{}".format(get_local_time())
        self.ckpt_dir = os.path.join(args.ckpt_dir, saved_model_dir)
        ensure_dir(self.ckpt_dir)

        self.best_loss = np.inf
        self.best_collision_rate = np.inf
        self.best_loss_ckpt = "best_loss_model.pth"
        self.best_collision_ckpt = "best_collision_model.pth"
        self.optimizer = self._build_optimizer()
        self.model = self.model.to(self.device)

    def _build_optimizer(self):
        """Build the optimizer selected by ``args.learner``.

        Unrecognized names fall back to Adam (without weight decay), with a
        warning logged.
        """
        params = self.model.parameters()
        learner = self.learner.lower()
        learning_rate = self.lr
        weight_decay = self.weight_decay

        if learner == "adam":
            optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner == "sgd":
            optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner == "adagrad":
            optimizer = optim.Adagrad(
                params, lr=learning_rate, weight_decay=weight_decay
            )
            # Adagrad allocates per-parameter state tensors eagerly; move any
            # existing state to the training device so updates do not mix
            # devices. (No-op for a freshly created optimizer.)
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
        elif learner == "rmsprop":
            optimizer = optim.RMSprop(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner == 'adamw':
            optimizer = optim.AdamW(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        else:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            optimizer = optim.Adam(params, lr=learning_rate)
        return optimizer

    def _check_nan(self, loss):
        """Abort training immediately if the loss has become NaN."""
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _train_epoch(self, train_data, epoch_idx):
        """Run one optimization pass over ``train_data``.

        Returns:
            tuple: (total loss, total reconstruction loss), each summed over
            all batches (not averaged).
        """
        self.model.train()

        total_loss = 0
        total_recon_loss = 0
        iter_data = tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx}","pink"),
        )

        for batch_idx, data in enumerate(iter_data):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            out, rq_loss, indices = self.model(data)
            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            total_recon_loss += loss_recon.item()

        return total_loss, total_recon_loss

    @torch.no_grad()
    def _valid_epoch(self, valid_data):
        """Measure the code-collision rate of the model over ``valid_data``.

        Each sample's quantization indices are joined into a string code;
        the collision rate is the fraction of samples whose code is not
        unique. Lower is better.
        """
        self.model.eval()

        iter_data = tqdm(
            valid_data,
            total=len(valid_data),
            ncols=100,
            desc=set_color(f"Evaluate ", "pink"),
        )
        indices_set = set()
        num_sample = 0
        for batch_idx, data in enumerate(iter_data):
            num_sample += len(data)
            data = data.to(self.device)
            indices = self.model.get_indices(data)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for index in indices:
                # e.g. [3, 17, 5] -> "3-17-5"; duplicates collapse in the set.
                code = "-".join([str(int(_)) for _ in index])
                indices_set.add(code)

        collision_rate = (num_sample - len(indices_set)) / num_sample

        return collision_rate

    def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
        """Serialize model + optimizer state to the run's checkpoint dir.

        If ``ckpt_file`` is not given, a descriptive name embedding the epoch
        and collision rate is generated.
        """
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_file) if ckpt_file \
            else os.path.join(self.ckpt_dir, 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate))
        state = {
            "args": self.args,
            "epoch": epoch,
            "best_loss": self.best_loss,
            "best_collision_rate": self.best_collision_rate,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, ckpt_path, pickle_protocol=4)

        self.logger.info(
            set_color("Saving current", "blue") + f": {ckpt_path}"
        )

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
        """Format one epoch's timing and losses for logging."""
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
        train_loss_output += ", "
        train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
        return train_loss_output + "]"

    def fit(self, data):
        """Train for ``self.epochs`` epochs, evaluating every ``eval_step``.

        ``data`` serves as both training and evaluation loader. Returns
        (best training loss, best collision rate).
        """
        cur_eval_step = 0

        for epoch_idx in range(self.epochs):
            training_start_time = time()
            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
            )
            self.logger.info(train_loss_output)

            if train_loss < self.best_loss:
                self.best_loss = train_loss
                # Bug fix: best_loss_ckpt was defined in __init__ but the
                # best-loss model was never actually saved; persist it here,
                # mirroring the best-collision branch below.
                self._save_checkpoint(epoch_idx, ckpt_file=self.best_loss_ckpt)

            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                collision_rate = self._valid_epoch(data)

                if collision_rate < self.best_collision_rate:
                    self.best_collision_rate = collision_rate
                    cur_eval_step = 0
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate,
                                          ckpt_file=self.best_collision_ckpt)
                else:
                    # Counts evaluations without improvement; kept for a
                    # possible early-stopping hook (currently unused).
                    cur_eval_step += 1

                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("collision_rate", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)

                self.logger.info(valid_score_output)
                # After a long warm-up, also keep periodic per-epoch snapshots.
                if epoch_idx > 1000:
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate)

        return self.best_loss, self.best_collision_rate
| |
|
| |
|
| |
|
| |
|
| |
|