| |
| import os.path as osp |
| import unittest |
| from unittest.mock import MagicMock, patch |
|
|
| import pytest |
|
|
| from mmdet.datasets import DATASETS |
|
|
|
|
@patch('mmdet.datasets.CocoDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CustomDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.XMLDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CityscapesDataset.load_annotations', MagicMock())
@patch('mmdet.datasets.CocoDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CustomDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.XMLDataset._filter_imgs', MagicMock)
@patch('mmdet.datasets.CityscapesDataset._filter_imgs', MagicMock)
@pytest.mark.parametrize('dataset',
                         ['CocoDataset', 'VOCDataset', 'CityscapesDataset'])
def test_custom_classes_override_default(dataset):
    """Check that the ``classes`` argument overrides the default ``CLASSES``.

    For each parametrized dataset name the class is looked up in ``DATASETS``
    and instantiated with ``classes`` given as a tuple, a list, a
    single-element list, ``None`` and a path to a text file with one class
    name per line. ``load_annotations`` and ``_filter_imgs`` are patched out
    by the decorators, so no annotation file is read.

    Args:
        dataset (str): Registry name of the dataset class under test.
    """
    import tempfile

    dataset_class = DATASETS.get(dataset)
    if dataset in ['CocoDataset', 'CityscapesDataset']:
        # NOTE(review): these are set on the class itself and therefore stay
        # in place after the test; presumably they stand in for attributes
        # the mocked ``load_annotations`` would normally create - confirm.
        dataset_class.coco = MagicMock()
        dataset_class.cat_ids = MagicMock()

    original_classes = dataset_class.CLASSES

    def build_dataset(classes):
        """Instantiate ``dataset_class`` with the given ``classes`` value."""
        return dataset_class(
            ann_file=MagicMock(),
            pipeline=[],
            classes=classes,
            test_mode=True,
            img_prefix='VOC2007' if dataset == 'VOCDataset' else '')

    # Test setting classes as a tuple
    custom_dataset = build_dataset(('bus', 'car'))
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ('bus', 'car')
    print(custom_dataset)

    # Test setting classes as a list
    custom_dataset = build_dataset(['bus', 'car'])
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['bus', 'car']
    print(custom_dataset)

    # Test overriding with a single class that is not among the defaults
    custom_dataset = build_dataset(['foo'])
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['foo']
    print(custom_dataset)

    # Test that ``classes=None`` keeps the default classes
    custom_dataset = build_dataset(None)
    assert custom_dataset.CLASSES == original_classes
    print(custom_dataset)

    # Test loading classes from a file
    with tempfile.TemporaryDirectory() as tmpdir:
        # Join with a separator so the file lives inside ``tmpdir`` and is
        # removed with it; plain string concatenation would create a sibling
        # file next to the directory that is never cleaned up.
        path = osp.join(tmpdir, 'classes.txt')
        with open(path, 'w') as f:
            f.write('bus\ncar\n')
        # The dataset must be built while the file still exists.
        custom_dataset = build_dataset(path)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == ['bus', 'car']
    print(custom_dataset)
|
|
|
|
class CustomDatasetTests(unittest.TestCase):
    """Check the ``data_infos`` built by the registered ``XMLDataset``."""

    def setUp(self):
        """Resolve the fixture data directory and the dataset class."""
        super().setUp()
        # ``data`` sits three directory levels above this file.
        base_dir = osp.dirname(osp.dirname(osp.dirname(__file__)))
        self.data_dir = osp.join(base_dir, 'data')
        self.dataset_class = DATASETS.get('XMLDataset')

    def test_data_infos__default_db_directories(self):
        """Test correct data read having a Pascal VOC directory structure."""
        voc_root = osp.join(self.data_dir, 'VOCdevkit', 'VOC2007')
        split_file = osp.join(voc_root, 'ImageSets', 'Main', 'trainval.txt')
        dataset = self.dataset_class(
            data_root=voc_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True)

        expected_infos = [
            dict(
                id='000001',
                filename=osp.join('JPEGImages', '000001.jpg'),
                width=353,
                height=500)
        ]
        self.assertListEqual(expected_infos, dataset.data_infos)

    def test_data_infos__overridden_db_subdirectories(self):
        """Test correct data read having a customized directory structure."""
        custom_root = osp.join(self.data_dir, 'custom_dataset')
        split_file = osp.join(custom_root, 'trainval.txt')
        dataset = self.dataset_class(
            data_root=custom_root,
            ann_file=split_file,
            pipeline=[],
            classes=('person', 'dog'),
            test_mode=True,
            img_prefix='',
            img_subdir='images',
            ann_subdir='images')

        expected_infos = [
            dict(
                id='000001',
                filename=osp.join('images', '000001.jpg'),
                width=353,
                height=500)
        ]
        self.assertListEqual(expected_infos, dataset.data_infos)
|
|