|
|
from torch import nn

from src.nn import SelfAttentionBlock, FFN, DropPath, LayerNorm, \
    INDEX_BASED_NORMS


__all__ = ['TransformerBlock']


class TransformerBlock(nn.Module):
    """Base block of the Transformer architecture:

        x ---------------- + ---------------- + -->
         \                 |  \               |
           -- N -- SA --       -- N -- FFN --

    Where:
        - N: Normalization
        - SA: Self-Attention
        - FFN: Feed-Forward Network

    Inspired by: https://github.com/microsoft/Swin-Transformer

    :param dim: int
        Dimension of the feature space on which the transformer
        block operates
    :param num_heads: int
        Number of attention heads
    :param qkv_bias: bool
        Whether the linear layers producing queries, keys, and
        values should have a bias
    :param qk_dim: int
        Dimension of the queries and keys
    :param qk_scale: str
        Scaling applied to the query*key product before the softmax.
        More specifically, one may want to normalize the query-key
        compatibilities based on the number of dimensions (referred
        to as 'd' here) as in a vanilla Transformer implementation,
        or based on the number of neighbors each node has in the
        attention graph (referred to as 'g' here). If nothing is
        specified, the scaling will be `1 / (sqrt(d) * sqrt(g))`,
        which is equivalent to passing `'d.g'`. Passing `'d+g'` will
        yield `1 / (sqrt(d) + sqrt(g))`. Meanwhile, passing `'d'` will
        yield `1 / sqrt(d)`, and passing `'g'` will yield
        `1 / sqrt(g)`
    :param in_rpe_dim: int
        Dimension of the features passed as input for relative
        positional encoding computation (i.e. edge features)
    :param ffn_ratio: int
        Multiplicative factor for computing the dimension of the
        `FFN` inverted bottleneck: `ffn_ratio * dim`
    :param attn_drop: float
        Dropout on the attention weights of the `SelfAttentionBlock`
    :param residual_drop: float
        Dropout on the output features of the `SelfAttentionBlock`
    :param drop_path: float
        Dropout on the `SelfAttentionBlock` and `FFN` paths. Contrary
        to other dropout parameters, here we either keep all features
        or none of them. This allows training with stochastic depth
    :param activation: nn.Module
        Activation function for the `FFN` module
    :param norm: nn.Module
        Normalization module used before/after the
        `SelfAttentionBlock` and `FFN`
    :param pre_norm: bool
        Whether the normalization should be applied before or after
        the `SelfAttentionBlock` and `FFN` in the residual branches
    :param no_sa: bool
        Whether a self-attention residual branch should be used at
        all
    :param no_ffn: bool
        Whether a feed-forward residual branch should be used at
        all
    :param k_rpe: bool
        Whether keys should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param q_rpe: bool
        Whether queries should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param v_rpe: bool
        Whether values should receive relative positional encodings
        computed from edge features. See `SelfAttentionBlock`
    :param k_delta_rpe: bool
        Whether keys should receive relative positional encodings
        computed from the difference between source and target node
        features. See `SelfAttentionBlock`
    :param q_delta_rpe: bool
        Whether queries should receive relative positional encodings
        computed from the difference between source and target node
        features. See `SelfAttentionBlock`
    :param qk_share_rpe: bool
        Whether queries and keys should use the same parameters for
        building relative positional encodings. See
        `SelfAttentionBlock`
    :param q_on_minus_rpe: bool
        Whether relative positional encodings for queries should be
        computed on the opposite of the features used for keys. This
        allows, for instance, breaking the symmetry when
        `qk_share_rpe` is set but we want relative positional
        encodings to capture different meanings for keys and queries.
        See `SelfAttentionBlock`
    :param heads_share_rpe: bool
        Whether attention heads should share the same parameters for
        building relative positional encodings. See
        `SelfAttentionBlock`
    """

    def __init__(
            self,
            dim,
            num_heads=1,
            qkv_bias=True,
            qk_dim=8,
            qk_scale=None,
            in_rpe_dim=18,
            ffn_ratio=4,
            attn_drop=None,
            residual_drop=None,
            drop_path=None,
            activation=nn.LeakyReLU(),
            norm=LayerNorm,
            pre_norm=True,
            no_sa=False,
            no_ffn=False,
            k_rpe=False,
            q_rpe=False,
            v_rpe=False,
            k_delta_rpe=False,
            q_delta_rpe=False,
            qk_share_rpe=False,
            q_on_minus_rpe=False,
            heads_share_rpe=False):
        super().__init__()

        self.dim = dim
        self.pre_norm = pre_norm

        # Self-Attention residual branch
        self.no_sa = no_sa
        if not no_sa:
            self.sa_norm = norm(dim)
            self.sa = SelfAttentionBlock(
                dim,
                num_heads=num_heads,
                in_dim=None,
                out_dim=dim,
                qkv_bias=qkv_bias,
                qk_dim=qk_dim,
                qk_scale=qk_scale,
                in_rpe_dim=in_rpe_dim,
                attn_drop=attn_drop,
                drop=residual_drop,
                k_rpe=k_rpe,
                q_rpe=q_rpe,
                v_rpe=v_rpe,
                k_delta_rpe=k_delta_rpe,
                q_delta_rpe=q_delta_rpe,
                qk_share_rpe=qk_share_rpe,
                q_on_minus_rpe=q_on_minus_rpe,
                heads_share_rpe=heads_share_rpe)

        # Feed-Forward Network residual branch
        self.no_ffn = no_ffn
        if not no_ffn:
            self.ffn_norm = norm(dim)
            self.ffn_ratio = ffn_ratio
            self.ffn = FFN(
                dim,
                hidden_dim=int(dim * ffn_ratio),
                activation=activation,
                drop=residual_drop)

        # Stochastic depth: entire residual branches are randomly
        # dropped with probability `drop_path`
        self.drop_path = DropPath(drop_path) \
            if drop_path is not None and drop_path > 0 else nn.Identity()

    def forward(self, x, norm_index, edge_index=None, edge_attr=None):
        """
        :param x: FloatTensor of shape (N, C)
            Node features
        :param norm_index: LongTensor of shape (N)
            Node indices for the LayerNorm
        :param edge_index: LongTensor of shape (2, E)
            Edges in torch_geometric [[sources], [targets]] format for
            the self-attention module
        :param edge_attr: FloatTensor of shape (E, F)
            Edge attributes in torch_geometric format for relative
            positional encoding in the self-attention module
        :return:
        """
        assert x.dim() == 2, 'x should be a 2D Tensor'
        assert x.is_floating_point(), 'x should be a 2D FloatTensor'
        assert norm_index.dim() == 1 and norm_index.shape[0] == x.shape[0], \
            'norm_index should be a 1D LongTensor with one entry per node'
        assert edge_index is None or \
            (edge_index.dim() == 2 and not edge_index.is_floating_point()), \
            'edge_index should be a 2D LongTensor'
        assert edge_attr is None or \
            (edge_attr.dim() == 2 and edge_attr.shape[0] == edge_index.shape[1]), \
            'edge_attr should be a 2D FloatTensor with one row per edge'

        # Keep track of the input features for the self-attention
        # residual connection
        shortcut = x

        # Self-Attention residual branch. Skipped if the branch is
        # disabled or if there are no edges to attend over
        if self.no_sa or edge_index is None or edge_index.shape[1] == 0:
            pass
        elif self.pre_norm:
            x = self._forward_norm(self.sa_norm, x, norm_index)
            x = self.sa(x, edge_index, edge_attr=edge_attr)
            x = shortcut + self.drop_path(x)
        else:
            x = self.sa(x, edge_index, edge_attr=edge_attr)
            x = self.drop_path(x)
            x = self._forward_norm(self.sa_norm, shortcut + x, norm_index)

        # Keep track of the features for the feed-forward residual
        # connection
        shortcut = x

        # Feed-Forward Network residual branch
        if not self.no_ffn and self.pre_norm:
            x = self._forward_norm(self.ffn_norm, x, norm_index)
            x = self.ffn(x)
            x = shortcut + self.drop_path(x)
        if not self.no_ffn and not self.pre_norm:
            x = self.ffn(x)
            x = self.drop_path(x)
            x = self._forward_norm(self.ffn_norm, shortcut + x, norm_index)

        return x, norm_index, edge_index

    @staticmethod
    def _forward_norm(norm, x, norm_index):
        """Simple helper for the forward pass on norm modules. Some
        modules require an index, while others don't.
        """
        if isinstance(norm, INDEX_BASED_NORMS):
            return norm(x, batch=norm_index)
        return norm(x)
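

# Minimal usage sketch, assuming the `src.nn` building blocks imported above
# behave as documented. It simply exercises `TransformerBlock.forward()` with
# random inputs shaped as described in its docstring; the sizes (N, E, dim)
# below are arbitrary illustration values, not values used by the library.
if __name__ == '__main__':
    import torch

    N, E, dim = 32, 128, 64
    block = TransformerBlock(dim, num_heads=4)

    x = torch.randn(N, dim)                        # (N, C) node features
    norm_index = torch.zeros(N, dtype=torch.long)  # all nodes in one norm group
    edge_index = torch.randint(0, N, (2, E))       # (2, E) attention graph edges

    out, _, _ = block(x, norm_index, edge_index=edge_index)
    print(out.shape)  # expected: torch.Size([32, 64])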
|
|
|