| | |
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | import re |
| | import subprocess |
| | from collections import defaultdict |
| | from pathlib import Path |
| | from typing import Any |
| |
|
# Base directory for all default paths: the parent of the directory that
# contains this script (presumably the repository root — the defaults below
# place this script under ROOT/scripts).
ROOT = Path(__file__).resolve().parents[1]

# Default locations of the benchmark inputs and outputs, relative to ROOT.
CARDS_DIR = ROOT.joinpath('.fast-agent', 'tool-cards')
PROMPTS_FILE = ROOT.joinpath('scripts', 'tool_routing_challenges.txt')
EXPECTED_FILE = ROOT.joinpath('scripts', 'tool_routing_expected.json')
OUT_DIR = ROOT.joinpath('docs', 'tool_routing_eval')

# Matches ANSI CSI escape sequences (ESC '[' params intermediates final),
# e.g. terminal colour codes, so console output can be stored as plain text.
ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
| |
|
| |
|
def strip_ansi(text: str) -> str:
    """Return *text* with every ANSI escape sequence matched by ``ANSI_RE`` removed."""
    # ANSI_RE has no capture groups, so joining the pieces between matches
    # is exactly the same as substituting each match with ''.
    pieces = ANSI_RE.split(text)
    return ''.join(pieces)
| |
|
| |
|
def load_prompts(path: Path) -> list[str]:
    """Read the prompts file and return one prompt per non-blank line.

    Each line is stripped of surrounding whitespace. Lines that are empty after
    stripping are dropped, so a prompt's 1-based position in the returned list
    is its case id.
    """
    text = path.read_text(encoding='utf-8')
    # filter(None, ...) discards the empty strings left by blank lines.
    return list(filter(None, map(str.strip, text.splitlines())))
| |
|
| |
|
def load_expected(path: Path) -> dict[int, dict[str, Any]]:
    """Load the expectations file and index its rows by integer ``id``.

    The file must contain a JSON array of objects, each with an ``id`` key
    (an int or a numeric string). If two rows share an id, the later row
    wins.
    """
    entries = json.loads(path.read_text(encoding='utf-8'))
    return {int(entry['id']): entry for entry in entries}
