File size: 6,784 Bytes
f43af3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import logging
from abc import abstractmethod
from easy_tpp.preprocess import TPPDataLoader
from easy_tpp.utils import Registrable, Timer, logger, get_unique_id, LogConst, get_stage, RunnerPhase
class Runner(Registrable):
    """Registrable base class for runners that train, evaluate, or generate
    with a temporal point process (TPP) model.

    Subclasses implement the abstract ``_train_model`` / ``_evaluate_model`` /
    ``_gen_model`` / ``_save_model`` / ``_load_model`` hooks; this base class
    handles config plumbing, data loading, logging, and timing.
    """

    def __init__(
            self,
            runner_config,
            unique_model_dir=False,
            **kwargs):
        """Initialize the base runner.

        Args:
            runner_config (RunnerConfig): config for the runner.
            unique_model_dir (bool, optional): whether to give a unique dir to save
                the model. Defaults to False.
            **kwargs: optional params. Recognized key: ``skip_data_loader`` (bool) —
                when True, no data loader is built (useful for load-and-generate flows).
        """
        self.runner_config = runner_config

        # Re-assign the model_dir with a unique suffix so concurrent runs
        # do not overwrite each other's checkpoints.
        if unique_model_dir:
            runner_config.model_dir = runner_config.base_config.specs['saved_model_dir'] + '_' + get_unique_id()

        self.save_log()

        skip_data_loader = kwargs.get('skip_data_loader', False)
        if not skip_data_loader:
            # Build the data reader.
            data_config = self.runner_config.data_config
            backend = self.runner_config.base_config.backend
            # BUGFIX: the original code assigned this to `kwargs`, clobbering the
            # caller-supplied **kwargs above; use a distinct local name instead.
            trainer_kwargs = self.runner_config.trainer_config.get_yaml_config()
            self._data_loader = TPPDataLoader(
                data_config=data_config,
                backend=backend,
                **trainer_kwargs
            )

            # Needed for the Intensity Free model: inter-event-time statistics of
            # the train set parameterize its output distribution.
            # NOTE(review): min_dt / max_dt are returned by get_dt_stats() but were
            # never stored on the config — confirm whether any model needs them.
            mean_log_inter_time, std_log_inter_time, _min_dt, _max_dt = (
                self._data_loader.train_loader().dataset.get_dt_stats())
            runner_config.model_config.set("mean_log_inter_time", mean_log_inter_time)
            runner_config.model_config.set("std_log_inter_time", std_log_inter_time)

        self.timer = Timer()

    @staticmethod
    def build_from_config(runner_config, unique_model_dir=False, **kwargs):
        """Build up the runner from runner config.

        Args:
            runner_config (RunnerConfig): config for the runner.
            unique_model_dir (bool, optional): whether to give a unique dir to save
                the model. Defaults to False.
            **kwargs: forwarded to the concrete runner's constructor.

        Returns:
            Runner: the concrete runner registered under
            ``runner_config.base_config.runner_id``.
        """
        runner_cls = Runner.by_name(runner_config.base_config.runner_id)
        return runner_cls(runner_config, unique_model_dir=unique_model_dir, **kwargs)

    def get_config(self):
        """Return the runner's RunnerConfig."""
        return self.runner_config

    def set_model_dir(self, model_dir):
        """Set the directory where the model is saved."""
        self.runner_config.base_config.specs['saved_model_dir'] = model_dir

    def get_model_dir(self):
        """Return the directory where the model is saved."""
        return self.runner_config.base_config.specs['saved_model_dir']

    def train(
            self,
            train_loader=None,
            valid_loader=None,
            test_loader=None,
            **kwargs
    ):
        """Train the model.

        Args:
            train_loader (EasyTPP.DataLoader, optional): data loader for train set.
                Defaults to None (built from config).
            valid_loader (EasyTPP.DataLoader, optional): data loader for valid set.
                Defaults to None (built from config).
            test_loader (EasyTPP.DataLoader, optional): data loader for test set.
                Defaults to None (built from config when test data is configured).
            **kwargs: forwarded to the subclass's ``_train_model``.

        Returns:
            The trained model, as returned by ``_train_model``.
        """
        # No train and valid loader supplied from outside: build from config.
        if train_loader is None and valid_loader is None:
            train_loader = self._data_loader.train_loader()
            valid_loader = self._data_loader.valid_loader()

            # No test loader from outside and there indeed exists test data in config.
            if test_loader is None and self.runner_config.data_config.test_dir is not None:
                test_loader = self._data_loader.test_loader()

        logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...')

        timer = self.timer
        timer.start()
        model_id = self.runner_config.base_config.model_id
        logger.info(f'Start {model_id} training...')

        model = self._train_model(
            train_loader,
            valid_loader,
            test_loader=test_loader,
            **kwargs
        )

        logger.info(f'End {model_id} train! Cost time: {timer.end()}')
        return model

    def evaluate(self, valid_loader=None, **kwargs):
        """Evaluate the model on the validation set.

        Args:
            valid_loader (EasyTPP.DataLoader, optional): data loader for valid set.
                Defaults to None (built from config).
            **kwargs: forwarded to the subclass's ``_evaluate_model``.

        Returns:
            The 'rmse' entry of the metric dict returned by ``_evaluate_model``.
        """
        if valid_loader is None:
            valid_loader = self._data_loader.valid_loader()

        logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...')

        timer = self.timer
        timer.start()
        model_id = self.runner_config.base_config.model_id
        logger.info(f'Start {model_id} evaluation...')

        metric = self._evaluate_model(
            valid_loader,
            **kwargs
        )

        logger.info(f'End {model_id} evaluation! Cost time: {timer.end()}')
        return metric['rmse']  # return a scalar for HPO to use

    def gen(self, gen_loader=None, **kwargs):
        """Run generation (sampling / prediction) with the model.

        Args:
            gen_loader (EasyTPP.DataLoader, optional): data loader used for
                generation. Defaults to None (test loader built from config).
            **kwargs: forwarded to the subclass's ``_gen_model``.

        Returns:
            The result of ``_gen_model``.
        """
        if gen_loader is None:
            gen_loader = self._data_loader.test_loader()

        logger.info(f'Data \'{self.runner_config.base_config.dataset_id}\' loaded...')

        timer = self.timer
        timer.start()
        model_name = self.runner_config.base_config.model_id
        # BUGFIX: the original message said "evaluation" here although this method
        # performs generation (cf. the matching end-of-run message below).
        logger.info(f'Start {model_name} generation...')

        model = self._gen_model(
            gen_loader,
            **kwargs
        )

        logger.info(f'End {model_name} generation! Cost time: {timer.end()}')
        return model

    @abstractmethod
    def _train_model(self, train_loader, valid_loader, **kwargs):
        """Subclass hook: train the model and return it."""
        pass

    @abstractmethod
    def _evaluate_model(self, data_loader, **kwargs):
        """Subclass hook: evaluate the model; must return a dict with an 'rmse' key."""
        pass

    @abstractmethod
    def _gen_model(self, data_loader, **kwargs):
        """Subclass hook: run generation with the model."""
        pass

    @abstractmethod
    def _save_model(self, model_dir, **kwargs):
        """Subclass hook: persist the model to model_dir."""
        pass

    @abstractmethod
    def _load_model(self, model_dir, **kwargs):
        """Subclass hook: load the model from model_dir."""
        pass

    def save_log(self):
        """Attach a file handler so the shared logger also writes to a local file.

        NOTE(review): each call adds a new FileHandler to the shared `logger`;
        constructing several runners in one process will duplicate log lines —
        confirm whether that is intended.
        """
        log_dir = self.runner_config.base_config.specs['saved_log_dir']
        fh = logging.FileHandler(log_dir)
        fh.setFormatter(logging.Formatter(LogConst.DEFAULT_FORMAT_LONG))
        logger.addHandler(fh)
        logger.info(f'Save the log to {log_dir}')
        return

    def save(
            self,
            model_dir=None,
            **kwargs
    ):
        """Save the model via the subclass's ``_save_model`` hook.

        Args:
            model_dir (str, optional): target directory; semantics are defined by
                the subclass. Defaults to None.

        Returns:
            Whatever ``_save_model`` returns.
        """
        return self._save_model(model_dir, **kwargs)

    def run(self, **kwargs):
        """Start the runner, dispatching on the configured stage.

        Args:
            **kwargs (dict): optional params forwarded to train / evaluate / gen.

        Returns:
            EasyTPP.BaseModel, dict: the results of the process.
        """
        current_stage = get_stage(self.runner_config.base_config.stage)
        if current_stage == RunnerPhase.TRAIN:
            return self.train(**kwargs)
        elif current_stage == RunnerPhase.VALIDATE:
            return self.evaluate(**kwargs)
        else:
            # Any other stage (e.g. generation) falls through to gen().
            return self.gen(**kwargs)
|