File size: 10,128 Bytes
acf312d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cfdc46
04f1733
acf312d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a7e38
acf312d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import torch
from typing import Annotated, TypedDict, Literal
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall

from langchain_huggingface.llms import HuggingFacePipeline
from langchain_huggingface import ChatHuggingFace
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.runnables import chain
from uuid import uuid4
import re
import matplotlib.pyplot as plt
from chem_nodes import *
from app import chat_model

import gradio as gr
from PIL import Image

def first_node(state: State) -> State:
  '''
    The first node of the agent. Parses the user's comma-separated
    "field: value" input into the state, then asks the LLM which of the
    available tools can answer the QUERY TASK.

      Input: the initial prompt from the user. Should contain one or more of:

             query_smiles: the smiles string, query_task: the query task,
             query_name: the molecule name, query_reference: the reference smiles

             Each value is separated from its field name by ':' and each field
             is separated from the previous one by ','. All parsed values are
             saved to the state; a literal 'none' value is stored as None.

      Output: the state, with state["tool_choice"] set to a 2-tuple of tool
              names (None in an unused slot) or None when no tool applies,
              and state["which_tool"] reset to 0.
  '''
  # Reset every per-run field so stale values from a previous run cannot leak.
  state["query_smiles"] = None
  state["query_task"] = None
  state["query_name"] = None
  state["query_reference"] = None
  state['similars_img'] = None
  state["props_string"] = ""
  state["loop_again"] = None

  raw_input = state["messages"][-1].content
  for part in raw_input.split(','):
    if ':' not in part:
      # Malformed fragment with no separator; the old split(':')[1] raised
      # IndexError here — skip it instead.
      continue
    field, _, value = part.partition(':')
    # partition keeps any ':' inside the value itself, and stripping makes the
    # 'none' sentinel comparison reliable (the input has a space after ':').
    value = value.strip()
    if 'query_smiles' in field:
      state["query_smiles"] = None if value.lower() == 'none' else value
    elif 'query_task' in field:
      state["query_task"] = value
    elif 'query_name' in field:
      state["query_name"] = None if value.lower() == 'none' else value
    elif 'query_reference' in field:
      state["query_reference"] = value

  query_smiles = state["query_smiles"]
  query_task = state["query_task"]
  query_name = state["query_name"]

  prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
can complete the task. If so, reply with only the tool names followed by "#". If two tools \
are required, reply with both tool names separated by a comma and followed by "#". \
If the tools cannot complete the task, reply with "None #".\n \
QUERY_TASK: {query_task}.\n \
The information provided by the user is:\n \
QUERY_SMILES: {query_smiles}.\n \
QUERY_NAME: {query_name}.\n \
Tools: \n \
smiles_tool: queries Pubchem for the smiles string of the molecule based on the name.\n \
name_tool: queries Pubchem for the NAME of the molecule based on the smiles string.\n \
similars_tool: queries Pubchem for similar molecules based on the smiles string or name and returns 20 results. \
returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n \
'

  res = chat_model.invoke(prompt)

  # The chat template echoes the prompt; keep only the assistant reply up to
  # its '#' terminator, then parse the comma-separated tool names.
  # NOTE(review): assumes the model output always contains '<|assistant|>' —
  # confirm against the chat template used by chat_model.
  reply = str(res).split('<|assistant|>')[1].split('#')[0].strip()
  names = [t.strip().lower() for t in reply.split(',')]
  names = [None if t == 'none' else t for t in names]
  if len(names) == 1:
    tool_choice = (names[0], None)
  elif len(names) == 2:
    # Unlike the original, a 'none,none' reply now yields (None, None)
    # rather than the inconsistent (None, 'none').
    tool_choice = (names[0], names[1])
  else:
    tool_choice = None

  state["tool_choice"] = tool_choice
  state["which_tool"] = 0
  print(f"The chosen tools are: {tool_choice}")

  return state

