# hf-papers / scripts / score_tool_routing_confusion.py
# Author: evalstate (HF Staff)
# Commit message: sync: promote hf_hub_community prompt v3 + add prompt/coverage harness
# Commit: bba4fab (verified)
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import re
import subprocess
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Default CLI inputs/outputs, all resolved relative to ROOT.
CARDS_DIR = ROOT.joinpath('.fast-agent', 'tool-cards')
PROMPTS_FILE = ROOT.joinpath('scripts', 'tool_routing_challenges.txt')
EXPECTED_FILE = ROOT.joinpath('scripts', 'tool_routing_expected.json')
OUT_DIR = ROOT.joinpath('docs', 'tool_routing_eval')
# ANSI CSI escape sequences (ESC '[' params intermediates final-byte), e.g. colour codes.
ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
def strip_ansi(text: str) -> str:
    """Return *text* with every sequence matched by ``ANSI_RE`` removed.

    Splitting on the pattern (which has no capture groups) and re-joining the
    pieces drops exactly the matched escape sequences.
    """
    pieces = ANSI_RE.split(text)
    return ''.join(pieces)
def load_prompts(path: Path) -> list[str]:
    """Read prompts from *path*, one per line.

    Each line is stripped of surrounding whitespace; lines that end up empty
    are dropped. Order is preserved.
    """
    prompts: list[str] = []
    for raw_line in path.read_text(encoding='utf-8').splitlines():
        cleaned = raw_line.strip()
        if not cleaned:
            continue
        prompts.append(cleaned)
    return prompts
def load_expected(path: Path) -> dict[int, dict[str, Any]]:
    """Load the expectations JSON (a list of row objects) keyed by integer ``id``.

    Each row's ``id`` is coerced with ``int()``. If two rows share an id, the
    later row wins.
    """
    records = json.loads(path.read_text(encoding='utf-8'))
    return {int(record['id']): record for record in records}
def _typed_text_items(content: Any) -> list[str]:
    """Return ``str(item['text'])`` for each ``{'type': 'text', 'text': <non-empty>}`` dict in a content list."""
    if not isinstance(content, list):
        return []
    return [
        str(item['text'])
        for item in content
        if isinstance(item, dict) and item.get('type') == 'text' and item.get('text')
    ]


def _extract_session_observations(result_path: Path) -> dict[str, Any]:
    """Parse a fast-agent ``--results`` JSON file into tool-routing observations.

    Walks the top-level ``messages`` list. Text is collected in this order:

    * from assistant messages: typed text content items, text from the
      ``reasoning`` channel, then for each entry of the ``tool_calls`` map a
      ``"tool call - <name>"`` marker and its JSON-encoded ``arguments``;
    * from user messages: typed text content of each ``tool_results`` entry.

    Entries with an unexpected shape (non-dict message, channels map, tool
    call or tool result; non-list content) are skipped instead of raising.

    Args:
        result_path: JSON file written by ``fast-agent go --results``.

    Returns:
        ``tool_calls``: every tool name in call order, duplicates kept.
        ``called_tools``: unique names in first-call order.
        ``merged_from_result``: all collected text joined by newlines, stripped.
    """
    data = json.loads(result_path.read_text(encoding='utf-8'))
    messages = data.get('messages', []) if isinstance(data, dict) else []
    tool_calls: list[str] = []
    merged_parts: list[str] = []
    for msg in messages:
        if not isinstance(msg, dict):
            continue
        role = msg.get('role')
        if role == 'assistant':
            merged_parts.extend(_typed_text_items(msg.get('content')))
            channels = msg.get('channels')
            if isinstance(channels, dict):
                for ch_name in ('reasoning',):
                    ch_items = channels.get(ch_name)
                    if not isinstance(ch_items, list):
                        continue
                    # Channel items are not type-tagged; any non-empty text counts.
                    for item in ch_items:
                        if isinstance(item, dict) and item.get('text'):
                            merged_parts.append(str(item['text']))
            tc_map = msg.get('tool_calls')
            if isinstance(tc_map, dict):
                for tc in tc_map.values():
                    params = tc.get('params') if isinstance(tc, dict) else None
                    if not isinstance(params, dict):
                        continue
                    name = params.get('name')
                    if isinstance(name, str):
                        tool_calls.append(name)
                        merged_parts.append(f'tool call - {name}')
                    args = params.get('arguments')
                    if isinstance(args, dict):
                        merged_parts.append(json.dumps(args, ensure_ascii=False))
        elif role == 'user':
            tr_map = msg.get('tool_results')
            if isinstance(tr_map, dict):
                for tr in tr_map.values():
                    if isinstance(tr, dict):
                        merged_parts.extend(_typed_text_items(tr.get('content')))
    # dict.fromkeys de-duplicates while keeping first-seen order.
    called_tools = list(dict.fromkeys(tool_calls))
    return {
        'tool_calls': tool_calls,
        'called_tools': called_tools,
        'merged_from_result': '\n'.join(merged_parts).strip(),
    }
