|
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
|
from typing import Dict, Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from mmengine.model import BaseModule
|
|
|
|
|
|
from mmaction.evaluation import top_k_accuracy
|
|
|
from mmaction.registry import MODELS
|
|
|
from mmaction.utils import ForwardResults, SampleList
|
|
|
|
|
|
|
|
|
class AvgConsensus(nn.Module):
    """Consensus module that averages its input along one dimension.

    Args:
        dim (int): The dimension reduced by averaging. It is kept in the
            output with size 1. Defaults to 1.
    """

    def __init__(self, dim: int = 1) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``x`` over ``self.dim``, keeping that dim."""
        reduce_dim = self.dim
        return torch.mean(x, reduce_dim, keepdim=True)
|
|
|
|
|
|
|
|
|
class BaseHead(BaseModule, metaclass=ABCMeta):
    """Base class for head.

    All heads should subclass it. All subclasses should override:

    - :meth:`forward`, supporting to forward both for training and testing.

    Args:
        num_classes (int): Number of classes to be classified.
        in_channels (int): Number of channels in input feature.
        loss_cls (dict, optional): Config for building loss. ``None`` (the
            default) builds ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        multi_class (bool): Determines whether it is a multi-class
            recognition task. Defaults to False.
        label_smooth_eps (float): Epsilon used in label smooth.
            Reference: arxiv.org/abs/1906.02629. Defaults to 0.
        topk (int | tuple[int] | list[int]): Top-k accuracy. A list (as
            produced by JSON/YAML configs) is converted to a tuple.
            Defaults to ``(1, 5)``.
        average_clips (str, optional): How class scores are averaged over
            multiple clips: ``'score'``, ``'prob'`` or None.
            Defaults to None.
        init_cfg (dict, optional): Config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 loss_cls: Optional[Dict] = None,
                 multi_class: bool = False,
                 label_smooth_eps: float = 0.0,
                 topk: Union[int, Tuple[int, ...], list] = (1, 5),
                 average_clips: Optional[str] = None,
                 init_cfg: Optional[Dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.in_channels = in_channels
        # Build the default loss config per call rather than sharing one
        # mutable default dict across every instantiation.
        if loss_cls is None:
            loss_cls = dict(type='CrossEntropyLoss', loss_weight=1.0)
        self.loss_cls = MODELS.build(loss_cls)
        self.multi_class = multi_class
        self.label_smooth_eps = label_smooth_eps
        self.average_clips = average_clips

        # Configs serialized through JSON/YAML turn tuples into lists;
        # accept both forms and normalize to a tuple.
        if isinstance(topk, list):
            topk = tuple(topk)
        assert isinstance(topk, (int, tuple))
        if isinstance(topk, int):
            topk = (topk, )
        for _topk in topk:
            assert _topk > 0, 'Top-k should be larger than 0'
        self.topk = topk

    @abstractmethod
    def forward(self, x, **kwargs) -> ForwardResults:
        """Defines the computation performed at every call."""
        raise NotImplementedError

    def loss(self, feats: Union[torch.Tensor, Tuple[torch.Tensor]],
             data_samples: SampleList, **kwargs) -> Dict:
        """Perform forward propagation of head and loss calculation on the
        features of the upstream network.

        Args:
            feats (torch.Tensor | tuple[torch.Tensor]): Features from
                upstream network.
            data_samples (list[:obj:`ActionDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of loss components.
        """
        cls_scores = self(feats, **kwargs)
        return self.loss_by_feat(cls_scores, data_samples)

    def loss_by_feat(self, cls_scores: torch.Tensor,
                     data_samples: SampleList) -> Dict:
        """Calculate the loss based on the features extracted by the head.

        Top-k accuracy entries (``top{k}_acc``) are added only when the labels
        are class indices, i.e. when their shape differs from ``cls_scores``.
        With a non-zero ``label_smooth_eps``, index labels are one-hot encoded
        and then smoothed before being passed to the loss.

        Args:
            cls_scores (torch.Tensor): Classification prediction results of
                all class, has shape (batch_size, num_classes).
            data_samples (list[:obj:`ActionDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of loss components.
        """
        labels = [x.gt_label for x in data_samples]
        labels = torch.stack(labels).to(cls_scores.device)
        labels = labels.squeeze()

        losses = dict()
        if labels.shape == torch.Size([]):
            # A batch holding a single index label is squeezed down to a
            # scalar; restore the batch dimension.
            labels = labels.unsqueeze(0)
        elif labels.dim() == 1 and labels.size()[0] == self.num_classes \
                and cls_scores.size()[0] == 1:
            # With batch size 1, a per-class label vector (e.g. soft labels,
            # which share the shape of ``cls_scores``) is squeezed down to
            # (num_classes, ); restore the batch dimension.
            labels = labels.unsqueeze(0)

        if cls_scores.size() != labels.size():
            top_k_acc = top_k_accuracy(cls_scores.detach().cpu().numpy(),
                                       labels.detach().cpu().numpy(),
                                       self.topk)
            for k, a in zip(self.topk, top_k_acc):
                losses[f'top{k}_acc'] = torch.tensor(
                    a, device=cls_scores.device)
        if self.label_smooth_eps != 0:
            if cls_scores.size() != labels.size():
                labels = F.one_hot(labels, num_classes=self.num_classes)
            labels = ((1 - self.label_smooth_eps) * labels +
                      self.label_smooth_eps / self.num_classes)

        loss_cls = self.loss_cls(cls_scores, labels)
        # The built loss may return either a dict of terms or one tensor.
        if isinstance(loss_cls, dict):
            losses.update(loss_cls)
        else:
            losses['loss_cls'] = loss_cls
        return losses

    def predict(self, feats: Union[torch.Tensor, Tuple[torch.Tensor]],
                data_samples: SampleList, **kwargs) -> SampleList:
        """Perform forward propagation of head and predict recognition results
        on the features of the upstream network.

        Args:
            feats (torch.Tensor | tuple[torch.Tensor]): Features from
                upstream network.
            data_samples (list[:obj:`ActionDataSample`]): The batch
                data samples.

        Returns:
            list[:obj:`ActionDataSample`]: Recognition results wrapped
                by :obj:`ActionDataSample`.
        """
        cls_scores = self(feats, **kwargs)
        return self.predict_by_feat(cls_scores, data_samples)

    def predict_by_feat(self, cls_scores: torch.Tensor,
                        data_samples: SampleList) -> SampleList:
        """Transform a batch of output features extracted from the head into
        prediction results.

        The number of clips per sample is inferred as
        ``cls_scores.shape[0] // len(data_samples)``; the scores are averaged
        with :meth:`average_clip`, and each data sample receives its score
        and the ``argmax`` over the last dimension as predicted label.

        Args:
            cls_scores (torch.Tensor): Classification scores, has a shape
                (B*num_segs, num_classes)
            data_samples (list[:obj:`ActionDataSample`]): The
                annotation data of every samples. It usually includes
                information such as `gt_label`.

        Returns:
            List[:obj:`ActionDataSample`]: Recognition results wrapped
                by :obj:`ActionDataSample`.
        """
        num_segs = cls_scores.shape[0] // len(data_samples)
        cls_scores = self.average_clip(cls_scores, num_segs=num_segs)
        pred_labels = cls_scores.argmax(dim=-1, keepdim=True).detach()

        for data_sample, score, pred_label in zip(data_samples, cls_scores,
                                                  pred_labels):
            data_sample.set_pred_score(score)
            data_sample.set_pred_label(pred_label)
        return data_samples

    def average_clip(self,
                     cls_scores: torch.Tensor,
                     num_segs: int = 1) -> torch.Tensor:
        """Averaging class scores over multiple clips.

        Uses the averaging type configured by ``average_clips`` ('score',
        'prob' or None) to compute the final class score. Called by
        :meth:`predict_by_feat`.

        Args:
            cls_scores (torch.Tensor): Class scores to be averaged, with the
                clips of each sample stacked along the first dimension.
            num_segs (int): Number of clips for each input sample.

        Returns:
            torch.Tensor: Averaged class scores. With ``'score'`` the raw
            scores are averaged over clips. With ``'prob'`` softmax
            probabilities are averaged over clips. With None the scores are
            only reshaped to ``(batch_size // num_segs, num_segs, ...)`` and
            returned without averaging.

        Raises:
            ValueError: If ``average_clips`` is not one of the supported
                values.
        """
        if self.average_clips not in ['score', 'prob', None]:
            raise ValueError(f'{self.average_clips} is not supported. '
                             f'Currently supported ones are '
                             f'["score", "prob", None]')

        batch_size = cls_scores.shape[0]
        cls_scores = cls_scores.view((batch_size // num_segs, num_segs) +
                                     cls_scores.shape[1:])

        if self.average_clips is None:
            return cls_scores
        elif self.average_clips == 'prob':
            cls_scores = F.softmax(cls_scores, dim=2).mean(dim=1)
        elif self.average_clips == 'score':
            cls_scores = cls_scores.mean(dim=1)

        return cls_scores
|
|
|
|