| import torch | |
| from torch import nn | |
| import timm | |
| from torchvision import transforms | |
class TimmModel(nn.Module):
    """Wrap a timm model behind a resize + normalize preprocessing stage.

    ``forward`` resizes its input to ``img_size`` (with antialiasing),
    normalizes it per channel with ``mean``/``std``, and feeds the result to
    the network built by ``timm.create_model``.

    Args:
        model_name: Model identifier passed to ``timm.create_model``.
        num_classes: Passed to ``timm.create_model`` as ``num_classes``.
        pretrained: Passed to ``timm.create_model`` as ``pretrained``.
        dropout_rate: Passed to ``timm.create_model`` as ``drop_rate``.
        img_size: Spatial size the input is resized to. An int ``s`` means
            ``(s, s)``; a pair is used as-is. Defaults to 224, the value this
            class previously hard-coded.
        mean: Per-channel mean given to ``transforms.Normalize``. Defaults to
            the ImageNet statistics this class previously hard-coded.
        std: Per-channel standard deviation given to ``transforms.Normalize``.
            Defaults to the ImageNet statistics this class previously
            hard-coded.
    """

    def __init__(
        self,
        model_name,
        num_classes=100,
        pretrained=True,
        dropout_rate=0.0,
        img_size=224,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
        super().__init__()
        # An int would make Resize scale only the shorter edge, keeping aspect
        # ratio. Expand it to a pair so the output is always exactly this size.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        # NOTE(review): timm checkpoints carry their own preferred input size
        # and mean/std (see the model's pretrained_cfg). Pass matching values
        # here for models that were not trained with the ImageNet defaults.
        # NOTE(review): the default mean/std presumably assume float images
        # scaled to [0, 1] — confirm against the data pipeline.
        self.preprocess = nn.Sequential(
            transforms.Resize(tuple(img_size), antialias=True),
            transforms.Normalize(mean=list(mean), std=list(std)),
        )
        print(f"Loading {model_name} from timm Pretrained= {pretrained}")
        self.net = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate,
        )

    def forward(self, x):
        """Preprocess ``x`` with ``self.preprocess`` and return ``self.net``'s output."""
        x = self.preprocess(x)
        return self.net(x)