from torch import nn
import torchvision


def pretrained_vit(*, num_classes: int = 101):
    """Build a pretrained ViT-B/16 whose classifier head is replaced.

    Loads torchvision's ``vit_b_16`` with ``ViT_B_16_Weights.DEFAULT`` and
    swaps its ``heads`` module for a newly initialised ``nn.Linear`` layer
    producing ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the new head. Defaults to 101,
            the value previously hard-coded here.

    Returns:
        A ``(model, transforms)`` tuple: the modified ViT model and the
        transforms object returned by ``weights.transforms()`` for the
        weights that were loaded.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    # NOTE: constructing the model with pretrained weights may download them
    # on first use.
    model = torchvision.models.vit_b_16(weights=weights)

    # Take the embedding width from the backbone (768 for ViT-B/16) instead of
    # hard-coding it, so the new head always matches the model's features.
    model.heads = nn.Linear(
        in_features=model.hidden_dim,
        out_features=num_classes,
    )
    return model, transforms