Spaces:
Sleeping
Sleeping
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from rdkit import Chem | |
| from rdkit.Chem import AllChem | |
class BACPI(nn.Module):
    """Bidirectional-attention compound-protein interaction (BACPI) network.

    The compound is encoded by ``num_head`` graph-attention heads over atom
    embeddings followed by an output GAT layer, plus a two-layer projection of a
    1024-dim fingerprint vector. The protein is encoded by ``layer_cnn`` 2-D
    convolutions over amino-token embeddings. ``bidat_num`` bidirectional
    attention blocks pool both sides, and the pooled vectors are fused by a
    per-sample outer product.
    """

    def __init__(
        self,
        n_atom,
        n_amino,
        comp_dim,
        prot_dim,
        gat_dim,
        num_head,
        dropout,
        alpha,
        window,
        layer_cnn,
        latent_dim,
    ):
        """Build all sub-layers.

        Args:
            n_atom: atom vocabulary size (the embedding has ``n_atom + 1`` rows).
            n_amino: amino-token vocabulary size (``n_amino + 1`` rows).
            comp_dim: atom embedding size.
            prot_dim: amino-token embedding size.
            gat_dim: output size of each hidden GAT head.
            num_head: number of hidden GAT heads.
            dropout: forwarded to the GAT layers and stored on the module.
            alpha: LeakyReLU negative slope used in both encoders.
            window: convolution half-width; kernels are ``2 * window + 1`` square
                with ``padding=window``, so spatial size is preserved.
            layer_cnn: number of protein convolution layers.
            latent_dim: size of the shared latent space.
        """
        super().__init__()
        self.embedding_layer_atom = nn.Embedding(n_atom + 1, comp_dim)
        self.embedding_layer_amino = nn.Embedding(n_amino + 1, prot_dim)
        self.dropout = dropout
        self.alpha = alpha
        self.layer_cnn = layer_cnn

        # A plain list registered through add_module (instead of nn.ModuleList)
        # produces parameter names ``gat_layer_{i}.*``; kept so state_dict keys
        # stay compatible with checkpoints saved by this class.
        self.gat_layers = [GATLayer(comp_dim, gat_dim, dropout=dropout, alpha=alpha, concat=True)
                           for _ in range(num_head)]
        for i, layer in enumerate(self.gat_layers):
            self.add_module('gat_layer_{}'.format(i), layer)
        self.gat_out = GATLayer(gat_dim * num_head, comp_dim, dropout=dropout, alpha=alpha, concat=False)
        self.W_comp = nn.Linear(comp_dim, latent_dim)

        self.conv_layers = nn.ModuleList([nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2 * window + 1,
                                                    stride=1, padding=window) for _ in range(layer_cnn)])
        self.W_prot = nn.Linear(prot_dim, latent_dim)

        # Two-layer projection of the 1024-dim fingerprint input.
        self.fp0 = nn.Parameter(torch.empty(size=(1024, latent_dim)))
        nn.init.xavier_uniform_(self.fp0, gain=1.414)
        self.fp1 = nn.Parameter(torch.empty(size=(latent_dim, latent_dim)))
        nn.init.xavier_uniform_(self.fp1, gain=1.414)

        # One set of the following parameters per bidirectional attention block.
        self.bidat_num = 4
        self.U = nn.ParameterList([
            nn.Parameter(torch.empty(size=(latent_dim, latent_dim))) for _ in range(self.bidat_num)
        ])
        for i in range(self.bidat_num):
            nn.init.xavier_uniform_(self.U[i], gain=1.414)
        self.transform_c2p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.transform_p2c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.bihidden_c = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.bihidden_p = nn.ModuleList([nn.Linear(latent_dim, latent_dim) for _ in range(self.bidat_num)])
        self.biatt_c = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])
        self.biatt_p = nn.ModuleList([nn.Linear(latent_dim * 2, 1) for _ in range(self.bidat_num)])
        self.comb_c = nn.Linear(latent_dim * self.bidat_num, latent_dim)
        self.comb_p = nn.Linear(latent_dim * self.bidat_num, latent_dim)

    def comp_gat(self, atoms, atoms_mask, adj):
        """Encode atoms with multi-head GAT and project to ``latent_dim``.

        Args:
            atoms: atom token IDs, looked up in ``embedding_layer_atom``.
            atoms_mask: unused here; accepted for interface symmetry.
            adj: adjacency passed to every GAT layer (entries > 0 are edges).

        Returns:
            Atom features whose last dimension is ``latent_dim``.
        """
        atoms_vector = self.embedding_layer_atom(atoms)
        # Hidden heads are concatenated along the feature axis.
        atoms_multi_head = torch.cat([gat(atoms_vector, adj) for gat in self.gat_layers], dim=2)
        atoms_vector = F.elu(self.gat_out(atoms_multi_head, adj))
        return F.leaky_relu(self.W_comp(atoms_vector), self.alpha)

    def prot_cnn(self, amino, amino_mask):
        """Encode amino tokens with stacked single-channel 2-D convolutions.

        Args:
            amino: amino token IDs, shape (b, L).
            amino_mask: unused here; accepted for interface symmetry.

        Returns:
            Tensor of shape (b, L, latent_dim).
        """
        amino_vector = self.embedding_layer_amino(amino)   # (b, L, prot_dim)
        amino_vector = torch.unsqueeze(amino_vector, 1)    # (b, 1, L, prot_dim): one input channel
        for conv in self.conv_layers:
            amino_vector = F.leaky_relu(conv(amino_vector), self.alpha)
        amino_vector = torch.squeeze(amino_vector, 1)
        return F.leaky_relu(self.W_prot(amino_vector), self.alpha)

    def mask_softmax(self, a, mask, dim=-1):
        """Softmax of ``a`` along ``dim`` where entries with ``mask == 0`` get zero weight.

        The 1e-6 in the denominator keeps fully-masked rows at zero instead of NaN.
        """
        a_max = torch.max(a, dim, keepdim=True)[0]
        a_exp = torch.exp(a - a_max)
        a_exp = a_exp * mask
        return a_exp / (torch.sum(a_exp, dim, keepdim=True) + 1e-6)

    def bidirectional_attention_prediction(self, atoms_vector, atoms_mask, fps, amino_vector, amino_mask):
        """Pool both sides with ``bidat_num`` bidirectional attention blocks and fuse them.

        Args:
            atoms_vector: (b, n_atoms, latent_dim) atom features.
            atoms_mask: (b, n_atoms) mask, nonzero at valid atoms.
            fps: fingerprint features; flattened to (b, -1) and appended to the
                pooled compound vector.
            amino_vector: (b, n_amino, latent_dim) protein features.
            amino_mask: (b, n_amino) mask, nonzero at valid residues.

        Returns:
            (b, d_c * latent_dim): LeakyReLU(0.1) of the flattened outer product
            of the compound vector (``comb_c`` output concatenated with ``fps``,
            size d_c) and the protein vector (``comb_p`` output).
        """
        b = atoms_vector.shape[0]
        atoms_mask = atoms_mask.view(b, -1).float()
        amino_mask = amino_mask.view(b, -1).float()
        # 1 only where both the atom and the residue are valid; loop-invariant.
        pair_mask = torch.matmul(atoms_mask.unsqueeze(2), amino_mask.unsqueeze(1))

        compound_pools, protein_pools = [], []
        for i in range(self.bidat_num):
            # Atom-residue affinity matrix, zeroed at padded pairs.
            A = torch.tanh(torch.matmul(torch.matmul(atoms_vector, self.U[i]), amino_vector.transpose(1, 2)))
            A = A * pair_mask
            # Information passed across the affinity matrix in both directions.
            atoms_trans = torch.matmul(A, torch.tanh(self.transform_p2c[i](amino_vector)))
            amino_trans = torch.matmul(A.transpose(1, 2), torch.tanh(self.transform_c2p[i](atoms_vector)))
            atoms_tmp = torch.cat([torch.tanh(self.bihidden_c[i](atoms_vector)), atoms_trans], dim=2)
            amino_tmp = torch.cat([torch.tanh(self.bihidden_p[i](amino_vector)), amino_trans], dim=2)
            atoms_att = self.mask_softmax(self.biatt_c[i](atoms_tmp).view(b, -1), atoms_mask)
            amino_att = self.mask_softmax(self.biatt_p[i](amino_tmp).view(b, -1), amino_mask)
            # Attention-weighted sums over atoms / residues: (b, latent_dim) each.
            compound_pools.append(torch.sum(atoms_vector * atoms_att.view(b, -1, 1), dim=1))
            protein_pools.append(torch.sum(amino_vector * amino_att.view(b, -1, 1), dim=1))
        cat_cf = torch.cat(compound_pools, dim=1)
        cat_pf = torch.cat(protein_pools, dim=1)

        cf_final = torch.cat([self.comb_c(cat_cf).view(b, -1), fps.view(b, -1)], dim=1)
        pf_final = self.comb_p(cat_pf)
        # Per-sample outer product, flattened.
        cf_pf = F.leaky_relu(
            torch.matmul(
                cf_final.view(b, -1, 1), pf_final.view(b, 1, -1)
            ).view(b, -1), 0.1
        )
        return cf_pf

    def forward(self, compound, protein):
        """Run the full model.

        Args:
            compound: ``((atom, atom_lengths), (adj, _), (fp, _))`` where ``atom``
                holds padded atom IDs with per-sample true lengths ``atom_lengths``
                and ``fp`` has a last dimension of 1024.
            protein: ``(amino, amino_lengths)`` with padded amino IDs and lengths.

        Returns:
            Output of ``bidirectional_attention_prediction``.
        """
        atom, adj, fp = compound
        atom, atom_lengths = atom
        adj, _ = adj
        fp, _ = fp
        amino, amino_lengths = protein

        # Masks are True at *valid* (non-padded) positions. They are multiplied
        # into the attention map and used as weights in ``mask_softmax``, so the
        # padding must be the False entries (``<``, not ``>=``).
        atom_mask = torch.arange(atom.size(1), device=atom.device) < atom_lengths.unsqueeze(1)
        amino_mask = torch.arange(amino.size(1), device=amino.device) < amino_lengths.unsqueeze(1)

        atoms_vector = self.comp_gat(atom, atom_mask, adj)
        amino_vector = self.prot_cnn(amino, amino_mask)

        super_feature = F.leaky_relu(torch.matmul(fp.float(), self.fp0), 0.1)
        super_feature = F.leaky_relu(torch.matmul(super_feature, self.fp1), 0.1)

        return self.bidirectional_attention_prediction(
            atoms_vector, atom_mask, super_feature, amino_vector, amino_mask)
