nehaMfiles's picture
Update app.py
3fe8e7c verified
Raw
History Blame
13.2 kB
"""
Final Assignment — GAIA agent for the HF Agents Course (Unit 4).
Built on smolagents. It:
- fetches the filtered GAIA questions,
- downloads any attached file and extracts its text,
- runs a CodeAgent with web search / webpage / Wikipedia tools,
- prompts with the OFFICIAL GAIA system prompt and extracts the text
after "FINAL ANSWER:",
- submits the bare answer to the course /submit API,
- also writes an official GAIA json-lines file (model_answer +
reasoning_trace) you can download and upload to the real leaderboard.
Set these as *Space secrets* (Settings -> Variables and secrets):
- HF_TOKEN (always needed; raises your inference rate limit)
- MODEL_PROVIDER "hf" (default) or "litellm"
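- MODEL_ID e.g. "Qwen/Qwen2.5-Coder-32B-Instruct" (hf, the default)
  or "gpt-4o" / "anthropic/claude-sonnet-4-5" (litellm)
- HF_INFERENCE_PROVIDER optional HF inference provider, e.g. "together" (hf only)
- LITELLM_API_KEY only if MODEL_PROVIDER=litellm (OPENAI_API_KEY is used
  as a fallback when it is not set)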
A GPT-4-level model follows the format prompt easily and scores much higher.
"""
import os
import io
import re
import json
import tempfile
import requests
import pandas as pd
import gradio as gr
from smolagents import (
CodeAgent,
InferenceClientModel,
LiteLLMModel,
DuckDuckGoSearchTool,
WikipediaSearchTool,
VisitWebpageTool,
)
# Base URL of the course scoring API; the /questions, /files/<task_id> and
# /submit endpoints used below are all relative to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Output path of the GAIA json-lines submission file. It is relative, so it is
# written into the process's current working directory.
JSONL_PATH = "gaia_submission.jsonl"
# Official GAIA system prompt (from the paper / leaderboard).
# Reproduced as published, including its wording (e.g. "depending of"). Editing
# it would make the prompt diverge from the official one, so leave it as-is.
GAIA_SYSTEM_PROMPT = (
"You are a general AI assistant. I will ask you a question. Report your "
"thoughts, and finish your answer with the following template: FINAL ANSWER: "
"[YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as "
"possible OR a comma separated list of numbers and/or strings. If you are asked "
"for a number, don't use comma to write your number neither use units such as $ "
"or percent sign unless specified otherwise. If you are asked for a string, "
"don't use articles, neither abbreviations (e.g. for cities), and write the "
"digits in plain text unless specified otherwise. If you are asked for a comma "
"separated list, apply the above rules depending of whether the element to be "
"put in the list is a number or a string."
)
# --------------------------------------------------------------------------- #
# Model selection
# --------------------------------------------------------------------------- #
def build_model():
    """Build the LLM backend for the agent from environment variables.

    Environment variables read:
        MODEL_PROVIDER: "hf" (default) or "litellm", case-insensitive.
        MODEL_ID: model identifier; defaults to
            "Qwen/Qwen2.5-Coder-32B-Instruct".
        LITELLM_API_KEY / OPENAI_API_KEY: API key for the litellm backend
            (the first one that is set wins).
        HF_INFERENCE_PROVIDER: optional inference provider for the hf backend.
        HF_TOKEN: optional Hugging Face token for the hf backend.

    Returns:
        A ``LiteLLMModel`` or an ``InferenceClientModel``, both at
        temperature 0.0.

    Raises:
        ValueError: if MODEL_PROVIDER names an unsupported backend.
    """
    # ``or`` instead of a getenv default also covers variables that exist but
    # are empty, which is easy to end up with when editing Space secrets.
    provider = (os.getenv("MODEL_PROVIDER") or "hf").strip().lower()
    model_id = (os.getenv("MODEL_ID") or "").strip() or "Qwen/Qwen2.5-Coder-32B-Instruct"
    if provider not in ("hf", "litellm"):
        # Fail fast: silently falling back to the HF backend would send e.g. a
        # "gpt-4o" model id to the HF Inference API and fail on every question.
        raise ValueError(
            f"Unsupported MODEL_PROVIDER {provider!r}; expected 'hf' or 'litellm'."
        )
    if provider == "litellm":
        return LiteLLMModel(
            model_id=model_id,
            api_key=os.getenv("LITELLM_API_KEY") or os.getenv("OPENAI_API_KEY"),
            temperature=0.0,
        )
    kwargs = {"model_id": model_id, "temperature": 0.0}
    hf_provider = (os.getenv("HF_INFERENCE_PROVIDER") or "").strip()  # e.g. "together", "sambanova"
    if hf_provider:
        kwargs["provider"] = hf_provider
    token = os.getenv("HF_TOKEN")
    if token:
        kwargs["token"] = token
    return InferenceClientModel(**kwargs)
# --------------------------------------------------------------------------- #
# File handling: download a task's attachment and extract usable text
# --------------------------------------------------------------------------- #
def _save_attachment(file_name: str, data: bytes) -> str:
    """Write *data* into a fresh temporary directory and return the file path.

    ``file_name`` comes from the remote API, so only its last path component
    is used; names such as ``../../x`` cannot escape the temp directory. Each
    call gets its own directory, so same-named attachments from different
    tasks never overwrite each other.
    """
    safe_name = os.path.basename(file_name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "attachment"
    path = os.path.join(tempfile.mkdtemp(prefix="gaia_"), safe_name)
    with open(path, "wb") as f:
        f.write(data)
    return path


def fetch_file_text(api_url: str, task_id: str, file_name: str) -> str:
    """Download a task's attachment and render it as prompt-ready text.

    The attachment is fetched from ``{api_url}/files/{task_id}``. Its
    extension decides how it is handled:

    - txt/py/md/json/xml/csv/tsv are decoded as UTF-8; CSV/TSV are rendered
      as tables.
    - Excel workbooks are rendered sheet by sheet.
    - PDFs and Word documents are reduced to their text.
    - Anything else is saved to a temporary file, and the returned text tells
      the agent where it is.

    Args:
        api_url: Base URL of the scoring API.
        task_id: GAIA task id whose attachment should be downloaded.
        file_name: Attachment name reported by the API.

    Returns:
        Text to append to the agent prompt. Download and parse failures come
        back as a bracketed message instead of raising, so one bad attachment
        does not abort the whole run.
    """
    try:
        response = requests.get(f"{api_url}/files/{task_id}", timeout=60)
        response.raise_for_status()
    except Exception as e:
        return f"[Could not download attached file '{file_name}': {e}]"
    data = response.content
    ext = file_name.lower().rsplit(".", 1)[-1] if "." in file_name else ""
    try:
        if ext in ("txt", "py", "md", "json", "xml", "csv", "tsv"):
            text = data.decode("utf-8", errors="replace")
            if ext == "csv":
                df = pd.read_csv(io.StringIO(text))
                return f"CSV file '{file_name}' content:\n{df.to_string()}"
            if ext == "tsv":
                df = pd.read_csv(io.StringIO(text), sep="\t")
                return f"TSV file '{file_name}' content:\n{df.to_string()}"
            return f"File '{file_name}' content:\n{text}"
        if ext in ("xlsx", "xls"):
            # sheet_name=None loads every sheet as a {name: DataFrame} dict.
            sheets = pd.read_excel(io.BytesIO(data), sheet_name=None)
            parts = [f"Excel file '{file_name}':"]
            for name, df in sheets.items():
                parts.append(f"--- sheet: {name} ---\n{df.to_string()}")
            return "\n".join(parts)
        if ext == "pdf":
            import pdfplumber  # optional dependency, only needed for PDFs

            with pdfplumber.open(io.BytesIO(data)) as pdf:
                pages = [p.extract_text() or "" for p in pdf.pages]
            return f"PDF file '{file_name}' text:\n" + "\n".join(pages)
        if ext == "docx":
            import docx  # python-docx, only needed for Word files

            # python-docx reads file-like objects, so no temp file is needed.
            doc = docx.Document(io.BytesIO(data))
            return f"Word file '{file_name}':\n" + "\n".join(
                p.text for p in doc.paragraphs
            )
        path = _save_attachment(file_name, data)
        return (
            f"[A file named '{file_name}' is attached and saved locally at '{path}'. "
            f"Use your tools / Python to inspect it if the question needs it.]"
        )
    except Exception as e:
        return f"[Attached file '{file_name}' could not be parsed: {e}]"
# --------------------------------------------------------------------------- #
# Answer extraction / normalization
# --------------------------------------------------------------------------- #
def extract_answer(raw: str) -> str:
    """Pull the bare answer out of the agent's output and normalize it.

    Steps:

    1. Take the text after the last "FINAL ANSWER:" marker, matched
       case-insensitively and tolerating bold markup such as
       "**FINAL ANSWER**:", if a marker is present.
    2. Keep only the first line.
    3. Repeatedly remove wrapping quotes, square brackets, backticks or
       ``**`` bold markers, and drop a trailing period unless the answer
       consists only of digits and dots.

    A wrapper is left in place when the text is really a list of wrapped
    items, e.g. ``"a", "b"``.

    Args:
        raw: Agent output; non-string values are converted with ``str()``.

    Returns:
        The normalized answer string (possibly empty).
    """
    text = str(raw).strip()
    matches = list(re.finditer(r"final answer\s*\**\s*:", text, flags=re.IGNORECASE))
    if matches:
        text = text[matches[-1].end():].strip()
        # "**FINAL ANSWER:** 42" leaves a stray "**" in front of the answer;
        # a bold answer like "**42**" is not touched here.
        text = re.sub(r"^\*+(?:\s+|$)", "", text)
    # The answer should be a single line; keep the first one.
    text = text.splitlines()[0].strip() if text else text
    wrappers = (("**", "**"), ('"', '"'), ("'", "'"), ("`", "`"), ("[", "]"))
    previous = None
    # Loop until stable so combinations such as '"Paris".' are fully cleaned;
    # every change shortens the text, so this terminates.
    while text != previous:
        previous = text
        for opening, closing in wrappers:
            if (
                len(text) >= len(opening) + len(closing)
                and text.startswith(opening)
                and text.endswith(closing)
            ):
                inner = text[len(opening):len(text) - len(closing)]
                # '"a", "b"' or '[a], [b]' is a list of wrapped items, not
                # one wrapped answer; unwrapping it would corrupt the list.
                if closing + "," not in inner:
                    text = inner.strip()
                    break
        # Drop a trailing period unless it is part of a number.
        if text.endswith(".") and not re.fullmatch(r"[\d.]+", text):
            text = text[:-1].strip()
    return text
# --------------------------------------------------------------------------- #
# The agent
# --------------------------------------------------------------------------- #
class GaiaAgent:
    """Answer GAIA questions with a smolagents ``CodeAgent``.

    Calling an instance runs the agent on one question, plus its attachment
    if any, and returns ``(answer, reasoning_trace)``. The answer is
    normalized by ``extract_answer``; the trace is a compact transcript
    rebuilt from the agent's memory.
    """

    def __init__(self, api_url: str = DEFAULT_API_URL):
        """Build the model, the tool list and the underlying CodeAgent.

        Args:
            api_url: Base URL of the scoring API; used to download attachments.
        """
        self.api_url = api_url
        model = build_model()
        tools = [
            DuckDuckGoSearchTool(),
            VisitWebpageTool(),
            WikipediaSearchTool(user_agent="GAIA-course-agent (student@example.com)"),
        ]
        self.agent = CodeAgent(
            tools=tools,
            model=model,
            # Adds smolagents' default toolbox on top of ``tools`` (originally
            # noted as python interpreter + transcriber). The exact set depends
            # on the installed smolagents version, so verify before relying on it.
            add_base_tools=True,
            additional_authorized_imports=[
                "pandas", "numpy", "math", "statistics",
                "json", "re", "datetime", "itertools",
            ],
            max_steps=10,
            verbosity_level=1,
        )
        print("GaiaAgent ready.")

    def _reasoning_trace(self) -> str:
        """Rebuild a compact trace from the agent's memory.

        Joins each memory step's ``model_output`` and its ``observations``.
        Each observation is clipped to 400 characters, and the whole trace to
        6000 characters. Returns "" if the memory cannot be read.
        """
        try:
            lines = []
            for step in getattr(self.agent.memory, "steps", []):
                out = getattr(step, "model_output", None)
                if out:
                    lines.append(str(out).strip())
                obs = getattr(step, "observations", None)
                if obs:
                    lines.append("Observation: " + str(obs).strip()[:400])
            return "\n".join(lines)[:6000]
        except Exception:
            # Best effort: failing to build a trace must never cost the answer.
            return ""

    def __call__(self, question: str, task_id: str = "", file_name: str = ""):
        """Answer one question.

        Args:
            question: The GAIA question text.
            task_id: Task id; used to download the attachment and in logs.
            file_name: Attachment name; empty when the task has no file.

        Returns:
            ``(answer, reasoning_trace)``. If the agent raises, the answer is
            "unknown" and the trace holds the steps recorded so far, followed
            by the error message.
        """
        prompt = (
            GAIA_SYSTEM_PROMPT
            + "\n\nWhen you call final_answer, pass ONLY the value that should "
            "follow 'FINAL ANSWER:', formatted by the rules above.\n\nQUESTION:\n"
            + question
        )
        if file_name:
            prompt += "\n\n" + fetch_file_text(self.api_url, task_id, file_name)
        try:
            result = self.agent.run(prompt)
            return extract_answer(result), self._reasoning_trace()
        except Exception as e:
            print(f"Agent error on task {task_id}: {e}")
            # Keep the steps that did run, so the jsonl record shows where the
            # attempt went wrong rather than only the final exception.
            # NOTE(review): assumes agent.run() clears memory when it starts
            # (smolagents' reset default), so these steps belong to this task.
            # Confirm for the installed version.
            partial = self._reasoning_trace()
            error_line = f"error: {e}"
            trace = f"{partial}\n{error_line}" if partial else error_line
            return "unknown", trace
# --------------------------------------------------------------------------- #
# Fetch -> run -> submit
# --------------------------------------------------------------------------- #
def _describe_request_error(e: Exception) -> str:
    """Return *e* as text, plus the HTTP status and server detail if available.

    The scoring API reports failures in its response body (a "detail" field
    when the body is a JSON object). ``str(HTTPError)`` alone drops that body.
    """
    response = getattr(e, "response", None)
    # Compare with None explicitly: a requests.Response is falsy for 4xx/5xx.
    if response is None:
        return str(e)
    try:
        detail = response.json().get("detail", response.text)
    except (ValueError, AttributeError):  # body is not JSON, or not an object
        detail = response.text
    return f"{e} (status {response.status_code}; detail: {str(detail)[:500]})"


def _write_jsonl(path: str, records: list) -> str | None:
    """Write *records* to *path* as JSON lines (UTF-8, non-ASCII kept as-is).

    Returns the path on success. On failure it prints the error and returns
    ``None``, so the run continues without a downloadable file.
    """
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    except (OSError, TypeError, ValueError) as e:
        print(f"Could not write jsonl: {e}")
        return None
    return path


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Answer every course question with GaiaAgent and submit the answers.

    Gradio fills ``profile`` with the logged-in user's OAuth profile. The
    function then:

    1. Fetches the questions.
    2. Runs the agent on each question.
    3. Writes the GAIA json-lines file (JSONL_PATH).
    4. Posts the answers to the course ``/submit`` endpoint.

    Returns:
        ``(status_message, results_dataframe_or_None, jsonl_path_or_None)``
        for the three Gradio outputs.
    """
    space_id = os.getenv("SPACE_ID")
    if not profile:
        return "Please Login to Hugging Face with the button.", None, None
    username = profile.username
    print(f"User logged in: {username}")

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    try:
        agent = GaiaAgent(api_url)
    except Exception as e:
        return f"Error initializing agent: {e}", None, None
    agent_code = (
        f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "local"
    )

    try:
        resp = requests.get(questions_url, timeout=30)
        resp.raise_for_status()
        questions = resp.json()
    except Exception as e:
        return f"Error fetching questions: {_describe_request_error(e)}", None, None
    # Anything but a list (e.g. an error object) would crash the loop below.
    if not isinstance(questions, list):
        return (
            f"Unexpected questions payload: expected a list, "
            f"got {type(questions).__name__}.",
            None,
            None,
        )
    if not questions:
        return "Fetched questions list is empty.", None, None

    results_log = []
    answers_payload = []  # for the course /submit API
    jsonl_records = []  # for the official GAIA leaderboard file
    for item in questions:
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        question = item.get("question")
        file_name = item.get("file_name", "") or ""
        if not task_id or question is None:
            continue
        print(f"Running task {task_id} ...")
        answer, trace = agent(question, task_id, file_name)
        answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        jsonl_records.append(
            {"task_id": task_id, "model_answer": answer, "reasoning_trace": trace}
        )
        results_log.append(
            {"Task ID": task_id, "Question": question, "Submitted Answer": answer}
        )

    # Write the official GAIA json-lines file for download.
    jsonl_file = _write_jsonl(JSONL_PATH, jsonl_records)
    results_df = pd.DataFrame(results_log)
    if not answers_payload:
        return "Agent produced no answers.", results_df, jsonl_file

    submission = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    try:
        resp = requests.post(submit_url, json=submission, timeout=120)
        resp.raise_for_status()
        data = resp.json()
        status = (
            f"Submission Successful!\n"
            f"User: {data.get('username')}\n"
            f"Score: {data.get('score', 'N/A')}% "
            f"({data.get('correct_count', '?')}/"
            f"{data.get('total_attempted', '?')} correct)\n"
            f"Message: {data.get('message', '')}"
        )
        return status, results_df, jsonl_file
    except Exception as e:
        return f"Submission Failed: {_describe_request_error(e)}", results_df, jsonl_file
# --------------------------------------------------------------------------- #
# Gradio UI
# --------------------------------------------------------------------------- #
def _build_demo() -> gr.Blocks:
    """Lay out the Gradio page and connect the run button to the pipeline.

    The page contains a title, usage notes and the Hugging Face login button.
    It also has a run button wired to ``run_and_submit_all``, whose three
    outputs are a status box, a results table and the jsonl download.
    """
    with gr.Blocks() as ui:
        gr.Markdown("# GAIA Agent — Final Assignment")
        gr.Markdown(
            "1. Log in with Hugging Face below.\n"
            "2. Click **Run Evaluation & Submit All Answers**.\n\n"
            "This submits to the course leaderboard AND produces a "
            "`gaia_submission.jsonl` file in the official GAIA format for download. "
            "Running all questions can take several minutes."
        )
        gr.LoginButton()
        trigger = gr.Button("Run Evaluation & Submit All Answers")
        # Order matters: it must match the (status, table, file) tuple that
        # run_and_submit_all returns.
        outputs = [
            gr.Textbox(
                label="Run Status / Submission Result", lines=5, interactive=False
            ),
            gr.DataFrame(label="Questions and Agent Answers", wrap=True),
            gr.File(label="Official GAIA submission (.jsonl)"),
        ]
        trigger.click(fn=run_and_submit_all, outputs=outputs)
    return ui


demo = _build_demo()

if __name__ == "__main__":
    demo.launch(share=False, debug=True)