import math
import functools
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_softmax

# ## debug
# import sys
# from pathlib import Path
#
# basedir = Path(__file__).resolve().parent.parent.parent
# sys.path.append(str(basedir))
# ###

from src.model.gvp import GVP, _norm_no_nan, tuple_sum, Dropout, LayerNorm, \
    tuple_cat, tuple_index, _rbf, _normalize


def tuple_mul(tup, val):
    """Multiply both entries of a (scalar, vector) tuple by ``val``.

    Works for dimension tuples (``val`` an int, e.g. ``tuple_mul((6, 3), 9)``)
    as well as feature tuples. When ``val`` is a tensor, an extra trailing axis
    is added for the vector part so it broadcasts over the xyz coordinates.

    :param tup: tuple (s, V)
    :param val: scalar or `torch.Tensor` broadcastable to ``s``
    :return: tuple (s * val, V * val)
    """
    if isinstance(val, torch.Tensor):
        return (tup[0] * val, tup[1] * val.unsqueeze(-1))
    return (tup[0] * val, tup[1] * val)


def _num_graphs(batch_mask):
    """Return the number of graph slots implied by ``batch_mask``.

    Uses ``max + 1`` rather than the number of unique ids so that scatter
    operations indexed by ``batch_mask`` can never go out of bounds when some
    graph ids have no entries (e.g. an empty graph in the middle of a batch).
    """
    if batch_mask.numel() == 0:
        return 0
    return int(batch_mask.max()) + 1


