from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from detect_tools.upn import ENCODERS, build_encoder
from detect_tools.upn.models.utils import get_activation_fn, get_clones
from detect_tools.upn.ops.modules import MSDeformAttn


@ENCODERS.register_module()
class DeformableTransformerEncoderLayer(nn.Module):
    """Deformable Transformer Encoder Layer.

    Args:
        d_model (int): The dimension of keys/values/queries in
            :class:`MSDeformAttn`.
        d_ffn (int): The dimension of the feedforward network model.
        dropout (float): Probability of an element to be zeroed.
        activation (str): Activation function in the feedforward network.
            'relu' and 'gelu' are supported.
        n_levels (int): The number of levels in Multi-Scale Deformable Attention.
        n_heads (int): Parallel attention heads.
        n_points (int): Number of sampling points in Multi-Scale Deformable Attention.
        add_channel_attention (bool): If True, add channel attention.
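
    Example (a minimal sketch; assumes the MSDeformAttn CUDA extension is
    built, so it may not run on a CPU-only install)::

        >>> layer = DeformableTransformerEncoderLayer(d_model=256, n_levels=1)
        >>> src = torch.randn(2, 16 * 16, 256)  # one flattened 16x16 level
        >>> ref = torch.rand(2, 16 * 16, 1, 2)  # normalized reference points
        >>> shapes = torch.as_tensor([[16, 16]])
        >>> start = torch.as_tensor([0])
        >>> out = layer(src, pos=None, reference_points=ref,
        ...             spatial_shapes=shapes, level_start_index=start)
        >>> out.shape
        torch.Size([2, 256, 256])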
    """

    def __init__(
        self,
        d_model: int = 256,
        d_ffn: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        n_levels: int = 4,
        n_heads: int = 8,
        n_points: int = 4,
        add_channel_attention: bool = False,
    ) -> None:
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = get_activation_fn(activation, d_model=d_ffn)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # channel attention
        self.add_channel_attention = add_channel_attention
        if add_channel_attention:
            self.activ_channel = get_activation_fn("dyrelu", d_model=d_model)
            self.norm_channel = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(
        tensor: torch.Tensor, pos: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Add the positional embedding when one is provided.
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src: torch.Tensor) -> torch.Tensor:
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        reference_points: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function for `DeformableTransformerEncoderLayer`.

        Args:
            src (torch.Tensor): The input sequence of shape (S, N, E).
            pos (torch.Tensor): The position embedding of shape (S, N, E).
            reference_points (torch.Tensor): The reference points of shape (N, L, 2).
            spatial_shapes (torch.Tensor): The spatial shapes of feature levels.
            level_start_index (torch.Tensor): The start index of each level.
            key_padding_mask (torch.Tensor): The mask for keys with shape (N, S).
        """
        # self attention
        src2 = self.self_attn(
            self.with_pos_embed(src, pos),
            reference_points,
            src,
            spatial_shapes,
            level_start_index,
            key_padding_mask,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        # channel attn
        if self.add_channel_attention:
            src = self.norm_channel(src + self.activ_channel(src))

        return src


@ENCODERS.register_module()
class UPNEncoder(nn.Module):
    """Implementation of UPN Encoder.

    Args:
        num_layers (int): The number of layers in the TransformerEncoder.
        d_model (int, optional): The dimension of the input feature. Defaults to 256.
        encoder_layer_cfg (Dict): Config for the DeformableEncoderLayer.
        use_checkpoint (bool, optional): Whether to use checkpoint in the fusion layer for
            memory saving. Defaults to False.
        use_transformer_ckpt (bool, optional): Whether to use checkpoint for the deformableencoder.
        enc_layer_share (bool, optional): Whether to share the same memory for the encoder_layer.
            Defaults to False. This is used for all the sub-layers in the basic block.
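
    Example (a sketch; assumes ENCODERS follows the mm-style ``dict(type=...)``
    config convention)::

        >>> encoder = UPNEncoder(
        ...     num_layers=6,
        ...     encoder_layer_cfg=dict(type="DeformableTransformerEncoderLayer"),
        ... )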
    """

    def __init__(
        self,
        num_layers: int,
        d_model: int = 256,
        encoder_layer_cfg: Optional[Dict] = None,
        use_checkpoint: bool = True,
        use_transformer_ckpt: bool = True,
        enc_layer_share: bool = False,
        multi_level_encoder_fusion: Optional[str] = None,
    ) -> None:
        super().__init__()
        # prepare layers
        encoder_layer = build_encoder(encoder_layer_cfg)

        self.multi_level_encoder_fusion = multi_level_encoder_fusion
        self._initialize_memory_fusion_layers(
            multi_level_encoder_fusion, num_layers, d_model
        )

        if num_layers > 0:
            self.layers = get_clones(
                encoder_layer, num_layers, layer_share=enc_layer_share
            )
        else:
            self.layers = []
            del encoder_layer

        self.query_scale = None
        self.num_layers = num_layers
        self.d_model = d_model

        self.use_checkpoint = use_checkpoint
        self.use_transformer_ckpt = use_transformer_ckpt

    def _initialize_memory_fusion_layers(self, fusion_type, num_layers, d_model):
        if fusion_type is None:
            self.memory_fusion_layer = None
            return

        if fusion_type == "stable_dense_fusion":
            # Fuse the input memory plus all layer outputs in one projection:
            # (num_layers + 1) * d_model -> d_model.
            self.memory_fusion_layer = nn.Sequential(
                nn.Linear(d_model * (num_layers + 1), d_model),
                nn.LayerNorm(d_model),
            )
            nn.init.constant_(self.memory_fusion_layer[0].bias, 0)
        elif fusion_type == "dense_net_fusion":
            # DenseNet-style fusion: after layer i, project the concatenation of
            # the input memory and all outputs so far back to d_model, e.g. with
            # d_model=256: layer 0 fuses 512 -> 256, layer 1 fuses 768 -> 256.
            self.memory_fusion_layer = nn.ModuleList()
            for i in range(num_layers):
                self.memory_fusion_layer.append(
                    nn.Sequential(
                        nn.Linear(d_model * (i + 2), d_model),
                        nn.LayerNorm(d_model),
                    )
                )
            for layer in self.memory_fusion_layer:
                nn.init.constant_(layer[0].bias, 0)
        else:
            raise NotImplementedError(
                f"Unsupported multi_level_encoder_fusion: {fusion_type!r}"
            )

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Compute normalized reference points for each feature level.

        Args:
            spatial_shapes (torch.Tensor): (hi, wi) of each level, of shape
                (n_levels, 2).
            valid_ratios (torch.Tensor): Valid (non-padded) ratio per level,
                of shape (bs, n_levels, 2).
            device (torch.device): Device to create the points on.

        Returns:
            torch.Tensor: Reference points of shape (bs, sum(hi*wi), n_levels, 2),
                normalized to [0, 1] within the valid region of each level.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # Pixel-center coordinates, normalized by the valid spatial extent.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # Scale per query by the valid ratio of every level it attends to.
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(
        self,
        src: torch.Tensor,
        pos: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        valid_ratios: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function

        Args:
            src (torch.Tensor): Flattened Image features in shape [bs, sum(hi*wi), 256]
            pos (torch.Tensor): Position embedding for image feature in shape [bs, sum(hi*wi), 256]
            spatial_shapes (torch.Tensor): Spatial shape of each level in shape [num_level, 2]
            level_start_index (torch.Tensor): Start index of each level in shape [num_level]
            valid_ratios (torch.Tensor): Valid ratio of each level in shape [bs, num_level, 2]
            key_padding_mask (torch.Tensor): Padding mask for image feature in shape [bs, sum(hi*wi)]
            memory_refImg (torch.Tensor, optional): Text feature in shape [bs, n_ref, 256]. Defaults
                to None.
            refImg_padding_mask (torch.Tensor, optional): Padding mask for reference image feature
                in shape [bs, n_text]. Defaults to None.
            pos_refImg (torch.Tensor, optional): Position embedding for reference image in shape
                [bs, n_ref, 256]. Defaults to None.
            refImg_self_attention_masks (torch.Tensor, optional): Self attention mask for reference
                image feature in shape [bs, n_ref, n_ref]. Defaults to None.
        Outpus:
            torch.Tensor: Encoded image feature in shape [bs, sum(hi*wi), 256]
            torch.Tensor: Encoded reference image feature in shape [bs, n_ref, 256]
        """

        output = src
        # preparation: reference points are only needed when layers will run
        if self.num_layers > 0:
            reference_points = self.get_reference_points(
                spatial_shapes, valid_ratios, device=src.device
            )

        # multi-level dense fusion: keep every intermediate output around
        output_list = [output]
        # main process
        for layer_id, layer in enumerate(self.layers):
            if self.use_transformer_ckpt:
                output = checkpoint.checkpoint(
                    layer,
                    output,
                    pos,
                    reference_points,
                    spatial_shapes,
                    level_start_index,
                    key_padding_mask,
                )
            else:
                output = layer(
                    src=output,
                    pos=pos,
                    reference_points=reference_points,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    key_padding_mask=key_padding_mask,
                )

            # Collect every layer output so both fusion variants see all
            # intermediate feature maps (this must run for the checkpointed
            # path as well, otherwise fusion silently degenerates).
            output_list.append(output)
            if self.multi_level_encoder_fusion == "dense_net_fusion":
                output = self.memory_fusion_layer[layer_id](
                    torch.cat(output_list, dim=-1)
                )

        if self.multi_level_encoder_fusion == "stable_dense_fusion":
            output = self.memory_fusion_layer(torch.cat(output_list, dim=-1))

        return output
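

# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the library API).
# It assumes the MSDeformAttn CUDA extension is built and that ENCODERS uses
# the mm-style ``dict(type=...)`` config convention; adjust both to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    bs, d_model = 2, 256
    # Two feature levels, e.g. 16x16 and 8x8 maps flattened to 320 tokens.
    spatial_shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long)
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
    )
    num_tokens = int(spatial_shapes.prod(1).sum())

    encoder = UPNEncoder(
        num_layers=2,
        d_model=d_model,
        encoder_layer_cfg=dict(
            type="DeformableTransformerEncoderLayer",
            d_model=d_model,
            n_levels=2,  # must match the number of feature levels above
        ),
        use_transformer_ckpt=False,  # checkpointing needs grads; skip for the demo
    )

    src = torch.randn(bs, num_tokens, d_model)
    pos = torch.randn(bs, num_tokens, d_model)
    valid_ratios = torch.ones(bs, 2, 2)  # fully valid (unpadded) feature maps

    out = encoder(src, pos, spatial_shapes, level_start_index, valid_ratios)
    print(out.shape)  # expected: torch.Size([2, 320, 256])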