| from torch import nn |
| from dataclasses import dataclass, field |
| from utils import parse_structure |
|
|
| import timm |
|
|
@dataclass
class TimmModelConfig:
    """Configuration values used to build a ``TimmModel``.

    Each field is forwarded as the keyword argument of the same name to
    ``timm.create_model``.
    """

    # Architecture identifier passed as ``model_name``.
    model_name: str = "efficientnet_b0"
    # Passed as ``pretrained``; enabled by default.
    pretrained: bool = True
    # Passed as ``num_classes``; defaults to a single output.
    num_classes: int = 1
|
|
class TimmModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a network built by ``timm.create_model``.

    The wrapped network is stored on ``self.model``; ``forward`` simply
    delegates to it.
    """

    cfg: TimmModelConfig

    def __init__(self, cfg: TimmModelConfig) -> None:
        """Build the underlying timm network from ``cfg``.

        Args:
            cfg: Model settings. It is passed through
                ``parse_structure(TimmModelConfig, cfg)`` before use;
                presumably this normalises dict-like input into a
                ``TimmModelConfig`` -- confirm against ``utils``.
        """
        super().__init__()

        parsed = parse_structure(TimmModelConfig, cfg)
        self.cfg = parsed

        # Mirror the individual settings on the module for direct access.
        self.model_name = parsed.model_name
        self.pretrained = parsed.pretrained
        self.num_classes = parsed.num_classes

        # Registered as a submodule because it is assigned as an attribute
        # of an ``nn.Module``.
        self.model = timm.create_model(
            model_name=self.model_name,
            pretrained=self.pretrained,
            num_classes=self.num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        output = self.model(x)
        return output