|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from mmengine.model import BaseModel, BaseModule
|
|
|
from mmengine.runner import CheckpointLoader
|
|
|
|
|
|
from mmaction.registry import MODELS
|
|
|
from mmaction.utils import OptConfigType
|
|
|
|
|
|
|
|
|
def batch_norm(inputs: torch.Tensor,
               module: nn.modules.batchnorm._BatchNorm,
               training: Optional[bool] = None) -> torch.Tensor:
    """Apply batch normalization to ``inputs`` with the params of ``module``.

    ``F.batch_norm`` only requires the channel dimension (dim 1) to match
    the module, so this lets e.g. a ``BatchNorm3d`` module normalize 4D
    (BCHW) tensors.

    In training mode the statistics of the current batch are used. The
    module's running statistics are never passed to ``F.batch_norm`` in that
    case, so they are NOT updated. In evaluation mode the running statistics
    are used when the module tracks them. Otherwise (``track_running_stats=
    False``) the batch statistics are used, mirroring ``nn.BatchNorm``.

    Args:
        inputs (Tensor): The input data, with channels on dim 1.
        module (_BatchNorm): A batch normalization module. Its weight, bias,
            running statistics and eps are used for the operation.
        training (bool, optional): If True, normalize with batch statistics.
            Defaults to None, which uses the training mode of the module.

    Returns:
        Tensor: The normalized tensor, with the same shape as ``inputs``.
    """
    if training is None:
        training = module.training

    running_mean = module.running_mean
    running_var = module.running_var
    # Modules built with track_running_stats=False have no running buffers;
    # fall back to batch statistics as nn.BatchNorm does in eval mode.
    if running_mean is None or running_var is None:
        training = True
    if training:
        # Withhold the running buffers so data flowing through this path
        # does not update them.
        running_mean = running_var = None

    # momentum only affects running-stat updates, which never happen here.
    # nn.BatchNorm allows momentum=None (cumulative averaging) but
    # F.batch_norm requires a float, so substitute a harmless 0.0.
    momentum = 0.0 if module.momentum is None else module.momentum

    return F.batch_norm(
        input=inputs,
        running_mean=running_mean,
        running_var=running_var,
        weight=module.weight,
        bias=module.bias,
        training=training,
        momentum=momentum,
        eps=module.eps)
