cascade_risk / scripts /05_evaluate.py
Lucasoppem's picture
Sync from GitHub main (part 2)
36f9d47 verified
Raw
History Blame Contribute Delete
23.7 kB
"""Step 5: Evaluate cascade predictions on 2025 test events.
Two evaluation modes:
- ``gold`` (default): predicted chain vs the test-side cascade chain
(``data/processed/cascade_chains/{event_id}.json``) extracted by
step 03. Produces precision / recall / F1 / domain-Jaccard /
severity-match-rate. Per-event JSON cached under
``data/evaluation/gold/``.
- ``judge``: the LLM-as-judge route — the predicted chain is graded against
the news articles. Produces evidence_level / timing_alignment / missed
cascades. Per-event JSON cached under ``data/evaluation/``.
- ``both``: run gold first, then judge; both caches are populated.
"""
from __future__ import annotations
import argparse
import json
import logging
from collections import Counter
from pathlib import Path
from src.eval.evaluator import Evaluator
from src.eval.gold_evaluator import GoldEvaluator
from src.llm import create_llm_client
from src.llm.client import load_config
from src.models.schemas import FloodEvent
from src.rag.embedder import Embedder
from src.rag.predictor import CascadePredictor
# Script-wide logging: INFO threshold, timestamped "time LEVEL message" lines.
# NOTE(review): this runs at import time, so importing this module from
# elsewhere also applies this logging configuration.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
def _parse_seed_list(raw: str) -> list[int]:
    """Parse a comma-separated ``--seeds`` value into a list of ints.

    Whitespace around items is ignored and empty items (e.g. from a trailing
    comma) are skipped. Invalid input raises ``argparse.ArgumentTypeError`` so
    argparse reports a readable usage error instead of a bare ``ValueError``.
    """
    items = [part.strip() for part in raw.split(",")]
    items = [part for part in items if part]
    if not items:
        raise argparse.ArgumentTypeError("--seeds needs at least one integer")
    try:
        return [int(part) for part in items]
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"--seeds expects comma-separated integers, got {raw!r}"
        ) from exc


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the CLI parser and parse *argv* (``sys.argv[1:]`` when ``None``).

    The module docstring is used as the parser description.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument(
        "--mode", choices=["gold", "judge", "both"], default="gold",
        help="Evaluation mode (default: gold).",
    )
    p.add_argument(
        "--force", action="store_true",
        help="Ignore cache fingerprints and re-run for every event.",
    )
    p.add_argument(
        "--event-id", dest="event_ids", action="append", default=None,
        help="Evaluate only this event id; repeat the flag for multiple ids.",
    )
    p.add_argument(
        "--threshold", type=float, default=None,
        help=(
            "Override the cosine match threshold for gold mode. "
            "Falls back to config.evaluation.cosine_threshold, then "
            "src/eval/metrics.py::DEFAULT_COSINE_THRESHOLD. "
            "A different threshold yields a different per-event cache "
            "fingerprint, so stale cache entries are skipped automatically; "
            "add --force to re-run every event regardless of any cache."
        ),
    )
    p.add_argument(
        "--dump-match-debug", action="store_true",
        help=(
            "Write all same-domain (pred, gold) cosine pairs (including "
            "below-threshold near-misses) to data/evaluation/gold/"
            "{event_id}.diag.json. Used to support offline threshold sweeps; "
            "does not affect the main cache fingerprint."
        ),
    )
    p.add_argument(
        "--dump-bfs-full", action="store_true",
        help=(
            "Issue #12 diagnostic dump. When set, the predictor stores the "
            "full retrieved-edge payload + per-layer LLM call (template path, "
            "variables, raw response) in PredictionResult.trace, and gold "
            "mode writes that trace to data/evaluation/diagnostics/"
            "{event_id}_bfs_full.json (or {event_id}_seed{seed}_bfs_full.json "
            "when seeds are in use). Independent of the main cache "
            "fingerprint and of --dump-match-debug; combine freely. Only "
            "fires for events where a fresh predict() runs (cache hits "
            "produce no dump — pair with --force when you want to refresh)."
        ),
    )
    p.add_argument(
        "--seeds",
        type=_parse_seed_list,
        default=None,
        help=(
            "Comma-separated list of integer seeds for n-seed averaging "
            "(local backend only). Defaults to config.evaluation.seed_list. "
            "When llm.temperature == 0 the local backend is greedy and seed "
            "is ignored, but the loop still runs n times to confirm "
            "determinism (per-event std should collapse to 0)."
        ),
    )
    p.add_argument(
        "--compare-to",
        dest="compare_to",
        default=None,
        help=(
            "Gold mode only. Path to a previously-written aggregate_*.json "
            "file (see data/evaluation/aggregate_*.json). When set, the "
            "AGGREGATE block appends a CI OVERLAP/SEPARATED comparison "
            "against this baseline. Does not affect KEEP/ROLLBACK acceptance "
            "— telemetry only (spec §1.2 / §5.2)."
        ),
    )
    return p.parse_args(argv)
