# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
import json
import sys
from typing import Dict
from common.utils import get_model_name_and_its_input_shape
# Pixel dimensions behind the QVGA_/VGA_RES_* macro names written into main.h.
_QVGA_WIDTH = 320
_QVGA_HEIGHT = 240
_VGA_WIDTH = 640
_VGA_HEIGHT = 480

# Maps each resolution macro name to its pixel count, used to size the camera buffer.
_RES_MACRO_PIXELS = {
    "QVGA_RES_WIDTH": _QVGA_WIDTH,
    "QVGA_RES_HEIGHT": _QVGA_HEIGHT,
    "VGA_RES_WIDTH": _VGA_WIDTH,
    "VGA_RES_HEIGHT": _VGA_HEIGHT,
}

# Marker lines delimiting the auto-generated regions in main.c and ai_interface.h.
# They are compared byte-for-byte, including the leading space and trailing newline.
# NOTE(review): confirm these match the C templates exactly (whitespace included).
_GENERATED_CODE_START = " /*** @GENERATED CODE START - DO NOT TOUCH@ ***/\n"
_GENERATED_CODE_STOP = " /*** @GENERATED CODE STOP - DO NOT TOUCH@ ***/\n"


def _bytes_per_pixel(network_channel):
    """Return the bytes per pixel of the image buffers for a channel count.

    Grayscale (1 channel) uses 1 byte per pixel; 3-channel input is stored as
    RGB565, i.e. 2 bytes per pixel.

    Raises:
        ValueError: if ``network_channel`` is neither 1 nor 3.
    """
    if network_channel == 1:
        return 1
    if network_channel == 3:
        return 2
    raise ValueError("Unsupported number of input channels: {} "
                     "(expected 1 for grayscale or 3 for RGB565).".format(network_channel))


def _select_camera_config(network_width, network_height, aspect_ratio):
    """Choose the smallest camera mode (QVGA, then VGA) that fits the network input.

    Returns:
        Tuple ``(cam_res, cam_res_width, cam_res_height, cam_buffer_width,
        cam_buffer_height)`` of C macro names.

    Raises:
        ValueError: if the input does not fit in VGA for the given ``aspect_ratio``.
    """
    if aspect_ratio == "crop":
        # Frame and buffer both use the *_RES_HEIGHT macro for every side (square).
        candidates = (
            (_QVGA_HEIGHT, _QVGA_HEIGHT,
             ("CAMERA_R320x240", "QVGA_RES_HEIGHT", "QVGA_RES_HEIGHT", "QVGA_RES_HEIGHT", "QVGA_RES_HEIGHT")),
            (_VGA_HEIGHT, _VGA_HEIGHT,
             ("CAMERA_R640x480", "VGA_RES_HEIGHT", "VGA_RES_HEIGHT", "VGA_RES_HEIGHT", "VGA_RES_HEIGHT")),
        )
    elif aspect_ratio == "padding":
        # Full width x height frame, but a square buffer whose side is *_RES_WIDTH.
        candidates = (
            (_QVGA_WIDTH, _QVGA_WIDTH,
             ("CAMERA_R320x240", "QVGA_RES_WIDTH", "QVGA_RES_HEIGHT", "QVGA_RES_WIDTH", "QVGA_RES_WIDTH")),
            (_VGA_WIDTH, _VGA_WIDTH,
             ("CAMERA_R640x480", "VGA_RES_WIDTH", "VGA_RES_HEIGHT", "VGA_RES_WIDTH", "VGA_RES_WIDTH")),
        )
    else:
        # Frame and buffer both use the full width x height of the camera mode.
        candidates = (
            (_QVGA_WIDTH, _QVGA_HEIGHT,
             ("CAMERA_R320x240", "QVGA_RES_WIDTH", "QVGA_RES_HEIGHT", "QVGA_RES_WIDTH", "QVGA_RES_HEIGHT")),
            (_VGA_WIDTH, _VGA_HEIGHT,
             ("CAMERA_R640x480", "VGA_RES_WIDTH", "VGA_RES_HEIGHT", "VGA_RES_WIDTH", "VGA_RES_HEIGHT")),
        )
    for max_width, max_height, config in candidates:
        if network_width <= max_width and network_height <= max_height:
            return config
    # The original code built this error without raising it, which later surfaced as a NameError.
    raise ValueError("Needed camera resolution ({}x{}) exceeds VGA format. ".format(network_width, network_height))


