cafierom committed on
Commit
ced091f
·
verified ·
1 Parent(s): e1b462f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +503 -503
app.py CHANGED
@@ -1,504 +1,504 @@
1
- import torch
2
- from typing import Annotated, TypedDict, Literal
3
- from langchain_community.tools import DuckDuckGoSearchRun
4
- from langchain_core.tools import tool
5
- from langgraph.prebuilt import ToolNode, tools_condition
6
- from langgraph.graph import StateGraph, START, END
7
- from langgraph.graph.message import add_messages
8
- from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
9
-
10
- from langchain_huggingface.llms import HuggingFacePipeline
11
- from langchain_huggingface import ChatHuggingFace
12
- from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
13
- from langchain_core.runnables import chain
14
- from uuid import uuid4
15
- import re
16
- import matplotlib.pyplot as plt
17
- import spaces
18
-
19
- from dockstring import load_target
20
- from rdkit import Chem
21
- from rdkit.Chem import AllChem, QED
22
- from rdkit.Chem import Draw
23
- from rdkit.Chem.Draw import MolsToGridImage
24
- import os, re
25
- import gradio as gr
26
- from PIL import Image
27
-
28
- device = "cuda" if torch.cuda.is_available() else "cpu"
29
-
30
- hf = HuggingFacePipeline.from_model_id(
31
- model_id= "microsoft/Phi-4-mini-instruct",
32
- task="text-generation",
33
- pipeline_kwargs = {"max_new_tokens": 500, "temperature": 0.4})
34
-
35
- chat_model = ChatHuggingFace(llm=hf)
36
-
37
- cpuCount = os.cpu_count()
38
- print(f"Number of CPUs: {cpuCount}")
39
-
40
- class State(TypedDict):
41
- '''
42
- The state of the agent.
43
- '''
44
- messages: Annotated[list, add_messages]
45
- #for the agent
46
- tool_choice: tuple
47
- which_tool: int
48
- props_string: str
49
- similars_img: str
50
- loop_again: str
51
-
52
- #for the user input
53
- query_smiles: str
54
- query_task: str
55
- query_protein: str
56
-
57
- def docking_node(state: State) -> State:
58
- '''
59
- Docking tool: uses dockstring to dock the molecule into the protein
60
- '''
61
- print("docking tool")
62
- print('===================================================')
63
- current_props_string = state["props_string"]
64
- query_protein = state["query_protein"]
65
- query_smiles = state["query_smiles"]
66
-
67
- print(f'query_protein: {query_protein}')
68
- print(f'query_smiles: {query_smiles}')
69
-
70
- try:
71
- target = load_target(query_protein)
72
- print("===============================================")
73
- print(f"Docking molecule with {cpuCount} cpu cores.")
74
- score, aux = target.dock(query_smiles, num_cpus = cpuCount)
75
- mol = aux['ligand']
76
- print(f"Docking score: {score}")
77
- print("===============================================")
78
- atoms_list = ""
79
- template = mol
80
- molH = Chem.AddHs(mol)
81
- AllChem.ConstrainedEmbed(molH,template, useTethers=True)
82
- xyz_string = f"{molH.GetNumAtoms()}\n\n"
83
- for atom in molH.GetAtoms():
84
- atoms_list += atom.GetSymbol()
85
- pos = molH.GetConformer().GetAtomPosition(atom.GetIdx())
86
- xyz_string += f"{atom.GetSymbol()} {pos[0]} {pos[1]} {pos[2]}\n"
87
- prop_string = f"Docking score: {score} kcal/mol \n\n"
88
- prop_string += f"pose structure: {xyz_string}\n"
89
-
90
- except:
91
- print(f"Molecule could not be docked!")
92
- prop_string = ''
93
-
94
-
95
- current_props_string += prop_string
96
- state["props_string"] = current_props_string
97
- state["which_tool"] += 1
98
- return state
99
-
100
- def first_node(state: State) -> State:
101
- '''
102
- The first node of the agent. This node receives the input and asks the LLM
103
- to determine which is the best tool to use to answer the QUERY TASK.
104
- Input: the initial prompt from the user. should contain only one of more of the following:
105
- smiles: the smiles string, task: the query task, path: the path to the file,
106
- reference: the reference smiles
107
- the value should be separated from the name by a ':' and each field should
108
- be separated from the previous one by a ','.
109
- All of these values are saved to the state
110
- Output: the tool choice
111
- '''
112
- #for the user input
113
- query_smiles = None
114
- state["query_smiles"] = query_smiles
115
- query_task = None
116
- state["query_task"] = query_task
117
- query_protein = None
118
- state["query_protein"] = query_protein
119
- #for the agent
120
- state['similars_img'] = None
121
- props_string = ""
122
- state["props_string"] = props_string
123
- state["loop_again"] = None
124
-
125
- raw_input = state["messages"][-1].content
126
- #print(raw_input)
127
- parts = raw_input.split(',')
128
- for part in parts:
129
- if 'query_smiles' in part:
130
- query_smiles = part.split(':')[1]
131
- if query_smiles.lower() == 'none':
132
- query_smiles = None
133
- state["query_smiles"] = query_smiles
134
- if 'query_task' in part:
135
- query_task = part.split(':')[1]
136
- state["query_task"] = query_task
137
- if 'query_protein' in part:
138
- query_protein = part.split(':')[1]
139
- state["query_protein"] = query_protein
140
-
141
- prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
142
- can complete the task. If so, reply with only the tool names followed by "#". If two tools \
143
- are required, reply with both tool names separated by a comma and followed by "#". \
144
- If the tools cannot complete the task, reply with "None #".\n \
145
- QUERY_TASK: {query_task}.\n \
146
- The information provided by the user is:\n \
147
- QUERY_SMILES: {query_smiles}.\n \
148
- QUERY_PROTEIN: {query_protein}.\n \
149
- Tools: \n \
150
- docking_tool: uses dockstring to dock the molecule into the protein, producing a pose structure and a docking score.\n \
151
- '
152
-
153
- res = chat_model.invoke(prompt)
154
- print(res)
155
-
156
- tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
157
- tool_choices = tool_choices.split(',')
158
- print(tool_choices)
159
-
160
- if len(tool_choices) == 1:
161
- tool1 = tool_choices[0].strip()
162
- if tool1.lower() == 'none':
163
- tool_choice = (None, None)
164
- else:
165
- tool_choice = (tool1, None)
166
- elif len(tool_choices) == 2:
167
- tool1 = tool_choices[0].strip()
168
- tool2 = tool_choices[1].strip()
169
- if tool1.lower() == 'none' and tool2.lower() == 'none':
170
- tool_choice = (None, None)
171
- elif tool1.lower() == 'none' and tool2.lower() != 'none':
172
- tool_choice = (None, tool2)
173
- elif tool2.lower() == 'none' and tool1.lower() != 'none':
174
- tool_choice = (tool1, None)
175
- else:
176
- tool_choice = (tool1, tool2)
177
- else:
178
- tool_choice = (None, None)
179
-
180
- state["tool_choice"] = tool_choice
181
- state["which_tool"] = 0
182
- print(f"First Node. The chosen tools are: {tool_choice}")
183
-
184
- return state
185
-
186
- def retry_node(state: State) -> State:
187
- '''
188
- If the previous loop of the agent does not get enough informartion from the
189
- tools to answer the query, this node is called to retry the previous loop.
190
- Input: the previous loop of the agent.
191
- Output: the tool choice
192
- '''
193
- query_task = state["query_task"]
194
- query_smiles = state["query_smiles"]
195
- query_protein = state["query_protein"]
196
-
197
- prompt = f'You were previously given the QUERY_TASK below, and asked to determine if one \
198
- or two of the tools descibed below could complete the task. The tool choices did not succeed. \
199
- Please re-examine the tool choices and determine if one or two of the tools descibed below \
200
- can complete the task. If so, reply with only the tool names followed by "#". If two tools \
201
- are required, reply with both tool names separated by a comma and followed by "#". \
202
- If the tools cannot complete the task, reply with "None #".\n \
203
- The information provided by the user is:\n \
204
- QUERY_SMILES: {query_smiles}.\n \
205
- QUERY_PROTEIN: {query_protein}.\n \
206
- The task is: \
207
- QUERY_TASK: {query_task}.\n \
208
- Tool options: \n \
209
- docking_tool: uses dockstring to dock the molecule into the protein, producing a pose structure and a docking score.\n \
210
- '
211
-
212
- res = chat_model.invoke(prompt)
213
-
214
- tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
215
- tool_choices = tool_choices.split(',')
216
- if len(tool_choices) == 1:
217
- tool1 = tool_choices[0].strip()
218
- if tool1.lower() == 'none':
219
- tool_choice = (None, None)
220
- else:
221
- tool_choice = (tool1, None)
222
- elif len(tool_choices) == 2:
223
- tool1 = tool_choices[0].strip()
224
- tool2 = tool_choices[1].strip()
225
- if tool1.lower() == 'none' and tool2.lower() == 'none':
226
- tool_choice = (None, None)
227
- elif tool1.lower() == 'none' and tool2.lower() != 'none':
228
- tool_choice = (None, tool2)
229
- elif tool2.lower() == 'none' and tool1.lower() != 'none':
230
- tool_choice = (tool1, None)
231
- else:
232
- tool_choice = (tool1, tool2)
233
- else:
234
- tool_choice = (None, None)
235
-
236
- state["tool_choice"] = tool_choice
237
- state["which_tool"] = 0
238
- print(f"The chosen tools are (Retry): {tool_choice}")
239
-
240
- return state
241
-
242
- def loop_node(state: State) -> State:
243
- '''
244
- This node accepts the tool returns and decides if it needs to call another
245
- tool or go on to the parser node.
246
- Input: the tool returns.
247
- Output: the next node to call.
248
- '''
249
- return state
250
-
251
- def parser_node(state: State) -> State:
252
- '''
253
- This is the third node in the agent. It receives the output from the tool,
254
- puts it into a prompt as CONTEXT, and asks the LLM to answer the original
255
- query.
256
- Input: the output from the tool.
257
- Output: the answer to the original query.
258
- '''
259
- props_string = state["props_string"]
260
- query_task = state["query_task"]
261
- tool_choice = state["tool_choice"]
262
-
263
- if type(tool_choice) != tuple and tool_choice == None:
264
- state["loop_again"] = "finish_gracefully"
265
- return state
266
- elif type(tool_choice) == tuple and (tool_choice[0] == None) and (tool_choice[1] == None):
267
- state["loop_again"] = "finish_gracefully"
268
- return state
269
-
270
- prompt = f'Using the CONTEXT below, answer the original query, which \
271
- was to answer the QUERY_TASK. End your answer with a "#" \
272
- CONTEXT: {props_string}.\n \
273
- QUERY_TASK: {query_task}.\n '
274
-
275
- res = chat_model.invoke(prompt)
276
- trial_answer = str(res).split('<|assistant|>')[1]
277
- print('parser 1 ', trial_answer)
278
- state["messages"] = res
279
-
280
- check_prompt = f'Determine if the TRIAL ANSWER below answers the original \
281
- QUERY TASK. If it does, respond with "PROCEED #" . If the TRIAL ANSWER did not \
282
- answer the QUERY TASK, respond with "LOOP #" \n \
283
- Only loop again if the TRIAL ANSWER did not answer the QUERY TASK. \
284
- TRIAL ANSWER: {trial_answer}.\n \
285
- QUERY_TASK: {query_task}.\n'
286
-
287
- res = chat_model.invoke(check_prompt)
288
- print('parser, loop again? ', res)
289
-
290
- if str(res).split('<|assistant|>')[1].split('#')[0].strip().lower() == "loop":
291
- state["loop_again"] = "loop_again"
292
- return state
293
- elif str(res).split('<|assistant|>')[1].split('#')[0].strip().lower() == "proceed":
294
- state["loop_again"] = None
295
- print('trying to break loop')
296
- elif "proceed" in str(res).split('<|assistant|>')[1].lower():
297
- state["loop_again"] = None
298
- print('trying to break loop')
299
-
300
- return state
301
-
302
- def reflect_node(state: State) -> State:
303
- '''
304
- This is the fourth node of the agent. It recieves the LLMs previous answer and
305
- tries to improve it.
306
- Input: the LLMs last answer.
307
- Output: the improved answer.
308
- '''
309
- previous_answer = state["messages"][-1].content
310
- props_string = state["props_string"]
311
-
312
- prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
313
- TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
314
- TOOL RESULTS by adding additional clarifying and enriching information. End \
315
- your new answer with a "#" \
316
- PREVIOUS ANSWER: {previous_answer}.\n \
317
- TOOL RESULTS: {props_string}. '
318
-
319
- res = chat_model.invoke(prompt)
320
-
321
- return {"messages": res}
322
-
323
- def graceful_exit_node(state: State) -> State:
324
- '''
325
- Called when the Agent cannot assign any tools for the task
326
- '''
327
- props_string = state["props_string"]
328
- prompt = f'Summarize the information in the CONTEXT, including any useful chemical information. Start your answer with: \
329
- Here is what I found: \n \
330
- CONTEXT: {props_string}'
331
-
332
- res = chat_model.invoke(prompt)
333
-
334
- return {"messages": res}
335
-
336
-
337
- def get_chemtool(state):
338
- '''
339
- '''
340
- which_tool = state["which_tool"]
341
- tool_choice = state["tool_choice"]
342
- print('in get_chemtool ',tool_choice)
343
- if tool_choice == None:
344
- return None
345
- if which_tool == 0 or which_tool == 1:
346
- current_tool = tool_choice[which_tool]
347
- elif which_tool > 1:
348
- current_tool = None
349
-
350
- return current_tool
351
-
352
- def loop_or_not(state):
353
- '''
354
- '''
355
- print(f"(line 417) Loop? {state['loop_again']}")
356
- if state["loop_again"] == "loop_again":
357
- return True
358
- elif state["loop_again"] == "finish_gracefully":
359
- return 'lets_get_outta_here'
360
- else:
361
- return False
362
-
363
- builder = StateGraph(State)
364
- #for the agent
365
- builder.add_node("first_node", first_node)
366
- builder.add_node("retry_node", retry_node)
367
- builder.add_node("loop_node", loop_node)
368
- builder.add_node("parser_node", parser_node)
369
- builder.add_node("reflect_node", reflect_node)
370
- builder.add_node("graceful_exit_node", graceful_exit_node)
371
- #for the tools
372
- builder.add_node("docking_node", docking_node)
373
-
374
- builder.add_edge(START, "first_node")
375
- builder.add_conditional_edges("first_node", get_chemtool, {
376
- "docking_tool": "docking_node",
377
- None: "parser_node"})
378
-
379
- builder.add_conditional_edges("retry_node", get_chemtool, {
380
- "docking_tool": "docking_node",
381
- None: "parser_node"})
382
-
383
- builder.add_edge("docking_node", "loop_node")
384
-
385
- builder.add_conditional_edges("loop_node", get_chemtool, {
386
- "docking_tool" : "docking_node",
387
- "loop_again": "first_node",
388
- None: "parser_node"})
389
-
390
- builder.add_conditional_edges("parser_node", loop_or_not, {
391
- True: "retry_node",
392
- 'lets_get_outta_here': "graceful_exit_node",
393
- False: "reflect_node"})
394
-
395
- builder.add_edge("reflect_node", END)
396
- builder.add_edge("graceful_exit_node", END)
397
-
398
- graph = builder.compile()
399
-
400
- @spaces.GPU
401
- def DockAgent(task, smiles, protein): # add variables as needed
402
-
403
- #if Similars_image.png exists, remove it
404
- if os.path.exists('Similars_image.png'):
405
- os.remove('Similars_image.png')
406
-
407
- input = {
408
- "messages": [
409
- HumanMessage(f'query_smiles: {smiles}, query_task: {task}, query_protein: {protein}') # add variables as needed
410
- ]
411
- }
412
- #print(input)
413
-
414
- replies = []
415
- for c in graph.stream(input): #, stream_mode='updates'):
416
- m = re.findall(r'[a-z]+\_node', str(c))
417
- if len(m) != 0:
418
- try:
419
- reply = c[str(m[0])]['messages']
420
- if 'assistant' in str(reply):
421
- reply = str(reply).split("<|assistant|>")[-1].split('#')[0].strip()
422
- replies.append(reply)
423
- except:
424
- reply = str(c).split("<|assistant|>")[-1].split('#')[0].strip()
425
- replies.append(reply)
426
- #check if image exists
427
- if os.path.exists('Similars_image.png'):
428
- img_loc = 'Similars_image.png'
429
- img = Image.open(img_loc)
430
- #else create a dummy blank image
431
- else:
432
- img = Image.new('RGB', (250, 250), color = (255, 255, 255))
433
-
434
- return replies[-1], img
435
-
436
- dudes = ['IGF1R', 'JAK2', 'KIT', 'LCK', 'MAPK14', 'MAPKAPK2', 'MET', 'PTK2', 'PTPN1', 'SRC', 'ABL1', 'AKT1', 'AKT2', 'CDK2', 'CSF1R', 'EGFR', 'KDR', 'MAPK1', 'FGFR1', 'ROCK1', 'MAP2K1', 'PLK1',
437
- 'HSD11B1', 'PARP1', 'PDE5A', 'PTGS2', 'ACHE', 'MAOB', 'CA2', 'GBA', 'HMGCR', 'NOS1', 'REN', 'DHFR', 'ESR1', 'ESR2', 'NR3C1', 'PGR', 'PPARA', 'PPARD', 'PPARG',
438
- 'AR','THRB','ADAM17', 'F10', 'F2', 'BACE1', 'CASP3', 'MMP13', 'DPP4', 'ADRB1', 'ADRB2', 'DRD2', 'DRD3','ADORA2A','CYP2C9', 'CYP3A4', 'HSP90AA1']
439
-
440
- with gr.Blocks(fill_height=True) as forest:
441
- gr.Markdown('''
442
- # Docking Agent
443
- - uses dockstring to dock a molecule into a protein using only a SMILES string and a protein name
444
- - produces a pose structure and a docking score
445
- ''')
446
-
447
- with gr.Accordion("ProteinOptions", open=False):
448
- gr.Markdown('''
449
- # Protein Options
450
- ## Kinase
451
- ### Highest quality
452
- - IGF1R, JAK2, KIT, LCK, MAPK14, MAPKAPK2, MET, PTK2, PTPN1, SRC
453
- ### Medium quality
454
- - ABL1, AKT1, AKT2, CDK2, CSF1R, EGFR, KDR, MAPK1, FGFR1, ROCK1
455
- ### Lower quality
456
- - MAP2K1, PLK1
457
- ## Enzyme
458
- ### Highest quality
459
- - HSD11B1, PARP1, PDE5A, PTGS2
460
- ### Medium quality
461
- - ACHE, MAOB
462
- ### Lower quality
463
- - CA2, GBA, HMGCR, NOS1, REN, DHFR
464
- ## Nuclear Receptor
465
- ### Highest quality
466
- - ESR1, ESR2, NR3C1, PGR, PPARA, PPARD, PPARG
467
- ### Medium quality
468
- - AR
469
- ### Lower quality
470
- - THRB
471
- ## Protease
472
- ### Higher quality
473
- - ADAM17, F10, F2
474
- ### Medium quality
475
- - BACE1, CASP3, MMP13
476
- ### Lower quality
477
- - DPP4
478
- ## GPCR
479
- ### Medium quality
480
- - ADRB1, ADRB2, DRD2, DRD3
481
- ### Lower quality
482
- - ADORA2A
483
- ## Cytochrome
484
- ### Medium quality
485
- - CYP2C9, CYP3A4
486
- ## Chaperone
487
- ### Lower quality
488
- - HSP90AA1
489
- ''')
490
- with gr.Row():
491
- with gr.Column():
492
- smiles = gr.Textbox(label="Molecule SMILES of interest (optional): ", placeholder='none')
493
- protein = gr.Dropdown(dudes, label="Protein name (see options): ")
494
- task = gr.Textbox(label="Task for Agent: ")
495
- # add variables as needed
496
- calc_btn = gr.Button(value = "Submit to Agent")
497
- with gr.Column():
498
- props = gr.Textbox(label="Agent results: ", lines=20 )
499
- pic = gr.Image(label="Molecule")
500
-
501
- calc_btn.click(DockAgent, inputs = [task, smiles, protein], outputs = [props, pic])
502
- task.submit(DockAgent, inputs = [task, smiles, protein], outputs = [props, pic])
503
-
504
  forest.launch(debug=False, mcp_server=True)
 
