# MoleculeAgent / agent_nodes.py
# Author: cafierom
# Last change: "Update agent_nodes.py" (commit 04f1733, verified)
import torch
from typing import Annotated, TypedDict, Literal
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_huggingface import ChatHuggingFace
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.runnables import chain
from uuid import uuid4
import re
import matplotlib.pyplot as plt
from chem_nodes import *
from app import chat_model
import gradio as gr
from PIL import Image
def first_node(state: State) -> State:
    '''
    The first node of the agent. Parses the user's input into the state and
    asks the LLM which tool(s) can answer the QUERY TASK.

    Input: the content of the last message in state["messages"]. It is a
    comma-separated list of "field: value" pairs; the recognized fields are
    query_smiles, query_task, query_name and query_reference. A value of
    "none" (any case) for query_smiles or query_name is stored as None, and
    any field that is not supplied is stored as None. Values are stripped of
    surrounding whitespace and may themselves contain ':'.

    Output: the state with the parsed fields, the per-run bookkeeping reset
    (similars_img=None, props_string="", loop_again=None), "tool_choice" set
    to a (first_tool, second_tool) tuple whose entries may be None (or to
    None if the LLM named more than two tools), and "which_tool" set to 0.
    '''
    state["query_smiles"] = None
    state["query_task"] = None
    state["query_name"] = None
    state["query_reference"] = None
    state["similars_img"] = None
    state["props_string"] = ""
    state["loop_again"] = None

    raw_input = state["messages"][-1].content
    for part in raw_input.split(','):
        # Split on the first ':' only, so values such as SMILES atom maps
        # ("[CH3:1]") are not truncated; parts without ':' carry no field.
        key, sep, value = part.partition(':')
        if not sep:
            continue
        value = value.strip()
        # Match on the field name only, so a task text that merely mentions
        # e.g. "query_smiles" does not overwrite that field.
        if 'query_smiles' in key:
            state["query_smiles"] = None if value.lower() == 'none' else value
        elif 'query_task' in key:
            state["query_task"] = value
        elif 'query_name' in key:
            state["query_name"] = None if value.lower() == 'none' else value
        elif 'query_reference' in key:
            state["query_reference"] = value

    query_task = state["query_task"]
    query_smiles = state["query_smiles"]
    query_name = state["query_name"]
    prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
can complete the task. If so, reply with only the tool names followed by "#". If two tools \
are required, reply with both tool names separated by a comma and followed by "#". \
If the tools cannot complete the task, reply with "None #".\n \
QUERY_TASK: {query_task}.\n \
The information provided by the user is:\n \
QUERY_SMILES: {query_smiles}.\n \
QUERY_NAME: {query_name}.\n \
Tools: \n \
smiles_tool: queries Pubchem for the smiles string of the molecule based on the name.\n \
name_tool: queries Pubchem for the NAME of the molecule based on the smiles string.\n \
similars_tool: queries Pubchem for similar molecules based on the smiles string or name and returns 20 results. \
returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n \
'
    res = chat_model.invoke(prompt)
    # The reply is the text after the last assistant marker (using [-1] avoids
    # an IndexError when the marker is missing), cut at the first '#'.
    reply = str(res).split('<|assistant|>')[-1].split('#')[0].strip()
    choices = [choice.strip().lower() for choice in reply.split(',')]
    if len(choices) > 2:
        tool_choice = None
    else:
        # An empty or "none" slot means "no tool" for that slot; pad to two.
        names = [None if choice in ('', 'none') else choice for choice in choices] + [None]
        tool_choice = (names[0], names[1])
    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are: {tool_choice}")
    return state
def retry_node(state: State) -> State:
    '''
    Re-ask the LLM for a tool choice after the previous pass did not gather
    enough information from the tools to answer the query.

    Input: the state, with "query_task", "query_smiles" and "query_name"
    already present (normally written by first_node).
    Output: the state with "tool_choice" set to a (first_tool, second_tool)
    tuple whose entries may be None (or to None if the LLM named more than
    two tools), and "which_tool" reset to 0.
    '''
    query_task = state["query_task"]
    query_smiles = state["query_smiles"]
    query_name = state["query_name"]
    prompt = f'You were previously given the QUERY_TASK below, and asked to determine if one \
or two of the tools descibed below could complete the task. Your tool choices did not succeed. \
Please re-examine the tool choices and determine if one or two of the tools descibed below \
can complete the task. If so, reply with only the tool names followed by "#". If two tools \
are required, reply with both tool names separated by a comma and followed by "#". \
If the tools cannot complete the task, reply with "None #".\n \
QUERY_TASK: {query_task}.\n \
The information provided by the user is:\n \
QUERY_SMILES: {query_smiles}.\n \
QUERY_NAME: {query_name}.\n \
Tools: \n \
smiles_tool: queries Pubchem for the smiles string of the molecule based on the name as input.\n \
name_tool: queries Pubchem for the NAME (IUPAC) of the molecule based on the smiles string as input. \
Also returns a short list of common names for the molecule. \n \
similars_tool: queries Pubchem for similar molecules based on the smiles string or name as input and returns 20 results. \
Returns the names, SMILES strings, molecular weights and logP values for the similar molecules. \n \
'
    res = chat_model.invoke(prompt)
    # The reply is the text after the last assistant marker (using [-1] avoids
    # an IndexError when the marker is missing), cut at the first '#'.
    reply = str(res).split('<|assistant|>')[-1].split('#')[0].strip()
    choices = [choice.strip().lower() for choice in reply.split(',')]
    if len(choices) > 2:
        tool_choice = None
    else:
        # An empty or "none" slot means "no tool" for that slot; pad to two.
        names = [None if choice in ('', 'none') else choice for choice in choices] + [None]
        tool_choice = (names[0], names[1])
    state["tool_choice"] = tool_choice
    state["which_tool"] = 0
    print(f"The chosen tools are (Retry): {tool_choice}")
    return state
