climate-risk-engine / scripts /calibrate_graphcast_mos.py
jtlevine's picture
Replace neural actuarial model with GraphCast triggers + empirical burn analysis
277cbcf
#!/usr/bin/env python3
"""Fit quantile-mapping MOS for GraphCast at Dar es Salaam.
Reads cached GraphCast forecasts from the lastmile-bench adapter cache
and pairs them with ERA5 actuals to build a quantile mapping for
temperature and humidity. Saves the mapping to models/graphcast_mos_dar.json.
Usage:
python3 scripts/calibrate_graphcast_mos.py
python3 scripts/calibrate_graphcast_mos.py --cv # leave-one-season-out CV
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import Counter, defaultdict
from datetime import date, timedelta
from pathlib import Path

import numpy as np
# Put the parent of this script's directory on sys.path so the ``src.*``
# imports below resolve when the file is run directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))
from src.calibration.graphcast_mos import (
apply_mapping,
correct_forecast_wbgt,
fit_quantile_mapping,
save_mapping,
)
from src.indexing.heat_index import calculate_wbgt
from src.prediction.forecast_trigger import forecast_trigger_decision
# --- Paths -----------------------------------------------------------------
# The GraphCast forecast cache lives inside the lastmile-bench checkout, under
# its climate_risk_engine adapter.
BENCH_ROOT = Path.home() / "lastmile-bench"
CACHE_DIR = BENCH_ROOT.joinpath(
    "adapters", "climate_risk_engine", ".cache", "graphcast_insurance"
)
# ERA5 daily records for Dar es Salaam, and where the fitted mapping is written.
ERA5_PATH = PROJECT_ROOT.joinpath("data", "era5land_dar_es_salaam.json")
OUTPUT_PATH = PROJECT_ROOT.joinpath("models", "graphcast_mos_dar.json")

# --- Trigger thresholds (mirrors insurance_benchmark.py) -------------------
# All four are forwarded unchanged to forecast_trigger_decision.
WINDOW_THRESHOLD_C = 35.1
PAYOUT_SEVERITY_C = 30.7
ALERT_DURATION_DAYS = 2
PAYOUT_DURATION_DAYS = 5
def load_paired_observations() -> list[dict]:
    """Build (GraphCast, ERA5) paired observations from cached forecasts.

    Each cache file ``gc_<YYYY-MM-DD>.json`` in ``CACHE_DIR`` holds a list of
    forecast steps; the date in the file name is the window start. Steps are
    consumed four per forecast day for at most five days, and day ``k``
    (lead ``k``, 1-based) is paired with the ERA5 record dated
    ``window_start + k`` days.

    Returns:
        One dict per usable day with keys ``window_start``, ``forecast_date``,
        ``season``, ``lead``, ``month``, ``gc_temp`` (max of the day's step
        temperatures), ``gc_hum`` (mean of the day's step humidities),
        ``era5_temp`` and ``era5_hum``. Days without an ERA5 record, or whose
        steps carry no temperature or no humidity values at all, are skipped.
    """
    era5 = json.loads(ERA5_PATH.read_text())
    era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]}
    pairs = []
    for f in sorted(CACHE_DIR.glob("gc_*.json")):
        date_str = f.stem.split("_")[-1]
        forecast = json.loads(f.read_text())
        ws = date.fromisoformat(date_str)
        # Season label from the window start: December belongs to the season
        # starting that year, any other month to the season that started the
        # previous year. NOTE: ``:02d`` on a four-digit year is a no-op, so the
        # labels read e.g. "2020-2021".
        season = f"{ws.year}-{ws.year+1:02d}" if ws.month == 12 else f"{ws.year-1}-{ws.year:02d}"
        # Assumes four steps per forecast day (presumably 6-hourly) — TODO confirm.
        for day_idx in range(min(5, len(forecast) // 4)):
            lead = day_idx + 1
            steps = forecast[day_idx * 4 : (day_idx + 1) * 4]
            # Aggregate only the values that are present. Treating a missing
            # humidity as 0 % would drag the daily mean down and bias the
            # fitted mapping; a JSON null would make max() raise.
            temps = [s["temperature"] for s in steps if s.get("temperature") is not None]
            hums = [s["humidity"] for s in steps if s.get("humidity") is not None]
            if not temps or not hums:
                continue
            target = ws + timedelta(days=lead)
            forecast_date = target.isoformat()
            era5_rec = era5_by_date.get(forecast_date)
            if era5_rec is None:
                continue
            pairs.append({
                "window_start": date_str,
                "forecast_date": forecast_date,
                "season": season,
                "lead": lead,
                "month": target.month,
                "gc_temp": max(temps),
                "gc_hum": sum(hums) / len(hums),
                "era5_temp": era5_rec["temp_max_c"],
                "era5_hum": era5_rec["humidity_pct"],
            })
    return pairs
def evaluate_trigger(
    pairs_by_window: dict[str, list[dict]],
    t_mapping: dict,
    h_mapping: dict,
    era5_by_date: dict[str, dict],
) -> dict:
    """Evaluate trigger accuracy with corrected forecasts.

    For every window with at least five paired days, the five lowest leads are
    bias-corrected with ``correct_forecast_wbgt`` and passed to
    ``forecast_trigger_decision``. The resulting decision is compared with the
    decision obtained from ERA5 WBGT for days ``window_start + 1 .. + 5``.
    Windows missing any of those five ERA5 days are skipped.

    Args:
        pairs_by_window: Paired observations grouped by ``window_start`` (ISO
            date string). The lists are read, never reordered or modified.
        t_mapping: Temperature mapping forwarded to ``correct_forecast_wbgt``.
        h_mapping: Humidity mapping forwarded to ``correct_forecast_wbgt``.
        era5_by_date: ERA5 records keyed by ISO date; ``temp_max_c`` and
            ``humidity_pct`` are read from each.

    Returns:
        Dict with ``hit_rate`` (matching windows / evaluated windows, 0.0 when
        none were evaluated), ``n`` (windows evaluated), ``correct`` (matching
        windows), and ``pred_dist`` / ``opt_dist`` (counts per decision).
    """
    # Identical trigger configuration for the forecast and the ERA5 decision.
    trigger_kwargs = {
        "alert_duration_days": ALERT_DURATION_DAYS,
        "payout_duration_days": PAYOUT_DURATION_DAYS,
        "window_threshold_c": WINDOW_THRESHOLD_C,
        "payout_severity_c": PAYOUT_SEVERITY_C,
    }
    correct = 0
    total = 0
    pred_dist = Counter()
    opt_dist = Counter()
    for window_start, day_pairs in pairs_by_window.items():
        if len(day_pairs) < 5:
            continue
        # sorted() rather than list.sort(): do not reorder the caller's lists.
        first_five = sorted(day_pairs, key=lambda p: p["lead"])[:5]

        # Optimal action from ERA5 actuals. Checked first so windows lacking
        # ERA5 coverage are skipped before any forecast-side work.
        ws = date.fromisoformat(window_start)
        era5_wbgt = []
        for lead in range(1, 6):
            rec = era5_by_date.get((ws + timedelta(days=lead)).isoformat())
            if rec:
                era5_wbgt.append(calculate_wbgt(rec["temp_max_c"], rec["humidity_pct"]))
        if len(era5_wbgt) < 5:
            continue
        opt = forecast_trigger_decision(era5_wbgt, **trigger_kwargs)

        # Predicted action from bias-corrected GraphCast WBGT.
        corrected_wbgt = correct_forecast_wbgt(
            [p["gc_temp"] for p in first_five],
            [p["gc_hum"] for p in first_five],
            t_mapping,
            h_mapping,
            calculate_wbgt,
        )
        pred = forecast_trigger_decision(corrected_wbgt, **trigger_kwargs)

        pred_dist[pred] += 1
        opt_dist[opt] += 1
        if pred == opt:
            correct += 1
        total += 1
    return {
        "hit_rate": correct / max(total, 1),
        "n": total,
        "correct": correct,
        "pred_dist": dict(pred_dist),
        "opt_dist": dict(opt_dist),
    }
def _fit_mappings(pairs: list[dict]) -> tuple[dict, dict]:
    """Fit (temperature, humidity) GraphCast→ERA5 quantile mappings on *pairs*."""
    t_map = fit_quantile_mapping(
        [p["gc_temp"] for p in pairs],
        [p["era5_temp"] for p in pairs],
    )
    h_map = fit_quantile_mapping(
        [p["gc_hum"] for p in pairs],
        [p["era5_hum"] for p in pairs],
    )
    return t_map, h_map


def _run_cv(
    pairs: list[dict],
    by_window: dict[str, list[dict]],
    era5_by_date: dict[str, dict],
) -> float:
    """Run leave-one-season-out CV, print per-season results, return the pooled hit rate."""
    seasons = sorted({p["season"] for p in pairs})
    print(f"\n Leave-one-season-out CV across {len(seasons)} seasons...")
    all_correct = 0
    all_total = 0
    for held_out in seasons:
        train = [p for p in pairs if p["season"] != held_out]
        if not train:
            # Only one season available: nothing left to fit on for this fold.
            print(f" {held_out}: skipped (no training data)")
            continue
        test_windows = {p["window_start"] for p in pairs if p["season"] == held_out}
        test_by_window = {ws: by_window[ws] for ws in test_windows}
        t_map, h_map = _fit_mappings(train)
        result = evaluate_trigger(test_by_window, t_map, h_map, era5_by_date)
        # round(), not int(): hit_rate * n can land just below the true count
        # (0.29 * 100 == 28.999999999999996) and int() would truncate it.
        all_correct += round(result["hit_rate"] * result["n"])
        all_total += result["n"]
        if result["n"] > 0:
            print(f" {held_out}: {result['hit_rate']:.1%} ({result['n']} windows)")
    cv_hit = all_correct / max(all_total, 1)
    print(f"\n CV hit rate: {cv_hit:.1%} ({all_correct}/{all_total})")
    return cv_hit


def main():
    """Fit the GraphCast MOS mappings, report trigger hit rates, and save them.

    With ``--cv``, leave-one-season-out cross-validation runs first. The final
    mappings are always fitted on all pairs, evaluated in-sample, and written
    to ``OUTPUT_PATH``. Exits with an error message if no pairs are found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cv", action="store_true", help="Run leave-one-season-out CV")
    args = parser.parse_args()

    print("Loading paired observations...")
    pairs = load_paired_observations()
    if not pairs:
        # Fail with a clear message instead of crashing inside the mapping fit.
        sys.exit(f"No paired GraphCast/ERA5 observations found in {CACHE_DIR}")
    n_forecasts = len({p["window_start"] for p in pairs})
    print(f" {len(pairs)} paired observations from {n_forecasts} forecasts")

    # Group by window_start for trigger evaluation
    by_window = defaultdict(list)
    for p in pairs:
        by_window[p["window_start"]].append(p)

    # Load ERA5 for evaluation
    era5 = json.loads(ERA5_PATH.read_text())
    era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]}

    cv_hit = _run_cv(pairs, by_window, era5_by_date) if args.cv else None

    # Fit final mapping on all data
    print("\nFitting final quantile mapping on all data...")
    t_mapping, h_mapping = _fit_mappings(pairs)
    print(f" Temperature: bias {t_mapping['mean_bias']:+.2f}°C "
          f"(GC mean={t_mapping['gc_mean']:.1f}, ERA5 mean={t_mapping['era5_mean']:.1f})")
    print(f" Humidity: bias {h_mapping['mean_bias']:+.2f}% "
          f"(GC mean={h_mapping['gc_mean']:.1f}, ERA5 mean={h_mapping['era5_mean']:.1f})")

    # Evaluate on all data (in-sample — use CV result for honest estimate)
    result = evaluate_trigger(by_window, t_mapping, h_mapping, era5_by_date)
    print(f"\n In-sample hit rate: {result['hit_rate']:.1%} ({result['n']} windows)")
    print(f" Predicted: {result['pred_dist']}")
    print(f" Optimal: {result['opt_dist']}")

    # Compare to baselines. The baseline figures are hard-coded, not recomputed
    # here — NOTE(review): confirm they still match the current window set.
    print("\n Comparison:")
    print(" always_no_trigger: 60.6%")
    print(" crude threshold (34.0): 67.4%")
    print(f" quantile-mapped MOS (in-sample): {result['hit_rate']:.1%}")
    if cv_hit is not None:
        print(f" quantile-mapped MOS (CV): {cv_hit:.1%}")

    # Save
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    save_mapping(t_mapping, h_mapping, OUTPUT_PATH)
    print(f"\n Saved to {OUTPUT_PATH}")
# Run the calibration only when executed as a script, not on import.
if __name__ == "__main__":
    main()