SpaceFozzy committed on
Commit
7dd0b14
Β·
1 Parent(s): 66dc640

Add new gaia agent

Browse files
Files changed (2) hide show
  1. agent/__init__.py +0 -0
  2. agent/gaia.py +215 -0
agent/__init__.py ADDED
File without changes
agent/gaia.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os
import asyncio
import time

from pydantic import BaseModel
from typing import Annotated

# LangChain / LangGraph stack: Anthropic chat model, tool plumbing,
# Tavily web search, and the graph primitives for the agent loop.
from langchain_anthropic import ChatAnthropic, convert_to_anthropic_tool
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool, InjectedToolCallId
from langchain_tavily import TavilySearch

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.types import Command
from langgraph.prebuilt import InjectedState, ToolNode


# Log level is configurable via the LOGLEVEL env var; defaults to WARNING.
logging.basicConfig(level=os.getenv("LOGLEVEL", "WARNING"))
21
+
22
+
23
class AgentState(BaseModel):
    """Shared LangGraph state threaded through every node of the agent."""

    # The GAIA question payload; read elsewhere via "task_id", "question",
    # and "file_name" keys (see submit_final_answer / answer_question).
    question: dict
    # Written by the submit_final_answer tool; None until the agent answers.
    final_agent_answer: dict | None
    # Conversation history; add_messages appends new messages to the list
    # instead of replacing it.
    messages: Annotated[list, add_messages]
27
+
28
+
29
# Exposed to the LLM as a tool; the docstring below is the tool description.
@tool
def add(x: float, y: float):
    """This function adds two numbers."""
    logging.info(f"Added {x} and {y}")
    return x + y
34
+
35
+
36
# Exposed to the LLM as a tool; the docstring below is the tool description.
@tool
def subtract(x: float, y: float):
    """This function subtracts two numbers."""
    logging.info(f"Subtracting {y} from {x}")
    return x - y
41
+
42
+
43
# Exposed to the LLM as a tool; the docstring below is the tool description.
@tool
def multiply(x: float, y: float):
    """This function multiplies two numbers."""
    logging.info(f"Multiplying {x} and {y}")
    return x * y
48
+
49
+
50
# Exposed to the LLM as a tool; the docstring below is the tool description.
@tool
def divide(x: float, y: float):
    """This function divides two numbers. Handles division by zero."""
    logging.info(f"Dividing {x} by {y}")
    # Return a readable error string instead of raising, so the tool-call
    # loop feeds the LLM a message rather than crashing the graph.
    if y == 0:
        return "Error: cannot divide by zero."
    return x / y
