File size: 6,273 Bytes
10e9b7d 1fc7c71 10e9b7d eccf8e4 3c4371f 1fc7c71 10e9b7d e80aab9 3db6293 e80aab9 1fc7c71 31243f4 1fc7c71 7d65c66 3c4371f 1fc7c71 e80aab9 1fc7c71 31243f4 1fc7c71 31243f4 3c4371f 1fc7c71 eccf8e4 1fc7c71 31243f4 7d65c66 1fc7c71 e80aab9 1fc7c71 7d65c66 1fc7c71 31243f4 1fc7c71 31243f4 1fc7c71 31243f4 1fc7c71 31243f4 1fc7c71 e80aab9 1fc7c71 e80aab9 1fc7c71 e80aab9 1fc7c71 e80aab9 1fc7c71 7d65c66 1fc7c71 e80aab9 1fc7c71 e80aab9 1fc7c71 0ee0419 1fc7c71 e80aab9 7e4a06b 1fc7c71 e80aab9 1fc7c71 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 | import os
import io
import gradio as gr
import requests
import pandas as pd
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
LiteLLMModel,
Tool,
tool,
)
# --- Constants ---
# Base URL of the scoring service. The /files/{task_id}, /questions and /submit
# endpoints requested elsewhere in this file are all built from it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# --- Custom Tool: Read task files from GAIA API ---
class TaskFileReaderTool(Tool):
    """Agent tool that downloads the file attached to a task and returns it as text.

    The file is fetched from ``{DEFAULT_API_URL}/files/{task_id}`` and converted
    to a string according to the response's ``Content-Type`` header.
    """

    name = "task_file_reader"
    description = (
        "Downloads and reads a file attached to a GAIA task by its task_id. "
        "Use this when the question mentions an attached file, document, spreadsheet, or image."
    )
    inputs = {
        "task_id": {
            "type": "string",
            "description": "The task_id to download the file for.",
        }
    }
    output_type = "string"

    def forward(self, task_id: str) -> str:
        """Download the attachment for *task_id* and render it as a string.

        Args:
            task_id: Identifier of the task whose attached file should be read.

        Returns:
            - For text/JSON/CSV responses: the first 10000 characters of the body.
            - For spreadsheet/Excel responses: the first sheet rendered with
              ``DataFrame.to_string()``.
            - For any other type: the first 10000 characters if the payload is
              valid UTF-8, otherwise a ``[Binary file, ...]`` placeholder.
            - On any failure (network, HTTP status, parsing): an
              ``Error downloading file ...`` message. This method never raises.
        """
        max_chars = 10000  # cap on raw text handed back to the agent
        try:
            response = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=30)
            response.raise_for_status()
            content_type = response.headers.get("Content-Type", "")
            # MIME types are case-insensitive, so match on a lowered copy.
            kind = content_type.lower()

            if "text" in kind or "json" in kind or "csv" in kind:
                return response.text[:max_chars]

            if "spreadsheet" in kind or "excel" in kind:
                sheet = pd.read_excel(io.BytesIO(response.content))
                return sheet.to_string()

            # Unknown type. `response.text` decodes with replacement characters
            # and does not fail on binary data, so it cannot be used to tell
            # text from binary; a strict decode of the raw bytes can.
            try:
                return response.content.decode("utf-8")[:max_chars]
            except UnicodeDecodeError:
                return f"[Binary file, {len(response.content)} bytes, type: {content_type}]"
        except Exception as e:
            # Best-effort tool: report the problem to the agent instead of raising.
            return f"Error downloading file for task {task_id}: {e}"
