# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import ssl
# NOTE(review): process-wide side effect at import time. This swaps ssl's default
# HTTPS context factory (used e.g. by urllib / http.client when no context is
# passed) for one that skips certificate and hostname verification.
# Presumably added so model/weight downloads work behind intercepting proxies —
# TODO confirm and consider scoping it to the specific download instead.
ssl._create_default_https_context = ssl._create_unverified_context
import sys
import os
from pathlib import Path
from typing import Tuple, Dict, Optional, List
import tensorflow as tf
from omegaconf import DictConfig
from common.utils import transfer_pretrained_weights, check_attribute_value, check_model_support, check_attributes
from image_classification.tf.src.models import get_mobilenetv1, get_mobilenetv2, get_fdmobilenet, get_resnet, \
get_resnet50v2, get_squeezenetv11, get_st_mnistv1, get_st_efficientnetlcv1, \
get_st_fdmobilenetv1, get_efficientnetv2, get_custom_model
def ai_runner_invoke(image_processed, ai_runner_interpreter):
    """
    Run an ai_runner interpreter on a batch of images and flatten its first output.

    Args:
        image_processed (tf.Tensor): input images
        ai_runner_interpreter: ai_runner object whose ``invoke`` method is called
            on the input images. ``invoke`` must return exactly two items; the
            first one is a sequence of output arrays, of which only the first
            is used.
    Returns:
        The first output array reshaped to ``(-1, nb_class)``, where ``nb_class``
        is the size of that array's last dimension.
    """
    outputs, _ = ai_runner_interpreter.invoke(image_processed)
    first_output = outputs[0]
    # Collapse every leading dimension so the result is one row per prediction.
    return first_output.reshape([-1, first_output.shape[-1]])
def change_model_number_of_classes(model: tf.keras.Model, num_classes: int):
    """
    Rebuild a model so that its classification head outputs ``num_classes`` values.

    If any output of ``model`` already has ``num_classes`` as its last dimension,
    the model is returned unchanged.

    Otherwise the layers are scanned from the last one backwards until a layer
    whose exact type is Conv2D, Conv2DTranspose, Conv1D, Conv1DTranspose or Dense
    is found. That layer is re-created from its config with ``filters`` (conv) or
    ``units`` (Dense) set to ``num_classes``, and applied to the output of the
    layer just before it. Every layer that followed it is re-created from its
    config and re-applied in the original order. Re-created layers are built from
    config only, so their trained weights are not carried over.

    NOTE(review): this assumes the head of the model is a simple chain of layers
    (each layer fed by ``model.layers[i - 1]``) — TODO confirm for branched heads.

    Args:
        model (tf.keras.Model): Keras model
        num_classes (int): new number of classes as output
    Returns:
        (tf.keras.Model): ``model`` itself if it already matches, otherwise a new
        model built on ``model.input`` with the same name.
    Raises:
        ValueError: if no layer after the first one has one of the supported types.
    """
    # If the model already has the correct number of classes -> don't do anything
    if any(outp.shape[-1] == num_classes for outp in model.outputs):
        return model

    conv_types = (tf.keras.layers.Conv2D,
                  tf.keras.layers.Conv2DTranspose,
                  tf.keras.layers.Conv1D,
                  tf.keras.layers.Conv1DTranspose)
    resizable_types = conv_types + (tf.keras.layers.Dense,)

    # Fresh copies of the layers that follow the resized one, collected last-to-first.
    trailing_layers = []
    # Walk backwards and stop at index 1: the resized layer takes its input from the
    # layer right before it, so the layer at index 0 can never be the one resized.
    for idx in range(len(model.layers) - 1, 0, -1):
        layer_type = type(model.layers[idx])
        layer_config = model.layers[idx].get_config()
        # Exact type match: subclasses of these layer classes are treated as
        # trailing layers, not resized.
        if layer_type in resizable_types:
            size_key = 'filters' if layer_type in conv_types else 'units'
            layer_config[size_key] = num_classes
            outputs = layer_type(**layer_config)(model.layers[idx - 1].output)
            for new_layer in reversed(trailing_layers):
                outputs = new_layer(outputs)
            return tf.keras.Model(inputs=model.input, outputs=outputs, name=model.name)
        trailing_layers.append(layer_type(**layer_config))

    raise ValueError(
        f"Cannot change the number of classes of model '{model.name}': no Conv1D, "
        "Conv2D, Conv1DTranspose, Conv2DTranspose or Dense layer found after the "
        "first layer.")
def change_model_input_shape(model: tf.keras.Model, new_inp_shape: Tuple) -> Tuple[tf.keras.Model, Tuple]:
    """
    Change model input shape.

    The model is rebuilt from its config with the first layer's input shape
    replaced. The weights of every other layer are then copied over, pairing
    layers by position.

    Args:
        model (tf.keras.Model): keras model
        new_inp_shape (Tuple): new input shape (including the batch dimension)
            written into the first layer's config
    Returns:
        Tuple[tf.keras.Model, Tuple]: the rebuilt model, and the input shape that
        the original model's first layer config held before the change.
    """
    conf = model.get_config()
    input_layer_conf = conf['layers'][0]['config']
    # Keras 3 stores the input shape under 'batch_shape'; Keras 2 (tf.keras)
    # uses 'batch_input_shape'. Prefer 'batch_shape' when present.
    shape_key = 'batch_shape'
    if shape_key not in input_layer_conf and 'batch_input_shape' in input_layer_conf:
        shape_key = 'batch_input_shape'
    # Read the old shape before overwriting it. This avoids a second get_config() call.
    old_inp_shape = input_layer_conf[shape_key]
    input_layer_conf[shape_key] = new_inp_shape
    new_model = model.__class__.from_config(conf, custom_objects={})
    # Copy weights for every layer except the first (input) one.
    # NOTE(review): assumes model.layers[0] is the input layer; for models whose
    # `layers` list omits the InputLayer this would skip the first real layer —
    # TODO confirm for Sequential models.
    for old_layer, new_layer in zip(model.layers[1:], new_model.layers[1:]):
        new_layer.set_weights(old_layer.get_weights())
    return new_model, old_inp_shape
def get_loss(num_classes: int) -> tf.keras.losses.Loss:
    """
    Returns the appropriate loss function based on the number of classes in the dataset.

    Args:
        num_classes (int): The number of classes in the dataset.
    Returns:
        tf.keras.losses.Loss: ``SparseCategoricalCrossentropy`` when there are more
        than two classes, ``BinaryCrossentropy`` otherwise. Both are created with
        ``from_logits=False``, so predictions are expected to be probabilities
        rather than logits.
    """
    # We use the sparse version of the categorical crossentropy because
    # this is what we use to load the dataset.
    if num_classes > 2:
        return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    return tf.keras.losses.BinaryCrossentropy(from_logits=False)
|