File size: 821 Bytes
148d42e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch import nn
from dataclasses import dataclass, field
from utils import parse_structure

import timm

@dataclass
class TimmModelConfig:
    """Settings used to build a model through ``timm.create_model``.

    Each field is forwarded under the same keyword name to
    ``timm.create_model`` by ``TimmModel``.
    """

    # Architecture identifier passed as ``model_name``.
    model_name: str = field(default="efficientnet_b0")
    # Forwarded as the ``pretrained`` flag.
    pretrained: bool = field(default=True)
    # Forwarded as ``num_classes``.
    num_classes: int = field(default=1)

class TimmModel(nn.Module):
    """``nn.Module`` wrapper around a network built by ``timm.create_model``.

    The constructor stores the parsed configuration on ``self.cfg``, mirrors
    its three fields as instance attributes, and keeps the created network
    on ``self.model``.
    """

    cfg: TimmModelConfig

    def __init__(self, cfg: TimmModelConfig) -> None:
        """Build the wrapped timm model from ``cfg``.

        Args:
            cfg: Configuration handed to ``parse_structure`` together with
                ``TimmModelConfig``.
        """
        super().__init__()

        # NOTE(review): parse_structure presumably validates/coerces ``cfg``
        # into a TimmModelConfig-like object -- confirm against utils.
        self.cfg = parse_structure(TimmModelConfig, cfg)

        # Mirror the configuration values directly on the module.
        settings = self.cfg
        self.model_name = settings.model_name
        self.pretrained = settings.pretrained
        self.num_classes = settings.num_classes

        create_kwargs = {
            "model_name": self.model_name,
            "pretrained": self.pretrained,
            "num_classes": self.num_classes,
        }
        self.model = timm.create_model(**create_kwargs)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        output = self.model(x)
        return output