class GVPBlock(nn.Module):
    """Stack of GVPs with optional dropout, residual connection and layer norm.

    The last GVP in the stack has no activations; all earlier ones use
    ``activations``.

    :param in_dims: input dimensions (n_scalar, n_vector)
    :param out_dims: output dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs (any value other than 1 builds
        ``max(n_layers, 2)`` GVPs)
    :param activations: tuple (scalar_act, vector_act) for all but the last GVP
    :param vector_gate: passed to every GVP
    :param dropout: dropout rate applied to the stack output; disabled if 0
    :param skip: add the input to the output (requires ``in_dims == out_dims``)
    :param layernorm: apply a LayerNorm after the (optional) residual
    """

    def __init__(self, in_dims, out_dims, n_layers=1,
                 activations=(F.relu, torch.sigmoid), vector_gate=False,
                 dropout=0.0, skip=False, layernorm=False):
        super(GVPBlock, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        assert not skip or (self.si == self.so and self.vi == self.vo)
        self.skip = skip

        GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate)

        module_list = []
        if n_layers == 1:
            module_list.append(GVP_(in_dims, out_dims, activations=(None, None)))
        else:
            module_list.append(GVP_(in_dims, out_dims))
            for _ in range(n_layers - 2):
                module_list.append(GVP_(out_dims, out_dims))
            module_list.append(GVP_(out_dims, out_dims, activations=(None, None)))

        self.layers = nn.Sequential(*module_list)

        self.norm = LayerNorm(out_dims, learnable_vector_weight=True) if layernorm else None
        self.dropout = Dropout(dropout) if dropout > 0 else None

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`
        """
        dx = self.layers(x)

        if self.dropout is not None:
            dx = self.dropout(dx)

        if self.skip:
            x = tuple_sum(x, dx)
        else:
            x = dx

        if self.norm is not None:
            x = self.norm(x)

        return x


class GeometricPNA(nn.Module):
    """Pool node (or edge) features into per-graph global features.

    Scalars are summarised by their per-graph mean, min, max and std.
    Vectors are summarised by their per-graph mean vector plus the min, max
    and std of each vector channel's norm. The pooled (scalar, vector) tuple is
    then passed through a single-layer GVPBlock.

    :param d_in: input dimensions (n_scalar, n_vector)
    :param d_out: output dimensions (n_scalar, n_vector)
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        si, vi = d_in
        # 4 scalar statistics per scalar channel, 3 norm statistics per vector
        # channel; the mean vectors keep the vector part at ``vi`` channels.
        self.gvp = GVPBlock((4 * si + 3 * vi, vi), d_out)

    def forward(self, x, batch_mask, batch_size=None):
        """
        :param x: tuple (s, V) of `torch.Tensor`
        :param batch_mask: graph index of every row of ``s`` / ``V``
        :param batch_size: number of output rows (forwarded as ``dim_size``)
        :return: tuple (s, V) with one row per graph
        """
        s, v = x
        sm = scatter_mean(s, batch_mask, dim=0, dim_size=batch_size)
        smi = scatter_min(s, batch_mask, dim=0, dim_size=batch_size)[0]
        sma = scatter_max(s, batch_mask, dim=0, dim_size=batch_size)[0]
        sstd = scatter_std(s, batch_mask, dim=0, dim_size=batch_size)

        vnorm = _norm_no_nan(v)
        vm = scatter_mean(v, batch_mask, dim=0, dim_size=batch_size)
        vmi = scatter_min(vnorm, batch_mask, dim=0, dim_size=batch_size)[0]
        vma = scatter_max(vnorm, batch_mask, dim=0, dim_size=batch_size)[0]
        vstd = scatter_std(vnorm, batch_mask, dim=0, dim_size=batch_size)

        z = torch.hstack((sm, smi, sma, sstd, vmi, vma, vstd))
        out = self.gvp((z, vm))
        return out


class TupleLinear(nn.Module):
    """Separate linear maps for scalar and vector channels.

    Vector channels are mixed along the channel axis only (the xyz axis is left
    untouched). NOTE: with ``bias=True`` the vector map adds the same bias to
    all three coordinates, which is not rotation-equivariant; every use in
    this file passes ``bias=False``.

    :param in_dims: input dimensions (n_scalar, n_vector); n_scalar must be > 0
    :param out_dims: output dimensions (n_scalar, n_vector); n_scalar must be > 0
    :param bias: whether the linear maps have a bias term
    """

    def __init__(self, in_dims, out_dims, bias=True):
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        assert self.si and self.so
        self.ws = nn.Linear(self.si, self.so, bias=bias)
        self.wv = nn.Linear(self.vi, self.vo, bias=bias) if self.vi and self.vo else None

    def forward(self, x):
        """
        :param x: tuple (s, V) if the input has vector channels, else just s
        :return: tuple (s, V) if the output has vector channels, else just s
        """
        if self.vi:
            s, v = x
            s = self.ws(s)
            if self.vo:
                # Move the channel axis last so nn.Linear mixes channels, not xyz.
                v = v.transpose(-1, -2)
                v = self.wv(v)
                v = v.transpose(-1, -2)
        else:
            s = self.ws(x)
            if self.vo:
                # No input vectors to map: emit zero vectors matching s.
                v = torch.zeros(s.size(0), self.vo, 3, device=s.device, dtype=s.dtype)

        return (s, v) if self.vo else s


class GVPTransformerLayer(nn.Module):
    """
    Full graph transformer layer with Geometric Vector Perceptrons.
    Inspired by
    - GVP: Jing, Bowen, et al. "Learning from protein structure with geometric
      vector perceptrons." arXiv preprint arXiv:2009.01411 (2020).
    - Transformer architecture: Vignac, Clement, et al. "Digress: Discrete
      denoising diffusion for graph generation." arXiv preprint
      arXiv:2209.14734 (2022).
    - Invariant point attention: Jumper, John, et al. "Highly accurate protein
      structure prediction with AlphaFold." Nature 596.7873 (2021): 583-589.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param global_dims: global feature dimension (n_scalar, n_vector)
    :param dk: key dimension, (n_scalar, n_vector)
    :param dv: node value dimension, (n_scalar, n_vector)
    :param de: edge value dimension, (n_scalar, n_vector)
    :param db: dimension of edge contribution to attention, int
    :param attn_heads: number of attention heads, int
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
        (vector_act will be used as sigma^+ in vector gating if `True`)
    :param attention: can be used to turn off the attention mechanism

    NOTE: ``x_out``, ``e_out`` and ``y_out`` are sized to include
    ``global_dims`` in their inputs, so calling ``forward`` with
    ``global_attr=None`` appears to lead to a dimension mismatch.
    """

    def __init__(self, node_dims, edge_dims, global_dims, dk, dv, de, db,
                 attn_heads, n_feedforward=1, drop_rate=0.0,
                 activations=(F.relu, torch.sigmoid), vector_gate=False,
                 attention=True):
        super(GVPTransformerLayer, self).__init__()

        self.attention = attention

        dq = dk
        self.dq = dq
        self.dk = dk
        self.dv = dv
        self.de = de
        self.db = db

        self.h = attn_heads

        # Per-head projections; all bias-free so the vector maps stay equivariant.
        self.q = TupleLinear(node_dims, tuple_mul(dq, self.h), bias=False) if self.attention else None
        self.k = TupleLinear(node_dims, tuple_mul(dk, self.h), bias=False) if self.attention else None
        self.vx = TupleLinear(node_dims, tuple_mul(dv, self.h), bias=False)

        self.ve = TupleLinear(edge_dims, tuple_mul(de, self.h), bias=False)
        self.b = TupleLinear(edge_dims, (db * self.h, 0), bias=False) if self.attention else None

        m_dim = tuple_sum(tuple_mul(dv, self.h), tuple_mul(de, self.h))
        self.msg = GVPBlock(m_dim, m_dim, n_feedforward,
                            activations=activations, vector_gate=vector_gate)
        m_dim = tuple_sum(m_dim, global_dims)
        self.x_out = GVPBlock(m_dim, node_dims, n_feedforward,
                              activations=activations, vector_gate=vector_gate)
        self.x_norm = LayerNorm(node_dims, learnable_vector_weight=True)
        self.x_dropout = Dropout(drop_rate)

        e_dim = tuple_sum(tuple_mul(node_dims, 2), edge_dims, global_dims)
        if self.attention:
            # attn_s, attn_v and attn_e (one value per head each) are appended.
            e_dim = (e_dim[0] + 3 * attn_heads, e_dim[1])
        self.e_out = GVPBlock(e_dim, edge_dims, n_feedforward,
                              activations=activations, vector_gate=vector_gate)
        self.e_norm = LayerNorm(edge_dims, learnable_vector_weight=True)
        self.e_dropout = Dropout(drop_rate)

        self.pna_x = GeometricPNA(node_dims, node_dims)
        self.pna_e = GeometricPNA(edge_dims, edge_dims)
        self.y = GVP(global_dims, global_dims, activations=(None, None), vector_gate=vector_gate)
        _dim = tuple_sum(node_dims, edge_dims, global_dims)
        self.y_out = GVPBlock(_dim, global_dims, n_feedforward,
                              activations=activations, vector_gate=vector_gate)
        self.y_norm = LayerNorm(global_dims, learnable_vector_weight=True)
        self.y_dropout = Dropout(drop_rate)

    def forward(self, x, edge_index, batch_mask, edge_attr, global_attr=None,
                node_mask=None):
        """
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]; messages are gathered
            from the ``edge_index[1]`` nodes and aggregated at ``edge_index[0]``
        :param batch_mask: array indicating different graphs
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param global_attr: tuple (s, V) of `torch.Tensor`, one row per graph
        :param node_mask: array of type `bool` to index into the first dim of
            node embeddings (s, V). If not `None`, only these nodes will be
            updated.
        :return: (X_out, E_out, y_out), updated node, edge and global tuples (s, V)
        """
        row, col = edge_index
        n = len(x[0])
        # Number of graph slots: trust the global features if given, otherwise
        # max id + 1 (robust to graph ids that have no nodes).
        if global_attr is not None:
            batch_size = global_attr[0].size(0)
        else:
            batch_size = _num_graphs(batch_mask)

        # Compute attention
        if self.attention:
            Q = self.q(x)
            K = self.k(x)
            b = self.b(edge_attr)

            qs, qv = Q  # (n, dq * h), (n, dq * h, 3)
            ks, kv = K  # (n, dq * h), (n, dq * h, 3)
            attn_s = (qs[row] * ks[col]).reshape(len(row), self.h, self.dq[0]).sum(dim=-1)  # (m, h)
            # NOTE: attn_v is the Frobenius inner product between vector-valued
            # queries and keys of size [dq, 3] (generalizes the dot-product
            # between queries and keys similar to Pocket2Mol). The reshape
            # splits the h * dq[1] channels head-major, the same layout used
            # for the value reshapes below.
            attn_v = (qv[row] * kv[col]).reshape(len(row), self.h, self.dq[1], 3).sum(dim=(-2, -1))  # (m, h)
            attn_e = b.reshape(b.size(0), self.h, self.db).sum(dim=-1)  # (m, h)

            attn = attn_s / math.sqrt(3 * self.dk[0]) + \
                attn_v / math.sqrt(9 * self.dk[1]) + \
                attn_e / math.sqrt(3 * self.db)
            # Softmax over all edges that share the same ``row`` node.
            attn = scatter_softmax(attn, row, dim=0)  # (m, h)
            attn = attn.unsqueeze(-1)  # (m, h, 1)

        # Compute new features
        Vx = self.vx(x)
        Ve = self.ve(edge_attr)

        mx = (Vx[0].reshape(Vx[0].size(0), self.h, self.dv[0]),  # (n, h, dv)
              Vx[1].reshape(Vx[1].size(0), self.h, self.dv[1], 3))  # (n, h, dv, 3)
        me = (Ve[0].reshape(Ve[0].size(0), self.h, self.de[0]),  # (m, h, de)
              Ve[1].reshape(Ve[1].size(0), self.h, self.de[1], 3))  # (m, h, de, 3)

        mx = tuple_index(mx, col)
        if self.attention:
            mx = tuple_mul(mx, attn)
            me = tuple_mul(me, attn)

        _m = tuple_cat(mx, me)
        _m = (_m[0].flatten(1), _m[1].flatten(1, 2))
        m = self.msg(_m)  # (m, h * (dv + de)), (m, h * (dv + de), 3)
        m = (scatter_mean(m[0], row, dim=0, dim_size=n),  # (n, h * (dv + de))
             scatter_mean(m[1], row, dim=0, dim_size=n))  # (n, h * (dv + de), 3)
        if global_attr is not None:
            m = tuple_cat(m, tuple_index(global_attr, batch_mask))
        X_out = self.x_norm(tuple_sum(x, self.x_dropout(self.x_out(m))))

        _e = tuple_cat(tuple_index(x, row), tuple_index(x, col), edge_attr)
        if self.attention:
            _e = (torch.cat([_e[0], attn_s, attn_v, attn_e], dim=-1), _e[1])
        if global_attr is not None:
            _e = tuple_cat(_e, tuple_index(global_attr, batch_mask[row]))
        E_out = self.e_norm(tuple_sum(edge_attr, self.e_dropout(self.e_out(_e))))

        _y = tuple_cat(self.pna_x(x, batch_mask, batch_size),
                       self.pna_e(edge_attr, batch_mask[row], batch_size))
        if global_attr is not None:
            _y = tuple_cat(_y, self.y(global_attr))
            y_out = self.y_norm(tuple_sum(global_attr, self.y_dropout(self.y_out(_y))))
        else:
            y_out = self.y_norm(self.y_dropout(self.y_out(_y)))

        if node_mask is not None:
            # Restore the input features of nodes that must not be updated.
            X_out[0][~node_mask], X_out[1][~node_mask] = tuple_index(x, ~node_mask)

        return X_out, E_out, y_out


class GVPTransformerModel(torch.nn.Module):
    """
    GVP-Transformer model

    :param node_in_dim: node dimension in input graph, scalars or tuple (scalars, vectors)
    :param node_h_dim: node dimensions to use in GVP-GNN layers, tuple (s, V)
    :param node_out_nf: number of output node scalar features, int (one output
        vector is always produced)
    :param edge_in_nf: edge dimension in input graph (scalars)
    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers, tuple (s, V)
    :param edge_out_nf: number of output edge scalar features, int
    :param num_layers: number of GVP-GNN layers
    :param dk: key dimension per head, tuple (s, V)
    :param dv: node value dimension per head, tuple (s, V)
    :param de: edge value dimension per head, tuple (s, V)
    :param db: dimension of edge contribution to attention per head, int
    :param dy: global feature dimensions, tuple (s, V)
    :param attn_heads: number of attention heads, int
    :param n_feedforward: number of GVPs in each feedforward block
    :param drop_rate: rate to use in all dropout layers
    :param reflection_equiv: bool, use reflection-sensitive feature based on the cross product if False
    :param d_max: upper distance passed as ``D_max`` to the RBF edge expansion
    :param num_rbf: number of RBF kernels in the edge distance expansion
    :param vector_gate: use vector gates in all GVPs
    :param attention: can be used to turn off the attention mechanism
    """

    def __init__(self, node_in_dim, node_h_dim, node_out_nf, edge_in_nf,
                 edge_h_dim, edge_out_nf, num_layers, dk, dv, de, db, dy,
                 attn_heads, n_feedforward, drop_rate, reflection_equiv=True,
                 d_max=20.0, num_rbf=16, vector_gate=False, attention=True):

        super(GVPTransformerModel, self).__init__()

        self.reflection_equiv = reflection_equiv
        self.d_max = d_max
        self.num_rbf = num_rbf

        # node_in_dim = (node_in_dim, 1)
        if not isinstance(node_in_dim, tuple):
            node_in_dim = (node_in_dim, 0)

        # Edge scalars: [h_row, h_col, rbf(dist), edge_attr]; edge vectors: the
        # unit bond direction (plus a cross-product vector if not reflection-equivariant).
        edge_in_dim = (edge_in_nf + 2 * node_in_dim[0] + self.num_rbf, 1)
        if not self.reflection_equiv:
            edge_in_dim = (edge_in_dim[0], edge_in_dim[1] + 1)

        self.W_v = GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate)
        self.W_e = GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate)
        # self.W_v = nn.Sequential(
        #     LayerNorm(node_in_dim, learnable_vector_weight=True),
        #     GVP(node_in_dim, node_h_dim, activations=(None, None)),
        # )
        # self.W_e = nn.Sequential(
        #     LayerNorm(edge_in_dim, learnable_vector_weight=True),
        #     GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
        # )

        self.dy = dy
        self.layers = nn.ModuleList(
            GVPTransformerLayer(node_h_dim, edge_h_dim, dy, dk, dv, de, db,
                                attn_heads, n_feedforward=n_feedforward,
                                drop_rate=drop_rate, vector_gate=vector_gate,
                                activations=(F.relu, None), attention=attention)
            for _ in range(num_layers))

        self.W_v_out = GVP(node_h_dim, (node_out_nf, 1), activations=(None, None), vector_gate=vector_gate)
        self.W_e_out = GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None), vector_gate=vector_gate)
        # self.W_v_out = nn.Sequential(
        #     LayerNorm(node_h_dim, learnable_vector_weight=True),
        #     GVP(node_h_dim, (node_out_nf, 1), activations=(None, None)),
        # )
        # self.W_e_out = nn.Sequential(
        #     LayerNorm(edge_h_dim, learnable_vector_weight=True),
        #     GVP(edge_h_dim, (edge_out_nf, 0), activations=(None, None))
        # )

    def edge_features(self, h, x, edge_index, batch_mask=None, edge_attr=None):
        """Build the input (scalar, vector) edge features.

        :param h: scalar node features
        :param x: node coordinates, indexed by ``edge_index`` and diffed pairwise
        :param edge_index: array of shape [2, n_edges]
        :param batch_mask: graph index per node; only needed when
            ``reflection_equiv`` is False (to compute per-graph centroids)
        :param edge_attr: optional extra scalar edge features, appended last
        :return: scalar and vector-valued edge features, with NaNs replaced
        """
        row, col = edge_index
        coord_diff = x[row] - x[col]
        dist = coord_diff.norm(dim=-1)
        rbf = _rbf(dist, D_max=self.d_max, D_count=self.num_rbf, device=x.device)

        edge_s = torch.cat([h[row], h[col], rbf], dim=1)
        edge_v = _normalize(coord_diff).unsqueeze(-2)

        if edge_attr is not None:
            edge_s = torch.cat([edge_s, edge_attr], dim=1)

        if not self.reflection_equiv:
            mean = scatter_mean(x, batch_mask, dim=0, dim_size=_num_graphs(batch_mask))
            # Cross product of centroid-relative positions flips sign under
            # reflection, making the features reflection-sensitive.
            cross = torch.cross(x[row] - mean[batch_mask[row]],
                                x[col] - mean[batch_mask[col]], dim=1)
            cross = _normalize(cross).unsqueeze(-2)

            edge_v = torch.cat([edge_v, cross], dim=-2)

        return torch.nan_to_num(edge_s), torch.nan_to_num(edge_v)

    def forward(self, h, x, edge_index, v=None, batch_mask=None, edge_attr=None):
        """
        :param h: scalar node features
        :param x: node coordinates
        :param edge_index: array of shape [2, n_edges]
        :param v: optional vector node features (used when ``node_in_dim`` has vectors)
        :param batch_mask: graph index per node (required)
        :param edge_attr: optional extra scalar edge features
        :return: (node scalars, one output vector per node with the singleton
            channel axis squeezed, edge scalars)
        """
        bs = _num_graphs(batch_mask)

        # h_v = (h, x.unsqueeze(-2))
        h_v = h if v is None else (h, v)
        h_e = self.edge_features(h, x, edge_index, batch_mask, edge_attr)

        h_v = self.W_v(h_v)
        h_e = self.W_e(h_e)
        # Global features start at zero and are built up by the layers.
        h_y = (torch.zeros(bs, self.dy[0], device=h.device),
               torch.zeros(bs, self.dy[1], 3, device=h.device))

        for layer in self.layers:
            h_v, h_e, h_y = layer(h_v, edge_index, batch_mask, h_e, h_y)

        # h, x = self.W_v_out(h_v)
        # x = x.squeeze(-2)

        h, vel = self.W_v_out(h_v)
        # x = x + vel.squeeze(-2)

        edge_attr = self.W_e_out(h_e)

        # return h, x, edge_attr
        return h, vel.squeeze(-2), edge_attr


if __name__ == "__main__":
    from src.model.gvp import randn
    from scipy.spatial.transform import Rotation

    def test_equivariance(model, nodes, edges, glob_feat):
        """Check scalar outputs are invariant and vector outputs equivariant
        under a random rotation of all vector inputs."""
        rot = torch.as_tensor(Rotation.random().as_matrix(),
                              dtype=torch.float32, device=nodes[0].device)
        with torch.no_grad():
            X_out, E_out, y_out = model(nodes, edges, glob_feat)
            n_v_rot, e_v_rot, y_v_rot = nodes[1] @ rot, edges[1] @ rot, glob_feat[1] @ rot
            X_out_v_rot = X_out[1] @ rot
            E_out_v_rot = E_out[1] @ rot
            y_out_v_rot = y_out[1] @ rot
            X_out_prime, E_out_prime, y_out_prime = model(
                (nodes[0], n_v_rot), (edges[0], e_v_rot), (glob_feat[0], y_v_rot))

            assert torch.allclose(X_out[0], X_out_prime[0], atol=1e-5, rtol=1e-4)
            assert torch.allclose(X_out_v_rot, X_out_prime[1], atol=1e-5, rtol=1e-4)
            assert torch.allclose(E_out[0], E_out_prime[0], atol=1e-5, rtol=1e-4)
            assert torch.allclose(E_out_v_rot, E_out_prime[1], atol=1e-5, rtol=1e-4)
            assert torch.allclose(y_out[0], y_out_prime[0], atol=1e-5, rtol=1e-4)
            assert torch.allclose(y_out_v_rot, y_out_prime[1], atol=1e-5, rtol=1e-4)
        print("SUCCESS")

    n_nodes = 300
    n_edges = 10000
    batch_size = 6

    node_dim = (16, 8)
    edge_dim = (8, 4)
    global_dim = (4, 2)
    dk = (6, 3)
    dv = (7, 4)
    de = (5, 2)
    db = 10
    attn_heads = 9

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    nodes = randn(n_nodes, node_dim, device=device)
    edges = randn(n_edges, edge_dim, device=device)
    glob_feat = randn(batch_size, global_dim, device=device)
    edge_index = torch.randint(0, n_nodes, (2, n_edges), device=device)
    batch_idx = torch.randint(0, batch_size, (n_nodes,), device=device)

    model = GVPTransformerLayer(node_dim, edge_dim, global_dim, dk, dv, de, db,
                                attn_heads, n_feedforward=2,
                                drop_rate=0.1).to(device).eval()

    def model_fn(h_V, h_E, h_y):
        return model(h_V, edge_index, batch_idx, h_E, h_y)

    test_equivariance(model_fn, nodes, edges, glob_feat)