# stm32-modelzoo-app/common/quantization/onnx_quantizer.py
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022 STMicroelectronics.
# * All rights reserved.
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import os

import numpy as np
import onnx
import onnxruntime
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from onnx import version_converter, checker
from onnxruntime import quantization
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, quantize_static

from .quant_utils import get_weights_activations_quant_type, get_calibration_method


def _update_opset(input_model, target_opset, export_dir):
"""
updates the opset of an onnx model
inputs:
input_model: path of the input model.
target_opset: the target opset model is to be updated to.
"""
# ir_version in function of opset
ir_version_dict = {21: 10,
20: 9,
19: 9,
18: 8,
17: 8,
16: 8,
15: 8
}
if not str(input_model).endswith('.onnx'):
raise TypeError("Error! The model file must be of onnx format!")
model = onnx.load(input_model)
# Check the current opset version
current_opset = model.opset_import[0].version
if current_opset >= target_opset:
print(f"[INFO] : The model is already using opset {current_opset} >= {target_opset}")
return input_model
# Modify the opset version in the model
converted_model = version_converter.convert_version(model, target_opset)
    # Potentially change ir_version; use .get() to avoid a KeyError for opsets not listed above
    print(f"[INFO] : Model current IR version: {converted_model.ir_version}")
    expected_ir_version = ir_version_dict.get(target_opset)
    if expected_ir_version is not None and converted_model.ir_version != expected_ir_version:
        converted_model.ir_version = expected_ir_version
        print(f"[INFO] : Updated model IR version to {converted_model.ir_version} for compatibility with target opset "
              f"{target_opset}")
# check if the obtained model is valid
try:
checker.check_model(converted_model)
except checker.ValidationError as e:
print(f"[ERROR] : The model is invalid. {e}")
    opset_model = os.path.join(export_dir, f'{os.path.splitext(os.path.basename(input_model))[0]}_opset{target_opset}.onnx')
onnx.save(converted_model, opset_model)
    # Load the modified model with ONNX Runtime to check that it is valid
session = onnxruntime.InferenceSession(opset_model)
try:
session.get_inputs()
except Exception as e:
print(f"[ERROR] : An error occurred while loading the modified model: {e}")
return
    print(f"[INFO] : The model has been converted to opset {target_opset}, IR {converted_model.ir_version} and saved "
          f"as {opset_model}.")
return opset_model
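
# Usage sketch for _update_opset (illustrative only; 'model.onnx' and './export' are
# hypothetical paths, not part of this repository):
#   opset_model = _update_opset(input_model='model.onnx', target_opset=17, export_dir='./export')
#   # -> './export/model_opset17.onnx', or 'model.onnx' unchanged if its opset is already >= 17
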
def _preprocess_random_images(height: int, width: int, channel: int, size_limit=10):
"""
Loads a batch of images and preprocess them
parameter height: image height in pixels
parameter width: image width in pixels
parameter size_limit: number of images to load. Default is 100
return: list of matrices characterizing multiple images
"""
unconcatenated_batch_data = []
for i in range(size_limit):
random_vals = np.random.uniform(0, 1, channel*height*width).astype('float32')
random_image = random_vals.reshape(1, channel, height, width)
unconcatenated_batch_data.append(random_image)
    # Stack the (1, C, H, W) images into a single (size_limit, 1, C, H, W) batch
    batch_data = np.stack(unconcatenated_batch_data, axis=0)
    print(f'[INFO] : A random dataset of {size_limit} random images has been prepared!')
return batch_data
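
# Quick shape check (illustrative only):
#   batch = _preprocess_random_images(height=224, width=224, channel=3, size_limit=10)
#   # batch.shape == (10, 1, 3, 224, 224): ten NCHW images, each with a batch axis of 1
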
class ImageDataReader(CalibrationDataReader):
    '''
    ImageDataReader for the calibration during onnx quantization.
    The constructor takes as input:
        quantization_samples: an np array containing the calibration samples dataset
            (a random dataset is generated when None)
        model_path: path of the model to be quantized
    '''
def __init__(self,
quantization_samples,
model_path: str):
# Use inference session to get input shape
session = onnxruntime.InferenceSession(model_path, None)
(_, channel, height, width) = session.get_inputs()[0].shape
        # Convert the images to model input data (NCHW layout, one batch axis per sample)
        if quantization_samples is not None:
            self.nchw_data_list = np.expand_dims(quantization_samples, axis=1)
        else:
            self.nchw_data_list = _preprocess_random_images(height,
                                                            width,
                                                            channel)
        self.input_name = session.get_inputs()[0].name
        self.datasize = len(self.nchw_data_list)
self.enum_data = None # Enumerator for calibration data
def get_next(self):
if self.enum_data is None:
            # Create an iterator that generates input dictionaries
            # with input name and corresponding data
            self.enum_data = iter(
                [{self.input_name: nchw_data} for nchw_data in self.nchw_data_list]
            )
return next(self.enum_data, None) # Return next item from enumerator
def rewind(self):
        self.enum_data = None  # Reset the enumeration of the calibration data
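
# Usage sketch (illustrative; 'model.onnx' is a hypothetical path). The calibrator in
# onnxruntime drives get_next() until it returns None:
#   dr = ImageDataReader(quantization_samples=None, model_path='model.onnx')
#   while (sample := dr.get_next()) is not None:
#       ...  # each sample is {input_name: array of shape (1, channel, height, width)}
#   dr.rewind()  # makes the data available for another calibration pass
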
def quantize_onnx(configs: DictConfig, model_path=None, quantization_samples=None,
                  model: object = None, extra_options: dict = None):
"""
Quantizes an ONNX model using onnx-runtime.
Args:
configs (DictConfig): Configuration dictionary containing quantization and model settings.
quantization_samples: Calibration/representative dataset as a numpy array (optional).
model (object, optional): Model object with a model_path attribute, or None to use configs.model.model_path.
extra_options (dict, optional): Extra options for ONNX quantizer.
Returns:
onnxruntime.InferenceSession: Quantized model session with model_path attribute.
"""
# model_path = configs.model.model_path
# model_path = model.model_path
granularity = configs.quantization.granularity.lower()
target_opset = configs.quantization.target_opset
output_dir = HydraConfig.get().runtime.output_dir
    export_dir = configs.quantization.export_dir
    export_dir = os.path.join(output_dir, export_dir) if export_dir else output_dir
    os.makedirs(export_dir, exist_ok=True)
    if granularity == 'per_tensor':
        quant_tag = 'quant_qdq_pt'
    elif granularity == 'per_channel':
        quant_tag = 'quant_qdq_pc'
    else:
        raise ValueError('Not a valid quantization granularity! '
                         'The only supported options are per_channel or per_tensor.')
print(f'[INFO] : Quantizing model : {model_path}')
opset_model = _update_opset(input_model=model_path, target_opset=target_opset, export_dir=output_dir)
# set the data reader pointing to the representative dataset
print('[INFO] : Prepare the data reader for the representative dataset...')
dr = ImageDataReader(quantization_samples=quantization_samples, model_path=opset_model)
    print('[INFO] : The data reader is ready')
# preprocess the model to infer shapes of each tensor
infer_model = os.path.join(export_dir, f'{os.path.basename(opset_model)[:-5]}_infer.onnx')
quantization.quant_pre_process(input_model_path=opset_model,
output_model_path=infer_model,
skip_optimization=False,
skip_symbolic_shape=True)
# settings for quantization
weight_type, activ_type = get_weights_activations_quant_type(cfg=configs)
calibration_method = get_calibration_method(cfg=configs)
if configs.quantization.onnx_quant_parameters:
op_types_to_quantize = configs.quantization.onnx_quant_parameters.op_types_to_quantize
nodes_to_quantize = configs.quantization.onnx_quant_parameters.nodes_to_quantize
nodes_to_exclude = configs.quantization.onnx_quant_parameters.nodes_to_exclude
else:
op_types_to_quantize = None
nodes_to_quantize = None
nodes_to_exclude = None
# prepare quantized onnx model filename
    quant_model = os.path.join(export_dir, f'{os.path.splitext(os.path.basename(opset_model))[0]}_{quant_tag}.onnx')
print(f'[INFO] : Quantizing the model {os.path.basename(model_path)}, please wait...')
quantize_static(
infer_model,
quant_model,
dr,
op_types_to_quantize=op_types_to_quantize,
calibrate_method=calibration_method,
quant_format=QuantFormat.QDQ,
        per_channel=(granularity == 'per_channel'),
weight_type=weight_type,
activation_type=activ_type,
nodes_to_quantize=nodes_to_quantize,
nodes_to_exclude=nodes_to_exclude,
        # optimize_model=False,
extra_options=extra_options)
    # Load the quantized model with ONNX Runtime to check that it is valid
    model = onnxruntime.InferenceSession(quant_model)
    try:
        model.get_inputs()
    except Exception as e:
        print(f"[ERROR] : An error occurred while loading the quantized model: {e}")
        return
    print(f"[INFO] : Quantized model path : {quant_model}")
setattr(model, 'model_path', quant_model)
return model
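
# Usage sketch (illustrative only): quantize_onnx must run inside a Hydra application,
# since it reads the run's output directory from HydraConfig. Besides the fields read
# above (granularity, target_opset, export_dir, onnx_quant_parameters), the config must
# carry whatever keys get_weights_activations_quant_type and get_calibration_method in
# quant_utils consume; 'user_config' below is a hypothetical config name:
#   import hydra
#   @hydra.main(config_path='.', config_name='user_config', version_base=None)
#   def main(cfg):
#       session = quantize_onnx(cfg, model_path=cfg.model.model_path)
#       print(session.model_path)  # path of the saved quantized .onnx file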