|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from mmaction.utils import get_str_type
|
|
|
from .resnet3d import ResNet3d
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class ResNet2Plus1d(ResNet3d):
    """ResNet (2+1)d backbone.

    This model is proposed in `A Closer Look at Spatiotemporal Convolutions for
    Action Recognition <https://arxiv.org/abs/1711.11248>`_

    Every constructor argument is handed to the parent class unchanged. Once
    the parent has been initialised, two constraints are checked:
    ``pretrained2d`` must be ``False`` and the ``type`` entry of ``conv_cfg``
    must resolve (via ``get_str_type``) to ``'Conv2plus1d'``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both checks read attributes that the parent constructor is expected
        # to have populated, so they can only run after ``super().__init__``.
        assert self.pretrained2d is False
        assert get_str_type(self.conv_cfg['type']) == 'Conv2plus1d'

    def _freeze_stages(self):
        """Freeze the stem and the first ``self.frozen_stages`` stages.

        A negative ``self.frozen_stages`` leaves the model untouched. Otherwise
        ``conv1`` and ``layer1`` .. ``layer{frozen_stages}`` are switched to
        eval mode and all of their parameters stop requiring gradients.
        """
        num_frozen = self.frozen_stages
        if num_frozen < 0:
            # Nothing to freeze, not even the stem.
            return

        def _freeze(module):
            # Put the module in eval mode and exclude its parameters from
            # gradient computation.
            module.eval()
            for weight in module.parameters():
                weight.requires_grad = False

        _freeze(self.conv1)
        for stage_idx in range(1, num_frozen + 1):
            _freeze(getattr(self, f'layer{stage_idx}'))

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The feature of the input
            samples extracted by the backbone.
        """
        stem_out = self.conv1(x)
        feats = self.maxpool(stem_out)
        # Residual stages are chained back to back; each stage consumes the
        # output of the previous one directly.
        for stage_name in self.res_layers:
            stage = getattr(self, stage_name)
            feats = stage(feats)
        return feats
|
|
|
|