|
|
| import os
|
| import json
|
| import warnings
|
| import argparse
|
| from io import BytesIO
|
|
|
| import onnx
|
| import torch
|
| from mmdet.apis import init_detector
|
| from mmengine.config import ConfigDict
|
| from mmengine.logging import print_log
|
| from mmengine.utils.path import mkdir_or_exist
|
|
|
| from easydeploy.model import DeployModel, MMYOLOBackend
|
|
|
# Silence these warning categories process-wide for the whole export run
# (presumably to keep the tracing/export log readable). filterwarnings()
# inserts each filter at the front of the list, so looping in this order
# yields exactly the same filter precedence as five separate calls.
for _ignored_category in (torch.jit.TracerWarning,
                          torch.jit.ScriptWarning,
                          UserWarning,
                          FutureWarning,
                          ResourceWarning):
    warnings.filterwarnings(action='ignore', category=_ignored_category)
del _ignored_category
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv (list[str] | None): Argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as
            ``argparse.ArgumentParser.parse_args`` does.

    Returns:
        argparse.Namespace: The parsed options. ``img_size`` is always a
        two-element ``[height, width]`` list; a single value given on the
        command line is used for both dimensions.

    Raises:
        SystemExit: Via ``parser.error`` when the arguments are invalid,
            including an ``--img-size`` with more than two values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--custom-text',
                        type=str,
                        help='custom text inputs (text json) for YOLO-World.')
    parser.add_argument('--add-padding',
                        action='store_true',
                        help='add an empty padding to texts.')
    parser.add_argument('--model-only',
                        action='store_true',
                        help='Export model only')
    parser.add_argument('--without-nms',
                        action='store_true',
                        help='Export model without NMS')
    parser.add_argument('--without-bbox-decoder',
                        action='store_true',
                        help='Export model without Bbox Decoder (for INT8 Quantization)')
    parser.add_argument('--work-dir',
                        default='./work_dirs',
                        help='Path to save export model')
    parser.add_argument('--img-size',
                        nargs='+',
                        type=int,
                        default=[640, 640],
                        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference')
    parser.add_argument('--simplify',
                        action='store_true',
                        help='Simplify onnx model by onnx-sim')
    parser.add_argument('--opset',
                        type=int,
                        default=11,
                        help='ONNX opset version')
    parser.add_argument('--backend',
                        type=str,
                        default='onnxruntime',
                        help='Backend for export onnx')
    parser.add_argument('--pre-topk',
                        type=int,
                        default=1000,
                        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument('--keep-topk',
                        type=int,
                        default=100,
                        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument('--iou-threshold',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    parser.add_argument('--score-threshold',
                        type=float,
                        default=0.25,
                        help='Score threshold for NMS')
    args = parser.parse_args(argv)

    if len(args.img_size) == 1:
        # A single value means a square input: reuse it for height and width.
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        # Fail fast: the size is splatted into torch.randn(batch, 3, *size),
        # so any other length would build a wrongly-ranked dummy input.
        parser.error('--img-size expects 1 or 2 values (height [width]), '
                     f'got {len(args.img_size)}')
    return args
|
|
|
|
|
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Create the detector via ``init_detector`` and switch it to eval mode.

    Args:
        config_path: Config file forwarded to ``init_detector``.
        checkpoint_path: Checkpoint file forwarded to ``init_detector``.
        device: Device forwarded to ``init_detector`` (the CLI default is
            ``'cuda:0'``).

    Returns:
        The model returned by ``init_detector``, after ``.eval()``.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    # Export must trace inference behaviour, not training behaviour.
    detector.eval()
    return detector
|
|
|
|
|
def _load_texts(custom_text, add_padding):
    """Return the class texts used to reparameterize the detector.

    Args:
        custom_text (str | None): Path to a JSON file. If empty or None, the
            class names from ``mmdet.datasets.CocoDataset.METAINFO`` are used.
            The JSON is loaded as a sequence whose items are indexed with
            ``[0]``; only that first element of each item is kept.
        add_padding (bool): Append a single-space entry to the texts.

    Returns:
        list: The texts, always as a list (the COCO class names are stored
        as a tuple, so they are copied into a list here).
    """
    if custom_text:
        # JSON text is UTF-8 by specification; do not depend on the locale.
        with open(custom_text, encoding='utf-8') as f:
            texts = [x[0] for x in json.load(f)]
    else:
        from mmdet.datasets import CocoDataset
        # METAINFO['classes'] is a tuple; without list() the padding below
        # would raise TypeError (tuple + list).
        texts = list(CocoDataset.METAINFO['classes'])
    if add_padding:
        texts = texts + [' ']
    return texts


def _onnx_save_path(work_dir, checkpoint):
    """Return ``<work_dir>/<checkpoint file stem>.onnx``.

    Only the final extension of the checkpoint's file name is replaced. Two
    cases are handled that a plain ``replace('pth', 'onnx')`` gets wrong:
    ``pth`` appearing elsewhere in the name, and checkpoints whose extension
    is not ``.pth``.
    """
    stem = os.path.splitext(os.path.basename(checkpoint))[0]
    return os.path.join(work_dir, stem + '.onnx')


def _set_trt_output_shapes(onnx_model, batch_size, keep_topk):
    """Write fixed sizes into ``dim_param`` of every graph-output dimension.

    Dimensions are filled in graph-output order with
    ``(batch_size, 1), (batch_size, keep_topk, 4), (batch_size, keep_topk),
    (batch_size, keep_topk)``. This presumably corresponds to the end-to-end
    outputs ``num_dets, boxes, scores, labels`` named in ``main``. It
    assumes the graph has exactly those outputs/ranks; extra dims raise
    StopIteration.
    """
    sizes = iter([
        batch_size, 1,
        batch_size, keep_topk, 4,
        batch_size, keep_topk,
        batch_size, keep_topk,
    ])
    for output in onnx_model.graph.output:
        for dim in output.type.tensor_type.shape.dim:
            dim.dim_param = str(next(sizes))


def _simplify_onnx(onnx_model):
    """Try to simplify ``onnx_model`` with onnx-sim (best effort).

    Returns the simplified model only if onnxsim reports a successful check.
    Otherwise, including when onnxsim is missing or raises, the failure is
    logged and the unmodified input model is returned.
    """
    try:
        import onnxsim
        simplified, check = onnxsim.simplify(onnx_model)
    except Exception as e:  # deliberate best effort: export still succeeds
        print_log(f'Simplify failure: {e}')
        return onnx_model
    if not check:
        # Keep the original graph rather than saving an unverified one.
        print_log('Simplify failure: assert check failed')
        return onnx_model
    return simplified


def main():
    """Export a detector built from ``config``/``checkpoint`` to ONNX.

    Steps:
      * Build the detector.
      * Call its ``reparameterize`` with the class texts, if it has one.
      * Wrap it in ``DeployModel``.
      * Trace it with ``torch.onnx.export`` on a random
        ``(batch, 3, H, W)`` input.
      * Validate, optionally patch or simplify, and save the result to
        ``<work_dir>/<checkpoint stem>.onnx``.

    Output layout depends on the flags:
      * Default: ``num_dets, boxes, scores, labels``.
      * ``--without-nms`` or ``--without-bbox-decoder``: ``scores, boxes``.
      * ``--model-only``, or a backend outside the supported list: no
        post-processing config and no output names.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)

    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Post-processing is only exported for the backends listed above.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(pre_top_k=args.pre_topk,
                                     keep_top_k=args.keep_topk,
                                     iou_threshold=args.iou_threshold,
                                     score_threshold=args.score_threshold)
        if args.without_bbox_decoder or args.without_nms:
            output_names = ['scores', 'boxes']
        else:
            output_names = ['num_dets', 'boxes', 'scores', 'labels']

    texts = _load_texts(args.custom_text, args.add_padding)

    base_model = build_model_from_cfg(args.config, args.checkpoint,
                                      args.device)
    if hasattr(base_model, 'reparameterize'):
        # The texts must be applied before tracing: the exported graph's
        # only input is 'images'.
        base_model.reparameterize([texts])
    deploy_model = DeployModel(baseModel=base_model,
                               backend=backend,
                               postprocess_cfg=postprocess_cfg,
                               with_nms=not args.without_nms,
                               without_bbox_decoder=args.without_bbox_decoder)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # Warm-up forward pass before tracing; no autograd graph is needed.
    with torch.no_grad():
        deploy_model(fake_input)

    save_onnx_path = _onnx_save_path(args.work_dir, args.checkpoint)

    with BytesIO() as f:
        torch.onnx.export(deploy_model,
                          fake_input,
                          f,
                          input_names=['images'],
                          output_names=output_names,
                          opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
    onnx.checker.check_model(onnx_model)

    # The fixed TensorRT shapes describe the 4-output end-to-end graph only.
    # With --without-bbox-decoder the graph has just 'scores, boxes', so the
    # shape list must not be applied there either.
    end2end = not (args.model_only or args.without_nms
                   or args.without_bbox_decoder)
    if end2end and backend in (MMYOLOBackend.TENSORRT8,
                               MMYOLOBackend.TENSORRT7):
        _set_trt_output_shapes(onnx_model, args.batch_size, args.keep_topk)

    if args.simplify:
        onnx_model = _simplify_onnx(onnx_model)
    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')
|
|
|
|
|
if __name__ == '__main__':
    # Run the export only when executed as a script, not when imported.
    main()
|
|
|