#!/usr/bin/env python
"""Evaluate any policy / model on the held-out scenario set.

Two modes:

  policy  Use the deterministic POLICY_PLANS asker from inference.py: no LLM,
          free, reproducible; used as the floor baseline.

  api     Use an OpenAI-compatible chat endpoint (the same code path the
          submission validator uses in inference.py). Set:
            MODEL_NAME    e.g. Qwen/Qwen3-0.6B
            API_BASE_URL  e.g. https://router.huggingface.co/v1
            HF_TOKEN      write/read token

Output: a single JSON file with per-scenario scores, breakdowns, question
counts, and aggregate metrics, in exactly the format `scripts/make_plots.py`
consumes.

Usage:
  # baseline (deterministic policy)
  python scripts/run_eval.py --mode policy --out outputs/eval_policy.json --limit 100

  # untrained Qwen3-0.6B via HF Inference router
  HF_TOKEN=hf_xxx MODEL_NAME=Qwen/Qwen3-0.6B \\
    python scripts/run_eval.py --mode api --out outputs/eval_qwen3-0.6b_base.json --limit 100

  # trained model via HF Inference Endpoints (you provide the URL)
  API_BASE_URL=https://my-endpoint.endpoints.huggingface.cloud/v1 \\
    MODEL_NAME=clarify-rl-grpo-qwen3-0.6b HF_TOKEN=hf_xxx \\
    python scripts/run_eval.py --mode api --out outputs/eval_qwen3-0.6b_trained.json --limit 100
"""
from __future__ import annotations

import argparse
import asyncio
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Optional

# Make the inference.py helpers importable without copy-paste.
_HERE = Path(__file__).resolve().parent
_REPO = _HERE.parent
sys.path.insert(0, str(_REPO))


def _lazy_import_inference():
    """Lazy-import inference.py so `--help` works without openai installed."""
    import inference as _inf  # type: ignore

    return _inf


def _make_ws_url(base_url: str) -> str:
    return base_url.replace("https://", "wss://").replace("http://", "ws://").rstrip("/") + "/ws"


async def _ws_reset_with_seed(ws, task_id: str, seed: int) -> dict:
    """Reset env to a specific (task_id, seed) — exact replay of an eval scenario."""
    await ws.send(json.dumps({"type": "reset", "data": {"task_id": task_id, "seed": seed}}))
    resp = json.loads(await ws.recv())
    if resp.get("type") == "error":
        return {"observation": {}, "reward": 0.0, "done": False, "error": resp.get("data", {})}
    data = resp.get("data", {})
    return {
        "observation": data.get("observation", {}),
        "reward": float(data.get("reward", 0.0)),
        "done": bool(data.get("done", False)),
    }


def _parse_observation(obs: dict) -> dict:
    """Pull the canonical tool-result dict out of an MCP observation."""
    result = obs.get("result")
    if isinstance(result, dict):
        if isinstance(result.get("structured_content"), dict):
            return result["structured_content"]
        if isinstance(result.get("data"), dict):
            return result["data"]
        content = result.get("content")
        if isinstance(content, list) and content:
            txt = content[0].get("text", "")
            try:
                parsed = json.loads(txt)
                if isinstance(parsed, dict):
                    return parsed
            except json.JSONDecodeError:
                pass
        return result
    if isinstance(result, str):
        try:
            parsed = json.loads(result)
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    return {}
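
# Example (illustrative) of what _parse_observation unwraps: an MCP-style
# observation whose result carries JSON text content reduces to the inner dict.
#   _parse_observation({"result": {"content": [{"text": '{"answer": "June"}'}]}})
#   -> {"answer": "June"}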


async def _eval_one_scenario(
    ws,
    scenario: dict,
    mode: str,
    llm_client,
    timeout_s: float,
    inf,
) -> dict:
    """Run a single scenario end-to-end. Returns a result row."""
    seed = scenario["seed"]
    task_id = scenario["task_id"]
    family = scenario.get("family", "")
    t0 = time.time()

    reset = await _ws_reset_with_seed(ws, task_id, seed)
    if "error" in reset:
        return {
            "seed": seed,
            "task_id": task_id,
            "scenario_id": f"seed{seed:05d}_{family}_{task_id}",
            "family": family,
            "request": "",
            "final_score": 0.0,
            "score_breakdown": {},
            "questions_asked": 0,
            "format_pass": False,
            "error": str(reset["error"]),
            "messages": [],
            "trace": [],
            "elapsed_s": time.time() - t0,
        }

    initial_data = _parse_observation(reset["observation"])
    request_text = initial_data.get("request", "")
    max_steps = int(initial_data.get("max_steps", 10))

    messages = [
        {"role": "system", "content": inf.SYSTEM_PROMPT},
        {"role": "user", "content": (
            f"USER REQUEST:\n{request_text}\n\nYou have {max_steps} steps. "
            "Available tools: ask_question(question), propose_plan(plan), get_task_info().\n\n"
            "RESPONSE FORMAT: Reply with ONE function call only, no other text.\n"
            "Examples:\n"
            "  ask_question(\"What is the date?\")\n"
            "  propose_plan('{\"event_type\": \"birthday\", \"date\": \"2024-12-25\"}')\n"
            "  get_task_info()\n"
        )},
    ]

    trace: list[dict] = []
    revealed: dict[str, Any] = {}
    questions_asked = 0
    final_score = 0.0
    score_breakdown: dict[str, float] = {}
    format_pass: Optional[bool] = None
    parse_error: Optional[str] = None
    llm_attempts = 0
    used_policy_step = 0
    done = False

    for step in range(max_steps):
        if time.time() - t0 > timeout_s:
            trace.append({"step": step, "error": "timeout"})
            break

        if mode == "policy":
            tool_name, args = inf._next_policy_action(  # type: ignore[attr-defined]
                task_id, used_policy_step, request_text, revealed
            )
            used_policy_step += 1
        else:  # api
            tool_name, args, fellback, llm_attempts = inf._choose_action(  # type: ignore[attr-defined]
                task_id, messages, llm_client, used_policy_step, llm_attempts,
                request_text, revealed
            )
            if fellback:
                used_policy_step += 1

        try:
            step_resp = await inf.ws_step(ws, tool_name, args)
        except Exception as exc:  # noqa: BLE001
            trace.append({"step": step, "error": f"ws_step exception: {exc}"})
            break

        obs = step_resp.get("observation", {}) or {}
        result = _parse_observation(obs)
        done = bool(step_resp.get("done"))

        record = {
            "step": step,
            "tool": tool_name,
            "args": args,
            "reward": float(step_resp.get("reward", 0.0)),
            "done": done,
            "result": result,
        }
        trace.append(record)

        format_reminder = (
            "\n\nReminder: Reply with ONE function call only "
            "(ask_question/propose_plan/get_task_info), no other text."
        )
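
        # Illustrative result shapes (the exact fields come from the env server;
        # only the keys read below are assumed):
        #   ask_question -> {"answer": "...", "field_revealed": "date"}
        #   propose_plan -> {"score": 0.8, "breakdown": {"FormatCheck": 1.0, ...},
        #                    "parse_error": None}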

        if tool_name == "ask_question":
            questions_asked += 1
            if isinstance(result, dict) and result.get("field_revealed"):
                fld = result["field_revealed"]
                ans = result.get("answer", "")
                revealed[fld] = ans
            messages.append({"role": "user", "content": json.dumps(result) + format_reminder})
        elif tool_name == "get_task_info":
            messages.append({"role": "user", "content": json.dumps(result) + format_reminder})
        elif tool_name == "propose_plan":
            if isinstance(result, dict):
                final_score = float(result.get("score", step_resp.get("reward", 0.0)))
                score_breakdown = result.get("breakdown", {}) or {}
                parse_error = result.get("parse_error")
                # Check for None explicitly: `a or b` would treat a legitimate
                # 0.0 FormatCheck score as missing and leave format_pass unset.
                fmt = score_breakdown.get("FormatCheck")
                if fmt is None:
                    fmt = score_breakdown.get("format_check")
                if fmt is not None:
                    format_pass = fmt > 0
            done = True

        if done:
            break

    return {
        "seed": seed,
        "task_id": task_id,
        "scenario_id": f"seed{seed:05d}_{family}_{task_id}",
        "family": family,
        "request": request_text,
        "final_score": final_score,
        "score_breakdown": score_breakdown,
        "questions_asked": questions_asked,
        "format_pass": format_pass,
        "parse_error": parse_error,
        "messages": messages,
        "trace": trace,
        "elapsed_s": time.time() - t0,
    }


async def _run(args) -> dict:
    inf = _lazy_import_inference()

    eval_path = Path(args.scenarios)
    if not eval_path.exists():
        raise FileNotFoundError(f"Scenario file not found: {eval_path}")
    scenarios = json.loads(eval_path.read_text())
    if args.limit and args.limit < len(scenarios):
        scenarios = scenarios[: args.limit]
    print(f"Loaded {len(scenarios)} scenarios from {eval_path}")

    llm_client = None
    if args.mode == "api":
        if not inf.API_KEY:
            raise RuntimeError("api mode requires HF_TOKEN / OPENAI_API_KEY")
        llm_client = inf.create_client()
        if llm_client is None:
            raise RuntimeError("Failed to create OpenAI client (check API_BASE_URL/HF_TOKEN)")
        print(f"Using OpenAI client with base_url={inf.API_BASE_URL} model={inf.MODEL_NAME}")
    else:
        print("Mode: policy (deterministic, no LLM)")

    import websockets  # deferred so `--help` works without the dependency

    results: list[dict] = []
    ws_url = _make_ws_url(args.env)
    print(f"Env WS: {ws_url}")
    print(f"Output to: {args.out}")
    print()

    overall_t0 = time.time()
    async with websockets.connect(
        ws_url, open_timeout=30, close_timeout=10, max_size=2**24
    ) as ws:
        for i, scn in enumerate(scenarios):
            print(
                f"[{i+1}/{len(scenarios)}] family={scn.get('family','?')} "
                f"task={scn['task_id']} seed={scn['seed']}",
                flush=True,
            )
            row = await _eval_one_scenario(ws, scn, args.mode, llm_client, args.timeout, inf)
            results.append(row)
            print(
                f" score={row['final_score']:.3f} q={row['questions_asked']} fmt={row['format_pass']} "
                f"err={row.get('error') or row.get('parse_error') or ''}",
                flush=True,
            )

    total_s = time.time() - overall_t0
    scores = [r["final_score"] for r in results]
    fmt_passes = [r["format_pass"] for r in results if r["format_pass"] is not None]
    qs = [r["questions_asked"] for r in results]

    summary = {
        "model": inf.MODEL_NAME if args.mode == "api" else None,
        "mode": args.mode,
        "scenarios_total": len(results),
        "elapsed_s": total_s,
        "avg_score": sum(scores) / len(scores) if scores else 0.0,
        "avg_questions": sum(qs) / len(qs) if qs else 0.0,
        "format_pass_rate": (sum(1 for f in fmt_passes if f) / len(fmt_passes)) if fmt_passes else 0.0,
        "completion_rate": sum(1 for r in results if r["final_score"] > 0) / max(1, len(results)),
    }
    payload = {
        "summary": summary,
        "config": {
            "mode": args.mode,
            "model": inf.MODEL_NAME if args.mode == "api" else None,
            "api_base_url": inf.API_BASE_URL if args.mode == "api" else None,
            "env_base_url": args.env,
            "scenarios_file": str(eval_path),
            "limit": args.limit,
        },
        "results": results,
    }
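
    # Saved file shape (abbreviated; values illustrative). This is what
    # scripts/make_plots.py reads:
    #   {
    #     "summary": {"avg_score": 0.42, "format_pass_rate": 0.90, ...},
    #     "config":  {"mode": "api", "model": "...", ...},
    #     "results": [{"seed": 1, "final_score": 0.5, "trace": [...]}, ...]
    #   }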
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2))

    print()
    print(f"Saved {len(results)} results to {out_path}")
    print(f"Avg score: {summary['avg_score']:.4f}")
    print(f"Format pass rate: {summary['format_pass_rate']:.4f}")
    print(f"Completion rate: {summary['completion_rate']:.4f}")
    print(f"Avg questions: {summary['avg_questions']:.2f}")
    print(f"Total elapsed: {total_s:.1f} s")
    return summary


def main() -> None:
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--mode", choices=("policy", "api"), required=True)
    parser.add_argument(
        "--scenarios",
        default=str(_REPO / "scenarios" / "eval_held_out.json"),
        help="Path to eval scenario JSON (default: scenarios/eval_held_out.json)",
    )
    parser.add_argument("--out", required=True, help="Output JSON file (e.g. outputs/eval_policy.json)")
    parser.add_argument("--limit", type=int, default=None, help="Cap to first N scenarios")
    parser.add_argument(
        "--env",
        default=os.environ.get("ENV_BASE_URL", "https://agarwalanu3103-clarify-rl.hf.space"),
        help="Env Space URL",
    )
    parser.add_argument("--timeout", type=float, default=180.0, help="Per-scenario timeout in seconds")
    args = parser.parse_args()

    asyncio.run(_run(args))


if __name__ == "__main__":
    main()
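
# Quick sanity check of a finished run (illustrative paths; run in a REPL):
#   import json
#   base = json.load(open("outputs/eval_policy.json"))["summary"]
#   trained = json.load(open("outputs/eval_qwen3-0.6b_trained.json"))["summary"]
#   print(f"delta avg_score: {trained['avg_score'] - base['avg_score']:+.3f}")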