Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import io | |
| import json | |
| import os | |
| import platform | |
| import random | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from unittest import mock | |
| import mmcv | |
| import numpy as np | |
| import pytest | |
| import torch | |
| from mmocr.apis import init_detector | |
| from mmocr.datasets.kie_dataset import KIEDataset | |
| from mmocr.utils.ocr import MMOCR | |
def test_ocr_init_errors():
    """Check that ``MMOCR`` rejects invalid model selections."""
    # An unrecognised name passed as ``det``, ``recog`` or ``kie`` is
    # expected to raise ``ValueError``.
    for bad_name_kwargs in ({'det': 'test'}, {'recog': 'test'},
                            {'kie': 'test'}):
        with pytest.raises(ValueError):
            MMOCR(**bad_name_kwargs)

    # Asking for ``kie='SDMGR'`` while ``recog`` is ``None`` (with or
    # without a detector) is expected to raise ``NotImplementedError``.
    for unsupported_kwargs in (
            {'det': None, 'recog': None, 'kie': 'SDMGR'},
            {'det': 'DB_r18', 'recog': None, 'kie': 'SDMGR'}):
        with pytest.raises(NotImplementedError):
            MMOCR(**unsupported_kwargs)
# ``configs/`` directory located under the current working directory.
cfg_default_prefix = os.path.join(f'{Path.cwd()}', 'configs/')
def test_ocr_init(mock_loading, mock_config, mock_build_detector,
                  mock_init_detector, det, recog, kie, config_dir, gt_cfg,
                  gt_ckpt):
    """Check which configs and checkpoints ``MMOCR`` requests on init.

    ``det`` and ``recog`` are always forwarded to ``MMOCR``; ``kie`` and
    ``config_dir`` are forwarded only when they are not empty strings.
    ``gt_cfg`` / ``gt_ckpt`` are the expected config / checkpoint
    arguments, given either as one string or as a sequence. The
    ``mock_*`` arguments are mock objects whose recorded calls are
    verified below.

    NOTE(review): the mocks and parameters are presumably injected by
    decorators that are not visible in this view -- confirm.
    """
    expected_device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')

    def check_load_checkpoint(*args, **kwargs):
        # The second positional argument must be the last expected
        # checkpoint, and it must be mapped onto the expected device.
        assert args[1] == gt_ckpt[-1]
        assert kwargs['map_location'] == expected_device

    mock_loading.side_effect = check_load_checkpoint

    # Pass only the optional arguments that were actually provided.
    init_kwargs = {'det': det, 'recog': recog}
    if kie != '':
        init_kwargs['kie'] = kie
    if config_dir != '':
        init_kwargs['config_dir'] = config_dir
    with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
        MMOCR(**init_kwargs)

    cfg_list = [gt_cfg] if isinstance(gt_cfg, str) else gt_cfg
    ckpt_list = [gt_ckpt] if isinstance(gt_ckpt, str) else gt_ckpt

    num_init_calls = len(cfg_list)
    if kie:
        # The last config/checkpoint pair is checked against the
        # config / build / load mocks rather than ``mock_init_detector``.
        num_init_calls -= 1
        mock_config.assert_called_with(cfg_list[-1])
        mock_build_detector.assert_called_once()
        mock_loading.assert_called_once()

    expected_calls = [
        mock.call(cfg_list[idx], ckpt_list[idx], device=expected_device)
        for idx in range(num_init_calls)
    ]
    mock_init_detector.assert_has_calls(expected_calls)
