import torchvision
import torch
from torch import nn


def vit_model(num_classes: int):
    """Build a frozen, pretrained ViT-B/16 with a fresh classification head.

    Loads torchvision's ``vit_b_16`` with its default pretrained weights,
    disables gradients on every existing parameter, and replaces
    ``model.heads`` with a single linear layer producing ``num_classes``
    outputs.

    Args:
        num_classes: Number of output units for the new classifier head.

    Returns:
        A ``(model, transform)`` tuple: the modified model and the
        preprocessing transform obtained from the pretrained weights.
    """
    # Default pretrained weights for ViT-B/16.
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT

    # Preprocessing transform that ships with those weights.
    transform = weights.transforms()

    # Instantiate the pretrained model.
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze every parameter of the pretrained model.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head. This happens after the freeze loop, so
    # the new layer's parameters are not touched by it. The output size
    # follows `num_classes` (previously hard-coded to 3, which ignored the
    # argument).
    model.heads = nn.Sequential(
        nn.Linear(in_features=768, out_features=num_classes),
    )

    return model, transform