# msanton's picture
# Add Agent and Tools
# 49ab10c verified
import os

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_litellm import ChatLiteLLM
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import *
# Load variables from a .env file into the process environment; GaiaAgent
# reads its API keys and endpoint from there via os.getenv.
load_dotenv()
class GaiaAgent:
    """Question-answering agent wired as a LangGraph state graph.

    Each run first retrieves context from a Chroma vector store, then
    alternates between the tool-enabled LLM and tool execution until the
    model answers without requesting another tool call.
    """

    def __init__(self):
        """Create the LLM client, tool list, system prompt and retriever.

        Reads ``ITP_API_KEY``, ``TRELLIS_URL`` and ``OPENAI_API_KEY`` from the
        environment.
        """
        self.llm = ChatLiteLLM(
            model="openai/gemini-2.5-pro",
            api_key=os.getenv("ITP_API_KEY"),
            api_base=os.getenv("TRELLIS_URL"),
            temperature=0.5,
        )
        # Tool callables come from the star import of the local `tools` module.
        self.tools = [
            web_search,
            wikipedia_search,
            arxiv_search,
            text_splitter,
            read_file,
            analyze_image,
            analyze_audio,
            analyze_youtube_video,
            multiply,
            add,
            subtract,
            divide,
        ]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # Kept as a plain string attribute; assistant_node wraps it in a
        # SystemMessage when calling the model.
        self.system_message = (
            "\n"
            "You are a general AI assistant. I will ask you a question.\n"
            "Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].\n"
            "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.\n"
            "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.\n"
            "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.\n"
            "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n"
        )
        self.vectorstore = Chroma(
            embedding_function=OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")),
            persist_directory="chroma_db",
        )
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

    def build_graph(self):
        """Build and compile the graph: retriever -> assistant <-> tools.

        Returns:
            The compiled graph, ready for ``invoke``.
        """
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self.retrieve_node)
        builder.add_node("assistant", self.assistant_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        # tools_condition decides after each assistant turn whether to run
        # the "tools" node or finish.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def retrieve_node(self, state: "MessagesState"):
        """Retriever node: attach vector-store context to the latest message.

        Looks up documents for the content of the last message and returns a
        new HumanMessage holding both the context and the question.
        """
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)
        if docs:
            context = "\n\n".join(d.page_content for d in docs)
        else:
            context = "No relevant documents found"
        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def assistant_node(self, state: "MessagesState"):
        """Assistant node: call the tool-enabled LLM on the conversation.

        The system prompt is never stored in the graph state, so it is
        prepended on every call. Without it the model is not told to use the
        ``FINAL ANSWER`` template that ``extract_answer`` relies on.
        """
        messages = list(state["messages"])
        # Respect a system message supplied by the caller instead of adding
        # a second one.
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages.insert(0, SystemMessage(content=self.system_message))
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    @staticmethod
    def extract_answer(text: str):
        """Return the text following the last ``FINAL ANSWER`` marker.

        The colon that the prompt template puts after the marker and any
        surrounding whitespace are removed. If the marker is absent, ``text``
        is returned unchanged.
        """
        keyword = "FINAL ANSWER"
        # Use the last occurrence: the reasoning may mention the marker
        # before the actual answer line.
        _, found, tail = text.rpartition(keyword)
        if not found:
            return text
        return tail.strip().removeprefix(":").strip()

    def run(self, task: dict):
        """Answer a single task and return the extracted final answer.

        Args:
            task: Mapping with ``task_id`` and ``question``; ``file_name`` is
                optional and may be empty or ``None`` when the task has no
                attachment.

        Returns:
            The answer text extracted from the last message of the graph run.
        """
        task_id = task["task_id"]
        question = task["question"]
        file_name = task.get("file_name")
        # Only mention the task id when a file is attached
        # (presumably the file tools locate the attachment by task id).
        if file_name:
            question = f"{question} with task_id {task_id}"
        graph = self.build_graph()
        result = graph.invoke({"messages": [HumanMessage(content=question)]})
        return self.extract_answer(result["messages"][-1].content)