|
|
|
|
|
import warnings
|
|
|
from collections import OrderedDict
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from mmcv.cnn import ConvModule
|
|
|
from mmengine.logging import MMLogger, print_log
|
|
|
from mmengine.model import BaseModule
|
|
|
from mmengine.model.weight_init import kaiming_init
|
|
|
from mmengine.runner.checkpoint import _load_checkpoint, load_checkpoint
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from .resnet3d import ResNet3d
|
|
|
|
|
|
|
|
|
class DeConvModule(BaseModule):
|
|
|
"""A deconv module that bundles deconv/norm/activation layers.
|
|
|
|
|
|
Args:
|
|
|
in_channels (int): Number of channels in the input feature map.
|
|
|
out_channels (int): Number of channels produced by the convolution.
|
|
|
kernel_size (int | tuple[int]): Size of the convolving kernel.
|
|
|
stride (int | tuple[int]): Stride of the convolution.
|
|
|
padding (int | tuple[int]): Zero-padding added to both sides of
|
|
|
the input.
|
|
|
bias (bool): Whether to add a learnable bias to the output.
|
|
|
Defaults to False.
|
|
|
with_bn (bool): Whether to add a BN layer. Defaults to True.
|
|
|
with_relu (bool): Whether to add a ReLU layer. Defaults to True.
|
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
|
in_channels: int,
|
|
|
out_channels: int,
|
|
|
kernel_size: int,
|
|
|
stride: Union[int, Tuple[int]] = (1, 1, 1),
|
|
|
padding: Union[int, Tuple[int]] = 0,
|
|
|
bias: bool = False,
|
|
|
with_bn: bool = True,
|
|
|
with_relu: bool = True) -> None:
|
|
|
super().__init__()
|
|
|
self.in_channels = in_channels
|
|
|
self.out_channels = out_channels
|
|
|
self.kernel_size = kernel_size
|
|
|
self.stride = stride
|
|
|
self.padding = padding
|
|
|
self.bias = bias
|
|
|
self.with_bn = with_bn
|
|
|
self.with_relu = with_relu
|
|
|
|
|
|
self.conv = nn.ConvTranspose3d(
|
|
|
in_channels,
|
|
|
out_channels,
|
|
|
kernel_size,
|
|
|
stride=stride,
|
|
|
padding=padding,
|
|
|
bias=bias)
|
|
|
self.bn = nn.BatchNorm3d(out_channels)
|
|
|
self.relu = nn.ReLU()
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
"""Defines the computation performed at every call."""
|
|
|
|
|
|
assert len(x.shape) == 5
|
|
|
N, C, T, H, W = x.shape
|
|
|
out_shape = (N, self.out_channels, self.stride[0] * T,
|
|
|
self.stride[1] * H, self.stride[2] * W)
|
|
|
x = self.conv(x, output_size=out_shape)
|
|
|
if self.with_bn:
|
|
|
x = self.bn(x)
|
|
|
if self.with_relu:
|
|
|
x = self.relu(x)
|
|
|
return x
|
|
|
|
|
|
|
|
|
class ResNet3dPathway(ResNet3d):
|
|
|
"""A pathway of Slowfast based on ResNet3d.
|
|
|
|
|
|
Args:
|
|
|
lateral (bool): Determines whether to enable the lateral connection
|
|
|
from another pathway. Defaults to False.
|
|
|
lateral_inv (bool): Whether to use deconv to upscale the time
|
|
|
dimension of features from another pathway. Defaults to False.
|
|
|
lateral_norm (bool): Determines whether to enable the lateral norm
|
|
|
in lateral layers. Defaults to False.
|
|
|
speed_ratio (int): Speed ratio indicating the ratio between time
|
|
|
dimension of the fast and slow pathway, corresponding to the
|
|
|
``alpha`` in the paper. Defaults to 8.
|
|
|
channel_ratio (int): Reduce the channel number of fast pathway
|
|
|
by ``channel_ratio``, corresponding to ``beta`` in the paper.
|
|
|
Defaults to 8.
|
|
|
fusion_kernel (int): The kernel size of lateral fusion.
|
|
|
Defaults to 5.
|
|
|
lateral_infl (int): The ratio of the inflated channels.
|
|
|
Defaults to 2.
|
|
|
lateral_activate (list[int]): Flags for activating the lateral
|
|
|
connection. Defaults to ``[1, 1, 1, 1]``.
|
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
|
lateral: bool = False,
|
|
|
lateral_inv: bool = False,
|
|
|
lateral_norm: bool = False,
|
|
|
speed_ratio: int = 8,
|
|
|
channel_ratio: int = 8,
|
|
|
fusion_kernel: int = 5,
|
|
|
lateral_infl: int = 2,
|
|
|
lateral_activate: List[int] = [1, 1, 1, 1],
|
|
|
**kwargs) -> None:
|
|
|
self.lateral = lateral
|
|
|
self.lateral_inv = lateral_inv
|
|
|
self.lateral_norm = lateral_norm
|
|
|
self.speed_ratio = speed_ratio
|
|
|
self.channel_ratio = channel_ratio
|
|
|
self.fusion_kernel = fusion_kernel
|
|
|
self.lateral_infl = lateral_infl
|
|
|
self.lateral_activate = lateral_activate
|
|
|
self._calculate_lateral_inplanes(kwargs)
|
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
self.inplanes = self.base_channels
|
|
|
if self.lateral and self.lateral_activate[0] == 1:
|
|
|
if self.lateral_inv:
|
|
|
self.conv1_lateral = DeConvModule(
|
|
|
self.inplanes * self.channel_ratio,
|
|
|
self.inplanes * self.channel_ratio // lateral_infl,
|
|
|
kernel_size=(fusion_kernel, 1, 1),
|
|
|
stride=(self.speed_ratio, 1, 1),
|
|
|
padding=((fusion_kernel - 1) // 2, 0, 0),
|
|
|
with_bn=True,
|
|
|
with_relu=True)
|
|
|
else:
|
|
|
self.conv1_lateral = ConvModule(
|
|
|
self.inplanes // self.channel_ratio,
|
|
|
self.inplanes * lateral_infl // self.channel_ratio,
|
|
|
kernel_size=(fusion_kernel, 1, 1),
|
|
|
stride=(self.speed_ratio, 1, 1),
|
|
|
padding=((fusion_kernel - 1) // 2, 0, 0),
|
|
|
bias=False,
|
|
|
conv_cfg=self.conv_cfg,
|
|
|
norm_cfg=self.norm_cfg if self.lateral_norm else None,
|
|
|
act_cfg=self.act_cfg if self.lateral_norm else None)
|
|
|
|
|
|
self.lateral_connections = []
|
|
|
for i in range(len(self.stage_blocks)):
|
|
|
planes = self.base_channels * 2**i
|
|
|
self.inplanes = planes * self.block.expansion
|
|
|
|
|
|
if lateral and i != self.num_stages - 1 \
|
|
|
and self.lateral_activate[i + 1]:
|
|
|
|
|
|
lateral_name = f'layer{(i + 1)}_lateral'
|
|
|
if self.lateral_inv:
|
|
|
conv_module = DeConvModule(
|
|
|
self.inplanes * self.channel_ratio,
|
|
|
self.inplanes * self.channel_ratio // lateral_infl,
|
|
|
kernel_size=(fusion_kernel, 1, 1),
|
|
|
stride=(self.speed_ratio, 1, 1),
|
|
|
padding=((fusion_kernel - 1) // 2, 0, 0),
|
|
|
bias=False,
|
|
|
with_bn=True,
|
|
|
with_relu=True)
|
|
|
else:
|
|
|
conv_module = ConvModule(
|
|
|
self.inplanes // self.channel_ratio,
|
|
|
self.inplanes * lateral_infl // self.channel_ratio,
|
|
|
kernel_size=(fusion_kernel, 1, 1),
|
|
|
stride=(self.speed_ratio, 1, 1),
|
|
|
padding=((fusion_kernel - 1) // 2, 0, 0),
|
|
|
bias=False,
|
|
|
conv_cfg=self.conv_cfg,
|
|
|
norm_cfg=self.norm_cfg if self.lateral_norm else None,
|
|
|
act_cfg=self.act_cfg if self.lateral_norm else None)
|
|
|
setattr(self, lateral_name, conv_module)
|
|
|
self.lateral_connections.append(lateral_name)
|
|
|
|
|
|
def _calculate_lateral_inplanes(self, kwargs):
|
|
|
"""Calculate inplanes for lateral connection."""
|
|
|
depth = kwargs.get('depth', 50)
|
|
|
expansion = 1 if depth < 50 else 4
|
|
|
base_channels = kwargs.get('base_channels', 64)
|
|
|
lateral_inplanes = []
|
|
|
for i in range(kwargs.get('num_stages', 4)):
|
|
|
if expansion % 2 == 0:
|
|
|
planes = base_channels * (2 ** i) * \
|
|
|
((expansion // 2) ** (i > 0))
|
|
|
else:
|
|
|
planes = base_channels * (2**i) // (2**(i > 0))
|
|
|
if self.lateral and self.lateral_activate[i]:
|
|
|
if self.lateral_inv:
|
|
|
lateral_inplane = planes * \
|
|
|
self.channel_ratio // self.lateral_infl
|
|
|
else:
|
|
|
lateral_inplane = planes * \
|
|
|
self.lateral_infl // self.channel_ratio
|
|
|
else:
|
|
|
lateral_inplane = 0
|
|
|
lateral_inplanes.append(lateral_inplane)
|
|
|
self.lateral_inplanes = lateral_inplanes
|
|
|
|
|
|
def inflate_weights(self, logger: MMLogger) -> None:
|
|
|
"""Inflate the resnet2d parameters to resnet3d pathway.
|
|
|
|
|
|
The differences between resnet3d and resnet2d mainly lie in an extra
|
|
|
axis of conv kernel. To utilize the pretrained parameters in 2d model,
|
|
|
the weight of conv2d models should be inflated to fit in the shapes of
|
|
|
the 3d counterpart. For pathway the ``lateral_connection`` part should
|
|
|
not be inflated from 2d weights.
|
|
|
|
|
|
Args:
|
|
|
logger (MMLogger): The logger used to print
|
|
|
debugging information.
|
|
|
"""
|
|
|
|
|
|
state_dict_r2d = _load_checkpoint(self.pretrained, map_location='cpu')
|
|
|
if 'state_dict' in state_dict_r2d:
|
|
|
state_dict_r2d = state_dict_r2d['state_dict']
|
|
|
|
|
|
inflated_param_names = []
|
|
|
for name, module in self.named_modules():
|
|
|
if 'lateral' in name:
|
|
|
continue
|
|
|
if isinstance(module, ConvModule):
|
|
|
|
|
|
|
|
|
if 'downsample' in name:
|
|
|
|
|
|
original_conv_name = name + '.0'
|
|
|
|
|
|
original_bn_name = name + '.1'
|
|
|
else:
|
|
|
|
|
|
original_conv_name = name
|
|
|
|
|
|
original_bn_name = name.replace('conv', 'bn')
|
|
|
if original_conv_name + '.weight' not in state_dict_r2d:
|
|
|
logger.warning(f'Module not exist in the state_dict_r2d'
|
|
|
f': {original_conv_name}')
|
|
|
else:
|
|
|
self._inflate_conv_params(module.conv, state_dict_r2d,
|
|
|
original_conv_name,
|
|
|
inflated_param_names)
|
|
|
if original_bn_name + '.weight' not in state_dict_r2d:
|
|
|
logger.warning(f'Module not exist in the state_dict_r2d'
|
|
|
f': {original_bn_name}')
|
|
|
else:
|
|
|
self._inflate_bn_params(module.bn, state_dict_r2d,
|
|
|
original_bn_name,
|
|
|
inflated_param_names)
|
|
|
|
|
|
|
|
|
remaining_names = set(
|
|
|
state_dict_r2d.keys()) - set(inflated_param_names)
|
|
|
if remaining_names:
|
|
|
logger.info(f'These parameters in the 2d checkpoint are not loaded'
|
|
|
f': {remaining_names}')
|
|
|
|
|
|
def _inflate_conv_params(self, conv3d: nn.Module,
|
|
|
state_dict_2d: OrderedDict, module_name_2d: str,
|
|
|
inflated_param_names: List[str]) -> None:
|
|
|
"""Inflate a conv module from 2d to 3d.
|
|
|
|
|
|
The differences of conv modules betweene 2d and 3d in Pathway
|
|
|
mainly lie in the inplanes due to lateral connections. To fit the
|
|
|
shapes of the lateral connection counterpart, it will expand
|
|
|
parameters by concatting conv2d parameters and extra zero paddings.
|
|
|
|
|
|
Args:
|
|
|
conv3d (nn.Module): The destination conv3d module.
|
|
|
state_dict_2d (OrderedDict): The state dict of pretrained 2d model.
|
|
|
module_name_2d (str): The name of corresponding conv module in the
|
|
|
2d model.
|
|
|
inflated_param_names (list[str]): List of parameters that have been
|
|
|
inflated.
|
|
|
"""
|
|
|
weight_2d_name = module_name_2d + '.weight'
|
|
|
conv2d_weight = state_dict_2d[weight_2d_name]
|
|
|
old_shape = conv2d_weight.shape
|
|
|
new_shape = conv3d.weight.data.shape
|
|
|
kernel_t = new_shape[2]
|
|
|
|
|
|
if new_shape[1] != old_shape[1]:
|
|
|
if new_shape[1] < old_shape[1]:
|
|
|
warnings.warn(f'The parameter of {module_name_2d} is not'
|
|
|
'loaded due to incompatible shapes. ')
|
|
|
return
|
|
|
|
|
|
new_channels = new_shape[1] - old_shape[1]
|
|
|
pad_shape = old_shape
|
|
|
pad_shape = pad_shape[:1] + (new_channels, ) + pad_shape[2:]
|
|
|
|
|
|
conv2d_weight = torch.cat(
|
|
|
(conv2d_weight,
|
|
|
torch.zeros(pad_shape).type_as(conv2d_weight).to(
|
|
|
conv2d_weight.device)),
|
|
|
dim=1)
|
|
|
|
|
|
new_weight = conv2d_weight.data.unsqueeze(2).expand_as(
|
|
|
conv3d.weight) / kernel_t
|
|
|
conv3d.weight.data.copy_(new_weight)
|
|
|
inflated_param_names.append(weight_2d_name)
|
|
|
|
|
|
if getattr(conv3d, 'bias') is not None:
|
|
|
bias_2d_name = module_name_2d + '.bias'
|
|
|
conv3d.bias.data.copy_(state_dict_2d[bias_2d_name])
|
|
|
inflated_param_names.append(bias_2d_name)
|
|
|
|
|
|
def _freeze_stages(self) -> None:
|
|
|
"""Prevent all the parameters from being optimized before
|
|
|
`self.frozen_stages`."""
|
|
|
if self.frozen_stages >= 0:
|
|
|
self.conv1.eval()
|
|
|
for param in self.conv1.parameters():
|
|
|
param.requires_grad = False
|
|
|
|
|
|
for i in range(1, self.frozen_stages + 1):
|
|
|
m = getattr(self, f'layer{i}')
|
|
|
m.eval()
|
|
|
for param in m.parameters():
|
|
|
param.requires_grad = False
|
|
|
|
|
|
if i != len(self.res_layers) and self.lateral:
|
|
|
|
|
|
lateral_name = self.lateral_connections[i - 1]
|
|
|
conv_lateral = getattr(self, lateral_name)
|
|
|
conv_lateral.eval()
|
|
|
for param in conv_lateral.parameters():
|
|
|
param.requires_grad = False
|
|
|
|
|
|
def init_weights(self, pretrained: Optional[str] = None) -> None:
|
|
|
"""Initiate the parameters either from existing checkpoint or from
|
|
|
scratch."""
|
|
|
if pretrained:
|
|
|
self.pretrained = pretrained
|
|
|
|
|
|
|
|
|
super().init_weights()
|
|
|
for module_name in self.lateral_connections:
|
|
|
layer = getattr(self, module_name)
|
|
|
for m in layer.modules():
|
|
|
if isinstance(m, (nn.Conv3d, nn.Conv2d)):
|
|
|
kaiming_init(m)
|
|
|
|
|
|
|
|
|
pathway_cfg = {
|
|
|
'resnet3d': ResNet3dPathway,
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
def build_pathway(cfg: Dict, *args, **kwargs) -> nn.Module:
|
|
|
"""Build pathway.
|
|
|
|
|
|
Args:
|
|
|
cfg (dict): cfg should contain:
|
|
|
- type (str): identify backbone type.
|
|
|
|
|
|
Returns:
|
|
|
nn.Module: Created pathway.
|
|
|
"""
|
|
|
if not (isinstance(cfg, dict) and 'type' in cfg):
|
|
|
raise TypeError('cfg must be a dict containing the key "type"')
|
|
|
cfg_ = cfg.copy()
|
|
|
|
|
|
pathway_type = cfg_.pop('type')
|
|
|
if pathway_type not in pathway_cfg:
|
|
|
raise KeyError(f'Unrecognized pathway type {pathway_type}')
|
|
|
|
|
|
pathway_cls = pathway_cfg[pathway_type]
|
|
|
pathway = pathway_cls(*args, **kwargs, **cfg_)
|
|
|
|
|
|
return pathway
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
|
|
|
class ResNet3dSlowFast(BaseModule):
|
|
|
"""Slowfast backbone.
|
|
|
|
|
|
This module is proposed in `SlowFast Networks for Video Recognition
|
|
|
<https://arxiv.org/abs/1812.03982>`_
|
|
|
|
|
|
Args:
|
|
|
pretrained (str): The file path to a pretrained model.
|
|
|
resample_rate (int): A large temporal stride ``resample_rate``
|
|
|
on input frames. The actual resample rate is calculated by
|
|
|
multipling the ``interval`` in ``SampleFrames`` in the
|
|
|
pipeline with ``resample_rate``, equivalent to the :math:`\\tau`
|
|
|
in the paper, i.e. it processes only one out of
|
|
|
``resample_rate * interval`` frames. Defaults to 8.
|
|
|
speed_ratio (int): Speed ratio indicating the ratio between time
|
|
|
dimension of the fast and slow pathway, corresponding to the
|
|
|
:math:`\\alpha` in the paper. Defaults to 8.
|
|
|
channel_ratio (int): Reduce the channel number of fast pathway
|
|
|
by ``channel_ratio``, corresponding to :math:`\\beta` in the paper.
|
|
|
Defaults to 8.
|
|
|
slow_pathway (dict): Configuration of slow branch. Defaults to
|
|
|
``dict(type='resnet3d', depth=50, pretrained=None, lateral=True,
|
|
|
conv1_kernel=(1, 7, 7), conv1_stride_t=1, pool1_stride_t=1,
|
|
|
inflate=(0, 0, 1, 1))``.
|
|
|
fast_pathway (dict): Configuration of fast branch. Defaults to
|
|
|
``dict(type='resnet3d', depth=50, pretrained=None, lateral=False,
|
|
|
base_channels=8, conv1_kernel=(5, 7, 7), conv1_stride_t=1,
|
|
|
pool1_stride_t=1)``.
|
|
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
|
|
Defaults to None.
|
|
|
"""
|
|
|
|
|
|
def __init__(self,
|
|
|
pretrained: Optional[str] = None,
|
|
|
resample_rate: int = 8,
|
|
|
speed_ratio: int = 8,
|
|
|
channel_ratio: int = 8,
|
|
|
slow_pathway: Dict = dict(
|
|
|
type='resnet3d',
|
|
|
depth=50,
|
|
|
pretrained=None,
|
|
|
lateral=True,
|
|
|
conv1_kernel=(1, 7, 7),
|
|
|
conv1_stride_t=1,
|
|
|
pool1_stride_t=1,
|
|
|
inflate=(0, 0, 1, 1)),
|
|
|
fast_pathway: Dict = dict(
|
|
|
type='resnet3d',
|
|
|
depth=50,
|
|
|
pretrained=None,
|
|
|
lateral=False,
|
|
|
base_channels=8,
|
|
|
conv1_kernel=(5, 7, 7),
|
|
|
conv1_stride_t=1,
|
|
|
pool1_stride_t=1),
|
|
|
init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
|
|
|
super().__init__(init_cfg=init_cfg)
|
|
|
self.pretrained = pretrained
|
|
|
self.resample_rate = resample_rate
|
|
|
self.speed_ratio = speed_ratio
|
|
|
self.channel_ratio = channel_ratio
|
|
|
|
|
|
if slow_pathway['lateral']:
|
|
|
slow_pathway['speed_ratio'] = speed_ratio
|
|
|
slow_pathway['channel_ratio'] = channel_ratio
|
|
|
|
|
|
self.slow_path = build_pathway(slow_pathway)
|
|
|
self.fast_path = build_pathway(fast_pathway)
|
|
|
|
|
|
def init_weights(self, pretrained: Optional[str] = None) -> None:
|
|
|
"""Initiate the parameters either from existing checkpoint or from
|
|
|
scratch."""
|
|
|
if pretrained:
|
|
|
self.pretrained = pretrained
|
|
|
|
|
|
if isinstance(self.pretrained, str):
|
|
|
logger = MMLogger.get_current_instance()
|
|
|
msg = f'load model from: {self.pretrained}'
|
|
|
print_log(msg, logger=logger)
|
|
|
|
|
|
load_checkpoint(self, self.pretrained, strict=True, logger=logger)
|
|
|
elif self.pretrained is None:
|
|
|
|
|
|
self.fast_path.init_weights()
|
|
|
self.slow_path.init_weights()
|
|
|
else:
|
|
|
raise TypeError('pretrained must be a str or None')
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> tuple:
|
|
|
"""Defines the computation performed at every call.
|
|
|
|
|
|
Args:
|
|
|
x (torch.Tensor): The input data.
|
|
|
|
|
|
Returns:
|
|
|
tuple[torch.Tensor]: The feature of the input samples
|
|
|
extracted by the backbone.
|
|
|
"""
|
|
|
x_slow = nn.functional.interpolate(
|
|
|
x,
|
|
|
mode='nearest',
|
|
|
scale_factor=(1.0 / self.resample_rate, 1.0, 1.0))
|
|
|
x_slow = self.slow_path.conv1(x_slow)
|
|
|
x_slow = self.slow_path.maxpool(x_slow)
|
|
|
|
|
|
x_fast = nn.functional.interpolate(
|
|
|
x,
|
|
|
mode='nearest',
|
|
|
scale_factor=(1.0 / (self.resample_rate // self.speed_ratio), 1.0,
|
|
|
1.0))
|
|
|
x_fast = self.fast_path.conv1(x_fast)
|
|
|
x_fast = self.fast_path.maxpool(x_fast)
|
|
|
|
|
|
if self.slow_path.lateral:
|
|
|
x_fast_lateral = self.slow_path.conv1_lateral(x_fast)
|
|
|
x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)
|
|
|
|
|
|
for i, layer_name in enumerate(self.slow_path.res_layers):
|
|
|
res_layer = getattr(self.slow_path, layer_name)
|
|
|
x_slow = res_layer(x_slow)
|
|
|
res_layer_fast = getattr(self.fast_path, layer_name)
|
|
|
x_fast = res_layer_fast(x_fast)
|
|
|
if (i != len(self.slow_path.res_layers) - 1
|
|
|
and self.slow_path.lateral):
|
|
|
|
|
|
lateral_name = self.slow_path.lateral_connections[i]
|
|
|
conv_lateral = getattr(self.slow_path, lateral_name)
|
|
|
x_fast_lateral = conv_lateral(x_fast)
|
|
|
x_slow = torch.cat((x_slow, x_fast_lateral), dim=1)
|
|
|
|
|
|
out = (x_slow, x_fast)
|
|
|
|
|
|
return out
|
|
|
|