File size: 8,332 Bytes
277cbcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#!/usr/bin/env python3
"""Fit quantile-mapping MOS for GraphCast at Dar es Salaam.

Reads cached GraphCast forecasts from the lastmile-bench adapter cache
and pairs them with ERA5 actuals to build a quantile mapping for
temperature and humidity. Saves the mapping to models/graphcast_mos_dar.json.

Usage:
    python3 scripts/calibrate_graphcast_mos.py
    python3 scripts/calibrate_graphcast_mos.py --cv   # leave-one-season-out CV
"""
from __future__ import annotations

import argparse
import json
import sys
from collections import defaultdict
from datetime import date, timedelta
from pathlib import Path

import numpy as np

PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))

from src.calibration.graphcast_mos import (
    apply_mapping,
    correct_forecast_wbgt,
    fit_quantile_mapping,
    save_mapping,
)
from src.indexing.heat_index import calculate_wbgt
from src.prediction.forecast_trigger import forecast_trigger_decision

# Filesystem locations.
# BENCH_ROOT is a sibling checkout expected in the user's home directory;
# CACHE_DIR is where the cached GraphCast forecast files (gc_*.json) are read from.
BENCH_ROOT = Path.home().joinpath("lastmile-bench")
CACHE_DIR = BENCH_ROOT.joinpath(
    "adapters", "climate_risk_engine", ".cache", "graphcast_insurance"
)
# ERA5 actuals (input) and the fitted mapping (output), both inside this project.
ERA5_PATH = PROJECT_ROOT.joinpath("data", "era5land_dar_es_salaam.json")
OUTPUT_PATH = PROJECT_ROOT.joinpath("models", "graphcast_mos_dar.json")

# Benchmark thresholds (from insurance_benchmark.py); these are forwarded
# unchanged to forecast_trigger_decision when scoring trigger accuracy.
# NOTE(review): the *_C suffix suggests degrees Celsius — confirm against the benchmark.
WINDOW_THRESHOLD_C = 35.1
PAYOUT_SEVERITY_C = 30.7
ALERT_DURATION_DAYS = 2
PAYOUT_DURATION_DAYS = 5


