File size: 21,721 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
#  /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022-2023 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

import os
from tabulate import tabulate
from tensorflow.python.keras.utils.layer_utils import count_params
from typing import Dict, Optional, Tuple, List
import tensorflow as tf
from tensorflow.keras.models import Model
import numpy as np
import sklearn
from pathlib import Path
from onnx import ModelProto
import onnxruntime
import mlflow
from common.utils import log_to_file
import torch


def ai_interp_input_quant(ai_interp, data: np.array, file_extension: str):
    """
    Prepare input data for an AiRunner interpreter.

    If the model input is quantized (int8/uint8), rescales the data with the
    input's scale and zero-point and casts it to the quantized dtype;
    otherwise the data is simply cast to float32.

    Args:
        ai_interp: AiRunner interpreter exposing get_inputs().
        data (np.array): Input data batch.
        file_extension (str): Model file extension (currently unused).

    Returns:
        np.array: The (possibly quantized) input data.
    """
    in_details = ai_interp.get_inputs()[0]  # single-input model assumed
    if in_details.dtype not in (np.uint8, np.int8):
        # Float model: no quantization needed.
        return data.astype(np.float32)
    # Rescale the data into the quantized domain.
    rescaled = (data / in_details.scale[0]) + in_details.zero_point[0]
    if in_details.dtype == np.int8:
        # Shift from the [0, 255] domain down to [-128, 127].
        return (rescaled - 128).astype(np.int8)
    return rescaled.astype(np.uint8)

def ai_interp_outputs_dequant(ai_interp, predictions: np.array):
    """
    Dequantize raw AiRunner predictions back to float32.

    Outputs whose scale and zero-point are both zero are treated as
    non-quantized and are only cast to float32.

    Args:
        ai_interp: AiRunner interpreter exposing get_outputs().
        predictions (np.array): Raw predictions, one array per model output.

    Returns:
        list: Dequantized float32 predictions, one entry per model output.
    """
    dequantized = []
    for idx, details in enumerate(ai_interp.get_outputs()):
        as_float = predictions[idx].astype(np.float32)
        if details.scale[0] != 0 or details.zero_point[0] != 0:
            # Quantized output: map back to the float domain.
            as_float = details.scale[0] * (as_float - details.zero_point[0])
        dequantized.append(as_float)
    return dequantized


def ai_runner_interp(target: str, name_model: str):
    """Returns an AiRunner interpreter for ST Edge AI inference.

    Args:
        target (str): Target to run the model on ('stedgeai_host',
            'stedgeai_n6' or 'stedgeai_h7p'). Any other value yields None.
        name_model (str): Name of the model, used in log/error messages.

    Returns:
        The connected AiRunner interpreter, or None when `target` is not
        an ST Edge AI target.

    Raises:
        TypeError: If the connection to the target cannot be established.
    """
    ai_runner_interpreter = None
    if target in ['stedgeai_host', 'stedgeai_n6', 'stedgeai_h7p']:
        print(f"Loading {target} for ST Edge AI inference of {name_model}")
        # Project-local import kept inside the branch: it is only needed
        # (and only guaranteed to be available) for ST Edge AI targets.
        from common.stm_ai_runner import AiRunner
        if target == 'stedgeai_host':
            # Host inference goes through the ST Edge AI workspace.
            ai_runner_desc = 'st_ai_ws'
        else:
            # Real boards (N6 / H7P) are reached over a serial link.
            ai_runner_desc = 'serial:921600'
        ai_runner_interpreter = AiRunner()
        if not ai_runner_interpreter.connect(ai_runner_desc):
            raise TypeError(f"model='{name_model}' unable to load the model")
        # Log the model I/O details for debugging purposes.
        for detail in ai_runner_interpreter.get_inputs():
            print(f" I: {detail.name} {detail.shape} {detail.dtype} {detail.scale} {detail.zero_point}")
        for detail in ai_runner_interpreter.get_outputs():
            print(f" O: {detail.name} {detail.shape} {detail.dtype} {detail.scale} {detail.zero_point}")
    return ai_runner_interpreter


def get_model_name(model_type: str,
                   input_shape: int,
                   project_name: str) -> str:
    """Build the canonical model name '<model_type>_<input_shape>_<project_name>'.

    Args:
        model_type (str): Type of the model.
        input_shape (int): Input shape of the model.
        project_name (str): Name of the project.

    Returns:
        str: String representation of the model name.
    """
    return f"{model_type}_{input_shape}_{project_name}"