1
+ import torch
2
+ from typing import Annotated, TypedDict, Literal
3
+ from langchain_community.tools import DuckDuckGoSearchRun
4
+ from langchain_core.tools import tool
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langgraph.graph import StateGraph, START, END
7
+ from langgraph.graph.message import add_messages
8
+ from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
9
+
10
+ from langchain_huggingface.llms import HuggingFacePipeline
11
+ from langchain_huggingface import ChatHuggingFace
12
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
13
+ from langchain_core.runnables import chain
14
+ from uuid import uuid4
15
+ import re
16
+ import matplotlib.pyplot as plt
17
+ import spaces
18
+
19
+ from dockstring import load_target
20
+ from rdkit import Chem
21
+ from rdkit.Chem import AllChem, QED
22
+ from rdkit.Chem import Draw
23
+ from rdkit.Chem.Draw import MolsToGridImage
24
+ import os, re
25
+ import gradio as gr
26
+ from PIL import Image
27
+
28
# Record whether CUDA is visible to torch.
# NOTE(review): `device` is not passed to the pipeline below in the code shown
# here - confirm whether the model is meant to be placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Local Hugging Face text-generation pipeline that backs every LLM call made
# by the agent nodes.
hf = HuggingFacePipeline.from_model_id(
    model_id="microsoft/Phi-4-mini-instruct",
    task="text-generation",
    pipeline_kwargs={"max_new_tokens": 500, "temperature": 0.4},
)

