File size: 772 Bytes
8bc22ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
from torch import nn
import timm
from torchvision import transforms
class TimmModel(nn.Module):
    """Wrap a timm model behind a fixed resize + normalize preprocessing stage.

    ``forward`` resizes its input to ``image_size``, normalizes it with
    ``mean``/``std``, then runs the timm network on the result.

    The defaults (224x224, ImageNet channel statistics) reproduce the original
    hard-coded preprocessing. NOTE(review): some timm checkpoints may expect
    different statistics or a different input resolution; pass values that
    match the chosen checkpoint's pretrained config in that case.
    """

    # ImageNet per-channel statistics; default normalization (previously hard-coded).
    # Tuples rather than lists so the shared defaults cannot be mutated.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        model_name,
        num_classes=100,
        pretrained=True,
        dropout_rate=0.0,
        *,
        image_size=(224, 224),
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    ):
        """Build the preprocessing pipeline and instantiate the timm model.

        Args:
            model_name: Architecture name passed to ``timm.create_model``.
            num_classes: Forwarded to ``timm.create_model`` as ``num_classes``.
            pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
            dropout_rate: Forwarded to ``timm.create_model`` as ``drop_rate``.
            image_size: Target ``(height, width)`` of the resize step. An int
                is expanded to a square ``(n, n)`` so the output size stays
                fixed. (torchvision would otherwise treat an int as "resize
                the shorter side".)
            mean: Per-channel means used by the normalization step.
            std: Per-channel standard deviations used by the normalization step.
        """
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.preprocess = nn.Sequential(
            transforms.Resize(tuple(image_size), antialias=True),
            transforms.Normalize(mean=list(mean), std=list(std)),
        )
        print(f"Loading {model_name} from timm Pretrained= {pretrained}")
        self.net = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate,
        )

    def forward(self, x):
        """Preprocess ``x`` and return the wrapped network's output.

        NOTE(review): assumes ``x`` is a float image tensor laid out as
        (..., C, H, W), presumably scaled to [0, 1]. Confirm against the
        data pipeline.
        """
        x = self.preprocess(x)
        return self.net(x)