File size: 7,186 Bytes
10efe81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import copy
import os
import pickle
import random
import sys
from collections import defaultdict

import numpy as np
import rdkit
import rdkit.Chem as Chem
import torch
from tqdm.auto import tqdm

sys.path.append("..")

from .chemutils import get_clique_mol, tree_decomp, get_mol, get_smiles, set_atommap, get_clique_mol_simple
def get_slots(smiles):
    """Describe every atom of the molecule parsed from ``smiles``.

    The SMILES string is parsed with ``sanitize=False``.

    Returns:
        A list with one ``(symbol, formal_charge, total_num_hs)`` tuple per
        atom, in the order the atoms are iterated by ``mol.GetAtoms()``.
    """
    mol = Chem.MolFromSmiles(smiles, sanitize=False)
    slots = []
    for atom in mol.GetAtoms():
        slots.append((atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()))
    return slots
class Vocab(object):
    """Two-way lookup between motif SMILES strings and integer indices."""

    def __init__(self, smiles_list):
        """Build the vocabulary from ``smiles_list``.

        Args:
            smiles_list: sequence of SMILES strings; the position of a string
                in the sequence is its vocabulary index. If a string occurs
                more than once, the last position wins in the reverse map.
        """
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        # Lazily filled cache {index: atom slots}. Slots are computed on first
        # request instead of eagerly for the whole vocabulary, so building a
        # Vocab stays cheap while get_slots() no longer fails on a missing
        # attribute.
        self.slots = {}

    def get_index(self, smiles):
        """Return the index of ``smiles``, or 0 when it is not in the vocabulary.

        Note that 0 is also the index of the first vocabulary entry, so an
        unknown string is indistinguishable from that entry.
        """
        return self.vmap.get(smiles, 0)

    def get_smiles(self, idx):
        """Return the SMILES string stored at index ``idx``."""
        return self.vocab[idx]

    def get_slots(self, idx):
        """Return a private copy of the atom slots of vocabulary entry ``idx``.

        The slots are computed once per index with the module-level
        ``get_slots`` function and cached; a deep copy is returned so callers
        may mutate the result freely.
        """
        if idx not in self.slots:
            # Inside a method the bare name resolves to the module-level
            # function get_slots(smiles), not to this method.
            self.slots[idx] = get_slots(self.vocab[idx])
        return copy.deepcopy(self.slots[idx])

    def size(self):
        """Return the number of entries in the vocabulary."""
        return len(self.vocab)
class MolTreeNode(object):
    """One clique (motif) of a decomposed molecule.

    ``nid`` and ``is_leaf`` are read by :meth:`recover` but are not set in
    ``__init__``; they are expected to be assigned by the code that builds the
    tree before :meth:`recover` is called.
    """

    def __init__(self, mol, cmol, clique):
        """Create a node.

        Args:
            mol: the full molecule the clique was taken from.
            cmol: the molecule fragment for this clique.
            clique: iterable of atom indices (into ``mol``) forming the clique.
        """
        self.smiles = Chem.MolToSmiles(cmol, canonical=True)
        self.mol = cmol
        self.clique = list(clique)  # private copy of the atom indices
        self.neighbors = []
        # A two-atom clique counts as rotatable when both of its atoms have
        # degree >= 2 in the full molecule. The bond order is not checked, so
        # double bonds are accepted as well.
        self.rotatable = len(self.clique) == 2 and all(
            mol.GetAtomWithIdx(atom_idx).GetDegree() >= 2 for atom_idx in self.clique
        )

    def add_neighbor(self, nei_node):
        """Append ``nei_node`` to this node's neighbor list."""
        self.neighbors.append(nei_node)

    def recover(self, original_mol):
        """Build and return the label SMILES of this node plus its neighbors.

        Atoms of ``original_mol`` are temporarily tagged with atom-map numbers:
        this node's atoms get ``self.nid`` unless the node is a leaf, and the
        atoms of every non-leaf neighbor get that neighbor's ``nid`` when they
        are not shared with this node or when the neighbor is a single atom.
        The union of all involved atoms is extracted with
        ``get_clique_mol_simple``; its SMILES is round-tripped through RDKit
        and stored in ``self.label``, and ``get_mol(self.label)`` is stored in
        ``self.label_mol``. All atom-map numbers that were touched are reset
        to 0 before returning.
        """
        atom_ids = list(self.clique)
        if not self.is_leaf:
            for atom_idx in self.clique:
                original_mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(self.nid)

        for neighbor in self.neighbors:
            atom_ids.extend(neighbor.clique)
            if neighbor.is_leaf:
                # Leaf neighbors are left unmarked.
                continue
            # A singleton neighbor may override the mapping of a shared atom.
            is_singleton = len(neighbor.clique) == 1
            for atom_idx in neighbor.clique:
                if is_singleton or atom_idx not in self.clique:
                    original_mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(neighbor.nid)

        atom_ids = list(set(atom_ids))
        fragment = get_clique_mol_simple(original_mol, atom_ids)
        self.label = Chem.MolToSmiles(Chem.MolFromSmiles(get_smiles(fragment)))
        self.label_mol = get_mol(self.label)

        # Undo the temporary atom-map tags.
        for atom_idx in atom_ids:
            original_mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(0)
        return self.label

    def assemble(self):
        """Enumerate attachment candidates and store them on the node.

        Neighbors are passed to ``enum_assemble`` ordered by descending atom
        count. Each candidate is expected to be a 3-tuple; the first entries
        are stored in ``self.cands`` and the second in ``self.cand_mols``
        (both as lists, empty when there are no candidates).

        NOTE(review): ``enum_assemble`` is not among the names imported at the
        top of this file -- confirm where it is supposed to come from.
        """
        ordered = sorted(self.neighbors, key=lambda nei: nei.mol.GetNumAtoms(), reverse=True)
        candidates = enum_assemble(self, ordered)
        if len(candidates) == 0:
            self.cands = []
            self.cand_mols = []
            return
        first_col, second_col, _ = zip(*candidates)
        self.cands = list(first_col)
        self.cand_mols = list(second_col)
