# gaia-agent / agent.py
# mrtom17's picture
# Update agent.py
# 6030af2 verified
# agent.py
import os
import logging
from typing import TypedDict, Annotated, Any
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from dotenv import load_dotenv
from langgraph.prebuilt import ToolNode
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
from tools import TOOLS # Your tools list should be defined here
import requests
import re
import json
# --- Environment and logging configuration ---
# Load variables from a .env file first (python-dotenv), ahead of the
# os.getenv() lookups made further down in this module.
load_dotenv()

# The log file sits next to this module; mode="w" starts it afresh on each run.
LOG_FILE = os.path.join(os.path.dirname(__file__), "agent.log")
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"

# Root logging config: INFO and above, emitted to the console stream handler
# and to LOG_FILE (UTF-8).
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(filename=LOG_FILE, mode="w", encoding="utf-8"),
    ],
    format=_LOG_FORMAT,
    level=logging.INFO,
)

# Logger shared by the functions in this module.
logger = logging.getLogger("agent_logger")
# --- Token Counting Helper ---
def count_tokens(messages):
    """Estimate the number of tokens in the contents of ``messages``.

    Each item's ``content`` is converted with ``str()`` and encoded with the
    tiktoken encoding registered for "gpt-3.5-turbo"; items that have no
    ``content`` attribute, or whose content is falsy, are not counted.

    Args:
        messages: Iterable of message-like objects.

    Returns:
        The summed token count, or -1 if tiktoken is not installed or any
        other exception occurs while counting (a warning is logged either way).
    """
    try:
        # Imported lazily so the module still works without tiktoken.
        import tiktoken

        encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
        return sum(
            len(encoder.encode(str(item.content)))
            for item in messages
            if getattr(item, "content", None)
        )
    except ImportError:
        logger.warning("tiktoken not installed, skipping token count.")
        return -1
    except Exception as e:
        logger.warning(f"Token counting error: {e}")
        return -1
# --- LLM configuration ---
# System prompt used as the first message of every conversation built by
# `solve`. It instructs the model to reply with the bare answer only, to rely
# on tool output, and to try a tool call before claiming it cannot access a
# file, audio clip or image.
# The pieces below are adjacent string literals, so Python joins them into one
# string; each piece ends with a space to keep the sentences separated.
system_prompt = (
    "You are a helpful assistant. The current year is 2025. When answering, output ONLY the answer to the question, with no extra text, explanation, or formatting. "
    "If you call a tool and receive its output, use the tool output as the main source for your answer. "
    "You may analyze, summarize, or combine tool outputs if needed to answer the question, but do not ignore tool outputs or say you cannot access files or images. "
    "Do not include phrases like 'Final answer', 'The answer is', or any commentary. Output only the answer string. "
    "If a question involves a file, audio, or image, use the appropriate tool to access or process the file. Do not say you cannot access files—always attempt a tool call first. "
    "If a tool result contains the answer, output the answer immediately. Do not make additional tool calls if the answer is already present in the tool result. "
)
# Chat model client. The API key is read from the OPENAI_API_KEY environment
# variable (os.getenv returns None when it is unset).
chat = ChatOpenAI(
    model="o3",  # OpenAI model identifier
    # NOTE(review): presumably pinned to 1 because this model family rejects
    # other temperature values - confirm against the OpenAI docs.
    temperature=1,
    openai_api_key=os.getenv("OPENAI_API_KEY"),
)
# Same model with the tools from tools.TOOLS bound, so its replies may contain
# tool calls.
chat_with_tools = chat.bind_tools(TOOLS)
# Agent state: tracks conversation history
class AgentState(TypedDict):
    """State schema for the agent graph (passed to ``StateGraph``).

    It declares a single key, ``messages``, holding the conversation so far.
    """

    # `add_messages` (langgraph.graph.message) is attached as Annotated
    # metadata; per the LangGraph docs it acts as the reducer that merges a
    # node's returned messages into this list instead of overwriting it -
    # confirm against the installed LangGraph version.
    messages: Annotated[list[AnyMessage], add_messages]
# Assistant node: single chat invocation, LLM always decides
# Number of tool-calling turns the model may take before it has to answer.
MAX_TOOL_CALL_ATTEMPTS = 2
# Notice shown to the model once the tool-call budget is spent.
TOOL_LIMIT_NOTICE = (
    "YOU CAN NO LONGER MAKE ANY TOOL CALL, PLEASE ANSWER WITH THE CONTEXT YOU HAVE, "
    "OR CLEARLY STATE THAT YOU DO NOT HAVE ENOUGH DATA."
)


