#!/usr/bin/env python3
"""Aggregate per-seed ``summary.json`` files under a run root into mean/std tables.

Each ``collect_*`` function globs one family of run directories below
``--run-root``, extracts a few metrics per run, de-duplicates runs sharing the
same key (keeping the most recently modified ``summary.json``), and reduces the
rows to per-group count/mean/std statistics. The combined result is written as
JSON and rendered as a Markdown bullet list.
"""
from __future__ import annotations

import argparse
import json
import math
import re
import statistics
from pathlib import Path
from typing import Any

# Maps the model slug embedded in a run directory name to its display label.
SLUG_LABELS = {
    "reference": "Reference",
    "prithvi_wxc": "Prithvi-WxC",
    "stormcast": "StormCast",
    "aurora": "Aurora",
    "climax": "ClimaX",
    "alphaearth": "AlphaEarth",
}


def load(path: Path) -> dict[str, Any]:
    """Read *path* as UTF-8 text and parse it as JSON."""
    return json.loads(path.read_text(encoding="utf-8"))


def stats(values: list[float]) -> dict[str, float | int]:
    """Summarize *values* as count, mean and sample standard deviation.

    Non-finite entries (NaN, +/-inf) are dropped first. With no finite values
    the result is ``{"n": 0, "mean": nan, "std": nan}``; with exactly one
    value the std is reported as 0.0.
    """
    values = [float(v) for v in values if math.isfinite(float(v))]
    if not values:
        return {"n": 0, "mean": math.nan, "std": math.nan}
    return {
        "n": len(values),
        "mean": float(statistics.fmean(values)),
        # statistics.stdev requires at least two data points.
        "std": float(statistics.stdev(values)) if len(values) > 1 else 0.0,
    }


def seed_from_path(path: Path) -> int | None:
    """Return the integer following the first ``_seed_`` in *path*, or None.

    The entire path string is searched, so a ``_seed_<n>`` appearing in a
    parent directory (e.g. the run root) is matched before the run's own
    seed directory.
    """
    match = re.search(r"_seed_(\d+)", str(path))
    return int(match.group(1)) if match else None


def label_from_seed_dir(path: Path, prefix: str) -> str:
    """Derive a display label from the ``<prefix><slug>_seed_<n>`` component of *path*.

    The first path component that starts with *prefix* and contains
    ``_seed_`` is used; the slug between them is mapped through
    ``SLUG_LABELS`` (unknown slugs are returned verbatim). Returns
    ``"unknown"`` when no component matches.
    """
    for part in path.parts:
        if part.startswith(prefix) and "_seed_" in part:
            slug = part[len(prefix) :].split("_seed_", 1)[0]
            return SLUG_LABELS.get(slug, slug)
    return "unknown"


def dedupe_rows(rows: list[dict[str, Any]], keys: tuple[str, ...]) -> list[dict[str, Any]]:
    """Keep one row per distinct tuple of ``row.get(k) for k in keys``.

    When several rows share a key, the one whose ``row["path"]`` file has the
    newest modification time wins; on equal mtimes the later row wins. The
    result is ordered by first occurrence of each key.
    """
    selected: dict[tuple[Any, ...], dict[str, Any]] = {}
    for row in rows:
        key = tuple(row.get(name) for name in keys)
        old = selected.get(key)
        if old is None:
            selected[key] = row
            continue
        old_mtime = Path(str(old["path"])).stat().st_mtime
        new_mtime = Path(str(row["path"])).stat().st_mtime
        if new_mtime >= old_mtime:
            selected[key] = row
    return list(selected.values())


def best_val_threshold(data: dict[str, Any]) -> str:
    """Return the key of the validation threshold entry with the highest F1.

    Ties on F1 are broken in favour of the lower ``threshold`` value. Raises
    ValueError (from ``max``) if there are no threshold entries.
    """
    entries = data["splits"]["val"]["threshold_metrics"]
    return max(entries, key=lambda key: (float(entries[key]["f1"]), -float(entries[key]["threshold"])))


def collect_occupancy(run_root: Path) -> dict[str, Any]:
    """Collect Table 3 occupancy runs under *run_root*.

    For each run the threshold with the best validation F1 is selected, and
    the test-split F1 at that threshold is reported as ``strict_f1`` (plain
    threshold metrics), ``tolerant_f1`` (the ``t0_s3`` tolerant metrics) and
    ``union_f1`` (the ``t3_s3`` tolerant metrics).
    """
    rows: list[dict[str, Any]] = []
    for path in sorted(run_root.glob("table3_occupancy_*_seed_*/run_*/summary.json")):
        data = load(path)
        threshold_key = best_val_threshold(data)
        test = data["splits"]["test"]
        rows.append(
            {
                "label": data.get("fm_family") or label_from_seed_dir(path, "table3_occupancy_"),
                "seed": seed_from_path(path),
                "strict_f1": float(test["threshold_metrics"][threshold_key]["f1"]),
                "tolerant_f1": float(test["tolerant_threshold_metrics"]["t0_s3"][threshold_key]["f1"]),
                "union_f1": float(test["tolerant_threshold_metrics"]["t3_s3"][threshold_key]["f1"]),
                "path": str(path),
            }
        )
    return group(rows, ["strict_f1", "tolerant_f1", "union_f1"])


