File size: 2,507 Bytes
152f0f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import tempfile
import time
from unittest import TestCase
from unittest.mock import MagicMock

import numpy as np
from mmengine.structures import InstanceData

from mmpose.engine.hooks import PoseVisualizationHook
from mmpose.structures import PoseDataSample
from mmpose.visualization import PoseLocalVisualizer


def _rand_poses(num_boxes, h, w):
    center = np.random.rand(num_boxes, 2)
    offset = np.random.rand(num_boxes, 5, 2) / 2.0

    pose = center[:, None, :] + offset.clip(0, 1)
    pose[:, :, 0] *= w
    pose[:, :, 1] *= h

    return pose


class TestVisualizationHook(TestCase):
    """Tests for ``PoseVisualizationHook`` driven by a ``MagicMock`` runner."""

    def setUp(self) -> None:
        # Create a named visualizer instance up front. Presumably the hook
        # looks up the current visualizer instance instead of building its
        # own (TODO confirm against PoseVisualizationHook).
        PoseLocalVisualizer.get_instance('test_visualization_hook')

        data_sample = PoseDataSample()
        data_sample.set_metainfo({
            'img_path':
            osp.join(
                osp.dirname(__file__), '../../data/coco/000000000785.jpg')
        })
        # A batch of two references to the same sample object.
        self.data_batch = {'data_samples': [data_sample] * 2}

        # Predictions: 5 instances with 5 random keypoints each, plus one
        # random score per keypoint.
        pred_instances = InstanceData()
        pred_instances.keypoints = _rand_poses(5, 10, 12)
        pred_instances.score = np.random.rand(5, 5)
        pred_det_data_sample = data_sample.clone()
        pred_det_data_sample.pred_instances = pred_instances
        self.outputs = [pred_det_data_sample] * 2

    def test_after_val_iter(self):
        # Smoke test: only checks that the hook runs without raising.
        runner = MagicMock()
        runner.iter = 1
        runner.val_evaluator.dataset_meta = dict()
        hook = PoseVisualizationHook(interval=1, enable=True)
        hook.after_val_iter(runner, 1, self.data_batch, self.outputs)

    def test_after_test_iter(self):
        runner = MagicMock()
        runner.iter = 1
        hook = PoseVisualizationHook(enable=True)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        # After one batch of two outputs the index is expected to be 2
        # (presumably one increment per output).
        self.assertEqual(hook._test_index, 2)

        # Use a private temporary work dir instead of the CWD, and register
        # its removal before anything is written so it is cleaned up even
        # if one of the assertions below fails.
        work_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, work_dir, ignore_errors=True)
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        out_dir = timestamp + '1'
        runner.work_dir = work_dir
        runner.timestamp = '1'
        # Output is expected under <work_dir>/<runner.timestamp>/<out_dir>.
        expected_dir = osp.join(work_dir, '1', out_dir)

        # A disabled hook must not create the output directory.
        hook = PoseVisualizationHook(enable=False, out_dir=out_dir)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        self.assertFalse(osp.exists(expected_dir))

        # An enabled hook creates it.
        hook = PoseVisualizationHook(enable=True, out_dir=out_dir)
        hook.after_test_iter(runner, 1, self.data_batch, self.outputs)
        self.assertTrue(osp.exists(expected_dir))