|
|
| from copy import deepcopy
|
| from typing import Optional
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils import checkpoint as cp
|
|
|
| from mmaction.registry import MODELS
|
| from ..common import TAM
|
| from .resnet import Bottleneck, ResNet
|
|
|
|
|
class TABlock(nn.Module):
    """Temporal Adaptive Block (TA-Block) for TANet.

    This block is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    The temporal adaptive module (TAM) is embedded into ResNet-Block
    after the first Conv2D, which turns the vanilla ResNet-Block
    into TA-Block.

    Args:
        block (nn.Module): Residual blocks to be substituted.
        num_segments (int): Number of frame segments.
        tam_cfg (dict): Config for temporal adaptive module (TAM).

    Raises:
        NotImplementedError: If ``block`` is not a ``Bottleneck`` instance.
    """

    def __init__(self, block: nn.Module, num_segments: int,
                 tam_cfg: dict) -> None:
        super().__init__()
        # Fail fast: validate the block type before constructing the TAM so
        # that no submodule is built for an unsupported block pattern.
        if not isinstance(block, Bottleneck):
            raise NotImplementedError('TA-Blocks have not been fully '
                                      'implemented except the pattern based '
                                      'on Bottleneck block.')

        # Deep-copy so later mutation of the caller's dict cannot leak in.
        self.tam_cfg = deepcopy(tam_cfg)
        self.block = block
        self.num_segments = num_segments
        # The TAM is inserted after conv1, so it operates on conv1's
        # output channels.
        self.tam = TAM(
            in_channels=block.conv1.out_channels,
            num_segments=num_segments,
            **self.tam_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: Output feature map after the TA-Block.
        """
        assert isinstance(self.block, Bottleneck)

        def _inner_forward(x):
            """Forward wrapper for utilizing checkpoint."""
            identity = x

            # TAM is applied right after the first conv of the Bottleneck,
            # per the TANet paper.
            out = self.block.conv1(x)
            out = self.tam(out)
            out = self.block.conv2(out)
            out = self.block.conv3(out)

            if self.block.downsample is not None:
                identity = self.block.downsample(x)

            out = out + identity

            return out

        if self.block.with_cp and x.requires_grad:
            # Activation checkpointing: trades extra compute in backward
            # for reduced memory in forward.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        # Final ReLU is applied outside the checkpointed region, matching
        # the vanilla Bottleneck's post-residual activation.
        out = self.block.relu(out)

        return out
|
|
|
|
|
@MODELS.register_module()
class TANet(ResNet):
    """Temporal Adaptive Network (TANet) backbone.

    This backbone is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_

    Embedding the temporal adaptive module (TAM) into ResNet to
    instantiate TANet.

    Args:
        depth (int): Depth of resnet, from ``{18, 34, 50, 101, 152}``.
        num_segments (int): Number of frame segments.
        tam_cfg (dict, optional): Config for temporal adaptive module (TAM).
            Defaults to None.
    """

    def __init__(self,
                 depth: int,
                 num_segments: int,
                 tam_cfg: Optional[dict] = None,
                 **kwargs) -> None:
        super().__init__(depth, **kwargs)
        # TAM aggregates over time; require a minimal temporal context.
        assert num_segments >= 3
        self.num_segments = num_segments
        # Own a private copy so callers mutating their dict afterwards
        # cannot affect this backbone.
        self.tam_cfg = dict() if tam_cfg is None else deepcopy(tam_cfg)
        # Initialize the plain ResNet weights BEFORE wrapping blocks with
        # TA-Blocks; TAM submodules initialize themselves on construction.
        super().init_weights()
        self.make_tam_modeling()

    def init_weights(self):
        """Initialize weights.

        Intentionally a no-op: weights were already initialized in
        ``__init__`` via ``super().init_weights()`` before the residual
        blocks were wrapped by :class:`TABlock`, so re-running the parent
        initialization here would act on the wrapped module tree.
        """
        pass

    def make_tam_modeling(self):
        """Replace ResNet-Block with TA-Block."""

        def make_tam_block(stage, num_segments, tam_cfg=None):
            # Use None as the default instead of a mutable ``dict()``
            # default argument (shared across calls); each TA-Block gets
            # its own deep copy of the config.
            tam_cfg = dict() if tam_cfg is None else tam_cfg
            blocks = [
                TABlock(block, num_segments, deepcopy(tam_cfg))
                for block in stage.children()
            ]
            return nn.Sequential(*blocks)

        # Rebuild each residual stage (layer1..layerN) with wrapped blocks.
        for i in range(self.num_stages):
            layer_name = f'layer{i + 1}'
            res_layer = getattr(self, layer_name)
            setattr(self, layer_name,
                    make_tam_block(res_layer, self.num_segments, self.tam_cfg))
|
|
|