# timm/src/model.py — author: YiMeng-SYSU
# Commit 8bc22ab: "Initial commit of timm project files"
import torch
from torch import nn
import timm
from torchvision import transforms
class TimmModel(nn.Module):
    """Wrap a timm model with a fixed resize + normalize preprocessing step.

    ``forward`` resizes the input to ``img_size`` and normalizes it with
    ``mean``/``std`` before passing it to the network built by
    ``timm.create_model``. The defaults (224x224, ImageNet mean/std)
    reproduce the previous hard-coded behaviour. Override them for
    backbones that were pretrained with a different input resolution or
    normalization statistics.

    Args:
        model_name: Model name passed to ``timm.create_model``.
        num_classes: Passed to ``timm.create_model`` as ``num_classes``.
        pretrained: Passed to ``timm.create_model`` as ``pretrained``.
        dropout_rate: Passed to ``timm.create_model`` as ``drop_rate``.
        img_size: Target spatial size for the resize. An int means a square
            ``(img_size, img_size)`` resize; a pair is ``(height, width)``.
        mean: Per-channel mean used by ``transforms.Normalize``.
        std: Per-channel standard deviation used by ``transforms.Normalize``.
    """

    def __init__(
        self,
        model_name,
        num_classes=100,
        pretrained=True,
        dropout_rate=0.0,
        *,
        img_size=224,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ):
        super().__init__()
        # Expand an int to a square size on purpose: transforms.Resize(int)
        # would instead match the *shorter* side and keep the aspect ratio,
        # which differs from the original fixed (224, 224) resize.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.preprocess = nn.Sequential(
            transforms.Resize(tuple(img_size), antialias=True),
            transforms.Normalize(mean=list(mean), std=list(std)),
        )
        print(f"Loading {model_name} from timm Pretrained= {pretrained}")
        self.net = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate,
        )

    def forward(self, x):
        """Preprocess ``x`` (resize + normalize) and run it through the network.

        NOTE(review): presumably ``x`` is a float image batch shaped
        (B, C, H, W) with values in [0, 1] and C matching ``len(mean)``,
        since Normalize is applied directly without rescaling — confirm
        against the data pipeline.
        """
        x = self.preprocess(x)
        return self.net(x)