def loop_node(state: State) -> State:
    '''
    Pass-through node placed after the tool calls.

    It does no work of its own and hands back the state object it received,
    unchanged. Whether another tool runs or the parser runs next is
    presumably decided by the conditional edge leaving this node (not visible
    here; confirm in the graph definition).
    '''
    unchanged = state
    return unchanged
def parser_node(state: State) -> State:
    '''
    The third node of the agent. Asks the LLM whether the tool output collected
    so far (state["props_string"]) is enough to answer the QUERY TASK and, if
    so, asks it to answer the task using that output as CONTEXT.

    Input: the state, with "props_string" and "query_task" set.
    Output:
      - if the LLM's verdict starts with "loop": the state with "loop_again"
        set to "loop_again", requesting another pass;
      - otherwise (normally "proceed"): an update holding the LLM's answer
        under "messages" and "loop_again" reset to None.
    '''
    props_string = state["props_string"]
    query_task = state["query_task"]
    check_prompt = f'Determine if there is enough CONTEXT below to answer the original \
QUERY TASK. If there is, respond with "PROCEED #" . If there is not enough information \
to answer the QUERY TASK, respond with "LOOP #" \n \
CONTEXT: {props_string}.\n \
QUERY_TASK: {query_task}.\n'
    res = chat_model.invoke(check_prompt)
    # Verdict: text after the last assistant marker, cut at the first '#'.
    verdict = str(res).split('<|assistant|>')[-1].split('#')[0].strip().lower()
    if verdict.startswith("loop"):
        state["loop_again"] = "loop_again"
        return state
    # Any other verdict (normally "proceed") answers with the context we have;
    # previously an unrecognised verdict returned None and produced no answer.
    prompt = f'Using the CONTEXT below, answer the original query, which \
was to answer the QUERY_TASK. End your answer with a "#" \
QUERY_TASK: {query_task}.\n \
CONTEXT: {props_string}.\n '
    res = chat_model.invoke(prompt)
    # The reset must be part of the returned update: LangGraph applies only
    # what a node returns, so setting state["loop_again"] in place would leave
    # a "loop_again" written by an earlier pass in the graph state.
    return {"messages": res, "loop_again": None}
def reflect_node(state: State) -> State:
    '''
    Described as the fourth node of the agent: asks the LLM to improve its
    previous answer.

    The content of the last message (the previous answer) and
    state["props_string"] (the tool results) are placed in a prompt, and the
    LLM's reply is returned as the "messages" update.
    '''
    last_answer = state["messages"][-1].content
    tool_results = state["props_string"]
    prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
TOOL RESULTS by adding additional clarifying and enriching information. End \
your new answer with a "#" \
PREVIOUS ANSWER: {last_answer}.\n \
TOOL RESULTS: {tool_results}. '
    improved = chat_model.invoke(prompt)
    return {"messages": improved}
def get_chemtool(state):
    '''
    Pick the tool to run on the current pass of the tool loop.

    Reads state["tool_choice"] -- a (first_tool, second_tool) tuple of tool
    names whose entries may be None, or None when no tool was chosen -- and
    state["which_tool"], the index (0 or 1) of the tool to run now.

    A lookup tool whose required input is missing is swapped for its
    counterpart: smiles_tool needs "query_name", name_tool needs
    "query_smiles". A key that is absent or holds None counts as missing.

    Returns the tool name (possibly None for an empty slot), or None when
    there is no tool choice or "which_tool" is not 0 or 1.
    '''
    which_tool = state["which_tool"]
    tool_choice = state["tool_choice"]
    # Any index other than 0/1 means both slots are used up (previously a
    # negative index left current_tool unbound and raised UnboundLocalError).
    if tool_choice is None or which_tool not in (0, 1):
        return None
    current_tool = tool_choice[which_tool]
    # first_node always writes query_name/query_smiles (None when not given),
    # so test the value: a key-presence check could never trigger the swap.
    if current_tool == "smiles_tool" and state.get("query_name") is None:
        current_tool = "name_tool"
        print("Switching from smiles tool to name tool")
    elif current_tool == "name_tool" and state.get("query_smiles") is None:
        current_tool = "smiles_tool"
        print("Switching from name tool to smiles tool")
    return current_tool
def loop_or_not(state):
    '''
    Routing predicate: report whether the parser requested another pass.

    Prints the current "loop_again" value, then returns True exactly when it
    equals the string "loop_again", and False otherwise.
    '''
    flag = state["loop_again"]
    print(f"Loop? {flag}")
    return flag == "loop_again"
def pretty_print(answer):
    '''
    Print the final answer of an agent run in lines of at most 100 characters.

    Takes str() of the last entry of answer['messages'], keeps the text after
    the last "<|assistant|>" marker and before the first "#", then trims
    leading/trailing whitespace and escaped newlines (the two characters
    backslash + "n" that appear when str() escapes a message's newlines).
    '''
    text = str(answer['messages'][-1]).split('<|assistant|>')[-1].split('#')[0]
    # The old .strip("n") chain removed any leading/trailing letter "n"
    # (e.g. "nitrogen" -> "itrogen"); trim only whole "\n" escapes and whitespace.
    final = re.sub(r'^(?:\\n|\s)+|(?:\\n|\s)+$', '', text)
    for start in range(0, len(final), 100):
        print(final[start:start + 100])
def print_short(answer):
    '''
    Print *answer* in consecutive slices of at most 100 characters, one slice
    per line. Nothing is printed for an empty input.
    '''
    width = 100
    offset = 0
    while offset < len(answer):
        print(answer[offset:offset + width])
        offset += width