def run_prompt(
    prompt: str,
    model: str,
    agent: str,
    cards_dir: Path,
    timeout_sec: int,
    result_path: Path,
) -> dict[str, Any]:
    """Run one prompt through ``fast-agent go`` and collect what the agent did.

    The CLI is invoked as an argument list (no shell) with ``--results``
    pointing at *result_path*. Tool usage and transcript text are then read
    back from that file with ``_extract_session_observations``.

    Args:
        prompt: User message passed via ``-m``.
        model: Value for ``--model``.
        agent: Value for ``--agent``.
        cards_dir: Directory passed as ``--agent-cards``.
        timeout_sec: Wall-clock limit for the subprocess.
        result_path: Where fast-agent should write its results JSON. Any
            existing file at this path is deleted before the run.

    Returns:
        ``returncode``; ANSI-stripped ``stdout`` and ``stderr`` plus their
        concatenation ``merged_console``; ``merged`` (transcript text from the
        results file); ``tool_calls`` / ``called_tools``; and ``result_file``.

    Raises:
        subprocess.TimeoutExpired: if the CLI runs longer than ``timeout_sec``.
        RuntimeError: if the CLI finishes without writing ``result_path``.
    """
    result_path.parent.mkdir(parents=True, exist_ok=True)
    # Callers may reuse the same path across runs (e.g. one file per model and
    # case). Remove any leftover file so a run that fails before writing its
    # results cannot be silently scored from the previous run's data.
    result_path.unlink(missing_ok=True)
    cmd = [
        'fast-agent', 'go',
        '--no-env',
        '--model', model,
        '--agent-cards', str(cards_dir),
        '--agent', agent,
        '--results', str(result_path),
        '-m', prompt,
    ]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_sec)
    out = strip_ansi(proc.stdout or '')
    err = strip_ansi(proc.stderr or '')
    merged_console = (out + '\n' + err).strip()
    if not result_path.exists():
        message = f'Expected --results file not written: {result_path} (returncode={proc.returncode})'
        err_tail = err.strip()[-500:]
        if err_tail:
            message += f'\nstderr (tail):\n{err_tail}'
        raise RuntimeError(message)
    parsed = _extract_session_observations(result_path)
    return {
        'returncode': proc.returncode,
        'stdout': out,
        'stderr': err,
        'merged': parsed['merged_from_result'],
        'merged_console': merged_console,
        'tool_calls': parsed['tool_calls'],
        'called_tools': parsed['called_tools'],
        'result_file': str(result_path),
    }
def _match_any(observed: str | None, expected_any: list[str] | str | None) -> bool | None:
    """Check whether an observed tool name is one of the accepted names.

    Args:
        observed: The tool name seen, or ``None`` if no tool was called.
        expected_any: Accepted tool names, or ``None`` for "no expectation".
            A bare string is treated as a one-element list, so it is matched
            exactly rather than as a substring.

    Returns:
        ``None`` when there is no expectation, ``False`` when nothing was
        observed, otherwise whether ``observed`` is an accepted name.
    """
    if expected_any is None:
        return None
    if observed is None:
        return False
    if isinstance(expected_any, str):
        # `in` on a str would be a substring test ("hf" in "hf_hub_query").
        expected_any = [expected_any]
    return observed in expected_any


def _efficiency_points(bucket: str, calls: int, expect_no_tool: bool) -> int:
    """Return 0-2 points for how few tool calls were used, with per-bucket thresholds."""
    if bucket == 'distractor_positive':
        return 2 if calls <= 2 else (1 if calls <= 4 else 0)
    if bucket == 'mixed_chain':
        return 2 if calls <= 4 else (1 if calls <= 6 else 0)
    if expect_no_tool:
        return 2 if calls == 0 else (1 if calls == 1 else 0)
    return 2 if calls <= 5 else (1 if calls <= 8 else 0)


def evaluate_case(obs: dict[str, Any], exp: dict[str, Any]) -> dict[str, Any]:
    """Score one case's observed tool usage against its expectations.

    Args:
        obs: Observation dict. Reads ``tool_calls``, ``called_tools``,
            ``returncode``, ``merged`` and, if present, ``merged_console``.
        exp: Expectation row. Optional keys: ``expect_no_tool_call``,
            ``expected_first_any``, ``expected_primary_any``,
            ``allowed_tools`` and ``bucket`` (default ``'other'``).

    Returns:
        Per-case verdicts plus a ``score`` dict whose ``total`` is out of 10:
        2 points each for first tool, primary tool, allowed chain, success and
        efficiency. ``first_ok``/``primary_ok`` are ``None`` when no
        expectation was given, and ``None`` earns no points.
    """
    tool_calls: list[str] = obs['tool_calls']
    called_tools: list[str] = obs['called_tools']
    first_tool = tool_calls[0] if tool_calls else None
    primary_tool = None
    if called_tools:
        # Most frequently called tool. max() keeps the first maximal element,
        # so ties go to the tool that was called earliest.
        counts = Counter(tool_calls)
        primary_tool = max(called_tools, key=lambda t: counts[t])

    expect_no_tool = bool(exp.get('expect_no_tool_call', False))
    expected_first = exp.get('expected_first_any')
    expected_primary = exp.get('expected_primary_any')
    allowed_tools = exp.get('allowed_tools')
    if isinstance(allowed_tools, str):
        allowed_tools = [allowed_tools]

    # A Python traceback may appear in the session transcript (e.g. a tool
    # result) or only on the CLI's stdout/stderr, so check both.
    console = obs.get('merged_console') or ''
    success = (
        obs['returncode'] == 0
        and 'Traceback' not in obs['merged']
        and 'Traceback' not in console
    )

    if expect_no_tool:
        first_ok = (first_tool is None)
        primary_ok = (primary_tool is None)
    else:
        first_ok = _match_any(first_tool, expected_first)
        primary_ok = _match_any(primary_tool, expected_primary)
    if allowed_tools is None:
        chain_ok = True
    else:
        chain_ok = all(t in allowed_tools for t in called_tools)

    # Simple /10 routing score: 2 points per satisfied check.
    route_first = 2 if first_ok else 0
    route_primary = 2 if primary_ok else 0
    route_chain = 2 if chain_ok else 0
    route_success = 2 if success else 0

    calls = len(tool_calls)
    bucket = exp.get('bucket', 'other')
    efficiency = _efficiency_points(bucket, calls, expect_no_tool)
    total = route_first + route_primary + route_chain + route_success + efficiency
    return {
        'first_tool': first_tool,
        'primary_tool': primary_tool,
        'tool_calls_count': calls,
        'first_ok': first_ok,
        'primary_ok': primary_ok,
        'chain_ok': chain_ok,
        'success': success,
        'bucket': bucket,
        'score': {
            'first': route_first,
            'primary': route_primary,
            'chain': route_chain,
            'success': route_success,
            'efficiency': efficiency,
            'total': total,
        },
    }
