import torch
from torch import nn
import torch.nn.functional as F

# some predefined parameters
elem_list = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K',
             'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
             'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'W', 'Ru', 'Nb', 'Re', 'Te', 'Rh', 'Tc', 'Ba', 'Bi', 'Hf', 'Mo', 'U',
             'Sm', 'Os', 'Ir', 'Ce', 'Gd', 'Ga', 'Cs', 'unknown']
atom_fdim = len(elem_list) + 6 + 6 + 6 + 1
bond_fdim = 6
max_nb = 6
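# Derived sizes, for reference: atom_fdim = 63 + 6 + 6 + 6 + 1 = 82 (the element
# one-hot over elem_list plus three 6-way blocks and one scalar, whose exact
# semantics are fixed by the upstream featurizer rather than by this file);
# bond_fdim = 6 features per bond; max_nb = 6 is the assumed maximum number of
# neighbours per atom.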

class MONN(nn.Module):
    # init_A, init_B, init_W = loading_emb(measure)
    # net = Net(init_A, init_B, init_W, params)
    def __init__(self, init_atom_features, init_bond_features, init_word_features, params):
        super().__init__()
        self.init_atom_features = init_atom_features
        self.init_bond_features = init_bond_features
        self.init_word_features = init_word_features

        """hyper part"""
        GNN_depth, inner_CNN_depth, DMA_depth, k_head, kernel_size, hidden_size1, hidden_size2 = params
        self.GNN_depth = GNN_depth
        self.inner_CNN_depth = inner_CNN_depth
        self.DMA_depth = DMA_depth
        self.k_head = k_head
        self.kernel_size = kernel_size
        self.hidden_size1 = hidden_size1
        self.hidden_size2 = hidden_size2
| """GraphConv Module""" | |
| self.vertex_embedding = nn.Linear(atom_fdim, | |
| self.hidden_size1) # first transform vertex features into hidden representations | |
| # GWM parameters | |
| self.W_a_main = nn.ModuleList( | |
| [nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in | |
| range(self.GNN_depth)]) | |
| self.W_a_super = nn.ModuleList( | |
| [nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in | |
| range(self.GNN_depth)]) | |
| self.W_main = nn.ModuleList( | |
| [nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.k_head)]) for i in | |
| range(self.GNN_depth)]) | |
| self.W_bmm = nn.ModuleList( | |
| [nn.ModuleList([nn.Linear(self.hidden_size1, 1) for i in range(self.k_head)]) for i in | |
| range(self.GNN_depth)]) | |
| self.W_super = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.W_main_to_super = nn.ModuleList( | |
| [nn.Linear(self.hidden_size1 * self.k_head, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.W_super_to_main = nn.ModuleList( | |
| [nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.W_zm1 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.W_zm2 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.W_zs1 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.W_zs2 = nn.ModuleList([nn.Linear(self.hidden_size1, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| self.GRU_main = nn.GRUCell(self.hidden_size1, self.hidden_size1) | |
| self.GRU_super = nn.GRUCell(self.hidden_size1, self.hidden_size1) | |
| # WLN parameters | |
| self.label_U2 = nn.ModuleList([nn.Linear(self.hidden_size1 + bond_fdim, self.hidden_size1) for i in | |
| range(self.GNN_depth)]) # assume no edge feature transformation | |
| self.label_U1 = nn.ModuleList( | |
| [nn.Linear(self.hidden_size1 * 2, self.hidden_size1) for i in range(self.GNN_depth)]) | |
| """CNN-RNN Module""" | |
| # CNN parameters | |
| self.embed_seq = nn.Embedding(len(self.init_word_features), 20, padding_idx=0) | |
| self.embed_seq.weight = nn.Parameter(self.init_word_features) | |
| self.embed_seq.weight.requires_grad = False | |
| self.conv_first = nn.Conv1d(20, self.hidden_size1, kernel_size=self.kernel_size, | |
| padding=(self.kernel_size - 1) / 2) | |
| self.conv_last = nn.Conv1d(self.hidden_size1, self.hidden_size1, kernel_size=self.kernel_size, | |
| padding=(self.kernel_size - 1) / 2) | |
| self.plain_CNN = nn.ModuleList([]) | |
| for i in range(self.inner_CNN_depth): | |
| self.plain_CNN.append(nn.Conv1d(self.hidden_size1, self.hidden_size1, kernel_size=self.kernel_size, | |
| padding=(self.kernel_size - 1) / 2)) | |
| """Affinity Prediction Module""" | |
| self.super_final = nn.Linear(self.hidden_size1, self.hidden_size2) | |
| self.c_final = nn.Linear(self.hidden_size1, self.hidden_size2) | |
| self.p_final = nn.Linear(self.hidden_size1, self.hidden_size2) | |
| # DMA parameters | |
| self.mc0 = nn.Linear(hidden_size2, hidden_size2) | |
| self.mp0 = nn.Linear(hidden_size2, hidden_size2) | |
| self.mc1 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
| self.mp1 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
| self.hc0 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
| self.hp0 = nn.ModuleList([nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
| self.hc1 = nn.ModuleList([nn.Linear(self.hidden_size2, 1) for i in range(self.DMA_depth)]) | |
| self.hp1 = nn.ModuleList([nn.Linear(self.hidden_size2, 1) for i in range(self.DMA_depth)]) | |
| self.c_to_p_transform = nn.ModuleList( | |
| [nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
| self.p_to_c_transform = nn.ModuleList( | |
| [nn.Linear(self.hidden_size2, self.hidden_size2) for i in range(self.DMA_depth)]) | |
| self.GRU_dma = nn.GRUCell(self.hidden_size2, self.hidden_size2) | |
| # Output layer | |
| self.W_out = nn.Linear(self.hidden_size2 * self.hidden_size2 * 2, 1) | |
| """Pairwise Interaction Prediction Module""" | |
| self.pairwise_compound = nn.Linear(self.hidden_size1, self.hidden_size1) | |
| self.pairwise_protein = nn.Linear(self.hidden_size1, self.hidden_size1) | |

    def mask_softmax(self, a, mask, dim=-1):
        a_max = torch.max(a, dim, keepdim=True)[0]
        a_exp = torch.exp(a - a_max)
        a_exp = a_exp * mask
        a_softmax = a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)
        return a_softmax
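    # Worked example (values rounded): a = [1, 2, 3] with mask = [1, 1, 0] gives
    # exp(a - 3) = [0.135, 0.368, 1.0], masked to [0.135, 0.368, 0], and after
    # normalisation [0.269, 0.731, 0.0], i.e. an ordinary softmax over the
    # unmasked positions, with masked positions forced to (near) zero.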

    def graph_conv_module(self, batch_size, vertex_mask, vertex, edge, atom_adj, bond_adj, nbs_mask):
        n_vertex = vertex_mask.size(1)

        # initial features
        vertex_initial = torch.index_select(self.init_atom_features, 0, vertex.view(-1))
        vertex_initial = vertex_initial.view(batch_size, -1, atom_fdim)
        edge_initial = torch.index_select(self.init_bond_features, 0, edge.view(-1))
        edge_initial = edge_initial.view(batch_size, -1, bond_fdim)

        vertex_feature = F.leaky_relu(self.vertex_embedding(vertex_initial), 0.1)
        super_feature = torch.sum(vertex_feature * vertex_mask.view(batch_size, -1, 1), dim=1, keepdim=True)

        for GWM_iter in range(self.GNN_depth):
            # prepare main node features
            for k in range(self.k_head):
                a_main = torch.tanh(self.W_a_main[GWM_iter][k](vertex_feature))
                a_super = torch.tanh(self.W_a_super[GWM_iter][k](super_feature))  # NOTE: computed but unused below
                a = self.W_bmm[GWM_iter][k](a_main * super_feature)
                attn = self.mask_softmax(a.view(batch_size, -1), vertex_mask).view(batch_size, -1, 1)
                k_main_to_super = torch.bmm(attn.transpose(1, 2), self.W_main[GWM_iter][k](vertex_feature))
                if k == 0:
                    m_main_to_super = k_main_to_super
                else:
                    m_main_to_super = torch.cat([m_main_to_super, k_main_to_super], dim=-1)  # concat k-head
            main_to_super = torch.tanh(self.W_main_to_super[GWM_iter](m_main_to_super))
            main_self = self.wln_unit(batch_size, vertex_mask, vertex_feature, edge_initial, atom_adj, bond_adj,
                                      nbs_mask, GWM_iter)
            super_to_main = torch.tanh(self.W_super_to_main[GWM_iter](super_feature))
            super_self = torch.tanh(self.W_super[GWM_iter](super_feature))
            # warp gate and GRU to update main node features, using main_self and super_to_main
            z_main = torch.sigmoid(self.W_zm1[GWM_iter](main_self) + self.W_zm2[GWM_iter](super_to_main))
            hidden_main = (1 - z_main) * main_self + z_main * super_to_main
            vertex_feature = self.GRU_main(hidden_main.view(-1, self.hidden_size1),
                                           vertex_feature.view(-1, self.hidden_size1))
            vertex_feature = vertex_feature.view(batch_size, n_vertex, self.hidden_size1)
            # warp gate and GRU to update super node features
            z_super = torch.sigmoid(self.W_zs1[GWM_iter](super_self) + self.W_zs2[GWM_iter](main_to_super))
            hidden_super = (1 - z_super) * super_self + z_super * main_to_super
            super_feature = self.GRU_super(hidden_super.view(batch_size, self.hidden_size1),
                                           super_feature.view(batch_size, self.hidden_size1))
            super_feature = super_feature.view(batch_size, 1, self.hidden_size1)

        return vertex_feature, super_feature
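    # Summary of one GWM iteration: a WLN pass updates the atom (main) nodes,
    # k-head attention exchanges messages with a single graph-level super node,
    # and a "warp gate" (the sigmoid-gated convex combination above) plus a
    # GRUCell merges the self- and cross-messages on each side.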

    def wln_unit(self, batch_size, vertex_mask, vertex_features, edge_initial, atom_adj, bond_adj, nbs_mask, GNN_iter):
        n_vertex = vertex_mask.size(1)
        n_nbs = nbs_mask.size(2)
        vertex_mask = vertex_mask.view(batch_size, n_vertex, 1)
        nbs_mask = nbs_mask.view(batch_size, n_vertex, n_nbs, 1)
        vertex_nei = torch.index_select(vertex_features.view(-1, self.hidden_size1), 0, atom_adj).view(
            batch_size, n_vertex, n_nbs, self.hidden_size1)
        edge_nei = torch.index_select(edge_initial.view(-1, bond_fdim), 0, bond_adj).view(
            batch_size, n_vertex, n_nbs, bond_fdim)
        # Weisfeiler-Lehman relabelling
        l_nei = torch.cat((vertex_nei, edge_nei), -1)
        nei_label = F.leaky_relu(self.label_U2[GNN_iter](l_nei), 0.1)
        nei_label = torch.sum(nei_label * nbs_mask, dim=-2)
        new_label = torch.cat((vertex_features, nei_label), 2)
        new_label = self.label_U1[GNN_iter](new_label)
        vertex_features = F.leaky_relu(new_label, 0.1)
        return vertex_features
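    # Assumed index layout: atom_adj and bond_adj are 1-D LongTensors of length
    # batch_size * n_vertex * max_nb holding flattened row indices (batch offsets
    # already applied) into vertex_features.view(-1, hidden_size1) and
    # edge_initial.view(-1, bond_fdim); torch.index_select requires a 1-D index,
    # so this layout must be produced by the data pipeline.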

    def cnn_module(self, sequence):
        ebd = self.embed_seq(sequence)
        ebd = ebd.transpose(1, 2)
        x = F.leaky_relu(self.conv_first(ebd), 0.1)
        for i in range(self.inner_CNN_depth):
            x = self.plain_CNN[i](x)
            x = F.leaky_relu(x, 0.1)
        x = F.leaky_relu(self.conv_last(x), 0.1)
        H = x.transpose(1, 2)
        # H, hidden = self.rnn(H)
        return H
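    # With stride 1 and padding = (kernel_size - 1) // 2, the output length is
    # L_out = L + 2 * padding - kernel_size + 1, so the residue dimension is
    # preserved only for odd kernel_size (e.g. kernel_size = 7 gives padding = 3
    # and L_out = L); an even kernel would shorten the sequence by one position
    # and break alignment with seq_mask in the modules below.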

    def pairwise_pred_module(self, batch_size, comp_feature, prot_feature, vertex_mask, seq_mask):
        pairwise_c_feature = F.leaky_relu(self.pairwise_compound(comp_feature), 0.1)
        pairwise_p_feature = F.leaky_relu(self.pairwise_protein(prot_feature), 0.1)
        pairwise_pred = torch.matmul(pairwise_c_feature, pairwise_p_feature.transpose(1, 2))
        # TODO: difference between the pairwise_mask here and in the data?
        pairwise_mask = torch.matmul(vertex_mask.view(batch_size, -1, 1), seq_mask.view(batch_size, 1, -1))
        pairwise_pred = pairwise_pred * pairwise_mask
        return pairwise_pred
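    # pairwise_pred has shape (batch, n_vertex, n_residue): entry (b, i, j) scores
    # the interaction between atom i of the compound and residue j of the protein,
    # with padded atom/residue pairs zeroed out by the outer-product mask.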

    def affinity_pred_module(self, batch_size, comp_feature, prot_feature, super_feature, vertex_mask, seq_mask,
                             pairwise_pred):
        comp_feature = F.leaky_relu(self.c_final(comp_feature), 0.1)
        prot_feature = F.leaky_relu(self.p_final(prot_feature), 0.1)
        super_feature = F.leaky_relu(self.super_final(super_feature.view(batch_size, -1)), 0.1)
        cf, pf = self.dma_gru(batch_size, comp_feature, vertex_mask, prot_feature, seq_mask, pairwise_pred)
        cf = torch.cat([cf.view(batch_size, -1), super_feature.view(batch_size, -1)], dim=1)
        kroneck = F.leaky_relu(
            torch.matmul(cf.view(batch_size, -1, 1), pf.view(batch_size, 1, -1)).view(batch_size, -1), 0.1)
        affinity_pred = self.W_out(kroneck)
        return affinity_pred

    def dma_gru(self, batch_size, comp_feats, vertex_mask, prot_feats, seq_mask, pairwise_pred):
        vertex_mask = vertex_mask.view(batch_size, -1, 1)
        seq_mask = seq_mask.view(batch_size, -1, 1)
        # placeholders, overwritten on every DMA iteration (assumes DMA_depth >= 1)
        cf = torch.Tensor()
        pf = torch.Tensor()
        c0 = torch.sum(comp_feats * vertex_mask, dim=1) / torch.sum(vertex_mask, dim=1)
        p0 = torch.sum(prot_feats * seq_mask, dim=1) / torch.sum(seq_mask, dim=1)
        m = c0 * p0
        for DMA_iter in range(self.DMA_depth):
            c_to_p = torch.matmul(pairwise_pred.transpose(1, 2),
                                  torch.tanh(self.c_to_p_transform[DMA_iter](comp_feats)))  # batch * n_residue * hidden
            p_to_c = torch.matmul(pairwise_pred,
                                  torch.tanh(self.p_to_c_transform[DMA_iter](prot_feats)))  # batch * n_vertex * hidden
            c_tmp = torch.tanh(self.hc0[DMA_iter](comp_feats)) * \
                torch.tanh(self.mc1[DMA_iter](m)).view(batch_size, 1, -1) * p_to_c
            p_tmp = torch.tanh(self.hp0[DMA_iter](prot_feats)) * \
                torch.tanh(self.mp1[DMA_iter](m)).view(batch_size, 1, -1) * c_to_p
            c_att = self.mask_softmax(self.hc1[DMA_iter](c_tmp).view(batch_size, -1), vertex_mask.view(batch_size, -1))
            p_att = self.mask_softmax(self.hp1[DMA_iter](p_tmp).view(batch_size, -1), seq_mask.view(batch_size, -1))
            cf = torch.sum(comp_feats * c_att.view(batch_size, -1, 1), dim=1)
            pf = torch.sum(prot_feats * p_att.view(batch_size, -1, 1), dim=1)
            m = self.GRU_dma(m, cf * pf)
        return cf, pf
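    # Reading of the DMA loop: m acts as a joint memory vector, refreshed by
    # GRU_dma from the element-wise product of the attended compound summary (cf)
    # and protein summary (pf); the cf/pf from the final iteration feed the
    # Kronecker-product affinity head in affinity_pred_module.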

    def forward(self, enc_drug, enc_protein):
        # each drug field is assumed to arrive as a (tensor, auxiliary) pair;
        # only the tensor part is used here
        vertex_mask, vertex, edge, atom_adj, bond_adj, nbs_mask = enc_drug
        vertex, vertex_mask = vertex
        edge, _ = edge
        atom_adj, _ = atom_adj
        bond_adj, _ = bond_adj
        nbs_mask, _ = nbs_mask
        seq_mask, sequence = enc_protein
        batch_size = vertex.size(0)

        atom_feature, super_feature = self.graph_conv_module(batch_size, vertex_mask, vertex, edge, atom_adj, bond_adj,
                                                             nbs_mask)
        prot_feature = self.cnn_module(sequence)
        pairwise_pred = self.pairwise_pred_module(batch_size, atom_feature, prot_feature, vertex_mask, seq_mask)
        affinity_pred = self.affinity_pred_module(batch_size, atom_feature, prot_feature, super_feature, vertex_mask,
                                                  seq_mask, pairwise_pred)
        return affinity_pred  # , pairwise_pred
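
# A minimal smoke test with random dummy inputs. All sizes and the (tensor,
# auxiliary) pairing of the drug fields are assumptions made for illustration,
# not values from the original MONN data pipeline.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, n_vertex, n_edge, seq_len, vocab = 2, 5, 4, 30, 25
    # (GNN_depth, inner_CNN_depth, DMA_depth, k_head, kernel_size, hidden_size1, hidden_size2)
    params = (1, 2, 2, 1, 7, 32, 16)

    init_atom = torch.eye(atom_fdim)    # one-hot atom feature lookup table
    init_bond = torch.eye(bond_fdim)    # one-hot bond feature lookup table
    init_word = torch.randn(vocab, 20)  # dummy pretrained residue embeddings
    net = MONN(init_atom, init_bond, init_word, params)

    vertex = torch.randint(0, atom_fdim, (B, n_vertex))
    edge = torch.randint(0, bond_fdim, (B, n_edge))
    # flat neighbour indices with batch offsets already applied (see wln_unit)
    atom_adj = torch.randint(0, B * n_vertex, (B * n_vertex * max_nb,))
    bond_adj = torch.randint(0, B * n_edge, (B * n_vertex * max_nb,))
    vertex_mask = torch.ones(B, n_vertex)
    nbs_mask = torch.ones(B, n_vertex, max_nb)
    sequence = torch.randint(1, vocab, (B, seq_len))
    seq_mask = torch.ones(B, seq_len)

    enc_drug = (vertex_mask, (vertex, vertex_mask), (edge, None),
                (atom_adj, None), (bond_adj, None), (nbs_mask, None))
    enc_protein = (seq_mask, sequence)
    print(net(enc_drug, enc_protein).shape)  # expected: torch.Size([2, 1])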