# Chat-style wrapper around the pipeline; the nodes call chat_model.invoke().
chat_model = ChatHuggingFace(llm=hf)

# Number of CPU cores handed to dockstring as `num_cpus` when docking.
cpuCount = os.cpu_count()
print(f"Number of CPUs: {cpuCount}")
39
+
40
class State(TypedDict):
    '''
    Shared state handed from node to node in the agent graph.
    '''
    # Conversation history; new messages are merged in by add_messages.
    messages: Annotated[list, add_messages]

    # --- fields maintained by the agent ---
    # Pair (first_tool, second_tool) chosen by first_node / retry_node;
    # either entry may be None.
    tool_choice: tuple
    # Index into tool_choice of the next tool to run; docking_node adds 1.
    which_tool: int
    # Text output accumulated from the tools, used as LLM context.
    props_string: str
    # Reset to None by first_node; not otherwise used in the code shown here.
    similars_img: str
    # None, "loop_again" or "finish_gracefully"; read by loop_or_not.
    loop_again: str

    # --- fields parsed from the user input by first_node ---
    query_smiles: str
    query_task: str
    query_protein: str
56
+
57
def docking_node(state: State) -> State:
    '''
    Docking tool: uses dockstring to dock the molecule into the protein.

    Reads query_protein and query_smiles from the state, docks the molecule
    and appends the docking score plus an XYZ-style listing of the
    hydrogen-completed pose to props_string. which_tool is always advanced by
    one, even when docking is skipped or fails, so the router moves on.

    Input: the agent state.
    Output: the same state with props_string and which_tool updated.
    '''
    print("docking tool")
    print('===================================================')
    current_props_string = state["props_string"]
    # first_node stores None for a missing value, so guard before .strip();
    # calling .strip() on None would raise outside the try block below.
    query_protein = (state["query_protein"] or "").strip()
    query_smiles = (state["query_smiles"] or "").strip()

    print(f'query_protein: {query_protein}')
    print(f'query_smiles: {query_smiles}')

    prop_string = ''
    if not query_protein or not query_smiles:
        # Nothing to dock: skip the expensive call instead of failing inside it.
        print("Molecule could not be docked!")
        print("Reason: missing protein name or SMILES string.")
    else:
        try:
            target = load_target(query_protein)
            print("===============================================")
            print(f"Docking molecule with {cpuCount} cpu cores.")
            score, aux = target.dock(query_smiles, num_cpus=cpuCount)
            mol = aux['ligand']
            print(f"Docking score: {score}")
            print("===============================================")
            # Add hydrogens and embed them while tethering the heavy atoms to
            # the docked pose, which serves as the template.
            molH = Chem.AddHs(mol)
            AllChem.ConstrainedEmbed(molH, mol, useTethers=True)
            conformer = molH.GetConformer()
            xyz_lines = []
            for atom in molH.GetAtoms():
                pos = conformer.GetAtomPosition(atom.GetIdx())
                xyz_lines.append(
                    f"{atom.GetSymbol()} {pos[0]} {pos[1]} {pos[2]}\n")
            xyz_string = f"{molH.GetNumAtoms()}\n\n" + "".join(xyz_lines)
            prop_string = f"Docking score: {score} kcal/mol \n\n"
            prop_string += f"pose structure: {xyz_string}\n"
        except Exception as err:
            # Best effort: a failed docking run must not crash the agent, but
            # the reason is reported instead of being silently swallowed.
            print("Molecule could not be docked!")
            print(f"Reason: {type(err).__name__}: {err}")
            prop_string = ''

    state["props_string"] = current_props_string + prop_string
    state["which_tool"] += 1
    return state
