| from omegaconf import DictConfig | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from lightning import LightningModule | |
| from utils import instantiate | |
| class BaseModule(LightningModule): | |
| def __init__(self, config: DictConfig): | |
| super().__init__() | |
| self.config = config | |
| self.model = self.configure_model() | |
| def forward(self, x): | |
| return self.model(x) | |
| def training_step(self, batch, batch_idx): | |
| raise NotImplementedError | |
| def configure_optimizers(self): | |
| raise NotImplementedError |