"""Step 5: Evaluate cascade predictions on 2025 test events. Two evaluation modes: - ``gold`` (default): predicted chain vs the test-side cascade chain (``data/processed/cascade_chains/{event_id}.json``) extracted by step 03. Produces precision / recall / F1 / domain-Jaccard / severity-match-rate. Per-event JSON cached under ``data/evaluation/gold/``. - ``judge``: the LLM-as-judge route, predicted chain graded against the news. Produces evidence_level / timing_alignment / missed cascades. Per-event JSON cached under ``data/evaluation/``. - ``both``: run gold first, then judge; both caches are populated. """ from __future__ import annotations import argparse import json import logging from collections import Counter from pathlib import Path from src.eval.evaluator import Evaluator from src.eval.gold_evaluator import GoldEvaluator from src.llm import create_llm_client from src.llm.client import load_config from src.models.schemas import FloodEvent from src.rag.embedder import Embedder from src.rag.predictor import CascadePredictor logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description=__doc__) p.add_argument( "--mode", choices=["gold", "judge", "both"], default="gold", help="Evaluation mode (default: gold).", ) p.add_argument( "--force", action="store_true", help="Ignore cache fingerprints and re-run for every event.", ) p.add_argument( "--event-id", dest="event_ids", action="append", default=None, help="Evaluate only this event id; repeat the flag for multiple ids.", ) p.add_argument( "--threshold", type=float, default=None, help=( "Override the cosine match threshold for gold mode. " "Falls back to config.evaluation.cosine_threshold, then " "src/eval/metrics.py::DEFAULT_COSINE_THRESHOLD. 
" "Note: a different threshold yields a different fingerprint, " "so the cache will be skipped automatically — but pass --force " "if you also want to bypass any cache that happens to share it." ), ) p.add_argument( "--dump-match-debug", action="store_true", help=( "Write all same-domain (pred, gold) cosine pairs (including " "below-threshold near-misses) to data/evaluation/gold/" "{event_id}.diag.json. Used to support offline threshold sweeps; " "does not affect the main cache fingerprint." ), ) p.add_argument( "--dump-bfs-full", action="store_true", help=( "Issue #12 diagnostic dump. When set, the predictor stores the " "full retrieved-edge payload + per-layer LLM call (template path, " "variables, raw response) in PredictionResult.trace, and the " "evaluator writes that trace to data/evaluation/diagnostics/" "{event_id}_bfs_full.json. Independent of the main cache " "fingerprint and of --dump-match-debug; combine freely. Only " "fires for events where a fresh predict() runs (cache hits " "produce no dump — pair with --force when you want to refresh)." ), ) p.add_argument( "--seeds", type=lambda s: [int(x) for x in s.split(",")], default=None, help=( "Comma-separated list of integer seeds for n-seed averaging " "(local backend only). Defaults to config.evaluation.seed_list. " "When llm.temperature == 0 the local backend is greedy and seed " "is ignored, but the loop still runs n times to confirm " "determinism (per-event std should collapse to 0)." ), ) p.add_argument( "--compare-to", dest="compare_to", default=None, help=( "Path to a previously-written aggregate_*.json file (see " "data/evaluation/aggregate_*.json). When set, the AGGREGATE " "block appends a CI OVERLAP/SEPARATED comparison against this " "baseline. Does not affect KEEP/ROLLBACK acceptance — " "telemetry only (spec §1.2 / §5.2)." ), ) return p.parse_args() # v0.7 issue A — config fingerprint + aggregate cache helpers. 
# Spec: docs/superpowers/specs/2026-06-07-v07-bootstrap-ci-design.md §5.3 / §5.2
_AGGREGATE_FORMAT_VERSION = 1

# Bootstrap settings shared by the CI computation, the console banner and the
# persisted artefact, so the three cannot drift apart.
_BOOTSTRAP_N_RESAMPLES = 1000
_BOOTSTRAP_SEED = 42
_BOOTSTRAP_CONFIDENCE = 0.95

# Map internal driver keys (cat_r / cat_sev / f1 / cat_f1) → public metric
# names that the spec / retro script use. Keep the public names stable across
# format versions.
_PUBLIC_METRIC_NAMES = {
    "cat_r": "category_recall",
    "cat_sev": "category_severity_match_rate",
    "f1": "cosine_f1",
    "cat_f1": "category_f1",
}


def _config_fingerprint(config: dict) -> str:
    """Return the sha256 hex digest of the config fields that influence CI numbers.

    Includes the full ``rag`` and ``llm`` sections plus
    ``evaluation.cosine_threshold`` and ``evaluation.seed_list``. Paths and
    other I/O-only fields are deliberately left out so that moving caches
    around does not invalidate fingerprints.
    """
    import hashlib

    eval_cfg = config.get("evaluation", {})
    payload = {
        "rag": config.get("rag", {}),
        "llm": config.get("llm", {}),
        "cosine_threshold": eval_cfg.get("cosine_threshold"),
        "seed_list": eval_cfg.get("seed_list"),
    }
    # sort_keys makes the digest independent of dict insertion order.
    blob = json.dumps(payload, sort_keys=True, default=str).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


def _write_aggregate_artefact(
    *,
    config: dict,
    seeds: list,
    ci_per_metric: dict,
    per_event_seed_values_per_metric: dict,
    ci_metric_attrs: list[tuple[str, str]],
) -> Path:
    """Persist mean + CI + per-event detail to ``aggregate_{fp}.json``.

    The file is written under ``config["evaluation"]["output_dir"]``. ``fp``
    is a 12-char prefix of the config fingerprint so the file is
    distinguishable in ``ls`` output, while the full hash is stored in the
    JSON body. Format version 1 is documented in spec §5.3.

    Args:
        config: Loaded configuration; must contain ``evaluation.output_dir``.
        seeds: Seeds used for the run; ``None`` entries are dropped from the
            stored ``seed_list`` but still counted in ``n_seeds``.
        ci_per_metric: Internal metric key → CI dict (spread verbatim into the
            metric entry; its ``"n"`` field supplies ``n_events``).
        per_event_seed_values_per_metric: Internal metric key →
            ``{event_id: [seed-level values]}``.
        ci_metric_attrs: ``(label, internal key)`` pairs fixing which metrics
            are written.

    Returns:
        Path of the written JSON file.
    """
    from datetime import datetime, timezone

    output_dir = Path(config["evaluation"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)

    fp_full = _config_fingerprint(config)
    path = output_dir / f"aggregate_{fp_full[:12]}.json"

    metrics_block: dict = {}
    for _, key in ci_metric_attrs:
        ci = ci_per_metric.get(key)
        if ci is None:
            continue
        seed_vals = per_event_seed_values_per_metric.get(key, {})
        # Skip empty seed lists instead of dividing by zero.
        per_event_means = {
            eid: sum(vs) / len(vs) for eid, vs in seed_vals.items() if vs
        }
        # Unknown keys fall back to their internal name rather than raising.
        metrics_block[_PUBLIC_METRIC_NAMES.get(key, key)] = {
            **ci,
            "per_event_values": per_event_means,
            "per_event_seed_values": seed_vals,
        }

    # An empty CI dict yields n_events == 0 instead of a StopIteration.
    first_ci = next(iter(ci_per_metric.values()), None)
    n_events = first_ci["n"] if first_ci is not None else 0

    # datetime.utcnow() is deprecated; take the aware UTC clock and drop the
    # tzinfo so the stored string keeps the same "<naive-iso>Z" shape.
    evaluated_at = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )

    payload = {
        "format_version": _AGGREGATE_FORMAT_VERSION,
        "config_fingerprint": fp_full,
        "evaluated_at": evaluated_at,
        "n_events": n_events,
        "n_seeds": len(seeds) if seeds else 0,
        "seed_list": [s for s in (seeds or []) if s is not None],
        "bootstrap": {
            "n_resamples": _BOOTSTRAP_N_RESAMPLES,
            "seed": _BOOTSTRAP_SEED,
            "confidence": _BOOTSTRAP_CONFIDENCE,
        },
        "metrics": metrics_block,
    }
    path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return path


def _print_ci_comparison(
    *,
    current_ci: dict,
    ci_metric_attrs: list[tuple[str, str]],
    baseline_path: Path,
) -> None:
    """Print an OVERLAP / SEPARATED comparison against a baseline aggregate file.

    Telemetry only — does not change acceptance. A missing, unparsable,
    wrongly-shaped or wrong-version baseline is reported on stdout and the
    comparison is skipped; nothing is raised.

    Args:
        current_ci: Internal metric key → CI dict (``mean`` / ``ci_low`` /
            ``ci_high``) for the current run.
        ci_metric_attrs: ``(label, internal key)`` pairs, in print order.
        baseline_path: Path to a previously written ``aggregate_*.json``.
    """
    if not baseline_path.is_file():
        print(f"\n[--compare-to] baseline file not found: {baseline_path}")
        return
    try:
        baseline = json.loads(baseline_path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError) as exc:
        print(f"\n[--compare-to] could not parse baseline: {exc}")
        return
    if not isinstance(baseline, dict):
        # Valid JSON that is not an object (e.g. a list) has no .get().
        print("\n[--compare-to] baseline is not a JSON object; comparison skipped")
        return
    if baseline.get("format_version") != _AGGREGATE_FORMAT_VERSION:
        print(
            f"\n[--compare-to] baseline format_version="
            f"{baseline.get('format_version')!r} != "
            f"{_AGGREGATE_FORMAT_VERSION}; comparison skipped"
        )
        return

    baseline_metrics = baseline.get("metrics")
    if not isinstance(baseline_metrics, dict):
        baseline_metrics = {}

    print(f"\nCI OVERLAP vs baseline ({baseline_path.name}):")
    for label, key in ci_metric_attrs:
        cur = current_ci.get(key)
        base = baseline_metrics.get(_PUBLIC_METRIC_NAMES.get(key, key))
        if cur is None or base is None:
            continue
        try:
            base_mean = base["mean"]
            base_lo = base["ci_low"]
            base_hi = base["ci_high"]
        except (KeyError, TypeError):
            # A hand-edited or truncated baseline should not abort the run.
            print(f" {label:<18} baseline entry malformed; skipped")
            continue
        delta = cur["mean"] - base_mean
        # Two intervals overlap iff the larger lower bound does not exceed
        # the smaller upper bound.
        overlap_lo = max(cur["ci_low"], base_lo)
        overlap_hi = min(cur["ci_high"], base_hi)
        verdict = "OVERLAP" if overlap_lo <= overlap_hi else "SEPARATED"
        print(
            f" {label:<18} "
            f"current {cur['mean']:.3f} "
            f"[{cur['ci_low']:.3f}, {cur['ci_high']:.3f}] "
            f"vs baseline {base_mean:.3f} "
            f"[{base_lo:.3f}, {base_hi:.3f}] "
            f"→ {verdict} (Δ mean {delta:+.3f})"
        )


def _load_test_events(config: dict, wanted_ids: set[str] | None) -> list[FloodEvent]:
    """Load the test split and optionally restrict it to ``wanted_ids``.

    Args:
        config: Loaded configuration; ``paths.test_events`` names a JSON file
            holding a list of event records.
        wanted_ids: Event ids to keep, or ``None`` / empty to keep them all.

    Raises:
        SystemExit: if ``wanted_ids`` is given and matches no event.
    """
    # Explicit UTF-8, consistent with every other read/write in this script;
    # the platform default encoding is not reliable for non-ASCII event text.
    raw = Path(config["paths"]["test_events"]).read_text(encoding="utf-8")
    test_events = [FloodEvent(**e) for e in json.loads(raw)]
    if wanted_ids:
        test_events = [e for e in test_events if e.event_id in wanted_ids]
        if not test_events:
            raise SystemExit(
                f"No matching events in test split for --event-id {sorted(wanted_ids)}"
            )
    return test_events


def run_gold(
    test_events: list[FloodEvent],
    config: dict,
    predictor: CascadePredictor,
    force: bool,
    threshold: float | None = None,
    dump_diagnostics: bool = False,
    dump_bfs_full: bool = False,
    seeds: list[int | None] | None = None,
    compare_to: str | None = None,
) -> None:
    """Per-event gold evaluation with optional multi-seed averaging (v0.5 issue A).

    ``seeds=[None]`` preserves legacy behaviour (one run, no ``_seed{seed}``
    suffix on diagnostic dumps). ``seeds=[42, 1337, 2026]`` evaluates every
    event once per seed and reports mean ± std.

    Args:
        test_events: Events to evaluate.
        config: Loaded configuration.
        predictor: Predictor handed to each ``GoldEvaluator``.
        force: Passed to ``evaluate_event`` as ``force_reeval``.
        threshold: Cosine match threshold; ``None`` falls back to
            ``config.evaluation.cosine_threshold`` and, if that is unset too,
            to whatever ``GoldEvaluator`` defaults to.
        dump_diagnostics: Forwarded to ``GoldEvaluator``.
        dump_bfs_full: Write the prediction trace of each event to
            ``<output_dir>/diagnostics/``.
        seeds: Seeds to loop over; ``None`` or empty means a single unseeded run.
        compare_to: Optional path to a baseline ``aggregate_*.json``.

    Prints the per-event table, the AGGREGATE block and the bootstrap CI
    block to stdout, and writes the aggregate artefact.
    """
    import math

    if not seeds:
        seeds = [None]

    embedder = Embedder(config)
    eval_cfg = config.get("evaluation", {})
    if threshold is None:
        threshold = eval_cfg.get("cosine_threshold")

    diagnostics_dir: Path | None = None
    if dump_bfs_full:
        diagnostics_dir = Path(eval_cfg["output_dir"]) / "diagnostics"
        diagnostics_dir.mkdir(parents=True, exist_ok=True)
        logger.info(
            "BFS full-trace dumps enabled — writing to %s/{event_id}_seed{seed}_bfs_full.json",
            diagnostics_dir,
        )

    print("\n" + "=" * 96)
    print(f"GOLD-BASED EVALUATION — per event (seeds={seeds})")
    print("=" * 96)
    print(f"{'event_id':<18} {'pred':>5} {'gold':>5} {'matched':>8} "
          f"{'f1 (mean±std)':>16} {'cat_f1 (mean±std)':>18} {'sev_match':>10}")

    # Per-event accumulator: {event_id: {'f1': [...], 'cat_f1': [...], ...}},
    # one list entry per seed that produced a value.
    per_event: dict[str, dict[str, list]] = {}

    for seed in seeds:
        gold_kwargs: dict = {
            "predictor": predictor,
            "embedder": embedder,
            "config": config,
            "dump_diagnostics": dump_diagnostics,
            "seed": seed,
        }
        # Only pass a threshold when one is known, so GoldEvaluator's own
        # default applies otherwise.
        if threshold is not None:
            gold_kwargs["threshold"] = threshold
        gold = GoldEvaluator(**gold_kwargs)
        logger.info(
            "Seed %s: gold-eval fingerprint=%s threshold=%.4f",
            seed, gold.fingerprint, gold.threshold,
        )

        for i, event in enumerate(test_events, start=1):
            logger.info("[seed=%s][%d/%d] %s (%s)",
                        seed, i, len(test_events), event.event_id, event.country)
            result = gold.evaluate_event(event, write_cache=True, force_reeval=force)

            if diagnostics_dir is not None and gold.last_prediction is not None:
                stem = (
                    f"{event.event_id}_bfs_full.json"
                    if seed is None
                    else f"{event.event_id}_seed{seed}_bfs_full.json"
                )
                (diagnostics_dir / stem).write_text(
                    json.dumps(
                        {"event_id": event.event_id, "seed": seed,
                         "trace": gold.last_prediction.trace},
                        indent=2,
                        ensure_ascii=False,
                        default=str,
                    ),
                    encoding="utf-8",
                )

            slot = per_event.setdefault(event.event_id, {
                "status": [], "pred_n": [], "gold_n": [], "matched_n": [],
                "f1": [], "cat_f1": [], "sev_match": [], "cat_sev": [],
                "p": [], "r": [], "cat_p": [], "cat_r": [], "dom_jac": [],
            })
            slot["status"].append(gold.last_status)
            if result is None:
                continue
            slot["pred_n"].append(result.predicted_node_count)
            slot["gold_n"].append(result.gold_node_count)
            slot["matched_n"].append(len(result.matches))
            slot["p"].append(result.precision)
            slot["r"].append(result.recall)
            slot["f1"].append(result.f1)
            slot["dom_jac"].append(result.domain_jaccard)
            if result.severity_match_rate is not None:
                slot["sev_match"].append(result.severity_match_rate)
            if result.category_f1 is not None:
                slot["cat_p"].append(result.category_precision or 0.0)
                slot["cat_r"].append(result.category_recall or 0.0)
                slot["cat_f1"].append(result.category_f1)
            if result.category_severity_match_rate is not None:
                slot["cat_sev"].append(result.category_severity_match_rate)

    # v0.5 issue A: count each event's status ONCE (last seed's status), not
    # n_seeds × n_events times. Counter is built after both loops finish so
    # the AGGREGATE summary lines up with the per-event row count.
    status_counts: Counter = Counter(
        slot["status"][-1] if slot["status"] else "unknown"
        for slot in per_event.values()
    )

    def _mean(xs: list[float]) -> float | None:
        """Arithmetic mean, or None for an empty list."""
        return sum(xs) / len(xs) if xs else None

    def _std(xs: list[float]) -> float | None:
        """Population std; 0.0 for a single value, None for an empty list."""
        if len(xs) < 2:
            return 0.0 if xs else None
        m = _mean(xs)
        return math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

    def _fmt(m: float | None, s: float | None, w: int = 16) -> str:
        """Right-justified 'mean±std' cell of width w ('-' when no mean)."""
        if m is None:
            return f"{'-':>{w}}"
        s_str = f"{s:.3f}" if s is not None else "-"
        return f"{m:.3f}±{s_str}".rjust(w)

    # --- Per-event rows ---
    for eid, slot in per_event.items():
        if not slot["f1"]:
            # all seeds reported None (no_gold / outlier)
            label = "outlier" if "outlier" in slot["status"] else "no_gold"
            print(f"{eid:<18} {'-':>5} {'-':>5} {'-':>8} "
                  f"{label:>16} {'-':>18} {'-':>10}")
            continue
        f1_m, f1_s = _mean(slot["f1"]), _std(slot["f1"])
        cat_f1_m, cat_f1_s = _mean(slot["cat_f1"]), _std(slot["cat_f1"])
        # Prefer the category-level severity rate when any seed produced one.
        sev = _mean(slot["cat_sev"]) if slot["cat_sev"] else _mean(slot["sev_match"])
        sev_str = f"{sev:.3f}" if sev is not None else "-"
        print(
            f"{eid:<18} "
            f"{int(_mean(slot['pred_n'])):>5} "
            f"{int(_mean(slot['gold_n'])):>5} "
            f"{int(_mean(slot['matched_n'])):>8} "
            f"{_fmt(f1_m, f1_s)} "
            f"{_fmt(cat_f1_m, cat_f1_s, w=18)} "
            f"{sev_str:>10}"
        )

    # --- AGGREGATE block ---
    print("\n" + "=" * 96)
    print(f"AGGREGATE (n_seeds={len(seeds)})")
    print("=" * 96)
    evaluable = [s for s in per_event.values() if s["f1"]]
    n = len(evaluable)
    if not n:
        print("No events had a usable gold chain — nothing to aggregate.")
        print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}/gold/")
        return

    macro_f1_means = [_mean(s["f1"]) for s in evaluable]
    macro_cat_f1_means = [_mean(s["cat_f1"]) for s in evaluable if s["cat_f1"]]
    per_event_f1_stds = [_std(s["f1"]) for s in evaluable]

    print(f"Events evaluated: {n} "
          f"(no_gold/outlier excluded; status counts={dict(status_counts)})")
    print(f"Cosine Macro F1 {_mean(macro_f1_means):.3f} "
          f"(per-event std mean = {_mean(per_event_f1_stds):.3f}, "
          f"max = {max(per_event_f1_stds):.3f})")
    if macro_cat_f1_means:
        per_event_cat_f1_stds = [_std(s["cat_f1"]) for s in evaluable if s["cat_f1"]]
        print(f"Category Macro F1 {_mean(macro_cat_f1_means):.3f} "
              f"(per-event std mean = {_mean(per_event_cat_f1_stds):.3f}, "
              f"max = {max(per_event_cat_f1_stds):.3f})")

    # Severity (category-level if available)
    cat_sev_means = [_mean(s["cat_sev"]) for s in evaluable if s["cat_sev"]]
    if cat_sev_means:
        print(f"Category Severity {_mean(cat_sev_means):.3f}")
    else:
        sev_means = [_mean(s["sev_match"]) for s in evaluable if s["sev_match"]]
        if sev_means:
            print(f"Severity match {_mean(sev_means):.3f}")

    print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}/gold/")

    # --- v0.7 issue A: BOOTSTRAP 95% CI block (telemetry only) ---
    from src.eval.metrics import bootstrap_macro_ci

    # Per-event seed-means for each of the 4 CI metrics. Mirrors the
    # 4-tuple in aggregate_with_ci default but inlined here so the
    # driver can reuse `per_event` directly (built across seeds above).
    ci_metric_attrs = [
        ("Cat Macro Recall", "cat_r"),
        ("Cat Macro Sev", "cat_sev"),
        ("Cos Macro F1", "f1"),
        ("Cat Macro F1", "cat_f1"),
    ]
    ci_per_metric: dict[str, dict] = {}
    per_event_seed_values_per_metric: dict[str, dict[str, list[float]]] = {}
    for _, key in ci_metric_attrs:
        # `per_event[eid][key]` is the list of seed-level values for this
        # event/metric. Empty for events with no usable gold (no_gold /
        # outlier) — those events contribute nothing.
        seed_values_per_event = {
            eid: slot[key] for eid, slot in per_event.items() if slot.get(key)
        }
        per_event_seed_values_per_metric[key] = seed_values_per_event
        means = [sum(vs) / len(vs) for vs in seed_values_per_event.values()]
        if not means:
            continue
        ci_per_metric[key] = bootstrap_macro_ci(
            means,
            n_resamples=_BOOTSTRAP_N_RESAMPLES,
            confidence=_BOOTSTRAP_CONFIDENCE,
            seed=_BOOTSTRAP_SEED,
        )

    if ci_per_metric:
        n_ci = next(iter(ci_per_metric.values()))["n"]
        print(
            f"\nBOOTSTRAP {_BOOTSTRAP_CONFIDENCE:.0%} CI "
            f"(B={_BOOTSTRAP_N_RESAMPLES}, n={n_ci} events, "
            f"event-level resample, seed={_BOOTSTRAP_SEED}):"
        )
        for label, key in ci_metric_attrs:
            ci = ci_per_metric.get(key)
            if ci is None:
                continue
            print(
                f" {label:<18} "
                f"{ci['mean']:.3f} [{ci['ci_low']:.3f}, {ci['ci_high']:.3f}]"
            )

    # --- Write aggregate cache artefact ---
    aggregate_path = _write_aggregate_artefact(
        config=config,
        seeds=seeds,
        ci_per_metric=ci_per_metric,
        per_event_seed_values_per_metric=per_event_seed_values_per_metric,
        ci_metric_attrs=ci_metric_attrs,
    )
    print(f"\nAggregate CI written to: {aggregate_path}")

    # --- Optional --compare-to OVERLAP/SEPARATED block ---
    if compare_to:
        _print_ci_comparison(
            current_ci=ci_per_metric,
            ci_metric_attrs=ci_metric_attrs,
            baseline_path=Path(compare_to),
        )