99
+
100
def first_node(state: State) -> State:
    '''
    The first node of the agent. This node receives the input and asks the LLM
    to determine which is the best tool to use to answer the QUERY TASK.

    Input: the initial prompt from the user, taken from the last message. It
    should contain one or more of the following fields:
    query_smiles: the smiles string, query_task: the query task,
    query_protein: the protein name.
    The value should be separated from the name by a ':' and each field should
    be separated from the previous one by a ','. Text after a ',' that does
    not start a new field is kept as part of the preceding field, so a task
    may itself contain commas.
    All of these values are saved to the state.
    Output: the state with tool_choice set to a pair of tool names (or None)
    and which_tool reset to 0.
    '''
    # Reset everything that is specific to one run of the agent.
    state["query_smiles"] = None
    state["query_task"] = None
    state["query_protein"] = None
    state['similars_img'] = None
    state["props_string"] = ""
    state["loop_again"] = None

    raw_input = state["messages"][-1].content

    fields = {"query_smiles": None, "query_task": None, "query_protein": None}
    current_key = None
    for part in raw_input.split(','):
        # partition() splits on the first ':' only, so a ':' inside the value
        # (possible in SMILES or free text) is preserved.
        key, sep, value = part.partition(':')
        key = key.strip()
        if sep and key in fields:
            current_key = key
            fields[key] = value
        elif current_key is not None:
            # Not a new field: this comma belonged to the previous value.
            fields[current_key] += ',' + part

    for key, value in fields.items():
        if value is not None:
            fields[key] = value.strip()

    query_smiles = fields["query_smiles"]
    # Compare after stripping; the raw value starts with a space, so an
    # unstripped comparison against 'none' can never match.
    if not query_smiles or query_smiles.lower() == 'none':
        query_smiles = None
    query_task = fields["query_task"]
    query_protein = fields["query_protein"]

    state["query_smiles"] = query_smiles
    state["query_task"] = query_task
    state["query_protein"] = query_protein

    prompt = (
        'For the QUERY_TASK given below, determine if one or two of the tools descibed below '
        'can complete the task. If so, reply with only the tool names followed by "#". If two tools '
        'are required, reply with both tool names separated by a comma and followed by "#". '
        'If the tools cannot complete the task, reply with "None #".\n '
        f'QUERY_TASK: {query_task}.\n '
        'The information provided by the user is:\n '
        f'QUERY_SMILES: {query_smiles}.\n '
        f'QUERY_PROTEIN: {query_protein}.\n '
        'Tools: \n '
        'docking_tool: uses dockstring to dock the molecule into the protein, '
        'producing a pose structure and a docking score.\n '
    )

    res = chat_model.invoke(prompt)
    print(res)

    # The reply is expected after the '<|assistant|>' marker; fall back to the
    # message content instead of raising IndexError when it is absent.
    marker = '<|assistant|>'
    res_text = str(res)
    if marker in res_text:
        reply = res_text.split(marker)[1]
    else:
        reply = str(getattr(res, "content", res_text))
    tool_choices = [name.strip() for name in reply.split('#')[0].split(',')]
    print(tool_choices)

    # Only names the graph can route on are kept; 'none' and anything the
    # conditional edges do not know about become None.
    known_tools = {"docking_tool": "docking_tool"}
    if len(tool_choices) in (1, 2):
        picks = [known_tools.get(name.lower()) for name in tool_choices]
        picks.append(None)
        tool_choice = (picks[0], picks[1])
    else:
        tool_choice = (None, None)

    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"First Node. The chosen tools are: {tool_choice}")

    return state
