File size: 4,836 Bytes
789eef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple

import numpy as np


def flip_keypoints(keypoints: np.ndarray,
                   keypoints_visible: Optional[np.ndarray],
                   image_size: Tuple[int, int],
                   flip_indices: List[int],
                   direction: str = 'horizontal'
                   ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Flip keypoints in the given direction.

    Note:

        - keypoint number: K
        - keypoint dimension: D

    Args:
        keypoints (np.ndarray): Keypoints in shape (..., K, D)
        keypoints_visible (np.ndarray, optional): The visibility of keypoints
            in shape (..., K, 1). Set ``None`` if the keypoint visibility is
            unavailable
        image_size (tuple): The image shape in [w, h]
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint
        direction (str): The flip direction. Options are ``'horizontal'``,
            ``'vertical'`` and ``'diagonal'``. Defaults to ``'horizontal'``

    Returns:
        tuple:
        - keypoints_flipped (np.ndarray): Flipped keypoints in shape
            (..., K, D)
        - keypoints_visible_flipped (np.ndarray, optional): Flipped keypoints'
            visibility in shape (..., K, 1). Return ``None`` if the input
            ``keypoints_visible`` is ``None``
    """

    assert keypoints.shape[:-1] == keypoints_visible.shape, (
        f'Mismatched shapes of keypoints {keypoints.shape} and '
        f'keypoints_visible {keypoints_visible.shape}')

    direction_options = {'horizontal', 'vertical', 'diagonal'}
    assert direction in direction_options, (
        f'Invalid flipping direction "{direction}". '
        f'Options are {direction_options}')

    # swap the symmetric keypoint pairs
    if direction == 'horizontal' or direction == 'vertical':
        keypoints = keypoints[..., flip_indices, :]
        if keypoints_visible is not None:
            keypoints_visible = keypoints_visible[..., flip_indices]

    # flip the keypoints
    w, h = image_size
    if direction == 'horizontal':
        keypoints[..., 0] = w - 1 - keypoints[..., 0]
    elif direction == 'vertical':
        keypoints[..., 1] = h - 1 - keypoints[..., 1]
    else:
        keypoints = [w, h] - keypoints - 1

    return keypoints, keypoints_visible


def flip_keypoints_custom_center(keypoints: np.ndarray,
                                 keypoints_visible: np.ndarray,
                                 flip_indices: List[int],
                                 center_mode: str = 'static',
                                 center_x: float = 0.5,
                                 center_index: int = 0):
    """Flip human joints horizontally.

    Note:
        - num_keypoint: K
        - dimension: D

    Args:
        keypoints (np.ndarray([..., K, D])): Coordinates of keypoints.
        keypoints_visible (np.ndarray([..., K])): Visibility item of keypoints.
        flip_indices (list[int]): The indices to flip the keypoints.
        center_mode (str): The mode to set the center location on the x-axis
            to flip around. Options are:

            - static: use a static x value (see center_x also)
            - root: use a root joint (see center_index also)

            Defaults: ``'static'``.
        center_x (float): Set the x-axis location of the flip center. Only used
            when ``center_mode`` is ``'static'``. Defaults: 0.5.
        center_index (int): Set the index of the root joint, whose x location
            will be used as the flip center. Only used when ``center_mode`` is
            ``'root'``. Defaults: 0.

    Returns:
        np.ndarray([..., K, C]): Flipped joints.
    """

    assert keypoints.ndim >= 2, f'Invalid pose shape {keypoints.shape}'

    allowed_center_mode = {'static', 'root'}
    assert center_mode in allowed_center_mode, 'Get invalid center_mode ' \
        f'{center_mode}, allowed choices are {allowed_center_mode}'

    if center_mode == 'static':
        x_c = center_x
    elif center_mode == 'root':
        assert keypoints.shape[-2] > center_index
        x_c = keypoints[..., center_index, 0]

    keypoints_flipped = keypoints.copy()
    keypoints_visible_flipped = keypoints_visible.copy()
    # Swap left-right parts
    for left, right in enumerate(flip_indices):
        keypoints_flipped[..., left, :] = keypoints[..., right, :]
        keypoints_visible_flipped[..., left] = keypoints_visible[..., right]

    # Flip horizontally
    keypoints_flipped[..., 0] = x_c * 2 - keypoints_flipped[..., 0]
    return keypoints_flipped, keypoints_visible_flipped