from transformers.image_utils import ImageInput
from transformers import BaseImageProcessor, BatchFeature
from torchvision.transforms import v2
import torch


class ResNetProcessor(BaseImageProcessor):
    """Image processor applying training-time augmentation for a ResNet.

    Each image is randomly cropped and resized to 224x224, randomly flipped
    horizontally, converted to ``float32`` with value scaling, and normalized
    with the three-channel mean/std constants defined below.
    """

    model_input_names = ["pixel_values"]

    # Class-level constants (not instance attributes) so they are shared by
    # all instances and stay out of the instance ``__dict__``.
    _CROP_SIZE = (224, 224)
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _build_transform(self) -> v2.Compose:
        """Return the augmentation + normalization pipeline for one image."""
        return v2.Compose([
            # Convert PIL / ndarray inputs to a tensor image first: the
            # dtype conversion and normalization steps operate on tensors
            # (``v2.Normalize`` is documented as not supporting PIL images).
            v2.ToImage(),
            v2.RandomResizedCrop(size=self._CROP_SIZE, antialias=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ])

    def preprocess(self, images: ImageInput, return_tensors="pt", **kwargs) -> BatchFeature:
        """Augment and normalize one image or a list/tuple of images.

        NOTE(review): the previous docstring described the inputs as
        grayscale, but normalization uses three-channel statistics --
        confirm the expected channel count of the inputs.

        Args:
            images: A single image or a list/tuple of images.
            return_tensors: Tensor type forwarded to ``BatchFeature``
                (defaults to ``"pt"``).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds the
            processed images stacked along a new leading batch dimension.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        transform = self._build_transform()

        # Apply the pipeline one image at a time: a v2 transform samples its
        # random parameters once per call, so passing the whole list would
        # give every image in the batch the same crop and flip.
        processed = [transform(image) for image in images]

        # Every output is cropped to the same size, so the results can be
        # stacked into a single batch tensor. An empty input yields an empty
        # tensor rather than an error from ``torch.stack``.
        pixel_values = torch.stack(processed) if processed else torch.empty(0)

        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )