# Provenance: Wildfire-FM / experiments / raw_reference / task_scripts / summarize_forced_meanstd_20260429.py
# Scraped artifact-listing header (uploader yx21e, "Initial FireWx-FM artifact
# release", commit 80ef3b2) — kept as comments so the file remains runnable.
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import math
import re
import statistics
from pathlib import Path
from typing import Any
# Maps the model slug embedded in a run-directory name (the segment between
# the task prefix and "_seed_<n>") to the human-readable label used in the
# summary tables. Unknown slugs fall back to the raw slug string (see
# label_from_seed_dir).
SLUG_LABELS = {
    "reference": "Reference",
    "prithvi_wxc": "Prithvi-WxC",
    "stormcast": "StormCast",
    "aurora": "Aurora",
    "climax": "ClimaX",
    "alphaearth": "AlphaEarth",
}
def load(path: Path) -> dict[str, Any]:
    """Read the UTF-8 encoded JSON file at *path* and return the parsed object."""
    raw = path.read_text(encoding="utf-8")
    return json.loads(raw)
def stats(values: list[float]) -> dict[str, float | int]:
    """Summarize *values* as count, mean, and sample standard deviation.

    Non-finite entries (NaN / +-inf) are discarded before aggregating.
    An empty result yields n=0 with NaN mean/std; a single survivor
    yields std 0.0 (sample stdev is undefined for n=1).
    """
    finite = [v for v in (float(x) for x in values) if math.isfinite(v)]
    count = len(finite)
    if count == 0:
        return {"n": 0, "mean": math.nan, "std": math.nan}
    spread = float(statistics.stdev(finite)) if count > 1 else 0.0
    return {"n": count, "mean": float(statistics.fmean(finite)), "std": spread}
def seed_from_path(path: Path) -> int | None:
    """Return the integer following the first ``_seed_`` in *path*, else None."""
    found = re.search(r"_seed_(\d+)", str(path))
    if found is None:
        return None
    return int(found.group(1))
def label_from_seed_dir(path: Path, prefix: str) -> str:
    """Derive a display label from the seeded run-directory component of *path*.

    Scans path components for the first one that starts with *prefix* and
    contains ``_seed_``; the text between prefix and ``_seed_`` is the model
    slug, mapped through SLUG_LABELS (falling back to the raw slug).
    Returns "unknown" when no matching component is found.
    """
    for component in path.parts:
        if not component.startswith(prefix) or "_seed_" not in component:
            continue
        slug = component[len(prefix):].split("_seed_", 1)[0]
        return SLUG_LABELS.get(slug, slug)
    return "unknown"
def dedupe_rows(rows: list[dict[str, Any]], keys: tuple[str, ...]) -> list[dict[str, Any]]:
    """Keep one row per key tuple, preferring the most recently modified file.

    Rows are keyed by the values at *keys*; on a collision the row whose
    ``row["path"]`` file has the newer (or equal) mtime wins, so later rows
    beat earlier rows on mtime ties.  First-seen key order is preserved in
    the returned list.

    Fix over the original: mtimes are cached per path, so each file is
    stat()'ed at most once instead of once per collision with the retained
    row (the old code re-stat'ed the kept row's file on every comparison).
    """
    mtimes: dict[str, float] = {}

    def _mtime(row: dict[str, Any]) -> float:
        # Cache stat() results — paths repeat across rows and collisions.
        path = str(row["path"])
        if path not in mtimes:
            mtimes[path] = Path(path).stat().st_mtime
        return mtimes[path]

    selected: dict[tuple[Any, ...], dict[str, Any]] = {}
    for row in rows:
        key = tuple(row.get(name) for name in keys)
        old = selected.get(key)
        # Last-writer-wins: replace unless the kept row's file is strictly newer.
        if old is None or _mtime(row) >= _mtime(old):
            selected[key] = row
    return list(selected.values())
def best_val_threshold(data: dict[str, Any]) -> str:
    """Return the validation threshold key with the highest F1.

    Ties on F1 break toward the *lowest* numeric threshold (negated in the
    sort key).
    """
    candidates = data["splits"]["val"]["threshold_metrics"]

    def rank(name: str) -> tuple[float, float]:
        entry = candidates[name]
        return (float(entry["f1"]), -float(entry["threshold"]))

    return max(candidates, key=rank)
