| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import ssl |
|
|
# NOTE(review): this module-level statement swaps the stdlib's default HTTPS
# context factory for ssl._create_unverified_context, which performs no
# certificate or hostname verification (see PEP 476). It takes effect for the
# whole process as soon as this module is imported, so every stdlib HTTPS
# client that relies on the default context (e.g. urllib) will then accept
# untrusted certificates. Presumably added so downloads work behind a
# TLS-intercepting proxy -- confirm it is still needed and, if so, scope it to
# the specific download instead of the whole process.
ssl._create_default_https_context = ssl._create_unverified_context
|
|
| import os |
| import numpy as np |
| import tensorflow as tf |
| from tensorflow.keras import backend |
| from typing import Tuple, Dict, Optional |
| from hydra.core.hydra_config import HydraConfig |
| from omegaconf import OmegaConf, DictConfig, open_dict |
| from common.utils import aspect_ratio_dict, color_mode_n6_dict |
| import onnxruntime |
|
|
def _h7_classes_table(class_names) -> str:
    """Build the brace-enclosed initializer spliced into the ``CLASSES_TABLE`` macro.

    The text opens with ``{`` followed by a line continuation and closes with ``}``
    followed by a line continuation, so it can sit inside a multi-line ``#define``.
    An extra continuation break is inserted after the entries at indices 5, 10, 15, ...
    Backslashes and double quotes in class names are escaped so every entry stays a
    valid C string literal.

    Raises:
        ValueError: if ``class_names`` is empty, since the macro would otherwise be
            left with an unclosed brace.
    """
    if not class_names:
        raise ValueError("dataset.class_names is empty: at least one class name is required "
                         "to generate CLASSES_TABLE.")
    last_index = len(class_names) - 1
    table = '{\\\n'
    for i, name in enumerate(class_names):
        literal = '"' + str(name).replace('\\', '\\\\').replace('"', '\\"') + '"'
        if i == last_index:
            table += ' ' + literal + '}\\'
        else:
            table += ' ' + literal + ' ,' + ('\\\n' if (i % 5 == 0 and i != 0) else '')
    return table


def _h7_io_format(dtype) -> str:
    """Map a model input/output dtype to the matching ``*_FORMAT`` macro name.

    Raises:
        ValueError: for any dtype other than uint8, int8 or float32.
    """
    formats = {
        np.dtype(np.uint8): "UINT8_FORMAT",
        np.dtype(np.int8): "INT8_FORMAT",
        np.dtype(np.float32): "FLOAT32_FORMAT",
    }
    try:
        # np.dtype() normalises both scalar type classes (np.uint8) and dtype objects.
        return formats[np.dtype(dtype)]
    except (KeyError, TypeError) as err:
        raise ValueError(f"Unsupported model I/O dtype {dtype!r}: expected uint8, int8 or float32.") from err


def _h7_project_link(name: str, location: str) -> str:
    """Return one ``<link>`` entry of the CubeIDE ``.project`` file in the exact layout this generator uses."""
    return (
        "\t\t<link>\n"
        f"\t\t\t<name>{name}</name>\n"
        "\t\t\t<type>1</type>\n"
        f"\t\t\t<locationURI>{location}</locationURI>\n"
        "\t\t</link>\n"
    )


def _h7_update_project_file(c_project_path: str, display_output, camera_input) -> None:
    """Link the USB display or USB camera sources into the NUCLEO-H743ZI2 CubeIDE project.

    Every link managed here is first stripped from
    ``<c_project_path>/Application/NUCLEO-H743ZI2/STM32CubeIDE/.project``. Then either the
    four USB display sources are re-added (when ``display_output`` is
    ``DISPLAY_INTERFACE_USB``), or the USB camera driver is re-added (when ``camera_input``
    is ``CAMERA_INTERFACE_USB``). New links are inserted just before the closing
    ``</linkedResources>`` tag.
    """
    project_file = os.path.join(c_project_path, 'Application/NUCLEO-H743ZI2/STM32CubeIDE/.project')
    display_links = [
        _h7_project_link(f"Middlewares/STM32_USB_Display/{src}",
                         f"PARENT-3-PROJECT_LOC/Middlewares/ST/STM32_USB_Display/Src/{src}")
        for src in ("usb_disp.c", "usb_disp_desc.c", "usb_disp_format.c", "usbd_conf.c")
    ]
    camera_link = _h7_project_link(
        "Drivers/BSP/NUCLEO_H743ZI2/nucleo_h743zi2_camera_usb.c",
        "PARENT-3-PROJECT_LOC/Drivers/BSP/NUCLEO-H743ZI2/nucleo_h743zi2_camera_usb.c",
    )
    closing_tags = "\t</linkedResources>\n</projectDescription>"

    # The .project file is XML; read/write it as UTF-8 rather than the locale default.
    with open(project_file, 'r', encoding='utf-8') as fh:
        data = fh.read()

    # Strip links from a previous run so switching I/O modes never accumulates entries.
    for link in (*display_links, camera_link):
        data = data.replace(link, '')

    if display_output == "DISPLAY_INTERFACE_USB":
        data = data.replace(closing_tags, ''.join(display_links) + closing_tags)
    elif camera_input == "CAMERA_INTERFACE_USB":
        data = data.replace(closing_tags, camera_link + closing_tags)

    with open(project_file, 'w', encoding='utf-8') as fh:
        fh.write(data)