| |
|
| |
|
| | def _extract_session_observations(result_path: Path) -> dict[str, Any]: |
| | data = json.loads(result_path.read_text(encoding='utf-8')) |
| | messages = data.get('messages', []) if isinstance(data, dict) else [] |
| |
|
| | tool_calls: list[str] = [] |
| | merged_parts: list[str] = [] |
| |
|
| | for msg in messages: |
| | if not isinstance(msg, dict): |
| | continue |
| |
|
| | if msg.get('role') == 'assistant': |
| | for item in msg.get('content', []) or []: |
| | if isinstance(item, dict) and item.get('type') == 'text' and item.get('text'): |
| | merged_parts.append(str(item['text'])) |
| |
|
| | channels = msg.get('channels') or {} |
| | for ch_name in ('reasoning',): |
| | for item in channels.get(ch_name, []) or []: |
| | if isinstance(item, dict) and item.get('text'): |
| | merged_parts.append(str(item['text'])) |
| |
|
| | tc_map = msg.get('tool_calls') or {} |
| | if isinstance(tc_map, dict): |
| | for tc in tc_map.values(): |
| | params = (tc or {}).get('params', {}) if isinstance(tc, dict) else {} |
| | name = params.get('name') if isinstance(params, dict) else None |
| | if isinstance(name, str): |
| | tool_calls.append(name) |
| | merged_parts.append(f'tool call - {name}') |
| | args = params.get('arguments') if isinstance(params, dict) else None |
| | if isinstance(args, dict): |
| | merged_parts.append(json.dumps(args, ensure_ascii=False)) |
| |
|
| | if msg.get('role') == 'user': |
| | tr_map = msg.get('tool_results') or {} |
| | if isinstance(tr_map, dict): |
| | for tr in tr_map.values(): |
| | for item in (tr or {}).get('content', []) or []: |
| | if isinstance(item, dict) and item.get('type') == 'text' and item.get('text'): |
| | merged_parts.append(str(item['text'])) |
| |
|
| | called_tools = list(dict.fromkeys(tool_calls)) |
| | return { |
| | 'tool_calls': tool_calls, |
| | 'called_tools': called_tools, |
| | 'merged_from_result': '\n'.join(merged_parts).strip(), |
| | } |
| |
|
| |
|
def run_prompt(
    prompt: str,
    model: str,
    agent: str,
    cards_dir: Path,
    timeout_sec: int,
    result_path: Path,
) -> dict[str, Any]:
    """Run one prompt through ``fast-agent go`` and collect what the agent did.

    The CLI is invoked directly (no shell) as::

        fast-agent go --no-env --model MODEL --agent-cards CARDS_DIR
            --agent AGENT --results RESULT_PATH -m PROMPT

    and the session file it writes to ``result_path`` is parsed with
    ``_extract_session_observations``.

    Args:
        prompt: The user message sent to the agent.
        model: Model ID passed to ``--model``.
        agent: Agent name passed to ``--agent``.
        cards_dir: Directory passed to ``--agent-cards``.
        timeout_sec: Maximum wall-clock seconds the CLI may run.
        result_path: Where the CLI must write its ``--results`` JSON. Parent
            directories are created, and any existing file is deleted first.

    Returns:
        A dict with ``returncode``; ANSI-stripped ``stdout`` and ``stderr`` and
        their concatenation ``merged_console``; ``merged`` (text gathered from
        the results file); ``tool_calls`` and ``called_tools``; and
        ``result_file`` (``result_path`` as a string).

    Raises:
        subprocess.TimeoutExpired: if the CLI runs longer than ``timeout_sec``.
        OSError: if the ``fast-agent`` executable cannot be started.
        RuntimeError: if the CLI finishes without writing ``result_path``.
    """
    result_path.parent.mkdir(parents=True, exist_ok=True)
    # Remove any results file from an earlier run. Otherwise a run that fails
    # before writing its own file would pass the existence check below, and
    # the case would be silently scored from stale data.
    if result_path.exists():
        result_path.unlink()

    cmd = [
        'fast-agent', 'go',
        '--no-env',
        '--model', model,
        '--agent-cards', str(cards_dir),
        '--agent', agent,
        '--results', str(result_path),
        '-m', prompt,
    ]

    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_sec)
    out = strip_ansi(proc.stdout or '')
    err = strip_ansi(proc.stderr or '')
    merged_console = (out + '\n' + err).strip()

    if not result_path.exists():
        raise RuntimeError(f'Expected --results file not written: {result_path}')

    parsed = _extract_session_observations(result_path)

    return {
        'returncode': proc.returncode,
        'stdout': out,
        'stderr': err,
        'merged': parsed['merged_from_result'],
        'merged_console': merged_console,
        'tool_calls': parsed['tool_calls'],
        'called_tools': parsed['called_tools'],
        'result_file': str(result_path),
    }
| |
|
| |
|
| | def _match_any(observed: str | None, expected_any: list[str] | None) -> bool | None: |
| | if expected_any is None: |
| | return None |
| | if observed is None: |
| | return False |
| | return observed in expected_any |
| |
|
| |
|
def evaluate_case(obs: dict[str, Any], exp: dict[str, Any]) -> dict[str, Any]:
    """Score one observed run against its expectation row.

    ``obs`` must provide ``tool_calls`` (every call, in order), ``called_tools``
    (distinct names, first-seen order), ``returncode`` and ``merged``. All keys
    of ``exp`` are optional: ``expect_no_tool_call``, ``expected_first_any``,
    ``expected_primary_any``, ``allowed_tools`` and ``bucket`` (default
    ``'other'``).

    Five components are scored and summed into ``score['total']`` (max 10):

    * first: 2 if the first tool called is acceptable, else 0;
    * primary: 2 if the most-called tool is acceptable, else 0;
    * chain: 2 if every called tool is in ``allowed_tools`` (or no list is given);
    * success: 2 if the exit code is 0 and ``'Traceback'`` is absent from ``merged``;
    * efficiency: 2, 1 or 0 depending on call count, with per-bucket thresholds.

    When no tool call is expected, "acceptable" means no tool was called.
    Otherwise, a missing ``expected_first_any`` / ``expected_primary_any``
    leaves the matching ``*_ok`` flag as ``None``, which scores 0 points.
    """
    calls_seen: list[str] = obs['tool_calls']
    distinct_tools: list[str] = obs['called_tools']
    n_calls = len(calls_seen)

    first_tool = calls_seen[0] if calls_seen else None

    # Primary tool = the most frequently called one. On a tie the tool listed
    # first in distinct_tools (i.e. called earliest) wins.
    primary_tool = None
    best_count = -1
    for tool_name in distinct_tools:
        count = calls_seen.count(tool_name)
        if count > best_count:
            primary_tool, best_count = tool_name, count

    expect_no_tool = bool(exp.get('expect_no_tool_call', False))
    expected_first = exp.get('expected_first_any')
    expected_primary = exp.get('expected_primary_any')
    allowed_tools = exp.get('allowed_tools')

    success = not (obs['returncode'] != 0 or 'Traceback' in obs['merged'])

    if expect_no_tool:
        first_ok = first_tool is None
        primary_ok = primary_tool is None
    else:
        # None means "no expectation given"; a missing observation is a miss.
        first_ok = None if expected_first is None else (
            first_tool is not None and first_tool in expected_first
        )
        primary_ok = None if expected_primary is None else (
            primary_tool is not None and primary_tool in expected_primary
        )

    chain_ok = allowed_tools is None or all(tool in allowed_tools for tool in distinct_tools)

    def points(flag: Any) -> int:
        return 2 if flag else 0

    # Efficiency thresholds (full credit at <= full_cap calls, partial credit
    # at <= partial_cap). The bucket takes precedence over expect_no_tool.
    bucket = exp.get('bucket', 'other')
    if bucket == 'distractor_positive':
        full_cap, partial_cap = 2, 4
    elif bucket == 'mixed_chain':
        full_cap, partial_cap = 4, 6
    elif expect_no_tool:
        full_cap, partial_cap = 0, 1
    else:
        full_cap, partial_cap = 5, 8

    if n_calls <= full_cap:
        efficiency = 2
    elif n_calls <= partial_cap:
        efficiency = 1
    else:
        efficiency = 0

    score = {
        'first': points(first_ok),
        'primary': points(primary_ok),
        'chain': points(chain_ok),
        'success': points(success),
        'efficiency': efficiency,
    }
    score['total'] = score['first'] + score['primary'] + score['chain'] + score['success'] + efficiency

    return {
        'first_tool': first_tool,
        'primary_tool': primary_tool,
        'tool_calls_count': n_calls,
        'first_ok': first_ok,
        'primary_ok': primary_ok,
        'chain_ok': chain_ok,
        'success': success,
        'bucket': bucket,
        'score': score,
    }
