from torchvision import transforms
from transformers import ViTImageProcessor

# Checkpoint identifier whose image-processor configuration is loaded below.
model_name = "google/vit-base-patch16-224"
# Built once at import time and reused by preprocess_train / preprocess_test.
# NOTE(review): from_pretrained presumably reads the local cache or downloads
# from the Hugging Face Hub -- confirm offline behaviour if that matters.
processor = ViTImageProcessor.from_pretrained(model_name)
# Training-time augmentation pipeline, applied to the raw image before it is
# handed to the processor: a random crop resized to 224, followed by a random
# horizontal flip.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
)
# Evaluation-time pipeline: no random augmentation, only a fixed resize to
# 224x224 before the image is handed to the processor.
test_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
    ]
)
def preprocess_train(example):
    """Augment one training example and convert it to model inputs.

    Args:
        example: Mapping with an ``"image"`` entry and a ``"label"`` entry.
            The image is presumably a PIL image -- confirm against the
            dataset that feeds this function.

    Returns:
        A dict with ``"pixel_values"`` (the first item of the batch the
        processor returns for this image) and the unchanged ``"label"``.
    """
    image = example["image"]
    # Grayscale, palette, RGBA or CMYK PIL images do not carry the three
    # channels the processor expects; normalise them to RGB first. Inputs
    # without a ``convert`` method (e.g. tensors) are passed through as-is.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    augmented = train_transforms(image)
    encoding = processor(images=augmented, return_tensors="pt")
    # The processor returns a batch even for one image; keep only item 0.
    return {"pixel_values": encoding.pixel_values[0], "label": example["label"]}
def preprocess_test(example):
    """Resize one evaluation example and convert it to model inputs.

    Args:
        example: Mapping with an ``"image"`` entry and a ``"label"`` entry.
            The image is presumably a PIL image -- confirm against the
            dataset that feeds this function.

    Returns:
        A dict with ``"pixel_values"`` (the first item of the batch the
        processor returns for this image) and the unchanged ``"label"``.
    """
    image = example["image"]
    # Grayscale, palette, RGBA or CMYK PIL images do not carry the three
    # channels the processor expects; normalise them to RGB first. Inputs
    # without a ``convert`` method (e.g. tensors) are passed through as-is.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    resized = test_transforms(image)
    encoding = processor(images=resized, return_tensors="pt")
    # The processor returns a batch even for one image; keep only item 0.
    return {"pixel_values": encoding.pixel_values[0], "label": example["label"]}