Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """Step 5: Evaluate cascade predictions on 2025 test events. | |
| Two evaluation modes: | |
| - ``gold`` (default): predicted chain vs the test-side cascade chain | |
| (``data/processed/cascade_chains/{event_id}.json``) extracted by | |
| step 03. Produces precision / recall / F1 / domain-Jaccard / | |
| severity-match-rate. Per-event JSON cached under | |
| ``data/evaluation/gold/``. | |
| - ``judge``: the LLM-as-judge route, predicted chain graded against | |
| the news. Produces evidence_level / timing_alignment / missed | |
| cascades. Per-event JSON cached under ``data/evaluation/``. | |
| - ``both``: run gold first, then judge; both caches are populated. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| from collections import Counter | |
| from pathlib import Path | |
| from src.eval.evaluator import Evaluator | |
| from src.eval.gold_evaluator import GoldEvaluator | |
| from src.llm import create_llm_client | |
| from src.llm.client import load_config | |
| from src.models.schemas import FloodEvent | |
| from src.rag.embedder import Embedder | |
| from src.rag.predictor import CascadePredictor | |
# Root logging is configured at import time so this script's progress lines
# share one timestamped format. (basicConfig is a no-op if the root logger
# already has handlers.)
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser(description=__doc__) | |
| p.add_argument( | |
| "--mode", choices=["gold", "judge", "both"], default="gold", | |
| help="Evaluation mode (default: gold).", | |
| ) | |
| p.add_argument( | |
| "--force", action="store_true", | |
| help="Ignore cache fingerprints and re-run for every event.", | |
| ) | |
| p.add_argument( | |
| "--event-id", dest="event_ids", action="append", default=None, | |
| help="Evaluate only this event id; repeat the flag for multiple ids.", | |
| ) | |
| p.add_argument( | |
| "--threshold", type=float, default=None, | |
| help=( | |
| "Override the cosine match threshold for gold mode. " | |
| "Falls back to config.evaluation.cosine_threshold, then " | |
| "src/eval/metrics.py::DEFAULT_COSINE_THRESHOLD. " | |
| "Note: a different threshold yields a different fingerprint, " | |
| "so the cache will be skipped automatically — but pass --force " | |
| "if you also want to bypass any cache that happens to share it." | |
| ), | |
| ) | |
| p.add_argument( | |
| "--dump-match-debug", action="store_true", | |
| help=( | |
| "Write all same-domain (pred, gold) cosine pairs (including " | |
| "below-threshold near-misses) to data/evaluation/gold/" | |
| "{event_id}.diag.json. Used to support offline threshold sweeps; " | |
| "does not affect the main cache fingerprint." | |
| ), | |
| ) | |
| p.add_argument( | |
| "--dump-bfs-full", action="store_true", | |
| help=( | |
| "Issue #12 diagnostic dump. When set, the predictor stores the " | |
| "full retrieved-edge payload + per-layer LLM call (template path, " | |
| "variables, raw response) in PredictionResult.trace, and the " | |
| "evaluator writes that trace to data/evaluation/diagnostics/" | |
| "{event_id}_bfs_full.json. Independent of the main cache " | |
| "fingerprint and of --dump-match-debug; combine freely. Only " | |
| "fires for events where a fresh predict() runs (cache hits " | |
| "produce no dump — pair with --force when you want to refresh)." | |
| ), | |
| ) | |
| p.add_argument( | |
| "--seeds", | |
| type=lambda s: [int(x) for x in s.split(",")], | |
| default=None, | |
| help=( | |
| "Comma-separated list of integer seeds for n-seed averaging " | |
| "(local backend only). Defaults to config.evaluation.seed_list. " | |
| "When llm.temperature == 0 the local backend is greedy and seed " | |
| "is ignored, but the loop still runs n times to confirm " | |
| "determinism (per-event std should collapse to 0)." | |
| ), | |
| ) | |
| p.add_argument( | |
| "--compare-to", | |
| dest="compare_to", | |
| default=None, | |
| help=( | |
| "Path to a previously-written aggregate_*.json file (see " | |
| "data/evaluation/aggregate_*.json). When set, the AGGREGATE " | |
| "block appends a CI OVERLAP/SEPARATED comparison against this " | |
| "baseline. Does not affect KEEP/ROLLBACK acceptance — " | |
| "telemetry only (spec §1.2 / §5.2)." | |
| ), | |
| ) | |
| return p.parse_args() | |
# v0.7 issue A — config fingerprint + aggregate cache helpers.
# Spec: docs/superpowers/specs/2026-06-07-v07-bootstrap-ci-design.md §5.3 / §5.2
# Schema version written into every aggregate_*.json. _print_ci_comparison
# skips any baseline whose format_version differs from this value.
_AGGREGATE_FORMAT_VERSION = 1
| def _config_fingerprint(config: dict) -> str: | |
| """sha256 of the config fields that influence aggregate CI numbers. | |
| Includes the full rag.* + llm.* sections plus evaluation.cosine_threshold | |
| and evaluation.seed_list (the eval knobs that change retrieval / model / | |
| matching behaviour). Does NOT include paths or other I/O-only fields so | |
| that moving caches around does not invalidate fingerprints. | |
| """ | |
| import hashlib | |
| import json as _json | |
| payload = { | |
| "rag": config.get("rag", {}), | |
| "llm": config.get("llm", {}), | |
| "cosine_threshold": config.get("evaluation", {}).get("cosine_threshold"), | |
| "seed_list": config.get("evaluation", {}).get("seed_list"), | |
| } | |
| blob = _json.dumps(payload, sort_keys=True, default=str).encode("utf-8") | |
| return hashlib.sha256(blob).hexdigest() | |
def _write_aggregate_artefact(
    *,
    config: dict,
    seeds: list,
    ci_per_metric: dict,
    per_event_seed_values_per_metric: dict,
    ci_metric_attrs: list[tuple[str, str]],
    n_resamples: int = 1000,
    bootstrap_seed: int = 42,
    confidence: float = 0.95,
) -> Path:
    """Persist mean + CI + per-event detail to ``<output_dir>/aggregate_{fp}.json``.

    ``fp`` is a 12-char prefix of the config fingerprint, so files are
    distinguishable in ``ls`` output. The full hash is stored in the JSON
    body for ``--compare-to`` integrity. Format version 1 is documented in
    spec §5.3.

    Args:
        config: Project config. ``evaluation.output_dir`` is required, and the
            fingerprint is computed from this dict.
        seeds: Seeds used by the run. ``None`` entries (legacy single run)
            count towards ``n_seeds`` but are omitted from ``seed_list``.
        ci_per_metric: ``{metric_key: ci_dict}``. Each ci_dict must carry
            ``n``; all of its keys are copied into the metric entry.
        per_event_seed_values_per_metric:
            ``{metric_key: {event_id: [value per seed]}}``.
        ci_metric_attrs: ``(label, metric_key)`` pairs that fix output order.
        n_resamples, bootstrap_seed, confidence: Bootstrap settings recorded
            in the header. Defaults are the values run_gold uses.

    Returns:
        Path of the written file. An existing file with the same name is
        overwritten.
    """
    from datetime import datetime, timezone

    output_dir = Path(config["evaluation"]["output_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    fp_full = _config_fingerprint(config)
    fp_short = fp_full[:12]
    path = output_dir / f"aggregate_{fp_short}.json"

    # Map internal driver keys (cat_r / cat_sev / f1 / cat_f1) → public
    # metric names that the spec / retro script use. Keep the public names
    # stable across format versions. Unknown keys keep their internal name
    # rather than raising KeyError.
    public_name = {
        "cat_r": "category_recall",
        "cat_sev": "category_severity_match_rate",
        "f1": "cosine_f1",
        "cat_f1": "category_f1",
    }

    metrics_block: dict = {}
    for _, key in ci_metric_attrs:
        ci = ci_per_metric.get(key)
        if ci is None:
            continue
        seed_vals = per_event_seed_values_per_metric.get(key, {})
        # Skip empty seed lists so a degenerate entry cannot divide by zero.
        per_event_means = {
            eid: sum(vs) / len(vs) for eid, vs in seed_vals.items() if vs
        }
        metrics_block[public_name.get(key, key)] = {
            **ci,
            "per_event_values": per_event_means,
            "per_event_seed_values": seed_vals,
        }

    # Metrics can cover different numbers of events (category metrics exist
    # only where they were produced). Report the broadest metric's count.
    # default=0 keeps an empty ci_per_metric from raising.
    n_events = max((ci["n"] for ci in ci_per_metric.values()), default=0)

    # Naive-UTC ISO timestamp + "Z". This is the same format utcnow() gave,
    # without the deprecated call.
    evaluated_at = (
        datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    )

    payload = {
        "format_version": _AGGREGATE_FORMAT_VERSION,
        "config_fingerprint": fp_full,
        "evaluated_at": evaluated_at,
        "n_events": n_events,
        "n_seeds": len(seeds) if seeds else 0,
        "seed_list": [s for s in (seeds or []) if s is not None],
        "bootstrap": {
            "n_resamples": n_resamples,
            "seed": bootstrap_seed,
            "confidence": confidence,
        },
        "metrics": metrics_block,
    }
    path.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return path
def _print_ci_comparison(
    *,
    current_ci: dict,
    ci_metric_attrs: list[tuple[str, str]],
    baseline_path: Path,
) -> None:
    """Print an OVERLAP / SEPARATED comparison against a baseline aggregate file.

    Telemetry only: it never changes acceptance and never raises on a bad
    baseline. An unreadable, non-UTF-8, non-object or wrong-version file, or
    a malformed per-metric entry, is reported and skipped.

    Args:
        current_ci: ``{metric_key: ci_dict}`` for the current run. Each
            ci_dict holds ``mean``, ``ci_low`` and ``ci_high``.
        ci_metric_attrs: ``(label, metric_key)`` pairs that fix print order.
        baseline_path: A previously written ``aggregate_*.json``.
    """
    if not baseline_path.is_file():
        print(f"\n[--compare-to] baseline file not found: {baseline_path}")
        return
    try:
        baseline = json.loads(baseline_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        # (e.g. a binary or non-UTF-8 file passed by mistake).
        print(f"\n[--compare-to] could not parse baseline: {exc}")
        return
    if not isinstance(baseline, dict):
        print("\n[--compare-to] baseline is not a JSON object; comparison skipped")
        return
    if baseline.get("format_version") != _AGGREGATE_FORMAT_VERSION:
        print(
            f"\n[--compare-to] baseline format_version="
            f"{baseline.get('format_version')!r} != "
            f"{_AGGREGATE_FORMAT_VERSION}; comparison skipped"
        )
        return

    # Same internal → public metric-name mapping the aggregate writer uses.
    public_name = {
        "cat_r": "category_recall",
        "cat_sev": "category_severity_match_rate",
        "f1": "cosine_f1",
        "cat_f1": "category_f1",
    }
    baseline_metrics = baseline.get("metrics")
    if not isinstance(baseline_metrics, dict):
        baseline_metrics = {}

    def _numeric_ci(entry: object) -> bool:
        """True if *entry* carries numeric mean / ci_low / ci_high."""
        return isinstance(entry, dict) and all(
            isinstance(entry.get(field), (int, float))
            for field in ("mean", "ci_low", "ci_high")
        )

    print(f"\nCI OVERLAP vs baseline ({baseline_path.name}):")
    compared = 0
    for label, key in ci_metric_attrs:
        cur = current_ci.get(key)
        base = baseline_metrics.get(public_name.get(key, key))
        if cur is None or base is None:
            continue
        if not _numeric_ci(base):
            print(f" {label:<18} baseline entry malformed; skipped")
            continue
        delta = cur["mean"] - base["mean"]
        # Two closed intervals overlap iff max(lows) <= min(highs).
        overlap_lo = max(cur["ci_low"], base["ci_low"])
        overlap_hi = min(cur["ci_high"], base["ci_high"])
        verdict = "OVERLAP" if overlap_lo <= overlap_hi else "SEPARATED"
        print(
            f" {label:<18} "
            f"current {cur['mean']:.3f} "
            f"[{cur['ci_low']:.3f}, {cur['ci_high']:.3f}] "
            f"vs baseline {base['mean']:.3f} "
            f"[{base['ci_low']:.3f}, {base['ci_high']:.3f}] "
            f"→ {verdict} (Δ mean {delta:+.3f})"
        )
        compared += 1
    if compared == 0:
        print(" (no metrics in common with the baseline)")
def _load_test_events(config: dict, wanted_ids: set[str] | None) -> list[FloodEvent]:
    """Load the test-split events, optionally restricted to ``wanted_ids``.

    Args:
        config: Project config. ``paths.test_events`` must point to a JSON
            list of ``FloodEvent`` keyword dicts.
        wanted_ids: Event ids to keep. ``None`` or empty keeps every event.

    Returns:
        The (possibly filtered) events, in file order.

    Raises:
        SystemExit: if ``wanted_ids`` is non-empty and none of the ids exist
            in the test split.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can mis-decode non-ASCII event fields.
    raw_events = json.loads(
        Path(config["paths"]["test_events"]).read_text(encoding="utf-8")
    )
    test_events = [FloodEvent(**e) for e in raw_events]
    if wanted_ids:
        test_events = [e for e in test_events if e.event_id in wanted_ids]
        if not test_events:
            raise SystemExit(
                f"No matching events in test split for --event-id {sorted(wanted_ids)}"
            )
        # Surface typos instead of silently evaluating a smaller subset.
        missing = set(wanted_ids) - {e.event_id for e in test_events}
        if missing:
            logger.warning(
                "--event-id values not found in test split (ignored): %s",
                sorted(missing),
            )
    return test_events
