# Raj · "Upload 2 files" · commit d2c906a (verified) · 10.2 kB
# (header residue from the hosting page's file view: Raw | History | Blame | Contribute | Delete)
"""GAIA agent: smolagents CodeAgent + HF Inference Providers backend.
Pipeline per question:
1. Eager-download `/files/{task_id}` so attachments are never missed.
2. Run the CodeAgent N times (self-consistency).
3. Majority-vote across normalized candidate answers.
4. Verifier LLM pass: re-asks the model to pick the best answer in the
canonical GAIA `FINAL ANSWER:` format given the candidates.
5. Final string normalization for exact-match.
"""
from __future__ import annotations
import os
import re
from collections import Counter
from typing import Optional
from huggingface_hub import InferenceClient
from smolagents import CodeAgent, InferenceClientModel
import config
from tools import (
analyze_image,
download_task_file,
read_table,
read_webpage,
transcribe_audio,
web_search,
wikipedia_search,
youtube_transcript,
)
# Task instructions placed at the start of every CodeAgent prompt
# (GaiaAgent._single_run appends the task id, attachment note and question).
SYSTEM_PROMPT = """You are a general AI assistant solving questions from the GAIA benchmark.
I will ask you a question. Think step by step using the tools available. When you
are confident, finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER]
YOUR FINAL ANSWER must obey these rules exactly:
- It is a number, OR as few words as possible, OR a comma separated list of
numbers and/or strings.
- If a number: digits only. No commas inside numbers. No units ($, %, kg, etc.)
unless the question explicitly asks for them.
- If a string: no articles (a, an, the), no abbreviations (write "Saint Louis"
not "St. Louis", "New York City" not "NYC"), and digits in plain text unless
the question asks for numerals.
- If a comma separated list: items separated by ", " (comma + single space),
in the order requested, each item following the number/string rules above.
- Do NOT include explanations, units, quotes, or trailing punctuation in the
FINAL ANSWER line itself.
Tool playbook:
- If the question mentions an attached file ("the attached", "this Excel /
CSV / audio / image / Python file", "the file"), call
`download_task_file(task_id)` FIRST to get a local path.
- Spreadsheets / CSV: `read_table(path)` then pandas in Python.
- Audio: `transcribe_audio(path)`.
- Image: `analyze_image(path, question)` — phrase the question precisely.
- YouTube link: `youtube_transcript(url)`.
- Open web: `web_search(query)` then `read_webpage(url)` on the best hit.
- Encyclopedic facts: `wikipedia_search(topic)`.
- Use Python freely for math, parsing, sorting, set operations, etc.
Self-check before answering:
- Did you answer the literal question asked? (Not a related one.)
- Does the format match the rules above?
- If unsure, do one more verification search.
"""
# System message for the verifier LLM call in GaiaAgent._verify: it must pick
# one candidate and re-emit it as a single "FINAL ANSWER: ..." line.
VERIFIER_SYSTEM = (
    "You are a strict GAIA answer formatter.\n"
    "You will see a GAIA question and several candidate answers produced by an agent.\n"
    "Pick the single best answer and reformat it to obey the GAIA answer format:\n"
    "- A number, a short string, or a comma-separated list.\n"
    "- Numbers: digits only, no commas, no units unless the question requires them.\n"
    "- Strings: no articles, no abbreviations, plain-text digits unless asked.\n"
    '- Comma-separated lists: ", " between items, in the order requested.\n'
    "- No explanations. No quotes. No trailing punctuation.\n"
    "Respond on a SINGLE LINE in exactly this template:\n"
    "FINAL ANSWER: <the answer>\n"
)
class GaiaAgent:
    """CodeAgent + self-consistency + verifier.

    Each call runs a freshly built smolagents ``CodeAgent`` several times
    (``config.SELF_CONSISTENCY_N``), majority-votes the normalized answers,
    asks a verifier LLM to pick and reformat the best candidate, and returns
    the normalized result.
    """

    def __init__(self, model_id: Optional[str] = None, provider: Optional[str] = None):
        """Create the agent model, the tool list and the verifier client.

        Args:
            model_id: Inference model id; defaults to ``config.AGENT_MODEL_ID``.
            provider: Inference provider; defaults to ``config.AGENT_PROVIDER``.

        Raises:
            RuntimeError: if the ``HF_TOKEN`` environment variable is unset or empty.
        """
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a Space secret to use HF Inference."
            )
        self._token = token
        self._model_id = model_id or config.AGENT_MODEL_ID
        self._provider = provider or config.AGENT_PROVIDER
        self.model = InferenceClientModel(
            model_id=self._model_id,
            provider=self._provider,
            token=token,
            max_tokens=config.AGENT_MAX_TOKENS,
        )
        self.tools = [
            web_search,
            read_webpage,
            wikipedia_search,
            youtube_transcript,
            download_task_file,
            read_table,
            transcribe_audio,
            analyze_image,
        ]
        # The verifier shares the agent's provider; _verify also reuses the
        # agent's model id.
        self._verifier = InferenceClient(token=token, provider=self._provider)
        self._n = config.SELF_CONSISTENCY_N

    # -------------------- public --------------------
    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer a GAIA question.

        Args:
            question: The question text.
            task_id: GAIA task id. When given, the task attachment is
                downloaded up front and the id is included in the prompt.

        Returns:
            The normalized final answer. If every agent attempt failed or
            produced nothing usable, a string starting with ``"AGENT ERROR"``.
        """
        # 1) Pre-download the attachment so the agent never misses it.
        attachment_line = ""
        if task_id:
            try:
                path = download_task_file(task_id)
                # Results with these prefixes are treated as "no usable file".
                if path and not path.startswith(
                    ("NO_FILE", "Download error", "Download failed")
                ):
                    attachment_line = (
                        f"An attached file for this task has already been "
                        f"downloaded to: {path}\nUse the right tool on this "
                        f"path (read_table for spreadsheets/CSV, "
                        f"transcribe_audio for audio, analyze_image for "
                        f"images, or open it with Python).\n"
                    )
            except Exception as e:
                # Best effort: the agent can still call download_task_file itself.
                print(f" pre-download skipped: {e}")

        # 2) Self-consistency: N independent CodeAgent runs.
        candidates: list[str] = []
        last_error: Optional[str] = None
        for i in range(self._n):
            try:
                raw = self._single_run(question, task_id, attachment_line)
            except Exception as e:
                # One failed run must not sink the others; remember why.
                last_error = f"{type(e).__name__}: {e}"
                print(f" attempt {i + 1} errored: {last_error}")
                continue
            norm = self._normalize(str(raw))
            print(f" attempt {i + 1}: {norm!r}")
            if norm and not norm.startswith("AGENT ERROR"):
                candidates.append(norm)
        if not candidates:
            return f"AGENT ERROR: all {self._n} attempts failed. Last error: {last_error}"

        # 3) Majority vote on normalized strings (case-insensitive bucket).
        voted = self._vote(candidates)

        # 4) Verifier pass - pick + reformat using the LLM.
        try:
            verified = self._verify(question, candidates, voted)
        except Exception as e:
            print(f" verifier errored, falling back to vote: {e}")
            verified = voted
        # 5) An empty or blank verifier reply must not discard a valid voted
        # answer. `voted` is already normalized, so it can be returned as is.
        return self._normalize(verified) or voted

    # -------------------- helpers --------------------
    def _build_agent(self) -> CodeAgent:
        """Build a fresh CodeAgent with this instance's model and tools.

        ``planning_interval`` is passed only when ``config.PLANNING_INTERVAL``
        is positive. If the constructor rejects it with ``TypeError`` (for
        example, an smolagents version without that keyword), the agent is
        built without it.
        """
        kwargs = dict(
            tools=self.tools,
            model=self.model,
            max_steps=config.MAX_STEPS,
            additional_authorized_imports=[
                "pandas", "numpy", "json", "re", "math", "statistics",
                "itertools", "collections", "datetime", "csv", "io",
                "pathlib", "string", "base64", "urllib", "unicodedata",
            ],
        )
        if config.PLANNING_INTERVAL > 0:
            try:
                return CodeAgent(planning_interval=config.PLANNING_INTERVAL, **kwargs)
            except TypeError:
                pass
        return CodeAgent(**kwargs)

    def _single_run(
        self,
        question: str,
        task_id: Optional[str],
        attachment_line: str,
    ) -> str:
        """Run one fresh CodeAgent and return its raw final answer.

        The prompt is SYSTEM_PROMPT, then the task id (if any), then the
        attachment note (if any), then the question.
        """
        agent = self._build_agent()
        prompt = SYSTEM_PROMPT
        if task_id:
            prompt += f"\nThe current task_id is: {task_id}\n"
        if attachment_line:
            prompt += attachment_line
        prompt += f"\nQuestion:\n{question}\n"
        return agent.run(prompt)

    @staticmethod
    def _vote(candidates: list[str]) -> str:
        """Majority-vote with case-insensitive bucketing; ties go to the first seen.

        Candidates are bucketed by ``c.lower().strip()``. The largest bucket
        wins, and among equally large buckets the one seen first wins. The
        winning bucket's first-seen member is returned with its original
        casing.

        Raises:
            IndexError: if ``candidates`` is empty.
        """
        def bucket(c: str) -> str:
            return c.lower().strip()

        counts = Counter(bucket(c) for c in candidates)
        # most_common orders equal counts by first encounter, which gives
        # the "ties -> first seen" rule.
        winner, _ = counts.most_common(1)[0]
        return next(c for c in candidates if bucket(c) == winner)

    def _verify(
        self,
        question: str,
        candidates: list[str],
        voted: str,
    ) -> str:
        """Ask the LLM to choose and reformat the best candidate.

        Sends VERIFIER_SYSTEM plus the question, all candidates and the voted
        answer to the agent's model at temperature 0. Returns the text after
        a ``FINAL ANSWER:`` marker on that line, or the whole stripped reply
        when no marker is present. Client errors propagate to the caller.
        """
        cand_block = "\n".join(f"- {c}" for c in candidates)
        user = (
            f"Question:\n{question}\n\n"
            f"Candidate answers from independent attempts:\n{cand_block}\n\n"
            f"Most common candidate: {voted}\n\n"
            f"Return the best final answer in the canonical GAIA format."
        )
        resp = self._verifier.chat.completions.create(
            model=self._model_id,
            messages=[
                {"role": "system", "content": VERIFIER_SYSTEM},
                {"role": "user", "content": user},
            ],
            max_tokens=256,
            temperature=0.0,
        )
        text = (resp.choices[0].message.content or "").strip()
        m = re.search(r"final answer\s*[:\-]\s*(.+)", text, flags=re.IGNORECASE)
        return m.group(1).strip() if m else text

    @staticmethod
    def _normalize(text: str) -> str:
        """Strip common LLM answer cruft so exact-match scoring has a chance.

        Steps:
          1. Keep only what follows a ``FINAL ANSWER:`` marker (if any), and
             only the first line of it.
          2. Repeatedly remove the following until nothing changes:
             - wrapping markdown ``*``/``_``/backtick markers;
             - a matching pair of wrapping quotes, straight or typographic;
             - for short answers (fewer than 6 spaces), a trailing period.
          3. Collapse runs of spaces and tabs.

        Falsy input is returned unchanged.
        """
        if not text:
            return text
        t = str(text).strip()

        # 1) If the model emitted "FINAL ANSWER: X" anywhere, keep only X.
        m = re.search(r"final answer\s*[:\-]\s*(.+)", t, flags=re.IGNORECASE | re.DOTALL)
        if m:
            t = m.group(1).strip()
        t = t.split("\n")[0].strip()

        # 2) Wrappers can nest in either order (e.g. '**42**.' or '"Paris".'),
        # so one fixed pass is not enough. Every change shortens t, so the
        # loop terminates.
        while True:
            before = t
            # Wrapping markdown bold/italic/code.
            t = re.sub(r"^[*_`]+|[*_`]+$", "", t).strip()
            # Wrapping quotes. Typographic pairs use different opening and
            # closing characters, so they need an explicit pair check.
            if len(t) >= 2 and (
                (t[0] == t[-1] and t[0] in "\"'`\u201c\u201d\u2018\u2019")
                or (t[0], t[-1]) in (("\u201c", "\u201d"), ("\u2018", "\u2019"))
            ):
                t = t[1:-1].strip()
            # Trailing period, only when the answer is short (not a sentence).
            if t.count(" ") < 6 and t.endswith("."):
                t = t[:-1].strip()
            if t == before:
                break

        # 3) Collapse internal whitespace.
        return re.sub(r"[ \t]+", " ", t).strip()