"""GAIA agent: smolagents CodeAgent + HF Inference Providers backend.

Pipeline per question:
  1. Eager-download `/files/{task_id}` so attachments are never missed.
  2. Run the CodeAgent N times (self-consistency).
  3. Majority-vote across normalized candidate answers.
  4. Verifier LLM pass: re-asks the model to pick the best answer in the
     canonical GAIA `FINAL ANSWER:` format given the candidates.
  5. Final string normalization for exact-match.
"""

from __future__ import annotations

import os
import re
from collections import Counter
from typing import Optional

from huggingface_hub import InferenceClient
from smolagents import CodeAgent, InferenceClientModel

import config
from tools import (
    analyze_image,
    download_task_file,
    read_table,
    read_webpage,
    transcribe_audio,
    web_search,
    wikipedia_search,
    youtube_transcript,
)

SYSTEM_PROMPT = """You are a general AI assistant solving questions from the GAIA benchmark.
I will ask you a question. Think step by step using the tools available.
When you are confident, finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER]

YOUR FINAL ANSWER must obey these rules exactly:
- It is a number, OR as few words as possible, OR a comma separated list of numbers and/or strings.
- If a number: digits only. No commas inside numbers. No units ($, %, kg, etc.) unless the question explicitly asks for them.
- If a string: no articles (a, an, the), no abbreviations (write "Saint Louis" not "St. Louis", "New York City" not "NYC"), and digits in plain text unless the question asks for numerals.
- If a comma separated list: items separated by ", " (comma + single space), in the order requested, each item following the number/string rules above.
- Do NOT include explanations, units, quotes, or trailing punctuation in the FINAL ANSWER line itself.

Tool playbook:
- If the question mentions an attached file ("the attached", "this Excel / CSV / audio / image / Python file", "the file"), call `download_task_file(task_id)` FIRST to get a local path.
- Spreadsheets / CSV: `read_table(path)` then pandas in Python.
- Audio: `transcribe_audio(path)`.
- Image: `analyze_image(path, question)` — phrase the question precisely.
- YouTube link: `youtube_transcript(url)`.
- Open web: `web_search(query)` then `read_webpage(url)` on the best hit.
- Encyclopedic facts: `wikipedia_search(topic)`.
- Use Python freely for math, parsing, sorting, set operations, etc.

Self-check before answering:
- Did you answer the literal question asked? (Not a related one.)
- Does the format match the rules above?
- If unsure, do one more verification search.
"""

VERIFIER_SYSTEM = """You are a strict GAIA answer formatter.
You will see a GAIA question and several candidate answers produced by an agent.
Pick the single best answer and reformat it to obey the GAIA answer format:
- A number, a short string, or a comma-separated list.
- Numbers: digits only, no commas, no units unless the question requires them.
- Strings: no articles, no abbreviations, plain-text digits unless asked.
- Comma-separated lists: ", " between items, in the order requested.
- No explanations. No quotes. No trailing punctuation.

Respond on a SINGLE LINE in exactly this template:
FINAL ANSWER:
"""

# Extra modules the CodeAgent's generated Python is allowed to import.
_AUTHORIZED_IMPORTS = (
    "pandas",
    "numpy",
    "json",
    "re",
    "math",
    "statistics",
    "itertools",
    "collections",
    "datetime",
    "csv",
    "io",
    "pathlib",
    "string",
    "base64",
    "urllib",
    "unicodedata",
)

# Return-value prefixes from download_task_file that are treated as
# "no usable attachment" rather than as a local file path.
_DOWNLOAD_FAILURE_PREFIXES = ("NO_FILE", "Download error", "Download failed")

# A "FINAL ANSWER:" marker; tolerates markdown emphasis between the words and
# the separator, e.g. "**FINAL ANSWER**: 42".
_FINAL_ANSWER_RE = re.compile(r"final\s+answer[\s*_]*[:\-]", re.IGNORECASE)

# Characters that count as wrapping quotes when the same one opens and closes.
_QUOTE_CHARS = "\"'`“”‘’"
# Curly quotes whose opening and closing characters differ.
_SMART_QUOTE_PAIRS = {("“", "”"), ("‘", "’")}


