# gaia-final-agent / agent_runner.py
# Sanjayyy06's picture
# Update agent_runner.py
# aad078a verified
import os
import re
import json
import time
import logging
import requests
from typing import List, Dict
from dotenv import load_dotenv
from smolagents import (
CodeAgent,
Tool,
DuckDuckGoSearchTool,
FinalAnswerTool,
PythonInterpreterTool,
InferenceClientModel
)
# Load variables from a local .env file (python-dotenv) before the
# os.getenv() lookups below, so values defined there are visible to them.
load_dotenv()
# Base URL of the GAIA scoring service; the /questions, /files/<task_id> and
# /submit endpoints used in this file are built by appending to it.
GAIA_API_BASE = "https://agents-course-unit4-scoring.hf.space"
# Token used as a Bearer credential for the scoring service and as the API key
# for the model; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
# Username reported in the submission payload.
HF_USERNAME = os.getenv("HF_USERNAME", "sanjayelango")
# Value sent as "agent_code" on submission; defaults to
# "<HF_USERNAME>/gaia-final-agent" when SPACE_ID is not set.
SPACE_ID = os.getenv("SPACE_ID", f"{HF_USERNAME}/gaia-final-agent")
# Root logging is configured at import time; this module logs as "gaia_agent".
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gaia_agent")
# Attachment extensions (image and audio formats) that solve_question does not
# download; such tasks get the placeholder text "FILE_SKIPPED" instead.
SKIP_EXT = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".mp3", ".wav", ".ogg"}
# Instruction text embedded under the "SYSTEM:" heading of every prompt that
# solve_question builds. It is runtime text: edits here change agent behavior.
SYSTEM_PROMPT = """
You are a GAIA Level-1 agent.
CRITICAL RULES:
- NEVER try to call tools inside python code.
- Python is ONLY for math, parsing, data processing, and simple computations.
- Search must ONLY be done with DuckDuckGoSearchTool, NOT python.
- Do NOT write code like: DuckDuckGoSearchTool(...), web_search(...), requests.get(...)
- If you need to search → CALL the search tool directly.
- If file missing → FILE_SKIPPED.
- Give exact final answer.
"""
def ext_of(name: str):
    """Return the lower-cased extension of *name*, including the leading dot.

    A falsy *name* (empty string or None) yields "" without touching
    ``os.path``; a name without an extension also yields "".
    """
    if not name:
        return ""
    _, suffix = os.path.splitext(name)
    return suffix.lower()
# ===============================
# Custom FetchFile Tool
# ===============================
class FetchFileTool(Tool):
    """Agent tool that downloads the attachment of a GAIA task as text.

    ``forward`` never raises for network or decoding problems; it reports
    them through sentinel strings so the caller can keep going:

    * ``"FILE_MISSING"`` - the request failed or the status was not 200.
    * ``"BINARY_FILE"``  - the payload is not decodable text.
    """

    name = "fetch_file"
    description = "Fetch GAIA file attachment"
    inputs = {
        "task_id": {
            "type": "string",
            "description": "GAIA task id"
        }
    }
    output_type = "string"

    def forward(self, task_id: str) -> str:
        """Fetch ``{GAIA_API_BASE}/files/{task_id}`` and return its text.

        Args:
            task_id: GAIA task identifier whose attachment is requested.

        Returns:
            The decoded file content, or one of the sentinel strings
            ``"FILE_MISSING"`` / ``"BINARY_FILE"`` described on the class.
        """
        url = f"{GAIA_API_BASE}/files/{task_id}"
        # The Authorization header is only sent when a token is configured.
        headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
        try:
            r = requests.get(url, headers=headers, timeout=10)
        except requests.RequestException as exc:
            # Only transport-level failures are treated as "missing"; a bare
            # ``except`` would also hide KeyboardInterrupt and coding errors.
            logger.warning("Could not fetch file for task %s: %s", task_id, exc)
            return "FILE_MISSING"
        if r.status_code != 200:
            return "FILE_MISSING"

        payload = r.content
        # ``r.text`` never raises (it decodes with replacement characters), so
        # binary data has to be detected explicitly: a NUL byte or a strict
        # decode failure marks the payload as binary.
        if b"\x00" in payload:
            return "BINARY_FILE"
        try:
            return payload.decode(r.encoding or "utf-8")
        except (UnicodeDecodeError, LookupError):
            # LookupError covers a charset name Python does not know.
            return "BINARY_FILE"
