# Nigou Julien
# Build routed GAIA agent v1
# 07fb471
from __future__ import annotations
import re
from pathlib import Path
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph
from gaia_agent.answer import normalize_answer
from gaia_agent.config import settings
from gaia_agent.llms import create_chat_model
from gaia_agent.observability import traced_step
from gaia_agent.prompts import (
GAIA_AGENT_SYSTEM_PROMPT,
GAIA_QUERY_PROMPT,
GAIA_VERIFY_PROMPT,
)
from gaia_agent.state import GaiaState
from gaia_agent.tools.files import (
download_task_file,
read_text_file,
summarize_spreadsheet,
)
from gaia_agent.tools.media import image_data_url, transcribe_audio_file
from gaia_agent.tools.python_repl import run_python_file
from gaia_agent.tools.web import (
extract_urls,
fetch_url,
get_youtube_transcript,
web_search,
)
# Character budget for the joined evidence text placed in a prompt; longer
# evidence is cut by `_trim_evidence`.
MAX_EVIDENCE_CHARS: int = 36 * 1000
# How many distinct search-result URLs `solve_web` will attempt to fetch
# (failed fetches count toward this budget too).
MAX_WEB_PAGES: int = 4
def build_graph(trace=None, llm=None):
    """Build and compile the routed GAIA question-answering graph.

    Flow: ``ingest_task`` (download the attached file, if any) ->
    ``classify_task`` (pick a ``task_type`` from the file extension and the
    question text) -> one ``solve_*`` node for that type -> ``verify_answer``
    -> ``normalize_final_answer`` -> END.

    Every node body runs through ``traced_step(trace, <node name>, run)``.
    Failures of the file, media and web tools are recorded as evidence
    strings instead of aborting the run, so the model can still attempt an
    answer from whatever was gathered.

    Args:
        trace: Passed unchanged to ``traced_step`` for each node; may be None.
        llm: Chat model to use; when falsy, ``create_chat_model()`` is called.

    Returns:
        The result of ``StateGraph.compile()``.
    """
    graph = StateGraph(GaiaState)
    chat_model = llm or create_chat_model()

    def answer_from_evidence(question: str, evidence: list[str]) -> str:
        # Shared final step of the tool-backed solvers: ask the model to answer
        # the question from the (trimmed) evidence collected so far.
        return _invoke_text(
            chat_model,
            GAIA_AGENT_SYSTEM_PROMPT,
            _question_with_evidence(question, evidence),
        )

    def ingest_task(state: GaiaState) -> dict[str, Any]:
        """Download the task's attached file when it is named but not yet local."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            output: dict[str, Any] = {
                "evidence": evidence,
                "tool_outputs": list(state.get("tool_outputs", [])),
            }
            # Nothing to fetch: the file is already local, or no file is attached.
            if state.get("file_path") or not state.get("file_name"):
                return output
            try:
                path = download_task_file(
                    settings.gaia_api_url,
                    state["task_id"],
                    state.get("file_name"),
                )
            except Exception as exc:
                # Best effort: the solvers report the missing file themselves.
                evidence.append(f"Could not download attached file: {exc}")
                output["error"] = str(exc)
                return output
            output["file_path"] = str(path)
            evidence.append(f"Downloaded attached file to {path}.")
            return output

        return traced_step(trace, "ingest_task", run)

    def classify_task(state: GaiaState) -> dict[str, str]:
        """Choose a ``task_type`` route from the file extension, then the question."""

        def run() -> dict[str, str]:
            question = state["question"].lower()
            # `or ""` also tolerates a file_name key that is present but None.
            file_name = (state.get("file_name") or "").lower()
            if file_name.endswith((".xlsx", ".xls", ".csv")):
                task_type = "spreadsheet"
            elif file_name.endswith(".py"):
                task_type = "python_file"
            elif file_name.endswith((".mp3", ".wav", ".m4a", ".ogg", ".flac")):
                task_type = "audio"
            elif file_name.endswith((".png", ".jpg", ".jpeg", ".webp")):
                task_type = "image"
            elif "youtube.com" in question or "youtu.be" in question:
                task_type = "youtube"
            elif _looks_like_computation(question):
                task_type = "compute"
            elif _looks_like_direct(question):
                task_type = "direct"
            else:
                task_type = "web"
            return {"task_type": task_type}

        return traced_step(trace, "classify_task", run)

    def solve_direct(state: GaiaState) -> dict[str, Any]:
        """Answer from the model alone, with no tools or evidence."""

        def run() -> dict[str, Any]:
            answer = _invoke_text(
                chat_model,
                GAIA_AGENT_SYSTEM_PROMPT,
                f"Question:\n{state['question']}",
            )
            return {"draft_answer": answer}

        return traced_step(trace, "solve_direct", run)

    def solve_compute(state: GaiaState) -> dict[str, Any]:
        """Answer with an instruction to compute table/rule results exactly."""

        def run() -> dict[str, Any]:
            answer = _invoke_text(
                chat_model,
                GAIA_AGENT_SYSTEM_PROMPT,
                (
                    "Solve this question carefully. If it includes a table or "
                    "formal rule, compute the requested value exactly.\n\n"
                    f"Question:\n{state['question']}"
                ),
            )
            return {"draft_answer": answer}

        return traced_step(trace, "solve_compute", run)

    def solve_spreadsheet(state: GaiaState) -> dict[str, Any]:
        """Summarize the attached spreadsheet and answer from that summary."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            path = state.get("file_path")
            if not path:
                evidence.append("Attached spreadsheet is unavailable.")
            else:
                try:
                    summary = summarize_spreadsheet(path)
                except Exception as exc:
                    evidence.append(f"Spreadsheet summary failed: {exc}")
                else:
                    evidence.append(f"Spreadsheet summary:\n{summary}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_spreadsheet", run)

    def solve_python_file(state: GaiaState) -> dict[str, Any]:
        """Read and execute the attached Python file, then answer from its output."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            path = state.get("file_path")
            if not path:
                evidence.append("Attached Python file is unavailable.")
                answer = answer_from_evidence(state["question"], evidence)
                return {"evidence": evidence, "draft_answer": answer}
            try:
                source = read_text_file(path, max_chars=30_000)
            except Exception as exc:
                evidence.append(f"Could not read attached Python file: {exc}")
            else:
                evidence.append(f"Attached Python source:\n{source}")
            try:
                result = run_python_file(path)
            except Exception as exc:
                evidence.append(f"Python execution failed: {exc}")
                answer = answer_from_evidence(state["question"], evidence)
                return {"evidence": evidence, "draft_answer": answer}
            stdout = str(result.get("stdout", ""))
            stderr = str(result.get("stderr", ""))
            exit_code = result.get("exit_code")
            evidence.append(
                "Python execution result:\n"
                f"exit_code={exit_code}\n"
                f"stdout:\n{stdout}\n"
                f"stderr:\n{stderr}"
            )
            # Trust the script's own output only for a clean run: exit code 0,
            # nothing on stderr, and some stdout. Its last line becomes the
            # answer and is pre-set as verified_answer, which makes
            # verify_answer pass it through without another model call.
            # NOTE(review): assumes run_python_file reports success as int 0.
            if exit_code == 0 and stdout.strip() and not stderr.strip():
                last_line = stdout.strip().splitlines()[-1].strip()
                return {
                    "evidence": evidence,
                    "draft_answer": last_line,
                    "verified_answer": last_line,
                }
            draft = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": draft}

        return traced_step(trace, "solve_python_file", run)

    def solve_audio(state: GaiaState) -> dict[str, Any]:
        """Transcribe the attached audio file and answer from the transcript."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            path = state.get("file_path")
            if not path:
                evidence.append("Attached audio file is unavailable.")
            else:
                try:
                    transcript = transcribe_audio_file(path)
                except Exception as exc:
                    evidence.append(f"Audio transcription failed: {exc}")
                else:
                    evidence.append(f"Audio transcript:\n{transcript}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_audio", run)

    def solve_image(state: GaiaState) -> dict[str, Any]:
        """Ask the model about the attached image, falling back to text-only."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            path = state.get("file_path")
            if not path:
                evidence.append("Attached image file is unavailable.")
                answer = answer_from_evidence(state["question"], evidence)
            else:
                try:
                    answer = _invoke_image(chat_model, state["question"], path)
                except Exception as exc:
                    evidence.append(f"Image analysis failed: {exc}")
                    answer = answer_from_evidence(state["question"], evidence)
                else:
                    evidence.append(f"Image analyzed from {path}.")
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_image", run)

    def solve_youtube(state: GaiaState) -> dict[str, Any]:
        """Fetch transcripts for YouTube URLs in the question and answer from them."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            for url in extract_urls(state["question"]):
                if "youtube.com" not in url and "youtu.be" not in url:
                    continue
                try:
                    transcript = get_youtube_transcript(url)
                except Exception as exc:
                    evidence.append(f"YouTube transcript failed for {url}: {exc}")
                else:
                    evidence.append(f"YouTube transcript for {url}:\n{transcript}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_youtube", run)

    def solve_web(state: GaiaState) -> dict[str, Any]:
        """Search the web with model-generated queries, fetch pages, and answer."""

        def run() -> dict[str, Any]:
            evidence = list(state.get("evidence", []))
            queries = _build_search_queries(chat_model, state["question"])
            seen_urls: set[str] = set()
            for query in queries:
                try:
                    # `or []` keeps the loop below safe if the tool returns None.
                    results = web_search(query, max_results=5) or []
                except Exception as exc:
                    evidence.append(f"Search failed for {query!r}: {exc}")
                    continue
                if results:
                    evidence.append(
                        "Search results for "
                        f"{query!r}:\n"
                        + "\n".join(f"- {item.title}: {item.url}" for item in results)
                    )
                for result in results:
                    # The page budget is shared across all queries and counts
                    # every distinct URL attempted, including failed fetches.
                    if len(seen_urls) >= MAX_WEB_PAGES:
                        break
                    if result.url in seen_urls:
                        continue
                    seen_urls.add(result.url)
                    try:
                        page_text = fetch_url(result.url, max_chars=12_000)
                    except Exception as exc:
                        evidence.append(f"Fetch failed for {result.url}: {exc}")
                        continue
                    evidence.append(f"Page: {result.title}\nURL: {result.url}\n{page_text}")
            answer = answer_from_evidence(state["question"], evidence)
            return {"evidence": evidence, "draft_answer": answer}

        return traced_step(trace, "solve_web", run)

    def verify_answer(state: GaiaState) -> dict[str, str]:
        """Have the model check the draft against the evidence, unless pre-verified."""

        def run() -> dict[str, str]:
            # A solver may already have set verified_answer (solve_python_file
            # does after a clean run); keep it as-is.
            if state.get("verified_answer"):
                return {"verified_answer": state["verified_answer"]}
            evidence = _trim_evidence(state.get("evidence", []))
            verified = _invoke_text(
                chat_model,
                GAIA_VERIFY_PROMPT,
                (
                    f"Question:\n{state['question']}\n\n"
                    f"Evidence:\n{evidence}\n\n"
                    f"Draft answer:\n{state.get('draft_answer', '')}"
                ),
            )
            return {"verified_answer": verified}

        return traced_step(trace, "verify_answer", run)

    def normalize_final_answer(state: GaiaState) -> dict[str, str]:
        """Normalize the verified answer (or the draft, if none) into final_answer."""

        def run() -> dict[str, str]:
            answer = state.get("verified_answer") or state.get("draft_answer", "")
            return {"final_answer": normalize_answer(answer)}

        return traced_step(trace, "normalize_final_answer", run)

    # task_type (as emitted by classify_task) -> (solver node name, node callable).
    solvers = {
        "direct": ("solve_direct", solve_direct),
        "compute": ("solve_compute", solve_compute),
        "spreadsheet": ("solve_spreadsheet", solve_spreadsheet),
        "python_file": ("solve_python_file", solve_python_file),
        "audio": ("solve_audio", solve_audio),
        "image": ("solve_image", solve_image),
        "youtube": ("solve_youtube", solve_youtube),
        "web": ("solve_web", solve_web),
    }

    graph.add_node("ingest_task", ingest_task)
    graph.add_node("classify_task", classify_task)
    for node_name, node_fn in solvers.values():
        graph.add_node(node_name, node_fn)
    graph.add_node("verify_answer", verify_answer)
    graph.add_node("normalize_final_answer", normalize_final_answer)

    graph.set_entry_point("ingest_task")
    graph.add_edge("ingest_task", "classify_task")
    graph.add_conditional_edges(
        "classify_task",
        lambda state: state.get("task_type", "web"),
        {task_type: node_name for task_type, (node_name, _) in solvers.items()},
    )
    for node_name, _ in solvers.values():
        graph.add_edge(node_name, "verify_answer")
    graph.add_edge("verify_answer", "normalize_final_answer")
    graph.add_edge("normalize_final_answer", END)
    return graph.compile()
