# File size: 16,525 Bytes (revision 747451d)
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import ssl
# NOTE(review): security concern. This module-level assignment replaces the
# stdlib's default HTTPS context factory with one that performs no certificate
# or hostname verification. Because the ssl module is shared, it affects the
# whole process: every stdlib HTTPS client relying on the default context
# (urllib.request, http.client; see PEP 476) stops verifying servers.
# Presumably added so model/dataset downloads work behind TLS-intercepting
# proxies; confirm and scope it down if possible.
ssl._create_default_https_context = ssl._create_unverified_context
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend
from typing import Tuple, Dict, Optional
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, DictConfig, open_dict
# Project-local lookup tables mapping config values to C macro values.
from common.utils import aspect_ratio_dict, color_mode_n6_dict
import onnxruntime
def _h7_c_string(value) -> str:
    """Escape *value* (via str()) so it can sit between double quotes in C source."""
    return str(value).replace("\\", "\\\\").replace('"', '\\"')


def _h7_classes_table(class_names) -> str:
    """
    Build the brace-enclosed initializer expanded by the CLASSES_TABLE macro.

    Every physical line except the last ends with a backslash so that the
    initializer stays part of the #define. A line break follows the entries at
    indices 5, 10, 15, ... (same layout as previously generated headers).

    Args:
        class_names: Sequence of class names, in model output order.

    Returns:
        The initializer text, starting with an opening brace and ending with
        the closing brace followed by a backslash.

    Raises:
        ValueError: if class_names is None or empty; the initializer would
            otherwise never be closed and the header would not compile.
    """
    if class_names is None or len(class_names) == 0:
        raise ValueError("dataset.class_names must contain at least one class name")
    last_index = len(class_names) - 1
    parts = ['{\\\n']
    for index, name in enumerate(class_names):
        parts.append(' "' + _h7_c_string(name) + '"')
        if index == last_index:
            parts.append('}\\')
        elif index % 5 == 0 and index != 0:
            parts.append(' ,\\\n')
        else:
            parts.append(' ,')
    return ''.join(parts)


def _h7_quant_format(dtype) -> str:
    """
    Map a TFLite tensor dtype to the QUANT_*_TYPE macro name.

    Raises:
        ValueError: if the dtype is not uint8, int8 or float32.
    """
    # Compare with == (as the former list.index lookup did) instead of hashing,
    # so both numpy scalar types and equivalent np.dtype objects match.
    for np_type, macro in ((np.uint8, "UINT8_FORMAT"),
                           (np.int8, "INT8_FORMAT"),
                           (np.float32, "FLOAT32_FORMAT")):
        if dtype == np_type:
            return macro
    raise ValueError("Unsupported model I/O dtype {}: expected uint8, int8 or float32".format(dtype))


def _h7_update_nucleo_h743_project(c_project_path: str, display_output: str, camera_input: str) -> None:
    """
    Relink the optional USB sources in the NUCLEO-H743ZI2 STM32CubeIDE .project.

    The code does not compile when the USB display files and the USB camera
    file are linked at the same time. Every optional USB link is therefore
    removed first. Then either the four USB display sources (when the output
    is DISPLAY_INTERFACE_USB) or the USB camera driver (when the input is
    CAMERA_INTERFACE_USB) are re-inserted just before the closing tags.

    Args:
        c_project_path: Root of the C project (deployment.c_project_path).
        display_output: deployment.hardware_setup.output value.
        camera_input: deployment.hardware_setup.input value.
    """
    def link(name: str, location: str) -> str:
        # One <link> entry, formatted exactly as it appears in the .project file.
        return ("\t\t<link>\n"
                "\t\t\t<name>" + name + "</name>\n"
                "\t\t\t<type>1</type>\n"
                "\t\t\t<locationURI>" + location + "</locationURI>\n"
                "\t\t</link>\n")

    display_links = [
        link("Middlewares/STM32_USB_Display/" + src,
             "PARENT-3-PROJECT_LOC/Middlewares/ST/STM32_USB_Display/Src/" + src)
        for src in ("usb_disp.c", "usb_disp_desc.c", "usb_disp_format.c", "usbd_conf.c")
    ]
    camera_link = link("Drivers/BSP/NUCLEO_H743ZI2/nucleo_h743zi2_camera_usb.c",
                       "PARENT-3-PROJECT_LOC/Drivers/BSP/NUCLEO-H743ZI2/nucleo_h743zi2_camera_usb.c")
    closing_tags = "\t</linkedResources>\n</projectDescription>"

    project_path = os.path.join(c_project_path, 'Application/NUCLEO-H743ZI2/STM32CubeIDE/.project')
    with open(project_path, 'r', encoding='utf-8') as project_file:
        project_data = project_file.read()

    # Remove all optional USB links, whatever the previous configuration was.
    for entry in display_links + [camera_link]:
        project_data = project_data.replace(entry, '')

    if display_output == "DISPLAY_INTERFACE_USB":
        project_data = project_data.replace(closing_tags, ''.join(display_links) + closing_tags)
    elif camera_input == "CAMERA_INTERFACE_USB":
        project_data = project_data.replace(closing_tags, camera_link + closing_tags)

    with open(project_path, 'w', encoding='utf-8') as project_file:
        project_file.write(project_data)


