| |
| from unittest import TestCase |
|
|
| import numpy as np |
|
|
| from mmpose.evaluation.functional.nms import nearby_joints_nms |
|
|
|
|
class TestNearbyJointsNMS(TestCase):
    """Unit tests for ``nearby_joints_nms``."""

    @staticmethod
    def _random_poses(num_poses, per_joint_scores):
        """Build a list of random pose dicts used as NMS input.

        Args:
            num_poses (int): number of pose entries to generate.
            per_joint_scores (bool): if True, ``score`` is an array with one
                value per joint (3 values); otherwise it is a single float.

        Returns:
            list[dict]: entries with a ``keypoints`` array of shape (3, 2)
            and a ``score``, all drawn from ``np.random.rand``.
        """
        poses = []
        for _ in range(num_poses):
            # Draw keypoints before the score so the random stream is
            # consumed in a fixed order for every entry.
            keypoints = np.random.rand(3, 2)
            if per_joint_scores:
                score = np.random.rand(3)
            else:
                score = np.random.rand()
            poses.append({'keypoints': keypoints, 'score': score})
        return poses

    def test_nearby_joints_nms(self):
        """Check ``nearby_joints_nms`` on empty, random and invalid input."""
        # Case 1: an empty pose list yields no kept indices.
        kept = nearby_joints_nms([], 0.05, score_per_joint=True, max_dets=1)
        self.assertEqual(len(kept), 0)

        # Case 2: per-joint scores with ``max_dets=1`` -- exactly one index
        # is kept and it refers to one of the 5 input poses.
        poses = self._random_poses(5, per_joint_scores=True)
        kept = nearby_joints_nms(
            poses, 0.05, score_per_joint=True, max_dets=1)
        self.assertEqual(len(kept), 1)
        self.assertLess(kept[0], 5)

        # Case 3: one scalar score per pose with an explicit
        # ``num_nearby_joints_thr`` -- between 1 and 5 poses survive.
        poses = self._random_poses(5, per_joint_scores=False)
        kept = nearby_joints_nms(poses, 0.05, num_nearby_joints_thr=2)
        self.assertLessEqual(len(kept), 5)
        self.assertGreater(len(kept), 0)

        # Case 4: a distance threshold of 0 is expected to be rejected.
        with self.assertRaises(AssertionError):
            nearby_joints_nms(poses, 0, num_nearby_joints_thr=2)

        # Case 5: ``num_nearby_joints_thr=3`` is expected to be rejected
        # (presumably because the poses only have 3 joints -- confirm
        # against the implementation).
        with self.assertRaises(AssertionError):
            nearby_joints_nms(poses, 0.05, num_nearby_joints_thr=3)
|
|