Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional, Sequence, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from torch import Tensor, nn | |
| from mmpose.evaluation.functional import keypoint_pck_accuracy | |
| from mmpose.models.utils.tta import flip_coordinates | |
| from mmpose.registry import KEYPOINT_CODECS, MODELS | |
| from mmpose.utils.tensor_utils import to_numpy | |
| from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, | |
| Predictions) | |
| from ..base_head import BaseHead | |
| OptIntSeq = Optional[Sequence[int]] | |
class RLEHead(BaseHead):
    """Top-down regression head introduced in `RLE`_ by Li et al(2021). The
    head is composed of a fully-connected layer that predicts the coordinates
    and sigma (the variance of the coordinates) together.

    Args:
        in_channels (int | sequence[int]): Number of input channels
        num_joints (int): Number of joints
        loss (Config): Config for keypoint loss. Defaults to use
            :class:`RLELoss`
        decoder (Config, optional): The decoder config that controls decoding
            keypoint coordinates from the network output. Defaults to ``None``
        init_cfg (Config, optional): Config to control the initialization. See
            :attr:`default_init_cfg` for default settings

    .. _`RLE`: https://arxiv.org/abs/2107.11291
    """

    # Compared against the ``version`` entry of the checkpoint metadata in
    # :meth:`_load_state_dict_pre_hook` to decide whether keys need renaming.
    _version = 2

    def __init__(self,
                 in_channels: Union[int, Sequence[int]],
                 num_joints: int,
                 loss: ConfigType = dict(
                     type='RLELoss', use_target_weight=True),
                 decoder: OptConfigType = None,
                 init_cfg: OptConfigType = None):

        if init_cfg is None:
            init_cfg = self.default_init_cfg

        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.num_joints = num_joints
        self.loss_module = MODELS.build(loss)
        if decoder is not None:
            self.decoder = KEYPOINT_CODECS.build(decoder)
        else:
            self.decoder = None

        # Define fully-connected layers. Four values are predicted per joint;
        # ``loss`` and ``predict`` treat channels [:2] as coordinates and
        # channels [2:] as sigma.
        self.fc = nn.Linear(in_channels, self.num_joints * 4)

        # Register the hook to automatically convert old version state dicts
        self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)

    def forward(self, feats: Tuple[Tensor]) -> Tensor:
        """Forward the network. The input is multi scale feature maps and the
        output is the coordinates together with the sigmas.

        Only the last feature map is used; it is flattened from dim 1 onwards
        before being fed to the fully-connected layer.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps.

        Returns:
            Tensor: raw (un-activated) output of shape
            ``(B, num_joints, 4)``.
        """
        x = feats[-1]

        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x.reshape(-1, self.num_joints, 4)

    def predict(self,
                feats: Tuple[Tensor],
                batch_data_samples: OptSampleList,
                test_cfg: ConfigType = {}) -> Predictions:
        """Predict results from outputs.

        Args:
            feats (Tuple[Tensor]): Multi scale feature maps. When
                ``test_cfg['flip_test']`` is truthy this must instead be a
                list of two such inputs: ``[orig, flipped]``.
            batch_data_samples (OptSampleList): Batch data samples. In flip
                test mode the ``flip_indices`` and ``input_size`` entries of
                the first sample's ``metainfo`` are used.
            test_cfg (Config): Test config, only read here. Recognized keys:
                ``flip_test`` (default ``False``) and ``shift_coords``
                (default ``True``).

        Returns:
            Predictions: the result of ``self.decode`` on the sigmoid-
            activated output.
        """

        if test_cfg.get('flip_test', False):
            # TTA: flip test -> feats = [orig, flipped]
            assert isinstance(feats, list) and len(feats) == 2
            flip_indices = batch_data_samples[0].metainfo['flip_indices']
            input_size = batch_data_samples[0].metainfo['input_size']

            _feats, _feats_flip = feats

            _batch_coords = self.forward(_feats)
            # Only the sigma channels are squashed; coordinates stay raw.
            _batch_coords[..., 2:] = _batch_coords[..., 2:].sigmoid()

            _batch_coords_flip = flip_coordinates(
                self.forward(_feats_flip),
                flip_indices=flip_indices,
                shift_coords=test_cfg.get('shift_coords', True),
                input_size=input_size)
            _batch_coords_flip[..., 2:] = _batch_coords_flip[..., 2:].sigmoid()

            # Average the original and the flipped-back predictions.
            batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
        else:
            batch_coords = self.forward(feats)  # (B, K, D)
            batch_coords[..., 2:] = batch_coords[..., 2:].sigmoid()

        batch_coords.unsqueeze_(dim=1)  # (B, N, K, D)

        # ``decode`` is not defined in this class (presumably inherited from
        # ``BaseHead`` -- confirm).
        preds = self.decode(batch_coords)

        return preds

    def loss(self,
             inputs: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: ConfigType = {}) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tuple[Tensor]): Multi scale feature maps.
            batch_data_samples (OptSampleList): Batch data samples; each one
                provides ``gt_instance_labels.keypoint_labels`` and
                ``gt_instance_labels.keypoint_weights``.
            train_cfg (Config): Training config. Not used by this head.

        Returns:
            dict: ``loss_kpt`` (output of the loss module) and ``acc_pose``
            (average PCK accuracy as a tensor).
        """

        pred_outputs = self.forward(inputs)

        keypoint_labels = torch.cat(
            [d.gt_instance_labels.keypoint_labels for d in batch_data_samples])
        keypoint_weights = torch.cat([
            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
        ])

        # Split the 4 predicted channels into coordinates and sigma.
        pred_coords = pred_outputs[:, :, :2]
        pred_sigma = pred_outputs[:, :, 2:4]

        # calculate losses
        losses = dict()
        loss = self.loss_module(pred_coords, pred_sigma, keypoint_labels,
                                keypoint_weights.unsqueeze(-1))

        losses.update(loss_kpt=loss)

        # calculate accuracy
        _, avg_acc, _ = keypoint_pck_accuracy(
            pred=to_numpy(pred_coords),
            gt=to_numpy(keypoint_labels),
            mask=to_numpy(keypoint_weights) > 0,
            thr=0.05,
            norm_factor=np.ones((pred_coords.size(0), 2), dtype=np.float32))

        acc_pose = torch.tensor(avg_acc, device=keypoint_labels.device)
        losses.update(acc_pose=acc_pose)

        return losses

    def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
                                  **kwargs):
        """A hook function to convert an old-version state dict of this head
        to the current key layout.

        Keys of the form ``<prefix>loss.xxx`` are renamed to
        ``<prefix>loss_module.xxx``. Nothing is done when the ``version``
        entry of ``local_meta`` is at least :attr:`_version`.

        The hook will be automatically registered during initialization.

        Args:
            state_dict (dict): The state dict being loaded; modified in
                place.
            prefix (str): The key prefix of this module in ``state_dict``.
            local_meta (dict): Metadata of this module; only its ``version``
                entry is read.
        """
        version = local_meta.get('version', None)
        if version and version >= self._version:
            return

        # convert old-version state dict
        # Every entry is popped and re-inserted so the key order is kept.
        keys = list(state_dict.keys())
        for _k in keys:
            v = state_dict.pop(_k)
            k_new = _k

            # Remove the prefix as a whole string. ``str.lstrip(prefix)``
            # would strip any leading characters contained in ``prefix`` and
            # could eat into the key itself (e.g. turn 'loss.x' into 'ss.x').
            # Keys that do not start with the prefix are left unchanged.
            if _k.startswith(prefix):
                k_parts = _k[len(prefix):].split('.')
                # In old version, "loss" includes the instances of loss,
                # now it should be renamed "loss_module"
                if k_parts[0] == 'loss':
                    # loss.xxx -> loss_module.xxx
                    k_new = prefix + 'loss_module.' + '.'.join(k_parts[1:])

            state_dict[k_new] = v

    @property
    def default_init_cfg(self):
        """list[dict]: Initialization config used when ``init_cfg`` is not
        given: normal initialization (std 0.01, zero bias) for ``Linear``
        layers.

        This is a property because ``__init__`` reads it as an attribute.
        """
        init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
        return init_cfg