Spaces:
Sleeping
Sleeping
File size: 1,199 Bytes
deb67e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
# Creates ViT pre-trained base model with SWAG weights
import torch
import torchvision
from torch import nn
def create_vit_b_16_swag(num_classes: int = 1000):
    """
    Creates ViT SWAG pre-trained base model from torchvision.models

    All pre-trained parameters are frozen and the classification head is
    replaced by a single new linear layer, so only that layer is trainable.

    Args:
        num_classes: int = 1000 - Number of classes in data.
    Returns:
        model: torch.nn.Module - Pre-trained ViT SWAG base model.
        transforms: torchvision.transforms._presets.ImageClassification - Data Transformation Pipeline required by pre-trained model.
    """
    # Get ViT weights and data transformation pipeline
    model_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
    model_transforms = model_weights.transforms()

    # Load in ViT Base model with patch size 16
    model = torchvision.models.vit_b_16(weights=model_weights)

    # Freeze every pre-trained parameter. The original head is frozen too,
    # which is harmless because it is discarded just below.
    for param_swag in model.parameters():
        param_swag.requires_grad = False

    # Custom classifier. A freshly constructed nn.Linear has
    # requires_grad=True on its parameters, so no explicit unfreezing is
    # needed. The input width is read from the model instead of hard-coding
    # 768 so it always matches the encoder's embedding size.
    model.heads = nn.Sequential(
        nn.Linear(in_features=model.hidden_dim, out_features=num_classes, bias=True)
    )

    return model, model_transforms
|