efficientnet_b0 / image_processing_efficientnet.py
Thastp's picture
Upload processor
b31a063 verified
raw
history blame
904 Bytes
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from configuration_efficientnet import MODEL_NAMES
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies timm's default data transform for a model.

    The preprocessing pipeline is built from the data config that
    ``timm.data.resolve_data_config`` reports for ``model_name``.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, model_name: str, **kwargs):
        """Resolve and store the timm data config for ``model_name``.

        Args:
            model_name: timm model identifier passed to ``timm.create_model``.
            **kwargs: Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # The architecture is instantiated without weights purely so timm can
        # report its default data config; the model object itself is not kept.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))

    def preprocess(self, image, return_tensors=None):
        """Transform one image, or a sequence of images, into model inputs.

        Args:
            image: A single image accepted by the timm transform (presumably a
                PIL image -- confirm against callers), or a list/tuple of such
                images to be processed as one batch.
            return_tensors: Optional tensor type forwarded to ``BatchFeature``
                as ``tensor_type``. This makes the common
                ``processor(img, return_tensors="pt")`` call work;
                ``BaseImageProcessor.__call__`` forwards keyword arguments here.

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry stacks the
            transformed images along a new leading batch dimension (size 1 for
            a single image).

        Raises:
            ValueError: If ``image`` is an empty list or tuple.
        """
        transforms = create_transform(**self.config)
        if isinstance(image, (list, tuple)):
            if not image:
                raise ValueError("Expected at least one image, got an empty sequence.")
            # The transform output is a torch tensor (the single-image path
            # calls ``.unsqueeze`` on it), so batches are stacked with torch.
            import torch

            pixel_values = torch.stack([transforms(img) for img in image])
        else:
            pixel_values = transforms(image).unsqueeze(0)
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
# Public API of this module: only the image processor class is exported.
__all__ = ["EfficientNetImageProcessor"]