|
|
import logging
|
|
|
from pathlib import Path
|
|
|
|
|
|
import einops
|
|
|
import torch
|
|
|
from omegaconf import OmegaConf
|
|
|
from timm.layers import trunc_normal_
|
|
|
from torch import nn
|
|
|
|
|
|
from .utils import check_if_file_exists_else_download
|
|
|
from .video_model_builder import VisionTransformer
|
|
|
|
|
|
# Download locations for every MotionFormer (SSv2) config and checkpoint this module may need,
# keyed by file name. Configs are pinned to commit bf43d50 of facebookresearch/Motionformer.
FILE2URL = {
    # --- configs ---
    'motionformer_224_16x4.yaml': (
        'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml'
    ),
    'joint_224_16x4.yaml': (
        'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml'
    ),
    'divided_224_16x4.yaml': (
        'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml'
    ),
    # --- checkpoints ---
    'ssv2_motionformer_224_16x4.pyth': 'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth',
    'ssv2_joint_224_16x4.pyth': 'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth',
    'ssv2_divided_224_16x4.pyth': 'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth',
}
|
|
|
|
|
|
|
|
|
class MotionFormer(VisionTransformer):
    ''' MotionFormer video backbone built on top of `VisionTransformer`.

    This class serves three purposes:
        1. Renames the class to MotionFormer.
        2. Downloads the cfg from the original repo (see `FILE2URL`) and patches it.
        3. Takes care of feature extraction by redefining .forward(). With `extract_features=True`
           and an input of shape (B, S, C, T, H, W), where S is the number of segments:
            - if `factorize_space_time` is falsy, the output is (B, S, t*h*w, D): the normalized
              patch tokens of every segment (the CLS token is dropped);
            - if `factorize_space_time` is truthy, every segment is pooled over space by
              `self.spatial_attn_agg` and over time by `self.temp_attn_agg`, giving (B, S, D)
              (or (B, S, t, D) when `self.temp_attn_agg` is an Identity).
        With `add_global_repr=True` a segment aggregator `self.global_attn_agg`,
        (B, S, D) -> (B, D), is also built. `.forward()` does not apply it; that is left to the caller.

    Here t = T // patch_embed_3d.z_block_size and h, w are patch_embed_3d.height / .width.
    '''

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        factorize_space_time: bool = None,
        agg_space_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ):
        '''
        Args:
            extract_features: replace the classification layers (pre_logits, head, head_drop) with
                Identity and build the aggregation modules; `.forward()` requires it to be True.
            ckpt_path: optional checkpoint. A path ending with one of the original MotionFormer ckpt
                names loads its `model_state`; a `.pt` path is loaded as an AVCLIP ckpt
                (its `v_encoder.` weights); anything else raises ValueError.
            factorize_space_time: if truthy, build spatial and temporal aggregators.
            agg_space_module: 'TransformerEncoderLayer' or 'AveragePooling'.
            agg_time_module: 'TransformerEncoderLayer', 'AveragePooling', or any string containing
                'Identity'.
            add_global_repr: if True, build `self.global_attn_agg` over segments.
            agg_segments_module: 'TransformerEncoderLayer' or 'AveragePooling' (anything else builds
                no segment aggregator).
            max_segments: capacity of the segment positional embedding (defaults to 16).

        Raises:
            ValueError: unsupported `ckpt_path`, or an unsupported spatial/temporal aggregator when
                `factorize_space_time` and `extract_features` are set.
        '''
        # plain (non-Module) attributes are safe to set before nn.Module.__init__ runs in super().__init__
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.factorize_space_time = factorize_space_time

        if self.ckpt_path is not None:
            check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
            # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            mformer_ckpt2cfg = {
                'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
                'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
                'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
            }
            # `.pt` ckpts are AVCLIP ckpts; their weights are loaded only after all submodules exist
            was_pt_on_avclip = self.ckpt_path.endswith('.pt')
            # pick the cfg with the same suffix test that recognises the ckpt, so a recognised ckpt
            # always maps to a cfg (a `Path(...).name` lookup would KeyError on e.g. 'my_ssv2_joint_224_16x4.pyth')
            mformer_cfg_fname = next(
                (cfg for ckpt_fname, cfg in mformer_ckpt2cfg.items() if self.ckpt_path.endswith(ckpt_fname)),
                None,
            )
            if mformer_cfg_fname is not None:
                cfg_fname = mformer_cfg_fname
            elif was_pt_on_avclip:
                # if the ckpt stores its training cfg, reuse the cfg of the MotionFormer ckpt it started from
                s1_cfg = ckpt.get('args', None)
                if s1_cfg is not None:
                    s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
                    if s1_vfeat_extractor_ckpt_path is not None:
                        cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
                    else:
                        cfg_fname = 'divided_224_16x4.yaml'
                else:
                    cfg_fname = 'divided_224_16x4.yaml'
            else:
                raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
        else:
            was_pt_on_avclip = False
            cfg_fname = 'divided_224_16x4.yaml'

        # positional-embedding layout used by each config
        pos_emb_type = {
            'motionformer_224_16x4.yaml': 'separate',
            'divided_224_16x4.yaml': 'separate',
            'joint_224_16x4.yaml': 'joint',
        }[cfg_fname]

        self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname

        check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
        mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
        logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')

        # patch the downloaded cfg before the network is built from it
        mformer_cfg.VIT.ATTN_DROPOUT = 0.0
        mformer_cfg.VIT.POS_EMBED = pos_emb_type
        mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
        mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none'
        mformer_cfg.VIT.APPROX_ATTN_DIM = 64

        super().__init__(mformer_cfg)

        # original MotionFormer ckpts can be loaded right away; AVCLIP ckpts are loaded further below
        if (self.ckpt_path is not None) and (not was_pt_on_avclip):
            _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
            if _ckpt_load_status.missing_keys or _ckpt_load_status.unexpected_keys:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. '
                                f'Missing keys: {_ckpt_load_status.missing_keys}, '
                                f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        if self.extract_features:
            assert isinstance(self.norm, nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
            # classification layers are not needed for feature extraction
            self.pre_logits = nn.Identity()
            self.head = nn.Identity()
            self.head_drop = nn.Identity()

            # shared kwargs of every TransformerEncoderLayer-based aggregator below
            transf_enc_layer_kwargs = dict(
                d_model=self.embed_dim,
                nhead=self.num_heads,
                activation=nn.GELU(),
                batch_first=True,
                # nn.Linear needs an int width, so guard against a float MLP ratio in the cfg
                dim_feedforward=int(self.mlp_ratio * self.embed_dim),
                dropout=self.drop_rate,
                layer_norm_eps=1e-6,
                norm_first=True,
            )

            if self.factorize_space_time:
                # spatial aggregation: (B*S, D, t, h, w) -> (B*S, t, D)
                if agg_space_module == 'TransformerEncoderLayer':
                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_space_module == 'AveragePooling':
                    self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
                                                           then_permute_pattern='BS D t -> BS t D')
                else:
                    # forward_segments always calls spatial_attn_agg, so fail at construction time
                    raise ValueError(f'agg_space_module {agg_space_module} is not supported.')
                # temporal aggregation: (B*S, t, D) -> (B*S, D), or left as is by an Identity
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif isinstance(agg_time_module, str) and 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
                else:
                    raise ValueError(f'agg_time_module {agg_time_module} is not supported.')

            # optional aggregation over segments, (B, S, D) -> (B, D); not applied by .forward()
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # a learned pos emb over segments is added; its capacity defaults to 16 segments
                    pos_max_len = max_segments if max_segments is not None else 16
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs)
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        if was_pt_on_avclip:
            # keep only the video-encoder weights of the AVCLIP ckpt and strip their prefixes
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.v_encoder.', 'v_encoder.')):
                    k = k.replace('module.', '').replace('v_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
                                f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                f'{_load_status.missing_keys}, \n' \
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                f'{_load_status.unexpected_keys} \n' \
                                f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # the 2D patch embedding is frozen
        self.patch_embed.requires_grad_(False)

    def forward(self, x):
        '''
        x is of shape (B, S, C, T, H, W) where S is the number of segments.
        Segments are folded into the batch, encoded by `forward_segments`, and unfolded again,
        so the output is (B, S, ...) with the trailing dims described in the class docstring.
        '''
        B, S, C, T, H, W = x.shape
        orig_shape = (B, S, C, T, H, W)
        # reshape (not view) so non-contiguous inputs are accepted too; it is a view whenever possible
        x = x.reshape(B * S, C, T, H, W)
        x = self.forward_segments(x, orig_shape=orig_shape)
        x = x.reshape(B, S, *x.shape[1:])
        return x

    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
        '''x is of shape (B*S, C, T, H, W) where S is the number of segments;
        `orig_shape` is the (B, S, C, T, H, W) shape of the unfolded input.'''
        # check before running the backbone: this path relies on the Identity head set up in __init__
        assert self.extract_features
        x, x_mask = self.forward_features(x)
        # drop the CLS token; the assert in __init__ guarantees a LayerNorm `self.norm`, which acts per token
        x = x[:, 1:, :]
        x = self.norm(x)
        x = self.pre_logits(x)
        if self.factorize_space_time:
            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w)
            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
            x = self.temp_attn_agg(x)  # (B*S, D), or (B*S, t, D) for an Identity
        return x

    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        '''
        feats are the patch tokens of shape (B*S, t*h*w, D) (CLS token already removed).
        Returns them as (B*S, D, t, h, w), where t = T // patch_embed_3d.z_block_size and
        h, w = patch_embed_3d.height, patch_embed_3d.width. Following `self.patch_embed_3d`,
        this is `feats.transpose(1, 2).view(B*S, D, t, h, w)`.
        '''
        B, S, C, T, H, W = orig_shape
        D = self.embed_dim

        t = T // self.patch_embed_3d.z_block_size
        h = self.patch_embed_3d.height
        w = self.patch_embed_3d.width

        feats = feats.permute(0, 2, 1)
        feats = feats.reshape(B * S, D, t, h, w)

        return feats
