| """GAIA agent: smolagents CodeAgent + HF Inference Providers backend. |
| |
| Pipeline per question: |
| 1. Eager-download `/files/{task_id}` so attachments are never missed. |
| 2. Run the CodeAgent N times (self-consistency). |
| 3. Majority-vote across normalized candidate answers. |
| 4. Verifier LLM pass: re-asks the model to pick the best answer in the |
| canonical GAIA `FINAL ANSWER:` format given the candidates. |
| 5. Final string normalization for exact-match. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import re |
| from collections import Counter |
| from typing import Optional |
|
|
| from huggingface_hub import InferenceClient |
| from smolagents import CodeAgent, InferenceClientModel |
|
|
| import config |
| from tools import ( |
| analyze_image, |
| download_task_file, |
| read_table, |
| read_webpage, |
| transcribe_audio, |
| web_search, |
| wikipedia_search, |
| youtube_transcript, |
| ) |
|
|
# Task prompt prepended to every agent run: it states the GAIA answer format
# and tells the model which tool to reach for per input type.
# NOTE: the dash in the "Image:" playbook line was a mis-encoded character in
# the previous revision; it is restored to a real em dash here.
SYSTEM_PROMPT = """You are a general AI assistant solving questions from the GAIA benchmark.

I will ask you a question. Think step by step using the tools available. When you
are confident, finish your answer with the following template:

FINAL ANSWER: [YOUR FINAL ANSWER]

YOUR FINAL ANSWER must obey these rules exactly:
- It is a number, OR as few words as possible, OR a comma separated list of
numbers and/or strings.
- If a number: digits only. No commas inside numbers. No units ($, %, kg, etc.)
unless the question explicitly asks for them.
- If a string: no articles (a, an, the), no abbreviations (write "Saint Louis"
not "St. Louis", "New York City" not "NYC"), and digits in plain text unless
the question asks for numerals.
- If a comma separated list: items separated by ", " (comma + single space),
in the order requested, each item following the number/string rules above.
- Do NOT include explanations, units, quotes, or trailing punctuation in the
FINAL ANSWER line itself.