class GATLayer(nn.Module):
    """Single-head graph attention layer over batched dense graphs.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        dropout: stored on the instance; attention dropout is not applied in
            ``forward``.
        alpha: negative slope of the LeakyReLU applied to attention logits.
        concat: if True, ELU is applied to the output; otherwise the raw
            attention-weighted aggregation is returned.
    """

    def __init__(self, in_features, out_features, dropout=0.5, alpha=0.2, concat=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, h, adj):
        """Attend over neighbours given by ``adj``.

        Args:
            h: node features, shape (b, N, in_features).
            adj: (b, N, N); entries > 0 mark pairs allowed to attend.

        Returns:
            (b, N, out_features) node features.
        """
        Wh = torch.matmul(h, self.W)  # (b, N, out_features)
        # a^T [Wh_i || Wh_j] == Wh_i . a[:F] + Wh_j . a[F:], so the pairwise
        # logits are formed by broadcasting two (b, N, 1) scores rather than
        # materialising the (b, N, N, 2F) tensor of concatenated pairs.
        n_out = self.out_features
        src_score = torch.matmul(Wh, self.a[:n_out, :])  # (b, N, 1)
        dst_score = torch.matmul(Wh, self.a[n_out:, :])  # (b, N, 1)
        e = F.leaky_relu(src_score + dst_score.transpose(1, 2), self.alpha)  # (b, N, N)
        # Non-edges get a huge negative logit so softmax gives them ~0 weight.
        attention = torch.where(adj > 0, e, torch.full_like(e, -9e15))
        attention = F.softmax(attention, dim=2)
        # NOTE: attention dropout is disabled in this implementation; self.dropout is unused here.
        h_prime = torch.bmm(attention, Wh)
        return F.elu(h_prime) if self.concat else h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        """Return the (b, N, N, 2F) tensor whose [b, i, j] entry is cat(Wh[b, i], Wh[b, j]).

        No longer used by ``forward``; kept for any caller that relies on it.
        """
        b = Wh.size()[0]
        N = Wh.size()[1]
        # Row k of each: Wh[k // N] (chunks) and Wh[k % N] (alternating).
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=1)
        Wh_repeated_alternating = Wh.repeat_interleave(N, dim=0).view(b, N * N, self.out_features)
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=2)
        return all_combinations_matrix.view(b, N, N, 2 * self.out_features)
