|
|
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from mmcv.cnn import ConvModule
|
|
|
from mmengine.model.weight_init import constant_init, normal_init, xavier_init
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from mmaction.utils import ConfigType, OptConfigType, SampleList
|
|
|
|
|
|
|
|
|
class DownSample(nn.Module):
    """Downsampling block combining a convolution with 3D max-pooling.

    The input feature is transformed by a :class:`ConvModule` and reduced
    by a :class:`nn.MaxPool3d`. ``downsample_position`` selects whether the
    pooling runs before (`pool-conv`) or after (`conv-pool`) the
    convolution.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output feature.
        kernel_size (int or Tuple[int]): Same as :class:`ConvModule`.
            Defaults to ``(3, 1, 1)``.
        stride (int or Tuple[int]): Same as :class:`ConvModule`.
            Defaults to ``(1, 1, 1)``.
        padding (int or Tuple[int]): Same as :class:`ConvModule`.
            Defaults to ``(1, 0, 0)``.
        groups (int): Same as :class:`ConvModule`. Defaults to 1.
        bias (bool or str): Same as :class:`ConvModule`. Defaults to False.
        conv_cfg (dict or ConfigDict): Same as :class:`ConvModule`.
            Defaults to ``dict(type='Conv3d')``.
        norm_cfg (dict or ConfigDict, optional): Same as :class:`ConvModule`.
            Defaults to None.
        act_cfg (dict or ConfigDict, optional): Same as :class:`ConvModule`.
            Defaults to None.
        downsample_position (str): Type of downsample position. Options are
            ``before`` and ``after``. Defaults to ``after``.
        downsample_scale (int or Tuple[int]): downsample scale for maxpooling.
            It will be used for kernel size and stride of maxpooling.
            Defaults to ``(1, 2, 2)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = (3, 1, 1),
        stride: Union[int, Tuple[int]] = (1, 1, 1),
        padding: Union[int, Tuple[int]] = (1, 0, 0),
        groups: int = 1,
        bias: Union[bool, str] = False,
        conv_cfg: ConfigType = dict(type='Conv3d'),
        norm_cfg: OptConfigType = None,
        act_cfg: OptConfigType = None,
        downsample_position: str = 'after',
        downsample_scale: Union[int, Tuple[int]] = (1, 2, 2)
    ) -> None:
        super().__init__()
        assert downsample_position in ['before', 'after']
        self.downsample_position = downsample_position
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups=groups,
            bias=bias,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # `downsample_scale` serves as both kernel size and stride.
        self.pool = nn.MaxPool3d(
            downsample_scale, downsample_scale, (0, 0, 0), ceil_mode=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        if self.downsample_position == 'before':
            return self.conv(self.pool(x))
        return self.pool(self.conv(x))
|
|
|
|
|
|
|
|
|
class LevelFusion(nn.Module):
    """Level Fusion module.

    This module is used to aggregate the hierarchical features dynamic in
    visual tempos and consistent in spatial semantics. Each input level is
    first reduced by its own :class:`DownSample` module; the reduced
    features are concatenated along the channel axis and fused with a
    1x1x1 convolution. The top/bottom features for top-down/bottom-up flow
    would be combined to achieve two additional options, namely
    'Cascade Flow' or 'Parallel Flow': applying a bottom-up flow after a
    top-down flow leads to the cascade flow, applying them simultaneously
    results in the parallel flow.

    Args:
        in_channels (Tuple[int]): Channel numbers of input features tuple.
        mid_channels (Tuple[int]): Channel numbers of middle features tuple.
        out_channels (int): Channel numbers of output features.
        downsample_scales (Tuple[int | Tuple[int]]): downsample scales for
            each :class:`DownSample` module.
            Defaults to ``((1, 1, 1), (1, 1, 1))``.
    """

    def __init__(
        self,
        in_channels: Tuple[int],
        mid_channels: Tuple[int],
        out_channels: int,
        downsample_scales: Tuple[int, Tuple[int]] = ((1, 1, 1), (1, 1, 1))
    ) -> None:
        super().__init__()

        # One channel-reducing DownSample per pyramid level.
        self.downsamples = nn.ModuleList(
            DownSample(
                in_channels[i],
                mid_channels[i],
                kernel_size=(1, 1, 1),
                stride=(1, 1, 1),
                bias=False,
                padding=(0, 0, 0),
                groups=32,
                norm_cfg=dict(type='BN3d', requires_grad=True),
                act_cfg=dict(type='ReLU', inplace=True),
                downsample_position='before',
                downsample_scale=downsample_scales[i])
            for i in range(len(in_channels)))

        # 1x1x1 conv fusing the channel-concatenated per-level features.
        self.fusion_conv = ConvModule(
            sum(mid_channels),
            out_channels,
            1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', requires_grad=True),
            act_cfg=dict(type='ReLU', inplace=True))

    def forward(self, x: Tuple[torch.Tensor]) -> torch.Tensor:
        """Defines the computation performed at every call."""
        reduced = [down(feat) for down, feat in zip(self.downsamples, x)]
        return self.fusion_conv(torch.cat(reduced, 1))
|
|
|
|
|
|
|
|
|
class SpatialModulation(nn.Module):
    """Spatial Semantic Modulation.

    This module is used to align spatial semantics of features in the
    multi-depth pyramid. For each but the top-level feature, a stack of
    convolutions with level-specific stride are applied to it, matching
    its spatial shape and receptive field with the top one.

    Args:
        in_channels (Tuple[int]): Channel numbers of input features tuple.
        out_channels (int): Channel numbers of output features tuple.
    """

    def __init__(self, in_channels: Tuple[int], out_channels: int) -> None:
        super().__init__()

        self.spatial_modulation = nn.ModuleList()
        for channel in in_channels:
            # Number of 2x reduction steps required to reach the
            # top-level channel count.
            num_steps = int(np.log2(out_channels // channel))
            if num_steps < 1:
                # Top-level feature already aligned: pass through.
                self.spatial_modulation.append(nn.Identity())
                continue
            convs = nn.ModuleList()
            for step in range(num_steps):
                convs.append(
                    ConvModule(
                        channel * 2**step,
                        channel * 2**(step + 1), (1, 3, 3),
                        stride=(1, 2, 2),
                        padding=(0, 1, 1),
                        bias=False,
                        conv_cfg=dict(type='Conv3d'),
                        norm_cfg=dict(type='BN3d', requires_grad=True),
                        act_cfg=dict(type='ReLU', inplace=True)))
            self.spatial_modulation.append(convs)

    def forward(self, x: Tuple[torch.Tensor]) -> list:
        """Defines the computation performed at every call."""
        out = []
        for feature, ops in zip(x, self.spatial_modulation):
            if isinstance(ops, nn.ModuleList):
                for op in ops:
                    feature = op(feature)
                out.append(feature)
            else:
                out.append(ops(feature))
        return out
|
|
|
|
|
|
|
|
|
class AuxHead(nn.Module):
    """Auxiliary Head.

    This auxiliary head is appended to receive stronger supervision,
    leading to enhanced semantics.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output features.
        loss_weight (float): weight of loss for the auxiliary head.
            Defaults to 0.5.
        loss_cls (dict or ConfigDict): Config for building loss.
            Defaults to ``dict(type='CrossEntropyLoss')``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        loss_weight: float = 0.5,
        loss_cls: ConfigType = dict(type='CrossEntropyLoss')
    ) -> None:
        super().__init__()

        # Spatial downsample that doubles the channel number.
        self.conv = ConvModule(
            in_channels,
            in_channels * 2, (1, 3, 3),
            stride=(1, 2, 2),
            padding=(0, 1, 1),
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', requires_grad=True))
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.loss_weight = loss_weight
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(in_channels * 2, out_channels)
        self.loss_cls = MODELS.build(loss_cls)

    def init_weights(self) -> None:
        """Initiate the parameters from scratch."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                normal_init(module, std=0.01)
            if isinstance(module, nn.Conv3d):
                xavier_init(module, distribution='uniform')
            if isinstance(module, nn.BatchNorm3d):
                constant_init(module, 1)

    def loss(self, x: torch.Tensor,
             data_samples: Optional[SampleList]) -> dict:
        """Calculate auxiliary loss."""
        scores = self(x)
        labels = [sample.gt_label for sample in data_samples]
        labels = torch.stack(labels).to(scores.device)
        labels = labels.squeeze()
        if labels.shape == torch.Size([]):
            # Restore the batch dimension for a single-sample batch.
            labels = labels.unsqueeze(0)

        return dict(loss_aux=self.loss_weight * self.loss_cls(scores, labels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Auxiliary head forward function."""
        x = self.conv(x)
        x = self.avg_pool(x).squeeze(-1).squeeze(-1).squeeze(-1)
        x = self.dropout(x)
        return self.fc(x)
|
|
|
|
|
|
|
|
|
class TemporalModulation(nn.Module):
    """Temporal Rate Modulation.

    The module is used to equip TPN with a similar flexibility for
    temporal tempo modulation as in the input-level frame pyramid.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output features.
        downsample_scale (int): Downsample scale for maxpooling. Defaults to 8.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 downsample_scale: int = 8) -> None:
        super().__init__()

        # Grouped temporal conv without activation, followed by temporal
        # max-pooling shrinking the time axis by `downsample_scale`.
        self.conv = ConvModule(
            in_channels,
            out_channels, (3, 1, 1),
            stride=(1, 1, 1),
            padding=(1, 0, 0),
            bias=False,
            groups=32,
            conv_cfg=dict(type='Conv3d'),
            act_cfg=None)
        pool_size = (downsample_scale, 1, 1)
        self.pool = nn.MaxPool3d(pool_size, pool_size, (0, 0, 0),
                                 ceil_mode=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        return self.pool(self.conv(x))
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class TPN(nn.Module):
    """TPN neck.

    This module is proposed in `Temporal Pyramid Network for Action Recognition
    <https://arxiv.org/pdf/2004.03548.pdf>`_

    Args:
        in_channels (Tuple[int]): Channel numbers of input features tuple.
        out_channels (int): Channel number of output feature.
        spatial_modulation_cfg (dict or ConfigDict, optional): Config for
            spatial modulation layers. Required keys are ``in_channels`` and
            ``out_channels``. Defaults to None.
        temporal_modulation_cfg (dict or ConfigDict, optional): Config for
            temporal modulation layers. Defaults to None.
        upsample_cfg (dict or ConfigDict, optional): Config for upsample
            layers. The keys are same as that in :class:``nn.Upsample``.
            Defaults to None.
        downsample_cfg (dict or ConfigDict, optional): Config for downsample
            layers. Defaults to None.
        level_fusion_cfg (dict or ConfigDict, optional): Config for level
            fusion layers.
            Required keys are ``in_channels``, ``mid_channels``,
            ``out_channels``. Defaults to None.
        aux_head_cfg (dict or ConfigDict, optional): Config for aux head
            layers. Required keys are ``out_channels``. Defaults to None.
        flow_type (str): Flow type to combine the features. Options are
            ``cascade`` and ``parallel``. Defaults to ``cascade``.
    """

    def __init__(self,
                 in_channels: Tuple[int],
                 out_channels: int,
                 spatial_modulation_cfg: OptConfigType = None,
                 temporal_modulation_cfg: OptConfigType = None,
                 upsample_cfg: OptConfigType = None,
                 downsample_cfg: OptConfigType = None,
                 level_fusion_cfg: OptConfigType = None,
                 aux_head_cfg: OptConfigType = None,
                 flow_type: str = 'cascade') -> None:
        super().__init__()
        assert isinstance(in_channels, tuple)
        assert isinstance(out_channels, int)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # One TPN stage per backbone feature level.
        self.num_tpn_stages = len(in_channels)

        assert spatial_modulation_cfg is None or isinstance(
            spatial_modulation_cfg, dict)
        assert temporal_modulation_cfg is None or isinstance(
            temporal_modulation_cfg, dict)
        assert upsample_cfg is None or isinstance(upsample_cfg, dict)
        assert downsample_cfg is None or isinstance(downsample_cfg, dict)
        assert aux_head_cfg is None or isinstance(aux_head_cfg, dict)
        assert level_fusion_cfg is None or isinstance(level_fusion_cfg, dict)

        if flow_type not in ['cascade', 'parallel']:
            raise ValueError(
                f"flow type in TPN should be 'cascade' or 'parallel', "
                f'but got {flow_type} instead.')
        self.flow_type = flow_type

        self.temporal_modulation_ops = nn.ModuleList()
        self.upsample_ops = nn.ModuleList()
        self.downsample_ops = nn.ModuleList()

        # NOTE(review): level_fusion_cfg and spatial_modulation_cfg are
        # typed optional (and asserted to allow None above) but are
        # unpacked unconditionally here — passing None raises TypeError.
        # Presumably callers always supply both; confirm against configs.
        self.level_fusion_1 = LevelFusion(**level_fusion_cfg)
        self.spatial_modulation = SpatialModulation(**spatial_modulation_cfg)

        for i in range(self.num_tpn_stages):

            if temporal_modulation_cfg is not None:
                # Each stage gets its own temporal downsample scale; all
                # stages consume features with the top-level channel count
                # (in_channels[-1]) after spatial modulation.
                downsample_scale = temporal_modulation_cfg[
                    'downsample_scales'][i]
                temporal_modulation = TemporalModulation(
                    in_channels[-1], out_channels, downsample_scale)
                self.temporal_modulation_ops.append(temporal_modulation)

            # Up/downsample ops connect adjacent levels, so there is one
            # fewer of each than there are stages.
            if i < self.num_tpn_stages - 1:
                if upsample_cfg is not None:
                    upsample = nn.Upsample(**upsample_cfg)
                    self.upsample_ops.append(upsample)

                if downsample_cfg is not None:
                    downsample = DownSample(out_channels, out_channels,
                                            **downsample_cfg)
                    self.downsample_ops.append(downsample)

        out_dims = level_fusion_cfg['out_channels']

        # Two LevelFusion modules: one for the top-down flow, one for the
        # bottom-up flow; their outputs are concatenated below.
        self.level_fusion_2 = LevelFusion(**level_fusion_cfg)

        # 1x1x1 conv fusing both flows into a fixed 2048-channel output.
        self.pyramid_fusion = ConvModule(
            out_dims * 2,
            2048,
            1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', requires_grad=True))

        # The aux head supervises the second-highest backbone feature
        # (in_channels[-2]); see forward().
        if aux_head_cfg is not None:
            self.aux_head = AuxHead(self.in_channels[-2], **aux_head_cfg)
        else:
            self.aux_head = None

    def init_weights(self) -> None:
        """Default init_weights for conv(msra) and norm in ConvModule."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                xavier_init(m, distribution='uniform')
            if isinstance(m, nn.BatchNorm3d):
                constant_init(m, 1)

        # AuxHead has its own init scheme (normal for Linear layers).
        if self.aux_head is not None:
            self.aux_head.init_weights()

    def forward(self,
                x: Tuple[torch.Tensor],
                data_samples: Optional[SampleList] = None) -> tuple:
        """Defines the computation performed at every call.

        Returns a tuple of the fused pyramid feature and a dict holding
        the auxiliary loss (empty when no aux head / no data samples).
        """
        loss_aux = dict()

        # Auxiliary supervision on the second-highest input feature.
        if self.aux_head is not None and data_samples is not None:
            loss_aux = self.aux_head.loss(x[-2], data_samples)

        # Align spatial semantics across levels.
        spatial_modulation_outs = self.spatial_modulation(x)

        # Modulate each level's temporal rate.
        temporal_modulation_outs = []
        for i, temporal_modulation in enumerate(self.temporal_modulation_ops):
            temporal_modulation_outs.append(
                temporal_modulation(spatial_modulation_outs[i]))

        # Top-down pathway: residually add upsampled higher levels into
        # lower ones. clone() protects temporal_modulation_outs from the
        # in-place accumulation.
        outs = [out.clone() for out in temporal_modulation_outs]
        if len(self.upsample_ops) != 0:
            for i in range(self.num_tpn_stages - 1, 0, -1):
                outs[i - 1] = outs[i - 1] + self.upsample_ops[i - 1](outs[i])

        top_down_outs = self.level_fusion_1(outs)

        # Bottom-up pathway. For 'parallel' flow it restarts from the raw
        # temporal-modulation outputs; for 'cascade' it continues from the
        # top-down results.
        if self.flow_type == 'parallel':
            outs = [out.clone() for out in temporal_modulation_outs]
        if len(self.downsample_ops) != 0:
            for i in range(self.num_tpn_stages - 1):
                outs[i + 1] = outs[i + 1] + self.downsample_ops[i](outs[i])

        botton_up_outs = self.level_fusion_2(outs)

        # Concatenate both flows and fuse to the final pyramid feature.
        outs = self.pyramid_fusion(
            torch.cat([top_down_outs, botton_up_outs], 1))

        return outs, loss_aux
|
|
|
|