Spaces:
Sleeping
Sleeping
| """GAIA agent: smolagents CodeAgent + HF Inference Providers backend.""" | |
| from __future__ import annotations | |
| import os | |
| import re | |
| from typing import Optional | |
| from smolagents import CodeAgent, InferenceClientModel | |
| from tools import ( | |
| analyze_image, | |
| download_task_file, | |
| read_table, | |
| read_webpage, | |
| transcribe_audio, | |
| web_search, | |
| wikipedia_search, | |
| youtube_transcript, | |
| ) | |
# Default model and inference provider for GaiaAgent. Both can be overridden
# per instance via the constructor, or globally via the AGENT_MODEL_ID and
# AGENT_PROVIDER environment variables.
DEFAULT_MODEL_ID, DEFAULT_PROVIDER = "Qwen/Qwen2.5-72B-Instruct", "auto"
# Task preamble prepended to every question. The answer-format rules
# paraphrase the GAIA answer-format spec. Answers are presumably scored by
# (near-)exact match, so formatting slips cost points even when the reasoning
# is right. The tool playbook after the rules is specific to the tools that
# GaiaAgent registers. A CodeAgent returns whatever is passed to its
# `final_answer` tool, so the prompt also says what to pass there.
SYSTEM_PROMPT = """You are a general AI assistant solving questions from the GAIA benchmark.
I will ask you a question. Think step by step using the tools available. When you
are confident, finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER]
When you return the answer through the `final_answer` tool, pass it only
YOUR FINAL ANSWER itself, without the "FINAL ANSWER:" prefix.
YOUR FINAL ANSWER must obey these rules exactly:
- It is a number, OR as few words as possible, OR a comma separated list of
numbers and/or strings.
- If a number: digits only. No commas inside numbers. No units ($, %, kg, etc.)
unless the question explicitly asks for them.
- If a string: no articles (a, an, the), no abbreviations (write "Saint Louis"
not "St. Louis", "New York City" not "NYC"), and digits in plain text unless
the question asks for numerals.
- If a comma separated list: items separated by ", " (comma + single space),
in the order requested, each item following the number/string rules above.
- Do NOT include explanations, unrequested units, quotes, or trailing punctuation in the
FINAL ANSWER line itself.
Tool playbook:
- If the question mentions an attached file ("the attached", "this Excel /
CSV / audio / image / Python file", "the file"), call
`download_task_file(task_id)` FIRST to get a local path.
- Spreadsheets / CSV: `read_table(path)` then pandas in Python.
- Audio: `transcribe_audio(path)`.
- Image: `analyze_image(path, question)` — phrase the question precisely.
- YouTube link: `youtube_transcript(url)`.
- Open web: `web_search(query)` then `read_webpage(url)` on the best hit.
- Encyclopedic facts: `wikipedia_search(topic)`.
- Use Python freely for math, parsing, sorting, set operations, etc.
Self-check before answering:
- Did you answer the literal question asked? (Not a related one.)
- Does the format match the rules above?
- If unsure, do one more verification search.
"""
class GaiaAgent:
    """Answer GAIA questions with a smolagents CodeAgent backed by HF Inference.

    The model client and tool list are created once in ``__init__``. A fresh
    ``CodeAgent`` is built for every question, so no agent memory carries over
    between tasks. Call the instance with a question (and optionally its
    ``task_id``) to get a normalized answer string.
    """

    def __init__(self, model_id: Optional[str] = None, provider: Optional[str] = None):
        """Create the inference model client and register the agent's tools.

        Args:
            model_id: Model to use. Falls back to the ``AGENT_MODEL_ID``
                environment variable, then ``DEFAULT_MODEL_ID``.
            provider: Inference provider. Falls back to ``AGENT_PROVIDER``,
                then ``DEFAULT_PROVIDER``.

        Raises:
            RuntimeError: if ``HF_TOKEN`` is not set in the environment.
        """
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError(
                "HF_TOKEN is not set. Add it as a Space secret to use HF Inference."
            )
        self.model = InferenceClientModel(
            model_id=model_id or os.getenv("AGENT_MODEL_ID", DEFAULT_MODEL_ID),
            provider=provider or os.getenv("AGENT_PROVIDER", DEFAULT_PROVIDER),
            token=token,
            max_tokens=4096,
        )
        self.tools = [
            web_search,
            read_webpage,
            wikipedia_search,
            youtube_transcript,
            download_task_file,
            read_table,
            transcribe_audio,
            analyze_image,
        ]

    def _build_agent(self) -> "CodeAgent":
        """Return a new CodeAgent that shares this instance's model and tools."""
        # Fresh agent per question — keeps memory clean between tasks.
        kwargs = dict(
            tools=self.tools,
            model=self.model,
            max_steps=15,
            additional_authorized_imports=[
                "pandas", "numpy", "json", "re", "math", "statistics",
                "itertools", "collections", "datetime", "csv", "io",
                "pathlib", "string", "base64", "urllib", "unicodedata",
            ],
        )
        # Some smolagents versions may not accept ``planning_interval``. If the
        # constructor rejects it, build the agent without periodic planning.
        try:
            return CodeAgent(planning_interval=4, **kwargs)
        except TypeError:
            return CodeAgent(**kwargs)

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Run a fresh agent on *question* and return the normalized answer.

        When ``task_id`` is given, it is added to the prompt. Its attachment
        (if any) is downloaded up front so the prompt can point the agent at
        the local path. If ``agent.run`` raises, the error is returned as an
        ``"AGENT ERROR: ..."`` string instead of propagating.
        """
        agent = self._build_agent()
        prompt = SYSTEM_PROMPT
        if task_id:
            prompt += f"\nThe current task_id is: {task_id}\n"
            # Eager-download the attachment so the agent never misses it.
            try:
                path = download_task_file(task_id)
                # Judging by these prefixes, download_task_file reports "no
                # attachment" and download failures as sentinel strings.
                if path and not path.startswith(
                    ("NO_FILE", "Download error", "Download failed")
                ):
                    prompt += (
                        f"An attached file for this task has already been "
                        f"downloaded to: {path}\n"
                        f"Use the right tool on this path (read_table for "
                        f"spreadsheets/CSV, transcribe_audio for audio, "
                        f"analyze_image for images, or open it with Python).\n"
                    )
            except Exception as e:
                # Best effort only: the prompt's playbook still tells the agent
                # to call download_task_file itself when a file is mentioned.
                print(f" pre-download skipped: {e}")
        prompt += f"\nQuestion:\n{question}\n"
        try:
            raw = agent.run(prompt)
        except Exception as e:
            # Returned rather than raised, presumably so a batch over many
            # tasks keeps going when one fails.
            return f"AGENT ERROR: {e}"
        return self._normalize(self._answer_to_text(raw))

    @staticmethod
    def _answer_to_text(raw: object) -> str:
        """Render the value returned by ``agent.run`` as answer text.

        ``agent.run`` can return a non-string value. Lists and tuples are
        joined with ", " to match the answer list format, instead of becoming
        Python's ``['a', 'b']`` repr. Everything else goes through ``str``.
        """
        if isinstance(raw, (list, tuple)):
            return ", ".join(str(item).strip() for item in raw)
        return str(raw)

    @staticmethod
    def _normalize(text: str) -> str:
        """Strip common LLM cruft so exact-match scoring has a chance.

        Steps, in order:

        1. If the text contains ``FINAL ANSWER: X`` (case-insensitive, with
           ``:`` or ``-``), keep only X from the *last* such marker that has
           content after it, cut at the first newline.
        2. Strip wrapping markdown bold/italic/code markers (``*``, ``_``,
           backtick).
        3. Strip one pair of wrapping quotes, including curly “…” / ‘…’ pairs.
        4. Collapse runs of spaces/tabs into a single space.
        5. Drop a single trailing period when the answer is short (fewer than
           six spaces, i.e. not a sentence).

        Falsy input is returned unchanged.
        """
        if not text:
            return text
        t = str(text).strip()
        # 1) Prefer the LAST marker. Earlier "final answer:" mentions are
        #    usually part of the model's reasoning, not the answer itself.
        tail = None
        for m in re.finditer(r"final answer\s*[:\-]\s*", t, flags=re.IGNORECASE):
            if m.end() < len(t):
                tail = t[m.end():]
        if tail is not None:
            # Trim at first newline (drops trailing explanation / markdown).
            t = tail.split("\n")[0].strip()
        # 2) Strip wrapping markdown bold/italic/code.
        t = re.sub(r"^[*_`]+|[*_`]+$", "", t).strip()
        # 3) Strip wrapping quotes. Curly quotes open and close with different
        #    characters, so accept the matching closer as well as the opener.
        curly_closer = {"“": "”", "‘": "’"}
        if (
            len(t) >= 2
            and t[0] in "\"'`“”‘’"
            and t[-1] in (t[0], curly_closer.get(t[0]))
        ):
            t = t[1:-1].strip()
        # 4) Collapse internal whitespace first, so the word-count check below
        #    is not skewed by repeated spaces.
        t = re.sub(r"[ \t]+", " ", t).strip()
        # 5) Drop ONE trailing period if the answer is short (not a sentence).
        if t.count(" ") < 6 and t.endswith("."):
            t = t[:-1].rstrip()
        return t