def collect_headcontrol(run_root: Path) -> dict[str, Any]:
    """Collect Table 2 Prithvi-WxC head-control runs under *run_root*.

    Every entry of ``selection_summary.rows`` in each summary becomes one row.
    Rows are de-duplicated per (label, scope, seed) and summarized per scope
    with the number of rows, the count flagged ``selection_failure``, and
    stats of the ranking-/decision-selected union F1 and the decision regret.
    """
    rows: list[dict[str, Any]] = []
    for path in sorted(run_root.glob("table2_prithvi_wxc_headcontrol_seed_*/run_*/summary.json")):
        data = load(path)
        seed = seed_from_path(path)
        for row in data.get("selection_summary", {}).get("rows", []):
            rows.append(
                {
                    "label": "Prithvi-WxC",
                    "scope": row["scope"],
                    "seed": seed,
                    "ranking_selected_union_f1": float(row["ranking_selected_union_f1"]),
                    "decision_selected_union_f1": float(row["decision_selected_union_f1"]),
                    "decision_regret_union_f1": float(row["decision_regret_union_f1"]),
                    "selection_failure": bool(row.get("selection_failure", False)),
                    "path": str(path),
                }
            )
    grouped: dict[str, Any] = {}
    rows = dedupe_rows(rows, ("label", "scope", "seed"))
    for scope in sorted({str(row["scope"]) for row in rows}):
        selected = [row for row in rows if row["scope"] == scope]
        grouped[scope] = {
            "n": len(selected),
            "failure_count": int(sum(1 for row in selected if row["selection_failure"])),
            "ranking_selected_union_f1": stats([row["ranking_selected_union_f1"] for row in selected]),
            "decision_selected_union_f1": stats([row["decision_selected_union_f1"] for row in selected]),
            "decision_regret_union_f1": stats([row["decision_regret_union_f1"] for row in selected]),
        }
    return {"rows": rows, "summary": grouped}


def collect_spread(run_root: Path) -> dict[str, Any]:
    """Collect Table 3 spread runs (model runs plus reference runs) under *run_root*.

    Reports ``strict_f1``, the ``s4`` same-sample spatial-tolerance F1
    (``spatial_f1``) and ``strict_AP`` (``ap``) from each summary's
    ``headline_metrics``.
    """
    rows: list[dict[str, Any]] = []
    # (glob pattern, directory prefix, fixed label). Reference runs get their
    # label from the source they were globbed from rather than from a substring
    # test on the full path, which would also match a run root whose own name
    # contains "reference_spread".
    sources = [
        ("table3_spread_*_seed_*/run_*/summary.json", "table3_spread_", None),
        ("table3_reference_spread_seed_*/run_*/summary.json", "table3_reference_spread_", "Reference"),
    ]
    for pattern, prefix, fixed_label in sources:
        for path in sorted(run_root.glob(pattern)):
            data = load(path)
            headline = data["headline_metrics"]
            label = data.get("fm_family") or fixed_label or label_from_seed_dir(path, prefix)
            rows.append(
                {
                    "label": label,
                    "seed": seed_from_path(path),
                    "strict_f1": float(headline["strict_f1"]),
                    "spatial_f1": float(headline["same_sample_spatial_tolerance_f1"]["s4"]),
                    "ap": float(headline["strict_AP"]),
                    "path": str(path),
                }
            )
    return group(rows, ["strict_f1", "spatial_f1", "ap"])


def collect_task(run_root: Path, glob_pattern: str, prefix: str, metrics_path: list[str], metric_keys: list[str]) -> dict[str, Any]:
    """Collect a generic task whose metrics sit in one nested dict of each summary.

    Args:
        run_root: Directory searched with *glob_pattern*.
        glob_pattern: Pattern matching the ``summary.json`` files.
        prefix: Run-directory prefix used to derive a label when the summary
            has no ``fm_family``.
        metrics_path: Keys followed from the summary root to the metrics dict.
        metric_keys: Metrics read (as floats) from that dict and summarized.
    """
    rows: list[dict[str, Any]] = []
    for path in sorted(run_root.glob(glob_pattern)):
        data = load(path)
        label = data.get("fm_family") or label_from_seed_dir(path, prefix)
        node: Any = data
        for key in metrics_path:
            node = node[key]
        row = {"label": label, "seed": seed_from_path(path), "path": str(path)}
        for key in metric_keys:
            row[key] = float(node[key])
        rows.append(row)
    return group(rows, metric_keys)