57
+
58
+
59
@tool
def submit_final_answer(
    answer: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[AgentState, InjectedState],
):
    """This function should be called to submit your final answer only once you have determined it. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    # Use the module logger (not print) so output respects the LOGLEVEL config.
    logging.info(f"Submitting final answer: {answer}")

    answer_data = {
        "task_id": state.question["task_id"],
        "agent_answer": answer,
    }

    logging.info("Final answer written, updating state with final answer...")
    # Command lets a tool write directly into graph state: record the answer
    # and append a ToolMessage so the pending tool call is closed out.
    return Command(
        update={
            "final_agent_answer": answer_data,
            "messages": [
                ToolMessage(
                    "You have successfully submitted your final answer. There is nothing left to be done.",
                    tool_call_id=tool_call_id,
                )
            ],
        }
    )
86
+
87
+
88
# Web search plus the math/answer tools defined above.
tavily = TavilySearch(max_results=2)
tools = [add, subtract, multiply, divide, tavily, submit_final_answer]

# Convert each tool into Anthropic's schema representation.
anthropic_tools = [convert_to_anthropic_tool(raw_tool) for raw_tool in tools]

# Adding the cache control block to the last tool caches the whole tool list.
anthropic_tools[-1]["cache_control"] = {"type": "ephemeral"}
99
+
100
+
101
# Claude Sonnet 4 with extended thinking enabled and the token-efficient
# tool-use beta header; tools are bound below in their Anthropic format.
llm = ChatAnthropic(
    model_name="claude-sonnet-4-20250514",
    max_tokens=5000,
    timeout=None,
    thinking={"type": "enabled", "budget_tokens": 4000},
    model_kwargs={
        "extra_headers": {"anthropic-beta": "token-efficient-tools-2025-02-19"}
    },
)
llm_with_tools = llm.bind_tools(anthropic_tools)
111
+
112
+
113
class GaiaAgent:
    """LangGraph agent that answers a GAIA question via an LLM + tool loop."""

    def __init__(self):
        self.llm = llm_with_tools
        self.agent_graph = self.compile_graph()

    def compile_graph(self):
        """Build the two-node graph: consider_question <-> tools, ending once
        a final answer has been written to state."""
        graph = StateGraph(AgentState)

        def should_continue(state):
            # Route to END once submit_final_answer has populated state;
            # otherwise execute the pending tool calls.
            logging.info(
                "Checking for final answer in decide_next_node conditional edge"
            )
            logging.info(state.final_agent_answer)
            if state.final_agent_answer:
                logging.info("Final answer submitted. Ending agent flow.")
                return END
            else:
                logging.info(
                    "No final answer submitted yet, proceed to the tool nodes."
                )
                return "tools"

        graph.add_node(self.consider_question)
        graph.add_node("tools", ToolNode(tools))

        graph.add_edge(START, "consider_question")
        graph.add_edge("tools", "consider_question")
        graph.add_conditional_edges(
            "consider_question", should_continue, ["tools", END]
        )
        return graph.compile()

    async def consider_question(self, state: AgentState):
        """Home of the agent. Looks at all the messages so far, generates the next message."""
        logging.info("Considering question...")
        # BUGFIX: time.sleep(5) blocked the whole event loop inside this
        # coroutine; await the async sleep so other tasks can run while the
        # (presumably rate-limiting) pause elapses.
        await asyncio.sleep(5)
        if state.final_agent_answer is None:
            messages = state.messages
            response = await self.llm.ainvoke(messages)
            # Log any text segments of the response for visibility.
            if hasattr(response, "content"):
                for message in response.content:
                    if "text" in message:
                        logging.info(message["text"])
            return {"messages": [response]}
        else:
            # If a final answer has been determined no more consideration is required
            logging.info(
                "Skipping question consideration because final answer is available"
            )
            return state

    async def answer_question(self, question):
        """Run the agent graph on one question dict and return the answer string.

        `question` is expected to carry "question", "task_id" and "file_name"
        keys. File-based questions are not supported yet.
        """
        if question["file_name"]:
            return "I don't know - I can't handle files yet!"
        question_text = question["question"]

        logging.debug("Initializing agent state to answer question...")
        system_prompt = """
        You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer by calling the submit_final_answer tool. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
        To operate effectively, always remember:
        1. Before using any math tools for operations, make sure you have thought about the math problem sufficiently and stated the equation that you will solve. Plan the equation first, then use the math tools to solve it precisely.
        """
        initial_state = {
            "question": question,
            "final_agent_answer": None,
            "messages": [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": system_prompt,
                            # Cache the system prompt across turns.
                            "cache_control": {"type": "ephemeral"},
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": question_text,
                },
            ],
        }

        async def get_final_answer(agent):
            # Stream the run and keep the last full state snapshot.
            final_output: dict | None = None
            async for mode, chunk in agent.astream(
                initial_state,
                stream_mode=["values", "messages"],
                config={"recursion_limit": 30},
            ):
                if mode == "values":
                    final_output = chunk

            # BUGFIX: also guard against the graph finishing (e.g. hitting
            # the recursion limit) without submit_final_answer ever running —
            # subscripting a None "final_agent_answer" raised TypeError.
            if final_output is None or not final_output.get("final_agent_answer"):
                return "I don't know!"

            return final_output["final_agent_answer"]["agent_answer"]

        result = await get_final_answer(self.agent_graph)
        return result

    def __call__(self, question):
        # Synchronous entry point: runs the async flow to completion.
        return asyncio.run(self.answer_question(question))