def retry_node(state: State) -> State:
  '''
    Called when the previous loop of the agent did not get enough information
    from the tools to answer the query. Re-prompts the LLM to pick tools again.

      Input: the state left by the previous loop of the agent
             (query_task / query_smiles / query_name already populated).

      Output: the state, with a fresh state["tool_choice"] 2-tuple (None in an
              unused slot, or None overall) and state["which_tool"] reset to 0.
  '''
  query_task = state["query_task"]
  query_smiles = state["query_smiles"]
  query_name = state["query_name"]

  # "TYou tool choices" in the original was garbled text sent to the model;
  # corrected to "Your tool choices".
  prompt = f'You were previously given the QUERY_TASK below, and asked to determine if one \
or two of the tools descibed below could complete the task. Your tool choices did not succeed. \
Please re-examine the tool choices and determine if one or two of the tools descibed below \
can complete the task. If so, reply with only the tool names followed by "#". If two tools \
are required, reply with both tool names separated by a comma and followed by "#". \
If the tools cannot complete the task, reply with "None #".\n \
QUERY_TASK: {query_task}.\n \
The information provided by the user is:\n \
QUERY_SMILES: {query_smiles}.\n \
QUERY_NAME: {query_name}.\n \
Tools: \n \
smiles_tool: queries Pubchem for the smiles string of the molecule based on the name as input.\n \
name_tool: queries Pubchem for the NAME (IUPAC) of the molecule based on the smiles string as input. \
Also returns a short list of common names for the molecule. \n \
similars_tool: queries Pubchem for similar molecules based on the smiles string or name as input and returns 20 results. \
Returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n \
'

  res = chat_model.invoke(prompt)

  # Keep only the assistant reply up to its '#' terminator, then parse the
  # comma-separated tool names — same scheme as the first node.
  reply = str(res).split('<|assistant|>')[1].split('#')[0].strip()
  names = [t.strip().lower() for t in reply.split(',')]
  names = [None if t == 'none' else t for t in names]
  if len(names) == 1:
    tool_choice = (names[0], None)
  elif len(names) == 2:
    tool_choice = (names[0], names[1])
  else:
    # Three or more names: give up. (The original's extra "'none' in ..."
    # branch was dead code — it returned None just like this else.)
    tool_choice = None

  state["tool_choice"] = tool_choice
  state["which_tool"] = 0
  print(f"The chosen tools are (Retry): {tool_choice}")

  return state

def loop_node(state: State) -> State:
  '''
    Pass-through decision point after the tool calls.

    The actual choice between calling another tool and moving on to the
    parser node is made by the graph's conditional edges; this node simply
    forwards the accumulated state unchanged.

      Input: the state carrying the tool returns.
      Output: the same state.
  '''
  return state

def parser_node(state: State) -> State:
  '''
    Third node in the agent. Receives the accumulated tool output
    (state["props_string"]), asks the LLM whether it is enough CONTEXT to
    answer the original query, and either flags another loop or produces
    the final answer.

      Input: the state carrying the tool output in "props_string".
      Output: the state with state["loop_again"] set when more tool calls are
              needed, or {"messages": <answer>} with the LLM's final answer.
  '''
  props_string = state["props_string"]
  query_task = state["query_task"]

  check_prompt = f'Determine if there is enough CONTEXT below to answer the original \
QUERY TASK. If there is, respond with "PROCEED #" . If there is not enough information \
to answer the QUERY TASK, respond with "LOOP #" \n \
CONTEXT: {props_string}.\n \
QUERY_TASK: {query_task}.\n'

  res = chat_model.invoke(check_prompt)
  # Hoisted: the original evaluated this split chain separately in both
  # branches. Keep only the assistant reply up to its '#' terminator.
  verdict = str(res).split('<|assistant|>')[1].split('#')[0].strip().lower()
  if verdict == "loop":
    state["loop_again"] = "loop_again"
    return state
  elif verdict == "proceed":
    state["loop_again"] = None

    prompt = f'Using the CONTEXT below, answer the original query, which \
was to answer the QUERY_TASK. End your answer with a "#" \
QUERY_TASK: {query_task}.\n \
CONTEXT: {props_string}.\n '

    res = chat_model.invoke(prompt)
    return {"messages": res}
  # NOTE(review): any other verdict falls through and returns None, matching
  # the original's implicit behavior — confirm LangGraph tolerates a None
  # state update here, or add an explicit fallback.

