Spaces:
Sleeping
Sleeping
| from mmpose.models import HeatmapHead | |
| from mmpose.registry import MODELS | |
| # Register your head to the `MODELS`. | |
| class ExampleHead(HeatmapHead): | |
| """Implements an example head. | |
| Implement the model head just like a normal pytorch module. | |
| """ | |
| def __init__(self, **kwargs) -> None: | |
| print('Initializing ExampleHead...') | |
| super().__init__(**kwargs) | |
| def forward(self, feats): | |
| """Forward the network. The input is multi scale feature maps and the | |
| output is the coordinates. | |
| Args: | |
| feats (Tuple[Tensor]): Multi scale feature maps. | |
| Returns: | |
| Tensor: output coordinates or heatmaps. | |
| """ | |
| return super().forward(feats) | |
| def predict(self, feats, batch_data_samples, test_cfg={}): | |
| """Predict results from outputs. The behaviour of head during testing | |
| should be defined in this function. | |
| Args: | |
| feats (Tuple[Tensor] | List[Tuple[Tensor]]): The multi-stage | |
| features (or multiple multi-stage features in TTA) | |
| batch_data_samples (List[:obj:`PoseDataSample`]): A list of | |
| data samples for instances in a batch | |
| test_cfg (dict): The runtime config for testing process. Defaults | |
| to {} | |
| Returns: | |
| Union[InstanceList | Tuple[InstanceList | PixelDataList]]: If | |
| ``test_cfg['output_heatmap']==True``, return both pose and heatmap | |
| prediction; otherwise only return the pose prediction. | |
| The pose prediction is a list of ``InstanceData``, each contains | |
| the following fields: | |
| - keypoints (np.ndarray): predicted keypoint coordinates in | |
| shape (num_instances, K, D) where K is the keypoint number | |
| and D is the keypoint dimension | |
| - keypoint_scores (np.ndarray): predicted keypoint scores in | |
| shape (num_instances, K) | |
| The heatmap prediction is a list of ``PixelData``, each contains | |
| the following fields: | |
| - heatmaps (Tensor): The predicted heatmaps in shape (K, h, w) | |
| """ | |
| return super().predict(feats, batch_data_samples, test_cfg) | |
| def loss(self, feats, batch_data_samples, train_cfg={}) -> dict: | |
| """Calculate losses from a batch of inputs and data samples. The | |
| behaviour of head during training should be defined in this function. | |
| Args: | |
| feats (Tuple[Tensor]): The multi-stage features | |
| batch_data_samples (List[:obj:`PoseDataSample`]): A list of | |
| data samples for instances in a batch | |
| train_cfg (dict): The runtime config for training process. | |
| Defaults to {} | |
| Returns: | |
| dict: A dictionary of losses. | |
| """ | |
| return super().loss(feats, batch_data_samples, train_cfg) | |