File size: 832 Bytes
9d4e1df 1523798 9d4e1df | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | from transformers.image_processing_utils import BaseImageProcessor
from configuration_efficientnet import MODEL_NAMES
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that delegates preprocessing to a timm transform.

    The transform is built by ``timm.data.transforms_factory.create_transform``
    from the data config that ``timm.data.resolve_data_config`` resolves for the
    timm model named ``model_name``.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, model_name: str, **kwargs):
        """Resolve the timm data config for ``model_name``.

        Args:
            model_name: timm model identifier passed to ``timm.create_model``.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # A model instance is created (with pretrained=False, so no weights are
        # requested) only so timm can resolve its data config from it.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
        # Cached transform and the config snapshot it was built from; filled
        # lazily by _get_transform() and excluded from to_dict() because
        # transform objects are not JSON-serializable.
        self._transform = None
        self._transform_config = None

    def _get_transform(self):
        """Return the timm transform for the current ``self.config``, reusing it across calls.

        The transform is rebuilt only when ``self.config`` no longer equals the
        snapshot it was built from, so a caller that edits ``config`` still
        gets a matching transform on the next call.
        """
        if self._transform is None or self._transform_config != self.config:
            self._transform = create_transform(**self.config)
            # Shallow copy so later in-place edits to self.config are detected.
            self._transform_config = dict(self.config)
        return self._transform

    def preprocess(self, image):
        """Apply the timm transform to ``image`` and prepend a batch dimension.

        Args:
            image: Input accepted by the timm transform. NOTE(review): presumably
                a PIL image; confirm with callers.

        Returns:
            The transformed output with a new leading dimension of size 1
            (via ``unsqueeze(0)``), i.e. a batch of one.
        """
        return self._get_transform()(image).unsqueeze(0)

    def to_dict(self):
        """Serialize like ``BaseImageProcessor.to_dict`` minus the private transform cache."""
        output = super().to_dict()
        # These are runtime caches, not configuration; keep them out of the
        # saved JSON so save_pretrained()/repr() keep working.
        output.pop("_transform", None)
        output.pop("_transform_config", None)
        return output
# Public API of this module.
__all__ = [
    "EfficientNetImageProcessor",
]