# v0.7 issue A — config fingerprint + aggregate cache helpers.
# Spec: docs/superpowers/specs/2026-06-07-v07-bootstrap-ci-design.md §5.3 / §5.2
# Schema version stamped into every aggregate_*.json as "format_version".
# --compare-to refuses baselines whose version differs, so bump this whenever
# the aggregate payload layout changes.
_AGGREGATE_FORMAT_VERSION = 1
def _config_fingerprint(config: dict) -> str:
    """Return the hex SHA-256 digest of the CI-relevant slice of *config*.

    The digest covers the complete ``rag`` and ``llm`` sections plus two
    evaluation knobs, ``evaluation.cosine_threshold`` and
    ``evaluation.seed_list`` — the settings that change retrieval, model or
    matching behaviour. Paths and other I/O-only settings are deliberately
    excluded so that relocating caches leaves fingerprints unchanged. The
    JSON is serialised with sorted keys, so dict ordering does not matter.
    """
    import hashlib

    evaluation = config.get("evaluation", {})
    fingerprint_fields = dict(
        rag=config.get("rag", {}),
        llm=config.get("llm", {}),
        cosine_threshold=evaluation.get("cosine_threshold"),
        seed_list=evaluation.get("seed_list"),
    )
    canonical = json.dumps(fingerprint_fields, sort_keys=True, default=str)
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
def _write_aggregate_artefact(
    *,
    config: dict,
    seeds: list,
    ci_per_metric: dict,
    per_event_seed_values_per_metric: dict,
    ci_metric_attrs: list[tuple[str, str]],
    n_resamples: int = 1000,
    bootstrap_seed: int = 42,
    confidence: float = 0.95,
) -> Path:
    """Persist mean + CI + per-event detail to ``{output_dir}/aggregate_{fp}.json``.

    ``fp`` is the first 12 hex chars of ``_config_fingerprint(config)`` so the
    file is recognisable in ``ls`` output; the full hash is stored in the JSON
    body for ``--compare-to`` integrity. The payload layout is format version
    ``_AGGREGATE_FORMAT_VERSION`` (spec §5.3).

    Args:
        config: Config whose ``evaluation.output_dir`` is the target directory
            and whose fingerprint names the file.
        seeds: Seeds used by the run. ``None`` entries (legacy unseeded run)
            count towards ``n_seeds`` but are omitted from ``seed_list``.
        ci_per_metric: ``{driver_key: ci_dict}``; each ci_dict is spread into
            its metric entry. ``n_events`` is taken from the first entry's
            ``"n"`` (0 when there are no entries).
        per_event_seed_values_per_metric:
            ``{driver_key: {event_id: [per-seed values]}}``.
        ci_metric_attrs: ``(label, driver_key)`` pairs; fixes metric order.
        n_resamples: Bootstrap resample count recorded in the artefact.
        bootstrap_seed: Bootstrap RNG seed recorded in the artefact.
        confidence: Bootstrap confidence level recorded in the artefact.
            Defaults of these three match the values run_gold uses.

    Returns:
        Path of the written JSON file (overwritten if it already exists).
    """
    from datetime import datetime, timezone

    output_dir = Path(config["evaluation"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    fp_full = _config_fingerprint(config)
    path = output_dir / f"aggregate_{fp_full[:12]}.json"

    # Map internal driver keys (cat_r / cat_sev / f1 / cat_f1) → public
    # metric names that the spec / retro script use. Keep the public names
    # stable across format versions; unknown keys are written verbatim.
    public_name = {
        "cat_r": "category_recall",
        "cat_sev": "category_severity_match_rate",
        "f1": "cosine_f1",
        "cat_f1": "category_f1",
    }

    metrics_block: dict = {}
    for _, key in ci_metric_attrs:
        ci = ci_per_metric.get(key)
        if ci is None:
            continue
        seed_vals = per_event_seed_values_per_metric.get(key, {})
        # Skip events with no seed values instead of dividing by zero.
        per_event_means = {
            eid: sum(vs) / len(vs) for eid, vs in seed_vals.items() if vs
        }
        metrics_block[public_name.get(key, key)] = {
            **ci,
            "per_event_values": per_event_means,
            "per_event_seed_values": seed_vals,
        }

    first_ci = next(iter(ci_per_metric.values()), None)
    # Timezone-aware "now" (datetime.utcnow() is deprecated); tzinfo is then
    # stripped so the string keeps the historical "YYYY-MM-DDTHH:MM:SS[.ffffff]Z"
    # shape rather than ending in "+00:00Z".
    evaluated_at = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )
    payload = {
        "format_version": _AGGREGATE_FORMAT_VERSION,
        "config_fingerprint": fp_full,
        "evaluated_at": evaluated_at,
        "n_events": first_ci["n"] if first_ci is not None else 0,
        "n_seeds": len(seeds) if seeds else 0,
        "seed_list": [s for s in (seeds or []) if s is not None],
        "bootstrap": {
            "n_resamples": n_resamples,
            "seed": bootstrap_seed,
            "confidence": confidence,
        },
        "metrics": metrics_block,
    }
    path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return path
def _print_ci_comparison(
    *,
    current_ci: dict,
    ci_metric_attrs: list[tuple[str, str]],
    baseline_path: Path,
) -> None:
    """Print an OVERLAP / SEPARATED comparison against a baseline aggregate_*.json.

    Telemetry only — does not change acceptance. For each ``(label, key)`` in
    *ci_metric_attrs* present in both *current_ci* and the baseline, prints
    both means and CIs, the verdict (OVERLAP when the intervals intersect,
    else SEPARATED) and the mean delta.

    Baseline problems (missing file, unreadable or undecodable JSON, a
    non-object top level, format-version mismatch, malformed metric entries)
    are reported on stdout and skipped instead of raising.
    """
    if not baseline_path.is_file():
        print(f"\n[--compare-to] baseline file not found: {baseline_path}")
        return
    try:
        baseline = json.loads(baseline_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"\n[--compare-to] could not parse baseline: {exc}")
        return
    if not isinstance(baseline, dict):
        print("\n[--compare-to] baseline is not a JSON object; comparison skipped")
        return
    if baseline.get("format_version") != _AGGREGATE_FORMAT_VERSION:
        print(
            f"\n[--compare-to] baseline format_version="
            f"{baseline.get('format_version')!r} != "
            f"{_AGGREGATE_FORMAT_VERSION}; comparison skipped"
        )
        return
    # Driver key → public metric name used inside the artefact's "metrics".
    public_name = {
        "cat_r": "category_recall",
        "cat_sev": "category_severity_match_rate",
        "f1": "cosine_f1",
        "cat_f1": "category_f1",
    }
    baseline_metrics = baseline.get("metrics", {})
    if not isinstance(baseline_metrics, dict):
        baseline_metrics = {}
    print(f"\nCI OVERLAP vs baseline ({baseline_path.name}):")
    for label, key in ci_metric_attrs:
        cur = current_ci.get(key)
        base = baseline_metrics.get(public_name.get(key, key))
        if cur is None or base is None:
            continue
        try:
            base_mean = float(base["mean"])
            base_lo = float(base["ci_low"])
            base_hi = float(base["ci_high"])
        except (KeyError, TypeError, ValueError):
            print(f"  {label:<18} baseline entry malformed; skipped")
            continue
        delta = cur["mean"] - base_mean
        # Closed intervals intersect iff max(lows) <= min(highs).
        overlap_lo = max(cur["ci_low"], base_lo)
        overlap_hi = min(cur["ci_high"], base_hi)
        verdict = "OVERLAP" if overlap_lo <= overlap_hi else "SEPARATED"
        print(
            f"  {label:<18} "
            f"current {cur['mean']:.3f} "
            f"[{cur['ci_low']:.3f}, {cur['ci_high']:.3f}] "
            f"vs baseline {base_mean:.3f} "
            f"[{base_lo:.3f}, {base_hi:.3f}] "
            f"→ {verdict} (Δ mean {delta:+.3f})"
        )
def _load_test_events(config: dict, wanted_ids: set[str] | None) -> list[FloodEvent]:
    """Load the test split from ``config.paths.test_events`` as FloodEvents.

    Args:
        config: Loaded project config.
        wanted_ids: Optional set of event ids to keep (from ``--event-id``).
            Falsy means "keep every event".

    Returns:
        The (optionally filtered) list of test events, in file order.

    Raises:
        SystemExit: If *wanted_ids* is given but none of them is in the split.
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    raw_events = json.loads(
        Path(config["paths"]["test_events"]).read_text(encoding="utf-8")
    )
    test_events = [FloodEvent(**e) for e in raw_events]
    if wanted_ids:
        test_events = [e for e in test_events if e.event_id in wanted_ids]
        if not test_events:
            raise SystemExit(
                f"No matching events in test split for --event-id {sorted(wanted_ids)}"
            )
        # Partial misses (e.g. a typo in one of several ids) used to be
        # dropped silently; surface them so the run's scope is explicit.
        missing = set(wanted_ids) - {e.event_id for e in test_events}
        if missing:
            logger.warning(
                "Ignoring --event-id values not in the test split: %s",
                sorted(missing),
            )
    return test_events
def run_gold(
    test_events: list[FloodEvent],
    config: dict,
    predictor: CascadePredictor,
    force: bool,
    threshold: float | None = None,
    dump_diagnostics: bool = False,
    dump_bfs_full: bool = False,
    seeds: list[int | None] | None = None,
    compare_to: str | None = None,
) -> None:
    """Run gold-chain evaluation per event, optionally averaged over seeds.

    For every seed (v0.5 issue A) a fresh ``GoldEvaluator`` scores each test
    event. Per-event values are accumulated across seeds and printed as
    mean ± std. The function then prints a macro AGGREGATE block and a
    bootstrap 95% CI block (v0.7 issue A), and writes the
    ``aggregate_{fp}.json`` artefact. When ``compare_to`` is given, it also
    prints a CI OVERLAP/SEPARATED comparison against that baseline.

    Args:
        test_events: Events to evaluate.
        config: Loaded project config; ``evaluation.output_dir`` must be set.
        predictor: Predictor shared by every per-seed GoldEvaluator.
        force: Passed as ``force_reeval`` (ignore per-event cache entries).
        threshold: Cosine match threshold override. Falls back to
            ``evaluation.cosine_threshold``; if that is also unset, no
            threshold is passed and GoldEvaluator's own default applies.
        dump_diagnostics: Forwarded to GoldEvaluator (near-miss pair dumps).
        dump_bfs_full: Write ``gold.last_prediction.trace`` for each freshly
            predicted event under ``{output_dir}/diagnostics/``.
        seeds: Seeds to run. ``None`` or empty means ``[None]``: one legacy
            run without the ``_seed{seed}`` cache suffix. ``[42, 1337, 2026]``
            runs predict 3× per event.
        compare_to: Optional path to a baseline ``aggregate_*.json``.
    """
    import math

    # Imported up front so a broken metrics module fails before the slow
    # per-event loop instead of after it.
    from src.eval.metrics import bootstrap_macro_ci

    # Bootstrap settings for the CI block. _write_aggregate_artefact records
    # the same values in the artefact's "bootstrap" section — keep in sync.
    n_resamples = 1000
    confidence = 0.95
    bootstrap_seed = 42

    if not seeds:
        seeds = [None]
    embedder = Embedder(config)
    eval_cfg = config.get("evaluation", {})
    if threshold is None:
        threshold = eval_cfg.get("cosine_threshold")

    # The aggregate artefact is named/fingerprinted from the config dict's
    # evaluation.cosine_threshold / evaluation.seed_list, but --threshold and
    # --seeds overrides never reach that dict. Without this, runs with
    # different overrides would share one aggregate_{fp}.json (overwriting
    # each other and misleading --compare-to). Keys are only replaced when an
    # override actually differs, so non-override runs keep their fingerprint.
    effective_eval = dict(eval_cfg)
    if threshold != eval_cfg.get("cosine_threshold"):
        effective_eval["cosine_threshold"] = threshold
    if seeds != (eval_cfg.get("seed_list") or [None]):
        effective_eval["seed_list"] = list(seeds)
    aggregate_config = {**config, "evaluation": effective_eval}

    diagnostics_dir: Path | None = None
    if dump_bfs_full:
        diagnostics_dir = Path(eval_cfg["output_dir"]) / "diagnostics"
        diagnostics_dir.mkdir(parents=True, exist_ok=True)
        logger.info(
            "BFS full-trace dumps enabled — writing to %s/{event_id}_seed{seed}_bfs_full.json",
            diagnostics_dir,
        )

    print("\n" + "=" * 96)
    print(f"GOLD-BASED EVALUATION — per event (seeds={seeds})")
    print("=" * 96)
    print(f"{'event_id':<18} {'pred':>5} {'gold':>5} {'matched':>8} "
          f"{'f1 (mean±std)':>16} {'cat_f1 (mean±std)':>18} {'sev_match':>10}")

    # Per-event accumulator: {event_id: {'f1': [...], 'cat_f1': [...], ...}},
    # one list entry per seed that produced a result.
    per_event: dict[str, dict[str, list]] = {}
    for seed in seeds:
        gold_kwargs: dict = {
            "predictor": predictor,
            "embedder": embedder,
            "config": config,
            "dump_diagnostics": dump_diagnostics,
            "seed": seed,
        }
        if threshold is not None:
            gold_kwargs["threshold"] = threshold
        gold = GoldEvaluator(**gold_kwargs)
        logger.info(
            "Seed %s: gold-eval fingerprint=%s threshold=%.4f",
            seed, gold.fingerprint, gold.threshold,
        )
        for i, event in enumerate(test_events, start=1):
            logger.info("[seed=%s][%d/%d] %s (%s)",
                        seed, i, len(test_events), event.event_id, event.country)
            result = gold.evaluate_event(event, write_cache=True, force_reeval=force)
            if diagnostics_dir is not None and gold.last_prediction is not None:
                stem = (
                    f"{event.event_id}_bfs_full.json"
                    if seed is None
                    else f"{event.event_id}_seed{seed}_bfs_full.json"
                )
                (diagnostics_dir / stem).write_text(
                    json.dumps(
                        {"event_id": event.event_id, "seed": seed,
                         "trace": gold.last_prediction.trace},
                        indent=2, ensure_ascii=False, default=str,
                    ),
                    encoding="utf-8",
                )
            slot = per_event.setdefault(event.event_id, {
                "status": [], "pred_n": [], "gold_n": [], "matched_n": [],
                "f1": [], "cat_f1": [], "sev_match": [], "cat_sev": [],
                "p": [], "r": [], "cat_p": [], "cat_r": [], "dom_jac": [],
            })
            slot["status"].append(gold.last_status)
            if result is None:  # no usable gold chain for this event/seed
                continue
            slot["pred_n"].append(result.predicted_node_count)
            slot["gold_n"].append(result.gold_node_count)
            slot["matched_n"].append(len(result.matches))
            slot["p"].append(result.precision)
            slot["r"].append(result.recall)
            slot["f1"].append(result.f1)
            slot["dom_jac"].append(result.domain_jaccard)
            if result.severity_match_rate is not None:
                slot["sev_match"].append(result.severity_match_rate)
            if result.category_f1 is not None:
                slot["cat_p"].append(result.category_precision or 0.0)
                slot["cat_r"].append(result.category_recall or 0.0)
                slot["cat_f1"].append(result.category_f1)
                if result.category_severity_match_rate is not None:
                    slot["cat_sev"].append(result.category_severity_match_rate)

    # v0.5 issue A: count each event's status ONCE (last seed's status), not
    # n_seeds × n_events times, so the AGGREGATE summary lines up with the
    # per-event row count.
    status_counts: Counter = Counter(
        slot["status"][-1] if slot["status"] else "unknown"
        for slot in per_event.values()
    )

    def _mean(xs: list[float]) -> float | None:
        return sum(xs) / len(xs) if xs else None

    def _std(xs: list[float]) -> float | None:
        # Population std (divide by n); 0.0 for a single value, None if empty.
        if len(xs) < 2:
            return 0.0 if xs else None
        m = _mean(xs)
        return math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

    def _fmt(m: float | None, s: float | None, w: int = 16) -> str:
        if m is None:
            return f"{'-':>{w}}"
        s_str = f"{s:.3f}" if s is not None else "-"
        return f"{m:.3f}±{s_str}".rjust(w)

    # --- Per-event rows ---
    for eid, slot in per_event.items():
        if not slot["f1"]:  # all seeds reported None (no_gold / outlier)
            label = "outlier" if "outlier" in slot["status"] else "no_gold"
            print(f"{eid:<18} {'-':>5} {'-':>5} {'-':>8} "
                  f"{label:>16} {'-':>18} {'-':>10}")
            continue
        f1_m, f1_s = _mean(slot["f1"]), _std(slot["f1"])
        cat_f1_m, cat_f1_s = _mean(slot["cat_f1"]), _std(slot["cat_f1"])
        sev = _mean(slot["cat_sev"]) if slot["cat_sev"] else _mean(slot["sev_match"])
        sev_str = f"{sev:.3f}" if sev is not None else "-"
        # Counts are averaged over seeds; round rather than truncate so that
        # e.g. a 2.67 mean is shown as 3, not 2.
        print(
            f"{eid:<18} "
            f"{round(_mean(slot['pred_n'])):>5} "
            f"{round(_mean(slot['gold_n'])):>5} "
            f"{round(_mean(slot['matched_n'])):>8} "
            f"{_fmt(f1_m, f1_s)} "
            f"{_fmt(cat_f1_m, cat_f1_s, w=18)} "
            f"{sev_str:>10}"
        )

    # --- AGGREGATE block ---
    print("\n" + "=" * 96)
    print(f"AGGREGATE (n_seeds={len(seeds)})")
    print("=" * 96)
    evaluable = [s for s in per_event.values() if s["f1"]]
    n = len(evaluable)
    if not n:
        print("No events had a usable gold chain — nothing to aggregate.")
        print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}/gold/")
        return
    macro_f1_means = [_mean(s["f1"]) for s in evaluable]
    macro_cat_f1_means = [_mean(s["cat_f1"]) for s in evaluable if s["cat_f1"]]
    per_event_f1_stds = [_std(s["f1"]) for s in evaluable]
    print(f"Events evaluated: {n}  "
          f"(no_gold/outlier excluded; status counts={dict(status_counts)})")
    print(f"Cosine Macro F1     {_mean(macro_f1_means):.3f}  "
          f"(per-event std mean = {_mean(per_event_f1_stds):.3f}, "
          f"max = {max(per_event_f1_stds):.3f})")
    if macro_cat_f1_means:
        per_event_cat_f1_stds = [_std(s["cat_f1"]) for s in evaluable if s["cat_f1"]]
        print(f"Category Macro F1   {_mean(macro_cat_f1_means):.3f}  "
              f"(per-event std mean = {_mean(per_event_cat_f1_stds):.3f}, "
              f"max = {max(per_event_cat_f1_stds):.3f})")
    # Severity (category-level if available)
    cat_sev_means = [_mean(s["cat_sev"]) for s in evaluable if s["cat_sev"]]
    if cat_sev_means:
        print(f"Category Severity   {_mean(cat_sev_means):.3f}")
    else:
        sev_means = [_mean(s["sev_match"]) for s in evaluable if s["sev_match"]]
        if sev_means:
            print(f"Severity match      {_mean(sev_means):.3f}")
    print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}/gold/")

    # --- v0.7 issue A: BOOTSTRAP CI block (telemetry only) ---
    # Per-event seed-means for each of the 4 CI metrics, computed directly
    # from `per_event` (built across seeds above).
    ci_metric_attrs = [
        ("Cat Macro Recall", "cat_r"),
        ("Cat Macro Sev", "cat_sev"),
        ("Cos Macro F1", "f1"),
        ("Cat Macro F1", "cat_f1"),
    ]
    ci_per_metric: dict[str, dict] = {}
    per_event_seed_values_per_metric: dict[str, dict[str, list[float]]] = {}
    for _, key in ci_metric_attrs:
        # `per_event[eid][key]` holds the seed-level values for this
        # event/metric; events with none (no_gold / outlier) contribute nothing.
        seed_values_per_event = {
            eid: slot[key]
            for eid, slot in per_event.items()
            if slot.get(key)
        }
        per_event_seed_values_per_metric[key] = seed_values_per_event
        means = [sum(vs) / len(vs) for vs in seed_values_per_event.values()]
        if not means:
            continue
        ci_per_metric[key] = bootstrap_macro_ci(
            means,
            n_resamples=n_resamples,
            confidence=confidence,
            seed=bootstrap_seed,
        )
    if not ci_per_metric:
        # Nothing to bootstrap, so there is no artefact to write or compare.
        return

    n_ci_events = next(iter(ci_per_metric.values()))["n"]
    print(
        f"\nBOOTSTRAP {confidence:.0%} CI "
        f"(B={n_resamples}, n={n_ci_events} events, event-level resample, "
        f"seed={bootstrap_seed}):"
    )
    for label, key in ci_metric_attrs:
        ci = ci_per_metric.get(key)
        if ci is None:
            continue
        print(
            f"  {label:<18} "
            f"{ci['mean']:.3f} [{ci['ci_low']:.3f}, {ci['ci_high']:.3f}]"
        )

    # --- Write aggregate cache artefact (fingerprint reflects CLI overrides) ---
    aggregate_path = _write_aggregate_artefact(
        config=aggregate_config,
        seeds=seeds,
        ci_per_metric=ci_per_metric,
        per_event_seed_values_per_metric=per_event_seed_values_per_metric,
        ci_metric_attrs=ci_metric_attrs,
    )
    print(f"\nAggregate CI written to: {aggregate_path}")

    # --- Optional --compare-to OVERLAP/SEPARATED block ---
    if compare_to:
        _print_ci_comparison(
            current_ci=ci_per_metric,
            ci_metric_attrs=ci_metric_attrs,
            baseline_path=Path(compare_to),
        )
def run_judge(
    test_events: list[FloodEvent],
    config: dict,
    predictor: CascadePredictor,
    llm_client,
    force: bool,
) -> None:
    """Grade each event's predicted chain against its news articles via the LLM judge.

    Prints one row per event (evidence-level counts plus missed cascades),
    then totals and a per-status breakdown (rejudged / cache-hit /
    skipped-no-news). Results are cached by ``Evaluator`` (``write_cache=True``);
    ``force`` is passed as ``force_rejudge``.

    Args:
        test_events: Events to judge.
        config: Loaded project config (``paths.articles_dir``,
            ``evaluation.output_dir``).
        predictor: Predictor used by the judge to produce chains.
        llm_client: LLM client used as the judge.
        force: Re-judge even when a cached result exists.
    """
    articles_dir = Path(config["paths"]["articles_dir"])
    judge = Evaluator(
        llm_client=llm_client,
        predictor=predictor,
        articles_dir=articles_dir,
        config=config,
    )
    logger.info("Judge fingerprint: %s", judge.judge_fingerprint_value)

    print("\n" + "=" * 72)
    print("JUDGE-BASED EVALUATION — per event")
    print("=" * 72)

    totals: Counter = Counter()
    missed_total = 0
    status_counts: Counter = Counter()
    n_events = len(test_events)
    for i, event in enumerate(test_events, start=1):
        logger.info("[%d/%d] %s (%s)", i, n_events,
                    event.event_id, event.country)
        result = judge.evaluate_event(
            event, write_cache=True, force_rejudge=force,
        )
        status = judge.last_cache_status or "unknown"
        status_counts[status] += 1
        if result is None:
            # NOTE(review): defensive — it is not visible here whether
            # evaluate_event can return None (e.g. on skip_no_news); if it
            # does, report a skip instead of crashing on attribute access.
            print(f"{event.event_id:<18} skipped ({status})")
            continue
        grades = Counter(n.evidence_level for n in result.node_evaluations)
        totals.update(grades)
        missed_total += len(result.missed_cascades)
        print(
            f"{event.event_id:<18} "
            f"verified:{grades.get('verified', 0):<2} "
            f"plausible:{grades.get('plausible', 0):<2} "
            f"unsupported:{grades.get('unsupported', 0):<2} "
            f"contradicted:{grades.get('contradicted', 0):<2} "
            f"missed:{len(result.missed_cascades)}"
        )

    print("\n" + "=" * 72)
    print("TOTALS")
    print("=" * 72)
    total_nodes = sum(totals.values())
    print(f"Nodes graded:        {total_nodes}")
    for level in ("verified", "plausible", "unsupported", "contradicted"):
        print(f"  {level:<13} {totals.get(level, 0)}")
    print(f"Missed cascades:     {missed_total}")
    print(
        f"Events — rejudged: {status_counts.get('rejudge', 0)} / "
        f"cache-hit: {status_counts.get('hit', 0)} / "
        f"skipped-no-news: {status_counts.get('skip_no_news', 0)}"
    )
    print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}")
def main() -> None:
    """CLI entry point.

    Parses arguments, loads config and the (optionally filtered) test split,
    builds one LLM client + predictor, then runs gold and/or judge evaluation
    according to ``--mode``. An explicit ``--seeds`` wins over
    ``config.evaluation.seed_list``.
    """
    args = parse_args()
    config = load_config()

    if args.seeds is None:
        seeds: list[int | None] = config.get("evaluation", {}).get("seed_list", [None])
    else:
        seeds = args.seeds

    requested_ids = set(args.event_ids) if args.event_ids else None
    test_events = _load_test_events(config, requested_ids)
    logger.info("Loaded %d test events; mode=%s", len(test_events), args.mode)

    llm_client = create_llm_client(config)
    predictor = CascadePredictor(
        llm_client, config, dump_full_trace=args.dump_bfs_full,
    )

    wants_gold = args.mode in ("gold", "both")
    wants_judge = args.mode in ("judge", "both")
    if wants_gold:
        run_gold(
            test_events,
            config,
            predictor,
            args.force,
            threshold=args.threshold,
            dump_diagnostics=args.dump_match_debug,
            dump_bfs_full=args.dump_bfs_full,
            seeds=seeds,
            compare_to=args.compare_to,
        )
    if wants_judge:
        run_judge(test_events, config, predictor, llm_client, args.force)

    print("\nLaunch the Streamlit Evaluation tab to explore case studies.")


if __name__ == "__main__":
    main()