File size: 5,942 Bytes
ae29340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import time
import sys
from data_gen import DataGenerator

from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm

from collections import defaultdict

import thundernet_config as Thundernet_config
import numpy as np
import argparse
from glob import glob
from utils import resolution2framesize3cha, simple_iou_for_multiple_classes, image_test
import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from images_toolkit import show_two_images, overlap_image_with_label, show_x_images

# Example command: python inference_config.py --model_path C:/Users/user/Documents/Thundernet/pruebas_modelos/32_ppm/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5 --classes 2

# Optional path to a file with saved baseline inference durations, read with
# np.load() at the end of main(). When set (truthy), main() additionally
# prints absolute and relative timing differences against it. It is not
# exposed as a command-line flag, so it must be edited here.
# NOTE(review): presumably an .npy array in milliseconds (main() measures in
# ms), one entry per validation image — confirm against how it is produced.
baseline_duration = None

# Command-line interface; every default comes from thundernet_config.
parser = argparse.ArgumentParser()

parser.add_argument(
    "--model_path",
    type=str,
    default=Thundernet_config.model_weights,
    help="Path to the hdf5 weights file; they are usually stored in /home/user/nas/deep_experiments/",
)

parser.add_argument(
    "--classes",
    type=int,
    default=Thundernet_config.classes,
    help="Number of classes (overridden when main() is given class_mappings).",
)

parser.add_argument(
    "--resolution",
    type=str,
    default=Thundernet_config.resolution,
    help="Input Resolution",
)


def _select_architecture(model: str):
    """
    Map a model name to its Thundernet class.
    Args:
     - model (str): "original" or "ppm"
    Returns:
     - The Thundernet class imported from model.model ("original") or
       model.model_ppm_factors ("ppm").
    Raises:
     - ValueError: if model is not one of the known names.
    """
    if model == "original":
        return Thundernet_original
    if model == "ppm":
        return Thundernet_ppm
    raise ValueError(f"Unknown model: {model}")


def _build_model(architecture, model: str, input_shape, classes: int):
    """
    Instantiate the requested Thundernet variant, falling back to the other one.
    If constructing `architecture` raises ValueError, the other variant is built
    instead ("ppm" falls back to the original model, "original" to the ppm one).
    Args:
     - architecture: Thundernet class returned by _select_architecture
     - model (str): name that selected `architecture`; decides the fallback
     - input_shape: value returned by resolution2framesize3cha
     - classes (int): number of output classes
    Returns:
     - The Thundernet wrapper instance (its Keras model is `.model`).
    """
    kwargs = dict(input_shape=input_shape, resnet_trainable=False, n_classes=classes)
    try:
        return architecture(**kwargs)
    except ValueError:
        # NOTE(review): only construction is guarded. If the intent is to
        # recover from weights that belong to the other architecture, that
        # failure would come from load_weights(), which is outside this try.
        fallback = Thundernet_original if model == "ppm" else Thundernet_ppm
        return fallback(**kwargs)


def _evaluate(keras_model, generator, classes: int, show: bool):
    """
    Predict every item of `generator`, collecting per-image IoU and latency.
    Args:
     - keras_model: model whose predict() is timed
     - generator (DataGenerator): indexable source of (X, y) batches
     - classes (int): number of classes passed to the IoU metric
     - show (bool): display ground truth and prediction for each image
    Returns:
     - (ious, durations_ms): `ious` stacks one simple_iou_for_multiple_classes
       result per image (np.array); `durations_ms` holds each predict() call's
       wall time in milliseconds (1-D np.array).
    """
    iou_rows: list = []
    durations_ms: list = []

    for i in tqdm.tqdm(range(len(generator))):
        X, y = generator[i]

        start_t = time.perf_counter()
        pred = keras_model.predict(X)  # e.g. [1, 480, 640, n_classes]
        durations_ms.append(1000 * (time.perf_counter() - start_t))

        # Per-pixel class ids for the first (batch size is 1) image.
        prediction = np.argmax(pred[0], axis=-1)  # e.g. [480, 640]
        ground_truth = y[0].argmax(axis=-1)

        if show:
            # Ground truth is scaled by 255 for display, as the original code did.
            label_RGB = overlap_image_with_label(X, ground_truth * 255)
            prediction_RGB = overlap_image_with_label(X, prediction)
            show_x_images(
                images=[label_RGB, prediction_RGB],
                titles=["Real", "Prediction"],
                horizontal=True,
            )

        iou_rows.append(
            simple_iou_for_multiple_classes(ground_truth, prediction, classes)
        )

    # Built once here instead of on every iteration (avoids O(n^2) copying).
    return np.array(iou_rows), np.array(durations_ms)