def group(rows: list[dict[str, Any]], metric_keys: list[str]) -> dict[str, Any]:
    """De-duplicate *rows* per (label, seed) and summarize each metric per label.

    Returns ``{"rows": <deduplicated rows>, "summary": {label: {"n": count,
    <metric>: stats(...)}}}`` with labels in sorted order.
    """
    if rows and "seed" in rows[0]:
        rows = dedupe_rows(rows, ("label", "seed"))
    summary: dict[str, Any] = {}
    for label in sorted({str(row["label"]) for row in rows}):
        selected = [row for row in rows if row["label"] == label]
        summary[label] = {"n": len(selected)}
        for key in metric_keys:
            summary[label][key] = stats([row[key] for row in selected])
    return {"rows": rows, "summary": summary}


def fmt(value: dict[str, Any], scale: float = 1.0, digits: int = 2) -> str:
    """Render a ``stats`` dict as ``"mean +/- std (n=N)"``, or ``"missing"`` when n is 0.

    Mean and std are multiplied by *scale* and printed with *digits* decimals.
    """
    if int(value["n"]) == 0:
        return "missing"
    return f"{float(value['mean']) * scale:.{digits}f} +/- {float(value['std']) * scale:.{digits}f} (n={int(value['n'])})"


def write_markdown(out: Path, summary: dict[str, Any]) -> None:
    """Write the collected *summary* to *out* as a Markdown bullet list.

    The parent directory of *out* is created if needed. F1 metrics (keys
    ending in ``_f1``) and ``ap`` are printed multiplied by 100; all other
    metrics are printed unscaled.
    """
    lines = ["# Forced Mean/Std Gap-Fill Summary", ""]
    for section in [
        "table2_headcontrol",
        "table3_occupancy",
        "table3_spread",
        "table4_final_area",
        "table4_analog",
        "table4_smoke",
        "table4_heat",
    ]:
        lines += [f"## {section}", ""]
        sec = summary.get(section, {}).get("summary", {})
        if section == "table2_headcontrol":
            for scope, row in sec.items():
                lines.append(
                    f"- {scope}: regret {fmt(row['decision_regret_union_f1'], 100.0)}; "
                    f"ranking union {fmt(row['ranking_selected_union_f1'], 100.0)}; "
                    f"decision union {fmt(row['decision_selected_union_f1'], 100.0)}; "
                    f"failures {row['failure_count']}/{row['n']}"
                )
        else:
            for label, row in sec.items():
                # Only the per-metric stats dicts are rendered; the plain "n" count is skipped.
                pieces = [f"{key} {fmt(val, 100.0 if key.endswith('_f1') or key == 'ap' else 1.0)}" for key, val in row.items() if isinstance(val, dict)]
                lines.append(f"- {label}: " + "; ".join(pieces))
        lines.append("")
    # The Markdown output may live in a different directory than the JSON output.
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text("\n".join(lines), encoding="utf-8")


def main() -> None:
    """Parse CLI arguments, collect every table, and write the JSON and Markdown summaries."""
    parser = argparse.ArgumentParser()
    # NOTE(review): these defaults look like shell-template placeholders (e.g. an
    # unquoted heredoc substituting ${RUN_ROOT}); if this file runs unsubstituted
    # they are literal path strings — confirm how the script is generated.
    parser.add_argument("--run-root", type=Path, default=Path("${RUN_ROOT}"))
    parser.add_argument("--out-json", type=Path, default=Path("${OUT_JSON}"))
    parser.add_argument("--out-md", type=Path, default=Path("${OUT_MD}"))
    args = parser.parse_args()
    summary = {
        "run_root": str(args.run_root),
        "table2_headcontrol": collect_headcontrol(args.run_root),
        "table3_occupancy": collect_occupancy(args.run_root),
        "table3_spread": collect_spread(args.run_root),
        "table4_final_area": collect_task(args.run_root, "table4_final_area_*_seed_*/run_*/summary.json", "table4_final_area_", ["headline_metrics"], ["log_rmse", "log_mae", "log_spearman"]),
        "table4_analog": collect_task(args.run_root, "table4_analog_*_seed_*/run_*/summary.json", "table4_analog_", ["test_metrics"], ["ndcg_at_10", "log_rmse", "log_mae"]),
        "table4_smoke": collect_task(args.run_root, "table4_smoke_*_seed_*/run_*/summary.json", "table4_smoke_", ["test_metrics"], ["rmse", "mae", "pearson_r"]),
        "table4_heat": collect_task(args.run_root, "table4_heat_*_seed_*/run_*/summary.json", "table4_heat_", ["test_metrics"], ["rmse_c", "mae_c", "pearson_r"]),
    }
    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    # json.dumps keeps allow_nan=True, so NaN stats (n == 0) are emitted as bare NaN.
    args.out_json.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_markdown(args.out_md, summary)
    print(f"wrote={args.out_json}")
    print(f"wrote={args.out_md}")


if __name__ == "__main__":
    main()