def _patch_main_h(path_main_h, cam_res, cam_res_width, cam_res_height):
    """Rewrite the camera resolution #defines in main.h (via a temp file and os.replace)."""
    tmp_path = os.path.join(os.path.dirname(path_main_h), 'main_modify.h')
    with open(path_main_h, 'r') as src, open(tmp_path, 'w') as dst:
        for line in src:
            if "#define CAMERA_RESOLUTION" in line:
                line = "#define CAMERA_RESOLUTION (" + cam_res + ")\n"
            elif "#define CAM_RES_WIDTH" in line:
                line = "#define CAM_RES_WIDTH (" + cam_res_width + ")\n"
            elif "#define CAM_RES_HEIGHT" in line:
                line = "#define CAM_RES_HEIGHT (" + cam_res_height + ")\n"
            dst.write(line)
    os.replace(tmp_path, path_main_h)


def _replace_generated_region(path, tmp_name, generated):
    """Replace the lines between the START/STOP markers of ``path`` with ``generated``.

    Both marker lines are kept. If STOP is missing, everything after START is dropped.
    The file is rewritten via ``tmp_name`` (in the same directory) and ``os.replace``.

    Returns:
        True if a START marker was found.
    """
    tmp_path = os.path.join(os.path.dirname(path), tmp_name)
    found = False
    inside = False
    with open(path, 'r') as src, open(tmp_path, 'w') as dst:
        for line in src:
            if line == _GENERATED_CODE_START:
                dst.write(line + generated)
                inside = True
                found = True
            elif line == _GENERATED_CODE_STOP:
                dst.write(line)
                inside = False
            elif not inside:
                dst.write(line)
    os.replace(tmp_path, path)
    return found


def _activation_buffers_code(activations):
    """Build the C declarations of the activation pools.

    Returns:
        Tuple ``(code, axiram_bytes)``. ``axiram_bytes`` is the total
        ``used_size_bytes`` of pools emitted as ``NN_Activation_Buffer_AXIRAM``.
    """
    parts = []
    pool_names = []
    axiram_bytes = 0
    for index, pool in enumerate(activations, start=1):
        if pool["name"] != "heap_overlay_pool":
            name_pool = "NN_Activation_Buffer_" + pool["name"]
        else:
            name_pool = "NN_Activation_Buffer_AXIRAM"
        size_macro = "AI_ACTIVATION_" + str(index) + "_SIZE_BYTES"
        parts.append('__attribute__((section(".' + name_pool + '")))\n__attribute__ ((aligned (32)))\n')
        # Rounds the size up for 32-byte alignment. Note this adds a full extra 32 bytes
        # when the size is already a multiple of 32.
        parts.append("static uint8_t " + name_pool + "[" + size_macro + " + 32 - (" + size_macro + "%32)];\n")
        pool_names.append(name_pool)
        if name_pool == "NN_Activation_Buffer_AXIRAM":
            axiram_bytes += pool['used_size_bytes']
    parts.append("uint8_t* NN_Activation_Buffer[AI_ACTIVATION_BUFFERS_COUNT] = { ")
    parts.extend(name + ", " for name in pool_names)
    parts.append("};\n\n")
    return "".join(parts), axiram_bytes


def _place_image_buffers(path_main_c, cam_buffer_size, resize_buffer_size, available_AXIRAM):
    """Point the captured/rescaled image buffer sections at AXIRAM or SDRAM.

    A buffer goes to AXIRAM only if it is strictly smaller than the remaining
    budget, which is then reduced. Buffers are handled in the order their section
    lines appear in main.c.

    Returns:
        The remaining AXIRAM budget.
    """
    placements = (("CapturedImage_Buffer", cam_buffer_size),
                  ("RescaledImage_Buffer", resize_buffer_size))
    tmp_path = os.path.join(os.path.dirname(path_main_c), 'main_modify.c')
    with open(path_main_c, 'r') as src, open(tmp_path, 'w') as dst:
        for line in src:
            for buffer_name, buffer_size in placements:
                if '__attribute__((section(".' + buffer_name in line:
                    if buffer_size < available_AXIRAM:
                        available_AXIRAM -= buffer_size
                        line = '__attribute__((section(".' + buffer_name + '_AXIRAM")))\n'
                    else:
                        line = '__attribute__((section(".' + buffer_name + '_SDRAM")))\n'
            dst.write(line)
    os.replace(tmp_path, path_main_c)
    return available_AXIRAM


