import os

import requests
from datasets import load_dataset
from langchain_core.messages import HumanMessage

from basic_agent import BasicAgent, DEFAULT_API_URL

# Exclusive upper bounds applied by build_ground_truth_mapping(): a GAIA task
# is kept only if its annotator-reported tool count is below TOOL_THRESHOLD
# and its step count is below STEP_THRESHOLD.
TOOL_THRESHOLD: int = 3
STEP_THRESHOLD: int = 6

def build_ground_truth_mapping(*, dataset=None, tool_threshold=None, step_threshold=None):
    """Return a dict mapping task_id -> final_answer for the filtered GAIA validation set.

    A task is kept only when its "Annotator Metadata" reports fewer than
    ``tool_threshold`` tools AND fewer than ``step_threshold`` steps.

    - A missing count falls back to 99, so such tasks are excluded by the
      default thresholds.
    - A count that is present but ``None`` or not an integer string makes
      the task be skipped instead of aborting the whole load.

    Args:
        dataset: Optional iterable of dict-like GAIA rows carrying "task_id",
            "Final answer" and "Annotator Metadata". When None (the default),
            the "2023_level1" validation split of "gaia-benchmark/GAIA" is
            loaded via ``load_dataset``.
        tool_threshold: Exclusive upper bound on "Number of tools". Defaults
            to the module-level TOOL_THRESHOLD.
        step_threshold: Exclusive upper bound on "Number of steps". Defaults
            to the module-level STEP_THRESHOLD.

    Returns:
        dict[str, str] mapping each kept task_id (as str) to its final
        answer (as str).
    """
    # Resolve defaults lazily so callers may override the module constants.
    if tool_threshold is None:
        tool_threshold = TOOL_THRESHOLD
    if step_threshold is None:
        step_threshold = STEP_THRESHOLD

    if dataset is None:
        print("Loading GAIA benchmark validation split… (this may take a minute)")
        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)

    mapping = {}
    for item in dataset:
        # Metadata may be absent or None; treat both as "no counts recorded".
        meta = item.get("Annotator Metadata") or {}
        try:
            n_tools = int(meta.get("Number of tools", 99))
            n_steps = int(meta.get("Number of steps", 99))
        except (TypeError, ValueError):
            # TypeError: the count is present but None; ValueError: it is not
            # an integer string. Either way the task cannot be classified.
            continue

        if n_tools < tool_threshold and n_steps < step_threshold:
            mapping[str(item["task_id"])] = str(item["Final answer"])

    print(f"Ground-truth map built for {len(mapping)} tasks.")
    return mapping

def fetch_random_question(api_base: str = DEFAULT_API_URL):
    """Fetch one random question from the scoring API.

    Issues ``GET {api_base}/random-question`` with a 30-second timeout and
    returns the decoded JSON body.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        Errors raised by ``requests.get`` itself (e.g. connection failures or
        timeouts) propagate unchanged.
    """
    endpoint = f"{api_base}/random-question"
    reply = requests.get(endpoint, timeout=30)
    reply.raise_for_status()
    payload = reply.json()
    return payload

def _message_text(content) -> str:
    """Flatten a LangChain message ``content`` value into plain text.

    ``content`` may be a plain string or a list of content blocks. List items
    that are strings, or dicts with a string "text" entry, are concatenated.
    Other block types are ignored.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "".join(parts)
    return "" if content is None else str(content)


def main():
    """Answer one random scoring-API question with BasicAgent and grade it locally.

    Steps:
    1. Build the filtered GAIA ground-truth map.
    2. Fetch a random question and ask ``BasicAgent``'s underlying agent to
       answer it.
    3. Print whether the reply matches the ground truth. The comparison is
       case-insensitive and ignores surrounding whitespace.
    """
    gt = build_ground_truth_mapping()

    q = fetch_random_question()
    task_id = str(q["task_id"])
    question_text = q["question"]

    print("\n=== Random Question ===")
    print(f"Task ID : {task_id}")
    print(f"Question: {question_text}\n")

    agent = BasicAgent()
    answer = agent.agent.invoke({"messages": [HumanMessage(content=question_text)]})

    # Expect a dict with a non-empty "messages" list; the last message is taken
    # as the agent's reply. Its content may be a str or a list of content
    # blocks, so normalise it to text before stripping.
    if isinstance(answer, dict) and "messages" in answer and answer["messages"]:
        answer_str = _message_text(answer["messages"][-1].content).strip()
    else:
        print("Agent returned unexpected structure – treating as error.")
        answer_str = ""

    print(f"Agent answer: {answer_str}")

    gt_answer = gt.get(task_id)
    if gt_answer is None:
        # The server's question pool may include tasks the local filter excluded.
        print("Ground-truth answer not found for this task ID. Filtering may be out of sync with the server.")
        return

    print(f"Ground truth: {gt_answer}")

    if answer_str.lower() == gt_answer.strip().lower():
        print("✅ Correct!")
    else:
        print("❌ Incorrect.")

if __name__ == "__main__":
    # Only run the single-question check when executed as a script, not on import.
    main()