def test_ocr_init_customize_config(mock_loading, mock_config,
                                   mock_build_detector, mock_init_detector,
                                   det, det_config, det_ckpt, recog,
                                   recog_config, recog_ckpt, kie, kie_config,
                                   kie_ckpt, config_dir, gt_cfg, gt_ckpt):
    """Check ``MMOCR`` init when configs and checkpoints are given explicitly.

    All model names, ``*_config`` / ``*_ckpt`` values and ``config_dir``
    are forwarded to ``MMOCR`` as keyword arguments. ``gt_cfg`` and
    ``gt_ckpt`` are indexable collections of the expected config and
    checkpoint arguments. The ``mock_*`` arguments are mock objects whose
    recorded calls are verified below.

    NOTE(review): the mocks and parameters are presumably injected by
    decorators that are not visible in this view -- confirm.
    """

    def check_load_checkpoint(*args, **kwargs):
        # The second positional argument must be the last expected
        # checkpoint.
        assert args[1] == gt_ckpt[-1]

    mock_loading.side_effect = check_load_checkpoint

    init_kwargs = {
        'det': det,
        'det_config': det_config,
        'det_ckpt': det_ckpt,
        'recog': recog,
        'recog_config': recog_config,
        'recog_ckpt': recog_ckpt,
        'kie': kie,
        'kie_config': kie_config,
        'kie_ckpt': kie_ckpt,
        'config_dir': config_dir,
    }
    with mock.patch('mmocr.utils.ocr.revert_sync_batchnorm'):
        MMOCR(**init_kwargs)

    num_init_calls = len(gt_cfg)
    if kie:
        # The last config/checkpoint pair is checked against the
        # config / build / load mocks rather than ``mock_init_detector``.
        num_init_calls -= 1
        mock_config.assert_called_with(gt_cfg[-1])
        mock_build_detector.assert_called_once()
        mock_loading.assert_called_once()

    expected_device = torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')
    expected_calls = [
        mock.call(gt_cfg[idx], gt_ckpt[idx], device=expected_device)
        for idx in range(num_init_calls)
    ]
    mock_init_detector.assert_has_calls(expected_calls)
def test_single_inference(mock_model_inference, mock_loading, mock_config,
                          mock_build_detector, mock_init_detector):
    """Check that ``single_inference`` returns its input data unchanged.

    The inference mock simply echoes the array it receives, so the
    result must equal the input for every batching configuration. For
    the two configurations asserted below (no ``batch_size``, and
    ``batch_size=100`` with 20 items) the mock must be called exactly
    once.
    """

    def echo_inference(model, arr, batch_mode):
        return arr

    mock_model_inference.side_effect = echo_inference
    ocr = MMOCR()

    samples = list(range(20))
    dummy_model = 'dummy'

    def run_and_compare(**inference_kwargs):
        # Run one inference and verify the data came back untouched.
        output = ocr.single_inference(dummy_model, samples,
                                      **inference_kwargs)
        assert samples == output

    # Without batch mode.
    run_and_compare(batch_mode=False)
    mock_model_inference.reset_mock()

    # Batch mode without an explicit batch size.
    run_and_compare(batch_mode=True)
    mock_model_inference.assert_called_once()
    mock_model_inference.reset_mock()

    # Batch size larger than the amount of data.
    run_and_compare(batch_mode=True, batch_size=100)
    mock_model_inference.assert_called_once()
    mock_model_inference.reset_mock()

    # Batch size smaller than the amount of data.
    run_and_compare(batch_mode=True, batch_size=3)
def MMOCR_testobj(mock_loading, mock_init_detector, **kwargs):
    """Return an ``MMOCR`` object, bypassing the checkpoint initialization.

    ``mock_init_detector`` is wired to call ``init_detector`` without the
    checkpoint argument, and ``mock_loading`` is wired to only set the
    model's ``class_list`` to the toy KIE class list. Any of ``det``,
    ``recog`` and ``kie`` missing from ``kwargs`` defaults to
    ``'DB_r18'``, ``'CRNN'`` and ``'SDMGR'`` respectively.
    """

    def build_without_ckpt(config, ckpt, device):
        # ``ckpt`` is dropped on purpose so that no weights are loaded.
        return init_detector(config, device=device)

    def set_toy_class_list(model, ckpt, map_location):
        model.class_list = 'tests/data/kie_toy_dataset/class_list.txt'

    mock_init_detector.side_effect = build_without_ckpt
    mock_loading.side_effect = set_toy_class_list

    # Fill in default models only where the caller did not choose one.
    kwargs.setdefault('det', 'DB_r18')
    kwargs.setdefault('recog', 'CRNN')
    kwargs.setdefault('kie', 'SDMGR')

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'
    return MMOCR(**kwargs, device=device)
