Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import copy | |
| from mmocr.datasets import UniformConcatDataset | |
| from mmocr.utils import list_from_file | |
def test_dataset_warpper():
    """Check how UniformConcatDataset assigns pipelines to sub-datasets.

    The assertions encode the expected behaviour for three ways of passing
    ``pipeline``:

    * a 1-D list of transform dicts with ``force_apply=True``: every dataset
      config is expected to end up with that pipeline, even one (``train2``)
      that already defines its own;
    * ``None``: each config keeps its own pipeline, for both a flat list and
      a nested (2-D) list of dataset configs;
    * a 2-D list: ``pipeline[i]`` is paired with ``datasets[i]``, and a config
      that already has a pipeline keeps it because ``force_apply`` is off.

    Every ``UniformConcatDataset`` call receives freshly deep-copied configs so
    no scenario can observe mutations made while building another.
    """
    pipeline1 = [dict(type='LoadImageFromFile')]
    pipeline2 = [dict(type='LoadImageFromFile'), dict(type='ColorJitter')]
    img_prefix = 'tests/data/ocr_toy_dataset/imgs'
    ann_file = 'tests/data/ocr_toy_dataset/label.txt'
    train1 = dict(
        type='OCRDataset',
        img_prefix=img_prefix,
        ann_file=ann_file,
        loader=dict(
            type='HardDiskLoader',
            repeat=1,
            parser=dict(
                type='LineStrParser',
                keys=['filename', 'text'],
                keys_idx=[0, 1],
                separator=' ')),
        pipeline=None,
        test_mode=False)
    # Same dataset definition as train1, but with its own pipeline set.
    train2 = {**train1, 'pipeline': pipeline2}

    # pipeline is 1d list: force_apply replaces train2's own pipeline too.
    tmp_dataset = UniformConcatDataset(
        datasets=[copy.deepcopy(train1), copy.deepcopy(train2)],
        pipeline=pipeline1,
        force_apply=True)
    assert len(tmp_dataset) == 2 * len(list_from_file(ann_file))
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(pipeline1)
    assert len(tmp_dataset.datasets[1].pipeline.transforms) == len(pipeline1)

    # pipeline is None, flat list of configs: the config's pipeline is kept.
    tmp_dataset = UniformConcatDataset(
        datasets=[copy.deepcopy(train2)], pipeline=None)
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(pipeline2)

    # pipeline is None, nested list of configs: sub-lists are flattened.
    # Use two independent copies so the sub-configs never share state.
    tmp_dataset = UniformConcatDataset(
        datasets=[[copy.deepcopy(train2)], [copy.deepcopy(train2)]],
        pipeline=None)
    assert len(tmp_dataset.datasets) == 2
    for sub_dataset in tmp_dataset.datasets:
        assert len(sub_dataset.pipeline.transforms) == len(pipeline2)

    # pipeline is 2d list: pipeline[i] is paired with datasets[i]. train1 has
    # no pipeline and receives pipeline1; train2 keeps its own pipeline2.
    tmp_dataset = UniformConcatDataset(
        datasets=[[copy.deepcopy(train1)], [copy.deepcopy(train2)]],
        pipeline=[pipeline1, pipeline2])
    assert len(tmp_dataset.datasets) == 2
    assert len(tmp_dataset.datasets[0].pipeline.transforms) == len(pipeline1)
    assert len(tmp_dataset.datasets[1].pipeline.transforms) == len(pipeline2)