def summarize(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-case evaluations into overall and per-bucket metrics.

    A case counts as a hit for an accuracy/rate metric when the matching
    ``eval`` flag is truthy (``None`` counts as a miss). Rates are rounded to
    4 decimals and averages to 3. With no rows, every overall metric is 0.0.
    Buckets appear in first-seen order.
    """
    def share(items: list[dict[str, Any]], flag: str) -> float:
        # Fraction of items whose eval[flag] is truthy; 0.0 when empty.
        if not items:
            return 0.0
        return sum(1 for item in items if item['eval'][flag]) / len(items)

    def mean(items: list[dict[str, Any]], pick: Any) -> float:
        # Mean of pick(eval) over items; 0.0 when empty.
        if not items:
            return 0.0
        return sum(pick(item['eval']) for item in items) / len(items)

    def calls_of(ev: dict[str, Any]) -> Any:
        return ev['tool_calls_count']

    def total_of(ev: dict[str, Any]) -> Any:
        return ev['score']['total']

    grouped: dict[Any, list[dict[str, Any]]] = {}
    for row in rows:
        grouped.setdefault(row['eval']['bucket'], []).append(row)

    per_bucket = {
        name: {
            'n': len(members),
            'first_acc': round(share(members, 'first_ok'), 4),
            'primary_acc': round(share(members, 'primary_ok'), 4),
            'avg_calls': round(mean(members, calls_of), 3),
            'avg_score': round(mean(members, total_of), 3),
        }
        for name, members in grouped.items()
    }

    return {
        'n_cases': len(rows),
        'first_accuracy': round(share(rows, 'first_ok'), 4),
        'primary_accuracy': round(share(rows, 'primary_ok'), 4),
        'chain_accuracy': round(share(rows, 'chain_ok'), 4),
        'success_rate': round(share(rows, 'success'), 4),
        'avg_tool_calls': round(mean(rows, calls_of), 3),
        'avg_score_total': round(mean(rows, total_of), 3),
        'bucket_summary': per_bucket,
    }
def render_md(rows: list[dict[str, Any]], summary: dict[str, Any], model: str, agent: str) -> str:
    """Render the evaluation results as a Markdown report.

    Sections: a header (model, agent, case count), an overall-metrics list, a
    per-bucket table sorted by bucket name, and a per-case table. Verdict
    flags render as 0/1 (``None`` renders as 0); a missing first/primary tool
    renders as ``-``.

    Returns:
        The report text, terminated by a newline.
    """
    def table_row(cells: list[Any]) -> str:
        # One Markdown table row: "| c1 | c2 | ... |".
        return '| ' + ' | '.join(f'{cell}' for cell in cells) + ' |'

    overall_metrics = (
        ('First-tool accuracy', 'first_accuracy'),
        ('Primary-tool accuracy', 'primary_accuracy'),
        ('Allowed-chain accuracy', 'chain_accuracy'),
        ('Success rate', 'success_rate'),
        ('Avg tool calls', 'avg_tool_calls'),
        ('Avg score (/10)', 'avg_score_total'),
    )
    buckets = summary['bucket_summary']

    lines: list[str] = [
        '# Tool Routing/Confusion Evaluation Report',
        '',
        f'- Model: `{model}`',
        f'- Agent: `{agent}`',
        f"- Cases: **{summary['n_cases']}**",
        '',
        '## Overall metrics',
        '',
    ]
    lines += [f'- {label}: **{summary[key]}**' for label, key in overall_metrics]
    lines += [
        '',
        '## By bucket',
        '',
        '| Bucket | N | First acc | Primary acc | Avg calls | Avg score |',
        '|---|---:|---:|---:|---:|---:|',
    ]
    for name in sorted(buckets):
        stats = buckets[name]
        lines.append(table_row([
            name,
            stats['n'],
            stats['first_acc'],
            stats['primary_acc'],
            stats['avg_calls'],
            stats['avg_score'],
        ]))
    lines += [
        '',
        '## Case details',
        '',
        '| # | Bucket | First tool | Primary tool | Calls | First OK | Primary OK | Chain OK | Success | Score |',
        '|---|---|---|---|---:|---:|---:|---:|---:|---:|',
    ]
    for row in rows:
        ev = row['eval']
        lines.append(table_row([
            row['id'],
            ev['bucket'],
            ev['first_tool'] or '-',
            ev['primary_tool'] or '-',
            ev['tool_calls_count'],
            int(bool(ev['first_ok'])),
            int(bool(ev['primary_ok'])),
            int(bool(ev['chain_ok'])),
            int(bool(ev['success'])),
            ev['score']['total'],
        ]))
    return '\n'.join(lines) + '\n'
def main() -> None:
    """Command-line entry point for the tool-routing/confusion benchmark.

    Runs every prompt whose 1-based line number falls in ``[--start, --end]``
    through ``run_prompt``. Each case is scored with ``evaluate_case`` against
    its row in the expectations file, falling back to
    ``{'id': i, 'bucket': 'other'}``. A one-line score is printed per case.
    Finally ``tool_routing_<model>.json`` and ``.md`` reports are written to
    ``--out-dir``. Raw ``--results`` files go to a per-model subdirectory of
    ``--raw-results-dir`` (default ``<out-dir>/raw_results``).
    """
    parser = argparse.ArgumentParser(description='Score tool-routing/confusion benchmark')
    parser.add_argument('--model', required=True, help='Model ID')
    parser.add_argument('--agent', default='hf_hub_community', help='Agent name to run')
    parser.add_argument('--agent-cards', type=Path, default=CARDS_DIR)
    parser.add_argument('--prompts', type=Path, default=PROMPTS_FILE)
    parser.add_argument('--expected', type=Path, default=EXPECTED_FILE)
    parser.add_argument('--start', type=int, default=1)
    parser.add_argument('--end', type=int, default=20)
    parser.add_argument('--timeout', type=int, default=240)
    parser.add_argument('--out-dir', type=Path, default=OUT_DIR)
    parser.add_argument('--raw-results-dir', type=Path, default=None, help='Where to store fast-agent --results JSON files')
    args = parser.parse_args()

    # Model IDs like "org/name" are flattened so they can be used in file names.
    safe_model = args.model.replace('/', '_')
    raw_results_dir = args.raw_results_dir or (args.out_dir / 'raw_results')
    prompts = load_prompts(args.prompts)
    expected = load_expected(args.expected)

    rows: list[dict[str, Any]] = []
    for case_id, prompt in enumerate(prompts, start=1):
        if case_id < args.start or case_id > args.end:
            continue
        obs = run_prompt(
            prompt,
            model=args.model,
            agent=args.agent,
            cards_dir=args.agent_cards,
            timeout_sec=args.timeout,
            result_path=raw_results_dir / safe_model / f'case_{case_id:02d}.json',
        )
        exp = expected.get(case_id, {'id': case_id, 'bucket': 'other'})
        ev = evaluate_case(obs, exp)
        rows.append({
            'id': case_id,
            'prompt': prompt,
            'expected': exp,
            'observed': {
                'returncode': obs['returncode'],
                'tool_calls': obs['tool_calls'],
                'called_tools': obs['called_tools'],
                'result_file': obs.get('result_file'),
            },
            'eval': ev,
            'merged': obs['merged'],
        })
        print(f"[{case_id}] score={ev['score']['total']}/10 first={ev['first_tool']} primary={ev['primary_tool']} calls={ev['tool_calls_count']}")

    summary = summarize(rows)
    args.out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"tool_routing_{safe_model}"
    json_path = args.out_dir / f"{stem}.json"
    md_path = args.out_dir / f"{stem}.md"
    json_path.write_text(json.dumps({'summary': summary, 'rows': rows}, indent=2), encoding='utf-8')
    md_path.write_text(render_md(rows, summary, model=args.model, agent=args.agent), encoding='utf-8')
    print(f"\nWrote:\n- {json_path}\n- {md_path}")


if __name__ == '__main__':
    main()