Spaces:
Sleeping
Sleeping
| """ | |
| CounterFeint — Before/After Evaluation Lane. | |
| This module provides a reproducible, held-out eval harness that compares two | |
| Investigator policies ("before training" vs "after training") on a fixed set | |
| of (task_id, seed) tuples which are **distinct** from any seeds used during | |
| training. The outputs are the three artefacts the README's headline section | |
| needs: | |
| eval_results.json — full per-episode metrics for every (tag, task, seed) | |
| eval_summary.md — markdown delta table, human-readable | |
| eval_plot.png — bar-chart visualisation (matplotlib, dev-only) | |
| Per-episode metrics tracked: | |
| * ``grader_score`` — Investigator triage/calibration score from | |
| :func:`counterfeint.graders.grade_episode` | |
| * ``track_a_score`` — reasoning-audit score from the Auditor's | |
| submitted :class:`AuditReport` | |
| * ``track_b_score`` — Fraudster plausibility score (sanity signal: | |
| higher = realistic adversary) | |
| * ``n_fraud_leaks`` — false approvals on ground-truth fraud ads | |
| (the single most visible failure mode) | |
| * ``budget_used_pct`` — fraction of the Investigator's action budget | |
| actually consumed (efficiency proxy) | |
| * ``fallback_count`` — LLM → scripted fallback count for the | |
| Investigator in this episode (0 means the | |
| LLM answered every turn cleanly) | |
| * ``rewards_by_role`` — raw cumulative rewards per role | |
| Design notes | |
| ------------ | |
| * **Factories, not instances.** The caller provides zero-argument callables | |
| that construct a fresh Investigator (and optional Fraudster / Auditor). | |
| This guarantees clean per-episode state — especially important for the | |
| :class:`counterfeint.agents.base.LLMPolicyBase` ``fallback_count``. | |
| * **Live server required.** The eval sweep drives real episodes via | |
| :func:`counterfeint.inference.run_three_agent_episode`, so the CounterFeint | |
| FraudArena server must be running at ``COUNTERFEINT_ENV_URL``. | |
| * **Sweep modes.** The primary sweep uses :class:`ReactiveFraudster` (stable, | |
| headline numbers). The caller can swap in :class:`LLMFraudster` for a | |
| harder adversarial-curriculum eval, or set ``use_real_world_ads=True`` to | |
| route through the Meta-CIB-modeled holdout set (see | |
| :mod:`counterfeint.data.real_world_loader`) once that asset lands. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List, Optional | |
| try: | |
| from .data.real_world_loader import ( | |
| HoldoutAd, | |
| count_by_ring, | |
| list_case_studies, | |
| load_real_world_holdout, | |
| ) | |
| from .inference import ENV_URL, run_three_agent_episode | |
| from .scripted import HeuristicAuditor, ReactiveFraudster, ScriptedInvestigator | |
| except ImportError: # pragma: no cover - script execution fallback | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
| from counterfeint.data.real_world_loader import ( # type: ignore[no-redef] | |
| HoldoutAd, | |
| count_by_ring, | |
| list_case_studies, | |
| load_real_world_holdout, | |
| ) | |
| from counterfeint.inference import ENV_URL, run_three_agent_episode # type: ignore[no-redef] | |
| from counterfeint.scripted import ( # type: ignore[no-redef] | |
| HeuristicAuditor, | |
| ReactiveFraudster, | |
| ScriptedInvestigator, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Public constants | |
| # --------------------------------------------------------------------------- | |
| EVAL_SEEDS: Dict[str, List[int]] = { | |
| "task_1": list(range(1001, 1011)), # 1001..1010 | |
| "task_2": list(range(2001, 2011)), # 2001..2010 | |
| "task_3": list(range(3001, 3011)), # 3001..3010 | |
| "task_3_unseen": list(range(4001, 4006)), # 4001..4005 — held-out budget regime | |
| } | |
| """Held-out seeds — 10 per training-tier task, plus 5 generalisation seeds | |
| on ``task_3_unseen``. All seeds are deliberately disjoint from the training | |
| seed range (which uses seed=42 for the scripted baseline and small integers | |
| for self-play rollouts). ``task_3_unseen`` shares the template universe of | |
| ``task_3`` but ships a tighter action budget (~1.2/ad) and one extra ring, | |
| so a model that over-fit to ``task_3``'s budget distribution should regress | |
| visibly on these seeds. Five seeds (vs ten on the training-tier tasks) | |
| keeps the eval-loop wallclock from doubling for what is purely a | |
| generalisation probe. | |
| Judges can reproduce any eval number with | |
| ``python -m counterfeint.eval_suite --seed 1001 --task task_1``. | |
| """ | |
| PolicyFactory = Callable[[], Any] | |
| """Zero-arg callable returning a fresh Policy (one per episode).""" | |
| # --------------------------------------------------------------------------- | |
| # Metric dataclasses | |
| # --------------------------------------------------------------------------- | |
| class EpisodeMetrics: | |
| """Flat, JSON-serialisable record of a single eval episode.""" | |
| tag: str | |
| task_id: str | |
| seed: int | |
| grader_score: float | |
| track_a_score: float | |
| track_b_score: float | |
| n_fraud_leaks: int | |
| n_ground_truth_fraud: int | |
| budget_used_pct: float | |
| fallback_count: int | |
| steps: int | |
| end_reason: Optional[str] | |
| rewards_by_role: Dict[str, float] = field(default_factory=dict) | |
| error: Optional[str] = None | |
| def to_dict(self) -> Dict[str, Any]: | |
| return asdict(self) | |
| class AggregatedMetrics: | |
| """Averages over the 10 episodes of a single task (under one tag).""" | |
| tag: str | |
| task_id: str | |
| n_episodes: int | |
| grader_score_mean: float | |
| track_a_score_mean: float | |
| n_fraud_leaks_mean: float | |
| budget_used_pct_mean: float | |
| fallback_count_total: int | |
| errors: int | |
| def to_dict(self) -> Dict[str, Any]: | |
| return asdict(self) | |
| # --------------------------------------------------------------------------- | |
| # Metric extraction (pure, unit-testable) | |
| # --------------------------------------------------------------------------- | |
| def _parse_episode_metrics( | |
| tag: str, | |
| task_id: str, | |
| seed: int, | |
| episode_result: Dict[str, Any], | |
| ) -> EpisodeMetrics: | |
| """Distill a ``run_three_agent_episode`` return dict into flat metrics. | |
| Pure function — no I/O, no server calls. Kept separate from | |
| :func:`run_eval` so it can be exercised directly from unit tests | |
| without spinning up a live environment. | |
| """ | |
| final_state = episode_result.get("final_state") or {} | |
| audit_report = final_state.get("audit_report") or {} | |
| track_a_score = float(audit_report.get("investigator_audit_score", 1.0)) | |
| track_b_score = float(audit_report.get("fraudster_plausibility_score", 1.0)) | |
| investigator_state = final_state.get("investigator_state") or {} | |
| verdicts = investigator_state.get("verdicts") or {} | |
| n_fraud_leaks = 0 | |
| n_ground_truth_fraud = 0 | |
| for v in verdicts.values(): | |
| ground_truth = v.get("ground_truth") | |
| if ground_truth == "fraud": | |
| n_ground_truth_fraud += 1 | |
| if v.get("verdict") == "approve": | |
| n_fraud_leaks += 1 | |
| total_ads = int(investigator_state.get("total_ads") or 0) | |
| remaining_budget = int(investigator_state.get("remaining_budget") or 0) | |
| budget_used_pct = 0.0 | |
| if total_ads > 0 and remaining_budget >= 0: | |
| # action_budget per task is proportional to total_ads; we express | |
| # budget consumption as "1 - remaining/total" as a stable proxy. | |
| budget_used_pct = max(0.0, min(1.0, 1.0 - (remaining_budget / max(total_ads, 1)))) | |
| fallback_counts = episode_result.get("fallback_counts") or {} | |
| fallback_count = int(fallback_counts.get("investigator", 0)) | |
| return EpisodeMetrics( | |
| tag=tag, | |
| task_id=task_id, | |
| seed=seed, | |
| grader_score=float(episode_result.get("grader_score", 0.0)), | |
| track_a_score=track_a_score, | |
| track_b_score=track_b_score, | |
| n_fraud_leaks=n_fraud_leaks, | |
| n_ground_truth_fraud=n_ground_truth_fraud, | |
| budget_used_pct=budget_used_pct, | |
| fallback_count=fallback_count, | |
| steps=int(episode_result.get("steps", 0)), | |
| end_reason=episode_result.get("end_reason"), | |
| rewards_by_role=dict(episode_result.get("rewards_by_role") or {}), | |
| error=episode_result.get("error"), | |
| ) | |
| def _aggregate_per_task( | |
| tag: str, | |
| task_id: str, | |
| episodes: List[EpisodeMetrics], | |
| ) -> AggregatedMetrics: | |
| """Compute mean-over-episodes for the summary markdown + plot.""" | |
| valid = [m for m in episodes if m.error is None] | |
| n = len(valid) | |
| if n == 0: | |
| return AggregatedMetrics( | |
| tag=tag, | |
| task_id=task_id, | |
| n_episodes=0, | |
| grader_score_mean=0.0, | |
| track_a_score_mean=0.0, | |
| n_fraud_leaks_mean=0.0, | |
| budget_used_pct_mean=0.0, | |
| fallback_count_total=sum(m.fallback_count for m in episodes), | |
| errors=len(episodes) - n, | |
| ) | |
| return AggregatedMetrics( | |
| tag=tag, | |
| task_id=task_id, | |
| n_episodes=n, | |
| grader_score_mean=sum(m.grader_score for m in valid) / n, | |
| track_a_score_mean=sum(m.track_a_score for m in valid) / n, | |
| n_fraud_leaks_mean=sum(m.n_fraud_leaks for m in valid) / n, | |
| budget_used_pct_mean=sum(m.budget_used_pct for m in valid) / n, | |
| fallback_count_total=sum(m.fallback_count for m in episodes), | |
| errors=len(episodes) - n, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Core sweep | |
| # --------------------------------------------------------------------------- | |
| def run_eval( | |
| *, | |
| tag: str, | |
| investigator_factory: PolicyFactory, | |
| fraudster_factory: PolicyFactory = lambda: ReactiveFraudster(seed=42), | |
| auditor_factory: PolicyFactory = lambda: HeuristicAuditor(), | |
| seeds: Optional[Dict[str, List[int]]] = None, | |
| env_base_url: str = ENV_URL, | |
| max_steps: int = 200, | |
| log: bool = False, | |
| ) -> Dict[str, List[EpisodeMetrics]]: | |
| """Run the held-out sweep under a single named policy configuration. | |
| Parameters | |
| ---------- | |
| tag | |
| Short string tagging this run in outputs (e.g. ``"before"``, | |
| ``"after_grpo_v1"``). Shown in the summary table and plot legend. | |
| investigator_factory | |
| Zero-arg callable building the Investigator to evaluate. | |
| fraudster_factory, auditor_factory | |
| Zero-arg callables for the opponent and auditor. Default to the | |
| scripted ``ReactiveFraudster`` / ``HeuristicAuditor`` so eval | |
| numbers are stable across runs. | |
| seeds | |
| Override :data:`EVAL_SEEDS`. Useful for a smoke test with a | |
| single seed per task. | |
| env_base_url | |
| CounterFeint server URL. Default: ``COUNTERFEINT_ENV_URL`` env | |
| var or ``http://localhost:8000``. | |
| max_steps, log | |
| Forwarded to :func:`run_three_agent_episode`. | |
| Returns | |
| ------- | |
| dict | |
| ``{task_id: [EpisodeMetrics, ...]}`` with one entry per seed. | |
| """ | |
| seeds = seeds or EVAL_SEEDS | |
| results: Dict[str, List[EpisodeMetrics]] = {} | |
| for task_id, task_seeds in seeds.items(): | |
| task_metrics: List[EpisodeMetrics] = [] | |
| for seed in task_seeds: | |
| logger.info( | |
| "[%s] running %s seed=%d (fraudster=%s, investigator=%s)...", | |
| tag, | |
| task_id, | |
| seed, | |
| type(fraudster_factory()).__name__, | |
| type(investigator_factory()).__name__, | |
| ) | |
| try: | |
| episode_result = run_three_agent_episode( | |
| task_id, | |
| seed=seed, | |
| env_base_url=env_base_url, | |
| max_steps=max_steps, | |
| fraudster_policy=fraudster_factory(), | |
| investigator_policy=investigator_factory(), | |
| auditor_policy=auditor_factory(), | |
| log=log, | |
| ) | |
| except Exception as exc: # noqa: BLE001 - eval must never raise | |
| logger.exception( | |
| "[%s] %s seed=%d raised: %s", tag, task_id, seed, exc | |
| ) | |
| episode_result = { | |
| "task_id": task_id, | |
| "grader_score": 0.0, | |
| "steps": 0, | |
| "end_reason": "eval_exception", | |
| "rewards_by_role": {}, | |
| "fallback_counts": {}, | |
| "final_state": {}, | |
| "error": str(exc), | |
| } | |
| metrics = _parse_episode_metrics(tag, task_id, seed, episode_result) | |
| task_metrics.append(metrics) | |
| logger.info( | |
| "[%s] %s seed=%d score=%.3f leaks=%d fallbacks=%d", | |
| tag, | |
| task_id, | |
| seed, | |
| metrics.grader_score, | |
| metrics.n_fraud_leaks, | |
| metrics.fallback_count, | |
| ) | |
| results[task_id] = task_metrics | |
| return results | |
| # --------------------------------------------------------------------------- | |
| # Before/After artefact writers | |
| # --------------------------------------------------------------------------- | |
| def summarize_real_world_holdout() -> Dict[str, Any]: | |
| """Summary of the Meta-CIB-modeled holdout dataset (no opt-in needed). | |
| Pulls counts via :func:`counterfeint.data.real_world_loader.count_by_ring` | |
| so we can render "evaluated against N synthetic ads grounded in M | |
| Meta CIB case studies" without ever loading the ad text itself. | |
| Suitable for embedding in :func:`run_before_after` outputs and the | |
| README "Evaluated against Meta-CIB-modeled ads" subsection. | |
| """ | |
| counts = count_by_ring() | |
| case_studies = list_case_studies() | |
| return { | |
| "n_ads_total": sum(counts.values()), | |
| "n_case_studies": len(case_studies), | |
| "case_studies": case_studies, | |
| "ads_per_case_study": counts, | |
| } | |
| def _write_eval_json( | |
| before: Dict[str, List[EpisodeMetrics]], | |
| after: Dict[str, List[EpisodeMetrics]], | |
| before_tag: str, | |
| after_tag: str, | |
| path: Path, | |
| *, | |
| holdout_summary: Optional[Dict[str, Any]] = None, | |
| ) -> None: | |
| payload = { | |
| "schema": "counterfeint.eval_suite.v1", | |
| "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), | |
| "tags": {"before": before_tag, "after": after_tag}, | |
| "before": { | |
| task_id: [m.to_dict() for m in metrics] | |
| for task_id, metrics in before.items() | |
| }, | |
| "after": { | |
| task_id: [m.to_dict() for m in metrics] | |
| for task_id, metrics in after.items() | |
| }, | |
| } | |
| if holdout_summary is not None: | |
| payload["real_world_holdout"] = holdout_summary | |
| path.write_text(json.dumps(payload, indent=2), encoding="utf-8") | |
| def _format_delta(delta: float, *, higher_is_better: bool = True) -> str: | |
| sign = "+" if delta >= 0 else "" | |
| marker = "" | |
| if higher_is_better and delta > 0: | |
| marker = " ↑" | |
| elif higher_is_better and delta < 0: | |
| marker = " ↓" | |
| elif not higher_is_better and delta < 0: | |
| marker = " ↑" # lower is better, so a decrease is an improvement | |
| elif not higher_is_better and delta > 0: | |
| marker = " ↓" | |
| return f"{sign}{delta:.3f}{marker}" | |
| def _write_eval_summary_md( | |
| before_agg: Dict[str, AggregatedMetrics], | |
| after_agg: Dict[str, AggregatedMetrics], | |
| before_tag: str, | |
| after_tag: str, | |
| path: Path, | |
| ) -> None: | |
| """Render a human-readable delta table to ``eval_summary.md``.""" | |
| lines: List[str] = [] | |
| lines.append(f"# CounterFeint Eval: `{before_tag}` → `{after_tag}`") | |
| lines.append("") | |
| lines.append( | |
| f"Held-out sweep across {len(EVAL_SEEDS)} tasks × 10 seeds each, " | |
| f"evaluated against `ReactiveFraudster` (stable adversary)." | |
| ) | |
| lines.append("") | |
| lines.append( | |
| "| Task | Metric | " | |
| f"{before_tag} | {after_tag} | Delta |" | |
| ) | |
| lines.append("|------|--------|--------|-------|-------|") | |
| task_ids = list(before_agg.keys() or after_agg.keys()) | |
| for task_id in task_ids: | |
| b = before_agg.get(task_id) | |
| a = after_agg.get(task_id) | |
| if b is None or a is None: | |
| continue | |
| def _row(label: str, b_val: float, a_val: float, *, higher_is_better: bool = True) -> str: | |
| delta = a_val - b_val | |
| return ( | |
| f"| {task_id} | {label} | " | |
| f"{b_val:.3f} | {a_val:.3f} | " | |
| f"{_format_delta(delta, higher_is_better=higher_is_better)} |" | |
| ) | |
| lines.append(_row("grader_score (↑)", b.grader_score_mean, a.grader_score_mean)) | |
| lines.append(_row("track_a_score (↑)", b.track_a_score_mean, a.track_a_score_mean)) | |
| lines.append( | |
| _row( | |
| "n_fraud_leaks (↓)", | |
| b.n_fraud_leaks_mean, | |
| a.n_fraud_leaks_mean, | |
| higher_is_better=False, | |
| ) | |
| ) | |
| lines.append( | |
| _row( | |
| "budget_used_pct (↓)", | |
| b.budget_used_pct_mean, | |
| a.budget_used_pct_mean, | |
| higher_is_better=False, | |
| ) | |
| ) | |
| lines.append("") | |
| if any(a.fallback_count_total for a in after_agg.values()): | |
| lines.append("### Fallback activity (LLM → scripted, per tag/task)") | |
| lines.append("") | |
| lines.append(f"| Task | {before_tag} fallbacks | {after_tag} fallbacks |") | |
| lines.append("|------|----------------------|----------------------|") | |
| for task_id in task_ids: | |
| b = before_agg.get(task_id) | |
| a = after_agg.get(task_id) | |
| if b is None or a is None: | |
| continue | |
| lines.append( | |
| f"| {task_id} | {b.fallback_count_total} | {a.fallback_count_total} |" | |
| ) | |
| path.write_text("\n".join(lines) + "\n", encoding="utf-8") | |
| def _write_eval_plot( | |
| before_agg: Dict[str, AggregatedMetrics], | |
| after_agg: Dict[str, AggregatedMetrics], | |
| before_tag: str, | |
| after_tag: str, | |
| path: Path, | |
| ) -> None: | |
| """Render a 2x2 bar-chart PNG comparing the four headline metrics. | |
| matplotlib is a **dev-only** dependency (see ``requirements-dev.txt``) | |
| to keep the runtime Docker image slim. If matplotlib is not available, | |
| this writes a small text stub next to ``path`` instead of raising. | |
| """ | |
| try: | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: WPS433 - lazy import by design | |
| except ImportError: | |
| stub = path.with_suffix(".txt") | |
| stub.write_text( | |
| "matplotlib not installed — install requirements-dev.txt to generate the PNG.\n", | |
| encoding="utf-8", | |
| ) | |
| logger.warning("matplotlib unavailable; wrote stub to %s", stub) | |
| return | |
| task_ids = list(before_agg.keys() or after_agg.keys()) | |
| if not task_ids: | |
| logger.warning("No tasks to plot; skipping %s", path) | |
| return | |
| n_tasks = len(task_ids) | |
| fig, axes = plt.subplots(2, 2, figsize=(10, 7)) | |
| fig.suptitle( | |
| f"CounterFeint Eval: {before_tag} vs {after_tag}\n" | |
| f"{n_tasks} tasks × 10 held-out seeds" | |
| ) | |
| metric_specs = [ | |
| ("grader_score_mean", "Grader score (↑ better)", axes[0][0]), | |
| ("track_a_score_mean", "Track A reasoning (↑)", axes[0][1]), | |
| ("n_fraud_leaks_mean", "Mean fraud leaks (↓)", axes[1][0]), | |
| ("budget_used_pct_mean", "Budget used pct (↓)", axes[1][1]), | |
| ] | |
| for attr, title, ax in metric_specs: | |
| before_vals = [getattr(before_agg[t], attr) for t in task_ids] | |
| after_vals = [getattr(after_agg[t], attr) for t in task_ids] | |
| x = list(range(n_tasks)) | |
| width = 0.35 | |
| ax.bar([i - width / 2 for i in x], before_vals, width, label=before_tag) | |
| ax.bar([i + width / 2 for i in x], after_vals, width, label=after_tag) | |
| ax.set_title(title) | |
| ax.set_xticks(x) | |
| ax.set_xticklabels(task_ids) | |
| ax.legend(fontsize="x-small") | |
| ax.grid(True, axis="y", linestyle="--", alpha=0.4) | |
| fig.tight_layout(rect=(0, 0, 1, 0.94)) | |
| fig.savefig(path, dpi=150) | |
| plt.close(fig) | |
| logger.info("Wrote %s", path) | |
| def run_before_after( | |
| *, | |
| before_tag: str, | |
| after_tag: str, | |
| before_investigator_factory: PolicyFactory, | |
| after_investigator_factory: PolicyFactory, | |
| fraudster_factory: PolicyFactory = lambda: ReactiveFraudster(seed=42), | |
| auditor_factory: PolicyFactory = lambda: HeuristicAuditor(), | |
| out_dir: Path, | |
| seeds: Optional[Dict[str, List[int]]] = None, | |
| env_base_url: str = ENV_URL, | |
| include_real_world_summary: bool = True, | |
| ) -> Dict[str, Any]: | |
| """Run the full before/after comparison and write all three artefacts. | |
| Writes: | |
| * ``{out_dir}/eval_results.json`` | |
| * ``{out_dir}/eval_summary.md`` | |
| * ``{out_dir}/eval_plot.png`` (stub .txt if matplotlib is missing) | |
| Returns the aggregated metrics dict for programmatic use (e.g. the | |
| final cell of the Colab training notebook). | |
| """ | |
| out_dir.mkdir(parents=True, exist_ok=True) | |
| logger.info("=== Running BEFORE sweep: %s ===", before_tag) | |
| before = run_eval( | |
| tag=before_tag, | |
| investigator_factory=before_investigator_factory, | |
| fraudster_factory=fraudster_factory, | |
| auditor_factory=auditor_factory, | |
| seeds=seeds, | |
| env_base_url=env_base_url, | |
| ) | |
| logger.info("=== Running AFTER sweep: %s ===", after_tag) | |
| after = run_eval( | |
| tag=after_tag, | |
| investigator_factory=after_investigator_factory, | |
| fraudster_factory=fraudster_factory, | |
| auditor_factory=auditor_factory, | |
| seeds=seeds, | |
| env_base_url=env_base_url, | |
| ) | |
| before_agg = { | |
| task_id: _aggregate_per_task(before_tag, task_id, eps) | |
| for task_id, eps in before.items() | |
| } | |
| after_agg = { | |
| task_id: _aggregate_per_task(after_tag, task_id, eps) | |
| for task_id, eps in after.items() | |
| } | |
| holdout_summary: Optional[Dict[str, Any]] = None | |
| if include_real_world_summary: | |
| try: | |
| holdout_summary = summarize_real_world_holdout() | |
| except Exception as exc: # noqa: BLE001 | |
| logger.warning("Could not summarise holdout dataset: %s", exc) | |
| json_path = out_dir / "eval_results.json" | |
| md_path = out_dir / "eval_summary.md" | |
| png_path = out_dir / "eval_plot.png" | |
| _write_eval_json( | |
| before, | |
| after, | |
| before_tag, | |
| after_tag, | |
| json_path, | |
| holdout_summary=holdout_summary, | |
| ) | |
| _write_eval_summary_md(before_agg, after_agg, before_tag, after_tag, md_path) | |
| _write_eval_plot(before_agg, after_agg, before_tag, after_tag, png_path) | |
| logger.info("Wrote %s, %s, %s", json_path, md_path, png_path) | |
| return { | |
| "before": {t: m.to_dict() for t, m in before_agg.items()}, | |
| "after": {t: m.to_dict() for t, m in after_agg.items()}, | |
| "real_world_holdout": holdout_summary, | |
| "out_dir": str(out_dir), | |
| } | |
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
| def _default_out_dir() -> Path: | |
| return Path(os.environ.get("COUNTERFEINT_EVAL_OUT_DIR", "eval_outputs")) | |
| def main() -> None: # pragma: no cover - thin CLI glue | |
| import argparse | |
| logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") | |
| parser = argparse.ArgumentParser( | |
| description="CounterFeint before/after eval sweep (scripted vs scripted by default)." | |
| ) | |
| parser.add_argument("--out-dir", type=Path, default=_default_out_dir()) | |
| parser.add_argument("--before-tag", default="scripted") | |
| parser.add_argument("--after-tag", default="scripted_rerun") | |
| parser.add_argument( | |
| "--smoke", | |
| action="store_true", | |
| help="Run one seed per task for a quick smoke test.", | |
| ) | |
| args = parser.parse_args() | |
| seeds: Optional[Dict[str, List[int]]] = None | |
| if args.smoke: | |
| seeds = {task: [seeds_list[0]] for task, seeds_list in EVAL_SEEDS.items()} | |
| run_before_after( | |
| before_tag=args.before_tag, | |
| after_tag=args.after_tag, | |
| before_investigator_factory=lambda: ScriptedInvestigator(), | |
| after_investigator_factory=lambda: ScriptedInvestigator(), | |
| out_dir=args.out_dir, | |
| seeds=seeds, | |
| ) | |
| if __name__ == "__main__": # pragma: no cover | |
| main() | |