# Preprocessing helpers for a ViT image classifier: torchvision transforms
# handle augmentation / resizing, then the Hugging Face image processor turns
# the result into the ``pixel_values`` tensor.
from torchvision import transforms
from transformers import ViTImageProcessor

model_name = "google/vit-base-patch16-224"
# NOTE: loads the processor configuration for the checkpoint at import time.
processor = ViTImageProcessor.from_pretrained(model_name)

# Training pipeline: random resized crop to 224 plus a random horizontal flip.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
])

# Evaluation pipeline: resize to 224x224 with no random augmentation.
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
])


def _ensure_rgb(image):
    """Return *image* in RGB mode when it is a non-RGB PIL-style image.

    Grayscale, palette or RGBA inputs would otherwise reach the processor
    with a channel count other than three. Objects without a callable
    ``convert`` (e.g. tensors or arrays) and images already in RGB mode are
    returned unchanged.
    """
    convert = getattr(image, "convert", None)
    if callable(convert) and getattr(image, "mode", None) != "RGB":
        return convert("RGB")
    return image


def _encode(example, image_transforms):
    """Apply *image_transforms* and the processor to ``example["image"]``.

    Shared implementation behind ``preprocess_train`` / ``preprocess_test``.
    """
    image = image_transforms(_ensure_rgb(example["image"]))
    encoded = processor(images=image, return_tensors="pt")
    # The processor returns a batch of one image; drop the batch dimension.
    return {"pixel_values": encoded.pixel_values[0], "label": example["label"]}


def preprocess_train(example):
    """Preprocess one training example with random augmentation.

    Args:
        example: Mapping with an ``"image"`` entry (assumed to be a PIL image
            or another input the torchvision transforms accept) and a
            ``"label"`` entry.

    Returns:
        A dict with ``"pixel_values"`` (the processed image tensor without
        the batch dimension) and the untouched ``"label"``.
    """
    return _encode(example, train_transforms)


def preprocess_test(example):
    """Preprocess one evaluation example without random augmentation.

    Args:
        example: Mapping with an ``"image"`` entry (assumed to be a PIL image
            or another input the torchvision transforms accept) and a
            ``"label"`` entry.

    Returns:
        A dict with ``"pixel_values"`` (the processed image tensor without
        the batch dimension) and the untouched ``"label"``.
    """
    return _encode(example, test_transforms)