import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2
from transformers import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput


class LeNetProcessor(BaseImageProcessor):
    """
    Image processor that prepares images for a LeNet-style model.

    Each image is resized to 28x28, converted to ``float32`` scaled to
    ``[0, 1]``, and normalized with mean ``0.1307`` and std ``0.3081``.
    The channel count of the input is preserved; no grayscale conversion
    is performed here.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        """
        Create the processor.

        Args:
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

    def preprocess(self, images: ImageInput, return_tensors=None, **kwargs) -> BatchFeature:
        """
        Preprocess one image or a list of images.

        Args:
            images: A single image or a list of images (PIL image, NumPy
                array, or torch tensor). Anything that is not a ``list`` is
                treated as a single image.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
                When not ``None``, the per-image tensors are stacked into a
                single batch tensor before conversion; when ``None``,
                ``pixel_values`` is a list with one tensor per image.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds the
            processed images.
        """
        if not isinstance(images, list):
            images = [images]

        # Built locally rather than stored on ``self`` so the instance holds
        # no transform objects.
        transform = v2.Compose([
            # Convert PIL images / NumPy arrays / tensors to image tensors
            # first: ToDtype leaves PIL images untouched and Normalize
            # rejects them, while NumPy arrays would otherwise pass through
            # the whole pipeline unprocessed.
            v2.ToImage(),
            v2.Resize(size=(28, 28), antialias=True),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(
                mean=[0.1307],
                std=[0.3081],
            ),
        ])

        pixel_values = transform(images)

        # Stack explicitly when tensors are requested so BatchFeature gets a
        # single batch tensor instead of a list of multi-element tensors.
        # An empty batch is left as-is because torch.stack rejects it.
        if return_tensors is not None and pixel_values:
            pixel_values = torch.stack(list(pixel_values), dim=0)

        data = {"pixel_values": pixel_values}
        return BatchFeature(data=data, tensor_type=return_tensors)