| |
| import os.path as osp |
| from unittest import TestCase |
|
|
| from mmdet.datasets import ReIDDataset |
|
|
# Directory holding the test fixtures, resolved relative to this test module.
PREFIX = osp.join(osp.dirname(__file__), '../data')
# Annotation file used to build the ReIDDataset fixtures in the tests below.
REID_ANN_FILE = f'{PREFIX}/demo_reid_data/mot17_reid/ann.txt'
|
|
|
|
class TestReIDDataset(TestCase):
    """Tests for ``ReIDDataset`` built from the ``REID_ANN_FILE`` annotations.

    Two class-level fixtures are shared by all tests: ``dataset`` (no
    triplet sampler) and ``dataset_triplet`` (constructed with a
    ``triplet_sampler`` config of ``num_ids`` x ``ins_per_id``).
    """

    @classmethod
    def setUpClass(cls):
        """Build both dataset fixtures once for the whole test class."""
        cls.num_ids = 8
        cls.ins_per_id = 4
        cls.dataset = ReIDDataset(
            pipeline=[], ann_file=REID_ANN_FILE, data_prefix=dict(img_path=''))
        cls.dataset_triplet = ReIDDataset(
            pipeline=[],
            triplet_sampler=dict(
                num_ids=cls.num_ids, ins_per_id=cls.ins_per_id),
            ann_file=REID_ANN_FILE,
            data_prefix=dict(img_path=''))

    def test_get_data_info(self):
        """Check how many samples ``get_data_info`` reports per label."""
        data_list = [
            self.dataset.get_data_info(i) for i in range(len(self.dataset))
        ]

        def count_label(label):
            # Rely on the truthiness of ``==`` exactly as a filtering
            # comprehension would, so any label type that compares to an int
            # keeps working.
            return sum(
                1 for data_info in data_list
                if data_info['gt_label'] == label)

        # Label 0 is expected to own 21 samples in the demo annotations.
        self.assertEqual(count_label(0), 21)
        # Label 11 is expected to be absent from the demo annotations.
        self.assertEqual(count_label(11), 0)

    def test_len(self):
        """Both datasets are expected to report 704 samples."""
        self.assertEqual(len(self.dataset), 704)
        self.assertEqual(len(self.dataset_triplet), 704)

    def test_getitem(self):
        """Check the structure of items returned by both datasets."""
        # Index explicitly: ``__getitem__`` is the interface under test.
        for i in range(len(self.dataset)):
            results = self.dataset[i]
            self.assertIsInstance(results, dict)
            self.assertIn('img_path', results)
            self.assertIn('gt_label', results)

        # Each triplet item is expected to bundle num_ids * ins_per_id images.
        num = self.num_ids * self.ins_per_id
        for i in range(len(self.dataset_triplet)):
            results = self.dataset_triplet[i]
            self.assertIsInstance(results, dict)
            # Verify the keys exist before indexing so a missing key is
            # reported as a test failure rather than a KeyError.
            self.assertIn('img_path', results)
            self.assertIn('gt_label', results)
            self.assertEqual(len(results['img_path']), num)

            labels = results['gt_label']
            for idx in range(num - 1):
                # Neighbours must share a label, except for the pair that
                # straddles the boundary between two groups of ``ins_per_id``.
                if (idx + 1) % self.ins_per_id != 0:
                    self.assertEqual(labels[idx], labels[idx + 1])
|
|