def collect_occupancy(run_root: Path) -> dict[str, Any]:
    """Collect per-seed occupancy F1 metrics from table3 summary files.

    The operating threshold is selected on the validation split and applied
    to the test split's strict / tolerant (t0_s3) / union (t3_s3) metrics.
    """
    collected: list[dict[str, Any]] = []
    for summary_path in sorted(run_root.glob("table3_occupancy_*_seed_*/run_*/summary.json")):
        data = load(summary_path)
        key = best_val_threshold(data)
        test_split = data["splits"]["test"]
        tolerant = test_split["tolerant_threshold_metrics"]
        collected.append({
            "label": data.get("fm_family") or label_from_seed_dir(summary_path, "table3_occupancy_"),
            "seed": seed_from_path(summary_path),
            "strict_f1": float(test_split["threshold_metrics"][key]["f1"]),
            "tolerant_f1": float(tolerant["t0_s3"][key]["f1"]),
            "union_f1": float(tolerant["t3_s3"][key]["f1"]),
            "path": str(summary_path),
        })
    return group(collected, ["strict_f1", "tolerant_f1", "union_f1"])
def collect_headcontrol(run_root: Path) -> dict[str, Any]:
    """Collect Prithvi-WxC head-control selection metrics, summarized per scope.

    Rows come from each run's selection_summary; duplicates (same label,
    scope, seed) keep the newest file via dedupe_rows.
    """
    collected: list[dict[str, Any]] = []
    for summary_path in sorted(run_root.glob("table2_prithvi_wxc_headcontrol_seed_*/run_*/summary.json")):
        data = load(summary_path)
        seed = seed_from_path(summary_path)
        for entry in data.get("selection_summary", {}).get("rows", []):
            collected.append({
                "label": "Prithvi-WxC",
                "scope": entry["scope"],
                "seed": seed,
                "ranking_selected_union_f1": float(entry["ranking_selected_union_f1"]),
                "decision_selected_union_f1": float(entry["decision_selected_union_f1"]),
                "decision_regret_union_f1": float(entry["decision_regret_union_f1"]),
                "selection_failure": bool(entry.get("selection_failure", False)),
                "path": str(summary_path),
            })
    collected = dedupe_rows(collected, ("label", "scope", "seed"))
    by_scope: dict[str, Any] = {}
    for scope in sorted({str(entry["scope"]) for entry in collected}):
        scoped = [entry for entry in collected if entry["scope"] == scope]
        by_scope[scope] = {
            "n": len(scoped),
            "failure_count": int(sum(1 for entry in scoped if entry["selection_failure"])),
            "ranking_selected_union_f1": stats([e["ranking_selected_union_f1"] for e in scoped]),
            "decision_selected_union_f1": stats([e["decision_selected_union_f1"] for e in scoped]),
            "decision_regret_union_f1": stats([e["decision_regret_union_f1"] for e in scoped]),
        }
    return {"rows": collected, "summary": by_scope}
def collect_spread(run_root: Path) -> dict[str, Any]:
    """Collect fire-spread headline metrics from FM and reference runs.

    Reference runs live under a different directory pattern; their label is
    forced to "Reference" when the summary carries no fm_family.
    """
    collected: list[dict[str, Any]] = []
    sources = (
        ("table3_spread_*_seed_*/run_*/summary.json", "table3_spread_"),
        ("table3_reference_spread_seed_*/run_*/summary.json", "table3_reference_spread_"),
    )
    for pattern, prefix in sources:
        for summary_path in sorted(run_root.glob(pattern)):
            data = load(summary_path)
            headline = data["headline_metrics"]
            # Label priority: fm_family > reference marker > directory slug.
            if data.get("fm_family"):
                label = data["fm_family"]
            elif "reference_spread" in str(summary_path):
                label = "Reference"
            else:
                label = label_from_seed_dir(summary_path, prefix)
            collected.append({
                "label": label,
                "seed": seed_from_path(summary_path),
                "strict_f1": float(headline["strict_f1"]),
                "spatial_f1": float(headline["same_sample_spatial_tolerance_f1"]["s4"]),
                "ap": float(headline["strict_AP"]),
                "path": str(summary_path),
            })
    return group(collected, ["strict_f1", "spatial_f1", "ap"])
def collect_task(run_root: Path, glob_pattern: str, prefix: str, metrics_path: list[str], metric_keys: list[str]) -> dict[str, Any]:
    """Generic collector for table4-style runs.

    Walks summary files matched by *glob_pattern*, drills into the nested
    dict addressed by *metrics_path*, and extracts *metric_keys* as floats.
    """
    collected: list[dict[str, Any]] = []
    for summary_path in sorted(run_root.glob(glob_pattern)):
        data = load(summary_path)
        node: Any = data
        for step in metrics_path:
            node = node[step]
        entry: dict[str, Any] = {
            "label": data.get("fm_family") or label_from_seed_dir(summary_path, prefix),
            "seed": seed_from_path(summary_path),
            "path": str(summary_path),
        }
        entry.update({metric: float(node[metric]) for metric in metric_keys})
        collected.append(entry)
    return group(collected, metric_keys)
