File size: 9,197 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2022 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

# Import necessary libraries
import os
import sys
from pathlib import Path
import warnings
import sklearn
import mlflow
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from typing import Tuple, Optional, List, Dict
import numpy as np

# Suppress warnings and TensorFlow logs
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import onnxruntime
import tensorflow as tf
import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix

# Import utility functions
from image_classification.tf.src.utils import ai_runner_invoke
from common.evaluation import model_is_quantized, predict_onnx_batch
from common.utils import (
    tf_dataset_to_np_array, display_figures, ai_runner_interp, ai_interp_input_quant,
    ai_interp_outputs_dequant, plot_confusion_matrix, torch_dataset_to_np_array
)  # Common utilities for evaluation and visualization


# Define a class for evaluating ONNX models
class ONNXModelEvaluator:
    """

    A class to evaluate ONNX models.



    Args:

        cfg (DictConfig): Configuration object for evaluation.

        model (object): The ONNX model to evaluate.

        dataloaders (dict): Dictionary containing datasets for testing and validation.

    """
    def __init__(self, cfg: DictConfig, model: object, 

                 dataloaders: dict = None):
        self.cfg = cfg
        self.input_model = model
        self.test_ds = dataloaders['test']
        self.valid_ds = dataloaders['valid']
        self.output_dir = HydraConfig.get().runtime.output_dir
        self.class_names = cfg.dataset.class_names
        self.display_figures = cfg.general.display_figures
        input_chpos = getattr(cfg.evaluation, 'input_chpos', 'chlast') if hasattr(cfg, 'evaluation') else 'chlast'
        if self.cfg.model.framework == "tf":
            # Dataloader is channel last with TF
            if input_chpos=="chfirst" or self._get_target() == 'host':
                self.nchw = True
            else:
                self.nchw = False
        else:
            # Dataloader is already channel first with Torch
            if input_chpos=="chfirst":
                self.nchw = False
            else:
                self.nchw = True
        
        self.eval_ds = None
        self.name_ds = None

    def _prepare_evaluation(self):
        """

        Prepares the evaluation process by selecting the appropriate dataset.

        """
        # Use the test dataset if available; otherwise, use the validation dataset
        if self.test_ds:
            self.eval_ds = self.test_ds
            self.name_ds = "test_set"
        else:
            self.eval_ds = self.valid_ds
            self.name_ds = "validation_set"

    def _ensure_output_dir(self):
        """

        Ensures that the output directory exists.

        """
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)

    def _get_target(self):
        """

        Retrieves the evaluation target from the configuration.

        Returns:

            str: the target on which evaluation will be done, by default 'host'

        """
        if self.cfg.evaluation and self.cfg.evaluation.target:
            return self.cfg.evaluation.target
        return "host"

    def _get_model_type(self):
        """

        Determines whether the model is quantized or float.

        Returns:

            str: 'quantized' or 'float' depending on model quantization status

        """
        return 'quantized' if model_is_quantized(self.input_model.model_path) else 'float'

    def _get_ai_runner_interpreter(self, target):
        """

        Retrieves the AI runner interpreter for the specified target.

        Args:

            target (str): target on which evaluation is to be performed

        Returns: ai_runner interpreter correctly parametrized

        """
        name_model = os.path.basename(self.input_model.model_path)
        return ai_runner_interp(target, name_model)

    def _predict(self, ai_runner_interpreter, model_type, target):
        """

        Runs predictions on the evaluation dataset.



        Args:

            ai_runner_interpreter: AI runner interpreter for inference.

            model_type (str): Type of the model ('float' or 'quantized').

            target (str): Evaluation target.



        Returns:

            Tuple[np.ndarray, np.ndarray]: Predicted labels and input labels.

        """
        if model_type == 'float' or target == 'host':
            # Use ONNX runtime for predictions
            prd_labels, input_labels = predict_onnx_batch(sess=self.input_model, 
                                                          data=self.eval_ds,
                                                          nchw=self.nchw)
            prd_labels = prd_labels.argmax(axis=1)
        elif target in ['stedgeai_host', 'stedgeai_n6', 'stedgeai_h7p'] and model_type == 'quantized':
            # Use AI runner for predictions
            prd_labels = []
            if self.cfg.model.framework == "tf":
                # Convert the tf dataset to NumPy array as dataloader was based on TF framework
                input_samples, input_labels = tf_dataset_to_np_array(input_ds=self.eval_ds, 
                                                                     nchw=self.nchw)
            else: #if self.cfg.model.framework == "torch":
                input_samples, input_labels = torch_dataset_to_np_array(input_loader=self.eval_ds,
                                                                        nchw=self.nchw)

            for i in tqdm.tqdm(range(input_samples.shape[0])):
                data = ai_interp_input_quant(ai_runner_interpreter, input_samples[i][None], '.onnx')
                prd_label = ai_runner_invoke(data, ai_runner_interpreter)
                prd_label = ai_interp_outputs_dequant(ai_runner_interpreter, [prd_label])[0]
                prd_label = prd_label.argmax(axis=1)
                prd_labels.append(prd_label)
            prd_labels = np.array(prd_labels, dtype=np.float32)
        else:
            raise TypeError("Only supported targets are \"host\", \"stedgeai_host\" or \"stedgeai_n6\". "
                            "Check the \"evaluation\" section of your configuration file.")
        return prd_labels, input_labels

    def _run_evaluate(self):
        """

        Runs the evaluation process and computes metrics.



        Returns:

            Tuple[float, np.ndarray]: Accuracy and confusion matrix.

        """
        self._ensure_output_dir()       # Ensure the output directory exists
        target = self._get_target()     # Get the evaluation target
        model_type = self._get_model_type() # Determine the model type
        ai_runner_interpreter = self._get_ai_runner_interpreter(target=target)   # Get the AI runner interpreter

        # Run predictions
        prd_labels, input_labels = self._predict(ai_runner_interpreter, model_type, target)
        accuracy = round(accuracy_score(input_labels, prd_labels) * 100, 2)
        print(f'[INFO] : Evaluation accuracy on {self.name_ds}: {accuracy} %')

        # Log evaluation results to a file
        log_file_name = f"{self.output_dir}/stm32ai_main.log"
        with open(log_file_name, 'a', encoding='utf-8') as f:
            f.write(f'{model_type} onnx model\nEvaluation accuracy: {accuracy} %\n')

        # Compute the confusion matrix
        cm = confusion_matrix(input_labels, prd_labels)
        acc_metric_name = f"int_acc_{self.name_ds}" if model_type == 'quantized' else f"float_acc_{self.name_ds}"
        mlflow.log_metric(acc_metric_name, accuracy)

        # Plot and display the confusion matrix if enabled
        if self.display_figures:
            model_name = f'{model_type}_onnx_model_{self.name_ds}'
            plot_confusion_matrix(
                cm=cm,
                class_names=self.class_names,
                model_name=model_name,
                title=f'{model_name}\naccuracy: {accuracy}',
                output_dir=self.output_dir
            )
            display_figures(self.cfg)
        return accuracy, cm

    def evaluate(self):
        """

        Executes the full evaluation process.



        Returns:

            float: Accuracy of the model on the evaluation dataset.

        """
        self._prepare_evaluation()      # Prepare the evaluation process
        acc, cm = self._run_evaluate()  # Run the evaluation
        print('[INFO] : Evaluation complete.')
        return acc   # Return the accuracy