File size: 12,553 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmengine import ConfigDict
from mmengine.logging import MMLogger
from mmengine.model.weight_init import kaiming_init, trunc_normal_
from mmengine.runner.checkpoint import _load_checkpoint, load_state_dict
from torch.nn.modules.utils import _pair
from mmaction.registry import MODELS
class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Every frame of the input clip is split into non-overlapping patches,
    and each patch is projected to ``embed_dims`` channels. The projection
    is a convolution whose kernel size and stride both equal ``patch_size``.

    Args:
        img_size (int | tuple): Size of input image.
        patch_size (int | tuple): Size of one patch.
        in_channels (int): Channel num of input features. Defaults to 3.
        embed_dims (int): Dimensions of embedding. Defaults to 768.
        conv_cfg (dict | None): Config dict for convolution layer. ``None``
            means ``dict(type='Conv2d')``. Defaults to None.
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=None):
        super().__init__()
        # Avoid a shared mutable default; None keeps the historical Conv2d.
        if conv_cfg is None:
            conv_cfg = dict(type='Conv2d')
        self.img_size = _pair(img_size)
        self.patch_size = _pair(patch_size)

        num_patches = (self.img_size[1] // self.patch_size[1]) * (
            self.img_size[0] // self.patch_size[0])
        # The product equals H*W only when H and W are each divisible by the
        # matching patch dimension, so this rejects partial patches.
        assert num_patches * self.patch_size[0] * self.patch_size[1] == \
            self.img_size[0] * self.img_size[1], \
            'The image size H*W must be divisible by patch size'
        self.num_patches = num_patches

        # Use conv layer to embed
        self.projection = build_conv_layer(
            conv_cfg,
            in_channels,
            embed_dims,
            kernel_size=patch_size,
            stride=patch_size)

        self.init_weights()

    def init_weights(self):
        """Initialize weights."""
        # Lecun norm from ClassyVision
        kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input clip, laid out as ``(B, C, T, H, W)``.

        Returns:
            Tensor: Patch embeddings of shape ``(B * T, H' * W', D)``, where
            ``D`` is the projection's output channels (``embed_dims`` for the
            default Conv2d) and ``H' * W'`` equals ``num_patches`` when the
            frame size matches ``img_size``. The frames of one clip are
            contiguous along the first axis.
        """
        # Fold time into the batch axis so each frame is embedded on its own.
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        # (B*T, D, H', W') -> (B*T, D, H'*W') -> (B*T, H'*W', D)
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
@MODELS.register_module()
class TimeSformer(nn.Module):
    """TimeSformer. A PyTorch impl of `Is Space-Time Attention All You Need for
    Video Understanding? <https://arxiv.org/abs/2102.05095>`_

    Args:
        num_frames (int): Number of frames in the video.
        img_size (int | tuple): Size of input image.
        patch_size (int | tuple): Size of one patch.
        pretrained (str | None): Checkpoint loaded by ``init_weights`` (any
            value accepted by ``_load_checkpoint``). Default: None.
        embed_dims (int): Dimensions of embedding. Defaults to 768.
        num_heads (int): Number of parallel attention heads in the default
            transformer layers. Defaults to 12.
        num_transformer_layers (int): Number of transformer layers. Defaults to
            12.
        in_channels (int): Channel num of input features. Defaults to 3.
        dropout_ratio (float): Probability of the dropout applied after the
            positional and time embeddings. Defaults to 0.
        transformer_layers (dict | list | None): Config passed directly to
            ``build_transformer_layer_sequence``. If None, a
            ``TransformerLayerSequence`` of ``num_transformer_layers`` layers
            matching ``attention_type`` is built. Defaults to None.
            NOTE(review): a list is admitted by the assertion; confirm that
            ``build_transformer_layer_sequence`` accepts one.
        attention_type (str): Type of attentions in the transformer. Choices
            are 'divided_space_time', 'space_only' and 'joint_space_time'.
            Defaults to 'divided_space_time'.
        norm_cfg (dict | None): Config for the norm applied to the output
            tokens. ``None`` means ``dict(type='LN', eps=1e-6)``. The norms
            inside the default transformer layers are always
            ``dict(type='LN', eps=1e-6)``. Defaults to None.
    """
    supported_attention_types = [
        'divided_space_time', 'space_only', 'joint_space_time'
    ]

    def __init__(self,
                 num_frames,
                 img_size,
                 patch_size,
                 pretrained=None,
                 embed_dims=768,
                 num_heads=12,
                 num_transformer_layers=12,
                 in_channels=3,
                 dropout_ratio=0.,
                 transformer_layers=None,
                 attention_type='divided_space_time',
                 norm_cfg=None,
                 **kwargs):
        super().__init__(**kwargs)
        assert attention_type in self.supported_attention_types, (
            f'Unsupported Attention Type {attention_type}!')
        assert transformer_layers is None or isinstance(
            transformer_layers, (dict, list))
        # Avoid a shared mutable default; None keeps the historical LN config.
        if norm_cfg is None:
            norm_cfg = dict(type='LN', eps=1e-6)

        self.num_frames = num_frames
        self.pretrained = pretrained
        self.embed_dims = embed_dims
        self.num_transformer_layers = num_transformer_layers
        self.attention_type = attention_type
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims)
        num_patches = self.patch_embed.num_patches

        # One learnable cls token plus a spatial position embedding shared by
        # every frame (the extra slot is for the cls token).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
        self.drop_after_pos = nn.Dropout(p=dropout_ratio)
        if self.attention_type != 'space_only':
            # Per-frame embedding; forward adds it along the time axis.
            self.time_embed = nn.Parameter(
                torch.zeros(1, num_frames, embed_dims))
            self.drop_after_time = nn.Dropout(p=dropout_ratio)

        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        if transformer_layers is None:
            # stochastic depth decay rule: drop-path rate grows linearly from
            # 0 to 0.1 across the layers.
            dpr = np.linspace(0, 0.1, num_transformer_layers)

            if self.attention_type == 'divided_space_time':
                # Each layer: temporal attention, spatial attention, then FFN.
                _transformerlayers_cfg = [
                    dict(
                        type='BaseTransformerLayer',
                        attn_cfgs=[
                            dict(
                                type='DividedTemporalAttentionWithNorm',
                                embed_dims=embed_dims,
                                num_heads=num_heads,
                                num_frames=num_frames,
                                dropout_layer=dict(
                                    type='DropPath', drop_prob=dpr[i]),
                                norm_cfg=dict(type='LN', eps=1e-6)),
                            dict(
                                type='DividedSpatialAttentionWithNorm',
                                embed_dims=embed_dims,
                                num_heads=num_heads,
                                num_frames=num_frames,
                                dropout_layer=dict(
                                    type='DropPath', drop_prob=dpr[i]),
                                norm_cfg=dict(type='LN', eps=1e-6))
                        ],
                        ffn_cfgs=dict(
                            type='FFNWithNorm',
                            embed_dims=embed_dims,
                            feedforward_channels=embed_dims * 4,
                            num_fcs=2,
                            act_cfg=dict(type='GELU'),
                            dropout_layer=dict(
                                type='DropPath', drop_prob=dpr[i]),
                            norm_cfg=dict(type='LN', eps=1e-6)),
                        operation_order=('self_attn', 'self_attn', 'ffn'))
                    for i in range(num_transformer_layers)
                ]
            else:
                # Space Only & Joint Space Time: a standard pre-norm layer.
                _transformerlayers_cfg = [
                    dict(
                        type='BaseTransformerLayer',
                        attn_cfgs=[
                            dict(
                                type='MultiheadAttention',
                                embed_dims=embed_dims,
                                num_heads=num_heads,
                                batch_first=True,
                                dropout_layer=dict(
                                    type='DropPath', drop_prob=dpr[i]))
                        ],
                        ffn_cfgs=dict(
                            type='FFN',
                            embed_dims=embed_dims,
                            feedforward_channels=embed_dims * 4,
                            num_fcs=2,
                            act_cfg=dict(type='GELU'),
                            dropout_layer=dict(
                                type='DropPath', drop_prob=dpr[i])),
                        operation_order=('norm', 'self_attn', 'norm', 'ffn'),
                        norm_cfg=dict(type='LN', eps=1e-6),
                        batch_first=True)
                    for i in range(num_transformer_layers)
                ]

            transformer_layers = ConfigDict(
                dict(
                    type='TransformerLayerSequence',
                    transformerlayers=_transformerlayers_cfg,
                    num_layers=num_transformer_layers))

        self.transformer_layers = build_transformer_layer_sequence(
            transformer_layers)

    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Args:
            pretrained (str | None): When truthy, replaces ``self.pretrained``
                before loading. Defaults to None.
        """
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            logger = MMLogger.get_current_instance()
            logger.info(f'load model from: {self.pretrained}')

            state_dict = _load_checkpoint(self.pretrained, map_location='cpu')
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            if self.attention_type == 'divided_space_time':
                # modify the key names of norm layers: the layer-level
                # ``norms.0`` / ``norms.1`` become the norms owned by the first
                # attention and by the FFN.
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'norms' in old_key:
                        new_key = old_key.replace('norms.0',
                                                  'attentions.0.norm')
                        new_key = new_key.replace('norms.1', 'ffns.0.norm')
                        state_dict[new_key] = state_dict.pop(old_key)

                # Duplicate every ``attentions.0`` entry into ``attentions.1``
                # so both attentions of a layer start from the checkpoint's
                # single attention (presumably a ViT-style checkpoint).
                # Entries the checkpoint already provides for
                # ``attentions.1`` are kept rather than overwritten.
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'attentions.0' in old_key:
                        new_key = old_key.replace('attentions.0',
                                                  'attentions.1')
                        if new_key not in state_dict:
                            state_dict[new_key] = state_dict[old_key].clone()

            load_state_dict(self, state_dict, strict=False, logger=logger)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (Tensor): Input clip, ``(B, C, T, H, W)`` as consumed by
                ``PatchEmbed``.

        Returns:
            Tensor: The normalized class token of each clip,
            ``(B, embed_dims)``.
        """
        batches = x.shape[0]
        # x [batch_size * num_frames, num_patches, embed_dims]
        x = self.patch_embed(x)

        # x [batch_size * num_frames, num_patches + 1, embed_dims]
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        # Add Time Embedding
        if self.attention_type != 'space_only':
            # Rows are ordered (b t), so striding by the frames-per-clip count
            # selects frame 0 of each clip: [batch_size, 1, embed_dims].
            cls_tokens = x[::x.size(0) // batches, :1, :]
            x = rearrange(x[:, 1:, :], '(b t) p m -> (b p) t m', b=batches)
            x = x + self.time_embed
            x = self.drop_after_time(x)
            x = rearrange(x, '(b p) t m -> b (p t) m', b=batches)
            # x [batch_size, num_patches * num_frames + 1, embed_dims]
            x = torch.cat((cls_tokens, x), dim=1)

        x = self.transformer_layers(x, None, None)

        if self.attention_type == 'space_only':
            # Average the per-frame tokens over time.
            # x [batch_size, num_patches + 1, embed_dims]
            x = x.view(-1, self.num_frames, *x.size()[-2:])
            x = torch.mean(x, 1)

        x = self.norm(x)
        # Return Class Token
        return x[:, 0]
|