|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
DETR Transformer class.

Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed into the multi-head attention (MHattention) modules
    * the extra LayerNorm (LN) at the end of the encoder is removed
    * the decoder returns a stack of activations from all decoding layers
"""
|
| from typing import Optional
|
|
|
| from torch import Tensor, nn
|
|
|
| from ..util.misc import _get_activation_fn, _get_clones
|
|
|
|
|
class TransformerEncoder(nn.Module):
    """Run an input through a stack of encoder layers, then an optional norm.

    Args:
        encoder_layer: Prototype layer handed to ``_get_clones`` to build the
            stack (presumably deep-copied ``num_layers`` times -- confirm
            against ``util.misc``).
        num_layers: Number of layers requested from ``_get_clones``; also
            stored on the instance as ``num_layers``.
        norm: Optional callable/module applied to the output of the last
            layer. Skipped when ``None``.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Apply every layer in order and return the final activation.

        Args:
            src: Input fed to the first layer.
            mask: Forwarded to each layer as ``src_mask``.
            src_key_padding_mask: Forwarded unchanged to each layer.
            pos: Positional encoding, forwarded unchanged to each layer (the
                same object is reused for every layer).

        Returns:
            The output of the last layer, passed through ``self.norm`` when
            one was supplied.
        """
        output = src

        # Each layer consumes the previous layer's output; the masks and the
        # positional encoding are the same for all layers.
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: self-attention followed by a feed-forward block.

    Both sub-blocks are wrapped in a residual connection. ``normalize_before``
    selects whether LayerNorm is applied before each sub-block (pre-norm,
    ``forward_pre``) or after each residual sum (post-norm, ``forward_post``).

    The attention module is created with ``batch_first=True``, so per the
    PyTorch ``nn.MultiheadAttention`` contract inputs are batch-major.

    Args:
        d_model: Embedding width used by attention, LayerNorm and the
            input/output of the feed-forward block.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability for the feed-forward hidden activation
            and for both residual branches.
        activation: Name resolved to a callable by ``_get_activation_fn``.
        normalize_before: ``True`` for pre-norm, ``False`` for post-norm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()

        # NOTE(review): attention dropout is fixed at 0 and does not follow the
        # `dropout` argument -- looks deliberate, but confirm with the author.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0, batch_first=True)

        # Feed-forward block: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # norm1/dropout1 belong to the attention sub-block,
        # norm2/dropout2 to the feed-forward sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is ``None``."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Post-norm variant: LayerNorm is applied after each residual sum.

        Args:
            src: Layer input.
            src_mask: Passed to attention as ``attn_mask``.
            src_key_padding_mask: Passed to attention as ``key_padding_mask``.
            pos: Optional positional encoding added to queries and keys only.

        Returns:
            The normalized output of the feed-forward sub-block.
        """
        # Positional information goes into queries/keys; values stay raw.
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask, need_weights=False
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Pre-norm variant: LayerNorm is applied to the input of each sub-block.

        Args:
            src: Layer input.
            src_mask: Passed to attention as ``attn_mask``.
            src_key_padding_mask: Passed to attention as ``key_padding_mask``.
            pos: Optional positional encoding added to queries and keys only.

        Returns:
            The un-normalized residual sum after the feed-forward sub-block.
        """
        src2 = self.norm1(src)
        # Queries/keys get the positional encoding; values are the normed input.
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask, need_weights=False
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Dispatch to ``forward_pre`` or ``forward_post`` per ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
|
|