def group(rows: list[dict[str, Any]], metric_keys: list[str]) -> dict[str, Any]:
    """Dedupe seeded rows by (label, seed), then summarize metrics per label."""
    if rows and "seed" in rows[0]:
        rows = dedupe_rows(rows, ("label", "seed"))
    per_label: dict[str, Any] = {}
    for label in sorted({str(r["label"]) for r in rows}):
        matching = [r for r in rows if r["label"] == label]
        entry: dict[str, Any] = {"n": len(matching)}
        for metric in metric_keys:
            entry[metric] = stats([r[metric] for r in matching])
        per_label[label] = entry
    return {"rows": rows, "summary": per_label}
def fmt(value: dict[str, Any], scale: float = 1.0, digits: int = 2) -> str:
    """Render a stats dict as "mean +/- std (n=...)", scaled; "missing" for n=0."""
    count = int(value["n"])
    if count == 0:
        return "missing"
    mean = float(value["mean"]) * scale
    std = float(value["std"]) * scale
    return f"{mean:.{digits}f} +/- {std:.{digits}f} (n={count})"
def write_markdown(out: Path, summary: dict[str, Any]) -> None:
    """Render the aggregated *summary* as a markdown bullet report at *out*.

    F1-style metrics (keys ending in "_f1", plus "ap") are reported as
    percentages; everything else is left unscaled.
    """
    sections = (
        "table2_headcontrol",
        "table3_occupancy",
        "table3_spread",
        "table4_final_area",
        "table4_analog",
        "table4_smoke",
        "table4_heat",
    )
    lines = ["# Forced Mean/Std Gap-Fill Summary", ""]
    for section in sections:
        lines.extend([f"## {section}", ""])
        per_key = summary.get(section, {}).get("summary", {})
        if section == "table2_headcontrol":
            # Head-control rows are keyed by scope and have a fixed metric set.
            for scope, entry in per_key.items():
                lines.append(
                    f"- {scope}: regret {fmt(entry['decision_regret_union_f1'], 100.0)}; "
                    f"ranking union {fmt(entry['ranking_selected_union_f1'], 100.0)}; "
                    f"decision union {fmt(entry['decision_selected_union_f1'], 100.0)}; "
                    f"failures {entry['failure_count']}/{entry['n']}"
                )
        else:
            for label, entry in per_key.items():
                pieces = []
                for key, val in entry.items():
                    if not isinstance(val, dict):
                        continue  # skip scalar fields such as "n"
                    scale = 100.0 if key.endswith("_f1") or key == "ap" else 1.0
                    pieces.append(f"{key} {fmt(val, scale)}")
                lines.append(f"- {label}: " + "; ".join(pieces))
        lines.append("")
    out.write_text("\n".join(lines), encoding="utf-8")
def main() -> None:
    """CLI entry point: aggregate every task's summaries into JSON + markdown."""
    parser = argparse.ArgumentParser()
    # NOTE(review): the "${...}" defaults look like unexpanded template
    # placeholders from the artifact pipeline — pass explicit paths; confirm
    # whether the generator was meant to substitute them.
    parser.add_argument("--run-root", type=Path, default=Path("${RUN_ROOT}"))
    parser.add_argument("--out-json", type=Path, default=Path("${OUT_JSON}"))
    parser.add_argument("--out-md", type=Path, default=Path("${OUT_MD}"))
    args = parser.parse_args()

    root = args.run_root
    summary = {
        "run_root": str(root),
        "table2_headcontrol": collect_headcontrol(root),
        "table3_occupancy": collect_occupancy(root),
        "table3_spread": collect_spread(root),
        "table4_final_area": collect_task(
            root,
            "table4_final_area_*_seed_*/run_*/summary.json",
            "table4_final_area_",
            ["headline_metrics"],
            ["log_rmse", "log_mae", "log_spearman"],
        ),
        "table4_analog": collect_task(
            root,
            "table4_analog_*_seed_*/run_*/summary.json",
            "table4_analog_",
            ["test_metrics"],
            ["ndcg_at_10", "log_rmse", "log_mae"],
        ),
        "table4_smoke": collect_task(
            root,
            "table4_smoke_*_seed_*/run_*/summary.json",
            "table4_smoke_",
            ["test_metrics"],
            ["rmse", "mae", "pearson_r"],
        ),
        "table4_heat": collect_task(
            root,
            "table4_heat_*_seed_*/run_*/summary.json",
            "table4_heat_",
            ["test_metrics"],
            ["rmse_c", "mae_c", "pearson_r"],
        ),
    }

    args.out_json.parent.mkdir(parents=True, exist_ok=True)
    args.out_json.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    write_markdown(args.out_md, summary)
    print(f"wrote={args.out_json}")
    print(f"wrote={args.out_md}")


if __name__ == "__main__":
    main()