Raj
Upload 2 files
91f0524 verified
Raw
History Blame Contribute Delete
6.48 kB
"""GAIA agent: smolagents CodeAgent + HF Inference Providers backend."""
from __future__ import annotations
import os
import re
from typing import Optional
from smolagents import CodeAgent, InferenceClientModel
from tools import (
analyze_image,
download_task_file,
read_table,
read_webpage,
transcribe_audio,
web_search,
wikipedia_search,
youtube_transcript,
)
# Default model for the InferenceClientModel backend. Qwen2.5-72B was chosen
# over Llama-3.3 as the stronger tool-using / code-writing model for GAIA.
# Override at runtime with the AGENT_MODEL_ID environment variable.
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-72B-Instruct"
# Inference provider handed to InferenceClientModel; override with the
# AGENT_PROVIDER environment variable. NOTE(review): "auto" presumably lets
# the HF client pick a provider — confirm against the huggingface_hub docs.
DEFAULT_PROVIDER = "auto"
# Task preamble that GaiaAgent prepends to every question. The answer-format
# rules follow the canonical GAIA system prompt; the "Tool playbook" and
# "Self-check" sections are additions describing this agent's own tools.
# Answers are compared by exact match, so deviating from the format is a common
# cause of "right reasoning, wrong score". Any edit here changes agent behavior.
SYSTEM_PROMPT = """You are a general AI assistant solving questions from the GAIA benchmark.
I will ask you a question. Think step by step using the tools available. When you
are confident, finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER]
YOUR FINAL ANSWER must obey these rules exactly:
- It is a number, OR as few words as possible, OR a comma separated list of
numbers and/or strings.
- If a number: digits only. No commas inside numbers. No units ($, %, kg, etc.)
unless the question explicitly asks for them.
- If a string: no articles (a, an, the), no abbreviations (write "Saint Louis"
not "St. Louis", "New York City" not "NYC"), and digits in plain text unless
the question asks for numerals.
- If a comma separated list: items separated by ", " (comma + single space),
in the order requested, each item following the number/string rules above.
- Do NOT include explanations, units, quotes, or trailing punctuation in the
FINAL ANSWER line itself.
Tool playbook:
- If the question mentions an attached file ("the attached", "this Excel /
CSV / audio / image / Python file", "the file"), call
`download_task_file(task_id)` FIRST to get a local path.
- Spreadsheets / CSV: `read_table(path)` then pandas in Python.
- Audio: `transcribe_audio(path)`.
- Image: `analyze_image(path, question)` — phrase the question precisely.
- YouTube link: `youtube_transcript(url)`.
- Open web: `web_search(query)` then `read_webpage(url)` on the best hit.
- Encyclopedic facts: `wikipedia_search(topic)`.
- Use Python freely for math, parsing, sorting, set operations, etc.
Self-check before answering:
- Did you answer the literal question asked? (Not a related one.)
- Does the format match the rules above?
- If unsure, do one more verification search.
"""
class GaiaAgent:
    """Callable GAIA solver built on a smolagents ``CodeAgent``.

    The inference model and the tool list are created once in ``__init__``;
    every call builds a fresh ``CodeAgent`` so no agent memory carries over
    from one question to the next.
    """

    # Every "final answer:" / "final answer -" marker, case-insensitive. The
    # trailing ``\s*`` also swallows newlines, so an answer written on the line
    # after the marker is still picked up.
    _FINAL_ANSWER_RE = re.compile(r"final answer\s*[:\-]\s*", re.IGNORECASE)
    # Markdown emphasis / inline-code characters hugging either end of the text.
    _MARKDOWN_EDGE_RE = re.compile(r"^[*_`]+|[*_`]+$")
    # (opening, closing) quote pairs to unwrap. Straight quotes and backticks
    # open and close with the same character; typographic quotes open and
    # close with different characters (“…”, ‘…’). Same-character curly
    # wrapping (”…”) is accepted too, for backward compatibility.
    _QUOTE_PAIRS = frozenset({
        ('"', '"'), ("'", "'"), ("`", "`"),
        ("“", "”"), ("“", "“"), ("”", "”"),
        ("‘", "’"), ("‘", "‘"), ("’", "’"),
    })

    def __init__(self, model_id: Optional[str] = None, provider: Optional[str] = None):
        """Create the inference model and the tool list.

        Args:
            model_id: Model to use. Falls back to the AGENT_MODEL_ID env var,
                then to DEFAULT_MODEL_ID (an empty env var counts as unset).
            provider: Inference provider. Falls back to the AGENT_PROVIDER env
                var, then to DEFAULT_PROVIDER (an empty env var counts as unset).

        Raises:
            RuntimeError: If the HF_TOKEN environment variable is unset or empty.
        """
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a Space secret to use HF Inference."
            )
        self.model = InferenceClientModel(
            model_id=model_id or os.getenv("AGENT_MODEL_ID") or DEFAULT_MODEL_ID,
            provider=provider or os.getenv("AGENT_PROVIDER") or DEFAULT_PROVIDER,
            token=token,
            max_tokens=4096,
        )
        self.tools = [
            web_search,
            read_webpage,
            wikipedia_search,
            youtube_transcript,
            download_task_file,
            read_table,
            transcribe_audio,
            analyze_image,
        ]

    def _build_agent(self) -> "CodeAgent":
        """Build a fresh CodeAgent for a single question.

        A new agent per question keeps agent memory clean between tasks.
        Planning every 4 steps is requested when the installed smolagents
        accepts ``planning_interval``.
        """
        kwargs = dict(
            tools=self.tools,
            model=self.model,
            max_steps=15,
            # Extra modules the agent's generated Python code may import.
            additional_authorized_imports=[
                "pandas", "numpy", "json", "re", "math", "statistics",
                "itertools", "collections", "datetime", "csv", "io",
                "pathlib", "string", "base64", "urllib", "unicodedata",
            ],
        )
        # Older smolagents releases reject planning_interval with a TypeError;
        # fall back to building the agent without it.
        try:
            return CodeAgent(planning_interval=4, **kwargs)
        except TypeError:
            return CodeAgent(**kwargs)

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer ``question`` and return a normalized answer string.

        Args:
            question: The GAIA question text.
            task_id: Optional GAIA task id. When given, it is included in the
                prompt and used to pre-download any attached file.

        Returns:
            The normalized answer, or ``"AGENT ERROR: <message>"`` if the agent
            run raised.
        """
        agent = self._build_agent()
        prompt = self._build_prompt(question, task_id)
        try:
            raw = agent.run(prompt)
        except Exception as e:
            # Report the failure as this task's answer instead of raising.
            return f"AGENT ERROR: {e}"
        return self._normalize(self._stringify(raw))

    @staticmethod
    def _build_prompt(question: str, task_id: Optional[str]) -> str:
        """Assemble SYSTEM_PROMPT, the task id, an attachment hint and the question."""
        prompt = SYSTEM_PROMPT
        if task_id:
            prompt += f"\nThe current task_id is: {task_id}\n"
            prompt += GaiaAgent._attachment_note(task_id)
        prompt += f"\nQuestion:\n{question}\n"
        return prompt

    @staticmethod
    def _attachment_note(task_id: str) -> str:
        """Eagerly download the task's attachment and describe where it is.

        Best effort: returns "" when there is no file, when the download tool
        reports an error string, or when the download raises.
        """
        # Eager download so the agent cannot forget to fetch the attachment.
        try:
            path = download_task_file(task_id)
        except Exception as e:
            print(f" pre-download skipped: {e}")
            return ""
        if not path:
            return ""
        path = str(path)
        # download_task_file signals "no file" / failure via these prefixes.
        if path.startswith(("NO_FILE", "Download error", "Download failed")):
            return ""
        return (
            f"An attached file for this task has already been "
            f"downloaded to: {path}\n"
            f"Use the right tool on this path (read_table for "
            f"spreadsheets/CSV, transcribe_audio for audio, "
            f"analyze_image for images, or open it with Python).\n"
        )

    @staticmethod
    def _stringify(raw: object) -> str:
        """Render an agent result as text.

        A list/tuple result is joined with ", ", the list format the prompt
        requires, instead of using its Python repr. Anything else goes
        through ``str``.
        """
        if isinstance(raw, (list, tuple)):
            return ", ".join(str(item).strip() for item in raw)
        return str(raw)

    @staticmethod
    def _normalize(text: str) -> str:
        """Strip common LLM formatting cruft so exact-match scoring has a chance.

        Steps:
          1. If "FINAL ANSWER:" markers are present, keep only the first line
             after the LAST marker that has text after it, so the model's
             final word wins over earlier drafts.
          2. Collapse runs of spaces/tabs to a single space.
          3. Repeatedly peel wrapping markdown (``*``, ``_``, backticks),
             wrapping quotes (including typographic “…” / ‘…’), and, for
             short answers (fewer than 6 spaces), trailing periods, until
             nothing changes. Repeating unwraps nested cruft such as
             ``**"X"**.`` completely.

        Args:
            text: Raw agent output.

        Returns:
            The cleaned answer. Falsy input is returned unchanged.
        """
        if not text:
            return text
        t = str(text).strip()

        # 1) Prefer the last "FINAL ANSWER" marker that is followed by text.
        for marker in reversed(list(GaiaAgent._FINAL_ANSWER_RE.finditer(t))):
            candidate = t[marker.end():].split("\n", 1)[0].strip()
            if candidate:
                t = candidate
                break

        # 2) Collapse horizontal whitespace first, so the space count below
        #    is not inflated by runs of spaces.
        t = re.sub(r"[ \t]+", " ", t).strip()

        # 3) Peel wrappers to a fixed point. Every step only removes
        #    characters, so the loop always terminates.
        previous = None
        while t != previous:
            previous = t
            t = GaiaAgent._MARKDOWN_EDGE_RE.sub("", t).strip()
            if len(t) >= 2 and (t[0], t[-1]) in GaiaAgent._QUOTE_PAIRS:
                t = t[1:-1].strip()
            # Only short answers lose trailing periods; longer text may be a
            # genuine sentence.
            if t.count(" ") < 6:
                t = t.rstrip(".").rstrip()
        return t