File size: 812 Bytes
820ae20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from torchvision import transforms
from transformers import ViTImageProcessor

# Pretrained checkpoint. Its image processor is applied by the preprocess
# functions *after* the torchvision transforms below, to turn the PIL image
# into model-ready pixel tensors.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

# Training augmentation: crop a random-scale / random-aspect region and
# resize it to 224x224, then flip left-right with probability 0.5
# (the torchvision default, spelled out here for clarity).
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
)

# Evaluation preprocessing: deterministic resize to exactly 224x224.
# An (h, w) tuple stretches the image, so the aspect ratio is not preserved.
test_transforms = transforms.Compose([transforms.Resize(size=(224, 224))])

def _ensure_rgb(image):
    """Return ``image`` converted to RGB when it is a non-RGB PIL-style image.

    Grayscale ("L"), palette ("P"), RGBA and similar modes are converted with
    ``image.convert("RGB")``; RGB images and objects without a ``mode``
    attribute (e.g. tensors or arrays) are returned unchanged.
    """
    # NOTE(review): the ViT processor expects 3-channel input (its
    # normalization uses per-channel RGB statistics); non-RGB PIL images
    # otherwise fail or are mis-normalized there — confirm for the
    # transformers version in use.
    if getattr(image, "mode", "RGB") != "RGB":
        return image.convert("RGB")
    return image


def _preprocess(example, image_transforms):
    """Shared body of ``preprocess_train`` and ``preprocess_test``.

    Args:
        example: a single record with ``"image"`` and ``"label"`` keys
            (presumably one row of a ``datasets.Dataset`` passed through a
            non-batched ``map`` — TODO confirm against the caller).
        image_transforms: torchvision transform applied to the RGB image
            before it is handed to the module-level ``processor``.

    Returns:
        dict with ``"pixel_values"`` (the processed tensor for this one image)
        and ``"label"`` (copied unchanged from ``example``).
    """
    image = image_transforms(_ensure_rgb(example["image"]))
    enc = processor(images=image, return_tensors="pt")
    # The processor returns a batch; [0] selects the single image passed in.
    return {"pixel_values": enc.pixel_values[0], "label": example["label"]}


def preprocess_train(example):
    """Convert one example to model inputs using the random training augmentation.

    See ``_preprocess`` for the expected ``example`` keys and the returned dict.
    """
    return _preprocess(example, train_transforms)


def preprocess_test(example):
    """Convert one example to model inputs using the deterministic eval resize.

    See ``_preprocess`` for the expected ``example`` keys and the returned dict.
    """
    return _preprocess(example, test_transforms)