Spaces:
Runtime error
Runtime error
SpaceFozzy committed on
Commit Β·
7dd0b14
1
Parent(s): 66dc640
Add new gaia agent
Browse files- agent/__init__.py +0 -0
- agent/gaia.py +215 -0
agent/__init__.py
ADDED
|
File without changes
|
agent/gaia.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import asyncio
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from typing import Annotated
|
| 8 |
+
|
| 9 |
+
from langchain_anthropic import ChatAnthropic, convert_to_anthropic_tool
|
| 10 |
+
from langchain_core.messages import ToolMessage
|
| 11 |
+
from langchain_core.tools import tool, InjectedToolCallId
|
| 12 |
+
from langchain_tavily import TavilySearch
|
| 13 |
+
|
| 14 |
+
from langgraph.graph import StateGraph, START, END
|
| 15 |
+
from langgraph.graph.message import add_messages
|
| 16 |
+
from langgraph.types import Command
|
| 17 |
+
from langgraph.prebuilt import InjectedState, ToolNode
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=os.getenv("LOGLEVEL", "WARNING"))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AgentState(BaseModel):
    """Shared LangGraph state threaded through every node of the agent graph."""

    # Raw GAIA question record; answer_question reads "task_id", "question"
    # and "file_name" keys from it -- TODO confirm full schema against caller.
    question: dict
    # Written by the submit_final_answer tool; None until an answer exists.
    # FIX: default to None so the field is genuinely optional when the
    # state is constructed (it previously had no default and was required).
    final_agent_answer: dict | None = None
    # Conversation history; the add_messages reducer appends new messages
    # instead of replacing the list on each state update.
    messages: Annotated[list, add_messages]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@tool
