| | |
| |
|
| | import bisect |
| | from unittest import TestCase |
| | from unittest.mock import patch |
| |
|
| | import numpy as np |
| | from torch.utils.data import ConcatDataset, Dataset |
| |
|
| | from mmdet.datasets.samplers import GroupMultiSourceSampler, MultiSourceSampler |
| |
|
| |
|
class DummyDataset(Dataset):
    """Minimal in-memory dataset used to drive the sampler tests.

    Each sample is one row of a randomly generated ``(length, 2)`` array;
    the first column is reported as the width and the second as the height.
    A single ``flag`` is shared by all samples so a test can tell which
    dataset a sample came from.

    Args:
        length: Number of samples stored and reported by ``len()``.
        flag: Marker returned unchanged by :meth:`get_data_info`.
    """

    def __init__(self, length, flag):
        self.length, self.flag = length, flag
        # One row per sample: column 0 -> width, column 1 -> height.
        self.shapes = np.random.random(size=(length, 2))

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the stored shape row at ``idx``."""
        return self.shapes[idx]

    def get_data_info(self, idx):
        """Describe sample ``idx`` as a width/height/flag dict."""
        shape = self.shapes[idx]
        return {'width': shape[0], 'height': shape[1], 'flag': self.flag}
| |
|
| |
|
class DummyConcatDataset(ConcatDataset):
    """``ConcatDataset`` that can also return per-sample data info.

    The wrapped datasets are expected to provide ``get_data_info(idx)``;
    this class routes a global index to the dataset that owns it.
    """

    def _get_ori_dataset_idx(self, idx):
        """Map a global index to ``(dataset_idx, sample_idx)``.

        Args:
            idx: Index into the concatenated dataset. Negative values count
                from the end, as in ordinary sequence indexing.

        Returns:
            tuple: Position of the owning dataset in ``self.datasets`` and
            the index of the sample inside that dataset.

        Raises:
            IndexError: If ``idx`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if not -total <= idx < total:
            raise IndexError(
                f'index {idx} is out of range for a dataset of length '
                f'{total}')
        if idx < 0:
            # Normalise before bisecting; otherwise a negative index would
            # always be attributed to the first dataset.
            idx += total
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            return dataset_idx, idx
        # Subtract the number of samples held by all preceding datasets.
        return dataset_idx, idx - self.cumulative_sizes[dataset_idx - 1]

    def get_data_info(self, idx: int):
        """Return the data info of global sample ``idx``.

        The call is forwarded to ``get_data_info`` of the dataset that owns
        the sample, using the index local to that dataset.
        """
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        return self.datasets[dataset_idx].get_data_info(sample_idx)
| |
|
| |
|
class TestMultiSourceSampler(TestCase):
    """Tests for ``MultiSourceSampler`` argument checks and source mixing."""

    @patch('mmengine.dist.get_dist_info', return_value=(7, 8))
    def setUp(self, mock):
        """Build a 100-sample source 'a', a 1000-sample source 'b' and
        their concatenation."""
        self.length_a = 100
        self.dataset_a = DummyDataset(self.length_a, flag='a')
        self.length_b = 1000
        self.dataset_b = DummyDataset(self.length_b, flag='b')
        self.dataset = DummyConcatDataset([self.dataset_a, self.dataset_b])

    def test_multi_source_sampler(self):
        """Check that bad arguments are rejected and batches mix 1:4."""
        # A single, non-concatenated dataset is expected to be rejected.
        with self.assertRaises(AssertionError):
            MultiSourceSampler(
                self.dataset_a, batch_size=5, source_ratio=[1, 4])
        # A negative batch_size is expected to be rejected. All other
        # arguments are valid here so that batch_size is the only thing
        # that can trigger the assertion.
        with self.assertRaises(AssertionError):
            MultiSourceSampler(
                self.dataset, batch_size=-5, source_ratio=[1, 4])
        # A source_ratio with more entries than there are datasets is
        # expected to be rejected.
        with self.assertRaises(AssertionError):
            MultiSourceSampler(
                self.dataset, batch_size=5, source_ratio=[1, 2, 4])

        sampler = iter(
            MultiSourceSampler(
                self.dataset, batch_size=5, source_ratio=[1, 4]))
        # Draw 20 batches of 5 and record which source each sample is from.
        flags = [
            self.dataset.get_data_info(next(sampler))['flag']
            for _ in range(100)
        ]
        # Every batch must hold one sample of 'a' followed by four of 'b'.
        flags_gt = ['a', 'b', 'b', 'b', 'b'] * 20
        self.assertEqual(flags, flags_gt)
| |
|
| |
|
class TestGroupMultiSourceSampler(TestCase):
    """Tests for the source mixing and grouping of
    ``GroupMultiSourceSampler``."""

    @patch('mmengine.dist.get_dist_info', return_value=(7, 8))
    def setUp(self, mock):
        """Build a 100-sample source 'a', a 1000-sample source 'b' and
        their concatenation."""
        self.length_a, self.length_b = 100, 1000
        self.dataset_a = DummyDataset(self.length_a, flag='a')
        self.dataset_b = DummyDataset(self.length_b, flag='b')
        self.dataset = DummyConcatDataset([self.dataset_a, self.dataset_b])

    def test_group_multi_source_sampler(self):
        """Check the 1:4 source mix and that each batch is one group."""
        sampler = iter(
            GroupMultiSourceSampler(
                self.dataset, batch_size=5, source_ratio=[1, 4]))
        flags, groups = [], []
        for _ in range(100):
            info = self.dataset.get_data_info(next(sampler))
            flags.append(info['flag'])
            # Group 0: width smaller than height; group 1: everything else.
            groups.append(0 if info['width'] < info['height'] else 1)

        self.assertEqual(flags, ['a', 'b', 'b', 'b', 'b'] * 20)

        # Sum the group ids per batch of 5: a batch made of a single group
        # sums to 0 or 5, and both values are expected across 20 batches.
        batch_sums = {
            sum(groups[start:start + 5]) for start in range(0, 100, 5)
        }
        self.assertEqual(batch_sums, {0, 5})
| |
|