Final_Assignment_Template / run_local_eval.py
W01fAI's picture
Upload 7 files
524e3cf verified
raw
history blame
3.31 kB
#!/usr/bin/env python3
"""Fetch GAIA course questions, run GaiaAgent, save JSON — does not submit."""
from __future__ import annotations
import argparse
import json
import os
import sys
import tempfile
from pathlib import Path
import requests
# Directory that contains this script.
ROOT = Path(__file__).resolve().parent

# Put the script directory at the front of the import search path (only once),
# so the imports that follow can find modules located next to this file.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
from agent import GaiaAgent # noqa: E402
from answer_normalize import normalize_answer # noqa: E402
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
def download_file(api_url: str, task_id: str, file_name: str) -> str | None:
    """Download the attachment of a task into a temporary file.

    Args:
        api_url: Base URL of the scoring API. Trailing slashes are ignored.
        task_id: Task identifier; the file is requested from
            ``{api_url}/files/{task_id}``.
        file_name: Attachment name reported for the task. Only its suffix is
            used (as the suffix of the temporary file). An empty or blank
            value means "no attachment".

    Returns:
        The path of a temporary file holding the response body, or ``None``
        when there is no attachment, the server does not answer with status
        200, or it answers with a JSON object carrying a truthy ``"detail"``
        field. The caller is responsible for deleting the returned file.

    Raises:
        Exceptions raised by ``requests.get`` (connection errors, timeouts)
        are not caught here and propagate to the caller.
    """
    name = str(file_name).strip() if file_name else ""
    if not name:
        return None

    # Normalise the base URL so "https://host/" does not yield "//files/...".
    url = f"{api_url.rstrip('/')}/files/{task_id}"
    r = requests.get(url, timeout=120)
    if r.status_code != 200:
        return None

    ctype = (r.headers.get("Content-Type") or "").lower()
    if "application/json" in ctype:
        try:
            data = r.json()
        except ValueError:
            # Every JSON decode error requests can raise is a ValueError
            # subclass; a body that is not valid JSON is kept as file content.
            data = None
        if isinstance(data, dict) and data.get("detail"):
            return None

    suffix = Path(name).suffix
    fd, path = tempfile.mkstemp(suffix=suffix, prefix=f"gaia_{str(task_id)[:8]}_")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(r.content)
    except BaseException:
        # Do not leave a partially written temp file behind.
        Path(path).unlink(missing_ok=True)
        raise
    return path
def main() -> None:
    """Fetch the questions, answer each one locally and write the answers as JSON.

    Command-line options:
        --api-url: Base URL of the scoring API (default: the ``GAIA_API_URL``
            environment variable, else ``DEFAULT_API_URL``).
        -o / --output: Path of the JSON file to write.

    A ``GaiaAgent`` is used when ``HF_TOKEN`` or ``HUGGINGFACEHUB_API_TOKEN``
    is set. Otherwise only ``tools.registry.deterministic_attempt`` is tried
    and ``"NO_HF_TOKEN"`` is recorded when it returns ``None``. Nothing is
    submitted to the API; results are only written to the output file.

    Raises:
        SystemExit: If the ``/questions`` endpoint does not return a JSON list.
        requests.HTTPError: If the ``/questions`` request fails
            (via ``raise_for_status``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--api-url",
        default=os.environ.get("GAIA_API_URL", DEFAULT_API_URL),
    )
    parser.add_argument(
        "-o",
        "--output",
        default=str(ROOT / "local_eval_answers.json"),
        help="Write answers JSON here",
    )
    args = parser.parse_args()

    # Normalise once so every endpoint is built from the same base URL.
    base_url = args.api_url.rstrip("/")
    q_url = f"{base_url}/questions"
    print(f"GET {q_url}")
    r = requests.get(q_url, timeout=60)
    r.raise_for_status()
    items = r.json()
    if not isinstance(items, list):
        # Fail with a clear message instead of an AttributeError further down.
        raise SystemExit(
            f"Unexpected response from {q_url}: expected a JSON list, "
            f"got {type(items).__name__}"
        )
    print(f"{len(items)} questions")

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    agent = GaiaAgent(hf_token=token) if token else None

    out: list[dict] = []
    for item in items:
        if not isinstance(item, dict):
            continue
        tid = item.get("task_id")
        q = item.get("question")
        fn = item.get("file_name") or ""
        if not tid or q is None:
            continue

        local = None
        try:
            if fn and str(fn).strip():
                local = download_file(base_url, str(tid), str(fn))
            if agent is not None:
                ans = agent(str(q), attachment_path=local, task_id=str(tid))
            else:
                # Imported lazily: only needed when no token is configured.
                from tools.registry import deterministic_attempt

                d = deterministic_attempt(str(q), local)
                ans = d if d is not None else "NO_HF_TOKEN"
        finally:
            # The downloaded attachment is only needed while answering.
            if local and Path(local).is_file():
                Path(local).unlink(missing_ok=True)

        # Real numbers are kept as-is; everything else (bools included)
        # goes through normalize_answer.
        if isinstance(ans, (int, float)) and not isinstance(ans, bool):
            sub = ans
        else:
            sub = normalize_answer(ans)
        out.append(
            {
                "task_id": tid,
                "question": q,
                "submitted_answer": sub,
            }
        )
        # str() because the task id is not guaranteed to be a string here.
        print(f"--- {str(tid)[:8]}… -> {out[-1]['submitted_answer']!r}")

    Path(args.output).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Wrote {args.output}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()