from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from detect_tools.upn import ENCODERS, build_encoder
from detect_tools.upn.models.utils import get_activation_fn, get_clones
from detect_tools.upn.ops.modules import MSDeformAttn


class DeformableTransformerEncoderLayer(nn.Module):
    """Deformable Transformer Encoder Layer.

    Args:
        d_model (int): The dimension of keys/values/queries in
            :class:`MultiheadAttention`.
        d_ffn (int): The dimension of the feedforward network model.
        dropout (float): Probability of an element to be zeroed.
        activation (str): Activation function in the feedforward network;
            'relu' and 'gelu' are supported.
        n_levels (int): The number of levels in Multi-Scale Deformable Attention.
        n_heads (int): Number of parallel attention heads.
        n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
        add_channel_attention (bool): If True, add channel attention after the FFN.
    """
    def __init__(
        self,
        d_model: int = 256,
        d_ffn: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        add_channel_attention: bool = False,
    ) -> None:
        super().__init__()
        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        # channel attention
        self.add_channel_attention = add_channel_attention
        if add_channel_attention:
            self.activ_channel = get_activation_fn("dyrelu", d_model=d_model)
            self.norm_channel = nn.LayerNorm(d_model)
    @staticmethod
    def with_pos_embed(tensor: torch.Tensor, pos: Optional[torch.Tensor]) -> torch.Tensor:
        # without @staticmethod, `self.with_pos_embed(src, pos)` would bind
        # `self` to `tensor` and silently shift the arguments
        return tensor if pos is None else tensor + pos
    def forward_ffn(self, src: torch.Tensor) -> torch.Tensor:
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src
    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        reference_points: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function for `DeformableTransformerEncoderLayer`.

        Args:
            src (torch.Tensor): The input sequence of shape (N, S, E), where N is
                the batch size, S the flattened spatial length, and E the embedding dim.
            pos (torch.Tensor): The position embedding of shape (N, S, E).
            reference_points (torch.Tensor): The normalized reference points of
                shape (N, S, n_levels, 2).
            spatial_shapes (torch.Tensor): The spatial shape (H, W) of each feature level.
            level_start_index (torch.Tensor): The start index of each level in the
                flattened sequence.
            key_padding_mask (torch.Tensor, optional): The mask for keys of shape (N, S).
        """
        # self attention
        src2 = self.self_attn(
            self.with_pos_embed(src, pos),
            reference_points,
            src,
            spatial_shapes,
            level_start_index,
            key_padding_mask,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # ffn
        src = self.forward_ffn(src)
        # channel attention
        if self.add_channel_attention:
            src = self.norm_channel(src + self.activ_channel(src))
        return src
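

# A minimal usage sketch for the layer above (not part of the original file).
# Shapes are illustrative only, and it assumes the MSDeformAttn op from
# detect_tools.upn.ops has been built for the current device. The dummy
# reference points in [0, 1] stand in for the grid that
# UPNEncoder.get_reference_points produces in the real pipeline.
def _example_encoder_layer_forward() -> torch.Tensor:
    layer = DeformableTransformerEncoderLayer(d_model=256, n_levels=1, n_heads=8, n_points=4)
    bs, h, w = 2, 16, 16
    src = torch.randn(bs, h * w, 256)               # flattened feature map
    pos = torch.randn(bs, h * w, 256)               # positional embedding, same shape
    spatial_shapes = torch.as_tensor([[h, w]])      # one 16x16 feature level
    level_start_index = torch.as_tensor([0])        # level 0 starts at token 0
    reference_points = torch.rand(bs, h * w, 1, 2)  # (bs, S, n_levels, 2), normalized
    return layer(src, pos, reference_points, spatial_shapes, level_start_index)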


class UPNEncoder(nn.Module):
    """Implementation of the UPN Encoder.

    Args:
        num_layers (int): The number of layers in the transformer encoder.
        d_model (int, optional): The dimension of the input feature. Defaults to 256.
        encoder_layer_cfg (Dict, optional): Config for the deformable encoder layer.
        use_checkpoint (bool, optional): Whether to use gradient checkpointing in the
            fusion layers to save memory. Defaults to True.
        use_transformer_ckpt (bool, optional): Whether to use gradient checkpointing
            for the deformable encoder layers. Defaults to True.
        enc_layer_share (bool, optional): Whether to share one set of weights across
            all encoder layers. Defaults to False.
        multi_level_encoder_fusion (str, optional): Fusion strategy for intermediate
            encoder outputs; one of 'dense_net_fusion' or 'stable_dense_fusion'.
            Defaults to None (no fusion).
    """
    def __init__(
        self,
        num_layers: int,
        d_model: int = 256,
        encoder_layer_cfg: Optional[Dict] = None,
        use_checkpoint: bool = True,
        use_transformer_ckpt: bool = True,
        enc_layer_share: bool = False,
        multi_level_encoder_fusion: Optional[str] = None,
    ) -> None:
        super().__init__()
        # prepare layers
        self.layers = []
        self.refImg_layers = []
        self.fusion_layers = []
        encoder_layer = build_encoder(encoder_layer_cfg)
        self.multi_level_encoder_fusion = multi_level_encoder_fusion
        self._initialize_memory_fusion_layers(
            multi_level_encoder_fusion, num_layers, d_model
        )
        if num_layers > 0:
            self.layers = get_clones(
                encoder_layer, num_layers, layer_share=enc_layer_share
            )
        else:
            self.layers = []
            del encoder_layer

        self.query_scale = None
        self.num_layers = num_layers
        self.d_model = d_model
        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt
    def _initialize_memory_fusion_layers(
        self, fusion_type: Optional[str], num_layers: int, d_model: int
    ) -> None:
        if fusion_type is None:
            self.memory_fusion_layer = None
            return
        assert fusion_type in ["dense_net_fusion", "stable_dense_fusion"]
        if fusion_type == "stable_dense_fusion":
            # fuse the input plus all layer outputs once, after the last layer
            self.memory_fusion_layer = nn.Sequential(
                nn.Linear(d_model * (num_layers + 1), d_model),
                nn.LayerNorm(d_model),
            )
            nn.init.constant_(self.memory_fusion_layer[0].bias, 0)
        elif fusion_type == "dense_net_fusion":
            # one fusion layer per encoder layer, DenseNet-style: the i-th fusion
            # layer concatenates the input plus the first (i + 1) layer outputs,
            # e.g. 512 -> 256 after the first layer, 768 -> 256 after the second
            self.memory_fusion_layer = nn.ModuleList()
            for i in range(num_layers):
                self.memory_fusion_layer.append(
                    nn.Sequential(
                        nn.Linear(d_model * (i + 2), d_model),
                        nn.LayerNorm(d_model),
                    )
                )
            for layer in self.memory_fusion_layer:
                nn.init.constant_(layer[0].bias, 0)
        else:
            raise NotImplementedError
    @staticmethod
    def get_reference_points(
        spatial_shapes: torch.Tensor, valid_ratios: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """Generate normalized reference points for every location of every level."""
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # pixel-center coordinates, normalized by the valid (unpadded) extent
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # broadcast against valid_ratios to (bs, sum(hi*wi), num_levels, 2)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        valid_ratios: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function.

        Args:
            src (torch.Tensor): Flattened image features in shape [bs, sum(hi*wi), 256].
            pos (torch.Tensor): Position embedding for the image features in shape
                [bs, sum(hi*wi), 256].
            spatial_shapes (torch.Tensor): Spatial shape of each level in shape [num_level, 2].
            level_start_index (torch.Tensor): Start index of each level in shape [num_level].
            valid_ratios (torch.Tensor): Valid ratio of each level in shape [bs, num_level, 2].
            key_padding_mask (torch.Tensor, optional): Padding mask for the image features
                in shape [bs, sum(hi*wi)]. Defaults to None.

        Returns:
            torch.Tensor: Encoded image features in shape [bs, sum(hi*wi), 256].
        """
        output = src
        # preparation and reshape
        if self.num_layers > 0:
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device
            )
        # collect intermediate outputs for multi-level dense fusion
        output_list = [output]
        # main process
        for layer_id, layer in enumerate(self.layers):
            if self.use_transformer_ckpt:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    pos,
                    reference_points,
                    spatial_shapes,
                    level_start_index,
                    key_padding_mask,
                )
            else:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask,
                )
            output_list.append(output)
            if self.multi_level_encoder_fusion == "dense_net_fusion":
                output = self.memory_fusion_layer[layer_id](
                    torch.cat(output_list, dim=-1)
                )
        if self.multi_level_encoder_fusion == "stable_dense_fusion":
            output = self.memory_fusion_layer(torch.cat(output_list, dim=-1))
        return output
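

# A small sanity sketch (not part of the original file): get_reference_points can
# be exercised standalone now that it is a staticmethod, with no dependence on the
# compiled MSDeformAttn op. Shapes are illustrative: two feature levels, batch
# size 2, and fully valid (unpadded) feature maps.
if __name__ == "__main__":
    spatial_shapes = torch.as_tensor([[16, 16], [8, 8]])  # (num_levels, 2)
    valid_ratios = torch.ones(2, 2, 2)                    # (bs, num_levels, 2), no padding
    ref = UPNEncoder.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
    # one normalized (x, y) point per location per level:
    # (bs, 16*16 + 8*8, num_levels, 2) = torch.Size([2, 320, 2, 2])
    print(ref.shape)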