def assistant(state: AgentState) -> dict[str, Any]:
    """Run one LLM turn over the conversation and return the state update.

    While fewer than ``MAX_TOOL_CALL_ATTEMPTS`` earlier AI messages contain
    tool calls, the tool-enabled model is invoked and may request more tools.
    After that, ``TOOL_LIMIT_NOTICE`` is added to the conversation and the
    model is invoked with ``tool_choice="none"`` so that it has to answer.

    The number of attempts is recomputed from the message history on every
    call: ``AgentState`` declares only ``messages``, so a counter stored under
    any other key is not available on the next call. The input ``state`` is
    never mutated.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A dict with:
          * ``"messages"``: only the messages produced in this turn (the
            limit notice, if any, followed by the model reply). The
            ``add_messages`` reducer on ``AgentState.messages`` merges them
            into the stored history.
          * ``"tool_call_attempts"``: the number of tool-calling AI turns so
            far, including this one. Informational only.
    """
    history = state["messages"]

    logger.info("[Agent] Thinking...")
    # Only log message types and contents, skip system prompt
    logger.info("[Agent] Messages so far:")
    for m in history:
        content = getattr(m, "content", None)
        # Skip system prompt
        if isinstance(content, str) and content.startswith("You are a helpful assistant."):
            continue
        logger.info(f"{type(m).__name__}: {getattr(m, 'content', str(m))}")
        logger.info("-" * 40)

    # Count earlier turns in which the model asked for tools.
    attempts = sum(
        1
        for m in history
        if isinstance(m, AIMessage) and getattr(m, "tool_calls", None)
    )

    if attempts >= MAX_TOOL_CALL_ATTEMPTS:
        logger.info("[Agent] Tool call limit reached. Injecting tool limit message.")
        # A HumanMessage is used for the notice: a ToolMessage whose
        # tool_call_id matches no pending tool call is not a valid message
        # sequence for the chat API.
        notice = HumanMessage(content=TOOL_LIMIT_NOTICE)
        # Keep the tool definitions (the history refers to them) but forbid
        # new tool calls, so this turn has to produce an answer.
        answer_only_chat = chat.bind_tools(TOOLS, tool_choice="none")
        next_msg = answer_only_chat.invoke([*history, notice])
        logger.info(f"[Agent] LLM response: {next_msg.content}")
        for tc in getattr(next_msg, "tool_calls", None) or []:
            logger.info(f"[Tool Call] {tc['name']} | Args: {tc['args']}")
        return {"messages": [notice, next_msg], "tool_call_attempts": attempts}

    next_msg = chat_with_tools.invoke(history)
    logger.info(f"[Agent] LLM response: {next_msg.content}")
    requested_calls = getattr(next_msg, "tool_calls", None) or []
    # If the LLM wants to call a tool, this turn counts as an attempt
    if requested_calls:
        attempts += 1
    for tc in requested_calls:
        logger.info(f"[Tool Call] {tc['name']} | Args: {tc['args']}")
    return {"messages": [next_msg], "tool_call_attempts": attempts}
# Condition: check if the assistant wants to use a tool again
def needs_tool(state: AgentState) -> str:
    """Choose the route to take after the assistant node.

    Args:
        state: Current graph state; only the last entry of
            ``state["messages"]`` is inspected.

    Returns:
        ``"tools"`` when the newest message has a non-empty ``tool_calls``
        attribute, otherwise ``"end"``.
    """
    newest = state["messages"][-1]
    pending_calls = getattr(newest, "tool_calls", None)
    return "tools" if pending_calls else "end"
