# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from LW-DETR (https://github.com/Atten4Vis/LW-DETR)
# Copyright (c) 2024 Baidu. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
# Copyright (c) 2021 Microsoft. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
| """ | |
| Transformer class | |
| """ | |
| import math | |
| import copy | |
| from typing import Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn, Tensor | |
| from rfdetr.models.ops.modules import MSDeformAttn | |


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
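

# Shape sketch for MLP (illustrative): MLP(256, 256, 4, 3) builds
# Linear(256, 256) -> ReLU -> Linear(256, 256) -> ReLU -> Linear(256, 4),
# i.e. ReLU follows every layer except the last.
#
#     head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#     out = head(torch.rand(2, 300, 256))   # out.shape == (2, 300, 4)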


def gen_sineembed_for_position(pos_tensor, dim=128):
    # n_query, bs, _ = pos_tensor.size()
    # sineembed_tensor = torch.zeros(n_query, bs, 256)
    scale = 2 * math.pi
    dim_t = torch.arange(dim, dtype=pos_tensor.dtype, device=pos_tensor.device)
    dim_t = 10000 ** (2 * (dim_t // 2) / dim)
    x_embed = pos_tensor[:, :, 0] * scale
    y_embed = pos_tensor[:, :, 1] * scale
    pos_x = x_embed[:, :, None] / dim_t
    pos_y = y_embed[:, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
    pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
    if pos_tensor.size(-1) == 2:
        pos = torch.cat((pos_y, pos_x), dim=2)
    elif pos_tensor.size(-1) == 4:
        w_embed = pos_tensor[:, :, 2] * scale
        pos_w = w_embed[:, :, None] / dim_t
        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
        h_embed = pos_tensor[:, :, 3] * scale
        pos_h = h_embed[:, :, None] / dim_t
        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
        pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
    return pos
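

# Shape note: for a box tensor of shape (bs, nq, 4) in (cx, cy, w, h) format and
# the default dim=128, each coordinate gets a 128-d sine/cosine embedding and the
# result is the concatenation (y, x, w, h), i.e. shape (bs, nq, 512). The decoder
# below calls this with dim = d_model / 2, so the output width is 2 * d_model,
# matching the input width of TransformerDecoder.ref_point_head.
#
#     boxes = torch.rand(2, 300, 4)
#     emb = gen_sineembed_for_position(boxes, dim=128)   # emb.shape == (2, 300, 512)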


def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, unsigmoid=True):
    """
    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw}
        - spatial_shapes: nlevel, 2
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    N_, S_, C_ = memory.shape
    base_scale = 4.0
    proposals = []
    _cur = 0
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        if memory_padding_mask is not None:
            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
        else:
            valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
            valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)

        grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                                        torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
        grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
        wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)
    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
    if unsigmoid:
        output_proposals = torch.log(output_proposals / (1 - output_proposals))  # unsigmoid
        if memory_padding_mask is not None:
            output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
    else:
        if memory_padding_mask is not None:
            output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float(0))

    output_memory = memory
    if memory_padding_mask is not None:
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
    output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
    return output_memory.to(memory.dtype), output_proposals.to(memory.dtype)
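

# Behavior note: every cell of every feature level becomes one proposal box
# centered at the (padding-normalized) cell center with width/height 0.05 * 2**lvl.
# With unsigmoid=True the proposals are returned in logit space and invalid entries
# (padded positions, or boxes outside the (0.01, 0.99) range) are set to +inf so a
# later sigmoid saturates; with unsigmoid=False (the bbox_reparam path) those
# entries are zeroed instead. A minimal sketch with a single 2x3 level and no mask:
#
#     memory = torch.rand(1, 6, 256)
#     shapes = torch.tensor([[2, 3]])
#     mem_out, props = gen_encoder_output_proposals(memory, None, shapes)
#     # mem_out.shape == (1, 6, 256), props.shape == (1, 6, 4)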


class Transformer(nn.Module):

    def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, group_detr=1,
                 two_stage=False,
                 num_feature_levels=4, dec_n_points=4,
                 lite_refpoint_refine=False,
                 decoder_norm_type='LN',
                 bbox_reparam=False):
        super().__init__()
        self.encoder = None

        decoder_layer = TransformerDecoderLayer(d_model, sa_nhead, ca_nhead, dim_feedforward,
                                                dropout, activation, normalize_before,
                                                group_detr=group_detr,
                                                num_feature_levels=num_feature_levels,
                                                dec_n_points=dec_n_points,
                                                skip_self_attn=False,)
        assert decoder_norm_type in ['LN', 'Identity']
        norm = {
            "LN": lambda channels: nn.LayerNorm(channels),
            "Identity": lambda channels: nn.Identity(),
        }
        decoder_norm = norm[decoder_norm_type](d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec,
                                          d_model=d_model,
                                          lite_refpoint_refine=lite_refpoint_refine,
                                          bbox_reparam=bbox_reparam)

        self.two_stage = two_stage
        if two_stage:
            self.enc_output = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(group_detr)])
            self.enc_output_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(group_detr)])

        self._reset_parameters()

        self.num_queries = num_queries
        self.d_model = d_model
        self.dec_layers = num_decoder_layers
        self.group_detr = group_detr
        self.num_feature_levels = num_feature_levels
        self.bbox_reparam = bbox_reparam

        self._export = False
    def export(self):
        self._export = True

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
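
    # Note: `mask` is (bs, H, W) with True marking padded pixels; the result is
    # (bs, 2) holding (valid_W / W, valid_H / H), the fraction of each feature map
    # that carries real image content. E.g. a 4x6 map whose right half is padding
    # gives roughly (0.5, 1.0).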
    def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
        src_flatten = []
        mask_flatten = [] if masks is not None else None
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        valid_ratios = [] if masks is not None else None
        for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            src = src.flatten(2).transpose(1, 2)  # bs, hw, c
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            lvl_pos_embed_flatten.append(pos_embed)
            src_flatten.append(src)
            if masks is not None:
                mask = masks[lvl].flatten(1)  # bs, hw
                mask_flatten.append(mask)
        memory = torch.cat(src_flatten, 1)  # bs, \sum{hxw}, c
        if masks is not None:
            mask_flatten = torch.cat(mask_flatten, 1)  # bs, \sum{hxw}
            valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # bs, \sum{hxw}, c
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))

        if self.two_stage:
            output_memory, output_proposals = gen_encoder_output_proposals(
                memory, mask_flatten, spatial_shapes, unsigmoid=not self.bbox_reparam)

            # group detr for first stage
            refpoint_embed_ts, memory_ts, boxes_ts = [], [], []
            group_detr = self.group_detr if self.training else 1
            for g_idx in range(group_detr):
                output_memory_gidx = self.enc_output_norm[g_idx](self.enc_output[g_idx](output_memory))

                enc_outputs_class_unselected_gidx = self.enc_out_class_embed[g_idx](output_memory_gidx)
                if self.bbox_reparam:
                    enc_outputs_coord_delta_gidx = self.enc_out_bbox_embed[g_idx](output_memory_gidx)
                    enc_outputs_coord_cxcy_gidx = enc_outputs_coord_delta_gidx[..., :2] * output_proposals[..., 2:] \
                        + output_proposals[..., :2]
                    enc_outputs_coord_wh_gidx = enc_outputs_coord_delta_gidx[..., 2:].exp() * output_proposals[..., 2:]
                    enc_outputs_coord_unselected_gidx = torch.concat(
                        [enc_outputs_coord_cxcy_gidx, enc_outputs_coord_wh_gidx], dim=-1)
                else:
                    enc_outputs_coord_unselected_gidx = self.enc_out_bbox_embed[g_idx](
                        output_memory_gidx) + output_proposals  # (bs, \sum{hw}, 4) unsigmoid

                topk = min(self.num_queries, enc_outputs_class_unselected_gidx.shape[-2])
                topk_proposals_gidx = torch.topk(enc_outputs_class_unselected_gidx.max(-1)[0], topk, dim=1)[1]  # bs, nq

                refpoint_embed_gidx_undetach = torch.gather(
                    enc_outputs_coord_unselected_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, 4))  # unsigmoid
                # for decoder layer, detached as initial ones, (bs, nq, 4)
                refpoint_embed_gidx = refpoint_embed_gidx_undetach.detach()

                # get memory tgt
                tgt_undetach_gidx = torch.gather(
                    output_memory_gidx, 1, topk_proposals_gidx.unsqueeze(-1).repeat(1, 1, self.d_model))

                refpoint_embed_ts.append(refpoint_embed_gidx)
                memory_ts.append(tgt_undetach_gidx)
                boxes_ts.append(refpoint_embed_gidx_undetach)

            # concat on dim=1, the nq dimension: (bs, nq, 4) --> (bs, group_detr * nq, 4)
            refpoint_embed_ts = torch.cat(refpoint_embed_ts, dim=1)
            # (bs, nq, d) --> (bs, group_detr * nq, d)
            memory_ts = torch.cat(memory_ts, dim=1)  # .transpose(0, 1)
            boxes_ts = torch.cat(boxes_ts, dim=1)  # .transpose(0, 1)

        if self.dec_layers > 0:
            tgt = query_feat.unsqueeze(0).repeat(bs, 1, 1)
            refpoint_embed = refpoint_embed.unsqueeze(0).repeat(bs, 1, 1)
            if self.two_stage:
                ts_len = refpoint_embed_ts.shape[-2]
                refpoint_embed_ts_subset = refpoint_embed[..., :ts_len, :]
                refpoint_embed_subset = refpoint_embed[..., ts_len:, :]

                if self.bbox_reparam:
                    refpoint_embed_cxcy = refpoint_embed_ts_subset[..., :2] * refpoint_embed_ts[..., 2:]
                    refpoint_embed_cxcy = refpoint_embed_cxcy + refpoint_embed_ts[..., :2]
                    refpoint_embed_wh = refpoint_embed_ts_subset[..., 2:].exp() * refpoint_embed_ts[..., 2:]
                    refpoint_embed_ts_subset = torch.concat(
                        [refpoint_embed_cxcy, refpoint_embed_wh], dim=-1
                    )
                else:
                    refpoint_embed_ts_subset = refpoint_embed_ts_subset + refpoint_embed_ts

                refpoint_embed = torch.concat(
                    [refpoint_embed_ts_subset, refpoint_embed_subset], dim=-2)

            hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask_flatten,
                                          pos=lvl_pos_embed_flatten, refpoints_unsigmoid=refpoint_embed,
                                          level_start_index=level_start_index,
                                          spatial_shapes=spatial_shapes,
                                          valid_ratios=valid_ratios.to(memory.dtype) if valid_ratios is not None else valid_ratios)
        else:
            assert self.two_stage, "if not using decoder, two_stage must be True"
            hs = None
            references = None

        if self.two_stage:
            if self.bbox_reparam:
                return hs, references, memory_ts, boxes_ts
            else:
                return hs, references, memory_ts, boxes_ts.sigmoid()
        return hs, references, None, None
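

# Interface note for Transformer.forward: `srcs` and `pos_embeds` are per-level
# lists of (bs, c, h, w) tensors, `masks` is an optional list of (bs, h, w) padding
# masks, and `refpoint_embed` / `query_feat` are the learned (nq, 4) reference
# boxes and (nq, d_model) query features. When two_stage is enabled, the per-group
# heads `enc_out_class_embed` and `enc_out_bbox_embed` are referenced above but not
# created in __init__; they are expected to be attached to this module from outside
# (e.g. by the detection model that owns the classification and box heads).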


class TransformerDecoder(nn.Module):

    def __init__(self,
                 decoder_layer,
                 num_layers,
                 norm=None,
                 return_intermediate=False,
                 d_model=256,
                 lite_refpoint_refine=False,
                 bbox_reparam=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.d_model = d_model
        self.norm = norm
        self.return_intermediate = return_intermediate
        self.lite_refpoint_refine = lite_refpoint_refine
        self.bbox_reparam = bbox_reparam

        self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2)

        self._export = False

    def export(self):
        self._export = True

    def refpoints_refine(self, refpoints_unsigmoid, new_refpoints_delta):
        if self.bbox_reparam:
            new_refpoints_cxcy = new_refpoints_delta[..., :2] * refpoints_unsigmoid[..., 2:] + refpoints_unsigmoid[..., :2]
            new_refpoints_wh = new_refpoints_delta[..., 2:].exp() * refpoints_unsigmoid[..., 2:]
            new_refpoints_unsigmoid = torch.concat(
                [new_refpoints_cxcy, new_refpoints_wh], dim=-1
            )
        else:
            new_refpoints_unsigmoid = refpoints_unsigmoid + new_refpoints_delta
        return new_refpoints_unsigmoid
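
    # Worked example of the reparameterized update (bbox_reparam=True): with a
    # current box (cx, cy, w, h) = (0.5, 0.5, 0.2, 0.4) and a predicted delta
    # (0.1, 0.0, ln 2, 0.0):
    #     cx' = 0.1 * 0.2 + 0.5 = 0.52    cy' = 0.0 * 0.4 + 0.5 = 0.5
    #     w'  = exp(ln 2) * 0.2 = 0.4     h'  = exp(0.0) * 0.4 = 0.4
    # i.e. the center moves in units of the current box size and the size is scaled
    # multiplicatively. Without bbox_reparam the delta is added in logit space.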
    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,
                # for memory
                level_start_index: Optional[Tensor] = None,  # num_levels
                spatial_shapes: Optional[Tensor] = None,  # bs, num_levels, 2
                valid_ratios: Optional[Tensor] = None):
        output = tgt

        intermediate = []
        hs_refpoints_unsigmoid = [refpoints_unsigmoid]

        def get_reference(refpoints):
            # [num_queries, batch_size, 4]
            obj_center = refpoints[..., :4]

            if self._export:
                query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model / 2)  # bs, nq, 256*2
                refpoints_input = obj_center[:, :, None]  # bs, nq, 1, 4
            else:
                refpoints_input = obj_center[:, :, None] \
                    * torch.cat([valid_ratios, valid_ratios], -1)[:, None]  # bs, nq, nlevel, 4
                query_sine_embed = gen_sineembed_for_position(
                    refpoints_input[:, :, 0, :], self.d_model / 2)  # bs, nq, 256*2
            query_pos = self.ref_point_head(query_sine_embed)
            return obj_center, refpoints_input, query_pos, query_sine_embed

        # always use init refpoints
        if self.lite_refpoint_refine:
            if self.bbox_reparam:
                obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
            else:
                obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())

        for layer_id, layer in enumerate(self.layers):
            # iter refine each layer
            if not self.lite_refpoint_refine:
                if self.bbox_reparam:
                    obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid)
                else:
                    obj_center, refpoints_input, query_pos, query_sine_embed = get_reference(refpoints_unsigmoid.sigmoid())

            # For the first decoder layer, we do not apply transformation over p_s
            pos_transformation = 1
            query_pos = query_pos * pos_transformation

            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0),
                           reference_points=refpoints_input,
                           spatial_shapes=spatial_shapes,
                           level_start_index=level_start_index)

            if not self.lite_refpoint_refine:
                # box iterative update
                new_refpoints_delta = self.bbox_embed(output)
                new_refpoints_unsigmoid = self.refpoints_refine(refpoints_unsigmoid, new_refpoints_delta)
                if layer_id != self.num_layers - 1:
                    hs_refpoints_unsigmoid.append(new_refpoints_unsigmoid)
                refpoints_unsigmoid = new_refpoints_unsigmoid.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            if self._export:
                # to shape: B, N, C
                hs = intermediate[-1]
                if self.bbox_embed is not None:
                    ref = hs_refpoints_unsigmoid[-1]
                else:
                    ref = refpoints_unsigmoid
                return hs, ref
            # box iterative update
            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate),
                    torch.stack(hs_refpoints_unsigmoid),
                ]
            else:
                return [
                    torch.stack(intermediate),
                    refpoints_unsigmoid.unsqueeze(0)
                ]

        return output.unsqueeze(0)
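

# Output note for TransformerDecoder.forward: with return_intermediate=True and a
# `bbox_embed` head attached from outside (it is referenced above but never created
# here), the decoder returns a (num_layers, bs, nq, d_model) stack of hidden states
# together with a matching (num_layers, bs, nq, 4) stack of unsigmoided reference
# boxes: the initial refpoints plus the refined boxes of every layer but the last.
# In export mode only the final layer's hidden states and boxes are returned.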


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, group_detr=1,
                 num_feature_levels=4, dec_n_points=4,
                 skip_self_attn=False):
        super().__init__()
        # Decoder Self-Attention
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Decoder Cross-Attention
        self.cross_attn = MSDeformAttn(
            d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points)
        self.nhead = ca_nhead

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self.group_detr = group_detr

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None,
                     query_sine_embed=None,
                     is_first=False,
                     reference_points=None,
                     spatial_shapes=None,
                     level_start_index=None,
                     ):
        bs, num_queries, _ = tgt.shape

        # ========== Begin of Self-Attention =============
        # Apply projections here
        # shape: batch_size x num_queries x 256
        q = k = tgt + query_pos
        v = tgt
        if self.training:
            q = torch.cat(q.split(num_queries // self.group_detr, dim=1), dim=0)
            k = torch.cat(k.split(num_queries // self.group_detr, dim=1), dim=0)
            v = torch.cat(v.split(num_queries // self.group_detr, dim=1), dim=0)
        tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask,
                              need_weights=False)[0]
        if self.training:
            tgt2 = torch.cat(tgt2.split(bs, dim=0), dim=1)
        # ========== End of Self-Attention =============

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ========== Begin of Cross-Attention =============
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            memory,
            spatial_shapes,
            level_start_index,
            memory_key_padding_mask
        )
        # ========== End of Cross-Attention =============

        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
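
    # Group-DETR note for the self-attention above: during training the queries are
    # split into `group_detr` groups along the query axis and restacked along the
    # batch axis, so each group only attends to itself. E.g. with bs=2,
    # num_queries=600 and group_detr=2, the (2, 600, d) tensors become (4, 300, d)
    # for attention and are folded back to (2, 600, d) afterwards. At inference time
    # no splitting is applied.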
    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                query_sine_embed=None,
                is_first=False,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None):
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,
                                 query_sine_embed, is_first,
                                 reference_points, spatial_shapes, level_start_index)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    try:
        two_stage = args.two_stage
    except AttributeError:
        two_stage = False
    return Transformer(
        d_model=args.hidden_dim,
        sa_nhead=args.sa_nheads,
        ca_nhead=args.ca_nheads,
        num_queries=args.num_queries,
        dropout=args.dropout,
        dim_feedforward=args.dim_feedforward,
        num_decoder_layers=args.dec_layers,
        return_intermediate_dec=True,
        group_detr=args.group_detr,
        two_stage=two_stage,
        num_feature_levels=args.num_feature_levels,
        dec_n_points=args.dec_n_points,
        lite_refpoint_refine=args.lite_refpoint_refine,
        decoder_norm_type=args.decoder_norm,
        bbox_reparam=args.bbox_reparam,
    )
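

# Minimal construction sketch (the argument values below are illustrative
# assumptions, not the project's defaults; see the RF-DETR configs for real ones):
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(
#         hidden_dim=256, sa_nheads=8, ca_nheads=16, num_queries=300,
#         dropout=0.0, dim_feedforward=2048, dec_layers=3, group_detr=1,
#         two_stage=True, num_feature_levels=4, dec_n_points=4,
#         lite_refpoint_refine=True, decoder_norm='LN', bbox_reparam=True,
#     )
#     transformer = build_transformer(args)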


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")