File size: 3,957 Bytes
152f0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from unittest import TestCase

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData

from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer


class TestPoseLocalVisualizer(TestCase):
    """Unit tests for :class:`PoseLocalVisualizer`.

    Rendered images are written into a per-test temporary directory, so the
    tests never leave files behind in the current working directory.
    """

    def setUp(self):
        self.visualizer = PoseLocalVisualizer(show_keypoint_weight=True)
        # Isolate output files per test; the directory is removed even when
        # the test fails (cleanups run regardless of the test outcome).
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp_dir.cleanup)

    def _get_dataset_meta(self):
        """Build a minimal 4-keypoint / 3-link dataset meta dict.

        Returns:
            dict: With keys ``'keypoint_colors'``, ``'skeleton_link_colors'``
            and ``'skeleton_links'``. ``None`` color entries mark a keypoint
            or link that should be hidden, so both visible and hidden
            elements are exercised.
        """
        # None: kpt or link is hidden
        pose_kpt_color = [None] + [(127, 127, 127)] * 2 + ['red']
        pose_link_color = [(127, 127, 127)] * 2 + [None]
        skeleton_links = [[0, 1], [1, 2], [2, 3]]
        return {
            'keypoint_colors': pose_kpt_color,
            'skeleton_link_colors': pose_link_color,
            'skeleton_links': skeleton_links
        }

    def test_set_dataset_meta(self):
        """Dataset meta populates the colors and skeleton on the visualizer."""
        dataset_meta = self._get_dataset_meta()
        self.visualizer.set_dataset_meta(dataset_meta)
        self.assertEqual(len(self.visualizer.kpt_color), 4)
        self.assertEqual(self.visualizer.kpt_color[-1], 'red')
        self.assertListEqual(self.visualizer.skeleton[-1], [2, 3])

        # Re-setting after clearing ``dataset_meta`` must restore it.
        self.visualizer.dataset_meta = None
        self.visualizer.set_dataset_meta(dataset_meta)
        self.assertIsNotNone(self.visualizer.dataset_meta)

    def test_add_datasample(self):
        """Check the saved canvas size for several drawing configurations.

        The expected shapes below are the ones this test asserts: GT and
        prediction side by side doubles the width, and drawing the heatmap
        doubles the height.
        """
        h, w = 100, 100
        image = np.zeros((h, w, 3), dtype=np.uint8)
        out_file = os.path.join(self._tmp_dir.name, 'out_file.jpg')

        dataset_meta = self._get_dataset_meta()
        self.visualizer.set_dataset_meta(dataset_meta)

        # setting keypoints
        gt_instances = InstanceData()
        gt_instances.keypoints = np.array([[[1, 1], [20, 20], [40, 40],
                                            [80, 80]]],
                                          dtype=np.float32)

        # setting bounding box
        gt_instances.bboxes = np.array([[20, 30, 50, 70]])

        # setting heatmap: low-amplitude noise plus one bright square per
        # channel. A local seeded generator keeps the input deterministic
        # without touching torch's global RNG state.
        gen = torch.Generator().manual_seed(0)
        heatmap = torch.randn(10, 100, 100, generator=gen) * 0.05
        for i in range(10):
            heatmap[i][i * 10:(i + 1) * 10, i * 10:(i + 1) * 10] += 5
        gt_heatmap = PixelData()
        gt_heatmap.heatmaps = heatmap

        # test gt_sample
        pred_pose_data_sample = PoseDataSample()
        pred_pose_data_sample.gt_instances = gt_instances
        pred_pose_data_sample.gt_fields = gt_heatmap
        pred_instances = gt_instances.clone()
        # Scores deliberately include values outside [0, 1] (1.7 and -0.2).
        pred_instances.scores = np.array([[0.9, 0.4, 1.7, -0.2]],
                                         dtype=np.float32)
        pred_pose_data_sample.pred_instances = pred_instances

        # GT + prediction side by side, with bboxes.
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=pred_pose_data_sample,
            draw_bbox=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        # GT only, with the heatmap drawn.
        self.visualizer.show_keypoint_weight = False
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=pred_pose_data_sample,
            draw_pred=False,
            draw_heatmap=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, ((h * 2), w, 3))

        # GT + prediction, with the heatmap drawn.
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=pred_pose_data_sample,
            draw_heatmap=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, ((h * 2), (w * 2), 3))

    def test_simcc_visualization(self):
        """Smoke-test drawing a 17-channel heatmap over a 512x512 image."""
        img = np.zeros((512, 512, 3), dtype=np.uint8)
        gen = torch.Generator().manual_seed(0)
        heatmap = torch.randn([17, 512, 512], generator=gen)
        pixel_data = PixelData()
        pixel_data.heatmaps = heatmap
        # NOTE(review): only checks that drawing does not raise; the return
        # value is not asserted.
        self.visualizer._draw_instance_xy_heatmap(pixel_data, img, 10)

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert ``out_file`` is a readable image of ``out_shape``.

        The file is always deleted afterwards (even when an assertion
        fails), so the next ``add_datasample`` call is checked against a
        freshly written file rather than a stale one.

        Args:
            out_file (str): Path of the image written by the visualizer.
            out_shape (tuple): Expected ``(height, width, channels)``.
        """
        try:
            self.assertTrue(os.path.exists(out_file))
            drawn_img = cv2.imread(out_file)
            # cv2.imread returns None instead of raising on unreadable files.
            self.assertIsNotNone(drawn_img,
                                 f'failed to read image from {out_file}')
            self.assertTupleEqual(drawn_img.shape, out_shape)
        finally:
            if os.path.exists(out_file):
                os.remove(out_file)