| | |
| | |
| |
|
| | import datetime |
| | import argparse |
| | from rknn.api import RKNN |
| | from sys import exit |
| | import os |
| | import onnxslim |
| |
|
# Values enumerated by convert_to_rknn when building the decoder's
# dynamic_input list: one shape set is emitted per (num_labels, num_points)
# pair drawn from these two lists.
num_pointss, num_labelss = [1], [1]
| |
|
def _rknn_output_path(onnx_model):
    """Return the ``.rknn`` path that sits next to *onnx_model*.

    Only the final extension is swapped. ``str.replace(".onnx", ".rknn")``
    would also rewrite directory names containing ".onnx", and would return
    the input path itself when there is no ``.onnx`` suffix.
    """
    root, _ext = os.path.splitext(onnx_model)
    return root + ".rknn"


def _dynamic_input_shapes(model_part):
    """Return the RKNN ``dynamic_input`` shape list for *model_part*.

    The encoder gets ``None`` (no dynamic shape sets configured). The decoder
    gets one shape set per (num_labels, num_points) pair taken from the
    module-level ``num_labelss`` / ``num_pointss`` lists.

    Raises:
        ValueError: if *model_part* is neither "encoder" nor "decoder".
    """
    if model_part == "encoder":
        return None
    if model_part == "decoder":
        return [
            [
                [1, 256, 64, 64],
                [1, 32, 256, 256],
                [1, 64, 128, 128],
                [num_labels, num_points, 2],
                [num_labels, num_points],
                [num_labels, 1, 256, 256],
                [num_labels],
            ]
            for num_labels in num_labelss
            for num_points in num_pointss
        ]
    raise ValueError(f"未知的 model_part: {model_part!r} (应为 'encoder' 或 'decoder')")


def convert_to_rknn(onnx_model, model_part, dataset="/home/zt/rk3588-nn/rknn_model_zoo/datasets/COCO/coco_subset_20.txt", quantize=False):
    """Convert a single ONNX model to RKNN format for the rk3588 target.

    The result is written next to *onnx_model*, with its extension replaced
    by ``.rknn``.

    Args:
        onnx_model: path of the ONNX file to convert.
        model_part: "encoder" or "decoder". Selects the dynamic input shapes
            and whether ``std_values`` normalization is configured (encoder only).
        dataset: dataset list file passed to ``rknn.build``. Presumably only
            consulted when *quantize* is True; confirm against the RKNN docs.
        quantize: value passed as ``do_quantization`` to ``rknn.build``.

    Raises:
        ValueError: if *model_part* is not "encoder" or "decoder".
        RuntimeError: if ``load_onnx``, ``build`` or ``export_rknn`` returns a
            non-zero status.
    """
    # Validate before any output, so a bad argument fails fast.
    input_shapes = _dynamic_input_shapes(model_part)
    rknn_model = _rknn_output_path(onnx_model)
    timedate_iso = datetime.datetime.now().isoformat()

    print(f"\n开始转换 {onnx_model} 到 {rknn_model}")

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            dynamic_input=input_shapes,
            std_values=[[255,255,255]] if model_part == "encoder" else None,
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
        )

        # The RKNN Toolkit API reports failures through a non-zero return code
        # rather than an exception, so each step must be checked explicitly.
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            raise RuntimeError(f"RKNN load_onnx 失败 (返回值 {ret}): {onnx_model}")

        ret = rknn.build(do_quantization=quantize, dataset=dataset, rknn_batch_size=None)
        if ret != 0:
            raise RuntimeError(f"RKNN build 失败 (返回值 {ret}): {onnx_model}")

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            raise RuntimeError(f"RKNN export_rknn 失败 (返回值 {ret}): {rknn_model}")
    finally:
        # Free the toolkit's resources whether or not conversion succeeded.
        rknn.release()

    print(f"完成转换 {rknn_model}\n")
| |
|
def main():
    """Command-line entry point: convert a SAM model's encoder from ONNX to RKNN.

    Reads ``model_name`` from the command line and expects both
    ``<model_name>_encoder.onnx`` and ``<model_name>_decoder.onnx`` to exist.
    If either is missing, it exits with status 1.

    The encoder goes through three steps:
      1. It is slimmed with onnxslim into a temporary ``encoder_slim.onnx``.
      2. It is converted with ``convert_to_rknn``.
      3. The result is moved to ``<model_name>_encoder.rknn``.

    The temporary ONNX file is removed even if a step fails.

    NOTE(review): the decoder is only checked for existence, not converted.
    Yet the final message says every model was converted. Confirm whether the
    decoder is meant to be used as ONNX at runtime.
    """
    parser = argparse.ArgumentParser(description='转换SAM模型从ONNX到RKNN格式')
    parser.add_argument('model_name', type=str, help='模型名称,例如: sam2.1_hiera_tiny')
    args = parser.parse_args()

    encoder_onnx = f"{args.model_name}_encoder.onnx"
    decoder_onnx = f"{args.model_name}_decoder.onnx"
    # Build the output name directly, rather than str.replace over a
    # user-supplied path that may itself contain ".onnx".
    encoder_rknn = f"{args.model_name}_encoder.rknn"

    for model in (encoder_onnx, decoder_onnx):
        if not os.path.exists(model):
            print(f"错误: 找不到文件 {model}")
            exit(1)

    print("开始转换encoder...")
    # convert_to_rknn writes its output next to its input, with .onnx -> .rknn.
    slim_onnx = "encoder_slim.onnx"
    slim_rknn = "encoder_slim.rknn"
    try:
        onnxslim.slim(encoder_onnx, output_model=slim_onnx, skip_fusion_patterns=["EliminationSlice"])
        convert_to_rknn(slim_onnx, model_part="encoder")
        # os.replace overwrites an existing target on every platform;
        # os.rename raises on Windows if the target already exists.
        os.replace(slim_rknn, encoder_rknn)
    finally:
        # Don't leave the intermediate slimmed model behind on failure.
        if os.path.exists(slim_onnx):
            os.remove(slim_onnx)

    print("所有模型转换完成!")


if __name__ == "__main__":
    main()
| |
|