185
+
186
def retry_node(state: State) -> State:
    '''
    If the previous loop of the agent does not get enough information from the
    tools to answer the query, this node is called to retry the previous loop.
    It asks the LLM to choose the tools again for the stored task.

    Input: the state left by the previous loop of the agent.
    Output: the state with tool_choice set to a pair of tool names (or None)
    and which_tool reset to 0.
    '''
    query_task = state["query_task"]
    query_smiles = state["query_smiles"]
    query_protein = state["query_protein"]

    prompt = (
        'You were previously given the QUERY_TASK below, and asked to determine if one '
        'or two of the tools descibed below could complete the task. The tool choices did not succeed. '
        'Please re-examine the tool choices and determine if one or two of the tools descibed below '
        'can complete the task. If so, reply with only the tool names followed by "#". If two tools '
        'are required, reply with both tool names separated by a comma and followed by "#". '
        'If the tools cannot complete the task, reply with "None #".\n '
        'The information provided by the user is:\n '
        f'QUERY_SMILES: {query_smiles}.\n '
        f'QUERY_PROTEIN: {query_protein}.\n '
        'The task is: '
        f'QUERY_TASK: {query_task}.\n '
        'Tool options: \n '
        'docking_tool: uses dockstring to dock the molecule into the protein, '
        'producing a pose structure and a docking score.\n '
    )

    res = chat_model.invoke(prompt)

    # The reply is expected after the '<|assistant|>' marker; fall back to the
    # message content instead of raising IndexError when it is absent.
    marker = '<|assistant|>'
    res_text = str(res)
    if marker in res_text:
        reply = res_text.split(marker)[1]
    else:
        reply = str(getattr(res, "content", res_text))
    tool_choices = [name.strip() for name in reply.split('#')[0].split(',')]

    # Only names the graph can route on are kept; 'none' and anything the
    # conditional edges do not know about become None.
    known_tools = {"docking_tool": "docking_tool"}
    if len(tool_choices) in (1, 2):
        picks = [known_tools.get(name.lower()) for name in tool_choices]
        picks.append(None)
        tool_choice = (picks[0], picks[1])
    else:
        tool_choice = (None, None)

    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are (Retry): {tool_choice}")

    return state
