FBAGSTM's picture
STM32 AI Experimentation Hub
747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright 2018 The TensorFlow Authors.
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import keras
from keras.applications import imagenet_utils
from keras import layers
from keras.applications.efficientnet_v2 import (EfficientNetV2B0, EfficientNetV2B1, EfficientNetV2B2, EfficientNetV2B3,
EfficientNetV2S)
def get_efficientnetv2(input_shape: tuple, model_type: str = None, num_classes: int = None, dropout: float = None,
                       pretrained: bool = True, **kwargs) -> keras.Model:
    """
    Build a Keras model based on the EfficientNetV2 architecture.

    Two construction modes are supported, selected by `dropout`:
      * `dropout` truthy: the backbone is created without its top, with global average pooling, and a
        new head (dropout + dense classifier) is stacked on it. The head is a `num_classes`-way softmax
        when `num_classes > 2`, otherwise a single sigmoid unit.
      * `dropout` falsy (None or 0): the complete Keras application, top included, is returned as-is.

    Args:
        input_shape (tuple): Shape of the input tensor.
        model_type (str): Backbone variant, one of "B0", "B1", "B2", "B3", "S".
        num_classes (int): Number of output classes of the target use-case. Required when `dropout` is set.
        dropout (float, optional): Dropout rate of the custom classifier head. A falsy value selects the
            full model without a custom head.
        pretrained (bool, optional): If True, the backbone is loaded with "imagenet" weights; otherwise it
            is randomly initialized. Default is True.
        **kwargs: Accepted for call-site compatibility; not used by this function.

    Returns:
        keras.Model: Model named "efficientnetv2" followed by `model_type`.

    Raises:
        ValueError: If `model_type` is not one of the supported variants, or if `dropout` is set while
            `num_classes` is None.
    """
    # Dispatch table from the variant name to the Keras application constructor.
    backbones = {
        "B0": EfficientNetV2B0,
        "B1": EfficientNetV2B1,
        "B2": EfficientNetV2B2,
        "B3": EfficientNetV2B3,
        "S": EfficientNetV2S,
    }
    backbone_func = backbones.get(model_type)
    if backbone_func is None:
        # Fail with an explicit message instead of an unbound-variable error further down.
        raise ValueError(
            f"Unsupported EfficientNetV2 model_type: {model_type!r}. "
            f"Expected one of {', '.join(backbones)}."
        )

    weights = "imagenet" if pretrained else None

    if dropout:
        # Validate before instantiating the backbone so that no model is built for nothing.
        if num_classes is None:
            raise ValueError("num_classes must be provided when dropout is set.")

        # Model loaded for training: headless backbone with global average pooling.
        base_model = backbone_func(
            include_top=False,
            weights=weights,
            input_tensor=None,
            input_shape=input_shape,
            pooling="avg",
            classes=num_classes,
            include_preprocessing=False,
        )

        # Create a new classifier head on top of the pooled features.
        x = layers.Dropout(rate=dropout, name="dropout")(base_model.output)
        if num_classes > 2:
            outputs = layers.Dense(num_classes, activation="softmax")(x)
        else:
            # Two classes or fewer: a single sigmoid unit is used.
            outputs = layers.Dense(1, activation="sigmoid")(x)
    else:
        # Model entirely loaded for other services than training => no dropout.
        # NOTE(review): per the Keras applications documentation, include_top=True combined with
        # "imagenet" weights is expected to require classes == 1000 -- confirm against the callers.
        base_model = backbone_func(
            include_top=True,
            weights=weights,
            input_tensor=None,
            input_shape=input_shape,
            pooling="avg",
            classes=num_classes,
            include_preprocessing=False,
            classifier_activation="softmax",
        )
        outputs = base_model.output

    # Create the Keras model
    model = keras.Model(inputs=base_model.input, outputs=outputs, name="efficientnetv2" + model_type)
    return model