cafierom committed on
Commit
acf312d
·
verified ·
1 Parent(s): 7c995c8

Create agent_nodes.py

Browse files
Files changed (1) hide show
  1. agent_nodes.py +275 -0
agent_nodes.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Annotated, TypedDict, Literal
3
+ from langchain_community.tools import DuckDuckGoSearchRun
4
+ from langchain_core.tools import tool
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langgraph.graph import StateGraph, START, END
7
+ from langgraph.graph.message import add_messages
8
+ from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
9
+
10
+ from langchain_huggingface.llms import HuggingFacePipeline
11
+ from langchain_huggingface import ChatHuggingFace
12
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
13
+ from langchain_core.runnables import chain
14
+ from uuid import uuid4
15
+ import re
16
+ import matplotlib.pyplot as plt
17
+
18
+ import gradio as gr
19
+ from PIL import Image
20
+
21
def first_node(state: State) -> State:
    '''
    Entry node of the agent: parse the user request and ask the LLM which
    tools should be used to answer the QUERY TASK.

    Input: the content of the last message in state["messages"], written as
    comma-separated "name: value" fields. The recognised field names are
    query_smiles, query_task, query_name and query_reference. For
    query_smiles and query_name the text "none" (any case) is stored as None.

    Parsing rules:
      * the value is everything after the FIRST ':' of a field, with
        surrounding whitespace removed;
      * a piece without ':' is treated as the continuation of the previous
        recognised field, so values that themselves contain commas
        (e.g. "2,4-dinitrophenol") are kept whole;
      * fields with an unrecognised name are ignored.

    Output: the same state object with the query fields, props_string,
    similars_img and loop_again reset/refilled, which_tool set to 0 and
    tool_choice set to a (first, second) tuple of lower-cased tool names
    (None for "no tool"), or to None when the reply lists more than two names.
    '''
    field_names = ("query_smiles", "query_task", "query_name", "query_reference")
    # Only these fields interpret the literal text "none" as "not provided".
    none_aware = ("query_smiles", "query_name")

    # Reset everything first so nothing leaks in from a previous query.
    for name in field_names:
        state[name] = None
    state['similars_img'] = None
    state["props_string"] = ""
    state["loop_again"] = None

    raw_input = state["messages"][-1].content

    collected = {}
    open_fields = []  # fields that the most recent "name: value" piece filled
    for part in raw_input.split(','):
        key, sep, value = part.partition(':')
        if sep:
            # Match on the text before the colon only, so a field name that
            # merely appears inside a value cannot start a new field.
            open_fields = [name for name in field_names if name in key]
            for name in open_fields:
                collected[name] = value
        elif part.strip():
            # No colon: this piece belongs to the value that was split apart
            # by the comma, so glue it back on.
            for name in open_fields:
                collected[name] += ',' + part

    for name, value in collected.items():
        value = value.strip()
        if name in none_aware and value.lower() == 'none':
            value = None
        state[name] = value

    query_task = state["query_task"]
    query_smiles = state["query_smiles"]
    query_name = state["query_name"]

    prompt = (
        'For the QUERY_TASK given below, determine if one or two of the tools descibed below '
        'can complete the task. If so, reply with only the tool names followed by "#". If two tools '
        'are required, reply with both tool names separated by a comma and followed by "#". '
        'If the tools cannot complete the task, reply with "None #".\n '
        f'QUERY_TASK: {query_task}.\n '
        'The information provided by the user is:\n '
        f'QUERY_SMILES: {query_smiles}.\n '
        f'QUERY_NAME: {query_name}.\n '
        'Tools: \n '
        'smiles_tool: queries Pubchem for the smiles string of the molecule based on the name.\n '
        'name_tool: queries Pubchem for the NAME of the molecule based on the smiles string.\n '
        'similars_tool: queries Pubchem for similar molecules based on the smiles string or name and returns 20 results. '
        'returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n '
    )

    res = chat_model.invoke(prompt)

    # The reply is the text after the assistant marker; if the marker is
    # missing, fall back to the message content instead of raising IndexError.
    segments = str(res).split('<|assistant|>')
    reply = segments[1] if len(segments) > 1 else str(getattr(res, "content", res))

    names = [name.strip().lower() for name in reply.split('#')[0].split(',')]
    # "none" (or an empty reply) means "no tool" in either position.
    names = [None if name in ('', 'none') else name for name in names]
    if len(names) == 1:
        tool_choice = (names[0], None)
    elif len(names) == 2:
        tool_choice = (names[0], names[1])
    else:
        # More than two names cannot be mapped onto the two tool slots.
        tool_choice = None

    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are: {tool_choice}")

    return state
