File size: 772 Bytes
8bc22ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch import nn
import timm
from torchvision import transforms

class TimmModel(nn.Module):
    """Image classifier wrapping a ``timm`` model with built-in preprocessing.

    ``forward`` resizes and normalizes its input before passing it to the
    network. Callers therefore feed raw image tensors, not pre-normalized ones.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        num_classes: Passed to ``timm.create_model`` as ``num_classes``.
        pretrained: Passed to ``timm.create_model`` as ``pretrained``.
        dropout_rate: Passed to ``timm.create_model`` as ``drop_rate``.
        input_size: Target size handed to ``transforms.Resize``. A
            ``(height, width)`` pair resizes to exactly that shape.
            Defaults to ``(224, 224)``.
        mean: Per-channel means handed to ``transforms.Normalize``.
            Defaults to the standard ImageNet statistics.
        std: Per-channel standard deviations handed to
            ``transforms.Normalize``. Defaults to the standard ImageNet
            statistics.

    Note:
        The defaults reproduce the previously hard-coded preprocessing. Some
        pretrained weights expect a different resolution or normalization,
        for example a 384-pixel ViT or weights trained with mean/std of 0.5.
        For those models, pass matching values, such as the ones
        ``timm.data.resolve_data_config({}, model=...)`` reports.
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int = 100,
        pretrained: bool = True,
        dropout_rate: float = 0.0,
        *,
        input_size=(224, 224),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
        super().__init__()

        # Preprocessing is a submodule, so it runs as part of forward().
        # NOTE(review): presumably expects float images scaled to [0, 1] with
        # channels at dim -3 (Normalize needs len(mean) channels); confirm
        # against callers.
        self.preprocess = nn.Sequential(
            transforms.Resize(input_size, antialias=True),
            transforms.Normalize(mean=list(mean), std=list(std)),
        )

        print(f"Loading {model_name} from timm Pretrained= {pretrained}")

        self.net = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate,
        )

    def forward(self, x):
        """Apply ``self.preprocess`` to ``x``, then return ``self.net``'s output."""
        return self.net(self.preprocess(x))