|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
Transformer class
|
|
|
"""
|
|
|
import math
|
|
|
import copy
|
|
|
from typing import Optional
|
|
|
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from torch import nn, Tensor
|
|
|
|
|
|
from rfdetr.models.ops.modules import MSDeformAttn
|
|
|
|
|
|
class MLP(nn.Module):
    """Simple multi-layer perceptron (also called a feed-forward network, FFN).

    Stacks ``num_layers`` linear layers; ReLU follows every layer except the
    last one, which stays linear so the output range is unconstrained.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Per-layer sizes: input -> hidden * (num_layers - 1) -> output.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1]) for i in range(num_layers)
        )

    def forward(self, x):
        last = self.num_layers - 1
        for idx, fc in enumerate(self.layers):
            x = fc(x)
            if idx < last:
                x = F.relu(x)
        return x
|
|
|
|
|
|
|
|
|
def gen_sineembed_for_position(pos_tensor, dim=128):
    """Sinusoidal embedding of normalized box coordinates (DETR-style).

    Args:
        pos_tensor: 3D tensor whose last dimension is 2 (x, y) or 4
            (x, y, w, h); values are multiplied by 2*pi before embedding.
            (First two dims assumed (n_query, batch) — order does not affect
            the math.)
        dim: number of embedding channels produced per coordinate.

    Returns:
        Tensor with last dimension 2*dim (order: y, x) or 4*dim
        (order: y, x, w, h), interleaving sin/cos per coordinate.

    Raises:
        ValueError: if the last dimension is neither 2 nor 4.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(dim, dtype=pos_tensor.dtype, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / dim)

    def embed(coord):
        # coord: (n_query, bs) -> (n_query, bs, dim), interleaved sin/cos.
        freq = (coord * scale)[:, :, None] / dim_t
        return torch.stack(
            (freq[:, :, 0::2].sin(), freq[:, :, 1::2].cos()), dim=3
        ).flatten(2)

    pos_x = embed(pos_tensor[:, :, 0])
    pos_y = embed(pos_tensor[:, :, 1])

    n_coords = pos_tensor.size(-1)
    if n_coords == 2:
        return torch.cat((pos_y, pos_x), dim=2)
    if n_coords == 4:
        pos_w = embed(pos_tensor[:, :, 2])
        pos_h = embed(pos_tensor[:, :, 3])
        return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    raise ValueError("Unknown pos_tensor shape(-1):{}".format(n_coords))
|
|
|
|
|
|
|
|
|
def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True):
    r"""Build one box proposal per encoder token plus a masked copy of memory.

    Every spatial location of every feature level becomes a proposal centered
    on that cell, with a level-dependent size (0.05 * 2**lvl).

    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw} (True = padded) or None
        - spatial_shapes: nlevel, 2
        - unsigmoid: if True, return proposals in logit space with invalid
          entries set to +inf (so sigmoid saturates to 1); if False, keep
          proposals in [0, 1] with invalid entries zeroed.
    Output:
        - output_memory: bs, \sum{hw}, d_model (padded/invalid positions zeroed)
        - output_proposals: bs, \sum{hw}, 4 (cx, cy, w, h)
    """
    N_, S_, C_ = memory.shape
    proposals = []
    _cur = 0  # running offset of the current level in the flattened sequence
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        # Count valid (non-padded) rows/cols per image so proposal centers are
        # normalized by the valid extent rather than the padded one.
        if memory_padding_mask is not None:
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        else:
            valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
            valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)

        # indexing='ij' matches the historical default and avoids the
        # deprecation warning raised when the argument is omitted (torch>=1.10,
        # which this file already requires via torch.concat).
        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
                                        indexing='ij')
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        # +0.5 shifts indices to cell centers; division normalizes into (0, 1).
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # Proposal size doubles at each (coarser) feature level.
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)

        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)

    output_proposals = torch.cat(proposals, 1)
    # A proposal is valid only when all 4 coords are well inside (0, 1).
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)

    if unsigmoid:
        # Inverse sigmoid; +inf marks padded/invalid entries so a later
        # sigmoid() pins them to 1.
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        if memory_padding_mask is not None:
            output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
    else:
        if memory_padding_mask is not None:
            output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float(0))

    output_memory = memory
    if memory_padding_mask is not None:
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))

    return output_memory.to(memory.dtype), output_proposals.to(memory.dtype)
|
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Decoder-only transformer for detection.

    No transformer encoder is built (``self.encoder`` is ``None``); flattened
    multi-scale backbone features are consumed directly by a deformable
    cross-attention decoder.  With ``two_stage=True`` the memory tokens are
    scored and the top-k are converted into extra reference boxes and query
    features (Deformable-DETR-style two-stage selection).
    """

    def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, group_detr=1,
                 two_stage=False,
                 num_feature_levels=4, dec_n_points=4,
                 lite_refpoint_refine=False,
                 decoder_norm_type='LN',
                 bbox_reparam=False):
        """
        Args:
            d_model: transformer feature width.
            sa_nhead: heads for decoder self-attention.
            ca_nhead: heads for deformable cross-attention.
            num_queries: number of object queries (per group).
            group_detr: number of parallel query groups used during training.
            two_stage: enable encoder-token top-k query selection.
            lite_refpoint_refine: passed to the decoder; reference points are
                computed once instead of per layer.
            decoder_norm_type: final decoder norm, 'LN' or 'Identity'.
            bbox_reparam: box deltas are relative (scaled shift, exp() size)
                rather than additive in logit space.
        """
        super().__init__()
        # Kept for interface parity with encoder-decoder variants; never built.
        self.encoder = None

        decoder_layer = TransformerDecoderLayer(d_model, sa_nhead, ca_nhead, dim_feedforward,
                                                dropout, activation, normalize_before,
                                                group_detr=group_detr,
                                                num_feature_levels=num_feature_levels,
                                                dec_n_points=dec_n_points,
                                                skip_self_attn=False,)
        assert decoder_norm_type in ['LN', 'Identity']
        norm = {
            "LN": lambda channels: nn.LayerNorm(channels),
            "Identity": lambda channels: nn.Identity(),
        }
        decoder_norm = norm[decoder_norm_type](d_model)

        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec,
                                          d_model=d_model,
                                          lite_refpoint_refine=lite_refpoint_refine,
                                          bbox_reparam=bbox_reparam)

        self.two_stage = two_stage
        if two_stage:
            # One projection + norm per query group for the two-stage branch.
            self.enc_output = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(group_detr)])
            self.enc_output_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(group_detr)])

        self._reset_parameters()

        self.num_queries = num_queries
        self.d_model = d_model
        self.dec_layers = num_decoder_layers
        self.group_detr = group_detr
        self.num_feature_levels = num_feature_levels
        self.bbox_reparam = bbox_reparam

        # Toggled by export(); inference paths may branch on it.
        self._export = False

    def export(self):
        """Switch the module into export (e.g. tracing) mode."""
        self._export = True

    def _reset_parameters(self):
        # Xavier init for all matrix-shaped parameters, then let each
        # deformable attention module run its own specialized init.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    def get_valid_ratio(self, mask):
        """Per-axis fraction of each image that is valid (non-padded).

        Args:
            mask: (bs, H, W) bool padding mask, True = padded.
        Returns:
            (bs, 2) tensor of (width_ratio, height_ratio).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
        """Flatten multi-scale features and run the decoder.

        Args:
            srcs: list of per-level feature maps, each (bs, c, h, w).
            masks: list of per-level padding masks (bs, h, w), or None.
            pos_embeds: list of per-level positional encodings matching srcs.
            refpoint_embed: learned reference boxes, (num_queries, 4).
            query_feat: learned query features, (num_queries, d_model).

        Returns:
            (hs, references, memory_ts, boxes_ts); the last two are None
            unless two-stage selection is enabled.
        """
        # --- flatten every level into one (bs, sum(h*w), c) sequence ---
        src_flatten = []
        mask_flatten = [] if masks is not None else None
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        valid_ratios = [] if masks is not None else None
        for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed_flatten.append(pos_embed)
            src_flatten.append(src)
            if masks is not None:
                mask = masks[lvl].flatten(1)
                mask_flatten.append(mask)
        memory = torch.cat(src_flatten, 1)
        if masks is not None:
            mask_flatten = torch.cat(mask_flatten, 1)
            valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
        # Start offset of each level inside the flattened sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))

        if self.two_stage:
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam)

            refpoint_embed_ts, memory_ts, boxes_ts = [], [], []
            # At eval time only the first query group is used.
            group_detr = self.group_detr if self.training else 1
            for g_idx in range(group_detr):
                output_memory_gidx = self.enc_output_norm[g_idx](self.enc_output[g_idx](output_memory))

                # NOTE(review): enc_out_class_embed / enc_out_bbox_embed are not
                # defined in __init__; they appear to be attached externally by
                # the model builder — confirm before refactoring.
                enc_outputs_class_unselected_gidx = self.enc_out_class_embed[g_idx](output_memory_gidx)
                if self.bbox_reparam:
                    # Deltas are relative to the proposal: the center shift is
                    # scaled by proposal w/h; the size update is multiplicative.
                    enc_outputs_coord_delta_gidx = self.enc_out_bbox_embed[g_idx](output_memory_gidx)
                    enc_outputs_coord_cxcy_gidx = enc_outputs_coord_delta_gidx[...,
                        :2] * output_proposals[..., 2:] + output_proposals[..., :2]
                    enc_outputs_coord_wh_gidx = enc_outputs_coord_delta_gidx[..., 2:].exp() * output_proposals[..., 2:]
                    enc_outputs_coord_unselected_gidx = torch.concat(
                        [enc_outputs_coord_cxcy_gidx, enc_outputs_coord_wh_gidx], dim=-1)
                else:
                    # Additive update in unsigmoid (logit) space.
                    enc_outputs_coord_unselected_gidx = self.enc_out_bbox_embed[g_idx](
                        output_memory_gidx) + output_proposals

                # Keep the top-k tokens ranked by their best class logit.
                topk = min(self.num_queries, enc_outputs_class_unselected_gidx.shape[-2])
                topk_proposals_gidx = torch.topk(enc_outputs_class_unselected_gidx.max(-1)[0], topk, dim=1)[1]

                refpoint_embed_gidx_undetach = torch.gather(
                    enc_outputs_coord_unselected_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, 4))

                # Detach: selection is supervised through the undetached
                # (auxiliary) outputs, not through the decoder path.
                refpoint_embed_gidx = refpoint_embed_gidx_undetach.detach()

                tgt_undetach_gidx = torch.gather(
                    output_memory_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, self.d_model))

                refpoint_embed_ts.append(refpoint_embed_gidx)
                memory_ts.append(tgt_undetach_gidx)
                boxes_ts.append(refpoint_embed_gidx_undetach)

            # Concatenate all groups along the query dimension.
            refpoint_embed_ts = torch.cat(refpoint_embed_ts, dim=1)

            memory_ts = torch.cat(memory_ts, dim=1)
            boxes_ts = torch.cat(boxes_ts, dim=1)

        # Learned queries, broadcast over the batch.
        tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1)
        refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
        if self.two_stage:
            # The first ts_len learned refpoints act as refinements of the
            # two-stage boxes; the remainder are used as-is.
            ts_len = refpoint_embed_ts.shape[-2]
            refpoint_embed_ts_subset = refpoint_embed[..., :ts_len, :]
            refpoint_embed_subset = refpoint_embed[..., ts_len:, :]

            if self.bbox_reparam:
                refpoint_embed_cxcy = refpoint_embed_ts_subset[..., :2] * refpoint_embed_ts[..., 2:]
                refpoint_embed_cxcy = refpoint_embed_cxcy + refpoint_embed_ts[..., :2]
                refpoint_embed_wh = refpoint_embed_ts_subset[..., 2:].exp() * refpoint_embed_ts[..., 2:]
                refpoint_embed_ts_subset = torch.concat(
                    [refpoint_embed_cxcy, refpoint_embed_wh], dim=-1
                )
            else:
                refpoint_embed_ts_subset = refpoint_embed_ts_subset + refpoint_embed_ts

            refpoint_embed = torch.concat(
                [refpoint_embed_ts_subset, refpoint_embed_subset], dim=-2)

        hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten,
                                      pos=lvl_pos_embed_flatten, refpoints_unsigmoid=refpoint_embed,
                                      level_start_index=level_start_index,
                                      spatial_shapes=spatial_shapes,
                                      valid_ratios=valid_ratios.to(memory.dtype) if valid_ratios is not None else valid_ratios)
        if self.two_stage:
            if self.bbox_reparam:
                return hs, references, memory_ts, boxes_ts
            else:
                # boxes_ts are logits in this branch; return probabilities.
                return hs, references, memory_ts, boxes_ts.sigmoid()
        return hs, references, None, None
|
|
|
|
|
|
|
|
|
class TransformerDecoder(nn.Module):
    """Stack of decoder layers with iterative reference-point refinement.

    After each layer the reference boxes are refined via ``self.bbox_embed``,
    unless ``lite_refpoint_refine`` is set, in which case the query positional
    embeddings are computed once up front and boxes stay fixed across layers.

    NOTE(review): ``self.bbox_embed`` is not created here; it appears to be
    attached externally after construction — confirm against the model builder.
    """

    def __init__(self,
                 decoder_layer,
                 num_layers,
                 norm=None,
                 return_intermediate=False,
                 d_model=256,
                 lite_refpoint_refine=False,
                 bbox_reparam=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.d_model = d_model
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.lite_refpoint_refine = lite_refpoint_refine
        self.bbox_reparam = bbox_reparam

        # Maps the sinusoidal box embedding (2*d_model) to a query pos embed.
        self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)

        # Toggled by export(); skips valid-ratio scaling in get_reference.
        self._export = False

    def export(self):
        """Switch into export (tracing) mode."""
        self._export = True

    def refpoints_refine(self, refpoints_unsigmoid, new_refpoints_delta):
        """Apply a per-layer box delta to the current reference boxes.

        With ``bbox_reparam`` the delta is relative to the box (center shift
        scaled by w/h, multiplicative exp() size update); otherwise it is a
        plain additive update in unsigmoid (logit) space.
        """
        if self.bbox_reparam:
            new_refpoints_cxcy = new_refpoints_delta[..., :2] * refpoints_unsigmoid[..., 2:] + refpoints_unsigmoid[..., :2]
            new_refpoints_wh = new_refpoints_delta[..., 2:].exp() * refpoints_unsigmoid[..., 2:]
            new_refpoints_unsigmoid = torch.concat(
                [new_refpoints_cxcy, new_refpoints_wh], dim=-1
            )
        else:
            new_refpoints_unsigmoid = refpoints_unsigmoid + new_refpoints_delta
        return new_refpoints_unsigmoid

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,

                level_start_index: Optional[Tensor] = None,
                spatial_shapes: Optional[Tensor] = None,
                valid_ratios: Optional[Tensor] = None):
        """Run every decoder layer, refining reference boxes between layers.

        Returns either (stacked intermediate outputs, stacked refpoints) when
        ``return_intermediate`` is set, the last layer's (hs, ref) pair in
        export mode, or the final output alone otherwise.
        """
        output = tgt

        intermediate = []
        # Reference boxes fed into each layer (index 0 = initial boxes).
        hs_refpoints_unsigmoid = [refpoints_unsigmoid]

        def get_reference(refpoints):
            # Derive (obj_center, per-level reference points, query positional
            # embedding, sinusoidal embedding) from the current boxes.
            obj_center = refpoints[..., :4]

            if self._export:
                # Export path: no padding assumed, valid ratios treated as 1.
                query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model / 2)
                refpoints_input = obj_center[:, :, None]
            else:
                # Scale boxes by per-level valid ratios for deformable attention.
                refpoints_input = obj_center[:, :, None] \
                    * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                query_sine_embed = gen_sineembed_for_position(
                    refpoints_input[:, :, 0, :], self.d_model / 2)
            query_pos = self.ref_point_head(query_sine_embed)
            return obj_center, refpoints_input, query_pos, query_sine_embed

        # Lite mode: compute references/query-pos once and reuse in all layers.
        if self.lite_refpoint_refine:
            if self.bbox_reparam:
                obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
            else:
                obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())

        for layer_id, layer in enumerate(self.layers):

            if not self.lite_refpoint_refine:
                if self.bbox_reparam:
                    obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
                else:
                    obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())

            # Placeholder scaling of the positional embedding (always 1: no-op).
            pos_transformation = 1

            query_pos = query_pos * pos_transformation

            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0),
                           reference_points=refpoints_input,
                           spatial_shapes=spatial_shapes,
                           level_start_index=level_start_index)

            if not self.lite_refpoint_refine:
                # Iterative refinement: predict a delta from the layer output,
                # then detach so gradients do not flow through coordinates.
                new_refpoints_delta = self.bbox_embed(output)
                new_refpoints_unsigmoid = self.refpoints_refine(refpoints_unsigmoid, new_refpoints_delta)
                if layer_id != self.num_layers - 1:
                    hs_refpoints_unsigmoid.append(new_refpoints_unsigmoid)
                refpoints_unsigmoid = new_refpoints_unsigmoid.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last intermediate with the freshly normed output.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            if self._export:
                # Export mode: only the last layer's outputs are needed.
                hs = intermediate[-1]
                if self.bbox_embed is not None:
                    ref = hs_refpoints_unsigmoid[-1]
                else:
                    ref = refpoints_unsigmoid
                return hs, ref

            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate),
                    torch.stack(hs_refpoints_unsigmoid),
                ]
            else:
                return [
                    torch.stack(intermediate),
                    refpoints_unsigmoid.unsqueeze(0)
                ]

        return output.unsqueeze(0)
|
|
|
|
|
|
|
|
|
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, deformable cross-attention, FFN.

    Self-attention supports Group-DETR: during training the queries are split
    into ``group_detr`` groups that attend only within their own group.
    """

    def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, group_detr=1,
                 num_feature_levels=4, dec_n_points=4,
                 skip_self_attn=False):
        super().__init__()

        # Decoder self-attention (queries attend to each other).
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Deformable cross-attention from queries into multi-scale memory.
        self.cross_attn = MSDeformAttn(
            d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points)

        self.nhead = ca_nhead

        # Feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        # NOTE(review): normalize_before and skip_self_attn are accepted but
        # only the post-norm, with-self-attn path is implemented here.
        self.normalize_before = normalize_before
        self.group_detr = group_detr

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add a positional embedding to the tensor (pass-through when None)."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     query_sine_embed = None,
                     is_first = False,
                     reference_points = None,
                     spatial_shapes=None,
                     level_start_index=None,
                     ):
        """Post-norm layer body: self-attn -> cross-attn -> FFN, each with a
        residual connection followed by LayerNorm."""
        bs, num_queries, _ = tgt.shape

        # --- self-attention: positional embedding added to q/k, not to v ---
        q = k = tgt + query_pos
        v = tgt
        if self.training:
            # Group-DETR: fold the query groups into the batch dimension so
            # each group self-attends independently.
            q = torch.cat(q.split(num_queries // self.group_detr, dim=1), dim=0)
            k = torch.cat(k.split(num_queries // self.group_detr, dim=1), dim=0)
            v = torch.cat(v.split(num_queries // self.group_detr, dim=1), dim=0)

        tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask,
                              need_weights=False)[0]

        if self.training:
            # Restore the original (bs, num_queries, d_model) layout.
            tgt2 = torch.cat(tgt2.split(bs, dim=0), dim=1)

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # --- deformable cross-attention into the flattened memory ---
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            memory,
            spatial_shapes,
            level_start_index,
            memory_key_padding_mask
        )

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # --- feed-forward network ---
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                query_sine_embed = None,
                is_first = False,
                reference_points = None,
                spatial_shapes=None,
                level_start_index=None):
        """Dispatch to the post-norm implementation (the only one provided)."""
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,
                                 query_sine_embed, is_first,
                                 reference_points, spatial_shapes, level_start_index)
|
|
|
|
|
|
|
|
|
def _get_clones(module, N):
|
|
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
|
|
|
|
|
|
|
def build_transformer(args):
    """Build a Transformer from an argparse-style config object.

    Args:
        args: namespace providing hidden_dim, sa_nheads, ca_nheads,
            num_queries, dropout, dim_feedforward, dec_layers, group_detr,
            num_feature_levels, dec_n_points, lite_refpoint_refine,
            decoder_norm and bbox_reparam; ``two_stage`` is optional.

    Returns:
        A configured Transformer instance.
    """
    # `two_stage` may be absent from older configs; default it to False.
    # (getattr replaces a bare `except:` that also swallowed unrelated
    # exceptions such as KeyboardInterrupt.)
    two_stage = getattr(args, 'two_stage', False)

    return Transformer(
        d_model=args.hidden_dim,
        sa_nhead=args.sa_nheads,
        ca_nhead=args.ca_nheads,
        num_queries=args.num_queries,
        dropout=args.dropout,
        dim_feedforward=args.dim_feedforward,
        num_decoder_layers=args.dec_layers,
        return_intermediate_dec=True,
        group_detr=args.group_detr,
        two_stage=two_stage,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        lite_refpoint_refine=args.lite_refpoint_refine,
        decoder_norm_type=args.decoder_norm,
        bbox_reparam=args.bbox_reparam,
    )
|
|
|
|
|
|
|
|
|
def _get_activation_fn(activation):
|
|
|
"""Return an activation function given a string"""
|
|
|
if activation == "relu":
|
|
|
return F.relu
|
|
|
if activation == "gelu":
|
|
|
return F.gelu
|
|
|
if activation == "glu":
|
|
|
return F.glu
|
|
|
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
|
|
|