# stm32-modelzoo-app / common/deployment/external_memory_mgt.py
# FBAGSTM (STM32 AI Experimentation Hub) - revision 747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os
import json
import sys
from typing import Dict
from common.utils import get_model_name_and_its_input_shape
# Pixel dimensions of the camera frame formats used to size the capture buffer.
_QVGA_WIDTH = 320
_QVGA_HEIGHT = 240
_VGA_WIDTH = 640
_VGA_HEIGHT = 480

# Pixel value this script assumes for each resolution macro it writes into the C code.
_RES_MACRO_TO_PIXELS = {
    "QVGA_RES_WIDTH": _QVGA_WIDTH,
    "QVGA_RES_HEIGHT": _QVGA_HEIGHT,
    "VGA_RES_WIDTH": _VGA_WIDTH,
    "VGA_RES_HEIGHT": _VGA_HEIGHT,
}

# Exact marker lines (leading space and trailing newline included) delimiting the
# auto-generated regions of main.c and ai_interface.h.
_GENERATED_START = " /*** @GENERATED CODE START - DO NOT TOUCH@ ***/\n"
_GENERATED_STOP = " /*** @GENERATED CODE STOP - DO NOT TOUCH@ ***/\n"


def _bytes_per_pixel(network_channel):
    """Return bytes per pixel of the image buffers: 1 for grayscale, 2 for RGB565.

    Raises:
        ValueError: If ``network_channel`` is neither 1 nor 3.
    """
    if network_channel == 1:  # Grayscale
        return 1
    if network_channel == 3:  # RGB565
        return 2
    raise ValueError(
        "Unsupported number of input channels ({}): expected 1 (grayscale) "
        "or 3 (RGB565).".format(network_channel))


def _select_camera_config(network_width, network_height, aspect_ratio):
    """Pick the smallest camera format that fits the network input.

    Returns:
        A tuple ``(cam_res, cam_res_width, cam_res_height, cam_buffer_width,
        cam_buffer_height)`` of C macro names, or ``None`` if even VGA does not fit.
    """
    if aspect_ratio == "crop":
        # Square buffer: resolution and capture buffer sides all use the frame height.
        candidates = (
            (_QVGA_HEIGHT, _QVGA_HEIGHT,
             ("CAMERA_R320x240", "QVGA_RES_HEIGHT", "QVGA_RES_HEIGHT", "QVGA_RES_HEIGHT", "QVGA_RES_HEIGHT")),
            (_VGA_HEIGHT, _VGA_HEIGHT,
             ("CAMERA_R640x480", "VGA_RES_HEIGHT", "VGA_RES_HEIGHT", "VGA_RES_HEIGHT", "VGA_RES_HEIGHT")),
        )
    elif aspect_ratio == "padding":
        # Full-frame resolution, square capture buffer whose side is the frame width.
        candidates = (
            (_QVGA_WIDTH, _QVGA_WIDTH,
             ("CAMERA_R320x240", "QVGA_RES_WIDTH", "QVGA_RES_HEIGHT", "QVGA_RES_WIDTH", "QVGA_RES_WIDTH")),
            (_VGA_WIDTH, _VGA_WIDTH,
             ("CAMERA_R640x480", "VGA_RES_WIDTH", "VGA_RES_HEIGHT", "VGA_RES_WIDTH", "VGA_RES_WIDTH")),
        )
    else:
        # Capture buffer matches the full frame.
        candidates = (
            (_QVGA_WIDTH, _QVGA_HEIGHT,
             ("CAMERA_R320x240", "QVGA_RES_WIDTH", "QVGA_RES_HEIGHT", "QVGA_RES_WIDTH", "QVGA_RES_HEIGHT")),
            (_VGA_WIDTH, _VGA_HEIGHT,
             ("CAMERA_R640x480", "VGA_RES_WIDTH", "VGA_RES_HEIGHT", "VGA_RES_WIDTH", "VGA_RES_HEIGHT")),
        )
    for max_width, max_height, config in candidates:
        if network_width <= max_width and network_height <= max_height:
            return config
    return None


