Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from itertools import product | |
| from unittest import TestCase | |
| import numpy as np | |
| import torch | |
| from munkres import Munkres | |
| from mmpose.codecs import AssociativeEmbedding | |
| from mmpose.registry import KEYPOINT_CODECS | |
| from mmpose.testing import get_coco_sample | |
| class TestAssociativeEmbedding(TestCase): | |
| def setUp(self) -> None: | |
| self.decode_keypoint_order = [ | |
| 0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16 | |
| ] | |
| def test_build(self): | |
| cfg = dict( | |
| type='AssociativeEmbedding', | |
| input_size=(256, 256), | |
| heatmap_size=(64, 64), | |
| use_udp=False, | |
| decode_keypoint_order=self.decode_keypoint_order, | |
| ) | |
| codec = KEYPOINT_CODECS.build(cfg) | |
| self.assertIsInstance(codec, AssociativeEmbedding) | |
| def test_encode(self): | |
| data = get_coco_sample(img_shape=(256, 256), num_instances=1) | |
| # w/o UDP | |
| codec = AssociativeEmbedding( | |
| input_size=(256, 256), | |
| heatmap_size=(64, 64), | |
| use_udp=False, | |
| decode_keypoint_order=self.decode_keypoint_order) | |
| encoded = codec.encode(data['keypoints'], data['keypoints_visible']) | |
| heatmaps = encoded['heatmaps'] | |
| keypoint_indices = encoded['keypoint_indices'] | |
| keypoint_weights = encoded['keypoint_weights'] | |
| self.assertEqual(heatmaps.shape, (17, 64, 64)) | |
| self.assertEqual(keypoint_indices.shape, (1, 17, 2)) | |
| self.assertEqual(keypoint_weights.shape, (1, 17)) | |
| for k in range(heatmaps.shape[0]): | |
| index_expected = np.argmax(heatmaps[k]) | |
| index_encoded = keypoint_indices[0, k, 0] | |
| self.assertEqual(index_expected, index_encoded) | |
| # w/ UDP | |
| codec = AssociativeEmbedding( | |
| input_size=(256, 256), | |
| heatmap_size=(64, 64), | |
| use_udp=True, | |
| decode_keypoint_order=self.decode_keypoint_order) | |
| encoded = codec.encode(data['keypoints'], data['keypoints_visible']) | |
| heatmaps = encoded['heatmaps'] | |
| keypoint_indices = encoded['keypoint_indices'] | |
| keypoint_weights = encoded['keypoint_weights'] | |
| self.assertEqual(heatmaps.shape, (17, 64, 64)) | |
| self.assertEqual(keypoint_indices.shape, (1, 17, 2)) | |
| self.assertEqual(keypoint_weights.shape, (1, 17)) | |
| for k in range(heatmaps.shape[0]): | |
| index_expected = np.argmax(heatmaps[k]) | |
| index_encoded = keypoint_indices[0, k, 0] | |
| self.assertEqual(index_expected, index_encoded) | |
| def _get_tags(self, | |
| heatmaps, | |
| keypoint_indices, | |
| tag_per_keypoint: bool, | |
| tag_dim: int = 1): | |
| K, H, W = heatmaps.shape | |
| N = keypoint_indices.shape[0] | |
| if tag_per_keypoint: | |
| tags = np.zeros((K * tag_dim, H, W), dtype=np.float32) | |
| else: | |
| tags = np.zeros((tag_dim, H, W), dtype=np.float32) | |
| for n, k in product(range(N), range(K)): | |
| y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W)) | |
| if tag_per_keypoint: | |
| tags[k::K, y, x] = n | |
| else: | |
| tags[:, y, x] = n | |
| return tags | |
| def _sort_preds(self, keypoints_pred, scores_pred, keypoints_gt): | |
| """Sort multi-instance predictions to best match the ground-truth. | |
| Args: | |
| keypoints_pred (np.ndarray): predictions in shape (N, K, D) | |
| scores (np.ndarray): predictions in shape (N, K) | |
| keypoints_gt (np.ndarray): ground-truth in shape (N, K, D) | |
| Returns: | |
| np.ndarray: Sorted predictions | |
| """ | |
| assert keypoints_gt.shape == keypoints_pred.shape | |
| costs = np.linalg.norm( | |
| keypoints_gt[None] - keypoints_pred[:, None], ord=2, | |
| axis=3).mean(axis=2) | |
| match = Munkres().compute(costs) | |
| keypoints_pred_sorted = np.zeros_like(keypoints_pred) | |
| scores_pred_sorted = np.zeros_like(scores_pred) | |
| for i, j in match: | |
| keypoints_pred_sorted[i] = keypoints_pred[j] | |
| scores_pred_sorted[i] = scores_pred[j] | |
| return keypoints_pred_sorted, scores_pred_sorted | |
| def test_decode(self): | |
| data = get_coco_sample( | |
| img_shape=(256, 256), num_instances=2, non_occlusion=True) | |
| # w/o UDP | |
| codec = AssociativeEmbedding( | |
| input_size=(256, 256), | |
| heatmap_size=(64, 64), | |
| use_udp=False, | |
| decode_keypoint_order=self.decode_keypoint_order) | |
| encoded = codec.encode(data['keypoints'], data['keypoints_visible']) | |
| heatmaps = encoded['heatmaps'] | |
| keypoint_indices = encoded['keypoint_indices'] | |
| tags = self._get_tags( | |
| heatmaps, keypoint_indices, tag_per_keypoint=True) | |
| # to Tensor | |
| batch_heatmaps = torch.from_numpy(heatmaps[None]) | |
| batch_tags = torch.from_numpy(tags[None]) | |
| batch_keypoints, batch_keypoint_scores, batch_instance_scores = \ | |
| codec.batch_decode(batch_heatmaps, batch_tags) | |
| self.assertIsInstance(batch_keypoints, list) | |
| self.assertIsInstance(batch_keypoint_scores, list) | |
| self.assertEqual(len(batch_keypoints), 1) | |
| self.assertEqual(len(batch_keypoint_scores), 1) | |
| keypoints, scores = self._sort_preds(batch_keypoints[0], | |
| batch_keypoint_scores[0], | |
| data['keypoints']) | |
| self.assertIsInstance(keypoints, np.ndarray) | |
| self.assertIsInstance(scores, np.ndarray) | |
| self.assertEqual(keypoints.shape, (2, 17, 2)) | |
| self.assertEqual(scores.shape, (2, 17)) | |
| self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0)) | |
| # w/o UDP, tag_imd=2 | |
| codec = AssociativeEmbedding( | |
| input_size=(256, 256), | |
| heatmap_size=(64, 64), | |
| use_udp=False, | |
| decode_keypoint_order=self.decode_keypoint_order) | |
| encoded = codec.encode(data['keypoints'], data['keypoints_visible']) | |
| heatmaps = encoded['heatmaps'] | |
| keypoint_indices = encoded['keypoint_indices'] | |
| tags = self._get_tags( | |
| heatmaps, keypoint_indices, tag_per_keypoint=True, tag_dim=2) | |
| # to Tensor | |
| batch_heatmaps = torch.from_numpy(heatmaps[None]) | |
| batch_tags = torch.from_numpy(tags[None]) | |
| batch_keypoints, batch_keypoint_scores, batch_instance_scores = \ | |
| codec.batch_decode(batch_heatmaps, batch_tags) | |
| self.assertIsInstance(batch_keypoints, list) | |
| self.assertIsInstance(batch_keypoint_scores, list) | |
| self.assertEqual(len(batch_keypoints), 1) | |
| self.assertEqual(len(batch_keypoint_scores), 1) | |
| keypoints, scores = self._sort_preds(batch_keypoints[0], | |
| batch_keypoint_scores[0], | |
| data['keypoints']) | |
| self.assertIsInstance(keypoints, np.ndarray) | |
| self.assertIsInstance(scores, np.ndarray) | |
| self.assertEqual(keypoints.shape, (2, 17, 2)) | |
| self.assertEqual(scores.shape, (2, 17)) | |
| self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0)) | |
| # w/ UDP | |
| codec = AssociativeEmbedding( | |
| input_size=(256, 256), | |
| heatmap_size=(64, 64), | |
| use_udp=True, | |
| decode_keypoint_order=self.decode_keypoint_order) | |
| encoded = codec.encode(data['keypoints'], data['keypoints_visible']) | |
| heatmaps = encoded['heatmaps'] | |
| keypoint_indices = encoded['keypoint_indices'] | |
| tags = self._get_tags( | |
| heatmaps, keypoint_indices, tag_per_keypoint=True) | |
| # to Tensor | |
| batch_heatmaps = torch.from_numpy(heatmaps[None]) | |
| batch_tags = torch.from_numpy(tags[None]) | |
| batch_keypoints, batch_keypoint_scores, batch_instance_scores = \ | |
| codec.batch_decode(batch_heatmaps, batch_tags) | |
| self.assertIsInstance(batch_keypoints, list) | |
| self.assertIsInstance(batch_keypoint_scores, list) | |
| self.assertEqual(len(batch_keypoints), 1) | |
| self.assertEqual(len(batch_keypoint_scores), 1) | |
| keypoints, scores = self._sort_preds(batch_keypoints[0], | |
| batch_keypoint_scores[0], | |
| data['keypoints']) | |
| self.assertIsInstance(keypoints, np.ndarray) | |
| self.assertIsInstance(scores, np.ndarray) | |
| self.assertEqual(keypoints.shape, (2, 17, 2)) | |
| self.assertEqual(scores.shape, (2, 17)) | |
| self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0)) | |