# File size: 2,632 Bytes
# 8417a25
import os
import requests
from datasets import load_dataset
from basic_agent import BasicAgent, DEFAULT_API_URL
from langchain_core.messages import HumanMessage
# Server-side filtering parameters (keep in sync with scoring API).
# Both values are exclusive upper bounds: a task is kept only when its
# annotated count is strictly below the threshold.
TOOL_THRESHOLD: int = 3  # bound on the "Number of tools" metadata field
STEP_THRESHOLD: int = 6  # bound on the "Number of steps" metadata field
def build_ground_truth_mapping(
    dataset: "Iterable[Mapping[str, Any]] | None" = None,
    *,
    tool_threshold: "int | None" = None,
    step_threshold: "int | None" = None,
) -> "dict[str, str]":
    """Return a dict mapping task_id -> final_answer for the filtered GAIA validation set.

    A task is kept only when its "Annotator Metadata" reports a "Number of
    tools" strictly below the tool threshold and a "Number of steps" strictly
    below the step threshold. Rows whose counts are missing, ``None`` or not
    parseable as integers are skipped.

    Args:
        dataset: Iterable of row mappings to filter. When omitted, the GAIA
            "2023_level1" validation split is downloaded with ``load_dataset``.
        tool_threshold: Exclusive upper bound on the tool count. Defaults to
            the module-level ``TOOL_THRESHOLD``.
        step_threshold: Exclusive upper bound on the step count. Defaults to
            the module-level ``STEP_THRESHOLD``.

    Returns:
        Mapping of ``str(task_id)`` to ``str(Final answer)`` for every kept row.
    """
    if dataset is None:
        print("Loading GAIA benchmark validation split… (this may take a minute)")
        dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
    # Resolve the defaults at call time so edits to the module constants are honoured.
    if tool_threshold is None:
        tool_threshold = TOOL_THRESHOLD
    if step_threshold is None:
        step_threshold = STEP_THRESHOLD

    mapping = {}
    for item in dataset:
        meta = item.get("Annotator Metadata") or {}
        try:
            # A missing key yields None, and int(None) raises TypeError, so
            # absent counts are skipped the same way as malformed ones.
            n_tools = int(meta.get("Number of tools"))
            n_steps = int(meta.get("Number of steps"))
        except (TypeError, ValueError):
            continue  # skip missing or malformed counts
        if n_tools < tool_threshold and n_steps < step_threshold:
            mapping[str(item["task_id"])] = str(item["Final answer"])
    print(f"Ground-truth map built for {len(mapping)} tasks.")
    return mapping
def fetch_random_question(api_base: str = DEFAULT_API_URL):
    """GET /random-question from the scoring API and return its decoded JSON body.

    Args:
        api_base: Base URL of the scoring API, with or without a trailing slash.

    Returns:
        Whatever ``Response.json()`` decodes from the response body.

    Raises:
        requests.HTTPError: If ``raise_for_status()`` rejects the response.
    """
    # Strip trailing slashes so a base such as "https://host/" does not
    # produce "https://host//random-question".
    url = f"{api_base.rstrip('/')}/random-question"
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return resp.json()
def _message_text(message) -> str:
    """Return the text of a chat message, stripped of surrounding whitespace."""
    content = message.content
    if isinstance(content, str):
        return content.strip()
    # Message content may also be a list of blocks (plain strings or dicts
    # such as {"type": "text", "text": ...}); keep only the textual pieces.
    parts = []
    for block in content or []:
        if isinstance(block, str):
            parts.append(block)
        elif (
            isinstance(block, dict)
            and block.get("type") == "text"
            and isinstance(block.get("text"), str)
        ):
            parts.append(block["text"])
    return "".join(parts).strip()


def main():
    """Fetch one random question, run the agent on it and check the answer.

    Builds the ground-truth map, requests a random question from the scoring
    API, invokes the agent synchronously, then prints the agent's answer and
    whether it matches the ground truth (case-insensitive, whitespace-trimmed).
    Returns early when the task ID has no ground-truth entry.
    """
    gt = build_ground_truth_mapping()
    q = fetch_random_question()
    task_id = str(q["task_id"])
    question_text = q["question"]

    print("\n=== Random Question ===")
    print(f"Task ID : {task_id}")
    print(f"Question: {question_text}\n")

    agent = BasicAgent()
    # Sync invoke for simplicity; an async entry point, if the agent offers
    # one, could instead be driven through asyncio.run.
    answer = agent.agent.invoke({"messages": [HumanMessage(content=question_text)]})

    if isinstance(answer, dict) and "messages" in answer and answer["messages"]:
        answer_str = _message_text(answer["messages"][-1])
    else:
        print("Agent returned unexpected structure – treating as error.")
        answer_str = ""
    print(f"Agent answer: {answer_str}")

    gt_answer = gt.get(task_id)
    if gt_answer is None:
        print("Ground-truth answer not found for this task ID. Filtering may be out of sync with the server.")
        return
    print(f"Ground truth: {gt_answer}")

    # answer_str is already stripped at this point.
    if answer_str.lower() == gt_answer.strip().lower():
        print("✅ Correct!")
    else:
        print("❌ Incorrect.")
# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()