# GAIA-Langgraph / agent.py
# jash0803's picture
# feat: langgraph agents
# 95bd81e
import re
import logging
from datetime import datetime
from agents.supervisor import build_supervisor_graph
# Module logger: INFO-and-above records are appended to gaia_agent.log in the
# current working directory, one "HH:MM:SS | message" line per record.
logger = logging.getLogger("gaia_agent")
# The encoding is pinned to UTF-8 because logged questions/answers may contain
# non-ASCII characters; the platform default (locale) encoding can fail to
# encode them, in which case the record would be lost.
_log_handler = logging.FileHandler("gaia_agent.log", mode="a", encoding="utf-8")
_log_handler.setFormatter(
    logging.Formatter("%(asctime)s | %(message)s", datefmt="%H:%M:%S")
)
logger.addHandler(_log_handler)
logger.setLevel(logging.INFO)
# Matches a string that consists solely of an internal routing / hand-off
# token such as "transfer_to_researcher". Such strings are never answers.
INTERNAL_ROUTING_PATTERNS = re.compile(
    r"^transfer_to_\w+$|^handoff_to_\w+$|^route_to_\w+$", re.IGNORECASE
)

# "FINAL ANSWER: ..." marker; the group captures the remainder of that line.
_FINAL_ANSWER_PATTERN = re.compile(r"(?i)FINAL\s*ANSWER\s*:\s*(.+)")

# Lead-ins removed from the start of a candidate answer, applied in this order.
_ANSWER_PREFIX_PATTERNS = (
    re.compile(r"(?i)^the\s+answer\s+is\s*:\s*"),
    re.compile(r"(?i)^answer\s*:\s*"),
)


