import torch
import torch.nn as nn
from torch.distributions import Categorical
from torch_scatter import scatter_mean

from src.modules.gvp_module import GVP, GVPConvLayer, LayerNorm, tuple_index


class GVP_Model(nn.Module):
    '''
    GVP-GNN for structure-conditioned autoregressive protein design
    as described in manuscript.

    Takes in protein structure graphs of type `torch_geometric.data.Data`
    or `torch_geometric.data.Batch` and returns per-node scores over the
    residue-token vocabulary (`NUM_TOKENS` entries, 33 by default) as a
    `torch.Tensor` with one row per node.

    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.

    The standard forward pass requires sequence information as input
    and should be used for training or evaluating likelihood.
    For sampling or design, use `self.sample`.

    :param args: opaque configuration object; stored as `self.args` and not
                 otherwise read by this class
    :param drop_rate: rate to use in all dropout layers
    :param node_in_dim: node dimensions in input graph,
                        should be (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph,
                        should be (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param num_layers: number of GVP-GNN layers in each of the encoder
                       and decoder modules
    '''

    # Number of rows in the sequence embedding table and width of the
    # output scores produced by `W_out`.
    NUM_TOKENS = 33
    # Width of the per-residue sequence embedding that is appended to the
    # scalar edge features before the decoder layers.
    SEQ_EMBED_DIM = 20

    def __init__(self, args, drop_rate=0.1, *, node_in_dim=(6, 3),
                 node_h_dim=(100, 16), edge_in_dim=(32, 1),
                 edge_h_dim=(32, 1), num_layers=3):
        super().__init__()
        self.args = args
        self.node_in_dim = tuple(node_in_dim)
        self.node_h_dim = tuple(node_h_dim)
        self.edge_in_dim = tuple(edge_in_dim)
        self.edge_h_dim = tuple(edge_h_dim)
        self.num_layers = num_layers

        # Input projections for node and edge features (no activations).
        self.W_v = nn.Sequential(
            GVP(self.node_in_dim, self.node_h_dim, activations=(None, None)),
            LayerNorm(self.node_h_dim)
        )
        self.W_e = nn.Sequential(
            GVP(self.edge_in_dim, self.edge_h_dim, activations=(None, None)),
            LayerNorm(self.edge_h_dim)
        )

        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                         drop_rate=drop_rate)
            for _ in range(self.num_layers))

        self.W_s = nn.Embedding(self.NUM_TOKENS, self.SEQ_EMBED_DIM)

        # Decoder edges carry the sequence embedding concatenated onto the
        # scalar channel.  NOTE: from here on `self.edge_h_dim` holds the
        # *decoder* edge dimensions (kept for backward compatibility with
        # code that reads this attribute after construction).
        self.edge_h_dim = (self.edge_h_dim[0] + self.SEQ_EMBED_DIM,
                           self.edge_h_dim[1])

        self.decoder_layers = nn.ModuleList(
            GVPConvLayer(self.node_h_dim, self.edge_h_dim,
                         drop_rate=drop_rate, autoregressive=True)
            for _ in range(self.num_layers))

        # Output head: NUM_TOKENS scalar channels, zero vector channels.
        self.W_out = GVP(self.node_h_dim, (self.NUM_TOKENS, 0),
                         activations=(None, None))

        # Counters initialised to zero; not updated anywhere in this class.
        self.encode_t = 0
        self.decode_t = 0

    def _get_features(self, batch):
        '''Return `batch` unchanged (no feature pre-processing is done).'''
        return batch

    def forward(self, batch):
        '''
        Forward pass to be used at train-time, or evaluating likelihood.

        :param batch: graph object providing the attributes
                      `node_s`, `node_v` (node scalar / vector features),
                      `edge_s`, `edge_v` (edge scalar / vector features),
                      `edge_index` (`torch.Tensor` of shape [2, num_edges])
                      and `seq` (integer token ids, one per node)
        :return: dict with keys `'logits'` (output of `W_out`, one row per
                 node) and `'log_probs'` (`log_softmax` of the logits over
                 the last dimension)
        '''
        h_V = (batch.node_s, batch.node_v)
        h_E = (batch.edge_s, batch.edge_v)
        edge_index = batch.edge_index
        seq = batch.seq

        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)

        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)

        encoder_embeddings = h_V

        # Per-edge sequence embedding of the node at edge_index[0]; it is
        # zeroed wherever edge_index[0] >= edge_index[1], so only nodes
        # with a strictly smaller index contribute sequence information.
        h_S = self.W_s(seq)
        h_S = h_S[edge_index[0]]
        h_S[edge_index[0] >= edge_index[1]] = 0
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])

        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E,
                        autoregressive_x=encoder_embeddings)

        logits = self.W_out(h_V)
        log_probs = nn.functional.log_softmax(logits, dim=-1)

        return {'log_probs': log_probs, 'logits': logits}

    def sample(self, h_V, edge_index, h_E, n_samples, temperature=0.1):
        '''
        Samples sequences auto-regressively from the distribution
        learned by the model.

        Runs under `torch.no_grad()`.  The train/eval mode of the module
        is not changed here.

        Side effect: sets `self.probs` to the concatenation (along dim 0)
        of `softmax(logits)` for every decoded position.  Rows are ordered
        position-first (all samples for position 0, then position 1, ...)
        and are computed from the raw logits, i.e. *without* the
        temperature scaling used for sampling.

        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param n_samples: number of samples
        :param temperature: temperature to use in softmax
                            over the categorical distribution

        :return: int `torch.Tensor` of shape [n_samples, n_nodes] based on
                 the residue-to-int mapping of the original training data
        '''
        with torch.no_grad():
            device = edge_index.device
            L = h_V[0].shape[0]

            # Encode the structure once; the result is shared by all samples.
            h_V = self.W_v(h_V)
            h_E = self.W_e(h_E)
            for layer in self.encoder_layers:
                h_V = layer(h_V, edge_index, h_E)

            # Replicate the encoded graph n_samples times as one large
            # disjoint graph: features are tiled and each copy's node
            # indices are shifted by a multiple of L.
            h_V = (h_V[0].repeat(n_samples, 1),
                   h_V[1].repeat(n_samples, 1, 1))
            h_E = (h_E[0].repeat(n_samples, 1),
                   h_E[1].repeat(n_samples, 1, 1))

            edge_index = edge_index.expand(n_samples, -1, -1)
            offset = L * torch.arange(n_samples, device=device).view(-1, 1, 1)
            edge_index = torch.cat(tuple(edge_index + offset), dim=-1)

            seq = torch.zeros(n_samples * L, device=device, dtype=torch.int)
            h_S = torch.zeros(n_samples * L, self.W_s.embedding_dim,
                              device=device)

            # One node-embedding cache per decoder layer.  Cache j is the
            # input to layer j; cache 0 is never written below and so keeps
            # the encoder output.
            h_V_cache = [(h_V[0].clone(), h_V[1].clone())
                         for _ in self.decoder_layers]

            # Loop-invariant edge quantities, computed once.
            src, dst = edge_index[0], edge_index[1]
            masked_edges = src >= dst   # edges whose sequence info is hidden
            dst_position = dst % L      # position of dst within its own copy

            all_probs = []
            for i in range(L):
                # Sequence embeddings decoded so far, gathered per edge.
                h_S_ = h_S[src]
                h_S_[masked_edges] = 0
                h_E_ = (torch.cat([h_E[0], h_S_], dim=-1), h_E[1])

                # Keep only edges that point at position i (in any copy).
                edge_mask = dst_position == i
                edge_index_ = edge_index[:, edge_mask]
                h_E_ = tuple_index(h_E_, edge_mask)

                # Nodes at position i in every copy.
                node_mask = torch.zeros(n_samples * L, device=device,
                                        dtype=torch.bool)
                node_mask[i::L] = True

                for j, layer in enumerate(self.decoder_layers):
                    out = layer(h_V_cache[j], edge_index_, h_E_,
                                autoregressive_x=h_V_cache[0],
                                node_mask=node_mask)
                    out = tuple_index(out, node_mask)

                    # Feed this layer's result for position i to the next
                    # layer's cache (the last layer has no successor).
                    if j < len(self.decoder_layers) - 1:
                        h_V_cache[j + 1][0][i::L] = out[0]
                        h_V_cache[j + 1][1][i::L] = out[1]

                logits = self.W_out(out)
                seq[i::L] = Categorical(logits=logits / temperature).sample()
                h_S[i::L] = self.W_s(seq[i::L])
                all_probs.append(torch.softmax(logits, dim=-1))

            self.probs = torch.cat(all_probs, dim=0)
            return seq.view(n_samples, L)

    def test_recovery(self, protein):
        '''
        Draw a single sequence sample for `protein`.

        :param protein: graph object providing `node_s`, `node_v`,
                        `edge_s`, `edge_v` and `edge_index`
        :return: the result of `self.sample(..., n_samples=1)` with the
                 leading sample dimension squeezed out
        '''
        h_V = (protein.node_s, protein.node_v)
        h_E = (protein.edge_s, protein.edge_v)
        sample = self.sample(h_V, protein.edge_index, h_E, n_samples=1)
        return sample.squeeze(0)