| |
| """ |
| DETR Transformer class. |
| |
| Copy-paste from torch.nn.Transformer with modifications: |
| * positional encodings are passed in MHattention |
| * extra LN at the end of encoder is removed |
| * decoder returns a stack of activations from all decoding layers |
| """ |
| import copy |
| from typing import Optional |
| import torch |
| import torch.nn.functional as F |
| from torch import nn, Tensor |
| import math |
| import numpy as np |
| from .attention import MultiheadAttention |
| from .crossattention import MultiheadAttention as cateattention |
|
|
class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with ReLU between them.

    The final linear layer is left without an activation.

    Args:
        input_dim: Feature size of the input.
        hidden_dim: Feature size of every intermediate layer.
        output_dim: Feature size of the output.
        num_layers: Total number of linear layers in the stack.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive pairs of this list are the (in, out) sizes of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last."""
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < last_index:
                x = F.relu(x)
        return x
|
|
def inverse_sigmoid(x, eps=1e-3):
    """Return ``log(x / (1 - x))`` with clamping for numerical stability.

    ``x`` is first clamped to ``[0, 1]``; numerator and denominator are then
    each clamped to at least ``eps`` so the logarithm stays finite.
    """
    bounded = x.clamp(min=0, max=1)
    numerator = bounded.clamp(min=eps)
    denominator = (1 - bounded).clamp(min=eps)
    return torch.log(numerator / denominator)
|
|
def gen_sineembed_for_position(pos_tensor, d_model):
    """Build a sinusoidal embedding from the two coordinates of ``pos_tensor``.

    ``pos_tensor[:, :, 0]`` and ``pos_tensor[:, :, 1]`` are each scaled by
    ``2 * pi`` and encoded with ``d_model // 2`` sin/cos features; the two
    encodings are concatenated along the last dimension.

    Args:
        pos_tensor: 3-D tensor whose last dimension holds at least two
            coordinates (indices 0 and 1 are read).
        d_model: Total embedding width; each coordinate gets half of it.

    Returns:
        Tensor with the same two leading dimensions as ``pos_tensor``.
    """
    half_dim = d_model // 2
    two_pi = 2 * math.pi
    freq_index = torch.arange(half_dim, dtype=torch.float32, device=pos_tensor.device)
    # Adjacent feature pairs share one frequency (index // 2).
    denominators = 10000 ** (2 * (freq_index // 2) / half_dim)

    def _encode(coordinate):
        scaled = (coordinate * two_pi)[:, :, None] / denominators
        # sin on even slots, cos on odd slots, interleaved by the flatten.
        paired = torch.stack((scaled[:, :, 0::2].sin(), scaled[:, :, 1::2].cos()), dim=3)
        return paired.flatten(2)

    center_part = _encode(pos_tensor[:, :, 0])
    span_part = _encode(pos_tensor[:, :, 1])
    return torch.cat((center_part, span_part), dim=2)
|
|
class Transformer(nn.Module):
    """Encoder/decoder transformer with a text-to-video stage and a prompt token.

    Four sub-networks are built:

    * ``mcls_encoder`` - self-attention encoder for the optional moment /
      non-moment clips passed to :meth:`forward`;
    * ``t2v_encoder`` - cross-attention encoder (``T2V_TransformerEncoderLayer``);
    * ``encoder`` - self-attention encoder over the fused sequence;
    * ``decoder`` - ``TransformerDecoder`` refining reference points.

    Args:
        d_model: Feature width used everywhere.
        nhead: Number of attention heads.
        num_queries: Stored on the instance; not otherwise used in this class.
        num_encoder_layers: Depth of ``encoder``.
        num_decoder_layers: Depth of ``decoder``.
        dim_feedforward: Hidden width of the feed-forward sub-layers.
        dropout: Drop probability handed to the layers.
        activation: Name understood by ``_get_activation_fn``.
        normalize_before: Forwarded to the layers; also adds a final
            ``LayerNorm`` to each encoder when true.
        return_intermediate_dec: Forwarded to the decoder.
        query_dim, keep_query_pos, query_scale_type, modulate_t_attn,
        bbox_embed_diff_each_layer: Forwarded to the decoder.
        num_patterns: Stored on the instance; not otherwise used in this class.
        args: Namespace that must provide ``moment_layers``, ``t2v_layers`` and
            ``num_dummies`` (read in ``__init__``) and ``num_prompts`` (read in
            ``forward``). The default ``None`` is therefore not usable.
    """

    def __init__(self, d_model=512, nhead=8, num_queries=2, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False, query_dim=2,
                 keep_query_pos=False, query_scale_type='cond_elewise',
                 num_patterns=0,
                 modulate_t_attn=True,
                 bbox_embed_diff_each_layer=False, args=None
                 ):
        super().__init__()
        self.args = args

        mcls_encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                     dropout, activation, normalize_before)
        mcls_encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.mcls_encoder = TransformerEncoder(mcls_encoder_layer, args.moment_layers, mcls_encoder_norm)

        t2v_encoder_layer = T2V_TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                        dropout, activation, normalize_before,
                                                        self.args.num_dummies)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.t2v_encoder = TransformerCATEEncoder(t2v_encoder_layer, args.t2v_layers, encoder_norm)

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before,
                                                keep_query_pos=keep_query_pos)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec,
                                          d_model=d_model, query_dim=query_dim,
                                          keep_query_pos=keep_query_pos,
                                          query_scale_type=query_scale_type,
                                          modulate_t_attn=modulate_t_attn,
                                          bbox_embed_diff_each_layer=bbox_embed_diff_each_layer)

        # NOTE(review): this runs after the decoder is built, so every
        # parameter with more than one dimension (including weights the decoder
        # zero-initialised) is re-drawn with Xavier-uniform here.
        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        self.dec_layers = num_decoder_layers
        self.num_queries = num_queries
        self.num_patterns = num_patterns

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _encode_moment(self, clip_src, clip_pos, clip_mask):
        """Run ``mcls_encoder`` on an optional clip; return (first token, rest)."""
        if clip_src is None:
            return None, None
        clip_src = clip_src.permute(1, 0, 2)
        clip_pos = clip_pos.permute(1, 0, 2)
        encoded = self.mcls_encoder(clip_src, src_key_padding_mask=clip_mask, pos=clip_pos)
        return encoded[0], encoded[1:]

    def forward(self, src, mask, query_embed, pos_embed, video_length=None, moment_idx=None, msrc=None, mpos=None, mmask=None,
                nmsrc=None, nmpos=None, nmmask=None,
                ctxtoken=None, gtoken=None, gpos=None, vlen=None):
        """
        Args:
            src: (batch_size, L, d)
            mask: (batch_size, L)
            query_embed: (#queries, d)
            pos_embed: (batch_size, L, d) the same as src
            video_length: feature shape; the first ``video_length`` positions
                of the sequence are treated as the video part
            moment_idx: accepted but not used by this method
            msrc, mpos, mmask: optional moment clip, its positional embedding
                and padding mask, encoded with ``mcls_encoder``
            nmsrc, nmpos, nmmask: same, for the non-moment clip
            ctxtoken: context token; permuted with (1, 0, 2) like ``src`` and
                required (its device is used for new tensors)
            gtoken: prompt token bank, indexed along dim 0
                (assumed (num_tokens, d) - TODO confirm against caller)
            gpos: positional embedding of the prompt token, reshaped to
                (1, 1, d_model)
            vlen: actual video length of each sample
        Returns:
            Tuple ``(hs, references, memory_local, memory_global,
            attn_weights, mmemory_moment, nmmemory_moment, mmemory_frames,
            nmmemory_frames)``. ``memory_local`` is returned transposed on its
            first two dims; the four moment entries are ``None`` when the
            corresponding clip was not supplied.
        """
        device = ctxtoken.device
        mmemory_moment, mmemory_frames = self._encode_moment(msrc, mpos, mmask)
        nmmemory_moment, nmmemory_frames = self._encode_moment(nmsrc, nmpos, nmmask)

        bs, _, d = src.shape
        src = src.permute(1, 0, 2)
        pos_embed = pos_embed.permute(1, 0, 2)
        refpoint_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)

        # Text-to-video stage over the full sequence.
        t2v_src, attn_weights = self.t2v_encoder(src, src_key_padding_mask=mask, pos=pos_embed,
                                                 video_length=video_length)

        ctx_src_ = ctxtoken.permute(1, 0, 2)

        # Softmax over the similarity between context-subtracted frames
        # (taken from the *input* src, not t2v_src) and the prompt bank.
        frame_feats = F.normalize((src[:video_length] - ctx_src_).permute(1, 0, 2), dim=2)
        prompt_feats = F.normalize(gtoken, dim=1)
        fr_token_sim = torch.softmax(torch.matmul(frame_feats, prompt_feats.T), dim=-1)

        # Per-frame weight: attention mass on the non-dummy keys, detached so
        # no gradient flows through the weighting.
        frame_importance = attn_weights[:, :, self.args.num_dummies:].sum(2).clone().detach()
        for i in range(len(frame_importance)):
            # Zero the frames past each sample's real length.
            frame_importance[i][vlen[i]:] *= 0.
        # Rescale each row so its entries sum to the row length.
        frame_importance = (frame_importance / frame_importance.sum(1).unsqueeze(1)) * frame_importance.size(1)

        fr_token_sim = (fr_token_sim * frame_importance.unsqueeze(2)).mean(1)
        topk_val, topkidx = torch.topk(fr_token_sim, k=self.args.num_prompts, dim=1)

        # Weighted sum of the selected prompt tokens for the whole batch at
        # once. The result follows ``src``'s dtype instead of a hard-coded
        # bfloat16, so the module also runs when the model is not in bfloat16.
        src_ = (topk_val.unsqueeze(-1) * gtoken[topkidx]).sum(1).to(device=device, dtype=src.dtype)
        src_ = src_.reshape(1, src.size(1), -1)

        src_ = src_ + ctx_src_
        pos_ = gpos.reshape([1, 1, self.d_model]).repeat(1, pos_embed.shape[1], 1)
        # The prepended token is never padding.
        mask_ = torch.zeros(mask.shape[0], 1, dtype=torch.bool, device=mask.device)

        src_, _ = self.t2v_encoder(src_, src_key_padding_mask=mask_, pos=pos_,
                                   video_length=video_length, dummy=False)

        # Prepend the prompt token, then keep it plus the video part only.
        src = torch.cat([src_, t2v_src], dim=0)
        mask = torch.cat([mask_, mask], dim=1)
        pos_embed = torch.cat([pos_, pos_embed], dim=0)

        src = src[:video_length + 1]
        mask = mask[:, :video_length + 1]
        pos_embed = pos_embed[:video_length + 1]

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        memory_global, memory_local = memory[0], memory[1:]
        # Add the global token to every local position (out of place, so the
        # encoder output itself is not mutated through a view).
        memory_local = memory_local + memory_global.unsqueeze(0)
        mask_local = mask[:, 1:]
        pos_embed_local = pos_embed[1:]

        # Decoder content queries start at zero, in the memory's dtype.
        tgt = torch.zeros(refpoint_embed.shape[0], bs, d, dtype=memory_local.dtype, device=device)

        hs, references = self.decoder(tgt, memory_local, memory_key_padding_mask=mask_local,
                                      pos=pos_embed_local, refpoints_unsigmoid=refpoint_embed)
        memory_local = memory_local.transpose(0, 1)

        return (hs, references, memory_local, memory_global, attn_weights,
                mmemory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames)
|
|
|
|
class TransformerCATEEncoder(nn.Module):
    """Stack of layers that each return ``(output, attention_weights)``.

    The attention weights of all layers are averaged.

    Args:
        encoder_layer: Prototype layer; ``num_layers`` deep copies are made.
        num_layers: Number of layers.
        norm: Optional module applied to the final output.
        return_intermediate: When true, ``forward`` returns only the stacked
            per-layer outputs (a single tensor, no attention weights).
    """

    def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                dummy=True,
                **kwargs):
        """Run every layer in order.

        Args:
            src: Input sequence.
            mask: Passed to each layer as ``src_mask``.
            src_key_padding_mask: Passed through to each layer.
            pos: Positional embedding passed through to each layer.
            dummy: Passed through to each layer.
            **kwargs: Extra keyword arguments forwarded to each layer.

        Returns:
            ``(output, attn_weights)`` where ``attn_weights`` is the mean of
            the per-layer weights (``None`` if there are no layers); or, when
            ``return_intermediate`` is set, the stacked per-layer outputs only.
        """
        output = src
        intermediate = []
        attn_sum = None
        for layer in self.layers:
            output, attn_weight = layer(output, src_mask=mask,
                                        src_key_padding_mask=src_key_padding_mask,
                                        pos=pos, dummy=dummy, **kwargs)
            attn_sum = attn_weight if attn_sum is None else attn_sum + attn_weight
            if self.return_intermediate:
                intermediate.append(output)

        # Out-of-place division: with a single layer ``attn_sum`` is the very
        # tensor the layer returned, and an in-place ``/=`` would mutate it.
        attn_weights = None if attn_sum is None else attn_sum / self.num_layers

        if self.norm is not None:
            output = self.norm(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output, attn_weights
|
|
class TransformerEncoder(nn.Module):
    """Sequential stack of identical encoder layers.

    Args:
        encoder_layer: Prototype layer; ``num_layers`` deep copies are made.
        num_layers: Number of layers.
        norm: Optional module applied to the final output.
        return_intermediate: When true, ``forward`` returns the stacked
            per-layer outputs instead of the final output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                **kwargs):
        """Feed ``src`` through every layer; extra kwargs go to each layer."""
        hidden = src
        collect = self.return_intermediate
        per_layer = []

        for block in self.layers:
            hidden = block(hidden, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos, **kwargs)
            if collect:
                per_layer.append(hidden)

        if self.norm is not None:
            hidden = self.norm(hidden)

        # The stacked intermediates are the raw layer outputs (pre-norm).
        return torch.stack(per_layer) if collect else hidden