def _report_iou(ious, classes: int) -> None:
    """
    Print the NaN-ignoring mean IoU for each class column of `ious`.
    Columns 0..classes-1 are printed; when classes > 3, column `classes` is
    printed as well.
    """
    # NOTE(review): the extra column for classes > 3 mirrors the original
    # loop; presumably the IoU helper returns one more entry in that case.
    n_columns = classes if classes <= 3 else classes + 1
    for i in range(n_columns):
        values = ious[:, i]
        values = values[~np.isnan(values)]
        print("IoU for class=", i, "is ", np.mean(values))


def _report_durations(durations) -> None:
    """
    Print mean/std inference time, and the difference to a baseline if set.
    The baseline is read from the module-level `baseline_duration` path.
    """
    print("")
    print("INFERENCE TIME")
    print(f" - Mean: {np.mean(durations)}")
    print(f" - Std: {np.std(durations)}")

    if not baseline_duration:
        return

    # Read inside `with` so the file handle is closed (it previously leaked).
    with Path(baseline_duration).open("rb") as baseline_file:
        durations_baseline = np.load(baseline_file)

    diff_durations = durations - durations_baseline
    print("INFERENCE TIME WITH RESPECT TO BASELINE (ABSOLUTE)")
    print(f" - Mean: {np.mean(diff_durations)}")
    print(f" - Std: {np.std(diff_durations)}")
    increase_durations = diff_durations / durations_baseline
    print("INFERENCE TIME WITH RESPECT TO BASELINE (RELATIVE)")
    print(f" - Mean: {np.mean(increase_durations)}")
    print(f" - Std: {np.std(increase_durations)}")


def main(
    args: list,
    model: str = "original",
    class_mappings: dict = None,
    transformations: tuple = tuple(),
    show: bool = False,
) -> None:
    """
    Perform inference in a set of images. If show=True, each prediction
    will be shown in the screen.
    Per-class IoU and inference-time statistics are printed at the end.
    Args:
     - args (list): list of command-line arguments, parsed with `parser`
     - model (str): type of model, "original" or "ppm". Default: "original"
     - class_mappings (dict): class mapper; when given, it overrides
       --classes. Default: None
     - transformations (tuple): list of transformations to execute in the data. Default: tuple()
     - show (bool): display the predictions. Default: False
    Returns:
    - None
    Raises:
     - ValueError: if `model` is not a known model name.
    """
    FLAGS: argparse.Namespace = parser.parse_args(args)
    architecture = _select_architecture(model)

    if class_mappings is not None:
        # Distinct target ids plus one; presumably the extra slot is the
        # class that unmapped ids fall into — confirm with DataGenerator.
        FLAGS.classes = len(set(class_mappings.values())) + 1

    input_shape = resolution2framesize3cha(FLAGS.resolution)
    classes = FLAGS.classes

    thundernet = _build_model(architecture, model, input_shape, classes)
    keras_model = thundernet.model
    keras_model.load_weights(FLAGS.model_path)

    # The dataset root is the parent of the configured training path.
    dataset_dir: Path = Path(Thundernet_config.train_path).parent
    validation_generator: DataGenerator
    _, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=1,
        validation_batch_size=1,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    ious, durations = _evaluate(keras_model, validation_generator, classes, show)

    # Without any sample there is nothing to average (and ious[:, i] would fail).
    if len(durations) == 0:
        print("Validation generator is empty; nothing to evaluate.")
        return

    _report_iou(ious, classes)
    _report_durations(durations)


if __name__ == "__main__":
    # Explicitly map class id 1 to 1; looking up any other key in this
    # defaultdict yields int() == 0.
    cli_mappings = defaultdict(int, {1: 1})
    main(sys.argv[1:], model="ppm", class_mappings=cli_mappings)
    # To evaluate the original (non-PPM) architecture instead:
    # main(sys.argv[1:], model="original", class_mappings=defaultdict(int, {1: 1}), show=False)