File size: 2,195 Bytes
1a96a3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torchvision
from torch import nn
from torchvision import transforms
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor
from typing import List

# Prefer GPU when one is visible to torch; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def create_vit(output_shape:int, classes:List, device:torch.device=device):
  """Creates a frozen HuggingFace ViT (google/vit-base-patch16-224) with a new head.

  Loads the pretrained checkpoint, freezes every backbone parameter, and
  replaces the classification head with a fresh trainable ``nn.Linear`` so
  only the head is updated during fine-tuning.

  Args:
    output_shape: Number of output units for the new classifier head.
      Should equal ``len(classes)`` so the ``id2label`` mapping stays coherent.
    classes: A list of class names; used to build id2label/label2id.
    device: A torch.device (or device string) the model is moved to.

  Returns:
    A tuple of (model, train_transforms, val_transforms, test_transforms).
  """
  id2label = {id:label for id, label in enumerate(classes)}
  label2id = {label:id for id,label in id2label.items()}

  model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224',
                                                  num_labels=len(classes),
                                                  id2label=id2label,
                                                  label2id=label2id,
                                                  ignore_mismatched_sizes=True)

  # Freeze the backbone BEFORE swapping the head: the replacement Linear
  # below is created afterwards, so it keeps requires_grad=True by default.
  for param in model.parameters():
    param.requires_grad = False

  # Read the embedding width from the checkpoint config instead of
  # hard-coding 768, so this also works for other ViT variants.
  # Can add dropout here if needed.
  model.classifier = nn.Linear(in_features=model.config.hidden_size,
                               out_features=output_shape)

  # Transform recipe follows:
  # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb
  processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
  image_mean = processor.image_mean
  image_std = processor.image_std
  size = processor.size["height"]

  normalize = transforms.Normalize(mean=image_mean, std=image_std)
  train_transforms = transforms.Compose([
      #transforms.RandomResizedCrop(size),
      transforms.Resize(size),
      transforms.CenterCrop(size),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      normalize])

  val_transforms = transforms.Compose([
      transforms.Resize(size),
      transforms.CenterCrop(size),
      transforms.ToTensor(),
      normalize])

  # No augmentation at test time; identical to the validation pipeline.
  test_transforms = val_transforms

  return model.to(device), train_transforms, val_transforms, test_transforms