# Tool instances are created once at import time and shared by every agent
# that solve_question builds. fetch_file_tool is also called directly there.
fetch_file_tool = FetchFileTool()
search_tool = DuckDuckGoSearchTool()
python_tool = PythonInterpreterTool()
final_tool = FinalAnswerTool()
# Tool list passed as ``tools=`` to the CodeAgent in solve_question.
TOOLS = [search_tool, python_tool, final_tool, fetch_file_tool]
# ===============================
# HF MODEL
# ===============================
def create_model():
    """Build the InferenceClientModel the agent talks to.

    Returns:
        An ``InferenceClientModel`` for ``meta-llama/Llama-3.1-8B-Instruct``
        with temperature 0.0 and a 600-token completion limit, authenticated
        with ``HF_TOKEN``.

    Raises:
        RuntimeError: when ``HF_TOKEN`` is unset or empty.
    """
    if HF_TOKEN:
        settings = {
            "model_id": "meta-llama/Llama-3.1-8B-Instruct",
            "api_key": HF_TOKEN,
            "temperature": 0.0,
            "max_tokens": 600,
        }
        return InferenceClientModel(**settings)
    raise RuntimeError("Missing HF_TOKEN")
# ===============================
# Solve a single question
# ===============================
def _extract_final_answer(result) -> str:
    """Reduce a raw agent result to a single-line answer string.

    The result is converted with ``str()`` first because the agent may hand
    back non-string values (numbers, lists). If the text contains
    "final answer" (case-insensitive), the remainder of that line is
    returned; otherwise the first line of the text is returned.
    """
    text = str(result).strip()
    m = re.search(r"final answer[:\s]*(.+)", text, re.I)
    if m:
        return m.group(1).strip()
    return text.split("\n")[0].strip()


def solve_question(q: Dict) -> str:
    """Run the agent on one GAIA question and return the answer to submit.

    Args:
        q: Question record. ``"task_id"`` and ``"question"`` are required;
            ``"file_name"`` is optional and names the attachment, if any.

    Returns:
        The extracted answer string, or ``"ERROR"`` when the agent run fails
        or produces no result.

    Raises:
        KeyError: if ``q`` lacks ``"task_id"`` or ``"question"``.
        RuntimeError: propagated from ``create_model`` when no token is set.
    """
    task_id = q["task_id"]
    question = q["question"]
    file_name = q.get("file_name", "")
    ext = ext_of(file_name)
    # Image/audio attachments are not downloaded; the prompt only carries a
    # placeholder for them.
    if ext in SKIP_EXT:
        file_text = "FILE_SKIPPED"
    else:
        file_text = fetch_file_tool(task_id=task_id)
    prompt = f"""
SYSTEM:
{SYSTEM_PROMPT}
USER:
Task ID: {task_id}
File Content: {file_text}
Question:
{question}
Provide ONLY the final answer.
"""
    model = create_model()
    # Only keyword arguments accepted by the CodeAgent constructor are passed.
    # The former ``restrict_code_execution`` / ``allow_real_python`` flags are
    # not constructor parameters, so passing them is expected to raise
    # TypeError before the agent can run.
    agent = CodeAgent(
        model=model,
        tools=TOOLS,
        max_steps=6,
        verbosity_level=0
    )
    try:
        result = agent.run(prompt)
        if result is None:
            logger.error("Failed on %s: agent returned no result", task_id)
            return "ERROR"
        return _extract_final_answer(result)
    except Exception as e:
        # Boundary handler: one failing question must not abort the whole
        # submission run, so the failure is logged (with traceback) and
        # replaced by a sentinel answer.
        logger.exception("Failed on %s: %s", task_id, e)
        return "ERROR"
# ===============================
# GAIA API
# ===============================
def get_questions():
    """Download the question list from the GAIA scoring service.

    Returns:
        The decoded JSON body of ``GET {GAIA_API_BASE}/questions``.

    Raises:
        requests.HTTPError: when the service answers with an error status.
    """
    endpoint = f"{GAIA_API_BASE}/questions"
    response = requests.get(endpoint, timeout=20)
    response.raise_for_status()
    return response.json()
def submit_answers(answers):
    """Post the collected answers to the GAIA scoring service.

    Args:
        answers: List of ``{"task_id": ..., "submitted_answer": ...}`` dicts.

    Returns:
        The decoded JSON body of the ``POST {GAIA_API_BASE}/submit`` response.

    Raises:
        requests.HTTPError: when the service answers with an error status.
        requests.Timeout: when the service does not respond in time.
    """
    payload = {
        "username": HF_USERNAME,
        "agent_code": SPACE_ID,
        "answers": answers,
    }
    headers = {"Content-Type": "application/json"}
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"
    # requests has no default timeout; without one a stalled connection would
    # block forever after all questions were already solved. The limit is
    # generous because the service does work on the submission before replying.
    r = requests.post(
        f"{GAIA_API_BASE}/submit",
        json=payload,
        headers=headers,
        timeout=60,
    )
    r.raise_for_status()
    return r.json()
def run_submission():
    """Solve every question in turn and submit the collected answers.

    Questions are processed sequentially with a one-second pause after each.

    Returns:
        Whatever ``submit_answers`` returns for the full answer list.
    """
    collected = []
    for item in get_questions():
        answer = solve_question(item)
        collected.append(dict(task_id=item["task_id"], submitted_answer=answer))
        time.sleep(1.0)
    return submit_answers(collected)