efficientnet_b0 / image_processing_efficientnet.py
Thastp's picture
Upload processor
1523798 verified
raw
history blame
832 Bytes
from transformers.image_processing_utils import BaseImageProcessor
from configuration_efficientnet import MODEL_NAMES
from timm import create_model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies timm's evaluation preprocessing for an EfficientNet variant.

    The preprocessing settings are resolved once, at construction time, with
    timm's ``resolve_data_config`` for ``model_name`` and stored in
    ``self.config``; ``preprocess`` builds the matching timm transform from them.
    """

    model_input_names = ["pixel_values"]

    def __init__(self,
                 model_name: str,
                 **kwargs
                 ):
        """Resolve the timm data config for *model_name*.

        Args:
            model_name: timm model identifier, passed to ``timm.create_model``.
            **kwargs: forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        # The model is instantiated with pretrained=False only so timm can read
        # its default data config; the model object itself is not kept.
        self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))

    def preprocess(self, image, return_tensors=None, **kwargs):
        """Apply the timm transform to one image or to a batch of images.

        Args:
            image: a single image accepted by the timm transform (e.g. a PIL
                image), or a list/tuple of such images.
            return_tensors: ``None`` or ``"pt"``; the result is always a
                PyTorch tensor. Accepted because ``BaseImageProcessor.__call__``
                forwards its keyword arguments to this method.
            **kwargs: accepted for compatibility with
                ``BaseImageProcessor.__call__``; currently ignored.

        Returns:
            The transformed image(s) with a leading batch dimension: size 1 for
            a single image, ``len(image)`` for a list/tuple.

        Raises:
            ValueError: if ``return_tensors`` requests a format other than
                PyTorch, or if ``image`` is an empty list/tuple.
        """
        # Fail loudly rather than silently returning a torch tensor when the
        # caller asked for e.g. NumPy output.
        if return_tensors is not None and return_tensors != "pt":
            raise ValueError(
                "EfficientNetImageProcessor only returns PyTorch tensors; "
                f"got return_tensors={return_tensors!r}."
            )
        # Built per call rather than stored on self: attributes in __dict__ are
        # serialized by BaseImageProcessor.to_dict(), and a transform object is
        # not JSON-serializable.
        transform = create_transform(**self.config)
        if isinstance(image, (list, tuple)):
            if not image:
                raise ValueError("Expected at least one image, got an empty sequence.")
            import torch

            # Assumes the transform yields equally-sized tensors per image
            # (fixed-size eval crop) so they can be stacked — TODO confirm for
            # non-default configs.
            return torch.stack([transform(img) for img in image])
        return transform(image).unsqueeze(0)
# Public API of this module: only the image processor class is exported.
__all__ = ["EfficientNetImageProcessor"]