Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import List, Tuple, Union | |
| import numpy as np | |
| from mmcv.transforms import BaseTransform | |
| from mmpose.registry import TRANSFORMS | |
class KeypointConverter(BaseTransform):
    """Change the order of keypoints according to the given mapping.

    Keypoints of the source layout are copied (or, for paired source
    indexes, averaged) into a zero-initialised array of the target layout.
    Target keypoints that no mapping entry writes to stay at zero with
    zero visibility and zero visibility weight.

    Required Keys:

        - keypoints

    Optional Keys:

        - keypoints_visible (treated as all ones when absent)
        - keypoints_3d (used as the coordinate source when present)
        - target_idx (read only when ``keypoints_3d`` is present)
        - flip_indices

    Modified Keys:

        - keypoints
        - keypoints_visible
        - flip_indices
        - keypoints_3d, lifting_target, lifting_target_visible (only when
          ``keypoints_3d`` is present)

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        mapping (list): A list containing mapping indexes. Each element has
            format (source_index, target_index). ``source_index`` may be a
            pair ``(a, b)``, in which case the target keypoint is the mean
            of source keypoints ``a`` and ``b``.

    Example:
        >>> import numpy as np
        >>> # case 1: 1-to-1 mapping
        >>> # (0, 0) means target[0] = source[0]
        >>> self = KeypointConverter(
        >>>     num_keypoints=3,
        >>>     mapping=[
        >>>         (0, 0), (1, 1), (2, 2)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(12).reshape(2, 3, 2),
        >>>     keypoints_visible=np.ones((2, 3)))
        >>> results = self(results)
        >>> assert np.equal(results['keypoints'],
        >>>                 np.arange(12).reshape(2, 3, 2)).all()
        >>> # visibility is returned stacked with the visibility weights
        >>> assert results['keypoints_visible'].shape == (2, 3, 2)
        >>> assert (results['keypoints_visible'] == 1).all()
        >>>
        >>> # case 2: 2-to-1 mapping
        >>> # ((1, 2), 0) means target[0] = (source[1] + source[2]) / 2
        >>> self = KeypointConverter(
        >>>     num_keypoints=3,
        >>>     mapping=[
        >>>         ((1, 2), 0), (1, 1), (2, 2)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(12).reshape(2, 3, 2),
        >>>     keypoints_visible=np.ones((2, 3)))
        >>> results = self(results)
        >>> assert np.equal(results['keypoints'][0, 0], [3., 4.]).all()
    """

    def __init__(self, num_keypoints: int,
                 mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
                                                                  int]]]):
        self.num_keypoints = num_keypoints
        self.mapping = mapping

        if len(mapping):
            source_index, target_index = zip(*mapping)
        else:
            source_index, target_index = [], []

        # Split every source entry into two parallel index lists. A plain
        # index is duplicated so that averaging the two lists is a no-op
        # for it.
        src1, src2 = [], []
        interpolation = False
        for x in source_index:
            if isinstance(x, (list, tuple)):
                assert len(x) == 2, 'source_index should be a list/tuple of ' \
                                    'length 2'
                src1.append(x[0])
                src2.append(x[1])
                interpolation = True
            else:
                src1.append(x)
                src2.append(x)

        # When paired source_indexes are input,
        # keep a self.source_index2 for interpolation
        if interpolation:
            self.source_index2 = src2

        self.source_index = src1
        self.target_index = list(target_index)
        self.interpolation = interpolation

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The result dict. It is updated in place.

        Returns:
            dict: The same dict with the keypoint entries rewritten in the
            target layout. ``keypoints_visible`` becomes an array of shape
            (num_instances, num_keypoints, 2) holding the visibility and a
            weight mask that is 1 for every mapped target keypoint.
        """
        num_instances = results['keypoints'].shape[0]

        if 'keypoints_visible' not in results:
            results['keypoints_visible'] = np.ones(
                (num_instances, results['keypoints'].shape[1]))

        # Visibility with a trailing channel axis: keep the first channel
        if len(results['keypoints_visible'].shape) > 2:
            results['keypoints_visible'] = results['keypoints_visible'][:, :,
                                                                        0]

        # Initialize output arrays
        keypoints = np.zeros((num_instances, self.num_keypoints, 3))
        keypoints_visible = np.zeros((num_instances, self.num_keypoints))
        key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'
        c = results[key].shape[-1]

        flip_indices = results.get('flip_indices', None)

        # Create a mask to weight visibility loss
        keypoints_visible_weights = keypoints_visible.copy()
        keypoints_visible_weights[:, self.target_index] = 1.0

        # Interpolate keypoints if pairs of source indexes provided
        if self.interpolation:
            keypoints[:, self.target_index, :c] = 0.5 * (
                results[key][:, self.source_index] +
                results[key][:, self.source_index2])
            keypoints_visible[:, self.target_index] = results[
                'keypoints_visible'][:, self.source_index] * results[
                    'keypoints_visible'][:, self.source_index2]

            # Flip keypoints if flip_indices provided
            if flip_indices is not None:
                # Write into a copy and read from the untouched original.
                # Updating the incoming sequence in place would leak the
                # change to whoever else holds it, and would let later
                # iterations read entries already overwritten by earlier
                # ones.
                source_flip = flip_indices
                flip_indices = source_flip.copy()
                for i, (x1, x2) in enumerate(
                        zip(self.source_index, self.source_index2)):
                    # An interpolated keypoint is its own flip partner
                    idx = source_flip[x1] if x1 == x2 else i
                    flip_indices[i] = idx if idx < self.num_keypoints else i
                flip_indices = flip_indices[:len(self.source_index)]
        # Otherwise just copy from the source index
        else:
            keypoints[:, self.target_index, :c] = results[key][
                :, self.source_index]
            keypoints_visible[:, self.target_index] = results[
                'keypoints_visible'][:, self.source_index]

        # Update the results dict
        results['keypoints'] = keypoints[..., :2]
        results['keypoints_visible'] = np.stack(
            [keypoints_visible, keypoints_visible_weights], axis=2)
        if 'keypoints_3d' in results:
            results['keypoints_3d'] = keypoints
            results['lifting_target'] = keypoints[results['target_idx']]
            results['lifting_target_visible'] = keypoints_visible[
                results['target_idx']]
        results['flip_indices'] = flip_indices

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'mapping={self.mapping})'
        return repr_str
class SingleHandConverter(BaseTransform):
    """Mapping a single hand keypoints into double hands according to the given
    mapping and hand type.

    ``hand_type`` selects which mapping is applied: ``[[0, 1]]`` uses
    ``left_hand_mapping`` and ``[[1, 0]]`` uses ``right_hand_mapping``.
    The comparison broadcasts over rows, so every row of ``hand_type`` must
    describe the same hand; any other value raises ``ValueError``.

    Required Keys:

        - keypoints
        - keypoints_visible
        - hand_type

    Modified Keys:

        - keypoints
        - keypoints_visible

    Args:
        num_keypoints (int): The number of keypoints in target dataset.
        left_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)
        right_hand_mapping (list): A list containing mapping indexes. Each
            element has format (source_index, target_index)

    Example:
        >>> import numpy as np
        >>> self = SingleHandConverter(
        >>>     num_keypoints=42,
        >>>     left_hand_mapping=[
        >>>         (0, 0), (1, 1), (2, 2), (3, 3)
        >>>     ],
        >>>     right_hand_mapping=[
        >>>         (0, 21), (1, 22), (2, 23), (3, 24)
        >>>     ])
        >>> results = dict(
        >>>     keypoints=np.arange(42).reshape(1, 21, 2),
        >>>     keypoints_visible=np.ones((1, 21)),
        >>>     hand_type=np.array([[1, 0]]))
        >>> results = self(results)
        >>> # right hand: source keypoint 0 lands on target keypoint 21
        >>> assert np.equal(results['keypoints'][0, 21], [0., 1.]).all()
    """

    def __init__(self, num_keypoints: int,
                 left_hand_mapping: Union[List[Tuple[int, int]],
                                          List[Tuple[Tuple, int]]],
                 right_hand_mapping: Union[List[Tuple[int, int]],
                                           List[Tuple[Tuple, int]]]):
        self.num_keypoints = num_keypoints
        # One converter per hand; both write into the same target layout
        self.left_hand_converter = KeypointConverter(num_keypoints,
                                                     left_hand_mapping)
        self.right_hand_converter = KeypointConverter(num_keypoints,
                                                      right_hand_mapping)

    def transform(self, results: dict) -> dict:
        """Transforms the keypoint results to match the target keypoints.

        Args:
            results (dict): The result dict. Must contain ``hand_type``.

        Returns:
            dict: The result dict converted by the converter of the
            matching hand.

        Raises:
            ValueError: If ``hand_type`` is neither left (``[[0, 1]]``) nor
                right (``[[1, 0]]``).
        """
        assert 'hand_type' in results, (
            'hand_type should be provided in results')
        hand_type = results['hand_type']

        # Compare with the sum of ABSOLUTE differences. A signed sum
        # cancels out (e.g. [[1, 0]] - [[0, 1]] sums to 0), which made every
        # single-hand input look like a left hand and left the right-hand
        # branch unreachable.
        if np.sum(np.abs(hand_type - [[0, 1]])) <= 1e-6:
            # left hand
            results = self.left_hand_converter(results)
        elif np.sum(np.abs(hand_type - [[1, 0]])) <= 1e-6:
            # right hand
            results = self.right_hand_converter(results)
        else:
            raise ValueError('hand_type should be left or right')

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(num_keypoints={self.num_keypoints}, '\
                    f'left_hand_converter={self.left_hand_converter}, '\
                    f'right_hand_converter={self.right_hand_converter})'
        return repr_str