import copy
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

"""Transformer building blocks.

Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MultiheadAttention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""


class Conv2d(torch.nn.Conv2d):
    """A wrapper around :class:`torch.nn.Conv2d` that supports an optional
    normalization layer and activation function applied to its output."""

    def __init__(self, *args, **kwargs):
        """Extra keyword arguments supported in addition to those in
        `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor, optional): a callable
                activation function

        The norm layer is applied before the activation.
        """
        norm = kwargs.pop('norm', None)
        activation = kwargs.pop('activation', None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                     self.dilation, self.groups)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
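

# A minimal usage sketch (illustrative, not part of the original module): a
# 3x3 conv whose output is normalized and activated in one call; `get_norm`
# is defined at the bottom of this file.
#
#   conv = Conv2d(256, 256, kernel_size=3, padding=1,
#                 norm=get_norm('GN', 256), activation=F.relu)
#   y = conv(torch.randn(2, 256, 32, 32))  # -> (2, 256, 32, 32)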


class PositionEmbeddingSine(nn.Module):
    """This is a more standard version of the position embedding, very
    similar to the one used by the Attention Is All You Need paper,
    generalized to work on images."""

    def __init__(self,
                 num_pos_feats=64,
                 temperature=10000,
                 normalize=False,
                 scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError('normalize should be True if scale is passed')
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)),
                               device=x.device,
                               dtype=torch.bool)
        not_mask = ~mask
        # Cumulative sums over unmasked positions give per-pixel y/x
        # coordinates, starting at 1.
        y_embed = not_mask.cumsum(1, dtype=x.dtype)
        x_embed = not_mask.cumsum(2, dtype=x.dtype)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(
            self.num_pos_feats, dtype=x.dtype, device=x.device)
        dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)

        # Interleave sine (even channels) and cosine (odd channels).
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

    def __repr__(self, _repr_indent=4):
        head = 'Positional encoding ' + self.__class__.__name__
        body = [
            'num_pos_feats: {}'.format(self.num_pos_feats),
            'temperature: {}'.format(self.temperature),
            'normalize: {}'.format(self.normalize),
            'scale: {}'.format(self.scale),
        ]

        lines = [head] + [' ' * _repr_indent + line for line in body]
        return '\n'.join(lines)
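

# A minimal usage sketch (illustrative sizes): for input features of shape
# (B, C, H, W) the module returns (B, 2 * num_pos_feats, H, W), so
# num_pos_feats is typically set to half the transformer's d_model.
#
#   pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
#   pos = pe(torch.randn(2, 256, 32, 32))  # -> (2, 256, 32, 32)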


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation='relu',
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Post-norm: attention -> residual -> LayerNorm, then FFN ->
        # residual -> LayerNorm.
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Pre-norm: LayerNorm is applied before attention and before the FFN.
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src2,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
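

# A minimal sketch of wiring the two encoder classes together (illustrative
# sizes; inputs are sequence-first, as nn.MultiheadAttention expects):
#
#   layer = TransformerEncoderLayer(d_model=256, nhead=8)
#   encoder = TransformerEncoder(layer, num_layers=6)
#   src = torch.randn(1024, 2, 256)  # (H*W, B, C) flattened feature map
#   pos = torch.randn(1024, 2, 256)  # matching positional encodings
#   memory = encoder(src, pos=pos)   # -> (1024, 2, 256)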


class SelfAttentionLayer(nn.Module):

    def __init__(self,
                 d_model,
                 nhead,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self,
                    tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q,
            k,
            value=tgt2,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self,
                tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask,
                                    query_pos)
        return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask,
                                 query_pos)


class CrossAttentionLayer(nn.Module):

    def __init__(self,
                 d_model,
                 nhead,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     tgt,
                     memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2, avg_attn = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt, avg_attn

    def forward_pre(self,
                    tgt,
                    memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2, avg_attn = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)

        return tgt, avg_attn

    def forward(self,
                tgt,
                memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):

    def __init__(self,
                 d_model,
                 dim_feedforward=2048,
                 dropout=0.0,
                 activation='relu',
                 normalize_before=False):
        super().__init__()

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)
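

# A minimal sketch of one query-decoding step built from the three layers
# above (illustrative sizes and wiring: cross-attention to image features,
# then self-attention among the queries, then the FFN):
#
#   cross_attn = CrossAttentionLayer(d_model=256, nhead=8)
#   self_attn = SelfAttentionLayer(d_model=256, nhead=8)
#   ffn = FFNLayer(d_model=256, dim_feedforward=2048)
#   queries = torch.randn(100, 2, 256)  # (num_queries, B, C)
#   memory = torch.randn(1024, 2, 256)  # flattened image features
#   queries, avg_attn = cross_attn(queries, memory)
#   queries = ffn(self_attn(queries))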


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        # ReLU between layers, no activation after the last one.
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
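

# A minimal usage sketch (illustrative sizes): a 3-layer head mapping 256-d
# query embeddings to 4 box coordinates, as in DETR-style prediction heads.
#
#   head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#   boxes = head(torch.randn(100, 2, 256))  # -> (100, 2, 4)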


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of 'BN' or 'GN'; or a callable
            that takes a channel number and returns the normalization
            layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            'BN': nn.BatchNorm2d,
            'GN': lambda channels: nn.GroupNorm(32, channels),
        }[norm]
    return norm(out_channels)
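

# A minimal usage sketch: 'GN' maps to GroupNorm with 32 groups, so
# out_channels must be divisible by 32.
#
#   gn = get_norm('GN', 256)  # nn.GroupNorm(32, 256)
#   bn = get_norm('BN', 256)  # nn.BatchNorm2d(256)
#   none = get_norm('', 256)  # -> None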


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    raise RuntimeError(f'activation should be relu/gelu/glu, not {activation}.')