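# NOTE: the three lines below are hosting-page residue (uploader caption,
# commit message, commit hash); they are not part of the module's code.
# Inmental's picture
# Upload folder using huggingface_hub
# 4c62147 verified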
import torch
from torchtask.utils import logger
def add_parser_arguments(parser):
    """Hook for registering trainer-specific command-line arguments.

    This base implementation registers nothing and leaves ``parser``
    untouched.

    Args:
        parser: the argument parser that options would be added to
            (presumably an ``argparse.ArgumentParser`` -- confirm against
            the caller).

    Returns:
        None.
    """
    # Intentionally a no-op: there are no arguments to register here.
    return None
def task_trainer(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    """Placeholder entry point for building a task trainer.

    This implementation does not build anything; it always raises.
    NOTE(review): presumably concrete trainer modules provide their own
    version of this function -- confirm against the rest of the package.

    Args:
        args: arguments required by the trainer.
        model_dict: models keyed by name.
        optimizer_dict: optimizers keyed by name.
        lrer_dict: learning-rate objects keyed by name.
        criterion_dict: criterions keyed by name.
        task_func: task-specific function object.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
class TaskTrainer:
    """Base class for task trainers.

    The public methods (``build``, ``train``, ``validate``,
    ``save_checkpoint``, ``load_checkpoint``) form the interface used by the
    task proxy. Each one forwards to the underscore-prefixed hook of the same
    name, and every hook raises ``NotImplementedError`` here, so subclasses
    must override all of them.
    """

    def __init__(self, args):
        """Store ``args`` and create the empty containers used by a trainer.

        Args:
            args: arguments required by the task trainer.
        """
        self.args = args                    # arguments required by the task trainer
        self.task_func = None               # instance of 'TaskFunc' associated with a particular task
        self.meters = logger.AvgMeterSet()  # tool class for logging

        self.models = {}        # dict of the models required by the task and algorithm
        self.optimizers = {}    # dict of the optimizers required by the task and algorithm
        self.lrers = {}         # dict of the learning-rate objects required by the task and algorithm
        self.criterions = {}    # dict of the criterions required by the task and algorithm

    # ---------------------------------------------------------------------
    # Interface for task proxy
    # ---------------------------------------------------------------------

    def build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Forward all arguments to ``_build``; returns None."""
        self._build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)

    def train(self, data_loader, epoch):
        """Forward to ``_train``; the hook's return value is discarded."""
        self._train(data_loader, epoch)

    def validate(self, data_loader, epoch):
        """Forward to ``_validate``; the hook's return value is discarded."""
        self._validate(data_loader, epoch)

    def save_checkpoint(self, epoch):
        """Forward to ``_save_checkpoint``; the hook's return value is discarded."""
        self._save_checkpoint(epoch)

    def load_checkpoint(self):
        """Forward to ``_load_checkpoint`` and return whatever it returns."""
        return self._load_checkpoint()

    # ---------------------------------------------------------------------
    # All task trainers should implement the following functions
    # ---------------------------------------------------------------------

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Hook behind ``build``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _train(self, data_loader, epoch):
        """Hook behind ``train``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _validate(self, data_loader, epoch):
        """Hook behind ``validate``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _save_checkpoint(self, epoch):
        """Hook behind ``save_checkpoint``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _load_checkpoint(self):
        """Hook behind ``load_checkpoint``; its return value is passed through.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError