File size: 472 Bytes
5d7161e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torchvision
import torch.nn as nn

def create_vitb16_model(
    num_classes: int,
):
    """Build a pretrained ViT-B/16 with a frozen backbone and a fresh classifier head.

    Loads ``torchvision.models.vit_b_16`` with ``ViT_B_16_Weights.DEFAULT``,
    freezes every pretrained parameter, and replaces ``heads`` with a single
    trainable ``nn.Linear`` mapping the encoder's hidden size to
    ``num_classes``.

    Args:
        num_classes: Number of output classes for the new head. Must be >= 1.

    Returns:
        A ``(model, transform)`` tuple: the modified ViT-B/16 model and the
        preprocessing transform returned by ``vit_weights.transforms()`` for
        the loaded weights.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    vit_model = torchvision.models.vit_b_16(weights=vit_weights)
    vit_transform = vit_weights.transforms()

    # Freeze the pretrained backbone. This happens before the head is replaced
    # so that only the newly created head's parameters remain trainable.
    vit_model.requires_grad_(False)

    # Take the feature width from the model rather than hard-coding 768, so the
    # head always matches the encoder output. An unnamed Sequential keeps the
    # layer under key "0" (heads.0.weight / heads.0.bias), preserving the
    # state_dict layout of previously saved checkpoints.
    vit_model.heads = nn.Sequential(
        nn.Linear(
            in_features=vit_model.hidden_dim,
            out_features=num_classes,
            bias=True,
        ),
    )

    return vit_model, vit_transform