File size: 8,761 Bytes
f43af3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
from collections import OrderedDict
from easy_tpp.runner.base_runner import Runner
from easy_tpp.utils import RunnerPhase, logger, MetricsHelper, MetricsTracker, concat_element, save_pickle
from easy_tpp.utils.const import Backend
@Runner.register(name='std_tpp')
class TPPRunner(Runner):
    """Standard TPP runner.

    Builds the model for the configured backend (Torch or TensorFlow) and
    implements the train / evaluate / generate hooks declared by ``Runner``.
    """

    def __init__(self, runner_config, unique_model_dir=False, **kwargs):
        """Create the runner, build the model and optionally restore weights.

        Args:
            runner_config: runner configuration; its ``trainer_config``,
                ``model_config``, ``base_config`` and ``data_config`` are read.
            unique_model_dir (bool): forwarded to ``Runner.__init__``.
            **kwargs: forwarded to ``Runner.__init__``.
        """
        super(TPPRunner, self).__init__(runner_config, unique_model_dir, **kwargs)

        self.metrics_tracker = MetricsTracker()
        # Metric functions are only built when metrics are configured; when this
        # attribute is absent, run_one_epoch skips metric computation.
        if self.runner_config.trainer_config.metrics is not None:
            self.metric_functions = self.runner_config.get_metric_functions()

        self._init_model()

        pretrain_dir = self.runner_config.model_config.pretrained_model_dir
        if pretrain_dir is not None:
            self._load_model(pretrain_dir)

    def _init_model(self):
        """Build the model and its backend-specific wrapper, then log its size.

        Torch is used when ``base_config.backend == Backend.Torch``; TensorFlow
        otherwise. The backend modules are imported inside each branch,
        presumably so that only the selected framework needs to be importable.
        """
        self.use_torch = self.runner_config.base_config.backend == Backend.Torch
        if self.use_torch:
            from easy_tpp.utils import set_seed
            from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel
            from easy_tpp.torch_wrapper import TorchModelWrapper
            from easy_tpp.utils import count_model_params

            set_seed(self.runner_config.trainer_config.seed)

            self.model = TorchBaseModel.generate_model_from_config(model_config=self.runner_config.model_config)
            self.model_wrapper = TorchModelWrapper(self.model,
                                                   self.runner_config.base_config,
                                                   self.runner_config.model_config,
                                                   self.runner_config.trainer_config)
            num_params = count_model_params(self.model)
        else:
            from easy_tpp.utils.tf_utils import set_seed
            from easy_tpp.model.tf_model.tf_basemodel import TfBaseModel
            from easy_tpp.tf_wrapper import TfModelWrapper
            from easy_tpp.utils.tf_utils import count_model_params

            set_seed(self.runner_config.trainer_config.seed)

            self.model = TfBaseModel.generate_model_from_config(model_config=self.runner_config.model_config)
            self.model_wrapper = TfModelWrapper(self.model,
                                                self.runner_config.base_config,
                                                self.runner_config.model_config,
                                                self.runner_config.trainer_config)
            # NOTE(review): the TF helper takes no model argument — presumably it
            # counts globally registered trainable variables; confirm.
            num_params = count_model_params()

        info_msg = f'Num of model parameters {num_params}'
        logger.info(info_msg)

    def _save_model(self, model_dir, **kwargs):
        """Save the model.

        Args:
            model_dir (str): the dir to save the model to; when None, falls back
                to ``base_config.specs['saved_model_dir']``.
        """
        if model_dir is None:
            model_dir = self.runner_config.base_config.specs['saved_model_dir']
        self.model_wrapper.save(model_dir)
        logger.critical(f'Save model to {model_dir}')

    def _load_model(self, model_dir, **kwargs):
        """Load the model from the dir.

        Args:
            model_dir (str): the dir to restore the model from.
        """
        self.model_wrapper.restore(model_dir)
        logger.critical(f'Load model from {model_dir}')

    def _train_model(self, train_loader, valid_loader, **kwargs):
        """Train for ``trainer_config.max_epoch`` epochs with periodic validation.

        Every ``trainer_config.valid_freq`` epochs (starting at epoch 0) the model
        is evaluated on ``valid_loader``. Whenever ``self.metrics_tracker`` reports
        a new best validation ``loglike``, the model is saved to
        ``base_config.specs['saved_model_dir']``. If a ``test_loader`` keyword
        argument is given, the test set is also evaluated at each validation epoch
        (using the VALIDATE phase).

        Args:
            train_loader (EasyTPP.DataLoader): data loader for the train set.
            valid_loader (EasyTPP.DataLoader): data loader for the valid set.
            **kwargs: may contain ``test_loader`` (EasyTPP.DataLoader).
        """
        test_loader = kwargs.get('test_loader')
        for i in range(self.runner_config.trainer_config.max_epoch):
            train_metrics = self.run_one_epoch(train_loader, RunnerPhase.TRAIN)

            message = f"[ Epoch {i} (train) ]: train " + MetricsHelper.metrics_dict_to_str(train_metrics)
            logger.info(message)

            self.model_wrapper.write_summary(i, train_metrics, RunnerPhase.TRAIN)

            # evaluate model
            if i % self.runner_config.trainer_config.valid_freq == 0:
                valid_metrics = self.run_one_epoch(valid_loader, RunnerPhase.VALIDATE)

                self.model_wrapper.write_summary(i, valid_metrics, RunnerPhase.VALIDATE)

                message = f"[ Epoch {i} (valid) ]: valid " + MetricsHelper.metrics_dict_to_str(valid_metrics)
                logger.info(message)

                updated = self.metrics_tracker.update_best("loglike", valid_metrics['loglike'], i)

                message_valid = "current best loglike on valid set is {:.4f} (updated at epoch-{})".format(
                    self.metrics_tracker.current_best['loglike'], self.metrics_tracker.episode_best)

                if updated:
                    message_valid += ", best updated at this epoch"
                    # checkpoint only when the validation loglike improved
                    self.model_wrapper.save(self.runner_config.base_config.specs['saved_model_dir'])

                if test_loader is not None:
                    test_metrics = self.run_one_epoch(test_loader, RunnerPhase.VALIDATE)

                    message = f"[ Epoch {i} (test) ]: test " + MetricsHelper.metrics_dict_to_str(test_metrics)
                    logger.info(message)

                logger.critical(message_valid)

        self.model_wrapper.close_summary()

    def _evaluate_model(self, data_loader, **kwargs):
        """Evaluate the model on the given dataset.

        Args:
            data_loader (EasyTPP.DataLoader): data loader for the evaluation set.

        Returns:
            dict: metrics dict produced by ``run_one_epoch`` in the VALIDATE phase.
        """
        eval_metrics = self.run_one_epoch(data_loader, RunnerPhase.VALIDATE)

        self.model_wrapper.write_summary(0, eval_metrics, RunnerPhase.VALIDATE)

        self.model_wrapper.close_summary()

        message = "Evaluation result: " + MetricsHelper.metrics_dict_to_str(eval_metrics)

        logger.critical(message)

        return eval_metrics

    def _gen_model(self, data_loader, **kwargs):
        """Generation of the TPP; one-step and multi-step are both supported.

        Runs the PREDICT phase over ``data_loader`` and pickles the resulting dict
        (keys ``pred`` and ``label``) to the relative path ``pred.pkl``.

        Args:
            data_loader (EasyTPP.DataLoader): data loader for the prediction set.
        """
        test_result = self.run_one_epoch(data_loader, RunnerPhase.PREDICT)

        # For the moment we save it to a pkl
        message = 'Save the prediction to pickle file pred.pkl'

        logger.critical(message)

        save_pickle('pred.pkl', test_result)

    def run_one_epoch(self, data_loader, phase):
        """Run one complete pass over ``data_loader``.

        Args:
            data_loader: iterable of batches accepted by ``self.model_wrapper.run_batch``.
            phase (RunnerPhase): TRAIN / VALIDATE accumulate the loss; any other
                phase only collects predictions and labels.

        Returns:
            OrderedDict: for TRAIN / VALIDATE, ``loglike`` (negated loss per event)
            and ``num_events``; for PREDICT, ``pred`` and ``label``. Outputs of
            ``self.metric_functions`` are added when both predictions and labels
            are present and metric functions were configured.
        """
        total_loss = 0
        total_num_event = 0
        epoch_label = []
        epoch_pred = []
        epoch_mask = []
        pad_index = self.runner_config.data_config.data_specs.pad_token_id
        metrics_dict = OrderedDict()

        if phase in [RunnerPhase.TRAIN, RunnerPhase.VALIDATE]:
            for batch in data_loader:
                batch_loss, batch_num_event, batch_pred, batch_label, batch_mask = \
                    self.model_wrapper.run_batch(batch, phase=phase)

                total_loss += batch_loss
                total_num_event += batch_num_event
                epoch_pred.append(batch_pred)
                epoch_label.append(batch_label)
                epoch_mask.append(batch_mask)

            # NOTE(review): fails (or yields inf/nan, depending on the loss type)
            # when the loader produced no events.
            avg_loss = total_loss / total_num_event

            metrics_dict.update({'loglike': -avg_loss, 'num_events': total_num_event})
        else:
            for batch in data_loader:
                batch_pred, batch_label = self.model_wrapper.run_batch(batch, phase=phase)
                epoch_pred.append(batch_pred)
                epoch_label.append(batch_label)

        # Each batch output is indexable; a None first element means the model did
        # not produce that output. Check for an empty loader before indexing so
        # both checks are guarded the same way.
        pred_exists, label_exists = False, False
        if len(epoch_pred) > 0 and epoch_pred[0][0] is not None:
            epoch_pred = concat_element(epoch_pred, pad_index)
            pred_exists = True
        if len(epoch_label) > 0 and epoch_label[0][0] is not None:
            epoch_label = concat_element(epoch_label, pad_index)
            label_exists = True
            if len(epoch_mask):
                epoch_mask = concat_element(epoch_mask, False)[0]  # retrieve the first element of concat array
                epoch_mask = epoch_mask.astype(bool)

        # metric_functions is only set in __init__ when metrics are configured.
        metric_functions = getattr(self, 'metric_functions', None)
        if pred_exists and label_exists and metric_functions is not None:
            metrics_dict.update(metric_functions(epoch_pred, epoch_label, seq_mask=epoch_mask))

        if phase == RunnerPhase.PREDICT:
            metrics_dict.update({'pred': epoch_pred, 'label': epoch_label})

        return metrics_dict
|