Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from functools import partial | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from mmcv.runner import load_checkpoint | |
def generate_inputs_and_wrap_model(config_path,
                                   checkpoint_path,
                                   input_config,
                                   cfg_options=None):
    """Prepare sample input and wrap model for ONNX export.

    The ONNX export API only accepts args, and all inputs should be
    torch.Tensor or corresponding types (such as tuple of tensor).
    So we should call this function before exporting. This function will:

    1. Generate corresponding inputs which are used to execute the model.
    2. Wrap the model's forward function.

    For example, the MMDet models' forward function has a parameter
    ``return_loss:bool``. We want to set it to False, but the export API
    supports neither bool type nor kwargs. So we have to replace the forward
    method like ``model.forward = partial(model.forward, return_loss=False)``.

    Args:
        config_path (str): the OpenMMLab config for the model we want to
            export to ONNX
        checkpoint_path (str): Path to the corresponding checkpoint
        input_config (dict): the exact data in this dict depends on the
            framework. For MMSeg, we can just declare the input shape,
            and generate the dummy data accordingly. However, for MMDet,
            we may pass the real img path, or the NMS will return None
            as there is no legal bbox.
        cfg_options (dict, optional): passed through unchanged to
            ``build_model_from_cfg``, which merges it into the loaded
            config. Defaults to None.

    Returns:
        tuple: (model, tensor_data) wrapped model which can be called by
            ``model(*tensor_data)`` and a list of inputs which are used to
            execute the model while exporting.

    Raises:
        NotImplementedError: if ``mmcv.onnx.symbolic`` cannot be imported
            (the installed mmcv is too old).
    """
    model = build_model_from_cfg(
        config_path, checkpoint_path, cfg_options=cfg_options)
    one_img, one_meta = preprocess_example_input(input_config)
    tensor_data = [one_img]
    # Bind the non-tensor arguments now so that the exporter only has to
    # supply the image tensor positionally.
    model.forward = partial(
        model.forward, img_metas=[[one_meta]], return_loss=False)

    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    opset_version = 11
    # put the import within the function thus it will not cause import error
    # when not using this function
    try:
        from mmcv.onnx.symbolic import register_extra_symbolics
    except ModuleNotFoundError as err:
        # Chain explicitly so the underlying import failure stays attached
        # as the cause of the user-facing error.
        raise NotImplementedError(
            'please update mmcv to version>=v1.0.4') from err
    register_extra_symbolics(opset_version)

    return model, tensor_data
def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None):
    """Build a model from config and load the given checkpoint.

    Args:
        config_path (str): the OpenMMLab config for the model we want to
            export to ONNX
        checkpoint_path (str): Path to the corresponding checkpoint
        cfg_options (dict, optional): options merged into the loaded config
            with ``merge_from_dict`` before the model is built. Defaults to
            None, in which case the config is used as loaded.

    Returns:
        torch.nn.Module: the built model, moved to CPU and set to eval mode.

    Raises:
        AssertionError: if the checkpoint meta has no ``CLASSES`` entry and
            the test dataset type from the config is not found in
            ``DATASETS``.
    """
    from mmdet.models import build_detector

    cfg = mmcv.Config.fromfile(config_path)
    if cfg_options is not None:
        cfg.merge_from_dict(cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the model
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    checkpoint = load_checkpoint(model, checkpoint_path, map_location='cpu')
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        # Fall back to the class names declared on the test dataset class.
        from mmdet.datasets import DATASETS
        dataset_type = cfg.data.test['type']
        dataset = DATASETS.get(dataset_type)
        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; the exception type is kept for compatibility.
        if dataset is None:
            raise AssertionError(
                f'dataset type {dataset_type!r} is not registered in '
                'DATASETS and the checkpoint meta has no CLASSES entry')
        model.CLASSES = dataset.CLASSES
    model.cpu().eval()
    return model
def preprocess_example_input(input_config):
    """Prepare an example input image for ``generate_inputs_and_wrap_model``.

    The image at ``input_config['input_path']`` is read, resized to the
    spatial size given by ``input_config['input_shape']``, optionally
    normalized, and converted to a batched float tensor.

    Args:
        input_config (dict): customized config describing the example input.
            Required keys are ``'input_path'`` and ``'input_shape'`` (a
            4-tuple unpacked as ``(_, C, H, W)``). The optional key
            ``'normalize_cfg'`` is a dict with ``'mean'``, ``'std'`` and an
            optional ``'to_rgb'`` (defaults to True).

    Returns:
        tuple: (one_img, one_meta). ``one_img`` is the float tensor of the
            example image with a leading batch dimension and
            ``requires_grad`` enabled. ``one_meta`` is a dict with keys
            ``'img_shape'``, ``'ori_shape'`` and ``'pad_shape'`` (each
            ``(H, W, C)`` taken from ``input_shape``), ``'filename'``,
            ``'scale_factor'`` (float32 array of four ones), ``'flip'``
            (False), ``'show_img'`` (copy of the resized image taken before
            normalization) and ``'flip_direction'`` (None).

    Examples:
        >>> from mmdet.core.export import preprocess_example_input
        >>> input_config = {
        >>>     'input_shape': (1, 3, 224, 224),
        >>>     'input_path': 'demo/demo.jpg',
        >>>     'normalize_cfg': {
        >>>         'mean': (123.675, 116.28, 103.53),
        >>>         'std': (58.395, 57.12, 57.375)
        >>>     }
        >>> }
        >>> one_img, one_meta = preprocess_example_input(input_config)
        >>> print(one_img.shape)
        torch.Size([1, 3, 224, 224])
        >>> print(one_meta['img_shape'])
        (224, 224, 3)
    """
    img_path = input_config['input_path']
    shape = input_config['input_shape']

    # shape[2:] is (H, W); the resize call receives it reversed, i.e. (W, H).
    resized = mmcv.imresize(mmcv.imread(img_path), shape[2:][::-1])
    # Keep an un-normalized copy for the 'show_img' meta entry.
    display_img = resized.copy()

    net_input = resized
    if 'normalize_cfg' in input_config:
        norm = input_config['normalize_cfg']
        mean = np.array(norm['mean'], dtype=np.float32)
        std = np.array(norm['std'], dtype=np.float32)
        net_input = mmcv.imnormalize(
            resized, mean, std, to_rgb=norm.get('to_rgb', True))

    # HWC -> CHW, then add the batch dimension.
    chw = net_input.transpose(2, 0, 1)
    tensor = torch.from_numpy(chw).unsqueeze(0).float()
    tensor.requires_grad_(True)

    _, channels, height, width = shape
    hwc = (height, width, channels)
    meta = {
        'img_shape': hwc,
        'ori_shape': hwc,
        'pad_shape': hwc,
        'filename': '<demo>.png',
        'scale_factor': np.ones(4, dtype=np.float32),
        'flip': False,
        'show_img': display_img,
        'flip_direction': None
    }
    return tensor, meta