# Creates a ViT base model pre-trained with SWAG weights.
import torch
import torchvision
from torch import nn


def create_vit_b_16_swag(num_classes: int = 1000):
    """Create a frozen ViT-B/16 SWAG pre-trained model with a new classifier head.

    Loads torchvision's ViT-B/16 with IMAGENET1K_SWAG_E2E_V1 weights
    (downloads the weights on first use), freezes the entire backbone,
    and replaces the classification head with a fresh trainable linear
    layer sized for `num_classes`.

    Args:
        num_classes: Number of output classes for the new classifier head.
            Defaults to 1000.

    Returns:
        model: torch.nn.Module - ViT-B/16 with frozen backbone and a
            trainable `heads` classifier.
        transforms: torchvision.transforms._presets.ImageClassification -
            Data transformation pipeline required by the pre-trained weights.
    """
    # Get the SWAG weights and their matching preprocessing pipeline.
    model_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
    model_transforms = model_weights.transforms()

    # Load ViT base model with patch size 16 and the SWAG weights.
    model = torchvision.models.vit_b_16(weights=model_weights)

    # Freeze the whole backbone so only the classifier head trains.
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head. Newly constructed parameters have
    # requires_grad=True by default, so no explicit unfreeze is needed
    # (the original code unfroze the *old* head, which was dead code
    # because the head is discarded here). 768 is the ViT-B hidden dim.
    model.heads = nn.Sequential(
        nn.Linear(in_features=768, out_features=num_classes, bias=True)
    )

    return model, model_transforms