# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from mmaction.registry import MODELS
from .base import BaseHead


@MODELS.register_module()
class FeatureHead(BaseHead):
"""General head for feature extraction.
Args:
spatial_type (str, optional): Pooling type in spatial dimension.
Default: 'avg'. If set to None, means keeping spatial dimension,
and for GCN backbone, keeping last two dimension(T, V).
temporal_type (str, optional): Pooling type in temporal dimension.
Default: 'avg'. If set to None, meanse keeping temporal dimnsion,
and for GCN backbone, keeping dimesion M. Please note that the
channel order would keep same with the output of backbone,
[N, T, C, H, W] for 2D recognizer, and [N, M, C, T, V] for GCN
recognizer.
backbone_name (str, optional): Backbone name to specifying special
operations.Currently supports: `'tsm'`, `'slowfast'`, and `'gcn'`.
Defaults to None, means take the input as normal feature.
num_segments (int, optional): Number of frame segments for TSM
backbone. Defaults to None.
kwargs (dict, optional): Any keyword argument to be used to initialize
the head.
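
    Examples:
        >>> # A minimal sketch: the input mimics a 3D-CNN feature of shape
        >>> # [N, C, T, H, W]; the channel size 256 is illustrative.
        >>> import torch
        >>> head = FeatureHead(spatial_type='avg', temporal_type='avg')
        >>> x = torch.rand(2, 256, 8, 7, 7)
        >>> head(x).shape
        torch.Size([2, 256])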
"""

    def __init__(self,
                 spatial_type: str = 'avg',
                 temporal_type: str = 'avg',
                 backbone_name: Optional[str] = None,
                 num_segments: Optional[int] = None,
                 **kwargs) -> None:
        super().__init__(None, None, **kwargs)
        self.temporal_type = temporal_type
        self.backbone_name = backbone_name
        self.num_segments = num_segments

        if spatial_type == 'avg':
            self.pool2d = torch.mean
        elif spatial_type == 'max':
            # torch.amax reduces over multiple dims and returns a Tensor,
            # whereas torch.max(x, dim) accepts only a single dim and
            # returns (values, indices), which would break the calls below.
            self.pool2d = torch.amax
        elif spatial_type is None:
            self.pool2d = lambda x, dim: x
        else:
            raise NotImplementedError(
                f'Unsupported spatial_type {spatial_type}')

        if temporal_type == 'avg':
            self.pool1d = torch.mean
        elif temporal_type == 'max':
            self.pool1d = torch.amax
        elif temporal_type is None:
            self.pool1d = lambda x, dim: x
        else:
            raise NotImplementedError(
                f'Unsupported temporal_type {temporal_type}')

    def forward(self,
                x: Tensor,
                num_segs: Optional[int] = None,
                **kwargs) -> Tensor:
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input data.
            num_segs (int, optional): Number of segments into which a video
                is divided, for 2D backbones. Defaults to None.

        Returns:
            Tensor: The output features after pooling.
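
        Examples:
            >>> # A hedged sketch for a 2D backbone: 2 videos x 4 segments
            >>> # of [C, H, W] features; the channel size 64 is illustrative.
            >>> head = FeatureHead()
            >>> x = torch.rand(8, 64, 7, 7)
            >>> head(x, num_segs=4).shape
            torch.Size([2, 64])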
"""
        if isinstance(x, Tensor):
            n_dims = x.ndim
        elif isinstance(x, tuple):
            n_dims = x[0].ndim
            assert self.backbone_name == 'slowfast', \
                'Only the SlowFast backbone supports tuple inputs'
        else:
            raise NotImplementedError(f'Unsupported feature type: {type(x)}')

        # For 2D backbone with spatial dimension
        if n_dims == 4:
            assert num_segs is not None
            if self.backbone_name == 'tsm':
                assert self.num_segments is not None, \
                    'Please specify num_segments for TSM'
                num_segs = self.num_segments
            # [N * T, C, H, W] -> [N, T, C, H, W]
            x = x.view((-1, num_segs) + x.shape[1:])
            # Pool spatially to [N, T, C], then temporally to [N, C]
            feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=1)
        elif n_dims == 5:
            if self.backbone_name == 'slowfast':
                x_slow, x_fast = x
                assert self.temporal_type is not None, \
                    'The SlowFast backbone has to pool the temporal dimension'
                # Pool each pathway to [N, C], then concatenate channels
                x_fast = self.pool1d(self.pool2d(x_fast, dim=[-2, -1]), dim=2)
                x_slow = self.pool1d(self.pool2d(x_slow, dim=[-2, -1]), dim=2)
                feat = torch.cat((x_slow, x_fast), dim=1)
            # For GCN-based backbone
            elif self.backbone_name == 'gcn':
                # [N, M, C, T, V]: pool (T, V) first, then persons M
                feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=1)
            # For 3D backbone with spatial dimension
            else:
                # [N, C, T, H, W]: pool (H, W) first, then T
                feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=2)
        # For backbone output feature without spatial and temporal dimension
        elif n_dims == 2:
            # [N, C]
            feat = x
        return feat

    def predict_by_feat(self, feats: Union[Tensor, Tuple[Tensor]],
                        data_samples) -> Tensor:
        """Integrate multi-view features into one tensor.

        Args:
            feats (torch.Tensor | tuple[torch.Tensor]): Features from
                upstream network.
            data_samples (list[:obj:`ActionDataSample`]): The batch
                data samples.

        Returns:
            Tensor: The integrated multi-view features.
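
        Examples:
            >>> # A hedged sketch, assuming ``average_clips='score'`` is
            >>> # forwarded to ``BaseHead`` via ``**kwargs`` so that clip
            >>> # features are averaged; 2 samples x 3 clips of 64-dim
            >>> # features, with all sizes illustrative.
            >>> head = FeatureHead(average_clips='score')
            >>> feats = torch.rand(6, 64)
            >>> head.predict_by_feat(feats, data_samples=[None, None]).shape
            torch.Size([2, 64])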
"""
num_segs = feats.shape[0] // len(data_samples)
feats = self.average_clip(feats, num_segs=num_segs)
return feats
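

# An illustrative usage sketch: like other OpenMMLab modules, FeatureHead
# is typically built from a config dict via the registry, e.g.:
#
#     from mmaction.registry import MODELS
#     head = MODELS.build(dict(type='FeatureHead', spatial_type='avg'))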