# Source: ATCTrack-VLM/lib/models/atctrack/decoder.py
# Uploaded by SunXiang2025
# Commit message: Upload ATCTrack-VLM code and selected checkpoints
# Commit 25986db (verified)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
"""
import copy
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
import torch.nn as nn
class DecoderEmbeddings(nn.Module):
    """Embed decoder input tokens drawn from two id ranges.

    Token ids below ``vocab_size`` are looked up in ``word_embeddings``; all
    other ids are shifted down by ``vocab_size`` and looked up in
    ``prompt_embeddings``. The result is layer-normalised and passed through
    dropout.

    ``position_embeddings`` is created here but is *not* applied in
    ``forward``; it is only exposed as an attribute for other modules to read.

    Args:
        vocab_size: number of rows in the word embedding table.
        instruct_vocab_size: number of rows in the prompt embedding table.
        hidden_dim: embedding width.
        max_position_embeddings: number of rows in the position table.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, vocab_size, instruct_vocab_size, hidden_dim, max_position_embeddings, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.instruct_vocab_size = instruct_vocab_size
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, hidden_dim)
        self.prompt_embeddings = nn.Embedding(instruct_vocab_size, hidden_dim)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_dim)

        self.LayerNorm = torch.nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return embeddings of shape ``x.shape + (hidden_dim,)`` for integer ids ``x``."""
        use_word_embeddings = x < self.vocab_size
        use_prompt_embeddings = ~use_word_embeddings

        # Allocate the buffer directly on the input's device and in the
        # embedding tables' dtype. A hard-coded float32 buffer makes the
        # masked assignments below fail once the module is cast to another
        # dtype, and building it on CPU first costs an extra transfer.
        embeddings = torch.zeros(
            *x.shape, self.hidden_dim,
            dtype=self.word_embeddings.weight.dtype,
            device=x.device,
        )
        embeddings[use_word_embeddings] = self.word_embeddings(x[use_word_embeddings])
        # Prompt ids are stored after the word vocabulary; shift them to 0-based.
        embeddings[use_prompt_embeddings] = self.prompt_embeddings(
            x[use_prompt_embeddings] - self.vocab_size
        )

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class SeqTrackDecoder(nn.Module):
    """Causal transformer decoder that predicts bounding-box tokens one at a time.

    Coordinate tokens occupy ids ``0 .. bins-1``. ``self.instruct_tokens`` maps
    names to the special ids ``bins .. bins+5``; the embedding module is built
    with a word vocabulary of ``bins + 2`` ids and a prompt vocabulary of 4
    ids, which together cover that range.

    Args:
        d_model: feature width of the decoder.
        nhead: number of attention heads.
        num_decoder_layers: number of stacked decoder layers.
        dim_feedforward: hidden width of each layer's feed-forward block.
        dropout: dropout probability.
        activation: name of the feed-forward activation.
        normalize_before: use pre-norm layers when True, post-norm otherwise.
        return_intermediate_dec: return every layer's output instead of the last.
        bins: number of discrete coordinate tokens.
        num_frames: used to size the positional table,
            ``(num_coordinates + 1) * num_frames`` rows.
        instruct: when False, the first token of every input sequence is
            replaced by ``bins + 1`` before decoding.
    """

    def __init__(self, d_model=512, nhead=8,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, bins=1000, num_frames=9,
                 instruct=True):
        super().__init__()
        self.bins = bins
        self.instruct = instruct
        self.instruct_tokens = {
            'end': bins,
            'lasot': bins+1,
            'trackingnet': bins+1,
            'got10k': bins+1,
            'coco': bins+1,
            'depthtrack': bins+2,
            'lasher': bins+3,
            'visevent': bins+4,
            'otb99_lang': bins+5,
            'refcocog': bins+5,
            'tnl2k': bins+5,
            'lasot_lang': bins+5
        }
        instruct_vocab_size = 4  # should be consistent with new tokens in self.instruct_tokens
        self.num_frames = num_frames
        self.num_coordinates = 4  # [x,y,w,h]
        max_position_embeddings = (self.num_coordinates+1) * num_frames
        self.embedding = DecoderEmbeddings(bins+2, instruct_vocab_size, d_model,
                                           max_position_embeddings, dropout)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.body = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                       return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _prepare_seq(self, seq):
        """Return the token sequence to decode, without touching the caller's tensor.

        When ``instruct`` is disabled, the first token is replaced by
        ``bins + 1`` on a copy, so the input tensor is not modified in place.
        """
        if self.instruct:
            return seq
        seq = seq.clone()
        seq[:, 0] = self.bins + 1
        return seq

    def _query_embed(self, bs):
        """Positional table repeated along a new batch axis: (positions, bs, d_model)."""
        return self.embedding.position_embeddings.weight.unsqueeze(1).repeat(1, bs, 1)

    def forward(self, src, pos_embed, seq):
        """Decode the whole input sequence in one causal pass.

        Args:
            src: encoder memory; must be 3-D with the batch on axis 1.
            pos_embed: positional encoding added to the memory keys (or None).
            seq: integer token ids indexed as (batch, length).

        Returns:
            Decoder states with the batch axis moved ahead of the sequence
            axis (``hs.transpose(1, 2)``).
        """
        _, bs, _ = src.shape
        seq = self._prepare_seq(seq)

        tgt = self.embedding(seq).permute(1, 0, 2)
        query_embed = self._query_embed(bs)

        memory = src
        # Causal mask: each position attends only to itself and earlier positions.
        tgt_mask = generate_square_subsequent_mask(len(tgt)).to(tgt.device)

        hs = self.body(tgt, memory, pos=pos_embed, query_pos=query_embed[:len(tgt)],
                       tgt_mask=tgt_mask, memory_mask=None)
        return hs.transpose(1, 2)

    def inference(self, src, pos_embed, seq, vocab_embed,
                  window, seq_format):
        """Greedily generate ``num_coordinates`` box tokens.

        Args:
            src: encoder memory; must be 3-D with the batch on axis 1.
            pos_embed: positional encoding; sliced to the memory length.
            seq: integer prompt token ids indexed as (batch, length).
            vocab_embed: callable mapping the last decoder state of each
                sample to vocabulary logits.
            window: optional penalty multiplied onto the confidences of the
                centre coordinates; pass None to disable. Assumed to be
                broadcastable against a (batch, bins) tensor - confirm
                against the caller.
            seq_format: ``'whxy'`` puts the centre coordinates at steps 2 and
                3; any other value puts them at steps 0 and 1.

        Returns:
            dict with ``'pred_boxes'`` (the last ``num_coordinates`` tokens of
            the generated sequence) and ``'confidence'`` (the softmax score of
            each chosen token, one column per step).
        """
        seq = self._prepare_seq(seq)
        _, bs, _ = src.shape
        memory = src

        box_pos = (0, 1, 2, 3)  # steps that emit a bounding-box coordinate
        center_pos = (2, 3) if seq_format == 'whxy' else (0, 1)  # x_center / y_center steps

        # Neither of these changes between decoding steps, so build them once.
        query_embed = self._query_embed(bs)
        memory_pos = pos_embed[:len(memory)]

        confidence_list = []
        # Only num_coordinates steps: the end token is never predicted at inference.
        for i in range(self.num_coordinates):
            tgt = self.embedding(seq).permute(1, 0, 2)
            tgt_mask = generate_square_subsequent_mask(len(tgt)).to(tgt.device)

            hs = self.body(tgt, memory, pos=memory_pos, query_pos=query_embed[:len(tgt)],
                           tgt_mask=tgt_mask, memory_mask=None)

            # embedding --> likelihood, using the state of the last position only
            out = vocab_embed(hs.transpose(1, 2)[-1, :, -1, :])
            out = out.softmax(-1)

            if i in box_pos:
                out = out[:, :self.bins]  # keep only the coordinate tokens' confidence
            # Identity check: comparing a tensor/array against None by value is
            # unreliable and can raise for array-like windows.
            if i in center_pos and window is not None:
                out = out * window  # window penalty

            confidence, token_generated = out.topk(dim=-1, k=1)
            seq = torch.cat([seq, token_generated], dim=-1)
            confidence_list.append(confidence)

        return {
            # Discard the prompt tokens, keep only the generated bounding box.
            'pred_boxes': seq[:, -self.num_coordinates:],
            'confidence': torch.cat(confidence_list, dim=-1),
        }
def generate_square_subsequent_mask(sz):
    r"""Build the additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``float('-inf')`` when ``j > i`` (a later position,
    which must stay hidden) and ``0.0`` otherwise, so every token can only
    attend to itself and the tokens before it.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return blocked.triu(diagonal=1)
