File size: 15,738 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# /*---------------------------------------------------------------------------------------------
#  * Copyright 2015 The TensorFlow Authors. 
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import tensorflow as tf
import keras
from keras.activations import silu, relu, relu6
from keras.ops import clip
from keras import layers
from keras.src.applications import imagenet_utils
from keras import Model
from typing import List, Tuple
import math


# Baseline (unscaled) number of MBConv blocks in each of the seven stages built by
# _EfficientNet; the entries sum to 16 blocks. _EfficientNet scales this list by
# depth_coefficient through _round_repeats before building the stages.
repeats = [1, 2, 2, 3, 3, 4, 1]


def _round_expansion(expansion_factor: int, repeats: List[int]) -> List[int] :
    """
    Docstring for _round_expansion
    
    Args:
        expansion_factor (int): scaling coefficient for the input filters
        repeats (list): list managing block repetition in the network and therefore depth
    Returns:
        exp_ratio: list of scaling coefficient 
    """
    exp_ratio = []
    flag = 1
    for r in repeats:
        if (r != 0) and flag:
            exp_ratio.append(1)
            flag = 0
        else:
            exp_ratio.append(expansion_factor)

    return exp_ratio


def _num_blocks(repeats: List[int]) -> List[int]:
    """
    Docstring for _num_blocks
    
    Args: 
        repeats (List[int]): repetition pattern along the network

    Returns: 
        blocks (List[int]): auxiliary list for network construction
    """
    blocks = []
    for r in repeats:
        if (r != 0):
            blocks.append(1)
        else:
            blocks.append(0)
    return blocks


def _round_filters(filters: int, width_coefficient: float, depth_divisor: int = 8, min_filters: int = None) -> int:
    """
    Round number of filters based on depth multiplier.
    
    Args: 
        filters (int): base filter number of a sub-block
        width_coefficient (int): scaling coefficient on filters
        depth_divisor (int): scaling coefficient on depth
        min_filters (int): minimum number of filters considered

    Returns: 
        Rounded number of filters
    """
    if not width_coefficient:
        return filters
    filters *= width_coefficient
    min_filters = min_filters or depth_divisor
    new_filters = max(int(filters + depth_divisor / 2) // depth_divisor * depth_divisor, min_filters)
    # Make sure that round down does not go down by more than 10%.
    if new_filters < 0.9 * filters:
        new_filters += depth_divisor
    return int(new_filters)


def _round_repeats(repeats: List[int], depth_coefficient: float, depth_trunc: str) -> List[int]:
    """ 
    Per-stage depth scaling. Scales the block repeats in each stage. This depth scaling maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.

    Args:
        repeats (list): pattern for sub-block repetition
        depth_coefficient (float): scaling coefficient on depth
        depth_trunc (str): method for truncation, example 'round'
    Returns:
        repeats_scaled (list): scaled repeat per stage

    """

    # We scale the total repeat count for each stage, there may be multiple
    # block arg defs per stage so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible
        num_repeat_scaled = round(num_repeat * depth_coefficient)
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_coefficient))
    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        if depth_trunc == 'round':
            rs = round((r / num_repeat * num_repeat_scaled))
        else:
            rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]
    return repeats_scaled


def _swish(x):
    """
    Apply the swish activation by delegating to ``keras.activations.silu``.

    Args:
        x (tf.Tensor): input tensor
    Returns:
        tensor holding silu(x)
    """
    activated = silu(x)
    return activated
    

