"""Local held-out evaluation and artifact generation for Counsel-Env."""

import argparse
import csv
import json
import random
from pathlib import Path
from statistics import mean
from typing import Callable, Dict, Iterable, List
from xml.sax.saxutils import escape

from counsel_env.models import CounselAction
from counsel_env.server.counsel_env_environment import CounselEnvironment

# A policy drives one episode on an already-reset environment using the given RNG.
AgentPolicy = Callable[[CounselEnvironment, random.Random], None]


def random_agent(env: CounselEnvironment, rng: random.Random) -> None:
    """Low-information baseline.

    Takes up to 8 actions: with probability 0.7 it asks a generic numbered question,
    otherwise it presents a randomly chosen exhibit. Rests the case if still open.
    """
    for _ in range(8):
        if env.done:
            break
        if rng.random() < 0.7:
            env.step(
                CounselAction(
                    tool="ask_question",
                    text=f"Can you explain detail {rng.randint(1, 999)}?",
                )
            )
        else:
            env.step(
                CounselAction(
                    tool="present_evidence",
                    exhibit_id=rng.choice(list(env.case["evidence"].keys())),
                )
            )
    if not env.done:
        env.step(CounselAction(tool="rest_case"))


def keyword_spam_agent(env: CounselEnvironment, rng: random.Random) -> None:
    """Reward-hacking probe: asks trigger-like terms but does not reason about exhibits.

    ``rng`` is unused; it is accepted only to satisfy the ``AgentPolicy`` signature.
    """
    questions = [
        "Where were you?",
        "What was your motive?",
        "Did you know the victim?",
        "What happened and why?",
        "Were you at the location?",
    ]
    for question in questions:
        if env.done:
            break
        env.step(CounselAction(tool="ask_question", text=question))
    if not env.done:
        env.step(CounselAction(tool="rest_case"))


def present_all_agent(env: CounselEnvironment, rng: random.Random) -> None:
    """Reward-hacking probe: blindly presents every exhibit, in case-dictionary order.

    ``rng`` is unused; it is accepted only to satisfy the ``AgentPolicy`` signature.
    """
    for exhibit_id in list(env.case["evidence"].keys()):
        if env.done:
            break
        env.step(CounselAction(tool="present_evidence", exhibit_id=exhibit_id))
    if not env.done:
        env.step(CounselAction(tool="rest_case"))


def oracle_scripted_agent(env: CounselEnvironment, rng: random.Random) -> None:
    """Upper-bound scripted strategy using hidden contradiction metadata.

    For each contradiction of the witness it asks about the first trigger keyword and
    then presents the exhibit that disproves it. ``rng`` is unused.
    """
    for contradiction in env.witness.contradictions:
        if env.done:
            break
        env.step(CounselAction(tool="ask_question", text=f"{contradiction.trigger_keywords[0]}?"))
        if env.done:
            break
        env.step(
            CounselAction(
                tool="present_evidence",
                exhibit_id=contradiction.disprover_evidence_id,
            )
        )
    if not env.done:
        env.step(CounselAction(tool="rest_case"))


AGENTS: Dict[str, AgentPolicy] = {
    "random": random_agent,
    "keyword_spam": keyword_spam_agent,
    "present_all": present_all_agent,
    "scripted_oracle": oracle_scripted_agent,
}


def make_eval_seeds(count: int = 30, start: int = 20260425) -> List[int]:
    """Return ``count`` consecutive integer seeds beginning at ``start``."""
    return list(range(start, start + count))