def gen_h_user_file_h7(config: DictConfig = None, quantized_model_path: str = None, board: str = None) -> None:
    """
    Generate the ai_model_config.h user header for STM32H7 deployments.

    The header is written to <hydra output dir>/C_header/ai_model_config.h.
    It holds:
      - the number of classes and the class-name table;
      - the model input size and the input/output tensor types, both read
        from the quantized TFLite model;
      - the aspect-ratio and color-mode preprocessing options.

    For the NUCLEO-H743ZI2 board, it also defines the display and camera
    interfaces and patches the application's STM32CubeIDE .project so that
    only the required USB sources are linked.

    Args:
        config: A configuration object containing user configuration for the AI model.
        quantized_model_path: The path to the quantized TFLite model file.
        board: The name of the board; only the text before the first comma is compared.

    Raises:
        ValueError: if dataset.class_names is empty, if the color mode or a model
            I/O dtype is unsupported, or if both a USB camera and a USB display
            are requested on the NUCLEO-H743ZI2 (it has a single USB port).
    """
    class Flags:
        def __init__(self, **entries):
            self.__dict__.update(entries)

    params = Flags(**config)
    class_names = params.dataset.class_names
    is_nucleo_h743 = str(board).split(",")[0] == "NUCLEO-H743ZI2"
    hardware = params.deployment.hardware_setup if is_nucleo_h743 else None

    # Validate everything before generating any file, so an invalid
    # configuration does not leave a freshly written header behind.
    if is_nucleo_h743 and hardware.output == "DISPLAY_INTERFACE_USB" and hardware.input == "CAMERA_INTERFACE_USB":
        raise ValueError("\033[31mThere is only one USB port on the Nucleo-H743ZI2 board. You can't select CAMERA_INTERFACE_USB as input and DISPLAY_INTERFACE_USB as output at the same time. \033[39m")

    classes = _h7_classes_table(class_names)
    aspect_ratio_mode = aspect_ratio_dict[params.preprocessing.resizing.aspect_ratio]
    color_formats = {"rgb": "RGB_FORMAT", "bgr": "BGR_FORMAT", "grayscale": "GRAYSCALE_FORMAT"}
    color_mode = params.preprocessing.color_mode
    if color_mode not in color_formats:
        raise ValueError("Unsupported preprocessing.color_mode {!r}: expected one of {}".format(
            color_mode, list(color_formats)))

    # Quantization params: input size and I/O tensor types come from the model.
    interpreter_quant = tf.lite.Interpreter(model_path=quantized_model_path)
    input_details = interpreter_quant.get_input_details()[0]
    output_details = interpreter_quant.get_output_details()[0]
    input_shape = input_details['shape']
    quant_input_type = _h7_quant_format(input_details['dtype'])
    quant_output_type = _h7_quant_format(output_details['dtype'])

    path = os.path.join(HydraConfig.get().runtime.output_dir, "C_header/")
    os.makedirs(path, exist_ok=True)

    lines = [
        "/**\n",
        " ******************************************************************************\n",
        " * @file ai_model_config.h\n",
        " * @author Artificial Intelligence Solutions group (AIS)\n",
        " * @brief User header file for Preprocessing configuration\n",
        " ******************************************************************************\n",
        " * @attention\n",
        " *\n",
        " * Copyright (c) 2024 STMicroelectronics.\n",
        " * All rights reserved.\n",
        " *\n",
        " * This software is licensed under terms that can be found in the LICENSE file in\n",
        " * the root directory of this software component.\n",
        " * If no LICENSE file comes with this software, it is provided AS-IS.\n",
        " *\n",
        " ******************************************************************************\n",
        " */\n\n",
        "/* --------------- Generated code ----------------- */\n",
        "#ifndef __AI_MODEL_CONFIG_H__\n",
        "#define __AI_MODEL_CONFIG_H__\n\n\n",
        "/* I/O configuration */\n",
        "#define NB_CLASSES ({})\n".format(len(class_names)),
        # Input shape indices 1..3 are read as height, width, channels
        # (index 0 presumably the batch dimension).
        "#define INPUT_HEIGHT ({})\n".format(int(input_shape[1])),
        "#define INPUT_WIDTH ({})\n".format(int(input_shape[2])),
        "#define INPUT_CHANNELS ({})\n".format(int(input_shape[3])),
        "\n",
        "/* Classes */\n",
        "#define CLASSES_TABLE const char* classes_table[NB_CLASSES] = {}\n".format(classes),
        "\n\n",
        "/***** Preprocessing configuration *****/\n\n",
        "/* Aspect Ratio configuration */\n",
        "#define ASPECT_RATIO_FIT (1)\n",
        "#define ASPECT_RATIO_CROP (2)\n",
        "#define ASPECT_RATIO_PADDING (3)\n\n",
        "#define ASPECT_RATIO_MODE {}\n".format(aspect_ratio_mode),
        "\n",
        "/* Input color format configuration */\n",
        "#define RGB_FORMAT (1)\n",
        "#define BGR_FORMAT (2)\n",
        "#define GRAYSCALE_FORMAT (3)\n",
        "\n",
        "#define PP_COLOR_MODE {}\n".format(color_formats[color_mode]),
        "\n",
        "/* Input/Output quantization configuration */\n",
        "#define UINT8_FORMAT (1)\n",
        "#define INT8_FORMAT (2)\n",
        "#define FLOAT32_FORMAT (3)\n",
        "\n",
        "#define QUANT_INPUT_TYPE {}\n".format(quant_input_type),
        "#define QUANT_OUTPUT_TYPE {}\n".format(quant_output_type),
        "\n",
    ]
    if is_nucleo_h743:
        lines += [
            "/* Display configuration */\n",
            "#define DISPLAY_INTERFACE_USB (1)\n",
            "#define DISPLAY_INTERFACE_SPI (2)\n",
            "\n",
            "#define DISPLAY_INTERFACE {}\n".format(hardware.output),
            "\n",
            "/* Camera configuration */\n",
            "#define CAMERA_INTERFACE_DCMI (1)\n",
            "#define CAMERA_INTERFACE_USB (2)\n",
            "#define CAMERA_INTERFACE_SPI (3)\n",
            "\n",
            "#define CAMERA_INTERFACE {}\n".format(str(hardware.input)),
            "\n",
        ]
    # NOTE(review): the camera-sensor section is emitted for every board here;
    # confirm whether only the NUCLEO-H743ZI2 application needs CAMERA_SENSOR.
    lines += [
        "/* Camera Sensor configuration */\n",
        "#define CAMERA_SENSOR_OV5640 (1)\n",
        "\n",
        "#define CAMERA_SENSOR CAMERA_SENSOR_OV5640\n",
        "\n",
        "#endif /* __AI_MODEL_CONFIG_H__ */\n",
    ]

    with open(os.path.join(path, "ai_model_config.h"), "w", encoding="utf-8") as f:
        f.write("".join(lines))

    if is_nucleo_h743:
        _h7_update_nucleo_h743_project(params.deployment.c_project_path, hardware.output, hardware.input)
