# ViT_cifar10 / model.py
# Author: LukeOLuck — init commit (1a96a3d)
import torch
import torchvision
from torch import nn
from torchvision import transforms
from transformers import ViTForImageClassification
from transformers import ViTImageProcessor
from typing import List
# Default compute device: use CUDA when available, else fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def create_vit(output_shape: int, classes: List, device: torch.device = device):
    """Create a frozen google/vit-base-patch16-224 ViT with a fresh classifier head.

    The pretrained backbone is fully frozen; only the newly created linear
    classifier head has trainable parameters.

    Args:
        output_shape: Number of output units for the new classifier head.
        classes: A list of class names, used to build the id2label/label2id
            metadata stored on the model config.
        device: The torch.device to move the model to.

    Returns:
        A tuple of (model, train_transforms, val_transforms, test_transforms).
    """
    id2label = {idx: label for idx, label in enumerate(classes)}
    label2id = {label: idx for idx, label in id2label.items()}
    # NOTE(review): num_labels=len(classes) and out_features=output_shape can
    # disagree if callers pass mismatched values — callers should keep them equal.
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch16-224',
        num_labels=len(classes),
        id2label=id2label,
        label2id=label2id,
        ignore_mismatched_sizes=True)
    # Freeze the pretrained backbone; the replacement head below is created
    # afterwards, so its parameters keep requires_grad=True and remain trainable.
    for param in model.parameters():
        param.requires_grad = False
    # Can add dropout here if needed.
    # Read the embedding width from the config instead of hard-coding 768.
    model.classifier = nn.Linear(in_features=model.config.hidden_size,
                                 out_features=output_shape)
    # https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb
    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
    image_mean = processor.image_mean
    image_std = processor.image_std
    size = processor.size["height"]
    normalize = transforms.Normalize(mean=image_mean, std=image_std)
    train_transforms = transforms.Compose([
        # transforms.RandomResizedCrop(size),
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])
    # Validation/test pipeline: deterministic (no augmentation).
    val_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        normalize])
    test_transforms = val_transforms
    return model.to(device), train_transforms, val_transforms, test_transforms