import os
import requests
from datasets import load_dataset
from basic_agent import BasicAgent, DEFAULT_API_URL
from langchain_core.messages import HumanMessage

# Server-side filtering parameters (keep in sync with scoring API)
TOOL_THRESHOLD = 3
STEP_THRESHOLD = 6


def build_ground_truth_mapping():
    """Return a dict mapping task_id -> final_answer for the filtered GAIA validation set.

    Loads the GAIA ``2023_level1`` validation split and keeps only tasks whose
    annotator metadata reports fewer than ``TOOL_THRESHOLD`` tools and fewer
    than ``STEP_THRESHOLD`` steps. A missing count defaults to 99, which filters
    the task out. A count that cannot be converted to ``int`` (non-numeric
    string or explicit ``None``) causes the task to be skipped.

    Returns:
        dict[str, str]: task_id -> final answer, both stringified.
    """
    print("Loading GAIA benchmark validation split… (this may take a minute)")
    ds = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
    mapping = {}
    for item in ds:
        meta = item.get("Annotator Metadata") or {}
        try:
            n_tools = int(meta.get("Number of tools", 99))
            n_steps = int(meta.get("Number of steps", 99))
        except (TypeError, ValueError):
            # ValueError: non-numeric string; TypeError: key present but value is None.
            continue  # skip malformed counts
        if n_tools < TOOL_THRESHOLD and n_steps < STEP_THRESHOLD:
            task_id = str(item["task_id"])
            mapping[task_id] = str(item["Final answer"])
    print(f"Ground-truth map built for {len(mapping)} tasks.")
    return mapping


def fetch_random_question(api_base: str = DEFAULT_API_URL):
    """GET /random-question from the scoring API and return its JSON.

    Args:
        api_base: Base URL of the scoring API. A trailing slash is tolerated.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # Strip a trailing slash so we never request "...//random-question".
    url = f"{api_base.rstrip('/')}/random-question"
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return resp.json()


def _extract_answer_text(result):
    """Return the stripped text of the last message in an agent result, or "" if unusable."""
    if not (isinstance(result, dict) and result.get("messages")):
        print("Agent returned unexpected structure – treating as error.")
        return ""
    content = getattr(result["messages"][-1], "content", "") or ""
    if isinstance(content, list):
        # LangChain message content may be a list of parts: plain strings or
        # dicts carrying a "text" entry. Keep only the textual parts.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        content = "".join(pieces)
    return str(content).strip()


def main():
    """Fetch one random question, run the agent on it, and compare with GAIA ground truth.

    The comparison is a case-insensitive exact string match after stripping
    surrounding whitespace.
    """
    gt = build_ground_truth_mapping()
    q = fetch_random_question()
    task_id = str(q["task_id"])
    question_text = q["question"]
    print("\n=== Random Question ===")
    print(f"Task ID : {task_id}")
    print(f"Question: {question_text}\n")

    agent = BasicAgent()
    # Synchronous invoke for simplicity; if the agent offers an async entry
    # point it could be driven via asyncio.run instead.
    answer = agent.agent.invoke({"messages": [HumanMessage(content=question_text)]})
    answer_str = _extract_answer_text(answer)
    print(f"Agent answer: {answer_str}")

    gt_answer = gt.get(task_id)
    if gt_answer is None:
        print("Ground-truth answer not found for this task ID. Filtering may be out of sync with the server.")
        return
    print(f"Ground truth: {gt_answer}")
    if answer_str.strip().lower() == gt_answer.strip().lower():
        print("✅ Correct!")
    else:
        print("❌ Incorrect.")


if __name__ == "__main__":
    main()