def _invoke_text(chat_model, system_prompt: str, user_prompt: str) -> str:
    """Send one system prompt and one user prompt to ``chat_model``.

    The messages are passed as ``(role, content)`` tuples. The ``content``
    attribute of the model's reply is returned, converted with ``str``.
    """
    conversation = [("system", system_prompt), ("user", user_prompt)]
    reply = chat_model.invoke(conversation)
    return str(reply.content)
def _invoke_image(chat_model, question: str, path: str | Path) -> str:
    """Ask ``chat_model`` a question about the image stored at ``path``.

    The system message is ``GAIA_AGENT_SYSTEM_PROMPT``. The human message has
    two content parts: the question text, and an ``image_url`` part whose URL
    is built by ``image_data_url(path)``. The reply's ``content`` is returned,
    converted with ``str``.
    """
    text_part = {"type": "text", "text": question}
    image_part = {
        "type": "image_url",
        "image_url": {"url": image_data_url(path)},
    }
    conversation = [
        SystemMessage(content=GAIA_AGENT_SYSTEM_PROMPT),
        HumanMessage(content=[text_part, image_part]),
    ]
    reply = chat_model.invoke(conversation)
    return str(reply.content)
def _build_search_queries(chat_model, question: str, max_queries: int = 3) -> list[str]:
    """Ask the model for web-search queries and return at most ``max_queries``.

    The model reply (prompted with ``GAIA_QUERY_PROMPT``) is read one query per
    line. For each line:

    * bullet markers (``-``, ``*``) and numbered markers (``1.``, ``2)``) are
      removed;
    * surrounding whitespace and quotes are stripped;
    * lines of 3 characters or fewer are dropped;
    * exact duplicates are skipped.

    The raw question is then appended as a fallback query (unless already
    present) and the list is cut to ``max_queries``. This means the question
    itself is only searched when the model produced fewer queries than that.

    Args:
        chat_model: Chat model passed through to ``_invoke_text``.
        question: The task question.
        max_queries: Maximum number of queries to return (default 3).

    Returns:
        The queries, in the model's order, followed by the fallback question.
    """
    raw_queries = _invoke_text(
        chat_model,
        GAIA_QUERY_PROMPT,
        f"Question:\n{question}",
    )
    queries: list[str] = []
    for line in raw_queries.splitlines():
        # Only strip genuine list markers. A bare leading number (for example a
        # year in "2023 census results") is part of the query and must survive,
        # which the old `[-*\d.)]+` pattern did not guarantee.
        query = re.sub(r"^\s*(?:[-*]+\s*|\d+[.)]\s+)", "", line).strip()
        query = query.strip("\"'")
        if len(query) > 3 and query not in queries:
            queries.append(query)
    if question not in queries:
        queries.append(question)
    return queries[:max_queries]
