# Uploaded by dianecy via huggingface_hub (commit 729c925, verified).
###########################################################################
# Created by: NTU
# Email: heshuting555@gmail.com
# Copyright (c) 2024
###########################################################################
from math import ceil
import fvcore.nn.weight_init as weight_init
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from detectron2.config import configurable
from detectron2.layers import Conv2d
from .fuse_modules import BiAttentionBlock
from scipy.optimize import linear_sum_assignment
import numpy as np
class SelfAttentionLayer(nn.Module):
    """Residual multi-head self-attention block with a single LayerNorm.

    ``normalize_before=True`` normalises the input before attention (pre-norm);
    otherwise the residual sum is normalised afterwards (post-norm). The
    positional embedding ``query_pos`` is added to queries and keys only, never
    to the values.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Stored for interface parity with FFNLayer; this block never calls it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim.

        1-D parameters (biases, LayerNorm affine terms) are left untouched.
        """
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _attend(self, x, attn_mask, key_padding_mask, query_pos):
        """Self-attend over ``x`` (positional embedding on queries/keys only)."""
        qk = self.with_pos_embed(x, query_pos)
        attended, _ = self.self_attn(qk, qk, value=x, attn_mask=attn_mask,
                                     key_padding_mask=key_padding_mask)
        return attended

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: ``norm(tgt + dropout(attn(tgt)))``."""
        attended = self._attend(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: ``tgt + dropout(attn(norm(tgt)))``."""
        attended = self._attend(self.norm(tgt), tgt_mask,
                                tgt_key_padding_mask, query_pos)
        return tgt + self.dropout(attended)

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm variant."""
        step = self.forward_pre if self.normalize_before else self.forward_post
        return step(tgt, tgt_mask, tgt_key_padding_mask, query_pos)
class CrossAttentionLayer(nn.Module):
    """Residual multi-head cross-attention block (``tgt`` attends to ``memory``).

    ``query_pos`` is added to the queries and ``pos`` to the keys; the values
    are the raw ``memory``. ``normalize_before`` selects pre-norm (normalise the
    queries first) versus post-norm (normalise the residual sum).
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Stored for interface parity with FFNLayer; this block never calls it.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _attend(self, queries, memory, memory_mask, memory_key_padding_mask,
                pos, query_pos):
        """Cross-attend ``queries`` to ``memory`` and return the attended values."""
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(queries, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return attended

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: ``norm(tgt + dropout(attn(tgt, memory)))``."""
        attended = self._attend(tgt, memory, memory_mask,
                                memory_key_padding_mask, pos, query_pos)
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: ``tgt + dropout(attn(norm(tgt), memory))``."""
        attended = self._attend(self.norm(tgt), memory, memory_mask,
                                memory_key_padding_mask, pos, query_pos)
        return tgt + self.dropout(attended)

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm variant."""
        step = self.forward_pre if self.normalize_before else self.forward_post
        return step(tgt, memory, memory_mask,
                    memory_key_padding_mask, pos, query_pos)
class FFNLayer(nn.Module):
    """Residual two-layer feed-forward block with a single LayerNorm.

    Computes ``linear2(dropout(activation(linear1(x))))`` and adds it back to
    the input, normalising either before (pre-norm) or after (post-norm)
    depending on ``normalize_before``.
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _mlp(self, x):
        """Apply ``linear1 -> activation -> dropout -> linear2``."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.linear2(hidden)

    def forward_post(self, tgt):
        """Post-norm variant: ``norm(tgt + dropout(mlp(tgt)))``."""
        return self.norm(tgt + self.dropout(self._mlp(tgt)))

    def forward_pre(self, tgt):
        """Pre-norm variant: ``tgt + dropout(mlp(norm(tgt)))``."""
        return tgt + self.dropout(self._mlp(self.norm(tgt)))

    def forward(self, tgt):
        """Dispatch to the pre- or post-norm variant."""
        step = self.forward_pre if self.normalize_before else self.forward_post
        return step(tgt)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    A stack of ``num_layers`` Linear layers mapping
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim`` with a ReLU
    after every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_dims
        out_dims = hidden_dims + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Run ``x`` through every layer, applying ReLU between layers."""
        last_idx = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last_idx:
                x = F.relu(x)
        return x
class DsHmpHierarchical(nn.Module):
    """Clip-level query decoder with language fusion and hierarchical motion
    cross-attention.

    Frame-level queries are temporally encoded. With ``enc_window_size > 0``
    this is windowed self-attention plus ``HierarchicalCrossAttentionLayer``
    against ``motion_feat``. The encoded queries are then fused with language
    features through ``BiAttentionBlock`` layers and decoded by ``num_queries``
    learnable clip queries. The decoder produces class logits, mask embeddings
    and contrastive embeddings.
    """

    @configurable
    def __init__(
        self,
        in_channels,
        aux_loss,
        *,
        hidden_dim: int,
        num_frame_queries: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        enc_layers: int,
        dec_layers: int,
        enc_window_size: int,
        pre_norm: bool,
        enforce_input_project: bool,
        num_frames: int,
        num_classes: int,
        clip_last_layer_num: bool,
        conv_dim: int,
        mask_dim: int,
        sim_use_clip: list,
        use_sim: bool,
    ):
        """
        NOTE: this interface is experimental.
        Args:
            in_channels: channels of the input features
            aux_loss: if True, every decoder layer's output is returned while
                training (used for auxiliary losses)
            hidden_dim: Transformer feature dimension
            num_frame_queries: number of frame-level queries per frame (fQ)
            num_queries: number of clip-level queries (cQ)
            nheads: number of heads
            dim_feedforward: feature dimension in feedforward network
            enc_layers: number of Transformer encoder layers
            dec_layers: number of Transformer decoder layers
            enc_window_size: temporal window of the encoder; 0 disables
                windowed attention (and with it the hierarchical cross-attention)
            pre_norm: whether to use pre-LayerNorm or not
            enforce_input_project: add the input projection even if input
                channels and hidden dim are identical
            num_frames: number of frames per clip during training
            num_classes: number of foreground classes (one extra "no object"
                logit is added)
            clip_last_layer_num: stored on the module; not read in this class
            conv_dim: input channels of ``vita_mask_features``
            mask_dim: output channels of ``vita_mask_features`` and of the
                mask embedding MLP
            sim_use_clip: whether to predict clip-query similarity embeddings
                (only effective when ``use_sim`` is true)
            use_sim: whether to predict frame-query similarity embeddings
        """
        super().__init__()

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()
        self.fusion_layers = nn.ModuleList()
        self.cross_test = nn.ModuleList()

        self.num_frames = num_frames
        self.num_classes = num_classes
        self.clip_last_layer_num = clip_last_layer_num
        self.enc_layers = enc_layers
        self.window_size = enc_window_size
        self.sim_use_clip = sim_use_clip
        self.use_sim = use_sim
        self.aux_loss = aux_loss

        if enc_layers > 0:
            self.enc_self_attn = nn.ModuleList()
            self.hierarchical_cross = nn.ModuleList()
            self.enc_ffn = nn.ModuleList()
            for _ in range(self.enc_layers):
                self.enc_self_attn.append(
                    SelfAttentionLayer(
                        d_model=hidden_dim,
                        nhead=nheads,
                        dropout=0.0,
                        normalize_before=pre_norm,
                    ),
                )
                self.hierarchical_cross.append(
                    HierarchicalCrossAttentionLayer(
                        d_model=hidden_dim,
                        nhead=nheads,
                        dropout=0.0,
                        normalize_before=False,
                    )
                )
                self.enc_ffn.append(
                    FFNLayer(
                        d_model=hidden_dim,
                        dim_feedforward=dim_feedforward,
                        dropout=0.0,
                        normalize_before=pre_norm,
                    )
                )

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            # Cross-attention from the clip queries to the fused language tokens.
            self.cross_test.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.fusion_layers.append(
                BiAttentionBlock(
                    v_dim=hidden_dim,
                    l_dim=hidden_dim,
                    embed_dim=dim_feedforward // 2,
                    num_heads=self.num_heads // 2,
                    dropout=0.1,
                    drop_path=0.1,
                )
            )

        self.vita_mask_features = Conv2d(
            conv_dim,
            mask_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        weight_init.c2_xavier_fill(self.vita_mask_features)

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        self.fq_pos = nn.Embedding(num_frame_queries, hidden_dim)

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj_dec = nn.Linear(hidden_dim, hidden_dim)
        else:
            self.input_proj_dec = nn.Sequential()
        self.src_embed = nn.Identity()

        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)

        if self.use_sim:
            self.sim_embed_frame = nn.Linear(hidden_dim, hidden_dim)
            if self.sim_use_clip:
                self.sim_embed_clip = nn.Linear(hidden_dim, hidden_dim)

        self.contrastive_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
        self.cross_motion = CrossAttentionLayer(d_model=hidden_dim,
                                                nhead=nheads,
                                                dropout=0.1,
                                                normalize_before=pre_norm)

    @classmethod
    def from_config(cls, cfg, in_channels):
        """Map a detectron2 config node onto the ``__init__`` keyword arguments."""
        ret = {}
        ret["in_channels"] = in_channels

        ret["hidden_dim"] = cfg.MODEL.VITA.HIDDEN_DIM
        ret["num_frame_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        ret["num_queries"] = cfg.MODEL.VITA.NUM_OBJECT_QUERIES

        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.VITA.NHEADS
        ret["dim_feedforward"] = cfg.MODEL.VITA.DIM_FEEDFORWARD

        assert cfg.MODEL.VITA.DEC_LAYERS >= 1
        ret["enc_layers"] = cfg.MODEL.VITA.ENC_LAYERS
        ret["dec_layers"] = cfg.MODEL.VITA.DEC_LAYERS
        ret["enc_window_size"] = cfg.MODEL.VITA.ENC_WINDOW_SIZE
        ret["pre_norm"] = cfg.MODEL.VITA.PRE_NORM
        ret["enforce_input_project"] = cfg.MODEL.VITA.ENFORCE_INPUT_PROJ

        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["num_frames"] = cfg.INPUT.SAMPLING_FRAME_NUM
        ret["clip_last_layer_num"] = cfg.MODEL.VITA.LAST_LAYER_NUM

        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
        ret["sim_use_clip"] = cfg.MODEL.VITA.SIM_USE_CLIP
        ret["use_sim"] = cfg.MODEL.VITA.SIM_WEIGHT > 0.0

        return ret

    def forward(self, frame_query, lang_feat, lang_mask, motion_feat=None):
        """
        L: Number of Layers.
        B: Batch size.
        T: Temporal window size. Number of frames per video.
        C: Channel size.
        fQ: Number of frame-wise queries from IFC.
        cQ: Number of clip-wise queries to decode Q.

        Args:
            frame_query: (L, B*T, fQ, C) frame-level queries. At inference only
                the last of the L entries is used and B is taken to be 1.
            lang_feat: language token features. Assumed (B, S, C) given the
                ``repeat(L, 1, 1).transpose(0, 1)`` below; TODO confirm.
            lang_mask: (B, S) mask, non-zero for valid tokens.
            motion_feat: optional motion/language feature. It is assumed to be
                (S_m, C) and shared by every (layer, batch) column, based on the
                unsqueeze/repeat below; TODO confirm. When None, the motion
                cross-attention and the hierarchical encoder cross-attention
                are skipped.
        """
        if not self.training:
            frame_query = frame_query[[-1]]

        L, BT, fQ, C = frame_query.shape
        B = BT // self.num_frames if self.training else 1
        T = self.num_frames if self.training else BT // B

        frame_query = frame_query.reshape(L * B, T, fQ, C)
        frame_query = frame_query.permute(1, 2, 0, 3).contiguous()
        frame_query = self.input_proj_dec(frame_query)  # T, fQ, LB, C

        if self.window_size > 0:
            # Pad T up to a multiple of the window; padded frames are masked.
            pad = int(ceil(T / self.window_size)) * self.window_size - T
            _T = pad + T
            frame_query = F.pad(frame_query, (0, 0, 0, 0, 0, 0, 0, pad))  # _T, fQ, LB, C
            enc_mask = frame_query.new_ones(L * B, _T).bool()  # LB, _T
            enc_mask[:, :T] = False
        else:
            enc_mask = None

        frame_query = self.encode_frame_query(frame_query, enc_mask, motion_feat)
        frame_query = frame_query[:T].flatten(0, 1)  # TfQ, LB, C

        if self.use_sim:
            pred_fq_embed = self.sim_embed_frame(frame_query)  # TfQ, LB, C
            pred_fq_embed = pred_fq_embed.transpose(0, 1).reshape(L, B, T, fQ, C)
        else:
            pred_fq_embed = None

        src = self.src_embed(frame_query)  # TfQ, LB, C
        dec_pos = self.fq_pos.weight[None, :, None, :].repeat(T, 1, L * B, 1).flatten(0, 1)  # TfQ, LB, C

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, L * B, 1)  # cQ, LB, C
        output = self.query_feat.weight.unsqueeze(1).repeat(1, L * B, 1)  # cQ, LB, C

        if motion_feat is not None:
            # The memory batch must match the query batch (L * B). The previous
            # repeat to L only worked when B == 1.
            motion_memory = motion_feat.unsqueeze(1).repeat(1, L * B, 1)
            output_motion = self.cross_motion(output, motion_memory)
            output = output + 0.1 * output_motion

        decoder_outputs = []
        lang_feat_fusion = lang_feat.repeat(L, 1, 1).transpose(0, 1)
        lang_mask = lang_mask.repeat(L, 1)
        for i in range(self.num_layers):
            # Bidirectional fusion between frame queries and language tokens.
            src, lang_feat_fusion = self.fusion_layers[i](
                v=src,
                l=lang_feat_fusion,
                attention_mask_v=None,
                attention_mask_l=~lang_mask.bool(),
            )
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src,
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=dec_pos, query_pos=query_embed
            )
            output = self.cross_test[i](
                output,
                lang_feat_fusion,
                memory_key_padding_mask=~lang_mask.bool()
            )
            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )
            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

            if (self.training and self.aux_loss) or (i == self.num_layers - 1):
                dec_out = self.decoder_norm(output)  # cQ, LB, C
                dec_out = dec_out.transpose(0, 1)  # LB, cQ, C
                decoder_outputs.append(dec_out.view(L, B, self.num_queries, C))

        decoder_outputs = torch.stack(decoder_outputs, dim=0)  # D, L, B, cQ, C

        pred_cls = self.class_embed(decoder_outputs)
        pred_mask_embed = self.mask_embed(decoder_outputs)
        if self.use_sim and self.sim_use_clip:
            pred_cq_embed = self.sim_embed_clip(decoder_outputs)
        else:
            pred_cq_embed = [None] * self.num_layers
        pred_contrastive_embed = self.contrastive_embed(decoder_outputs)

        # Mean of the fused language features over valid tokens -> (LB, C).
        lang_feat_mask = lang_feat_fusion.transpose(0, 1) * lang_mask.unsqueeze(-1)
        lang_feat_mean = torch.sum(lang_feat_mask, dim=1) / lang_mask.sum(dim=1, keepdim=True)

        out = {
            'pred_logits': pred_cls[-1],
            'pred_mask_embed': pred_mask_embed[-1],
            # NOTE(review): this key is a single space and only the last LB row
            # is kept (i.e. one sample when B > 1). The original comment calls it
            # lang_feat_mean; kept byte-identical because consumers may index it.
            ' ': lang_feat_mean[-1],  # lang_feat_mean
            'pred_fq_embed': pred_fq_embed,
            'pred_cq_embed': pred_cq_embed[-1],
            "pred_contrastive_embed": pred_contrastive_embed[-1],
            'aux_outputs': self._set_aux_loss(
                pred_cls, pred_mask_embed, pred_cq_embed, pred_fq_embed
            )
        }
        return out

    @torch.jit.unused
    def _set_aux_loss(
        self, outputs_cls, outputs_mask_embed, outputs_cq_embed, outputs_fq_embed
    ):
        """Pack every decoder layer except the last into per-layer output dicts."""
        return [{"pred_logits": a, "pred_mask_embed": b, "pred_cq_embed": c, "pred_fq_embed": outputs_fq_embed}
                for a, b, c in zip(outputs_cls[:-1], outputs_mask_embed[:-1], outputs_cq_embed[:-1])]

    def encode_frame_query(self, frame_query, attn_mask, lang_motion_feat):
        """
        input shape (frame_query) : T, fQ, LB, C
        output shape (frame_query) : T, fQ, LB, C

        With ``window_size == 0`` this is plain self-attention + FFN over all
        T*fQ tokens; ``attn_mask`` and ``lang_motion_feat`` are unused. With
        ``window_size > 0``, T must be a multiple of the window (``forward``
        pads to guarantee this) and ``attn_mask`` is the (LB, T) padding mask.
        """
        # Not using window-based attention if self.window_size == 0.
        if self.window_size == 0:
            return_shape = frame_query.shape  # T, fQ, LB, C
            frame_query = frame_query.flatten(0, 1)  # TfQ, LB, C

            for i in range(self.enc_layers):
                frame_query = self.enc_self_attn[i](frame_query)
                frame_query = self.enc_ffn[i](frame_query)

            frame_query = frame_query.view(return_shape)
            return frame_query
        # Using window-based attention if self.window_size > 0.
        else:
            T, fQ, LB, C = frame_query.shape
            W = self.window_size
            Nw = T // W
            half_W = int(ceil(W / 2))

            # Key padding mask for the plain windows: (LB*Nw, W*fQ).
            window_mask = attn_mask.view(LB * Nw, W)[..., None].repeat(1, 1, fQ).flatten(1)

            # Additive attention mask for the shifted windows. The first and
            # last windows wrap around after the roll, so their two halves must
            # not attend to each other.
            _attn_mask = torch.roll(attn_mask, half_W, 1)
            _attn_mask = _attn_mask.view(LB, Nw, W)[..., None].repeat(1, 1, 1, W)  # LB, Nw, W, W
            _attn_mask[:, 0] = _attn_mask[:, 0] | _attn_mask[:, 0].transpose(-2, -1)
            _attn_mask[:, -1] = _attn_mask[:, -1] | _attn_mask[:, -1].transpose(-2, -1)
            _attn_mask[:, 0, :half_W, half_W:] = True
            _attn_mask[:, 0, half_W:, :half_W] = True
            _attn_mask = _attn_mask.view(LB * Nw, 1, W, 1, W, 1).repeat(1, self.num_heads, 1, fQ, 1, fQ).view(LB * Nw * self.num_heads, W * fQ, W * fQ)
            shift_window_mask = _attn_mask.float() * -1000

            for layer_idx in range(self.enc_layers):
                # Shifted windows are only used on odd layers at inference time.
                if self.training or layer_idx % 2 == 0:
                    frame_query = self._window_attn(frame_query, window_mask, layer_idx, lang_motion_feat)
                else:
                    frame_query = self._shift_window_attn(frame_query, shift_window_mask, layer_idx, lang_motion_feat)
            return frame_query

    def _window_attn(self, frame_query, attn_mask, layer_idx, lang_motion_feat):
        """One encoder layer over non-overlapping temporal windows of size W."""
        T, fQ, LB, C = frame_query.shape
        # LBN, WTfQ = attn_mask.shape

        W = self.window_size
        Nw = T // W

        frame_query = frame_query.view(Nw, W, fQ, LB, C)
        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)

        frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_key_padding_mask=attn_mask)
        if lang_motion_feat is not None:
            frame_query = frame_query.reshape(W, fQ, LB * Nw, C)
            frame_query = self.hierarchical_cross[layer_idx](frame_query, lang_motion_feat)
            frame_query = frame_query.reshape(W * fQ, LB * Nw, C)
        frame_query = self.enc_ffn[layer_idx](frame_query)
        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)

        return frame_query

    def _shift_window_attn(self, frame_query, attn_mask, layer_idx, lang_motion_feat):
        """One encoder layer over windows shifted by ceil(W / 2) frames."""
        T, fQ, LB, C = frame_query.shape
        # LBNH, WfQ, WfQ = attn_mask.shape

        W = self.window_size
        Nw = T // W
        half_W = int(ceil(W / 2))

        frame_query = torch.roll(frame_query, half_W, 0)
        frame_query = frame_query.view(Nw, W, fQ, LB, C)
        frame_query = frame_query.permute(1, 2, 3, 0, 4).reshape(W * fQ, LB * Nw, C)

        frame_query = self.enc_self_attn[layer_idx](frame_query, tgt_mask=attn_mask)
        if lang_motion_feat is not None:
            frame_query = frame_query.reshape(W, fQ, LB * Nw, C)
            frame_query = self.hierarchical_cross[layer_idx](frame_query, lang_motion_feat)
            frame_query = frame_query.reshape(W * fQ, LB * Nw, C)
        frame_query = self.enc_ffn[layer_idx](frame_query)
        frame_query = frame_query.reshape(W, fQ, LB, Nw, C).permute(3, 0, 1, 2, 4).reshape(T, fQ, LB, C)

        frame_query = torch.roll(frame_query, -half_W, 0)

        return frame_query
