| import numpy as np | |
| from PIL import Image | |
| from transformers import BaseImageProcessor, BatchFeature | |
| from transformers.image_utils import ImageInput | |
| import torch | |
| from torchvision.transforms import v2 | |
class LeNetProcessor(BaseImageProcessor):
    """
    Image processor for a LeNet-style model: converts input images to
    single-channel tensors, resizes them to 28x28, scales pixel values to
    [0, 1], and normalizes with the MNIST statistics (mean 0.1307, std 0.3081).
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: Forwarded unchanged to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        # Build the (input-invariant) transform pipeline once instead of
        # reconstructing it on every preprocess() call.
        self._transform = v2.Compose([
            # Convert PIL.Image / ndarray inputs to a torch image tensor.
            # Without this, ToDtype/Normalize below fail on PIL inputs,
            # since they operate on tensors only.
            v2.ToImage(),
            # Force a single channel; no-op for already-grayscale input.
            # Normalize below uses 1-channel stats, so 3-channel input
            # would otherwise raise a shape error.
            v2.Grayscale(num_output_channels=1),
            v2.Resize(size=(28, 28), antialias=True),
            # scale=True rescales uint8 [0, 255] to float32 [0, 1].
            v2.ToDtype(torch.float32, scale=True),
            # MNIST per-channel normalization constants.
            v2.Normalize(mean=[0.1307], std=[0.3081]),
        ])

    def preprocess(self, images: ImageInput, return_tensors=None, **kwargs) -> BatchFeature:
        """
        Preprocess one image or a list of images.

        Args:
            images: A single image or a list of images (PIL, ndarray, or
                tensor — anything ``v2.ToImage`` accepts).
            return_tensors: Tensor type for the returned ``BatchFeature``
                (e.g. ``"pt"``).

        Returns:
            A ``BatchFeature`` whose ``pixel_values`` is a float32 tensor of
            shape ``(batch, 1, 28, 28)``.
        """
        if not isinstance(images, list):
            images = [images]
        # Transform each image individually and stack into one batch tensor.
        # Applying the Compose directly to the list would return a *list* of
        # per-image tensors, not a batched (N, 1, 28, 28) tensor.
        pixel_values = torch.stack([self._transform(image) for image in images])
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)