"""EfficientNet configuration and ONNX export configuration.

Defines a ``transformers`` config that records which named EfficientNet
variant to use, and an ``optimum`` ONNX export config that inherits from
``ViTOnnxConfig`` and overrides only the output dynamic axes.
"""
from typing import Dict, OrderedDict

from optimum.exporters.onnx.model_configs import ViTOnnxConfig
from optimum.utils.input_generators import DummyVisionInputGenerator
from optimum.utils.normalized_config import NormalizedVisionConfig
from transformers.configuration_utils import PretrainedConfig

# NOTE(review): OrderedDict, DummyVisionInputGenerator and NormalizedVisionConfig
# are not used in this module; they are kept in case other code imports them
# from here — confirm before removing.

# Variant names accepted by ``EfficientNetConfig(model_name=...)``.
MODEL_NAMES = [
    'efficientnet_b0',
    'efficientnet_b1',
    'efficientnet_b2',
    'efficientnet_b3',
    'efficientnet_b4',
    'efficientnet_b5',
    'efficientnet_b6',
    'efficientnet_b7',
    'efficientnet_b8',
    'efficientnet_l2',
]


class EfficientNetConfig(PretrainedConfig):
    """Configuration recording which EfficientNet variant to use.

    Args:
        model_name: Variant name; must be one of ``MODEL_NAMES``.
        pretrained: Flag stored on the config as ``self.pretrained``.
            Presumably tells the model builder to load pretrained weights —
            TODO confirm against the model class that consumes this config.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not in ``MODEL_NAMES``.
    """

    # NOTE(review): newer ``transformers`` releases ship a built-in model whose
    # model_type is also "efficientnet"; registering this class with AutoConfig
    # under the same name may conflict — confirm before registering.
    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        **kwargs,
    ) -> None:
        # Validate first so an unknown variant is rejected before any base
        # initialization happens.
        if model_name not in MODEL_NAMES:
            raise ValueError(f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}')
        self.model_name = model_name
        self.pretrained = pretrained
        super().__init__(**kwargs)


class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export config for EfficientNet, built on ``ViTOnnxConfig``.

    Everything except ``outputs`` is inherited unchanged from the ViT export
    config; ``outputs`` is overridden to give EfficientNet-specific dynamic axes.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the output names mapped to their dynamic axes for ``self.task``.

        Starts from the parent's outputs. For "image-classification" the axes
        of ``logits`` are replaced; for "feature-extraction" the axes of
        ``last_hidden_state`` are replaced. For any other task the parent's
        mapping is returned unchanged.
        """
        common_outputs = super().outputs
        if self.task == "image-classification":
            # 2-D output: axes named batch_size and num_classes.
            common_outputs["logits"] = {0: "batch_size", 1: "num_classes"}
        elif self.task == "feature-extraction":
            # 4-D output: axes named batch_size, num_features, height, width.
            common_outputs["last_hidden_state"] = {
                0: "batch_size",
                1: "num_features",
                2: "height",
                3: "width",
            }
        return common_outputs


__all__ = [
    'MODEL_NAMES',
    'EfficientNetConfig',
    'EfficientNetOnnxConfig',
]