class HierarchicalCrossAttentionLayer(nn.Module):
    """Cross-attend temporally aligned frame queries to a memory sequence and
    pool the result hierarchically over time.

    First, the frame queries are re-ordered so that instance slot ``k`` follows
    the same object across frames. This uses Hungarian matching on cosine
    similarity between each frame and the previous aligned frame. Each aligned
    track (a sequence over time) then attends to ``memory`` in three rounds,
    and after every round the time axis is halved by pairwise averaging. The
    pooled result is added back to every frame through a tanh gate, and the
    original per-frame slot order is restored.

    ``nhead``, ``dropout``, ``activation`` and ``normalize_before`` are accepted
    for interface parity with the other layers in this file but are unused.
    """

    def __init__(self, d_model, nhead, dropout=0.0, eps=1e-8,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        self.out1 = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.out2 = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.out3 = nn.Linear(d_model, d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.eps = eps
        self.scale = d_model ** -0.5
        self.res_gate = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=False),
            nn.Tanh()
        )

    def sort_embedding(self, frame_query):
        """Align instance slots across frames.

        Args:
            frame_query: tensor of shape (T, fQ, L, C).

        Returns:
            ``(sorted_query, inverse_maps)``. ``sorted_query`` has the input's
            shape, with the fQ axis of every frame permuted to follow the
            previous aligned frame (frame 0 is kept as is).
            ``inverse_maps[j][i]`` is the index array that undoes the
            permutation applied to frame ``i`` of column ``j``.
        """
        T, fQ, L, C = frame_query.shape
        sorted_columns = []
        inverse_maps = []
        for j in range(L):
            aligned_frames = []
            inverse_maps_j = []
            for i in range(T):
                cur_embeds = frame_query[i, :, j]
                if i == 0:
                    ref_embeds = cur_embeds
                    inverse_map = np.arange(ref_embeds.size(0))
                else:
                    indices = self.match_embeds(ref_embeds, cur_embeds)
                    inverse_map = np.argsort(indices)
                    ref_embeds = cur_embeds[indices]
                aligned_frames.append(ref_embeds)
                inverse_maps_j.append(inverse_map)
            sorted_columns.append(torch.stack(aligned_frames))
            inverse_maps.append(inverse_maps_j)
        return torch.stack(sorted_columns, dim=2), inverse_maps

    def unsort_embedding(self, frame_query, inverse_maps):
        """Undo ``sort_embedding`` in place using its ``inverse_maps``.

        ``frame_query`` is (T, fQ, L, C); it is modified in place and returned.
        """
        T, fQ, L, C = frame_query.shape
        for j in range(L):
            for i in range(1, T):
                frame_query[i, :, j] = frame_query[i, :, j][inverse_maps[j][i]]
        return frame_query

    def match_embeds(self, ref_embds, cur_embds):
        """Return the permutation that aligns ``cur_embds`` to ``ref_embds``.

        Both inputs are (Q, C) matrices. The result ``indices`` is a NumPy
        integer array such that ``cur_embds[indices][k]`` is the current
        embedding assigned to ``ref_embds[k]``. The assignment minimises the
        total cost ``1 - cosine similarity`` (NaN costs are treated as 0).
        """
        ref_embds, cur_embds = ref_embds.detach(), cur_embds.detach()
        ref_embds = ref_embds / (ref_embds.norm(dim=1)[:, None] + 1e-6)
        cur_embds = cur_embds / (cur_embds.norm(dim=1)[:, None] + 1e-6)
        cos_sim = torch.mm(ref_embds, cur_embds.transpose(0, 1))  # (ref, cur)
        cost = 1 - cos_sim
        cost = cost.cpu().float()
        cost = torch.where(torch.isnan(cost), torch.full_like(cost, 0), cost)
        # Rows must be the reference slots. linear_sum_assignment returns, for
        # each row, the assigned column (a current slot), which is exactly the
        # gather index that re-orders `cur` into `ref` order. Solving on the
        # transposed matrix (rows = current slots) yields the inverse permutation.
        _, indices = linear_sum_assignment(cost.numpy())
        return indices

    def pairwise_average(self, tensor):
        """Halve dim 1 by averaging neighbouring pairs (the last element is
        duplicated when the length is odd)."""
        if tensor.shape[1] % 2 != 0:
            tensor = torch.cat([tensor, tensor.narrow(1, -1, 1)], dim=1)
        tensor = (tensor[:, ::2] + tensor[:, 1::2]) / 2
        return tensor

    def duplicate_and_trim(self, tensor, final_dim):
        """Repeat-interleave dim 1 by 2 until it reaches ``final_dim``, then trim."""
        while tensor.shape[1] < final_dim:
            tensor = torch.repeat_interleave(tensor, 2, dim=1)
        return tensor.narrow(1, 0, final_dim)

    def get_attn(self, q, k, v):
        """Batched attention normalised over the query axis.

        The softmax runs over queries (dim -2). Each query's weights are then
        renormalised by their sum over keys. Returns ``(output, weight)``,
        where ``weight`` is that per-query sum.
        """
        dots = torch.matmul(q, k.transpose(1, 2)) * self.scale
        attn = dots.softmax(dim=-2)
        weight = attn.sum(dim=-1, keepdim=True)
        attn = attn / (weight + self.eps)
        output = torch.matmul(attn, v)
        return output, weight

    def forward(self, tgt, memory):
        """
        Args:
            tgt: (T, fQ, bs, C) frame queries.
            memory: (S, C) sequence shared by every track.

        Returns:
            Tensor of the same shape as ``tgt``.

        Three rounds of halving reduce T to ceil(ceil(ceil(T/2)/2)/2). The
        pooled track feature is broadcast back over time, so T must reduce to 1
        (T <= 8).
        """
        n_frame, n_instance, bs, n_channel = tgt.shape
        tgt, inverse_maps = self.sort_embedding(tgt)
        # (T, fQ, bs, C) -> (fQ * bs, T, C): one temporal sequence per track.
        tgt = tgt.flatten(1, 2).transpose(0, 1)
        memory = memory.unsqueeze(0).repeat(n_instance * bs, 1, 1)
        k = self.to_k(memory)
        v = self.to_v(memory)

        output = self.to_q(tgt)
        stages = ((self.out1, self.norm1), (self.out2, self.norm2), (self.out3, self.norm3))
        for out_proj, norm in stages:
            q = output
            attended, _ = self.get_attn(q, k, v)
            output = norm(out_proj(attended) + q)
            output = self.pairwise_average(output)

        tgt = tgt + (self.res_gate(output) * output)
        tgt = tgt.transpose(0, 1).reshape(n_frame, n_instance, bs, n_channel)
        tgt = self.unsort_embedding(tgt, inverse_maps)
        return tgt