# --- Agent Definition ---
class GAIAAgent:
def __init__(self):
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("Set ANTHROPIC_API_KEY env var")
model = LiteLLMModel(
model_id="anthropic/claude-sonnet-4-20250514",
api_key=api_key,
)
self.agent = CodeAgent(
tools=[DuckDuckGoSearchTool(), TaskFileReaderTool()],
model=model,
max_steps=8,
verbosity_level=1,
additional_authorized_imports=[
"re", "json", "math", "collections",
"itertools", "statistics", "unicodedata",
],
)
print("GAIAAgent initialized with Claude Sonnet.")
def __call__(self, question: str, task_id: str = None) -> str:
prompt = (
f"Question: {question}\n\n"
f"INSTRUCTIONS:\n"
f"- If the question references an attached file, use task_file_reader with task_id='{task_id}'.\n"
f"- Use web_search to find factual information when needed.\n"
f"- Give ONLY the exact final answer. No explanation, no 'The answer is', no extra words.\n"
f"- For numbers: just the number. For names: just the name. For lists: comma-separated.\n"
)
try:
result = self.agent.run(prompt)
answer = str(result).strip()
for prefix in ["The answer is ", "Answer: ", "FINAL ANSWER: ", "Final answer: "]:
if answer.lower().startswith(prefix.lower()):
answer = answer[len(prefix):].strip()
return answer
except Exception as e:
print(f"Agent error: {e}")
return "Unable to determine answer"
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Run a GAIAAgent over every fetched question and submit the answers.

    Args:
        profile: The logged-in user's profile, or ``None`` when nobody is
            logged in.

    Returns:
        A ``(status_message, results)`` pair. ``results`` is ``None`` when the
        run stops before any question is attempted (no login, agent init
        failure, question fetch failure); otherwise it is a DataFrame with the
        columns "Task ID", "Question" and "Submitted Answer".
    """
    # Anonymous visitors get a prompt to log in and nothing else happens.
    if not profile:
        return "Please Login to Hugging Face with the button.", None

    username = profile.username
    space_id = os.getenv("SPACE_ID")
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    # Step 1: build the agent.
    try:
        agent = GAIAAgent()
    except Exception as init_err:
        return f"Error initializing agent: {init_err}", None

    # Step 2: download the question list.
    try:
        questions_resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
        questions_resp.raise_for_status()
        questions_data = questions_resp.json()
        total = len(questions_data)
        print(f"Fetched {total} questions.")
    except Exception as fetch_err:
        return f"Error fetching questions: {fetch_err}", None

    # Step 3: answer each question, recording both the payload and a display log.
    results_log = []
    answers_payload = []
    for position, item in enumerate(questions_data, start=1):
        task_id, question_text = item.get("task_id"), item.get("question")
        # Entries lacking an id or a question cannot be answered or submitted.
        if question_text is None or not task_id:
            continue

        print(f"\n--- Q{position}/{total} [{task_id}] ---")
        print(f"Q: {question_text[:120]}")
        try:
            answer = agent(question_text, task_id=task_id)
            print(f"A: {answer}")
        except Exception as run_err:
            answer = f"ERROR: {run_err}"
            print(f"Error: {run_err}")

        answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        results_log.append(
            {"Task ID": task_id, "Question": question_text, "Submitted Answer": answer}
        )

    if not answers_payload:
        return "No answers produced.", pd.DataFrame(results_log)

    # Step 4: submit everything in one request and summarise the response.
    submission = dict(
        username=username.strip(),
        agent_code=agent_code,
        answers=answers_payload,
    )
    try:
        submit_resp = requests.post(f"{DEFAULT_API_URL}/submit", json=submission, timeout=120)
        submit_resp.raise_for_status()
        data = submit_resp.json()
        summary_lines = [
            "Submission Successful!",
            f"User: {data.get('username')}",
            f"Score: {data.get('score', 'N/A')}% "
            f"({data.get('correct_count', '?')}/{data.get('total_attempted', '?')} correct)",
            f"Message: {data.get('message', '')}",
        ]
        return "\n".join(summary_lines), pd.DataFrame(results_log)
    except Exception as submit_err:
        return f"Submission Failed: {submit_err}", pd.DataFrame(results_log)
# --- Gradio UI ---
# Components are created in display order inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent — smolagents + Claude Sonnet")
    # Step 2 quotes the button label exactly so users can find the button.
    gr.Markdown(
        "1. Log in with HuggingFace\n"
        "2. Click 'Run Evaluation & Submit All Answers'\n"
        "3. Wait for the agent to answer all 20 questions"
    )
    gr.LoginButton()
    run_btn = gr.Button("Run Evaluation & Submit All Answers")
    status_box = gr.Textbox(label="Status", lines=5, interactive=False)
    results_tbl = gr.DataFrame(label="Results", wrap=True)
    # No explicit inputs: per Gradio's OAuth docs, the handler's
    # `gr.OAuthProfile | None` parameter is filled in from the login session.
    run_btn.click(fn=run_and_submit_all, outputs=[status_box, results_tbl])

if __name__ == "__main__":
    demo.launch(debug=True, share=False)
|