# lenet / preprocessor_resnet.py
# l45k's picture
# Upload processor
# 44cadd4 verified
from transformers.image_utils import ImageInput
from transformers import BaseImageProcessor, BatchFeature
from torchvision.transforms import v2
import torch
class ResNetProcessor(BaseImageProcessor):
    """
    A custom image processor that applies training-time augmentation for ResNet.

    Each image is randomly cropped and resized to 224x224, randomly flipped
    horizontally, converted to ``float32`` with value scaling, and normalized
    with the three-channel mean/std listed in :meth:`preprocess`.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def preprocess(self, images: ImageInput, return_tensors="pt", **kwargs) -> BatchFeature:
        """
        Augment and normalize one image or a batch of images.

        NOTE(review): the previous docstring described the inputs as
        "grayscale", but the normalization statistics below have three
        entries, so three-channel input appears to be expected -- confirm
        against the dataset.

        Args:
            images: A single image, or a list/tuple of images.
            return_tensors: Forwarded to ``BatchFeature`` as ``tensor_type``.
                Defaults to ``"pt"``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds the
            transformed images stacked along a new leading dimension.

        Raises:
            ValueError: If ``images`` is an empty list or tuple.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if not images:
            raise ValueError("`images` must contain at least one image.")

        # Built locally on each call; nothing is stored on the instance.
        transform = v2.Compose([
            # Convert PIL images / ndarrays to tensor images first so that
            # ToDtype and Normalize (which operate on tensors) apply to them.
            v2.ToImage(),
            v2.RandomResizedCrop(size=(224, 224), antialias=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Run the pipeline once per image: a v2 transform samples its random
        # parameters once per call, so passing the whole list in a single call
        # would give every image the same crop and flip (and requires all
        # images to share one size).
        augmented = [transform(image) for image in images]

        # Every result has the same 224x224 spatial size, so the per-image
        # tensors can be stacked into one batch tensor.
        pixel_values = torch.stack(augmented)

        # Honor the caller's requested tensor type instead of forcing "pt".
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)