# lenet / preprocessor_lenet.py
# Author: l45k — commit 9bd9cdc ("Upload processor", verified)
import numpy as np
from PIL import Image
from transformers import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput
import torch
from torchvision.transforms import v2
class LeNetProcessor(BaseImageProcessor):
    """
    A custom processor that turns grayscale images into a normalized batch
    for a LeNet-style model.

    Each image is converted to a tensor image (PIL images are first converted
    to single-channel "L" mode), resized to 28x28, scaled to float32 in
    [0, 1] and normalized with mean 0.1307 / std 0.3081. The per-image
    results are stacked into one ``pixel_values`` tensor of shape
    ``(batch, channels, 28, 28)``.
    """

    model_input_names = ["pixel_values"]

    # Built once and shared by all instances. The pipeline holds no per-call
    # state, and keeping it as a class attribute keeps it out of the instance
    # attributes that get serialized when the processor is saved.
    _transform = v2.Compose([
        # PIL image / numpy array / plain tensor -> tv_tensors.Image.
        # Without this step Normalize does not accept PIL images, and numpy
        # arrays are not transformed at all by v2.
        v2.ToImage(),
        v2.Resize(size=(28, 28), antialias=True),
        # Integer images are scaled to [0, 1]; float inputs are only cast.
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.1307], std=[0.3081]),
    ])

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

    def preprocess(self, images: ImageInput, return_tensors=None, **kwargs) -> BatchFeature:
        """
        Preprocess one image or a batch of grayscale images.

        Args:
            images: A single image or a list/tuple of images (PIL images,
                numpy arrays or torch tensors). A single 4D torch tensor is
                treated as an already-batched input and split along dim 0.
            return_tensors: Tensor type forwarded to ``BatchFeature``
                (e.g. ``"pt"`` or ``"np"``); ``None`` leaves the stacked
                torch tensor as is.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            BatchFeature whose ``pixel_values`` entry is the stacked batch of
            shape ``(batch, channels, 28, 28)``.

        Raises:
            ValueError: If no images are given.
        """
        if isinstance(images, torch.Tensor) and images.ndim == 4:
            # Split a batched tensor so every sample goes through the same
            # per-image pipeline as list input (and is stacked only once).
            images = list(images)
        elif not isinstance(images, (list, tuple)):
            images = [images]
        if len(images) == 0:
            raise ValueError("LeNetProcessor.preprocess received no images.")

        # Transform each image separately: when v2 transforms get several
        # plain tensors in one call, only the first is treated as an image
        # and the others are passed through unchanged.
        pixel_values = torch.stack(
            [self._transform(self._to_grayscale(image)) for image in images]
        )
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    @staticmethod
    def _to_grayscale(image):
        """Convert non-"L" PIL images to 8-bit grayscale; return other inputs unchanged.

        NOTE(review): high-bit-depth PIL modes (e.g. "I;16", "F") are also
        converted to 8-bit "L" here and may lose precision.
        """
        if isinstance(image, Image.Image) and image.mode != "L":
            return image.convert("L")
        return image