File size: 1,513 Bytes
9c69ae8
 
6febcc2
12a9466
6febcc2
 
 
 
 
 
 
 
 
 
 
 
 
 
9c69ae8
6febcc2
 
 
 
9c69ae8
 
 
 
 
 
 
6febcc2
 
 
 
 
f41fda1
 
6febcc2
 
9c69ae8
12a9466
6febcc2
 
 
 
 
 
 
 
 
9c69ae8
6febcc2
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Dict

from transformers.configuration_utils import PretrainedConfig
from optimum.exporters.onnx.model_configs import ViTOnnxConfig

# Supported EfficientNet variants, in order:
#   efficientnet_b0 ... efficientnet_b8, followed by efficientnet_l2.
# EfficientNetConfig validates its `model_name` argument against this list.
MODEL_NAMES = [f'efficientnet_b{index}' for index in range(9)] + ['efficientnet_l2']


class EfficientNetConfig(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` describing an EfficientNet variant.

    Args:
        model_name: Variant identifier; must be a member of ``MODEL_NAMES``.
        pretrained: Stored as ``self.pretrained``. Presumably tells the model
            builder whether to load pretrained weights; it is not used here.
        num_classes: Stored as ``self.num_classes``. Presumably the size of the
            classifier output; it is not used here.
        global_pool: Stored as ``self.global_pool``. Presumably a pooling-mode
            name (default ``'avg'``); it is not used here.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """

    # Identifier that transformers' config machinery uses for this config type.
    model_type = 'efficientnet'

    def __init__(
        self,
        model_name: str = 'efficientnet_b0',
        pretrained: bool = False,
        num_classes: int = 1000,
        global_pool: str = 'avg',
        **kwargs,
    ):
        # Fail fast on an unknown variant, before any attribute is stored.
        is_known_variant = model_name in MODEL_NAMES
        if not is_known_variant:
            raise ValueError(
                f'`model_name` must be one of these: {MODEL_NAMES}, but got {model_name}'
            )

        # Same assignment order as before: model_name, pretrained,
        # num_classes, global_pool.
        self.model_name, self.pretrained = model_name, pretrained
        self.num_classes, self.global_pool = num_classes, global_pool

        # The remaining options (e.g. generic PretrainedConfig fields) are
        # handled by the base class.
        super().__init__(**kwargs)


class EfficientNetOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for EfficientNet, reusing ``ViTOnnxConfig``."""

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the mapping of output name to ``{axis index: axis name}``.

        For any task, this is the parent's mapping. For
        ``image-classification`` the ``logits`` entry is set so that it names
        axis 0 ``batch_size`` and axis 1 ``num_classes``.
        """
        # NOTE(review): the returned mapping is mutated in place. This relies
        # on the parent returning a fresh copy (optimum's OnnxConfig.outputs
        # deep-copies); confirm if the base class changes.
        output_axes = super().outputs

        # Only the classification task gets the class-dimension axis name.
        if self.task != "image-classification":
            return output_axes

        output_axes["logits"] = {0: "batch_size", 1: "num_classes"}
        return output_axes
    

# Public API of this module; the star-import exposes only these two classes.
__all__ = ['EfficientNetConfig', 'EfficientNetOnnxConfig']