class TransformerDecoder(nn.Module):
    """Stack of identical decoder layers with an optional final normalisation.

    Args:
        decoder_layer: layer prototype; it is deep-copied ``num_layers`` times.
        num_layers: number of layers in the stack.
        norm: optional module applied to the output (and to every recorded
            intermediate output).
        return_intermediate: when True, ``forward`` returns the output of
            every layer stacked on a new leading axis instead of only the
            last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Run ``tgt`` through every layer, attending to ``memory``.

        Returns:
            ``torch.stack`` of the per-layer outputs when
            ``return_intermediate`` is set, otherwise the final output with a
            singleton leading axis (``output.unsqueeze(0)``).
        """
        output = tgt
        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                # Without a norm module, record the raw layer output rather
                # than calling None.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # The last recorded entry is the final output; reuse the
                # tensor computed just above.
                intermediate[-1] = output

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, then a feed-forward block.

    Each of the three sub-blocks is wrapped in a residual connection with
    dropout. ``normalize_before`` selects whether LayerNorm is applied to the
    input of each sub-block (pre-norm) or to the residual sum (post-norm).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attention sub-blocks.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_attend(self, x, query_pos, tgt_mask, tgt_key_padding_mask):
        """Self-attention over ``x``; positions are added to queries and keys only."""
        q = k = self.with_pos_embed(x, query_pos)
        attended = self.self_attn(q, k, x, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)
        return attended[0]

    def _cross_attend(self, x, memory, pos, query_pos, memory_mask, memory_key_padding_mask):
        """Attend from ``x`` to ``memory``; the values are the raw memory."""
        attended = self.multihead_attn(self.with_pos_embed(x, query_pos),
                                       self.with_pos_embed(memory, pos),
                                       memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)
        return attended[0]

    def _feed_forward(self, x):
        """linear1 -> activation -> dropout -> linear2."""
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: LayerNorm is applied after each residual sum."""
        sa_out = self._self_attend(tgt, query_pos, tgt_mask, tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(sa_out))

        ca_out = self._cross_attend(tgt, memory, pos, query_pos,
                                    memory_mask, memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(ca_out))

        ff_out = self._feed_forward(tgt)
        return self.norm3(tgt + self.dropout3(ff_out))

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm is applied to the input of each sub-block."""
        sa_out = self._self_attend(self.norm1(tgt), query_pos,
                                   tgt_mask, tgt_key_padding_mask)
        tgt = tgt + self.dropout1(sa_out)

        ca_out = self._cross_attend(self.norm2(tgt), memory, pos, query_pos,
                                    memory_mask, memory_key_padding_mask)
        tgt = tgt + self.dropout2(ca_out)

        ff_out = self._feed_forward(self.norm3(tgt))
        return tgt + self.dropout3(ff_out)

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre-norm or post-norm path according to ``normalize_before``."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt, memory, tgt_mask, memory_mask,
                   tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