def _new_index():
    """Return an empty defaultdict that assigns each unseen key the next integer ID."""
    table = defaultdict(lambda: len(table))
    return table


# Featurisation vocabularies. Looking up an unseen key inserts it with
# ID == current size, so IDs depend on first-encounter order.
atom_dict = _new_index()
bond_dict = _new_index()
fingerprint_dict = _new_index()
edge_dict = _new_index()
word_dict = _new_index()
def create_atoms(mol):
    """Map each atom of *mol* to an integer ID from ``atom_dict``.

    The lookup key is the element symbol, or ``(symbol, 'aromatic')`` for atoms
    reported by ``mol.GetAromaticAtoms()``. The result is a NumPy array ordered
    by atom index.
    """
    aromatic_idx = {atom.GetIdx() for atom in mol.GetAromaticAtoms()}
    keys = [
        (atom.GetSymbol(), 'aromatic') if atom.GetIdx() in aromatic_idx else atom.GetSymbol()
        for atom in mol.GetAtoms()
    ]
    return np.array([atom_dict[key] for key in keys])
def create_ijbonddict(mol):
    """Build ``atom_idx -> [(neighbor_idx, bond_id), ...]`` for *mol*.

    Every bond is recorded in both directions with its ``bond_dict`` ID, keyed
    by ``str(bond.GetBondType())``. Atoms without bonds get a single self-loop
    labelled ``bond_dict['nan']``. That key is looked up (and thus registered)
    even when no atom is isolated.
    """
    neighbours = defaultdict(list)
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_id = bond_dict[str(bond.GetBondType())]
        neighbours[begin].append((end, bond_id))
        neighbours[end].append((begin, bond_id))

    lonely = set(range(mol.GetNumAtoms())) - set(neighbours.keys())
    self_loop_id = bond_dict['nan']
    for idx in lonely:
        neighbours[idx].append((idx, self_loop_id))
    return neighbours
