| |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import re |
| import statistics |
| from pathlib import Path |
| from typing import Any |
|
|
|
|
# Display labels keyed by the model slug that label_from_seed_dir extracts
# from a run directory name; slugs missing here are shown verbatim.
SLUG_LABELS = dict(
    reference="Reference",
    prithvi_wxc="Prithvi-WxC",
    stormcast="StormCast",
    aurora="Aurora",
    climax="ClimaX",
    alphaearth="AlphaEarth",
)
|
|
|
|
def load(path: Path) -> dict[str, Any]:
    """Parse the UTF-8 encoded JSON document stored at *path* and return it."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)
|
|
|
|
def stats(values: list[float]) -> dict[str, float | int]:
    """Summarize the finite entries of *values* as count, mean and sample std.

    Every entry is converted with float(); NaN and +/-inf are then discarded.
    With no finite entries, mean and std are NaN. With exactly one entry, std
    is 0.0, because the sample standard deviation needs at least two points.
    """
    finite = [number for number in map(float, values) if math.isfinite(number)]
    count = len(finite)
    if count == 0:
        return {"n": 0, "mean": math.nan, "std": math.nan}
    spread = statistics.stdev(finite) if count > 1 else 0.0
    return {"n": count, "mean": float(statistics.fmean(finite)), "std": float(spread)}
|
|
|
|
def seed_from_path(path: Path) -> int | None:
    """Extract the seed number encoded in *path*.

    Returns the integer formed by the digits after the first "_seed_" found
    anywhere in the full path string, or None when the path has no such token.

    NOTE(review): the search includes parent directories, so a run root whose
    own name contains "_seed_<digits>" would shadow the per-run seed; confirm
    run roots never look like that.
    """
    hit = re.search(r"_seed_(\d+)", str(path))
    return None if hit is None else int(hit.group(1))
|
|
|
|
def label_from_seed_dir(path: Path, prefix: str) -> str:
    """Derive a display label from a path component shaped like <prefix><slug>_seed_...

    The first component of *path* that starts with *prefix* and contains
    "_seed_" is used. The text between the prefix and the first "_seed_" is the
    slug, which is mapped through SLUG_LABELS and returned verbatim when it is
    not listed there. Returns "unknown" when no component qualifies.
    """
    matches = (part for part in path.parts if part.startswith(prefix) and "_seed_" in part)
    seed_dir = next(matches, None)
    if seed_dir is None:
        return "unknown"
    slug, _, _ = seed_dir[len(prefix) :].partition("_seed_")
    return SLUG_LABELS.get(slug, slug)
|
|
|
|
def dedupe_rows(rows: list[dict[str, Any]], keys: tuple[str, ...]) -> list[dict[str, Any]]:
    """Keep a single row for each distinct combination of the *keys* fields.

    Missing fields count as None. When rows collide, the one whose "path" file
    has the newer modification time wins. On equal mtimes the row appearing
    later in *rows* wins. The result lists the winners in the order their key
    was first encountered.
    """
    winners: dict[tuple[Any, ...], dict[str, Any]] = {}
    for candidate in rows:
        identity = tuple(candidate.get(field) for field in keys)
        incumbent = winners.get(identity)
        if incumbent is not None:
            incumbent_time = Path(str(incumbent["path"])).stat().st_mtime
            candidate_time = Path(str(candidate["path"])).stat().st_mtime
            if candidate_time < incumbent_time:
                continue
        winners[identity] = candidate
    return list(winners.values())
|
|
|
|
def best_val_threshold(data: dict[str, Any]) -> str:
    """Return the key of the validation threshold entry with the highest F1.

    Reads data["splits"]["val"]["threshold_metrics"], a mapping from a key to a
    record with numeric "f1" and "threshold" fields. Ties on F1 are broken in
    favour of the smallest threshold. NaN values rank last, so an entry with a
    NaN F1 is never chosen over one with a real score.

    Raises:
        KeyError: if the expected nested keys are missing.
        ValueError: if there are no threshold entries.
    """
    entries = data["splits"]["val"]["threshold_metrics"]

    def rank(key: str) -> tuple[float, float]:
        f1 = float(entries[key]["f1"])
        threshold = float(entries[key]["threshold"])
        # NaN compares False against everything. In a max() key that makes the
        # result order-dependent: a NaN seen first can never be displaced.
        if math.isnan(f1):
            f1 = -math.inf
        if math.isnan(threshold):
            threshold = math.inf
        return (f1, -threshold)

    return max(entries, key=rank)
|
|
|
|
def collect_occupancy(run_root: Path) -> dict[str, Any]:
    """Gather Table 3 occupancy test-split F1 scores for every seed run under *run_root*.

    Each table3_occupancy_*_seed_*/run_*/summary.json is loaded. The threshold
    key chosen by best_val_threshold (from the validation split) is then used
    to read three test F1 values: the strict one, the "t0_s3" tolerant view
    (reported as tolerant_f1) and the "t3_s3" view (reported as union_f1).
    The rows are then deduplicated and summarized per label by group().
    """
    collected: list[dict[str, Any]] = []
    for summary_path in sorted(run_root.glob("table3_occupancy_*_seed_*/run_*/summary.json")):
        payload = load(summary_path)
        chosen = best_val_threshold(payload)
        test_split = payload["splits"]["test"]
        collected.append(
            {
                "label": payload.get("fm_family") or label_from_seed_dir(summary_path, "table3_occupancy_"),
                "seed": seed_from_path(summary_path),
                "strict_f1": float(test_split["threshold_metrics"][chosen]["f1"]),
                "tolerant_f1": float(test_split["tolerant_threshold_metrics"]["t0_s3"][chosen]["f1"]),
                "union_f1": float(test_split["tolerant_threshold_metrics"]["t3_s3"][chosen]["f1"]),
                "path": str(summary_path),
            }
        )
    return group(collected, ["strict_f1", "tolerant_f1", "union_f1"])
|
|
|
|
def collect_headcontrol(run_root: Path) -> dict[str, Any]:
    """Aggregate the Prithvi-WxC head-control selection results (Table 2) per scope.

    Every entry of a run's selection_summary["rows"] becomes one record tagged
    with the run's seed. Records are deduplicated on (label, scope, seed) via
    dedupe_rows. They are then summarized per scope: the record count, the
    number of records flagged as selection failures, and stats() for each of
    the three union-F1 quantities.

    Returns:
        {"rows": deduplicated records, "summary": {scope: {...}}}, with scopes
        in sorted string order.
    """
    metric_names = (
        "ranking_selected_union_f1",
        "decision_selected_union_f1",
        "decision_regret_union_f1",
    )
    records: list[dict[str, Any]] = []
    for summary_path in sorted(run_root.glob("table2_prithvi_wxc_headcontrol_seed_*/run_*/summary.json")):
        payload = load(summary_path)
        seed = seed_from_path(summary_path)
        for entry in payload.get("selection_summary", {}).get("rows", []):
            record: dict[str, Any] = {"label": "Prithvi-WxC", "scope": entry["scope"], "seed": seed}
            for name in metric_names:
                record[name] = float(entry[name])
            record["selection_failure"] = bool(entry.get("selection_failure", False))
            record["path"] = str(summary_path)
            records.append(record)

    records = dedupe_rows(records, ("label", "scope", "seed"))
    per_scope: dict[str, Any] = {}
    for scope in sorted({str(record["scope"]) for record in records}):
        members = [record for record in records if record["scope"] == scope]
        scope_summary: dict[str, Any] = {
            "n": len(members),
            "failure_count": sum(1 for record in members if record["selection_failure"]),
        }
        for name in metric_names:
            scope_summary[name] = stats([record[name] for record in members])
        per_scope[scope] = scope_summary
    return {"rows": records, "summary": per_scope}
|
|
|
|
def collect_spread(run_root: Path) -> dict[str, Any]:
    """Gather Table 3 spread headline metrics for model and reference seed runs.

    Two directory layouts under *run_root* are scanned:

    * table3_spread_<slug>_seed_*/run_*/summary.json, labelled from <slug>
      via label_from_seed_dir;
    * table3_reference_spread_seed_*/run_*/summary.json, labelled "Reference".

    A non-empty "fm_family" in a summary overrides either label. Each row
    records strict F1, the "s4" same-sample spatial-tolerance F1 and strict AP
    from "headline_metrics". The rows are then deduplicated and summarized per
    label by group().
    """
    # (glob pattern, seed-dir prefix, fixed label or None to derive it from the dir name)
    sources: list[tuple[str, str, str | None]] = [
        ("table3_spread_*_seed_*/run_*/summary.json", "table3_spread_", None),
        ("table3_reference_spread_seed_*/run_*/summary.json", "table3_reference_spread_", "Reference"),
    ]
    rows: list[dict[str, Any]] = []
    for pattern, prefix, fixed_label in sources:
        for path in sorted(run_root.glob(pattern)):
            data = load(path)
            headline = data["headline_metrics"]
            # The reference label is decided by which glob matched, not by a
            # substring test on the full path. A run_root whose name happens to
            # contain "reference_spread" therefore cannot relabel model runs.
            # Reference dirs have no slug before "_seed_", so label_from_seed_dir
            # cannot name them.
            label = data.get("fm_family") or fixed_label or label_from_seed_dir(path, prefix)
            rows.append(
                {
                    "label": label,
                    "seed": seed_from_path(path),
                    "strict_f1": float(headline["strict_f1"]),
                    "spatial_f1": float(headline["same_sample_spatial_tolerance_f1"]["s4"]),
                    "ap": float(headline["strict_AP"]),
                    "path": str(path),
                }
            )
    return group(rows, ["strict_f1", "spatial_f1", "ap"])
|
|
|
|
def collect_task(run_root: Path, glob_pattern: str, prefix: str, metrics_path: list[str], metric_keys: list[str]) -> dict[str, Any]:
    """Collect scalar metrics from every summary under *run_root* matching *glob_pattern*.

    *metrics_path* is the chain of keys leading from the summary root to the
    metrics mapping. Each name in *metric_keys* is read from that mapping and
    converted to float. The label is the summary's "fm_family" when present;
    otherwise it is derived from the seed directory name using *prefix*. The
    rows are then deduplicated and summarized per label by group().
    """

    def descend(document: Any) -> Any:
        # Follow metrics_path one key at a time down to the metrics mapping.
        for step in metrics_path:
            document = document[step]
        return document

    collected: list[dict[str, Any]] = []
    for summary_path in sorted(run_root.glob(glob_pattern)):
        payload = load(summary_path)
        label = payload.get("fm_family") or label_from_seed_dir(summary_path, prefix)
        metrics = descend(payload)
        record: dict[str, Any] = {"label": label, "seed": seed_from_path(summary_path), "path": str(summary_path)}
        record.update((name, float(metrics[name])) for name in metric_keys)
        collected.append(record)
    return group(collected, metric_keys)
|
|
|
|
def group(rows: list[dict[str, Any]], metric_keys: list[str]) -> dict[str, Any]:
    """Deduplicate *rows* when they carry a seed, then summarize them per label.

    Deduplication (on label and seed, via dedupe_rows) happens only when the
    first row has a "seed" key. Labels are the sorted string forms of each
    row's "label". Each label's entry holds its row count under "n" and one
    stats() record per name in *metric_keys*.

    Returns:
        {"rows": the (possibly deduplicated) rows, "summary": {label: {...}}}.
    """
    if rows and "seed" in rows[0]:
        rows = dedupe_rows(rows, ("label", "seed"))
    labels = sorted({str(row["label"]) for row in rows})
    buckets = {label: [row for row in rows if row["label"] == label] for label in labels}
    summary: dict[str, Any] = {
        label: {"n": len(members), **{key: stats([row[key] for row in members]) for key in metric_keys}}
        for label, members in buckets.items()
    }
    return {"rows": rows, "summary": summary}
|
|
|
|
def fmt(value: dict[str, Any], scale: float = 1.0, digits: int = 2) -> str:
    """Render a stats() record as "mean +/- std (n=count)".

    Mean and std are multiplied by *scale* and printed with *digits* decimals.
    Returns "missing" when the record has no samples.
    """
    count = int(value["n"])
    if not count:
        return "missing"
    mean = float(value["mean"]) * scale
    std = float(value["std"]) * scale
    return f"{mean:.{digits}f} +/- {std:.{digits}f} (n={count})"
|
|
|
|
def write_markdown(out: Path, summary: dict[str, Any]) -> None:
    """Write a Markdown digest of *summary* to *out* (UTF-8, overwriting any existing file).

    Sections are emitted in a fixed order. A section that is absent from
    *summary* still gets its heading, with no bullet lines. The head-control
    section lists, per scope, regret plus ranking/decision union F1 (x100) and
    the failure count out of n. Every other section lists, per label, each
    stats() record it holds; metrics named "*_f1" or "ap" are scaled by 100.
    """
    sections = (
        "table2_headcontrol",
        "table3_occupancy",
        "table3_spread",
        "table4_final_area",
        "table4_analog",
        "table4_smoke",
        "table4_heat",
    )
    parts = ["# Forced Mean/Std Gap-Fill Summary", ""]
    for name in sections:
        parts.extend([f"## {name}", ""])
        body = summary.get(name, {}).get("summary", {})
        if name == "table2_headcontrol":
            parts.extend(
                f"- {scope}: regret {fmt(entry['decision_regret_union_f1'], 100.0)}; "
                f"ranking union {fmt(entry['ranking_selected_union_f1'], 100.0)}; "
                f"decision union {fmt(entry['decision_selected_union_f1'], 100.0)}; "
                f"failures {entry['failure_count']}/{entry['n']}"
                for scope, entry in body.items()
            )
        else:
            for label, entry in body.items():
                pieces = []
                for metric, record in entry.items():
                    # Skip the plain integer "n"; only stats() dicts are rendered.
                    if not isinstance(record, dict):
                        continue
                    as_percent = metric.endswith("_f1") or metric == "ap"
                    pieces.append(f"{metric} {fmt(record, 100.0 if as_percent else 1.0)}")
                parts.append(f"- {label}: " + "; ".join(pieces))
        parts.append("")
    out.write_text("\n".join(parts), encoding="utf-8")
|
|
|
|
def main() -> None:
    """Summarize every gap-fill run under --run-root and write JSON and Markdown reports.

    NOTE(review): the default paths are RUN_ROOT / OUT_JSON / OUT_MD placeholder
    tokens, presumably substituted when this script is generated (confirm). If
    the script is run unmodified they are literal paths, so pass the flags
    explicitly.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-root", type=Path, default=Path("${RUN_ROOT}"))
    parser.add_argument("--out-json", type=Path, default=Path("${OUT_JSON}"))
    parser.add_argument("--out-md", type=Path, default=Path("${OUT_MD}"))
    args = parser.parse_args()

    run_root: Path = args.run_root
    summary = {
        "run_root": str(run_root),
        "table2_headcontrol": collect_headcontrol(run_root),
        "table3_occupancy": collect_occupancy(run_root),
        "table3_spread": collect_spread(run_root),
        "table4_final_area": collect_task(
            run_root,
            "table4_final_area_*_seed_*/run_*/summary.json",
            "table4_final_area_",
            ["headline_metrics"],
            ["log_rmse", "log_mae", "log_spearman"],
        ),
        "table4_analog": collect_task(
            run_root,
            "table4_analog_*_seed_*/run_*/summary.json",
            "table4_analog_",
            ["test_metrics"],
            ["ndcg_at_10", "log_rmse", "log_mae"],
        ),
        "table4_smoke": collect_task(
            run_root,
            "table4_smoke_*_seed_*/run_*/summary.json",
            "table4_smoke_",
            ["test_metrics"],
            ["rmse", "mae", "pearson_r"],
        ),
        "table4_heat": collect_task(
            run_root,
            "table4_heat_*_seed_*/run_*/summary.json",
            "table4_heat_",
            ["test_metrics"],
            ["rmse_c", "mae_c", "pearson_r"],
        ),
    }
    # --out-md need not live next to --out-json. Create both parent directories
    # up front, so the Markdown write cannot fail after the JSON has been written.
    for destination in (args.out_json, args.out_md):
        destination.parent.mkdir(parents=True, exist_ok=True)
    args.out_json.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_markdown(args.out_md, summary)
    print(f"wrote={args.out_json}")
    print(f"wrote={args.out_md}")


if __name__ == "__main__":
    main()
|
|