| # Copyright (c) OpenMMLab. All rights reserved. | |
| import warnings | |
| from argparse import ArgumentParser | |
| from functools import partial | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from mmcv.onnx import register_extra_symbolics | |
| from mmcv.parallel import collate | |
| from mmdet.datasets import replace_ImageToTensor | |
| from mmdet.datasets.pipelines import Compose | |
| from torch import nn | |
| from mmocr.apis import init_detector | |
| from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer | |
| from mmocr.datasets.pipelines.crop import crop_img # noqa: F401 | |
| from mmocr.utils import is_2dlist | |
| def _convert_batchnorm(module): | |
| module_output = module | |
| if isinstance(module, torch.nn.SyncBatchNorm): | |
| module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, | |
| module.momentum, module.affine, | |
| module.track_running_stats) | |
| if module.affine: | |
| module_output.weight.data = module.weight.data.clone().detach() | |
| module_output.bias.data = module.bias.data.clone().detach() | |
| # keep requires_grad unchanged | |
| module_output.weight.requires_grad = module.weight.requires_grad | |
| module_output.bias.requires_grad = module.bias.requires_grad | |
| module_output.running_mean = module.running_mean | |
| module_output.running_var = module.running_var | |
| module_output.num_batches_tracked = module.num_batches_tracked | |
| for name, child in module.named_children(): | |
| module_output.add_module(name, _convert_batchnorm(child)) | |
| del module | |
| return module_output | |
def _prepare_data(cfg, imgs):
    """Run the test pipeline on the input image(s) and collate a batch.

    Args:
        cfg (mmcv.Config): Model config providing ``data.test.pipeline``.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        dict: Collated ``img`` and ``img_metas`` ready for inference.
    """
    # normalize input to a non-empty list of paths or arrays
    if isinstance(imgs, (list, tuple)):
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')
    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    from_ndarray = isinstance(imgs[0], np.ndarray)
    if from_ndarray:
        cfg = cfg.copy()
        # in-memory arrays need the ndarray-loading op instead of file loading
        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    pipeline = Compose(cfg.data.test.pipeline)

    batch = []
    for img in imgs:
        if from_ndarray:
            sample = dict(img=img)
        else:
            sample = dict(img_info=dict(filename=img), img_prefix=None)
        batch.append(pipeline(sample))

    # aug-test pipelines produce a list of images per sample; those cannot
    # be batched
    if len(batch) > 1 and isinstance(batch[0]['img'], list):
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(batch)}')

    data = collate(batch, samples_per_gpu=len(imgs))

    # unwrap the DataContainer objects produced by collate
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data
    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data
    return data
