Stylique's picture
Upload folder using huggingface_hub
789eef1 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from typing import Tuple, Union
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (Features, InstanceList, OptConfigType,
OptSampleList, Predictions)
class BaseHead(BaseModule, metaclass=ABCMeta):
    """Abstract base class for heads.

    Concrete heads must implement the abstract methods :meth:`forward`,
    :meth:`predict` and :meth:`loss`. The shared :meth:`decode` helper turns
    raw network outputs into per-sample instances using ``self.decoder``.

    Args:
        init_cfg (dict, optional): The extra init config of layers.
            Defaults to None.
    """

    @abstractmethod
    def forward(self, feats: Tuple[Tensor]):
        """Run the head's forward pass on the given features."""

    # NOTE: the ``{}`` defaults below are single shared dict objects (Python
    # evaluates defaults once); implementations should not mutate them.
    @abstractmethod
    def predict(self,
                feats: Features,
                batch_data_samples: OptSampleList,
                test_cfg: OptConfigType = {}) -> Predictions:
        """Produce predictions from features."""

    @abstractmethod
    def loss(self,
             feats: Tuple[Tensor],
             batch_data_samples: OptSampleList,
             train_cfg: OptConfigType = {}) -> dict:
        """Compute the losses for a batch of inputs and data samples."""

    def decode(self, batch_outputs: Union[Tensor,
                                          Tuple[Tensor]]) -> InstanceList:
        """Decode keypoints from network outputs.

        When the decoder reports ``support_batch_decoding``, the whole batch
        is handed to ``decoder.batch_decode`` in one call. Otherwise the
        outputs are converted with ``to_numpy(..., unzip=True)`` and each
        resulting item is passed to ``decoder.decode`` separately.

        Args:
            batch_outputs (Tensor | Tuple[Tensor]): The network outputs of
                a data batch. A tuple is spread into positional arguments of
                the decoder call; anything else is passed as one argument.

        Returns:
            List[InstanceData]: One ``InstanceData`` per decoded
            (keypoints, scores) pair, with fields ``keypoints`` and
            ``keypoint_scores``.

        Raises:
            RuntimeError: If ``self.decoder`` is ``None``.
        """

        def _call_unpacked(func, outputs):
            # The decoder callables take positional arguments, so a lone
            # (non-tuple) output is wrapped into a 1-tuple before spreading.
            packed = outputs if isinstance(outputs, tuple) else (outputs, )
            return func(*packed)

        decoder = self.decoder
        if decoder is None:
            raise RuntimeError(
                f'The decoder has not been set in {self.__class__.__name__}. '
                'Please set the decoder configs in the init parameters to '
                'enable head methods `head.predict()` and `head.decode()`')

        if decoder.support_batch_decoding:
            # One call decodes the whole batch; the result must unpack into
            # exactly (keypoints, scores).
            all_keypoints, all_scores = _call_unpacked(
                decoder.batch_decode, batch_outputs)
            decoded = zip(all_keypoints, all_scores)
        else:
            # Decode sample by sample on the numpy-converted outputs.
            decoded = []
            for sample_outputs in to_numpy(batch_outputs, unzip=True):
                sample_keypoints, sample_scores = _call_unpacked(
                    decoder.decode, sample_outputs)
                decoded.append((sample_keypoints, sample_scores))

        return [
            InstanceData(keypoints=kpts, keypoint_scores=kpt_scores)
            for kpts, kpt_scores in decoded
        ]