Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from mmcv.image import tensor2imgs | |
| from mmcv.parallel import DataContainer | |
| from mmdet.core import encode_mask_results | |
| from .utils import tensor2grayimgs | |
def retrieve_img_tensor_and_meta(data):
    """Retrieve img_tensor, img_metas and img_norm_cfg from one batch.

    Args:
        data (dict): One batch data from data_loader. Must provide the keys
            ``'img'`` and ``'img_metas'``.

    Returns:
        tuple: Returns (img_tensor, img_metas, img_norm_cfg).

            - | img_tensor (Tensor): Input image tensor with shape
                :math:`(N, C, H, W)`.
            - | img_metas (list[dict]): The metadata of images.
            - | img_norm_cfg (dict): Config for image normalization. If the
                stored ``mean`` has a maximum <= 1, ``mean`` and ``std`` are
                returned multiplied by 255. This is a copy: the metadata
                inside ``data`` is left unmodified.

    Raises:
        TypeError: If ``data['img']`` (or its first element, when it is a
            list) is neither a ``torch.Tensor`` nor a ``DataContainer``.
        KeyError: If a required key is missing from the first image's
            metadata.
    """
    img = data['img']
    if isinstance(img, torch.Tensor):
        # for textrecog with batch_size > 1
        # and not use 'DefaultFormatBundle' in pipeline
        img_tensor = img
        img_metas = data['img_metas'].data[0]
    elif isinstance(img, list):
        if isinstance(img[0], torch.Tensor):
            # for textrecog with aug_test and batch_size = 1
            img_tensor = img[0]
        elif isinstance(img[0], DataContainer):
            # for textdet with 'MultiScaleFlipAug'
            # and 'DefaultFormatBundle' in pipeline
            img_tensor = img[0].data[0]
        else:
            # Previously fell through and failed later with an unbound
            # local; fail here with an actionable message instead.
            raise TypeError(
                'Unsupported element type in data["img"]: '
                f'{type(img[0]).__name__}')
        img_metas = data['img_metas'][0].data[0]
    elif isinstance(img, DataContainer):
        # for textrecog with 'DefaultFormatBundle' in pipeline
        img_tensor = img.data[0]
        img_metas = data['img_metas'].data[0]
    else:
        raise TypeError(
            f'Unsupported type for data["img"]: {type(img).__name__}')

    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape', 'ori_shape']
    for key in must_keys:
        if key not in img_metas[0]:
            raise KeyError(
                f'Please add {key} to the "meta_keys" in the pipeline')

    # Shallow-copy so the rescaling below does not leak back into the batch
    # metadata (mean/std are rebound to new lists, never mutated).
    img_norm_cfg = dict(img_metas[0]['img_norm_cfg'])
    # Heuristic: a mean whose largest entry is <= 1 is taken to be on the
    # 0-1 scale and is converted to the 0-255 pixel scale.
    if max(img_norm_cfg['mean']) <= 1:
        img_norm_cfg['mean'] = [255 * x for x in img_norm_cfg['mean']]
        img_norm_cfg['std'] = [255 * x for x in img_norm_cfg['std']]
    return img_tensor, img_metas, img_norm_cfg
