efficientnet_b0 / configuration_efficientnet.py
Thastp's picture
Upload model
44aa62f verified
raw
history blame
3.01 kB
from transformers.configuration_utils import PretrainedConfig
from optimum.utils.normalized_config import NormalizedVisionConfig
from optimum.utils.input_generators import DummyVisionInputGenerator
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
from typing import OrderedDict, Dict
# Variant names accepted by `EfficientNetConfig.model_name`:
# efficientnet_b0 through efficientnet_b8, followed by efficientnet_l2.
MODEL_NAMES = [f'efficientnet_b{index}' for index in range(9)] + ['efficientnet_l2']
class EfficientNetConfig(PretrainedConfig):
    """Configuration that selects an EfficientNet variant by name.

    Args:
        model_name: Variant identifier; must be one of ``MODEL_NAMES``.
        pretrained: Flag stored on the config as ``pretrained``. How it is
            consumed is decided by the model code, which is not part of
            this module.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """

    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        **kwargs
    ):
        # Reject unknown variants before any attribute is set.
        is_supported = model_name in MODEL_NAMES
        if not is_supported:
            message = f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}'
            raise ValueError(message)

        self.model_name, self.pretrained = model_name, pretrained

        # TODO: add architecture attributes. Candidates noted so far:
        # batch_norm_eps, batch_norm_momentum, depth_coefficient,
        # depth_divisor, depthwise_padding, drop_connect_rate, dropout_rate,
        # expand_ratios, hidden_act, hidden_dim, id2label / label2id
        # (ImageNet labels), image_size, in_channels, initializer_range,
        # kernel_sizes, num_block_repeats, num_channels, num_hidden_layers,
        # out_channels, pooling_type, squeeze_expansion_ratio, strides,
        # width_coefficient.

        # The base initializer runs last and receives only the extra kwargs.
        super().__init__(**kwargs)
class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export config that declares a ``logits`` output for image classification."""

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the mapping of output names to ``{axis index: axis name}``.

        The mapping comes from the parent class. When ``self.task`` is
        ``"image-classification"``, a ``logits`` entry is set on it with
        axis 0 named ``batch_size`` and axis 1 named ``num_classes``.
        For any other task the parent's mapping is returned as-is.
        """
        axes_by_output = super().outputs
        if self.task != "image-classification":
            return axes_by_output

        axes_by_output["logits"] = {0: "batch_size", 1: "num_classes"}
        return axes_by_output
# Names exported by `from <module> import *`.
__all__ = [
    "MODEL_NAMES",
    "EfficientNetConfig",
    "EfficientNetOnnxConfig",
]