cafierom commited on
Commit
b964fe6
·
verified ·
1 Parent(s): 37e3c22

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -0
app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import transformers
4
+ import datasets
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from transformers import AutoTokenizer
8
+ from transformers import pipeline
9
+ import random
10
+ import deepchem
11
+ from rdkit import Chem
12
+ from rdkit.Chem import Draw
13
+
14
+ model_name = f"cafierom/bert-base-cased-ChemTok-ZN250K-V1"
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name,padding = True, truncation = True)
17
+ mask_filler = pipeline("fill-mask", model_name)
18
+
19
+ def tokenize(batch):
20
+ return tokenizer(batch["text"], padding=True, truncation=True, max_length=250, return_special_tokens_mask=True)
21
+
22
+ def gen_from_multimask(text, print_flag=True, mask_flag="random", percent = 0.10, top_k = 3):
23
+ """
24
+ Takes a SMILES string and tokenizes it. Depending on the mask flag, it then masks the
25
+ requested percentage of tokens in the string either randomly, at the begining (first) or at
26
+ the end (last). The masked string is then sent to the mask filler, and the result is expanded
27
+ into all possible new strings where the top k beams are selected and used if their probability
28
+ is greater than 0.1. Entropy is also calculated for each beam.
29
+
30
+ Args:
31
+ text: The SMILES string of the original molecule.
32
+
33
+ Returns:
34
+ final_smiles: a list of all the generated molecules.
35
+ total_entropy: a list of the entropy of each generated molecule.
36
+ """
37
+ new_tok_list = []
38
+ single_tok = tokenizer(text, padding=True, truncation=True, max_length=250, return_special_tokens_mask=True)
39
+ length_count = 0
40
+ for token in single_tok["input_ids"]:
41
+ if token != 0:
42
+ length_count += 1
43
+
44
+ if mask_flag == "last":
45
+ masked_tokens = [*range(int(length_count*(1.0-percent))-1,length_count-1)]
46
+ elif mask_flag == "first":
47
+ masked_tokens = [*range(0,int(length_count*percent))]
48
+ elif mask_flag == "random":
49
+ masked_tokens = random.sample(range(1, length_count), int(length_count*percent))
50
+
51
+ for j,token in enumerate(single_tok["input_ids"]):
52
+ if token != 0:
53
+ if j in masked_tokens:
54
+ new_tok_list.append(103)
55
+ else:
56
+ new_tok_list.append(token)
57
+ masked_smile = tokenizer.decode(new_tok_list,
58
+ skip_special_tokens=False).replace("[PAD]","").replace("[SEP]","").replace("[CLS]","").replace(" ","")
59
+ result = mask_filler(masked_smile,top_k=top_k)
60
+
61
+ new_smiles = []
62
+ total_batch = []
63
+ total_entropy = []
64
+
65
+ for i in range(len(result)):
66
+
67
+ batch_smiles = []
68
+ batch_entropy = []
69
+
70
+ for j in range(top_k):
71
+
72
+ p = result[i][j]["score"]
73
+
74
+ if result[i][j]["score"] > 0.1:
75
+ if i == 0:
76
+ new_smile = result[i][j]["sequence"].replace(" ","").replace("[SEP]","").replace("[CLS]","")
77
+ batch_smiles.append(new_smile)
78
+ batch_entropy.append(-p*np.log(p))
79
+ else:
80
+ for smile,entropy in zip(total_batch[i-1],total_entropy[i-1]):
81
+ new_smile = smile.replace("[MASK]",result[i][j]["token_str"],1)
82
+ batch_smiles.append(new_smile)
83
+ new_entropy = entropy - p*np.log(p)
84
+ batch_entropy.append(new_entropy)
85
+
86
+ total_entropy.append(batch_entropy)
87
+ total_batch.append(batch_smiles)
88
+
89
+ final_smiles = []
90
+ for smile in total_batch[-1]:
91
+ new_smile = smile.replace("##","")
92
+ final_smiles.append(new_smile)
93
+
94
+ if print_flag:
95
+ print(f"original: {text}")
96
+ final_smiles.insert(0,text)
97
+ for smile in final_smiles:
98
+ print(f"generated: {smile}")
99
+
100
+ return final_smiles,total_entropy[-1]
101
+
102
+ def validate_smiles(in_smiles, in_entropy):
103
+ """
104
+ Takes a list of SMILES strings checks to see if the compile to valid MOL objects.
105
+ Valid molecules are then converted to canonical SMILES strings and duplicates are
106
+ dropped.
107
+
108
+ Args:
109
+ text: The SMILES string of the original molecule.
110
+
111
+ Returns:
112
+ unique_smiles: a list of all the unique, valid generated molecules.
113
+ unique_entropies: a list of the entropy of each generated molecule.
114
+ """
115
+ valid_smiles = []
116
+ valid_entropies = []
117
+ unique_smiles = []
118
+ unique_entropies = []
119
+
120
+ for smile,entropy in zip(in_smiles,in_entropy):
121
+ try:
122
+ mol = Chem.MolFromSmiles(smile)
123
+ if mol is not None:
124
+ valid_smiles.append(smile)
125
+ valid_entropies.append(entropy)
126
+ except:
127
+ print("Could not convert to mol")
128
+
129
+ canon_smiles = [Chem.CanonSmiles(smile) for smile in valid_smiles]
130
+
131
+ for smile,entropy in zip(canon_smiles,valid_entropies):
132
+ if smile not in unique_smiles:
133
+ unique_smiles.append(smile)
134
+ unique_entropies.append(entropy)
135
+
136
+ print(f"Total unique SMILES generated: {len(unique_smiles)}")
137
+ print(f"Average entropy: {sum(unique_entropies)/len(unique_entropies)}")
138
+
139
+ return unique_smiles,unique_entropies
140
+
141
+ def calc_qed(smiles):
142
+ mols = [Chem.MolFromSmiles(smile) for smile in smiles]
143
+ qed = [Chem.QED.default(mol) for mol in mols]
144
+ return qed,mols
145
+
146
+ def gen_mask(smile_in: str) -> str:
147
+ """
148
+ The molecule corresponding to the input smiles is masked in different,
149
+ random ways, creating various masked versions of the molelcule.
150
+ A model, cafierom/bert-base-cased-ChemTok-ZN250K-V1,
151
+ is used to generate SMILES strings for analogue molecules by unmasking the
152
+ masked versions. All possibilities created by the generative mask-filling
153
+ are kept as long as the probability is greater than a cut-off, which is set
154
+ to 0.1 but which may be changed.The QED value, or quantitative estimate of druglikeness, a weighted average of
155
+ various ADME properties is also calculated. A value of 1.0 is perfect
156
+ drug-likeness, and a value of 0.0 is not drug-like.
157
+
158
+ Args:
159
+ smile: The SMILES string of the original molecule.
160
+
161
+ Returns:
162
+ out_text: a string with all of the SMILES for the generated molecules
163
+ and their QED values.
164
+
165
+ pic: An image of the molecules with QED values.
166
+ """
167
+ which_statins = [smile_in]
168
+ percent_to_use = 0.10
169
+ try:
170
+ main_smiles = []
171
+ main_entropy = []
172
+ for statin in which_statins:
173
+ result, calc_entropy = gen_from_multimask(statin, print_flag=False, mask_flag = "first", percent=percent_to_use)
174
+ for smile,entropy in zip(result,calc_entropy):
175
+ if smile not in main_smiles:
176
+ main_smiles.append(smile)
177
+ main_entropy.append(entropy)
178
+ length = len(main_smiles)
179
+ print(f"First masking generated {length} SMILES")
180
+
181
+ result, calc_entropy = gen_from_multimask(statin, print_flag=False, mask_flag = "last", percent=percent_to_use)
182
+ for smile,entropy in zip(result,calc_entropy):
183
+ if smile not in main_smiles:
184
+ main_smiles.append(smile)
185
+ main_entropy.append(entropy)
186
+ print(f"Last masking generated {len(main_smiles)-length} SMILES")
187
+ length = len(main_smiles)
188
+
189
+ for _ in range(4):
190
+ result, calc_entropy = gen_from_multimask(statin, print_flag=False, mask_flag = "random", percent=percent_to_use)
191
+ for smile,entropy in zip(result,calc_entropy):
192
+ if smile not in main_smiles:
193
+ main_smiles.append(smile)
194
+ main_entropy.append(entropy)
195
+ print(f"Random masking generated {len(main_smiles)-length} SMILES")
196
+ length = len(main_smiles)
197
+
198
+ print(f"Total SMILES generated: {len(main_smiles)}")
199
+
200
+ final_smiles,final_entropy = validate_smiles(main_smiles,main_entropy)
201
+ qeds,mols = calc_qed(final_smiles)
202
+
203
+ out_text = f"Total SMILES generated for hit: {len(final_smiles)}\n"
204
+ out_text += "===================================================\n"
205
+ i = 1
206
+ for smile, qed in zip(final_smiles,qeds):
207
+ out_text += f"analogue {i}: {smile} with QED: {qed:.3f}\n"
208
+ i += 1
209
+
210
+ legends = [f"QED = {qed:.3f}" for qed in qeds]
211
+
212
+ img = Draw.MolsToGridImage(mols, legends=legends, molsPerRow=3, subImgSize=(200,200),useSVG=False,returnPNG=False)
213
+
214
+ except:
215
+ out_text = "Invalid SMILES string"
216
+ img = None
217
+ return out_text,img
218
+
219
+ with gr.Blocks() as forest:
220
+ gr.Markdown(
221
+ """
222
+ # Generate Analogues of a hit for hit expansion using generative mask-filling.
223
+
224
+ - The hit molecule is input by the user; this molecule is then masked in different,
225
+ random ways. A model, cafierom/bert-base-cased-ChemTok-ZN250K-V1,
226
+ is used to generate SMILES strings for analogue molecules by unmasking the
227
+ hit molecule. All possibilities created by the generative mask-filling
228
+ are kept as long as the probability is greater than a cut-off, which is set
229
+ to 0.1 but which may be changed.
230
+
231
+ - The QED value, or quantitative estimate of druglikeness, a weighted average of
232
+ various ADME properties is also calculated. A value of 1.0 is perfect
233
+ drug-likeness, and a value of 0.0 is not drug-like. A value of about 0.5
234
+ is average for many drugs.
235
+ """)
236
+
237
+ with gr.Row():
238
+ smile = gr.Textbox(label="SMILES for hit expansion")
239
+
240
+ adme_btn = gr.Button("Generate analogues.")
241
+
242
+ with gr.Row():
243
+ results = gr.Textbox(label="New Molecules: ")
244
+ mol_pic = gr.Image(label="Molecule Images:")
245
+
246
+
247
+ @adme_btn.click(inputs=[smile], outputs=[results, mol_pic])
248
+ def do_genmask(smile,struct_type):
249
+ return gen_mask(smile)
250
+
251
+ @smile.submit(inputs=[smile], outputs=[results, mol_pic])
252
+ def do_genmask(smile,struct_type):
253
+ return gen_mask(smile)
254
+
255
+
256
+ forest.launch(mcp_server=True)