def add(x: float, y: float):
    """This function adds two numbers."""
    total = x + y
    logging.info("Added %s and %s", x, y)
    return total
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@tool
def subtract(x: float, y: float):
    """This function subtracts two numbers."""
    difference = x - y
    logging.info("Subtracting %s from %s", y, x)
    return difference
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@tool
def multiply(x: float, y: float):
    """This function multiplies two numbers."""
    product = x * y
    logging.info("Multiplying %s and %s", x, y)
    return product
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@tool
def divide(x: float, y: float):
    """This function divides two numbers. Handles division by zero."""
    # Lazy %-style args so formatting is skipped when INFO is disabled.
    logging.info("dividing %s by %s", x, y)
    # Return an explanatory string instead of raising so the LLM can
    # recover from a bad tool call (matches the original behavior).
    if y == 0:
        return "error: cannot divide by zero."
    return x / y
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@tool
def submit_final_answer(
    answer: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[AgentState, InjectedState],
):
    """This function should be called to submit your final answer only once you have determined it. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    # FIX: was a bare print(); use logging like every other tool in this file.
    logging.info(f"Submitting final answer: {answer}")

    # Pair the answer with its task id so the grader can match them up.
    answer_data = {
        "task_id": state.question["task_id"],
        "agent_answer": answer,
    }

    logging.info("Final answer written, updating state with final answer...")
    # Command lets a tool write back into graph state (final_agent_answer),
    # not just append to the message history.
    return Command(
        update={
            "final_agent_answer": answer_data,
            "messages": [
                ToolMessage(
                    "You have successfully submitted your final answer. There is nothing left to be done.",
                    tool_call_id=tool_call_id,
                )
            ],
        }
    )
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Web search plus the math/answer tools defined above.
tavily = TavilySearch(max_results=2)
tools = [add, subtract, multiply, divide, tavily, submit_final_answer]

# Convert each tool into Anthropic's native tool schema.
anthropic_tools = [convert_to_anthropic_tool(raw_tool) for raw_tool in tools]

# To cache all tools we add the cache control block to the last tool
anthropic_tools[-1]["cache_control"] = {"type": "ephemeral"}


llm = ChatAnthropic(
    model_name="claude-sonnet-4-20250514",
    max_tokens=5000,
    timeout=None,
    thinking={"type": "enabled", "budget_tokens": 4000},
    model_kwargs={
        "extra_headers": {"anthropic-beta": "token-efficient-tools-2025-02-19"}
    },
)
llm_with_tools = llm.bind_tools(anthropic_tools)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class GaiaAgent:
    """LangGraph-based agent for answering GAIA benchmark questions.

    The compiled graph loops between the LLM node ("consider_question")
    and the tool node until the submit_final_answer tool populates
    ``state.final_agent_answer``, at which point the run ends.
    """

    def __init__(self):
        self.llm = llm_with_tools
        self.agent_graph = self.compile_graph()

    def compile_graph(self):
        """Build and compile the consider_question <-> tools loop graph."""
        graph = StateGraph(AgentState)

        def should_continue(state):
            # Route to END once a final answer has been recorded,
            # otherwise keep dispatching tool calls.
            logging.info(
                "Checking for final answer in decide_next_node conditional edge"
            )
            logging.info(state.final_agent_answer)
            if state.final_agent_answer:
                logging.info("Final answer submitted. Ending agent flow.")
                return END
            else:
                logging.info(
                    "No final answer submitted yet, proceed to the tool nodes."
                )
                return "tools"

        # Node name defaults to the function name: "consider_question".
        graph.add_node(self.consider_question)
        graph.add_node("tools", ToolNode(tools))

        graph.add_edge(START, "consider_question")
        graph.add_edge("tools", "consider_question")
        graph.add_conditional_edges(
            "consider_question", should_continue, ["tools", END]
        )
        return graph.compile()

    async def consider_question(self, state: AgentState):
        """Home of the agent. Looks at all the messages so far, generates the next message."""
        logging.info("Considering question...")
        # FIX: time.sleep(5) blocked the whole event loop; await the async
        # sleep so other coroutines can run during this pacing pause.
        await asyncio.sleep(5)
        if state.final_agent_answer is None:
            response = await self.llm.ainvoke(state.messages)
            if hasattr(response, "content"):
                for message in response.content:
                    # NOTE(review): assumes Anthropic-style content blocks
                    # (dicts with a "text" key) -- log any text parts.
                    if "text" in message:
                        logging.info(message["text"])
            return {"messages": [response]}
        else:
            # If a final answer has been determined no more consideration is required
            logging.info(
                "Skipping question consideration because final answer is available"
            )
            return state

    async def answer_question(self, question):
        """Run the agent graph on one question dict and return the answer string.

        `question` is expected to carry "task_id", "question" and optionally
        "file_name" keys (GAIA record shape -- TODO confirm against caller).
        """
        # FIX: .get avoids a KeyError for records without a "file_name" key.
        if question.get("file_name"):
            return "I don't know - I can't handle files yet!"
        question_text = question["question"]

        logging.debug("Initializing agent state to answer question...")
        system_prompt = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer by calling the submit_final_answer tool. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
To operate effectively, always remember:
1. Before using any math tools for operations, make sure you have thought about the math problem sufficiently and stated the equation that you will solve. Plan the equation first, then use the math tools to solve it precisely.
"""
        initial_state = {
            "question": question,
            "final_agent_answer": None,
            "messages": [
                {
                    "role": "system",
                    "content": [
                        {
                            "type": "text",
                            "text": system_prompt,
                            # Prompt caching: reuse the system prompt across turns.
                            "cache_control": {"type": "ephemeral"},
                        }
                    ],
                },
                {
                    "role": "user",
                    "content": question_text,
                },
            ],
        }

        async def get_final_answer(agent):
            # Stream the graph; the last "values" chunk holds the final state.
            final_output: dict | None = None
            async for mode, chunk in agent.astream(
                initial_state,
                stream_mode=["values", "messages"],
                config={"recursion_limit": 30},
            ):
                if mode == "values":
                    final_output = chunk

            # FIX: also guard against a run that finished (e.g. graph ended)
            # without a submitted answer, which left final_agent_answer None
            # and previously raised a TypeError on subscripting.
            if final_output is None or not final_output.get("final_agent_answer"):
                return "I don't know!"

            return final_output["final_agent_answer"]["agent_answer"]

        result = await get_final_answer(self.agent_graph)
        return result

    def __call__(self, question):
        # Synchronous entry point: drives the async flow to completion.
        return asyncio.run(self.answer_question(question))
|