"""Image processor that reuses timm's evaluation transforms for EfficientNet."""

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from configuration_efficientnet import MODEL_NAMES
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


class EfficientNetImageProcessor(BaseImageProcessor):
    """Preprocess a single image with the timm transform of a given model.

    The data configuration is resolved once, at construction time, from a
    non-pretrained timm model created for ``model_name``; the transform
    itself is built from that configuration on every ``preprocess`` call.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, model_name: str, **kwargs):
        """Store ``model_name`` and resolve the matching timm data config.

        Args:
            model_name: Name passed to ``timm.create_model``.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # pretrained=False: the model is only instantiated so that timm can
        # derive its data configuration; no weights are requested.
        model = create_model(model_name, pretrained=False)
        self.config = resolve_data_config({}, model=model)

    def preprocess(self, image, return_tensors=None, **kwargs) -> BatchFeature:
        """Transform one image into a batch of size one.

        Args:
            image: A single image accepted by the timm transform
                (presumably a PIL image -- confirm against callers).
            return_tensors: Optional tensor type forwarded to
                ``BatchFeature`` as ``tensor_type`` (e.g. ``"pt"``).
                ``None`` leaves the transformed tensor untouched, which is
                the original behaviour.
            **kwargs: Accepted so that keyword arguments forwarded by
                ``BaseImageProcessor.__call__`` do not raise ``TypeError``;
                they are otherwise ignored.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry is the
            transformed image with a leading batch dimension added.
        """
        # NOTE(review): the transform is rebuilt per call rather than cached
        # on ``self`` -- presumably because instance attributes end up in the
        # serialised processor config; confirm before hoisting.
        transform = create_transform(**self.config)
        pixel_values = transform(image).unsqueeze(0)
        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )


__all__ = ["EfficientNetImageProcessor"]