File size: 12,439 Bytes
d670799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath
from mmengine import to_2tuple
from mmaction.registry import MODELS
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features (int): Size of the last dimension of the input.
        hidden_features (int, optional): Width of the hidden layer. Falls
            back to ``in_features`` when falsy. Defaults to None.
        out_features (int, optional): Size of the last dimension of the
            output. Falls back to ``in_features`` when falsy.
            Defaults to None.
        act_layer (type): Activation class, instantiated without arguments.
            Defaults to ``nn.GELU``.
        drop (float): Dropout probability applied after the second linear
            layer. Defaults to 0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        # Unspecified (falsy) widths default to the input width.
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1, the activation, fc2 and finally dropout to ``x``."""
        # There is deliberately no dropout between the activation and fc2;
        # it is disabled to follow the original BERT implementation.
        return self.drop(self.fc2(self.act(self.fc1(x))))
class Attention(nn.Module):
    """Multi-head self-attention with optional learnable query/value biases.

    Args:
        dim (int): Size of the last dimension of the input and output.
        num_heads (int): Number of attention heads. Defaults to 8.
        qkv_bias (bool): If True, create learnable biases for the query and
            value projections; the key projection always uses a zero bias.
            Defaults to False.
        qk_scale (float, optional): Factor the queries are multiplied by.
            Falls back to ``head_dim ** -0.5`` when falsy. Defaults to None.
        attn_drop (float): Dropout probability on the attention weights.
            Defaults to 0.
        proj_drop (float): Dropout probability after the output projection.
            Defaults to 0.
        attn_head_dim (int, optional): Per-head width. When None it is
            ``dim // num_heads``. Defaults to None.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None \
            else dim // num_heads
        inner_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        # The fused projection carries no bias of its own; the bias vector
        # is assembled from q_bias / v_bias in ``forward``.
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(inner_dim))
            self.v_bias = nn.Parameter(torch.zeros(inner_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Run self-attention over the second dimension of a 3-D input."""
        batch, tokens, _ = x.shape

        if self.q_bias is None:
            fused_bias = None
        else:
            # Key bias is a constant zero block between the q and v biases.
            key_zeros = torch.zeros_like(self.v_bias, requires_grad=False)
            fused_bias = torch.cat((self.q_bias, key_zeros, self.v_bias))

        fused = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        # (batch, tokens, 3 * heads * d) -> (3, batch, heads, tokens, d)
        fused = fused.reshape(batch, tokens, 3, self.num_heads, -1)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]

        query = query * self.scale
        weights = query @ key.transpose(-2, -1)
        weights = self.attn_drop(weights.softmax(dim=-1))

        # Merge the heads back into the last dimension.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(merged))
class Block(nn.Module):
    """Transformer block: normalised attention and MLP in residual branches.

    Each branch is ``x + drop_path(branch(norm(x)))``. When ``init_values``
    is a positive number, the output of each branch is additionally
    multiplied by a learnable per-channel scale (``gamma_1`` / ``gamma_2``)
    initialised to ``init_values``.

    Args:
        dim (int): Channel dimension of the tokens.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Hidden width of the MLP as a multiple of ``dim``.
            Defaults to 4.
        qkv_bias (bool): Forwarded to :class:`Attention`. Defaults to False.
        qk_scale (float, optional): Forwarded to :class:`Attention`.
            Defaults to None.
        drop (float): Dropout probability used for the attention output
            projection and for the MLP. Defaults to 0.
        attn_drop (float): Dropout probability on the attention weights.
            Defaults to 0.
        drop_path (float): Drop-path probability; ``nn.Identity`` is used
            when it is not greater than 0. Defaults to 0.
        init_values (float, optional): Initial value of the branch scales.
            ``None`` or a non-positive value disables the scales.
            Defaults to None.
        act_layer (type): Activation class for the MLP.
            Defaults to ``nn.GELU``.
        norm_layer (callable): Factory called as ``norm_layer(dim)``.
            Defaults to ``nn.LayerNorm``.
        attn_head_dim (int, optional): Forwarded to :class:`Attention`.
            Defaults to None.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 init_values=None,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        # The default is None, and ``None > 0`` raises TypeError, so the
        # None case has to be ruled out before the comparison.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        """Apply the attention branch, then the MLP branch, to ``x``."""
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            # Scale each branch output per channel before the residual add.
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class PatchEmbed(nn.Module):
    """Turn a video into a token sequence with one strided 3-D convolution.

    Args:
        img_size (int | tuple): Expected input height/width, passed through
            ``to_2tuple``. Defaults to 224.
        patch_size (int | tuple): Spatial patch size, passed through
            ``to_2tuple``. Defaults to 16.
        in_chans (int): Number of input channels. Defaults to 3.
        embed_dim (int): Number of output channels per token.
            Defaults to 768.
        num_frames (int): Number of frames used to compute ``num_patches``.
            Defaults to 16.
        tubelet_size (int): Temporal kernel size and stride of the
            convolution. Defaults to 2.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 num_frames=16,
                 tubelet_size=2):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.tubelet_size = int(tubelet_size)

        # Token counts along width, height and time.
        grid_w = img_size[1] // patch_size[1]
        grid_h = img_size[0] // patch_size[0]
        grid_t = num_frames // self.tubelet_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = grid_w * grid_h * grid_t

        # Kernel and stride are identical, so the patches do not overlap.
        tube_shape = (self.tubelet_size, patch_size[0], patch_size[1])
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=tube_shape,
            stride=tube_shape)

    def forward(self, x):
        """Embed a 5-D input whose last two dimensions match ``img_size``.

        Returns the convolution output with everything after the channel
        dimension flattened and moved in front of the channels.
        """
        _, _, _, height, width = x.shape
        size_ok = height == self.img_size[0] and width == self.img_size[1]
        assert size_ok, \
            f"Input image size ({height}*{width}) doesn't match model " \
            f'({self.img_size[0]}*{self.img_size[1]}).'
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
# sin-cos position encoding
def get_sinusoid_encoding_table(n_position,
                                d_hid,
                                cur_frame=-1,
                                pre_n_position=1568):
    """Build a sine-cosine position table, resized to the current layout.

    The table is generated for ``pre_n_position`` positions, which are
    treated as 8 frames of a square patch grid. When ``cur_frame`` is given
    and the current layout differs from that, the table is interpolated
    bicubically over the spatial grid and/or linearly over time.

    Args:
        n_position (int): Number of positions of the current model.
        d_hid (int): Channel dimension of every position vector.
        cur_frame (int): Number of temporal positions of the current model.
            ``-1`` disables both interpolation steps. Defaults to -1.
        pre_n_position (int): Number of positions the table is generated
            for before any resizing. Defaults to 1568 (8 x 14 x 14).

    Returns:
        torch.Tensor | nn.Parameter: Float table whose first dimension is 1
        and whose last dimension is ``d_hid``. It is a plain tensor when
        ``n_position == pre_n_position``; otherwise it is wrapped in a
        trainable ``nn.Parameter``.
    """
    pretrain_frames = 8  # temporal length the generated table stands for

    # angle[pos, j] = pos / 10000^(2 * (j // 2) / d_hid), built in one
    # broadcast expression instead of a nested Python loop.
    channel_idx = np.arange(d_hid)
    divisor = np.power(10000, 2 * (channel_idx // 2) / d_hid)
    sinusoid_table = np.arange(pre_n_position)[:, None] / divisor
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    sinusoid_table = torch.tensor(
        sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)

    print(f'n_position: {n_position}')
    print(f'pre_n_position: {pre_n_position}')

    # Spatial resize: needed when the current per-frame grid differs from
    # the grid the table was generated for.
    if (cur_frame != -1 and
            n_position // cur_frame * pretrain_frames != pre_n_position):
        T = pretrain_frames  # checkpoint frame
        # Side of the generated grid, derived from pre_n_position so that
        # grids other than the default 14x14 (1568 = 8 * 14 * 14) work too.
        P = int((pre_n_position // T)**0.5)  # checkpoint size
        C = d_hid
        new_P = int((n_position // cur_frame)**0.5)  # testing size
        print(
            f'Pretraining uses {P}x{P}, but current version is '
            f'{new_P}x{new_P}')
        print('Interpolate the position embedding')
        sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
        sinusoid_table = sinusoid_table.reshape(-1, P, P,
                                                C).permute(0, 3, 1, 2)
        sinusoid_table = torch.nn.functional.interpolate(
            sinusoid_table,
            size=(new_P, new_P),
            mode='bicubic',
            align_corners=False)
        # BT, C, H, W -> BT, H, W, C ->  B, T, H, W, C
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(
            -1, T, new_P, new_P, C)
        sinusoid_table = sinusoid_table.flatten(1, 3)

    # Temporal resize: needed when the current frame count is not 8.
    if cur_frame != -1 and cur_frame != pretrain_frames:
        print(f'Pretraining uses 8 frames, but current frame is {cur_frame}')
        print('Interpolate the position embedding')
        T = pretrain_frames  # checkpoint frame
        new_T = cur_frame  # testing frame
        # By now the spatial grid already has the testing size.
        P = int((n_position // cur_frame)**0.5)  # testing size
        C = d_hid
        sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C)
        sinusoid_table = sinusoid_table.permute(0, 2, 3, 4,
                                                1).reshape(-1, C,
                                                           T)  # BHW, C, T
        sinusoid_table = torch.nn.functional.interpolate(
            sinusoid_table, size=new_T, mode='linear')
        sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(
            0, 4, 1, 2, 3)  # B, T, H, W, C
        sinusoid_table = sinusoid_table.flatten(1, 3)

    if n_position == pre_n_position:
        return sinusoid_table
    else:
        print('Use learnable position embedding')
        return nn.Parameter(sinusoid_table, requires_grad=True)
@MODELS.register_module()
class UMTViT(nn.Module):
    """Video vision transformer backbone built from :class:`Block` layers.

    Args:
        img_size (int): Input height/width. Defaults to 224.
        patch_size (int): Spatial patch size. Defaults to 16.
        in_chans (int): Number of input channels. Defaults to 3.
        embed_dim (int): Token channel dimension. Defaults to 768.
        depth (int): Number of transformer blocks. Defaults to 12.
        num_heads (int): Attention heads per block. Defaults to 12.
        mlp_ratio (float): MLP hidden width as a multiple of ``embed_dim``.
            Defaults to 4.
        qkv_bias (bool): Forwarded to every block. Defaults to False.
        qk_scale (float, optional): Forwarded to every block.
            Defaults to None.
        drop_rate (float): Dropout after the position embedding and inside
            the blocks. Defaults to 0.
        attn_drop_rate (float): Attention dropout inside the blocks.
            Defaults to 0.
        drop_path_rate (float): Final drop-path rate; block rates are
            linearly spaced from 0 to this value. Defaults to 0.
        norm_layer (callable): Factory called as ``norm_layer(embed_dim)``.
            Defaults to ``partial(nn.LayerNorm, eps=1e-6)``.
        init_values (float): Forwarded to every block. Defaults to 0.
        use_learnable_pos_emb (bool): Use a zero-initialised parameter as
            position embedding instead of the sine-cosine table.
            Defaults to False.
        all_frames (int): Number of input frames. Defaults to 16.
        tubelet_size (int): Temporal patch size. Defaults to 1.
        use_checkpoint (bool): Enable ``torch.utils.checkpoint`` for the
            leading blocks. Defaults to False.
        checkpoint_num (int): How many leading blocks are checkpointed when
            ``use_checkpoint`` is True. Defaults to 0.
        use_mean_pooling (bool): If True, the output is the normalised mean
            over tokens; otherwise the first token after the final norm.
            Defaults to True.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 init_values=0.,
                 use_learnable_pos_emb=False,
                 all_frames=16,
                 tubelet_size=1,
                 use_checkpoint=False,
                 checkpoint_num=0,
                 use_mean_pooling=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_features = embed_dim
        self.tubelet_size = tubelet_size

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            num_frames=all_frames,
            tubelet_size=self.tubelet_size)
        token_count = self.patch_embed.num_patches

        self.use_checkpoint = use_checkpoint
        self.checkpoint_num = checkpoint_num
        print(f'Use checkpoint: {use_checkpoint}')
        print(f'Checkpoint number: {checkpoint_num}')

        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, token_count, embed_dim))
        else:
            # Sine-cosine table; the size it is generated for depends on
            # the patch size.
            pre_n_position = 2048 if patch_size == 14 else 1568
            self.pos_embed = get_sinusoid_encoding_table(
                token_count,
                embed_dim,
                all_frames // tubelet_size,
                pre_n_position=pre_n_position)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Drop-path rate grows linearly with block index.
        path_rates = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, depth)
        ]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=path_rate,
                norm_layer=norm_layer,
                init_values=init_values) for path_rate in path_rates
        ])

        # With mean pooling the norm is applied after pooling (fc_norm);
        # otherwise it is applied to the token sequence (norm).
        if use_mean_pooling:
            self.norm = nn.Identity()
            self.fc_norm = norm_layer(embed_dim)
        else:
            self.norm = norm_layer(embed_dim)
            self.fc_norm = None

    def forward_features(self, x):
        """Embed ``x``, run all blocks and pool the tokens to one vector."""
        tokens = self.patch_embed(x)
        batch_size, _, _ = tokens.size()

        if self.pos_embed is not None:
            # A detached copy is added, so no gradient flows into pos_embed
            # through this addition.
            pos = self.pos_embed.expand(batch_size, -1, -1)
            pos = pos.type_as(tokens).to(tokens.device).clone().detach()
            tokens = tokens + pos
        tokens = self.pos_drop(tokens)

        for layer_idx, layer in enumerate(self.blocks):
            checkpointed = (
                self.use_checkpoint and layer_idx < self.checkpoint_num)
            if checkpointed:
                tokens = checkpoint.checkpoint(layer, tokens)
            else:
                tokens = layer(tokens)

        tokens = self.norm(tokens)
        if self.fc_norm is None:
            return tokens[:, 0]
        return self.fc_norm(tokens.mean(1))

    def forward(self, x):
        """Return the pooled features of ``x``."""
        return self.forward_features(x)
|