|
|
|
|
|
|
|
|
|
class BaseEncoderLayer(nn.TransformerEncoderLayer):
    '''
    A wrapper around nn.TransformerEncoderLayer that prepends a learned CLS token to the sequence
    and returns only the CLS token's output representation.

    It is the parent of SpatialTransformerEncoderLayer and TemporalTransformerEncoderLayer. Per the
    original note, audio-stream layers defined elsewhere reuse it too. Optionally, a learned positional
    embedding is added to the input sequence, which lets the same layer aggregate segments globally.
    '''

    def __init__(self,
                 add_pos_emb: bool = False,
                 pos_emb_drop: float = None,
                 pos_max_len: int = None,
                 *args_transformer_enc,
                 **kwargs_transformer_enc):
        '''
        Args:
            add_pos_emb: add a learned positional embedding (covering the CLS token as well).
            pos_emb_drop: dropout after adding the positional embedding (None means 0.0).
            pos_max_len: max number of non-CLS tokens; required when `add_pos_emb=True`.
            *args_transformer_enc, **kwargs_transformer_enc: forwarded to nn.TransformerEncoderLayer.

        Raises:
            ValueError: `add_pos_emb=True` without `pos_max_len`.
        '''
        super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        self.add_pos_emb = add_pos_emb
        if add_pos_emb:
            if pos_max_len is None:
                raise ValueError('`pos_max_len` must be provided when `add_pos_emb=True`.')
            # +1 slot for the CLS token
            self.pos_max_len = 1 + pos_max_len
            self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
            # nn.Dropout(None) would raise, so treat a missing rate as "no dropout"
            self.pos_drop = nn.Dropout(0.0 if pos_emb_drop is None else pos_emb_drop)
            trunc_normal_(self.pos_emb, std=.02)

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
        ''' x is of shape (B, N, D); if provided, x_mask is of shape (B, N) with 0=masked, 1=kept
        (any dtype; it is cast to bool). Returns the CLS token's output of shape (B, D). '''
        batch_dim = x.shape[0]

        cls_tokens = self.cls_token.expand(batch_dim, -1, -1)
        x = torch.cat((cls_tokens, x), dim=-2)

        if x_mask is not None:
            # cast first: concatenating a bool CLS mask with an int/float mask would promote the dtype
            x_mask = x_mask.bool()
            cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device)
            x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)
            B, N = x_mask_w_cls.shape
            num_heads = self.self_attn.num_heads
            # the same key mask for every head and every query: (B * num_heads, N, N)
            x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N) \
                .expand(-1, num_heads, N, -1) \
                .reshape(B * num_heads, N, N)
            # in a bool attention mask True means "may not attend", hence the inversion of 1=kept
            x_mask_w_cls = ~x_mask_w_cls
        else:
            x_mask_w_cls = None

        if self.add_pos_emb:
            seq_len = x.shape[1]
            assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
            x = x + self.pos_emb[:, :seq_len, :]
            x = self.pos_drop(x)

        x = super().forward(src=x, src_mask=x_mask_w_cls)

        # keep only the CLS token
        x = x[:, 0, :]

        return x

    def _init_weights(self, m):
        # truncated-normal Linear weights, zero biases, unit LayerNorm scale
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names to exclude from weight decay
        return {'cls_token', 'pos_emb'}
