Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from itertools import zip_longest | |
| from typing import Optional | |
| from torch import Tensor | |
| from mmpose.registry import MODELS | |
| from mmpose.utils.typing import (ConfigType, InstanceList, OptConfigType, | |
| OptMultiConfig, PixelDataList, SampleList) | |
| from .base import BasePoseEstimator | |
@MODELS.register_module()
class TopdownPoseEstimator(BasePoseEstimator):
    """Base class for top-down pose estimators.

    Args:
        backbone (dict): The backbone config
        neck (dict, optional): The neck config. Defaults to ``None``
        head (dict, optional): The head config. Defaults to ``None``
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        test_cfg (dict, optional): The runtime config for testing process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``
        metainfo (dict): Meta information for dataset, such as keypoints
            definition and properties. If set, the metainfo of the input data
            batch will be overridden. For more details, please refer to
            https://mmpose.readthedocs.io/en/latest/user_guides/
            prepare_datasets.html#create-a-custom-dataset-info-
            config-file-for-the-dataset. Defaults to ``None``
        freeze_backbone (bool): If ``True``, ``requires_grad`` is set to
            ``False`` on every backbone parameter right after construction,
            so the backbone receives no gradients. This does not put the
            backbone in eval mode; any normalization buffers it may have can
            still be updated in training mode. Defaults to ``False``
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 metainfo: Optional[dict] = None,
                 freeze_backbone: bool = False):
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg,
            metainfo=metainfo)

        # Freeze all params of the backbone (neck/head stay trainable).
        if freeze_backbone:
            print("Freezing backbone!")
            self.backbone.requires_grad_(False)

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses. Empty when the model has no head.
        """
        feats = self.extract_feat(inputs)

        losses = dict()
        if self.with_head:
            losses.update(
                self.head.loss(feats, data_samples, train_cfg=self.train_cfg))

        return losses

    def predict(self, inputs: Tensor, data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        assert self.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            # Run the backbone on the original and the horizontally flipped
            # (last axis) input; the head receives both feature sets as a
            # two-element list and is responsible for fusing them.
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)

        # Heads may return either instances only, or (instances, fields).
        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.add_pred_to_datasample(batch_pred_instances,
                                              batch_pred_fields, data_samples)

        return results

    def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
                               batch_pred_fields: Optional[PixelDataList],
                               batch_data_samples: SampleList) -> SampleList:
        """Add predictions into data samples.

        Keypoint coordinates are mapped from model-input space back to image
        space using each sample's ``input_center``/``input_scale``/
        ``input_size`` metainfo. If ``test_cfg['output_keypoint_indices']`` is
        set, only those keypoints (and matching heatmap channels) are kept.

        Args:
            batch_pred_instances (List[InstanceData]): The predicted instances
                of the input data batch
            batch_pred_fields (List[PixelData], optional): The predicted
                fields (e.g. heatmaps) of the input batch
            batch_data_samples (List[PoseDataSample]): The input data batch

        Returns:
            List[PoseDataSample]: A list of data samples where the predictions
            are stored in the ``pred_instances`` field of each data sample.
            The same list object as ``batch_data_samples`` is returned; its
            elements are modified in place.
        """
        assert len(batch_pred_instances) == len(batch_data_samples)
        if batch_pred_fields is None:
            batch_pred_fields = []
        output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)

        # zip_longest lets batch_pred_fields be empty/shorter: missing
        # entries come back as None and are skipped below.
        for pred_instances, pred_fields, data_sample in zip_longest(
                batch_pred_instances, batch_pred_fields, batch_data_samples):
            if pred_instances is None:
                continue

            gt_instances = data_sample.gt_instances

            # convert keypoint coordinates from input space to image space
            input_center = data_sample.metainfo['input_center']
            input_scale = data_sample.metainfo['input_scale']
            input_size = data_sample.metainfo['input_size']

            pred_instances.keypoints[..., :2] = \
                pred_instances.keypoints[..., :2] / input_size * input_scale \
                + input_center - 0.5 * input_scale
            if 'keypoints_visible' not in pred_instances:
                # Fall back to the confidence scores as visibility.
                pred_instances.keypoints_visible = \
                    pred_instances.keypoint_scores

            # Keypoint count before any index selection; used to recognise
            # per-keypoint heatmap channels in pred_fields.
            num_keypoints = pred_instances.keypoints.shape[1]

            if output_keypoint_indices is not None:
                # select output keypoints with given indices
                for key, value in pred_instances.all_items():
                    if key.startswith('keypoint'):
                        pred_instances.set_field(
                            value[:, output_keypoint_indices], key)

            # add bbox information into pred_instances
            pred_instances.bboxes = gt_instances.bboxes
            pred_instances.bbox_scores = gt_instances.bbox_scores

            data_sample.pred_instances = pred_instances

            if pred_fields is not None:
                if output_keypoint_indices is not None:
                    # select output heatmap channels with keypoint indices
                    # when the number of heatmap channel matches num_keypoints
                    for key, value in pred_fields.all_items():
                        if value.shape[0] != num_keypoints:
                            continue
                        pred_fields.set_field(value[output_keypoint_indices],
                                              key)
                data_sample.pred_fields = pred_fields

        return batch_data_samples