cafierom commited on
Commit
9837b4e
·
verified ·
1 Parent(s): 803f8c0

Upload 2 files

Browse files
Files changed (2) hide show
  1. PA_requirements.txt +13 -0
  2. PropAgent_HFS.py +506 -0
PA_requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ langchain-huggingface
3
+ langchain_core
4
+ langchain_community
5
+ langgraph
6
+ rdkit
7
+ matplotlib
8
+ pillow
9
+ gradio
10
+ transformers
11
+ huggingface-hub
12
+ accelerate
13
+
PropAgent_HFS.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Annotated, TypedDict, Literal
3
+ from langchain_community.tools import DuckDuckGoSearchRun
4
+ from langchain_core.tools import tool
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langgraph.graph import StateGraph, START, END
7
+ from langgraph.graph.message import add_messages
8
+ from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
9
+
10
+ from langchain_huggingface.llms import HuggingFacePipeline
11
+ from langchain_huggingface import ChatHuggingFace
12
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
13
+ from langchain_core.runnables import chain
14
+ from uuid import uuid4
15
+ import re
16
+ import matplotlib.pyplot as plt
17
+ import PIL.Image as Image
18
+ import gradio as gr
19
+ import spaces
20
+
21
+ from rdkit import Chem
22
+ from rdkit.Chem import AllChem, QED
23
+ from rdkit.Chem import Draw
24
+ from rdkit import rdBase
25
+ from rdkit.Chem import rdMolAlign
26
+ import os
27
+ from rdkit import RDConfig
28
+ from rdkit.Chem.Features.ShowFeats import _featColors as featColors
29
+ from rdkit.Chem.FeatMaps import FeatMaps
30
+
31
+ fdef = AllChem.BuildFeatureFactory(os.path.join(RDConfig.RDDataDir,'BaseFeatures.fdef'))
32
+
33
+ fmParams = {}
34
+ for k in fdef.GetFeatureFamilies():
35
+ fparams = FeatMaps.FeatMapParams()
36
+ fmParams[k] = fparams
37
+
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+
40
+ hf = HuggingFacePipeline.from_model_id(
41
+ #model_id= "swiss-ai/Apertus-8B-Instruct-2509",
42
+ model_id= "microsoft/Phi-4-mini-instruct",
43
+ task="text-generation",
44
+ pipeline_kwargs = {"max_new_tokens": 500, "temperature": 0.4})
45
+
46
+ chat_model = ChatHuggingFace(llm=hf)
47
+
48
+ class State(TypedDict):
49
+ '''
50
+ The state of the agent.
51
+ '''
52
+ messages: Annotated[list, add_messages]
53
+ query_smiles: str
54
+ query_task: str
55
+ query_path: str
56
+ query_reference: str
57
+ tool_choice: tuple
58
+ which_tool: int
59
+ props_string: str
60
+ #(Literal["lipinski_tool", "substitution_tool", "pharm_feature_tool"],
61
+ # Literal["lipinski_tool", "substitution_tool", "pharm_feature_tool"])
62
+
63
+
64
+ def substitution_node(state: State) -> State:
65
+ '''
66
+ A simple substitution routine that looks for a substituent on a phenyl ring and
67
+ substitutes different fragments in that location. Returns a list of novel molecules and their
68
+ QED score (1 is most drug-like, 0 is least drug-like).
69
+
70
+ Args:
71
+ smiles: the input smiles string
72
+ Returns:
73
+ new_smiles_string: a string of novel molecules and their QED scores.
74
+ '''
75
+ print("substitution tool")
76
+ print('===================================================')
77
+
78
+ smiles = state["query_smiles"]
79
+ current_props_string = state["props_string"]
80
+
81
+ new_fragments = ["c(Cl)c", "c(F)c", "c(O)c", "c(C)c", "c(OC)c", "c([NH3+])c",
82
+ "c(Br)c", "c(C(F)(F)(F))c"]
83
+
84
+ new_smiles = []
85
+ for fragment in new_fragments:
86
+ m = re.findall(r"c(\D\D*)c", smiles)
87
+ if len(m) != 0:
88
+ for group in m:
89
+ #print(group)
90
+ if fragment not in group:
91
+ new_smile = smiles.replace(group[1:], fragment)
92
+ new_smiles.append(new_smile)
93
+
94
+ qeds = []
95
+ for new_smile in new_smiles:
96
+ qeds.append(get_qed(new_smile))
97
+ original_qed = get_qed(smiles)
98
+
99
+ new_smiles_string = "Substitution or Analogue creation tool results: \n"
100
+ new_smiles_string += f"The original molecule SMILES was {smiles} with QED {original_qed}.\n"
101
+ new_smiles_string += "Novel Molecules or Analogues and QED values: \n"
102
+ for i in range(len(new_smiles)):
103
+ new_smiles_string += f"SMILES: {new_smiles[i]}, QED: {qeds[i]:.3f}\n"
104
+
105
+ if len(new_smiles) > 0:
106
+ img = Draw.MolsToGridImage(new_smiles, molsPerRow=3, subImgSize=(200,200), legends=[f"QED: {qeds[i]:.3f}" for i in range(len(new_smiles))])
107
+ img.save('Substitution_image.png')
108
+ else:
109
+ new_smiles_string += "No valid substitutions were found.\n"
110
+
111
+ print(new_smiles_string)
112
+ current_props_string += new_smiles_string
113
+ state["props_string"] = current_props_string
114
+ state["which_tool"] += 1
115
+ return state
116
+
117
+ def get_qed(smiles):
118
+ '''
119
+ Helper function to compute QED for a given molecule.
120
+ Args:
121
+ smiles: the input smiles string
122
+ Returns:
123
+ qed: the QED score of the molecule.
124
+ '''
125
+ mol = Chem.MolFromSmiles(smiles)
126
+ qed = Chem.QED.default(mol)
127
+
128
+ return qed
129
+
130
+ def lipinski_node(state: State) -> State:
131
+ '''
132
+ A tool to calculate QED and other lipinski properties of a molecule.
133
+ Args:
134
+ smiles: the input smiles string
135
+ Returns:
136
+ props_string: a string of the QED and other lipinski properties of the molecule,
137
+ including Molecular Weight, LogP, HBA, HBD, Polar Surface Area,
138
+ Rotatable Bonds, Aromatic Rings and Undesireable Moieties.
139
+ '''
140
+ print("lipinski tool")
141
+ print('===================================================')
142
+
143
+ smiles = state["query_smiles"]
144
+ current_props_string = state["props_string"]
145
+
146
+ mol = Chem.MolFromSmiles(smiles)
147
+ qed = Chem.QED.default(mol)
148
+
149
+ p = Chem.QED.properties(mol)
150
+ mw = p[0]
151
+ logP = p[1]
152
+ hba = p[2]
153
+ hbd = p[3]
154
+ psa = p[4]
155
+ rb = p[5]
156
+ ar = p[6]
157
+ um = p[7]
158
+
159
+ props_string = "Lipinski tool results: \n"
160
+ props_string += f'''QED and other lipinski properties of the molecule:
161
+ SMILES: {smiles},
162
+ QED: {qed:.3f},
163
+ Molecular Weight: {mw:.3f},
164
+ LogP: {logP:.3f},
165
+ Hydrogen bond acceptors: {hba},
166
+ Hydrogen bond donors: {hbd},
167
+ Polar Surface Area: {psa:.3f},
168
+ Rotatable Bonds: {rb},
169
+ Aromatic Rings: {ar},
170
+ Undesireable moieties: {um}
171
+ '''
172
+
173
+ current_props_string += props_string
174
+ state["props_string"] = current_props_string
175
+ state["which_tool"] += 1
176
+ return state
177
+
178
+ def pharmfeature_node(state: State) -> State:
179
+ '''
180
+ A tool to compare the pharmacophore features of a query molecule against
181
+ a those of a reference molecule and report the pharmacophore features of both and the feature
182
+ score of the query molecule.
183
+
184
+ Args:
185
+ known_smiles: the reference smiles string
186
+ test_smiles: the query smiles string
187
+ Returns:
188
+ props_string: a string of the pharmacophore features of both molecules and the feature
189
+ score of the query molecule.
190
+ '''
191
+ print("pharmfeature tool")
192
+ print('===================================================')
193
+
194
+ test_smiles = state["query_smiles"]
195
+ known_smiles = state["query_reference"]
196
+ current_props_string = state["props_string"]
197
+
198
+ smiles = [known_smiles, test_smiles]
199
+ mols = [Chem.MolFromSmiles(x) for x in smiles]
200
+
201
+ mols = [Chem.AddHs(m) for m in mols]
202
+ ps = AllChem.ETKDGv3()
203
+
204
+ for m in mols:
205
+ AllChem.EmbedMolecule(m,ps)
206
+
207
+ o3d = rdMolAlign.GetO3A(mols[1],mols[0])
208
+ o3d.Align()
209
+
210
+ keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable', 'ZnBinder', 'Aromatic', 'LumpedHydrophobe')
211
+ feat_hash = {'Donor': 'Hydrogen bond donors', 'Acceptor': 'Hydrogen bond acceptors',
212
+ 'NegIonizable': 'Negatively ionizable groups', 'PosIonizable': 'Positively ionizable groups',
213
+ 'ZnBinder': 'Zinc Binders', 'Aromatic': 'Aromatic rings', 'LumpedHydrophobe': 'Hydrophobic/non-polar groups' }
214
+
215
+ feat_vectors = []
216
+ for m in mols:
217
+ rawFeats = fdef.GetFeaturesForMol(m)
218
+ feat_vectors.append([f for f in rawFeats if f.GetFamily() in keep])
219
+
220
+ feat_maps = [FeatMaps.FeatMap(feats = x,weights=[1]*len(x),params=fmParams) for x in feat_vectors]
221
+ test_score = feat_maps[0].ScoreFeats(feat_maps[1].GetFeatures())/(feat_maps[0].GetNumFeatures())
222
+
223
+ feats_known = {}
224
+ feats_test = {}
225
+ for feat in feat_vectors[0]:
226
+ if feat.GetFamily() not in feats_known.keys():
227
+ feats_known[feat.GetFamily()] = 1
228
+ else:
229
+ feats_known[feat.GetFamily()] += 1
230
+
231
+ for feat in feat_vectors[1]:
232
+ if feat.GetFamily() not in feats_test.keys():
233
+ feats_test[feat.GetFamily()] = 1
234
+ else:
235
+ feats_test[feat.GetFamily()] += 1
236
+
237
+ props_string = "PharmFeature tool results: \n"
238
+ props_string += f"The Pharmacophore Feature Overlap Score of the test molecule \
239
+ versus the reference molecule is {test_score:.3f}. \n\n"
240
+
241
+ for feat in feats_known.keys():
242
+ props_string += f"There are {feats_known[feat]} {feat_hash[feat]} in the reference molecule. \n"
243
+
244
+ for feat in feats_test.keys():
245
+ props_string += f"There are {feats_test[feat]} {feat_hash[feat]} in the test molecule. \n"
246
+
247
+ current_props_string += props_string
248
+ state["props_string"] = current_props_string
249
+ state["which_tool"] += 1
250
+ return state
251
+
252
+ def first_node(state: State) -> State:
253
+ '''
254
+ The first node of the agent. This node receives the input and asks the LLM
255
+ to determine which is the best tool to use to answer the QUERY TASK.
256
+
257
+ Input: the initial prompt from the user. should contain only one of more of the following:
258
+
259
+ smiles: the smiles string, task: the query task, path: the path to the file,
260
+ reference: the reference smiles
261
+
262
+ the value should be separated from the name by a ':' and each field should
263
+ be separated from the previous one by a ','.
264
+
265
+ All of these values are saved to the state
266
+
267
+ Output: the tool choice
268
+ '''
269
+ query_smiles = None
270
+ state["query_smiles"] = query_smiles
271
+ query_task = None
272
+ state["query_task"] = query_task
273
+ query_path = None
274
+ state["query_path"] = query_path
275
+ query_reference = None
276
+ state["query_reference"] = query_reference
277
+ props_string = ""
278
+ state["props_string"] = props_string
279
+
280
+ raw_input = state["messages"][-1].content
281
+ parts = raw_input.split(',')
282
+ for part in parts:
283
+ if 'smiles' in part:
284
+ query_smiles = part.split(':')[1]
285
+ if query_smiles.lower() == 'none':
286
+ query_smiles = None
287
+ state["query_smiles"] = query_smiles
288
+ if 'task' in part:
289
+ query_task = part.split(':')[1]
290
+ state["query_task"] = query_task
291
+ if 'path' in part:
292
+ query_path = part.split(':')[1]
293
+ if query_path.lower() == 'none':
294
+ query_path = None
295
+ state["query_path"] = query_path
296
+ if 'reference' in part:
297
+ query_reference = part.split(':')[1]
298
+ if query_reference.lower() == 'none':
299
+ query_reference = None
300
+ state["query_reference"] = query_reference
301
+
302
+ prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
303
+ can complete the task. If so, reply with only the tool names followed by "#". If two tools \
304
+ are required, reply with both tool names separated by a comma and followed by "#". \
305
+ If the tools cannot complete the task, reply with "None #".\n \
306
+ QUERY_TASK: {query_task}.\n \
307
+ Tools: \n \
308
+ lipinski_tool: this tool can calculate the following properties: Quantitative \
309
+ Estimate of Drug-likeness (QED), Molecular weight, LogP (measures lipophilicity, higher is more lipophilic), \
310
+ HBA, HBD, Polar Surface Area, Rotatable Bonds, Aromatic Rings and Undesireable Moieties. \n \
311
+ substitution_tool: this tool can generate analogues of the molecule by substituting \
312
+ different chemical groups on the original molecule. Returns a list of novel molecules and their \
313
+ QED score (1 is most drug-like, 0 is least drug-like). \n \
314
+ pharm_feature_tool: this tool can compare the pharmacophore features of a query molecule against \
315
+ a those of a reference molecule and report the pharmacophore features of both and the feature \
316
+ score of the query molecule. This score tells how the common features score against each other, but \
317
+ does not inform about features unique to each molecule.'
318
+
319
+ res = chat_model.invoke(prompt)
320
+
321
+ tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
322
+ tool_choices = tool_choices.split(',')
323
+ if len(tool_choices) == 1:
324
+ if tool_choices[0].strip().lower() == 'none':
325
+ tool_choice = (None, None)
326
+ else:
327
+ tool_choice = (tool_choices[0].strip().lower(), None)
328
+ elif len(tool_choices) == 2:
329
+ if tool_choices[0].strip().lower() == 'none':
330
+ tool_choice = (None, tool_choices[1].strip().lower())
331
+ elif tool_choices[1].strip().lower() == 'none':
332
+ tool_choice = (tool_choices[0].strip().lower(), None)
333
+ else:
334
+ tool_choice = (tool_choices[0].strip().lower(), tool_choices[1].strip().lower())
335
+ else:
336
+ tool_choice = (None, None)
337
+ state["tool_choice"] = tool_choice
338
+ state["which_tool"] = 0
339
+ print(f"The chosen tools are: {tool_choice}")
340
+
341
+ return state
342
+
343
+ def loop_node(state: State) -> State:
344
+ '''
345
+ This node accepts the tool returns and decides if it needs to call another
346
+ tool or go on to the parser node.
347
+
348
+ Input: the tool returns.
349
+ Output: the next node to call.
350
+ '''
351
+ return state
352
+
353
+ def parser_node(state: State) -> State:
354
+ '''
355
+ This is the third node in the agent. It receives the output from the tool,
356
+ puts it into a prompt as CONTEXT, and asks the LLM to answer the original
357
+ query.
358
+
359
+ Input: the output from the tool.
360
+ Output: the answer to the original query.
361
+ '''
362
+ props_string = state["props_string"]
363
+ query_task = state["query_task"]
364
+
365
+ prompt = f'Using the CONTEXT below, answer the original query, which \
366
+ was to answer the QUERY_TASK. End your answer with a "#" \
367
+ QUERY_TASK: {query_task}.\n \
368
+ CONTEXT: {props_string}.\n '
369
+
370
+ res = chat_model.invoke(prompt)
371
+ return {"messages": res}
372
+
373
+ def reflect_node(state: State) -> State:
374
+ '''
375
+ This is the fourth node of the agent. It recieves the LLMs previous answer and
376
+ tries to improve it.
377
+
378
+ Input: the LLMs last answer.
379
+ Output: the improved answer.
380
+ '''
381
+ previous_answer = state["messages"][-1].content
382
+ props_string = state["props_string"]
383
+
384
+ prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
385
+ TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
386
+ TOOL RESULTS by adding additional clarifying and enriching information. End \
387
+ your new answer with a "#" \
388
+ PREVIOUS ANSWER: {previous_answer}.\n \
389
+ TOOL RESULTS: {props_string}. '
390
+
391
+ res = chat_model.invoke(prompt)
392
+ return {"messages": res}
393
+
394
+ def get_chemtool(state):
395
+ '''
396
+ '''
397
+ which_tool = state["which_tool"]
398
+ tool_choice = state["tool_choice"]
399
+
400
+ if tool_choice is None or tool_choice == (None, None):
401
+ return None
402
+
403
+ if which_tool == 0 or which_tool == 1:
404
+ current_tool = tool_choice[which_tool]
405
+ if current_tool is None:
406
+ return None
407
+ elif which_tool > 1:
408
+ current_tool = None
409
+
410
+ return current_tool
411
+
412
+ def pretty_print(answer):
413
+ final = str(answer['messages'][-1]).split('<|assistant|>')[-1].split('#')[0].strip("n").strip('\\').strip('n').strip('\\')
414
+ for i in range(0,len(final),100):
415
+ print(final[i:i+100])
416
+
417
+ def print_short(answer):
418
+ for i in range(0,len(answer),100):
419
+ print(answer[i:i+100])
420
+
421
+ builder = StateGraph(State)
422
+ builder.add_node("first_node", first_node)
423
+ builder.add_node("substitution_node", substitution_node)
424
+ builder.add_node("lipinski_node", lipinski_node)
425
+ builder.add_node("pharmfeature_node", pharmfeature_node)
426
+ builder.add_node("loop_node", loop_node)
427
+ builder.add_node("parser_node", parser_node)
428
+ builder.add_node("reflect_node", reflect_node)
429
+
430
+ builder.add_edge(START, "first_node")
431
+ builder.add_conditional_edges("first_node", get_chemtool, {
432
+ "substitution_tool": "substitution_node",
433
+ "lipinski_tool": "lipinski_node",
434
+ "pharm_feature_tool": "pharmfeature_node",
435
+ None: "parser_node"})
436
+
437
+ builder.add_edge("lipinski_node", "loop_node")
438
+ builder.add_edge("substitution_node", "loop_node")
439
+ builder.add_edge("pharmfeature_node", "loop_node")
440
+
441
+ builder.add_conditional_edges("loop_node", get_chemtool, {
442
+ "substitution_tool": "substitution_node",
443
+ "lipinski_tool": "lipinski_node",
444
+ "pharm_feature_tool": "pharmfeature_node",
445
+ None: "parser_node"})
446
+
447
+ builder.add_edge("parser_node", "reflect_node")
448
+ builder.add_edge("reflect_node", END)
449
+
450
+ graph = builder.compile()
451
+
452
+ @spaces.GPU
453
+ def PropAgent(smiles, reference, task):
454
+
455
+ #if Substitution_image.png exists, remove it
456
+ if os.path.exists('Substitution_image.png'):
457
+ os.remove('Substitution_image.png')
458
+
459
+ input = {
460
+ "messages": [
461
+ HumanMessage(f'query_smiles: {smiles}, query_task: {task}, query_reference: {reference}')
462
+ ]
463
+ }
464
+ #print(input)
465
+
466
+ replies = []
467
+ for c in graph.stream(input): #, stream_mode='updates'):
468
+ m = re.findall(r'[a-z]+\_node', str(c))
469
+ if len(m) != 0:
470
+ reply = c[str(m[0])]['messages']
471
+ if 'assistant' in str(reply):
472
+ reply = str(reply).split("<|assistant|>")[-1].split('#')[0].strip()
473
+ replies.append(reply)
474
+ #check if image exists
475
+ if os.path.exists('Substitution_image.png'):
476
+ img_loc = 'Substitution_image.png'
477
+ img = Image.open(img_loc)
478
+ #else create a dummy blank image
479
+ else:
480
+ img = Image.new('RGB', (250, 250), color = (255, 255, 255))
481
+
482
+ return replies[-1], img
483
+
484
+ with gr.Blocks(fill_height=True) as forest:
485
+ gr.Markdown('''
486
+ # Properties Agent
487
+ - uses RDKit to calculate lipinski properties
488
+ - finds pharmacophore similarity between two molecules
489
+ - generated analogues of a molecule
490
+ ''')
491
+
492
+ name, smiles = None, None
493
+ with gr.Row():
494
+ with gr.Column():
495
+ smiles = gr.Textbox(label="Molecule SMILES of interest (optional): ", placeholder='none')
496
+ ref = gr.Textbox(label="Reference molecule SMILES of interest (optional): ", placeholder='none')
497
+ task = gr.Textbox(label="Task for Agent: ")
498
+ calc_btn = gr.Button(value = "Submit to Agent")
499
+ with gr.Column():
500
+ props = gr.Textbox(label="Agent results: ", lines=20 )
501
+ pic = gr.Image(label="Molecule")
502
+
503
+
504
+ calc_btn.click(PropAgent, inputs = [smiles, ref, task], outputs = [props, pic])
505
+
506
+ forest.launch(debug=False, mcp_server=True)