|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
Base class for trainable models.
|
|
|
"""
|
|
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
from copy import copy
|
|
|
|
|
|
from omegaconf import OmegaConf
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract base class for trainable models.

    Subclasses must implement ``_init`` and ``_forward``; ``loss`` and
    ``metrics`` may be overridden as needed.

    Class attributes:
        required_data_keys: Keys that ``forward`` requires in its input.
            Either a flat iterable of keys, or a dict whose values describe
            the keys required in the corresponding nested entry.
        strict_conf: Declared for subclasses; it is not read anywhere in
            this base class.
    """

    required_data_keys = []
    strict_conf = True

    def __init__(self, conf):
        """Freeze the configuration and call the child's ``_init``.

        Args:
            conf: Model configuration, as an OmegaConf container or a plain
                ``dict`` (which is converted with ``OmegaConf.create``).
                The configuration is made read-only and put in struct mode
                before ``_init`` runs. An OmegaConf container passed in by
                the caller is flagged in place, not copied.
        """
        super().__init__()
        # A plain dict carries no OmegaConf flags, so it must be wrapped
        # before it can be marked read-only / struct.
        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        # Shallow per-instance copy: rebinding or extending the keys on one
        # instance must not touch the container shared at class level.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

    def forward(self, data):
        """Check the data and call the ``_forward`` method of the child model.

        Args:
            data: Mapping that must contain every key listed in
                ``required_data_keys`` (recursively, when that attribute is
                a dict).

        Returns:
            Whatever the child's ``_forward`` returns for ``data``.

        Raises:
            AssertionError: If a required key is missing from ``data``.
        """

        def recursive_key_check(expected, given):
            # Only a dict of expectations describes a nested layout; a flat
            # list of keys is checked for membership only.
            nested = isinstance(expected, dict)
            for key in expected:
                # Raised explicitly (rather than with ``assert``) so the
                # check is still performed under ``python -O``.
                if key not in given:
                    raise AssertionError(f"Missing key {key} in data")
                if nested:
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class.

        Called at the end of ``__init__`` with the frozen configuration.
        """
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class.

        Called by ``forward`` once the required keys have been checked.
        """
        raise NotImplementedError

    def loss(self, pred, data):
        """To be implemented by the child class.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError

    def metrics(self):
        """Return the model's metrics; the base implementation has none."""
        return {}
|
|
|
|