def run_gold(
    test_events: list[FloodEvent],
    config: dict,
    predictor: CascadePredictor,
    force: bool,
    threshold: float | None = None,
    dump_diagnostics: bool = False,
    dump_bfs_full: bool = False,
    seeds: list[int | None] | None = None,
    compare_to: str | None = None,
) -> None:
    """Gold-chain evaluation with optional multi-seed averaging (v0.5 issue A).

    For every seed, each event is scored by ``GoldEvaluator``, whose
    per-event JSON lands under ``<output_dir>/gold/``. Seed-level values are
    then collapsed into:

    * a per-event mean ± std table;
    * an AGGREGATE block of macro averages;
    * a bootstrap CI block;
    * an ``aggregate_*.json`` artefact;
    * optionally, a CI comparison against ``compare_to``.

    Args:
        test_events: Events to evaluate.
        config: Project config. ``evaluation.output_dir`` is required.
        predictor: Shared cascade predictor handed to ``GoldEvaluator``.
        force: Forwarded as ``force_reeval`` (bypass per-event caches).
        threshold: Cosine match threshold override. ``None`` falls back to
            ``evaluation.cosine_threshold``; if that is also unset, no
            threshold is passed and ``GoldEvaluator`` uses its own default.
        dump_diagnostics: Forwarded to ``GoldEvaluator`` (near-miss dumps).
        dump_bfs_full: Write each fresh prediction's trace to
            ``<output_dir>/diagnostics/``.
        seeds: ``[None]`` (or ``None`` / empty) preserves legacy behaviour:
            one run, no ``_seed{seed}`` cache suffix. ``[42, 1337, 2026]``
            runs predict 3× per event and reports mean ± std.
        compare_to: Optional path to a baseline ``aggregate_*.json``.
    """
    import math

    # Bootstrap CI settings (telemetry only). They match the values recorded
    # in the aggregate artefact header (B=1000, 95%, seed=42).
    n_resamples = 1000
    confidence = 0.95
    bootstrap_seed = 42

    if not seeds:
        seeds = [None]
    embedder = Embedder(config)
    eval_cfg = config.get("evaluation", {})
    if threshold is None:
        threshold = eval_cfg.get("cosine_threshold")

    diagnostics_dir: Path | None = None
    if dump_bfs_full:
        diagnostics_dir = Path(eval_cfg["output_dir"]) / "diagnostics"
        diagnostics_dir.mkdir(parents=True, exist_ok=True)
        logger.info(
            "BFS full-trace dumps enabled — writing to %s/{event_id}_seed{seed}_bfs_full.json",
            diagnostics_dir,
        )

    print("\n" + "=" * 96)
    print(f"GOLD-BASED EVALUATION — per event (seeds={seeds})")
    print("=" * 96)
    print(f"{'event_id':<18} {'pred':>5} {'gold':>5} {'matched':>8} "
          f"{'f1 (mean±std)':>16} {'cat_f1 (mean±std)':>18} {'sev_match':>10}")

    # Per-event accumulator: {event_id: {metric_key: [one value per seed]}}
    per_event: dict[str, dict[str, list]] = {}
    for seed in seeds:
        gold_kwargs: dict = {
            "predictor": predictor,
            "embedder": embedder,
            "config": config,
            "dump_diagnostics": dump_diagnostics,
            "seed": seed,
        }
        # Pass a threshold only when one was resolved, so GoldEvaluator can
        # otherwise apply its own default.
        if threshold is not None:
            gold_kwargs["threshold"] = threshold
        gold = GoldEvaluator(**gold_kwargs)
        logger.info(
            "Seed %s: gold-eval fingerprint=%s threshold=%.4f",
            seed, gold.fingerprint, gold.threshold,
        )
        for i, event in enumerate(test_events, start=1):
            logger.info("[seed=%s][%d/%d] %s (%s)",
                        seed, i, len(test_events), event.event_id, event.country)
            result = gold.evaluate_event(event, write_cache=True, force_reeval=force)
            # NOTE(review): relies on GoldEvaluator leaving last_prediction as
            # None when no fresh predict() ran (cache hit). Otherwise a stale
            # trace from an earlier event would be written here — confirm.
            if diagnostics_dir is not None and gold.last_prediction is not None:
                stem = (
                    f"{event.event_id}_bfs_full.json"
                    if seed is None
                    else f"{event.event_id}_seed{seed}_bfs_full.json"
                )
                (diagnostics_dir / stem).write_text(
                    json.dumps(
                        {"event_id": event.event_id, "seed": seed,
                         "trace": gold.last_prediction.trace},
                        indent=2, ensure_ascii=False, default=str,
                    ),
                    encoding="utf-8",
                )
            slot = per_event.setdefault(event.event_id, {
                "status": [], "pred_n": [], "gold_n": [], "matched_n": [],
                "f1": [], "cat_f1": [], "sev_match": [], "cat_sev": [],
                "p": [], "r": [], "cat_p": [], "cat_r": [], "dom_jac": [],
            })
            slot["status"].append(gold.last_status)
            # None = no usable gold chain for this seed (no_gold / outlier).
            if result is None:
                continue
            slot["pred_n"].append(result.predicted_node_count)
            slot["gold_n"].append(result.gold_node_count)
            slot["matched_n"].append(len(result.matches))
            slot["p"].append(result.precision)
            slot["r"].append(result.recall)
            slot["f1"].append(result.f1)
            slot["dom_jac"].append(result.domain_jaccard)
            if result.severity_match_rate is not None:
                slot["sev_match"].append(result.severity_match_rate)
            if result.category_f1 is not None:
                slot["cat_p"].append(result.category_precision or 0.0)
                slot["cat_r"].append(result.category_recall or 0.0)
                slot["cat_f1"].append(result.category_f1)
            if result.category_severity_match_rate is not None:
                slot["cat_sev"].append(result.category_severity_match_rate)

    # v0.5 issue A: count each event's status ONCE (last seed's status), not
    # n_seeds × n_events times. Counter is built after both loops finish so
    # the AGGREGATE summary lines up with the per-event row count.
    status_counts: Counter = Counter(
        slot["status"][-1] if slot["status"] else "unknown"
        for slot in per_event.values()
    )

    def _mean(xs: list[float]) -> float | None:
        """Arithmetic mean, or None for an empty list."""
        return sum(xs) / len(xs) if xs else None

    def _std(xs: list[float]) -> float | None:
        """Population std (ddof=0); 0.0 for one value, None for none."""
        if len(xs) < 2:
            return 0.0 if xs else None
        m = _mean(xs)
        return math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

    def _fmt(m: float | None, s: float | None, w: int = 16) -> str:
        """Right-align ``mean±std`` in width *w*; '-' when there is no mean."""
        if m is None:
            return f"{'-':>{w}}"
        s_str = f"{s:.3f}" if s is not None else "-"
        return f"{m:.3f}±{s_str}".rjust(w)

    # --- Per-event rows ---
    for eid, slot in per_event.items():
        if not slot["f1"]:  # all seeds reported None (no_gold / outlier)
            label = "outlier" if "outlier" in slot["status"] else "no_gold"
            print(f"{eid:<18} {'-':>5} {'-':>5} {'-':>8} "
                  f"{label:>16} {'-':>18} {'-':>10}")
            continue
        f1_m, f1_s = _mean(slot["f1"]), _std(slot["f1"])
        cat_f1_m, cat_f1_s = _mean(slot["cat_f1"]), _std(slot["cat_f1"])
        sev = _mean(slot["cat_sev"]) if slot["cat_sev"] else _mean(slot["sev_match"])
        sev_str = f"{sev:.3f}" if sev is not None else "-"
        # Node counts are seed means. Round rather than truncate, so e.g. a
        # 4.67 mean shows as 5, not 4.
        print(
            f"{eid:<18} "
            f"{round(_mean(slot['pred_n'])):>5} "
            f"{round(_mean(slot['gold_n'])):>5} "
            f"{round(_mean(slot['matched_n'])):>8} "
            f"{_fmt(f1_m, f1_s)} "
            f"{_fmt(cat_f1_m, cat_f1_s, w=18)} "
            f"{sev_str:>10}"
        )

    # --- AGGREGATE block ---
    print("\n" + "=" * 96)
    print(f"AGGREGATE (n_seeds={len(seeds)})")
    print("=" * 96)
    evaluable = [s for s in per_event.values() if s["f1"]]
    n = len(evaluable)
    if not n:
        print("No events had a usable gold chain — nothing to aggregate.")
        print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}/gold/")
        return
    macro_f1_means = [_mean(s["f1"]) for s in evaluable]
    macro_cat_f1_means = [_mean(s["cat_f1"]) for s in evaluable if s["cat_f1"]]
    per_event_f1_stds = [_std(s["f1"]) for s in evaluable]
    print(f"Events evaluated: {n} "
          f"(no_gold/outlier excluded; status counts={dict(status_counts)})")
    print(f"Cosine Macro F1 {_mean(macro_f1_means):.3f} "
          f"(per-event std mean = {_mean(per_event_f1_stds):.3f}, "
          f"max = {max(per_event_f1_stds):.3f})")
    if macro_cat_f1_means:
        per_event_cat_f1_stds = [_std(s["cat_f1"]) for s in evaluable if s["cat_f1"]]
        print(f"Category Macro F1 {_mean(macro_cat_f1_means):.3f} "
              f"(per-event std mean = {_mean(per_event_cat_f1_stds):.3f}, "
              f"max = {max(per_event_cat_f1_stds):.3f})")
    # Severity (category-level if available)
    cat_sev_means = [_mean(s["cat_sev"]) for s in evaluable if s["cat_sev"]]
    if cat_sev_means:
        print(f"Category Severity {_mean(cat_sev_means):.3f}")
    else:
        sev_means = [_mean(s["sev_match"]) for s in evaluable if s["sev_match"]]
        if sev_means:
            print(f"Severity match {_mean(sev_means):.3f}")
    print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}/gold/")

    # --- v0.7 issue A: BOOTSTRAP 95% CI block (telemetry only) ---
    from src.eval.metrics import bootstrap_macro_ci

    # Per-event seed-means for each of the 4 CI metrics. This mirrors the
    # 4-tuple in aggregate_with_ci's default, inlined so the driver can reuse
    # `per_event` directly.
    ci_metric_attrs = [
        ("Cat Macro Recall", "cat_r"),
        ("Cat Macro Sev", "cat_sev"),
        ("Cos Macro F1", "f1"),
        ("Cat Macro F1", "cat_f1"),
    ]
    ci_per_metric: dict[str, dict] = {}
    per_event_seed_values_per_metric: dict[str, dict[str, list[float]]] = {}
    for _, key in ci_metric_attrs:
        # Events without values for this metric (no_gold / outlier, or no
        # category metrics) contribute nothing.
        seed_values_per_event = {
            eid: slot[key]
            for eid, slot in per_event.items()
            if slot.get(key)
        }
        per_event_seed_values_per_metric[key] = seed_values_per_event
        means = [sum(vs) / len(vs) for vs in seed_values_per_event.values()]
        if not means:
            continue
        ci_per_metric[key] = bootstrap_macro_ci(
            means,
            n_resamples=n_resamples,
            confidence=confidence,
            seed=bootstrap_seed,
        )
    if not ci_per_metric:
        return

    # Metrics can cover different event counts. The header reports the
    # broadest, and any metric with fewer events is annotated with its own n.
    n_ci = max(ci["n"] for ci in ci_per_metric.values())
    print(
        f"\nBOOTSTRAP {confidence:.0%} CI "
        f"(B={n_resamples}, n={n_ci} events, event-level resample, "
        f"seed={bootstrap_seed}):"
    )
    for label, key in ci_metric_attrs:
        ci = ci_per_metric.get(key)
        if ci is None:
            continue
        n_note = "" if ci["n"] == n_ci else f" (n={ci['n']})"
        print(
            f" {label:<18} "
            f"{ci['mean']:.3f} [{ci['ci_low']:.3f}, {ci['ci_high']:.3f}]{n_note}"
        )

    # --- Write aggregate cache artefact ---
    # The fingerprint must describe what actually ran, so CLI --threshold /
    # --seeds overrides are folded into a copy of the evaluation section.
    # Without overrides these equal the config values, so existing aggregate
    # filenames stay stable.
    # NOTE(review): the fingerprint does not cover the evaluated event set,
    # so a subset run (--event-id) reuses the full-split filename.
    fingerprint_eval = dict(eval_cfg)
    if threshold is not None:
        fingerprint_eval["cosine_threshold"] = threshold
    if seeds != [None]:
        fingerprint_eval["seed_list"] = list(seeds)
    fingerprint_config = {**config, "evaluation": fingerprint_eval}
    aggregate_path = _write_aggregate_artefact(
        config=fingerprint_config,
        seeds=seeds,
        ci_per_metric=ci_per_metric,
        per_event_seed_values_per_metric=per_event_seed_values_per_metric,
        ci_metric_attrs=ci_metric_attrs,
    )
    print(f"\nAggregate CI written to: {aggregate_path}")

    # --- Optional --compare-to OVERLAP/SEPARATED block ---
    if compare_to:
        _print_ci_comparison(
            current_ci=ci_per_metric,
            ci_metric_attrs=ci_metric_attrs,
            baseline_path=Path(compare_to),
        )
def run_judge(
    test_events: list[FloodEvent],
    config: dict,
    predictor: CascadePredictor,
    llm_client,
    force: bool,
) -> None:
    """LLM-as-judge evaluation: grade each predicted chain against the news.

    Prints one row per event with counts of
    verified / plausible / unsupported / contradicted nodes and missed
    cascades, then totals and cache-status counts.

    Args:
        test_events: Events to evaluate.
        config: Project config. Uses ``paths.articles_dir`` and
            ``evaluation.output_dir``.
        predictor: Shared cascade predictor handed to ``Evaluator``.
        llm_client: LLM client used by the judge.
        force: Forwarded as ``force_rejudge`` (bypass per-event caches).
    """
    articles_dir = Path(config["paths"]["articles_dir"])
    judge = Evaluator(
        llm_client=llm_client,
        predictor=predictor,
        articles_dir=articles_dir,
        config=config,
    )
    logger.info("Judge fingerprint: %s", judge.judge_fingerprint_value)

    print("\n" + "=" * 72)
    print("JUDGE-BASED EVALUATION — per event")
    print("=" * 72)

    totals: Counter = Counter()
    missed_total = 0
    status_counts: Counter = Counter()
    for i, event in enumerate(test_events, start=1):
        logger.info("[%d/%d] %s (%s)", i, len(test_events),
                    event.event_id, event.country)
        result = judge.evaluate_event(
            event, write_cache=True, force_rejudge=force,
        )
        status = judge.last_cache_status or "unknown"
        status_counts[status] += 1
        # Defensive: the gold path already handles None results. Guard here
        # too so a skipped event (e.g. no news) cannot crash the whole run.
        if result is None:
            print(f"{event.event_id:<18} no judge result (status={status})")
            continue
        grades = Counter(n.evidence_level for n in result.node_evaluations)
        totals.update(grades)
        missed_total += len(result.missed_cascades)
        print(
            f"{event.event_id:<18} "
            f"verified:{grades.get('verified', 0):<2} "
            f"plausible:{grades.get('plausible', 0):<2} "
            f"unsupported:{grades.get('unsupported', 0):<2} "
            f"contradicted:{grades.get('contradicted', 0):<2} "
            f"missed:{len(result.missed_cascades)}"
        )

    print("\n" + "=" * 72)
    print("TOTALS")
    print("=" * 72)
    total_nodes = sum(totals.values())
    print(f"Nodes graded: {total_nodes}")
    for level in ("verified", "plausible", "unsupported", "contradicted"):
        print(f" {level:<13} {totals.get(level, 0)}")
    print(f"Missed cascades: {missed_total}")
    print(
        f"Events — rejudged: {status_counts.get('rejudge', 0)} / "
        f"cache-hit: {status_counts.get('hit', 0)} / "
        f"skipped-no-news: {status_counts.get('skip_no_news', 0)}"
    )
    print(f"\nPer-event JSON cached to: {config['evaluation']['output_dir']}")
def main() -> None:
    """Script entry point: parse flags, build shared clients, run the chosen mode(s)."""
    args = parse_args()
    config = load_config()

    # Seed precedence: --seeds, then config.evaluation.seed_list, then [None]
    # (a single legacy run without a seed suffix).
    if args.seeds is None:
        seeds: list[int | None] = config.get("evaluation", {}).get("seed_list", [None])
    else:
        seeds = args.seeds

    wanted = set(args.event_ids) if args.event_ids else None
    test_events = _load_test_events(config, wanted)
    logger.info("Loaded %d test events; mode=%s", len(test_events), args.mode)

    # One client + predictor serves both modes, so "both" reuses them.
    llm_client = create_llm_client(config)
    predictor = CascadePredictor(
        llm_client, config, dump_full_trace=args.dump_bfs_full,
    )

    wants_gold = args.mode in ("gold", "both")
    wants_judge = args.mode in ("judge", "both")
    if wants_gold:
        run_gold(
            test_events,
            config,
            predictor,
            args.force,
            threshold=args.threshold,
            dump_diagnostics=args.dump_match_debug,
            dump_bfs_full=args.dump_bfs_full,
            seeds=seeds,
            compare_to=args.compare_to,
        )
    if wants_judge:
        run_judge(test_events, config, predictor, llm_client, args.force)
    print("\nLaunch the Streamlit Evaluation tab to explore case studies.")
| if __name__ == "__main__": | |
| main() | |