def _activation_buffers_code(activations, available_AXIRAM):
    """Build the C declarations of the activation buffers.

    Args:
        activations: Memory-pool dicts (``"name"``, ``"used_size_bytes"``) in the order
            that defines each pool's ``AI_ACTIVATION_<n>_SIZE_BYTES`` index (1-based).
        available_AXIRAM: AXIRAM budget in bytes before the activation buffers.

    Returns:
        ``(code, remaining_AXIRAM)``: the C text to insert after the START marker and
        the budget left after charging the AXIRAM-placed pool.
    """
    code = []
    buffer_names = []
    for index, pool in enumerate(activations, start=1):
        # "heap_overlay_pool" goes to the .NN_Activation_Buffer_AXIRAM section.
        if pool["name"] == "heap_overlay_pool":
            buffer_name = "NN_Activation_Buffer_AXIRAM"
        else:
            buffer_name = "NN_Activation_Buffer_" + pool["name"]
        size_macro = "AI_ACTIVATION_" + str(index) + "_SIZE_BYTES"
        code.append('__attribute__((section(".' + buffer_name + '")))\n__attribute__ ((aligned (32)))\n')
        # SIZE + 32 - SIZE%32 rounds up to a multiple of 32 (adding 32 when already aligned).
        code.append("static uint8_t " + buffer_name + "[" + size_macro
                    + " + 32 - (" + size_macro + "%32)];\n")
        buffer_names.append(buffer_name)
        if buffer_name == "NN_Activation_Buffer_AXIRAM":
            # NOTE(review): charges used_size_bytes, not the 32-byte-padded array size
            # declared above; the budget may be overestimated by up to 32 bytes.
            available_AXIRAM -= pool["used_size_bytes"]
    code.append("uint8_t* NN_Activation_Buffer[AI_ACTIVATION_BUFFERS_COUNT] = { ")
    code.extend(name + ", " for name in buffer_names)
    code.append("};\n\n")
    return "".join(code), available_AXIRAM


def _set_camera_defines(lines, cam_res, cam_res_width, cam_res_height):
    """Yield ``lines`` with the camera resolution ``#define`` lines of main.h replaced."""
    replacements = (
        ("#define CAMERA_RESOLUTION", "#define CAMERA_RESOLUTION (" + cam_res + ")\n"),
        ("#define CAM_RES_WIDTH", "#define CAM_RES_WIDTH (" + cam_res_width + ")\n"),
        ("#define CAM_RES_HEIGHT", "#define CAM_RES_HEIGHT (" + cam_res_height + ")\n"),
    )
    for line in lines:
        for marker, replacement in replacements:
            if marker in line:
                line = replacement
                break
        yield line


def _replace_generated_region(lines, generated):
    """Yield ``lines`` with the START..STOP region body replaced by ``generated``.

    The START and STOP marker lines are kept. If STOP never appears, everything after
    START is dropped.
    """
    inside = False
    for line in lines:
        if line == _GENERATED_START:
            inside = True
            yield line + generated
            continue
        if line == _GENERATED_STOP:
            inside = False
        if not inside:
            yield line


def _place_image_buffers(lines, cam_buffer_size, resize_buffer_size, available_AXIRAM):
    """Yield main.c ``lines`` with the image buffers' section attributes rewritten.

    Each buffer goes to its ``_AXIRAM`` section if it is strictly smaller than the
    remaining AXIRAM budget (which it then consumes), else to its ``_SDRAM`` section.
    Buffers are considered in the order their lines appear in the file.
    """
    buffers = (
        ("CapturedImage_Buffer", cam_buffer_size),
        ("RescaledImage_Buffer", resize_buffer_size),
    )
    remaining = available_AXIRAM
    for line in lines:
        for buffer_name, size in buffers:
            if '__attribute__((section(".' + buffer_name in line:
                if size < remaining:
                    remaining -= size
                    region = "AXIRAM"
                else:
                    region = "SDRAM"
                line = '__attribute__((section(".' + buffer_name + "_" + region + '")))\n'
                break
        yield line


def _rewrite_in_place(path, tmp_name, rewrite):
    """Rewrite ``path`` through ``rewrite`` (open file -> iterable of text chunks).

    Output goes to ``tmp_name`` in the same directory, which then replaces ``path``.
    The temporary file is removed if the rewrite fails.
    """
    tmp_path = os.path.join(os.path.dirname(path), tmp_name)
    try:
        with open(path, 'r') as src, open(tmp_path, 'w') as dst:
            dst.writelines(rewrite(src))
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, path)


