|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
from functools import partial
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from timm.layers import trunc_normal_
|
|
|
|
|
|
from . import vit_helper
|
|
|
|
|
|
|
|
|
class VisionTransformer(nn.Module):
    """Vision Transformer with support for patch or hybrid CNN input stage.

    The model is configured entirely from ``cfg``. The constructor reads
    ``cfg.DATA.TRAIN_CROP_SIZE``, ``cfg.TRAIN.DATASET``,
    ``cfg.MODEL.NUM_CLASSES`` and the ``cfg.VIT.*`` options.

    Classification heads:
        * When ``num_classes`` is a list with more than one entry, one linear
          head per entry is registered as ``head0``, ``head1``, ...
        * Otherwise a single ``head`` is registered (``nn.Identity`` when
          ``num_classes`` is not positive).

    ``forward_features`` returns the token sequence produced by the
    transformer blocks; it applies neither ``self.norm``, ``self.pre_logits``,
    ``self.head_drop`` nor any head.
    """

    def __init__(self, cfg):
        """Build embeddings, transformer blocks and classification head(s).

        Args:
            cfg: configuration object exposing the ``DATA``, ``TRAIN``,
                ``MODEL`` and ``VIT`` attribute groups read below.
        """
        super().__init__()
        self.img_size = cfg.DATA.TRAIN_CROP_SIZE
        self.patch_size = cfg.VIT.PATCH_SIZE
        self.in_chans = cfg.VIT.CHANNELS
        if cfg.TRAIN.DATASET == "Epickitchens":
            # Two heads with fixed sizes for this dataset
            # (presumably verb / noun classes -- confirm against the dataset).
            self.num_classes = [97, 300]
        else:
            self.num_classes = cfg.MODEL.NUM_CLASSES
        self.embed_dim = cfg.VIT.EMBED_DIM
        self.depth = cfg.VIT.DEPTH
        self.num_heads = cfg.VIT.NUM_HEADS
        self.mlp_ratio = cfg.VIT.MLP_RATIO
        self.qkv_bias = cfg.VIT.QKV_BIAS
        self.drop_rate = cfg.VIT.DROP
        self.drop_path_rate = cfg.VIT.DROP_PATH
        self.head_dropout = cfg.VIT.HEAD_DROPOUT
        self.video_input = cfg.VIT.VIDEO_INPUT
        self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
        self.use_mlp = cfg.VIT.USE_MLP
        self.num_features = self.embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
        self.head_act = cfg.VIT.HEAD_ACT
        self.cfg = cfg

        # 2D patch embedding. Its ``num_patches`` sizes ``pos_embed`` and is
        # passed to the blocks as ``seq_len`` in ``forward_features``.
        # NOTE(review): img_size is hard-coded to 224 rather than
        # ``self.img_size``; ``forward_features`` adds ``pos_embed`` without
        # resizing it, so other crop sizes presumably need the same number of
        # spatial patches -- confirm before changing either value.
        self.patch_embed = vit_helper.PatchEmbed(
            img_size=224,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
        )

        # 3D patch embedding: this is the module actually applied to the
        # input in ``forward_features``.
        self.patch_embed_3d = vit_helper.PatchEmbed3D(
            img_size=self.img_size,
            temporal_resolution=self.temporal_resolution,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP,
        )
        # The projection weights start at zero. ``_init_weights`` only
        # re-initialises nn.Linear / nn.LayerNorm modules, so the zeros
        # survive ``self.apply`` unless ``proj`` is one of those types.
        self.patch_embed_3d.proj.weight.data = torch.zeros_like(
            self.patch_embed_3d.proj.weight.data)

        # Total number of tokens, excluding the class token.
        if self.video_input:
            num_patches = self.patch_embed.num_patches * self.temporal_resolution
        else:
            num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # Class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # Spatial positional embedding (class token + one frame of patches).
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
        trunc_normal_(self.pos_embed, std=.02)

        if self.cfg.VIT.POS_EMBED == "joint":
            # One embedding per space-time token (plus the class token).
            self.st_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, self.embed_dim))
            trunc_normal_(self.st_embed, std=.02)
        elif self.cfg.VIT.POS_EMBED == "separate":
            # One embedding per frame; it is left at zero on construction.
            self.temp_embed = nn.Parameter(
                torch.zeros(1, self.temporal_resolution, self.embed_dim))

        # Stochastic-depth rate per block, linearly spaced from 0 to
        # ``drop_path_rate``.
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]

        # Keyword arguments shared by both block types.
        block_kwargs = dict(
            attn_type=cfg.VIT.ATTN_LAYER,
            dim=self.embed_dim,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=self.qkv_bias,
            drop=self.drop_rate,
            attn_drop=self.attn_drop_rate,
            norm_layer=norm_layer,
        )
        if self.cfg.VIT.ATTN_LAYER == "divided":
            self.blocks = nn.ModuleList([
                vit_helper.DividedSpaceTimeBlock(drop_path=rate, **block_kwargs)
                for rate in dpr
            ])
        else:
            # Only read in this branch, so "divided" configs need not define it.
            use_original_code = self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE
            self.blocks = nn.ModuleList([
                vit_helper.Block(drop_path=rate,
                                 use_original_code=use_original_code,
                                 **block_kwargs)
                for rate in dpr
            ])
        self.norm = norm_layer(self.embed_dim)

        # Optional MLP applied before the classification head(s).
        if self.use_mlp:
            hidden_dim = self.embed_dim
            # Any value other than 'tanh' / 'gelu' falls back to ReLU.
            act_cls = {'tanh': nn.Tanh, 'gelu': nn.GELU}.get(self.head_act, nn.ReLU)
            self.pre_logits = nn.Sequential(
                OrderedDict([
                    ('fc', nn.Linear(self.embed_dim, hidden_dim)),
                    ('act', act_cls()),
                ]))
        else:
            self.pre_logits = nn.Identity()

        # Classification head(s).
        self.head_drop = nn.Dropout(p=self.head_dropout)
        if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
            for idx, n_cls in enumerate(self.num_classes):
                setattr(self, "head%d" % idx, nn.Linear(self.embed_dim, n_cls))
        else:
            self.head = (nn.Linear(self.embed_dim, self.num_classes)
                         if self.num_classes > 0 else nn.Identity())

        # Initialise every nn.Linear / nn.LayerNorm submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single submodule; used as the ``self.apply`` callback.

        ``nn.Linear``: truncated-normal weights (std 0.02) and zero bias.
        ``nn.LayerNorm``: unit weight and zero bias.
        Every other module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return a set of parameter names for the active positional scheme.

        Per the method name, these are presumably meant to be excluded from
        weight decay by the optimizer setup -- confirm against the caller.
        ``'st_embed'`` is listed for ``POS_EMBED == "joint"``; every other
        setting lists ``'temp_embed'``.
        """
        if self.cfg.VIT.POS_EMBED == "joint":
            return {'pos_embed', 'cls_token', 'st_embed'}
        return {'pos_embed', 'cls_token', 'temp_embed'}

    def get_classifier(self):
        """Return the classification head(s).

        Returns:
            ``self.head`` for a single-head model, or the list
            ``[head0, head1, ...]`` when ``num_classes`` is a list with more
            than one entry (no ``head`` attribute exists in that case).
        """
        if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
            return [getattr(self, "head%d" % idx)
                    for idx in range(len(self.num_classes))]
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the single classification head.

        Args:
            num_classes: output size of the new head; a non-positive value
                installs ``nn.Identity``.
            global_pool: accepted but not used by this method.

        The new layer keeps PyTorch's default initialisation
        (``_init_weights`` is not re-applied), and any existing ``head<i>``
        modules of a multi-head model are left in place.
        """
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward_features(self, x):
        """Embed the input and run it through every transformer block.

        Args:
            x: input handed to ``self.patch_embed_3d``; its first dimension
                is used as the batch size. The embedding output must be
                ``(B, N, embed_dim)`` so that the class token can be
                concatenated along dim 1.

        Returns:
            tuple ``(tokens, tok_mask)``: the output of the last block and
            the token mask, which is always ``None`` here. ``self.norm``,
            ``self.pre_logits`` and the head(s) are not applied.
        """
        B = x.shape[0]

        # Tokenise the input.
        x = self.patch_embed_3d(x)
        tok_mask = None

        # Prepend the class token to every sample.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Number of spatial patches per frame (from the 2D embedding).
        npatch = self.patch_embed.num_patches

        # Positional information.
        if self.video_input:
            if self.cfg.VIT.POS_EMBED == "separate":
                # Spatial embedding tiled across frames plus a per-frame
                # temporal embedding repeated across the patches of a frame.
                cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
                tile_pos_embed = self.pos_embed[:, 1:, :].repeat(
                    1, self.temporal_resolution, 1)
                tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
                total_pos_embed = torch.cat(
                    [cls_embed, tile_pos_embed + tile_temporal_embed], dim=1)
                x = x + total_pos_embed
            elif self.cfg.VIT.POS_EMBED == "joint":
                x = x + self.st_embed
            # Any other POS_EMBED value adds no positional embedding.
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        # Transformer encoder.
        for blk in self.blocks:
            x = blk(x,
                    seq_len=npatch,
                    num_frames=self.temporal_resolution,
                    approx=self.cfg.VIT.APPROX_ATTN_TYPE,
                    num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
                    tok_mask=tok_mask)

        return x, tok_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|