import torch.nn as nn
from torchvision import transforms, models
from typing import Callable, Dict, Literal

# Default pretrained ViT-B/16 weights (ImageNet-1k).
_weights = models.ViT_B_16_Weights.DEFAULT

# Preprocessing for each model variant: a plain resize + tensor conversion for
# the custom model, and the transforms bundled with the pretrained weights
# (note: the bundled preset is a callable module, not a transforms.Compose).
model_transforms: Dict[Literal['Custom', 'Pretrained'], Callable] = {
    'Custom': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ]),
    'Pretrained': _weights.transforms()
}

def get_pretrained_vit() -> models.VisionTransformer:
    """Return a ViT-B/16 with a frozen pretrained backbone and a new 3-class head."""
    model = models.vit_b_16(weights=_weights)
    # Freeze the backbone so only the new classification head is trained.
    for parameter in model.parameters():
        parameter.requires_grad = False
    # Replace the head (ViT-B/16 hidden dim is 768) with a 3-class linear layer.
    model.heads = nn.Linear(in_features=768, out_features=3)
    return model
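

# Minimal usage sketch (not part of the original module): assumes a 3-channel
# PIL image; shows how the pretrained transforms and model fit together.
if __name__ == '__main__':
    import torch
    from PIL import Image

    model = get_pretrained_vit()
    model.eval()
    img = Image.new('RGB', (300, 300))  # placeholder image for the sketch
    batch = model_transforms['Pretrained'](img).unsqueeze(0)  # [1, 3, 224, 224]
    with torch.no_grad():
        logits = model(batch)  # [1, 3] class logits
    print(logits.shape)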