File size: 3,220 Bytes
01aa6b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb4047
01aa6b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb4047
 
 
 
 
 
 
 
01aa6b9
 
 
 
afb4047
 
 
 
 
01aa6b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import Dict, List, cast

from langchain_core.messages import AIMessage
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from src.config import Configuration
from src.model import GoogleModel
from src.state import InputState, State
from src.tools import TOOLS

class GaiaAgent:
    """LangGraph-based agent that cycles between an LLM node and a tool node.

    The compiled graph is built once at construction time and stored on
    ``self.graph``.
    """

    def __init__(self) -> None:
        # Build and compile the agent graph once; reused for every run.
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the agent's state graph.

        The graph alternates between ``call_model`` and ``tools`` until the
        model produces a response with no tool calls, at which point the run
        ends with that response.

        Returns:
            The compiled graph produced by ``StateGraph.compile`` (a compiled
            graph object, not a ``StateGraph``).
        """
        builder = StateGraph(State, input=InputState, config_schema=Configuration)

        # The two nodes we cycle between: the LLM call and tool execution.
        builder.add_node("call_model", self._call_model)
        builder.add_node("tools", ToolNode(TOOLS))

        # Entry point: the model is always called first.
        builder.add_edge("__start__", "call_model")
        builder.add_conditional_edges(
            "call_model",
            # Route to "tools" when the latest message contains tool calls;
            # otherwise end the run with the model's direct response.
            tools_condition,
        )
        # After tools execute, hand control back to the model.
        builder.add_edge("tools", "call_model")

        return builder.compile(name="GAIA Agent", debug=False)

    def _call_model(self, state: State) -> Dict[str, List[AIMessage]]:
        """Call the LLM powering the agent.

        Prepares the system prompt (augmented with file-download instructions
        when the task references a file), invokes the model, and returns the
        response in the shape the graph's state reducer expects.

        Args:
            state: The current conversation state.

        Returns:
            A dict with a ``"messages"`` list containing the model's response,
            appended to the existing conversation by the state reducer.
        """
        configuration = Configuration.from_context()

        # Model with tool binding; model name and temperature come from the
        # runtime configuration. Change the model or add more tools here.
        model = GoogleModel(
            model=configuration.google_model,
            temperature=configuration.temperature,
            tools=TOOLS,
        )

        # System prompt. Customize this to change the agent's behavior.
        system_message = configuration.system_prompt

        # When the task has an attached file, tell the model how to fetch it
        # (the task id is what the download tool needs).
        if state.file_name:
            file_prompt = (
                f"\n\nThe task id is {state.task_id}.\n"
                f"Please use this to download the file."
            )
            system_message += file_prompt

        # Invoke the model on the system prompt plus the conversation so far.
        response = cast(
            AIMessage,
            model.llm.invoke(
                [{"role": "system", "content": system_message}, *state.messages]
            ),
        )

        # If we've hit the step limit but the model still wants to call a
        # tool, bail out with an apology instead of looping further.
        if state.is_last_step and response.tool_calls:
            return {
                "messages": [
                    AIMessage(
                        id=response.id,
                        content="Sorry, I could not find an answer to your question in the specified number of steps.",
                    )
                ]
            }

        # Return the response as a list so the reducer appends it to state.
        return {"messages": [response]}