# File size: 904 Bytes
import torch
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature

from configuration_efficientnet import MODEL_NAMES


class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies the transforms derived from a timm model.

    The preprocessing configuration is resolved from the timm model named
    ``model_name`` with ``timm.data.resolve_data_config`` and stored on
    ``self.config``. ``preprocess`` builds a transform from that config with
    ``timm.data.transforms_factory.create_transform`` and returns the result
    under the ``pixel_values`` key.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, model_name: str, **kwargs):
        """Resolve and store the timm data config for ``model_name``.

        Args:
            model_name: Name of a timm model; it is instantiated with
                ``pretrained=False`` solely to read its data configuration.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # pretrained=False: only the architecture's data config is needed,
        # so no pretrained weights are requested.
        self.config = resolve_data_config(
            {}, model=create_model(model_name, pretrained=False)
        )

    def preprocess(self, image, return_tensors=None) -> BatchFeature:
        """Transform one image, or a list/tuple of images, into ``pixel_values``.

        Args:
            image: A single image, or a list/tuple of images, in whatever
                format the timm transform accepts (presumably PIL images —
                confirm against the transform built from ``self.config``).
            return_tensors: Optional tensor type (e.g. ``"pt"``) forwarded to
                ``BatchFeature`` as ``tensor_type``. Accepting it lets the
                usual ``processor(images, return_tensors=...)`` call work,
                since ``BaseImageProcessor.__call__`` forwards its keyword
                arguments here.

        Returns:
            A ``BatchFeature`` whose ``pixel_values`` has a leading batch
            dimension: 1 for a single image, ``len(image)`` for a sequence.

        Raises:
            ValueError: If an empty list/tuple of images is given.
        """
        # Rebuilt per call so that no non-serializable transform object is
        # stored on the instance alongside the plain-dict config.
        transform = create_transform(**self.config)
        if isinstance(image, (list, tuple)):
            if not image:
                raise ValueError("Expected at least one image, got an empty sequence.")
            # Transform each image on its own, then stack along a new batch axis.
            pixel_values = torch.stack([transform(img) for img in image])
        else:
            # Single image: add the batch axis, exactly as before.
            pixel_values = transform(image).unsqueeze(0)
        return BatchFeature(
            data={"pixel_values": pixel_values}, tensor_type=return_tensors
        )


# Public API of this module.
__all__ = [
    "EfficientNetImageProcessor",
]