# Scraped file-page header (not code): author i4ata, commit d835e09 ("tiny fix"), 629 bytes.
from typing import Callable, Dict, Literal

import torch.nn as nn
from torchvision import models, transforms
# Pretrained weights enum for ViT-B/16 (torchvision's DEFAULT entry). Its
# bundled preprocessing is reused below for the 'Pretrained' variant.
_weights = models.ViT_B_16_Weights.DEFAULT

# Preprocessing pipeline per model variant:
# - 'Custom': resize to 224x224 and convert to a tensor (no normalization).
# - 'Pretrained': the transform bundled with ``_weights``. This is a torchvision
#   ``ImageClassification`` module rather than a ``transforms.Compose``, so the
#   values are typed as generic callables instead of ``Compose``.
model_transforms: Dict[Literal['Custom', 'Pretrained'], Callable] = {
    'Custom': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]),
    'Pretrained': _weights.transforms(),
}
def get_pretrained_vit(num_classes: int = 3) -> models.VisionTransformer:
    """Build a pretrained ViT-B/16 with a frozen backbone and a new linear head.

    All parameters of the loaded model are frozen; only the freshly attached
    classification head is trainable.

    Args:
        num_classes: Number of outputs of the new head. Defaults to 3.

    Returns:
        The ViT-B/16 model whose ``heads`` module has been replaced by a single
        trainable ``nn.Linear`` layer.
    """
    # Load the same weights enum the 'Pretrained' transform is derived from,
    # so the model and its preprocessing cannot drift apart.
    model = models.vit_b_16(weights=_weights)
    # Freeze every existing parameter. The head created below is built after
    # this call, so its parameters keep requires_grad=True.
    model.requires_grad_(False)
    # hidden_dim is the embedding width fed into the head (768 for ViT-B/16);
    # reading it from the model avoids a hard-coded magic number.
    model.heads = nn.Linear(in_features=model.hidden_dim, out_features=num_classes)
    return model