from math import ceil
from typing import Optional

import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn, Tensor
from torch.nn import functional as F

import fvcore.nn.weight_init as weight_init
from detectron2.config import configurable
from detectron2.layers import Conv2d

from .fuse_modules import BiAttentionBlock


class SelfAttentionLayer(nn.Module):
    """Transformer self-attention block with optional pre-/post-LayerNorm."""

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)


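# Illustrative sketch (not part of the model): SelfAttentionLayer follows the
# seq-first convention of nn.MultiheadAttention, i.e. inputs are (S, N, C),
# and `query_pos` is added to queries/keys but not values. All shapes and
# hyper-parameters below are made up for the demo.
def _demo_self_attention_layer():
    layer = SelfAttentionLayer(d_model=256, nhead=8, normalize_before=True)
    tgt = torch.randn(100, 2, 256)        # (num_queries, batch, channels)
    query_pos = torch.randn(100, 2, 256)  # learned positional embedding
    out = layer(tgt, query_pos=query_pos)
    assert out.shape == tgt.shape         # residual block preserves shape

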
class CrossAttentionLayer(nn.Module):
    """Transformer cross-attention block with optional pre-/post-LayerNorm."""

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


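# Illustrative sketch (not part of the model): CrossAttentionLayer attends from
# `tgt` queries to `memory` tokens. `memory_key_padding_mask` follows the
# nn.MultiheadAttention convention that True marks positions to *ignore*, which
# is why the decoder below passes `~lang_mask.bool()` for a validity mask.
# All shapes here are made up for the demo.
def _demo_cross_attention_layer():
    layer = CrossAttentionLayer(d_model=256, nhead=8)
    tgt = torch.randn(20, 2, 256)      # (num_queries, batch, channels)
    memory = torch.randn(15, 2, 256)   # (num_tokens, batch, channels)
    valid = torch.ones(2, 15, dtype=torch.bool)
    valid[:, 10:] = False              # last 5 tokens are padding
    out = layer(tgt, memory, memory_key_padding_mask=~valid)
    assert out.shape == tgt.shape

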
class FFNLayer(nn.Module):
    """Feed-forward block with optional pre-/post-LayerNorm."""

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


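# Illustrative sketch (not part of the model): `mask_embed` below is an MLP
# that maps decoder queries to mask embeddings; segmentation logits are then
# typically a dot product with per-pixel mask features. The einsum layout is an
# assumption here, mirroring Mask2Former-style heads; shapes are made up.
def _demo_mask_embed_head():
    mask_embed = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
    queries = torch.randn(2, 100, 256)           # (batch, num_queries, channels)
    mask_features = torch.randn(2, 32, 24, 40)   # (batch, mask_dim, H, W)
    embed = mask_embed(queries)                  # (batch, num_queries, mask_dim)
    masks = torch.einsum("bqc,bchw->bqhw", embed, mask_features)
    assert masks.shape == (2, 100, 24, 40)

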
class DsHmpHierarchical(nn.Module):
    """Clip-level decoder with vision-language fusion and hierarchical
    motion perception over frame queries."""

    @configurable
    def __init__(
        self,
        in_channels,
        aux_loss,
        *,
        hidden_dim: int,
        num_frame_queries: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        enc_layers: int,
        dec_layers: int,
        enc_window_size: int,
        pre_norm: bool,
        enforce_input_project: bool,
        num_frames: int,
        num_classes: int,
        clip_last_layer_num: bool,
        conv_dim: int,
        mask_dim: int,
        sim_use_clip: list,
        use_sim: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            hidden_dim: Transformer feature dimension
            num_queries: number of queries
            nheads: number of heads
            dim_feedforward: feature dimension in the feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            pre_norm: whether to use pre-LayerNorm or not
            enforce_input_project: add an input projection 1x1 conv even if the
                input channel count and the hidden dim are identical
        """
        super().__init__()

        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()
        self.fusion_layers = nn.ModuleList()
        self.cross_test = nn.ModuleList()
        self.num_frames = num_frames
        self.num_classes = num_classes
        self.clip_last_layer_num = clip_last_layer_num

        self.enc_layers = enc_layers
        self.window_size = enc_window_size
        self.sim_use_clip = sim_use_clip
        self.use_sim = use_sim
        self.aux_loss = aux_loss

        if enc_layers > 0:
            self.enc_self_attn = nn.ModuleList()
            self.hierarchical_cross = nn.ModuleList()
            self.enc_ffn = nn.ModuleList()
            for _ in range(self.enc_layers):
                self.enc_self_attn.append(
                    SelfAttentionLayer(
                        d_model=hidden_dim,
                        nhead=nheads,
                        dropout=0.0,
                        normalize_before=pre_norm,
                    )
                )
                self.hierarchical_cross.append(
                    HierarchicalCrossAttentionLayer(
                        d_model=hidden_dim,
                        nhead=nheads,
                        dropout=0.0,
                        normalize_before=False,
                    )
                )
                self.enc_ffn.append(
                    FFNLayer(
                        d_model=hidden_dim,
                        dim_feedforward=dim_feedforward,
                        dropout=0.0,
                        normalize_before=pre_norm,
                    )
                )

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.cross_test.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.fusion_layers.append(
                BiAttentionBlock(
                    v_dim=hidden_dim,
                    l_dim=hidden_dim,
                    embed_dim=dim_feedforward // 2,
                    num_heads=self.num_heads // 2,
                    dropout=0.1,
                    drop_path=0.1,
                )
            )
        self.vita_mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        weight_init.c2_xavier_fill(self.vita_mask_features)

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
        else:
            self.input_proj_dec = nn.Sequential()
        # `src_embed` is used unconditionally in forward(), so it is defined
        # outside the branch above.
        self.src_embed = nn.Identity()

        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        if self.use_sim:
            self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
            if self.sim_use_clip:
                self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)
        self.contrastive_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
        self.cross_motion = CrossAttentionLayer(d_model=hidden_dim,
                                                nhead=nheads,
                                                dropout=0.1,
                                                normalize_before=pre_norm)

    @classmethod
    def from_config(cls, cfg, in_channels):
        ret = {}
        ret["in_channels"] = in_channels
        # `aux_loss` is required by __init__ but was not populated here; the
        # Mask2Former deep-supervision switch is an assumed source for it.
        ret["aux_loss"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION

ret["hidden_dim"] = cfg.MODEL.VITA.HIDDEN_DIM |
|
|
ret["num_frame_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES |
|
|
ret["num_queries"] = cfg.MODEL.VITA.NUM_OBJECT_QUERIES |
|
|
|
|
|
ret["nheads"] = cfg.MODEL.VITA.NHEADS |
|
|
ret["dim_feedforward"] = cfg.MODEL.VITA.DIM_FEEDFORWARD |
|
|
|
|
|
assert cfg.MODEL.VITA.DEC_LAYERS >= 1 |
|
|
ret["enc_layers"] = cfg.MODEL.VITA.ENC_LAYERS |
|
|
ret["dec_layers"] = cfg.MODEL.VITA.DEC_LAYERS |
|
|
ret["enc_window_size"] = cfg.MODEL.VITA.ENC_WINDOW_SIZE |
|
|
ret["pre_norm"] = cfg.MODEL.VITA.PRE_NORM |
|
|
ret["enforce_input_project"] = cfg.MODEL.VITA.ENFORCE_INPUT_PROJ |
|
|
|
|
|
ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES |
|
|
ret["num_frames"] = cfg.INPUT.SAMPLING_FRAME_NUM |
|
|
ret["clip_last_layer_num"] = cfg.MODEL.VITA.LAST_LAYER_NUM |
|
|
|
|
|
ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM |
|
|
ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM |
|
|
ret["sim_use_clip"] = cfg.MODEL.VITA.SIM_USE_CLIP |
|
|
ret["use_sim"] = cfg.MODEL.VITA.SIM_WEIGHT > 0.0 |
|
|
|
|
|
return ret |
|
|
|
|
|
    def forward(self, frame_query, lang_feat, lang_mask, motion_feat=None):
        """
        L: Number of layers.
        B: Batch size.
        T: Temporal window size. Number of frames per video.
        C: Channel size.
        fQ: Number of frame-wise queries from IFC.
        cQ: Number of clip-wise queries to decode Q.
        """
        if not self.training:
            frame_query = frame_query[[-1]]

        L, BT, fQ, C = frame_query.shape
        B = BT // self.num_frames if self.training else 1
        T = self.num_frames if self.training else BT // B

        frame_query = frame_query.reshape(L * B, T, fQ, C)
        frame_query = frame_query.permute(1, 2, 0, 3).contiguous()  # T, fQ, LB, C
        frame_query = self.input_proj_dec(frame_query)

        if self.window_size > 0:
            pad = int(ceil(T / self.window_size)) * self.window_size - T
            _T = pad + T
            frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad))  # pad the frame axis
            enc_mask = frame_query.new_ones(L * B, _T).bool()
            enc_mask[:, :T] = False
        else:
            enc_mask = None

        frame_query = self.encode_frame_query(frame_query, enc_mask, motion_feat)
        frame_query = frame_query[:T].flatten(0, 1)  # TfQ, LB, C

        if self.use_sim:
            pred_fq_embed = self.sim_embed_frame(frame_query)
            pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
        else:
            pred_fq_embed = None

        src = self.src_embed(frame_query)
        dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, 1).flatten(0, 1)

        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, L * B, 1)  # cQ, LB, C
        output = self.query_feat.weight.unsqueeze(1).repeat(1, L * B, 1)

        output_motion = self.cross_motion(output, motion_feat.unsqueeze(1).repeat(1, L, 1))
        output = output + 0.1 * output_motion
        decoder_outputs = []

        lang_feat_fusion = lang_feat.repeat(L, 1, 1).transpose(0, 1)
        lang_mask = lang_mask.repeat(L, 1)

        for i in range(self.num_layers):
            # vision-language bidirectional fusion
            src, lang_feat_fusion = self.fusion_layers[i](
                v=src,
                l=lang_feat_fusion,
                attention_mask_v=None,
                attention_mask_l=~lang_mask.bool(),
            )

            # cross-attention to the fused frame queries
            output = self.transformer_cross_attention_layers[i](
                output, src,
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=dec_pos, query_pos=query_embed
            )

            # cross-attention to the fused language tokens
            output = self.cross_test[i](
                output,
                lang_feat_fusion,
                memory_key_padding_mask=~lang_mask.bool()
            )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )

            output = self.transformer_ffn_layers[i](
                output
            )

            if (self.training and self.aux_loss) or (i == self.num_layers - 1):
                dec_out = self.decoder_norm(output)  # cQ, LB, C
                dec_out = dec_out.transpose(0, 1)    # LB, cQ, C
                decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))

        decoder_outputs = torch.stack(decoder_outputs, dim=0)  # D, L, B, cQ, C

        pred_cls = self.class_embed(decoder_outputs)
        pred_mask_embed = self.mask_embed(decoder_outputs)
        if self.use_sim and self.sim_use_clip:
            pred_cq_embed = self.sim_embed_clip(decoder_outputs)
        else:
            pred_cq_embed = [None] * self.num_layers
        pred_contrastive_embed = self.contrastive_embed(decoder_outputs)

        # masked mean-pool of the fused language features over valid tokens
        lang_feat_mask = lang_feat_fusion.transpose(0, 1) * lang_mask.unsqueeze(-1)
        lang_feat_mean = torch.sum(lang_feat_mask, dim=1) / lang_mask.sum(dim=1, keepdim=True)
        out = {
            'pred_logits': pred_cls[-1],
            'pred_mask_embed': pred_mask_embed[-1],
            # NOTE: this entry was keyed by a literal blank string (' ') in the
            # original; 'pred_lang_embed' is an assumed, descriptive replacement
            # and must match whatever key the criterion looks up.
            'pred_lang_embed': lang_feat_mean[-1],
            'pred_fq_embed': pred_fq_embed,
            'pred_cq_embed': pred_cq_embed[-1],
            'pred_contrastive_embed': pred_contrastive_embed[-1],
            'aux_outputs': self._set_aux_loss(
                pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
            )
        }
        return out

    @torch.jit.unused
    def _set_aux_loss(
        self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
    ):
        return [{"pred_logits": a, "pred_mask_embed": b, "pred_cq_embed": c, "pred_fq_embed": outputs_fq_embed}
                for a, b, c in zip(outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1])]

    def encode_frame_query(self, frame_query, attn_mask, lang_motion_feat):
        """
        input shape (frame_query)  : T, fQ, LB, C
        output shape (frame_query) : T, fQ, LB, C
        """
        # Plain self-attention over all frames when self.window_size == 0.
        if self.window_size == 0:
            return_shape = frame_query.shape     # T, fQ, LB, C
            frame_query = frame_query.flatten(0, 1)  # TfQ, LB, C

            for i in range(self.enc_layers):
                frame_query = self.enc_self_attn[i](frame_query)
                frame_query = self.enc_ffn[i](frame_query)

            frame_query = frame_query.view(return_shape)
            return frame_query
        # Window-based attention when self.window_size > 0.
        else:
            T, fQ, LB, C = frame_query.shape
            W = self.window_size
            Nw = T // W
            half_W = int(ceil(W / 2))

            window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)

            _attn_mask = torch.roll(attn_mask, half_W, 1)
            _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W)  # LB, Nw, W, W
            _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
            _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
            _attn_mask[:, 0, :half_W, half_W:] = True
            _attn_mask[:, 0, half_W:, :half_W] = True
            _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(
                1, self.num_heads, 1, fQ, 1, fQ
            ).view(LB * Nw * self.num_heads, W * fQ, W * fQ)
            shift_window_mask = _attn_mask.float() * -1000

            for layer_idx in range(self.enc_layers):
                if self.training or layer_idx % 2 == 0:
                    frame_query = self._window_attn(frame_query, window_mask, layer_idx, lang_motion_feat)
                else:
                    frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx, lang_motion_feat)
            return frame_query

    def _window_attn(self, frame_query, attn_mask, layer_idx, lang_motion_feat):
        T, fQ, LB, C = frame_query.shape

        W = self.window_size
        Nw = T // W

        frame_query = frame_query.view(Nw, W, fQ, LB, C)
        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)

        frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_key_padding_mask=attn_mask)
        frame_query = frame_query.reshape(W, fQ, LB * Nw, C)
        frame_query = self.hierarchical_cross[layer_idx](frame_query, lang_motion_feat)
        frame_query = frame_query.reshape(W * fQ, LB * Nw, C)
        frame_query = self.enc_ffn[layer_idx](frame_query)
        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)

        return frame_query

    def _shift_window_attn(self, frame_query, attn_mask, layer_idx, lang_motion_feat):
        T, fQ, LB, C = frame_query.shape

        W = self.window_size
        Nw = T // W
        half_W = int(ceil(W / 2))

        frame_query = torch.roll(frame_query, half_W, 0)
        frame_query = frame_query.view(Nw, W, fQ, LB, C)
        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)

        frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_mask=attn_mask)
        frame_query = frame_query.reshape(W, fQ, LB * Nw, C)
        frame_query = self.hierarchical_cross[layer_idx](frame_query, lang_motion_feat)
        frame_query = frame_query.reshape(W * fQ, LB * Nw, C)
        frame_query = self.enc_ffn[layer_idx](frame_query)
        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)

        frame_query = torch.roll(frame_query, -half_W, 0)

        return frame_query


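# Illustrative sketch (not part of the model): how `encode_frame_query` pads
# the temporal axis to a multiple of the window size and builds the padding
# mask (True = padded frame, ignored by attention). Numbers are made up.
def _demo_window_padding():
    T, W = 5, 3                      # frames, window size
    pad = int(ceil(T / W)) * W - T   # -> 1, so the padded length is 6
    _T = T + pad
    enc_mask = torch.ones(4, _T, dtype=torch.bool)  # (L*B, padded frames)
    enc_mask[:, :T] = False          # real frames attend; padded ones are masked
    assert _T % W == 0 and enc_mask[:, T:].all()

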
class HierarchicalCrossAttentionLayer(nn.Module):
    """Cross-attention from frame queries to language/motion tokens, pooled
    over a three-level temporal hierarchy with a gated residual."""

    def __init__(self, d_model, nhead, dropout=0.0, eps=1e-8,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        self.out1 = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.out2 = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.out3 = nn.Linear(d_model, d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.eps = eps
        self.scale = d_model ** -0.5
        self.res_gate = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=False),
            nn.Tanh()
        )

    def sort_embedding(self, frame_query):
        """Reorder each frame's queries to match the previous frame via
        Hungarian matching; returns sorted queries and per-frame inverse maps."""
        sort_frame_query_j = []
        inverse_maps = []
        T, fQ, L, C = frame_query.shape
        for j in range(L):
            sort_frame_query_t = []
            inverse_maps_j = []
            for i in range(T):
                single_frame_embeds = frame_query[i, :, j]
                if i == 0:
                    last_frame_embeds = single_frame_embeds
                    inverse_map = np.arange(last_frame_embeds.size(0))
                else:
                    indices = self.match_embeds(last_frame_embeds, single_frame_embeds)
                    inverse_map = np.argsort(indices)
                    last_frame_embeds = single_frame_embeds[indices]
                sort_frame_query_t.append(last_frame_embeds)
                inverse_maps_j.append(inverse_map)
            sort_frame_query = torch.stack(sort_frame_query_t)
            sort_frame_query_j.append(sort_frame_query)
            inverse_maps.append(inverse_maps_j)
        sort_frame_query_j = torch.stack(sort_frame_query_j, dim=2)
        return sort_frame_query_j, inverse_maps

    def unsort_embedding(self, frame_query, inverse_maps):
        T, fQ, L, C = frame_query.shape
        for j in range(L):
            for i in range(1, T):
                frame_query[i, :, j] = frame_query[i, :, j][inverse_maps[j][i]]
        return frame_query

    def match_embeds(self, ref_embds, cur_embds):
        ref_embds, cur_embds = ref_embds.detach(), cur_embds.detach()
        ref_embds = ref_embds / (ref_embds.norm(dim=1)[:, None] + 1e-6)
        cur_embds = cur_embds / (cur_embds.norm(dim=1)[:, None] + 1e-6)
        cos_sim = torch.mm(ref_embds, cur_embds.transpose(0, 1))
        C = 1 - cos_sim  # cost matrix: cosine distance

        C = C.cpu()
        C = torch.where(torch.isnan(C), torch.full_like(C, 0), C)

        indices = linear_sum_assignment(C.transpose(0, 1))
        indices = indices[1]  # permutation aligning current queries to refs
        return indices

    def pairwise_average(self, tensor):
        # Pad to an even length by repeating the last element, then average
        # adjacent pairs, halving dim 1.
        if tensor.shape[1] % 2 != 0:
            tensor = torch.cat([tensor, tensor.narrow(1, -1, 1)], dim=1)

        tensor = (tensor[:, ::2] + tensor[:, 1::2]) / 2
        return tensor

    def duplicate_and_trim(self, tensor, final_dim):
        # Inverse-style op: repeat elements along dim 1 until it reaches
        # `final_dim`, then trim the excess.
        while tensor.shape[1] < final_dim:
            tensor = torch.repeat_interleave(tensor, 2, dim=1)
        return tensor.narrow(1, 0, final_dim)

    def get_attn(self, q, k, v):
        dots = torch.matmul(q, k.transpose(1, 2)) * self.scale

        # Softmax over the *query* axis, then renormalize each query row by its
        # total weight: a doubly-normalized attention variant.
        attn = dots.softmax(dim=-2)
        weight = attn.sum(dim=-1, keepdim=True)
        attn = attn / (weight + self.eps)
        output = torch.matmul(attn, v)
        return output, weight

    def forward(self, tgt, memory):
        n_frame, n_instance, bs, n_channel = tgt.shape
        tgt, inverse_maps = self.sort_embedding(tgt)
        tgt = tgt.flatten(1, 2).transpose(0, 1)  # (n_instance * bs, n_frame, C)
        memory = memory.unsqueeze(0).repeat(n_instance * bs, 1, 1)
        q = self.to_q(tgt)
        k = self.to_k(memory)
        v = self.to_v(memory)

        # level 1: attend, then halve the frame axis
        output, weight = self.get_attn(q, k, v)
        output = self.out1(output)
        output = self.norm1(output + q)
        output = self.pairwise_average(output)

        # level 2
        q = output
        output, weight = self.get_attn(q, k, v)
        output = self.out2(output)
        output = self.norm2(output + q)
        output = self.pairwise_average(output)

        # level 3
        q = output
        output, weight = self.get_attn(q, k, v)
        output = self.out3(output)
        output = self.norm3(output + q)
        output = self.pairwise_average(output)

        # gated residual: broadcast the pooled summary back over the frame axis
        tgt = tgt + (self.res_gate(output) * output)
        tgt = tgt.transpose(0, 1).reshape(n_frame, n_instance, bs, n_channel)
        tgt = self.unsort_embedding(tgt, inverse_maps)
        return tgt


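# Illustrative sketch (not part of the model): the three attention +
# pairwise_average rounds in HierarchicalCrossAttentionLayer.forward halve the
# frame axis each time, so a window of 8 frames collapses to a single
# aggregated token (8 -> 4 -> 2 -> 1) that the gated residual broadcasts back
# over all frames. Shapes below are made up for the demo.
def _demo_hierarchical_pooling_and_matching():
    layer = HierarchicalCrossAttentionLayer(d_model=256, nhead=8)
    x = torch.randn(10, 8, 256)   # (n_instance * bs, frames, channels)
    for _ in range(3):
        x = layer.pairwise_average(x)
    assert x.shape[1] == 1        # 8 -> 4 -> 2 -> 1

    # Hungarian matching between two frames' query embeddings: the result is a
    # permutation that aligns the current frame's queries with the reference.
    ref = torch.randn(10, 256)
    cur = torch.randn(10, 256)
    indices = layer.match_embeds(ref, cur)
    assert sorted(indices.tolist()) == list(range(10))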