|
|
|
|
|
from copy import deepcopy
from unittest import TestCase

from mmpose.datasets.transforms import KeypointConverter
from mmpose.testing import get_coco_sample
|
|
|
|
|
|
|
|
class TestKeypointConverter(TestCase):
    """Unit tests for the ``KeypointConverter`` transform."""

    def setUp(self):
        """Create the sample that every test converts.

        The sample comes from ``get_coco_sample`` with a (240, 320) image
        shape, 4 instances and ``with_bbox_cs=True``.
        """
        self.data_info = get_coco_sample(
            img_shape=(240, 320), num_instances=4, with_bbox_cs=True)

    def _convert(self, mapping, num_keypoints=5):
        """Apply ``KeypointConverter`` to an independent copy of the sample.

        Args:
            mapping: List of ``(source, target)`` entries handed to the
                transform. ``source`` is an index or a pair of indices.
            num_keypoints: Number of keypoints in the target layout.

        Returns:
            The value returned by calling the transform.
        """
        transform = KeypointConverter(
            num_keypoints=num_keypoints, mapping=mapping)
        # A deep copy keeps the reference arrays in ``self.data_info``
        # untouched even if the transform writes into its input in place;
        # a shallow ``dict.copy()`` would share those arrays and could hide
        # such a mutation from the comparisons below.
        return transform(deepcopy(self.data_info))

    def _assert_converted_shapes(self, results, num_keypoints=5):
        """Check that the instance axis is kept and the keypoint axis resized.

        Args:
            results: Output of the transform.
            num_keypoints: Expected size of the keypoint axis.
        """
        keypoints = results['keypoints']
        visible = results['keypoints_visible']

        self.assertEqual(keypoints.shape[0],
                         self.data_info['keypoints'].shape[0])
        self.assertEqual(keypoints.shape[1], num_keypoints)
        self.assertEqual(keypoints.shape[2], 2)
        self.assertEqual(visible.shape[0],
                         self.data_info['keypoints_visible'].shape[0])
        self.assertEqual(visible.shape[1], num_keypoints)

    def _assert_mapping_applied(self, results, mapping):
        """Check every mapping entry against the untouched source sample.

        For a single source index the target must equal that source
        keypoint and visibility. For a pair of source indices the target
        keypoint must equal the mean of the two sources and the target
        visibility must equal the product of the two source visibilities.

        Args:
            results: Output of the transform.
            mapping: The mapping that was given to the transform.
        """
        src_keypoints = self.data_info['keypoints']
        src_visible = self.data_info['keypoints_visible']

        for source_index, target_index in mapping:
            if isinstance(source_index, tuple):
                first, second = source_index
                expected_keypoints = 0.5 * (
                    src_keypoints[:, first] + src_keypoints[:, second])
                expected_visible = (
                    src_visible[:, first] * src_visible[:, second])
            else:
                expected_keypoints = src_keypoints[:, source_index]
                expected_visible = src_visible[:, source_index]

            self.assertTrue(
                (results['keypoints'][:, target_index] ==
                 expected_keypoints).all(),
                msg=f'keypoints mismatch for mapping '
                f'{source_index} -> {target_index}')
            self.assertTrue(
                (results['keypoints_visible'][:, target_index] ==
                 expected_visible).all(),
                msg=f'keypoints_visible mismatch for mapping '
                f'{source_index} -> {target_index}')

    def test_transform(self):
        """Convert with a 1-to-1 mapping and with a 2-to-1 mapping."""
        # Every target keypoint is taken from exactly one source keypoint.
        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        results = self._convert(mapping)
        self._assert_converted_shapes(results)
        self._assert_mapping_applied(results, mapping)

        # Target 0 is built from the two source keypoints 3 and 5.
        mapping = [((3, 5), 0), (6, 1), (16, 2), (5, 3)]
        results = self._convert(mapping)
        self._assert_converted_shapes(results)
        self._assert_mapping_applied(results, mapping)
|
|
|