241
+
242
def loop_node(state: State) -> State:
    '''
    Pass-through node placed after a tool node. It does not modify the state;
    the conditional edge attached to it decides whether another tool runs or
    the parser node is called next.
    Input: the state returned by the tool.
    Output: the same state, unchanged.
    '''
    return state
250
+
251
def parser_node(state: State) -> State:
    '''
    This is the third node in the agent. It receives the output from the tool,
    puts it into a prompt as CONTEXT, and asks the LLM to answer the original
    query. A second LLM call then judges whether that trial answer addresses
    the task.

    Input: the state holding the tool output (props_string), the task and the
    tool choice.
    Output: the state with the trial answer stored under messages and
    loop_again set to "finish_gracefully" (no tool was chosen), "loop_again"
    (the judge replied LOOP) or None (proceed).
    '''
    props_string = state["props_string"]
    query_task = state["query_task"]
    tool_choice = state["tool_choice"]

    # No usable tool was selected: skip the LLM calls and exit gracefully.
    no_tool = tool_choice is None or (
        isinstance(tool_choice, tuple)
        and tool_choice[0] is None
        and tool_choice[1] is None
    )
    if no_tool:
        state["loop_again"] = "finish_gracefully"
        return state

    marker = '<|assistant|>'

    prompt = (
        'Using the CONTEXT below, answer the original query, which '
        'was to answer the QUERY_TASK. End your answer with a "#" '
        f'CONTEXT: {props_string}.\n '
        f'QUERY_TASK: {query_task}.\n '
    )

    res = chat_model.invoke(prompt)
    res_text = str(res)
    # Fall back to the message content instead of raising IndexError when the
    # assistant marker is absent.
    if marker in res_text:
        trial_answer = res_text.split(marker)[1]
    else:
        trial_answer = str(getattr(res, "content", res_text))
    print('parser 1 ', trial_answer)
    state["messages"] = res

    check_prompt = (
        'Determine if the TRIAL ANSWER below answers the original '
        'QUERY TASK. If it does, respond with "PROCEED #" . If the TRIAL ANSWER did not '
        'answer the QUERY TASK, respond with "LOOP #" \n '
        'Only loop again if the TRIAL ANSWER did not answer the QUERY TASK. '
        f'TRIAL ANSWER: {trial_answer}.\n '
        f'QUERY_TASK: {query_task}.\n'
    )

    res = chat_model.invoke(check_prompt)
    print('parser, loop again? ', res)

    verdict_text = str(res)
    if marker in verdict_text:
        verdict_reply = verdict_text.split(marker)[1]
    else:
        verdict_reply = str(getattr(res, "content", verdict_text))
    verdict = verdict_reply.split('#')[0].strip().lower()

    if verdict == "loop":
        state["loop_again"] = "loop_again"
        return state

    # Anything other than an explicit LOOP proceeds. Clearing the flag here
    # stops a stale "loop_again" from an earlier pass forcing another retry
    # when the verdict cannot be recognised.
    state["loop_again"] = None
    if verdict == "proceed" or "proceed" in verdict_reply.lower():
        print('trying to break loop')

    return state