|
|
|
|
|
|
|
|
|
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
    ''' Pools the h*w spatial grid of every frame into one vector with a CLS-token encoder layer.
    Construction arguments are those of BaseEncoderLayer. '''

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' Aggregate space frame by frame.

        x: (B*S, D, t, h, w), S being the number of segments.
        x_mask: optional (B*S, t, h, w) mask with 0=masked, 1=kept.
        Returns: (B*S, t, D), one vector per frame.
        '''
        n_clips, _, n_frames, _, _ = x.shape

        # every frame becomes its own sequence of h*w tokens
        frame_tokens = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
        frame_mask = None if x_mask is None else einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')

        per_frame = super().forward(x=frame_tokens, x_mask=frame_mask)

        return einops.rearrange(per_frame, '(BS t) D -> BS t D', BS=n_clips, t=n_frames)
|
|
|
|
|
|
|
|
|
class TemporalTransformerEncoderLayer(BaseEncoderLayer):
    ''' Pools a sequence of per-frame vectors with a CLS-token encoder layer. With `add_pos_emb=True`
    it doubles as the global aggregator over segments. Construction arguments are those of
    BaseEncoderLayer. '''

    def forward(self, x):
        ''' Aggregate over time.

        x: (B*S, t, D), S being the number of segments.
        Returns: (B*S, D).
        '''
        # unpacking only enforces a 3-D input
        _num_seqs, _num_steps, _dim = x.shape
        return super().forward(x)
|
|
|
|
|
|
|
|
|
class AveragePooling(nn.Module):
    ''' Mean-pools a tensor with an einops reduction pattern, optionally rearranging the result. '''

    def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
        ''' Patterns are einops expressions, e.g. "bs t d -> bs d". '''
        super().__init__()

        self.reduce_fn = 'mean'
        self.avg_pattern = avg_pattern
        self.then_permute_pattern = then_permute_pattern

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        # x_mask is accepted only for interface parity with the attention aggregators and is ignored:
        # masked positions are included in the mean
        pooled = einops.reduce(x, self.avg_pattern, self.reduce_fn)
        if self.then_permute_pattern is None:
            return pooled
        return einops.rearrange(pooled, self.then_permute_pattern)
|
|
|
|