def run_judge(
    test_events: list[FloodEvent],
    config: dict,
    predictor: CascadePredictor,
    llm_client,
    force: bool,
) -> None:
    """Grade each event's predicted chain with the LLM judge and print totals.

    Args:
        test_events: Events to evaluate.
        config: Loaded configuration; ``paths.articles_dir`` is handed to the
            ``Evaluator``.
        predictor: Predictor handed to the ``Evaluator``.
        llm_client: LLM client used by the judge.
        force: Passed to ``evaluate_event`` as ``force_rejudge``.

    Prints one row per event (evidence-level counts + missed cascades)
    followed by a TOTALS block.
    """
    articles_dir = Path(config["paths"]["articles_dir"])
    judge = Evaluator(
        llm_client=llm_client,
        predictor=predictor,
        articles_dir=articles_dir,
        config=config,
    )
    logger.info("Judge fingerprint: %s", judge.judge_fingerprint_value)

    print("\n" + "=" * 72)
    print("JUDGE-BASED EVALUATION — per event")
    print("=" * 72)

    totals: Counter = Counter()
    missed_total = 0
    status_counts: Counter = Counter()

    for i, event in enumerate(test_events, start=1):
        logger.info("[%d/%d] %s (%s)", i, len(test_events), event.event_id, event.country)
        result = judge.evaluate_event(
            event, write_cache=True, force_rejudge=force,
        )
        status_counts[judge.last_cache_status or "unknown"] += 1
        if result is None:
            # NOTE(review): defensive, mirroring run_gold's None handling —
            # confirm whether Evaluator.evaluate_event can return None
            # (e.g. for skip_no_news); if it cannot, this branch is inert.
            print(f"{event.event_id:<18} (no judge result)")
            continue

        grades = Counter(n.evidence_level for n in result.node_evaluations)
        totals.update(grades)
        missed_total += len(result.missed_cascades)

        print(
            f"{event.event_id:<18} "
            f"verified:{grades.get('verified', 0):<2} "
            f"plausible:{grades.get('plausible', 0):<2} "
            f"unsupported:{grades.get('unsupported', 0):<2} "
            f"contradicted:{grades.get('contradicted', 0):<2} "
            f"missed:{len(result.missed_cascades)}"
        )

    print("\n" + "=" * 72)
    print("TOTALS")
    print("=" * 72)
    total_nodes = sum(totals.values())
    print(f"Nodes graded: {total_nodes}")
    for level in ("verified", "plausible", "unsupported", "contradicted"):
        print(f" {level:<13} {totals.get(level, 0)}")
    print(f"Missed cascades: {missed_total}")
    print(
        f"Events — rejudged: {status_counts.get('rejudge', 0)} / "
        f"cache-hit: {status_counts.get('hit', 0)} / "
        f"skipped-no-news: {status_counts.get('skip_no_news', 0)}"
    )
    print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}")


def main() -> None:
    """Script entry point: parse the CLI, load config/events, run the chosen mode(s)."""
    args = parse_args()
    config = load_config()

    # CLI seeds win; otherwise use config.evaluation.seed_list, and a single
    # unseeded run when neither is set.
    seeds: list[int | None] = (
        args.seeds
        if args.seeds is not None
        else config.get("evaluation", {}).get("seed_list", [None])
    )

    test_events = _load_test_events(
        config, set(args.event_ids) if args.event_ids else None
    )
    logger.info("Loaded %d test events; mode=%s", len(test_events), args.mode)

    llm_client = create_llm_client(config)
    predictor = CascadePredictor(
        llm_client, config, dump_full_trace=args.dump_bfs_full,
    )

    if args.mode in ("gold", "both"):
        run_gold(
            test_events, config, predictor, args.force,
            threshold=args.threshold,
            dump_diagnostics=args.dump_match_debug,
            dump_bfs_full=args.dump_bfs_full,
            seeds=seeds,
            compare_to=args.compare_to,
        )
    if args.mode in ("judge", "both"):
        run_judge(test_events, config, predictor, llm_client, args.force)

    print("\nLaunch the Streamlit Evaluation tab to explore case studies.")


if __name__ == "__main__":
    main()