|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch import Tensor |
|
|
|
|
|
from mmpose.registry import KEYPOINT_CODECS |
|
|
from .base import BaseKeypointCodec |
|
|
from .utils import (batch_heatmap_nms, generate_displacement_heatmap, |
|
|
generate_gaussian_heatmaps, get_diagonal_lengths, |
|
|
get_instance_root) |
|
|
|
|
|
|
|
|
@KEYPOINT_CODECS.register_module()
class SPR(BaseKeypointCodec):
    """Encode/decode keypoints with Structured Pose Representation (SPR).

    See the paper `Single-stage multi-person pose machines`_
    by Nie et al (2019) for details

    Note:

        - instance number: N
        - keypoint number: K
        - keypoint dimension: D
        - image size: [w, h]
        - heatmap size: [W, H]

    Encoded:

        - heatmaps (np.ndarray): The generated heatmap in shape (1, H, W)
            where [W, H] is the `heatmap_size`. If the keypoint heatmap is
            generated together, the output heatmap shape is (K+1, H, W),
            with the K keypoint channels first and the root channel last.
        - heatmap_weights (np.ndarray): The target weights for heatmaps which
            has same shape with heatmaps.
        - displacements (np.ndarray): The dense keypoint displacement in
            shape (K*2, H, W).
        - displacement_weights (np.ndarray): The target weights for
            displacements which has same shape with displacements.

    Args:
        input_size (tuple): Image size in [w, h]
        heatmap_size (tuple): Heatmap size in [W, H]
        sigma (float or tuple, optional): The sigma values of the Gaussian
            heatmaps. If sigma is a tuple, it includes both sigmas for root
            and keypoint heatmaps. ``None`` means the sigmas are computed
            automatically from the heatmap size: ``sqrt(W * H) / 32`` for
            the root heatmap and ``root_sigma // 2`` (floor division) for
            the keypoint heatmaps. Defaults to ``None``
        generate_keypoint_heatmaps (bool): Whether to generate Gaussian
            heatmaps for each keypoint. Defaults to ``False``
        root_type (str): The method to generate the instance root. Options
            are:

            - ``'kpt_center'``: Average coordinate of all visible keypoints.
            - ``'bbox_center'``: Center point of bounding boxes outlined by
                all visible keypoints.

            Defaults to ``'kpt_center'``

        minimal_diagonal_length (int or float): The threshold of diagonal
            length of instance bounding box, measured in heatmap pixels.
            Small instances will not be used in training. Defaults to 5
        background_weight (float): Loss weight of background pixels.
            Defaults to 0.1
        decode_thr (float): The threshold of keypoint response value in
            heatmaps. Defaults to 0.01
        decode_nms_kernel (int): The kernel size of the NMS during decoding,
            which should be an odd integer. Defaults to 5
        decode_max_instances (int): The maximum number of instances
            to decode. Defaults to 30

    .. _`Single-stage multi-person pose machines`:
            https://arxiv.org/abs/1908.09220
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        heatmap_size: Tuple[int, int],
        sigma: Optional[Union[float, Tuple[float]]] = None,
        generate_keypoint_heatmaps: bool = False,
        root_type: str = 'kpt_center',
        minimal_diagonal_length: Union[int, float] = 5,
        background_weight: float = 0.1,
        decode_nms_kernel: int = 5,
        decode_max_instances: int = 30,
        decode_thr: float = 0.01,
    ):
        super().__init__()

        self.input_size = input_size
        self.heatmap_size = heatmap_size
        self.generate_keypoint_heatmaps = generate_keypoint_heatmaps
        self.root_type = root_type
        self.minimal_diagonal_length = minimal_diagonal_length
        self.background_weight = background_weight
        self.decode_nms_kernel = decode_nms_kernel
        self.decode_max_instances = decode_max_instances
        self.decode_thr = decode_thr

        # Per-axis ratio (x, y) used to map between input-image space and
        # heatmap space.
        self.scale_factor = (np.array(input_size) /
                             heatmap_size).astype(np.float32)

        if sigma is None:
            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 32
            if generate_keypoint_heatmaps:
                # NOTE(review): floor division is kept for reproducibility of
                # existing configs, but it yields 0 whenever the root sigma
                # is below 2 (i.e. heatmaps smaller than ~64x64) — confirm
                # small heatmap sizes are not used with this default.
                self.sigma = (sigma, sigma // 2)
            else:
                self.sigma = (sigma, )
        else:
            if not isinstance(sigma, (tuple, list)):
                sigma = (sigma, )
            if generate_keypoint_heatmaps:
                assert len(sigma) == 2, 'sigma for keypoints must be given ' \
                                        'if `generate_keypoint_heatmaps` ' \
                                        'is True. e.g. sigma=(4, 2)'
            self.sigma = sigma

    def _get_heatmap_weights(self,
                             heatmaps,
                             fg_weight: float = 1,
                             bg_weight: float = 0):
        """Generate weight array for heatmaps.

        Pixels whose heatmap value is strictly positive receive
        ``fg_weight``; all other pixels receive ``bg_weight``.

        Args:
            heatmaps (np.ndarray): Root and keypoint (optional) heatmaps
            fg_weight (float): Weight for foreground pixels. Defaults to 1.0
            bg_weight (float): Weight for background pixels. Defaults to 0.0

        Returns:
            np.ndarray: Heatmap weight array (float64) in the same shape
            with heatmaps
        """
        heatmap_weights = np.ones(heatmaps.shape) * bg_weight
        heatmap_weights[heatmaps > 0] = fg_weight
        return heatmap_weights

    def encode(self,
               keypoints: np.ndarray,
               keypoints_visible: Optional[np.ndarray] = None) -> dict:
        """Encode keypoints into root heatmaps and keypoint displacement
        fields. Note that the original keypoint coordinates should be in the
        input image space.

        Args:
            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
            keypoints_visible (np.ndarray): Keypoint visibilities in shape
                (N, K). ``None`` treats every keypoint as visible.

        Returns:
            dict:
            - heatmaps (np.ndarray): The generated heatmap in shape
                (1, H, W) where [W, H] is the `heatmap_size`. If keypoint
                heatmaps are generated together, the shape is (K+1, H, W),
                with the keypoint channels first and the root channel last
            - heatmap_weights (np.ndarray): The pixel-wise weight for heatmaps
                which has same shape with `heatmaps`
            - displacements (np.ndarray): The generated displacement fields in
                shape (K*D, H, W). The vector on each pixels represents the
                displacement of keypoints belong to the associated instance
                from this pixel.
            - displacement_weights (np.ndarray): The pixel-wise weight for
                displacements which has same shape with `displacements`
        """
        if keypoints_visible is None:
            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)

        # Map keypoints from input-image space into heatmap space.
        _keypoints = keypoints / self.scale_factor

        # One root per instance, plus the instance's diagonal length (both
        # computed in heatmap space).
        roots, roots_visible = get_instance_root(_keypoints, keypoints_visible,
                                                 self.root_type)
        diagonal_lengths = get_diagonal_lengths(_keypoints, keypoints_visible)

        # Drop instances that are too small to supervise reliably.
        roots_visible[diagonal_lengths < self.minimal_diagonal_length] = 0

        # Root heatmap: one channel holding every (kept) instance root.
        heatmaps, _ = generate_gaussian_heatmaps(
            heatmap_size=self.heatmap_size,
            keypoints=roots[:, None],
            keypoints_visible=roots_visible[:, None],
            sigma=self.sigma[0])
        heatmap_weights = self._get_heatmap_weights(
            heatmaps, bg_weight=self.background_weight)

        if self.generate_keypoint_heatmaps:
            keypoint_heatmaps, _ = generate_gaussian_heatmaps(
                heatmap_size=self.heatmap_size,
                keypoints=_keypoints,
                keypoints_visible=keypoints_visible,
                sigma=self.sigma[1])

            keypoint_heatmaps_weights = self._get_heatmap_weights(
                keypoint_heatmaps, bg_weight=self.background_weight)

            # Keypoint channels first, root channel last; `decode` relies on
            # this ordering (root = heatmaps[-1], keypoints = heatmaps[:K]).
            heatmaps = np.concatenate((keypoint_heatmaps, heatmaps), axis=0)
            heatmap_weights = np.concatenate(
                (keypoint_heatmaps_weights, heatmap_weights), axis=0)

        # Dense displacement fields from each pixel to its instance keypoints.
        displacements, displacement_weights = \
            generate_displacement_heatmap(
                self.heatmap_size,
                _keypoints,
                keypoints_visible,
                roots,
                roots_visible,
                diagonal_lengths,
                self.sigma[0],
            )

        encoded = dict(
            heatmaps=heatmaps,
            heatmap_weights=heatmap_weights,
            displacements=displacements,
            displacement_weights=displacement_weights)

        return encoded

    def decode(
        self, heatmaps: Tensor, displacements: Tensor
    ) -> Tuple[Tensor, Tuple[Tensor, Optional[Tensor]]]:
        """Decode the keypoint coordinates from heatmaps and displacements. The
        decoded keypoint coordinates are in the input image space.

        Args:
            heatmaps (Tensor): Encoded root and keypoints (optional) heatmaps
                in shape (1, H, W) or (K+1, H, W); the root heatmap is the
                last channel
            displacements (Tensor): Encoded keypoints displacement fields
                in shape (K*D, H, W) with D == 2

        Returns:
            tuple:
            - keypoints (Tensor): Decoded keypoint coordinates in shape
                (N, K, D)
            - scores (tuple):
                - root_scores (Tensor): The root scores in shape (N, )
                - keypoint_scores (Tensor): The keypoint scores in
                    shape (N, K). If keypoint heatmaps are not generated,
                    `keypoint_scores` will be `None`
        """
        _k, h, w = displacements.shape
        k = _k // 2
        # (K*2, H, W) -> (K, 2, H, W); reshape also tolerates
        # non-contiguous inputs.
        displacements = displacements.reshape(k, 2, h, w)

        # (2, H, W) grid holding each pixel's own (x, y) coordinate. Adding
        # the displacements gives, at every pixel, the absolute keypoint
        # positions of the instance associated with that pixel.
        xs = torch.arange(w).view(1, w).expand(h, w)
        ys = torch.arange(h).view(h, 1).expand(h, w)
        regular_grid = torch.stack([xs, ys], dim=0).to(displacements)
        posemaps = (regular_grid[None] + displacements).flatten(2)

        # Instance roots are the NMS peaks of the root (last) heatmap.
        root_heatmap_peaks = batch_heatmap_nms(heatmaps[None, -1:],
                                               self.decode_nms_kernel)
        peak_scores = root_heatmap_peaks.flatten()
        # `topk` raises if asked for more elements than the map holds.
        num_candidates = min(self.decode_max_instances, peak_scores.numel())
        root_scores, pos_idx = peak_scores.topk(num_candidates)
        mask = root_scores > self.decode_thr
        root_scores, pos_idx = root_scores[mask], pos_idx[mask]

        # (K, 2, N) -> (N, K, 2), still in heatmap space.
        keypoints = posemaps[:, :, pos_idx].permute(2, 0, 1).contiguous()

        if self.generate_keypoint_heatmaps and heatmaps.shape[0] == 1 + k:
            # Scores are sampled in heatmap space, i.e. before rescaling.
            keypoint_scores = self.get_keypoint_scores(heatmaps[:k], keypoints)
        else:
            keypoint_scores = None

        # Heatmap space -> input image space, per axis (x, y).
        keypoints = keypoints * keypoints.new_tensor(self.scale_factor)
        return keypoints, (root_scores, keypoint_scores)

    def get_keypoint_scores(self, heatmaps: Tensor, keypoints: Tensor):
        """Calculate the keypoint scores with keypoints heatmaps and
        coordinates.

        Scores are bilinearly sampled from each keypoint's heatmap at the
        (sub-pixel) keypoint location; out-of-range locations are clamped to
        the border.

        Args:
            heatmaps (Tensor): Keypoint heatmaps in shape (K, H, W)
            keypoints (Tensor): Keypoint coordinates in heatmap space, in
                shape (N, K, D)

        Returns:
            Tensor: Keypoint scores in [N, K]
        """
        k, h, w = heatmaps.shape
        if keypoints.shape[0] == 0:
            # No instances: avoid sampling with an empty grid.
            return heatmaps.new_zeros((0, k))

        # Normalize pixel coordinates to [-1, 1] so that -1 / 1 land on the
        # centres of the corner pixels — this is the `align_corners=True`
        # convention, which must be passed to `grid_sample` below. The
        # max(..., 1) guards single-pixel-wide/high heatmaps.
        keypoints = torch.stack((
            keypoints[..., 0] / max(w - 1, 1) * 2 - 1,
            keypoints[..., 1] / max(h - 1, 1) * 2 - 1,
        ),
                                dim=-1)
        # (N, K, 2) -> (K, 1, N, 2): one sampling row per keypoint channel.
        keypoints = keypoints.transpose(0, 1).unsqueeze(1).contiguous()

        # Output (K, 1, 1, N) -> (N, K).
        keypoint_scores = torch.nn.functional.grid_sample(
            heatmaps.unsqueeze(1),
            keypoints,
            padding_mode='border',
            align_corners=True)
        return keypoint_scores.view(k, -1).transpose(0, 1).contiguous()
|
|
|