Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import warnings | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from mmcv.ops import RoIPool | |
| from mmcv.parallel import collate, scatter | |
| from mmcv.runner import load_checkpoint | |
| from mmdet.core import get_classes | |
| from mmdet.datasets import replace_ImageToTensor | |
| from mmdet.datasets.pipelines import Compose | |
| from mmocr.models import build_detector | |
| from mmocr.utils import is_2dlist | |
| from .utils import disable_text_recog_aug_test | |
def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object. A config object is modified in place (``cfg_options``
            are merged into it, ``model.pretrained`` may be cleared and
            ``model.train_cfg`` is set to None).
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights and ``model.CLASSES`` is not set here.
        device: Target handed to ``model.to`` before the model is returned.
            Defaults to ``'cuda:0'``.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector, switched to eval mode, with the
        config attached as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')

    if cfg_options is not None:
        config.merge_from_dict(cfg_options)

    # Drop the `pretrained` entry (when set) and the training config before
    # the model is built; weights come only from `checkpoint` below.
    if config.model.get('pretrained'):
        config.model.pretrained = None
    config.model.train_cfg = None

    model = build_detector(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        loaded = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = loaded.get('meta', {})
        if 'CLASSES' in meta:
            model.CLASSES = loaded['meta']['CLASSES']
        else:
            # NOTE: simplefilter() edits the process-wide warning filter, so
            # this also affects warnings raised after this call returns.
            warnings.simplefilter('once')
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
def model_inference(model,
                    imgs,
                    ann=None,
                    batch_mode=False,
                    return_data=False):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector. Its ``cfg`` attribute is read
            to build the test pipeline; the resolved pipeline may be written
            back to that config.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images. A list/tuple must contain
            only strings or only arrays.
        ann (dict): Annotation info for key information extraction. When
            given, its entries are merged into every sample.
        batch_mode (bool): If True, use batch mode for inference.
        return_data (bool): If True, also return the postprocessed pipeline
            output alongside the predictions.

    Returns:
        The predicted results. For a single (non-list) input this is the
        first element of the model output, otherwise the whole output. With
        ``return_data`` set, a ``(results, data)`` tuple is returned instead,
        where ``data`` is the pipeline output for the same image(s).

    Raises:
        ValueError: If ``imgs`` is an empty list/tuple, or if the pipeline
            yields augmented (list) images for more than one input.
        AssertionError: If ``imgs`` (or any element of it) is neither a
            string nor a numpy array, if strings and arrays are mixed, or if
            a CPU model contains an ``RoIPool`` module.
    """
    if isinstance(imgs, (list, tuple)):
        is_batch = True
        if len(imgs) == 0:
            raise ValueError('empty imgs provided, please check and try again')
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')
        # The first element decides which loader is used for the whole
        # batch, so every other element has to be of the same kind.
        expected = np.ndarray if isinstance(imgs[0], np.ndarray) else str
        if not all(isinstance(img, expected) for img in imgs):
            raise AssertionError(
                'imgs must be all strings or all numpy arrays')
    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
        is_batch = False
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    is_ndarray = isinstance(imgs[0], np.ndarray)

    cfg = model.cfg
    if batch_mode:
        # presumably removes test-time augmentation so samples can be
        # stacked -- verify against `disable_text_recog_aug_test`.
        cfg = disable_text_recog_aug_test(cfg, set_types=['test'])

    first_param = next(model.parameters())
    device = first_param.device  # model device

    # Resolve the test pipeline: fall back to the first (possibly nested)
    # dataset entry, then flatten a list of pipelines to its first one.
    if cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(cfg.data.test.datasets):
            cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
        else:
            cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
    if is_2dlist(cfg.data.test.pipeline):
        cfg.data.test.pipeline = cfg.data.test.pipeline[0]

    if is_ndarray:
        # NOTE(review): how deep this copy goes is not visible here; confirm
        # whether the loader swap below leaks into `model.cfg`.
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'

    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    # Loop-invariant: in batch mode the per-key lists produced by
    # MultiScaleFlipAug are reduced to their first entry so they can stack.
    take_first_aug = (
        batch_mode
        and cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug')

    datas = []
    for img in imgs:
        # prepare data
        if is_ndarray:
            # directly add img
            data = dict(
                img=img,
                ann_info=ann,
                img_info=dict(width=img.shape[1], height=img.shape[0]),
                bbox_fields=[])
        else:
            # add information into dict
            data = dict(
                img_info=dict(filename=img),
                img_prefix=None,
                ann_info=ann,
                bbox_fields=[])
        if ann is not None:
            data.update(ann)

        # build the data pipeline
        data = test_pipeline(data)
        # get tensor from list to stack for batch mode (text detection)
        if take_first_aug:
            for key, value in data.items():
                data[key] = value[0]
        datas.append(data)

    if isinstance(datas[0]['img'], list) and len(datas) > 1:
        raise ValueError('aug test does not support '
                         f'inference with batch size '
                         f'{len(datas)}')

    data = collate(datas, samples_per_gpu=len(imgs))

    # process img_metas: unwrap the `.data` payload of the collated values
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data

    # for KIE models
    if ann is not None:
        data['relations'] = data['relations'].data[0]
        data['gt_bboxes'] = data['gt_bboxes'].data[0]
        data['texts'] = data['texts'].data[0]
        data['img'] = data['img'][0]
        data['img_metas'] = data['img_metas'][0]

    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    elif any(isinstance(m, RoIPool) for m in model.modules()):
        # Raised explicitly so the check also runs under `python -O`.
        raise AssertionError(
            'CPU inference with RoIPool is not supported currently.')

    # forward the model
    with torch.no_grad():
        results = model(return_loss=False, rescale=True, **data)

    if is_batch:
        return (results, datas) if return_data else results
    return (results[0], datas[0]) if return_data else results[0]
def text_model_inference(model, input_sentence):
    """Inference text(s) with the entity recognizer.

    Args:
        model (nn.Module): The loaded recognizer. Its ``cfg`` attribute is
            read to build the test pipeline; the resolved pipeline may be
            written back to that config.
        input_sentence (str): A text entered by the user.

    Returns:
        The model output for the sentence (the value returned by calling
        ``model(None, img_metas, return_loss=False)``).

    Raises:
        AssertionError: If ``input_sentence`` is not a str, or if the
            pipeline does not produce a dict of ``img_metas``.
    """
    # Raised explicitly (instead of `assert`) so the check survives
    # `python -O`.
    if not isinstance(input_sentence, str):
        raise AssertionError(
            f'input_sentence must be a str, but got {type(input_sentence)}')

    cfg = model.cfg
    # Resolve the test pipeline: fall back to the first (possibly nested)
    # dataset entry, then flatten a list of pipelines to its first one.
    if cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(cfg.data.test.datasets):
            cfg.data.test.pipeline = cfg.data.test.datasets[0][0].pipeline
        else:
            cfg.data.test.pipeline = cfg.data.test.datasets[0].pipeline
    if is_2dlist(cfg.data.test.pipeline):
        cfg.data.test.pipeline = cfg.data.test.pipeline[0]
    test_pipeline = Compose(cfg.data.test.pipeline)

    # build the data pipeline
    data = test_pipeline({'text': input_sentence, 'label': {}})

    # The pipeline may return the metas as a plain dict or wrapped in an
    # object exposing them through `.data`.
    img_metas = data['img_metas']
    if not isinstance(img_metas, dict):
        img_metas = img_metas.data
    if not isinstance(img_metas, dict):
        raise AssertionError(
            f'img_metas must be a dict, but got {type(img_metas)}')

    # Add a leading dimension of size 1 to each tensor the model consumes.
    meta_keys = ('input_ids', 'attention_masks', 'token_type_ids', 'labels')
    img_metas = {key: img_metas[key].unsqueeze(0) for key in meta_keys}

    # forward the model
    with torch.no_grad():
        result = model(None, img_metas, return_loss=False)
    return result