def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model we want to export.
        model_type (str): Model type, detection ('det') or recognition
            ('recog') model.
        img_path (str): We need to use this input to execute the model.
        verbose (bool): Whether print the computation graph. Default: False.
        show (bool): Whether visualize final results. Default: False.
        opset_version (int): The onnx op version. Default: 11.
        output_file (string): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether apply dynamic export.
            Default: False.
        device_id (int): Device id to place model and data.
            Default: 0
    """
    # NOTE(review): a CUDA device is required here; there is no CPU
    # fallback in this script.
    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    _convert_batchnorm(model)
    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')
    if isinstance(imgs, list):
        imgs = imgs[0]
    # add a batch dimension to each image tensor
    img_list = [img[None, :].to(device) for img in imgs]
    # rebind forward so torch.onnx.export traces the inference path with
    # img_metas already bound; restored from origin_forward before verify
    origin_forward = model.forward
    if (model_type == 'det'):
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            model.forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)
    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)
    # mark batch/spatial dims dynamic; recognition keeps a fixed height
    dynamic_axes = None
    if dynamic_export and model_type == 'det':
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    elif dynamic_export and model_type == 'recog':
        dynamic_axes = {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                1: 'seq_len',
                2: 'num_classes'
            }
        }
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list[0], ),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=verbose,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
        if dynamic_export:
            # scale image for dynamic shape test
            img_list = [
                nn.functional.interpolate(_, scale_factor=scale_factor)
                for _ in img_list
            ]
            if model_type == 'det':
                # tuple repetition: (0.5, 0.5) * 2 gives a 4-element
                # per-coordinate scale matching the meta's scale_factor
                img_metas[0][0][
                    'scale_factor'] = img_metas[0][0]['scale_factor'] * (
                        scale_factor * 2)
        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            model.forward = origin_forward
            pytorch_out = model.simple_test(
                img_list[0], img_metas[0], rescale=True)
        # get onnx output
        if model_type == 'det':
            onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
        else:
            onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
                                               device_id)
        onnx_out = onnx_model.simple_test(
            img_list[0], img_metas[0], rescale=True)
        # compare results
        same_diff = 'same'
        if model_type == 'recog':
            # recognizers: texts must match exactly, scores within tolerance
            for onnx_result, pytorch_result in zip(onnx_out, pytorch_out):
                if onnx_result['text'] != pytorch_result[
                        'text'] or not np.allclose(
                            np.array(onnx_result['score']),
                            np.array(pytorch_result['score']),
                            rtol=1e-4,
                            atol=1e-4):
                    same_diff = 'different'
                    break
        else:
            # detectors: compare each predicted boundary within tolerance
            for onnx_result, pytorch_result in zip(
                    onnx_out[0]['boundary_result'],
                    pytorch_out[0]['boundary_result']):
                if not np.allclose(
                        np.array(onnx_result),
                        np.array(pytorch_result),
                        rtol=1e-4,
                        atol=1e-4):
                    same_diff = 'different'
                    break
        print('The outputs are {} between PyTorch and ONNX'.format(same_diff))
        if show:
            # side-by-side visualization; fall back to the raw image when
            # show_result returns nothing to draw
            onnx_img = onnx_model.show_result(
                img_path, onnx_out[0], out_file='onnx.jpg', show=False)
            pytorch_img = model.show_result(
                img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
            if onnx_img is None:
                onnx_img = cv2.imread(img_path)
            if pytorch_img is None:
                pytorch_img = cv2.imread(img_path)
            cv2.imshow('PyTorch', pytorch_img)
            cv2.imshow('ONNXRuntime', onnx_img)
            cv2.waitKey()
    return
def main():
    """Parse CLI arguments and convert an MMOCR model to ONNX.

    Positional arguments select the config, checkpoint, model type
    ('recog' or 'det') and the sample image used to trace the model;
    optional flags forward to :func:`pytorch2onnx`.
    """
    parser = ArgumentParser(
        description='Convert MMOCR models from pytorch to ONNX')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_ckpt', type=str, help='Checkpoint file (local or url).')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Input Image file.')
    parser.add_argument(
        '--output-file',
        type=str,
        help='Output file name of the onnx model.',
        default='tmp.onnx')
    # FIX: without type=int a user-supplied value is parsed as str, and
    # torch.device(type='cuda', index='1') raises TypeError downstream.
    parser.add_argument(
        '--device-id',
        type=int,
        default=0,
        help='Device used for inference.')
    parser.add_argument(
        '--opset-version',
        type=int,
        help='ONNX opset version, default to 11.',
        default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether verify the outputs of onnx and pytorch are same.',
        default=False)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether print the computation graph.',
        default=False)
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether visualize final output.',
        default=False)
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether dynamically export onnx model.',
        default=False)
    args = parser.parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    device = torch.device(type='cuda', index=args.device_id)

    # build model
    model = init_detector(args.model_config, args.model_ckpt, device=device)
    if hasattr(model, 'module'):
        model = model.module
    # resolve the test pipeline when the config nests it under datasets
    if model.cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(model.cfg.data.test.datasets):
            model.cfg.data.test.pipeline = \
                model.cfg.data.test.datasets[0][0].pipeline
        else:
            model.cfg.data.test.pipeline = \
                model.cfg.data.test['datasets'][0].pipeline
    if is_2dlist(model.cfg.data.test.pipeline):
        model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]

    pytorch2onnx(
        model,
        model_type=args.model_type,
        output_file=args.output_file,
        img_path=args.image_path,
        opset_version=args.opset_version,
        verify=args.verify,
        verbose=args.verbose,
        show=args.show,
        device_id=args.device_id,
        dynamic_export=args.dynamic_export)


if __name__ == '__main__':
    main()