def _mb_conv_block(inputs: tf.Tensor, in_channels: int, out_channels: int, num_repeat: int, stride: int, expansion_factor: int, se_ratio: float, k: int, drop_rate: float,
                   prev_block_num: int, activation, num_blocks_total: int = 16) -> tf.Tensor:
    """
    Build one stage made of ``num_repeat`` stacked MBConv blocks.

    Each block is: optional 1x1 expansion conv -> kxk depthwise conv -> squeeze-and-excitation
    -> 1x1 projection conv (no activation). When a block keeps both its stride at 1 and its
    channel count, its output is added to its input, after a per-sample Dropout
    (``noise_shape=(None, 1, 1, 1)``) whose rate grows with the global block index.

    Args:
        inputs (tf.Tensor): stage input tensor
        in_channels (int): number of channels of ``inputs``
        out_channels (int): number of output channels of every block in the stage
        num_repeat (int): number of MBConv blocks stacked in this stage (0 returns ``inputs``)
        stride (int): stride of the depthwise conv of the first block; later blocks use 1
        expansion_factor (int): scaling coefficient for the input filters (1 skips expansion)
        se_ratio (float): between 0 and 1, fraction to squeeze the input filters
        k (int): depthwise kernel size
        drop_rate (float): base drop-connect rate, scaled by block_index / num_blocks_total
        prev_block_num (int): number of blocks built in previous stages; used for the
            drop-connect schedule and to give the SE activation layers unique names
        activation: activation passed to ``layers.Activation`` (a name or a callable)
        num_blocks_total (int): denominator of the drop-connect schedule. The default of 16
            equals the sum of the unscaled per-stage repeats used by this network.
    Returns:
        tf.Tensor: output of the last block of the stage
    Raises:
        ValueError: if ``num_blocks_total`` is not positive
    """
    if num_blocks_total <= 0:
        raise ValueError(f"num_blocks_total must be positive, got {num_blocks_total}")

    x = inputs
    input_filters = in_channels

    for i in range(num_repeat):
        input_tensor = x
        block_index = prev_block_num + i
        # Only the first block of the stage handles the stride (and the channel change).
        block_stride = stride if i == 0 else 1

        # Expansion phase: widen the channels (inverted residual block).
        expanded_filters = input_filters * expansion_factor
        if expansion_factor != 1:
            x = layers.Conv2D(filters=expanded_filters, kernel_size=(1, 1), strides=1, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.Activation(activation)(x)

        # Depthwise phase.
        x = layers.DepthwiseConv2D(kernel_size=(k, k), strides=block_stride, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation(activation)(x)

        # Squeeze-and-excitation: global average pooling, squeeze to se_ratio of the
        # block's input channels, re-expand, and gate the features with a sigmoid.
        squeezed_filters = max(1, int(input_filters * se_ratio))
        se_tensor = layers.GlobalAveragePooling2D()(x)
        se_tensor = layers.Reshape((1, 1, expanded_filters))(se_tensor)
        se_tensor = layers.Conv2D(filters=squeezed_filters, kernel_size=(1, 1), padding='same')(se_tensor)
        se_tensor = layers.BatchNormalization()(se_tensor)
        se_tensor = layers.Activation(activation, name='act_{}'.format(block_index))(se_tensor)
        se_tensor = layers.Conv2D(filters=expanded_filters, kernel_size=(1, 1), padding='same')(se_tensor)
        se_tensor = layers.BatchNormalization()(se_tensor)
        se_tensor = layers.Activation('sigmoid', name='act2_{}'.format(block_index))(se_tensor)
        x = layers.multiply([x, se_tensor])

        # Projection phase: linear 1x1 conv down to out_channels.
        x = layers.Conv2D(filters=out_channels, kernel_size=(1, 1), strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)

        # Residual connection only when the output shape matches the block input.
        if block_stride == 1 and input_filters == out_channels:
            dropout_rate = drop_rate * float(block_index) / num_blocks_total
            if dropout_rate > 0:
                x = layers.Dropout(dropout_rate, noise_shape=(None, 1, 1, 1))(x)
            x = layers.add([x, input_tensor])

        input_filters = out_channels

    return x


def _EfficientNet(width_coefficient_list: float = 1.0,
                  depth_coefficient: float = 1.0,
                  input_resolution: int = 224,
                  expansion_factor: int = 6,
                  se_ratio: float = 0.25,
                  input_channels: int = 3,
                  dropout_rate: float = 0.2,
                  drop_connect_rate: float = 0.2,
                  depth_trunc: str = 'ceil',
                  activation: str = 'relu',
                  include_top=True,
                  pooling: str = None,
                  classes: int = 101) -> keras.Model:
    """
    Build an EfficientNet-style Keras model: stem, seven MBConv stages and a 1x1 top conv.

    Args:
        width_coefficient_list: width multipliers, one per position of the channel table
            (stem, outputs of the seven stages, top conv), i.e. at least 9 values; extra
            values are ignored. A single number (or None) is applied to every position.
        depth_coefficient (float): scaling coefficient for network depth
        input_resolution (int): size of the (square) input in pixels
        expansion_factor (int): scaling coefficient for the input filters
        se_ratio (float): between 0 and 1, fraction to squeeze the input filters
        input_channels (int): number of input channels
        dropout_rate (float): dropout rate before the final classifier layer (None/0 disables it)
        drop_connect_rate (float): base dropout rate applied on the residual connections
        depth_trunc (str): truncation method for depth scaling ('round' or 'ceil')
        activation: activation name or callable; 'swish' and 'relu6' are mapped to functions
        include_top (bool): whether to include the softmax classifier at the top of the network
        pooling (str): when include_top is False: None, 'avg' or 'max' global pooling
        classes (int): number of classes of the classifier
    Returns:
        keras.Model with the EfficientNet topology
    Raises:
        ValueError: if fewer than 9 width coefficients are given
    """
    # Base (unscaled) channel counts: stem, outputs of the seven MBConv stages, top conv.
    base_filters = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
    # (stride, depthwise kernel size) of each of the seven MBConv stages.
    stage_specs = [(1, 3), (2, 3), (2, 5), (2, 3), (1, 5), (2, 5), (1, 3)]

    # A single scalar width coefficient applies to every position of the channel table.
    if width_coefficient_list is None or isinstance(width_coefficient_list, (int, float)):
        width_coefficient_list = [width_coefficient_list] * len(base_filters)
    if len(width_coefficient_list) < len(base_filters):
        raise ValueError(f"Expecting at least {len(base_filters)} width coefficients. "
                         f"Received {width_coefficient_list}")
    filters = [_round_filters(f, w) for f, w in zip(base_filters, width_coefficient_list)]

    # Determine proper input shape
    inputs = keras.Input(shape=(input_resolution, input_resolution, input_channels))

    # Activation: pass the functions themselves (calling _swish() with no argument fails).
    if activation == 'swish':
        activation = _swish
    elif activation == 'relu6':
        activation = relu6

    # Build stem
    x = layers.Conv2D(filters=filters[0], kernel_size=(3, 3), strides=(2, 2), padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation, name='stem_activation')(x)

    # Build blocks. Stage i maps filters[i] -> filters[i + 1] channels.
    # NOTE(review): _mb_conv_block normalises the drop-connect rate by 16 blocks (the sum of the
    # unscaled `repeats`), so with depth_coefficient != 1 the per-block rates are not
    # renormalised to the scaled depth.
    repeats_scaled = _round_repeats(repeats, depth_coefficient, depth_trunc)
    exp_ratio = _round_expansion(expansion_factor, repeats_scaled)
    for stage, (stride, kernel_size) in enumerate(stage_specs):
        x = _mb_conv_block(inputs=x, in_channels=filters[stage], out_channels=filters[stage + 1],
                           num_repeat=repeats_scaled[stage], stride=stride,
                           expansion_factor=exp_ratio[stage], se_ratio=se_ratio, k=kernel_size,
                           drop_rate=drop_connect_rate,
                           prev_block_num=sum(repeats_scaled[:stage]), activation=activation)

    # Build top
    x = layers.Conv2D(filters=filters[8], kernel_size=(1, 1), padding='same', name='top_conv')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation(activation, name='top_activation')(x)

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        if dropout_rate and dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name='top_dropout')(x)
        x = layers.Dense(classes, activation='softmax', name='output_probs')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)

    # Create model.
    return Model(inputs, x, name="st_efficientnetlcv1")


