# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init, normal_init, xavier_init

from mmaction.registry import MODELS
from mmaction.utils import ConfigType, OptConfigType, SampleList


class DownSample(nn.Module):
    """DownSample modules.

    It uses convolution and maxpooling to downsample the input feature,
    and specifies downsample position to determine `pool-conv` or
    `conv-pool`.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output feature.
        kernel_size (int or Tuple[int]): Same as :class:`ConvModule`.
            Defaults to ``(3, 1, 1)``.
        stride (int or Tuple[int]): Same as :class:`ConvModule`.
            Defaults to ``(1, 1, 1)``.
        padding (int or Tuple[int]): Same as :class:`ConvModule`.
            Defaults to ``(1, 0, 0)``.
        groups (int): Same as :class:`ConvModule`. Defaults to 1.
        bias (bool or str): Same as :class:`ConvModule`. Defaults to False.
        conv_cfg (dict or ConfigDict): Same as :class:`ConvModule`.
            Defaults to ``dict(type='Conv3d')``.
        norm_cfg (dict or ConfigDict, optional): Same as :class:`ConvModule`.
            Defaults to None.
        act_cfg (dict or ConfigDict, optional): Same as :class:`ConvModule`.
            Defaults to None.
        downsample_position (str): Type of downsample position. Options are
            ``before`` and ``after``. Defaults to ``after``.
        downsample_scale (int or Tuple[int]): downsample scale for maxpooling.
            It will be used for kernel size and stride of maxpooling.
            Defaults to ``(1, 2, 2)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = (3, 1, 1),
        stride: Union[int, Tuple[int]] = (1, 1, 1),
        padding: Union[int, Tuple[int]] = (1, 0, 0),
        groups: int = 1,
        bias: Union[bool, str] = False,
        conv_cfg: ConfigType = dict(type='Conv3d'),
        norm_cfg: OptConfigType = None,
        act_cfg: OptConfigType = None,
        downsample_position: str = 'after',
        downsample_scale: Union[int, Tuple[int]] = (1, 2, 2)
    ) -> None:
        super().__init__()
        self.conv = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            groups=groups,
            bias=bias,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Only 'before' (pool-conv) and 'after' (conv-pool) are supported.
        assert downsample_position in ['before', 'after']
        self.downsample_position = downsample_position
        # `downsample_scale` serves as both kernel size and stride, so the
        # pooling reduces each dimension exactly by that factor
        # (ceil_mode=True keeps the trailing partial window).
        self.pool = nn.MaxPool3d(
            downsample_scale, downsample_scale, (0, 0, 0), ceil_mode=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        if self.downsample_position == 'before':
            # pool-conv ordering
            x = self.pool(x)
            x = self.conv(x)
        else:
            # conv-pool ordering
            x = self.conv(x)
            x = self.pool(x)
        return x


class LevelFusion(nn.Module):
    """Level Fusion module.

    This module is used to aggregate the hierarchical features dynamic in
    visual tempos and consistent in spatial semantics. The top/bottom features
    for top-down/bottom-up flow would be combined to achieve two additional
    options, namely 'Cascade Flow' or 'Parallel Flow'. While applying a
    bottom-up flow after a top-down flow will lead to the cascade flow,
    applying them simultaneously will result in the parallel flow.

    Args:
        in_channels (Tuple[int]): Channel numbers of input features tuple.
        mid_channels (Tuple[int]): Channel numbers of middle features tuple.
        out_channels (int): Channel numbers of output features.
        downsample_scales (Tuple[int | Tuple[int]]): downsample scales for
            each :class:`DownSample` module.
            Defaults to ``((1, 1, 1), (1, 1, 1))``.
    """

    def __init__(
        self,
        in_channels: Tuple[int],
        mid_channels: Tuple[int],
        out_channels: int,
        downsample_scales: Tuple[int, Tuple[int]] = ((1, 1, 1), (1, 1, 1))
    ) -> None:
        super().__init__()
        num_stages = len(in_channels)

        # One DownSample per pyramid stage: a grouped 1x1x1 conv (pool first,
        # then conv) projecting in_channels[i] -> mid_channels[i].
        self.downsamples = nn.ModuleList()
        for i in range(num_stages):
            downsample = DownSample(
                in_channels[i],
                mid_channels[i],
                kernel_size=(1, 1, 1),
                stride=(1, 1, 1),
                bias=False,
                padding=(0, 0, 0),
                groups=32,
                norm_cfg=dict(type='BN3d', requires_grad=True),
                act_cfg=dict(type='ReLU', inplace=True),
                downsample_position='before',
                downsample_scale=downsample_scales[i])
            self.downsamples.append(downsample)

        # Fuse the channel-concatenated stage outputs with a 1x1x1 conv.
        self.fusion_conv = ConvModule(
            sum(mid_channels),
            out_channels,
            1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', requires_grad=True),
            act_cfg=dict(type='ReLU', inplace=True))

    def forward(self, x: Tuple[torch.Tensor]) -> torch.Tensor:
        """Defines the computation performed at every call."""
        # Downsample each stage independently, then fuse along channels.
        out = [self.downsamples[i](feature) for i, feature in enumerate(x)]
        out = torch.cat(out, 1)
        out = self.fusion_conv(out)
        return out


class SpatialModulation(nn.Module):
    """Spatial Semantic Modulation.

    This module is used to align spatial semantics of features in the
    multi-depth pyramid. For each but the top-level feature, a stack of
    convolutions with level-specific stride are applied to it, matching its
    spatial shape and receptive field with the top one.

    Args:
        in_channels (Tuple[int]): Channel numbers of input features tuple.
        out_channels (int): Channel numbers of output features tuple.
    """

    def __init__(self, in_channels: Tuple[int], out_channels: int) -> None:
        super().__init__()
        self.spatial_modulation = nn.ModuleList()
        for channel in in_channels:
            # Number of stride-2 spatial halvings needed to reach the
            # top-level channel count; assumes out_channels is a
            # power-of-two multiple of each entry in in_channels.
            downsample_scale = out_channels // channel
            downsample_factor = int(np.log2(downsample_scale))
            op = nn.ModuleList()
            if downsample_factor < 1:
                # Top level already matches: pass through unchanged.
                op = nn.Identity()
            else:
                # Each step halves spatial size and doubles channels.
                for factor in range(downsample_factor):
                    in_factor = 2**factor
                    out_factor = 2**(factor + 1)
                    op.append(
                        ConvModule(
                            channel * in_factor,
                            channel * out_factor, (1, 3, 3),
                            stride=(1, 2, 2),
                            padding=(0, 1, 1),
                            bias=False,
                            conv_cfg=dict(type='Conv3d'),
                            norm_cfg=dict(type='BN3d', requires_grad=True),
                            act_cfg=dict(type='ReLU', inplace=True)))
            self.spatial_modulation.append(op)

    def forward(self, x: Tuple[torch.Tensor]) -> list:
        """Defines the computation performed at every call."""
        out = []
        for i, _ in enumerate(x):
            if isinstance(self.spatial_modulation[i], nn.ModuleList):
                # Apply the level-specific conv stack sequentially.
                out_ = x[i]
                for op in self.spatial_modulation[i]:
                    out_ = op(out_)
                out.append(out_)
            else:
                # nn.Identity branch for the top level.
                out.append(self.spatial_modulation[i](x[i]))
        return out


class AuxHead(nn.Module):
    """Auxiliary Head.

    This auxiliary head is appended to receive stronger supervision,
    leading to enhanced semantics.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output features.
        loss_weight (float): weight of loss for the auxiliary head.
            Defaults to 0.5.
        loss_cls (dict or ConfigDict): Config for building loss.
            Defaults to ``dict(type='CrossEntropyLoss')``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        loss_weight: float = 0.5,
        loss_cls: ConfigType = dict(type='CrossEntropyLoss')
    ) -> None:
        super().__init__()

        # Spatially-strided conv that doubles the channel count before
        # global pooling and classification.
        self.conv = ConvModule(
            in_channels,
            in_channels * 2, (1, 3, 3),
            stride=(1, 2, 2),
            padding=(0, 1, 1),
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', requires_grad=True))
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.loss_weight = loss_weight
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(in_channels * 2, out_channels)
        self.loss_cls = MODELS.build(loss_cls)

    def init_weights(self) -> None:
        """Initiate the parameters from scratch."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                normal_init(m, std=0.01)
            if isinstance(m, nn.Conv3d):
                xavier_init(m, distribution='uniform')
            if isinstance(m, nn.BatchNorm3d):
                constant_init(m, 1)

    def loss(self, x: torch.Tensor,
             data_samples: Optional[SampleList]) -> dict:
        """Calculate auxiliary loss.

        Args:
            x (torch.Tensor): Input feature map for the auxiliary head.
            data_samples (list, optional): Data samples carrying ``gt_label``.

        Returns:
            dict: Dict with a single key ``loss_aux`` holding the weighted
            classification loss.
        """
        x = self(x)
        # NOTE: the comprehension variable shadows the feature `x` only
        # inside the comprehension; the outer `x` (logits) is unaffected.
        labels = [x.gt_label for x in data_samples]
        labels = torch.stack(labels).to(x.device)
        labels = labels.squeeze()
        # squeeze() on a batch of one collapses to a 0-dim tensor; restore
        # the batch dimension so the loss sees shape (1,).
        if labels.shape == torch.Size([]):
            labels = labels.unsqueeze(0)

        losses = dict()
        losses['loss_aux'] = self.loss_weight * self.loss_cls(x, labels)
        return losses

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Auxiliary head forward function."""
        x = self.conv(x)
        # Global average pool to (N, C, 1, 1, 1), then drop the three
        # singleton dims to get (N, C) for the linear classifier.
        x = self.avg_pool(x).squeeze(-1).squeeze(-1).squeeze(-1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class TemporalModulation(nn.Module):
    """Temporal Rate Modulation.

    The module is used to equip TPN with a similar flexibility for temporal
    tempo modulation as in the input-level frame pyramid.

    Args:
        in_channels (int): Channel number of input features.
        out_channels (int): Channel number of output features.
        downsample_scale (int): Downsample scale for maxpooling. Defaults
            to 8.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 downsample_scale: int = 8) -> None:
        super().__init__()

        # Grouped temporal-only conv (no norm/activation: act_cfg=None).
        self.conv = ConvModule(
            in_channels,
            out_channels, (3, 1, 1),
            stride=(1, 1, 1),
            padding=(1, 0, 0),
            bias=False,
            groups=32,
            conv_cfg=dict(type='Conv3d'),
            act_cfg=None)
        # Temporal-only max pooling by `downsample_scale`; spatial dims
        # are untouched.
        self.pool = nn.MaxPool3d((downsample_scale, 1, 1),
                                 (downsample_scale, 1, 1), (0, 0, 0),
                                 ceil_mode=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        x = self.conv(x)
        x = self.pool(x)
        return x


@MODELS.register_module()
class TPN(nn.Module):
    """TPN neck.

    This module is proposed in `Temporal Pyramid Network for Action
    Recognition <https://arxiv.org/abs/2004.03548>`_

    Args:
        in_channels (Tuple[int]): Channel numbers of input features tuple.
        out_channels (int): Channel number of output feature.
        spatial_modulation_cfg (dict or ConfigDict, optional): Config for
            spatial modulation layers. Required keys are ``in_channels``
            and ``out_channels``. Defaults to None.
        temporal_modulation_cfg (dict or ConfigDict, optional): Config for
            temporal modulation layers. Defaults to None.
        upsample_cfg (dict or ConfigDict, optional): Config for upsample
            layers. The keys are same as that in :class:`nn.Upsample`.
            Defaults to None.
        downsample_cfg (dict or ConfigDict, optional): Config for downsample
            layers. Defaults to None.
        level_fusion_cfg (dict or ConfigDict, optional): Config for level
            fusion layers. Required keys are ``in_channels``,
            ``mid_channels``, ``out_channels``. Defaults to None.
        aux_head_cfg (dict or ConfigDict, optional): Config for aux head
            layers. Required keys are ``out_channels``. Defaults to None.
        flow_type (str): Flow type to combine the features. Options are
            ``cascade`` and ``parallel``. Defaults to ``cascade``.
    """

    def __init__(self,
                 in_channels: Tuple[int],
                 out_channels: int,
                 spatial_modulation_cfg: OptConfigType = None,
                 temporal_modulation_cfg: OptConfigType = None,
                 upsample_cfg: OptConfigType = None,
                 downsample_cfg: OptConfigType = None,
                 level_fusion_cfg: OptConfigType = None,
                 aux_head_cfg: OptConfigType = None,
                 flow_type: str = 'cascade') -> None:
        super().__init__()
        assert isinstance(in_channels, tuple)
        assert isinstance(out_channels, int)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_tpn_stages = len(in_channels)

        assert spatial_modulation_cfg is None or isinstance(
            spatial_modulation_cfg, dict)
        assert temporal_modulation_cfg is None or isinstance(
            temporal_modulation_cfg, dict)
        assert upsample_cfg is None or isinstance(upsample_cfg, dict)
        assert downsample_cfg is None or isinstance(downsample_cfg, dict)
        assert aux_head_cfg is None or isinstance(aux_head_cfg, dict)
        assert level_fusion_cfg is None or isinstance(level_fusion_cfg, dict)

        if flow_type not in ['cascade', 'parallel']:
            raise ValueError(
                f"flow type in TPN should be 'cascade' or 'parallel', "
                f'but got {flow_type} instead.')
        self.flow_type = flow_type

        self.temporal_modulation_ops = nn.ModuleList()
        self.upsample_ops = nn.ModuleList()
        self.downsample_ops = nn.ModuleList()

        # NOTE(review): despite the asserts above allowing None, the two
        # lines below unpack these cfgs directly, so `level_fusion_cfg`
        # and `spatial_modulation_cfg` must actually be provided.
        self.level_fusion_1 = LevelFusion(**level_fusion_cfg)
        self.spatial_modulation = SpatialModulation(**spatial_modulation_cfg)

        for i in range(self.num_tpn_stages):

            if temporal_modulation_cfg is not None:
                # Every temporal modulation consumes the top-stage channel
                # count (spatial modulation has already aligned channels).
                downsample_scale = temporal_modulation_cfg[
                    'downsample_scales'][i]
                temporal_modulation = TemporalModulation(
                    in_channels[-1], out_channels, downsample_scale)
                self.temporal_modulation_ops.append(temporal_modulation)

            # Up/downsample ops connect adjacent stages, hence one fewer
            # than the number of stages.
            if i < self.num_tpn_stages - 1:
                if upsample_cfg is not None:
                    upsample = nn.Upsample(**upsample_cfg)
                    self.upsample_ops.append(upsample)

                if downsample_cfg is not None:
                    downsample = DownSample(out_channels, out_channels,
                                            **downsample_cfg)
                    self.downsample_ops.append(downsample)

        out_dims = level_fusion_cfg['out_channels']

        # two pyramids
        self.level_fusion_2 = LevelFusion(**level_fusion_cfg)

        # Fuse the top-down and bottom-up pyramid outputs into a single
        # 2048-channel feature.
        self.pyramid_fusion = ConvModule(
            out_dims * 2,
            2048,
            1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=dict(type='Conv3d'),
            norm_cfg=dict(type='BN3d', requires_grad=True))

        if aux_head_cfg is not None:
            # Aux head supervises the second-to-last backbone stage.
            self.aux_head = AuxHead(self.in_channels[-2], **aux_head_cfg)
        else:
            self.aux_head = None

    def init_weights(self) -> None:
        """Default init_weights for conv(msra) and norm in ConvModule."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                xavier_init(m, distribution='uniform')
            if isinstance(m, nn.BatchNorm3d):
                constant_init(m, 1)

        if self.aux_head is not None:
            self.aux_head.init_weights()

    def forward(self,
                x: Tuple[torch.Tensor],
                data_samples: Optional[SampleList] = None) -> tuple:
        """Defines the computation performed at every call.

        Args:
            x (Tuple[torch.Tensor]): Multi-stage backbone features.
            data_samples (list, optional): Data samples with ``gt_label``;
                required only for the auxiliary loss.

        Returns:
            tuple: Fused pyramid feature and the (possibly empty) dict of
            auxiliary losses.
        """
        loss_aux = dict()

        # Calculate auxiliary loss if `self.aux_head`
        # and `data_samples` are not None.
        if self.aux_head is not None and data_samples is not None:
            loss_aux = self.aux_head.loss(x[-2], data_samples)

        # Spatial Modulation
        spatial_modulation_outs = self.spatial_modulation(x)

        # Temporal Modulation
        temporal_modulation_outs = []
        for i, temporal_modulation in enumerate(self.temporal_modulation_ops):
            temporal_modulation_outs.append(
                temporal_modulation(spatial_modulation_outs[i]))

        # Clone so the top-down accumulation leaves the modulation outputs
        # intact for the optional parallel flow below.
        outs = [out.clone() for out in temporal_modulation_outs]
        if len(self.upsample_ops) != 0:
            # Top-down flow: propagate from the deepest stage upwards.
            for i in range(self.num_tpn_stages - 1, 0, -1):
                outs[i - 1] = outs[i - 1] + self.upsample_ops[i - 1](outs[i])

        # Get top-down outs
        top_down_outs = self.level_fusion_1(outs)

        # Build bottom-up flow using downsample operation
        if self.flow_type == 'parallel':
            # Parallel flow restarts from the raw modulation outputs;
            # cascade flow (default) reuses the top-down results.
            outs = [out.clone() for out in temporal_modulation_outs]
        if len(self.downsample_ops) != 0:
            # Bottom-up flow: propagate from the shallowest stage downwards.
            for i in range(self.num_tpn_stages - 1):
                outs[i + 1] = outs[i + 1] + self.downsample_ops[i](outs[i])

        # Get bottom-up outs
        botton_up_outs = self.level_fusion_2(outs)

        # fuse two pyramid outs
        outs = self.pyramid_fusion(
            torch.cat([top_down_outs, botton_up_outs], 1))

        return outs, loss_aux