File size: 1,069 Bytes
44cadd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from transformers.image_utils import ImageInput
from transformers import BaseImageProcessor, BatchFeature
from torchvision.transforms import v2
import torch


class ResNetProcessor(BaseImageProcessor):
    """
    Image processor for ResNet training with ImageNet-style augmentation.

    Every call applies random augmentation (a random resized crop to 224x224
    and a random horizontal flip) followed by ImageNet mean/std normalization,
    so it is meant for training data, not for deterministic evaluation.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def _build_transform() -> v2.Compose:
        """Return the per-image augmentation + normalization pipeline."""
        # Built on demand rather than stored on the instance, so the Compose
        # object never becomes one of the processor's saved attributes.
        return v2.Compose([
            # Convert PIL images / numpy (HWC) arrays / tensors into a CHW
            # image tensor. Without this step PIL inputs pass through ToDtype
            # unconverted and are then rejected by Normalize.
            v2.ToImage(),
            v2.RandomResizedCrop(size=(224, 224), antialias=True),
            v2.RandomHorizontalFlip(p=0.5),
            # Integer pixels are rescaled into [0, 1] as float32.
            v2.ToDtype(torch.float32, scale=True),
            # ImageNet per-channel statistics (3 channels).
            # NOTE(review): a single-channel (grayscale) image is broadcast
            # against these 3-entry mean/std values, yielding 3 output
            # channels — confirm this is the intended grayscale handling.
            v2.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def preprocess(self, images: ImageInput, return_tensors="pt", **kwargs) -> BatchFeature:
        """
        Augment, normalize and batch one image or a sequence of images.

        Args:
            images: A single image, or a list/tuple of images (PIL images,
                numpy arrays, or CHW tensors). Images in the same call may
                have different sizes; each is cropped to 224x224.
            return_tensors: Tensor type handed to ``BatchFeature`` as
                ``tensor_type`` (defaults to ``"pt"``).
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds the stacked
            batch, shaped (batch_size, channels, 224, 224).

        Raises:
            ValueError: If ``images`` is an empty list or tuple.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if len(images) == 0:
            raise ValueError("ResNetProcessor.preprocess received no images.")

        transform = self._build_transform()
        # Transform each image on its own: a v2 transform called on a whole
        # list treats it as ONE sample, so all images would share the same
        # random crop/flip, differently sized images would be rejected, and
        # the output would be a plain list instead of a batched tensor.
        pixel_values = torch.stack([transform(image) for image in images])

        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)