|
|
import torch |
|
|
|
|
|
from torchtask.utils import logger |
|
|
|
|
|
|
|
|
def add_parser_arguments(parser):
    """Register this module's command-line options on ``parser``.

    This generic module contributes no options, so ``parser`` is left
    untouched and nothing is returned. (``parser`` is presumably an
    ``argparse.ArgumentParser`` — confirm against callers.)
    """
    return None
|
|
|
|
|
|
|
|
def task_trainer(args, model_dict, optimizer_dict, lrer_dict, criterion_dict, task_func):
    """Placeholder trainer entry point for this generic module.

    The version defined here is not implemented and always raises
    ``NotImplementedError``. Presumably task-specific modules supply a
    concrete implementation with this same signature — TODO confirm.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
class TaskTrainer:
    """Base class for task-specific trainers.

    Each public entry point (``build``, ``train``, ``validate``,
    ``save_checkpoint``, ``load_checkpoint``) delegates to a matching
    underscore-prefixed hook and returns whatever that hook returns.
    Subclasses override the hooks; the base-class hooks raise
    ``NotImplementedError``.
    """

    def __init__(self, args):
        """Store ``args`` and create the (initially empty) trainer state.

        Args:
            args: configuration object kept as ``self.args``
                (presumably parsed command-line arguments — confirm).
        """
        self.args = args

        # Set to None here; a subclass's ``_build`` presumably assigns it
        # from the ``task_func`` argument of ``build`` — confirm.
        self.task_func = None

        # Meter container from torchtask.utils.logger; presumably tracks
        # running averages of logged quantities — confirm in that module.
        self.meters = logger.AvgMeterSet()

        # Empty registries, presumably filled by ``_build`` from the
        # corresponding factory callables — confirm in subclasses.
        self.models = {}
        self.optimizers = {}
        self.lrers = {}  # presumably learning-rate schedulers — confirm
        self.criterions = {}

    # ------------------------------------------------------------------
    # Public entry points: thin wrappers around the subclass hooks.
    # Every wrapper returns the hook's result (previously only
    # ``load_checkpoint`` did), so e.g. metrics returned by ``_validate``
    # reach the caller.
    # ------------------------------------------------------------------

    def build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Build the trainer's components via the ``_build`` hook.

        Returns:
            Whatever ``_build`` returns.
        """
        return self._build(model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func)

    def train(self, data_loader, epoch):
        """Run training for ``epoch`` via the ``_train`` hook.

        Returns:
            Whatever ``_train`` returns.
        """
        return self._train(data_loader, epoch)

    def validate(self, data_loader, epoch):
        """Run validation for ``epoch`` via the ``_validate`` hook.

        Returns:
            Whatever ``_validate`` returns.
        """
        return self._validate(data_loader, epoch)

    def save_checkpoint(self, epoch):
        """Save a checkpoint for ``epoch`` via the ``_save_checkpoint`` hook.

        Returns:
            Whatever ``_save_checkpoint`` returns.
        """
        return self._save_checkpoint(epoch)

    def load_checkpoint(self):
        """Load a checkpoint via the ``_load_checkpoint`` hook.

        Returns:
            Whatever ``_load_checkpoint`` returns.
        """
        return self._load_checkpoint()

    # ------------------------------------------------------------------
    # Hooks: subclasses must override these.
    # ------------------------------------------------------------------

    def _build(self, model_funcs, optimizer_funcs, lrer_funcs, criterion_funcs, task_func):
        """Subclass hook behind ``build``; the base version is abstract."""
        raise NotImplementedError(f"{type(self).__name__} must implement _build()")

    def _train(self, data_loader, epoch):
        """Subclass hook behind ``train``; the base version is abstract."""
        raise NotImplementedError(f"{type(self).__name__} must implement _train()")

    def _validate(self, data_loader, epoch):
        """Subclass hook behind ``validate``; the base version is abstract."""
        raise NotImplementedError(f"{type(self).__name__} must implement _validate()")

    def _save_checkpoint(self, epoch):
        """Subclass hook behind ``save_checkpoint``; the base version is abstract."""
        raise NotImplementedError(f"{type(self).__name__} must implement _save_checkpoint()")

    def _load_checkpoint(self):
        """Subclass hook behind ``load_checkpoint``; the base version is abstract."""
        raise NotImplementedError(f"{type(self).__name__} must implement _load_checkpoint()")
|
|
|