|
|
|
|
|
|
|
|
|
class BottleNeck(BaseModule):
    """Bottleneck residual block for Omni-ResNet.

    The block holds 3D convolutions and accepts both video clips (BCTHW)
    and images (BCHW). For images the 3D kernels are converted to 2D
    kernels on the fly (see ``forward_2d``).

    Args:
        inplanes (int): Number of channels of the block input.
        planes (int): Number of channels produced by the first and second
            conv layers. The block outputs ``planes * 4`` channels.
        temporal_kernel (int): Temporal kernel size of the first conv layer.
            Should be either 1 or 3. Defaults to 3.
        spatial_stride (int): Spatial stride of the second conv layer and of
            the downsample shortcut. Defaults to 1.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 temporal_kernel: int = 3,
                 spatial_stride: int = 1,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(BottleNeck, self).__init__(init_cfg=init_cfg)
        assert temporal_kernel in [1, 3]

        # Channel reduction; temporal padding of kernel // 2 keeps T unchanged.
        self.conv1 = nn.Conv3d(
            inplanes,
            planes,
            kernel_size=(temporal_kernel, 1, 1),
            padding=(temporal_kernel // 2, 0, 0),
            bias=False)
        # Purely spatial 3x3 conv carrying the block's spatial stride.
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            stride=(1, spatial_stride, spatial_stride),
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
            bias=False)
        # Channel expansion to planes * 4.
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm3d(planes, momentum=0.01)
        self.bn2 = nn.BatchNorm3d(planes, momentum=0.01)
        self.bn3 = nn.BatchNorm3d(planes * 4, momentum=0.01)

        # Projection shortcut, only created when the identity cannot be added
        # directly (channel or spatial-size mismatch). The forward passes
        # detect it with hasattr().
        if inplanes != planes * 4 or spatial_stride != 1:
            downsample = [
                nn.Conv3d(
                    inplanes,
                    planes * 4,
                    kernel_size=1,
                    stride=(1, spatial_stride, spatial_stride),
                    bias=False),
                nn.BatchNorm3d(planes * 4, momentum=0.01)
            ]
            self.downsample = nn.Sequential(*downsample)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Accept both 3D (BCTHW for videos) and 2D (BCHW for images) tensors.
        4D inputs are routed to ``forward_2d``.
        """
        if x.ndim == 4:
            return self.forward_2d(x)

        out = self.conv1(x)
        out = self.bn1(out).relu_()

        out = self.conv2(out)
        out = self.bn2(out).relu_()

        out = self.conv3(out)
        out = self.bn3(out)

        if hasattr(self, 'downsample'):
            x = self.downsample(x)

        return out.add_(x).relu_()

    def forward_2d(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call for 2D (BCHW) tensors using the 3D block weights."""
        # conv1 may have a temporal extent of 3: summing its temporal taps
        # gives the response of the 3D conv to a clip that is constant along
        # time (ignoring temporal zero-padding).
        out = F.conv2d(x, self.conv1.weight.sum(2))
        out = batch_norm(out, self.bn1).relu_()

        # conv2, conv3 and the downsample conv have a temporal extent of 1,
        # so dropping that axis yields the equivalent 2D kernel.
        out = F.conv2d(
            out,
            self.conv2.weight.squeeze(2),
            stride=self.conv2.stride[-1],
            padding=self.conv2.padding[-1])
        out = batch_norm(out, self.bn2).relu_()

        out = F.conv2d(out, self.conv3.weight.squeeze(2))
        out = batch_norm(out, self.bn3)

        if hasattr(self, 'downsample'):
            x = F.conv2d(
                x,
                self.downsample[0].weight.squeeze(2),
                stride=self.downsample[0].stride[-1])
            x = batch_norm(x, self.downsample[1])

        return out.add_(x).relu_()
|
|
|
|
|
|
|
|
|
@MODELS.register_module()
class OmniResNet(BaseModel):
    """Omni-ResNet that accepts both image and video inputs.

    Videos are given as 5D tensors (BCTHW) and images as 4D tensors (BCHW).
    Both share the same 3D weights; for images the 3D kernels are converted
    to 2D kernels on the fly.

    Args:
        layers (List[int], optional): Number of bottleneck blocks in each of
            the four residual stages. Defaults to None, which means
            [3, 4, 6, 3].
        pretrain_2d (str, optional): Path to the 2D pretraining checkpoint.
            If given, weights are inflated from it at construction time.
            Defaults to None.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 layers: Optional[List[int]] = None,
                 pretrain_2d: Optional[str] = None,
                 init_cfg: OptConfigType = None) -> None:
        super(OmniResNet, self).__init__(init_cfg=init_cfg)
        # Avoid a shared mutable default argument.
        if layers is None:
            layers = [3, 4, 6, 3]

        self.inplanes = 64
        self.conv1 = nn.Conv3d(
            3,
            self.inplanes,
            kernel_size=(1, 7, 7),
            stride=(1, 2, 2),
            padding=(0, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(self.inplanes, momentum=0.01)

        # Equivalent spatial max-pooling for video and image inputs.
        self.pool3d = nn.MaxPool3d((1, 3, 3), (1, 2, 2), (0, 1, 1))
        self.pool2d = nn.MaxPool2d(3, 2, 1)

        # _make_layer reads self.temporal_kernel: the first two stages use a
        # temporal kernel of 1, the last two a temporal kernel of 3.
        self.temporal_kernel = 1
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], stride=2)
        self.temporal_kernel = 3
        self.layer3 = self._make_layer(256, layers[2], stride=2)
        self.layer4 = self._make_layer(512, layers[3], stride=2)

        if pretrain_2d is not None:
            self.init_from_2d(pretrain_2d)

    def _make_layer(self,
                    planes: int,
                    num_blocks: int,
                    stride: int = 1) -> nn.Module:
        """Build one residual stage of ``num_blocks`` bottleneck blocks.

        Only the first block receives ``stride``. Updates ``self.inplanes``
        to ``planes * 4`` so that the next stage is wired correctly, and
        uses the current ``self.temporal_kernel`` for every block.
        """
        layers = [
            BottleNeck(
                self.inplanes,
                planes,
                spatial_stride=stride,
                temporal_kernel=self.temporal_kernel)
        ]
        self.inplanes = planes * 4
        for _ in range(1, num_blocks):
            layers.append(
                BottleNeck(
                    self.inplanes,
                    planes,
                    temporal_kernel=self.temporal_kernel))
        return nn.Sequential(*layers)

    def init_from_2d(self, pretrain: str) -> None:
        """Initialize the model from a 2D checkpoint by weight inflation.

        Every entry of this model's state dict whose key also appears in the
        checkpoint is replaced by the checkpoint value. 4D (2D conv) weights
        are inflated to the matching 5D shape. They are repeated ``t`` times
        along the temporal axis and divided by ``t``, so the temporal taps
        sum to the original 2D kernel. Keys missing from the checkpoint keep
        their current values.

        Args:
            pretrain (str): Checkpoint location, passed to
                ``CheckpointLoader.load_checkpoint``.
        """
        checkpoint = CheckpointLoader.load_checkpoint(
            pretrain, map_location='cpu')
        # Full training checkpoints (e.g. saved by mmengine runners) nest
        # the weights under 'state_dict'; raw state dicts do not.
        param2d = checkpoint.get('state_dict', checkpoint)

        param3d = self.state_dict()
        for key in param3d:
            if key not in param2d:
                continue
            weight = param2d[key]
            if weight.ndim == 4:
                t = param3d[key].shape[2]
                weight = weight.unsqueeze(2).expand(-1, -1, t, -1, -1) / t
            param3d[key] = weight
        self.load_state_dict(param3d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Accept both 3D (BCTHW for videos) and 2D (BCHW for images) tensors.
        4D inputs are routed to ``forward_2d``.
        """
        if x.ndim == 4:
            return self.forward_2d(x)

        x = self.conv1(x)
        x = self.bn1(x).relu_()
        x = self.pool3d(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def forward_2d(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call for 2D (BCHW) tensors using the 3D stem weights."""
        # The stem conv has a temporal extent of 1, so dropping that axis
        # yields the equivalent 2D kernel.
        x = F.conv2d(
            x,
            self.conv1.weight.squeeze(2),
            stride=self.conv1.stride[-1],
            padding=self.conv1.padding[-1])
        x = batch_norm(x, self.bn1).relu_()
        x = self.pool2d(x)

        # Each BottleNeck dispatches 4D inputs to its own 2D path.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
|
|
|
|