def _question_with_evidence(question: str, evidence: list[str]) -> str:
    """Build a user prompt with a ``Question:`` section and an ``Evidence:`` section.

    The evidence snippets are joined and length-limited by ``_trim_evidence``.
    """
    template = "Question:\n{}\n\nEvidence:\n{}"
    return template.format(question, _trim_evidence(evidence))
def _trim_evidence(evidence: list[str], max_chars: int | None = None) -> str:
    """Join evidence snippets into one string, truncating it when too long.

    The snippets are separated by a ``---`` divider line with a blank line on
    each side. If the joined text is longer than the limit, only its first
    ``limit`` characters are kept and a ``[trimmed after N chars]`` marker is
    appended. The result can therefore be slightly longer than the limit.

    Args:
        evidence: Evidence snippets in the order they were collected.
        max_chars: Character limit. When None, ``MAX_EVIDENCE_CHARS`` is used.

    Returns:
        The joined text, possibly truncated. An empty string for no evidence.

    Raises:
        ValueError: If ``max_chars`` is negative.
    """
    limit = MAX_EVIDENCE_CHARS if max_chars is None else max_chars
    if limit < 0:
        raise ValueError(f"max_chars must be non-negative, got {limit}")
    text = "\n\n---\n\n".join(evidence)
    if len(text) <= limit:
        return text
    return f"{text[:limit]}\n\n[trimmed after {limit} chars]"