def update_activation_c_code(c_project_path: str, model_path: str=None, path_network_c_info: str=None, available_AXIRAM: int=0, aspect_ratio = None, custom_objects: Dict = None):
    """Patch the STM32H747I-DISCO C application sources for the given model.

    Rewrites, in place under ``Application/STM32H747I-DISCO``:
      * ``Inc/CM7/main.h``: the ``CAMERA_RESOLUTION``, ``CAM_RES_WIDTH`` and
        ``CAM_RES_HEIGHT`` defines, chosen from the model input size and ``aspect_ratio``.
      * ``Src/CM7/main.c``: the generated region, with one buffer per writable, non-empty
        memory pool of ``path_network_c_info``. It also rewrites the section attributes of
        the captured/rescaled image buffers: ``*_AXIRAM`` while the AXIRAM budget allows,
        ``*_SDRAM`` otherwise.
      * ``Inc/CM7/ai_interface.h``: the generated region defining which activation buffer
        holds the network inputs.

    The model shape and JSON report are read and validated before any file is modified.

    Args:
        c_project_path: Root directory of the C project.
        model_path: Model whose input shape is read via
            ``get_model_name_and_its_input_shape`` and used as (height, width, channels).
        path_network_c_info: JSON report whose ``"memory_pools"`` entries provide
            ``"name"``, ``"rights"`` and ``"used_size_bytes"``.
        available_AXIRAM: AXIRAM budget in bytes shared by activation and image buffers.
        aspect_ratio: ``"crop"``, ``"padding"``, or anything else for full-frame capture.
        custom_objects: Forwarded to ``get_model_name_and_its_input_shape``.

    Raises:
        ValueError: If the model has neither 1 nor 3 input channels, or its input does
            not fit a VGA capture for the selected ``aspect_ratio``.
    """
    path_main_h = os.path.join(c_project_path, "Application/STM32H747I-DISCO/Inc/CM7/main.h")
    path_main_c = os.path.join(c_project_path, "Application/STM32H747I-DISCO/Src/CM7/main.c")
    path_ai_interface_h = os.path.join(c_project_path, "Application/STM32H747I-DISCO/Inc/CM7/ai_interface.h")

    ### Get NN preprocessing buffers size
    _, input_shape = get_model_name_and_its_input_shape(model_path=model_path, custom_objects=custom_objects)
    network_height = input_shape[0]
    network_width = input_shape[1]
    network_channel = input_shape[2]
    bytes_per_pixel = _bytes_per_pixel(network_channel)
    resize_buffer_size = network_height * network_width * bytes_per_pixel

    camera_config = _select_camera_config(network_width, network_height, aspect_ratio)
    if camera_config is None:
        # Fail before any project file is touched.
        raise ValueError("Needed camera resolution ({}x{}) exceeds VGA format. ".format(network_width, network_height))
    cam_res, cam_res_width, cam_res_height, cam_buffer_width, cam_buffer_height = camera_config
    cam_buffer_size = (_RES_MACRO_TO_PIXELS[cam_buffer_width]
                       * _RES_MACRO_TO_PIXELS[cam_buffer_height]
                       * bytes_per_pixel)

    with open(path_network_c_info, 'r') as f:
        graph = json.load(f)
    # Writable, non-empty pools sorted by size (stable). Their position gives the
    # AI_ACTIVATION_<n>_SIZE_BYTES index; assumes this matches the numbering in the
    # generated network header - TODO confirm.
    activations = sorted(
        (pool for pool in graph["memory_pools"]
         if pool["rights"] == "ACC_WRITE" and pool["used_size_bytes"] != 0),
        key=lambda pool: pool["used_size_bytes"],
    )
    activation_code, available_AXIRAM = _activation_buffers_code(activations, available_AXIRAM)

    ### Generate main.h
    _rewrite_in_place(
        path_main_h, 'main_modify.h',
        lambda lines: _set_camera_defines(lines, cam_res, cam_res_width, cam_res_height))

    ### Generate main.c (activation buffers first, then image buffer placement)
    _rewrite_in_place(
        path_main_c, 'main_modify.c',
        lambda lines: _replace_generated_region(lines, activation_code))
    _rewrite_in_place(
        path_main_c, 'main_modify.c',
        lambda lines: _place_image_buffers(lines, cam_buffer_size, resize_buffer_size, available_AXIRAM))

    ### Generate ai_interface.h
    input_buffer_activation_buffer_index = 0
    interface_code = (
        "#define AI_NETWORK_INPUTS_IN_ACTIVATIONS_INDEX " + str(input_buffer_activation_buffer_index)
        + "\n#define AI_NETWORK_INPUTS_IN_ACTIVATIONS_SIZE AI_ACTIVATION_"
        + str(input_buffer_activation_buffer_index + 1) + "_SIZE_BYTES\n\n"
    )
    _rewrite_in_place(
        path_ai_interface_h, 'interface_modify.h',
        lambda lines: _replace_generated_region(lines, interface_code))