def evaluate_agent(
    name: str,
    policy: AgentPolicy,
    seeds: Iterable[int],
    curriculum_stage: str = "mixed",
    transcript_limit: int = 3,
) -> tuple[List[dict], List[str]]:
    """Run ``policy`` once per seed on a fresh environment and collect metrics.

    Args:
        name: Agent name; stored in each row and used in episode ids and transcript headers.
        policy: Callable that plays one episode on the reset environment.
        seeds: Environment seeds, one episode per seed (iterated once).
        curriculum_stage: Passed through to ``CounselEnvironment.reset``.
        transcript_limit: Number of leading episodes whose markdown transcript is kept.

    Returns:
        ``(rows, markdown_samples)``: one metrics dict per episode, and up to
        ``transcript_limit`` markdown transcripts.
    """
    rows: List[dict] = []
    markdown_samples: List[str] = []
    # Name-derived offset so different agents get different RNG streams for the same seed.
    agent_offset = sum(ord(ch) for ch in name)
    for index, seed in enumerate(seeds):
        rng = random.Random(seed + agent_offset)
        env = CounselEnvironment()
        obs = env.reset(seed=seed, curriculum_stage=curriculum_stage, episode_id=f"{name}_{seed}")
        assert obs.case_id == env.case["case_id"]
        policy(env, rng)
        # Rest any episode the policy left open before reward components are computed.
        if not env.done:
            env.step(CounselAction(tool="rest_case"))
        components = env._calculate_reward_components()
        rows.append(
            {
                "agent": name,
                "seed": seed,
                "case_id": env.case["case_id"],
                "difficulty": env.case["difficulty"],
                "reward": components["total_reward"],
                "primary_reward": components["primary_reward"],
                "auxiliary_reward": components["auxiliary_reward_raw"],
                "contradictions_total": int(components["contradictions_total"]),
                "contradictions_triggered": int(components["contradictions_triggered"]),
                "contradictions_surfaced": int(components["contradictions_surfaced"]),
                "questions_used": env.questions_used,
                "evidence_presented": env.evidence_presented_count,
                "evidence_timing_successes": int(components["evidence_timing_successes"]),
                "blind_evidence_count": int(components["blind_evidence_count"]),
                "useless_questions_ratio": components["useless_questions_ratio"],
                "avg_question_length": components["avg_question_length"],
            }
        )
        if index < transcript_limit:
            markdown_samples.append(f"# Agent: {name}\n\n" + env.export_transcript_markdown())
    return rows, markdown_samples


def summarize(rows: List[dict]) -> List[dict]:
    """Aggregate per-episode rows into one summary dict per agent, sorted by agent name.

    Trigger and surface rates are per-episode ratios (``max(1, total)`` avoids division
    by zero when a case has no contradictions), averaged over the agent's episodes.
    """
    summaries: List[dict] = []
    for agent in sorted({row["agent"] for row in rows}):
        agent_rows = [row for row in rows if row["agent"] == agent]
        summaries.append(
            {
                "agent": agent,
                "episodes": len(agent_rows),
                "avg_reward": mean(row["reward"] for row in agent_rows),
                "avg_primary_reward": mean(row["primary_reward"] for row in agent_rows),
                "avg_trigger_rate": mean(
                    row["contradictions_triggered"] / max(1, row["contradictions_total"])
                    for row in agent_rows
                ),
                "avg_surface_rate": mean(
                    row["contradictions_surfaced"] / max(1, row["contradictions_total"])
                    for row in agent_rows
                ),
                "avg_evidence_timing": mean(row["evidence_timing_successes"] for row in agent_rows),
                "avg_useless_ratio": mean(row["useless_questions_ratio"] for row in agent_rows),
            }
        )
    return summaries


def write_jsonl(path: Path, rows: Iterable[dict]) -> None:
    """Write ``rows`` as JSON Lines (one key-sorted object per line), UTF-8 encoded."""
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, sort_keys=True) + "\n")


def write_csv(path: Path, rows: List[dict]) -> None:
    """Write ``rows`` as CSV using the first row's keys as header; no file if ``rows`` is empty."""
    if not rows:
        return
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


# Bar fill colours shared by the matplotlib and SVG charts; cycled if there are more agents.
_BAR_COLORS = ["#777777", "#ba5a31", "#4c78a8", "#2f855a"]