# Build the graph
def build_langgraph():
    """Assemble and compile the agent graph.

    Layout: the entry point is the ``assistant`` node. After it runs,
    ``needs_tool`` picks the next hop: ``"tools"`` goes to a ``ToolNode``
    built from ``TOOLS``, ``"end"`` goes to ``END``. The tools node always
    hands control back to ``assistant``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    # Maps the strings returned by needs_tool to their destinations.
    routes = {"tools": "tools", "end": END}

    workflow = StateGraph(AgentState)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(TOOLS))
    workflow.set_entry_point("assistant")
    workflow.add_conditional_edges("assistant", needs_tool, routes)
    workflow.add_edge("tools", "assistant")
    return workflow.compile()
# High-level solve function with logging and token counting
def solve(question: str) -> str:
    """Answer ``question`` with the agent graph and return the final text.

    A fresh graph is built, seeded with the system prompt and the question,
    and invoked. The messages produced by the run are then logged and
    audited:

    * Results of ``google_search_tool`` beyond ``MAX_GOOGLE_SEARCH_CALLS`` are
      replaced by a refusal message in the transcript kept in ``state``.
    * If the same tool was called with the same arguments more than
      ``GIVE_UP_THRESHOLD`` times, the run is treated as stuck and a fixed
      fallback answer is returned instead of the model's last message.

    Args:
        question: The user question, passed to the model as a human message.

    Returns:
        The content of the last message of the run; or
        "Unable to determine from available data." when the give-up condition
        is met; or "Unable to find the answer with the given data." when the
        graph hits its recursion limit.

    Raises:
        Exception: Any error other than the recursion limit is logged and
            re-raised unchanged.
    """
    # Imported here, as in the original error handler, to keep the dependency
    # local to this function.
    from langgraph.errors import GraphRecursionError

    MAX_GOOGLE_SEARCH_CALLS = 10
    GIVE_UP_THRESHOLD = 5
    # Upper bound on graph steps for one invoke() call.
    RECURSION_LIMIT = 13
    fallback_answer = "Unable to determine from available data."
    recursion_fallback = "Unable to find the answer with the given data."

    logger.info(f"[User] {question}")
    graph = build_langgraph()
    state = {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}
    all_messages = list(state["messages"])
    google_search_calls = 0
    tool_call_counts = {}
    # tool_call_id -> arguments, collected from the AI messages that issued
    # the calls. Tool results only carry the id and the tool name, so this is
    # how a result is matched to the arguments of its call.
    args_by_call_id = {}
    step = 0
    try:
        while True:
            step += 1
            logger.info(f"\n========== Step {step} ==========")
            # invoke() runs the graph until it reaches END or raises
            # GraphRecursionError, so normally this loop body executes once;
            # the loop only continues if the run somehow ended on a message
            # that still requests tools.
            result = graph.invoke(state, {"recursion_limit": RECURSION_LIMIT})
            first_new = len(state["messages"])
            new_msgs = result["messages"][first_new:]
            for offset, msg in enumerate(new_msgs):
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    for tool_call in msg.tool_calls:
                        logger.info(f"[Tool Call] {tool_call['name']} | Args: {tool_call.get('args', tool_call.get('function', {}).get('arguments', ''))}")
                        args_by_call_id[tool_call.get("id")] = tool_call.get("args", "")
                if isinstance(msg, ToolMessage):
                    logger.info(f"[Tool Result] {msg.content}")
                if isinstance(msg, AIMessage):
                    logger.info(f"[Agent Thinking] {msg.content}")
                if not (hasattr(msg, "name") and hasattr(msg, "tool_call_id")):
                    continue  # everything below applies to tool results only

                if msg.name == "google_search_tool":
                    google_search_calls += 1
                    if google_search_calls > MAX_GOOGLE_SEARCH_CALLS:
                        # The search has already run by the time its result is
                        # seen here; swapping the message only keeps the
                        # over-budget result out of the transcript in `state`.
                        refusal = ToolMessage(
                            content=f"Google search tool call refused: limit of {MAX_GOOGLE_SEARCH_CALLS} calls per question reached.",
                            tool_call_id=msg.tool_call_id,
                        )
                        result["messages"][first_new + offset] = refusal
                        logger.info("[ToolMessage] Google search tool call refused: limit reached.")

                # Give up when one tool keeps being called with the same
                # arguments. Arguments are serialized with sorted keys so that
                # key order does not make equal calls look different.
                raw_args = args_by_call_id.get(msg.tool_call_id, "")
                if isinstance(raw_args, str):
                    tool_args = raw_args
                else:
                    tool_args = json.dumps(raw_args, sort_keys=True, default=str)
                tool_key = (msg.name, tool_args.strip().lower())
                tool_call_counts[tool_key] = tool_call_counts.get(tool_key, 0) + 1
                if tool_call_counts[tool_key] > GIVE_UP_THRESHOLD:
                    logger.info(f"[Agent] Give up condition met for tool {msg.name} with similar arguments: {tool_args}")
                    return fallback_answer
            all_messages.extend(new_msgs)
            state["messages"] = result["messages"]
            # Next action logging
            last_msg = state["messages"][-1]
            if getattr(last_msg, "tool_calls", None):
                logger.info("[Next Action] Agent will call a tool.")
            else:
                logger.info("[Next Action] Agent will answer.")
                break
        logger.info(f"[Agent] Final answer: {state['messages'][-1].content}")
        token_count = count_tokens(all_messages)
        if token_count >= 0:
            logger.info(f"[Stats] Total tokens used: {token_count}")
        return state["messages"][-1].content
    except GraphRecursionError:
        logger.info("[Agent] Recursion limit reached, returning fallback answer.")
        return recursion_fallback
    except Exception as e:
        logger.error(f"[Agent] Unexpected error: {e}")
        raise
def download_file(url, dest_path, timeout=30):
    """Download ``url`` to ``dest_path``, streaming the body in 8 KiB chunks.

    Args:
        url: Address to fetch with an HTTP GET.
        dest_path: File path to write; an existing file is overwritten.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``, so a
            stalled server cannot block the caller forever.

    Raises:
        requests.HTTPError: If the response status is 4xx or 5xx.
        requests.RequestException: On connection errors or timeouts.
        OSError: If ``dest_path`` cannot be opened or written.
    """
    # The context manager closes the streamed response, releasing the
    # connection even when the status check or the file write fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:  # skip empty chunks
                    f.write(chunk)
    # Reported through the module logger, like the rest of this file.
    logger.info(f"Downloaded {url} to {dest_path}")