File size: 930 Bytes
563e5fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn
from torchvision import transforms, models
from typing import Literal, Dict

# Pretrained ViT-B/16 weight bundle; it ships the exact preprocessing
# pipeline the checkpoint was trained with (via `.transforms()`).
_weights = models.ViT_B_16_Weights.DEFAULT

# All custom pipelines feed the ViT its expected 224x224 input size.
_resize_224 = transforms.Resize((224, 224))

# Preprocessing pipelines keyed first by model flavor, then by split.
# 'custom' uses hand-rolled transforms (train adds TrivialAugmentWide for
# augmentation); 'pretrained' defers entirely to the checkpoint's own recipe.
model_transforms: Dict[Literal['custom', 'pretrained'], Dict[Literal['train', 'val'], transforms.Compose]] = {
    'custom': {
        'train': transforms.Compose(
            [_resize_224, transforms.TrivialAugmentWide(), transforms.ToTensor()]
        ),
        'val': transforms.Compose(
            [_resize_224, transforms.ToTensor()]
        ),
    },
    # One fresh `.transforms()` instance per split, mirroring the original
    # two separate calls.
    'pretrained': {split: _weights.transforms() for split in ('train', 'val')},
}

def get_pretrained_vit(num_classes: int = 3) -> models.VisionTransformer:
    """Build a frozen, ImageNet-pretrained ViT-B/16 with a fresh linear head.

    Loads the default pretrained weights, freezes every backbone parameter
    (linear-probe / feature-extraction setup), and swaps the classification
    head for a new, trainable linear layer.

    Args:
        num_classes: Output dimension of the new head. Defaults to 3,
            preserving the original hard-coded behavior.

    Returns:
        A ``VisionTransformer`` whose only trainable parameters are those of
        the newly attached head.
    """
    # Explicit enum is clearer than the string alias 'DEFAULT' and matches
    # how the weights bundle is referenced elsewhere in this module.
    model = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
    # Freeze the backbone so only the replacement head receives gradients.
    for parameter in model.parameters():
        parameter.requires_grad = False
    # Read the embedding width from the existing head rather than
    # hard-coding 768, so this survives a change of backbone variant.
    in_features = model.heads.head.in_features
    model.heads = nn.Linear(in_features=in_features, out_features=num_classes)
    return model