| import torchvision | |
| import torch | |
| from torch import nn | |
def vit_model(num_classes: int):
    """Build a frozen, pretrained ViT-B/16 with a fresh classification head.

    Loads torchvision's ViT-B/16 with its default pretrained weights,
    disables gradients on every pretrained parameter, and replaces
    ``model.heads`` with a single linear layer sized for ``num_classes``.

    Args:
        num_classes: Number of output classes for the new head.

    Returns:
        A ``(model, transform)`` tuple: the modified model and the
        preprocessing transform obtained from the pretrained weights.
    """
    # Width of the features fed to the classifier head (768, as in the
    # original hard-coded head for ViT-B/16).
    hidden_dim = 768

    # Set up the pretrained ViT-B/16 weights.
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT

    # Preprocessing transform that ships with these weights.
    transform = weights.transforms()

    # Create an instance of the pretrained model.
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze the base layers. This runs before the head is replaced, so the
    # new head below is not affected by the freeze.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head. The output size now follows
    # ``num_classes`` instead of being fixed at 3.
    model.heads = nn.Sequential(
        nn.Linear(in_features=hidden_dim, out_features=num_classes),
    )

    return model, transform