| from transformers.image_utils import ImageInput | |
| from transformers import BaseImageProcessor, BatchFeature | |
| from torchvision.transforms import v2 | |
| import torch | |
class ResNetProcessor(BaseImageProcessor):
    """
    A custom image processor for ResNet training.

    Applies ImageNet-style training augmentation (random resized crop to
    224x224, random horizontal flip) followed by float conversion and
    channel-wise normalization, and returns the result as a BatchFeature.

    NOTE(review): the original docstring said "grayscale images", but
    ``Normalize`` uses 3-channel ImageNet mean/std — confirm that inputs
    are RGB (or replicated to 3 channels) before relying on this.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Build the augmentation pipeline once here instead of on every
        # preprocess() call.
        self._transform = v2.Compose([
            v2.RandomResizedCrop(size=(224, 224), antialias=True),
            v2.RandomHorizontalFlip(p=0.5),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def preprocess(self, images: ImageInput, return_tensors="pt", **kwargs) -> BatchFeature:
        """
        Preprocess a single image or a batch of images.

        Args:
            images: One image, or a list of images, in any form accepted by
                torchvision v2 transforms.
            return_tensors: Tensor framework for the returned features.
                (Previously accepted but ignored; now forwarded to
                ``BatchFeature``.)

        Returns:
            BatchFeature with ``"pixel_values"`` holding a stacked
            ``(batch, channels, 224, 224)`` tensor.
        """
        if not isinstance(images, list):
            images = [images]
        # Transform each image separately so the random crop/flip parameters
        # are drawn independently per image, then stack into one batch
        # tensor. (Passing the whole list through a single transform call
        # applies one set of random parameters to every image and returns a
        # list, which BatchFeature's tensor conversion cannot stack.)
        pixel_values = torch.stack([self._transform(image) for image in images])
        # Forward return_tensors instead of hard-coding "pt".
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)