# File size: 1,736 Bytes
from transformers.configuration_utils import PretrainedConfig
from optimum.utils.normalized_config import NormalizedVisionConfig
from optimum.utils.input_generators import DummyVisionInputGenerator
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
from typing import OrderedDict, Dict
# EfficientNet variant identifiers: the numbered B-series (b0 through b8)
# in ascending order, followed by the single L2 variant.
MODEL_NAMES = [f'efficientnet_b{index}' for index in range(9)] + ['efficientnet_l2']
class EfficientNetConfig(PretrainedConfig):
    """Configuration recording which EfficientNet variant a model uses.

    Args:
        model_name: Variant identifier; must be one of the entries in
            ``MODEL_NAMES``.
        pretrained: Flag stored unchanged on the config. How it is consumed
            is not visible in this module (presumably it selects pretrained
            weights; confirm against the model code).
        **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """

    model_type = 'efficientnet'

    def __init__(self, model_name: str = 'efficientnet_b0', pretrained: bool = False, **kwargs):
        # Reject unknown variants before any attribute is set.
        is_supported = model_name in MODEL_NAMES
        if not is_supported:
            message = f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}'
            raise ValueError(message)

        # Own fields are assigned first, then the base class handles the
        # remaining keyword arguments.
        self.model_name, self.pretrained = model_name, pretrained
        super().__init__(**kwargs)
class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration that names the output axes for EfficientNet.

    Starts from the outputs reported by ``ViTOnnxConfig`` and sets the
    axis-name mapping for the output of the two tasks handled here.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the output-name -> {axis index: axis name} mapping.

        For ``"image-classification"`` the ``logits`` entry is set to two
        named axes; for ``"feature-extraction"`` the ``last_hidden_state``
        entry is set to four named axes. For any other task the parent's
        mapping is returned as-is.
        """
        axes_by_output = super().outputs
        task = self.task

        if task == "feature-extraction":
            axes_by_output["last_hidden_state"] = {
                0: "batch_size",
                1: "num_features",
                2: "height",
                3: "width",
            }
            return axes_by_output

        if task == "image-classification":
            axes_by_output["logits"] = {0: "batch_size", 1: "num_classes"}

        return axes_by_output
# Names exported by a star-import of this module.
__all__ = [
    'MODEL_NAMES',
    'EfficientNetConfig',
    'EfficientNetOnnxConfig',
]