from PIL import Image
from torch import Tensor, stack
from typing import Union, List

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies the timm transform pipeline of a model.

    The data configuration is resolved once, at construction time, from the
    timm model named ``model_name``; ``preprocess`` then builds the transform
    from that configuration and applies it to each input image.
    """

    # Name of the key under which ``preprocess`` returns the batch tensor.
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ):
        """
        Store the model name and resolve its timm data configuration.

        Parameters
        ----------
        model_name : str
            Name passed to ``timm.create_model``.
        **kwargs
            Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        self.model_name = model_name
        # The model is instantiated with pretrained=False and used only as
        # the source from which the data config is resolved.
        self.config = resolve_data_config(
            {}, model=create_model(model_name, pretrained=False)
        )
        super().__init__(**kwargs)

    def preprocess(
        self,
        images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
    ) -> BatchFeature:
        """
        Apply the model's transforms to the input images and batch the result.

        Parameters
        ----------
        images : Union[List[Union[PIL.Image.Image, torch.Tensor]], PIL.Image.Image, torch.Tensor]
            A single image or a list of images. Anything that is not a
            ``list`` is treated as one image.

        Returns
        -------
        BatchFeature
            Holds the stacked transformed images under ``"pixel_values"``.

        Raises
        ------
        ValueError
            If ``images`` is an empty list.
        TypeError
            If any image is neither a ``PIL.Image.Image`` nor a
            ``torch.Tensor``.
        """
        if not isinstance(images, list):
            images = [images]

        if not images:
            raise ValueError("Received an empty list of images")

        # Validate every element, not just the first, so a bad entry later in
        # the list is reported here instead of failing inside the transforms.
        for index, image in enumerate(images):
            if not isinstance(image, (Image.Image, Tensor)):
                raise TypeError(
                    f"Expected image to be of type PIL.Image.Image or torch.Tensor, "
                    f"but got {type(image).__name__} instead (at index {index})."
                )

        # NOTE(review): the transform is rebuilt on every call; caching it as
        # an instance attribute may interfere with the base class's
        # serialization of attributes -- confirm before hoisting.
        transforms = create_transform(**self.config)
        transformed_images = [transforms(image) for image in images]

        # Stack the per-image tensors into one batch tensor (torch.stack
        # requires all of them to have the same shape).
        pixel_values = stack(transformed_images)

        return BatchFeature(data={'pixel_values': pixel_values})


__all__ = [
    "EfficientNetImageProcessor",
]