Tool playbook:
- If the question mentions an attached file ("the attached", "this Excel /
CSV / audio / image / Python file", "the file"), call
`download_task_file(task_id)` FIRST to get a local path.
- Spreadsheets / CSV: `read_table(path)` then pandas in Python.
- Audio: `transcribe_audio(path)`.
- Image: `analyze_image(path, question)` — phrase the question precisely.
- YouTube link: `youtube_transcript(url)`.
- Open web: `web_search(query)` then `read_webpage(url)` on the best hit.
- Encyclopedic facts: `wikipedia_search(topic)`.
- Use Python freely for math, parsing, sorting, set operations, etc.

Self-check before answering:
- Did you answer the literal question asked? (Not a related one.)
- Does the format match the rules above?
- If unsure, do one more verification search.
"""
|
|
# System message for the verifier pass: the model is shown the question plus
# the candidate answers and must reply with one line in the GAIA
# "FINAL ANSWER:" template.  Built from adjacent literals, one per prompt line.
VERIFIER_SYSTEM = (
    "You are a strict GAIA answer formatter.\n"
    "\n"
    "You will see a GAIA question and several candidate answers produced by an agent.\n"
    "Pick the single best answer and reformat it to obey the GAIA answer format:\n"
    "\n"
    "- A number, a short string, or a comma-separated list.\n"
    "- Numbers: digits only, no commas, no units unless the question requires them.\n"
    "- Strings: no articles, no abbreviations, plain-text digits unless asked.\n"
    "- Comma-separated lists: \", \" between items, in the order requested.\n"
    "- No explanations. No quotes. No trailing punctuation.\n"
    "\n"
    "Respond on a SINGLE LINE in exactly this template:\n"
    "FINAL ANSWER: <the answer>\n"
)
|
|
|
|
class GaiaAgent:
    """CodeAgent + self-consistency + verifier.

    Calling an instance with a question runs the agent ``config.SELF_CONSISTENCY_N``
    times, majority-votes the normalized answers, asks a verifier LLM to pick
    and reformat the best candidate, and returns the normalized result.
    """

    # Captures the remainder of the line following a "FINAL ANSWER:" /
    # "final answer -" marker (case-insensitive).  `\s*` may cross a newline,
    # so an answer placed on the line after the marker is still captured.
    _FINAL_ANSWER_RE = re.compile(r"final answer\s*[:\-]\s*(.+)", re.IGNORECASE)
    # Leading / trailing markdown emphasis or code markers.
    _MARKDOWN_EDGE_RE = re.compile(r"^[*_`]+|[*_`]+$")
    # Runs of spaces/tabs (newlines are deliberately left alone).
    _SPACE_RUN_RE = re.compile(r"[ \t]+")
    # Characters recognised as a wrapping quote: ASCII quotes, backtick and
    # the curly double/single quotes (written as escapes so the source file
    # encoding cannot corrupt them).
    _QUOTE_CHARS = "\"'`\u201c\u201d\u2018\u2019"
    # Opening curly quote -> its matching closing quote.
    _CLOSING_QUOTE = {"\u201c": "\u201d", "\u2018": "\u2019"}
    # Prefixes of the strings `download_task_file` is checked against to
    # detect "no usable file" results.
    _DOWNLOAD_FAILURE_PREFIXES = ("NO_FILE", "Download error", "Download failed")

    def __init__(self, model_id: Optional[str] = None, provider: Optional[str] = None):
        """Create the model, tool list and verifier client.

        Args:
            model_id: Model to use; defaults to ``config.AGENT_MODEL_ID``.
            provider: Inference provider; defaults to ``config.AGENT_PROVIDER``.

        Raises:
            RuntimeError: If the ``HF_TOKEN`` environment variable is unset or empty.
        """
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a Space secret to use HF Inference."
            )
        self._token = token
        self._model_id = model_id or config.AGENT_MODEL_ID
        self._provider = provider or config.AGENT_PROVIDER
        self.model = InferenceClientModel(
            model_id=self._model_id,
            provider=self._provider,
            token=token,
            max_tokens=config.AGENT_MAX_TOKENS,
        )
        self.tools = [
            web_search,
            read_webpage,
            wikipedia_search,
            youtube_transcript,
            download_task_file,
            read_table,
            transcribe_audio,
            analyze_image,
        ]
        # Separate raw client used only for the verifier chat completion.
        self._verifier = InferenceClient(token=token, provider=self._provider)
        self._n = config.SELF_CONSISTENCY_N

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer one question.

        Args:
            question: The question text.
            task_id: Optional task identifier; when given, the task's attached
                file is downloaded up front and its path is put in the prompt.

        Returns:
            The normalized final answer, or a string starting with
            ``"AGENT ERROR"`` when every attempt failed.
        """
        attachment_line = self._attachment_note(task_id)

        # Self-consistency: run the agent several times, keeping usable answers.
        candidates: list[str] = []
        last_error: Optional[str] = None
        for attempt in range(1, self._n + 1):
            try:
                raw = self._single_run(question, task_id, attachment_line)
            except Exception as e:  # one failed attempt must not abort the rest
                last_error = f"{type(e).__name__}: {e}"
                print(f" attempt {attempt} errored: {last_error}")
                continue
            norm = self._normalize(str(raw))
            print(f" attempt {attempt}: {norm!r}")
            if norm and not norm.startswith("AGENT ERROR"):
                candidates.append(norm)

        if not candidates:
            return f"AGENT ERROR: all {self._n} attempts failed. Last error: {last_error}"

        voted = self._vote(candidates)

        try:
            verified = self._verify(question, candidates, voted)
        except Exception as e:  # verifier is best-effort; the vote is the fallback
            print(f" verifier errored, falling back to vote: {e}")
            verified = voted

        # An empty verifier reply (or one that normalizes to nothing) must not
        # replace a real majority answer with an empty string.
        return self._normalize(verified) or voted

    def _attachment_note(self, task_id: Optional[str]) -> str:
        """Eagerly download the task attachment and return a prompt line for it.

        Returns an empty string when there is no ``task_id``, the download
        raised, or the result is empty / starts with a known failure prefix.
        """
        if not task_id:
            return ""
        try:
            path = download_task_file(task_id)
        except Exception as e:  # best-effort: the agent can still retry via its tool
            print(f" pre-download skipped: {e}")
            return ""
        if not path or str(path).startswith(self._DOWNLOAD_FAILURE_PREFIXES):
            return ""
        return (
            f"An attached file for this task has already been "
            f"downloaded to: {path}\nUse the right tool on this "
            f"path (read_table for spreadsheets/CSV, "
            f"transcribe_audio for audio, analyze_image for "
            f"images, or open it with Python).\n"
        )

    def _build_agent(self) -> "CodeAgent":
        """Build a fresh ``CodeAgent`` with this instance's tools and model.

        ``planning_interval`` is passed only when ``config.PLANNING_INTERVAL``
        is positive; if the constructor rejects it with ``TypeError`` the agent
        is built without it.
        """
        kwargs = dict(
            tools=self.tools,
            model=self.model,
            max_steps=config.MAX_STEPS,
            additional_authorized_imports=[
                "pandas", "numpy", "json", "re", "math", "statistics",
                "itertools", "collections", "datetime", "csv", "io",
                "pathlib", "string", "base64", "urllib", "unicodedata",
            ],
        )
        if config.PLANNING_INTERVAL > 0:
            try:
                return CodeAgent(planning_interval=config.PLANNING_INTERVAL, **kwargs)
            except TypeError as e:
                # Presumably an installed smolagents version without this
                # keyword; fall through to the plain constructor.
                print(f" planning_interval rejected, building agent without it: {e}")
        return CodeAgent(**kwargs)

    def _single_run(
        self,
        question: str,
        task_id: Optional[str],
        attachment_line: str,
    ) -> str:
        """Run one fresh agent on the full prompt and return its raw result.

        The prompt is ``SYSTEM_PROMPT`` followed by the task id (if any), the
        attachment note (if any) and the question.
        """
        agent = self._build_agent()
        parts = [SYSTEM_PROMPT]
        if task_id:
            parts.append(f"\nThe current task_id is: {task_id}\n")
        if attachment_line:
            parts.append(attachment_line)
        parts.append(f"\nQuestion:\n{question}\n")
        return agent.run("".join(parts))

    @staticmethod
    def _vote(candidates: list[str]) -> str:
        """Majority-vote with case-insensitive bucketing; ties go to the first seen.

        Args:
            candidates: Non-empty list of candidate answers.

        Returns:
            The first candidate (original spelling) of the largest bucket.

        Raises:
            IndexError: If ``candidates`` is empty.
        """
        tally = Counter(c.lower().strip() for c in candidates)
        # most_common orders equal counts by first occurrence, which gives the
        # "first seen wins" tie-break.
        winner_key = tally.most_common(1)[0][0]
        return next(c for c in candidates if c.lower().strip() == winner_key)

    def _verify(
        self,
        question: str,
        candidates: list[str],
        voted: str,
    ) -> str:
        """Ask the LLM to choose + reformat the best candidate.

        Args:
            question: The original question.
            candidates: Normalized answers from the independent attempts.
            voted: The majority-vote winner, shown to the model as a hint.

        Returns:
            The text after the ``FINAL ANSWER:`` marker in the reply, the whole
            reply if the marker is absent, or ``voted`` if the reply is empty.
        """
        cand_block = "\n".join(f"- {c}" for c in candidates)
        user = (
            f"Question:\n{question}\n\n"
            f"Candidate answers from independent attempts:\n{cand_block}\n\n"
            f"Most common candidate: {voted}\n\n"
            "Return the best final answer in the canonical GAIA format."
        )
        resp = self._verifier.chat.completions.create(
            model=self._model_id,
            messages=[
                {"role": "system", "content": VERIFIER_SYSTEM},
                {"role": "user", "content": user},
            ],
            max_tokens=256,
            temperature=0.0,
        )
        text = (resp.choices[0].message.content or "").strip()
        if not text:
            # Nothing usable came back; keep the majority answer.
            return voted
        m = GaiaAgent._FINAL_ANSWER_RE.search(text)
        return m.group(1).strip() if m else text

    @staticmethod
    def _normalize(text: str) -> str:
        """Strip every common LLM cruft pattern so exact-match has a chance.

        Steps, in order: keep only the first line after a ``FINAL ANSWER:``
        marker (text without the marker is kept whole), drop leading/trailing
        markdown markers, remove one pair of wrapping quotes (identical
        characters or a matched curly pair), drop trailing periods from short
        answers (fewer than six spaces), and collapse runs of spaces/tabs.

        Returns:
            The cleaned answer; ``""`` for empty or ``None`` input.
        """
        if not text:
            return ""
        t = str(text).strip()

        # If the model leaked "FINAL ANSWER: ...", keep the first line after it.
        m = GaiaAgent._FINAL_ANSWER_RE.search(t)
        if m:
            t = m.group(1).strip()

        t = GaiaAgent._MARKDOWN_EDGE_RE.sub("", t).strip()

        # Wrapping quotes: accept the same character at both ends, or an
        # opening curly quote closed by its matching closing quote.
        if len(t) >= 2 and t[0] in GaiaAgent._QUOTE_CHARS:
            if t[-1] == t[0] or t[-1] == GaiaAgent._CLOSING_QUOTE.get(t[0]):
                t = t[1:-1].strip()

        # Only short answers lose the trailing period; longer text is treated
        # as a sentence and left as written.
        if t.count(" ") < 6:
            t = t.rstrip(".")

        return GaiaAgent._SPACE_RUN_RE.sub(" ", t).strip()
|
|