def _looks_like_computation(question: str) -> bool:
    """Return True if ``question`` contains a phrase typical of compute tasks.

    This is a plain, case-sensitive substring test. ``classify_task`` calls it
    with an already-lowercased question.
    """
    compute_phrases = (
        "given this table",
        "provide the subset",
        "counter-examples",
        "not commutative",
        "calculate",
        "numeric output",
    )
    for phrase in compute_phrases:
        if phrase in question:
            return True
    return False
def _looks_like_direct(question: str) -> bool:
    """Guess whether ``question`` can be answered without tools.

    Returns True in any of these cases:

    * the question has at most 8 spaces (a short question);
    * it looks like reversed text (``_looks_reversed``);
    * it contains one of a few hard-coded marker phrases.
    """
    is_short = question.count(" ") <= 8
    marker_phrases = ("grocery list", "categorizing things", "write the opposite")
    return (
        is_short
        or _looks_reversed(question)
        or any(phrase in question for phrase in marker_phrases)
    )
def _looks_reversed(question: str) -> bool:
    """Heuristically detect a question written backwards.

    The question is split into runs of 4+ lowercase ASCII letters. If there
    are fewer than 3 such runs, the result is False. Otherwise the result is
    True when at least two distinct words from a small set of reversed common
    words (e.g. ``"rewsna"`` for "answer") appear among the runs.
    """
    tokens = re.findall(r"[a-z]{4,}", question)
    if len(tokens) < 3:
        return False
    backwards_words = frozenset(("rewsna", "drow", "etirw", "ecnetnes", "dnatsrednu"))
    hits = sum(1 for word in backwards_words if word in tokens)
    return hits >= 2