|
|
|
|
|
import os
import tempfile
from unittest import TestCase

import cv2
import numpy as np
import torch
from mmengine.structures import InstanceData, PixelData

from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer
|
|
|
|
|
|
|
|
class TestPoseLocalVisualizer(TestCase):
    """Unit tests for :class:`PoseLocalVisualizer`."""

    def setUp(self):
        """Create a fresh visualizer, a scratch directory and a seeded RNG."""
        self.visualizer = PoseLocalVisualizer(show_keypoint_weight=True)

        # Rendered images go into a per-test temporary directory instead of
        # the current working directory: concurrent runs cannot collide and
        # nothing is left behind in the checkout if a test fails midway.
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp_dir.cleanup)

        # A private seeded generator makes the random heatmaps reproducible
        # without mutating torch's global RNG state for other tests.
        self._rng = torch.Generator().manual_seed(0)

    def _get_dataset_meta(self):
        """Return a minimal dataset meta dict with 4 keypoints and 3 links.

        The first keypoint color and the last link color are ``None`` and the
        last keypoint color is a color name (``'red'``), so mixed color
        specifications are exercised.
        """
        pose_kpt_color = [None] + [(127, 127, 127)] * 2 + ['red']
        pose_link_color = [(127, 127, 127)] * 2 + [None]
        skeleton_links = [[0, 1], [1, 2], [2, 3]]
        return {
            'keypoint_colors': pose_kpt_color,
            'skeleton_link_colors': pose_link_color,
            'skeleton_links': skeleton_links,
        }

    def test_set_dataset_meta(self):
        """``set_dataset_meta`` populates keypoint colors and the skeleton."""
        dataset_meta = self._get_dataset_meta()
        self.visualizer.set_dataset_meta(dataset_meta)
        self.assertEqual(len(self.visualizer.kpt_color), 4)
        self.assertEqual(self.visualizer.kpt_color[-1], 'red')
        self.assertListEqual(self.visualizer.skeleton[-1], [2, 3])

        # Calling it again after clearing ``dataset_meta`` must restore it.
        self.visualizer.dataset_meta = None
        self.visualizer.set_dataset_meta(dataset_meta)
        self.assertIsNotNone(self.visualizer.dataset_meta)

    def test_add_datasample(self):
        """``add_datasample`` writes canvases of the expected shape."""
        h, w = 100, 100
        image = np.zeros((h, w, 3), dtype=np.uint8)
        out_file = os.path.join(self._tmp_dir.name, 'out_file.jpg')

        self.visualizer.set_dataset_meta(self._get_dataset_meta())

        # One instance with four keypoints, matching the dataset meta.
        gt_instances = InstanceData()
        gt_instances.keypoints = np.array(
            [[[1, 1], [20, 20], [40, 40], [80, 80]]], dtype=np.float32)
        gt_instances.bboxes = np.array([[20, 30, 50, 70]])

        # Ten low-amplitude noise channels; channel i gets a bright 10x10
        # block at rows/cols [10*i, 10*(i+1)), i.e. along the diagonal.
        heatmap = torch.randn(10, 100, 100, generator=self._rng) * 0.05
        for i in range(10):
            heatmap[i][i * 10:(i + 1) * 10, i * 10:(i + 1) * 10] += 5
        gt_heatmap = PixelData()
        gt_heatmap.heatmaps = heatmap

        # A single sample carrying GT instances, GT heatmaps and predictions.
        data_sample = PoseDataSample()
        data_sample.gt_instances = gt_instances
        data_sample.gt_fields = gt_heatmap
        pred_instances = gt_instances.clone()
        # Scores deliberately include values outside [0, 1] (1.7 and -0.2).
        pred_instances.scores = np.array([[0.9, 0.4, 1.7, -0.2]],
                                         dtype=np.float32)
        data_sample.pred_instances = pred_instances

        # Default drawing options plus bboxes: expected canvas is twice as
        # wide as the input image.
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=data_sample,
            draw_bbox=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h, w * 2, 3))

        # Without predictions but with heatmaps: expected canvas is twice as
        # tall as the input image.
        self.visualizer.show_keypoint_weight = False
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=data_sample,
            draw_pred=False,
            draw_heatmap=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h * 2, w, 3))

        # Predictions and heatmaps together: expected canvas is doubled in
        # both dimensions.
        self.visualizer.add_datasample(
            'image',
            image,
            data_sample=data_sample,
            draw_heatmap=True,
            out_file=out_file)
        self._assert_image_and_shape(out_file, (h * 2, w * 2, 3))

    def test_simcc_visualization(self):
        """Smoke test: ``_draw_instance_xy_heatmap`` runs without raising."""
        img = np.zeros((512, 512, 3), dtype=np.uint8)
        heatmap = torch.randn([17, 512, 512], generator=self._rng)
        pixel_data = PixelData()
        pixel_data.heatmaps = heatmap
        # NOTE(review): only checks that the call does not raise; the
        # returned image is not inspected here.
        self.visualizer._draw_instance_xy_heatmap(pixel_data, img, 10)

    def _assert_image_and_shape(self, out_file, out_shape):
        """Assert ``out_file`` is a readable image of ``out_shape``.

        The file is deleted afterwards even if an assertion fails, so a
        later check in the same test can never read a stale image.

        Args:
            out_file: Path of the image written by the visualizer.
            out_shape: Expected ``(height, width, channels)`` tuple.
        """
        try:
            self.assertTrue(os.path.exists(out_file))
            drawn_img = cv2.imread(out_file)
            # cv2.imread signals failure by returning None rather than
            # raising; fail with a clear message instead of AttributeError.
            self.assertIsNotNone(drawn_img,
                                 f'cv2 could not read image {out_file}')
            self.assertTupleEqual(drawn_img.shape, out_shape)
        finally:
            if os.path.exists(out_file):
                os.remove(out_file)
|
|
|