cafierom commited on
Commit
5f27d36
·
verified ·
1 Parent(s): e68f539

Upload PropAgent_HFS.py

Browse files
Files changed (1) hide show
  1. PropAgent_HFS.py +650 -0
PropAgent_HFS.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Annotated, TypedDict, Literal
3
+ from langchain_community.tools import DuckDuckGoSearchRun
4
+ from langchain_core.tools import tool
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langgraph.graph import StateGraph, START, END
7
+ from langgraph.graph.message import add_messages
8
+ from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
9
+
10
+ from langchain_huggingface.llms import HuggingFacePipeline
11
+ from langchain_huggingface import ChatHuggingFace
12
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
13
+ from langchain_core.runnables import chain
14
+ from uuid import uuid4
15
+ import re
16
+ import matplotlib.pyplot as plt
17
+ import PIL.Image as Image
18
+ import gradio as gr
19
+ import spaces
20
+
21
+ from rdkit import Chem
22
+ from rdkit.Chem import AllChem, QED
23
+ from rdkit.Chem import Draw
24
+ from rdkit import rdBase
25
+ from rdkit.Chem import rdMolAlign
26
+ import os
27
+ from rdkit import RDConfig
28
+ from rdkit.Chem.Features.ShowFeats import _featColors as featColors
29
+ from rdkit.Chem.FeatMaps import FeatMaps
30
+
31
+ fdef = AllChem.BuildFeatureFactory(os.path.join(RDConfig.RDDataDir,'BaseFeatures.fdef'))
32
+
33
+ fmParams = {}
34
+ for k in fdef.GetFeatureFamilies():
35
+ fparams = FeatMaps.FeatMapParams()
36
+ fmParams[k] = fparams
37
+
38
+ device = "cuda" if torch.cuda.is_available() else "cpu"
39
+
40
+ hf = HuggingFacePipeline.from_model_id(
41
+ #model_id= "swiss-ai/Apertus-8B-Instruct-2509",
42
+ model_id= "microsoft/Phi-4-mini-instruct",
43
+ task="text-generation",
44
+ pipeline_kwargs = {"max_new_tokens": 500, "temperature": 0.4})
45
+
46
+ chat_model = ChatHuggingFace(llm=hf)
47
+
48
+ class State(TypedDict):
49
+ '''
50
+ The state of the agent.
51
+ '''
52
+ messages: Annotated[list, add_messages]
53
+ query_smiles: str
54
+ query_task: str
55
+ query_path: str
56
+ query_reference: str
57
+ tool_choice: tuple
58
+ which_tool: int
59
+ props_string: str
60
+ loop_again: str
61
+ #(Literal["lipinski_tool", "substitution_tool", "pharm_feature_tool"],
62
+ # Literal["lipinski_tool", "substitution_tool", "pharm_feature_tool"])
63
+
64
+
65
+ def substitution_node(state: State) -> State:
66
+ '''
67
+ A simple substitution routine that looks for a substituent on a phenyl ring and
68
+ substitutes different fragments in that location. Returns a list of novel molecules and their
69
+ QED score (1 is most drug-like, 0 is least drug-like).
70
+
71
+ Args:
72
+ smiles: the input smiles string
73
+ Returns:
74
+ new_smiles_string: a string of novel molecules and their QED scores.
75
+ '''
76
+ print("substitution tool")
77
+ print('===================================================')
78
+
79
+ smiles = state["query_smiles"]
80
+ current_props_string = state["props_string"]
81
+
82
+ new_fragments = ["c(Cl)c", "c(F)c", "c(O)c", "c(C)c", "c(OC)c", "c([NH3+])c",
83
+ "c(Br)c", "c(C(F)(F)(F))c"]
84
+
85
+ new_smiles = []
86
+ for fragment in new_fragments:
87
+ m = re.findall(r"c(\D\D*)c", smiles)
88
+ if len(m) != 0:
89
+ for group in m:
90
+ #print(group)
91
+ if fragment not in group:
92
+ new_smile = smiles.replace(group[1:], fragment)
93
+ new_smiles.append(new_smile)
94
+
95
+ qeds = []
96
+ for new_smile in new_smiles:
97
+ qeds.append(get_qed(new_smile))
98
+ original_qed = get_qed(smiles)
99
+
100
+ new_smiles_string = "Substitution or Analogue creation tool results: \n"
101
+ new_smiles_string += f"The original molecule SMILES was {smiles} with QED {original_qed}.\n"
102
+ new_smiles_string += "Novel Molecules or Analogues and QED values: \n"
103
+ for i in range(len(new_smiles)):
104
+ new_smiles_string += f"SMILES: {new_smiles[i]}, QED: {qeds[i]:.3f}\n"
105
+ new_mols = [Chem.MolFromSmiles(x) for x in new_smiles]
106
+ if len(new_smiles) > 0:
107
+ img = Draw.MolsToGridImage(new_mols, molsPerRow=3, subImgSize=(200,200), legends=[f"QED: {qeds[i]:.3f}" for i in range(len(new_smiles))])
108
+ img.save('Substitution_image.png')
109
+ else:
110
+ new_smiles_string += "No valid substitutions were found.\n"
111
+
112
+ print(new_smiles_string)
113
+ current_props_string += new_smiles_string
114
+ state["props_string"] = current_props_string
115
+ state["which_tool"] += 1
116
+ return state
117
+
118
+ def get_qed(smiles):
119
+ '''
120
+ Helper function to compute QED for a given molecule.
121
+ Args:
122
+ smiles: the input smiles string
123
+ Returns:
124
+ qed: the QED score of the molecule.
125
+ '''
126
+ mol = Chem.MolFromSmiles(smiles)
127
+ qed = Chem.QED.default(mol)
128
+
129
+ return qed
130
+
131
+ def lipinski_node(state: State) -> State:
132
+ '''
133
+ A tool to calculate QED and other lipinski properties of a molecule.
134
+ Args:
135
+ smiles: the input smiles string
136
+ Returns:
137
+ props_string: a string of the QED and other lipinski properties of the molecule,
138
+ including Molecular Weight, LogP, HBA, HBD, Polar Surface Area,
139
+ Rotatable Bonds, Aromatic Rings and Undesireable Moieties.
140
+ '''
141
+ print("lipinski tool")
142
+ print('===================================================')
143
+
144
+ smiles = state["query_smiles"]
145
+ current_props_string = state["props_string"]
146
+
147
+ mol = Chem.MolFromSmiles(smiles)
148
+ qed = Chem.QED.default(mol)
149
+
150
+ p = Chem.QED.properties(mol)
151
+ mw = p[0]
152
+ logP = p[1]
153
+ hba = p[2]
154
+ hbd = p[3]
155
+ psa = p[4]
156
+ rb = p[5]
157
+ ar = p[6]
158
+ um = p[7]
159
+
160
+ props_string = "Lipinski tool results: \n"
161
+ props_string += f'''QED and other lipinski properties of the molecule:
162
+ SMILES: {smiles},
163
+ QED: {qed:.3f},
164
+ Molecular Weight: {mw:.3f},
165
+ LogP: {logP:.3f},
166
+ Hydrogen bond acceptors: {hba},
167
+ Hydrogen bond donors: {hbd},
168
+ Polar Surface Area: {psa:.3f},
169
+ Rotatable Bonds: {rb},
170
+ Aromatic Rings: {ar},
171
+ Undesireable moieties: {um}
172
+ '''
173
+
174
+ current_props_string += props_string
175
+ state["props_string"] = current_props_string
176
+ state["which_tool"] += 1
177
+ return state
178
+
179
+ def pharmfeature_node(state: State) -> State:
180
+ '''
181
+ A tool to compare the pharmacophore features of a query molecule against
182
+ a those of a reference molecule and report the pharmacophore features of both and the feature
183
+ score of the query molecule.
184
+
185
+ Args:
186
+ known_smiles: the reference smiles string
187
+ test_smiles: the query smiles string
188
+ Returns:
189
+ props_string: a string of the pharmacophore features of both molecules and the feature
190
+ score of the query molecule.
191
+ '''
192
+ print("pharmfeature tool")
193
+ print('===================================================')
194
+
195
+ test_smiles = state["query_smiles"]
196
+ known_smiles = state["query_reference"]
197
+ current_props_string = state["props_string"]
198
+
199
+ smiles = [known_smiles, test_smiles]
200
+ mols = [Chem.MolFromSmiles(x) for x in smiles]
201
+
202
+ mols = [Chem.AddHs(m) for m in mols]
203
+ ps = AllChem.ETKDGv3()
204
+
205
+ for m in mols:
206
+ AllChem.EmbedMolecule(m,ps)
207
+
208
+ o3d = rdMolAlign.GetO3A(mols[1],mols[0])
209
+ o3d.Align()
210
+
211
+ keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable', 'ZnBinder', 'Aromatic', 'LumpedHydrophobe')
212
+ feat_hash = {'Donor': 'Hydrogen bond donors', 'Acceptor': 'Hydrogen bond acceptors',
213
+ 'NegIonizable': 'Negatively ionizable groups', 'PosIonizable': 'Positively ionizable groups',
214
+ 'ZnBinder': 'Zinc Binders', 'Aromatic': 'Aromatic rings', 'LumpedHydrophobe': 'Hydrophobic/non-polar groups' }
215
+
216
+ feat_vectors = []
217
+ for m in mols:
218
+ rawFeats = fdef.GetFeaturesForMol(m)
219
+ feat_vectors.append([f for f in rawFeats if f.GetFamily() in keep])
220
+
221
+ feat_maps = [FeatMaps.FeatMap(feats = x,weights=[1]*len(x),params=fmParams) for x in feat_vectors]
222
+ test_score = feat_maps[0].ScoreFeats(feat_maps[1].GetFeatures())/(feat_maps[0].GetNumFeatures())
223
+
224
+ feats_known = {}
225
+ feats_test = {}
226
+ for feat in feat_vectors[0]:
227
+ if feat.GetFamily() not in feats_known.keys():
228
+ feats_known[feat.GetFamily()] = 1
229
+ else:
230
+ feats_known[feat.GetFamily()] += 1
231
+
232
+ for feat in feat_vectors[1]:
233
+ if feat.GetFamily() not in feats_test.keys():
234
+ feats_test[feat.GetFamily()] = 1
235
+ else:
236
+ feats_test[feat.GetFamily()] += 1
237
+
238
+ props_string = "PharmFeature tool results: \n"
239
+ props_string += f"The Pharmacophore Feature Overlap Score of the test molecule \
240
+ versus the reference molecule is {test_score:.3f}. \n\n"
241
+
242
+ for feat in feats_known.keys():
243
+ props_string += f"There are {feats_known[feat]} {feat_hash[feat]} in the reference molecule. \n"
244
+
245
+ for feat in feats_test.keys():
246
+ props_string += f"There are {feats_test[feat]} {feat_hash[feat]} in the test molecule. \n"
247
+
248
+ current_props_string += props_string
249
+ state["props_string"] = current_props_string
250
+ state["which_tool"] += 1
251
+ return state
252
+
253
+ def first_node(state: State) -> State:
254
+ '''
255
+ The first node of the agent. This node receives the input and asks the LLM
256
+ to determine which is the best tool to use to answer the QUERY TASK.
257
+
258
+ Input: the initial prompt from the user. should contain only one of more of the following:
259
+
260
+ smiles: the smiles string, task: the query task, path: the path to the file,
261
+ reference: the reference smiles
262
+
263
+ the value should be separated from the name by a ':' and each field should
264
+ be separated from the previous one by a ','.
265
+
266
+ All of these values are saved to the state
267
+
268
+ Output: the tool choice
269
+ '''
270
+ query_smiles = None
271
+ state["query_smiles"] = query_smiles
272
+ query_task = None
273
+ state["query_task"] = query_task
274
+ query_path = None
275
+ state["query_path"] = query_path
276
+ query_reference = None
277
+ state["query_reference"] = query_reference
278
+ props_string = ""
279
+ state["props_string"] = props_string
280
+ state["loop_again"] = None
281
+
282
+ raw_input = state["messages"][-1].content
283
+ parts = raw_input.split(',')
284
+ for part in parts:
285
+ if 'smiles' in part:
286
+ query_smiles = part.split(':')[1]
287
+ if query_smiles.lower() == 'none':
288
+ query_smiles = None
289
+ state["query_smiles"] = query_smiles
290
+ if 'task' in part:
291
+ query_task = part.split(':')[1]
292
+ state["query_task"] = query_task
293
+ if 'path' in part:
294
+ query_path = part.split(':')[1]
295
+ if query_path.lower() == 'none':
296
+ query_path = None
297
+ state["query_path"] = query_path
298
+ if 'reference' in part:
299
+ query_reference = part.split(':')[1]
300
+ if query_reference.lower() == 'none':
301
+ query_reference = None
302
+ state["query_reference"] = query_reference
303
+
304
+ prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
305
+ can complete the task. If so, reply with only the tool names followed by "#". If two tools \
306
+ are required, reply with both tool names separated by a comma and followed by "#". \
307
+ If the tools cannot complete the task, reply with "None #".\n \
308
+ QUERY_TASK: {query_task}.\n \
309
+ Tools: \n \
310
+ lipinski_tool: this tool can calculate the following properties: Quantitative \
311
+ Estimate of Drug-likeness (QED), Molecular weight, LogP (measures lipophilicity, higher is more lipophilic), \
312
+ HBA, HBD, Polar Surface Area, Rotatable Bonds, Aromatic Rings and Undesireable Moieties. \n \
313
+ substitution_tool: this tool can generate analogues of the molecule by substituting \
314
+ different chemical groups on the original molecule. Returns a list of novel molecules and their \
315
+ QED score (1 is most drug-like, 0 is least drug-like). \n \
316
+ pharm_feature_tool: this tool can compare the pharmacophore features of a query molecule against \
317
+ a those of a reference molecule and report the pharmacophore features of both and the feature \
318
+ score of the query molecule. This score tells how the common features score against each other, but \
319
+ does not inform about features unique to each molecule.'
320
+
321
+ res = chat_model.invoke(prompt)
322
+
323
+ tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
324
+ tool_choices = tool_choices.split(',')
325
+
326
+ if len(tool_choices) == 1:
327
+ tool1 = tool_choices[0].strip()
328
+ if tool1.lower() == 'none':
329
+ tool_choice = (None, None)
330
+ else:
331
+ tool_choice = (tool1, None)
332
+ elif len(tool_choices) == 2:
333
+ tool1 = tool_choices[0].lower().strip()
334
+ tool2 = tool_choices[1].lower().strip()
335
+ if tool1.lower() == 'none' and tool2.lower() == 'none':
336
+ tool_choice = (None, None)
337
+ elif tool1.lower() == 'none' and tool2.lower() != 'none':
338
+ tool_choice = (None, tool2)
339
+ elif tool2.lower() == 'none' and tool1.lower() != 'none':
340
+ tool_choice = (tool1, None)
341
+ else:
342
+ tool_choice = (tool1, tool2)
343
+ else:
344
+ tool_choice = (None, None)
345
+
346
+ state["tool_choice"] = tool_choice
347
+ state["which_tool"] = 0
348
+ print(f"The chosen tools are: {tool_choice}")
349
+
350
+ return state
351
+
352
+ def retry_node(state: State) -> State:
353
+ '''
354
+ If the previous loop of the agent does not get enough information from the
355
+ tools to answer the query, this node is called to retry the previous loop.
356
+ Input: the previous loop of the agent.
357
+ Output: the tool choice
358
+ '''
359
+ query_task = state["query_task"]
360
+ query_smiles = state["query_smiles"]
361
+ query_reference = state["query_reference"]
362
+
363
+ prompt = f'You were previously given the QUERY_TASK below, and asked to determine if one \
364
+ or two of the tools described below could complete the task. The tool choices did not succeed. \
365
+ Please re-examine the tool choices and determine if one or two of the tools described below \
366
+ can complete the task. If so, reply with only the tool names followed by "#". If two tools \
367
+ are required, reply with both tool names separated by a comma and followed by "#". \
368
+ If the tools cannot complete the task, reply with "None #".\n \
369
+ The information provided by the user is:\n \
370
+ QUERY_SMILES: {query_smiles}.\n \
371
+ QUERY_REFERENCE: {query_reference}.\n \
372
+ The task is: \
373
+ QUERY_TASK: {query_task}.\n \
374
+ Tool options: \n \
375
+ lipinski_tool: this tool can calculate the following properties: Quantitative \
376
+ Estimate of Drug-likeness (QED), Molecular weight, LogP (measures lipophilicity, higher is more lipophilic), \
377
+ HBA, HBD, Polar Surface Area, Rotatable Bonds, Aromatic Rings and Undesireable Moieties. \n \
378
+ substitution_tool: this tool can generate analogues of the molecule by substituting \
379
+ different chemical groups on the original molecule. Returns a list of novel molecules and their \
380
+ QED score (1 is most drug-like, 0 is least drug-like). \n \
381
+ pharm_feature_tool: this tool can compare the pharmacophore features of a query molecule against \
382
+ a those of a reference molecule and report the pharmacophore features of both and the feature \
383
+ score of the query molecule. This score tells how the common features score against each other, but \
384
+ does not inform about features unique to each molecule. \n'
385
+
386
+ res = chat_model.invoke(prompt)
387
+
388
+ tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
389
+ tool_choices = tool_choices.split(',')
390
+
391
+ if len(tool_choices) == 1:
392
+ tool1 = tool_choices[0].strip()
393
+ if tool1.lower() == 'none':
394
+ tool_choice = (None, None)
395
+ else:
396
+ tool_choice = (tool1.lower().strip(), None)
397
+ elif len(tool_choices) > 1:
398
+ tool1 = tool_choices[0].lower().strip()
399
+ tool2 = tool_choices[1].lower().strip()
400
+ if tool1.lower() == 'none' and tool2.lower() == 'none':
401
+ tool_choice = (None, None)
402
+ elif tool1.lower() == 'none' and tool2.lower() != 'none':
403
+ tool_choice = (None, tool2)
404
+ elif tool2.lower() == 'none' and tool1.lower() != 'none':
405
+ tool_choice = (tool1, None)
406
+ else:
407
+ tool_choice = (tool1, tool2)
408
+ else:
409
+ tool_choice = (None, None)
410
+
411
+ state["tool_choice"] = tool_choice
412
+ state["which_tool"] = 0
413
+ print(f"The chosen tools are (Retry): {tool_choice}")
414
+
415
+ return state
416
+
417
+ def loop_node(state: State) -> State:
418
+ '''
419
+ This node accepts the tool returns and decides if it needs to call another
420
+ tool or go on to the parser node.
421
+
422
+ Input: the tool returns.
423
+ Output: the next node to call.
424
+ '''
425
+ return state
426
+
427
+ def parser_node(state: State) -> State:
428
+ '''
429
+ This is the third node in the agent. It receives the output from the tool,
430
+ puts it into a prompt as CONTEXT, and asks the LLM to answer the original
431
+ query.
432
+
433
+ Input: the output from the tool.
434
+ Output: the answer to the original query.
435
+ '''
436
+ props_string = state["props_string"]
437
+ query_task = state["query_task"]
438
+ tool_choice = state["tool_choice"]
439
+
440
+ if type(tool_choice) != tuple and tool_choice == None:
441
+ state["loop_again"] = "finish_gracefully"
442
+ return state
443
+ elif type(tool_choice) == tuple and (tool_choice[0] == None) and (tool_choice[1] == None):
444
+ state["loop_again"] = "finish_gracefully"
445
+ return state
446
+
447
+ prompt = f'Using the CONTEXT below, answer the original query, which \
448
+ was to answer the QUERY_TASK. End your answer with a "#" \
449
+ QUERY_TASK: {query_task}.\n \
450
+ CONTEXT: {props_string}.\n '
451
+
452
+ res = chat_model.invoke(prompt)
453
+ trial_answer = str(res).split('<|assistant|>')[1]
454
+ print('parser 1 ', trial_answer)
455
+ state["messages"] = res
456
+
457
+ check_prompt = f'Determine if the TRIAL ANSWER below answers the original \
458
+ QUERY TASK. If it does, respond with "PROCEED #" . If the TRIAL ANSWER did not \
459
+ answer the QUERY TASK, respond with "LOOP #" \n \
460
+ Only loop again if the TRIAL ANSWER did not answer the QUERY TASK. \
461
+ TRIAL ANSWER: {trial_answer}.\n \
462
+ QUERY_TASK: {query_task}.\n'
463
+
464
+ res = chat_model.invoke(check_prompt)
465
+ print('parser, loop again? ', res)
466
+
467
+ if str(res).split('<|assistant|>')[1].split('#')[0].strip().lower() == "loop":
468
+ state["loop_again"] = "loop_again"
469
+ return state
470
+ elif str(res).split('<|assistant|>')[1].split('#')[0].strip().lower() == "proceed":
471
+ state["loop_again"] = None
472
+ print('trying to break loop')
473
+ elif "proceed" in str(res).split('<|assistant|>')[1].lower():
474
+ state["loop_again"] = None
475
+ print('trying to break loop')
476
+
477
+ return state
478
+
479
+ def reflect_node(state: State) -> State:
480
+ '''
481
+ This is the fourth node of the agent. It recieves the LLMs previous answer and
482
+ tries to improve it.
483
+
484
+ Input: the LLMs last answer.
485
+ Output: the improved answer.
486
+ '''
487
+ previous_answer = state["messages"][-1].content
488
+ props_string = state["props_string"]
489
+
490
+ prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
491
+ TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
492
+ TOOL RESULTS by adding additional clarifying and enriching information. End \
493
+ your new answer with a "#" \
494
+ PREVIOUS ANSWER: {previous_answer}.\n \
495
+ TOOL RESULTS: {props_string}. '
496
+
497
+ res = chat_model.invoke(prompt)
498
+ return {"messages": res}
499
+
500
+ def graceful_exit_node(state: State) -> State:
501
+ '''
502
+ Called when the Agent cannot assign any tools for the task
503
+ '''
504
+ props_string = state["props_string"]
505
+ prompt = f'Summarize the information in the CONTEXT, including any useful chemical information. Start your answer with: \
506
+ Here is what I found: \n \
507
+ CONTEXT: {props_string}'
508
+
509
+ res = chat_model.invoke(prompt)
510
+
511
+ return {"messages": res}
512
+
513
+ def get_chemtool(state):
514
+ '''
515
+ '''
516
+ which_tool = state["which_tool"]
517
+ tool_choice = state["tool_choice"]
518
+
519
+ if tool_choice is None or tool_choice == (None, None):
520
+ return None
521
+
522
+ if which_tool == 0 or which_tool == 1:
523
+ current_tool = tool_choice[which_tool]
524
+ if current_tool is None:
525
+ return None
526
+ elif which_tool > 1:
527
+ current_tool = None
528
+
529
+ return current_tool
530
+
531
+ def loop_or_not(state):
532
+ '''
533
+ '''
534
+ print(f"(line 482) Loop? {state['loop_again']}")
535
+ if state["loop_again"] == "loop_again":
536
+ return True
537
+ elif state["loop_again"] == "finish_gracefully":
538
+ return 'lets_get_outta_here'
539
+ else:
540
+ return False
541
+
542
+ def pretty_print(answer):
543
+ final = str(answer['messages'][-1]).split('<|assistant|>')[-1].split('#')[0].strip("n").strip('\\').strip('n').strip('\\')
544
+ for i in range(0,len(final),100):
545
+ print(final[i:i+100])
546
+
547
+ def print_short(answer):
548
+ for i in range(0,len(answer),100):
549
+ print(answer[i:i+100])
550
+
551
+ builder = StateGraph(State)
552
+ builder.add_node("first_node", first_node)
553
+ builder.add_node("retry_node", retry_node)
554
+ builder.add_node("substitution_node", substitution_node)
555
+ builder.add_node("lipinski_node", lipinski_node)
556
+ builder.add_node("pharmfeature_node", pharmfeature_node)
557
+ builder.add_node("loop_node", loop_node)
558
+ builder.add_node("parser_node", parser_node)
559
+ builder.add_node("reflect_node", reflect_node)
560
+ builder.add_node("graceful_exit_node", graceful_exit_node)
561
+
562
+ builder.add_edge(START, "first_node")
563
+ builder.add_conditional_edges("first_node", get_chemtool, {
564
+ "substitution_tool": "substitution_node",
565
+ "lipinski_tool": "lipinski_node",
566
+ "pharm_feature_tool": "pharmfeature_node",
567
+ None: "parser_node"})
568
+
569
+ builder.add_conditional_edges("retry_node", get_chemtool, {
570
+ "substitution_tool": "substitution_node",
571
+ "lipinski_tool": "lipinski_node",
572
+ "pharm_feature_tool": "pharmfeature_node",
573
+ None: "parser_node"})
574
+
575
+ builder.add_edge("lipinski_node", "loop_node")
576
+ builder.add_edge("substitution_node", "loop_node")
577
+ builder.add_edge("pharmfeature_node", "loop_node")
578
+
579
+ builder.add_conditional_edges("loop_node", get_chemtool, {
580
+ "substitution_tool": "substitution_node",
581
+ "lipinski_tool": "lipinski_node",
582
+ "pharm_feature_tool": "pharmfeature_node",
583
+ None: "parser_node"})
584
+
585
+ builder.add_conditional_edges("parser_node", loop_or_not, {
586
+ True: "retry_node",
587
+ 'lets_get_outta_here': "graceful_exit_node",
588
+ False: "reflect_node"})
589
+
590
+ builder.add_edge("reflect_node", END)
591
+ builder.add_edge("graceful_exit_node", END)
592
+
593
+ graph = builder.compile()
594
+
595
+ @spaces.GPU
596
+ def PropAgent(task, smiles, reference):
597
+
598
+ #if Substitution_image.png exists, remove it
599
+ if os.path.exists('Substitution_image.png'):
600
+ os.remove('Substitution_image.png')
601
+
602
+ input = {
603
+ "messages": [
604
+ HumanMessage(f'query_smiles: {smiles}, query_task: {task}, query_reference: {reference}')
605
+ ]
606
+ }
607
+ #print(input)
608
+
609
+ replies = []
610
+ for c in graph.stream(input): #, stream_mode='updates'):
611
+ m = re.findall(r'[a-z]+\_node', str(c))
612
+ if len(m) != 0:
613
+ reply = c[str(m[0])]['messages']
614
+ if 'assistant' in str(reply):
615
+ reply = str(reply).split("<|assistant|>")[-1].split('#')[0].strip()
616
+ replies.append(reply)
617
+ #check if image exists
618
+ if os.path.exists('Substitution_image.png'):
619
+ img_loc = 'Substitution_image.png'
620
+ img = Image.open(img_loc)
621
+ #else create a dummy blank image
622
+ else:
623
+ img = Image.new('RGB', (250, 250), color = (255, 255, 255))
624
+
625
+ return replies[-1], img
626
+
627
+ with gr.Blocks(fill_height=True) as forest:
628
+ gr.Markdown('''
629
+ # Properties Agent
630
+ - uses RDKit to calculate lipinski properties
631
+ - finds pharmacophore similarity between two molecules
632
+ - generated analogues of a molecule
633
+ ''')
634
+
635
+ name, smiles = None, None
636
+ with gr.Row():
637
+ with gr.Column():
638
+ smiles = gr.Textbox(label="Molecule SMILES of interest (optional): ", placeholder='none')
639
+ ref = gr.Textbox(label="Reference molecule SMILES of interest (optional): ", placeholder='none')
640
+ task = gr.Textbox(label="Task for Agent: ")
641
+ calc_btn = gr.Button(value = "Submit to Agent")
642
+ with gr.Column():
643
+ props = gr.Textbox(label="Agent results: ", lines=20 )
644
+ pic = gr.Image(label="Molecule")
645
+
646
+
647
+ calc_btn.click(PropAgent, inputs = [task, smiles, ref], outputs = [props, pic])
648
+ task.submit(PropAgent, inputs = [task, smiles, ref], outputs = [props, pic])
649
+
650
+ forest.launch(debug=False, mcp_server=True)