def _content_to_text(content) -> str:
    """Flatten message content into a plain string.

    Chat-model message content is usually a ``str`` but may also be a list
    of content blocks (plain strings and ``{"type": "text", "text": ...}``
    dicts). Text blocks are concatenated; non-text blocks are ignored.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type", "text") == "text"
                and isinstance(block.get("text"), str)
            ):
                parts.append(block["text"])
        return "".join(parts)
    return str(content)


def _strip_answer_prefixes(value: str) -> str:
    """Remove each known answer lead-in from the start of *value*, in order."""
    for pattern in _ANSWER_PREFIX_PATTERNS:
        value = pattern.sub("", value).strip()
    return value


def _extract_answer(text: str) -> str:
    """Pull the bare answer out of a model's final message.

    Resolution order:

    1. Empty input, or input that is only a routing token, yields ``""``.
    2. If a ``FINAL ANSWER: ...`` marker occurs anywhere, the rest of that
       line (first occurrence) is returned.
    3. Otherwise the last line that is neither blank, a routing token, nor a
       bare lead-in is returned with "The answer is:" / "Answer:" removed.

    Args:
        text: Final message content. Normally a string; list-of-blocks
            content and ``None`` are accepted and flattened first.

    Returns:
        The extracted answer, or ``""`` when no usable line exists.
    """
    text = _content_to_text(text)
    if not text:
        return ""
    if INTERNAL_ROUTING_PATTERNS.match(text.strip()):
        return ""

    marker = _FINAL_ANSWER_PATTERN.search(text)
    if marker:
        return marker.group(1).strip()

    # Strip a lead-in from the start of the whole text first, so a
    # single-line "Answer: The answer is: X" is peeled twice, as before.
    cleaned = _strip_answer_prefixes(text.strip())

    # Walk upwards from the last line; keep going past routing tokens and
    # lines that are empty once their lead-in is removed.
    for line in reversed(cleaned.split("\n")):
        candidate = line.strip()
        if not candidate or INTERNAL_ROUTING_PATTERNS.match(candidate):
            continue
        candidate = _strip_answer_prefixes(candidate)
        if candidate:
            return candidate

    # Only routing tokens / bare lead-ins were present: there is no answer,
    # and the routing tokens must not leak out as one.
    return ""
def _extract_trace(messages) -> tuple[list[str], list[str]]:
    """Report which agents and tools appear in a message history.

    Messages are classified by their class *name*: an ``AIMessage`` with a
    non-empty ``name`` other than ``"supervisor"`` counts as an agent, and a
    ``ToolMessage`` with a non-empty ``name`` counts as a tool. Each name is
    reported once, in order of first appearance.

    Returns:
        ``(agents_used, tools_used)``.
    """
    agent_names: list[str] = []
    tool_names: list[str] = []
    # Class name -> list that collects names for that kind of message.
    buckets = {"AIMessage": agent_names, "ToolMessage": tool_names}

    for message in messages:
        kind = type(message).__name__
        label = getattr(message, "name", None)
        if not label:
            continue
        bucket = buckets.get(kind)
        if bucket is None:
            continue
        # The supervisor itself is not reported as an agent.
        if bucket is agent_names and label == "supervisor":
            continue
        if label not in bucket:
            bucket.append(label)

    return agent_names, tool_names
class GAIAAgent:
    """Callable agent that answers a question by running the supervisor graph.

    An instance is called as ``agent(question, task_id, file_name)`` and
    always returns a string: the extracted answer on success, or
    ``"Error: <exception>"`` if anything raised while running the graph.
    Progress is printed to stdout and a one-line summary per question is
    written to the module logger.
    """

    def __init__(self):
        """Build the supervisor graph and mark the session start in the log."""
        print("Initializing GAIAAgent with multi-agent supervisor...")
        self.graph = build_supervisor_graph()
        logger.info("--- Session started ---")
        print("GAIAAgent initialized successfully.")

    @staticmethod
    def _build_prompt(question: str, task_id: str | None, file_name: str) -> str:
        """Return *question*, with download instructions appended if a task id is known."""
        if file_name and task_id:
            return (
                f"{question}\n\n"
                f"[IMPORTANT CONTEXT: This question has an associated file named '{file_name}'. "
                f"You MUST use the download_gaia_file tool with task_id='{task_id}' and "
                f"file_name='{file_name}' to download and process this file before answering.]"
            )
        if task_id:
            return (
                f"{question}\n\n"
                f"[Context: Task ID is '{task_id}'. If you need to download an associated file, "
                f"use the download_gaia_file tool with this task_id.]"
            )
        # Without a task id no download hint is added, even if a file name was given.
        return question

    def __call__(self, question: str, task_id: str | None = None, file_name: str = "") -> str:
        """Run the graph on *question* and return the extracted answer.

        Args:
            question: The question text.
            task_id: Optional task identifier, embedded in the prompt so the
                agents can download an associated file.
            file_name: Name of the associated file; empty means no file.

        Returns:
            The extracted answer, or ``"Error: <exception>"`` on failure.
        """
        print(f"\n{'='*60}")
        print(f"Question (first 100 chars): {question[:100]}...")
        print(f"Task ID: {task_id}")
        has_file = bool(file_name)
        print(f"Associated file: {'yes (' + file_name + ')' if has_file else 'no'}")

        prompt = self._build_prompt(question, task_id, file_name)
        messages = [{"role": "user", "content": prompt}]

        # Top-level boundary: any failure below is reported as an
        # "Error: ..." answer instead of propagating to the caller.
        try:
            result = self.graph.invoke(
                {"messages": messages},
                config={"recursion_limit": 50},
            )
            response_messages = result.get("messages", [])
            agents_used, tools_used = _extract_trace(response_messages)

            if response_messages:
                final_msg = response_messages[-1]
                raw_answer = (
                    final_msg.content
                    if hasattr(final_msg, "content")
                    else str(final_msg)
                )
            else:
                # No messages came back; fall back to the raw result's repr.
                raw_answer = str(result)

            answer = _extract_answer(raw_answer)

            logger.info(
                "Q: %s... | file=%s | agents: %s | tools: %s | answer: %s",
                question[:80],
                "yes" if has_file else "no",
                ", ".join(agents_used) or "none",
                ", ".join(tools_used) or "none",
                answer[:80],
            )
            print(f"Agents used: {agents_used}")
            print(f"Tools used: {tools_used}")
            print(f"Final answer: {answer}")
            print(f"{'='*60}\n")
            return answer
        except Exception as e:
            print(f"Error running agent: {e}")
            # logger.exception logs at ERROR level and appends the traceback,
            # so failures are diagnosable from the log file alone.
            logger.exception("Q: %s... | ERROR: %s", question[:80], e)
            import traceback

            traceback.print_exc()
            return f"Error: {e}"