# Page header captured with the file (not part of the module):
# Spaces: Running on Zero
# File size: 7,519 Bytes
# Revision: 7968cb0
# (The viewer's line-number gutter, 1-183, has been dropped.)
import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch_scatter import scatter_mean
from src.modules.gvp_module import GVP, GVPConvLayer, LayerNorm, tuple_index
class GVP_Model(nn.Module):
    '''
    GVP-GNN for structure-conditioned autoregressive
    protein design as described in manuscript.

    Takes in protein structure graphs of type `torch_geometric.data.Data`
    or `torch_geometric.data.Batch` and returns, for each node, logits and
    log-probabilities over the model's 33-token vocabulary (the output head
    `W_out` is built with 33 scalar channels, matching the 33 rows of the
    sequence embedding `W_s`).

    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.

    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.

    The feature dimensions are fixed inside `__init__` rather than being
    constructor parameters:

    * node_in_dim = (6, 3): node dimensions in the input graph
    * node_h_dim = (100, 16): node dimensions used in GVP-GNN layers
    * edge_in_dim = (32, 1): edge dimensions in the input graph
    * edge_h_dim = (32, 1): edge dimensions to embed to before use
      in GVP-GNN layers
    * num_layers = 3: number of GVP-GNN layers in each of the encoder
      and decoder modules

    :param args: configuration object; stored on `self.args` and not
                 otherwise read by this class
    :param drop_rate: rate to use in all dropout layers
    '''

    def __init__(self, args, drop_rate=0.1):
        super().__init__()
        self.args = args

        self.node_in_dim = (6, 3)
        self.node_h_dim = (100, 16)
        self.edge_in_dim = (32, 1)
        self.edge_h_dim = (32, 1)
        self.num_layers = 3

        # Input embeddings for nodes and edges (no activations, then norm).
        self.W_v = nn.Sequential(
            GVP(self.node_in_dim, self.node_h_dim, activations=(None, None)),
            LayerNorm(self.node_h_dim)
        )
        self.W_e = nn.Sequential(
            GVP(self.edge_in_dim, self.edge_h_dim, activations=(None, None)),
            LayerNorm(self.edge_h_dim)
        )

        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(self.node_h_dim, self.edge_h_dim, drop_rate=drop_rate)
            for _ in range(self.num_layers))

        # Sequence token embedding: 33 tokens -> 20 channels.
        self.W_s = nn.Embedding(33, 20)

        # NOTE: `self.edge_h_dim` is overwritten here. After construction it
        # holds the *decoder* edge dims (scalar channels widened by the
        # sequence embedding), not the dims used by `W_e` and the encoder.
        self.edge_h_dim = (self.edge_h_dim[0] + self.W_s.embedding_dim,
                           self.edge_h_dim[1])

        self.decoder_layers = nn.ModuleList(
            GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                         drop_rate=drop_rate, autoregressive=True)
            for _ in range(self.num_layers))

        # Output head: one scalar channel per vocabulary entry, no vectors.
        self.W_out = GVP(self.node_h_dim, (self.W_s.num_embeddings, 0),
                         activations=(None, None))

        # Initialised to 0 and not updated anywhere in this class.
        self.encode_t = 0
        self.decode_t = 0

    def _get_features(self, batch):
        '''Return `batch` unchanged.'''
        return batch

    def forward(self, batch):
        '''
        Forward pass to be used at train-time, or evaluating likelihood.

        :param batch: graph object exposing `node_s`, `node_v` (node
                      scalar / vector features), `edge_s`, `edge_v` (edge
                      scalar / vector features), `edge_index` (indexed as
                      `edge_index[0]` / `edge_index[1]`) and `seq`
                      (token ids fed to the sequence embedding)
        :return: dict with `'logits'` (output of `W_out`) and `'log_probs'`
                 (log-softmax of the logits over the last dimension)
        '''
        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        seq = batch.seq

        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)

        encoder_embeddings = h_V

        # Attach the embedding of the node at `edge_index[0]` to every edge,
        # zeroing it wherever edge_index[0] >= edge_index[1] so that only
        # lower-indexed nodes pass sequence information along an edge.
        h_S = self.W_s(seq)
        h_S = h_S[edge_index[0]]
        h_S[edge_index[0] >= edge_index[1]] = 0
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])

        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E,
                        autoregressive_x=encoder_embeddings)

        logits = self.W_out(h_V)
        log_probs = nn.functional.log_softmax(logits, dim=-1)
        return {'log_probs': log_probs, 'logits': logits}

    def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1):
        '''
        Samples sequences auto-regressively from the distribution
        learned by the model.

        Runs under `torch.no_grad()`. It does not switch the module to
        eval mode; the caller decides whether to call `eval()` first.

        As a side effect, `self.probs` is set to the concatenation of
        `softmax(logits)` for every decoding step (one chunk per position,
        in position order). These probabilities are NOT temperature-scaled;
        `temperature` only affects the draw itself.

        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param n_samples: number of samples
        :param temperature: temperature to use in softmax
                            over the categorical distribution

        :return: int `torch.Tensor` of shape [n_samples, n_nodes] based on the
                 residue-to-int mapping of the original training data
        '''
        with torch.no_grad():
            device = edge_index.device
            L = h_V[0].shape[0]

            h_V = self.W_v(h_V)
            h_E = self.W_e(h_E)

            for layer in self.encoder_layers:
                h_V = layer(h_V, edge_index, h_E)

            # Tile the encoded graph n_samples times. Copy k owns node ids
            # [k * L, (k + 1) * L), so its edges are shifted by k * L.
            h_V = (h_V[0].repeat(n_samples, 1),
                   h_V[1].repeat(n_samples, 1, 1))
            h_E = (h_E[0].repeat(n_samples, 1),
                   h_E[1].repeat(n_samples, 1, 1))

            edge_index = edge_index.expand(n_samples, -1, -1)
            offset = L * torch.arange(n_samples, device=device).view(-1, 1, 1)
            edge_index = torch.cat(tuple(edge_index + offset), dim=-1)

            # Loop-invariant edge masks, computed once instead of per step.
            src, dst = edge_index[0], edge_index[1]
            non_causal = src >= dst      # edges that must carry no sequence info
            dst_position = dst % L       # within-copy position of the target node

            seq = torch.zeros(n_samples * L, device=device, dtype=torch.int)
            # Sized and typed from the embedding so it always matches W_s.
            h_S = torch.zeros(n_samples * L, self.W_s.embedding_dim,
                              device=device, dtype=self.W_s.weight.dtype)

            # One node-embedding cache per decoder layer (input to that layer).
            h_V_cache = [(h_V[0].clone(), h_V[1].clone())
                         for _ in self.decoder_layers]
            all_probs = []

            for i in range(L):
                # Sequence embeddings sampled so far, gathered onto edges.
                h_S_ = h_S[src]
                h_S_[non_causal] = 0
                h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])

                # Keep only edges pointing at position i of every copy.
                edge_mask = dst_position == i
                edge_index_ = edge_index[:, edge_mask]
                h_E_ = tuple_index(h_E_, edge_mask)

                node_mask = torch.zeros(n_samples * L, device=device,
                                        dtype=torch.bool)
                node_mask[i::L] = True

                for j, layer in enumerate(self.decoder_layers):
                    out = layer(h_V_cache[j], edge_index_, h_E_,
                                autoregressive_x=h_V_cache[0],
                                node_mask=node_mask)
                    out = tuple_index(out, node_mask)

                    # Feed this layer's output at position i to the next layer.
                    if j < len(self.decoder_layers) - 1:
                        h_V_cache[j + 1][0][i::L] = out[0]
                        h_V_cache[j + 1][1][i::L] = out[1]

                logits = self.W_out(out)
                seq[i::L] = Categorical(logits=logits / temperature).sample()
                h_S[i::L] = self.W_s(seq[i::L])
                all_probs.append(torch.softmax(logits, dim=-1))

            self.probs = torch.cat(all_probs, dim=0)
            return seq.view(n_samples, L)

    def test_recovery(self, protein):
        '''
        Sample a single sequence for `protein` with the default temperature.

        :param protein: graph object exposing `node_s`, `node_v`, `edge_s`,
                        `edge_v` and `edge_index`
        :return: the sampled sequence with the leading sample dimension
                 squeezed out
        '''
        h_V = (protein.node_s, protein.node_v)
        h_E = (protein.edge_s, protein.edge_v)
        sample = self.sample(h_V, protein.edge_index, h_E, n_samples=1)
        return sample.squeeze(0)