# SAT-HMR/models/decoder.py
# Modified from DAB-DETR (https://github.com/IDEA-Research/DAB-DETR)
import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.init import constant_

from utils.misc import inverse_sigmoid
from .position_encoding import position_encoding_xy
from xformers.ops import memory_efficient_attention, fmha

class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
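
# Hedged usage sketch (illustrative sizes, not from the original file): a 2-layer MLP
# mapping 256-d features to 4 box coordinates, with ReLU between layers.
#   mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=2)
#   out = mlp(torch.randn(10, 256))  # -> shape (10, 4)
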
class TransformerDecoder(nn.Module):
def __init__(self, d_model=512, nhead=8, num_queries=300,
num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
activation="relu",
return_intermediate_dec=False, query_dim=4,
keep_query_pos=False, query_scale_type='cond_elewise',
modulate_hw_attn=True,
bbox_embed_diff_each_layer=True,
):
super().__init__()
decoder_layer = XformerDecoderLayer(d_model, nhead, dim_feedforward,
dropout, activation, keep_query_pos=keep_query_pos)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = XformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
return_intermediate=return_intermediate_dec,
d_model=d_model, query_dim=query_dim, keep_query_pos=keep_query_pos, query_scale_type=query_scale_type,
modulate_hw_attn=modulate_hw_attn,
bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)
self._reset_parameters()
assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
self.d_model = d_model
self.nhead = nhead
self.dec_layers = num_decoder_layers
self.num_queries = num_queries
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def mask2bias(self, mask, batch_size):
if mask is None:
return None
assert mask.dtype == torch.bool
assert mask.ndim == 2
L, S = mask.shape[0], mask.shape[1]
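        # xformers' memory_efficient_attention expects the additive bias's last
        # dimension to be stride-aligned; allocate the width padded up to a
        # multiple of 8, then slice back to S so the padded strides are kept.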
pad_size = (S + 7) // 8 * 8
        bias = torch.zeros((batch_size, self.nhead, L, pad_size), device=mask.device)[:, :, :, :S]
bias.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
return bias
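
    # Hedged example (hypothetical shapes): a (L, S) boolean mask, with True marking
    # pairs that must NOT attend, becomes a (B, nhead, L, S) additive 0/-inf bias:
    #   mask = torch.zeros(300, 300, dtype=torch.bool)
    #   bias = decoder.mask2bias(mask, batch_size=2)  # -> (2, nhead, 300, 300)
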
def forward(self, memory, memory_lens, tgt, tgt_lens, refpoint_embed, pos_embed, self_attn_mask):
self_attn_bias = self.mask2bias(self_attn_mask, batch_size=len(memory_lens))
hs, references = self.decoder(memory=memory, memory_lens=memory_lens,
tgt=tgt, tgt_lens=tgt_lens,
pos=pos_embed, refpoints_unsigmoid=refpoint_embed,
self_attn_bias = self_attn_bias)
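        # hs: [num_layers, B, num_queries, d_model]; references: per-layer reference
        # boxes when a bbox head is attached, else the initial points with a leading
        # singleton layer dim.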
return hs, references
class XformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=True,
d_model=512, query_dim=4, keep_query_pos=False, query_scale_type='cond_elewise',
modulate_hw_attn=False,
bbox_embed_diff_each_layer=False,
):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
assert return_intermediate
self.query_dim = query_dim
assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise']
self.query_scale_type = query_scale_type
if query_scale_type == 'cond_elewise':
self.query_scale = MLP(d_model, d_model, d_model, 2)
elif query_scale_type == 'cond_scalar':
self.query_scale = MLP(d_model, d_model, 1, 2)
elif query_scale_type == 'fix_elewise':
self.query_scale = nn.Embedding(num_layers, d_model)
else:
raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))
self.ref_point_head = MLP(query_dim // 2 * d_model, d_model, d_model, 2)
self.bbox_embed = None
self.d_model = d_model
self.modulate_hw_attn = modulate_hw_attn
self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
if modulate_hw_attn:
self.ref_anchor_head = MLP(d_model, d_model, 2, 2)
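        # When the positional query is only injected at the first decoder layer
        # (keep_query_pos=False), the later layers' query-position projections are
        # unused, so drop them (DAB-DETR convention).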
if not keep_query_pos:
for layer_id in range(num_layers - 1):
self.layers[layer_id + 1].ca_qpos_proj = None
def forward(self, memory, memory_lens, tgt, tgt_lens,
pos: Optional[Tensor] = None,
refpoints_unsigmoid: Optional[Tensor] = None, # L_tgt, 4
self_attn_bias = None):
B, num_queries = len(tgt_lens), tgt_lens[0]
output = tgt
intermediate = []
reference_points = refpoints_unsigmoid.sigmoid()
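        # reference_points holds normalized (cx, cy, w, h) anchor boxes in [0, 1]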
ref_points = [reference_points.view(B, num_queries, self.query_dim)]
for layer_id, layer in enumerate(self.layers):
obj_center = reference_points[:, :self.query_dim] # [L_tgt, 4]
# get sine embedding for the query vector
xy_embed = position_encoding_xy(obj_center[:,0], obj_center[:,1], self.d_model)
wh_embed = position_encoding_xy(obj_center[:,2], obj_center[:,3], self.d_model)
query_sine_embed = torch.cat([xy_embed,wh_embed],dim=1) #[L_tgt, 2*d_model]
query_pos = self.ref_point_head(query_sine_embed)
# For the first decoder layer, we do not apply transformation over p_s
if self.query_scale_type != 'fix_elewise':
if layer_id == 0:
pos_transformation = 1
else:
pos_transformation = self.query_scale(output)
else:
pos_transformation = self.query_scale.weight[layer_id]
# apply transformation
query_sine_embed = query_sine_embed[:,:self.d_model] * pos_transformation
            # Modulated HW attention (DAB-DETR): rescale the positional embedding by the
            # ratio of a predicted reference size to the current anchor's width/height,
            # so the attention spread adapts to each box's size.
            if self.modulate_hw_attn:
                refHW_cond = self.ref_anchor_head(output).sigmoid()  # [L_tgt, 2]
                query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
                query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)
output = layer(memory=memory, memory_lens=memory_lens,
tgt=output, tgt_lens=tgt_lens,
pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
is_first=(layer_id == 0),
self_attn_bias = self_attn_bias)
# iter update
if self.bbox_embed is not None:
if self.bbox_embed_diff_each_layer:
tmp = self.bbox_embed[layer_id](self.norm(output))
else:
tmp = self.bbox_embed(self.norm(output))
tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
new_reference_points = tmp[..., :self.query_dim].sigmoid()
if layer_id != self.num_layers - 1:
ref_points.append(new_reference_points.view(B, num_queries, self.query_dim))
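                # detach: gradients from later layers do not flow back through the boxes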
reference_points = new_reference_points.detach()
if self.return_intermediate:
intermediate.append(self.norm(output).view(B, num_queries, self.d_model))
if self.return_intermediate:
if self.bbox_embed is not None:
return [
torch.stack(intermediate),
torch.stack(ref_points),
]
else:
return [
torch.stack(intermediate),
reference_points.unsqueeze(0)
]
return output.unsqueeze(0)
class XformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.0,
activation="relu", keep_query_pos=False):
super().__init__()
# Decoder Self-Attention
self.sa_qcontent_proj = nn.Linear(d_model, d_model)
self.sa_qpos_proj = nn.Linear(d_model, d_model)
self.sa_kcontent_proj = nn.Linear(d_model, d_model)
self.sa_kpos_proj = nn.Linear(d_model, d_model)
self.sa_v_proj = nn.Linear(d_model, d_model)
self.sa_out_proj = nn.Linear(d_model, d_model)
constant_(self.sa_out_proj.bias, 0.)
self.norm1 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
# Decoder Cross-Attention
self.ca_qcontent_proj = nn.Linear(d_model, d_model)
self.ca_qpos_proj = nn.Linear(d_model, d_model)
self.ca_kcontent_proj = nn.Linear(d_model, d_model)
self.ca_kpos_proj = nn.Linear(d_model, d_model)
self.ca_v_proj = nn.Linear(d_model, d_model)
self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
self.ca_out_proj = nn.Linear(d_model, d_model)
constant_(self.ca_out_proj.bias, 0.)
self.d_model = d_model
self.nhead = nhead
assert self.d_model%self.nhead == 0
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.keep_query_pos = keep_query_pos
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, memory, memory_lens, pos,
tgt, tgt_lens, query_pos, query_sine_embed,
is_first=False,
self_attn_bias=None):
        # self_attn_bias is only used for dn training. It is an additive bias in which
        # blocked query/key pairs are already filled with -inf (see mask2bias above).
B, num_queries = len(tgt_lens), tgt_lens[0]
L_mem, C_mem = memory.shape
L_tgt, C_tgt = tgt.shape
assert C_mem == C_tgt
# ========== Begin of Self-Attention =============
tgt_b4n = tgt
tgt = self.norm1(tgt)
q_content = self.sa_qcontent_proj(tgt)
q_pos = self.sa_qpos_proj(query_pos)
k_content = self.sa_kcontent_proj(tgt)
k_pos = self.sa_kpos_proj(query_pos)
v = self.sa_v_proj(tgt)
q = q_content + q_pos
k = k_content + k_pos
q = q.view(B, num_queries, self.nhead, self.d_model // self.nhead)
k = k.view(B, num_queries, self.nhead, self.d_model // self.nhead)
v = v.view(B, num_queries, self.nhead, self.d_model // self.nhead)
tgt2 = memory_efficient_attention(q, k, v, attn_bias=self_attn_bias)
tgt2 = self.sa_out_proj(tgt2.view(L_tgt, self.d_model))
tgt = tgt_b4n + self.dropout1(tgt2)
# ========== End of Self-Attention =============
# ========== Begin of Cross-Attention =============
tgt_b4n = tgt
tgt = self.norm2(tgt)
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)
k_pos = self.ca_kpos_proj(pos)
        # For the first decoder layer, we add the positional embedding predicted from
        # the object query to the content query (and likewise for the key), as in DETR.
if is_first or self.keep_query_pos:
q_pos = self.ca_qpos_proj(query_pos)
q = q_content + q_pos
k = k_content + k_pos
else:
q = q_content
k = k_content
q = q.view(1, L_tgt, self.nhead, self.d_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(1, L_tgt, self.nhead, self.d_model//self.nhead)
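        # Concatenate content and positional features along the head dimension (keys are
        # concatenated the same way below), so each head computes content-content and
        # position-position similarities in separate channels (Conditional/DAB-DETR style).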
q = torch.cat([q, query_sine_embed], dim=3)
k = k.view(1, L_mem, self.nhead, self.d_model//self.nhead)
k_pos = k_pos.view(1, L_mem, self.nhead, self.d_model//self.nhead)
k = torch.cat([k, k_pos], dim=3)
v = v.view(1, L_mem, self.nhead, self.d_model//self.nhead)
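        # Variable-length sequences are packed along one axis (batch size 1); the block-
        # diagonal mask keeps each query segment attending only to its own memory segment.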
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(q_seqlen = tgt_lens, kv_seqlen = memory_lens)
tgt2 = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
tgt2 = self.ca_out_proj(tgt2.view(L_tgt, self.d_model))
tgt = tgt_b4n + self.dropout2(tgt2)
# ========== End of Cross-Attention =============
# FFN
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm3(tgt)))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
if activation == "prelu":
return nn.PReLU()
if activation == "selu":
return F.selu
    raise RuntimeError(f"activation should be relu/gelu/glu/prelu/selu, not {activation}.")
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
def build_decoder(args):
return TransformerDecoder(
d_model=args.hidden_dim,
dropout=args.dropout,
nhead=args.nheads,
num_queries=args.num_queries,
dim_feedforward=args.dim_feedforward,
num_decoder_layers=args.dec_layers,
return_intermediate_dec=True,
query_dim=4,
activation=args.transformer_activation
)
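
# Hedged usage sketch (argument values are illustrative, not from the original config):
#   from types import SimpleNamespace
#   args = SimpleNamespace(hidden_dim=256, dropout=0.0, nheads=8, num_queries=300,
#                          dim_feedforward=2048, dec_layers=6, transformer_activation='relu')
#   decoder = build_decoder(args)
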
def torch_attention(query, key, value, attn_bias = None):
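    # Pure-PyTorch reference for the xformers attention used above: inputs are
    # [B, L, H, D]; transpose to [B, H, L, D], apply scaled softmax attention with
    # an optional additive bias, then transpose back.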
scale = 1.0 / query.shape[-1] ** 0.5
query = query * scale
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
attn = attn @ value
return attn.transpose(1, 2)
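

# A minimal sanity check (not part of the original file): torch_attention should match
# PyTorch's built-in scaled_dot_product_attention, which takes [B, H, L, D] inputs.
if __name__ == "__main__":
    q = torch.randn(2, 16, 8, 32)  # [B, L, H, D], illustrative sizes
    k = torch.randn(2, 24, 8, 32)
    v = torch.randn(2, 24, 8, 32)
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    assert torch.allclose(torch_attention(q, k, v), ref, atol=1e-5)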