# This file is part of OpenCV Zoo project.
# It is subject to the license terms in the LICENSE file found in the same directory.
#
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
# Third party copyrights are property of their respective owners.

import os
import sys

import numpy as np
import cv2 as cv
import onnx
from onnx import version_converter
import onnxruntime
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat, quant_pre_process

from transform import Compose, Resize, CenterCrop, Normalize, ColorConvert, HandAlign
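
# DataReader adapts a folder of calibration images to onnxruntime's
# CalibrationDataReader interface: quantize_static() calls get_next() repeatedly
# to pull {input_name: blob} feeds and gathers activation statistics from them
# to derive quantization scales and zero points.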
class DataReader(CalibrationDataReader):
    def __init__(self, model_path, image_dir, transforms, data_dim):
        model = onnx.load(model_path)
        # feed calibration blobs to the model's first input
        self.input_name = model.graph.input[0].name
        self.transforms = transforms
        self.data_dim = data_dim
        self.data = self.get_calibration_data(image_dir)
        self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])

    def get_next(self):
        # return the next input feed, or None when calibration data is exhausted
        return next(self.enum_data_dicts, None)

    def get_calibration_data(self, image_dir):
        blobs = []
        supported = ["jpg", "png"]  # supported file suffixes
        for image_name in os.listdir(image_dir):
            image_name_suffix = image_name.split('.')[-1].lower()
            if image_name_suffix not in supported:
                continue
            img = cv.imread(os.path.join(image_dir, image_name))
            # skip unreadable files before applying transforms
            if img is None:
                continue
            img = self.transforms(img)
            blob = cv.dnn.blobFromImage(img)
            if self.data_dim == 'hwc':
                # blobFromImage produces NCHW; transpose to NHWC for hwc models
                blob = cv.transposeND(blob, [0, 2, 3, 1])
            blobs.append(blob)
        return blobs
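
# Quantize drives onnxruntime's post-training static quantization for a single
# model: it upgrades the model to opset 13 if needed, runs quant_pre_process()
# (shape inference and graph optimization), then calls quantize_static() with
# the calibration DataReader above.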
class Quantize:
    def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8', data_dim='chw', nodes_to_exclude=None):
        self.type_dict = {"uint8": QuantType.QUInt8, "int8": QuantType.QInt8}

        self.model_path = model_path
        self.calibration_image_dir = calibration_image_dir
        self.transforms = transforms
        self.per_channel = per_channel
        self.act_type = act_type
        self.wt_type = wt_type
        # avoid a mutable default argument; fall back to an empty exclusion list
        self.nodes_to_exclude = [] if nodes_to_exclude is None else nodes_to_exclude

        # data reader
        self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
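
    # onnxruntime's quantizer targets opset 13, where QuantizeLinear/DequantizeLinear
    # gained per-axis (per-channel) support, so models on older opsets are upgraded
    # before quantization.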
    def check_opset(self):
        model = onnx.load(self.model_path)
        if model.opset_import[0].version != 13:
            print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version))
            # convert opset version to 13
            model_opset13 = version_converter.convert_version(model, 13)
            # save converted model
            output_name = '{}-opset13.onnx'.format(self.model_path[:-5])
            onnx.save_model(model_opset13, output_name)
            # update model_path for quantization
            return output_name
        return self.model_path

    def run(self):
        print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
        new_model_path = self.check_opset()
        quant_pre_process(new_model_path, new_model_path)
        output_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type)
        quantize_static(new_model_path, output_name, self.dr,
                        quant_format=QuantFormat.QOperator,  # since onnxruntime==1.11.0, quant_format defaults to QuantFormat.QDQ, which performs fake quantization
                        per_channel=self.per_channel,
                        weight_type=self.type_dict[self.wt_type],
                        activation_type=self.type_dict[self.act_type],
                        nodes_to_exclude=self.nodes_to_exclude)
        if new_model_path != self.model_path:
            # remove the temporary opset-13 model
            os.remove(new_model_path)
        print('\tQuantized model saved to {}'.format(output_name))
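
# Registry of models to quantize: each entry pairs a model with its calibration
# image directory and preprocessing pipeline. A minimal sketch of adding a new
# entry (the name, path and input size below are hypothetical):
#
#   my_model=Quantize(model_path='../../models/my_model/my_model.onnx',
#                     calibration_image_dir='../../benchmark/data/my_task',
#                     transforms=Compose([Resize(size=(256, 256))])),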
models = dict(
    yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2023mar.onnx',
                   calibration_image_dir='../../benchmark/data/face_detection',
                   transforms=Compose([Resize(size=(160, 120))]),
                   nodes_to_exclude=['MaxPool_5', 'MaxPool_18', 'MaxPool_25', 'MaxPool_32']),
    sface=Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
                   calibration_image_dir='../../benchmark/data/face_recognition',
                   transforms=Compose([Resize(size=(112, 112))])),
    pphumanseg=Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2023mar.onnx',
                        calibration_image_dir='../../benchmark/data/human_segmentation',
                        transforms=Compose([Resize(size=(192, 192))])),
    ppresnet50=Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
                        calibration_image_dir='../../benchmark/data/image_classification',
                        transforms=Compose([Resize(size=(224, 224))])),
    # TBD: VitTrack
    youtureid=Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
                       calibration_image_dir='../../benchmark/data/person_reid',
                       transforms=Compose([Resize(size=(128, 256))])),
    ppocrv3det_en=Quantize(model_path='../../models/text_detection_ppocr/text_detection_en_ppocrv3_2023may.onnx',
                           calibration_image_dir='../../benchmark/data/text',
                           transforms=Compose([Resize(size=(736, 736)),
                                               Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])])),
    ppocrv3det_cn=Quantize(model_path='../../models/text_detection_ppocr/text_detection_cn_ppocrv3_2023may.onnx',
                           calibration_image_dir='../../benchmark/data/text',
                           transforms=Compose([Resize(size=(736, 736)),
                                               Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375])])),
    crnn_en=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
                     calibration_image_dir='../../benchmark/data/text',
                     transforms=Compose([Resize(size=(100, 32)),
                                         Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
                                         ColorConvert(ctype=cv.COLOR_BGR2GRAY)])),
    crnn_cn=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
                     calibration_image_dir='../../benchmark/data/text',
                     transforms=Compose([Resize(size=(100, 32))])),
    mp_palmdet=Quantize(model_path='../../models/palm_detection_mediapipe/palm_detection_mediapipe_2023feb.onnx',
                        calibration_image_dir='path/to/dataset',
                        transforms=Compose([Resize(size=(192, 192)),
                                            Normalize(std=[255, 255, 255]),
                                            ColorConvert(ctype=cv.COLOR_BGR2RGB)]),
                        data_dim='hwc'),
    mp_handpose=Quantize(model_path='../../models/handpose_estimation_mediapipe/handpose_estimation_mediapipe_2023feb.onnx',
                         calibration_image_dir='path/to/dataset',
                         transforms=Compose([HandAlign("mp_handpose"),
                                             Resize(size=(224, 224)),
                                             Normalize(std=[255, 255, 255]),
                                             ColorConvert(ctype=cv.COLOR_BGR2RGB)]),
                         data_dim='hwc'),
    lpd_yunet=Quantize(model_path='../../models/license_plate_detection_yunet/license_plate_detection_lpd_yunet_2023mar.onnx',
                       calibration_image_dir='../../benchmark/data/license_plate_detection',
                       transforms=Compose([Resize(size=(320, 240))]),
                       nodes_to_exclude=['MaxPool_5', 'MaxPool_18', 'MaxPool_25', 'MaxPool_32', 'MaxPool_39']),
)
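
# Usage: pass model names (keys of `models`) on the command line to quantize a
# subset, or no arguments to quantize everything, e.g. (assuming this script is
# saved as quantize-ort.py):
#   python quantize-ort.py yunet sface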
if __name__ == '__main__':
    selected_models = sys.argv[1:]
    if not selected_models:
        selected_models = list(models.keys())
    print('Models to be quantized: {}'.format(str(selected_models)))

    for selected_model_name in selected_models:
        q = models[selected_model_name]
        q.run()