import os

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_litellm import ChatLiteLLM
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import *  # noqa: F403 -- supplies the tool callables bound below

load_dotenv()


class GaiaAgent:
    """Retrieval-augmented, tool-using agent for GAIA-style question answering.

    Pipeline: retrieve top-k context from a persistent Chroma vector store,
    then loop between an LLM assistant node and a tool-execution node until
    the model stops requesting tools; finally parse the "FINAL ANSWER:" tail.
    """

    def __init__(self):
        # LLM reached through a LiteLLM-compatible proxy; credentials and
        # endpoint come from the environment (loaded via dotenv above).
        self.llm = ChatLiteLLM(
            model="openai/gemini-2.5-pro",
            api_key=os.getenv("ITP_API_KEY"),
            api_base=os.getenv("TRELLIS_URL"),
            temperature=0.5,
        )
        # Tool callables are provided by the project-local `tools` module.
        self.tools = [
            web_search,
            wikipedia_search,
            arxiv_search,
            text_splitter,
            read_file,
            analyze_image,
            analyze_audio,
            analyze_youtube_video,
            multiply,
            add,
            subtract,
            divide,
        ]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # FIX: wrap the prompt in a typed SystemMessage so the chat model
        # receives it with the system role (the original stored a bare str).
        # The prompt text itself is kept verbatim: downstream parsing relies
        # on the "FINAL ANSWER:" template.
        self.system_message = SystemMessage(
            content=""" You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. 
"""
        )
        # Persistent vector store of reference documents; top-3 matches.
        self.vectorstore = Chroma(
            embedding_function=OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")),
            persist_directory="chroma_db",
        )
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

    def build_graph(self):
        """Compile the retriever -> assistant <-> tools LangGraph."""
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self.retrieve_node)
        builder.add_node("assistant", self.assistant_node)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        # tools_condition routes to "tools" while the model emits tool
        # calls, otherwise to END.
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def retrieve_node(self, state: MessagesState):
        """Retriever node: prepend retrieved context to the latest question."""
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)
        context = (
            "\n\n".join(d.page_content for d in docs)
            if docs
            else "No relevant documents found"
        )
        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def assistant_node(self, state: MessagesState):
        """Assistant node: invoke the tool-bound LLM with the system prompt.

        FIX: the original prepended the system prompt only when the state
        contained no HumanMessage — which never happens, because
        retrieve_node always injects one — so the system prompt was
        silently dropped on every turn. We now prepend it whenever no
        SystemMessage is already present.
        """
        if any(isinstance(m, SystemMessage) for m in state["messages"]):
            messages = state["messages"]
        else:
            messages = [self.system_message] + state["messages"]
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    @staticmethod
    def extract_answer(text: str) -> str:
        """Return the text after the final-answer marker, or the whole text.

        FIX: the original sliced right after "FINAL ANSWER", leaving the
        template's ":" glued to the returned answer; strip the colon and
        surrounding whitespace as well.
        """
        keyword = "FINAL ANSWER"
        index = text.find(keyword)
        if index == -1:
            return text
        return text[index + len(keyword):].lstrip(" :").strip()

    def run(self, task: dict):
        """Answer one GAIA task dict with keys task_id, question, file_name.

        Returns the parsed final answer string.
        """
        task_id = task["task_id"]
        question = task["question"]
        file_name = task["file_name"]
        # FIX: the original test `file_name != "" or file_name is not None`
        # is a tautology (for "" the second clause holds, for None the
        # first), so every question got the task_id suffix. Attach it only
        # when a file actually accompanies the task.
        if file_name:
            question = f"{question} with task_id {task_id}"
        graph = self.build_graph()
        result = graph.invoke({"messages": [HumanMessage(content=question)]})
        # FIX: extract once — the original extracted twice; the second pass
        # was a confusing no-op on already-stripped text.
        return self.extract_answer(result["messages"][-1].content)