vision_ / model.py
relixsx's picture
Update model.py
ecd2b2f verified
Raw
History Blame Contribute Delete
605 Bytes
import torchvision
import torch
from torch import nn
def vit_model(num_classes):
    """Build a ViT-B/16 feature extractor with a new trainable classifier head.

    The backbone is loaded with torchvision's default pretrained weights and
    fully frozen; only the replacement head is left trainable.

    Args:
        num_classes: Number of output classes for the new classifier head.

    Returns:
        A ``(model, transform)`` tuple: the modified ``vit_b_16`` model and the
        preprocessing transform bundled with the pretrained weights, which
        inputs should be passed through before reaching the model.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")

    # Default pretrained weights for ViT-B/16.
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # Use the preprocessing that matches these weights.
    transform = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze every pretrained parameter so only the new head gets trained.
    for param in model.parameters():
        param.requires_grad = False

    # Read the embedding width from the existing head rather than hard-coding
    # 768. The new Linear layer is created after freezing, so its parameters
    # keep requires_grad=True by default. The unnamed Sequential wrapper is
    # kept so saved state_dict keys ("heads.0.*") stay compatible.
    in_features = model.heads.head.in_features
    model.heads = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform