Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import tempfile | |
| import numpy as np | |
| import pytest | |
| from mmocr.datasets.pipelines.ocr_seg_targets import OCRSegTargets | |
def _create_dummy_dict_file(dict_file):
    """Write the ten decimal digits to ``dict_file``, one character per line.

    Args:
        dict_file (str): Path of the file to write. It is opened in ``'w'``
            mode, so any existing content is replaced.
    """
    digits = '0123456789'
    with open(dict_file, 'w') as handle:
        # One newline-terminated entry per character, in ascending order.
        handle.writelines(digit + '\n' for digit in digits)
def test_ocr_segm_targets():
    """Check ``OCRSegTargets`` argument validation, kernels and ``__call__``.

    The test covers three things:

    1. The constructor raises ``AssertionError`` for an invalid label
       convertor config or invalid shrink ratios.
    2. ``generate_kernels`` raises ``AssertionError`` for malformed
       arguments and yields the expected attention (binary) and
       segmentation (class-index) maps for a single character box.
    3. Calling the transform on a ``results`` dict registers ``gt_kernels``
       in ``mask_fields`` and produces the same two maps.
    """
    # Use the context manager so the temporary directory is removed even
    # when one of the assertions below fails; a trailing ``cleanup()`` call
    # would only run on the success path.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # create dummy dict file
        dict_file = osp.join(tmp_dir, 'fake_chars.txt')
        _create_dummy_dict_file(dict_file)

        # dummy label convertor
        label_convertor = dict(
            type='SegConvertor',
            dict_file=dict_file,
            with_unknown=True,
            lower=True)

        # test init
        with pytest.raises(AssertionError):
            OCRSegTargets(None, 0.5, 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, '1by2', 0.5)
        with pytest.raises(AssertionError):
            OCRSegTargets(label_convertor, 0.5, 2)

        ocr_seg_tgt = OCRSegTargets(label_convertor, 0.5, 0.5)

        # test generate kernels
        img_size = (8, 8)
        pad_size = (8, 10)
        char_boxes = [[2, 2, 6, 6]]
        char_idxs = [2]

        # img_size must not be a bare int
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(8, pad_size, char_boxes, char_idxs,
                                         0.5, True)
        # char_boxes must be a list of boxes, not a single flat box
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, [2, 2, 6, 6],
                                         char_idxs, 0.5, True)
        # char_idxs must not be a bare int
        with pytest.raises(AssertionError):
            ocr_seg_tgt.generate_kernels(img_size, pad_size, char_boxes, 2,
                                         0.5, True)

        attn_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=True)
        # Expected binary map: the box shrunk by ratio 0.5 is marked with 1,
        # and the two columns added by padding (8 -> 10) are filled with 255.
        expect_attn_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 1, 1, 1, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(attn_tgt,
                           np.array(expect_attn_tgt, dtype=np.int32))

        segm_tgt = ocr_seg_tgt.generate_kernels(
            img_size, pad_size, char_boxes, char_idxs, 0.5, binary=False)
        # Expected non-binary map: same region, but holding the character
        # index (2) instead of 1.
        expect_segm_tgt = [[0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 2, 2, 2, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255],
                           [0, 0, 0, 0, 0, 0, 0, 0, 255, 255]]
        assert np.allclose(segm_tgt,
                           np.array(expect_segm_tgt, dtype=np.int32))

        # test __call__
        # The 4x4 image is declared as resized to 8x8, and the char rect
        # [1, 1, 3, 3] with char '1' is expected to yield the same two maps
        # as the direct generate_kernels calls above.
        results = {
            'img_shape': (4, 4, 3),
            'resize_shape': (8, 8, 3),
            'pad_shape': (8, 10),
            'ann_info': {
                'char_rects': [[1, 1, 3, 3]],
                'chars': ['1'],
            },
        }

        results = ocr_seg_tgt(results)
        assert results['mask_fields'] == ['gt_kernels']
        assert np.allclose(results['gt_kernels'].masks[0],
                           np.array(expect_attn_tgt, dtype=np.int32))
        assert np.allclose(results['gt_kernels'].masks[1],
                           np.array(expect_segm_tgt, dtype=np.int32))