# hispath/models/timm_model.py
# Source snapshot: commit 148d42e ("init") by kohido, 821 bytes.
from torch import nn
from dataclasses import dataclass, field
from utils import parse_structure
import timm
@dataclass
class TimmModelConfig:
    """Settings consumed by ``TimmModel`` when it builds its timm network.

    Attributes:
        model_name: Architecture name passed to ``timm.create_model``.
        pretrained: Flag passed as ``pretrained`` to ``timm.create_model``.
        num_classes: Value passed as ``num_classes`` to ``timm.create_model``.
    """

    model_name: str = field(default='efficientnet_b0')
    pretrained: bool = field(default=True)
    num_classes: int = field(default=1)
class TimmModel(nn.Module):
    """``nn.Module`` wrapper around a network built by ``timm.create_model``.

    The architecture name, pretrained flag and ``num_classes`` come from a
    ``TimmModelConfig``, which is first normalized via ``parse_structure``.
    """

    cfg: TimmModelConfig

    def __init__(self, cfg: TimmModelConfig) -> None:
        """Build the wrapped timm model from ``cfg``.

        Args:
            cfg: Model configuration. It is passed through
                ``parse_structure(TimmModelConfig, cfg)`` before use; presumably
                this also accepts dict-like configs (e.g. loaded from YAML) —
                TODO confirm against ``utils.parse_structure``.
        """
        super().__init__()
        self.cfg = parse_structure(TimmModelConfig, cfg)
        # Mirror the config fields as plain attributes for direct access.
        self.model_name = self.cfg.model_name
        self.pretrained = self.cfg.pretrained
        self.num_classes = self.cfg.num_classes
        # Per timm's documented behavior, `num_classes` sets the size of the
        # classifier head and `pretrained` controls loading pretrained weights.
        self.model = timm.create_model(
            model_name=self.model_name,
            pretrained=self.pretrained,
            num_classes=self.num_classes,
        )

    def forward(self, x):
        """Run the wrapped timm model on ``x`` and return its output unchanged.

        ``x`` is presumably an image batch shaped (batch, channels, height,
        width) — TODO confirm against callers.
        """
        return self.model(x)