def atom_features(atoms, i_jbond_dict, radius):
    """Compute per-atom r-radius subgraph IDs via Weisfeiler-Lehman-style relabelling.

    Args:
        atoms: per-atom IDs indexed by atom index (as produced by ``create_atoms``).
        i_jbond_dict: ``atom_idx -> [(neighbor_idx, bond_id), ...]`` with an entry
            for every atom (as produced by ``create_ijbonddict``).
        radius: number of relabelling rounds. With ``radius <= 0`` or a single
            atom, the ``fingerprint_dict`` IDs of the raw atom IDs are returned.

    Returns:
        ``np.ndarray`` of fingerprint IDs in which element ``i`` belongs to atom ``i``
        (the same order as the adjacency matrix rows).

    Side effects:
        Unseen subgraphs/edges are added to the module-level ``fingerprint_dict``
        and ``edge_dict`` vocabularies.
    """
    if (len(atoms) == 1) or (radius <= 0):
        fingerprints = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict
        for _ in range(radius):
            # Visit atoms in index order. Dict insertion order follows bond order,
            # which differs from atom order for e.g. disconnected molecules or
            # isolated atoms. Without sorting, ``nodes[i]`` on the next round and
            # the returned array would be misaligned with atom indices.
            ordered = sorted(i_jedge_dict.items())

            # Relabel each node from its own label and its (neighbour, edge) labels.
            fingerprints = []
            for i, j_edge in ordered:
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                fingerprints.append(fingerprint_dict[fingerprint])
            nodes = fingerprints

            # Relabel each edge from its previous label and its two endpoint labels.
            _i_jedge_dict = defaultdict(list)
            for i, j_edge in ordered:
                for j, edge in j_edge:
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    _i_jedge_dict[i].append((j, edge))
            i_jedge_dict = _i_jedge_dict
    return np.array(fingerprints)
def create_adjacency(mol):
    """Return the atom adjacency matrix of *mol* with 1 added on the diagonal (self-loops).

    The dtype is whatever ``Chem.GetAdjacencyMatrix`` returns; the addition is in place.
    """
    matrix = np.array(Chem.GetAdjacencyMatrix(mol))
    matrix[np.diag_indices_from(matrix)] += 1
    return matrix
def get_fingerprints(mol):
    """Return the chirality-aware Morgan bit-vector fingerprint of *mol* as a NumPy array.

    Uses radius 2 and 1024 bits.
    """
    morgan_radius, n_bits = 2, 1024
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(
        mol, morgan_radius, nBits=n_bits, useChirality=True
    )
    return np.array(bit_vector)
def split_sequence(sequence, ngram=3):
    """Encode *sequence* as overlapping ``ngram``-character word IDs from ``word_dict``.

    The sequence is wrapped in ``'-'`` and ``'='`` delimiters before a sliding
    window of width ``ngram`` (stride 1) is applied. Unseen words are added to
    ``word_dict``.
    """
    padded = '-' + sequence + '='
    starts = range(len(padded) - ngram + 1)
    return np.array([word_dict[padded[s:s + ngram]] for s in starts])
def drug_featurizer(smiles, radius=2):
    """Featurise a SMILES string into BACPI compound inputs.

    Hydrogens are added explicitly before featurisation.

    Args:
        smiles: SMILES string of the compound.
        radius: number of relabelling rounds passed to ``atom_features``.

    Returns:
        ``(compound, adjacency, fp)``, or ``None`` when RDKit cannot parse the
        SMILES or any featurisation step raises. Exceptions are logged as
        warnings, not propagated.
    """
    from deepscreen.utils import get_logger

    log = get_logger(__name__)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return None
        mol = Chem.AddHs(mol)
        compound = atom_features(create_atoms(mol), create_ijbonddict(mol), radius)
        return compound, create_adjacency(mol), get_fingerprints(mol)
    except Exception as e:
        log.warning(f"Failed to featurize SMILES ({smiles}) to graph due to {str(e)}")
        return None