File size: 1,199 Bytes
deb67e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Creates a ViT-B/16 base model pre-trained with SWAG weights

import torch
import torchvision

from torch import nn

def create_vit_b_16_swag(num_classes: int = 1000) -> "tuple[nn.Module, nn.Module]":
    """Create a ViT-B/16 model pre-trained with SWAG weights, ready for fine-tuning.

    The backbone is loaded from ``torchvision.models`` with the
    ``IMAGENET1K_SWAG_E2E_V1`` weights. Every pre-trained parameter is
    frozen, and the classification head is replaced by a freshly
    initialised, trainable linear layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of classes in the data. Must be >= 1.

    Returns:
        A ``(model, transforms)`` tuple:
        - ``model``: the ViT-B/16 network with the new classifier head.
        - ``transforms``: the ``torchvision.transforms._presets.ImageClassification``
          pipeline returned by ``weights.transforms()``. Apply it to inputs
          so they are preprocessed the way the pre-trained weights expect.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    # Pre-trained weights and the data transformation pipeline bundled with them.
    model_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
    model_transforms = model_weights.transforms()

    # ViT Base with patch size 16. Loading the weights may trigger a
    # download on first use.
    model = torchvision.models.vit_b_16(weights=model_weights)

    # Freeze the entire pre-trained network. The original head is replaced
    # below, so there is no point unfreezing it here. The replacement
    # layer is created with requires_grad=True, so it is the only trainable part.
    for param in model.parameters():
        param.requires_grad = False

    # Custom classifier. Read the embedding width from the model rather
    # than hard-coding 768, so it stays correct if the backbone changes.
    # Keep the nn.Sequential wrapper so state_dict keys remain "heads.0.*".
    model.heads = nn.Sequential(
        nn.Linear(
            in_features=model.hidden_dim,
            out_features=num_classes,
            bias=True,
        )
    )

    return model, model_transforms