|
|
|
|
|
from itertools import product |
|
|
from unittest import TestCase |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from munkres import Munkres |
|
|
|
|
|
from mmpose.codecs import AssociativeEmbedding |
|
|
from mmpose.registry import KEYPOINT_CODECS |
|
|
from mmpose.testing import get_coco_sample |
|
|
|
|
|
|
|
|
class TestAssociativeEmbedding(TestCase): |
|
|
|
|
|
def setUp(self) -> None: |
|
|
self.decode_keypoint_order = [ |
|
|
0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16 |
|
|
] |
|
|
|
|
|
def test_build(self): |
|
|
cfg = dict( |
|
|
type='AssociativeEmbedding', |
|
|
input_size=(256, 256), |
|
|
heatmap_size=(64, 64), |
|
|
use_udp=False, |
|
|
decode_keypoint_order=self.decode_keypoint_order, |
|
|
) |
|
|
codec = KEYPOINT_CODECS.build(cfg) |
|
|
self.assertIsInstance(codec, AssociativeEmbedding) |
|
|
|
|
|
def test_encode(self): |
|
|
data = get_coco_sample(img_shape=(256, 256), num_instances=1) |
|
|
|
|
|
|
|
|
codec = AssociativeEmbedding( |
|
|
input_size=(256, 256), |
|
|
heatmap_size=(64, 64), |
|
|
use_udp=False, |
|
|
decode_keypoint_order=self.decode_keypoint_order) |
|
|
|
|
|
encoded = codec.encode(data['keypoints'], data['keypoints_visible']) |
|
|
|
|
|
heatmaps = encoded['heatmaps'] |
|
|
keypoint_indices = encoded['keypoint_indices'] |
|
|
keypoint_weights = encoded['keypoint_weights'] |
|
|
|
|
|
self.assertEqual(heatmaps.shape, (17, 64, 64)) |
|
|
self.assertEqual(keypoint_indices.shape, (1, 17, 2)) |
|
|
self.assertEqual(keypoint_weights.shape, (1, 17)) |
|
|
|
|
|
for k in range(heatmaps.shape[0]): |
|
|
index_expected = np.argmax(heatmaps[k]) |
|
|
index_encoded = keypoint_indices[0, k, 0] |
|
|
self.assertEqual(index_expected, index_encoded) |
|
|
|
|
|
|
|
|
codec = AssociativeEmbedding( |
|
|
input_size=(256, 256), |
|
|
heatmap_size=(64, 64), |
|
|
use_udp=True, |
|
|
decode_keypoint_order=self.decode_keypoint_order) |
|
|
|
|
|
encoded = codec.encode(data['keypoints'], data['keypoints_visible']) |
|
|
|
|
|
heatmaps = encoded['heatmaps'] |
|
|
keypoint_indices = encoded['keypoint_indices'] |
|
|
keypoint_weights = encoded['keypoint_weights'] |
|
|
|
|
|
self.assertEqual(heatmaps.shape, (17, 64, 64)) |
|
|
self.assertEqual(keypoint_indices.shape, (1, 17, 2)) |
|
|
self.assertEqual(keypoint_weights.shape, (1, 17)) |
|
|
|
|
|
for k in range(heatmaps.shape[0]): |
|
|
index_expected = np.argmax(heatmaps[k]) |
|
|
index_encoded = keypoint_indices[0, k, 0] |
|
|
self.assertEqual(index_expected, index_encoded) |
|
|
|
|
|
def _get_tags(self, |
|
|
heatmaps, |
|
|
keypoint_indices, |
|
|
tag_per_keypoint: bool, |
|
|
tag_dim: int = 1): |
|
|
|
|
|
K, H, W = heatmaps.shape |
|
|
N = keypoint_indices.shape[0] |
|
|
|
|
|
if tag_per_keypoint: |
|
|
tags = np.zeros((K * tag_dim, H, W), dtype=np.float32) |
|
|
else: |
|
|
tags = np.zeros((tag_dim, H, W), dtype=np.float32) |
|
|
|
|
|
for n, k in product(range(N), range(K)): |
|
|
y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W)) |
|
|
if tag_per_keypoint: |
|
|
tags[k::K, y, x] = n |
|
|
else: |
|
|
tags[:, y, x] = n |
|
|
|
|
|
return tags |
|
|
|
|
|
def _sort_preds(self, keypoints_pred, scores_pred, keypoints_gt): |
|
|
"""Sort multi-instance predictions to best match the ground-truth. |
|
|
|
|
|
Args: |
|
|
keypoints_pred (np.ndarray): predictions in shape (N, K, D) |
|
|
scores (np.ndarray): predictions in shape (N, K) |
|
|
keypoints_gt (np.ndarray): ground-truth in shape (N, K, D) |
|
|
|
|
|
Returns: |
|
|
np.ndarray: Sorted predictions |
|
|
""" |
|
|
assert keypoints_gt.shape == keypoints_pred.shape |
|
|
costs = np.linalg.norm( |
|
|
keypoints_gt[None] - keypoints_pred[:, None], ord=2, |
|
|
axis=3).mean(axis=2) |
|
|
match = Munkres().compute(costs) |
|
|
keypoints_pred_sorted = np.zeros_like(keypoints_pred) |
|
|
scores_pred_sorted = np.zeros_like(scores_pred) |
|
|
for i, j in match: |
|
|
keypoints_pred_sorted[i] = keypoints_pred[j] |
|
|
scores_pred_sorted[i] = scores_pred[j] |
|
|
|
|
|
return keypoints_pred_sorted, scores_pred_sorted |
|
|
|
|
|
def test_decode(self): |
|
|
data = get_coco_sample( |
|
|
img_shape=(256, 256), num_instances=2, non_occlusion=True) |
|
|
|
|
|
|
|
|
codec = AssociativeEmbedding( |
|
|
input_size=(256, 256), |
|
|
heatmap_size=(64, 64), |
|
|
use_udp=False, |
|
|
decode_keypoint_order=self.decode_keypoint_order) |
|
|
|
|
|
encoded = codec.encode(data['keypoints'], data['keypoints_visible']) |
|
|
|
|
|
heatmaps = encoded['heatmaps'] |
|
|
keypoint_indices = encoded['keypoint_indices'] |
|
|
|
|
|
tags = self._get_tags( |
|
|
heatmaps, keypoint_indices, tag_per_keypoint=True) |
|
|
|
|
|
|
|
|
batch_heatmaps = torch.from_numpy(heatmaps[None]) |
|
|
batch_tags = torch.from_numpy(tags[None]) |
|
|
|
|
|
batch_keypoints, batch_keypoint_scores = codec.batch_decode( |
|
|
batch_heatmaps, batch_tags) |
|
|
|
|
|
self.assertIsInstance(batch_keypoints, list) |
|
|
self.assertIsInstance(batch_keypoint_scores, list) |
|
|
self.assertEqual(len(batch_keypoints), 1) |
|
|
self.assertEqual(len(batch_keypoint_scores), 1) |
|
|
|
|
|
keypoints, scores = self._sort_preds(batch_keypoints[0], |
|
|
batch_keypoint_scores[0], |
|
|
data['keypoints']) |
|
|
|
|
|
self.assertIsInstance(keypoints, np.ndarray) |
|
|
self.assertIsInstance(scores, np.ndarray) |
|
|
self.assertEqual(keypoints.shape, (2, 17, 2)) |
|
|
self.assertEqual(scores.shape, (2, 17)) |
|
|
|
|
|
self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0)) |
|
|
|
|
|
|
|
|
codec = AssociativeEmbedding( |
|
|
input_size=(256, 256), |
|
|
heatmap_size=(64, 64), |
|
|
use_udp=False, |
|
|
decode_keypoint_order=self.decode_keypoint_order) |
|
|
|
|
|
encoded = codec.encode(data['keypoints'], data['keypoints_visible']) |
|
|
|
|
|
heatmaps = encoded['heatmaps'] |
|
|
keypoint_indices = encoded['keypoint_indices'] |
|
|
|
|
|
tags = self._get_tags( |
|
|
heatmaps, keypoint_indices, tag_per_keypoint=True, tag_dim=2) |
|
|
|
|
|
|
|
|
batch_heatmaps = torch.from_numpy(heatmaps[None]) |
|
|
batch_tags = torch.from_numpy(tags[None]) |
|
|
|
|
|
batch_keypoints, batch_keypoint_scores = codec.batch_decode( |
|
|
batch_heatmaps, batch_tags) |
|
|
|
|
|
self.assertIsInstance(batch_keypoints, list) |
|
|
self.assertIsInstance(batch_keypoint_scores, list) |
|
|
self.assertEqual(len(batch_keypoints), 1) |
|
|
self.assertEqual(len(batch_keypoint_scores), 1) |
|
|
|
|
|
keypoints, scores = self._sort_preds(batch_keypoints[0], |
|
|
batch_keypoint_scores[0], |
|
|
data['keypoints']) |
|
|
|
|
|
self.assertIsInstance(keypoints, np.ndarray) |
|
|
self.assertIsInstance(scores, np.ndarray) |
|
|
self.assertEqual(keypoints.shape, (2, 17, 2)) |
|
|
self.assertEqual(scores.shape, (2, 17)) |
|
|
|
|
|
self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0)) |
|
|
|
|
|
|
|
|
codec = AssociativeEmbedding( |
|
|
input_size=(256, 256), |
|
|
heatmap_size=(64, 64), |
|
|
use_udp=True, |
|
|
decode_keypoint_order=self.decode_keypoint_order) |
|
|
|
|
|
encoded = codec.encode(data['keypoints'], data['keypoints_visible']) |
|
|
|
|
|
heatmaps = encoded['heatmaps'] |
|
|
keypoint_indices = encoded['keypoint_indices'] |
|
|
|
|
|
tags = self._get_tags( |
|
|
heatmaps, keypoint_indices, tag_per_keypoint=True) |
|
|
|
|
|
|
|
|
batch_heatmaps = torch.from_numpy(heatmaps[None]) |
|
|
batch_tags = torch.from_numpy(tags[None]) |
|
|
|
|
|
batch_keypoints, batch_keypoint_scores = codec.batch_decode( |
|
|
batch_heatmaps, batch_tags) |
|
|
|
|
|
self.assertIsInstance(batch_keypoints, list) |
|
|
self.assertIsInstance(batch_keypoint_scores, list) |
|
|
self.assertEqual(len(batch_keypoints), 1) |
|
|
self.assertEqual(len(batch_keypoint_scores), 1) |
|
|
|
|
|
keypoints, scores = self._sort_preds(batch_keypoints[0], |
|
|
batch_keypoint_scores[0], |
|
|
data['keypoints']) |
|
|
|
|
|
self.assertIsInstance(keypoints, np.ndarray) |
|
|
self.assertIsInstance(scores, np.ndarray) |
|
|
self.assertEqual(keypoints.shape, (2, 17, 2)) |
|
|
self.assertEqual(scores.shape, (2, 17)) |
|
|
|
|
|
self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0)) |
|
|
|