def _show_kie_results(model, data, result, show, out_dir):
    """Visualize KIE predictions for one batch (batch size must be 1)."""
    img_tensor = data['img'].data[0]
    if img_tensor.shape[0] != 1:
        raise KeyError('Visualizing KIE outputs in batches is '
                       'currently not supported.')
    gt_bboxes = data['gt_bboxes'].data[0]
    img_metas = data['img_metas'].data[0]
    must_keys = ['img_norm_cfg', 'ori_filename', 'img_shape']
    for key in must_keys:
        if key not in img_metas[0]:
            raise KeyError(
                f'Please add {key} to the "meta_keys" in config.')

    if np.prod(img_tensor.shape) == 0:
        # for no visual model: the image tensor is empty, so reload the
        # images from disk.
        imgs = []
        for img_meta in img_metas:
            try:
                img = mmcv.imread(img_meta['filename'])
            except Exception as e:
                # Deliberate best-effort: fall back to a blank image of the
                # recorded shape rather than aborting the evaluation.
                print(f'Load image with error: {e}, '
                      'use empty image instead.')
                img = np.ones(img_meta['img_shape'], dtype=np.uint8)
            imgs.append(img)
    else:
        imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])

    for i, img in enumerate(imgs):
        # Take only (h, w) so 2-element shapes and 2-D fallback images work
        # too; matches how the non-KIE path crops.
        h, w = img_metas[i]['img_shape'][:2]
        img_show = img[:h, :w]
        if out_dir:
            out_file = osp.join(out_dir, img_metas[i]['ori_filename'])
        else:
            out_file = None
        model.module.show_result(
            img_show, result[i], gt_bboxes[i], show=show, out_file=out_file)


def _show_results(model, data, result, show, out_dir, show_score_thr):
    """Visualize text detection/recognition predictions for one batch."""
    img_tensor, img_metas, img_norm_cfg = \
        retrieve_img_tensor_and_meta(data)
    if img_tensor.size(1) == 1:
        imgs = tensor2grayimgs(img_tensor, **img_norm_cfg)
    else:
        imgs = tensor2imgs(img_tensor, **img_norm_cfg)
    assert len(imgs) == len(img_metas)

    for j, (img, img_meta) in enumerate(zip(imgs, img_metas)):
        img_shape, ori_shape = img_meta['img_shape'], img_meta['ori_shape']
        # Drop anything beyond img_shape (presumably batch padding), then
        # resize back to the original size; mmcv.imresize takes (w, h).
        img_show = img[:img_shape[0], :img_shape[1]]
        img_show = mmcv.imresize(img_show, (ori_shape[1], ori_shape[0]))
        if out_dir:
            out_file = osp.join(out_dir, img_meta['ori_filename'])
        else:
            out_file = None
        model.module.show_result(
            img_show,
            result[j],
            show=show,
            out_file=out_file,
            score_thr=show_score_thr)


def single_gpu_test(model,
                    data_loader,
                    show=False,
                    out_dir=None,
                    is_kie=False,
                    show_score_thr=0.3):
    """Run inference over every batch of ``data_loader`` on a single device.

    Args:
        model (nn.Module): Model called as
            ``model(return_loss=False, rescale=True, **data)``.
            Visualization goes through ``model.module.show_result``, so the
            model must expose ``.module``. Presumably it is a data-parallel
            wrapper; confirm against callers.
        data_loader: Iterable of batch dicts whose ``.dataset`` supports
            ``len()`` (used to size the progress bar).
        show (bool): Whether to display visualized results.
        out_dir (str | None): If set, visualizations are written to
            ``osp.join(out_dir, ori_filename)``.
        is_kie (bool): Use the KIE visualization path. That path needs
            ``'gt_bboxes'`` in each batch and a batch size of 1.
        show_score_thr (float): Score threshold forwarded to
            ``show_result`` on the non-KIE path.

    Returns:
        list: The per-batch model outputs, concatenated in loader order.
        Tuple outputs ``(bbox_results, mask_results)`` have their masks
        passed through ``encode_mask_results``.

    Raises:
        KeyError: When visualizing, if required image metadata keys are
            missing, or if a KIE batch contains more than one image.
    """
    model.eval()
    results = []
    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        batch_size = len(result)

        if show or out_dir:
            if is_kie:
                _show_kie_results(model, data, result, show, out_dir)
            else:
                _show_results(model, data, result, show, out_dir,
                              show_score_thr)

        # encode mask results (guarded so an empty batch result is a no-op)
        if result and isinstance(result[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in result]
        results.extend(result)

        for _ in range(batch_size):
            prog_bar.update()
    return results