def reflect_node(state: State) -> State:
  '''
    Fourth node of the agent: a reflection pass.

    Feeds the LLM its own previous answer together with the accumulated tool
    results and asks it to produce an enriched, improved answer.

      Input: the state whose last message is the LLM's previous answer.
      Output: {"messages": <improved answer>}.
  '''
  prior_answer = state["messages"][-1].content
  tool_results = state["props_string"]

  reflection_prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
TOOL RESULTS by adding additional clarifying and enriching information. End \
your new answer with a "#" \
PREVIOUS ANSWER: {prior_answer}.\n \
TOOL RESULTS: {tool_results}. '

  improved = chat_model.invoke(reflection_prompt)
  return {"messages": improved}

def get_chemtool(state):
  '''
    Conditional-edge selector: pick the tool to run for the current step.

    Looks up state["tool_choice"][state["which_tool"]] and, when the chosen
    tool's required input is unavailable, swaps smiles_tool <-> name_tool
    (smiles_tool needs a name; name_tool needs a smiles string).

      Input: the state with "tool_choice" and "which_tool" populated.
      Output: the tool name to run, or None when no tool applies or both
              tool slots have already been consumed.
  '''
  which_tool = state["which_tool"]
  tool_choice = state["tool_choice"]
  if tool_choice is None:  # `is None`, not `== None`
    return None
  if which_tool not in (0, 1):
    # Both slots consumed (the original also left current_tool unbound for a
    # negative index, raising UnboundLocalError; now it returns None).
    return None

  current_tool = tool_choice[which_tool]
  # Treat an explicit None value the same as a missing key: first_node always
  # stores the key (possibly as None), so the original `not in state.keys()`
  # check never detected a missing input.
  if current_tool == "smiles_tool" and state.get("query_name") is None:
    current_tool = "name_tool"
    print("Switching from smiles tool to name tool")
  elif current_tool == "name_tool" and state.get("query_smiles") is None:
    current_tool = "smiles_tool"
    print("Switching from name tool to smiles tool")

  return current_tool

def loop_or_not(state):
  '''
    Conditional-edge predicate: True when the parser flagged the state for
    another tool loop, False otherwise.
  '''
  print(f"Loop? {state['loop_again']}")
  return state["loop_again"] == "loop_again"

def pretty_print(answer):
  '''
    Print the agent's final answer wrapped to 100-character lines.

    Takes the text after the last '<|assistant|>' marker and before the '#'
    terminator, then removes surrounding whitespace and literal "\\n" escape
    residue left by str() of the message object.

      Input: the final graph state (dict with a "messages" list).
      Output: None (prints to stdout).
  '''
  final = str(answer['messages'][-1]).split('<|assistant|>')[-1].split('#')[0]
  # The original chained .strip("n").strip('\\') calls, which also deleted
  # legitimate leading/trailing 'n' characters (e.g. "...carbon" -> "...carbo").
  # Only remove whitespace and literal backslash-n pairs at the edges.
  final = final.strip()
  while final.startswith('\\n'):
    final = final[2:].strip()
  while final.endswith('\\n'):
    final = final[:-2].strip()
  for i in range(0, len(final), 100):
    print(final[i:i+100])

def print_short(answer):
  '''
    Print *answer* to stdout, wrapped into chunks of 100 characters.
  '''
  width = 100
  for start in range(0, len(answer), width):
    print(answer[start:start + width])