def write_plots(plot_dir: Path, summaries: List[dict]) -> None:
    """Write the reward bar chart and the audit-metric line chart into ``plot_dir``.

    Uses matplotlib (PNG output) when it can be imported; otherwise falls back to a CSV
    dump plus dependency-free SVG charts so evaluation still produces plot artifacts.
    """
    try:
        import matplotlib.pyplot as plt
    except Exception:
        # Deliberate best-effort: any failure to load matplotlib selects the SVG fallback.
        write_csv(plot_dir / "summary_for_plots.csv", summaries)
        write_svg_bar_chart(
            plot_dir / "baseline_vs_oracle.svg",
            summaries,
            metric="avg_reward",
            title="Held-out evaluation reward by baseline",
        )
        write_svg_multi_metric(plot_dir / "rubric_breakdown.svg", summaries)
        return

    agents = [row["agent"] for row in summaries]
    rewards = [row["avg_reward"] for row in summaries]
    primary = [row["avg_primary_reward"] for row in summaries]
    surface = [row["avg_surface_rate"] for row in summaries]
    trigger = [row["avg_trigger_rate"] for row in summaries]
    useless = [row["avg_useless_ratio"] for row in summaries]
    bar_colors = [_BAR_COLORS[idx % len(_BAR_COLORS)] for idx in range(len(agents))]

    plt.figure(figsize=(8, 4.5))
    plt.bar(agents, rewards, color=bar_colors)
    plt.ylabel("average reward")
    plt.xlabel("agent")
    plt.title("Held-out evaluation reward by baseline")
    plt.xticks(rotation=20, ha="right")
    plt.tight_layout()
    plt.savefig(plot_dir / "baseline_vs_oracle.png", dpi=180)
    plt.close()

    x = range(len(agents))
    plt.figure(figsize=(8, 4.5))
    plt.plot(x, primary, marker="o", label="primary reward")
    plt.plot(x, trigger, marker="o", label="trigger rate")
    plt.plot(x, surface, marker="o", label="surface rate")
    plt.plot(x, useless, marker="o", label="useless question ratio")
    plt.xticks(list(x), agents, rotation=20, ha="right")
    plt.ylabel("rate")
    plt.xlabel("agent")
    plt.title("Reward-hacking audit metrics")
    plt.legend()
    plt.tight_layout()
    plt.savefig(plot_dir / "rubric_breakdown.png", dpi=180)
    plt.close()


_SVG_FONT = 'font-family="Helvetica, Arial, sans-serif"'


def _svg_frame(
    width: int,
    height: int,
    title: str,
    margin_left: int,
    margin_top: int,
    axis_right: float,
    baseline_y: float,
) -> List[str]:
    """Return the opening SVG elements shared by both charts: root, background, title, axes."""
    return [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}">',
        f'<rect x="0" y="0" width="{width}" height="{height}" fill="#ffffff"/>',
        f'<text x="{width / 2:.1f}" y="{margin_top / 2:.1f}" text-anchor="middle" '
        f'{_SVG_FONT} font-size="18" font-weight="bold">{escape(title)}</text>',
        f'<line x1="{margin_left}" y1="{baseline_y}" x2="{axis_right}" y2="{baseline_y}" '
        'stroke="#333333"/>',
        f'<line x1="{margin_left}" y1="{margin_top}" x2="{margin_left}" y2="{baseline_y}" '
        'stroke="#333333"/>',
    ]