def get_model_name_and_its_input_shape(model_path: str = None,
                                       custom_objects: Dict = None) -> Tuple:
    """
    Load a model from a given file path and return the model name and its
    input shape. Supported model formats are .h5, .keras, .tflite and .onnx.
    The basename of the model file is used as the model name. The input
    shape is extracted from the model.

    Args:
        model_path (str): A path to an .h5, .keras, .tflite or .onnx model file.
        custom_objects (Dict): Custom objects required to load the model
            (only used for .h5/.keras models).

    Returns:
        Tuple: (model name, input shape). The input shape is a tuple of
               at most 3 elements.

    Raises:
        RuntimeError: If the file extension is unsupported or the input
            shape cannot be extracted from the model.
    """
    # We use the file basename as the model name.
    model_name = Path(model_path).stem
    file_extension = Path(model_path).suffix

    if file_extension in [".h5", ".keras"]:
        # When we resume a training, the model includes the preprocessing
        # layers (augmented model). Therefore, we need to declare the custom
        # data augmentation layer as a custom object to be able to load it.
        model = tf.keras.models.load_model(
                        model_path,
                        custom_objects=custom_objects,
                        compile=False)
        try:
            input_shape = tuple(model.input.shape[1:])
        except Exception:
            # Multi-input / functional models expose inputs as a list.
            input_shape = tuple(model.inputs[0].shape[1:])

    elif file_extension == ".tflite":
        try:
            # Load the tflite model and read the input tensor details.
            interpreter = tf.lite.Interpreter(model_path=model_path)
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()
            # Keep only the last 3 dimensions (drop the batch dimension).
            input_shape = tuple(input_details[0]['shape'])[-3:]
        except RuntimeError as error:
            raise RuntimeError("\nUnable to extract input shape from .tflite model file\n"
                               f"Received path {model_path}") from error

    elif file_extension == ".onnx":
        try:
            # Load the model and query the runtime session for the input shape.
            onx = ModelProto()
            with open(model_path, "rb") as f:
                onx.ParseFromString(f.read())
            sess = onnxruntime.InferenceSession(model_path)
            input_shape = tuple(sess.get_inputs()[0].shape)[-3:]
        except RuntimeError as error:
            raise RuntimeError("\nUnable to extract input shape from .onnx model file\n"
                               f"Received path {model_path}") from error

    else:
        # Report the full path (the original message printed only the extension).
        raise RuntimeError(f"\nUnknown/unsupported model file type.\nExpected `.tflite`, `.h5`, `.keras`, or `.onnx`."
                           f"\nReceived path {model_path}")

    return model_name, input_shape


def check_model_support(model_name: str, version: Optional[str] = None,
                        supported_models: Dict = None,
                        message: Optional[str] = None) -> None:
    """
    Check that a model name and version appear in a dictionary of supported
    models and versions.

    Args:
        model_name(str): The name of the model to check.
        version(str): The version of the model to check. May be None.
        supported_models(Dict[str, List[str]]): Supported models mapped to
            their supported versions (an empty/None list means "no versions").
        message(str): Extra text appended to every error message.

    Raises:
        ValueError: If the model name or version is unsupported, if a
            required version is missing, or if a version is given for a
            model that has none.
    """
    if model_name not in supported_models:
        raise ValueError("\nSupported model names are {}. Received {}.{}".format(
            list(supported_models.keys()), model_name, message))

    versions = supported_models[model_name]
    if not versions:
        # This model has no versions: providing one is an error.
        if version:
            raise ValueError("\nThe `version` attribute is not applicable "
                             "to '{}' model.{}".format(model_name, message))
        return

    # The model is versioned: a valid version must be supplied.
    if not version:
        raise ValueError("\nMissing `version` attribute for `{}` model.{}".format(model_name, message))
    if version not in versions:
        raise ValueError("\nSupported versions for `{}` model are {}. "
                         "Received {}.{}".format(model_name, versions, version, message))

def check_attribute_value(attribute_value: str, values: List[str] = None,
                          name: str = None, message: str = None) -> None:
    """
    Check that an attribute value belongs to a list of supported values.

    Args:
        attribute_value(str): The value of the attribute to check.
        values(List[str]): The supported values.
        name(str): The name of the attribute being checked.
        message(str): Extra text appended to the error message.

    Raises:
        ValueError: If the attribute value is not a supported value.
    """
    if attribute_value in values:
        return
    raise ValueError(f"\nSupported values for `{name}` attribute are {values}. "
                     f"Received {attribute_value}.{message}")