def gen_h_user_file_n6(config: DictConfig = None, quantized_model_path: str = None) -> Tuple[bool, Optional[str]]:
    """
    Generate the app_config.h user header for STM32N6 deployments.

    The header is written to <hydra output dir>/C_header/app_config.h. It holds:
      - the camera flip setting;
      - the aspect-ratio and color-mode preprocessing options;
      - the class-name table;
      - the two welcome-message macros.

    Args:
        config: A configuration object containing user configuration for the AI model.
        quantized_model_path: The path to the quantized model file (.tflite or .onnx).
            Nothing in app_config.h depends on the model tensors, so the file is
            not opened here; the path is only handed back to the caller.

    Returns:
        A (TFLite_Detection_PostProcess_id, quantized_model_path) tuple. The flag
        is always False in this function and the path is returned unchanged.

    Raises:
        ValueError: if dataset.class_names is None or empty.
    """
    class Flags:
        def __init__(self, **entries):
            self.__dict__.update(entries)

    def c_string(value) -> str:
        # Escape backslashes and double quotes so str(value) is a valid C string-literal body.
        return str(value).replace("\\", "\\\\").replace('"', '\\"')

    params = Flags(**config)
    class_names = params.dataset.class_names
    if class_names is None or len(class_names) == 0:
        # An empty list would leave the CLASSES_TABLE initializer unterminated.
        raise ValueError("dataset.class_names must contain at least one class name")

    tflite_detection_postprocess_id = False

    # CLASSES_TABLE initializer: every physical line but the last ends with a
    # backslash so it stays inside the #define. A line break follows the
    # entries at indices 5, 10, 15, ...
    last_index = len(class_names) - 1
    parts = ['{\\\n']
    for index, name in enumerate(class_names):
        parts.append(' "' + c_string(name) + '"')
        if index == last_index:
            parts.append('}\\')
        elif index % 5 == 0 and index != 0:
            parts.append(' ,\\\n')
        else:
            parts.append(' ,')
    classes = ''.join(parts)

    aspect_ratio_mode = aspect_ratio_dict[params.preprocessing.resizing.aspect_ratio]
    color_mode = color_mode_n6_dict[params.preprocessing.color_mode]
    welcome_msg_1 = c_string(os.path.basename(params.model.model_path))
    if params.deployment.hardware_setup.board == 'NUCLEO-N657X0-Q':
        # For this board the message is given as a two-element string array.
        welcome_msg_2 = '#define WELCOME_MSG_2 ((char *[2]) {"Model Running in STM32 MCU", "internal memory"})\n'
    else:
        welcome_msg_2 = '#define WELCOME_MSG_2 "{}"\n'.format("Model Running in STM32 MCU internal memory")

    path = os.path.join(HydraConfig.get().runtime.output_dir, "C_header/")
    os.makedirs(path, exist_ok=True)

    lines = [
        "/**\n",
        "******************************************************************************\n",
        "* @file app_config.h\n",
        "* @author GPM Application Team\n",
        "*\n",
        "******************************************************************************\n",
        "* @attention\n",
        "*\n",
        "* Copyright (c) 2023 STMicroelectronics.\n",
        "* All rights reserved.\n",
        "*\n",
        "* This software is licensed under terms that can be found in the LICENSE file\n",
        "* in the root directory of this software component.\n",
        "* If no LICENSE file comes with this software, it is provided AS-IS.\n",
        "*\n",
        "******************************************************************************\n",
        "*/\n\n",
        "/* --------------- Generated code ----------------- */\n",
        "#ifndef APP_CONFIG\n",
        "#define APP_CONFIG\n\n",
        "#define USE_DCACHE\n\n",
        '#include "arm_math.h"\n\n',
        "/*Defines: CMW_MIRRORFLIP_NONE; CMW_MIRRORFLIP_FLIP; CMW_MIRRORFLIP_MIRROR; CMW_MIRRORFLIP_FLIP_MIRROR;*/\n",
        "#define CAMERA_FLIP CMW_MIRRORFLIP_NONE\n\n",
        "#define ASPECT_RATIO_CROP (1) /* Crop both pipes to nn input aspect ratio; Original aspect ratio kept */\n",
        "#define ASPECT_RATIO_FIT (2) /* Resize both pipe to NN input aspect ratio; Original aspect ratio not kept */\n",
        "#define ASPECT_RATIO_FULLSCREEN (3) /* Resize camera image to NN input size and display a fullscreen image */\n",
        "#define ASPECT_RATIO_MODE {}\n".format(aspect_ratio_mode),
        "\n\n",
        "#define COLOR_BGR (0)\n",
        "#define COLOR_RGB (1)\n",
        "#define COLOR_MODE {}\n".format(color_mode),
        "/* Classes */\n",
        "#define NB_CLASSES ({})\n".format(len(class_names)),
        "#define CLASSES_TABLE const char* classes_table[NB_CLASSES] = {}\n\n".format(classes),
        '#define WELCOME_MSG_1 "{}"\n'.format(welcome_msg_1),
        welcome_msg_2,
        "\n",
        "#endif /* APP_CONFIG */\n",
    ]
    with open(os.path.join(path, "app_config.h"), "w", encoding="utf-8") as f:
        f.write("".join(lines))
    return tflite_detection_postprocess_id, quantized_model_path