def write_svg_bar_chart(
    path: Path,
    summaries: List[dict],
    metric: str,
    title: str,
    y_label: str = "average reward",
) -> None:
    """Write a dependency-free SVG bar chart of ``summaries[i][metric]`` per agent.

    Args:
        path: Destination ``.svg`` file.
        summaries: Per-agent summary dicts (need ``"agent"`` and ``metric`` keys).
        metric: Key of the value to plot.
        title: Chart title (XML-escaped).
        y_label: Vertical axis label (XML-escaped).

    The y scale is at least 1.0 so small values are not stretched to full height.
    Negative values are drawn as zero-height bars (a negative SVG rect height is
    invalid) but their true value is still printed above the bar.
    """
    width, height = 840, 460
    margin_left, margin_bottom, margin_top = 80, 90, 60
    chart_w = width - margin_left - 40
    chart_h = height - margin_top - margin_bottom
    baseline_y = height - margin_bottom
    # default=0.0 keeps an empty summary list from crashing max().
    max_value = max(1.0, max((row[metric] for row in summaries), default=0.0))
    bar_gap = 24
    bar_w = (chart_w - bar_gap * (len(summaries) - 1)) / max(1, len(summaries))

    parts = _svg_frame(width, height, title, margin_left, margin_top, width - 40, baseline_y)
    for idx, row in enumerate(summaries):
        value = row[metric]
        x = margin_left + idx * (bar_w + bar_gap)
        bar_h = chart_h * max(0.0, value) / max_value
        y = baseline_y - bar_h
        center = x + bar_w / 2
        color = _BAR_COLORS[idx % len(_BAR_COLORS)]
        parts.append(
            f'<rect x="{x:.1f}" y="{y:.1f}" width="{bar_w:.1f}" height="{bar_h:.1f}" '
            f'fill="{color}"/>'
        )
        parts.append(
            f'<text x="{center:.1f}" y="{y - 6:.1f}" text-anchor="middle" {_SVG_FONT} '
            f'font-size="12">{value:.3f}</text>'
        )
        parts.append(
            f'<text x="{center:.1f}" y="{baseline_y + 22}" text-anchor="middle" {_SVG_FONT} '
            f'font-size="12">{escape(str(row["agent"]))}</text>'
        )
    label_y = margin_top + chart_h / 2
    parts.append(
        f'<text x="24" y="{label_y:.1f}" transform="rotate(-90 24 {label_y:.1f})" '
        f'text-anchor="middle" {_SVG_FONT} font-size="13">{escape(y_label)}</text>'
    )
    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")


def write_svg_multi_metric(path: Path, summaries: List[dict]) -> None:
    """Write a dependency-free SVG line chart of the audit metrics per agent.

    Each metric becomes one polyline over the agents (in ``summaries`` order) on a
    fixed 0-1 y axis. Values outside [0, 1] are clamped for drawing only.
    """
    width, height = 880, 500
    margin_left, margin_bottom, margin_top = 80, 100, 60
    chart_w = width - margin_left - 50
    chart_h = height - margin_top - margin_bottom
    baseline_y = height - margin_bottom
    metrics = [
        ("avg_primary_reward", "primary", "#2f855a"),
        ("avg_trigger_rate", "trigger", "#4c78a8"),
        ("avg_surface_rate", "surface", "#805ad5"),
        ("avg_useless_ratio", "useless", "#ba5a31"),
    ]
    agents = [row["agent"] for row in summaries]
    # A single agent sits at the left edge; max() avoids dividing by zero.
    x_step = chart_w / max(1, len(agents) - 1)

    parts = _svg_frame(
        width,
        height,
        "Reward-hacking audit metrics",
        margin_left,
        margin_top,
        margin_left + chart_w,
        baseline_y,
    )
    for value in [0.0, 0.25, 0.5, 0.75, 1.0]:
        y = baseline_y - chart_h * value
        parts.append(
            f'<line x1="{margin_left}" y1="{y:.1f}" x2="{margin_left + chart_w}" y2="{y:.1f}" '
            'stroke="#e2e8f0"/>'
        )
        parts.append(
            f'<text x="{margin_left - 10}" y="{y + 4:.1f}" text-anchor="end" {_SVG_FONT} '
            f'font-size="12">{value:.2f}</text>'
        )

    legend_x = width - 180
    for metric_idx, (metric, label, color) in enumerate(metrics):
        points = [
            (
                margin_left + idx * x_step,
                baseline_y - chart_h * max(0.0, min(1.0, row[metric])),
            )
            for idx, row in enumerate(summaries)
        ]
        path_data = " ".join(f"{x:.1f},{y:.1f}" for x, y in points)
        parts.append(
            f'<polyline points="{path_data}" fill="none" stroke="{color}" stroke-width="2.5"/>'
        )
        for x, y in points:
            parts.append(f'<circle cx="{x:.1f}" cy="{y:.1f}" r="4" fill="{color}"/>')
        legend_y = margin_top + 22 * metric_idx
        parts.append(
            f'<rect x="{legend_x}" y="{legend_y}" width="14" height="14" fill="{color}"/>'
        )
        parts.append(
            f'<text x="{legend_x + 20}" y="{legend_y + 12}" {_SVG_FONT} '
            f'font-size="12">{escape(label)}</text>'
        )

    for idx, agent in enumerate(agents):
        x = margin_left + idx * x_step
        parts.append(
            f'<text x="{x:.1f}" y="{baseline_y + 22}" text-anchor="middle" {_SVG_FONT} '
            f'font-size="12">{escape(str(agent))}</text>'
        )
    parts.append("</svg>")
    path.write_text("\n".join(parts), encoding="utf-8")


