| import argparse |
| import os |
| import sys |
| import warnings |
| from io import BytesIO |
| from pathlib import Path |
|
|
| import onnx |
| import torch |
| from mmdet.apis import init_detector |
| from mmengine.config import ConfigDict |
| from mmengine.logging import print_log |
| from mmengine.utils.path import mkdir_or_exist |
|
|
| |
# Put the directory three levels above this file's folder on the import path;
# it must contain the `projects` package imported just below.
sys.path.append(str(Path(__file__).resolve().parents[3]))
from projects.easydeploy.model import DeployModel, MMYOLOBackend  # noqa: E402

# Silence these warning categories process-wide: the TorchScript tracer /
# scripting warnings plus generic User/Future/Resource warnings.
for _ignored_category in (torch.jit.TracerWarning, torch.jit.ScriptWarning,
                          UserWarning, FutureWarning, ResourceWarning):
    warnings.filterwarnings(action='ignore', category=_ignored_category)
del _ignored_category
|
|
|
|
def parse_args():
    """Parse the command-line options of the ONNX export script.

    Returns:
        argparse.Namespace: The parsed options. ``img_size`` is always a
        two-element ``[height, width]`` list; a single ``--img-size`` value
        is used for both dimensions.

    Raises:
        SystemExit: Via ``parser.error`` when ``--img-size`` is given more
            than two values (or on any other argparse usage error).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--model-only', action='store_true', help='Export model only')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width '
        '(a single value is used for both)')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Simplify onnx model by onnx-sim')
    parser.add_argument(
        '--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument(
        '--backend',
        type=str,
        default='onnxruntime',
        help='Backend for export onnx')
    parser.add_argument(
        '--pre-topk',
        type=int,
        default=1000,
        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument(
        '--keep-topk',
        type=int,
        default=100,
        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument(
        '--iou-threshold',
        type=float,
        default=0.65,
        help='IoU threshold for NMS')
    parser.add_argument(
        '--score-threshold',
        type=float,
        default=0.25,
        help='Score threshold for NMS')
    args = parser.parse_args()

    # Normalise to exactly [height, width]. The export script unpacks this
    # list into the dummy input shape (batch, 3, *img_size), so any other
    # length would silently yield an input tensor of the wrong rank.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        parser.error('--img-size expects one or two integers '
                     f'(got {len(args.img_size)})')
    return args
|
|
|
|
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Create a detector via ``mmdet.apis.init_detector`` and put it in eval mode.

    Args:
        config_path: Config file forwarded to ``init_detector``.
        checkpoint_path: Checkpoint file forwarded to ``init_detector``.
        device: Device forwarded to ``init_detector`` as ``device=``.

    Returns:
        The object returned by ``init_detector``, after calling ``.eval()``.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    detector.eval()
    return detector
|
|
|
|
def main():
    """Export a detector (optionally with bbox decoding + NMS) to ONNX.

    Steps: parse the CLI options, build the detector from config and
    checkpoint, wrap it in ``DeployModel``, trace it with
    ``torch.onnx.export`` on a random input, validate the graph with
    ``onnx.checker``, optionally pin the TensorRT output dims and simplify
    with onnx-sim, then save ``<work_dir>/<checkpoint stem>.onnx``.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)
    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Postprocess export is only supported for the backends listed above;
        # any other backend falls back to exporting the bare model.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(
            pre_top_k=args.pre_topk,
            keep_top_k=args.keep_topk,
            iou_threshold=args.iou_threshold,
            score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']

    base_model = build_model_from_cfg(args.config, args.checkpoint,
                                      args.device)
    deploy_model = DeployModel(
        baseModel=base_model, backend=backend,
        postprocess_cfg=postprocess_cfg)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # One forward pass before tracing; presumably to trigger any lazy
    # initialisation inside the model — TODO confirm it is still required.
    deploy_model(fake_input)

    # Swap only the file extension. A substring replace of 'pth' would also
    # rewrite names such as 'depth.pth', and would keep the original
    # extension for checkpoints that are not '.pth'.
    save_onnx_path = os.path.join(
        args.work_dir,
        Path(args.checkpoint).with_suffix('.onnx').name)

    with BytesIO() as f:
        torch.onnx.export(
            deploy_model,
            fake_input,
            f,
            input_names=['images'],
            output_names=output_names,
            opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
        onnx.checker.check_model(onnx_model)

    if not args.model_only and backend in (MMYOLOBackend.TENSORRT8,
                                           MMYOLOBackend.TENSORRT7):
        # Values are consumed in order across every dim of every graph
        # output: num_dets (B, 1), boxes (B, K, 4), scores (B, K),
        # labels (B, K) with B = batch size, K = keep_topk. Assumes the graph
        # outputs follow `output_names` order.
        shapes = [
            args.batch_size, 1, args.batch_size, args.keep_topk, 4,
            args.batch_size, args.keep_topk, args.batch_size,
            args.keep_topk
        ]
        for output in onnx_model.graph.output:
            for dim in output.type.tensor_type.shape.dim:
                dim.dim_param = str(shapes.pop(0))

    if args.simplify:
        try:
            import onnxsim
            simplified_model, check = onnxsim.simplify(onnx_model)
        except Exception as e:
            # Best effort (includes onnxsim not being installed): keep the
            # unsimplified model and carry on.
            print_log(f'Simplify failure: {e}')
        else:
            # Only adopt the simplified graph when onnx-sim validated it;
            # otherwise the unverified model would be saved.
            if check:
                onnx_model = simplified_model
            else:
                print_log('Simplify failure: onnxsim check failed, '
                          'keep the unsimplified model')

    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')
|
|
|
|
if __name__ == '__main__':
    # Run the export only when executed as a script, not on import.
    main()
|
|