File size: 544 Bytes
9ad5b1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from omegaconf import DictConfig
import torch.nn as nn
import torch.nn.functional as F
from lightning import LightningModule
from utils import instantiate


class BaseModule(LightningModule):
    """Common base for Lightning modules driven by an OmegaConf config.

    On construction the config is stored as ``self.config`` and the wrapped
    network is obtained from ``self.configure_model()`` and stored as
    ``self.model``; ``forward`` delegates to that network. Subclasses are
    expected to override ``configure_model`` (returning the network),
    ``training_step`` and ``configure_optimizers``.

    NOTE(review): in recent Lightning releases ``configure_model`` is also a
    Trainer hook that the Trainer calls itself (per fit/validate/test/predict
    stage) and whose return value it ignores. An override that builds a new
    model will therefore run again at those points - confirm this is intended,
    or make the override cheap/idempotent.
    """

    def __init__(self, config: DictConfig) -> None:
        """Store *config* and build the wrapped model.

        Args:
            config: Configuration for this module; kept on ``self.config``.
        """
        super().__init__()
        # Assigned before configure_model() is called so an override can
        # read self.config while building the network.
        self.config = config
        self.model = self.configure_model()

    def forward(self, x, *args, **kwargs):
        """Run the wrapped model on *x*.

        Any extra positional and keyword arguments are passed through
        unchanged to ``self.model``, so multi-input models are supported.

        Raises:
            NotImplementedError: if ``self.model`` is ``None``, e.g. because
                ``configure_model`` was not overridden to return a module.
        """
        if self.model is None:
            # Fail with an actionable message instead of
            # "'NoneType' object is not callable".
            raise NotImplementedError(
                f"{type(self).__name__}.model is None; override "
                "configure_model() to return the network to wrap."
            )
        return self.model(x, *args, **kwargs)

    def training_step(self, batch, batch_idx):
        """Subclass hook for a single training step; always raises here."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement training_step()."
        )

    def configure_optimizers(self):
        """Subclass hook returning the optimizer setup; always raises here."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement configure_optimizers()."
        )