def transfer_pretrained_weights(target_model: tf.keras.Model, source_model_path: str = None,

                                end_layer_index: int = None, target_model_name: str = None) -> None:
    # NOTE : Unused in AED for now.
    # When it's ready to use, call it after loading model in get_model.
    """

    Copy the weights of a source model to a target model. Only the backbone weights

    are copied as the two models can have different classifiers.



    Args:

        target_model (tf.keras.Model): The target model.

        source_model_path (str): Path to the source model file (h5 or keras file).

        end_layer_index (int): Index of the last backbone layer (the first layer of the model has index 0).

        target_model_name (str): The name of the target model.



    Raises:

        ValueError: The source model file cannot be found.

        ValueError: The two models are incompatible because they have different backbones.

    """

    if source_model_path:
        if not os.path.isfile(source_model_path):
            raise ValueError("Unable to find pretrained model file.\nReceived "
                             f"model path {source_model_path}")
        source_model = tf.keras.models.load_model(source_model_path, compile=False)

    message = f"\nUnable to transfer to model `{target_model_name}`"
    message += f"the weights from model {source_model_path}\n"
    message += "Models are incompatible (backbones are different)."
    if len(source_model.layers) < end_layer_index + 1:
        raise ValueError(message)
    for i in range(end_layer_index + 1):
        weights = source_model.layers[i].get_weights()
        try:
            target_model.layers[i].set_weights(weights)
        except ValueError as error:
            raise message from error


def model_summary(model):
    """
    Display a model summary similar to a Keras summary, with additional
    information: layer indices, per-layer trainable status, total number
    of layers, and counts of trainable/non-trainable layers.
    """
    rows = []
    num_trainable = 0
    for idx, layer in enumerate(model.layers):
        kind = layer.__class__.__name__
        # InputLayer's own output attribute is not meaningful here; use
        # the model input shape instead.
        shape = model.input.shape if kind == "InputLayer" else layer.output.shape
        trainable = bool(layer.trainable)
        if trainable:
            num_trainable += 1
        rows.append([idx, trainable, layer.name, kind, layer.count_params(), shape])

    separator = 108 * '='
    rule = 108 * '-'
    print(separator)
    print("  Model:", model.name)
    print(separator)
    print(tabulate(rows, headers=["Layer index", "Trainable", "Name", "Type", "Params#", "Output shape"]))
    print(rule)
    print("Total params:", model.count_params())
    print("Trainable params: ", count_params(model.trainable_weights))
    print("Non-trainable params: ", count_params(model.non_trainable_weights))
    print(rule)
    print("Total layers:", len(model.layers))
    print("Trainable layers:", num_trainable)
    print("Non-trainable layers:", len(model.layers) - num_trainable)
    print(separator)


def count_h5_parameters(output_dir: str = None, model: tf.keras.Model = None):
    """Log (to MLflow and the run log file) and print the parameter count
    of a float Keras model.

    Args:
        output_dir (str): Directory used by log_to_file for the log file.
        model (tf.keras.Model): The model whose parameters are counted.
    """
    nb_params = model.count_params()
    mlflow.log_metric(f"nb_params", nb_params)
    log_to_file(output_dir, f"Nb params of float model : {nb_params}")
    print(f"[INFO] : Nb params of float model : {nb_params}")