|
|
|
|
class TransformerDecoder(nn.Module):
    """Decoder stack that refines a set of reference points layer by layer.

    Each layer receives a sinusoidal embedding of the current reference
    points; after each layer a small MLP predicts an offset that is added to
    the points in inverse-sigmoid space.

    Args:
        decoder_layer: Prototype layer; ``num_layers`` deep copies are made.
        num_layers: Number of layers.
        norm: Optional module applied to each collected layer output.
        return_intermediate: Must be true; other values are rejected.
        d_model: Feature width.
        query_dim: Number of leading reference-point coordinates that are
            embedded and refined.
        keep_query_pos: When false, ``ca_qpos_proj`` is removed from every
            layer except the first.
        query_scale_type: ``'cond_elewise'``, ``'cond_scalar'`` or
            ``'fix_elewise'``; selects how the sine embedding is rescaled.
        modulate_t_attn: When true, the sine embedding is additionally scaled
            by a predicted value divided by reference coordinate 1.
        bbox_embed_diff_each_layer: Use a separate offset head per layer.

    Raises:
        ValueError: If ``return_intermediate`` is false.
        NotImplementedError: If ``query_scale_type`` is unknown.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False,
                 d_model=256, query_dim=2, keep_query_pos=False, query_scale_type='cond_elewise',
                 modulate_t_attn=False,
                 bbox_embed_diff_each_layer=False,
                 ):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate
        if not return_intermediate:
            # Explicit raise instead of ``assert`` so the check survives ``-O``.
            raise ValueError("TransformerDecoder requires return_intermediate=True")
        self.query_dim = query_dim

        self.query_scale_type = query_scale_type
        if query_scale_type == 'cond_elewise':
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        elif query_scale_type == 'cond_scalar':
            self.query_scale = MLP(d_model, d_model, 1, 2)
        elif query_scale_type == 'fix_elewise':
            self.query_scale = nn.Embedding(num_layers, d_model)
        else:
            raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type))

        self.ref_point_head = MLP(d_model, d_model, d_model, 2)

        # Offset head(s); the last linear layer starts at zero.
        if bbox_embed_diff_each_layer:
            self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 2, 3) for _ in range(num_layers)])
            offset_heads = list(self.bbox_embed)
        else:
            self.bbox_embed = MLP(d_model, d_model, 2, 3)
            offset_heads = [self.bbox_embed]
        for head in offset_heads:
            nn.init.constant_(head.layers[-1].weight.data, 0)
            nn.init.constant_(head.layers[-1].bias.data, 0)

        self.d_model = d_model
        self.modulate_t_attn = modulate_t_attn
        self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer

        if modulate_t_attn:
            self.ref_anchor_head = MLP(d_model, d_model, 1, 2)

        if not keep_query_pos:
            # Only the first layer keeps its query-position projection.
            for layer in self.layers[1:]:
                layer.ca_qpos_proj = None

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                refpoints_unsigmoid: Optional[Tensor] = None,
                ):
        """Decode ``tgt`` against ``memory`` while refining reference points.

        Args:
            tgt: Initial content queries.
            memory: Encoder output attended to by every layer.
            tgt_mask, memory_mask, tgt_key_padding_mask,
            memory_key_padding_mask: Passed through to every layer.
            pos: Positional embedding of ``memory``.
            refpoints_unsigmoid: Initial reference points before sigmoid
                (required despite the ``None`` default).

        Returns:
            ``[hidden_states, reference_points]``: the stacked (normalised)
            per-layer outputs and the stacked reference points, each with
            dims 1 and 2 swapped.
        """
        output = tgt

        intermediate = []
        reference_points = refpoints_unsigmoid.sigmoid()
        ref_points = [reference_points]

        for layer_id, layer in enumerate(self.layers):
            obj_center = reference_points[..., :self.query_dim]

            query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model)
            # Follow the dtype of the incoming queries rather than a
            # hard-coded bfloat16, so non-bfloat16 models work as well.
            query_sine_embed = query_sine_embed.to(tgt.dtype)

            query_pos = self.ref_point_head(query_sine_embed)

            if self.query_scale_type != 'fix_elewise':
                # The first layer has no decoded content to condition on yet.
                pos_transformation = 1 if layer_id == 0 else self.query_scale(output)
            else:
                pos_transformation = self.query_scale.weight[layer_id]

            query_sine_embed = query_sine_embed * pos_transformation

            if self.modulate_t_attn:
                reft_cond = self.ref_anchor_head(output).sigmoid()
                # In-place on purpose: keeps ``query_sine_embed``'s dtype even
                # if the ratio is computed in a wider dtype.
                query_sine_embed *= (reft_cond[..., 0] / obj_center[..., 1]).unsqueeze(-1)

            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                           is_first=(layer_id == 0))

            # Iterative refinement of the reference points.
            if self.bbox_embed is not None:
                if self.bbox_embed_diff_each_layer:
                    tmp = self.bbox_embed[layer_id](output)
                else:
                    tmp = self.bbox_embed(output)

                tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)
                new_reference_points = tmp[..., :self.query_dim].sigmoid()
                if layer_id != self.num_layers - 1:
                    ref_points.append(new_reference_points)
                # The next layer sees the refined points without gradient.
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                # ``norm`` defaults to None; only apply it when present.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            if self.bbox_embed is not None:
                return [
                    torch.stack(intermediate).transpose(1, 2),
                    torch.stack(ref_points).transpose(1, 2),
                ]
            else:
                return [
                    torch.stack(intermediate).transpose(1, 2),
                    reference_points.unsqueeze(0).transpose(1, 2)
                ]

        return output.unsqueeze(0)
|
|
|
|
class TransformerEncoderLayerThin(nn.Module):
    """Encoder layer with self-attention and one linear layer, no feed-forward block.

    Args:
        d_model: Feature width.
        nhead: Number of attention heads.
        dim_feedforward: Accepted for signature compatibility; unused.
        dropout: Drop probability for attention and the residual branch.
        activation: Accepted for signature compatibility; unused.
        normalize_before: Selects the pre-norm path instead of post-norm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm path: attention -> linear -> residual add -> LayerNorm."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src2 = self.linear(src2)
        src = src + self.dropout(src2)
        src = self.norm(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm -> attention -> linear -> residual add.

        Uses only the modules this layer actually owns (``norm``, ``linear``,
        ``dropout``); the earlier version referenced attributes that are never
        created and raised ``AttributeError``.
        """
        src2 = self.norm(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src2 = self.linear(src2)
        return src + self.dropout(src2)

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm path per ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
|
|
|
class T2V_TransformerEncoderLayer(nn.Module):
    """Cross-attention layer: the leading part of the sequence attends to the rest.

    The first ``video_length`` positions act as queries; the remaining
    positions act as keys and values and are passed through unchanged.

    Args:
        d_model: Feature width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Drop probability for attention, feed-forward and DropPath.
        activation: Name understood by ``_get_activation_fn``.
        normalize_before: Pre-norm is not implemented; setting this makes
            ``forward`` raise ``NotImplementedError``.
        num_dummies: Forwarded to the project's cross-attention module.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, num_dummies=3):
        super().__init__()
        self.self_attn = cateattention(d_model, nhead, dropout=dropout, num_dummies=num_dummies)

        # Feed-forward block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = DropPath(dropout)
        self.dropout2 = DropPath(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.nhead = nhead

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     video_length=None, dummy=True):
        """Post-norm path.

        Args:
            src: Sequence whose first ``video_length`` rows are the queries.
            src_mask: Accepted but unused.
            src_key_padding_mask: Padding mask, sliced at ``video_length``
                along dim 1 into query and key parts.
            pos: Positional embedding added to queries and keys (not values).
            video_length: Split point between query and key/value rows.
            dummy: Forwarded to the cross-attention module.

        Returns:
            ``(src, attn_weights)`` where ``src`` is the updated query rows
            concatenated with the untouched key/value rows.

        Raises:
            ValueError: If ``video_length`` is None.
        """
        if video_length is None:
            raise ValueError("video_length is required")
        pos_src = self.with_pos_embed(src, pos)
        q, k, v = pos_src[:video_length], pos_src[video_length:], src[video_length:]

        qmask = src_key_padding_mask[:, :video_length].unsqueeze(2)
        kmask = src_key_padding_mask[:, video_length:].unsqueeze(1)
        # True only where both the query and the key position are padding.
        # NOTE(review): ``repeat(self.nhead, 1, 1)`` tiles whole batches per
        # head (row = head * bs + batch). torch.nn.MultiheadAttention orders
        # rows as batch * nhead + head; confirm which layout the project's
        # cross-attention module expects.
        attn_mask = torch.matmul(qmask.float(), kmask.float()).bool().repeat(self.nhead, 1, 1)

        src2, attn_weights = self.self_attn(q, k, v, attn_mask=attn_mask,
                                            key_padding_mask=src_key_padding_mask[:, video_length:],
                                            dummy=dummy)

        src2 = src[:video_length] + self.dropout1(src2)
        src3 = self.norm1(src2)
        src3 = self.linear2(self.dropout(self.activation(self.linear1(src3))))
        # The residual uses the un-normalised ``src2``; norm1 feeds only the
        # feed-forward branch.
        src2 = src2 + self.dropout2(src3)
        src2 = self.norm2(src2)

        src = torch.cat([src2, src[video_length:]])
        return src, attn_weights

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None, dummy=True):
        """Pre-norm path; not implemented for this layer.

        Raises:
            NotImplementedError: Always. Previously this returned ``None``,
                which only failed later when the caller unpacked the result.
        """
        raise NotImplementedError(
            "T2V_TransformerEncoderLayer does not implement normalize_before=True"
        )

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None, dummy=True,
                **kwargs):
        """Dispatch per ``normalize_before``; ``**kwargs`` reach only the post-norm path."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos, dummy=dummy)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos, dummy=dummy, **kwargs)
|
|
class TransformerEncoderLayer(nn.Module):
    """Self-attention encoder layer with a feed-forward block and DropPath residuals.

    Args:
        d_model: Feature width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Drop probability for attention, feed-forward and DropPath.
        activation: Name understood by ``_get_activation_fn``.
        normalize_before: Selects the pre-norm path instead of post-norm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = DropPath(dropout)
        self.dropout2 = DropPath(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm path: each sub-layer is followed by residual add and LayerNorm."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm precedes each sub-layer, residual add follows.

        Previously a bare ``pass``, so ``normalize_before=True`` made the
        layer return ``None``.
        """
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm path per ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
|
|
|
class TransformerDecoderLayer(nn.Module):
    """Decoder layer with separate content/position projections for attention.

    Self-attention (optional) sums projected content and position for both
    queries and keys. Cross-attention concatenates, per head, the content
    part with a positional part, so its query/key width is ``2 * d_model``.

    Args:
        d_model: Feature width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Drop probability for attention, feed-forward and DropPath.
        activation: Name understood by ``_get_activation_fn``.
        normalize_before: Stored on the instance; ``forward`` does not read it.
        keep_query_pos: Add the projected query position to the
            cross-attention content in every layer, not only the first.
        rm_self_attn_decoder: Skip building and running the self-attention block.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, keep_query_pos=False,
                 rm_self_attn_decoder=False):
        super().__init__()

        # Self-attention block (optional).
        if not rm_self_attn_decoder:
            self.sa_qcontent_proj = nn.Linear(d_model, d_model)
            self.sa_qpos_proj = nn.Linear(d_model, d_model)
            self.sa_kcontent_proj = nn.Linear(d_model, d_model)
            self.sa_kpos_proj = nn.Linear(d_model, d_model)
            self.sa_v_proj = nn.Linear(d_model, d_model)
            self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model)

            self.norm1 = nn.LayerNorm(d_model)
            self.dropout1 = DropPath(dropout)

        # Cross-attention block.
        self.ca_qcontent_proj = nn.Linear(d_model, d_model)
        self.ca_qpos_proj = nn.Linear(d_model, d_model)
        self.ca_kcontent_proj = nn.Linear(d_model, d_model)
        self.ca_kpos_proj = nn.Linear(d_model, d_model)
        self.ca_v_proj = nn.Linear(d_model, d_model)
        self.ca_qpos_sine_proj = nn.Linear(d_model, d_model)
        self.cross_attn = MultiheadAttention(d_model * 2, nhead, dropout=dropout, vdim=d_model)

        self.nhead = nhead
        self.rm_self_attn_decoder = rm_self_attn_decoder

        # Feed-forward block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout2 = DropPath(dropout)
        self.dropout3 = DropPath(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self.keep_query_pos = keep_query_pos

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None,
                query_sine_embed=None,
                is_first=False):
        """Run self-attention (if enabled), cross-attention and the feed-forward block.

        Args:
            tgt: Content queries.
            memory: Encoder output used as cross-attention keys/values.
            tgt_mask, tgt_key_padding_mask: Masks for self-attention.
            memory_mask, memory_key_padding_mask: Masks for cross-attention.
            pos: Positional embedding of ``memory``.
            query_pos: Positional embedding of the queries.
            query_sine_embed: Sinusoidal query embedding concatenated to the
                cross-attention query.
            is_first: Whether this is the first decoder layer.

        Returns:
            The updated queries, same layout as ``tgt``.
        """
        # ---- self-attention ----
        if not self.rm_self_attn_decoder:
            sa_query = self.sa_qcontent_proj(tgt) + self.sa_qpos_proj(query_pos)
            sa_key = self.sa_kcontent_proj(tgt) + self.sa_kpos_proj(query_pos)
            sa_value = self.sa_v_proj(tgt)

            attended = self.self_attn(sa_query, sa_key, value=sa_value, attn_mask=tgt_mask,
                                      key_padding_mask=tgt_key_padding_mask)[0]
            tgt = self.norm1(tgt + self.dropout1(attended))

        # ---- cross-attention ----
        q_content = self.ca_qcontent_proj(tgt)
        k_content = self.ca_kcontent_proj(memory)
        ca_value = self.ca_v_proj(memory)

        num_queries, bs, n_model = q_content.shape
        mem_len = k_content.shape[0]
        head_dim = n_model // self.nhead

        k_pos = self.ca_kpos_proj(pos)

        # Positional terms are summed into the content only in the first
        # layer, unless keep_query_pos asks for it everywhere.
        if is_first or self.keep_query_pos:
            q = q_content + self.ca_qpos_proj(query_pos)
            k = k_content + k_pos
        else:
            q, k = q_content, k_content

        # Per head, concatenate the content part with the positional part.
        q = q.view(num_queries, bs, self.nhead, head_dim)
        sine_part = self.ca_qpos_sine_proj(query_sine_embed)
        sine_part = sine_part.view(num_queries, bs, self.nhead, head_dim)
        q = torch.cat([q, sine_part], dim=3).view(num_queries, bs, n_model * 2)

        k = k.view(mem_len, bs, self.nhead, head_dim)
        k_pos = k_pos.view(mem_len, bs, self.nhead, head_dim)
        k = torch.cat([k, k_pos], dim=3).view(mem_len, bs, n_model * 2)

        attended = self.cross_attn(query=q,
                                   key=k,
                                   value=ca_value, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(attended))

        # ---- feed-forward ----
        ffn_out = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = self.norm3(tgt + self.dropout3(ffn_out))
        return tgt
