import torch
from torch import nn
import timm
from torchvision import transforms

# ImageNet per-channel statistics; these are the normalization defaults the
# preprocessing stage has always used.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


class TimmModel(nn.Module):
    """A timm classification model with a built-in resize + normalize stage.

    ``forward`` first runs the input through ``self.preprocess`` (resize to
    ``img_size``, then per-channel normalization with ``mean``/``std``) and
    then through the timm network ``self.net``.

    NOTE(review): the preprocessing presumably expects float image tensors
    shaped (..., C, H, W) with values already scaled to [0, 1]; confirm this
    against the data pipeline.
    """

    def __init__(
        self,
        model_name,
        num_classes=100,
        pretrained=True,
        dropout_rate=0.0,
        img_size=224,
        mean=_IMAGENET_MEAN,
        std=_IMAGENET_STD,
    ):
        """Build the preprocessing stage and create the timm model.

        Args:
            model_name: timm model identifier passed to ``timm.create_model``.
            num_classes: Size of the classifier head.
            pretrained: Whether timm should load pretrained weights.
            dropout_rate: Forwarded to timm as ``drop_rate``.
            img_size: Resize target. Either an int (square) or an (H, W)
                pair. The default of 224 matches the original hard-coded size.
                Some timm models require a specific input size (e.g. 384 for
                ``*_384`` variants), so set this to match the model.
            mean: Per-channel normalization mean. The default is ImageNet's.
            std: Per-channel normalization std. The default is ImageNet's.
        """
        super().__init__()
        # Accept an int for the common square case, mirroring Resize((s, s)).
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.preprocess = nn.Sequential(
            transforms.Resize(tuple(img_size), antialias=True),
            transforms.Normalize(mean=list(mean), std=list(std)),
        )
        print(f"Loading {model_name} from timm Pretrained= {pretrained}")
        self.net = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate,
        )

    def forward(self, x):
        """Preprocess ``x`` (resize + normalize), then return the timm model's output."""
        x = self.preprocess(x)
        return self.net(x)