File size: 4,114 Bytes
152f0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData

from mmpose.structures import MultilevelPixelData, PoseDataSample


class TestPoseDataSample(TestCase):
    """Unit tests for construction, setters and deleters of PoseDataSample."""

    def get_pose_data_sample(self, multilevel: bool = False):
        """Build a ``PoseDataSample`` filled with random data.

        Args:
            multilevel (bool): If True, ``gt_fields`` is a
                ``MultilevelPixelData`` holding three levels of heatmaps and
                masks; otherwise it is a ``PixelData`` with a single
                ``heatmaps`` tensor. Defaults to False.

        Returns:
            PoseDataSample: A sample with ``gt_instances``,
            ``pred_instances``, ``gt_fields``, ``pred_fields`` and meta
            information set.
        """
        # meta
        pose_meta = dict(
            img_shape=(600, 900),  # [h, w]
            crop_size=(256, 192),  # [h, w]
            heatmap_size=(64, 48),  # [h, w]
        )
        # gt_instances
        gt_instances = InstanceData()
        gt_instances.bboxes = torch.rand(1, 4)
        gt_instances.keypoints = torch.rand(1, 17, 2)
        gt_instances.keypoints_visible = torch.rand(1, 17)

        # pred_instances
        pred_instances = InstanceData()
        pred_instances.keypoints = torch.rand(1, 17, 2)
        pred_instances.keypoint_scores = torch.rand(1, 17)

        # gt_fields
        if multilevel:
            # generate multilevel gt_fields
            metainfo = dict(num_keypoints=17)
            sizes = [(64, 48), (32, 24), (16, 12)]
            heatmaps = [np.random.rand(17, h, w) for h, w in sizes]
            masks = [torch.rand(1, h, w) for h, w in sizes]
            gt_fields = MultilevelPixelData(
                metainfo=metainfo, heatmaps=heatmaps, masks=masks)
        else:
            gt_fields = PixelData()
            gt_fields.heatmaps = torch.rand(17, 64, 48)

        # pred_fields
        pred_fields = PixelData()
        pred_fields.heatmaps = torch.rand(17, 64, 48)

        data_sample = PoseDataSample(
            gt_instances=gt_instances,
            pred_instances=pred_instances,
            gt_fields=gt_fields,
            pred_fields=pred_fields,
            metainfo=pose_meta)

        return data_sample

    @staticmethod
    def _equal(x, y):
        """Compare two values, using ``allclose`` for tensors and arrays.

        Args:
            x: First value.
            y: Second value.

        Returns:
            ``False`` when the exact types of ``x`` and ``y`` differ;
            otherwise the result of ``torch.allclose`` for tensors,
            ``np.allclose`` for arrays, or ``x == y`` for anything else.
        """
        # Exact type match is required, so a tensor is never considered
        # equal to an ndarray even when the values coincide.
        if type(x) is not type(y):
            return False
        if isinstance(x, torch.Tensor):
            return torch.allclose(x, y)
        if isinstance(x, np.ndarray):
            return np.allclose(x, y)
        return x == y

    def test_init(self):
        """A freshly built sample exposes its meta keys and one instance."""
        data_sample = self.get_pose_data_sample()
        self.assertIn('img_shape', data_sample)
        # assertEqual reports the actual length on failure.
        self.assertEqual(len(data_sample.gt_instances), 1)

    def test_setter(self):
        """Fields can be reassigned and read back through the properties."""
        data_sample = self.get_pose_data_sample()

        # test gt_instances
        data_sample.gt_instances = InstanceData()
        self.assertIsInstance(data_sample.gt_instances, InstanceData)

        # test gt_fields
        data_sample.gt_fields = PixelData()
        self.assertIsInstance(data_sample.gt_fields, PixelData)

        # test multilevel gt_fields
        data_sample = self.get_pose_data_sample(multilevel=True)
        data_sample.gt_fields = MultilevelPixelData()
        self.assertIsInstance(data_sample.gt_fields, MultilevelPixelData)

        # test pred_instances as pytorch tensor
        pred_instances_data = dict(
            keypoints=torch.rand(1, 17, 2), scores=torch.rand(1, 17, 1))
        data_sample.pred_instances = InstanceData(**pred_instances_data)

        self.assertTrue(
            self._equal(data_sample.pred_instances.keypoints,
                        pred_instances_data['keypoints']))
        self.assertTrue(
            self._equal(data_sample.pred_instances.scores,
                        pred_instances_data['scores']))

        # test pred_fields as numpy array
        pred_fields_data = dict(heatmaps=np.random.rand(17, 64, 48))
        data_sample.pred_fields = PixelData(**pred_fields_data)

        self.assertTrue(
            self._equal(data_sample.pred_fields.heatmaps,
                        pred_fields_data['heatmaps']))

        # test to_tensor
        data_sample = data_sample.to_tensor()
        self.assertTrue(
            self._equal(data_sample.pred_fields.heatmaps,
                        torch.from_numpy(pred_fields_data['heatmaps'])))

    def test_deleter(self):
        """Each field is present initially and absent after deletion."""
        data_sample = self.get_pose_data_sample()

        for key in (
                'gt_instances',
                'pred_instances',
                'gt_fields',
                'pred_fields',
        ):
            # subTest labels a failure with the offending key.
            with self.subTest(key=key):
                self.assertIn(key, data_sample)
                # delattr is the direct equivalent of ``del obj.<key>`` and
                # avoids building and exec-ing source code.
                delattr(data_sample, key)
                self.assertNotIn(key, data_sample)