File size: 821 Bytes
148d42e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from torch import nn
from dataclasses import dataclass, field
from utils import parse_structure
import timm
@dataclass
class TimmModelConfig:
model_name: str = 'efficientnet_b0'
pretrained: bool = True
num_classes: int = 1
class TimmModel(nn.Module):
cfg: TimmModelConfig
def __init__(self, cfg: TimmModelConfig) -> None:
super(TimmModel, self).__init__()
self.cfg = parse_structure(TimmModelConfig, cfg)
self.model_name = self.cfg.model_name
self.pretrained = self.cfg.pretrained
self.num_classes = self.cfg.num_classes
self.model = timm.create_model(
model_name=self.model_name,
pretrained=self.pretrained,
num_classes=self.num_classes
)
def forward(self, x):
return self.model(x) |