Spaces:
Runtime error
Runtime error
File size: 4,280 Bytes
49ab10c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | import os
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_litellm import ChatLiteLLM
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import *
# Load variables from a .env file before GaiaAgent reads its API keys and
# endpoint via os.getenv().
load_dotenv()
class GaiaAgent:
    """Tool-using question-answering agent built on a LangGraph state graph.

    The graph runs ``retriever -> assistant``, then loops
    ``assistant -> tools -> assistant`` for as long as the model requests
    tool calls. The model is instructed to end its reply with
    ``FINAL ANSWER: ...`` and :meth:`run` returns only that trailing answer.
    """

    def __init__(self):
        """Create the LLM client, tool list, system prompt and retriever.

        Reads ``ITP_API_KEY``, ``TRELLIS_URL`` and ``OPENAI_API_KEY`` from
        the environment.
        """
        self.llm = ChatLiteLLM(
            model="openai/gemini-2.5-pro",
            api_key=os.getenv("ITP_API_KEY"),
            api_base=os.getenv("TRELLIS_URL"),
            temperature=0.5,
        )
        # Tool callables come from the star import of the local `tools` module.
        self.tools = [
            web_search,
            wikipedia_search,
            arxiv_search,
            text_splitter,
            read_file,
            analyze_image,
            analyze_audio,
            analyze_youtube_video,
            multiply,
            add,
            subtract,
            divide,
        ]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # Kept as a plain string attribute; assistant_node wraps it in a
        # SystemMessage when calling the model.
        self.system_message = (
            "\n"
            "You are a general AI assistant. I will ask you a question.\n"
            "Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].\n"
            "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.\n"
            "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.\n"
            "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.\n"
            "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n"
        )
        self.vectorstore = Chroma(
            embedding_function=OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")),
            persist_directory="chroma_db",
        )
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

    def build_graph(self):
        """Assemble and compile the retriever/assistant/tools graph.

        Returns:
            The compiled graph produced by ``StateGraph.compile()``.
        """
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self.retrieve_node)
        builder.add_node("assistant", self.assistant_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        # tools_condition picks the next step after each assistant reply.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def retrieve_node(self, state: MessagesState):
        """Look up context for the latest message and restate the question.

        Args:
            state: Graph state; the last entry of ``state["messages"]`` is
                treated as the question.

        Returns:
            A state update holding one ``HumanMessage`` that combines the
            retrieved context with the question.
        """
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)
        if docs:
            context = "\n\n".join(d.page_content for d in docs)
        else:
            context = "No relevant documents found"
        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def assistant_node(self, state: MessagesState):
        """Call the tool-enabled LLM on the conversation so far.

        The system prompt is prepended for the model call only, unless the
        history already contains a system message. It is not written back
        to the state, so it is not duplicated across tool-call loops.

        Args:
            state: Graph state holding the message history.

        Returns:
            A state update holding the model's response message.
        """
        history = list(state["messages"])
        # The original tested for HumanMessage, which is always present,
        # so the system prompt was never sent to the model.
        if not any(isinstance(m, SystemMessage) for m in history):
            history.insert(0, SystemMessage(content=self.system_message))
        response = self.llm_with_tools.invoke(history)
        return {"messages": [response]}

    @staticmethod
    def extract_answer(text: str) -> str:
        """Return the text following the last ``FINAL ANSWER`` marker.

        The separating colon and surrounding whitespace are removed. The
        last occurrence is used so that an earlier mention of the marker
        in the model's reasoning does not capture the rest of the reply.

        Args:
            text: Full model reply.

        Returns:
            The extracted answer, or ``text`` unchanged when the marker
            does not occur.
        """
        _, marker, tail = text.rpartition("FINAL ANSWER")
        if not marker:
            return text
        return tail.strip().removeprefix(":").strip()

    @staticmethod
    def _format_question(task: dict) -> str:
        """Build the prompt text for a task.

        The task id is appended only when the task names an attached file
        (presumably so the file tools can locate it by task id; confirm
        against the ``tools`` module).
        """
        question = task["question"]
        # Truthiness covers "", None and a missing key. The original
        # `!= "" or is not None` test was always true.
        if task.get("file_name"):
            return f"{question} with task_id {task['task_id']}"
        return question

    def run(self, task: dict) -> str:
        """Answer one task and return the extracted final answer.

        Args:
            task: Mapping with ``"question"`` and, for tasks with an
                attachment, ``"task_id"`` and a non-empty ``"file_name"``.

        Returns:
            The answer extracted from the last message of the graph run.
        """
        question = self._format_question(task)
        graph = self.build_graph()
        result = graph.invoke({"messages": [HumanMessage(content=question)]})
        return self.extract_answer(result["messages"][-1].content)