File size: 4,280 Bytes
49ab10c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_litellm import ChatLiteLLM
from langchain_openai import OpenAIEmbeddings
from langgraph.graph import START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from tools import *

# Load variables from a .env file into the process environment at import time,
# so the os.getenv() calls in GaiaAgent.__init__ (ITP_API_KEY, TRELLIS_URL,
# OPENAI_API_KEY) can see them.
load_dotenv()

class GaiaAgent:
    """LangGraph agent that answers GAIA-style questions.

    Each run executes a graph of three nodes:

    * ``retriever`` - fetches context from a local Chroma store and appends a
      combined context+question message;
    * ``assistant`` - calls the tool-bound LLM;
    * ``tools`` - executes any tool calls the LLM requested.

    ``tools_condition`` routes from ``assistant`` to ``tools`` while the LLM
    keeps requesting tools, and ends the graph otherwise.
    """

    def __init__(self):
        # The "openai/" model prefix makes LiteLLM use the OpenAI-compatible
        # protocol against ``api_base``; endpoint and key come from the env.
        self.llm = ChatLiteLLM(
            model="openai/gemini-2.5-pro",
            api_key=os.getenv("ITP_API_KEY"),
            api_base=os.getenv("TRELLIS_URL"),
            temperature=0.5,
        )
        # Tool callables come from the local ``tools`` module (star-imported).
        self.tools = [
            web_search,
            wikipedia_search,
            arxiv_search,
            text_splitter,
            read_file,
            analyze_image,
            analyze_audio,
            analyze_youtube_video,
            multiply,
            add,
            subtract,
            divide,
        ]
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # Answer-format instructions; wrapped in a SystemMessage by
        # assistant_node before every LLM call.
        self.system_message = """
        You are a general AI assistant. I will ask you a question. 
        Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. 
        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. 
        If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. 
        If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. 
        If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
        """
        # Persistent vector store on disk under ./chroma_db.
        self.vectorstore = Chroma(
            embedding_function=OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY")),
            persist_directory="chroma_db",
        )
        # Return the top 3 matches for each query.
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 3})

    def build_graph(self):
        """Build and compile the retriever -> assistant <-> tools graph."""
        builder = StateGraph(MessagesState)
        builder.add_node("retriever", self.retrieve_node)
        builder.add_node("assistant", self.assistant_node)
        builder.add_node("tools", ToolNode(self.tools))

        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")

        # Go to "tools" when the last AI message has tool calls, else finish.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        # After running tools, hand the results back to the LLM.
        builder.add_edge("tools", "assistant")
        return builder.compile()

    def retrieve_node(self, state: MessagesState):
        """Prepend retrieved context to the latest question.

        Queries the retriever with the content of the last message and appends
        a new HumanMessage containing the joined document texts followed by
        the question. A placeholder is used when nothing is retrieved.
        """
        question = state["messages"][-1].content
        docs = self.retriever.invoke(question)

        if docs:
            context = "\n\n".join(d.page_content for d in docs)
        else:
            context = "No relevant documents found"

        combined = f"Context:\n{context}\n\nQuestion:\n{question}"
        return {"messages": [HumanMessage(content=combined)]}

    def assistant_node(self, state: MessagesState):
        """Invoke the tool-bound LLM on the conversation so far.

        The system prompt is not stored in the graph state (only the LLM
        response is returned), so it is prepended on every call, including
        after each tool round-trip. If the caller already placed a
        SystemMessage in the state, that one is used as-is.
        """
        history = state["messages"]
        if any(isinstance(m, SystemMessage) for m in history):
            messages = list(history)
        else:
            # Must be an explicit SystemMessage: a bare str in a message list
            # is coerced to a human message by LangChain.
            messages = [SystemMessage(content=self.system_message), *history]

        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    @staticmethod
    def extract_answer(text: str):
        """Return the answer that follows the last ``FINAL ANSWER`` marker.

        Surrounding whitespace is removed, along with the ``:`` separator the
        prompt template asks for and stray Markdown bold asterisks. For
        example, ``"... FINAL ANSWER: 42"`` and ``"**FINAL ANSWER**: 42"``
        both yield ``"42"``. If the marker is absent, ``text`` is returned
        unchanged.
        """
        keyword = "FINAL ANSWER"
        # Use the last occurrence: the model may echo the template while
        # reasoning before writing its real final answer.
        index = text.rfind(keyword)
        if index == -1:
            return text
        answer = text[index + len(keyword):].strip()
        return answer.lstrip(":*").strip()

    def run(self, task: dict):
        """Answer one task and return the extracted final answer string.

        ``task`` must contain ``"task_id"`` and ``"question"``. ``"file_name"``
        is optional; when it is non-empty, the task id is appended to the
        question (presumably so file-reading tools can locate the attachment;
        TODO confirm against the ``tools`` module).
        """
        task_id, question = task["task_id"], task["question"]
        file_name = task.get("file_name")

        # Both "" and None mean "no attachment".
        if file_name:
            question = f"{question} with task_id {task_id}"

        graph = self.build_graph()

        messages: list[HumanMessage] = [HumanMessage(content=question)]
        result = graph.invoke({"messages": messages})

        return self.extract_answer(result["messages"][-1].content)