def gen_h_user_file_h7(config: DictConfig = None, quantized_model_path: str = None, board: str = None) -> None:
    """
    Generates a C header file containing user configuration for the AI model.

    Writes ``ai_model_config.h`` into ``<hydra output dir>/C_header/``. When the board is
    NUCLEO-H743ZI2 it also rewrites
    ``<deployment.c_project_path>/Application/NUCLEO-H743ZI2/STM32CubeIDE/.project`` so that
    the USB display or USB camera sources are linked according to
    ``deployment.hardware_setup``.

    Args:
        config: A configuration object containing user configuration for the AI model.
        quantized_model_path: The path to the quantized (.tflite) model file.
        board: the name of the board; only the part before the first comma is compared.

    Raises:
        ValueError: if the class list is empty, if the color mode or model I/O dtype is
            unsupported, or if USB display and USB camera are both selected on
            NUCLEO-H743ZI2. Validation happens before any file is written.
    """

    class Flags:
        def __init__(self, **entries):
            self.__dict__.update(entries)

    params = Flags(**config)
    class_names = params.dataset.class_names
    is_nucleo_h743zi2 = str(board).split(",")[0] == "NUCLEO-H743ZI2"
    hardware_setup = params.deployment.hardware_setup if is_nucleo_h743zi2 else None

    # Reject the impossible USB configuration before producing any output.
    if is_nucleo_h743zi2 and hardware_setup.output == "DISPLAY_INTERFACE_USB" \
            and hardware_setup.input == "CAMERA_INTERFACE_USB":
        raise ValueError("\033[31mThere is only one USB port on the Nucleo-H743ZI2 board. You can't select CAMERA_INTERFACE_USB as input and DISPLAY_INTERFACE_USB as output at the same time. \033[39m")

    # Resolve every configurable value up front, so a bad setting fails before the header
    # file is opened instead of leaving a truncated header behind.
    classes = _h7_classes_table(class_names)
    aspect_ratio_mode = aspect_ratio_dict[params.preprocessing.resizing.aspect_ratio]
    color_formats = {"rgb": "RGB_FORMAT", "bgr": "BGR_FORMAT", "grayscale": "GRAYSCALE_FORMAT"}
    color_mode = params.preprocessing.color_mode
    if color_mode not in color_formats:
        raise ValueError(f"Unsupported preprocessing.color_mode {color_mode!r}: "
                         f"expected one of {list(color_formats)}.")

    interpreter_quant = tf.lite.Interpreter(model_path=quantized_model_path)
    input_details = interpreter_quant.get_input_details()[0]
    output_details = interpreter_quant.get_output_details()[0]
    # Indices 1..3 of the input shape are emitted as height, width and channels below.
    input_shape = input_details['shape']
    input_format = _h7_io_format(input_details['dtype'])
    output_format = _h7_io_format(output_details['dtype'])

    # An existing directory is the normal case on re-runs; any other failure propagates.
    path = os.path.join(HydraConfig.get().runtime.output_dir, "C_header/")
    os.makedirs(path, exist_ok=True)

    with open(os.path.join(path, "ai_model_config.h"), "wt", encoding="utf-8") as f:
        f.write("/**\n")
        f.write(" ******************************************************************************\n")
        f.write(" * @file ai_model_config.h\n")
        f.write(" * @author Artificial Intelligence Solutions group (AIS)\n")
        f.write(" * @brief User header file for Preprocessing configuration\n")
        f.write(" ******************************************************************************\n")
        f.write(" * @attention\n")
        f.write(" *\n")
        f.write(" * Copyright (c) 2024 STMicroelectronics.\n")
        f.write(" * All rights reserved.\n")
        f.write(" *\n")
        f.write(" * This software is licensed under terms that can be found in the LICENSE file in\n")
        f.write(" * the root directory of this software component.\n")
        f.write(" * If no LICENSE file comes with this software, it is provided AS-IS.\n")
        f.write(" *\n")
        f.write(" ******************************************************************************\n")
        f.write(" */\n\n")
        f.write("/* --------------- Generated code ----------------- */\n")
        f.write("#ifndef __AI_MODEL_CONFIG_H__\n")
        f.write("#define __AI_MODEL_CONFIG_H__\n\n\n")
        f.write("/* I/O configuration */\n")
        f.write("#define NB_CLASSES ({})\n".format(len(class_names)))
        f.write("#define INPUT_HEIGHT ({})\n".format(int(input_shape[1])))
        f.write("#define INPUT_WIDTH ({})\n".format(int(input_shape[2])))
        f.write("#define INPUT_CHANNELS ({})\n".format(int(input_shape[3])))
        f.write("\n")
        f.write("/* Classes */\n")
        f.write("#define CLASSES_TABLE const char* classes_table[NB_CLASSES] = {}\n".format(classes))
        f.write("\n\n")
        f.write("/***** Preprocessing configuration *****/\n\n")
        f.write("/* Aspect Ratio configuration */\n")
        f.write("#define ASPECT_RATIO_FIT (1)\n")
        f.write("#define ASPECT_RATIO_CROP (2)\n")
        f.write("#define ASPECT_RATIO_PADDING (3)\n\n")
        f.write("#define ASPECT_RATIO_MODE {}\n".format(aspect_ratio_mode))
        f.write("\n")
        f.write("/* Input color format configuration */\n")
        f.write("#define RGB_FORMAT (1)\n")
        f.write("#define BGR_FORMAT (2)\n")
        f.write("#define GRAYSCALE_FORMAT (3)\n")
        f.write("\n")
        f.write("#define PP_COLOR_MODE {}\n".format(color_formats[color_mode]))
        f.write("\n")
        f.write("/* Input/Output quantization configuration */\n")
        f.write("#define UINT8_FORMAT (1)\n")
        f.write("#define INT8_FORMAT (2)\n")
        f.write("#define FLOAT32_FORMAT (3)\n")
        f.write("\n")
        f.write("#define QUANT_INPUT_TYPE {}\n".format(input_format))
        f.write("#define QUANT_OUTPUT_TYPE {}\n".format(output_format))
        f.write("\n")
        if is_nucleo_h743zi2:
            f.write("/* Display configuration */\n")
            f.write("#define DISPLAY_INTERFACE_USB (1)\n")
            f.write("#define DISPLAY_INTERFACE_SPI (2)\n")
            f.write("\n")
            f.write("#define DISPLAY_INTERFACE {}\n".format(hardware_setup.output))
            f.write("\n")
            f.write("/* Camera configuration */\n")
            f.write("#define CAMERA_INTERFACE_DCMI (1)\n")
            f.write("#define CAMERA_INTERFACE_USB (2)\n")
            f.write("#define CAMERA_INTERFACE_SPI (3)\n")
            f.write("\n")
            f.write("#define CAMERA_INTERFACE {}\n".format(str(hardware_setup.input)))
            f.write("\n")
            f.write("/* Camera Sensor configuration */\n")
            f.write("#define CAMERA_SENSOR_OV5640 (1)\n")
            f.write("\n")
            f.write("#define CAMERA_SENSOR CAMERA_SENSOR_OV5640\n")
            f.write("\n")

        f.write("#endif /* __AI_MODEL_CONFIG_H__ */\n")

    if is_nucleo_h743zi2:
        _h7_update_project_file(params.deployment.c_project_path,
                                hardware_setup.output, hardware_setup.input)
