# PengLiu — push inference code (commit 56ef371)
import math
from typing import Dict, List
import torch
import torch.nn as nn
from detect_tools.upn import ARCHITECTURES, build_decoder, build_encoder
from detect_tools.upn.models.utils import (gen_encoder_output_proposals,
inverse_sigmoid)
from detect_tools.upn.ops.modules import MSDeformAttn
@ARCHITECTURES.register_module()
class DeformableTransformer(nn.Module):
    """Implementation of Deformable DETR.

    Args:
        encoder_cfg (Dict): Config for the TransformerEncoder.
        decoder_cfg (Dict): Config for the TransformerDecoder.
        mask_decoder_cfg (Dict): Optional config for a standalone mask decoder.
            Default: None.
        num_queries (int): Number of queries. This is for the matching part. Default: 900.
        d_model (int): Dimension of the model. Default: 256.
        num_feature_levels (int): Number of feature levels. Default: 4.
        binary_query_selection (bool): Whether to use binary query selection. Default: False.
            When using binary query selection, a linear layer with out channel = 1 is used to
            select topk proposals. Otherwise, ContrastiveAssign is used to select topk proposals.
        learnable_tgt_init (bool): Whether to use learnable target init. Default: True. If False,
            the topk encoder features are used as the target init.
        random_refpoints_xy (bool): Whether to use random refpoints xy. This is only used when
            two_stage_type is 'no'. Default: False.
        two_stage_type (str): Type of two stage. Default: 'standard'. Options: 'no', 'standard'.
        two_stage_learn_wh (bool): Whether to learn the width and height of anchor boxes.
            Default: False.
        two_stage_keep_all_tokens (bool): If False, the returned hs_enc, ref_enc,
            init_box_proposal will only be the topk proposals. Otherwise, all the proposals
            from the encoder are returned. Default: False.
        two_stage_bbox_embed_share (bool): Whether to share the bbox embedding between the two
            stages. Default: False.
        two_stage_class_embed_share (bool): Whether to share the class embedding between the two
            stages.
        rm_self_attn_layers (List[int]): The indices of the decoder layers whose self-attention
            is removed. Default: None.
        rm_detach (List[str]): Which detach operations to remove in the decoder. Valid entries
            are 'enc_ref', 'enc_tgt' and 'dec'. Default: None.
        with_encoder_out (bool): Whether to build the encoder-output projection used by the
            two-stage proposal head. Default: True.
    """

    def __init__(
        self,
        encoder_cfg: Dict,
        decoder_cfg: Dict,
        mask_decoder_cfg: Dict = None,
        num_queries: int = 900,
        d_model: int = 256,
        num_feature_levels: int = 4,
        binary_query_selection: bool = False,
        # init query (target)
        learnable_tgt_init=True,
        random_refpoints_xy=False,
        # for two stage
        two_stage_type: str = "standard",
        two_stage_learn_wh: bool = False,
        two_stage_keep_all_tokens: bool = False,
        two_stage_bbox_embed_share: bool = False,
        two_stage_class_embed_share: bool = False,
        # evo of #anchors
        rm_self_attn_layers: List[int] = None,
        # for detach
        rm_detach: List[str] = None,
        with_encoder_out: bool = True,
    ) -> None:
        super().__init__()
        self.binary_query_selection = binary_query_selection
        self.num_queries = num_queries
        self.num_feature_levels = num_feature_levels
        self.rm_self_attn_layers = rm_self_attn_layers
        self.d_model = d_model
        self.two_stage_bbox_embed_share = two_stage_bbox_embed_share
        self.two_stage_class_embed_share = two_stage_class_embed_share
        if self.binary_query_selection:
            # Single-logit head used to score encoder tokens for top-k selection.
            self.binary_query_selection_layer = nn.Linear(d_model, 1)

        # build encoder and decoder
        self.encoder = build_encoder(encoder_cfg)
        self.decoder = build_decoder(decoder_cfg)
        self.num_decoder_layers = self.decoder.num_layers

        # build sole mask decoder (optional)
        if mask_decoder_cfg is not None:
            self.mask_decoder = build_decoder(mask_decoder_cfg)
        else:
            self.mask_decoder = None

        # level embedding: one learnable vector per feature level.
        # Define the attribute in both branches so _reset_parameters can read it safely.
        if num_feature_levels > 1:
            self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
        else:
            self.level_embed = None

        # learnable target (query content) embedding
        self.learnable_tgt_init = learnable_tgt_init
        assert learnable_tgt_init, "only learnable_tgt_init=True is supported"
        self.tgt_embed = nn.Embedding(num_queries, d_model)
        nn.init.normal_(self.tgt_embed.weight.data)

        # for two stage
        self.two_stage_type = two_stage_type
        self.two_stage_learn_wh = two_stage_learn_wh
        self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
        assert two_stage_type in [
            "no",
            "standard",
        ], "unknown param {} of two_stage_type".format(two_stage_type)
        self.with_encoder_out = with_encoder_out
        if two_stage_type == "standard":
            # anchor selection at the output of encoder
            if with_encoder_out:
                self.enc_output = nn.Linear(d_model, d_model)
                self.enc_output_norm = nn.LayerNorm(d_model)
            if two_stage_learn_wh:
                # one learnable (w, h) prior, stored in logit space (see _reset_parameters)
                self.two_stage_wh_embedding = nn.Embedding(1, 2)
            else:
                self.two_stage_wh_embedding = None
        elif two_stage_type == "no":
            # init self.refpoint_embed
            self.init_ref_points(num_queries, random_refpoints_xy)

        # These two heads are initialized outside of the model.
        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None

        self._reset_parameters()

        # optionally remove self-attention in the requested decoder layers
        if rm_self_attn_layers is not None:
            print(
                "Removing the self-attn in {} decoder layers".format(
                    rm_self_attn_layers
                )
            )
            for lid, dec_layer in enumerate(self.decoder.layers):
                if lid in rm_self_attn_layers:
                    dec_layer.rm_self_attn_modules()

        self.rm_detach = rm_detach
        if self.rm_detach:
            assert isinstance(rm_detach, list)
            # every requested detach-removal must be a known option
            assert all(i in ["enc_ref", "enc_tgt", "dec"] for i in rm_detach)
            self.decoder.rm_detach = rm_detach

    def _reset_parameters(self):
        """Initialize weights: xavier for matrices, per-module resets for
        MSDeformAttn, normal init for level embeddings, and a small logit-space
        constant for the learnable (w, h) prior."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        if self.num_feature_levels > 1 and self.level_embed is not None:
            nn.init.normal_(self.level_embed)
        if self.two_stage_learn_wh:
            # initialize the (w, h) prior to 0.05 in sigmoid space
            nn.init.constant_(
                self.two_stage_wh_embedding.weight, math.log(0.05 / (1 - 0.05))
            )

    def init_ref_points(self, num_queries: int, random_refpoints_xy: bool = False):
        """Initialize learnable reference points for each query.

        Args:
            num_queries (int): number of queries
            random_refpoints_xy (bool, optional): whether to init the refpoints randomly.
                Defaults to False.
        """
        self.refpoint_embed = nn.Embedding(num_queries, 4)
        if random_refpoints_xy:
            # random xy in (0, 1), stored unsigmoided
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(
                self.refpoint_embed.weight.data[:, :2]
            )
            # NOTE(review): setting requires_grad on a slice of .data has no effect on the
            # Parameter itself (kept for parity with the reference DAB-DETR implementation).
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

    def get_valid_ratio(self, mask):
        """Compute the (w, h) fraction of non-padded area per image.

        Args:
            mask: padding mask of shape (bs, H, W); True marks padded positions.

        Returns:
            Tensor of shape (bs, 2) holding (valid_ratio_w, valid_ratio_h).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(
        self,
        src_flatten: torch.Tensor,
        lvl_pos_embed_flatten: torch.Tensor,
        level_start_index: List[int],
        spatial_shapes: List[torch.Tensor],
        valid_ratios: List[torch.Tensor],
        mask_flatten: torch.Tensor,
        prompt_type: str,
    ) -> List[torch.Tensor]:
        """Encode the flattened features, build the prompt-dependent reference
        dict, select two-stage proposals, and run the decoder.

        Returns:
            (hs, references, ref_dict): decoder hidden states, per-layer
            reference points, and the reference-feature dict fed to the
            proposal head.
        """
        memory = self.encoder(
            src_flatten,
            pos=lvl_pos_embed_flatten,
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            key_padding_mask=mask_flatten,
        )
        batch_size = src_flatten.shape[0]
        # default (no prompt): a zero reference feature
        crop_region_features = torch.zeros(batch_size, 1, self.d_model).to(
            memory.device
        )
        # NOTE(review): self.fine_grained_prompt / self.coarse_grained_prompt are not
        # created in __init__ — presumably attached externally like enc_out_class_embed;
        # confirm before calling forward with these prompt types.
        if prompt_type == "fine_grained_prompt":
            crop_region_features = (
                self.fine_grained_prompt.weight[0]
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
        elif prompt_type == "coarse_grained_prompt":
            crop_region_features = (
                self.coarse_grained_prompt.weight[0]
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
        pad_mask = torch.ones(batch_size, 1).to(crop_region_features.device).bool()
        self_attn_mask = torch.ones(batch_size, 1, 1).to(crop_region_features.device)
        ref_dict = dict(
            encoded_ref_feature=crop_region_features,
            pad_mask=pad_mask,
            self_attn_mask=self_attn_mask,
            prompt_type="universal_prompt",
        )
        (
            refpoint_embed,
            tgt,
            init_box_proposal,
        ) = self.get_two_stage_proposal(memory, mask_flatten, spatial_shapes, ref_dict)
        hs, references = self.decoder(
            tgt=tgt.transpose(0, 1),
            tgt_key_padding_mask=None,
            memory=memory.transpose(0, 1),
            memory_key_padding_mask=mask_flatten,
            pos=lvl_pos_embed_flatten.transpose(0, 1),
            refpoints_unsigmoid=refpoint_embed.transpose(0, 1),
            level_start_index=level_start_index,
            spatial_shapes=spatial_shapes,
            valid_ratios=valid_ratios,
            tgt_mask=None,
            # we ~ the mask . False means use the token; True means pad the token
        )
        return (
            hs,
            references,
            ref_dict,
        )

    def get_two_stage_proposal(
        self,
        memory: torch.Tensor,
        mask_flatten: torch.Tensor,
        spatial_shapes: List[torch.Tensor],
        ref_dict: Dict,
    ) -> List[torch.Tensor]:
        """Two stage proposal generation for the decoder.

        Args:
            memory (torch.Tensor): Image encoded feature. [bs, n, 256]
            mask_flatten (torch.Tensor): Flattened mask. [bs, n]
            spatial_shapes (List[torch.Tensor]): Spatial shapes of each feature map.
                [bs, num_levels, 2]
            ref_dict (Dict): A dict containing all kinds of reference-image related
                features; passed to the (externally initialized) class head.

        Returns:
            (refpoint_embed, tgt, init_box_proposal): detached top-k box proposals
            (unsigmoided), the learnable target embedding tiled per batch, and the
            sigmoided initial box proposals.
        """
        bs = memory.shape[0]
        input_hw = None
        output_memory, output_proposals = gen_encoder_output_proposals(
            memory, mask_flatten, spatial_shapes, input_hw
        )
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        if self.binary_query_selection:  # Unused
            topk_logits = self.binary_query_selection_layer(output_memory).squeeze(-1)
        else:
            if ref_dict is not None:
                # Not a linear prediction layer but a contrastive similarity,
                # shape [B, len_image, len_text].
                enc_outputs_class_unselected = self.enc_out_class_embed(
                    output_memory, ref_dict
                )
            else:
                enc_outputs_class_unselected = self.enc_out_class_embed(output_memory)
            topk_logits = enc_outputs_class_unselected.max(-1)[
                0
            ]  # shape [B, len_image]
        enc_outputs_coord_unselected = (
            self.enc_out_bbox_embed(output_memory) + output_proposals
        )  # (bs, \sum{hw}, 4) unsigmoid
        topk = self.num_queries
        # fail with an informative message instead of a bare try/except around topk
        num_tokens = topk_logits.shape[1]
        if num_tokens < topk:
            raise ValueError(
                f"Cannot select top-{topk} proposals from only {num_tokens} encoder "
                f"tokens (topk_logits shape: {tuple(topk_logits.shape)})."
            )
        topk_proposals = torch.topk(topk_logits, topk, dim=1)[1]  # bs, nq
        # gather boxes
        refpoint_embed_undetach = torch.gather(
            enc_outputs_coord_unselected,
            1,
            topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
        )  # unsigmoid
        refpoint_embed_ = refpoint_embed_undetach.detach()
        init_box_proposal = torch.gather(
            output_proposals, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
        ).sigmoid()  # sigmoid
        # gather tgt
        tgt_undetach = torch.gather(
            output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, self.d_model)
        )
        tgt_ = (
            self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)
        )  # nq, bs, d_model
        refpoint_embed, tgt = refpoint_embed_, tgt_
        return (
            refpoint_embed,
            tgt,
            init_box_proposal,
        )