File size: 2,323 Bytes
4c62147 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import torch
from torchtask.utils import logger
def add_parser_arguments(parser):
    """Register task-trainer-specific command-line options on *parser*.

    The base implementation registers nothing and returns ``None``; it is a
    hook that task-specific modules may replace. *parser* is presumably an
    ``argparse.ArgumentParser`` (not established here -- confirm at the call
    site).
    """
def task_trainer(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    """Placeholder entry point for building/running a task trainer.

    The base version is intentionally unimplemented and always raises
    ``NotImplementedError``; presumably each concrete task module provides
    its own ``task_trainer`` (TODO confirm how the framework looks it up).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
class TaskTrainer:
    """Base class for task trainers.

    The public methods (``build``, ``train``, ``validate``,
    ``save_checkpoint``, ``load_checkpoint``) form the interface used by the
    task proxy. Each one delegates to the matching underscore-prefixed hook
    (``_build``, ``_train``, ``_validate``, ``_save_checkpoint``,
    ``_load_checkpoint``), which every concrete task trainer should override;
    the base hooks raise ``NotImplementedError``.

    Every public method returns whatever its hook returns, so a subclass can
    hand results (for example, validation metrics) back to the caller.
    """

    def __init__(self, args):
        """Store *args* and create empty containers for the task's components.

        Args:
            args: arguments required by the task trainer (presumably the
                parsed command-line namespace -- confirm against the caller).
        """
        self.args = args  # arguments required by the task trainer
        self.task_func = None  # instance of 'TaskFunc' associated with a particular task
        self.meters = logger.AvgMeterSet()  # tool class for logging
        self.models = {}  # dict of the models required by the task and algorithm
        self.optimizers = {}  # dict of the optimizers required by the task and algorithm
        # dict of the learning-rate objects required by the task and algorithm
        # (presumably LR schedulers, given the 'lrer' naming -- TODO confirm)
        self.lrers = {}
        self.criterions = {}  # dict of the criterions required by the task and algorithm

    # ---------------------------------------------------------------------
    # Interface for task proxy
    # ---------------------------------------------------------------------
    # Each wrapper returns its hook's result (like load_checkpoint always did)
    # so subclasses are not forced to communicate results via side effects.

    def build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Build the trainer's components; delegates to ``_build``."""
        return self._build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)

    def train(self, data_loader, epoch):
        """Run training for *epoch* over *data_loader*; delegates to ``_train``."""
        return self._train(data_loader, epoch)

    def validate(self, data_loader, epoch):
        """Run validation for *epoch* over *data_loader*; delegates to ``_validate``."""
        return self._validate(data_loader, epoch)

    def save_checkpoint(self, epoch):
        """Save a checkpoint for *epoch*; delegates to ``_save_checkpoint``."""
        return self._save_checkpoint(epoch)

    def load_checkpoint(self):
        """Load a checkpoint and return the result of ``_load_checkpoint``."""
        return self._load_checkpoint()

    # ---------------------------------------------------------------------
    # All task trainer should implement the following functions
    # ---------------------------------------------------------------------

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Hook: construct models/optimizers/lrers/criterions. Must be overridden."""
        raise NotImplementedError

    def _train(self, data_loader, epoch):
        """Hook: train for one epoch. Must be overridden."""
        raise NotImplementedError

    def _validate(self, data_loader, epoch):
        """Hook: validate for one epoch. Must be overridden."""
        raise NotImplementedError

    def _save_checkpoint(self, epoch):
        """Hook: persist a checkpoint. Must be overridden."""
        raise NotImplementedError

    def _load_checkpoint(self):
        """Hook: restore a checkpoint and return its result. Must be overridden."""
        raise NotImplementedError
|