File size: 3,308 Bytes
524e3cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python3
"""Fetch GAIA course questions, run GaiaAgent, save JSON — does not submit."""

from __future__ import annotations

import argparse
import json
import os
import sys
import tempfile
from pathlib import Path

import requests

# Directory that contains this script. It is put at the front of sys.path
# (once) so the project imports below resolve regardless of the current
# working directory the script is launched from.
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

# These imports must come after the sys.path tweak above, hence the E402
# suppressions.
# NOTE(review): `agent` and `answer_normalize` are presumably modules that
# live next to this file — confirm.
from agent import GaiaAgent  # noqa: E402
from answer_normalize import normalize_answer  # noqa: E402

# Base URL of the scoring API; main() uses it as the --api-url default when
# the GAIA_API_URL environment variable is not set.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


def download_file(api_url: str, task_id: str, file_name: str) -> str | None:
    """Download the attachment of a task into a temporary file.

    Args:
        api_url: Base URL of the scoring API; a trailing slash is tolerated.
        task_id: Identifier of the task whose file is requested from
            ``{api_url}/files/{task_id}``.
        file_name: Attachment name reported by the API. Only its suffix is
            used (as the temp file's extension); an empty or blank name
            means "no attachment".

    Returns:
        Path of the temporary file holding the downloaded bytes, or ``None``
        when there is nothing to download or the download failed (network
        error, non-200 status, or a JSON body carrying a ``detail`` field).
        The caller is responsible for deleting the returned file.
    """
    if not file_name or not str(file_name).strip():
        return None

    # Normalise the base so ".../" and "..." both yield a single slash.
    url = f"{api_url.rstrip('/')}/files/{task_id}"
    try:
        r = requests.get(url, timeout=120)
    except requests.RequestException as exc:
        # Every other failure mode is reported as None; a transient network
        # error on one attachment should not abort the caller's whole run.
        print(f"download failed for {task_id}: {exc}")
        return None
    if r.status_code != 200:
        return None

    ctype = (r.headers.get("Content-Type") or "").lower()
    if "application/json" in ctype:
        try:
            data = r.json()
        except ValueError:
            # ValueError is the common base of json.JSONDecodeError and the
            # simplejson variant that requests raises when simplejson is
            # installed. An unparsable body is treated as a regular file.
            data = None
        # A JSON object with a "detail" field is handled as an error payload,
        # not as attachment content.
        if isinstance(data, dict) and data.get("detail"):
            return None

    suffix = Path(file_name).suffix
    fd, path = tempfile.mkstemp(suffix=suffix, prefix=f"gaia_{task_id[:8]}_")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(r.content)
    except OSError:
        # Do not leave a partial temp file behind if the write fails.
        Path(path).unlink(missing_ok=True)
        raise
    return path


def main() -> None:
    """Fetch the questions, answer each one, and write the answers as JSON.

    Command-line options:
        --api-url: Base URL of the scoring API (default: the ``GAIA_API_URL``
            environment variable, else ``DEFAULT_API_URL``).
        -o/--output: Destination of the answers JSON file.

    For every question that has a ``task_id`` and a ``question`` field, the
    optional attachment is downloaded to a temporary file, the question is
    answered, and the temporary file is removed again. When ``HF_TOKEN`` or
    ``HUGGINGFACEHUB_API_TOKEN`` is set the answer comes from ``GaiaAgent``;
    otherwise ``tools.registry.deterministic_attempt`` is tried and
    ``"NO_HF_TOKEN"`` is recorded if it returns ``None``.

    Raises:
        requests.HTTPError: If the ``/questions`` request returns an error
            status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--api-url",
        default=os.environ.get("GAIA_API_URL", DEFAULT_API_URL),
    )
    parser.add_argument(
        "-o",
        "--output",
        default=str(ROOT / "local_eval_answers.json"),
        help="Write answers JSON here",
    )
    args = parser.parse_args()

    # Strip the trailing slash once so the questions URL and the per-task
    # file URLs are built from the same normalised base.
    api_url = args.api_url.rstrip("/")

    q_url = f"{api_url}/questions"
    print(f"GET {q_url}")
    r = requests.get(q_url, timeout=60)
    r.raise_for_status()
    items = r.json()
    print(f"{len(items)} questions")

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    agent = GaiaAgent(hf_token=token) if token else None

    out: list[dict] = []
    for item in items:
        tid = item.get("task_id")
        q = item.get("question")
        fn = item.get("file_name") or ""
        if not tid or q is None:
            continue
        local = None
        try:
            if fn and str(fn).strip():
                local = download_file(api_url, str(tid), str(fn))
            if agent is not None:
                ans = agent(str(q), attachment_path=local, task_id=str(tid))
            else:
                # Imported lazily: only needed on the token-less fallback path.
                from tools.registry import deterministic_attempt

                d = deterministic_attempt(str(q), local)
                ans = d if d is not None else "NO_HF_TOKEN"
        finally:
            # The attachment is only needed while answering this question.
            if local and Path(local).is_file():
                Path(local).unlink(missing_ok=True)

        # Real numbers are kept as-is (bool is excluded even though it is an
        # int subclass); everything else goes through normalize_answer.
        if isinstance(ans, (int, float)) and not isinstance(ans, bool):
            sub = ans
        else:
            sub = normalize_answer(ans)
        out.append(
            {
                "task_id": tid,
                "question": q,
                "submitted_answer": sub,
            }
        )
        # str() guards against a non-string task_id, which cannot be sliced.
        print(f"--- {str(tid)[:8]}… -> {sub!r}")

    Path(args.output).write_text(json.dumps(out, indent=2), encoding="utf-8")
    print(f"Wrote {args.output}")


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()