# Source header (model-hub file-view residue): PengLiu — "push inference code",
# commit 56ef371, 11.2 kB. Commented out so the file parses as valid Python.
from typing import Dict
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from detect_tools.upn import ENCODERS, build_encoder
from detect_tools.upn.models.utils import get_activation_fn, get_clones
from detect_tools.upn.ops.modules import MSDeformAttn
@ENCODERS.register_module()
class DeformableTransformerEncoderLayer(nn.Module):
    """A single deformable-attention transformer encoder layer.

    Applies multi-scale deformable self-attention, then a position-wise
    feed-forward network, each wrapped in a residual connection followed by
    LayerNorm (post-norm). Optionally appends a channel-attention sub-layer
    (a "dyrelu" activation with its own residual + LayerNorm).

    Args:
        d_model (int): Dimension of queries/keys/values.
        d_ffn (int): Hidden dimension of the feed-forward network.
        dropout (float): Dropout probability used after attention and inside
            the feed-forward block.
        activation (str): Name of the FFN activation ('relu', 'gelu', ...).
        n_levels (int): Number of feature levels in deformable attention.
        n_heads (int): Number of parallel attention heads.
        n_points (int): Sampling points per head/level in deformable attention.
        add_channel_attention (bool): If True, append the channel-attention
            sub-layer.
    """

    def __init__(
        self,
        d_model: int = 256,
        d_ffn: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        add_channel_attention: bool = False,
    ) -> None:
        super().__init__()

        # Multi-scale deformable self-attention sub-layer.
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Feed-forward sub-layer: linear -> activation -> dropout -> linear.
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # Optional channel-attention sub-layer.
        self.add_channel_attention = add_channel_attention
        if add_channel_attention:
            self.activ_channel = get_activation_fn("dyrelu", d_model=d_model)
            self.norm_channel = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add the positional embedding to ``tensor``; no-op when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_ffn(self, src: torch.Tensor) -> torch.Tensor:
        """Residual feed-forward block: LayerNorm(src + FFN(src))."""
        hidden = self.activation(self.linear1(src))
        projected = self.linear2(self.dropout2(hidden))
        return self.norm2(src + self.dropout3(projected))

    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        reference_points: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run one encoder layer over the flattened multi-level features.

        Args:
            src (torch.Tensor): Input feature sequence.
            pos (torch.Tensor): Positional embedding, same shape as ``src``.
            reference_points (torch.Tensor): Reference points for deformable
                attention.
            spatial_shapes (torch.Tensor): Spatial shape of each feature level.
            level_start_index (torch.Tensor): Start index of each level in the
                flattened sequence.
            key_padding_mask (torch.Tensor, optional): Padding mask for keys.

        Returns:
            torch.Tensor: Encoded features, same shape as ``src``.
        """
        # Deformable self-attention (positions added to the query only),
        # then residual + norm.
        attn_out = self.self_attn(
            self.with_pos_embed(src, pos),
            reference_points,
            src,
            spatial_shapes,
            level_start_index,
            key_padding_mask,
        )
        src = self.norm1(src + self.dropout1(attn_out))

        # Position-wise feed-forward block.
        src = self.forward_ffn(src)

        # Optional channel attention with its own residual + norm.
        if self.add_channel_attention:
            src = self.norm_channel(src + self.activ_channel(src))
        return src
@ENCODERS.register_module()
class UPNEncoder(nn.Module):
    """UPN transformer encoder: a stack of deformable encoder layers.

    Optionally fuses the per-layer outputs ("dense fusion") either after each
    layer ("dense_net_fusion") or once after the last layer
    ("stable_dense_fusion").

    Args:
        num_layers (int): Number of encoder layers in the stack.
        d_model (int, optional): Feature dimension. Defaults to 256.
        encoder_layer_cfg (Dict, optional): Config used to build the encoder
            layer. Required when ``num_layers > 0``.
        use_checkpoint (bool, optional): Whether to use gradient checkpointing
            in the fusion layer for memory saving. Defaults to True.
        use_transformer_ckpt (bool, optional): Whether to use gradient
            checkpointing for the deformable encoder layers. Defaults to True.
        enc_layer_share (bool, optional): Share one layer instance across the
            stack instead of cloning. Defaults to False.
        multi_level_encoder_fusion (str, optional): One of None,
            "dense_net_fusion", "stable_dense_fusion". Defaults to None
            (no fusion).
    """

    def __init__(
        self,
        num_layers: int,
        d_model: int = 256,
        encoder_layer_cfg: Dict = None,
        use_checkpoint: bool = True,
        use_transformer_ckpt: bool = True,
        enc_layer_share: bool = False,
        multi_level_encoder_fusion: str = None,
    ):
        super().__init__()
        # NOTE(review): refImg_layers / fusion_layers are never populated or
        # read in this file; kept only in case external code inspects them.
        self.layers = []
        self.refImg_layers = []
        self.fusion_layers = []

        self.multi_level_encoder_fusion = multi_level_encoder_fusion
        self._initilize_memory_fusion_layers(
            multi_level_encoder_fusion, num_layers, d_model
        )

        # Build the layer stack only when it is actually needed. (The original
        # code built a layer unconditionally and deleted it again when
        # num_layers == 0.)
        if num_layers > 0:
            encoder_layer = build_encoder(encoder_layer_cfg)
            self.layers = get_clones(
                encoder_layer, num_layers, layer_share=enc_layer_share
            )
            del encoder_layer

        self.query_scale = None
        self.num_layers = num_layers
        self.d_model = d_model
        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt

    def _initilize_memory_fusion_layers(self, fusion_type, num_layers, d_model):
        """Create the fusion module(s) selected by ``fusion_type``.

        - None: no fusion, ``self.memory_fusion_layer`` is None.
        - "stable_dense_fusion": a single Linear+LayerNorm fusing the input
          plus all ``num_layers`` outputs at once.
        - "dense_net_fusion": one Linear+LayerNorm per layer; layer ``i``
          fuses the input plus the first ``i + 1`` layer outputs
          (e.g. 512 -> 256 after layer 0, 768 -> 256 after layer 1).

        NOTE(review): the misspelled name ('initilize') is kept deliberately
        so external subclasses that override/call it keep working.

        Raises:
            NotImplementedError: For an unrecognized ``fusion_type``. (Was
                previously an ``assert``, which is stripped under ``-O`` and
                made the final branch unreachable.)
        """
        if fusion_type is None:
            self.memory_fusion_layer = None
            return
        if fusion_type == "stable_dense_fusion":
            self.memory_fusion_layer = nn.Sequential(
                nn.Linear(d_model * (num_layers + 1), d_model),
                nn.LayerNorm(d_model),
            )
            nn.init.constant_(self.memory_fusion_layer[0].bias, 0)
        elif fusion_type == "dense_net_fusion":
            self.memory_fusion_layer = nn.ModuleList(
                nn.Sequential(
                    nn.Linear(d_model * (i + 2), d_model),
                    nn.LayerNorm(d_model),
                )
                for i in range(num_layers)
            )
            for layer in self.memory_fusion_layer:
                nn.init.constant_(layer[0].bias, 0)
        else:
            raise NotImplementedError(
                f"Unsupported multi_level_encoder_fusion: {fusion_type!r}"
            )

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Build normalized per-level reference points for deformable attention.

        Args:
            spatial_shapes: Iterable of (H, W) pairs, one per feature level.
            valid_ratios (torch.Tensor): Valid (non-padded) ratios in shape
                [bs, num_level, 2], ordered (w_ratio, h_ratio).
            device: Device on which to create the points.

        Returns:
            torch.Tensor: Reference points of shape
            [bs, sum(H_l * W_l), num_level, 2], normalized to (0, 1) within
            each level's valid region.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Pixel-center coordinates for this level.
            # NOTE(review): relies on torch.meshgrid's default "ij" indexing;
            # passing indexing="ij" explicitly would require torch >= 1.10.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
            )
            # Normalize by the valid extent of the level (per batch element).
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            reference_points_list.append(torch.stack((ref_x, ref_y), -1))
        reference_points = torch.cat(reference_points_list, 1)
        # Broadcast each point against every level's valid ratio:
        # [bs, S, 1, 2] * [bs, 1, L, 2] -> [bs, S, L, 2].
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        valid_ratios: torch.Tensor,
        key_padding_mask: torch.Tensor = None,
    ):
        """Encode the flattened multi-level image features.

        Args:
            src (torch.Tensor): Flattened image features in shape
                [bs, sum(hi*wi), 256].
            pos (torch.Tensor): Position embedding for the image features in
                shape [bs, sum(hi*wi), 256].
            spatial_shapes (torch.Tensor): Spatial shape of each level in
                shape [num_level, 2].
            level_start_index (torch.Tensor): Start index of each level in
                shape [num_level].
            valid_ratios (torch.Tensor): Valid ratio of each level in shape
                [bs, num_level, 2].
            key_padding_mask (torch.Tensor, optional): Padding mask for the
                image features in shape [bs, sum(hi*wi)].

        Returns:
            torch.Tensor: Encoded image features in shape [bs, sum(hi*wi), 256].
        """
        output = src
        # Reference points are only needed when there are layers to run.
        if self.num_layers > 0:
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device
            )
        # Collect the input and every (pre-fusion) layer output for dense fusion.
        output_list = [output]
        for layer_id, layer in enumerate(self.layers):
            if self.use_transformer_ckpt:
                # Gradient checkpointing recomputes activations in backward to
                # save memory; checkpoint() requires positional arguments.
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    pos,
                    reference_points,
                    spatial_shapes,
                    level_start_index,
                    key_padding_mask,
                )
            else:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask,
                )
            output_list.append(output)
            # dense_net_fusion: after every layer, fuse all outputs so far.
            # (A string compare already excludes None.)
            if self.multi_level_encoder_fusion == "dense_net_fusion":
                output = self.memory_fusion_layer[layer_id](
                    torch.cat(output_list, dim=-1)
                )
        # stable_dense_fusion: fuse everything once after the last layer.
        if self.multi_level_encoder_fusion == "stable_dense_fusion":
            output = self.memory_fusion_layer(torch.cat(output_list, dim=-1))
        return output