| from functools import partial |
|
|
| import mmcv |
| import numpy as np |
| import torch |
| from mmcv.runner import load_checkpoint |
|
|
|
|
| def generate_inputs_and_wrap_model(config_path, |
| checkpoint_path, |
| input_config, |
| cfg_options=None): |
| """Prepare sample input and wrap model for ONNX export. |
| |
| The ONNX export API only accept args, and all inputs should be |
| torch.Tensor or corresponding types (such as tuple of tensor). |
| So we should call this function before exporting. This function will: |
| |
| 1. generate corresponding inputs which are used to execute the model. |
| 2. Wrap the model's forward function. |
| |
| For example, the MMDet models' forward function has a parameter |
| ``return_loss:bool``. As we want to set it as False while export API |
| supports neither bool type or kwargs. So we have to replace the forward |
| like: ``model.forward = partial(model.forward, return_loss=False)`` |
| |
| Args: |
| config_path (str): the OpenMMLab config for the model we want to |
| export to ONNX |
| checkpoint_path (str): Path to the corresponding checkpoint |
| input_config (dict): the exactly data in this dict depends on the |
| framework. For MMSeg, we can just declare the input shape, |
| and generate the dummy data accordingly. However, for MMDet, |
| we may pass the real img path, or the NMS will return None |
| as there is no legal bbox. |
| |
| Returns: |
| tuple: (model, tensor_data) wrapped model which can be called by \ |
| model(*tensor_data) and a list of inputs which are used to execute \ |
| the model while exporting. |
| """ |
|
|
| model = build_model_from_cfg( |
| config_path, checkpoint_path, cfg_options=cfg_options) |
| one_img, one_meta = preprocess_example_input(input_config) |
| tensor_data = [one_img] |
| model.forward = partial( |
| model.forward, img_metas=[[one_meta]], return_loss=False) |
|
|
| |
| |
| opset_version = 11 |
| |
| |
| try: |
| from mmcv.onnx.symbolic import register_extra_symbolics |
| except ModuleNotFoundError: |
| raise NotImplementedError('please update mmcv to version>=v1.0.4') |
| register_extra_symbolics(opset_version) |
|
|
| return model, tensor_data |
|
|
|
|
| def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None): |
| """Build a model from config and load the given checkpoint. |
| |
| Args: |
| config_path (str): the OpenMMLab config for the model we want to |
| export to ONNX |
| checkpoint_path (str): Path to the corresponding checkpoint |
| |
| Returns: |
| torch.nn.Module: the built model |
| """ |
| from mmdet.models import build_detector |
|
|
| cfg = mmcv.Config.fromfile(config_path) |
| if cfg_options is not None: |
| cfg.merge_from_dict(cfg_options) |
| |
| if cfg.get('custom_imports', None): |
| from mmcv.utils import import_modules_from_strings |
| import_modules_from_strings(**cfg['custom_imports']) |
| |
| if cfg.get('cudnn_benchmark', False): |
| torch.backends.cudnn.benchmark = True |
| cfg.model.pretrained = None |
| cfg.data.test.test_mode = True |
|
|
| |
| cfg.model.train_cfg = None |
| model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) |
| load_checkpoint(model, checkpoint_path, map_location='cpu') |
| model.cpu().eval() |
| return model |
|
|
|
|
| def preprocess_example_input(input_config): |
| """Prepare an example input image for ``generate_inputs_and_wrap_model``. |
| |
| Args: |
| input_config (dict): customized config describing the example input. |
| |
| Returns: |
| tuple: (one_img, one_meta), tensor of the example input image and \ |
| meta information for the example input image. |
| |
| Examples: |
| >>> from mmdet.core.export import preprocess_example_input |
| >>> input_config = { |
| >>> 'input_shape': (1,3,224,224), |
| >>> 'input_path': 'demo/demo.jpg', |
| >>> 'normalize_cfg': { |
| >>> 'mean': (123.675, 116.28, 103.53), |
| >>> 'std': (58.395, 57.12, 57.375) |
| >>> } |
| >>> } |
| >>> one_img, one_meta = preprocess_example_input(input_config) |
| >>> print(one_img.shape) |
| torch.Size([1, 3, 224, 224]) |
| >>> print(one_meta) |
| {'img_shape': (224, 224, 3), |
| 'ori_shape': (224, 224, 3), |
| 'pad_shape': (224, 224, 3), |
| 'filename': '<demo>.png', |
| 'scale_factor': 1.0, |
| 'flip': False} |
| """ |
| input_path = input_config['input_path'] |
| input_shape = input_config['input_shape'] |
| one_img = mmcv.imread(input_path) |
| one_img = mmcv.imresize(one_img, input_shape[2:][::-1]) |
| show_img = one_img.copy() |
| if 'normalize_cfg' in input_config.keys(): |
| normalize_cfg = input_config['normalize_cfg'] |
| mean = np.array(normalize_cfg['mean'], dtype=np.float32) |
| std = np.array(normalize_cfg['std'], dtype=np.float32) |
| to_rgb = normalize_cfg.get('to_rgb', True) |
| one_img = mmcv.imnormalize(one_img, mean, std, to_rgb=to_rgb) |
| one_img = one_img.transpose(2, 0, 1) |
| one_img = torch.from_numpy(one_img).unsqueeze(0).float().requires_grad_( |
| True) |
| (_, C, H, W) = input_shape |
| one_meta = { |
| 'img_shape': (H, W, C), |
| 'ori_shape': (H, W, C), |
| 'pad_shape': (H, W, C), |
| 'filename': '<demo>.png', |
| 'scale_factor': 1.0, |
| 'flip': False, |
| 'show_img': show_img, |
| } |
|
|
| return one_img, one_meta |
|
|