Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from copy import deepcopy | |
| from typing import Dict | |
| import numpy as np | |
| from mmcv.transforms import BaseTransform | |
| from mmpose.registry import TRANSFORMS | |
| from mmpose.structures.keypoint import flip_keypoints_custom_center | |
| class RandomFlipAroundRoot(BaseTransform): | |
| """Data augmentation with random horizontal joint flip around a root joint. | |
| Args: | |
| keypoints_flip_cfg (dict): Configurations of the | |
| ``flip_keypoints_custom_center`` function for ``keypoints``. Please | |
| refer to the docstring of the ``flip_keypoints_custom_center`` | |
| function for more details. | |
| target_flip_cfg (dict): Configurations of the | |
| ``flip_keypoints_custom_center`` function for ``lifting_target``. | |
| Please refer to the docstring of the | |
| ``flip_keypoints_custom_center`` function for more details. | |
| flip_prob (float): Probability of flip. Default: 0.5. | |
| flip_camera (bool): Whether to flip horizontal distortion coefficients. | |
| Default: ``False``. | |
| flip_label (bool): Whether to flip labels instead of data. | |
| Default: ``False``. | |
| Required keys: | |
| - keypoints or keypoint_labels | |
| - lifting_target or lifting_target_label | |
| - keypoints_visible or keypoint_labels_visible (optional) | |
| - lifting_target_visible (optional) | |
| - flip_indices (optional) | |
| Modified keys: | |
| - keypoints or keypoint_labels (optional) | |
| - keypoints_visible or keypoint_labels_visible (optional) | |
| - lifting_target or lifting_target_label (optional) | |
| - lifting_target_visible (optional) | |
| - camera_param (optional) | |
| """ | |
| def __init__(self, | |
| keypoints_flip_cfg: dict, | |
| target_flip_cfg: dict, | |
| flip_prob: float = 0.5, | |
| flip_camera: bool = False, | |
| flip_label: bool = False): | |
| self.keypoints_flip_cfg = keypoints_flip_cfg | |
| self.target_flip_cfg = target_flip_cfg | |
| self.flip_prob = flip_prob | |
| self.flip_camera = flip_camera | |
| self.flip_label = flip_label | |
| def transform(self, results: Dict) -> dict: | |
| """The transform function of :class:`RandomFlipAroundRoot`. | |
| See ``transform()`` method of :class:`BaseTransform` for details. | |
| Args: | |
| results (dict): The result dict | |
| Returns: | |
| dict: The result dict. | |
| """ | |
| if np.random.rand() <= self.flip_prob: | |
| if self.flip_label: | |
| assert 'keypoint_labels' in results | |
| assert 'lifting_target_label' in results | |
| keypoints_key = 'keypoint_labels' | |
| keypoints_visible_key = 'keypoint_labels_visible' | |
| target_key = 'lifting_target_label' | |
| else: | |
| assert 'keypoints' in results | |
| assert 'lifting_target' in results | |
| keypoints_key = 'keypoints' | |
| keypoints_visible_key = 'keypoints_visible' | |
| target_key = 'lifting_target' | |
| keypoints = results[keypoints_key] | |
| if keypoints_visible_key in results: | |
| keypoints_visible = results[keypoints_visible_key] | |
| else: | |
| keypoints_visible = np.ones( | |
| keypoints.shape[:-1], dtype=np.float32) | |
| lifting_target = results[target_key] | |
| if 'lifting_target_visible' in results: | |
| lifting_target_visible = results['lifting_target_visible'] | |
| else: | |
| lifting_target_visible = np.ones( | |
| lifting_target.shape[:-1], dtype=np.float32) | |
| if 'flip_indices' not in results: | |
| flip_indices = list(range(self.num_keypoints)) | |
| else: | |
| flip_indices = results['flip_indices'] | |
| # flip joint coordinates | |
| _camera_param = deepcopy(results['camera_param']) | |
| keypoints, keypoints_visible = flip_keypoints_custom_center( | |
| keypoints, | |
| keypoints_visible, | |
| flip_indices, | |
| center_mode=self.keypoints_flip_cfg.get( | |
| 'center_mode', 'static'), | |
| center_x=self.keypoints_flip_cfg.get('center_x', 0.5), | |
| center_index=self.keypoints_flip_cfg.get('center_index', 0)) | |
| lifting_target, lifting_target_visible = flip_keypoints_custom_center( # noqa | |
| lifting_target, | |
| lifting_target_visible, | |
| flip_indices, | |
| center_mode=self.target_flip_cfg.get('center_mode', 'static'), | |
| center_x=self.target_flip_cfg.get('center_x', 0.5), | |
| center_index=self.target_flip_cfg.get('center_index', 0)) | |
| results[keypoints_key] = keypoints | |
| results[keypoints_visible_key] = keypoints_visible | |
| results[target_key] = lifting_target | |
| results['lifting_target_visible'] = lifting_target_visible | |
| # flip horizontal distortion coefficients | |
| if self.flip_camera: | |
| assert 'camera_param' in results, \ | |
| 'Camera parameters are missing.' | |
| assert 'c' in _camera_param | |
| _camera_param['c'][0] *= -1 | |
| if 'p' in _camera_param: | |
| _camera_param['p'][0] *= -1 | |
| results['camera_param'].update(_camera_param) | |
| return results | |