def _extract_final_answer(text: str) -> Optional[str]:
    """Return the answer after the last non-empty "FINAL ANSWER:" marker.

    Only the first line following the marker is kept, and the text is cut
    off at the next marker. Returns None when no marker is followed by any
    content.
    """
    matches = list(_FINAL_ANSWER_RE.finditer(text))
    stops = [m.start() for m in matches[1:]] + [len(text)]
    # Prefer the last marker: models often mention the template in their
    # reasoning before emitting the real answer.
    for match, stop in reversed(list(zip(matches, stops))):
        segment = text[match.end():stop].strip()
        if segment:
            return segment.split("\n", 1)[0].strip()
    return None


class GaiaAgent:
    """CodeAgent + self-consistency + verifier."""

    def __init__(self, model_id: Optional[str] = None, provider: Optional[str] = None):
        """Build the agent model, tool list and verifier client.

        Args:
            model_id: Model to use; defaults to ``config.AGENT_MODEL_ID``.
            provider: Inference provider; defaults to ``config.AGENT_PROVIDER``.

        Raises:
            RuntimeError: if the ``HF_TOKEN`` environment variable is not set.
        """
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a Space secret to use HF Inference."
            )
        self._token = token
        self._model_id = model_id or config.AGENT_MODEL_ID
        self._provider = provider or config.AGENT_PROVIDER
        self.model = InferenceClientModel(
            model_id=self._model_id,
            provider=self._provider,
            token=token,
            max_tokens=config.AGENT_MAX_TOKENS,
        )
        self.tools = [
            web_search,
            read_webpage,
            wikipedia_search,
            youtube_transcript,
            download_task_file,
            read_table,
            transcribe_audio,
            analyze_image,
        ]
        self._verifier = InferenceClient(token=token, provider=self._provider)
        self._n = config.SELF_CONSISTENCY_N

    # -------------------- public --------------------

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer one GAIA question and return the normalized answer string.

        Returns an ``"AGENT ERROR: ..."`` string (instead of raising) when no
        attempt produced a usable answer.
        """
        # 1) Pre-download attachment.
        attachment_line = self._attachment_hint(task_id)

        # 2) Self-consistency: N independent CodeAgent runs.
        candidates, last_error = self._collect_candidates(
            question, task_id, attachment_line
        )
        if not candidates:
            return f"AGENT ERROR: all {self._n} attempts failed. Last error: {last_error}"

        # 3) Majority vote on normalized strings (case-insensitive bucket).
        voted = self._vote(candidates)

        # 4) Verifier pass — pick + reformat using the LLM.
        try:
            verified = self._verify(question, candidates, voted)
        except Exception as e:  # best-effort: any verifier failure falls back to the vote
            print(f" verifier errored, falling back to vote: {e}")
            verified = voted

        # 5) An empty/unusable verifier reply must not discard the voted answer.
        return self._normalize(verified) or self._normalize(voted)

    # -------------------- helpers --------------------

    def _attachment_hint(self, task_id: Optional[str]) -> str:
        """Pre-download the task's attachment and return a prompt hint.

        Returns "" when there is no task_id, the download reports no usable
        file, or the download raises (best-effort: the agent can still call
        ``download_task_file`` itself).
        """
        if not task_id:
            return ""
        try:
            path = download_task_file(task_id)
            if not path or path.startswith(_DOWNLOAD_FAILURE_PREFIXES):
                return ""
        except Exception as e:
            print(f" pre-download skipped: {e}")
            return ""
        return (
            f"An attached file for this task has already been "
            f"downloaded to: {path}\nUse the right tool on this "
            f"path (read_table for spreadsheets/CSV, "
            f"transcribe_audio for audio, analyze_image for "
            f"images, or open it with Python).\n"
        )

    def _collect_candidates(
        self,
        question: str,
        task_id: Optional[str],
        attachment_line: str,
    ) -> tuple[list[str], Optional[str]]:
        """Run the agent ``self._n`` times.

        Returns:
            (normalized non-error answers in attempt order,
             description of the last exception raised, or None).
        """
        candidates: list[str] = []
        last_error: Optional[str] = None
        for i in range(self._n):
            try:
                raw = self._single_run(question, task_id, attachment_line)
            except Exception as e:  # one failed attempt must not abort the rest
                last_error = f"{type(e).__name__}: {e}"
                print(f" attempt {i + 1} errored: {last_error}")
                continue
            # agent.run may hand back non-str values (hence str()); None means
            # no answer and must not become the literal candidate "None".
            norm = "" if raw is None else self._normalize(str(raw))
            print(f" attempt {i + 1}: {norm!r}")
            if norm and not norm.startswith("AGENT ERROR"):
                candidates.append(norm)
        return candidates, last_error

    def _build_agent(self) -> CodeAgent:
        """Create a new CodeAgent; called once per attempt so runs don't share an agent."""
        kwargs = dict(
            tools=self.tools,
            model=self.model,
            max_steps=config.MAX_STEPS,
            additional_authorized_imports=list(_AUTHORIZED_IMPORTS),
        )
        if config.PLANNING_INTERVAL > 0:
            try:
                return CodeAgent(planning_interval=config.PLANNING_INTERVAL, **kwargs)
            except TypeError:
                # Presumably a smolagents version without `planning_interval`;
                # retry without planning rather than failing the attempt.
                pass
        return CodeAgent(**kwargs)

    def _single_run(
        self,
        question: str,
        task_id: Optional[str],
        attachment_line: str,
    ) -> str:
        """Build a fresh agent, assemble the task prompt, and return ``agent.run``'s result."""
        agent = self._build_agent()
        prompt = SYSTEM_PROMPT
        if task_id:
            prompt += f"\nThe current task_id is: {task_id}\n"
        if attachment_line:
            prompt += attachment_line
        prompt += f"\nQuestion:\n{question}\n"
        return agent.run(prompt)

    @staticmethod
    def _vote(candidates: list[str]) -> str:
        """Majority-vote with case-insensitive bucketing; ties -> first seen.

        Returns the first-seen spelling of the winning bucket, or "" when
        *candidates* is empty.
        """
        if not candidates:
            return ""
        keys = [c.lower().strip() for c in candidates]
        # most_common orders equal counts by first insertion, i.e. first seen.
        winner, _ = Counter(keys).most_common(1)[0]
        return candidates[keys.index(winner)]

    def _verify(
        self,
        question: str,
        candidates: list[str],
        voted: str,
    ) -> str:
        """Ask the LLM to choose + reformat the best candidate.

        Returns the text after the verifier's ``FINAL ANSWER:`` marker. If the
        marker is missing, a single-line reply is returned as-is; a multi-line
        reply is treated as unusable and "" is returned so the caller falls
        back to the vote.
        """
        cand_block = "\n".join(f"- {c}" for c in candidates)
        user = (
            f"Question:\n{question}\n\n"
            f"Candidate answers from independent attempts:\n{cand_block}\n\n"
            f"Most common candidate: {voted}\n\n"
            f"Return the best final answer in the canonical GAIA format."
        )
        resp = self._verifier.chat.completions.create(
            model=self._model_id,
            messages=[
                {"role": "system", "content": VERIFIER_SYSTEM},
                {"role": "user", "content": user},
            ],
            max_tokens=256,
            temperature=0.0,
        )
        text = (resp.choices[0].message.content or "").strip()
        extracted = _extract_final_answer(text)
        if extracted is not None:
            return extracted
        return text if "\n" not in text else ""

    @staticmethod
    def _normalize(text: str) -> str:
        """Strip every common LLM cruft pattern so exact-match has a chance.

        Falsy input is returned unchanged.
        """
        if not text:
            return text
        t = str(text).strip()

        # 1) If the model emitted "FINAL ANSWER: X" anywhere, keep only X
        #    (first line after the last marker).
        extracted = _extract_final_answer(t)
        if extracted is not None:
            t = extracted

        # 2) Strip wrapping markdown bold/italic/code.
        t = re.sub(r"^[*_`]+|[*_`]+$", "", t).strip()

        # 3) Strip one layer of wrapping quotes, incl. curly “…” / ‘…’ pairs.
        if len(t) >= 2 and (
            (t[0] == t[-1] and t[0] in _QUOTE_CHARS)
            or (t[0], t[-1]) in _SMART_QUOTE_PAIRS
        ):
            t = t[1:-1].strip()

        # 4) Short answers (fewer than 6 spaces) are not sentences: drop any
        #    trailing periods.
        if t.count(" ") < 6:
            t = t.rstrip(".")

        # 5) Collapse runs of spaces/tabs.
        t = re.sub(r"[ \t]+", " ", t).strip()
        return t