111
+
112
def retry_node(state: State) -> State:
    '''
    Ask the LLM to pick the tools again after a previous pass did not gather
    enough information to answer the query.

    Input: state["query_task"], state["query_smiles"] and state["query_name"]
    as stored by an earlier node.

    Output: the same state object with which_tool reset to 0 and tool_choice
    set to a (first, second) tuple of lower-cased tool names (None for
    "no tool"), or to None when the reply lists more than two names.
    '''
    query_task = state["query_task"]
    query_smiles = state["query_smiles"]
    query_name = state["query_name"]

    prompt = (
        'You were previously given the QUERY_TASK below, and asked to determine if one '
        'or two of the tools descibed below could complete the task. TYou tool choices did not succeed. '
        'Please re-examine the tool choices and determine if one or two of the tools descibed below '
        'can complete the task. If so, reply with only the tool names followed by "#". If two tools '
        'are required, reply with both tool names separated by a comma and followed by "#". '
        'If the tools cannot complete the task, reply with "None #".\n '
        f'QUERY_TASK: {query_task}.\n '
        'The information provided by the user is:\n '
        f'QUERY_SMILES: {query_smiles}.\n '
        f'QUERY_NAME: {query_name}.\n '
        'Tools: \n '
        'smiles_tool: queries Pubchem for the smiles string of the molecule based on the name as input.\n '
        'name_tool: queries Pubchem for the NAME (IUPAC) of the molecule based on the smiles string as input. '
        'Also returns a short list of common names for the molecule. \n '
        'similars_tool: queries Pubchem for similar molecules based on the smiles string or name as input and returns 20 results. '
        'Returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n '
    )

    res = chat_model.invoke(prompt)

    # The reply is the text after the assistant marker; if the marker is
    # missing, fall back to the message content instead of raising IndexError.
    segments = str(res).split('<|assistant|>')
    reply = segments[1] if len(segments) > 1 else str(getattr(res, "content", res))

    names = [name.strip().lower() for name in reply.split('#')[0].split(',')]
    # "none" (or an empty reply) means "no tool" in either position.
    names = [None if name in ('', 'none') else name for name in names]
    if len(names) == 1:
        tool_choice = (names[0], None)
    elif len(names) == 2:
        tool_choice = (names[0], names[1])
    else:
        # More than two names cannot be mapped onto the two tool slots.
        tool_choice = None

    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are (Retry): {tool_choice}")

    return state
169
+
170
def loop_node(state: State) -> State:
    '''
    Pass-through node placed after the tools.

    The function itself makes no decision and changes nothing: it hands the
    state back untouched. Whatever chooses between "call another tool" and
    "go on to the parser node" must therefore live outside this function
    (presumably in the graph's routing -- confirm where the graph is built).

    Input: the state as left by the previous node.
    Output: the very same state object.
    '''
    return state
179
+
180
def parser_node(state: State) -> State:
    '''
    Third node of the agent: decide whether the tool output is sufficient and,
    if so, ask the LLM to answer the original query from it.

    Input: state["props_string"] (the collected tool output, used as CONTEXT)
    and state["query_task"].

    Output:
      * when the LLM answers "LOOP": the state object with loop_again set to
        "loop_again" (no answer is generated);
      * otherwise: an update dict whose "messages" entry is the LLM answer.
        When the LLM explicitly answered "PROCEED" the update also carries
        loop_again = None, so that a flag set by an earlier pass is cleared.
    '''
    props_string = state["props_string"]
    query_task = state["query_task"]

    check_prompt = (
        'Determine if there is enough CONTEXT below to answer the original '
        'QUERY TASK. If there is, respond with "PROCEED #" . If there is not enough information '
        'to answer the QUERY TASK, respond with "LOOP #" \n '
        f'CONTEXT: {props_string}.\n '
        f'QUERY_TASK: {query_task}.\n'
    )

    res = chat_model.invoke(check_prompt)

    # Parse the verdict once. If the assistant marker is missing, fall back to
    # the message content instead of raising IndexError.
    segments = str(res).split('<|assistant|>')
    reply = segments[1] if len(segments) > 1 else str(getattr(res, "content", res))
    verdict = reply.split('#')[0].strip().lower()

    if verdict == "loop":
        state["loop_again"] = "loop_again"
        return state

    update = {}
    if verdict == "proceed":
        state["loop_again"] = None
        # Only the returned mapping is handed back to the caller, so the
        # cleared flag has to be part of it; mutating the input alone would
        # leave a "loop_again" from an earlier pass in place.
        update["loop_again"] = None

    prompt = (
        'Using the CONTEXT below, answer the original query, which '
        'was to answer the QUERY_TASK. End your answer with a "#" '
        f'QUERY_TASK: {query_task}.\n '
        f'CONTEXT: {props_string}.\n '
    )

    res = chat_model.invoke(prompt)
    update["messages"] = res
    return update
