| |
| """Fit quantile-mapping MOS for GraphCast at Dar es Salaam. |
| |
| Reads cached GraphCast forecasts from the lastmile-bench adapter cache |
| and pairs them with ERA5 actuals to build a quantile mapping for |
| temperature and humidity. Saves the mapping to models/graphcast_mos_dar.json. |
| |
| Usage: |
| python3 scripts/calibrate_graphcast_mos.py |
| python3 scripts/calibrate_graphcast_mos.py --cv # leave-one-season-out CV |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from collections import defaultdict |
| from datetime import date, timedelta |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
# Repository root: the directory two levels above this script's resolved path
# (the usage examples run it as scripts/calibrate_graphcast_mos.py).
_SCRIPT_PATH = Path(__file__).resolve()
PROJECT_ROOT = _SCRIPT_PATH.parents[1]

# Put the repository root at the front of sys.path so that the ``src.*``
# imports that follow resolve when the script is executed directly.
sys.path[:0] = [str(PROJECT_ROOT)]
|
|
| from src.calibration.graphcast_mos import ( |
| apply_mapping, |
| correct_forecast_wbgt, |
| fit_quantile_mapping, |
| save_mapping, |
| ) |
| from src.indexing.heat_index import calculate_wbgt |
| from src.prediction.forecast_trigger import forecast_trigger_decision |
|
|
| |
# lastmile-bench checkout (looked up in the user's home directory) and the
# adapter cache directory that is scanned for cached GraphCast forecast files.
BENCH_ROOT = Path.home().joinpath("lastmile-bench")
CACHE_DIR = BENCH_ROOT.joinpath(
    "adapters", "climate_risk_engine", ".cache", "graphcast_insurance"
)

# ERA5 input read by this script, and the JSON file the fitted mapping is
# saved to.
ERA5_PATH = PROJECT_ROOT.joinpath("data", "era5land_dar_es_salaam.json")
OUTPUT_PATH = PROJECT_ROOT.joinpath("models", "graphcast_mos_dar.json")

# Trigger parameters, forwarded unchanged as keyword arguments to
# forecast_trigger_decision() when evaluating trigger accuracy.
WINDOW_THRESHOLD_C = 35.1
PAYOUT_SEVERITY_C = 30.7
ALERT_DURATION_DAYS = 2
PAYOUT_DURATION_DAYS = 5
|
|
|
|
def load_paired_observations(
    era5_path: Path | None = None,
    cache_dir: Path | None = None,
) -> list[dict]:
    """Build (GraphCast, ERA5) paired daily observations from cached forecasts.

    Every ``gc_<YYYY-MM-DD>.json`` file in the cache directory is read as a
    sequence of forecast steps for the window starting on the date in its
    file name. Steps are grouped four at a time into forecast days (at most
    five days per file). For each day, the maximum step temperature and the
    mean step humidity are paired with the ERA5 record dated
    ``window start + lead`` days.

    Args:
        era5_path: JSON file whose ``"DAR-JAN"`` entry is a list of daily
            records with ``date``, ``temp_max_c`` and ``humidity_pct`` keys.
            Defaults to the module-level ``ERA5_PATH``.
        cache_dir: Directory holding the cached forecast files. Defaults to
            the module-level ``CACHE_DIR``.

    Returns:
        One dict per forecast day that has both usable forecast values and a
        matching ERA5 record, with keys ``window_start``, ``forecast_date``,
        ``season``, ``lead``, ``month``, ``gc_temp``, ``gc_hum``,
        ``era5_temp`` and ``era5_hum``. Days are skipped (not zero-filled)
        when the forecast has no temperature or no humidity value, or when
        ERA5 has no record for that date.
    """
    era5_path = ERA5_PATH if era5_path is None else era5_path
    cache_dir = CACHE_DIR if cache_dir is None else cache_dir

    # NOTE(review): presumably four 6-hourly steps per day -- confirm against
    # the adapter that writes the cache.
    steps_per_day = 4
    max_lead_days = 5

    era5 = json.loads(era5_path.read_text(encoding="utf-8"))
    era5_by_date = {rec["date"]: rec for rec in era5["DAR-JAN"]}

    pairs = []
    for path in sorted(cache_dir.glob("gc_*.json")):
        date_str = path.stem.split("_")[-1]
        forecast = json.loads(path.read_text(encoding="utf-8"))
        window_start = date.fromisoformat(date_str)

        # December windows open the season that ends the following year; any
        # other month is assigned to the season that began the previous year.
        # ``:02d`` is only a minimum width, so four-digit years are printed in
        # full (e.g. "2023-2024").
        if window_start.month == 12:
            first_year = window_start.year
        else:
            first_year = window_start.year - 1
        season = f"{first_year}-{first_year + 1:02d}"

        n_days = min(max_lead_days, len(forecast) // steps_per_day)
        for lead in range(1, n_days + 1):
            steps = forecast[(lead - 1) * steps_per_day : lead * steps_per_day]

            # Use only values that are actually present. Substituting 0 for a
            # missing value would drag the humidity mean down and could yield
            # a 0 degree daily maximum, corrupting the fitted mapping.
            temps = [
                s["temperature"] for s in steps
                if s.get("temperature") is not None
            ]
            hums = [
                s["humidity"] for s in steps
                if s.get("humidity") is not None
            ]
            if not temps or not hums:
                continue

            target_day = window_start + timedelta(days=lead)
            forecast_date = target_day.isoformat()
            era5_rec = era5_by_date.get(forecast_date)
            if era5_rec is None:
                continue

            pairs.append({
                "window_start": date_str,
                "forecast_date": forecast_date,
                "season": season,
                "lead": lead,
                "month": target_day.month,
                "gc_temp": max(temps),
                "gc_hum": sum(hums) / len(hums),
                "era5_temp": era5_rec["temp_max_c"],
                "era5_hum": era5_rec["humidity_pct"],
            })

    return pairs
|
|
|
|
def evaluate_trigger(
    pairs_by_window: dict[str, list[dict]],
    t_mapping: dict,
    h_mapping: dict,
    era5_by_date: dict[str, dict],
) -> dict:
    """Evaluate trigger accuracy with corrected forecasts.

    For every forecast window with at least five paired days, the GraphCast
    temperatures and humidities are passed through ``correct_forecast_wbgt``
    and the result is fed to ``forecast_trigger_decision``. The same decision
    function is applied to WBGT values computed from the ERA5 records of the
    five days after the window start; a window counts as a hit when both
    decisions are equal.

    Args:
        pairs_by_window: Paired observations grouped by window start date
            (ISO string). The lists are not modified.
        t_mapping: Temperature mapping handed to ``correct_forecast_wbgt``.
        h_mapping: Humidity mapping handed to ``correct_forecast_wbgt``.
        era5_by_date: ERA5 daily records keyed by ISO date, each providing
            ``temp_max_c`` and ``humidity_pct``.

    Returns:
        Dict with ``hit_rate`` (0.0 when no window could be evaluated), ``n``
        (windows evaluated), ``n_correct`` (exact number of hits), and the
        decision frequency tables ``pred_dist`` and ``opt_dist``.
    """
    from collections import Counter

    window_days = 5
    trigger_kwargs = {
        "alert_duration_days": ALERT_DURATION_DAYS,
        "payout_duration_days": PAYOUT_DURATION_DAYS,
        "window_threshold_c": WINDOW_THRESHOLD_C,
        "payout_severity_c": PAYOUT_SEVERITY_C,
    }

    correct = 0
    total = 0
    pred_dist = Counter()
    opt_dist = Counter()

    for window_start, day_pairs in pairs_by_window.items():
        if len(day_pairs) < window_days:
            continue

        # sorted() rather than list.sort(): the caller's lists must not be
        # reordered as a side effect of evaluation.
        ordered = sorted(day_pairs, key=lambda p: p["lead"])[:window_days]

        # Decision from the bias-corrected forecast.
        corrected_wbgt = correct_forecast_wbgt(
            [p["gc_temp"] for p in ordered],
            [p["gc_hum"] for p in ordered],
            t_mapping,
            h_mapping,
            calculate_wbgt,
        )
        pred = forecast_trigger_decision(corrected_wbgt, **trigger_kwargs)

        # Reference decision from the ERA5 actuals for the same five days.
        ws = date.fromisoformat(window_start)
        era5_wbgt = []
        for lead in range(1, window_days + 1):
            rec = era5_by_date.get((ws + timedelta(days=lead)).isoformat())
            if rec:
                era5_wbgt.append(
                    calculate_wbgt(rec["temp_max_c"], rec["humidity_pct"])
                )
        if len(era5_wbgt) < window_days:
            continue
        opt = forecast_trigger_decision(era5_wbgt, **trigger_kwargs)

        pred_dist[pred] += 1
        opt_dist[opt] += 1
        if pred == opt:
            correct += 1
        total += 1

    return {
        # max(..., 1) keeps the rate defined when nothing was evaluated.
        "hit_rate": correct / max(total, 1),
        "n": total,
        # Exact hit count, so callers need not rebuild it from the float rate.
        "n_correct": correct,
        "pred_dist": dict(pred_dist),
        "opt_dist": dict(opt_dist),
    }
|
|
|
|
def _fit_mappings(rows: list[dict]) -> tuple[dict, dict]:
    """Fit the temperature and humidity quantile mappings on *rows*."""
    t_map = fit_quantile_mapping(
        [r["gc_temp"] for r in rows],
        [r["era5_temp"] for r in rows],
    )
    h_map = fit_quantile_mapping(
        [r["gc_hum"] for r in rows],
        [r["era5_hum"] for r in rows],
    )
    return t_map, h_map


def main():
    """Fit, evaluate and save the GraphCast quantile mapping.

    Loads the paired observations, optionally runs leave-one-season-out
    cross-validation (``--cv``), fits the final temperature and humidity
    mappings on all pairs, prints the in-sample trigger hit rate next to two
    fixed reference figures, and writes the mappings to ``OUTPUT_PATH``.
    Exits with an error message when no paired observations are available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cv", action="store_true", help="Run leave-one-season-out CV")
    args = parser.parse_args()

    print("Loading paired observations...")
    pairs = load_paired_observations()
    n_forecasts = len({p["window_start"] for p in pairs})
    print(f" {len(pairs)} paired observations from {n_forecasts} forecasts")

    # Without any pairs there is nothing to fit; stop before a meaningless
    # mapping can be written over an existing model file.
    if not pairs:
        sys.exit(
            f"No paired observations found (cache: {CACHE_DIR}, "
            f"ERA5: {ERA5_PATH}); nothing to fit."
        )

    # Group the pairs by forecast window for trigger evaluation.
    by_window = defaultdict(list)
    for p in pairs:
        by_window[p["window_start"]].append(p)

    # ERA5 records keyed by date, used for the reference trigger decisions.
    era5 = json.loads(ERA5_PATH.read_text(encoding="utf-8"))
    era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]}

    if args.cv:
        seasons = sorted({p["season"] for p in pairs})
        print(f"\n Leave-one-season-out CV across {len(seasons)} seasons...")

        all_correct = 0
        all_total = 0

        for held_out in seasons:
            train = [p for p in pairs if p["season"] != held_out]
            # All pairs of one window share its season, so the first pair is
            # enough to decide whether the window is held out.
            test_by_window = {
                ws: rows
                for ws, rows in by_window.items()
                if rows[0]["season"] == held_out
            }

            t_map, h_map = _fit_mappings(train)
            result = evaluate_trigger(test_by_window, t_map, h_map, era5_by_date)

            # round(), not int(): hit_rate * n can land just below the true
            # integer count (e.g. 28.999999999999996), which int() would
            # truncate to one hit too few.
            all_correct += round(result["hit_rate"] * result["n"])
            all_total += result["n"]

            if result["n"] > 0:
                print(f" {held_out}: {result['hit_rate']:.1%} ({result['n']} windows)")

        cv_hit = all_correct / max(all_total, 1)
        print(f"\n CV hit rate: {cv_hit:.1%} ({all_correct}/{all_total})")

    print("\nFitting final quantile mapping on all data...")
    t_mapping, h_mapping = _fit_mappings(pairs)

    print(f" Temperature: bias {t_mapping['mean_bias']:+.2f}°C "
          f"(GC mean={t_mapping['gc_mean']:.1f}, ERA5 mean={t_mapping['era5_mean']:.1f})")
    print(f" Humidity: bias {h_mapping['mean_bias']:+.2f}% "
          f"(GC mean={h_mapping['gc_mean']:.1f}, ERA5 mean={h_mapping['era5_mean']:.1f})")

    # In-sample evaluation: the mapping is scored on the data it was fitted on.
    result = evaluate_trigger(by_window, t_mapping, h_mapping, era5_by_date)
    print(f"\n In-sample hit rate: {result['hit_rate']:.1%} ({result['n']} windows)")
    print(f" Predicted: {result['pred_dist']}")
    print(f" Optimal: {result['opt_dist']}")

    # The two baseline figures are fixed reference values, not computed here.
    print("\n Comparison:")
    print(" always_no_trigger: 60.6%")
    print(" crude threshold (34.0): 67.4%")
    print(f" quantile-mapped MOS: {result['hit_rate']:.1%}")

    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    save_mapping(t_mapping, h_mapping, OUTPUT_PATH)
    print(f"\n Saved to {OUTPUT_PATH}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|