File size: 1,796 Bytes
0efc562
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import numpy as np

from mmpose.evaluation.functional import (transform_ann, transform_pred,
                                          transform_sigmas)


class TestKeypointEval(TestCase):

    def test_transform_sigmas(self):

        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        num_keypoints = 5
        sigmas = np.random.rand(17)
        new_sigmas = transform_sigmas(sigmas, num_keypoints, mapping)
        self.assertEqual(len(new_sigmas), 5)
        for i, j in mapping:
            self.assertEqual(sigmas[i], new_sigmas[j])

    def test_transform_ann(self):
        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        num_keypoints = 5

        kpt_info = dict(
            num_keypoints=17,
            keypoints=np.random.randint(3, size=(17 * 3, )).tolist())
        kpt_info_copy = deepcopy(kpt_info)

        _ = transform_ann(kpt_info, num_keypoints, mapping)

        self.assertEqual(kpt_info['num_keypoints'], 5)
        self.assertEqual(len(kpt_info['keypoints']), 15)
        for i, j in mapping:
            self.assertListEqual(kpt_info_copy['keypoints'][i * 3:i * 3 + 3],
                                 kpt_info['keypoints'][j * 3:j * 3 + 3])

    def test_transform_pred(self):
        mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
        num_keypoints = 5

        kpt_info = dict(
            num_keypoints=17,
            keypoints=np.random.randint(3, size=(
                1,
                17,
                3,
            )),
            keypoint_scores=np.ones((1, 17)))

        _ = transform_pred(kpt_info, num_keypoints, mapping)

        self.assertEqual(kpt_info['num_keypoints'], 5)
        self.assertEqual(len(kpt_info['keypoints']), 1)