def update_activation_c_code(c_project_path: str, model_path: str = None, path_network_c_info: str = None,
                             available_AXIRAM: int = 0, aspect_ratio=None, custom_objects: Dict = None):
    """Patch the STM32H747I-DISCO C application for the given model.

    Rewrites, in place:
      * ``Inc/CM7/main.h``: camera resolution #defines.
      * ``Src/CM7/main.c``: the generated activation-pool declarations, and the
        linker sections (AXIRAM or SDRAM) of the captured/rescaled image buffers.
      * ``Inc/CM7/ai_interface.h``: the generated input-in-activations #defines.

    All inputs are validated and the network-info JSON is loaded before any file
    is modified, so a failure does not leave the project half-patched.

    Args:
        c_project_path: Root directory of the C project.
        model_path: Model file. Its input shape is read as (height, width, channels).
        path_network_c_info: JSON file. Its ``memory_pools`` entries provide
            ``rights``, ``used_size_bytes`` and ``name``.
        available_AXIRAM: AXIRAM byte budget. It is consumed first by AXIRAM
            activation pools, then by the image buffers.
        aspect_ratio: ``"crop"``, ``"padding"``, or anything else for a
            full-frame camera buffer.
        custom_objects: Forwarded to ``get_model_name_and_its_input_shape``.

    Raises:
        ValueError: if the input exceeds VGA or has an unsupported channel count.
    """
    path_main_h = os.path.join(c_project_path, "Application/STM32H747I-DISCO/Inc/CM7/main.h")
    path_main_c = os.path.join(c_project_path, "Application/STM32H747I-DISCO/Src/CM7/main.c")
    path_ai_interface_h = os.path.join(c_project_path, "Application/STM32H747I-DISCO/Inc/CM7/ai_interface.h")

    # --- Validate and compute everything before touching any file ---
    _, input_shape = get_model_name_and_its_input_shape(model_path=model_path, custom_objects=custom_objects)
    network_height = input_shape[0]
    network_width = input_shape[1]
    network_channel = input_shape[2]

    bytes_per_pixel = _bytes_per_pixel(network_channel)
    cam_res, cam_res_width, cam_res_height, cam_buffer_width, cam_buffer_height = _select_camera_config(
        network_width, network_height, aspect_ratio)

    resize_buffer_size = network_height * network_width * bytes_per_pixel
    cam_buffer_size = (_RES_MACRO_PIXELS[cam_buffer_height] * _RES_MACRO_PIXELS[cam_buffer_width]
                       * bytes_per_pixel)

    with open(path_network_c_info, 'r') as f:
        graph = json.load(f)
    # Writable, non-empty pools, smallest first. The AI_ACTIVATION_<n> macros follow this order.
    activations = sorted(
        (pool for pool in graph["memory_pools"]
         if pool["rights"] == "ACC_WRITE" and pool["used_size_bytes"] != 0),
        key=lambda pool: pool['used_size_bytes'])

    # --- main.h ---
    _patch_main_h(path_main_h, cam_res, cam_res_width, cam_res_height)

    # --- main.c: activation pools, then image buffer placement ---
    activation_code, axiram_pool_bytes = _activation_buffers_code(activations)
    if _replace_generated_region(path_main_c, 'main_modify.c', activation_code):
        available_AXIRAM -= axiram_pool_bytes
    available_AXIRAM = _place_image_buffers(path_main_c, cam_buffer_size, resize_buffer_size, available_AXIRAM)

    # --- ai_interface.h ---
    # NOTE(review): index 0 is hard-coded; confirm the network input always lives in
    # the first activation pool.
    input_buffer_activation_buffer_index = 0
    interface_code = ("#define AI_NETWORK_INPUTS_IN_ACTIVATIONS_INDEX " + str(input_buffer_activation_buffer_index)
                      + "\n#define AI_NETWORK_INPUTS_IN_ACTIVATIONS_SIZE AI_ACTIVATION_"
                      + str(input_buffer_activation_buffer_index + 1) + "_SIZE_BYTES\n\n")
    _replace_generated_region(path_ai_interface_h, 'interface_modify.h', interface_code)