import torch
import torch.nn as nn

from mmaction.registry import MODELS
from .resnet_tsm import ResNetTSM


def linear_sampler(data, offset):
    """Differentiable Temporal-wise Frame Sampling, which is essentially a
    linear interpolation process.

    It takes a feature map that has been split into several groups and
    shifts each group by a different offset according to its group.
    Then it computes a weighted sum along the temporal dimension.

    Args:
        data (torch.Tensor): Split data for a certain group in shape
            [N, num_segments, C, H, W].
        offset (torch.Tensor): Data offsets for this group data in shape
            [N, num_segments].

    Returns:
        torch.Tensor: Sampled data in shape [N, num_segments, C, H, W].
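
    Example::

        >>> # A minimal sketch with illustrative sizes. `linear_sampler`
        >>> # imports `tin_shift` from mmcv.ops, so a CUDA build of mmcv
        >>> # is assumed here and the tensors live on the GPU.
        >>> data = torch.randn(2, 8, 16, 4, 4).cuda()
        >>> offset = (torch.rand(2, 8) * 2).cuda()
        >>> out = linear_sampler(data, offset)
        >>> out.shape
        torch.Size([2, 8, 16, 4, 4])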
    """
    n, t, c, h, w = data.shape

    # offset0 / offset1: the floor and ceil integer offsets around each
    # fractional offset, both in shape [N, num_segments]
    offset0 = torch.floor(offset).int()
    offset1 = offset0 + 1

    # data: [N, num_segments, C, H * W]
    data = data.view(n, t, c, h * w).contiguous()
    try:
        from mmcv.ops import tin_shift
    except (ImportError, ModuleNotFoundError):
        raise ImportError('Failed to import `tin_shift` from `mmcv.ops`. '
                          'You will be unable to use TIN.')
    # shift the data by the two integer offsets
    data0 = tin_shift(data, offset0)
    data1 = tin_shift(data, offset1)

    # linear interpolation weights, in shape [N, num_segments]
    weight0 = 1 - (offset - offset0.float())
    weight1 = 1 - weight0
    # repeat the weights along the channel dimension:
    # [N, num_segments] -> [N, num_segments, C // group_size] -> [N, C]
    group_size = offset.shape[1]
    weight0 = weight0[:, :, None].repeat(1, 1, c // group_size)
    weight0 = weight0.view(weight0.size(0), -1)
    weight1 = weight1[:, :, None].repeat(1, 1, c // group_size)
    weight1 = weight1.view(weight1.size(0), -1)
    # [N, C] -> [N, 1, C, 1]
    weight0 = weight0[:, None, :, None]
    weight1 = weight1[:, None, :, None]

    # weighted sum of the two shifted feature maps, then restore the
    # spatial dimensions: [N, num_segments, C, H, W]
    output = weight0 * data0 + weight1 * data1
    output = output.view(n, t, c, h, w)

    return output


class CombineNet(nn.Module):
    """Combine Net.

    It combines a temporal interlace module with part of a ResNet layer.

    Args:
        net1 (nn.Module): The temporal interlace module.
        net2 (nn.Module): Part of a ResNet layer.
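
    Example::

        >>> # a hypothetical wrapping, mirroring how
        >>> # `ResNetTIN.make_temporal_interlace` combines the two nets
        >>> ti = TemporalInterlace(64, num_segments=8, shift_div=4)
        >>> conv = nn.Conv2d(64, 64, 1)
        >>> combined = CombineNet(ti, conv)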
    """

    def __init__(self, net1, net2):
        super().__init__()
        self.net1 = net1
        self.net2 = net2

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # apply the temporal interlace module first,
        # then the wrapped ResNet layer
        x = self.net1(x)
        x = self.net2(x)
        return x


class WeightNet(nn.Module):
    """WeightNet in Temporal interlace module.

    The WeightNet consists of two parts: one convolution layer and a
    sigmoid function. Following the convolution layer, the sigmoid
    function and a rescaling step scale the output to the range (0, 2).
    The initial bias of the convolution layer is set to 0, so the
    initial output of the module is 2 * sigmoid(0) = 1.0.

    Args:
        in_channels (int): Channel num of input features.
        groups (int): Number of groups for conv layer outputs.
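
    Example::

        >>> # a minimal shape sketch with illustrative sizes
        >>> weight_net = WeightNet(in_channels=64, groups=2)
        >>> x = torch.randn(4, 64, 8)       # [N, C, T]
        >>> weight_net(x).shape             # [N, T, groups]
        torch.Size([4, 8, 2])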
    """

    def __init__(self, in_channels, groups):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.groups = groups

        self.conv = nn.Conv1d(in_channels, groups, 3, padding=1)

        self.init_weights()

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        # zero the conv bias so that the initial post-sigmoid output
        # of the module is 2 * sigmoid(0) = 1.0
        self.conv.bias.data[...] = 0

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # x: [N, C, T]
        n, _, t = x.shape

        # x: [N, groups, T]
        x = self.conv(x)
        x = x.view(n, self.groups, t)

        # x: [N, T, groups]
        x = x.permute(0, 2, 1)

        # rescale the sigmoid output to the range (0, 2)
        x = 2 * self.sigmoid(x)

        return x


class OffsetNet(nn.Module):
    """OffsetNet in Temporal interlace module.

    The OffsetNet consists of one convolution layer and two fc layers
    with a relu activation in between, followed by a sigmoid function.
    The convolution layer is applied first; its output is then passed
    through the two fc layers and the relu. Finally, the sigmoid output
    is shifted by -0.5 and multiplied by 4, which transforms the output
    to the range (-2, 2).

    Args:
        in_channels (int): Channel num of input features.
        groups (int): Number of groups for fc layer outputs.
        num_segments (int): Number of frame segments.
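
    Example::

        >>> # a minimal shape sketch with illustrative sizes
        >>> offset_net = OffsetNet(in_channels=64, groups=2, num_segments=8)
        >>> x = torch.randn(4, 64, 8)       # [N, C, T]
        >>> offset_net(x).shape             # [N, 1, groups]
        torch.Size([4, 1, 2])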
    """

    def __init__(self, in_channels, groups, num_segments):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

        kernel_size = 3
        padding = 1

        self.conv = nn.Conv1d(in_channels, 1, kernel_size, padding=padding)
        self.fc1 = nn.Linear(num_segments, num_segments)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(num_segments, groups)

        self.init_weights()

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        # the bias 0.5108 gives sigmoid(0.5108) ~= 0.625, so the initial
        # offset is roughly 4 * (0.625 - 0.5) = 0.5
        self.fc2.bias.data[...] = 0.5108

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # x: [N, C, T]
        n, _, t = x.shape

        # x: [N, 1, T] -> [N, T]
        x = self.conv(x)
        x = x.view(n, t)

        # x: [N, T] -> [N, T] -> [N, groups]
        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        # x: [N, 1, groups]
        x = x.view(n, 1, -1)

        # rescale the sigmoid output to the range (-2, 2)
        x = 4 * (self.sigmoid(x) - 0.5)

        return x


class TemporalInterlace(nn.Module):
    """Temporal interlace module.

    This module is proposed in `Temporal Interlacing Network
    <https://arxiv.org/abs/2001.06499>`_

    Args:
        in_channels (int): Channel num of input features.
        num_segments (int): Number of frame segments. Default: 3.
        shift_div (int): Number of division parts for shift. Default: 1.
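
    Example::

        >>> # a minimal shape sketch with illustrative sizes; the module
        >>> # relies on `tin_shift` from mmcv.ops, so a CUDA build of
        >>> # mmcv is assumed here
        >>> ti = TemporalInterlace(64, num_segments=8, shift_div=4).cuda()
        >>> x = torch.randn(16, 64, 7, 7).cuda()  # [N * T, C, H, W]
        >>> ti(x).shape
        torch.Size([16, 64, 7, 7])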
    """

    def __init__(self, in_channels, num_segments=3, shift_div=1):
        super().__init__()
        self.num_segments = num_segments
        self.shift_div = shift_div
        self.in_channels = in_channels

        # the number of deformable groups is fixed to 2
        self.deform_groups = 2

        self.offset_net = OffsetNet(in_channels // shift_div,
                                    self.deform_groups, num_segments)
        self.weight_net = WeightNet(in_channels // shift_div,
                                    self.deform_groups)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # x: [N, C, H, W], where N = num_batches * num_segments
        n, c, h, w = x.size()
        num_batches = n // self.num_segments
        num_folds = c // self.shift_div

        # x_out: [N, C, H, W]
        x_out = torch.zeros((n, c, h, w), device=x.device)
        # x_descriptor: [num_batches, num_segments, num_folds, H, W]
        x_descriptor = x[:, :num_folds, :, :].view(num_batches,
                                                   self.num_segments,
                                                   num_folds, h, w)

        # pool over the spatial dimensions so that the descriptor only
        # carries temporal and channel information
        # x_pooled: [num_batches, num_segments, num_folds]
        x_pooled = torch.mean(x_descriptor, 3)
        x_pooled = torch.mean(x_pooled, 3)
        # x_pooled: [num_batches, num_folds, num_segments]
        x_pooled = x_pooled.permute(0, 2, 1).contiguous()

        # x_offset: [num_batches, deform_groups]
        x_offset = self.offset_net(x_pooled).view(num_batches, -1)
        # x_weight: [num_batches, num_segments, deform_groups]
        x_weight = self.weight_net(x_pooled)

        # x_offset: [num_batches, 2 * deform_groups]
        x_offset = torch.cat([x_offset, -x_offset], 1)
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = linear_sampler(x_descriptor, x_offset)

        # x_weight: [num_batches, num_segments, deform_groups, 1]
        x_weight = x_weight[:, :, :, None]
        # x_weight:
        # [num_batches, num_segments, 2 * deform_groups, num_folds // 4]
        x_weight = x_weight.repeat(1, 1, 2, num_folds // 2 // 2)
        # x_weight: [num_batches, num_segments, num_folds]
        x_weight = x_weight.view(x_weight.size(0), x_weight.size(1), -1)

        # x_weight: [num_batches, num_segments, num_folds, 1, 1]
        x_weight = x_weight[:, :, :, None, None]
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = x_shift * x_weight
        # x_shift: [N, num_folds, H, W]
        x_shift = x_shift.contiguous().view(n, num_folds, h, w)

        # only the first num_folds channels are shifted;
        # the remaining channels are passed through unchanged
        x_out[:, :num_folds, :] = x_shift
        x_out[:, num_folds:, :] = x[:, num_folds:, :]

        return x_out


@MODELS.register_module()
class ResNetTIN(ResNetTSM):
    """ResNet backbone for TIN.

    Args:
        depth (int): Depth of ResNet, from {18, 34, 50, 101, 152}.
        num_segments (int): Number of frame segments. Default: 8.
        is_tin (bool): Whether to apply temporal interlace. Default: True.
        shift_div (int): Number of division parts for shift. Default: 4.
        kwargs (dict, optional): Arguments for ResNet.
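
    Example::

        >>> # a hypothetical construction; `num_segments` and `shift_div`
        >>> # are forwarded to ResNetTSM through ``**kwargs``
        >>> backbone = ResNetTIN(depth=50, num_segments=8, shift_div=4)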
    """

    def __init__(self, depth, is_tin=True, **kwargs):
        self.is_tin = is_tin
        super().__init__(depth, **kwargs)

    def init_structure(self):
        """Initialize the structure of ResNetTIN."""
        if self.is_tin:
            self.make_temporal_interlace()
        if len(self.non_local_cfg) != 0:
            self.make_non_local()

    def _get_wrap_prefix(self):
        return ['.net2']

    def make_temporal_interlace(self):
        """Make temporal interlace for some layers."""
        num_segment_list = [self.num_segments] * 4
        assert num_segment_list[-1] > 0

        # `n_round` controls how often blocks are wrapped: a block is
        # wrapped only when its index is divisible by `n_round`
        n_round = 1
        if len(list(self.layer3.children())) >= 23:
            print(f'=> Using n_round {n_round} to insert temporal shift.')

        def make_block_interlace(stage, num_segments, shift_div):
            """Apply Deformable shift for a ResNet layer module.

            Args:
                stage (nn.Module): A ResNet layer to be deformed.
                num_segments (int): Number of frame segments.
                shift_div (int): Number of division parts for shift.

            Returns:
                nn.Sequential: A Sequential container consisting of
                    deformed Interlace blocks.
            """
            blocks = list(stage.children())
            for i, b in enumerate(blocks):
                if i % n_round == 0:
                    tds = TemporalInterlace(
                        b.conv1.in_channels,
                        num_segments=num_segments,
                        shift_div=shift_div)
                    blocks[i].conv1.conv = CombineNet(tds,
                                                      blocks[i].conv1.conv)
            return nn.Sequential(*blocks)

        self.layer1 = make_block_interlace(self.layer1, num_segment_list[0],
                                           self.shift_div)
        self.layer2 = make_block_interlace(self.layer2, num_segment_list[1],
                                           self.shift_div)
        self.layer3 = make_block_interlace(self.layer3, num_segment_list[2],
                                           self.shift_div)
        self.layer4 = make_block_interlace(self.layer4, num_segment_list[3],
                                           self.shift_div)