def build_decoder(cfg):
    """Construct a ``SeqTrackDecoder`` from the values stored in ``cfg``.

    Reads ``cfg.MODEL.HIDDEN_DIM``, ``cfg.MODEL.BINS``,
    ``cfg.DATA.SEARCH.NUMBER`` and the ``cfg.MODEL.DECODER`` entries
    ``DROPOUT``, ``NHEADS``, ``DIM_FEEDFORWARD``, ``DEC_LAYERS``, ``PRE_NORM``
    and ``INSTRUCT``. Intermediate decoder outputs are always disabled.
    """
    decoder_kwargs = {
        "d_model": cfg.MODEL.HIDDEN_DIM,
        "dropout": cfg.MODEL.DECODER.DROPOUT,
        "nhead": cfg.MODEL.DECODER.NHEADS,
        "dim_feedforward": cfg.MODEL.DECODER.DIM_FEEDFORWARD,
        "num_decoder_layers": cfg.MODEL.DECODER.DEC_LAYERS,
        "normalize_before": cfg.MODEL.DECODER.PRE_NORM,
        "return_intermediate_dec": False,
        "bins": cfg.MODEL.BINS,
        "num_frames": cfg.DATA.SEARCH.NUMBER,
        "instruct": cfg.MODEL.DECODER.INSTRUCT,
    }
    return SeqTrackDecoder(**decoder_kwargs)
def _get_activation_fn(activation):
    """Return the ``torch.nn.functional`` activation named by ``activation``.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    supported = {"relu": F.relu, "gelu": F.gelu, "glu": F.glu}
    try:
        return supported[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are also unsupported.
        # The message lists every accepted name, including "glu".
        raise RuntimeError(
            f"activation should be relu/gelu/glu, not {activation}."
        ) from None