| | from transformers.configuration_utils import PretrainedConfig
|
| |
|
| | from optimum.utils.normalized_config import NormalizedVisionConfig
|
| | from optimum.utils.input_generators import DummyVisionInputGenerator
|
| | from optimum.exporters.onnx.model_configs import ViTOnnxConfig
|
| |
|
| | from typing import OrderedDict, Dict
|
| |
|
# Every EfficientNet variant identifier that ``EfficientNetConfig`` accepts as
# ``model_name``: 'efficientnet_b0' through 'efficientnet_b8', then 'efficientnet_l2'.
MODEL_NAMES = [f'efficientnet_b{index}' for index in range(9)] + ['efficientnet_l2']
|
| |
|
class EfficientNetConfig(PretrainedConfig):
    """Hugging Face configuration object describing an EfficientNet variant.

    Args:
        model_name: Variant identifier; must be an entry of ``MODEL_NAMES``.
        pretrained: Flag stored as ``self.pretrained`` (presumably whether
            pretrained weights should be loaded — not consumed in this module;
            confirm against the model class).
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """

    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        **kwargs,
    ):
        # Reject unknown variants before any state is written to the instance.
        is_known_variant = model_name in MODEL_NAMES
        if not is_known_variant:
            raise ValueError(
                f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}'
            )

        # Assigned left to right: model_name first, then pretrained.
        self.model_name, self.pretrained = model_name, pretrained

        # The base-class initialiser runs last, after the custom fields exist.
        super().__init__(**kwargs)
|
| |
|
class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for EfficientNet.

    Reuses everything from ``ViTOnnxConfig`` except the dynamic-axis naming
    of the outputs, which is overridden per task in ``outputs``.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the ONNX output-name -> dynamic-axes mapping for ``self.task``.

        Starts from the parent's outputs and replaces the axes of ``logits``
        for "image-classification" and of ``last_hidden_state`` for
        "feature-extraction". Any other task gets the parent's mapping
        unchanged.
        """
        result = super().outputs

        # Built on every call so the returned axis dicts are never shared
        # between calls.
        per_task = {
            "image-classification": (
                "logits",
                {0: "batch_size", 1: "num_classes"},
            ),
            # NOTE(review): 4-D axes assume the backbone returns a feature
            # map (batch, features, height, width) — confirm against the model.
            "feature-extraction": (
                "last_hidden_state",
                {0: "batch_size", 1: "num_features", 2: "height", 3: "width"},
            ),
        }

        override = per_task.get(self.task)
        if override is not None:
            output_name, axes = override
            result[output_name] = axes

        return result
|
| |
|
# Names exported by ``from <module> import *``.
__all__ = [
    "MODEL_NAMES",
    "EfficientNetConfig",
    "EfficientNetOnnxConfig",
]