# beer-classificator / preprocessing.py
# Uploaded by ramnck using huggingface_hub (commit 820ae20).
from torchvision import transforms
from transformers import ViTImageProcessor
# Pretrained checkpoint whose image processor is shared by both
# preprocessing functions. Loaded once, at import time (from_pretrained may
# read from the local cache or the network).
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

# Side length, in pixels, of the square images produced by the torchvision
# pipelines below. Presumably chosen to match the 224px checkpoint above --
# revisit if model_name changes.
_IMAGE_SIZE = 224

# Training pipeline: a random crop rescaled to _IMAGE_SIZE x _IMAGE_SIZE,
# followed by a random left-right mirror.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(_IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
    ]
)

# Evaluation pipeline: a deterministic resize straight to
# _IMAGE_SIZE x _IMAGE_SIZE (the aspect ratio is not preserved).
test_transforms = transforms.Compose(
    [transforms.Resize((_IMAGE_SIZE, _IMAGE_SIZE))]
)
def _encode(image, label, pipeline):
    """Apply *pipeline* and then the ViT processor to *image*; pair it with *label*.

    Shared implementation of ``preprocess_train`` and ``preprocess_test``,
    which differ only in the torchvision pipeline run before the processor.

    Args:
        image: the raw image from a dataset example. Presumably a PIL image
            (e.g. a Hugging Face ``datasets`` ``Image`` feature) -- confirm.
        label: returned unchanged.
        pipeline: torchvision transform applied before the processor.

    Returns:
        A dict with ``"pixel_values"`` (the single processed image tensor)
        and ``"label"``.
    """
    # ViTImageProcessor expects 1- or 3-channel images and may fail to infer
    # the channel layout of grayscale, RGBA or palette PIL images, so convert
    # PIL inputs to RGB first. Objects without a PIL ``mode`` (e.g. tensors)
    # are passed through unchanged.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    image = pipeline(image)
    encoding = processor(images=image, return_tensors="pt")
    # return_tensors="pt" yields a batch of one image; drop the batch axis.
    return {"pixel_values": encoding.pixel_values[0], "label": example_label(label)}


def example_label(label):
    """Return *label* unchanged (kept as a hook for label remapping)."""
    return label


def preprocess_train(example):
    """Preprocess one training example, with random augmentation.

    Runs ``train_transforms`` (random resized crop and random horizontal
    flip), then the ViT processor.

    NOTE(review): the augmentation is re-sampled on every call. If this is
    applied through ``Dataset.map``, the results are cached and every epoch
    sees the same crop. On-the-fly transforms (``set_transform`` /
    ``with_transform``) pass *batches*, which this function does not handle.
    Confirm which behaviour is intended.

    Args:
        example: a mapping with ``"image"`` and ``"label"`` keys.

    Returns:
        A dict with ``"pixel_values"`` and ``"label"``.
    """
    return _encode(example["image"], example["label"], train_transforms)


def preprocess_test(example):
    """Preprocess one evaluation example deterministically.

    Runs ``test_transforms`` (a fixed resize), then the ViT processor.

    Args:
        example: a mapping with ``"image"`` and ``"label"`` keys.

    Returns:
        A dict with ``"pixel_values"`` and ``"label"``.
    """
    return _encode(example["image"], example["label"], test_transforms)