| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature | |
| from configuration_efficientnet import MODEL_NAMES | |
| from timm import create_model | |
| from timm.data import resolve_data_config | |
| from timm.data.transforms_factory import create_transform | |
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that runs timm's transform pipeline for a named model.

    The data config (the keyword arguments later handed to
    ``create_transform``) is resolved once, at construction time, from a
    timm model created with ``model_name``.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, model_name: str, **kwargs):
        """Resolve and store the timm data config for ``model_name``.

        Args:
            model_name: Name passed to ``timm.create_model``.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # The model is built with pretrained=False; it is only used here as
        # the source from which timm resolves the data config.
        self.config = resolve_data_config(
            {}, model=create_model(model_name, pretrained=False)
        )

    def preprocess(self, image, return_tensors=None, **kwargs):
        """Transform one image, or a sequence of images, into ``pixel_values``.

        Args:
            image: A single image accepted by the timm transform, or a
                list/tuple of such images to be stacked into one batch.
            return_tensors: Optional tensor type forwarded to
                ``BatchFeature`` as ``tensor_type``. ``None`` leaves the
                transformed values as produced by the transform.
            **kwargs: Accepted and ignored. ``BaseImageProcessor.__call__``
                forwards arbitrary keyword arguments to ``preprocess``, so
                rejecting them would raise ``TypeError`` for ordinary calls.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry has a leading
            batch dimension.

        Raises:
            ValueError: If ``image`` is an empty list or tuple.
        """
        # Built per call rather than cached on the instance, so the instance
        # keeps only the attributes it had before (model_name and config).
        transforms = create_transform(**self.config)

        if isinstance(image, (list, tuple)):
            if not image:
                raise ValueError("`image` must contain at least one image.")
            # Local import: only the batched path needs torch directly.
            import torch

            pixel_values = torch.stack([transforms(item) for item in image])
        else:
            pixel_values = transforms(image).unsqueeze(0)

        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,
        )
# Names exported by ``from <module> import *``.
__all__ = ["EfficientNetImageProcessor"]