dikdimon's picture
Upload extensions using SD-Hub extension
f4a41d8 verified
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
from mmpose.evaluation.functional.nms import nearby_joints_nms
class TestNearbyJointsNMS(TestCase):
def test_nearby_joints_nms(self):
kpts_db = []
keep_pose_inds = nearby_joints_nms(
kpts_db, 0.05, score_per_joint=True, max_dets=1)
self.assertEqual(len(keep_pose_inds), 0)
kpts_db = []
for _ in range(5):
kpts_db.append(
dict(keypoints=np.random.rand(3, 2), score=np.random.rand(3)))
keep_pose_inds = nearby_joints_nms(
kpts_db, 0.05, score_per_joint=True, max_dets=1)
self.assertEqual(len(keep_pose_inds), 1)
self.assertLess(keep_pose_inds[0], 5)
kpts_db = []
for _ in range(5):
kpts_db.append(
dict(keypoints=np.random.rand(3, 2), score=np.random.rand()))
keep_pose_inds = nearby_joints_nms(
kpts_db, 0.05, num_nearby_joints_thr=2)
self.assertLessEqual(len(keep_pose_inds), 5)
self.assertGreater(len(keep_pose_inds), 0)
with self.assertRaises(AssertionError):
_ = nearby_joints_nms(kpts_db, 0, num_nearby_joints_thr=2)
with self.assertRaises(AssertionError):
_ = nearby_joints_nms(kpts_db, 0.05, num_nearby_joints_thr=3)