"""Model configuration and ONNX-export configuration for EfficientNet variants."""
from typing import OrderedDict, Dict

from transformers.configuration_utils import PretrainedConfig
from optimum.utils.normalized_config import NormalizedVisionConfig
from optimum.utils.input_generators import DummyVisionInputGenerator
from optimum.exporters.onnx.model_configs import ViTOnnxConfig

# Accepted values for `EfficientNetConfig.model_name`:
# 'efficientnet_b0' ... 'efficientnet_b8' (in order), followed by 'efficientnet_l2'.
MODEL_NAMES = [f'efficientnet_b{index}' for index in range(9)] + ['efficientnet_l2']


class EfficientNetConfig(PretrainedConfig):
    """Configuration selecting an EfficientNet variant by name.

    Only two model-specific fields are stored, `model_name` and `pretrained`.
    Every other keyword argument is forwarded unchanged to
    `PretrainedConfig.__init__`.
    """

    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        **kwargs
    ):
        """Validate and store the variant name and the `pretrained` flag.

        Args:
            model_name: One of `MODEL_NAMES`.
            pretrained: Stored as-is on the config. Presumably tells the model
                whether to load pretrained weights; this class only records it.
            **kwargs: Passed through to `PretrainedConfig.__init__`.

        Raises:
            ValueError: If `model_name` is not listed in `MODEL_NAMES`.
        """
        # Reject unknown variants before any attribute is assigned.
        if model_name not in MODEL_NAMES:
            raise ValueError(f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}')

        self.model_name = model_name
        self.pretrained = pretrained

        # TODO: add the following attributes. The values are an example
        # configuration recorded in an earlier note (its "model_type" was
        # "efficientnet"):
        #   batch_norm_eps=0.001, batch_norm_momentum=0.99,
        #   depth_coefficient=3.1, depth_divisor=8, depthwise_padding=[],
        #   drop_connect_rate=0.2, dropout_rate=0.5, expand_ratios=[1, 6, 6],
        #   hidden_act="gelu", hidden_dim=2560,
        #   id2label={"0": "LABEL_0", "1": "LABEL_1", ...}  (ImageNet dataset),
        #   label2id={"LABEL_0": 0, "LABEL_1": 1, ...},
        #   image_size=600, in_channels=[32, 16, 24], initializer_range=0.02,
        #   kernel_sizes=[3, 3, 5], num_block_repeats=[1, 1, 2],
        #   num_channels=3, num_hidden_layers=16, out_channels=[16, 24, 40],
        #   pooling_type="mean", squeeze_expansion_ratio=0.25,
        #   strides=[1, 1, 2], width_coefficient=2.0
        super().__init__(**kwargs)


class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for EfficientNet.

    Everything except `outputs` is inherited from `ViTOnnxConfig`.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the exported outputs as {name: {axis index: axis name}}.

        For the "image-classification" task, the `logits` entry is replaced so
        that axis 0 is "batch_size" and axis 1 is "num_classes". For any other
        task, the parent's mapping is returned unchanged.
        """
        # NOTE(review): this mutates the mapping returned by the parent
        # property, which assumes the parent hands out a fresh copy per access.
        # Confirm against the installed optimum version.
        result = super().outputs
        if self.task != "image-classification":
            return result
        result["logits"] = {0: "batch_size", 1: "num_classes"}
        return result


__all__ = [
    'MODEL_NAMES',
    'EfficientNetConfig',
    'EfficientNetOnnxConfig',
]