|
|
|
|
class TransformerDecoderLayerThin(nn.Module):
    """Decoder layer without the feed-forward block.

    Self-attention is followed by cross-attention and a single linear layer.

    Args:
        d_model: Feature width.
        nhead: Number of attention heads.
        dim_feedforward: Accepted for signature compatibility; unused.
        dropout: Drop probability for attention and DropPath.
        activation: Accepted for signature compatibility; unused.
        normalize_before: Selects the pre-norm path instead of post-norm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = DropPath(dropout)
        self.dropout2 = DropPath(dropout)

        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm path: each attention block is followed by residual add and LayerNorm."""
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt2 = self.linear1(tgt2)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm path: LayerNorm precedes each attention block.

        Uses only the modules this thin layer owns. The earlier version called
        ``norm3``, ``linear2``, ``dropout``, ``activation`` and ``dropout3``,
        none of which exist here, and raised ``AttributeError``.
        """
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt2 = self.linear1(tgt2)
        tgt = tgt + self.dropout2(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm path per ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
|
|
|
|
|
|
| def _get_clones(module, N): |
| return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) |
|
|
|
|
def build_transformer(args):
    """Create a ``Transformer`` from the fields of ``args``.

    Reads ``hidden_dim``, ``dropout``, ``nheads``, ``dim_feedforward``,
    ``enc_layers``, ``dec_layers`` and ``pre_norm`` from ``args`` and also
    hands ``args`` itself to the model. Intermediate decoder outputs are
    always returned and the activation is fixed to ``'prelu'``.
    """
    settings = dict(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
        activation='prelu',
        args=args,
    )
    return Transformer(**settings)
|
|
def drop_path(x, drop_prob=0.0, training=False):
    """Stochastic depth: zero whole samples along dim 0 and rescale the rest.

    Args:
        x: Input tensor; one keep/drop decision is drawn per index of dim 0.
        drop_prob: Probability of dropping a sample. ``None`` and ``0`` turn
            the function into the identity.
        training: The input is returned unchanged when this is false.

    Returns:
        ``x`` itself when nothing is dropped; otherwise a new tensor in which
        kept samples are divided by ``1 - drop_prob`` and dropped samples are
        zero. With ``drop_prob >= 1`` the result is all zeros (the earlier
        code divided by zero there and produced NaN).
    """
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x

    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Everything is dropped; avoid the inf * 0 = NaN of dividing by zero.
        return torch.zeros_like(x)

    # One Bernoulli(keep_prob) draw per sample, broadcast over the other dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()
    return x.div(keep_prob) * mask
|
|
class DropPath(nn.Module):
    """
    Drop paths per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability of dropping a sample. ``None`` (the default) is
            stored as ``0.0`` so a default-constructed module is a no-op
            rather than failing on ``1 - None`` in training mode.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = 0.0 if drop_prob is None else drop_prob

    def forward(self, x):
        """Apply ``drop_path`` with the per-sample mask drawn along dim 1 of ``x``.

        ``x`` must be 3-D. ``drop_path`` draws its mask over dim 0 of what it
        receives, so dims 0 and 1 are swapped before the call and swapped back
        after (presumably because inputs are laid out (sequence, batch,
        feature) - TODO confirm against callers).
        """
        x = x.permute(1, 0, 2)
        res = drop_path(x, self.drop_prob, self.training)
        return res.permute(1, 0, 2)

    def extra_repr(self):
        """Show the drop probability in the module's repr."""
        return f"drop_prob={self.drop_prob}"
|
|
| def _get_activation_fn(activation): |
| """Return an activation function given a string""" |
| if activation == "relu": |
| return F.relu |
| if activation == "gelu": |
| return F.gelu |
| if activation == "glu": |
| return F.glu |
| if activation == "prelu": |
| return nn.PReLU() |
| if activation == "selu": |
| return F.selu |
| raise RuntimeError(F"activation should be relu/gelu, not {activation}.") |
|
|