def get_st_efficientnetlcv1(input_shape: Tuple[int, int, int] = None, num_classes: int = None, dropout: float = None, pretrained: bool = False, **kwargs) -> keras.Model:
    """
    Create the ST EfficientNet LC v1 Keras classification model with random weights.

    Args:
        input_shape (Tuple[int, int, int]): (height, width, channels) of the input; height and
            width must be equal and multiples of 32.
        num_classes (int): number of classes for the classification task (>= 1).
        dropout (float): dropout rate before the classifier (None or 0 disables it).
        pretrained (bool): no pretrained weights exist; True only prints a warning.
        **kwargs: accepted and ignored.

    Returns:
        keras.Model: the classification model.

    Raises:
        ValueError: if input_shape is missing or invalid, or num_classes is missing or < 1.
    """
    if pretrained:
        print("WARNING: No pretrained weights are found for 'st_efficientnet_lc_v1' model. Random weights are used instead.")

    if input_shape is None or len(input_shape) < 3:
        raise ValueError(f"Expecting input_shape as (height, width, channels). Received {input_shape}")
    if num_classes is None or num_classes < 1:
        raise ValueError(f"Expecting a positive number of classes. Received {num_classes}")
    # Validate input_shape is square
    if input_shape[0] != input_shape[1]:
        raise ValueError(f"Expecting image width and height to be the same. Received image shape {input_shape}")
    # Validate input_shape is multiple of 32 (height == width at this point)
    if input_shape[0] % 32:
        raise ValueError(f"Expecting image width and height to be multiples of 32. Received image shape {input_shape}")

    activation = 'relu6'
    depth_coefficient = 1.
    # Uniform width multiplier; _EfficientNet reads the first 9 entries.
    width_coefficients = [0.45] * 10
    expansion_factor = 3
    return _EfficientNet(width_coefficient_list=width_coefficients, depth_coefficient=depth_coefficient,
                         input_resolution=input_shape[0], expansion_factor=expansion_factor,
                         depth_trunc='ceil', activation=activation, input_channels=input_shape[2],
                         dropout_rate=dropout, include_top=True, classes=num_classes)