Spaces:
Sleeping
Sleeping
import copy
import os
import random
import sys
from copy import deepcopy
from itertools import compress

sys.path.append("..")

import numpy as np
import torch
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from torch_geometric.data import Data, Batch
from torch_geometric.nn.pool import knn_graph
from torch_geometric.nn.pool import radius
from torch_geometric.transforms import Compose
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.subgraph import subgraph
from torch_scatter import scatter_add

from .data import ProteinLigandData
from .protein_ligand import ATOM_FAMILIES
from .chemutils import enumerate_assemble, list_filter, rand_rotate
# Vocabularies used to index-encode RDKit atom and bond attributes.
# A feature value is represented by its position inside the matching list.
allowable_features = dict(
    possible_atomic_num_list=list(range(1, 119)),
    possible_formal_charge_list=list(range(-5, 6)),
    possible_chirality_list=[
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    possible_hybridization_list=[
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    possible_numH_list=list(range(9)),
    possible_implicit_valence_list=list(range(7)),
    possible_degree_list=list(range(11)),
    possible_bonds=[
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # only for double bond stereo information
    possible_bond_dirs=[
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
)
def mol_to_graph_data_obj_simple(mol):
    """
    Converts rdkit mol object to graph Data object required by the pytorch
    geometric package. NB: Uses simplified atom and bond features, and
    represents them as indices into the ``allowable_features`` vocabularies.

    :param mol: rdkit mol object
    :return: graph data object with the attributes:
        x          -- long tensor [num_atoms, 2]:
                      (atomic number index, chirality tag index)
        edge_index -- long tensor [2, 2 * num_bonds]; every bond is stored
                      in both directions
        edge_attr  -- long tensor [2 * num_bonds, 2]:
                      (bond type index, bond direction index)
    :raises ValueError: propagated from ``list.index`` when an atom or bond
        attribute is not present in its vocabulary
    """
    num_atom_features = 2  # atom type, chirality tag
    num_bond_features = 2  # bond type, bond direction

    atomic_nums = allowable_features['possible_atomic_num_list']
    chirality_tags = allowable_features['possible_chirality_list']
    bond_types = allowable_features['possible_bonds']
    bond_dirs = allowable_features['possible_bond_dirs']

    # atoms
    atom_features_list = [
        [atomic_nums.index(atom.GetAtomicNum()),
         chirality_tags.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    # The reshape keeps x two-dimensional even when the molecule has no atoms
    # (an empty list would otherwise yield a tensor of shape (0,)).
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
    x = x.reshape(-1, num_atom_features)

    # bonds
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = [bond_types.index(bond.GetBondType()),
                        bond_dirs.index(bond.GetBondDir())]
        # store each bond once per direction, sharing the same features
        edges_list.extend([(i, j), (j, i)])
        edge_features_list.extend([edge_feature, edge_feature])

    if edges_list:  # mol has bonds
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data
class RefineData(object):
    """Transform that removes hydrogen atoms (element == 1) from a sample.

    Pocket side: filters ``protein_atom_name``, ``protein_atom_to_aa_type``,
    ``protein_element``, ``protein_is_backbone`` and ``protein_pos``.

    Ligand side: filters ``ligand_atom_feature``, ``ligand_element`` and
    ``ligand_pos``, then rebuilds ``ligand_nbh_list``, ``ligand_bond_index``
    and ``ligand_bond_type`` so that they only reference the remaining atoms,
    renumbered consecutively from 0.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Strip hydrogens from ``data`` in place and return it.

        :param data: sample object exposing the attributes listed in the class
            docstring (``ligand_bond_index`` is read as a ``[2, num_bonds]``
            index tensor)
        :return: the same ``data`` object
        """
        # delete H atom of pocket
        is_H_protein = (data.protein_element == 1)
        if torch.sum(is_H_protein) > 0:
            not_H_protein = ~is_H_protein
            data.protein_atom_name = list(
                compress(data.protein_atom_name, not_H_protein.tolist()))
            data.protein_atom_to_aa_type = data.protein_atom_to_aa_type[not_H_protein]
            data.protein_element = data.protein_element[not_H_protein]
            data.protein_is_backbone = data.protein_is_backbone[not_H_protein]
            data.protein_pos = data.protein_pos[not_H_protein]

        # delete H atom of ligand
        is_H_ligand = (data.ligand_element == 1)
        if torch.sum(is_H_ligand) > 0:
            not_H_ligand = ~is_H_ligand
            keep_flags = not_H_ligand.tolist()
            data.ligand_atom_feature = data.ligand_atom_feature[not_H_ligand]
            data.ligand_element = data.ligand_element[not_H_ligand]
            data.ligand_pos = data.ligand_pos[not_H_ligand]

            # A set gives O(1) membership tests instead of scanning a tensor
            # for every neighbour / bond endpoint.
            h_atoms = set(torch.nonzero(is_H_ligand)[:, 0].tolist())

            # old atom index -> new atom index (-1 for removed hydrogens)
            index_changer = -np.ones(len(keep_flags), dtype=np.int64)
            index_changer[np.array(keep_flags, dtype=bool)] = np.arange(sum(keep_flags))

            # nbh: keep the lists of surviving atoms, drop H neighbours, renumber
            kept_nbh = [
                neigh
                for keep, neigh in zip(keep_flags, data.ligand_nbh_list.values())
                if keep
            ]
            data.ligand_nbh_list = {
                i: [index_changer[node] for node in neigh if int(node) not in h_atoms]
                for i, neigh in enumerate(kept_nbh)
            }

            # bond: drop every bond touching a hydrogen, renumber the rest.
            # The explicit bool dtype keeps the mask valid when there are no bonds.
            keep_bond = torch.tensor(
                [(bond_i not in h_atoms) and (bond_j not in h_atoms)
                 for bond_i, bond_j in zip(*data.ligand_bond_index.tolist())],
                dtype=torch.bool,
            )
            old_ligand_bond_index = data.ligand_bond_index[:, keep_bond]
            data.ligand_bond_index = torch.tensor(index_changer)[old_ligand_bond_index]
            data.ligand_bond_type = data.ligand_bond_type[keep_bond]
        return data
class FocalBuilder(object):
    """Transform that derives focal-atom / generated-atom supervision targets.

    With a non-empty ligand context, the focal atoms are the context atoms
    bonded to a masked atom ("bridge" bonds). Without any context (initial
    step), the focal atoms are the protein atoms found within 4.0 of a masked
    ligand atom, falling back to the single closest protein/ligand pair.

    ``close_threshold`` and ``max_bond_length`` are stored on the instance but
    are not read by ``__call__``; the negative-position sampling that used
    them is disabled.
    """

    def __init__(self, close_threshold=0.8, max_bond_length=2.4):
        self.close_threshold = close_threshold
        self.max_bond_length = max_bond_length
        super().__init__()

    def __call__(self, data: "ProteinLigandData"):
        """Attach focal/generation targets to ``data`` and return it.

        Reads ``ligand_masked_pos``, ``protein_pos``, ``context_idx``,
        ``masked_idx`` and ``ligand_bond_index`` (plus
        ``idx_protein_in_compose`` when there is no ligand context).

        Writes ``idx_generated_in_ligand_masked``, ``pos_generate``,
        ``idx_focal_in_compose``, ``idx_protein_all_mask`` and
        ``y_protein_frontier``.
        """
        ligand_masked_pos = data.ligand_masked_pos
        protein_pos = data.protein_pos
        context_idx = data.context_idx
        masked_idx = data.masked_idx
        old_bond_index = data.ligand_bond_index

        has_unmask_atoms = context_idx.nelement() > 0
        if has_unmask_atoms:
            # get bridge bond index (mask-context bond); row 0 holds the masked
            # atom and row 1 the context atom of each selected bond
            context_set = set(context_idx.tolist())
            masked_set = set(masked_idx.tolist())
            # explicit bool dtype so an empty bond list still yields a valid mask
            ind_edge_index_candidate = torch.tensor(
                [(context_node in context_set) and (mask_node in masked_set)
                 for mask_node, context_node in zip(*old_bond_index.tolist())],
                dtype=torch.bool,
            )
            bridge_bond_index = old_bond_index[:, ind_edge_index_candidate]
            idx_generated_in_whole_ligand = bridge_bond_index[0]
            idx_focal_in_whole_ligand = bridge_bond_index[1]

            # whole-ligand index -> position inside the masked subset
            # (guard: .max() is undefined for an empty masked set)
            num_slots_masked = int(masked_idx.max()) + 1 if masked_idx.nelement() > 0 else 0
            index_changer_masked = torch.zeros(num_slots_masked, dtype=torch.int64)
            index_changer_masked[masked_idx] = torch.arange(len(masked_idx))
            idx_generated_in_ligand_masked = index_changer_masked[idx_generated_in_whole_ligand]
            pos_generate = ligand_masked_pos[idx_generated_in_ligand_masked]

            data.idx_generated_in_ligand_masked = idx_generated_in_ligand_masked
            data.pos_generate = pos_generate

            # whole-ligand index -> position inside the context subset
            index_changer_context = torch.zeros(context_idx.max() + 1, dtype=torch.int64)
            index_changer_context[context_idx] = torch.arange(len(context_idx))
            idx_focal_in_ligand_context = index_changer_context[idx_focal_in_whole_ligand]
            # if ligand_context was not before protein in the compose, this was not correct
            idx_focal_in_compose = idx_focal_in_ligand_context
            data.idx_focal_in_compose = idx_focal_in_compose

            data.idx_protein_all_mask = torch.empty(0, dtype=torch.long)  # no use if has context
            data.y_protein_frontier = torch.empty(0, dtype=torch.bool)  # no use if has context
        else:  # the initial atom. surface atoms between ligand and protein
            assign_index = radius(x=ligand_masked_pos, y=protein_pos, r=4., num_workers=16)
            if assign_index.size(1) == 0:
                # nothing within the radius: fall back to the closest pair
                dist = torch.norm(
                    data.protein_pos.unsqueeze(1) - data.ligand_masked_pos.unsqueeze(0),
                    p=2, dim=-1)
                assign_index = torch.nonzero(dist <= torch.min(dist) + 1e-5)[0:1].transpose(0, 1)
            idx_focal_in_protein = assign_index[0]
            data.idx_focal_in_compose = idx_focal_in_protein  # no ligand context, so all composes are protein atoms
            data.pos_generate = ligand_masked_pos[assign_index[1]]
            data.idx_generated_in_ligand_masked = torch.unique(assign_index[1])  # for real of the contractive transform

            data.idx_protein_all_mask = data.idx_protein_in_compose  # for input of initial frontier prediction
            y_protein_frontier = torch.zeros_like(
                data.idx_protein_all_mask, dtype=torch.bool)  # for label of initial frontier prediction
            y_protein_frontier[torch.unique(idx_focal_in_protein)] = True
            data.y_protein_frontier = y_protein_frontier
        return data
class AtomComposer(object):
    """Transform that merges ligand-context atoms and protein atoms into one
    "compose" point set and builds a k-nearest-neighbour graph over it.

    Ligand-context atoms come first in the composed tensors, protein atoms
    follow.
    """

    def __init__(self, protein_dim, ligand_dim, knn):
        """
        :param protein_dim: feature width of protein atoms
        :param ligand_dim: feature width of ligand-context atoms; ligand
            features are zero-padded by ``protein_dim - ligand_dim`` columns
        :param knn: number of neighbours per compose atom
        """
        super().__init__()
        self.protein_dim = protein_dim
        self.ligand_dim = ligand_dim
        self.knn = knn  # knn of compose atoms

    def __call__(self, data: "ProteinLigandData"):
        """Add ``compose_*`` and ``idx_*_in_compose`` entries to ``data``.

        Reads ``ligand_context_pos``, ``ligand_context_feature_full``,
        ``protein_pos`` and ``protein_atom_feature`` (plus the bond entries
        used by :meth:`get_knn_graph`) and returns the same ``data`` object.
        """
        # fetch ligand context and protein from data
        ligand_context_pos = data['ligand_context_pos']
        ligand_context_feature_full = data['ligand_context_feature_full']
        protein_pos = data['protein_pos']
        protein_atom_feature = data['protein_atom_feature']
        len_ligand_ctx = len(ligand_context_pos)
        len_protein = len(protein_pos)

        # compose ligand context and protein. save idx of them in compose
        data['compose_pos'] = torch.cat([ligand_context_pos, protein_pos], dim=0)
        len_compose = len_ligand_ctx + len_protein
        # pad ligand features with zero columns up to the protein feature width
        ligand_context_feature_full_expand = torch.cat([
            ligand_context_feature_full,
            torch.zeros([len_ligand_ctx, self.protein_dim - self.ligand_dim], dtype=torch.long),
        ], dim=1)
        data['compose_feature'] = torch.cat(
            [ligand_context_feature_full_expand, protein_atom_feature], dim=0)
        data['idx_ligand_ctx_in_compose'] = torch.arange(len_ligand_ctx, dtype=torch.long)  # can be delete
        data['idx_protein_in_compose'] = torch.arange(len_protein, dtype=torch.long) + len_ligand_ctx  # can be delete

        # build knn graph and bond type
        data = self.get_knn_graph(data, self.knn, len_ligand_ctx, len_compose, num_workers=16)
        return data

    # Declared static: the function takes no ``self`` and __call__ invokes it
    # through the instance, which would otherwise bind ``self`` to ``data``.
    @staticmethod
    def get_knn_graph(data: "ProteinLigandData", knn, len_ligand_ctx, len_compose, num_workers=1, ):
        """Build the compose kNN graph and label edges that are ligand bonds.

        Writes ``compose_knn_edge_index``, ``compose_knn_edge_type`` (0 unless
        the edge coincides with a ligand-context bond, else that bond's type)
        and ``compose_knn_edge_feature`` (4-wide one-hot of the edge type).

        :param data: sample holding ``compose_pos``,
            ``ligand_context_bond_index`` and ``ligand_context_bond_type``
        :param knn: number of neighbours per node
        :param len_ligand_ctx: number of ligand-context atoms
        :param len_compose: total number of compose atoms
        :param num_workers: forwarded to ``knn_graph``
        :return: the same ``data`` object
        """
        knn_edge_index = knn_graph(
            data['compose_pos'], knn, flow='target_to_source', num_workers=num_workers)
        data['compose_knn_edge_index'] = knn_edge_index
        num_edges = len(knn_edge_index[0])

        # Encode every (i, j) pair as the single integer i * len_compose + j.
        # NOTE(review): only the first len_ligand_ctx * knn edges are searched;
        # this presumes knn_graph lists exactly `knn` edges per node, ordered
        # by node, with the ligand-context nodes first -- confirm.
        num_ctx_edges = len_ligand_ctx * knn
        id_compose_edge = (knn_edge_index[0, :num_ctx_edges] * len_compose
                           + knn_edge_index[1, :num_ctx_edges])
        bond_index = data['ligand_context_bond_index']
        bond_type = data['ligand_context_bond_type']
        id_ligand_ctx_edge = bond_index[0] * len_compose + bond_index[1]

        # position of each ligand bond inside the kNN edge list, -1 if absent
        idx_edge = [torch.nonzero(id_compose_edge == id_) for id_ in id_ligand_ctx_edge]
        idx_edge = torch.tensor(
            [a.squeeze() if len(a) > 0 else torch.tensor(-1) for a in idx_edge],
            dtype=torch.long)
        found = idx_edge >= 0

        edge_type = torch.zeros(num_edges, dtype=torch.long)  # for encoder edge embedding
        edge_type[idx_edge[found]] = bond_type[found]
        data['compose_knn_edge_type'] = edge_type

        # default row is [1, 0, 0, 0] (no bond); bonded edges get their one-hot
        edge_feature = torch.cat([
            torch.ones([num_edges, 1], dtype=torch.long),
            torch.zeros([num_edges, 3], dtype=torch.long),
        ], dim=-1)
        edge_feature[idx_edge[found]] = F.one_hot(bond_type[found], num_classes=4)  # 0 (1,2,3)-onehot
        data['compose_knn_edge_feature'] = edge_feature
        return data
class FeaturizeProteinAtom(object):
    """Transform producing a 38-way one-hot feature per protein atom.

    NOTE(review): ``data['protein_atom_name']`` must support ``.view`` and
    hold integer codes comparable with ``torch.arange(38)`` -- presumably an
    index tensor prepared upstream; confirm against the data pipeline.
    """

    def __init__(self):
        super().__init__()
        self.max_num_aa = 21
        self.atom_types = torch.arange(38)
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 16, 34])  # C, N, O, S, Se

    def feature_dim(self):
        """Return the width of the protein atom feature (38)."""
        return 38

    def __call__(self, data: ProteinLigandData):
        """Store the float one-hot encoding in ``data['protein_atom_feature']``."""
        codes = data['protein_atom_name'].view(-1, 1)
        vocabulary = self.atom_types.view(1, -1)
        one_hot = codes == vocabulary  # (N_atoms, 38) boolean
        data['protein_atom_feature'] = one_hot.float()
        return data
class FeaturizeLigandAtom(object):
    """Transform that prepends an element one-hot to the ligand atom features."""

    def __init__(self):
        super().__init__()
        # self.atomic_numbers = torch.LongTensor([1,6,7,8,9,15,16,17])  # H C N O F P S Cl
        self.atomic_numbers = torch.LongTensor([6, 7, 8, 9, 15, 16, 17])  # C N O F P S Cl

    def num_properties(self):
        """Return the number of atom families (``len(ATOM_FAMILIES)``)."""
        return len(ATOM_FAMILIES)

    def feature_dim(self):
        """Return the element one-hot width plus the number of atom families."""
        num_elements = self.atomic_numbers.size(0)
        return num_elements + len(ATOM_FAMILIES)

    def __call__(self, data: ProteinLigandData):
        """Replace ``data['ligand_atom_feature']`` with
        ``[element one-hot | previous features]`` as a float tensor."""
        elements = data['ligand_element'].view(-1, 1)
        known = self.atomic_numbers.view(1, -1)
        is_element = elements == known  # (N_atoms, N_elements)
        combined = torch.cat([is_element, data['ligand_atom_feature']], dim=-1)
        data['ligand_atom_feature'] = combined.float()
        return data
class FeaturizeLigandBond(object):
    """Transform adding one-hot bond features and per-atom neighbour lists."""

    def __init__(self):
        super().__init__()

    def __call__(self, data: ProteinLigandData):
        """Write ``ligand_bond_feature`` and ``ligand_neighbors`` into ``data``.

        ``ligand_bond_feature`` is the 3-class one-hot of ``ligand_bond_type - 1``.
        ``ligand_neighbors`` maps each atom position in ``data['moltree'].mol``
        to the indices of its neighbouring atoms.
        """
        zero_based_type = data['ligand_bond_type'] - 1  # (1,2,3) to (0,1,2)
        data['ligand_bond_feature'] = F.one_hot(zero_based_type, num_classes=3)

        # used in rotation angle prediction
        molecule = data['moltree'].mol
        data['ligand_neighbors'] = {
            position: [neighbor.GetIdx() for neighbor in atom.GetNeighbors()]
            for position, atom in enumerate(molecule.GetAtoms())
        }
        return data
class LigandCountNeighbors(object):
    """Transform computing per-atom neighbour counts and summed bond orders."""

    # Declared static: the function takes no ``self`` and __call__ invokes it
    # through the instance, which would otherwise bind ``self`` to
    # ``edge_index`` and collide with the ``symmetry`` keyword.
    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum ``valence`` over the edges leaving each node.

        :param edge_index: index tensor whose row 0 holds the node each edge
            is attributed to
        :param symmetry: must be ``True``; only symmetrical edge lists are
            supported
        :param valence: per-edge weight; defaults to 1 for every edge, which
            makes the result a plain neighbour count
        :param num_nodes: number of nodes; derived from ``edge_index`` via
            ``maybe_num_nodes`` when omitted
        :return: long tensor with one entry per node
        :raises AssertionError: if ``symmetry`` is not ``True``
        """
        assert symmetry == True, 'Only support symmetrical edges.'
        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)
        if valence is None:
            valence = torch.ones([edge_index.size(1)], device=edge_index.device)
        valence = valence.view(edge_index.size(1))
        return scatter_add(valence, index=edge_index[0], dim=0, dim_size=num_nodes).long()

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Write ``ligand_num_neighbors`` and ``ligand_atom_valence`` into ``data``.

        Both are computed from ``data['ligand_bond_index']``; the valence is
        additionally weighted by ``data['ligand_bond_type']``. Returns ``data``.
        """
        num_atoms = data['ligand_element'].size(0)
        data['ligand_num_neighbors'] = self.count_neighbors(
            data['ligand_bond_index'],
            symmetry=True,
            num_nodes=num_atoms,
        )
        data['ligand_atom_valence'] = self.count_neighbors(
            data['ligand_bond_index'],
            symmetry=True,
            valence=data['ligand_bond_type'],
            num_nodes=num_atoms,
        )
        return data
def kabsch(A, B):
    """Find the rigid transform that best maps the points ``B`` onto ``A``.

    Solves for the rotation ``R`` and translation ``t`` minimising the squared
    distance between ``R @ b_i + t`` and ``a_i`` (Kabsch algorithm, via SVD of
    the cross-covariance of the centred point sets).

    Matrix products are written with ``@`` so that plain ``ndarray`` inputs
    are handled correctly (``*`` is elementwise for ndarrays); ``@`` is also a
    matrix product for ``np.matrix``, so such inputs keep working.

    :param A: nominal points, N x 3
    :param B: measured points, N x 3
    :return: ``(R, t)`` where ``R`` is the 3 x 3 rotation (B to A) and ``t``
        the translation (B to A); for ndarray inputs ``t`` has shape ``(3,)``
    :raises AssertionError: if ``A`` and ``B`` hold a different number of points
    """
    assert len(A) == len(B)

    centroid_A = np.mean(A, axis=0)
    centroid_B = np.mean(B, axis=0)

    # center the points
    AA = A - centroid_A
    BB = B - centroid_B

    # 3 x 3 cross-covariance between the centred sets
    H = np.transpose(BB) @ AA
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # special reflection case: flip the axis of the smallest singular value
    # so that R is a proper rotation (det = +1)
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = -R @ centroid_B.T + centroid_A.T
    return R, t