File size: 1,110 Bytes
b3d7090 9bd9cdc b3d7090 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import numpy as np
from PIL import Image
from transformers import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput
import torch
from torchvision.transforms import v2
class LeNetProcessor(BaseImageProcessor):
    """Image processor that prepares images for a LeNet-style model.

    Each image is resized to 28x28, converted to ``float32`` (with
    ``scale=True``, so integer inputs are rescaled) and normalized with a
    single-channel mean/std.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        """Create the processor.

        Args:
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

    def preprocess(self, images: ImageInput, return_tensors=None, **kwargs) -> BatchFeature:
        """Preprocess one image or a batch of images.

        Args:
            images: A single image or a list/tuple of images. Each image may
                be a PIL image, a NumPy array or a torch tensor.
                NOTE(review): NumPy arrays are assumed to be laid out as
                (H, W) or (H, W, C) and tensors as (..., C, H, W) -- confirm
                against callers.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
                When ``None``, ``pixel_values`` is a list with one tensor per
                image; otherwise the per-image tensors are stacked into a
                single batch before being handed to ``BatchFeature``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds the
            processed images.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        transform = v2.Compose([
            # Convert PIL images / NumPy arrays / tensors to an image tensor
            # first: v2.Normalize rejects PIL images, and v2 transforms pass
            # NumPy arrays through untouched.
            v2.ToImage(),
            v2.Resize(size=(28, 28), antialias=True),
            v2.ToDtype(torch.float32, scale=True),
            # Single-channel statistics (presumably the MNIST mean/std --
            # TODO confirm).
            v2.Normalize(
                mean=[0.1307],
                std=[0.3081],
            ),
        ])

        # Run the pipeline once per image. Passing the whole list would be
        # treated by torchvision v2 as ONE sample, in which only the first
        # plain tensor is transformed and the others are passed through.
        # `as_subclass` drops the tv_tensors.Image wrapper so callers get
        # plain torch tensors.
        pixel_values = [
            transform(image).as_subclass(torch.Tensor) for image in images
        ]

        # Hand BatchFeature a ready-made batch when a tensor type is requested
        # instead of relying on it to convert a list of tensors.
        if return_tensors is not None and pixel_values:
            pixel_values = torch.stack(pixel_values)

        data = {"pixel_values": pixel_values}
        return BatchFeature(data=data, tensor_type=return_tensors)
|