Abigail99216's picture
Upload folder using huggingface_hub
f43af3c verified
""" Initialize a Pytorch model wrapper that feed into Model Runner """
import torch
from torch.utils.tensorboard import SummaryWriter
from easy_tpp.utils import RunnerPhase, set_optimizer, set_device
class TorchModelWrapper:
    """Wrapper that adapts a PyTorch TPP model to the EasyTPP Model Runner.

    Owns device placement, the optimizer (training only), optional
    TensorBoard summary writers, checkpoint save/restore, and per-batch
    execution for the train / validate / predict phases.
    """

    def __init__(self, model, base_config, model_config, trainer_config):
        """A wrapper class for Torch backends.

        Args:
            model (BaseModel): a TPP model.
            base_config (EasyTPP.Config): basic configs.
            model_config (EasyTPP.ModelConfig): model spec configs.
            trainer_config (EasyTPP.TrainerConfig): trainer spec configs.
        """
        self.model = model
        self.base_config = base_config
        self.model_config = model_config
        self.trainer_config = trainer_config
        self.model_id = self.base_config.model_id

        # Sometimes PyTorch may not switch the active device context for all
        # operations, which causes an illegal-memory-access error; pin the
        # CUDA device explicitly when a GPU index is configured (-1 = CPU).
        if self.trainer_config.gpu != -1:
            torch.cuda.set_device(self.trainer_config.gpu)

        self.device = set_device(self.trainer_config.gpu)
        self.model.to(self.device)

        if self.model_config.is_training:
            # set up optimizer (only needed when training)
            optimizer = self.trainer_config.optimizer
            self.learning_rate = self.trainer_config.learning_rate
            self.opt = set_optimizer(optimizer, self.model.parameters(), self.learning_rate)

        # set up tensorboard writers; both stay None when tfb is disabled
        self.train_summary_writer, self.valid_summary_writer = None, None
        if self.trainer_config.use_tfb:
            self.train_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_train_dir'])
            self.valid_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_valid_dir'])

    def restore(self, ckpt_dir):
        """Load the checkpoint to restore the model.

        Args:
            ckpt_dir (str): path for the checkpoint.
        """
        # map_location makes a checkpoint saved on one device loadable on
        # another (e.g. GPU-trained weights on a CPU-only host); without it
        # torch.load tries to restore tensors onto the device they were
        # saved from and fails when that device is absent.
        self.model.load_state_dict(torch.load(ckpt_dir, map_location=self.device), strict=False)

    def save(self, ckpt_dir):
        """Save the checkpoint for the model.

        Args:
            ckpt_dir (str): path for the checkpoint.
        """
        torch.save(self.model.state_dict(), ckpt_dir)

    def write_summary(self, epoch, kv_pairs, phase):
        """Write the kv_pairs into the tensorboard.

        Args:
            epoch (int): epoch index in the training.
            kv_pairs (dict): metrics dict.
            phase (RunnerPhase): a const that defines the stage of model runner.
        """
        if self.trainer_config.use_tfb:
            summary_writer = None
            if phase == RunnerPhase.TRAIN:
                summary_writer = self.train_summary_writer
            elif phase == RunnerPhase.VALIDATE:
                summary_writer = self.valid_summary_writer
            elif phase == RunnerPhase.PREDICT:
                # no writer is kept for the prediction stage
                pass
            if summary_writer is not None:
                for k, v in kv_pairs.items():
                    # 'num_events' is a normalization count, not a metric
                    if k != 'num_events':
                        summary_writer.add_scalar(k, v, epoch)
                summary_writer.flush()
        return

    def close_summary(self):
        """Close the tensorboard summary writers (no-op when tfb is disabled)."""
        if self.train_summary_writer is not None:
            self.train_summary_writer.close()
        if self.valid_summary_writer is not None:
            self.valid_summary_writer.close()
        return

    def run_batch(self, batch, phase):
        """Run one batch.

        Args:
            batch (EasyTPP.BatchEncoding): preprocessed batch data that go into the model.
            phase (RunnerPhase): a const that defines the stage of model runner.

        Returns:
            tuple: for training and validation we return loss, prediction and labels;
                for prediction we return prediction.
        """
        batch = batch.to(self.device).values()
        if phase in (RunnerPhase.TRAIN, RunnerPhase.VALIDATE):
            # set mode to train (True) or eval (False)
            is_training = (phase == RunnerPhase.TRAIN)
            self.model.train(is_training)

            # FullyNN needs grad even in the validation stage
            grad_flag = is_training if not self.model_id == 'FullyNN' else True

            # run model
            with torch.set_grad_enabled(grad_flag):
                loss, num_event = self.model.loglike_loss(batch)

            # Assume we don't do prediction on the train set
            pred_dtime, pred_type, label_dtime, label_type, mask = None, None, None, None, None

            # update grad
            if is_training:
                self.opt.zero_grad()
                # normalize by event count so the step size is scale-invariant
                (loss / num_event).backward()
                self.opt.step()
            else:
                # by default we do not do evaluation on the train set, which
                # may take a long time; one-step-ahead prediction requires an
                # event sampler on the model
                if self.model.event_sampler:
                    self.model.eval()
                    with torch.no_grad():
                        # NOTE(review): batch layout assumed from the indexing
                        # below — [1] dtimes, [2] types, [3] mask; confirm
                        # against BatchEncoding.values() ordering.
                        if batch[1] is not None and batch[2] is not None:
                            label_dtime, label_type = batch[1][:, 1:].cpu().numpy(), batch[2][:, 1:].cpu().numpy()
                        if batch[3] is not None:
                            mask = batch[3][:, 1:].cpu().numpy()
                        pred_dtime, pred_type = self.model.predict_one_step_at_every_event(batch=batch)
                        pred_dtime = pred_dtime.detach().cpu().numpy()
                        pred_type = pred_type.detach().cpu().numpy()
            return loss.item(), num_event, (pred_dtime, pred_type), (label_dtime, label_type), (mask,)
        else:
            # RunnerPhase.PREDICT: multi-step generation since the last event
            pred_dtime, pred_type, label_dtime, label_type = self.model.predict_multi_step_since_last_event(batch=batch)
            pred_dtime = pred_dtime.detach().cpu().numpy()
            pred_type = pred_type.detach().cpu().numpy()
            label_dtime = label_dtime.detach().cpu().numpy()
            label_type = label_type.detach().cpu().numpy()
            return (pred_dtime, pred_type), (label_dtime, label_type)