215
+
216
def reflect_node(state: State) -> State:
    '''
    Fourth node of the agent: hand the LLM its own last answer together with
    the tool results and ask for an improved, enriched answer.

    Input: the content of the last message in state["messages"] and
    state["props_string"].
    Output: an update dict whose "messages" entry is the LLM response.
    '''
    tool_results = state["props_string"]
    earlier_answer = state["messages"][-1].content

    instructions = (
        'Look at the PREVIOUS ANSWER below which you provided and the '
        'TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the '
        'TOOL RESULTS by adding additional clarifying and enriching information. End '
        'your new answer with a "#" '
    )
    request = (
        instructions
        + f'PREVIOUS ANSWER: {earlier_answer}.\n '
        + f'TOOL RESULTS: {tool_results}. '
    )

    return {"messages": chat_model.invoke(request)}
236
+
237
def get_chemtool(state):
    '''
    Pick the tool that should run next.

    Input: state["tool_choice"] (None, or a pair of tool names where an entry
    may be None) and state["which_tool"] (index of the entry to use).

    Output: the selected tool name, or None when no tool was chosen or the
    index is not 0 or 1.

    If the selected tool needs an input that is not available, the
    complementary tool is chosen instead: smiles_tool becomes name_tool when
    there is no query_name, and name_tool becomes smiles_tool when there is
    no query_smiles. "Not available" covers both a missing key and a value of
    None, because first_node always creates both keys and stores None for
    absent values.
    '''
    which_tool = state["which_tool"]
    tool_choice = state["tool_choice"]
    if tool_choice is None:
        return None

    # Only two tool slots exist; any other index means there is nothing left
    # to run (this also avoids an unbound result for negative indices).
    if which_tool not in (0, 1):
        return None

    current_tool = tool_choice[which_tool]
    if current_tool == "smiles_tool" and state.get("query_name") is None:
        current_tool = "name_tool"
        print("Switching from smiles tool to name tool")
    elif current_tool == "name_tool" and state.get("query_smiles") is None:
        current_tool = "smiles_tool"
        print("Switching from name tool to smiles tool")

    return current_tool
258
+
259
def loop_or_not(state):
    '''
    Report whether the agent should run another loop.

    Input: state["loop_again"].
    Output: True when the flag equals "loop_again", False otherwise. The
    current flag value is also printed.
    '''
    loop_again = state["loop_again"]
    # Read the flag into a local first: reusing double quotes inside a
    # double-quoted f-string is a syntax error before Python 3.12.
    print(f"Loop? {loop_again}")
    return loop_again == "loop_again"
267
+
268
def pretty_print(answer):
    '''
    Print the final answer of the agent in lines of at most 100 characters.

    Input: a mapping whose "messages" entry is a sequence; the last element is
    converted with str().
    Output: nothing is returned. The printed text is the part after the last
    "<|assistant|>" marker and before the first "#", with surrounding
    whitespace, stray backslashes and literal backslash-n sequences removed.
    '''
    text = str(answer['messages'][-1])
    final = text.split('<|assistant|>')[-1].split('#')[0]
    # Trim escaped newlines as whole "\n" units. Stripping the characters
    # "n" and "\" separately would also eat real letters, turning
    # "nitrogen" into "itrogen".
    final = re.sub(r'^(?:\s|\\n?)+|(?:\s|\\n?)+$', '', final)
    for start in range(0, len(final), 100):
        print(final[start:start + 100])
272
+
273
def print_short(answer):
    '''
    Print answer in consecutive slices of at most 100 characters, one slice
    per line. Nothing is printed for an empty answer.
    '''
    width = 100
    start = 0
    total = len(answer)
    while start < total:
        print(answer[start:start + width])
        start += width