def test_readtext(mock_kiedataset):
    """Exercise ``MMOCR.readtext`` on the toy dataset images.

    Covers single and multiple image inputs (paths and arrays), batch
    mode for the end-to-end, detection-only and recognition-only
    objects, and the ``export``, ``output``, ``imshow``, ``print_result``
    and ``merge`` options.

    Args:
        mock_kiedataset: Mock whose ``side_effect`` is set here so that a
            ``KIEDataset`` is built with the toy dictionary file.
            NOTE(review): presumably injected by a patch decorator that
            is not visible in this view -- confirm.
    """
    # Imported locally so this block does not depend on a file-level
    # import being added.
    from contextlib import redirect_stdout

    # Fixing the weights of models to prevent them from
    # generating invalid results and triggering other assertion errors
    torch.manual_seed(4)
    random.seed(4)
    mmocr = MMOCR_testobj()
    mmocr_det = MMOCR_testobj(kie='', recog='')
    mmocr_recog = MMOCR_testobj(kie='', det='', recog='CRNN_TPS')
    mmocr_det_recog = MMOCR_testobj(kie='')

    def readtext(imgs, ocr_obj=mmocr, **kwargs):
        # filename can be different depending on how
        # the image was loaded
        e2e_res = ocr_obj.readtext(imgs, **kwargs)
        for res in e2e_res:
            res.pop('filename')
        return e2e_res

    def kiedataset_with_test_dict(**kwargs):
        kwargs['dict_file'] = 'tests/data/kie_toy_dataset/dict.txt'
        return KIEDataset(**kwargs)

    def readtext_with_captured_print(imgs, **kwargs):
        # ``redirect_stdout`` restores whatever stream was installed
        # before, even when ``readtext`` raises, so a failure here cannot
        # leave ``sys.stdout`` pointing at a closed buffer.
        with io.StringIO() as captured_output:
            with redirect_stdout(captured_output):
                res = mmocr.readtext(imgs, print_result=True, **kwargs)
            printed = captured_output.getvalue()
        return res, printed

    def parse_printed_results(printed):
        # Turn the blank-line separated, single-quoted printed results
        # into a JSON list so they can be compared with the return value.
        return json.loads('[%s]' % printed.strip().replace(
            '\n\n', ',').replace("'", '"'))

    mock_kiedataset.side_effect = kiedataset_with_test_dict

    # Single image
    toy_dir = 'tests/data/toy_dataset/imgs/test/'
    toy_img1_path = toy_dir + 'img_1.jpg'
    str_e2e_res = readtext(toy_img1_path)
    toy_img1 = mmcv.imread(toy_img1_path)
    np_e2e_res = readtext(toy_img1)
    assert str_e2e_res == np_e2e_res

    # Multiple images
    toy_img2_path = toy_dir + 'img_2.jpg'
    toy_img2 = mmcv.imread(toy_img2_path)
    toy_imgs = [toy_img1, toy_img2]
    toy_img_paths = [toy_img1_path, toy_img2_path]
    np_e2e_results = readtext(toy_imgs)
    str_e2e_results = readtext(toy_img_paths)
    str_tuple_e2e_results = readtext(tuple(toy_img_paths))
    assert np_e2e_results == str_e2e_results
    assert str_e2e_results == str_tuple_e2e_results

    # Batch mode test
    toy_imgs.append(toy_dir + 'img_3.jpg')
    e2e_res = readtext(toy_imgs)
    full_batch_e2e_res = readtext(toy_imgs, batch_mode=True)
    assert full_batch_e2e_res == e2e_res
    batch_e2e_res = readtext(
        toy_imgs, batch_mode=True, recog_batch_size=2, det_batch_size=2)
    assert batch_e2e_res == full_batch_e2e_res

    # Batch mode test with DBNet only
    full_batch_det_res = mmocr_det.readtext(toy_imgs, batch_mode=True)
    det_res = mmocr_det.readtext(toy_imgs)
    batch_det_res = mmocr_det.readtext(
        toy_imgs, batch_mode=True, single_batch_size=2)
    assert len(full_batch_det_res) == len(det_res)
    assert len(batch_det_res) == len(det_res)
    assert all([
        np.allclose(full_batch_det_res[i]['boundary_result'],
                    det_res[i]['boundary_result'])
        for i in range(len(full_batch_det_res))
    ])
    assert all([
        np.allclose(batch_det_res[i]['boundary_result'],
                    det_res[i]['boundary_result'])
        for i in range(len(batch_det_res))
    ])

    # Batch mode test with CRNN_TPS only (CRNN doesn't support batch inference)
    full_batch_recog_res = mmocr_recog.readtext(toy_imgs, batch_mode=True)
    recog_res = mmocr_recog.readtext(toy_imgs)
    batch_recog_res = mmocr_recog.readtext(
        toy_imgs, batch_mode=True, single_batch_size=2)
    full_batch_recog_res.sort(key=lambda x: x['text'])
    batch_recog_res.sort(key=lambda x: x['text'])
    recog_res.sort(key=lambda x: x['text'])
    # Compare lengths first (as done for the detection results) so the
    # element-wise checks below cannot skip or over-run any entry.
    assert len(full_batch_recog_res) == len(recog_res)
    assert len(batch_recog_res) == len(recog_res)
    assert np.all([
        np.allclose(full_batch_recog_res[i]['score'], recog_res[i]['score'])
        for i in range(len(full_batch_recog_res))
    ])
    assert np.all([
        np.allclose(batch_recog_res[i]['score'], recog_res[i]['score'])
        for i in range(len(batch_recog_res))
    ])

    # Test export
    with tempfile.TemporaryDirectory() as tmpdirname:
        mmocr.readtext(toy_imgs, export=tmpdirname)
        assert len(os.listdir(tmpdirname)) == len(toy_imgs)
    with tempfile.TemporaryDirectory() as tmpdirname:
        mmocr_det.readtext(toy_imgs, export=tmpdirname)
        assert len(os.listdir(tmpdirname)) == len(toy_imgs)
    with tempfile.TemporaryDirectory() as tmpdirname:
        mmocr_recog.readtext(toy_imgs, export=tmpdirname)
        assert len(os.listdir(tmpdirname)) == len(toy_imgs)

    # Test output
    # Single image
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_output = os.path.join(tmpdirname, '1.jpg')
        mmocr.readtext(toy_imgs[0], output=tmp_output)
        assert os.path.exists(tmp_output)
    # Multiple images
    with tempfile.TemporaryDirectory() as tmpdirname:
        mmocr.readtext(toy_imgs, output=tmpdirname)
        assert len(os.listdir(tmpdirname)) == len(toy_imgs)

    # Test imshow
    with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow:
        mmocr.readtext(toy_img1_path, imshow=True)
        mock_imshow.assert_called_once()
        mock_imshow.reset_mock()
        mmocr.readtext(toy_imgs, imshow=True)
        assert mock_imshow.call_count == len(toy_imgs)

    # Test print_result
    res, printed = readtext_with_captured_print(toy_imgs)
    assert parse_printed_results(printed) == res
    res, printed = readtext_with_captured_print(toy_imgs, details=True)
    assert parse_printed_results(printed) == res

    # Test merge
    with mock.patch('mmocr.utils.ocr.stitch_boxes_into_lines') as mock_merge:
        mmocr_det_recog.readtext(toy_imgs, merge=True)
        assert mock_merge.call_count == len(toy_imgs)