Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import warnings | |
| from typing import Optional, Sequence, Tuple, Union | |
| import torch | |
| from mmcv.cnn import build_conv_layer | |
| from mmengine.dist import get_dist_info | |
| from mmengine.structures import PixelData | |
| from torch import Tensor, nn | |
| from mmpose.codecs.utils import get_simcc_normalized | |
| from mmpose.evaluation.functional import simcc_pck_accuracy | |
| from mmpose.models.utils.tta import flip_vectors | |
| from mmpose.registry import KEYPOINT_CODECS, MODELS | |
| from mmpose.utils.tensor_utils import to_numpy | |
| from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, | |
| OptSampleList) | |
| from ..base_head import BaseHead | |
# Type alias: an optional sequence of ints, used to annotate the per-layer
# channel / kernel-size / group-number arguments of the head below.
OptIntSeq = Optional[Sequence[int]]
class SimCCHead(BaseHead):
    """Top-down heatmap head introduced in `SimCC`_ by Li et al (2022). The
    head is composed of a few deconvolutional layers followed by a fully-
    connected layer to generate 1d representation from low-resolution feature
    maps.

    Args:
        in_channels (int | sequence[int]): Number of channels in the input
            feature map
        out_channels (int): Number of channels in the output heatmap
        input_size (tuple): Input image size in shape [w, h]
        in_featuremap_size (sequence[int]): Size of the input feature map.
            Must hold two elements, since both are indexed to compute the
            flattened feature dimension
        simcc_split_ratio (float): Split ratio of pixels. Defaults to ``2.0``
        deconv_type (str, optional): The type of deconv head which should
            be one of the following options:

                - ``'heatmap'``: make deconv layers in `HeatmapHead`
                - ``'vipnas'``: make deconv layers in `ViPNASHead`

            Defaults to ``'heatmap'``
        deconv_out_channels (sequence[int]): The output channel number of each
            deconv layer. An empty or ``None`` value disables the deconv
            head. Defaults to ``(256, 256, 256)``
        deconv_kernel_sizes (sequence[int | tuple], optional): The kernel size
            of each deconv layer. Each element should be either an integer for
            both height and width dimensions, or a tuple of two integers for
            the height and the width dimension respectively. Only forwarded
            when ``deconv_type='heatmap'``. Defaults to ``(4, 4, 4)``
        deconv_num_groups (Sequence[int], optional): The group number of each
            deconv layer. Only forwarded when ``deconv_type='vipnas'``.
            Defaults to ``(16, 16, 16)``
        conv_out_channels (sequence[int], optional): The output channel number
            of each intermediate conv layer. ``None`` means no intermediate
            conv layer between deconv layers and the final conv layer.
            Defaults to ``None``
        conv_kernel_sizes (sequence[int | tuple], optional): The kernel size
            of each intermediate conv layer. Defaults to ``None``
        final_layer (dict): Arguments of the final Conv2d layer.
            Defaults to ``dict(kernel_size=1)``
        loss (Config): Config of the keypoint loss. Defaults to use
            :class:`KLDiscretLoss`
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Defaults to ``None``
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings

    .. _`SimCC`: https://arxiv.org/abs/2107.03332
    """

    _version = 2

    def __init__(
        self,
        in_channels: Union[int, Sequence[int]],
        out_channels: int,
        input_size: Tuple[int, int],
        in_featuremap_size: Tuple[int, int],
        simcc_split_ratio: float = 2.0,
        deconv_type: str = 'heatmap',
        deconv_out_channels: OptIntSeq = (256, 256, 256),
        deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
        deconv_num_groups: OptIntSeq = (16, 16, 16),
        conv_out_channels: OptIntSeq = None,
        conv_kernel_sizes: OptIntSeq = None,
        final_layer: dict = dict(kernel_size=1),
        loss: ConfigType = dict(type='KLDiscretLoss', use_target_weight=True),
        decoder: OptConfigType = None,
        init_cfg: OptConfigType = None,
    ):

        # `default_init_cfg` is a property, so this yields the config list
        # itself rather than a bound method.
        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        if deconv_type not in {'heatmap', 'vipnas'}:
            raise ValueError(
                f'{self.__class__.__name__} got invalid `deconv_type` value '
                f'{deconv_type}. Should be one of '
                '{"heatmap", "vipnas"}')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_size = input_size
        self.in_featuremap_size = in_featuremap_size
        self.simcc_split_ratio = simcc_split_ratio
        self.loss_module = MODELS.build(loss)
        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

        num_deconv = len(deconv_out_channels) if deconv_out_channels else 0
        if num_deconv != 0:
            # the feature map size is doubled once per deconv layer
            self.heatmap_size = tuple(
                [s * (2**num_deconv) for s in in_featuremap_size])

            # deconv layers + 1x1 conv
            self.deconv_head = self._make_deconv_head(
                in_channels=in_channels,
                out_channels=out_channels,
                deconv_type=deconv_type,
                deconv_out_channels=deconv_out_channels,
                deconv_kernel_sizes=deconv_kernel_sizes,
                deconv_num_groups=deconv_num_groups,
                conv_out_channels=conv_out_channels,
                conv_kernel_sizes=conv_kernel_sizes,
                final_layer=final_layer)
        else:
            self.deconv_head = None

            if final_layer is not None:
                # user-provided arguments override the 1x1 conv defaults
                cfg = dict(
                    type='Conv2d',
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1)
                cfg.update(final_layer)
                self.final_layer = build_conv_layer(cfg)
            else:
                self.final_layer = None

            self.heatmap_size = in_featuremap_size

        # Define SimCC layers
        flatten_dims = self.heatmap_size[0] * self.heatmap_size[1]

        # number of 1-D bins along each axis
        W = int(self.input_size[0] * self.simcc_split_ratio)
        H = int(self.input_size[1] * self.simcc_split_ratio)

        self.mlp_head_x = nn.Linear(flatten_dims, W)
        self.mlp_head_y = nn.Linear(flatten_dims, H)

    def _make_deconv_head(
        self,
        in_channels: Union[int, Sequence[int]],
        out_channels: int,
        deconv_type: str = 'heatmap',
        deconv_out_channels: OptIntSeq = (256, 256, 256),
        deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
        deconv_num_groups: OptIntSeq = (16, 16, 16),
        conv_out_channels: OptIntSeq = None,
        conv_kernel_sizes: OptIntSeq = None,
        final_layer: dict = dict(kernel_size=1)
    ) -> nn.Module:
        """Create deconvolutional layers by given parameters.

        Builds a ``HeatmapHead`` when ``deconv_type`` is ``'heatmap'`` and a
        ``ViPNASHead`` otherwise, through the ``MODELS`` registry.
        ``deconv_kernel_sizes`` is only forwarded to ``HeatmapHead`` and
        ``deconv_num_groups`` only to ``ViPNASHead``.

        Returns:
            nn.Module: The module built from the registry.
        """
        # arguments shared by both head types
        head_cfg = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            deconv_out_channels=deconv_out_channels,
            conv_out_channels=conv_out_channels,
            conv_kernel_sizes=conv_kernel_sizes,
            final_layer=final_layer)

        if deconv_type == 'heatmap':
            head_cfg.update(
                type='HeatmapHead', deconv_kernel_sizes=deconv_kernel_sizes)
        else:
            head_cfg.update(
                type='ViPNASHead', deconv_num_groups=deconv_num_groups)

        return MODELS.build(head_cfg)

    def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
        """Forward the network.

        The input is the featuremap extracted by backbone and the
        output is the simcc representation.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            pred_x (Tensor): 1d representation of x.
            pred_y (Tensor): 1d representation of y.
        """
        if self.deconv_head is None:
            # without a deconv head only the last feature map is used
            feats = feats[-1]
            if self.final_layer is not None:
                feats = self.final_layer(feats)
        else:
            feats = self.deconv_head(feats)

        # flatten the output heatmap
        x = torch.flatten(feats, 2)

        pred_x = self.mlp_head_x(x)
        pred_y = self.mlp_head_y(x)

        return pred_x, pred_y

    def predict(
        self,
        feats: Tuple[Tensor],
        batch_data_samples: OptSampleList,
        test_cfg: OptConfigType = {},
    ) -> InstanceList:
        """Predict results from features.

        Args:
            feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage
                features (or multiple multi-stage features in TTA)
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples
            test_cfg (dict): The runtime config for testing process. Defaults
                to {}

        Returns:
            List[InstanceData]: The pose predictions, each contains
            the following fields:

                - keypoints (np.ndarray): predicted keypoint coordinates in
                    shape (num_instances, K, D) where K is the keypoint number
                    and D is the keypoint dimension
                - keypoint_scores (np.ndarray): predicted keypoint scores in
                    shape (num_instances, K)
                - keypoint_x_labels (np.ndarray, optional): The predicted 1-D
                    intensity distribution in the x direction
                - keypoint_y_labels (np.ndarray, optional): The predicted 1-D
                    intensity distribution in the y direction

            When ``test_cfg['output_heatmaps']`` is truthy, a tuple
            ``(preds, pred_fields)`` is returned instead, where
            ``pred_fields`` is a list of :obj:`PixelData` holding one
            ``heatmaps`` tensor per batch element.
        """

        if test_cfg.get('flip_test', False):
            # TTA: flip test -> feats = [orig, flipped]
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            _feats, _feats_flip = feats

            _batch_pred_x, _batch_pred_y = self.forward(_feats)

            _batch_pred_x_flip, _batch_pred_y_flip = self.forward(_feats_flip)
            _batch_pred_x_flip, _batch_pred_y_flip = flip_vectors(
                _batch_pred_x_flip,
                _batch_pred_y_flip,
                flip_indices=flip_indices)

            # average the original and the flipped-back predictions
            batch_pred_x = (_batch_pred_x + _batch_pred_x_flip) * 0.5
            batch_pred_y = (_batch_pred_y + _batch_pred_y_flip) * 0.5
        else:
            batch_pred_x, batch_pred_y = self.forward(feats)

        # NOTE(review): `decode` is not defined in this class; presumably it
        # is inherited from `BaseHead` -- confirm.
        preds = self.decode((batch_pred_x, batch_pred_y))

        if test_cfg.get('output_heatmaps', False):
            rank, _ = get_dist_info()
            if rank == 0:
                warnings.warn('The predicted simcc values are normalized for '
                              'visualization. This may cause discrepancy '
                              'between the keypoint scores and the 1D heatmaps'
                              '.')

            # normalize the predicted 1d distribution
            sigma = self.decoder.sigma
            batch_pred_x = get_simcc_normalized(batch_pred_x, sigma[0])
            batch_pred_y = get_simcc_normalized(batch_pred_y, sigma[1])

            B, K, _ = batch_pred_x.shape
            # (B, K, Wx) -> (B, K, 1, Wx)
            x = batch_pred_x.reshape(B, K, 1, -1)
            # (B, K, Wy) -> (B, K, Wy, 1)
            y = batch_pred_y.reshape(B, K, -1, 1)
            # outer product: (B, K, Wy, 1) @ (B, K, 1, Wx) -> (B, K, Wy, Wx)
            batch_heatmaps = torch.matmul(y, x)
            pred_fields = [
                PixelData(heatmaps=hm) for hm in batch_heatmaps.detach()
            ]

            # attach the normalized 1-D distributions to each prediction,
            # with a new leading axis added by `[None]`
            for pred_instances, pred_x, pred_y in zip(preds,
                                                      to_numpy(batch_pred_x),
                                                      to_numpy(batch_pred_y)):

                pred_instances.keypoint_x_labels = pred_x[None]
                pred_instances.keypoint_y_labels = pred_y[None]

            return preds, pred_fields
        else:
            return preds

    def loss(
        self,
        feats: Tuple[Tensor],
        batch_data_samples: OptSampleList,
        train_cfg: OptConfigType = {},
    ) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            feats (Tuple[Tensor]): The features passed to :meth:`forward`.
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples. Each one must provide
                ``gt_instance_labels.keypoint_x_labels``,
                ``gt_instance_labels.keypoint_y_labels`` and
                ``gt_instance_labels.keypoint_weights``.
            train_cfg (dict): The runtime config for training process. It is
                not read by this method. Defaults to {}

        Returns:
            dict: A dictionary with ``loss_kpt`` (the output of the loss
            module) and ``acc_pose`` (the average accuracy as a tensor).
        """

        pred_x, pred_y = self.forward(feats)

        # concatenate the per-sample targets along the first dimension
        gt_x = torch.cat([
            d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples
        ],
                         dim=0)
        gt_y = torch.cat([
            d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples
        ],
                         dim=0)
        keypoint_weights = torch.cat(
            [
                d.gt_instance_labels.keypoint_weights
                for d in batch_data_samples
            ],
            dim=0,
        )

        pred_simcc = (pred_x, pred_y)
        gt_simcc = (gt_x, gt_y)

        # calculate losses
        losses = dict()
        loss = self.loss_module(pred_simcc, gt_simcc, keypoint_weights)

        losses.update(loss_kpt=loss)

        # calculate accuracy; only keypoints with positive weight are counted
        _, avg_acc, _ = simcc_pck_accuracy(
            output=to_numpy(pred_simcc),
            target=to_numpy(gt_simcc),
            simcc_split_ratio=self.simcc_split_ratio,
            mask=to_numpy(keypoint_weights) > 0,
        )

        acc_pose = torch.tensor(avg_acc, device=gt_x.device)
        losses.update(acc_pose=acc_pose)

        return losses

    @property
    def default_init_cfg(self):
        """list[dict]: The initialization config used by the constructor
        when ``init_cfg`` is ``None``."""
        init_cfg = [
            dict(
                type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
            dict(type='Constant', layer='BatchNorm2d', val=1),
            dict(type='Normal', layer=['Linear'], std=0.01, bias=0),
        ]
        return init_cfg