Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import shutil | |
| import time | |
| from unittest import TestCase | |
| from unittest.mock import MagicMock | |
| import numpy as np | |
| from mmengine.structures import InstanceData | |
| from mmpose.engine.hooks import PoseVisualizationHook | |
| from mmpose.structures import PoseDataSample | |
| from mmpose.visualization import PoseLocalVisualizer | |
| def _rand_poses(num_boxes, h, w): | |
| center = np.random.rand(num_boxes, 2) | |
| offset = np.random.rand(num_boxes, 5, 2) / 2.0 | |
| pose = center[:, None, :] + offset.clip(0, 1) | |
| pose[:, :, 0] *= w | |
| pose[:, :, 1] *= h | |
| return pose | |
| class TestVisualizationHook(TestCase): | |
| def setUp(self) -> None: | |
| PoseLocalVisualizer.get_instance('test_visualization_hook') | |
| data_sample = PoseDataSample() | |
| data_sample.set_metainfo({ | |
| 'img_path': | |
| osp.join( | |
| osp.dirname(__file__), '../../data/coco/000000000785.jpg') | |
| }) | |
| self.data_batch = {'data_samples': [data_sample] * 2} | |
| pred_instances = InstanceData() | |
| pred_instances.keypoints = _rand_poses(5, 10, 12) | |
| pred_instances.score = np.random.rand(5, 5) | |
| pred_det_data_sample = data_sample.clone() | |
| pred_det_data_sample.pred_instances = pred_instances | |
| self.outputs = [pred_det_data_sample] * 2 | |
| def test_after_val_iter(self): | |
| runner = MagicMock() | |
| runner.iter = 1 | |
| runner.val_evaluator.dataset_meta = dict() | |
| hook = PoseVisualizationHook(interval=1, enable=True) | |
| hook.after_val_iter(runner, 1, self.data_batch, self.outputs) | |
| def test_after_test_iter(self): | |
| runner = MagicMock() | |
| runner.iter = 1 | |
| hook = PoseVisualizationHook(enable=True) | |
| hook.after_test_iter(runner, 1, self.data_batch, self.outputs) | |
| self.assertEqual(hook._test_index, 2) | |
| # test | |
| timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) | |
| out_dir = timestamp + '1' | |
| runner.work_dir = timestamp | |
| runner.timestamp = '1' | |
| hook = PoseVisualizationHook(enable=False, out_dir=out_dir) | |
| hook.after_test_iter(runner, 1, self.data_batch, self.outputs) | |
| self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}')) | |
| hook = PoseVisualizationHook(enable=True, out_dir=out_dir) | |
| hook.after_test_iter(runner, 1, self.data_batch, self.outputs) | |
| self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) | |
| shutil.rmtree(f'{timestamp}') | |