class MolTree(object):
    """Tree of :class:`MolTreeNode` cliques obtained by decomposing a molecule."""

    def __init__(self, mol):
        """Decompose ``mol`` into cliques and link them into a tree.

        After construction:
          * ``self.nodes`` holds one node per clique, with the clique that
            contains atom index 0 moved to position 0;
          * every node has ``nid`` (1-based position in ``self.nodes``) and
            ``is_leaf`` (exactly one neighbor) assigned;
          * ``self.num_rotatable_bond`` counts the nodes flagged ``rotatable``.

        Args:
            mol: RDKit molecule to decompose.
        """
        self.smiles = Chem.MolToSmiles(mol)
        self.mol = mol
        self.num_rotatable_bond = 0

        # Vanilla tree decomposition is used for simplicity; no reference
        # vocabulary / frequency threshold is applied to limit the vocab size.
        cliques, edges = tree_decomp(self.mol, reference_vocab=None)

        self.nodes = []
        root = 0
        for i, c in enumerate(cliques):
            cmol = get_clique_mol_simple(self.mol, c)
            self.nodes.append(MolTreeNode(self.mol, cmol, c))
            if min(c) == 0:
                # The clique containing atom 0 becomes the root.
                root = i

        self.num_rotatable_bond = sum(1 for node in self.nodes if node.rotatable)

        # Edge endpoints index the clique list in its original order, so the
        # neighbors must be linked before the root is swapped to the front.
        for x, y in edges:
            self.nodes[x].add_neighbor(self.nodes[y])
            self.nodes[y].add_neighbor(self.nodes[x])

        if root > 0:
            self.nodes[0], self.nodes[root] = self.nodes[root], self.nodes[0]

        for i, node in enumerate(self.nodes):
            node.nid = i + 1
            # MolTreeNode.recover() reads is_leaf on the node and on each of
            # its neighbors, so it has to be set here. Atom-map tagging of
            # node.mol (set_atommap) intentionally stays disabled.
            node.is_leaf = (len(node.neighbors) == 1)

    def size(self):
        """Return the number of nodes in the tree."""
        return len(self.nodes)

    def recover(self):
        """Call ``recover`` on every node, passing the full molecule."""
        for node in self.nodes:
            node.recover(self.mol)

    def assemble(self):
        """Call ``assemble`` on every node."""
        for node in self.nodes:
            node.assemble()
if __name__ == "__main__":
    # Builds the motif vocabulary from the PDBbind ligands and writes it to
    # ./vocab.txt as "<smiles>:<count>" lines, most frequent first.
    # (An earlier variant iterated ./data/crossdocked_pocket10/index.pkl and
    # loaded each ligand_fn the same way.)
    # NOTE: this file uses a relative import (.chemutils), so it has to be run
    # as a module of its package (python -m <package>.<module>).

    # Seed every RNG in use so repeated runs behave the same.
    seed = 2023
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    vocab = {}    # motif SMILES -> number of occurrences
    cnt = 0       # ligands decomposed successfully
    rot = 0       # ligands with at least one rotatable bond
    skipped = 0   # ligands that could not be read or decomposed

    data_root = '/n/holyscratch01/mzitnik_lab/zaixizhang/pdbbind_pocket10/'
    index = torch.load(os.path.join(data_root, 'index.pt'))

    for pdbid in tqdm(index):
        if pdbid is None:
            continue
        try:
            ligand_path = os.path.join(data_root, pdbid, pdbid + '_ligand.sdf')
            mol = Chem.MolFromMolFile(ligand_path, sanitize=False)
            # MolFromMolFile may hand back None instead of raising.
            moltree = MolTree(mol) if mol is not None else None
        except Exception:
            # Best effort: one bad ligand must not abort the whole run, but
            # Ctrl-C / SystemExit are no longer swallowed.
            moltree = None
        if moltree is None:
            skipped += 1
            continue

        cnt += 1
        if moltree.num_rotatable_bond > 0:
            rot += 1
        for node in moltree.nodes:
            vocab[node.smiles] = vocab.get(node.smiles, 0) + 1

    # Most frequent motif first; ties broken by SMILES in descending order.
    vocab = dict(sorted(vocab.items(), key=lambda kv: (kv[1], kv[0]), reverse=True))
    with open('./vocab.txt', 'w') as vocab_file:
        vocab_file.writelines('{}:{}\n'.format(k, v) for k, v in vocab.items())

    # number of molecules and vocab
    print('Size of the motif vocab:', len(vocab))
    print('Total number of molecules', cnt)
    print('Number of skipped ligands:', skipped)
    # Guard against an empty / fully failed index instead of dividing by zero.
    print('Percent of molecules with rotatable bonds:', rot / cnt if cnt else 0.0)
|