301
+
302
def reflect_node(state: State) -> State:
    '''
    This is the fourth node of the agent. It takes the content of the last
    message (the LLM's previous answer) together with the accumulated tool
    results and asks the LLM for an improved answer.
    Input: the state holding the previous answer and props_string.
    Output: a state update containing the new LLM reply under "messages".
    '''
    tool_results = state["props_string"]
    last_answer = state["messages"][-1].content

    prompt = (
        'Look at the PREVIOUS ANSWER below which you provided and the '
        'TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the '
        'TOOL RESULTS by adding additional clarifying and enriching information. End '
        'your new answer with a "#" '
        f'PREVIOUS ANSWER: {last_answer}.\n '
        f'TOOL RESULTS: {tool_results}. '
    )

    improved = chat_model.invoke(prompt)
    return {"messages": improved}
322
+
323
def graceful_exit_node(state: State) -> State:
    '''
    Called when the Agent cannot assign any tools for the task. Asks the LLM
    to summarize whatever is stored in props_string.
    Input: the state holding props_string.
    Output: a state update containing the LLM summary under "messages".
    '''
    context = state["props_string"]
    prompt = (
        'Summarize the information in the CONTEXT, including any useful '
        'chemical information. Start your answer with: '
        'Here is what I found: \n '
        f'CONTEXT: {context}'
    )

    summary = chat_model.invoke(prompt)
    return {"messages": summary}
335
+
336
+
337
def get_chemtool(state):
    '''
    Routing function for the conditional edges: returns the name of the tool
    to run next.

    Input: the state; reads which_tool (index of the next tool) and
    tool_choice (pair of tool names, or None).
    Output: tool_choice[which_tool] when which_tool is 0 or 1 and that entry
    exists, otherwise None (no tool left to run).
    '''
    which_tool = state["which_tool"]
    tool_choice = state["tool_choice"]
    print('in get_chemtool ', tool_choice)
    if tool_choice is None:
        return None
    # Only the two slots of the pair are valid; any other index means there is
    # nothing left to run rather than an error.
    if which_tool in (0, 1) and which_tool < len(tool_choice):
        return tool_choice[which_tool]
    return None
351
+
352
def loop_or_not(state):
    '''
    Routing function used after parser_node. Maps the loop_again flag to the
    key of the next edge: True for "loop_again", 'lets_get_outta_here' for
    "finish_gracefully", and False for anything else.
    '''
    flag = state['loop_again']
    print(f"(line 417) Loop? {flag}")
    if flag == "loop_again":
        return True
    if flag == "finish_gracefully":
        return 'lets_get_outta_here'
    return False
362
+
363
# Assemble the agent's state graph: register the nodes, wire the edges, compile.
builder = StateGraph(State)

# Agent nodes first, then the tool node, in the same order as before.
for node_name, node_fn in (
    ("first_node", first_node),
    ("retry_node", retry_node),
    ("loop_node", loop_node),
    ("parser_node", parser_node),
    ("reflect_node", reflect_node),
    ("graceful_exit_node", graceful_exit_node),
    ("docking_node", docking_node),
):
    builder.add_node(node_name, node_fn)

builder.add_edge(START, "first_node")

# first_node and retry_node share the same routing: run the chosen tool,
# or hand over to the parser when get_chemtool returns None.
for source_node in ("first_node", "retry_node"):
    builder.add_conditional_edges(
        source_node,
        get_chemtool,
        {"docking_tool": "docking_node", None: "parser_node"},
    )

builder.add_edge("docking_node", "loop_node")

builder.add_conditional_edges(
    "loop_node",
    get_chemtool,
    {
        "docking_tool": "docking_node",
        "loop_again": "first_node",
        None: "parser_node",
    },
)

# The parser decides between retrying, bailing out, or polishing the answer.
builder.add_conditional_edges(
    "parser_node",
    loop_or_not,
    {
        True: "retry_node",
        'lets_get_outta_here': "graceful_exit_node",
        False: "reflect_node",
    },
)

for terminal_node in ("reflect_node", "graceful_exit_node"):
    builder.add_edge(terminal_node, END)

