"""Thin ``nn.Module`` wrapper around a model built by ``timm.create_model``."""

from dataclasses import dataclass, field

import timm
import torch
from torch import nn

from utils import parse_structure


@dataclass
class TimmModelConfig:
    """Settings forwarded to ``timm.create_model`` by :class:`TimmModel`."""

    # Name of the timm architecture to instantiate (passed as ``model_name``).
    model_name: str = 'efficientnet_b0'
    # Passed as ``pretrained``; when True, timm loads pretrained weights.
    pretrained: bool = True
    # Passed as ``num_classes``, i.e. the size of the model's output head.
    num_classes: int = 1


class TimmModel(nn.Module):
    """Wrap a timm model described by a :class:`TimmModelConfig`.

    The wrapped network is stored in ``self.model``.
    ``forward`` simply delegates to it.
    """

    cfg: TimmModelConfig

    def __init__(self, cfg: TimmModelConfig) -> None:
        """Build the underlying timm model from ``cfg``.

        Args:
            cfg: Model configuration. It is normalized through
                ``parse_structure(TimmModelConfig, cfg)`` before use.
                NOTE(review): presumably this also accepts a plain mapping
                or other config object; confirm against
                ``utils.parse_structure``.
        """
        super().__init__()
        self.cfg = parse_structure(TimmModelConfig, cfg)
        # Mirror the config fields as attributes for convenient access.
        self.model_name = self.cfg.model_name
        self.pretrained = self.cfg.pretrained
        self.num_classes = self.cfg.num_classes
        self.model = timm.create_model(
            model_name=self.model_name,
            pretrained=self.pretrained,
            num_classes=self.num_classes,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped timm model on ``x`` and return its output.

        Args:
            x: Input batch.
                NOTE(review): presumably an image tensor shaped
                (N, C, H, W); confirm against the chosen architecture.

        Returns:
            Whatever ``self.model(x)`` returns. For timm classifiers this
            is typically logits with ``num_classes`` columns.
        """
        return self.model(x)