|
|
|
def gen_h_user_file_n6(config: DictConfig = None, quantized_model_path: str = None) -> Tuple[bool, str]:
    """
    Generates a C header file containing user configuration for the AI model.

    Writes ``app_config.h`` into ``<hydra output dir>/C_header/``.

    Args:
        config: A configuration object containing user configuration for the AI model.
        quantized_model_path: The path to the quantized model file. It is only passed
            through to the return value; nothing in the generated header depends on
            the model itself, so it is not loaded here.

    Returns:
        A ``(TFLite_Detection_PostProcess_id, quantized_model_path)`` tuple. The flag is
        always ``False`` in this function. It is presumably kept for interface parity
        with other header generators; confirm against callers.

    Raises:
        ValueError: if ``dataset.class_names`` is empty.
    """
    class Flags:
        def __init__(self, **entries):
            self.__dict__.update(entries)

    params = Flags(**config)
    TFLite_Detection_PostProcess_id = False

    class_names = params.dataset.class_names
    if not class_names:
        raise ValueError("dataset.class_names is empty: at least one class name is required "
                         "to generate CLASSES_TABLE.")

    # Resolve lookups before opening the output file, so an unknown setting does not
    # leave a partially written header behind.
    aspect_ratio_mode = aspect_ratio_dict[params.preprocessing.resizing.aspect_ratio]
    color_mode = color_mode_n6_dict[params.preprocessing.color_mode]

    # Brace-enclosed initializer for the multi-line CLASSES_TABLE macro. Each line ends
    # with a continuation backslash, with an extra break after indices 5, 10, 15, ...
    last_index = len(class_names) - 1
    classes = '{\\\n'
    for i, name in enumerate(class_names):
        # Escape backslashes and double quotes so each name stays a valid C string literal.
        literal = '"' + str(name).replace('\\', '\\\\').replace('"', '\\"') + '"'
        if i == last_index:
            classes += ' ' + literal + '}\\'
        else:
            classes += ' ' + literal + ' ,' + ('\\\n' if (i % 5 == 0 and i != 0) else '')

    # An existing directory is the normal case on re-runs; any other failure propagates.
    path = os.path.join(HydraConfig.get().runtime.output_dir, "C_header/")
    os.makedirs(path, exist_ok=True)

    with open(os.path.join(path, "app_config.h"), "wt", encoding="utf-8") as f:
        f.write("/**\n")
        f.write("******************************************************************************\n")
        f.write("* @file app_config.h\n")
        f.write("* @author GPM Application Team\n")
        f.write("*\n")
        f.write("******************************************************************************\n")
        f.write("* @attention\n")
        f.write("*\n")
        f.write("* Copyright (c) 2023 STMicroelectronics.\n")
        f.write("* All rights reserved.\n")
        f.write("*\n")
        f.write("* This software is licensed under terms that can be found in the LICENSE file\n")
        f.write("* in the root directory of this software component.\n")
        f.write("* If no LICENSE file comes with this software, it is provided AS-IS.\n")
        f.write("*\n")
        f.write("******************************************************************************\n")
        f.write("*/\n\n")
        f.write("/* --------------- Generated code ----------------- */\n")
        f.write("#ifndef APP_CONFIG\n")
        f.write("#define APP_CONFIG\n\n")
        f.write("#define USE_DCACHE\n\n")
        f.write('#include "arm_math.h"\n\n')
        f.write("/*Defines: CMW_MIRRORFLIP_NONE; CMW_MIRRORFLIP_FLIP; CMW_MIRRORFLIP_MIRROR; CMW_MIRRORFLIP_FLIP_MIRROR;*/\n")
        f.write("#define CAMERA_FLIP CMW_MIRRORFLIP_NONE\n\n")
        f.write("#define ASPECT_RATIO_CROP (1) /* Crop both pipes to nn input aspect ratio; Original aspect ratio kept */\n")
        f.write("#define ASPECT_RATIO_FIT (2) /* Resize both pipe to NN input aspect ratio; Original aspect ratio not kept */\n")
        f.write("#define ASPECT_RATIO_FULLSCREEN (3) /* Resize camera image to NN input size and display a fullscreen image */\n")
        f.write("#define ASPECT_RATIO_MODE {}\n".format(aspect_ratio_mode))
        f.write("\n\n")
        f.write("#define COLOR_BGR (0)\n")
        f.write("#define COLOR_RGB (1)\n")
        f.write("#define COLOR_MODE {}\n".format(color_mode))

        f.write("/* Classes */\n")
        f.write("#define NB_CLASSES ({})\n".format(len(class_names)))
        f.write("#define CLASSES_TABLE const char* classes_table[NB_CLASSES] = {}\n\n".format(classes))

        f.write('#define WELCOME_MSG_1 "{}"\n'.format(os.path.basename(params.model.model_path)))
        if params.deployment.hardware_setup.board == 'NUCLEO-N657X0-Q':
            # Terminate the line like the other branch so the macro never runs into what follows.
            f.write('#define WELCOME_MSG_2 ((char *[2]) {"Model Running in STM32 MCU", "internal memory"})\n')
        else:
            f.write('#define WELCOME_MSG_2 "{}"\n'.format("Model Running in STM32 MCU internal memory"))
        f.write("\n")
        f.write("#endif /* APP_CONFIG */\n")

    return TFLite_Detection_PostProcess_id, quantized_model_path