def load_paired_observations() -> list[dict]:
    """Build (GraphCast, ERA5) paired observations from cached forecasts.

    Every ``gc_*.json`` file in ``CACHE_DIR`` is read as a sequence of forecast
    steps; the ISO date after the last underscore of the file stem is taken as
    the window start. Steps are consumed in groups of four, one group per lead
    day (at most five lead days). For each group the maximum ``temperature``
    and the mean ``humidity`` are paired with the ERA5 record dated
    ``window_start + lead`` days.

    Steps that lack a temperature or humidity value are ignored rather than
    counted as 0, and a lead day with no usable temperature or humidity is
    skipped, so missing data cannot bias the fitted mapping. Lead days without
    a matching ERA5 record are skipped as well.

    Returns:
        One dict per matched day with the keys ``window_start``,
        ``forecast_date``, ``season``, ``lead``, ``month``, ``gc_temp``,
        ``gc_hum``, ``era5_temp`` and ``era5_hum``.
    """
    era5 = json.loads(ERA5_PATH.read_text())
    era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]}

    pairs = []
    for f in sorted(CACHE_DIR.glob("gc_*.json")):
        date_str = f.stem.split("_")[-1]
        forecast = json.loads(f.read_text())
        ws = date.fromisoformat(date_str)

        # Determine season (Dec-Mar → season year = Dec's year): a December
        # window opens a new season, any other month belongs to the season
        # that started the previous December.
        # NOTE(review): ":02d" has no effect on a four-digit year, so labels
        # look like "2023-2024"; they are only used as grouping keys here.
        season_start = ws.year if ws.month == 12 else ws.year - 1
        season = f"{season_start}-{season_start + 1:02d}"

        for day_idx in range(min(5, len(forecast) // 4)):
            lead = day_idx + 1
            steps = forecast[day_idx * 4 : (day_idx + 1) * 4]

            # Keep only values that are actually present; defaulting to 0
            # would drag the humidity mean down and could produce a bogus
            # 0-degree daily maximum.
            temps = [
                s["temperature"] for s in steps
                if s.get("temperature") is not None
            ]
            hums = [
                s["humidity"] for s in steps
                if s.get("humidity") is not None
            ]
            if not temps or not hums:
                continue

            forecast_day = ws + timedelta(days=lead)
            forecast_date = forecast_day.isoformat()
            era5_rec = era5_by_date.get(forecast_date)
            if era5_rec is None:
                continue

            pairs.append({
                "window_start": date_str,
                "forecast_date": forecast_date,
                "season": season,
                "lead": lead,
                "month": forecast_day.month,
                "gc_temp": max(temps),
                "gc_hum": sum(hums) / len(hums),
                "era5_temp": era5_rec["temp_max_c"],
                "era5_hum": era5_rec["humidity_pct"],
            })

    return pairs


def evaluate_trigger(
    pairs_by_window: dict[str, list[dict]],
    t_mapping: dict,
    h_mapping: dict,
    era5_by_date: dict[str, dict],
) -> dict:
    """Evaluate trigger accuracy with corrected forecasts.

    For each forecast window the first five lead days (ordered by ``lead``)
    are bias-corrected via ``correct_forecast_wbgt`` and turned into a
    predicted action with ``forecast_trigger_decision``. The "optimal" action
    is the same decision computed from WBGT derived from the ERA5 actuals for
    the five days after the window start. A window counts as a hit when both
    actions are equal.

    Windows with fewer than five paired days, or with an ERA5 record missing
    for any of the five lead days, are not scored.

    Args:
        pairs_by_window: Paired observations grouped by window-start ISO date.
            The lists are not modified.
        t_mapping: Temperature mapping, passed through to
            ``correct_forecast_wbgt``.
        h_mapping: Humidity mapping, passed through to
            ``correct_forecast_wbgt``.
        era5_by_date: ERA5 records keyed by ISO date; each record must provide
            ``temp_max_c`` and ``humidity_pct``.

    Returns:
        Dict with ``hit_rate`` (0.0 when no window was scored), ``n`` (number
        of scored windows), ``correct`` (number of hits), and the action
        histograms ``pred_dist`` and ``opt_dist``.
    """
    from collections import Counter

    correct = 0
    total = 0
    pred_dist = Counter()
    opt_dist = Counter()

    # Both decisions must be taken with identical thresholds.
    decision_kwargs = {
        "alert_duration_days": ALERT_DURATION_DAYS,
        "payout_duration_days": PAYOUT_DURATION_DAYS,
        "window_threshold_c": WINDOW_THRESHOLD_C,
        "payout_severity_c": PAYOUT_SEVERITY_C,
    }

    for window_start, day_pairs in pairs_by_window.items():
        if len(day_pairs) < 5:
            continue
        # sorted() instead of list.sort(): do not reorder the caller's data.
        first_five = sorted(day_pairs, key=lambda p: p["lead"])[:5]

        # Corrected WBGT
        gc_temps = [p["gc_temp"] for p in first_five]
        gc_hums = [p["gc_hum"] for p in first_five]
        corrected_wbgt = correct_forecast_wbgt(
            gc_temps, gc_hums, t_mapping, h_mapping, calculate_wbgt,
        )

        # Predicted action
        pred = forecast_trigger_decision(corrected_wbgt, **decision_kwargs)

        # Optimal action (from ERA5 actuals)
        era5_wbgt = []
        ws = date.fromisoformat(window_start)
        for lead in range(1, 6):
            fd = (ws + timedelta(days=lead)).isoformat()
            rec = era5_by_date.get(fd)
            if rec:
                era5_wbgt.append(calculate_wbgt(rec["temp_max_c"], rec["humidity_pct"]))
        if len(era5_wbgt) < 5:
            continue

        opt = forecast_trigger_decision(era5_wbgt, **decision_kwargs)

        pred_dist[pred] += 1
        opt_dist[opt] += 1
        if pred == opt:
            correct += 1
        total += 1

    return {
        # max(total, 1) avoids a ZeroDivisionError when nothing was scored.
        "hit_rate": correct / max(total, 1),
        "n": total,
        # Exact hit count, so callers can aggregate without going through
        # the floating-point hit_rate.
        "correct": correct,
        "pred_dist": dict(pred_dist),
        "opt_dist": dict(opt_dist),
    }


def main():
    """Fit, evaluate and save the GraphCast quantile mapping.

    Steps:
      1. Load the paired (GraphCast, ERA5) observations and group them by
         forecast window.
      2. With ``--cv``, run leave-one-season-out cross-validation: fit the
         temperature and humidity mappings on all other seasons and score the
         trigger decision on the held-out season.
      3. Fit the final mappings on all pairs, print their bias summary and the
         in-sample trigger hit rate.
      4. Write the mappings to ``OUTPUT_PATH``.

    Exits with an error message if no paired observations could be loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cv", action="store_true", help="Run leave-one-season-out CV")
    args = parser.parse_args()

    print("Loading paired observations...")
    pairs = load_paired_observations()
    n_forecasts = len({p["window_start"] for p in pairs})
    print(f"  {len(pairs)} paired observations from {n_forecasts} forecasts")

    # Nothing to fit: stop with a clear message instead of handing empty
    # samples to the fitting routine.
    if not pairs:
        sys.exit(
            f"No paired observations found (cache: {CACHE_DIR}, ERA5: {ERA5_PATH}); "
            "nothing to calibrate."
        )

    # Group by window_start for trigger evaluation
    by_window = defaultdict(list)
    for p in pairs:
        by_window[p["window_start"]].append(p)

    # Load ERA5 for evaluation
    era5 = json.loads(ERA5_PATH.read_text())
    era5_by_date = {r["date"]: r for r in era5["DAR-JAN"]}

    if args.cv:
        # Leave-one-season-out cross-validation
        seasons = sorted({p["season"] for p in pairs})
        print(f"\n  Leave-one-season-out CV across {len(seasons)} seasons...")

        all_correct = 0
        all_total = 0

        for held_out in seasons:
            train = [p for p in pairs if p["season"] != held_out]
            test_windows = {
                p["window_start"] for p in pairs if p["season"] == held_out
            }
            test_by_window = {ws: by_window[ws] for ws in test_windows}

            t_map = fit_quantile_mapping(
                [p["gc_temp"] for p in train],
                [p["era5_temp"] for p in train],
            )
            h_map = fit_quantile_mapping(
                [p["gc_hum"] for p in train],
                [p["era5_hum"] for p in train],
            )

            result = evaluate_trigger(test_by_window, t_map, h_map, era5_by_date)
            # hit_rate is correct / n as a float; multiplying back can land
            # just below the integer (e.g. 1 / 49 * 49), so round instead of
            # truncating with int().
            all_correct += round(result["hit_rate"] * result["n"])
            all_total += result["n"]

            if result["n"] > 0:
                print(f"    {held_out}: {result['hit_rate']:.1%} ({result['n']} windows)")

        cv_hit = all_correct / max(all_total, 1)
        print(f"\n  CV hit rate: {cv_hit:.1%} ({all_correct}/{all_total})")

    # Fit final mapping on all data
    print("\nFitting final quantile mapping on all data...")
    t_mapping = fit_quantile_mapping(
        [p["gc_temp"] for p in pairs],
        [p["era5_temp"] for p in pairs],
    )
    h_mapping = fit_quantile_mapping(
        [p["gc_hum"] for p in pairs],
        [p["era5_hum"] for p in pairs],
    )

    print(f"  Temperature: bias {t_mapping['mean_bias']:+.2f}°C "
          f"(GC mean={t_mapping['gc_mean']:.1f}, ERA5 mean={t_mapping['era5_mean']:.1f})")
    print(f"  Humidity:    bias {h_mapping['mean_bias']:+.2f}% "
          f"(GC mean={h_mapping['gc_mean']:.1f}, ERA5 mean={h_mapping['era5_mean']:.1f})")

    # Evaluate on all data (in-sample — use CV result for honest estimate)
    result = evaluate_trigger(by_window, t_mapping, h_mapping, era5_by_date)
    print(f"\n  In-sample hit rate: {result['hit_rate']:.1%} ({result['n']} windows)")
    print(f"  Predicted:  {result['pred_dist']}")
    print(f"  Optimal:    {result['opt_dist']}")

    # Compare to baselines. The two baseline figures are hard-coded literals,
    # not computed by this script.
    print("\n  Comparison:")
    print("    always_no_trigger:      60.6%")
    print("    crude threshold (34.0): 67.4%")
    print(f"    quantile-mapped MOS:    {result['hit_rate']:.1%}")

    # Save
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    save_mapping(t_mapping, h_mapping, OUTPUT_PATH)
    print(f"\n  Saved to {OUTPUT_PATH}")


# Run the calibration only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()