Thundernet / inference_config.py
ExtendedRealityLab's picture
Add files using upload-large-folder tool
ae29340 verified
import time
import sys
from data_gen import DataGenerator
from model.model import Thundernet as Thundernet_original
from model.model_ppm_factors import Thundernet as Thundernet_ppm
from collections import defaultdict
import thundernet_config as Thundernet_config
import numpy as np
import argparse
from glob import glob
from utils import resolution2framesize3cha, simple_iou_for_multiple_classes, image_test
import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from images_toolkit import show_two_images, overlap_image_with_label, show_x_images
# Example command: python inference_config.py --model_path C:/Users/user/Documents/Thundernet/pruebas_modelos/32_ppm/BS4_lossBCE_weights_lr_0.00013713842558297858_reg-1.1743577101671763e-05-ep-13-val_loss0.11463435739278793-train_loss0.053004469722509384-val_iou0.8959722518920898-train_iou0.9606077075004578.hdf5 --classes 2
# Path to a saved NumPy array of baseline inference times. While it is falsy
# (the default), main() skips the comparison against the baseline.
baseline_duration = None


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by main().

    Every option defaults to the corresponding value in thundernet_config,
    so the script can be run without any arguments.
    """
    # (flag, type, default, help) for each supported option.
    option_table = (
        (
            "--model_path",
            str,
            Thundernet_config.model_weights,
            "Base directory for the hdf5 model, they are usually stored is /home/user/nas/deep_experiments/",
        ),
        (
            "--classes",
            int,
            Thundernet_config.classes,
            "Number of classes. ",
        ),
        (
            "--resolution",
            str,
            Thundernet_config.resolution,
            "Input Resolution",
        ),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    return cli


parser = _build_parser()
def _report_class_iou(iou_per_image: np.ndarray, classes: int) -> None:
    """Print the mean IoU of each class, skipping images where it is NaN.

    Args:
    - iou_per_image (np.ndarray): 2-D array, one row per image and one
      column per class
    - classes (int): number of classes the model predicts
    """
    # Columns 0..classes-1 are reported when classes <= 3; for larger class
    # counts column `classes` is reported as well.
    # NOTE(review): presumably simple_iou_for_multiple_classes returns an
    # extra column in that case -- confirm against utils.
    n_columns = classes if classes <= 3 else classes + 1
    for class_idx in range(n_columns):
        values = iou_per_image[:, class_idx]
        values = values[~np.isnan(values)]
        print("IoU for class=", class_idx, "is ", np.mean(values))


def _report_inference_time(durations: np.ndarray) -> None:
    """Print inference-time statistics, optionally relative to a baseline.

    Args:
    - durations (np.ndarray): per-image inference times in milliseconds
    """
    print("")
    print("INFERENCE TIME")
    print(f" - Mean: {np.mean(durations)}")
    print(f" - Std: {np.std(durations)}")
    if not baseline_duration:
        return
    # Passing the path lets np.load open and close the file itself, instead
    # of leaving an explicitly opened handle dangling.
    durations_baseline = np.load(Path(baseline_duration))
    diff_durations = durations - durations_baseline
    print("INFERENCE TIME WITH RESPECT TO BASELINE (ABSOLUTE)")
    print(f" - Mean: {np.mean(diff_durations)}")
    print(f" - Std: {np.std(diff_durations)}")
    increase_durations = diff_durations / durations_baseline
    print("INFERENCE TIME WITH RESPECT TO BASELINE (RELATIVE)")
    print(f" - Mean: {np.mean(increase_durations)}")
    print(f" - Std: {np.std(increase_durations)}")


def main(
    args: list,
    model: str = "original",
    class_mappings: dict = None,
    transformations: tuple = tuple(),
    show: bool = False,
) -> None:
    """
    Run inference over the validation set and print IoU and timing metrics.
    If show=True, each prediction is displayed next to its ground truth.

    Args:
    - args (list): command-line style arguments, parsed with the module parser
    - model (str): architecture to build, "original" or "ppm". Default: "original"
    - class_mappings (dict): class mapper; when given, the number of classes
      is derived from it (distinct mapped values + 1). Default: None
    - transformations (tuple): transformations forwarded to the data generator.
      Default: tuple()
    - show (bool): display the predictions. Default: False
    Returns:
    - None
    Raises:
    - ValueError: if `model` is not a known architecture or the validation
      generator yields no samples
    """
    FLAGS: argparse.Namespace = parser.parse_args(args)

    # Select the requested architecture and keep the other one as a fallback.
    if model == "original":
        primary_cls, fallback_cls = Thundernet_original, Thundernet_ppm
    elif model == "ppm":
        primary_cls, fallback_cls = Thundernet_ppm, Thundernet_original
    else:
        raise ValueError(f"Unknown model: {model}")

    # A class mapping overrides the --classes flag.
    if class_mappings is not None:
        FLAGS.classes = len(set(class_mappings.values())) + 1

    input_shape = resolution2framesize3cha(FLAGS.resolution)
    classes = FLAGS.classes

    # Build the network; if the requested architecture rejects the
    # configuration with a ValueError, retry with the other architecture.
    try:
        thundernet = primary_cls(
            input_shape=input_shape, resnet_trainable=False, n_classes=classes
        )
    except ValueError:
        thundernet = fallback_cls(
            input_shape=input_shape, resnet_trainable=False, n_classes=classes
        )
    # Kept under its own name so the `model` argument (a string) is not shadowed.
    network = thundernet.model
    network.load_weights(FLAGS.model_path)

    # Create the validation data loader (the training one is discarded).
    dataset_dir: Path = Path(Thundernet_config.train_path).parent
    validation_generator: DataGenerator
    _, validation_generator = DataGenerator.create_generators(
        dataset_dir,
        FLAGS.classes,
        training_batch_size=1,
        validation_batch_size=1,
        to_stereo=False,
        transformations=transformations,
        class_mappings=class_mappings,
    )

    n_samples = len(validation_generator)
    if n_samples == 0:
        raise ValueError("Validation generator is empty; nothing to evaluate.")

    iou_global: list = []
    durations: list = []
    for sample_idx in tqdm.tqdm(range(n_samples)):
        X, y = validation_generator[sample_idx]

        # Only the forward pass is timed; durations are stored in milliseconds.
        start_t = time.perf_counter()
        pred = network.predict(X)  # batch of one, e.g. [1, 480, 640, 2]
        durations.append(1000 * (time.perf_counter() - start_t))

        prediction = np.argmax(pred[0], axis=2)  # per-pixel class index
        target = y[0].argmax(axis=-1)

        if show:
            label_RGB = overlap_image_with_label(X, target * 255)
            prediction_RGB = overlap_image_with_label(X, prediction)
            show_x_images(
                images=[label_RGB, prediction_RGB],
                titles=["Real", "Prediction"],
                horizontal=True,
            )

        iou_global.append(
            simple_iou_for_multiple_classes(target, prediction, classes)
        )

    # Convert once, after the loop, instead of rebuilding the array per image.
    _report_class_iou(np.array(iou_global), classes)
    _report_inference_time(np.array(durations))
if __name__ == "__main__":
    # defaultdict(int) returns 0 for every key other than 1, so this mapping
    # has a single distinct stored value and main() derives 2 classes from it.
    # NOTE(review): presumably DataGenerator uses it to remap labels -- confirm.
    binary_mapping = defaultdict(int, {1: 1})
    cli_args = sys.argv[1:]
    main(cli_args, model="ppm", class_mappings=binary_mapping)
    # Alternative run with the original architecture:
    # main(sys.argv[1:], model="original", class_mappings=defaultdict(int, {1: 1}), show=False)