File size: 1,110 Bytes
b3d7090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd9cdc
b3d7090
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
from PIL import Image
from transformers import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput
import torch
from torchvision.transforms import v2


class LeNetProcessor(BaseImageProcessor):
    """
    Image processor that turns grayscale images into normalized model inputs.

    Each image is converted to a single-channel tensor image, resized to
    ``size``, converted to ``float32`` (integer inputs are rescaled to
    ``[0, 1]``), and normalized with ``image_mean`` / ``image_std``. The
    defaults (28x28, mean 0.1307, std 0.3081) match the standard MNIST
    normalization constants.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size=(28, 28),
        image_mean=(0.1307,),
        image_std=(0.3081,),
        **kwargs,
    ):
        """
        Args:
            size: Output ``(height, width)`` that every image is resized to.
            image_mean: Per-channel mean subtracted after scaling (one value
                for the single grayscale channel).
            image_std: Per-channel standard deviation used to divide after
                mean subtraction.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        # Keep only plain Python values on the instance so the processor's
        # configuration stays JSON-serializable.
        self.size = list(size)
        self.image_mean = list(image_mean)
        self.image_std = list(image_std)

    def _build_transform(self):
        """Return the per-image v2 pipeline: to tensor, resize, to float32, normalize."""
        return v2.Compose([
            # v2 transforms do not convert PIL images or numpy arrays to
            # tensors on their own (Normalize rejects PIL input), so convert
            # every supported input type to a tensor image first.
            v2.ToImage(),
            v2.Resize(size=list(self.size), antialias=True),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=list(self.image_mean), std=list(self.image_std)),
        ])

    @staticmethod
    def _to_grayscale(image):
        """Convert a PIL image whose mode is not ``"L"`` to ``"L"``; return other inputs unchanged."""
        if isinstance(image, Image.Image) and image.mode != "L":
            # Collapse colour/palette/other modes to one channel so the
            # output matches the single-element mean/std.
            return image.convert("L")
        return image

    def preprocess(self, images: ImageInput, return_tensors=None, **kwargs) -> BatchFeature:
        """
        Preprocess one grayscale image or a batch of them.

        Args:
            images: A single image or a list/tuple of images. Each image may
                be a PIL image, a numpy array shaped ``(H, W)`` or
                ``(H, W, C)``, or a torch tensor shaped ``(H, W)`` or
                ``(C, H, W)``. A 4-D torch tensor is treated as a batch
                ``(N, C, H, W)``.
            return_tensors: Passed to ``BatchFeature`` as ``tensor_type``.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` is a ``float32`` batch
            of shape ``(N, C, height, width)``; an empty input yields shape
            ``(0, 1, height, width)``.
        """
        if isinstance(images, torch.Tensor) and images.ndim == 4:
            # Already batched: split into per-image (C, H, W) tensors.
            images = list(images)
        elif isinstance(images, tuple):
            images = list(images)
        elif not isinstance(images, list):
            images = [images]

        height, width = self.size
        transform = self._build_transform()
        pixel_values = [transform(self._to_grayscale(image)) for image in images]

        if pixel_values:
            # Every image was resized to the same (height, width), so the
            # results can be stacked into a single batch tensor.
            batch = torch.stack(pixel_values)
        else:
            batch = torch.empty((0, 1, height, width), dtype=torch.float32)

        return BatchFeature(data={"pixel_values": batch}, tensor_type=return_tensors)