| |
|
| |
|
def summarize(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-case evaluations into overall and per-bucket metrics.

    Each row must carry an ``eval`` dict as produced by ``evaluate_case``.
    Accuracies count a flag as a hit only when it is truthy, so ``None``
    ("no expectation") counts as a miss. Rates are rounded to 4 decimals and
    averages to 3. An empty ``rows`` yields zeros and an empty
    ``bucket_summary``. Buckets appear in order of first occurrence.
    """

    def share(items: list[dict[str, Any]], flag: str) -> float:
        # Fraction of items whose eval[flag] is truthy; 0.0 when there are none.
        if not items:
            return 0.0
        return sum(1 for item in items if item['eval'][flag]) / len(items)

    def mean(items: list[dict[str, Any]], pick: Any) -> float:
        # Average of pick(eval) over items; 0.0 when there are none.
        if not items:
            return 0.0
        return sum(pick(item['eval']) for item in items) / len(items)

    def count_of(ev: dict[str, Any]) -> Any:
        return ev['tool_calls_count']

    def total_of(ev: dict[str, Any]) -> Any:
        return ev['score']['total']

    grouped: dict[Any, list[dict[str, Any]]] = {}
    for item in rows:
        grouped.setdefault(item['eval']['bucket'], []).append(item)

    bucket_summary = {
        name: {
            'n': len(members),
            'first_acc': round(share(members, 'first_ok'), 4),
            'primary_acc': round(share(members, 'primary_ok'), 4),
            'avg_calls': round(mean(members, count_of), 3),
            'avg_score': round(mean(members, total_of), 3),
        }
        for name, members in grouped.items()
    }

    return {
        'n_cases': len(rows),
        'first_accuracy': round(share(rows, 'first_ok'), 4),
        'primary_accuracy': round(share(rows, 'primary_ok'), 4),
        'chain_accuracy': round(share(rows, 'chain_ok'), 4),
        'success_rate': round(share(rows, 'success'), 4),
        'avg_tool_calls': round(mean(rows, count_of), 3),
        'avg_score_total': round(mean(rows, total_of), 3),
        'bucket_summary': bucket_summary,
    }
| |
|
| |
|
def render_md(rows: list[dict[str, Any]], summary: dict[str, Any], model: str, agent: str) -> str:
    """Render the evaluation results as a Markdown report.

    The report contains a header (model, agent, case count), a bullet list of
    the overall metrics from ``summary``, a per-bucket table sorted by bucket
    name, and a per-case table in ``rows`` order. In the case table, flags are
    printed as ``1``/``0`` (``None`` becomes ``0``) and a missing tool name as
    ``-``. The returned text always ends with a newline.
    """

    def table_row(cells: list[Any]) -> str:
        # Same formatting as an f-string of the form "| {a} | {b} | ... |".
        return '| ' + ' | '.join(f'{cell}' for cell in cells) + ' |'

    def as_flag(value: Any) -> int:
        return int(bool(value))

    header = [
        '# Tool Routing/Confusion Evaluation Report',
        '',
        f'- Model: `{model}`',
        f'- Agent: `{agent}`',
        f"- Cases: **{summary['n_cases']}**",
        '',
    ]

    metrics = [
        '## Overall metrics',
        '',
        f"- First-tool accuracy: **{summary['first_accuracy']}**",
        f"- Primary-tool accuracy: **{summary['primary_accuracy']}**",
        f"- Allowed-chain accuracy: **{summary['chain_accuracy']}**",
        f"- Success rate: **{summary['success_rate']}**",
        f"- Avg tool calls: **{summary['avg_tool_calls']}**",
        f"- Avg score (/10): **{summary['avg_score_total']}**",
        '',
    ]

    bucket_table = [
        '## By bucket',
        '',
        '| Bucket | N | First acc | Primary acc | Avg calls | Avg score |',
        '|---|---:|---:|---:|---:|---:|',
    ]
    for name, stats in sorted(summary['bucket_summary'].items()):
        bucket_table.append(table_row([
            name, stats['n'], stats['first_acc'], stats['primary_acc'], stats['avg_calls'], stats['avg_score'],
        ]))

    case_table = [
        '',
        '## Case details',
        '',
        '| # | Bucket | First tool | Primary tool | Calls | First OK | Primary OK | Chain OK | Success | Score |',
        '|---|---|---|---|---:|---:|---:|---:|---:|---:|',
    ]
    for row in rows:
        ev = row['eval']
        case_table.append(table_row([
            row['id'],
            ev['bucket'],
            ev['first_tool'] or '-',
            ev['primary_tool'] or '-',
            ev['tool_calls_count'],
            as_flag(ev['first_ok']),
            as_flag(ev['primary_ok']),
            as_flag(ev['chain_ok']),
            as_flag(ev['success']),
            ev['score']['total'],
        ]))

    # The trailing '' element makes the joined text end with a newline.
    return '\n'.join(header + metrics + bucket_table + case_table + [''])
| |
|
| |
|
def main() -> None:
    """Command-line entry point for the tool-routing benchmark.

    Runs each prompt whose 1-based position (blank lines excluded) lies in
    ``--start``..``--end`` inclusive. Each run is scored against the matching
    expectation row; cases without one fall back to
    ``{'id': i, 'bucket': 'other'}``. A one-line progress message is printed
    per case. At the end, ``tool_routing_<model>.json`` (summary + rows) and a
    matching ``.md`` report are written to ``--out-dir``. Raw ``--results``
    session files go to ``--raw-results-dir`` (default
    ``<out-dir>/raw_results``), in one sub-directory per model.
    """
    parser = argparse.ArgumentParser(description='Score tool-routing/confusion benchmark')
    parser.add_argument('--model', required=True, help='Model ID')
    parser.add_argument('--agent', default='hf_hub_community', help='Agent name to run')
    parser.add_argument('--agent-cards', type=Path, default=CARDS_DIR)
    parser.add_argument('--prompts', type=Path, default=PROMPTS_FILE)
    parser.add_argument('--expected', type=Path, default=EXPECTED_FILE)
    parser.add_argument('--start', type=int, default=1)
    parser.add_argument('--end', type=int, default=20)
    parser.add_argument('--timeout', type=int, default=240)
    parser.add_argument('--out-dir', type=Path, default=OUT_DIR)
    parser.add_argument('--raw-results-dir', type=Path, default=None, help='Where to store fast-agent --results JSON files')
    args = parser.parse_args()

    # Model IDs such as "org/model" would otherwise create nested directories.
    safe_model = args.model.replace('/', '_')
    raw_results_dir = args.raw_results_dir or (args.out_dir / 'raw_results')
    model_results_dir = raw_results_dir / safe_model

    prompts = load_prompts(args.prompts)
    expected = load_expected(args.expected)

    rows: list[dict[str, Any]] = []
    for case_id, prompt in enumerate(prompts, start=1):
        if not (args.start <= case_id <= args.end):
            continue

        obs = run_prompt(
            prompt,
            model=args.model,
            agent=args.agent,
            cards_dir=args.agent_cards,
            timeout_sec=args.timeout,
            result_path=model_results_dir / f'case_{case_id:02d}.json',
        )
        exp = expected.get(case_id, {'id': case_id, 'bucket': 'other'})
        ev = evaluate_case(obs, exp)

        observed = {
            'returncode': obs['returncode'],
            'tool_calls': obs['tool_calls'],
            'called_tools': obs['called_tools'],
            'result_file': obs.get('result_file'),
        }
        rows.append({
            'id': case_id,
            'prompt': prompt,
            'expected': exp,
            'observed': observed,
            'eval': ev,
            'merged': obs['merged'],
        })
        print(f"[{case_id}] score={ev['score']['total']}/10 first={ev['first_tool']} primary={ev['primary_tool']} calls={ev['tool_calls_count']}")

    summary = summarize(rows)

    args.out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"tool_routing_{safe_model}"
    json_path = args.out_dir / f"{stem}.json"
    md_path = args.out_dir / f"{stem}.md"

    json_path.write_text(json.dumps({'summary': summary, 'rows': rows}, indent=2), encoding='utf-8')
    md_path.write_text(render_md(rows, summary, model=args.model, agent=args.agent), encoding='utf-8')

    print(f"\nWrote:\n- {json_path}\n- {md_path}")


if __name__ == '__main__':
    main()
| |
|