def write_before_after_pairs(path: Path, transcript_by_agent: Dict[str, List[str]]) -> None:
    """Write the first transcript sample of each agent into one markdown file.

    The built-in agents come first in baseline-to-oracle order; any other agents in
    ``transcript_by_agent`` follow in dict order. Agents without samples are skipped.
    """
    preferred = ["random", "keyword_spam", "present_all", "scripted_oracle"]
    ordered = preferred + [agent for agent in transcript_by_agent if agent not in preferred]
    sections = ["# Before / After Transcript Samples", ""]
    for agent in ordered:
        samples = transcript_by_agent.get(agent, [])
        if samples:
            sections.append(samples[0])
            sections.append("")
    path.write_text("\n".join(sections), encoding="utf-8")


def run_evaluation(output_dir: str | Path = "assets", episodes: int = 30) -> dict:
    """Evaluate every agent in ``AGENTS`` on the same held-out seeds and write artifacts.

    Writes the following under ``output_dir``:
    - ``heldout_eval.jsonl`` and ``heldout_eval_summary.{csv,json}``;
    - the files produced by ``write_plots`` in ``plots/``;
    - ``transcripts/before_after_pairs.md``.

    Missing parent directories are created.

    Returns:
        ``{"rows": per-episode rows, "summaries": per-agent summaries}``.
    """
    output = Path(output_dir)
    plot_dir = output / "plots"
    transcript_dir = output / "transcripts"
    # parents=True so a nested --output-dir whose parent does not exist yet still works.
    for directory in (output, plot_dir, transcript_dir):
        directory.mkdir(parents=True, exist_ok=True)

    seeds = make_eval_seeds(episodes)
    all_rows: List[dict] = []
    transcript_by_agent: Dict[str, List[str]] = {}
    for agent, policy in AGENTS.items():
        rows, markdown_samples = evaluate_agent(agent, policy, seeds)
        all_rows.extend(rows)
        transcript_by_agent[agent] = markdown_samples

    summaries = summarize(all_rows)
    write_jsonl(output / "heldout_eval.jsonl", all_rows)
    write_csv(output / "heldout_eval_summary.csv", summaries)
    (output / "heldout_eval_summary.json").write_text(
        json.dumps(summaries, indent=2), encoding="utf-8"
    )
    write_plots(plot_dir, summaries)
    write_before_after_pairs(transcript_dir / "before_after_pairs.md", transcript_by_agent)
    return {"rows": all_rows, "summaries": summaries}


def main() -> int:
    """CLI entry point: run the evaluation and print the per-agent summaries as JSON."""
    parser = argparse.ArgumentParser(
        description="Run the Counsel-Env held-out baseline evaluation and write artifacts."
    )
    parser.add_argument(
        "--output-dir",
        default="assets",
        help="directory for evaluation artifacts (created if missing)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=30,
        help="number of held-out seeds evaluated per agent",
    )
    args = parser.parse_args()
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")
    result = run_evaluation(args.output_dir, args.episodes)
    print(json.dumps(result["summaries"], indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())