"""Image transforms and a frozen, pretrained ViT-B/16 backbone with a new head."""

from typing import Dict, Literal

import torch.nn as nn
from torchvision import models, transforms

# Single source of truth for the pretrained checkpoint: both the 'Pretrained'
# transform below and the model loaded in get_pretrained_vit() derive from it,
# so preprocessing and weights cannot drift apart.
_weights = models.ViT_B_16_Weights.DEFAULT

# Preprocessing pipelines keyed by name.
#   'Custom'     - resize to 224x224 and convert to a tensor (no other steps).
#   'Pretrained' - the transform bundled with the selected pretrained weights.
model_transforms: Dict[Literal['Custom', 'Pretrained'], transforms.Compose] = {
    'Custom': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
    'Pretrained': _weights.transforms(),
}


def get_pretrained_vit(num_classes: int = 3) -> models.VisionTransformer:
    """Return a pretrained ViT-B/16 with a frozen backbone and a fresh linear head.

    Every parameter of the pretrained network is frozen
    (``requires_grad = False``), then ``model.heads`` is replaced by a newly
    created ``nn.Linear`` layer.

    Args:
        num_classes: Number of output features of the new head. Defaults to 3.

    Returns:
        The modified ``VisionTransformer`` instance.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = models.vit_b_16(weights=_weights)

    # Freeze the pretrained parameters.
    for parameter in model.parameters():
        parameter.requires_grad = False

    # The head is created after the freeze loop, so the loop never touches it.
    # Its input width is read from the backbone instead of being hard-coded.
    model.heads = nn.Linear(in_features=model.hidden_dim, out_features=num_classes)
    return model