FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022 STMicroelectronics.
# * All rights reserved.
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import keras
from keras import layers
from keras.regularizers import l2
from typing import Tuple
def _resnet_layer(inputs: layers.Input, num_filters: int = 16, kernel_size: int = 3, strides: int = 1,
                  activation: str = 'relu', batch_normalization: bool = True,
                  conv_first: bool = True, **kwargs) -> layers.Activation:
    """
    Build one Conv2D / BatchNormalization / Activation unit for the ResNet builder.

    The convolution always uses 'same' padding, He-normal kernel initialization
    and an L2(1e-4) kernel regularizer. Batch normalization and the activation
    are each optional. They are applied after the convolution when
    ``conv_first`` is True, and before it otherwise.

    Args:
        inputs: Tensor coming from the model input or from a previous layer.
        num_filters: Number of Conv2D filters.
        kernel_size: Side length of the square Conv2D kernel.
        strides: Square Conv2D stride.
        activation: Name of the activation to apply, or None to skip it.
        batch_normalization: If True, insert a BatchNormalization layer.
        conv_first: True for Conv-BN-Activation ordering,
            False for BN-Activation-Conv ordering.
        **kwargs: Accepted and ignored.

    Returns:
        The output tensor, to be fed into the next layer.
    """
    conv = layers.Conv2D(num_filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same',
                         kernel_initializer='he_normal',
                         kernel_regularizer=l2(1e-4))

    def _normalize_and_activate(tensor):
        # Optional BN followed by optional activation; shared by both orderings.
        if batch_normalization:
            tensor = layers.BatchNormalization()(tensor)
        if activation is not None:
            tensor = layers.Activation(activation)(tensor)
        return tensor

    if not conv_first:
        # Pre-activation ordering: BN -> Activation -> Conv.
        return conv(_normalize_and_activate(inputs))
    # Post-activation ordering: Conv -> BN -> Activation.
    return _normalize_and_activate(conv(inputs))
def get_resnet(num_classes: int = None, input_shape: Tuple[int, int, int] = None,
               depth: int = None, dropout: float = None, pretrained: bool = False, **kwargs) -> keras.Model:
    """
    ResNet Version 1 Model builder.

    Stacks of 2 x (3 x 3) Conv2D-BN-ReLU. Last ReLU is after the shortcut connection.
    At the beginning of each stage, the feature map size is halved (down-sampled)
    by a convolutional layer with strides=2, while the number of filters is doubled.
    Within each stage, all layers have the same number of filters and the same
    feature map size.

    Args:
        num_classes: Number of classes in the dataset. More than 2 builds a
            softmax head with ``num_classes`` units. 2 or fewer builds a single
            sigmoid unit (binary classification).
        input_shape: Shape of the input tensor.
        depth: Depth of the ResNet model. Must be one of [8, 20, 32].
        dropout: Dropout rate applied to the pooled features before the final
            Dense layer. A falsy value (None or 0) disables dropout.
        pretrained: No pretrained weights exist for this model. If True, a
            warning is printed and random weights are used.
        **kwargs: Accepted and ignored.

    Returns:
        A Keras model instance named ``resnet{depth}``.

    Raises:
        ValueError: If ``depth`` is not one of the allowed values, or if
            ``num_classes`` is not given.
    """
    if pretrained:
        print("WARNING: No pretrained weights are found for 'resnet' model. Random weights are used instead.")
    allowed_depths = [8, 20, 32]
    if depth not in allowed_depths:
        raise ValueError(f"depth must be one of {allowed_depths}, got {depth}")
    if num_classes is None:
        # Fail early with a clear message instead of a TypeError from `None > 2`
        # once the whole network has been built.
        raise ValueError("num_classes must be provided")

    # Start model definition.
    num_filters = 16
    # depth = 6 * n + 2: n residual blocks per stage, 3 stages of 2 convs each,
    # plus the stem conv and the classifier.
    num_res_blocks = (depth - 2) // 6

    inputs = keras.Input(shape=input_shape)
    x = _resnet_layer(inputs=inputs)
    # Instantiate the stack of residual units.
    for stack in range(3):
        for res_block in range(num_res_blocks):
            # The first block of every stage except the first down-samples.
            downsample = stack > 0 and res_block == 0
            strides = 2 if downsample else 1
            y = _resnet_layer(inputs=x,
                              num_filters=num_filters,
                              strides=strides)
            y = _resnet_layer(inputs=y,
                              num_filters=num_filters,
                              activation=None)
            if downsample:
                # Linear projection shortcut to match the changed dimensions.
                x = _resnet_layer(inputs=x,
                                  num_filters=num_filters,
                                  kernel_size=1,
                                  strides=strides,
                                  activation=None,
                                  batch_normalization=False)
            x = layers.add([x, y])
            x = layers.Activation('relu')(x)
        num_filters *= 2

    # Add classifier on top.
    # v1 does not use BN after last shortcut connection-ReLU.
    # A pool size of 8 reduces the 8x8 final map of a 32x32 input to 1x1.
    # Clamp it to the known spatial size of the final map so that smaller
    # inputs (e.g. 28x28 -> 7x7) still pool to 1x1 rather than yielding an
    # empty or invalid 'valid'-padded pooling output. Maps of at least 8x8 keep
    # the original (8, 8) pool.
    pool_size = tuple(8 if dim is None else min(8, int(dim)) for dim in x.shape[1:3])
    x = layers.AveragePooling2D(pool_size=pool_size)(x)
    x = layers.Flatten()(x)
    if dropout:
        x = layers.Dropout(dropout)(x)
    if num_classes > 2:
        outputs = layers.Dense(num_classes, activation="softmax", kernel_initializer='he_normal')(x)
    else:
        outputs = layers.Dense(1, activation="sigmoid")(x)

    # Instantiate model.
    model = keras.Model(inputs=inputs, outputs=outputs, name=f"resnet{depth}")
    return model