graph = builder.compile()
399
+
400
@spaces.GPU
def DockAgent(task, smiles, protein): # add variables as needed
    '''
    Run the agent graph on one request and return its final answer and image.

    Args:
        task: the task text; placed in the query_task field of the first message.
        smiles: the SMILES string; placed in the query_smiles field.
        protein: the protein name; placed in the query_protein field.

    Returns:
        A tuple (answer, image). answer is the last reply collected from the
        graph stream, or a short notice when nothing was collected. image is
        the contents of Similars_image.png when that file exists after the
        run, otherwise a blank white 250x250 image.
    '''
    image_path = 'Similars_image.png'

    # Remove any image left by an earlier run so a stale picture is never returned.
    if os.path.exists(image_path):
        os.remove(image_path)

    def clean(raw):
        # Keep the text after the last assistant marker and before the "#" terminator.
        return str(raw).split("<|assistant|>")[-1].split('#')[0].strip()

    graph_input = {
        "messages": [
            HumanMessage(f'query_smiles: {smiles}, query_task: {task}, query_protein: {protein}') # add variables as needed
        ]
    }

    replies = []
    for chunk in graph.stream(graph_input): #, stream_mode='updates'):
        # Read the node name from the chunk's keys. A regex over str(chunk)
        # cannot match names that contain underscores, e.g. graceful_exit_node.
        for node_name, update in chunk.items():
            if not node_name.endswith('_node'):
                continue
            try:
                reply = update['messages']
            except (KeyError, TypeError):
                # No usable "messages" entry: fall back to the whole chunk text.
                replies.append(clean(chunk))
                continue
            if 'assistant' in str(reply):
                replies.append(clean(reply))

    if os.path.exists(image_path):
        # Image.open is lazy; load inside the context manager so the file
        # handle is released while the pixel data stays available.
        with Image.open(image_path) as img:
            img.load()
    else:
        # No picture was produced: return a blank placeholder instead.
        img = Image.new('RGB', (250, 250), color = (255, 255, 255))

    final_reply = replies[-1] if replies else 'The agent did not return an answer.'
    return final_reply, img
435
+
436
# Protein names offered in the dropdown, one source line per protein family.
dudes = (
    # Kinase
    'IGF1R JAK2 KIT LCK MAPK14 MAPKAPK2 MET PTK2 PTPN1 SRC '
    'ABL1 AKT1 AKT2 CDK2 CSF1R EGFR KDR MAPK1 FGFR1 ROCK1 MAP2K1 PLK1 '
    # Enzyme
    'HSD11B1 PARP1 PDE5A PTGS2 ACHE MAOB CA2 GBA HMGCR NOS1 REN DHFR '
    # Nuclear receptor
    'ESR1 ESR2 NR3C1 PGR PPARA PPARD PPARG AR THRB '
    # Protease
    'ADAM17 F10 F2 BACE1 CASP3 MMP13 DPP4 '
    # GPCR
    'ADRB1 ADRB2 DRD2 DRD3 ADORA2A '
    # Cytochrome
    'CYP2C9 CYP3A4 '
    # Chaperone
    'HSP90AA1'
).split()
439
+
440
# Gradio front end: an intro, a collapsible list of proteins, the input
# widgets on the left and the agent's outputs on the right.
with gr.Blocks(fill_height=True) as forest:
    gr.Markdown('''
    # Docking Agent
    - uses dockstring to dock a molecule into a protein using only a SMILES string and a protein name
    - produces a pose structure and a docking score
    ''')

    with gr.Accordion("ProteinOptions", open=False):
        gr.Markdown('''
        # Protein Options
        ## Kinase
        ### Highest quality
        - IGF1R, JAK2, KIT, LCK, MAPK14, MAPKAPK2, MET, PTK2, PTPN1, SRC
        ### Medium quality
        - ABL1, AKT1, AKT2, CDK2, CSF1R, EGFR, KDR, MAPK1, FGFR1, ROCK1
        ### Lower quality
        - MAP2K1, PLK1
        ## Enzyme
        ### Highest quality
        - HSD11B1, PARP1, PDE5A, PTGS2
        ### Medium quality
        - ACHE, MAOB
        ### Lower quality
        - CA2, GBA, HMGCR, NOS1, REN, DHFR
        ## Nuclear Receptor
        ### Highest quality
        - ESR1, ESR2, NR3C1, PGR, PPARA, PPARD, PPARG
        ### Medium quality
        - AR
        ### Lower quality
        - THRB
        ## Protease
        ### Higher quality
        - ADAM17, F10, F2
        ### Medium quality
        - BACE1, CASP3, MMP13
        ### Lower quality
        - DPP4
        ## GPCR
        ### Medium quality
        - ADRB1, ADRB2, DRD2, DRD3
        ### Lower quality
        - ADORA2A
        ## Cytochrome
        ### Medium quality
        - CYP2C9, CYP3A4
        ## Chaperone
        ### Lower quality
        - HSP90AA1
        ''')

    with gr.Row():
        # Left column: user inputs and the submit button.
        with gr.Column():
            smiles = gr.Textbox(
                label="Molecule SMILES of interest (optional): ",
                placeholder='none',
            )
            protein = gr.Dropdown(dudes, label="Protein name (see options): ")
            task = gr.Textbox(label="Task for Agent: ")
            # add variables as needed
            calc_btn = gr.Button(value="Submit to Agent")
        # Right column: the agent's text answer and the image.
        with gr.Column():
            props = gr.Textbox(label="Agent results: ", lines=20)
            pic = gr.Image(label="Molecule")

    # The button click and pressing Enter in the task box both run the agent.
    calc_btn.click(DockAgent, inputs=[task, smiles, protein], outputs=[props, pic])
    task.submit(DockAgent, inputs=[task, smiles, protein], outputs=[props, pic])

forest.launch(debug=False, mcp_server=True)