# Benchmark-Single/index/trainer.py
# Uploaded by Junyin: "Add files using upload-large-folder tool" (commit 811e03d, verified)
import logging
import numpy as np
import torch
from time import time
from torch import optim
from tqdm import tqdm
from utils import ensure_dir,set_color,get_local_time
import os
class Trainer(object):
    """Train a quantization model and checkpoint it by code collision rate.

    The wrapped ``model`` must provide the three entry points this class calls:

    * ``model(batch)`` returning a 3-tuple ``(out, rq_loss, indices)``,
    * ``model.compute_loss(out, rq_loss, xs=batch)`` returning
      ``(loss, reconstruction_loss)``,
    * ``model.get_indices(batch)`` returning a tensor whose last dimension
      holds the code of one sample.

    ``args`` must expose ``lr``, ``learner``, ``weight_decay``, ``epochs``,
    ``eval_step``, ``device`` and ``ckpt_dir``.
    """

    # Once the epoch index exceeds this value, every evaluation additionally
    # writes an epoch-stamped checkpoint (on top of the "best" checkpoint).
    PERIODIC_SAVE_AFTER_EPOCH = 1000

    def __init__(self, args, model):
        """Store the hyper-parameters, create the run directory and optimizer.

        Args:
            args: Configuration object (see the class docstring for the
                attributes that are read). It is also stored verbatim in
                every checkpoint.
            model: The ``torch.nn.Module`` to train.
        """
        self.args = args
        self.model = model
        self.logger = logging.getLogger()

        self.lr = args.lr
        self.learner = args.learner
        self.weight_decay = args.weight_decay
        self.epochs = args.epochs
        # Evaluating less often than once per run would never evaluate at all.
        self.eval_step = min(args.eval_step, self.epochs)
        self.device = torch.device(args.device)

        # Every run writes into its own sub-directory named by get_local_time().
        self.ckpt_dir = os.path.join(args.ckpt_dir, "{}".format(get_local_time()))
        ensure_dir(self.ckpt_dir)

        self.best_loss = np.inf
        self.best_collision_rate = np.inf
        self.best_loss_ckpt = "best_loss_model.pth"
        self.best_collision_ckpt = "best_collision_model.pth"

        # Move the model first: optimizers that allocate their state eagerly
        # (e.g. Adagrad) then create it on the right device straight away.
        self.model = self.model.to(self.device)
        self.optimizer = self._build_optimizer()

    def _build_optimizer(self):
        """Create the optimizer selected by ``self.learner`` (case-insensitive).

        Supported names are ``adam``, ``sgd``, ``adagrad``, ``rmsprop`` and
        ``adamw``. Any other name logs a warning and falls back to Adam; the
        fallback uses the configured weight decay like every other choice.

        Returns:
            A ``torch.optim.Optimizer`` over all parameters of ``self.model``.
        """
        optimizer_classes = {
            "adam": optim.Adam,
            "sgd": optim.SGD,
            "adagrad": optim.Adagrad,
            "rmsprop": optim.RMSprop,
            "adamw": optim.AdamW,
        }
        name = self.learner.lower()
        optimizer_cls = optimizer_classes.get(name)
        if optimizer_cls is None:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            optimizer_cls = optim.Adam

        optimizer = optimizer_cls(
            self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )

        if name == "adagrad":
            # Adagrad builds its state in the constructor; make sure every
            # state tensor lives on the training device.
            for state in optimizer.state.values():
                for key, value in state.items():
                    if torch.is_tensor(value):
                        state[key] = value.to(self.device)
        return optimizer

    def _check_nan(self, loss):
        """Raise ``ValueError`` if the (scalar) training loss is NaN."""
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _train_epoch(self, train_data, epoch_idx):
        """Run one optimization pass over ``train_data``.

        Args:
            train_data: Sized iterable of batches; each batch must support
                ``.to(device)``.
            epoch_idx: Epoch number, used only for the progress-bar label.

        Returns:
            ``(total_loss, total_recon_loss)``: the per-batch losses summed
            (not averaged) over the epoch.

        Raises:
            ValueError: If a batch produces a NaN loss.
        """
        self.model.train()

        total_loss = 0
        total_recon_loss = 0
        progress = tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx}", "pink"),
        )
        for batch in progress:
            batch = batch.to(self.device)
            self.optimizer.zero_grad()
            out, rq_loss, _ = self.model(batch)
            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=batch)
            # Fail before back-propagating a NaN into the weights.
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            total_recon_loss += loss_recon.item()

        return total_loss, total_recon_loss

    @torch.no_grad()
    def _valid_epoch(self, valid_data):
        """Measure how many samples share their code with another sample.

        Args:
            valid_data: Sized iterable of batches; each batch must support
                ``len()`` and ``.to(device)``.

        Returns:
            The collision rate ``(samples - distinct codes) / samples``, or
            ``0.0`` (with a warning) when ``valid_data`` yields no samples.
        """
        self.model.eval()

        progress = tqdm(
            valid_data,
            total=len(valid_data),
            ncols=100,
            desc=set_color("Evaluate ", "pink"),
        )
        unique_codes = set()
        num_sample = 0
        for batch in progress:
            num_sample += len(batch)
            batch = batch.to(self.device)
            codes = self.model.get_indices(batch)
            # Flatten to one row per sample; each row is that sample's code.
            rows = codes.view(-1, codes.shape[-1]).cpu().tolist()
            unique_codes.update(tuple(int(v) for v in row) for row in rows)

        if num_sample == 0:
            # Nothing was evaluated, so the ratio below would divide by zero.
            self.logger.warning(
                "Evaluation data is empty; reporting a collision rate of 0.0"
            )
            return 0.0

        return (num_sample - len(unique_codes)) / num_sample

    def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
        """Write model, optimizer and bookkeeping state to ``self.ckpt_dir``.

        Args:
            epoch: Epoch index stored in the checkpoint (and in the default
                file name).
            collision_rate: Used only to build the default file name.
            ckpt_file: File name inside ``self.ckpt_dir``. When omitted, an
                ``epoch_<epoch>_collision_<rate>_model.pth`` name is used.
        """
        if ckpt_file:
            file_name = ckpt_file
        else:
            file_name = 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate)
        ckpt_path = os.path.join(self.ckpt_dir, file_name)

        state = {
            "args": self.args,
            "epoch": epoch,
            "best_loss": self.best_loss,
            "best_collision_rate": self.best_collision_rate,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, ckpt_path, pickle_protocol=4)

        self.logger.info(
            set_color("Saving current", "blue") + f": {ckpt_path}"
        )

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
        """Format the per-epoch training log line.

        Args:
            epoch_idx: Epoch number shown in the message.
            s_time: Epoch start timestamp in seconds.
            e_time: Epoch end timestamp in seconds.
            loss: Total training loss of the epoch.
            recon_loss: Total reconstruction loss of the epoch.

        Returns:
            The colorized message string.
        """
        header = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        loss_part = set_color("train loss", "blue") + ": %.4f" % loss
        recon_part = set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
        return header + loss_part + ", " + recon_part + "]"

    def fit(self, data):
        """Train for ``self.epochs`` epochs, evaluating every ``eval_step``.

        The same ``data`` is used for training and for the collision-rate
        evaluation. Whenever the collision rate improves, the model is saved
        as ``self.best_collision_ckpt``.

        Args:
            data: Sized iterable of batches, passed to both the training and
                the evaluation pass.

        Returns:
            ``(best_loss, best_collision_rate)`` observed during the run.
        """
        for epoch_idx in range(self.epochs):
            # train
            training_start_time = time()
            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
            training_end_time = time()
            self.logger.info(
                self._generate_train_loss_output(
                    epoch_idx,
                    training_start_time,
                    training_end_time,
                    train_loss,
                    train_recon_loss,
                )
            )

            # The best training loss is only tracked; no checkpoint is
            # written for it.
            if train_loss < self.best_loss:
                self.best_loss = train_loss

            # eval
            if (epoch_idx + 1) % self.eval_step != 0:
                continue

            valid_start_time = time()
            collision_rate = self._valid_epoch(data)
            if collision_rate < self.best_collision_rate:
                self.best_collision_rate = collision_rate
                self._save_checkpoint(
                    epoch_idx,
                    collision_rate=collision_rate,
                    ckpt_file=self.best_collision_ckpt,
                )
            valid_end_time = time()

            valid_score_output = (
                set_color("epoch %d evaluating", "green")
                + " ["
                + set_color("time", "blue")
                + ": %.2fs, "
                + set_color("collision_rate", "blue")
                + ": %f]"
            ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)
            self.logger.info(valid_score_output)

            # Late in training, keep an epoch-stamped snapshot of every
            # evaluated epoch as well.
            if epoch_idx > self.PERIODIC_SAVE_AFTER_EPOCH:
                self._save_checkpoint(epoch_idx, collision_rate=collision_rate)

        return self.best_loss, self.best_collision_rate