import copy
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from detect_tools.upn import ARCHITECTURES, build_architecture, build_backbone
from detect_tools.upn.models.module import (MLP, ContrastiveAssign,
                                            NestedTensor,
                                            nested_tensor_from_tensor_list)
from detect_tools.upn.models.utils import inverse_sigmoid


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel dim (dim 1) for each spatial position.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


@ARCHITECTURES.register_module()
class UPN(nn.Module):
    """Implementation of UPN."""

    def __init__(
        self,
        vision_backbone_cfg: Dict,
        transformer_cfg: Dict,
        num_queries: int,
        dec_pred_class_embed_share: bool = True,
        dec_pred_bbox_embed_share: bool = True,
        decoder_sa_type: str = "sa",
    ):
        super().__init__()
        # build vision backbone
        self.backbone = build_backbone(vision_backbone_cfg)
        # build transformer
        self.transformer = build_architecture(transformer_cfg)
        self.hidden_dim = self.transformer.d_model
        # for dn training
        self.num_queries = num_queries
        self.num_feature_levels = self.transformer.num_feature_levels
        # prepare projection layers for vision features
        self.input_proj = self.prepare_vision_feature_projection_layer(
            self.backbone,
            self.transformer.num_feature_levels,
            self.hidden_dim,
            self.transformer.two_stage_type,
        )
        # prepare prediction head
        self.prepare_prediction_head(
            dec_pred_class_embed_share,
            dec_pred_bbox_embed_share,
            self.hidden_dim,
            self.transformer.num_decoder_layers,
        )
        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ["sa", "ca_label", "ca_content"]
        for layer in self.transformer.decoder.layers:
            layer.label_embedding = None
        self.label_embedding = None
        # build universal prompt tokens
        self.transformer.fine_grained_prompt = nn.Embedding(1, self.hidden_dim)
        self.transformer.coarse_grained_prompt = nn.Embedding(1, self.hidden_dim)
        self._reset_parameters()

    def forward(self, samples: NestedTensor, prompt_type: Optional[str] = None) -> Dict:
        """Forward function."""
        self.device = samples.device
        (
            src_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
            mask_flatten,
        ) = self.forward_backbone_encoder(samples)
        hs, reference, ref_dict = self.transformer(
            src_flatten,
            lvl_pos_embed_flatten,
            level_start_index,
            spatial_shapes,
            valid_ratios,
            mask_flatten,
            prompt_type,
        )
        # deformable-DETR-like iterative anchor update
        outputs_coord_list = []
        for layer_ref_sig, layer_bbox_embed, layer_hs in zip(
            reference[:-1], self.bbox_embed, hs
        ):
            layer_delta_unsig = layer_bbox_embed(layer_hs)
            layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
            layer_outputs_unsig = layer_outputs_unsig.sigmoid()
            outputs_coord_list.append(layer_outputs_unsig)
        outputs_coord_list = torch.stack(outputs_coord_list)

        if ref_dict is None:
            # build a mock outputs_class for mask_dn training,
            # on the same device as the predicted boxes
            outputs_class = torch.zeros(
                outputs_coord_list.shape[0],
                outputs_coord_list.shape[1],
                outputs_coord_list.shape[2],
                self.hidden_dim,
                device=outputs_coord_list.device,
            )
        else:
            outputs_class = torch.stack(
                [
                    layer_cls_embed(layer_hs, ref_dict)
                    for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
                ]
            )
        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord_list[-1],
        }
        out["ref_dict"] = ref_dict
        return out
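    # Shape-level note on the forward() outputs above (a sketch; the cxcywh
    # normalized-box convention is an assumption from the DETR family, not
    # stated in this file):
    #   out["pred_boxes"]:  (bs, num_queries, 4), sigmoid-normalized boxes
    #   out["pred_logits"]: the last decoder layer's ContrastiveAssign output,
    #                       or a zero (bs, num_queries, hidden_dim) mock when
    #                       ref_dict is None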
"pred_boxes": outputs_coord_list[-1], } out["ref_dict"] = ref_dict return out def forward_backbone_encoder(self, samples: NestedTensor) -> Tuple: # pass through backbone if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) features, poss = self.backbone(samples) # project features srcs = [] masks = [] for l, feat in enumerate(features): src, mask = feat.decompose() srcs.append(self.input_proj[l](src)) # downsample the feature map to 256 masks.append(mask) assert mask is not None if self.num_feature_levels > len( srcs ): # add more feature levels by downsampling the last feature map _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: src = self.input_proj[l](features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = samples.mask mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to( torch.bool )[0] pos_l = self.backbone.forward_pos_embed_only( NestedTensor(src, mask) ).to(src.dtype) srcs.append(src) masks.append(mask) poss.append(pos_l) # prepare input for encoder with the following steps: # 1. flatten the feature maps and masks # 2. Add positional embedding and level embedding # 3. Calculate the valid ratio of each feature map based on the mask src_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, poss)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) src = src.flatten(2).transpose(1, 2) # bs, hw, c mask = mask.flatten(1) # bs, hw pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c if self.num_feature_levels > 1 and self.transformer.level_embed is not None: lvl_pos_embed = pos_embed + self.transformer.level_embed[lvl].view( 1, 1, -1 ) else: lvl_pos_embed = pos_embed lvl_pos_embed_flatten.append(lvl_pos_embed) src_flatten.append(src) mask_flatten.append(mask) src_flatten = torch.cat(src_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=src_flatten.device ) level_start_index = torch.cat( (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]) ) valid_ratios = torch.stack( [self.transformer.get_valid_ratio(m) for m in masks], 1 ) return ( src_flatten, lvl_pos_embed_flatten, level_start_index, spatial_shapes, valid_ratios, mask_flatten, ) def prepare_vision_feature_projection_layer( self, backbone: nn.Module, num_feature_levels: int, hidden_dim: int, two_stage_type: str, ) -> nn.ModuleList: """Prepare projection layer to map backbone feature to hidden dim. Args: backbone (nn.Module): Backbone. num_feature_levels (int): Number of feature levels. hidden_dim (int): Hidden dim. two_stage_type (str): Type of two stage. Returns: nn.ModuleList: Projection layer. """ if num_feature_levels > 1: num_backbone_outs = len(backbone.num_channels) input_proj_list = [] for _ in range(num_backbone_outs): in_channels = backbone.num_channels[_] input_proj_list.append( nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), ) ) for _ in range(num_feature_levels - num_backbone_outs): input_proj_list.append( nn.Sequential( nn.Conv2d( in_channels, hidden_dim, kernel_size=3, stride=2, padding=1 ), nn.GroupNorm(32, hidden_dim), ) ) in_channels = hidden_dim input_proj = nn.ModuleList(input_proj_list) else: assert ( two_stage_type == "no" ), "two_stage_type should be no if num_feature_levels=1 !!!" 
    def prepare_prediction_head(
        self,
        dec_pred_class_embed_share: bool,
        dec_pred_bbox_embed_share: bool,
        hidden_dim: int,
        num_decoder_layers: int,
    ) -> None:
        """Prepare the prediction head, including the class embed and bbox embed.

        Args:
            dec_pred_class_embed_share (bool): Whether to share the class embed
                across all decoder layers.
            dec_pred_bbox_embed_share (bool): Whether to share the bbox embed
                across all decoder layers.
            hidden_dim (int): Hidden dim.
            num_decoder_layers (int): Number of decoder layers.
        """
        _class_embed = ContrastiveAssign()
        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for _ in range(num_decoder_layers)]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed) for _ in range(num_decoder_layers)
            ]
        if dec_pred_class_embed_share:
            class_embed_layerlist = [_class_embed for _ in range(num_decoder_layers)]
        else:
            class_embed_layerlist = [
                copy.deepcopy(_class_embed) for _ in range(num_decoder_layers)
            ]
        bbox_embed = nn.ModuleList(box_embed_layerlist)
        class_embed = nn.ModuleList(class_embed_layerlist)
        self.bbox_embed = bbox_embed
        self.class_embed = class_embed
        # initialize bbox embed and class embed in the transformer
        self.transformer.decoder.bbox_embed = bbox_embed
        self.transformer.decoder.class_embed = class_embed
        if self.transformer.two_stage_type != "no":
            if self.transformer.two_stage_bbox_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)
            if self.transformer.two_stage_class_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)
        self.refpoint_embed = None

    def _reset_parameters(self):
        # init input_proj
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)
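

if __name__ == "__main__":
    # Minimal sanity-check sketch (not part of the original module): since
    # LayerNorm2d normalizes over the channel dim only, it should match
    # nn.LayerNorm applied in channels-last layout with the same eps.
    torch.manual_seed(0)
    x = torch.randn(2, 8, 4, 4)
    ln2d = LayerNorm2d(8)
    ln = nn.LayerNorm(8, eps=1e-6)
    ref = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(ln2d(x), ref, atol=1e-5), "LayerNorm2d mismatch"
    print("LayerNorm2d matches channels-last nn.LayerNorm:", tuple(ln2d(x).shape))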