cafierom committed on
Commit
2b31e2e
·
verified ·
1 Parent(s): 04f1733

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -1
app.py CHANGED
@@ -29,7 +29,6 @@ import gradio as gr
29
  from PIL import Image
30
 
31
  from chem_nodes import *
32
- from agent_nodes import *
33
 
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
 
@@ -40,6 +39,252 @@ hf = HuggingFacePipeline.from_model_id(
40
 
41
  chat_model = ChatHuggingFace(llm=hf)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  builder = StateGraph(State)
44
  builder.add_node("first_node", first_node)
45
  builder.add_node("retry_node", retry_node)
 
29
  from PIL import Image
30
 
31
  from chem_nodes import *
 
32
 
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
 
39
 
40
  chat_model = ChatHuggingFace(llm=hf)
41
 
42
def first_node(state: State) -> State:
    '''
    The first node of the agent. This node receives the input and asks the LLM
    to determine which is the best tool to use to answer the QUERY TASK.

    Input: the initial prompt from the user (the content of the last message
    in state["messages"]). It should contain one or more of the following
    fields:
        query_smiles: the smiles string, query_task: the query task,
        query_name: the molecule name, query_reference: the reference smiles
    The value should be separated from the field name by a ':' and each field
    should be separated from the previous one by a ','. The literal value
    "None" (any case) for query_smiles or query_name is stored as None.
    All of these values are saved to the state.

    Output: the state, with state["tool_choice"] set to a 2-tuple of tool
    names (None for an unused slot), or to None when the LLM reply could not
    be interpreted as one or two tools. state["which_tool"] is reset to 0.
    '''
    # Reset every per-query field so nothing carries over from an earlier query.
    query_smiles = None
    query_task = None
    query_name = None
    query_reference = None
    state["query_smiles"] = query_smiles
    state["query_task"] = query_task
    state["query_name"] = query_name
    state["query_reference"] = query_reference
    state['similars_img'] = None
    state["props_string"] = ""
    state["loop_again"] = None

    raw_input = state["messages"][-1].content
    # NOTE: fields are comma separated, so a value that itself contains a ','
    # is cut at the first comma.
    for part in raw_input.split(','):
        # partition() keeps any further ':' inside the value and never raises
        # when a fragment has no ':' at all.
        key, sep, value = part.partition(':')
        if not sep:
            continue
        value = value.strip()
        # Only the text before the ':' is inspected, so a task that merely
        # mentions another field name does not overwrite that field.
        if 'query_smiles' in key:
            query_smiles = None if value.lower() == 'none' else value
            state["query_smiles"] = query_smiles
        if 'query_task' in key:
            query_task = value
            state["query_task"] = query_task
        if 'query_name' in key:
            query_name = None if value.lower() == 'none' else value
            state["query_name"] = query_name
        if 'query_reference' in key:
            query_reference = value
            state["query_reference"] = query_reference

    # The continuation lines start at column 0 on purpose: with a backslash
    # continuation any leading indentation would become part of the prompt.
    prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
can complete the task. If so, reply with only the tool names followed by "#". If two tools \
are required, reply with both tool names separated by a comma and followed by "#". \
If the tools cannot complete the task, reply with "None #".\n \
QUERY_TASK: {query_task}.\n \
The information provided by the user is:\n \
QUERY_SMILES: {query_smiles}.\n \
QUERY_NAME: {query_name}.\n \
Tools: \n \
smiles_tool: queries Pubchem for the smiles string of the molecule based on the name.\n \
name_tool: queries Pubchem for the NAME of the molecule based on the smiles string.\n \
similars_tool: queries Pubchem for similar molecules based on the smiles string or name and returns 20 results. \
returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n \
'

    res = chat_model.invoke(prompt)

    # The reply is expected to look like "...<|assistant|> tool_a, tool_b # ...".
    segments = str(res).split('<|assistant|>')
    if len(segments) < 2:
        # No assistant marker: nothing can be parsed, treat as "no tool".
        tool_choice = None
    else:
        names = [name.strip().lower() for name in segments[1].split('#')[0].split(',')]
        if len(names) > 2:
            # More than two tools were named; the agent supports at most two.
            tool_choice = None
        else:
            # Pad to two slots, then map "none"/empty entries to None.
            names += [''] * (2 - len(names))
            tool_choice = tuple(None if name in ('', 'none') else name for name in names)

    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are: {tool_choice}")

    return state
127
+
128
def retry_node(state: State) -> State:
    '''
    If the previous loop of the agent does not get enough information from the
    tools to answer the query, this node is called to retry the previous loop.

    Input: the state left by the previous loop; reads state["query_task"],
    state["query_smiles"] and state["query_name"].
    Output: the state, with state["tool_choice"] set to a 2-tuple of tool
    names (None for an unused slot), or to None when the LLM reply could not
    be interpreted as one or two tools. state["which_tool"] is reset to 0.
    '''
    query_task = state["query_task"]
    query_smiles = state["query_smiles"]
    query_name = state["query_name"]

    # The continuation lines start at column 0 on purpose: with a backslash
    # continuation any leading indentation would become part of the prompt.
    prompt = f'You were previously given the QUERY_TASK below, and asked to determine if one \
or two of the tools descibed below could complete the task. TYou tool choices did not succeed. \
Please re-examine the tool choices and determine if one or two of the tools descibed below \
can complete the task. If so, reply with only the tool names followed by "#". If two tools \
are required, reply with both tool names separated by a comma and followed by "#". \
If the tools cannot complete the task, reply with "None #".\n \
QUERY_TASK: {query_task}.\n \
The information provided by the user is:\n \
QUERY_SMILES: {query_smiles}.\n \
QUERY_NAME: {query_name}.\n \
Tools: \n \
smiles_tool: queries Pubchem for the smiles string of the molecule based on the name as input.\n \
name_tool: queries Pubchem for the NAME (IUPAC) of the molecule based on the smiles string as input. \
Also returns a short list of common names for the molecule. \n \
similars_tool: queries Pubchem for similar molecules based on the smiles string or name as input and returns 20 results. \
Returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n \
'

    res = chat_model.invoke(prompt)

    # The reply is expected to look like "...<|assistant|> tool_a, tool_b # ...".
    segments = str(res).split('<|assistant|>')
    if len(segments) < 2:
        # No assistant marker: nothing can be parsed, treat as "no tool".
        tool_choice = None
    else:
        names = [name.strip().lower() for name in segments[1].split('#')[0].split(',')]
        if len(names) > 2:
            # More than two tools were named; the agent supports at most two.
            tool_choice = None
        else:
            # Pad to two slots, then map "none"/empty entries to None.
            names += [''] * (2 - len(names))
            tool_choice = tuple(None if name in ('', 'none') else name for name in names)

    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are (Retry): {tool_choice}")

    return state
183
+
184
def loop_node(state: State) -> State:
    '''
    Pass-through node placed after the tool calls.

    The node itself performs no work and returns the state unchanged.
    Presumably the choice between calling another tool and moving on to the
    parser node is made by the routing attached to this node elsewhere in the
    graph definition (not visible here).
    Input: the state after the tool returns.
    Output: the same state, unmodified.
    '''
    return state
192
+
193
def parser_node(state: State) -> State:
    '''
    This is the third node in the agent. It receives the output from the tool,
    puts it into a prompt as CONTEXT, and asks the LLM to answer the original
    query.

    The LLM is first asked whether the CONTEXT is sufficient. If it answers
    "LOOP", state["loop_again"] is set to "loop_again" and the state is
    returned without answering. Any other verdict ("PROCEED" or a reply that
    cannot be parsed) clears the flag and produces the answer, so a stale
    "loop_again" from an earlier pass cannot survive.

    Input: the output from the tool (state["props_string"]) and the task
    (state["query_task"]).
    Output: the answer to the original query as a "messages" update, together
    with the cleared "loop_again" flag; or the full state when looping.
    '''
    props_string = state["props_string"]
    query_task = state["query_task"]

    # The continuation lines start at column 0 on purpose: with a backslash
    # continuation any leading indentation would become part of the prompt.
    check_prompt = f'Determine if there is enough CONTEXT below to answer the original \
QUERY TASK. If there is, respond with "PROCEED #" . If there is not enough information \
to answer the QUERY TASK, respond with "LOOP #" \n \
CONTEXT: {props_string}.\n \
QUERY_TASK: {query_task}.\n'

    res = chat_model.invoke(check_prompt)

    # Parse the verdict once; a reply without the assistant marker yields ''.
    segments = str(res).split('<|assistant|>')
    verdict = segments[1].split('#')[0].strip().lower() if len(segments) > 1 else ''

    if verdict == "loop":
        state["loop_again"] = "loop_again"
        return state

    # "proceed" or anything unrecognised: answer with what is available.
    state["loop_again"] = None

    prompt = f'Using the CONTEXT below, answer the original query, which \
was to answer the QUERY_TASK. End your answer with a "#" \
QUERY_TASK: {query_task}.\n \
CONTEXT: {props_string}.\n '

    res = chat_model.invoke(prompt)
    # Include the cleared flag in the returned update: only the returned
    # mapping is guaranteed to be written back to the graph state.
    return {"messages": res, "loop_again": None}
227
+
228
def reflect_node(state: State) -> State:
    '''
    Fourth node of the agent: asks the LLM to refine its own last answer.

    Reads the content of the most recent message and the accumulated tool
    output (state["props_string"]), and prompts the LLM for an enriched
    version of that answer.
    Input: the LLM's last answer.
    Output: a "messages" update holding the improved answer.
    '''
    props_string = state["props_string"]
    previous_answer = state["messages"][-1].content

    # The continuation lines start at column 0 on purpose: with a backslash
    # continuation any leading indentation would become part of the prompt.
    prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
TOOL RESULTS by adding additional clarifying and enriching information. End \
your new answer with a "#" \
PREVIOUS ANSWER: {previous_answer}.\n \
TOOL RESULTS: {props_string}. '

    return {"messages": chat_model.invoke(prompt)}
247
+
248
def get_chemtool(state):
    '''
    Select the tool to run next from the tools chosen by the LLM.

    Reads state["tool_choice"] (None, or a pair of tool names) and
    state["which_tool"] (the index of the slot to use). If the selected tool
    needs an input the user did not supply, the complementary lookup tool is
    used instead: smiles_tool needs a name, name_tool needs a smiles string.
    An input counts as missing when its key is absent or its value is None.

    Input: the agent state.
    Output: the name of the tool to call, or None when no tool was chosen or
    which_tool does not point at one of the two slots.
    '''
    which_tool = state["which_tool"]
    tool_choice = state["tool_choice"]

    if tool_choice is None:
        return None
    # Only slots 0 and 1 exist; any other index means there is nothing to run.
    if which_tool not in (0, 1):
        return None

    current_tool = tool_choice[which_tool]
    if current_tool == "smiles_tool" and state.get("query_name") is None:
        current_tool = "name_tool"
        print("Switching from smiles tool to name tool")
    elif current_tool == "name_tool" and state.get("query_smiles") is None:
        current_tool = "smiles_tool"
        print("Switching from name tool to smiles tool")

    return current_tool
269
+
270
def loop_or_not(state):
    '''
    Report whether the agent should run another tool loop.

    Prints the current value of state["loop_again"] and returns True only
    when that value is exactly "loop_again"; otherwise returns False.
    '''
    print(f"Loop? {state['loop_again']}")
    return state["loop_again"] == "loop_again"
278
+
279
def pretty_print(answer):
    '''
    Print the final LLM answer in rows of at most 100 characters.

    The last entry of answer['messages'] is converted with str(); the text
    after the last '<|assistant|>' marker and before the first '#' is kept.
    Surrounding whitespace and literal backslash-n sequences (escaped
    newlines that appear when a message object is stringified) are removed
    from both ends. Only whole backslash-n pairs are removed, so an answer
    that genuinely starts or ends with the letter "n" is left intact.

    Input: a mapping with a non-empty 'messages' sequence.
    Output: None; the text is written to stdout.
    '''
    final = str(answer['messages'][-1]).split('<|assistant|>')[-1].split('#')[0].strip()
    # Peel off escaped newlines one pair at a time from each end.
    while final.startswith('\\n'):
        final = final[2:].lstrip()
    while final.endswith('\\n'):
        final = final[:-2].rstrip()
    for i in range(0, len(final), 100):
        print(final[i:i+100])
283
+
284
def print_short(answer):
    '''
    Print answer in consecutive slices of at most 100 items, one per line.

    Nothing is printed for an empty answer.
    '''
    width = 100
    start = 0
    total = len(answer)
    while start < total:
        print(answer[start:start + width])
        start += width
287
+
288
  builder = StateGraph(State)
289
  builder.add_node("first_node", first_node)
290
  builder.add_node("retry_node", retry_node)