Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import copy | |
| import os | |
| import pytest | |
| from mmcv import Config | |
| from mmocr.apis.utils import (disable_text_recog_aug_test, | |
| replace_image_to_tensor) | |
def test_disable_text_recog_aug_test(cfg_file):
    """Exercise ``disable_text_recog_aug_test`` on five test-set layouts.

    Each scenario starts from a fresh deep copy of the loaded config,
    rearranges ``cfg.data.test``, runs the helper with
    ``set_types=['test']`` and asserts that the second pipeline step is
    not of type ``'MultiRotateAugOCR'`` afterwards.

    Args:
        cfg_file (str): Config path, joined onto the parent of the
            directory that contains this file.
            NOTE(review): presumably injected by a parametrize decorator
            or a fixture that is not visible here - confirm.
    """
    banned_type = 'MultiRotateAugOCR'

    root_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    base_cfg = Config.fromfile(os.path.join(root_dir, cfg_file))
    base_dataset = base_cfg.data.test.datasets[0]

    # Scenario 1: cfg.data.test.type is 'OCRDataset', i.e. the test entry
    # is a single dataset that carries the pipeline itself.
    single_cfg = copy.deepcopy(base_cfg)
    single_ds = copy.deepcopy(base_dataset)
    single_ds.pipeline = single_cfg.data.test.pipeline
    single_cfg.data.test = single_ds
    single_cfg = disable_text_recog_aug_test(single_cfg, set_types=['test'])
    assert single_cfg.data.test.pipeline[1].type != banned_type

    # Scenario 2: cfg.data.test.type is 'UniformConcatDataset' and
    # cfg.data.test.pipeline is list[dict].
    uniform_cfg = copy.deepcopy(base_cfg)
    uniform_ds = copy.deepcopy(base_dataset)
    uniform_ds.pipeline = uniform_cfg.data.test.pipeline
    uniform_cfg.data.test.datasets = [uniform_ds]
    uniform_cfg = disable_text_recog_aug_test(
        uniform_cfg, set_types=['test'])
    assert uniform_cfg.data.test.pipeline[1].type != banned_type
    inner_step = uniform_cfg.data.test.datasets[0].pipeline[1]
    assert inner_step.type != banned_type

    # Scenario 3: cfg.data.test.type is 'ConcatDataset'.
    concat_cfg = copy.deepcopy(base_cfg)
    concat_ds = copy.deepcopy(base_dataset)
    concat_ds.pipeline = concat_cfg.data.test.pipeline
    concat_cfg.data.test = Config({
        'type': 'ConcatDataset',
        'datasets': [concat_ds],
    })
    concat_cfg = disable_text_recog_aug_test(concat_cfg, set_types=['test'])
    inner_step = concat_cfg.data.test.datasets[0].pipeline[1]
    assert inner_step.type != banned_type

    # Scenario 4: cfg.data.test.type is 'UniformConcatDataset' and
    # cfg.data.test.pipeline is list[list[dict]].
    nested_cfg = copy.deepcopy(base_cfg)
    nested_ds = copy.deepcopy(base_dataset)
    nested_ds.pipeline = nested_cfg.data.test.pipeline
    # The second group deliberately reuses the dataset entry of the
    # originally loaded config (no copy), as the first group is the only
    # one inspected below.
    nested_cfg.data.test.datasets = [[nested_ds], [base_dataset]]
    nested_cfg.data.test.pipeline = [
        nested_cfg.data.test.pipeline,
        nested_cfg.data.test.pipeline,
    ]
    nested_cfg = disable_text_recog_aug_test(nested_cfg, set_types=['test'])
    inner_step = nested_cfg.data.test.datasets[0][0].pipeline[1]
    assert inner_step.type != banned_type

    # Scenario 5: cfg.data.test.type is 'UniformConcatDataset' and
    # cfg.data.test.pipeline is None, so only the dataset-level pipeline
    # (an independent deep copy) exists.
    bare_cfg = copy.deepcopy(base_cfg)
    bare_ds = copy.deepcopy(base_dataset)
    bare_ds.pipeline = copy.deepcopy(bare_cfg.data.test.pipeline)
    bare_cfg.data.test.datasets = [bare_ds]
    bare_cfg.data.test.pipeline = None
    bare_cfg = disable_text_recog_aug_test(bare_cfg, set_types=['test'])
    inner_step = bare_cfg.data.test.datasets[0].pipeline[1]
    assert inner_step.type != banned_type
def test_replace_image_to_tensor(cfg_file):
    """Exercise ``replace_image_to_tensor`` on flat and nested layouts.

    After the helper runs with ``set_types=['test']``, the fourth
    transform inside the second pipeline step must be of type
    ``'DefaultFormatBundle'``, both in the top-level test pipeline and in
    the pipeline attached to the first dataset.

    Args:
        cfg_file (str): Config path, joined onto the parent of the
            directory that contains this file.
            NOTE(review): presumably injected by a parametrize decorator
            or a fixture that is not visible here - confirm.
    """
    expected_type = 'DefaultFormatBundle'

    root_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    base_cfg = Config.fromfile(os.path.join(root_dir, cfg_file))
    base_dataset = base_cfg.data.test.datasets[0]

    # Scenario 1: cfg.data.test.pipeline is list[dict]
    # and cfg.data.test.datasets is list[dict].
    flat_cfg = copy.deepcopy(base_cfg)
    flat_ds = copy.deepcopy(base_dataset)
    # The dataset pipeline is an independent copy taken from the
    # originally loaded config.
    flat_ds.pipeline = copy.deepcopy(base_cfg.data.test.pipeline)
    flat_cfg.data.test.datasets = [flat_ds]
    flat_cfg = replace_image_to_tensor(flat_cfg, set_types=['test'])
    outer_step = flat_cfg.data.test.pipeline[1]
    assert outer_step['transforms'][3]['type'] == expected_type
    inner_step = flat_cfg.data.test.datasets[0].pipeline[1]
    assert inner_step['transforms'][3]['type'] == expected_type

    # Scenario 2: cfg.data.test.pipeline is list[list[dict]]
    # and cfg.data.test.datasets is list[list[dict]].
    nested_cfg = copy.deepcopy(base_cfg)
    nested_ds = copy.deepcopy(base_dataset)
    nested_ds.pipeline = copy.deepcopy(base_cfg.data.test.pipeline)
    # Both groups are built from the very same dataset entry.
    nested_cfg.data.test.datasets = [[nested_ds], [nested_ds]]
    nested_cfg.data.test.pipeline = [
        nested_cfg.data.test.pipeline,
        nested_cfg.data.test.pipeline,
    ]
    nested_cfg = replace_image_to_tensor(nested_cfg, set_types=['test'])
    outer_step = nested_cfg.data.test.pipeline[0][1]
    assert outer_step['transforms'][3]['type'] == expected_type
    inner_step = nested_cfg.data.test.datasets[0][0].pipeline[1]
    assert inner_step['transforms'][3]['type'] == expected_type