import torch
import torch.nn as nn
from einops import rearrange
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from mmengine.utils import digit_version

from mmaction.registry import MODELS


@MODELS.register_module()
class DividedTemporalAttentionWithNorm(BaseModule):
    """Temporal Attention in Divided Space Time Attention.

    Args:
        embed_dims (int): Dimensions of embedding.
        num_heads (int): Number of parallel attention heads in
            TransformerCoder.
        num_frames (int): Number of frames in the video.
        attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
            0..
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Defaults to 0..
        dropout_layer (dict): The dropout_layer used when adding the shortcut.
            Defaults to `dict(type='DropPath', drop_prob=0.1)`.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            `dict(type='LN')`.
        init_cfg (dict | None): The Config for initialization. Defaults to
            None.
    """
    def __init__(self,
                 embed_dims,
                 num_heads,
                 num_frames,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='DropPath', drop_prob=0.1),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_frames = num_frames
        self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
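        # nn.MultiheadAttention only accepts the `batch_first` argument from
        # torch 1.9.0 onwards, so drop it from kwargs on older versions.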
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            kwargs.pop('batch_first', None)
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)
        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
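        # `temporal_fc` is zero-initialized in `init_weights`, so the temporal
        # attention branch contributes nothing right after initialization and
        # the block starts out as an identity mapping through the shortcut.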
        self.temporal_fc = nn.Linear(self.embed_dims, self.embed_dims)

        self.init_weights()

    def init_weights(self):
        """Initialize weights."""
        constant_init(self.temporal_fc, val=0, bias=0)

    def forward(self, query, key=None, value=None, residual=None, **kwargs):
        """Defines the computation performed at every call."""
        assert residual is None, (
            'Always adding the shortcut in the forward function')

        init_cls_token = query[:, 0, :].unsqueeze(1)
        identity = query_t = query[:, 1:, :]
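        # query_t: [batch_size, num_patches * num_frames, embed_dims]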
        b, pt, m = query_t.size()
        p, t = pt // self.num_frames, self.num_frames
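        # Fold the patches into the batch dimension so that attention runs
        # over the frame axis; the permutes move the frame axis to the front
        # for nn.MultiheadAttention's default (seq_len, batch, embed_dims)
        # layout and back afterwards.
        # res_temporal: [batch_size * num_patches, num_frames, embed_dims]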
        query_t = self.norm(query_t.reshape(b * p, t, m)).permute(1, 0, 2)
        res_temporal = self.attn(query_t, query_t, query_t)[0].permute(1, 0, 2)
        res_temporal = self.dropout_layer(
            self.proj_drop(res_temporal.contiguous()))
        res_temporal = self.temporal_fc(res_temporal)
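        # res_temporal: [batch_size, num_patches * num_frames, embed_dims]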
        res_temporal = res_temporal.reshape(b, p * t, m)
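        # add the shortcut and re-attach the class token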
        new_query_t = identity + res_temporal
        new_query = torch.cat((init_cls_token, new_query_t), 1)
        return new_query


@MODELS.register_module()
class DividedSpatialAttentionWithNorm(BaseModule):
    """Spatial Attention in Divided Space Time Attention.

    Args:
        embed_dims (int): Dimensions of embedding.
        num_heads (int): Number of parallel attention heads in
            TransformerCoder.
        num_frames (int): Number of frames in the video.
        attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
            0..
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Defaults to 0..
        dropout_layer (dict): The dropout_layer used when adding the shortcut.
            Defaults to `dict(type='DropPath', drop_prob=0.1)`.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            `dict(type='LN')`.
        init_cfg (dict | None): The Config for initialization. Defaults to
            None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 num_frames,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='DropPath', drop_prob=0.1),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_frames = num_frames
        self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            kwargs.pop('batch_first', None)
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)
        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        self.init_weights()

    def init_weights(self):
        """Initialize DividedSpatialAttentionWithNorm by default."""
        pass
    def forward(self, query, key=None, value=None, residual=None, **kwargs):
        """Defines the computation performed at every call."""
        assert residual is None, (
            'Always adding the shortcut in the forward function')

        identity = query
        init_cls_token = query[:, 0, :].unsqueeze(1)
        query_s = query[:, 1:, :]
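        # query_s: [batch_size, num_patches * num_frames, embed_dims]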
        b, pt, m = query_s.size()
        p, t = pt // self.num_frames, self.num_frames
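        # copy the class token to every frame;
        # cls_token: [batch_size * num_frames, 1, embed_dims]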
        cls_token = init_cls_token.repeat(1, t, 1).reshape(b * t,
                                                           m).unsqueeze(1)
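        # group the patches of each frame into one sequence and prepend the
        # per-frame class token;
        # query_s: [batch_size * num_frames, num_patches + 1, embed_dims]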
        query_s = rearrange(query_s, 'b (p t) m -> (b t) p m', p=p, t=t)
        query_s = torch.cat((cls_token, query_s), 1)
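        # spatial attention within each frame;
        # res_spatial: [batch_size * num_frames, num_patches + 1, embed_dims]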
        query_s = self.norm(query_s).permute(1, 0, 2)
        res_spatial = self.attn(query_s, query_s, query_s)[0].permute(1, 0, 2)
        res_spatial = self.dropout_layer(
            self.proj_drop(res_spatial.contiguous()))
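        # average the per-frame class tokens back into a single class token;
        # cls_token: [batch_size, 1, embed_dims]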
        cls_token = res_spatial[:, 0, :].reshape(b, t, m)
        cls_token = torch.mean(cls_token, 1, True)
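        # restore the patch tokens to the divided layout and re-attach the
        # averaged class token;
        # res_spatial: [batch_size, num_patches * num_frames + 1, embed_dims]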
        res_spatial = rearrange(
            res_spatial[:, 1:, :], '(b t) p m -> b (p t) m', p=p, t=t)
        res_spatial = torch.cat((cls_token, res_spatial), 1)
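        # residual connection with the pre-norm input `query`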
        new_query = identity + res_spatial
        return new_query


@MODELS.register_module()
class FFNWithNorm(FFN):
    """FFN with pre normalization layer.

    FFNWithNorm is implemented to be compatible with `BaseTransformerLayer`
    when using `DividedTemporalAttentionWithNorm` and
    `DividedSpatialAttentionWithNorm`.

    FFNWithNorm has one main difference from FFN:

    - It applies one normalization layer before forwarding the input data to
      feed-forward networks.

    Args:
        embed_dims (int): Dimensions of embedding. Defaults to 256.
        feedforward_channels (int): Hidden dimension of FFNs. Defaults to 1024.
        num_fcs (int, optional): Number of fully-connected layers in FFNs.
            Defaults to 2.
        act_cfg (dict): Config for activation layers.
            Defaults to `dict(type='ReLU')`.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Defaults to 0..
        add_residual (bool, optional): Whether to add the
            residual connection. Defaults to `True`.
        dropout_layer (dict | None): The dropout_layer used when adding the
            shortcut. Defaults to None.
        init_cfg (dict): The Config for initialization. Defaults to None.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            `dict(type='LN')`.
    """
    def __init__(self, *args, norm_cfg=dict(type='LN'), **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
    def forward(self, x, residual=None):
        """Defines the computation performed at every call."""
        assert residual is None, ('Cannot apply pre-norm with FFNWithNorm')
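        # pre-norm: feed the normalized input to the FFN, while the original
        # un-normalized `x` is passed as the identity for the shortcut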
        return super().forward(self.norm(x), x)