def count_tflite_parameters(output_dir: str = None,
                            model_path: str = None,
                            num_threads: Optional[int] = 1):
    """Log (to MLflow and the run log file) and print the total tensor
    element count of a .tflite model.

    NOTE(review): this sums the element counts of *all* tensors reported
    by the interpreter, not only weight tensors — confirm that this
    over-approximation of the parameter count is intended.

    Args:
        output_dir (str): Directory used by log_to_file for the log file.
        model_path (str): Path to the .tflite model file.
        num_threads (Optional[int]): Number of interpreter threads.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path, num_threads=num_threads)
    total_params = 0
    for tensor in interpreter.get_tensor_details():
        shape = tensor['shape']
        if shape is not None:
            # Product of all dimensions (1 for a scalar tensor).
            count = 1
            for dim in shape:
                count *= dim
            total_params += count
    mlflow.log_metric(f"quantized_nb_params", total_params)
    log_to_file(output_dir, f"Nb params of quantized model : {total_params}")
    print(f"[INFO] : Nb params of quantized model : {total_params}")


def tf_dataset_to_np_array(input_ds, nchw=True, labels_included=True):
    """
    Convert a TensorFlow dataset into NumPy arrays of data and labels.

    Iterates over `input_ds`, casts the images to float32 and concatenates
    all batches along the batch (first) dimension. When `nchw` is True the
    image array is transposed from channels-last to channels-first order.

    Parameters:
    - input_ds (tf.data.Dataset): Dataset yielding (images, labels) batches
      when `labels_included` is True, or image-only batches otherwise.
    - nchw (bool): If True, transpose the images to channels-first order.
    - labels_included (bool): Whether the dataset yields labels alongside
      the images.

    Returns:
    - tuple: (images, labels). When `labels_included` is False, the labels
      entry is an empty list (kept for API compatibility).

    Raises:
    - ValueError: If the concatenated image array is not 3-D or 4-D when
      `nchw` is True.

    Note:
    - The dataset is expected to yield batches; the function fails if the
      data does not match the expected format.
    """
    image_chunks = []
    label_chunks = []
    for batch in input_ds:
        if labels_included:
            images, labels = batch
            label_chunks.append(labels)
        else:
            images = batch
        image_chunks.append(tf.cast(images, dtype=tf.float32).numpy())

    data = np.concatenate(image_chunks, axis=0)
    if labels_included:
        labels_out = np.concatenate(label_chunks, axis=0)
    else:
        labels_out = label_chunks  # empty list when no labels were collected

    if nchw and data is not None:
        if data.ndim == 4:
            # [n, h, w, c] -> [n, c, h, w]
            axes_order = (0, 3, 1, 2)
        elif data.ndim == 3:
            # [n, h, c] -> [n, c, h]
            axes_order = (0, 2, 1)
        else:
            raise ValueError("The input array must have either 3 or 4 dimensions.")
        data = np.transpose(data, axes_order)

    return data, labels_out


def torch_dataset_to_np_array(input_loader, nchw=True, labels_included=True, device="cpu"):
    """
    Convert a PyTorch DataLoader into NumPy arrays of data and labels.

    Iterates over `input_loader`, moves each batch to `device`, converts it
    to NumPy and concatenates along the batch dimension.

    NOTE(review): despite its name, `nchw=True` transposes 4-D image batches
    from NCHW (the usual PyTorch layout) to NHWC (channels-last) — confirm
    the parameter naming against the callers.

    Parameters:
    - input_loader (torch.utils.data.DataLoader): Loader yielding batches of
      (images, labels) or image-only tensors.
    - nchw (bool): If True and the images are 4-D, transpose NCHW -> NHWC.
    - labels_included (bool): Whether the loader yields labels.
    - device (str): Device to move tensors to before conversion (default 'cpu').

    Returns:
    - tuple: (images_np, labels_np)
        - images_np: all image batches concatenated along the batch dim.
        - labels_np: all labels concatenated, or None when
          labels_included is False.
    """
    image_chunks = []
    label_chunks = []

    # No gradients are needed for a pure data conversion.
    with torch.no_grad():
        for batch in input_loader:
            if labels_included:
                images, labels = batch
                images = images.to(device)
                label_chunks.append(labels.cpu().numpy())
            else:
                # Some loaders wrap image-only batches in a 1-tuple/list.
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                images = images.to(device)
            image_chunks.append(images.cpu().numpy())

    images_np = np.concatenate(image_chunks, axis=0)
    labels_np = np.concatenate(label_chunks, axis=0) if labels_included else None

    if nchw and images_np.ndim == 4:
        # NCHW -> NHWC
        images_np = np.transpose(images_np, (0, 2, 3, 1))

    return images_np, labels_np

def compute_confusion_matrix(test_set: tf.data.Dataset = None, model: Model = None) -> Tuple[np.ndarray, np.float32]:
    """
    Compute the confusion matrix and accuracy of a model on a test set.

    Args:
        test_set (tf.data.Dataset): Dataset yielding (images, labels) batches.
        model (tf.keras.models.Model): The trained model to evaluate.

    Returns:
        Tuple: (confusion matrix, accuracy in percent rounded to 2 decimals).
    """
    predicted = []
    ground_truth = []
    for batch in test_set:
        scores = model.predict_on_batch(batch[0])
        if scores.shape[1] > 1:
            # Multi-class: the predicted class is the arg-max of the scores.
            predicted.append(np.argmax(scores, axis=1))
        else:
            # Binary: threshold the single score at 0.5.
            predicted.append(np.squeeze(np.where(scores < 0.5, 0, 1)))
        # Ground truth may be one-hot encoded or plain integer labels.
        if len(batch[1].shape) > 1:
            ground_truth.append(np.argmax(batch[1], axis=1))
        else:
            ground_truth.append(batch[1])

    labels = np.concatenate(ground_truth, axis=0)
    logits = np.concatenate(predicted, axis=0)
    test_accuracy = round((np.sum(labels == logits) * 100) / len(labels), 2)

    # Calculate the confusion matrix.
    cm = sklearn.metrics.confusion_matrix(labels, logits)
    return cm, test_accuracy