""" Initialize a PyTorch model wrapper that feeds into the Model Runner. """

import torch
from torch.utils.tensorboard import SummaryWriter

from easy_tpp.utils import RunnerPhase, set_optimizer, set_device


class TorchModelWrapper:
    def __init__(self, model, base_config, model_config, trainer_config):
        """A wrapper class for Torch backends.

        Args:
            model (BaseModel): a TPP model.
            base_config (EasyTPP.Config): basic configs.
            model_config (EasyTPP.ModelConfig): model spec configs.
            trainer_config (EasyTPP.TrainerConfig): trainer spec configs.
        """
        self.model = model
        self.base_config = base_config
        self.model_config = model_config
        self.trainer_config = trainer_config

        self.model_id = self.base_config.model_id
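
        # A GPU id of -1 means CPU-only; otherwise pin the CUDA device before
        # moving the model onto it.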
        if self.trainer_config.gpu != -1:
            torch.cuda.set_device(self.trainer_config.gpu)
        self.device = set_device(self.trainer_config.gpu)

        self.model.to(self.device)
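
        # The optimizer is only needed (and only built) in training mode.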
        if self.model_config.is_training:
            optimizer = self.trainer_config.optimizer
            self.learning_rate = self.trainer_config.learning_rate
            self.opt = set_optimizer(optimizer, self.model.parameters(), self.learning_rate)
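
        # Separate TensorBoard writers keep the train and validation curves in
        # distinct log directories.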
        self.train_summary_writer, self.valid_summary_writer = None, None
        if self.trainer_config.use_tfb:
            self.train_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_train_dir'])
            self.valid_summary_writer = SummaryWriter(log_dir=self.base_config.specs['tfb_valid_dir'])

    def restore(self, ckpt_dir):
        """Load the checkpoint to restore the model.

        Args:
            ckpt_dir (str): path for the checkpoint.
        """
        self.model.load_state_dict(torch.load(ckpt_dir), strict=False)

    def save(self, ckpt_dir):
        """Save the checkpoint for the model.

        Args:
            ckpt_dir (str): path for the checkpoint.
        """
        torch.save(self.model.state_dict(), ckpt_dir)

    def write_summary(self, epoch, kv_pairs, phase):
        """Write the kv_pairs into TensorBoard.

        Args:
            epoch (int): epoch index in the training.
            kv_pairs (dict): metrics dict.
            phase (RunnerPhase): a const that defines the stage of the model runner.
        """
        if self.trainer_config.use_tfb:
            summary_writer = None
            if phase == RunnerPhase.TRAIN:
                summary_writer = self.train_summary_writer
            elif phase == RunnerPhase.VALIDATE:
                summary_writer = self.valid_summary_writer
            elif phase == RunnerPhase.PREDICT:
                pass

            if summary_writer is not None:
                for k, v in kv_pairs.items():
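                    # 'num_events' is a count used for bookkeeping, not a
                    # metric worth plotting as a scalar curve.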
                    if k != 'num_events':
                        summary_writer.add_scalar(k, v, epoch)

                summary_writer.flush()
        return

    def close_summary(self):
        """Close the TensorBoard summary writers."""
        if self.train_summary_writer is not None:
            self.train_summary_writer.close()

        if self.valid_summary_writer is not None:
            self.valid_summary_writer.close()
        return

    def run_batch(self, batch, phase):
        """Run one batch.

        Args:
            batch (EasyTPP.BatchEncoding): preprocessed batch data that goes into the model.
            phase (RunnerPhase): a const that defines the stage of the model runner.

        Returns:
            tuple: for training and validation, the loss, event count, predictions,
                labels and mask; for prediction, the predictions and labels.
        """
        batch = batch.to(self.device).values()
        if phase in (RunnerPhase.TRAIN, RunnerPhase.VALIDATE):
            is_training = (phase == RunnerPhase.TRAIN)
            self.model.train(is_training)

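            # FullyNN models the intensity as the time-derivative of a
            # cumulative hazard network, so autograd must stay enabled for it
            # even in evaluation mode.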
            grad_flag = True if self.model_id == 'FullyNN' else is_training

            with torch.set_grad_enabled(grad_flag):
                loss, num_event = self.model.loglike_loss(batch)

            pred_dtime, pred_type, label_dtime, label_type, mask = None, None, None, None, None

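            # Normalize the loss by the event count so the gradient step does
            # not scale with how many events fall in the batch.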
            if is_training:
                self.opt.zero_grad()
                (loss / num_event).backward()
                self.opt.step()
            else:
                if self.model.event_sampler:
                    self.model.eval()
                    with torch.no_grad():
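                        # Drop the first event of each sequence: one-step-ahead
                        # prediction has no target for an event without history.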
                        if batch[1] is not None and batch[2] is not None:
                            label_dtime, label_type = batch[1][:, 1:].cpu().numpy(), batch[2][:, 1:].cpu().numpy()
                        if batch[3] is not None:
                            mask = batch[3][:, 1:].cpu().numpy()
                        pred_dtime, pred_type = self.model.predict_one_step_at_every_event(batch=batch)
                        pred_dtime = pred_dtime.detach().cpu().numpy()
                        pred_type = pred_type.detach().cpu().numpy()
            return loss.item(), num_event, (pred_dtime, pred_type), (label_dtime, label_type), (mask,)
        else:
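            # PREDICT phase: roll the event sampler forward from the last
            # observed event instead of scoring one-step-ahead predictions.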
            pred_dtime, pred_type, label_dtime, label_type = \
                self.model.predict_multi_step_since_last_event(batch=batch)
            pred_dtime = pred_dtime.detach().cpu().numpy()
            pred_type = pred_type.detach().cpu().numpy()
            label_dtime = label_dtime.detach().cpu().numpy()
            label_type = label_type.detach().cpu().numpy()
            return (pred_dtime, pred_type), (label_dtime, label_type)
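

# A minimal usage sketch (hypothetical objects; the exact config factories
# depend on the EasyTPP version in use):
#
#   wrapper = TorchModelWrapper(model, base_config, model_config, trainer_config)
#   loss, num_event, preds, labels, mask = wrapper.run_batch(batch, RunnerPhase.TRAIN)
#   wrapper.write_summary(epoch=0, kv_pairs={'loss': loss}, phase=RunnerPhase.TRAIN)
#   wrapper.save('./checkpoints/model.pt')
#   wrapper.close_summary()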