File size: 1,115 Bytes
8bc3305 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | import datetime
from copy import deepcopy
from abc import ABC, abstractmethod
class BaseTrainer(ABC):
    """Abstract base class for trainers.

    Holds the five components every trainer is constructed with (config,
    model, optimizer, scheduler, writer) and declares the hooks a concrete
    trainer must implement: ``speed_up``, ``setTrain``, ``setEval``,
    ``load_ckpt``, ``save_ckpt`` and ``inference``.
    """

    def __init__(
        self,
        config,
        model,
        optimizer,
        scheduler,
        writer,
    ):
        """Store the trainer components after checking none of them is ``None``.

        Args:
            config: Trainer configuration object.
            model: The model being trained.
            optimizer: Optimizer driving the parameter updates.
            scheduler: Learning-rate (or similar) scheduler.
            writer: Logging writer; the error message below refers to it as a
                tensorboard writer.

        Raises:
            NotImplementedError: If any component is ``None``. The message
                lists every missing component. (This exception type is kept for
                backward compatibility with callers that already catch it.)
        """
        components = {
            "config": config,
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "writer": writer,
        }
        # `is None` rather than truthiness: falsy-but-valid objects are accepted.
        missing = [name for name, value in components.items() if value is None]
        if missing:
            raise NotImplementedError(
                "config, model, optimizer, scheduler, and tensorboard writer "
                f"must all be provided; got None for: {', '.join(missing)}"
            )
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.writer = writer

    @abstractmethod
    def speed_up(self):
        """Apply subclass-specific acceleration.

        The exact meaning is defined by the subclass (the name suggests
        compilation, mixed precision, or similar; TODO confirm).
        """

    @abstractmethod
    def setTrain(self):
        """Switch the trainer into training mode (semantics defined by the subclass)."""

    @abstractmethod
    def setEval(self):
        """Switch the trainer into evaluation mode (semantics defined by the subclass)."""

    @abstractmethod
    def load_ckpt(self, model_path):
        """Load a checkpoint from ``model_path``."""

    @abstractmethod
    def save_ckpt(self, dataset, epoch, iters, best=False):
        """Save a checkpoint.

        Args:
            dataset: Dataset identifier or object associated with the checkpoint
                (exact use is subclass-defined; TODO confirm).
            epoch: Current epoch.
            iters: Current iteration count.
            best: Presumably marks this checkpoint as the best so far; verify
                against implementations.
        """

    @abstractmethod
    def inference(self, data_dict):
        """Run inference on ``data_dict``.

        The schema of ``data_dict`` is defined by the concrete trainer.
        """
|