#!/usr/bin/env python3 """Fit quantile-mapping MOS for GraphCast at Dar es Salaam. Reads cached GraphCast forecasts from the lastmile-bench adapter cache and pairs them with ERA5 actuals to build a quantile mapping for temperature and humidity. Saves the mapping to models/graphcast_mos_dar.json. Usage: python3 scripts/calibrate_graphcast_mos.py python3 scripts/calibrate_graphcast_mos.py --cv # leave-one-season-out CV """ from __future__ import annotations import argparse import json import sys from collections import defaultdict from datetime import date, timedelta from pathlib import Path import numpy as np PROJECT_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(PROJECT_ROOT)) from src.calibration.graphcast_mos import ( apply_mapping, correct_forecast_wbgt, fit_quantile_mapping, save_mapping, ) from src.indexing.heat_index import calculate_wbgt from src.prediction.forecast_trigger import forecast_trigger_decision # Paths BENCH_ROOT = Path.home() / "lastmile-bench" CACHE_DIR = BENCH_ROOT / "adapters" / "climate_risk_engine" / ".cache" / "graphcast_insurance" ERA5_PATH = PROJECT_ROOT / "data" / "era5land_dar_es_salaam.json" OUTPUT_PATH = PROJECT_ROOT / "models" / "graphcast_mos_dar.json" # Benchmark thresholds (from insurance_benchmark.py) WINDOW_THRESHOLD_C = 35.1 PAYOUT_SEVERITY_C = 30.7 ALERT_DURATION_DAYS = 2 PAYOUT_DURATION_DAYS = 5 def load_paired_observations() -> list[dict]: """Build (GraphCast, ERA5) paired observations from cached forecasts.""" era5 = json.loads(ERA5_PATH.read_text()) era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]} pairs = [] for f in sorted(CACHE_DIR.glob("gc_*.json")): date_str = f.stem.split("_")[-1] forecast = json.loads(f.read_text()) ws = date.fromisoformat(date_str) # Determine season (Dec-Mar → season year = Dec's year) season = f"{ws.year}-{ws.year+1:02d}" if ws.month == 12 else f"{ws.year-1}-{ws.year:02d}" for day_idx in range(min(5, len(forecast) // 4)): lead = day_idx + 1 steps = forecast[day_idx * 4 : (day_idx + 1) * 4] gc_temp = max(s.get("temperature", 0) for s in steps) gc_hum = sum(s.get("humidity", 0) for s in steps) / len(steps) forecast_date = (ws + timedelta(days=lead)).isoformat() era5_rec = era5_by_date.get(forecast_date) if era5_rec is None: continue pairs.append({ "window_start": date_str, "forecast_date": forecast_date, "season": season, "lead": lead, "month": (ws + timedelta(days=lead)).month, "gc_temp": gc_temp, "gc_hum": gc_hum, "era5_temp": era5_rec["temp_max_c"], "era5_hum": era5_rec["humidity_pct"], }) return pairs def evaluate_trigger( pairs_by_window: dict[str, list[dict]], t_mapping: dict, h_mapping: dict, era5_by_date: dict[str, dict], ) -> dict: """Evaluate trigger accuracy with corrected forecasts.""" correct = 0 total = 0 from collections import Counter pred_dist = Counter() opt_dist = Counter() for window_start, day_pairs in pairs_by_window.items(): if len(day_pairs) < 5: continue day_pairs.sort(key=lambda p: p["lead"]) # Corrected WBGT gc_temps = [p["gc_temp"] for p in day_pairs[:5]] gc_hums = [p["gc_hum"] for p in day_pairs[:5]] corrected_wbgt = correct_forecast_wbgt( gc_temps, gc_hums, t_mapping, h_mapping, calculate_wbgt, ) # Predicted action pred = forecast_trigger_decision( corrected_wbgt, alert_duration_days=ALERT_DURATION_DAYS, payout_duration_days=PAYOUT_DURATION_DAYS, window_threshold_c=WINDOW_THRESHOLD_C, payout_severity_c=PAYOUT_SEVERITY_C, ) # Optimal action (from ERA5 actuals) era5_wbgt = [] ws = date.fromisoformat(window_start) for lead in range(1, 6): fd = (ws + timedelta(days=lead)).isoformat() rec = era5_by_date.get(fd) if rec: era5_wbgt.append(calculate_wbgt(rec["temp_max_c"], rec["humidity_pct"])) if len(era5_wbgt) < 5: continue opt = forecast_trigger_decision( era5_wbgt, alert_duration_days=ALERT_DURATION_DAYS, payout_duration_days=PAYOUT_DURATION_DAYS, window_threshold_c=WINDOW_THRESHOLD_C, payout_severity_c=PAYOUT_SEVERITY_C, ) pred_dist[pred] += 1 opt_dist[opt] += 1 if pred == opt: correct += 1 total += 1 return { "hit_rate": correct / max(total, 1), "n": total, "pred_dist": dict(pred_dist), "opt_dist": dict(opt_dist), } def main(): parser = argparse.ArgumentParser() parser.add_argument("--cv", action="store_true", help="Run leave-one-season-out CV") args = parser.parse_args() print("Loading paired observations...") pairs = load_paired_observations() print(f" {len(pairs)} paired observations from {len(set(p['window_start'] for p in pairs))} forecasts") # Group by window_start for trigger evaluation by_window = defaultdict(list) for p in pairs: by_window[p["window_start"]].append(p) # Load ERA5 for evaluation era5 = json.loads(ERA5_PATH.read_text()) era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]} if args.cv: # Leave-one-season-out cross-validation seasons = sorted(set(p["season"] for p in pairs)) print(f"\n Leave-one-season-out CV across {len(seasons)} seasons...") all_correct = 0 all_total = 0 for held_out in seasons: train = [p for p in pairs if p["season"] != held_out] test_windows = { p["window_start"] for p in pairs if p["season"] == held_out } test_by_window = {ws: by_window[ws] for ws in test_windows} t_map = fit_quantile_mapping( [p["gc_temp"] for p in train], [p["era5_temp"] for p in train], ) h_map = fit_quantile_mapping( [p["gc_hum"] for p in train], [p["era5_hum"] for p in train], ) result = evaluate_trigger(test_by_window, t_map, h_map, era5_by_date) all_correct += int(result["hit_rate"] * result["n"]) all_total += result["n"] if result["n"] > 0: print(f" {held_out}: {result['hit_rate']:.1%} ({result['n']} windows)") cv_hit = all_correct / max(all_total, 1) print(f"\n CV hit rate: {cv_hit:.1%} ({all_correct}/{all_total})") # Fit final mapping on all data print("\nFitting final quantile mapping on all data...") t_mapping = fit_quantile_mapping( [p["gc_temp"] for p in pairs], [p["era5_temp"] for p in pairs], ) h_mapping = fit_quantile_mapping( [p["gc_hum"] for p in pairs], [p["era5_hum"] for p in pairs], ) print(f" Temperature: bias {t_mapping['mean_bias']:+.2f}°C " f"(GC mean={t_mapping['gc_mean']:.1f}, ERA5 mean={t_mapping['era5_mean']:.1f})") print(f" Humidity: bias {h_mapping['mean_bias']:+.2f}% " f"(GC mean={h_mapping['gc_mean']:.1f}, ERA5 mean={h_mapping['era5_mean']:.1f})") # Evaluate on all data (in-sample — use CV result for honest estimate) result = evaluate_trigger(by_window, t_mapping, h_mapping, era5_by_date) print(f"\n In-sample hit rate: {result['hit_rate']:.1%} ({result['n']} windows)") print(f" Predicted: {result['pred_dist']}") print(f" Optimal: {result['opt_dist']}") # Compare to baselines print(f"\n Comparison:") print(f" always_no_trigger: 60.6%") print(f" crude threshold (34.0): 67.4%") print(f" quantile-mapped MOS: {result['hit_rate']:.1%}") # Save OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True) save_mapping(t_mapping, h_mapping, OUTPUT_PATH) print(f"\n Saved to {OUTPUT_PATH}") if __name__ == "__main__": main()