|
|
from torch import nn |
|
|
from dataclasses import dataclass, field |
|
|
from utils import parse_structure |
|
|
|
|
|
import timm |
|
|
|
|
|
@dataclass
class TimmModelConfig:
    """Settings consumed by ``TimmModel`` when building a timm network.

    Fields (in positional order):
        model_name: timm architecture identifier passed to ``timm.create_model``.
        pretrained: whether timm should load pretrained weights.
        num_classes: number of outputs requested for the classifier head.
    """

    # Architecture identifier handed straight to timm.
    model_name: str = "efficientnet_b0"
    # Forwarded to timm's ``pretrained`` flag.
    pretrained: bool = True
    # Output width of the head; 1 is the default here (presumably a single
    # logit / score — confirm intended use).
    num_classes: int = 1
|
|
|
|
|
class TimmModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a network built by ``timm.create_model``.

    The wrapped network is stored in ``self.model``; ``forward`` simply
    delegates to it, so the output is whatever that timm model returns.
    """

    cfg: TimmModelConfig

    def __init__(self, cfg: TimmModelConfig) -> None:
        """Build the timm network described by *cfg*.

        Args:
            cfg: Model configuration. It is normalised via
                ``parse_structure(TimmModelConfig, cfg)`` before use.
                NOTE(review): parse_structure presumably also accepts a plain
                dict / config mapping here, not only a TimmModelConfig —
                confirm against ``utils.parse_structure``.

        Note:
            When ``pretrained`` is true, timm loads pretrained weights for
            ``model_name``, which may involve fetching them over the network.
        """
        super().__init__()

        # Normalise the incoming config into a TimmModelConfig-shaped object.
        self.cfg = parse_structure(TimmModelConfig, cfg)

        # Mirror the config fields as plain attributes; they may be read by
        # callers elsewhere, so they are kept alongside ``self.cfg``.
        self.model_name = self.cfg.model_name
        self.pretrained = self.cfg.pretrained
        self.num_classes = self.cfg.num_classes

        # timm builds the classifier head with ``num_classes`` outputs.
        self.model = timm.create_model(
            model_name=self.model_name,
            pretrained=self.pretrained,
            num_classes=self.num_classes,
        )

    def forward(self, x):
        """Run the wrapped timm model on *x* and return its output unchanged.

        NOTE(review): assumes ``x`` is an image batch shaped (N, C, H, W), as
        most timm vision models expect — confirm for the configured model.
        """
        return self.model(x)