jtlevine Claude Opus 4.7 (1M context) commited on
Commit
64c4f2a
·
1 Parent(s): 202abe1

Cut dead neural/LSTM pricing + predictor paths

Browse files

These were left behind after the earlier XGBoost heat-wave predictor
cut (93f565b) and the move to empirical burn analysis for pricing.
Nothing in src/pipeline.py or src/api.py imports them; the eval tests
they ship with couldn't pass either (neural_pricer_dar.pt never
existed on disk, so test_neural_model_loads fails by construction).

Removed:
- src/prediction/heat_forecast.py (HeatWavePredictor)
- src/prediction/lstm_model.py (LSTM trainer + CITY_THRESHOLDS)
- src/pricing/neural_actuarial.py (neural pricer; replaced by burn analysis)
- src/notification/sender.py + __init__.py (notify step was cut in a
prior pipeline pass; no one imports this anymore)
- scripts/train_neural_pricer.py, train_on_era5.py, train_on_nasa_power.py,
train_lstm.py, backtest_pricing.py (all drove the dead paths)
- tests/eval_heat_predictor.py, eval_neural_pricer.py (evaluated dead code)
- models/heat_predictor_xgb.json, trigger_head_retrained.pt
- (untracked) models/heat_lstm.pt, lstm_norm.json,
scripts/retrain_trigger_heads.py

Kept: BurnAnalysisPricer, UHI XGBoost correction (still live),
GraphCast forecast_trigger, RAG index, basis_risk assessor.

Smoke-tested: src.pipeline + src.api import clean post-cut.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

models/heat_predictor_xgb.json DELETED
The diff for this file is too large to render. See raw diff
 
models/trigger_head_retrained.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6fb225a2eabbb94697c1c4a3f772d0850c54f48c83855a318282c54d77a1f186
3
- size 168777
 
 
 
 
scripts/backtest_pricing.py DELETED
@@ -1,508 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Actuarial backtesting: predicted frequency vs. actual trigger events.
4
-
5
- Runs the NeuralActuarialPricer (or GLM fallback) against 19 historical
6
- hot seasons (Dec-Mar 2005-06 through 2023-24) using ERA5-Land data for
7
- all 15 Dar es Salaam zones. Counts actual alert and payout events
8
- using the same thresholds as the benchmark panel, then compares to the
9
- model's learned_frequency (lambda).
10
-
11
- Usage:
12
- python scripts/backtest_pricing.py
13
- python scripts/backtest_pricing.py --zone DAR-JAN # single zone
14
- python scripts/backtest_pricing.py --all-zones # all 15 zones
15
- python scripts/backtest_pricing.py --no-uhi # skip UHI correction
16
- """
17
-
18
- from __future__ import annotations
19
-
20
- import argparse
21
- import json
22
- import sys
23
- from datetime import date, timedelta
24
- from pathlib import Path
25
-
26
- import numpy as np
27
-
28
- # ── Project imports ──────────────────────────────────────────────────────
29
- PROJECT_ROOT = Path.home() / "climate-risk-engine"
30
- sys.path.insert(0, str(PROJECT_ROOT))
31
-
32
- from config import ZONE_MAP, ZONES
33
- from src.indexing.heat_index import calculate_wbgt
34
- from src.downscaling.uhi_model import UHI_RANGES
35
-
36
- # ── Constants ────────────────────────────────────────────────────────────
37
- # Hot season: December through March (Dar es Salaam)
38
- HOT_SEASON_MONTHS = {12, 1, 2, 3}
39
-
40
- # Benchmark trigger thresholds (from neural_actuarial.py and CLAUDE.md)
41
- WINDOW_ENTRY_WBGT = 35.1 # approximate window-entry threshold
42
- ALERT_MIN_DAYS = 2 # consecutive days for alert tier
43
- PAYOUT_MIN_DAYS = 5 # consecutive days for payout tier
44
- PAYOUT_PEAK_WBGT = 30.7 # peak WBGT for full payout qualification
45
-
46
- # Climate history window fed to the pricer
47
- HISTORY_DAYS = 90
48
-
49
- # Default payout per event per worker (USD)
50
- PAYOUT_PER_EVENT = 10.0
51
-
52
-
53
- # ── Helpers ──────────────────────────────────────────────────────────────
54
-
55
- def load_era5_data() -> dict[str, list[dict]]:
56
- """Load ERA5-Land daily records keyed by zone_id."""
57
- path = PROJECT_ROOT / "data" / "era5land_dar_es_salaam.json"
58
- with open(path) as f:
59
- return json.load(f)
60
-
61
-
62
- def build_date_index(records: list[dict]) -> dict[str, int]:
63
- """Map date string -> index in records list."""
64
- return {r["date"]: i for i, r in enumerate(records)}
65
-
66
-
67
- def season_label(year: int) -> str:
68
- """e.g. season_label(2005) -> '2005-06'"""
69
- return f"{year}-{str(year + 1)[-2:]}"
70
-
71
-
72
- def get_season_dates(start_year: int) -> tuple[date, date]:
73
- """Return (first day of Dec, last day of Mar) for a hot season."""
74
- season_start = date(start_year, 12, 1)
75
- season_end = date(start_year + 1, 3, 31)
76
- return season_start, season_end
77
-
78
-
79
- def get_history_window(
80
- records: list[dict],
81
- date_idx: dict[str, int],
82
- season_start: date,
83
- ) -> list[dict]:
84
- """Get the 90-day climate history ending just before the season start."""
85
- end = season_start - timedelta(days=1)
86
- start = end - timedelta(days=HISTORY_DAYS - 1)
87
- window = []
88
- d = start
89
- while d <= end:
90
- ds = d.isoformat()
91
- if ds in date_idx:
92
- window.append(records[date_idx[ds]])
93
- d += timedelta(days=1)
94
- return window
95
-
96
-
97
- def get_season_records(
98
- records: list[dict],
99
- date_idx: dict[str, int],
100
- season_start: date,
101
- season_end: date,
102
- ) -> list[dict]:
103
- """Get all daily records within the hot season window."""
104
- out = []
105
- d = season_start
106
- while d <= season_end:
107
- ds = d.isoformat()
108
- if ds in date_idx:
109
- out.append(records[date_idx[ds]])
110
- d += timedelta(days=1)
111
- return out
112
-
113
-
114
- def count_trigger_events(
115
- season_records: list[dict],
116
- uhi_delta: float,
117
- apply_uhi: bool = True,
118
- ) -> dict:
119
- """
120
- Count alert and payout events during a season.
121
-
122
- A "window" = consecutive days where WBGT >= WINDOW_ENTRY_WBGT.
123
- - Alert event: window duration >= ALERT_MIN_DAYS
124
- - Payout event: window duration >= PAYOUT_MIN_DAYS AND peak WBGT >= PAYOUT_PEAK_WBGT
125
-
126
- Returns counts with and without UHI correction.
127
- """
128
- results = {}
129
- for label, use_uhi in [("uhi", True), ("grid", False)]:
130
- if label == "uhi" and not apply_uhi:
131
- continue
132
-
133
- delta = uhi_delta if use_uhi else 0.0
134
- wbgts = []
135
- for rec in season_records:
136
- t_max = (rec.get("temp_max_c") or 30.0) + delta
137
- hum = rec.get("humidity_pct") or 75.0
138
- wbgts.append(calculate_wbgt(t_max, hum))
139
-
140
- # Find consecutive windows above threshold
141
- alert_events = 0
142
- payout_events = 0
143
- run_length = 0
144
- run_peak = 0.0
145
-
146
- for w in wbgts:
147
- if w >= WINDOW_ENTRY_WBGT:
148
- run_length += 1
149
- run_peak = max(run_peak, w)
150
- else:
151
- if run_length >= ALERT_MIN_DAYS:
152
- alert_events += 1
153
- if run_length >= PAYOUT_MIN_DAYS and run_peak >= PAYOUT_PEAK_WBGT:
154
- payout_events += 1
155
- run_length = 0
156
- run_peak = 0.0
157
- # Close trailing run
158
- if run_length >= ALERT_MIN_DAYS:
159
- alert_events += 1
160
- if run_length >= PAYOUT_MIN_DAYS and run_peak >= PAYOUT_PEAK_WBGT:
161
- payout_events += 1
162
-
163
- results[label] = {
164
- "alert_events": alert_events,
165
- "payout_events": payout_events,
166
- "mean_wbgt": round(float(np.mean(wbgts)), 2) if wbgts else 0.0,
167
- "max_wbgt": round(float(np.max(wbgts)), 2) if wbgts else 0.0,
168
- "days_above_threshold": sum(1 for w in wbgts if w >= WINDOW_ENTRY_WBGT),
169
- }
170
-
171
- return results
172
-
173
-
174
- # ── Main backtest ────────────────────────────────────────────────────────
175
-
176
- def run_backtest(
177
- zone_ids: list[str] | None = None,
178
- apply_uhi: bool = True,
179
- verbose: bool = True,
180
- ) -> list[dict]:
181
- """
182
- Run the backtest across seasons and zones.
183
-
184
- Returns list of row dicts with season, zone, predicted, and actual counts.
185
- """
186
- # Load pricer (neural if available, else GLM fallback)
187
- try:
188
- from src.pricing.neural_actuarial import NeuralActuarialPricer
189
- pricer = NeuralActuarialPricer()
190
- pricer_type = "chronos" if pricer._encoder_type == "chronos" else (
191
- "lstm" if pricer._encoder_type == "lstm" else "glm"
192
- )
193
- except Exception as e:
194
- print(f"[warn] Neural pricer unavailable ({e}), using GLM fallback")
195
- from src.pricing.actuarial import ActuarialPricer as _AP
196
- pricer = _AP()
197
- pricer_type = "glm"
198
-
199
- if verbose:
200
- print(f"Pricer: {pricer_type}")
201
- print(f"UHI correction: {'ON' if apply_uhi else 'OFF'}")
202
- print()
203
-
204
- # Load ERA5 data
205
- era5 = load_era5_data()
206
-
207
- # Resolve zone list
208
- dar_zones = [z for z in ZONES if z.city == "Dar es Salaam"]
209
- if zone_ids:
210
- dar_zones = [z for z in dar_zones if z.zone_id in zone_ids]
211
- if not dar_zones:
212
- print("No matching zones found.")
213
- return []
214
-
215
- if verbose:
216
- print(f"Zones: {[z.zone_id for z in dar_zones]}")
217
- print()
218
-
219
- # Seasons: Dec 2005 through Mar 2024 -> start years 2005..2023
220
- season_years = list(range(2005, 2024))
221
- rows = []
222
-
223
- for zone in dar_zones:
224
- records = era5.get(zone.zone_id, [])
225
- if not records:
226
- print(f"[warn] No ERA5 data for {zone.zone_id}, skipping")
227
- continue
228
-
229
- date_idx = build_date_index(records)
230
- uhi_lo, uhi_hi = UHI_RANGES.get(zone.settlement_type, (1.0, 2.0))
231
- mean_uhi = (uhi_lo + uhi_hi) / 2.0
232
-
233
- for sy in season_years:
234
- season_start, season_end = get_season_dates(sy)
235
-
236
- # 1. Get 90-day history ending before season start
237
- history = get_history_window(records, date_idx, season_start)
238
- if len(history) < 30:
239
- continue # not enough history
240
-
241
- # 2. Run pricer to get predicted frequency (lambda)
242
- predicted_freq = None
243
- try:
244
- if pricer_type == "glm":
245
- # GLM needs a frequency estimate -- use historical rate
246
- # from the previous season as a naive baseline
247
- result = pricer.price_zone(
248
- zone=zone,
249
- predicted_frequency=10.0, # placeholder
250
- basis_risk_score=0.3,
251
- payout_per_event=PAYOUT_PER_EVENT,
252
- enrolled=zone.worker_population_est,
253
- )
254
- predicted_freq = 10.0 # GLM doesn't learn frequency
255
- else:
256
- result = pricer.price_zone(
257
- zone=zone,
258
- predicted_frequency=10.0,
259
- basis_risk_score=0.3,
260
- payout_per_event=PAYOUT_PER_EVENT,
261
- enrolled=zone.worker_population_est,
262
- climate_history=history,
263
- )
264
- predicted_freq = result.cost_breakdown.get(
265
- "learned_frequency", 10.0
266
- )
267
- except Exception as e:
268
- if verbose:
269
- print(f" [warn] Pricer failed for {zone.zone_id} "
270
- f"season {season_label(sy)}: {e}")
271
- continue
272
-
273
- # 3. Count actual events during the season
274
- season_recs = get_season_records(
275
- records, date_idx, season_start, season_end
276
- )
277
- if not season_recs:
278
- continue
279
-
280
- actuals = count_trigger_events(
281
- season_recs, mean_uhi, apply_uhi=apply_uhi
282
- )
283
-
284
- # Build row -- use UHI-corrected actuals for primary comparison
285
- # if UHI is on, otherwise use grid actuals
286
- primary = actuals.get("uhi", actuals.get("grid", {}))
287
- grid = actuals.get("grid", {})
288
-
289
- row = {
290
- "season": season_label(sy),
291
- "zone_id": zone.zone_id,
292
- "zone_name": zone.name,
293
- "settlement_type": zone.settlement_type,
294
- "predicted_freq": round(predicted_freq, 1),
295
- "actual_alert_uhi": primary.get("alert_events", 0),
296
- "actual_payout_uhi": primary.get("payout_events", 0),
297
- "actual_alert_grid": grid.get("alert_events", 0),
298
- "actual_payout_grid": grid.get("payout_events", 0),
299
- "mean_wbgt_uhi": primary.get("mean_wbgt", 0),
300
- "max_wbgt_uhi": primary.get("max_wbgt", 0),
301
- "days_above_uhi": primary.get("days_above_threshold", 0),
302
- "mean_wbgt_grid": grid.get("mean_wbgt", 0),
303
- "max_wbgt_grid": grid.get("max_wbgt", 0),
304
- "days_above_grid": grid.get("days_above_threshold", 0),
305
- "season_days": len(season_recs),
306
- }
307
- rows.append(row)
308
-
309
- return rows
310
-
311
-
312
- def print_summary(rows: list[dict], apply_uhi: bool = True) -> None:
313
- """Print formatted summary tables."""
314
- if not rows:
315
- print("No backtest results to display.")
316
- return
317
-
318
- # ── Per-season detail table ──────────────────────────────────────
319
- zones_in_data = sorted(set(r["zone_id"] for r in rows))
320
- single_zone = len(zones_in_data) == 1
321
-
322
- print("=" * 95)
323
- print("ACTUARIAL BACKTEST: Predicted Frequency vs. Actual Trigger Events")
324
- print("=" * 95)
325
- print()
326
-
327
- if single_zone:
328
- zone_id = zones_in_data[0]
329
- zone_name = rows[0]["zone_name"]
330
- print(f"Zone: {zone_id} ({zone_name})")
331
- print(f"Threshold: WBGT >= {WINDOW_ENTRY_WBGT} C (window entry)")
332
- print(f"Alert: >= {ALERT_MIN_DAYS} consecutive days | "
333
- f"Payout: >= {PAYOUT_MIN_DAYS} days AND peak WBGT >= {PAYOUT_PEAK_WBGT}")
334
- print()
335
-
336
- header = (
337
- f"{'Season':<10} {'Pred_Freq':>9} "
338
- f"{'Alert_UHI':>9} {'Payout_UHI':>10} "
339
- f"{'Alert_Grid':>10} {'Payout_Grid':>11} "
340
- f"{'MeanWBGT':>8} {'MaxWBGT':>7} {'Days>Thr':>8}"
341
- )
342
- print(header)
343
- print("-" * len(header))
344
-
345
- for r in sorted(rows, key=lambda x: x["season"]):
346
- print(
347
- f"{r['season']:<10} {r['predicted_freq']:>9.1f} "
348
- f"{r['actual_alert_uhi']:>9} {r['actual_payout_uhi']:>10} "
349
- f"{r['actual_alert_grid']:>10} {r['actual_payout_grid']:>11} "
350
- f"{r['mean_wbgt_uhi']:>8.1f} {r['max_wbgt_uhi']:>7.1f} "
351
- f"{r['days_above_uhi']:>8}"
352
- )
353
- else:
354
- # Multi-zone: aggregate by zone
355
- print(f"Zones: {len(zones_in_data)} | Seasons: "
356
- f"{sorted(set(r['season'] for r in rows))[0]} to "
357
- f"{sorted(set(r['season'] for r in rows))[-1]}")
358
- print(f"Threshold: WBGT >= {WINDOW_ENTRY_WBGT} C")
359
- print()
360
-
361
- header = (
362
- f"{'Zone':<10} {'Type':<11} {'Seasons':>7} "
363
- f"{'MeanPred':>8} {'MeanAlertU':>10} {'MeanPayU':>8} "
364
- f"{'MeanAlertG':>10} {'MeanPayG':>8} {'MeanWBGT':>8}"
365
- )
366
- print(header)
367
- print("-" * len(header))
368
-
369
- for zid in zones_in_data:
370
- zrows = [r for r in rows if r["zone_id"] == zid]
371
- n = len(zrows)
372
- print(
373
- f"{zid:<10} {zrows[0]['settlement_type']:<11} {n:>7} "
374
- f"{np.mean([r['predicted_freq'] for r in zrows]):>8.1f} "
375
- f"{np.mean([r['actual_alert_uhi'] for r in zrows]):>10.1f} "
376
- f"{np.mean([r['actual_payout_uhi'] for r in zrows]):>8.1f} "
377
- f"{np.mean([r['actual_alert_grid'] for r in zrows]):>10.1f} "
378
- f"{np.mean([r['actual_payout_grid'] for r in zrows]):>8.1f} "
379
- f"{np.mean([r['mean_wbgt_uhi'] for r in zrows]):>8.1f}"
380
- )
381
-
382
- # ── Aggregate metrics ────────────────────────────────────────────
383
- print()
384
- print("=" * 95)
385
- print("AGGREGATE METRICS")
386
- print("=" * 95)
387
- print()
388
-
389
- pred = np.array([r["predicted_freq"] for r in rows])
390
- # Annualize actuals: season is ~121 days (Dec-Mar), so scale up
391
- # But predicted_freq is already annual. We compare predicted annual
392
- # to actual season counts directly -- the model should predict events
393
- # per year, but events in Dec-Mar ARE the bulk of the hot season.
394
- # So the comparison is: does the model's annual lambda match the
395
- # observed rate when we look at the actual hot season?
396
-
397
- for label_suffix, alert_key, payout_key in [
398
- ("(UHI-corrected)", "actual_alert_uhi", "actual_payout_uhi"),
399
- ("(grid / no UHI)", "actual_alert_grid", "actual_payout_grid"),
400
- ]:
401
- if not apply_uhi and label_suffix == "(UHI-corrected)":
402
- continue
403
-
404
- actual_alert = np.array([r[alert_key] for r in rows], dtype=float)
405
- actual_payout = np.array([r[payout_key] for r in rows], dtype=float)
406
-
407
- print(f"--- {label_suffix} ---")
408
- print(f" Mean predicted frequency (annual lambda): {pred.mean():.2f}")
409
- print(f" Mean actual alert events per season: {actual_alert.mean():.2f}")
410
- print(f" Mean actual payout events per season: {actual_payout.mean():.2f}")
411
- print()
412
- print(f" Predicted/Actual alert ratio: "
413
- f"{pred.mean() / max(actual_alert.mean(), 0.01):.2f}x")
414
- print(f" Predicted/Actual payout ratio: "
415
- f"{pred.mean() / max(actual_payout.mean(), 0.01):.2f}x")
416
-
417
- # Correlation (only meaningful if there's variance)
418
- if actual_alert.std() > 0 and pred.std() > 0:
419
- corr_alert = float(np.corrcoef(pred, actual_alert)[0, 1])
420
- print(f" Pearson correlation (pred vs alert): {corr_alert:.3f}")
421
- else:
422
- print(f" Pearson correlation (pred vs alert): N/A (no variance)")
423
-
424
- if actual_payout.std() > 0 and pred.std() > 0:
425
- corr_payout = float(np.corrcoef(pred, actual_payout)[0, 1])
426
- print(f" Pearson correlation (pred vs payout): {corr_payout:.3f}")
427
- else:
428
- print(f" Pearson correlation (pred vs payout): N/A (no variance)")
429
-
430
- # RMSE
431
- rmse_alert = float(np.sqrt(np.mean((pred - actual_alert) ** 2)))
432
- rmse_payout = float(np.sqrt(np.mean((pred - actual_payout) ** 2)))
433
- print(f" RMSE (pred vs alert): {rmse_alert:.2f}")
434
- print(f" RMSE (pred vs payout): {rmse_payout:.2f}")
435
-
436
- # Per-season hit rate: how often does predicted > 0 match actual > 0
437
- pred_positive = pred > 0.5
438
- actual_alert_positive = actual_alert > 0
439
- actual_payout_positive = actual_payout > 0
440
- if len(pred) > 0:
441
- alert_hit = float(
442
- np.mean(pred_positive == actual_alert_positive) * 100
443
- )
444
- payout_hit = float(
445
- np.mean(pred_positive == actual_payout_positive) * 100
446
- )
447
- print(f" Direction accuracy (alert): {alert_hit:.0f}%")
448
- print(f" Direction accuracy (payout): {payout_hit:.0f}%")
449
- print()
450
-
451
- # ── By settlement type ───────────────────────────────────────────
452
- stypes = sorted(set(r["settlement_type"] for r in rows))
453
- if len(stypes) > 1:
454
- print("--- By settlement type ---")
455
- for st in stypes:
456
- st_rows = [r for r in rows if r["settlement_type"] == st]
457
- st_pred = np.mean([r["predicted_freq"] for r in st_rows])
458
- st_alert = np.mean([r["actual_alert_uhi"] for r in st_rows])
459
- st_payout = np.mean([r["actual_payout_uhi"] for r in st_rows])
460
- print(f" {st:<12} pred={st_pred:.1f} alert={st_alert:.1f} "
461
- f"payout={st_payout:.1f} "
462
- f"ratio={st_pred / max(st_alert, 0.01):.1f}x")
463
- print()
464
-
465
-
466
- def main():
467
- parser = argparse.ArgumentParser(
468
- description="Backtest pricing model against historical ERA5-Land data"
469
- )
470
- parser.add_argument(
471
- "--zone", type=str, default="DAR-JAN",
472
- help="Zone ID to test (default: DAR-JAN)"
473
- )
474
- parser.add_argument(
475
- "--all-zones", action="store_true",
476
- help="Run backtest across all 15 Dar es Salaam zones"
477
- )
478
- parser.add_argument(
479
- "--no-uhi", action="store_true",
480
- help="Disable UHI correction (grid-only WBGT)"
481
- )
482
- parser.add_argument(
483
- "--quiet", action="store_true",
484
- help="Suppress per-season detail, show only aggregate metrics"
485
- )
486
- args = parser.parse_args()
487
-
488
- apply_uhi = not args.no_uhi
489
-
490
- if args.all_zones:
491
- zone_ids = None # all Dar zones
492
- else:
493
- zone_ids = [args.zone]
494
-
495
- rows = run_backtest(
496
- zone_ids=zone_ids,
497
- apply_uhi=apply_uhi,
498
- verbose=not args.quiet,
499
- )
500
-
501
- if not args.quiet:
502
- print()
503
-
504
- print_summary(rows, apply_uhi=apply_uhi)
505
-
506
-
507
- if __name__ == "__main__":
508
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train_lstm.py DELETED
@@ -1,58 +0,0 @@
1
- """Train the LSTM heat wave predictor.
2
-
3
- Tries ERA5 data first, falls back to synthetic data generation
4
- (same seasonal + AR(1) approach as the existing XGBoost trainer).
5
-
6
- Usage:
7
- python3 scripts/train_lstm.py
8
- """
9
-
10
- import sys
11
- import time
12
-
13
- sys.path.insert(0, ".")
14
-
15
- from src.prediction.lstm_model import LSTMTrainer, generate_synthetic_zone_data
16
- from config import ZONES
17
-
18
- print("=" * 60)
19
- print("LSTM Heat Wave Predictor -- Training")
20
- print("=" * 60)
21
-
22
- # Try ERA5 data first, fall back to synthetic
23
- zone_data = None
24
-
25
- try:
26
- from src.ingestion.era5_fetcher import fetch_era5_sync
27
-
28
- print("\nFetching ERA5 data for training...")
29
- raw = fetch_era5_sync(ZONES, days_back=365)
30
- # Convert to training format if fetch succeeds
31
- if raw and len(raw) > 0:
32
- zone_data = raw
33
- print(f" Loaded ERA5 data for {len(zone_data)} zones")
34
- except Exception as e:
35
- print(f"\nERA5 unavailable ({e}), using synthetic training data")
36
-
37
- if zone_data is None:
38
- print("\nGenerating synthetic training data (2 years x 20 zones)...")
39
- t0 = time.time()
40
- zone_data = generate_synthetic_zone_data(ZONES, n_days=730, seed=42)
41
- elapsed = time.time() - t0
42
- total_days = sum(len(v) for v in zone_data.values())
43
- print(f" Generated {total_days:,} zone-days in {elapsed:.1f}s")
44
-
45
- # Train
46
- print("\nTraining LSTM...")
47
- t0 = time.time()
48
- trainer = LSTMTrainer(epochs=50, patience=5)
49
- metrics = trainer.train(zone_data)
50
- elapsed = time.time() - t0
51
-
52
- print(f"\nTraining complete in {elapsed:.1f}s")
53
- print(f" Epochs trained: {metrics.get('epochs_trained', '?')}")
54
- print(f" Train loss: {metrics.get('train_loss', '?')}")
55
- print(f" Val loss: {metrics.get('val_loss', '?')}")
56
- print(f" Val AUROC: {metrics.get('val_auroc', '?')}")
57
- print(f" Samples: {metrics.get('samples', '?')}")
58
- print("=" * 60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train_neural_pricer.py DELETED
@@ -1,142 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Train the Neural Actuarial Pricing Engine on real climate data.
4
-
5
- Usage:
6
- python3 scripts/train_neural_pricer.py # default: config.PRIMARY_CITY
7
- python3 scripts/train_neural_pricer.py --city Kampala # Other city
8
-
9
- Requires: data/era5land_{slug}.json where slug is derived from the primary city
10
- (config.PRIMARY_CITY) or the --city flag.
11
- """
12
-
13
- import argparse
14
- import json
15
- import sys
16
- from pathlib import Path
17
-
18
- # Add project root to path
19
- sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
20
-
21
- from config import ZONES, ZONE_MAP, slug_for
22
- from src.pricing.neural_actuarial import (
23
- build_training_samples,
24
- load_climate_data,
25
- NeuralPricerTrainer,
26
- )
27
-
28
-
29
- def main():
30
- parser = argparse.ArgumentParser(description="Train neural actuarial pricer")
31
- parser.add_argument("--city", default="Dar es Salaam", help="City to train for")
32
- parser.add_argument("--data", default=None, help="Path to climate data JSON")
33
- parser.add_argument("--epochs", type=int, default=80)
34
- parser.add_argument("--lr", type=float, default=1e-3)
35
- parser.add_argument("--patience", type=int, default=10)
36
- parser.add_argument("--encoder", choices=["chronos", "lstm"], default="chronos",
37
- help="Encoder type: chronos (foundation model) or lstm (legacy)")
38
- args = parser.parse_args()
39
-
40
- data_dir = Path(__file__).resolve().parents[1] / "data"
41
-
42
- # Find climate data file
43
- if args.data:
44
- data_path = Path(args.data)
45
- else:
46
- city_slug = slug_for(args.city)
47
- # Try ERA5-Land first, fall back to NASA POWER
48
- era5_path = data_dir / f"era5land_{city_slug}.json"
49
- nasa_path = data_dir / f"nasa_power_{city_slug}.json"
50
- if era5_path.exists():
51
- data_path = era5_path
52
- print(f"Using ERA5-Land data: {data_path}")
53
- elif nasa_path.exists():
54
- data_path = nasa_path
55
- print(f"Using NASA POWER data: {data_path}")
56
- else:
57
- print(f"ERROR: No climate data found for {args.city}")
58
- print(f" Looked for: {era5_path}")
59
- print(f" Looked for: {nasa_path}")
60
- print(f" Run scripts/fetch_nasa_power_dar.py first")
61
- sys.exit(1)
62
-
63
- # Filter zones for the target city
64
- city_zones = [z for z in ZONES if z.city == args.city]
65
- if not city_zones:
66
- print(f"ERROR: No zones found for city '{args.city}'")
67
- print(f" Available cities: {sorted(set(z.city for z in ZONES))}")
68
- sys.exit(1)
69
-
70
- print(f"\n{'='*60}")
71
- print(f"Neural Actuarial Pricer — {args.city}")
72
- print(f"{'='*60}")
73
- print(f" Zones: {len(city_zones)} ({', '.join(z.name for z in city_zones)})")
74
- print(f" Data: {data_path}")
75
-
76
- # Load climate data
77
- climate_data = load_climate_data(data_path)
78
- zone_ids = {z.zone_id for z in city_zones}
79
- climate_data = {k: v for k, v in climate_data.items() if k in zone_ids}
80
-
81
- if not climate_data:
82
- print(f"ERROR: No matching zone data found in {data_path}")
83
- print(f" Expected zone IDs: {zone_ids}")
84
- print(f" Found zone IDs: {set(load_climate_data(data_path).keys())}")
85
- sys.exit(1)
86
-
87
- print(f" Records: {sum(len(v) for v in climate_data.values()):,} zone-days")
88
-
89
- # Build training samples with city-specific WBGT threshold
90
- from src.pricing.neural_actuarial import CITY_WBGT_THRESHOLDS
91
- wbgt_thresh = CITY_WBGT_THRESHOLDS.get(args.city, 35.0)
92
- print(f"\nBuilding training samples (90-day windows, stride=7, WBGT threshold={wbgt_thresh}°C)...")
93
- X, targets = build_training_samples(climate_data, city_zones, wbgt_threshold=wbgt_thresh)
94
- print(f" Samples: {len(X):,}")
95
- print(f" Shape: {X.shape}")
96
- print(f" Target ranges:")
97
- for k, v in targets.items():
98
- print(f" {k}: [{v.min():.3f}, {v.max():.3f}], mean={v.mean():.3f}")
99
-
100
- # Train
101
- print(f"\nTraining (encoder={args.encoder}, epochs={args.epochs}, lr={args.lr}, patience={args.patience})...")
102
- trainer = NeuralPricerTrainer(
103
- lr=args.lr, epochs=args.epochs, patience=args.patience,
104
- encoder=args.encoder,
105
- )
106
- metrics = trainer.train(X, targets)
107
-
108
- print(f"\n{'='*60}")
109
- print(f"Training complete")
110
- print(f"{'='*60}")
111
- for k, v in metrics.items():
112
- print(f" {k}: {v}")
113
-
114
- # Quick inference test
115
- print(f"\nInference test:")
116
- from src.pricing.neural_actuarial import NeuralActuarialPricer
117
- pricer = NeuralActuarialPricer()
118
- print(f" Neural model loaded: {pricer.is_neural}")
119
-
120
- if pricer.is_neural:
121
- for zone in city_zones:
122
- history = climate_data[zone.zone_id][-90:]
123
- result = pricer.price_zone(
124
- zone=zone,
125
- predicted_frequency=10.0, # ignored by neural model
126
- basis_risk_score=0.2, # ignored by neural model
127
- payout_per_event=10.0,
128
- enrolled=zone.worker_population_est,
129
- climate_history=history,
130
- )
131
- cb = result.cost_breakdown
132
- print(
133
- f" {zone.name:12s} ({zone.settlement_type:9s}): "
134
- f"${result.cost_per_worker_year:.2f}/worker/yr "
135
- f"(λ={cb.get('learned_frequency', '?')}, "
136
- f"basis_risk={cb.get('learned_basis_risk', '?')}, "
137
- f"δ_NN={cb.get('neural_correction_pct', '?'):+.1f}%)"
138
- )
139
-
140
-
141
- if __name__ == "__main__":
142
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train_on_era5.py DELETED
@@ -1,491 +0,0 @@
1
- """
2
- Train all ML models on real ERA5 reanalysis data.
3
-
4
- Steps:
5
- 1. Fetch 2 years of ERA5 data for all 20 zones via Google ARCO Zarr store
6
- 2. Validate data quality (coverage, temp ranges, nulls)
7
- 3. Retrain XGBoost heat predictor on real data
8
- 4. Retrain LSTM on real data
9
- 5. Verify UHI model works with real ERA5 temps
10
- """
11
-
12
- import sys
13
- import os
14
- import time
15
- import logging
16
- import math
17
-
18
- import numpy as np
19
-
20
- # Project root on sys.path
21
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
-
23
- from config import ZONES, ZONE_MAP
24
- from src.ingestion.era5_fetcher import fetch_era5_sync
25
- from src.ingestion.models import DailyReading
26
- from src.indexing.heat_index import calculate_wbgt
27
-
28
- logging.basicConfig(
29
- level=logging.INFO,
30
- format="%(asctime)s %(name)s %(levelname)s %(message)s",
31
- datefmt="%H:%M:%S",
32
- )
33
- log = logging.getLogger("train_era5")
34
-
35
- # Expected temp ranges per city (max daily temps, deg C)
36
- EXPECTED_RANGES = {
37
- "Nairobi": (18, 35),
38
- "Dar es Salaam": (25, 40),
39
- "Kampala": (22, 36),
40
- "Kigali": (20, 34),
41
- }
42
-
43
-
44
- # ======================================================================
45
- # Step 1: Fetch ERA5 data
46
- # ======================================================================
47
-
48
- def fetch_data():
49
- log.info("=" * 60)
50
- log.info("STEP 1: Fetching 2 years of ERA5 data for %d zones", len(ZONES))
51
- log.info("=" * 60)
52
- t0 = time.time()
53
- data = fetch_era5_sync(ZONES, days_back=730)
54
- elapsed = time.time() - t0
55
- log.info("Fetch complete in %.1f seconds", elapsed)
56
- return data
57
-
58
-
59
- # ======================================================================
60
- # Step 2: Validate data quality
61
- # ======================================================================
62
-
63
- def validate_data(data: dict[str, list[DailyReading]]):
64
- log.info("=" * 60)
65
- log.info("STEP 2: Validating ERA5 data quality")
66
- log.info("=" * 60)
67
-
68
- issues = []
69
- stats = {}
70
-
71
- for zone in ZONES:
72
- zid = zone.zone_id
73
- readings = data.get(zid, [])
74
-
75
- if not readings:
76
- issues.append(f"{zid}: NO DATA")
77
- stats[zid] = {"days": 0, "issue": "no data"}
78
- continue
79
-
80
- temps = [r.temp_max_c for r in readings if r.temp_max_c is not None]
81
- humids = [r.humidity_pct for r in readings if r.humidity_pct is not None]
82
- winds = [r.wind_speed_ms for r in readings if r.wind_speed_ms is not None]
83
-
84
- if not temps:
85
- issues.append(f"{zid}: all temps are null")
86
- stats[zid] = {"days": len(readings), "issue": "all null temps"}
87
- continue
88
-
89
- t_min, t_max = min(temps), max(temps)
90
- t_mean = sum(temps) / len(temps)
91
-
92
- # Check physical reasonableness
93
- exp_lo, exp_hi = EXPECTED_RANGES.get(zone.city, (15, 42))
94
- if t_min < exp_lo - 5 or t_max > exp_hi + 5:
95
- issues.append(
96
- f"{zid} ({zone.city}): temp range [{t_min:.1f}, {t_max:.1f}] "
97
- f"outside expected [{exp_lo-5}, {exp_hi+5}]"
98
- )
99
-
100
- null_count = sum(1 for r in readings if r.temp_max_c is None)
101
-
102
- stats[zid] = {
103
- "days": len(readings),
104
- "temp_days": len(temps),
105
- "temp_min": round(t_min, 1),
106
- "temp_max": round(t_max, 1),
107
- "temp_mean": round(t_mean, 1),
108
- "humidity_mean": round(sum(humids)/len(humids), 1) if humids else None,
109
- "wind_mean": round(sum(winds)/len(winds), 1) if winds else None,
110
- "null_temps": null_count,
111
- }
112
-
113
- # Print summary
114
- print("\n--- ERA5 Data Summary ---")
115
- print(f"{'Zone':<12} {'City':<16} {'Days':>5} {'Temp min':>9} {'Temp max':>9} {'Temp mean':>10} {'Humidity':>9} {'Nulls':>6}")
116
- print("-" * 90)
117
-
118
- by_city = {}
119
- for zone in ZONES:
120
- s = stats.get(zone.zone_id, {})
121
- days = s.get("days", 0)
122
- t_lo = s.get("temp_min", "N/A")
123
- t_hi = s.get("temp_max", "N/A")
124
- t_mn = s.get("temp_mean", "N/A")
125
- hum = s.get("humidity_mean", "N/A")
126
- nulls = s.get("null_temps", "N/A")
127
- print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {t_lo:>9} {t_hi:>9} {t_mn:>10} {hum:>9} {nulls:>6}")
128
-
129
- city = zone.city
130
- if city not in by_city:
131
- by_city[city] = []
132
- by_city[city].append(s)
133
-
134
- print("\n--- Per-city aggregated temp ranges ---")
135
- for city, zone_stats in by_city.items():
136
- all_mins = [s["temp_min"] for s in zone_stats if s.get("temp_min") is not None]
137
- all_maxs = [s["temp_max"] for s in zone_stats if s.get("temp_max") is not None]
138
- if all_mins and all_maxs:
139
- print(f" {city:<16}: {min(all_mins):.1f} - {max(all_maxs):.1f} C")
140
-
141
- if issues:
142
- print(f"\n ISSUES ({len(issues)}):")
143
- for issue in issues:
144
- print(f" - {issue}")
145
- else:
146
- print("\n No data quality issues found.")
147
-
148
- zones_with_data = sum(1 for s in stats.values() if s.get("days", 0) > 0)
149
- assert zones_with_data == len(ZONES), f"Only {zones_with_data}/{len(ZONES)} zones have data"
150
- print(f"\n All {zones_with_data} zones have data.\n")
151
-
152
- return stats
153
-
154
-
155
- # ======================================================================
156
- # Step 3: Retrain XGBoost heat predictor on real data
157
- # ======================================================================
158
-
159
- def retrain_xgboost(data: dict[str, list[DailyReading]]):
160
- log.info("=" * 60)
161
- log.info("STEP 3: Retraining XGBoost heat predictor on real ERA5 data")
162
- log.info("=" * 60)
163
-
164
- from src.prediction.heat_forecast import HeatWavePredictor, CITY_THRESHOLDS, CITY_CLIMATE
165
- from src.prediction.lstm_model import CITY_CLIMATE as _ # ensure import works
166
-
167
- import xgboost as xgb
168
-
169
- # We replicate the training logic from HeatWavePredictor.train() but
170
- # use real ERA5 temps/humidity instead of synthetic series.
171
-
172
- all_X = []
173
- all_y = []
174
-
175
- for zone in ZONES:
176
- zid = zone.zone_id
177
- readings = data.get(zid, [])
178
- if len(readings) < 40:
179
- log.warning("Zone %s has only %d readings, skipping for XGBoost training", zid, len(readings))
180
- continue
181
-
182
- city = zone.city
183
- threshold = CITY_THRESHOLDS.get(city, 33.0)
184
-
185
- # Extract time series from real data
186
- temps = []
187
- humidity = []
188
- for r in readings:
189
- t = r.temp_max_c
190
- h = r.humidity_pct
191
- if t is None:
192
- continue
193
- temps.append(t)
194
- humidity.append(h if h is not None else 65.0)
195
-
196
- n_days = len(temps)
197
- if n_days < 40:
198
- log.warning("Zone %s has only %d valid temp readings, skipping", zid, n_days)
199
- continue
200
-
201
- # Compute WBGT series
202
- wbgt_series = [calculate_wbgt(t, h) for t, h in zip(temps, humidity)]
203
-
204
- # Labels: trigger within next 7 days (2+ consecutive above threshold)
205
- labels = [0] * n_days
206
- for day in range(n_days - 7):
207
- window = temps[day + 1:day + 8]
208
- consec = 0
209
- triggered = False
210
- for t in window:
211
- if t > threshold:
212
- consec += 1
213
- if consec >= 2:
214
- triggered = True
215
- break
216
- else:
217
- consec = 0
218
- labels[day] = 1 if triggered else 0
219
-
220
- # Vulnerability encoding
221
- vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
222
- zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
223
-
224
- rng = np.random.default_rng(42)
225
-
226
- # Build features (need 30-day lookback)
227
- for day in range(30, n_days - 7):
228
- t_window = temps[day - 30:day + 1]
229
- h_window = humidity[day - 30:day + 1]
230
- w_window = wbgt_series[day - 30:day + 1]
231
-
232
- current_temp = t_window[-1]
233
- current_wbgt = w_window[-1]
234
- current_humidity = h_window[-1]
235
-
236
- # Trend: slope of last 7 days
237
- x7 = np.arange(7, dtype=np.float64)
238
- y7 = np.array(t_window[-7:], dtype=np.float64)
239
- temp_trend = float(np.polyfit(x7, y7, 1)[0])
240
-
241
- # Anomaly: current vs 30-day mean
242
- temp_anomaly = current_temp - float(np.mean(t_window))
243
-
244
- # Soil moisture proxy
245
- soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
246
-
247
- # Rolling error (use neutral prior for training data)
248
- rolling_err = rng.uniform(0.1, 0.5)
249
-
250
- # Day-of-year encoding (use day index within 365-day cycle)
251
- doy = day % 365
252
- doy_sin = np.sin(2 * np.pi * doy / 365.0)
253
- doy_cos = np.cos(2 * np.pi * doy / 365.0)
254
-
255
- # Random hour for variety
256
- hour = rng.integers(6, 19)
257
- hour_sin = np.sin(2 * np.pi * hour / 24.0)
258
- hour_cos = np.cos(2 * np.pi * hour / 24.0)
259
-
260
- row = [
261
- current_temp,
262
- current_wbgt,
263
- current_humidity,
264
- temp_trend,
265
- temp_anomaly,
266
- soil_proxy,
267
- rolling_err,
268
- doy_sin,
269
- doy_cos,
270
- hour_sin,
271
- hour_cos,
272
- zone_vuln,
273
- ]
274
-
275
- all_X.append(row)
276
- all_y.append(labels[day])
277
-
278
- X = np.array(all_X, dtype=np.float32)
279
- y = np.array(all_y, dtype=np.int32)
280
-
281
- pos_rate = y.sum() / len(y) if len(y) > 0 else 0
282
- log.info(
283
- "XGBoost training data: %d samples, %.1f%% positive rate",
284
- len(X), pos_rate * 100,
285
- )
286
-
287
- # Create a fresh predictor to get the model object, then retrain
288
- predictor = HeatWavePredictor.__new__(HeatWavePredictor)
289
- predictor.model_path = HeatWavePredictor.__init__.__defaults__[0] # fallback
290
- from pathlib import Path
291
- predictor.model_path = Path(__file__).resolve().parents[1] / "models" / "heat_predictor_xgb.json"
292
- predictor._rolling_errors = []
293
-
294
- model = xgb.XGBClassifier(
295
- n_estimators=150,
296
- max_depth=5,
297
- learning_rate=0.1,
298
- eval_metric="logloss",
299
- random_state=42,
300
- )
301
-
302
- # Train/validation split (temporal: first 75% train, last 25% val)
303
- split = int(len(X) * 0.75)
304
- X_train, X_val = X[:split], X[split:]
305
- y_train, y_val = y[:split], y[split:]
306
-
307
- model.fit(
308
- X_train, y_train,
309
- eval_set=[(X_val, y_val)],
310
- verbose=False,
311
- )
312
-
313
- # Evaluate on validation set
314
- from sklearn.metrics import roc_auc_score, precision_score, recall_score
315
-
316
- val_probs = model.predict_proba(X_val)[:, 1]
317
- val_preds = (val_probs > 0.5).astype(int)
318
-
319
- if len(set(y_val)) > 1:
320
- auroc = roc_auc_score(y_val, val_probs)
321
- precision = precision_score(y_val, val_preds, zero_division=0)
322
- recall = recall_score(y_val, val_preds, zero_division=0)
323
- else:
324
- auroc, precision, recall = 0.5, 0.0, 0.0
325
-
326
- print(f"\n--- XGBoost Results (real ERA5 data) ---")
327
- print(f" Training samples: {len(X_train)}")
328
- print(f" Validation samples: {len(X_val)}")
329
- print(f" Positive rate: {pos_rate:.1%}")
330
- print(f" Val AUROC: {auroc:.4f}")
331
- print(f" Val Precision: {precision:.4f}")
332
- print(f" Val Recall: {recall:.4f}")
333
-
334
- # Save model
335
- predictor.model_path.parent.mkdir(parents=True, exist_ok=True)
336
- model.save_model(str(predictor.model_path))
337
- log.info("XGBoost model saved to %s", predictor.model_path)
338
-
339
- return {
340
- "train_samples": len(X_train),
341
- "val_samples": len(X_val),
342
- "positive_rate": round(pos_rate, 4),
343
- "val_auroc": round(auroc, 4),
344
- "val_precision": round(precision, 4),
345
- "val_recall": round(recall, 4),
346
- }
347
-
348
-
349
- # ======================================================================
350
- # Step 4: Retrain LSTM on real data
351
- # ======================================================================
352
-
353
- def retrain_lstm(data: dict[str, list[DailyReading]]):
354
- log.info("=" * 60)
355
- log.info("STEP 4: Retraining LSTM on real ERA5 data")
356
- log.info("=" * 60)
357
-
358
- from src.prediction.lstm_model import LSTMTrainer
359
-
360
- # Convert ERA5 DailyReading objects into the format the LSTM trainer expects:
361
- # dict of zone_id -> list of dicts with keys: temp_max_c, humidity_pct, wind_speed_ms, city
362
- zone_readings = {}
363
- for zone in ZONES:
364
- zid = zone.zone_id
365
- readings = data.get(zid, [])
366
- days = []
367
- for r in readings:
368
- if r.temp_max_c is None:
369
- continue
370
- days.append({
371
- "temp_max_c": r.temp_max_c,
372
- "humidity_pct": r.humidity_pct if r.humidity_pct is not None else 65.0,
373
- "wind_speed_ms": r.wind_speed_ms if r.wind_speed_ms is not None else 3.0,
374
- "city": zone.city,
375
- })
376
- if len(days) > 30:
377
- zone_readings[zid] = days
378
- log.info("Zone %s: %d valid readings for LSTM", zid, len(days))
379
- else:
380
- log.warning("Zone %s: only %d valid readings, skipping LSTM", zid, len(days))
381
-
382
- log.info("Training LSTM on %d zones", len(zone_readings))
383
- trainer = LSTMTrainer(epochs=50, patience=5)
384
- metrics = trainer.train(zone_readings)
385
-
386
- print(f"\n--- LSTM Results (real ERA5 data) ---")
387
- for k, v in metrics.items():
388
- print(f" {k}: {v}")
389
-
390
- return metrics
391
-
392
-
393
- # ======================================================================
394
- # Step 5: Verify UHI model with real ERA5 temps
395
- # ======================================================================
396
-
397
- def verify_uhi(data: dict[str, list[DailyReading]]):
398
- log.info("=" * 60)
399
- log.info("STEP 5: Verifying UHI model with real ERA5 temperatures")
400
- log.info("=" * 60)
401
-
402
- from src.downscaling.uhi_model import UHICorrector
403
-
404
- corrector = UHICorrector()
405
-
406
- results = {}
407
- for zone in ZONES:
408
- zid = zone.zone_id
409
- readings = data.get(zid, [])
410
- if not readings:
411
- continue
412
-
413
- # Use real ERA5 temps as grid baseline
414
- real_temps = [r.temp_max_c for r in readings if r.temp_max_c is not None]
415
- if not real_temps:
416
- continue
417
-
418
- # Sample a few real temps and apply UHI correction
419
- sample_indices = np.linspace(0, len(real_temps) - 1, min(20, len(real_temps)), dtype=int)
420
- deltas = []
421
- corrected_temps = []
422
-
423
- for idx in sample_indices:
424
- grid_temp = real_temps[idx]
425
- corrected, delta, conf = corrector.correct_temperature(zone, grid_temp, hour=14, month=1)
426
- deltas.append(delta)
427
- corrected_temps.append(corrected)
428
-
429
- results[zid] = {
430
- "city": zone.city,
431
- "settlement": zone.settlement_type,
432
- "mean_grid_temp": round(sum(real_temps) / len(real_temps), 1),
433
- "mean_uhi_delta": round(sum(deltas) / len(deltas), 2),
434
- "mean_corrected": round(sum(corrected_temps) / len(corrected_temps), 1),
435
- }
436
-
437
- print(f"\n--- UHI Verification with Real ERA5 Temps ---")
438
- print(f"{'Zone':<12} {'City':<16} {'Type':<12} {'Grid T':>7} {'UHI +':>7} {'Corrected':>10}")
439
- print("-" * 70)
440
- for zid, r in results.items():
441
- print(
442
- f"{zid:<12} {r['city']:<16} {r['settlement']:<12} "
443
- f"{r['mean_grid_temp']:>6.1f}C {r['mean_uhi_delta']:>+6.2f}C {r['mean_corrected']:>9.1f}C"
444
- )
445
-
446
- return results
447
-
448
-
449
- # ======================================================================
450
- # Main
451
- # ======================================================================
452
-
453
- def main():
454
- t_start = time.time()
455
-
456
- # Step 1: Fetch
457
- data = fetch_data()
458
-
459
- # Step 2: Validate
460
- data_stats = validate_data(data)
461
-
462
- # Step 3: XGBoost
463
- xgb_metrics = retrain_xgboost(data)
464
-
465
- # Step 4: LSTM
466
- lstm_metrics = retrain_lstm(data)
467
-
468
- # Step 5: UHI verification
469
- uhi_results = verify_uhi(data)
470
-
471
- total_time = time.time() - t_start
472
-
473
- print("\n" + "=" * 60)
474
- print("TRAINING COMPLETE")
475
- print("=" * 60)
476
-
477
- total_days = sum(
478
- len([r for r in data.get(z.zone_id, []) if r.temp_max_c is not None])
479
- for z in ZONES
480
- )
481
- print(f" Total real data points: {total_days} zone-days across {len(ZONES)} zones")
482
- print(f" XGBoost val AUROC: {xgb_metrics['val_auroc']:.4f}")
483
- print(f" LSTM val AUROC: {lstm_metrics.get('val_auroc', 'N/A')}")
484
- print(f" LSTM epochs trained: {lstm_metrics.get('epochs_trained', 'N/A')}")
485
- print(f" LSTM final val loss: {lstm_metrics.get('val_loss', 'N/A')}")
486
- print(f" Total time: {total_time:.1f}s")
487
- print()
488
-
489
-
490
- if __name__ == "__main__":
491
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train_on_nasa_power.py DELETED
@@ -1,660 +0,0 @@
1
- """
2
- Train all ML models on real NASA POWER daily data.
3
-
4
- Steps:
5
- 1. Fetch 2 years of NASA POWER data for all 20 zones (with caching)
6
- 2. Validate data quality (coverage, temp ranges, nulls)
7
- 3. Retrain LSTM heat predictor on real data
8
- 4. Retrain XGBoost heat predictor on real data
9
- 5. Verify UHI model still works (no retraining — literature-calibrated)
10
-
11
- Usage:
12
- python3 scripts/train_on_nasa_power.py
13
- """
14
-
15
- import sys
16
- import os
17
- import json
18
- import time
19
- import logging
20
- from datetime import date, timedelta
21
- from pathlib import Path
22
-
23
- import numpy as np
24
- import httpx
25
-
26
- # Project root on sys.path
27
- PROJECT_ROOT = Path(__file__).resolve().parents[1]
28
- sys.path.insert(0, str(PROJECT_ROOT))
29
-
30
- from config import ZONES, ZONE_MAP
31
- from src.indexing.heat_index import calculate_wbgt
32
-
33
- logging.basicConfig(
34
- level=logging.INFO,
35
- format="%(asctime)s %(name)s %(levelname)s %(message)s",
36
- datefmt="%H:%M:%S",
37
- )
38
- log = logging.getLogger("train_nasa_power")
39
-
40
- # Paths
41
- CACHE_DIR = PROJECT_ROOT / "data" / "nasa_power_cache"
42
- MODELS_DIR = PROJECT_ROOT / "models"
43
-
44
- # NASA POWER config (from config.py)
45
- NASA_POWER_URL = "https://power.larc.nasa.gov/api/temporal/daily/point"
46
- NASA_POWER_PARAMS = ["T2M", "T2M_MAX", "T2M_MIN", "RH2M", "WS2M", "ALLSKY_SFC_SW_DWN"]
47
- NASA_MISSING = -999.0
48
-
49
- # Date range: 2 years ending ~yesterday (API has 2-day lag)
50
- END_DATE = date(2026, 3, 29)
51
- START_DATE = date(2024, 4, 1)
52
-
53
- # Expected temp ranges per city (max daily temps, deg C)
54
- EXPECTED_RANGES = {
55
- "Nairobi": (18, 35),
56
- "Dar es Salaam": (25, 40),
57
- "Kampala": (22, 36),
58
- "Kigali": (20, 34),
59
- }
60
-
61
-
62
- # ======================================================================
63
- # Step 1: Fetch NASA POWER data (with caching)
64
- # ======================================================================
65
-
66
- def _safe_val(val) -> float | None:
67
- """Return None for NASA's -999 missing sentinel."""
68
- if val is None:
69
- return None
70
- try:
71
- f = float(val)
72
- return None if f == NASA_MISSING else round(f, 2)
73
- except (ValueError, TypeError):
74
- return None
75
-
76
-
77
- def fetch_zone_data(zone, start: date, end: date) -> list[dict]:
78
- """Fetch NASA POWER data for a single zone, using cache if available.
79
-
80
- Returns list of dicts with keys:
81
- date, temp_mean_c, temp_max_c, temp_min_c, humidity_pct,
82
- wind_speed_ms, solar_radiation
83
- """
84
- cache_file = CACHE_DIR / f"{zone.zone_id}.json"
85
-
86
- # Check cache
87
- if cache_file.exists():
88
- with open(cache_file) as f:
89
- cached = json.load(f)
90
- # Verify cache covers our date range
91
- if (cached.get("start_date") == start.isoformat()
92
- and cached.get("end_date") == end.isoformat()
93
- and len(cached.get("readings", [])) > 0):
94
- log.info("Cache hit for zone %s (%d readings)", zone.zone_id, len(cached["readings"]))
95
- return cached["readings"]
96
-
97
- # Fetch from NASA POWER API
98
- params = {
99
- "parameters": ",".join(NASA_POWER_PARAMS),
100
- "community": "AG",
101
- "longitude": zone.longitude,
102
- "latitude": zone.latitude,
103
- "start": start.strftime("%Y%m%d"),
104
- "end": end.strftime("%Y%m%d"),
105
- "format": "JSON",
106
- }
107
-
108
- log.info("Fetching NASA POWER for zone %s (%s, %s) ...", zone.zone_id, zone.city, zone.name)
109
-
110
- max_retries = 3
111
- for attempt in range(max_retries):
112
- try:
113
- resp = httpx.get(NASA_POWER_URL, params=params, timeout=60.0)
114
-
115
- if resp.status_code == 429:
116
- wait = 15 * (attempt + 1)
117
- log.warning(" Rate limited (429), backing off %ds ...", wait)
118
- time.sleep(wait)
119
- continue
120
-
121
- resp.raise_for_status()
122
- data = resp.json()
123
- break
124
- except httpx.TimeoutException:
125
- wait = 10 * (attempt + 1)
126
- log.warning(" Timeout (attempt %d/%d), retrying in %ds ...", attempt + 1, max_retries, wait)
127
- time.sleep(wait)
128
- except httpx.HTTPStatusError as exc:
129
- if exc.response.status_code >= 500:
130
- wait = 10 * (attempt + 1)
131
- log.warning(" Server error %d (attempt %d/%d), retrying ...", exc.response.status_code, attempt + 1, max_retries)
132
- time.sleep(wait)
133
- else:
134
- log.error(" HTTP %d: %s", exc.response.status_code, exc.response.text[:200])
135
- return []
136
- else:
137
- log.error(" Failed after %d attempts for zone %s", max_retries, zone.zone_id)
138
- return []
139
-
140
- # Parse response
141
- try:
142
- props = data["properties"]["parameter"]
143
- except (KeyError, TypeError):
144
- log.error(" Unexpected response structure for zone %s", zone.zone_id)
145
- return []
146
-
147
- t2m_data = props.get("T2M", {})
148
- t2m_max_data = props.get("T2M_MAX", {})
149
- t2m_min_data = props.get("T2M_MIN", {})
150
- rh2m_data = props.get("RH2M", {})
151
- ws2m_data = props.get("WS2M", {})
152
- solar_data = props.get("ALLSKY_SFC_SW_DWN", {})
153
-
154
- all_days = sorted(t2m_data.keys())
155
- readings = []
156
-
157
- for day_str in all_days:
158
- try:
159
- formatted_date = f"{day_str[:4]}-{day_str[4:6]}-{day_str[6:8]}"
160
- except (IndexError, TypeError):
161
- continue
162
-
163
- temp_mean = _safe_val(t2m_data.get(day_str))
164
- temp_max = _safe_val(t2m_max_data.get(day_str))
165
- temp_min = _safe_val(t2m_min_data.get(day_str))
166
- humidity = _safe_val(rh2m_data.get(day_str))
167
- wind = _safe_val(ws2m_data.get(day_str))
168
- solar = _safe_val(solar_data.get(day_str))
169
-
170
- # Skip days where key fields are all missing
171
- if temp_max is None and temp_mean is None:
172
- continue
173
-
174
- readings.append({
175
- "date": formatted_date,
176
- "temp_mean_c": temp_mean,
177
- "temp_max_c": temp_max if temp_max is not None else temp_mean,
178
- "temp_min_c": temp_min,
179
- "humidity_pct": humidity,
180
- "wind_speed_ms": wind,
181
- "solar_radiation": solar,
182
- })
183
-
184
- # Cache the results
185
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
186
- cache_obj = {
187
- "zone_id": zone.zone_id,
188
- "city": zone.city,
189
- "start_date": start.isoformat(),
190
- "end_date": end.isoformat(),
191
- "latitude": zone.latitude,
192
- "longitude": zone.longitude,
193
- "readings": readings,
194
- }
195
- with open(cache_file, "w") as f:
196
- json.dump(cache_obj, f, indent=2)
197
-
198
- log.info(" Zone %s: %d days fetched and cached", zone.zone_id, len(readings))
199
- return readings
200
-
201
-
202
- def fetch_all_zones():
203
- """Fetch NASA POWER data for all zones with rate-limiting delay."""
204
- log.info("=" * 60)
205
- log.info("STEP 1: Fetching NASA POWER data for %d zones", len(ZONES))
206
- log.info(" Date range: %s to %s (%d days)", START_DATE, END_DATE, (END_DATE - START_DATE).days + 1)
207
- log.info("=" * 60)
208
-
209
- all_data: dict[str, list[dict]] = {}
210
- t0 = time.time()
211
-
212
- for i, zone in enumerate(ZONES):
213
- readings = fetch_zone_data(zone, START_DATE, END_DATE)
214
- all_data[zone.zone_id] = readings
215
-
216
- # Rate limiting delay between API calls (skip for cached results)
217
- if i < len(ZONES) - 1:
218
- time.sleep(0.5)
219
-
220
- elapsed = time.time() - t0
221
- total_readings = sum(len(v) for v in all_data.values())
222
- zones_with_data = sum(1 for v in all_data.values() if v)
223
-
224
- log.info(
225
- "Fetch complete in %.1fs: %d/%d zones with data, %d total readings",
226
- elapsed, zones_with_data, len(ZONES), total_readings,
227
- )
228
-
229
- return all_data
230
-
231
-
232
- # ======================================================================
233
- # Step 2: Validate data quality
234
- # ======================================================================
235
-
236
- def validate_data(data: dict[str, list[dict]]):
237
- log.info("=" * 60)
238
- log.info("STEP 2: Validating NASA POWER data quality")
239
- log.info("=" * 60)
240
-
241
- issues = []
242
- stats = {}
243
-
244
- for zone in ZONES:
245
- zid = zone.zone_id
246
- readings = data.get(zid, [])
247
-
248
- if not readings:
249
- issues.append(f"{zid}: NO DATA")
250
- stats[zid] = {"days": 0, "issue": "no data"}
251
- continue
252
-
253
- temps = [r["temp_max_c"] for r in readings if r["temp_max_c"] is not None]
254
- humids = [r["humidity_pct"] for r in readings if r["humidity_pct"] is not None]
255
- winds = [r["wind_speed_ms"] for r in readings if r["wind_speed_ms"] is not None]
256
-
257
- if not temps:
258
- issues.append(f"{zid}: all temps are null")
259
- stats[zid] = {"days": len(readings), "issue": "all null temps"}
260
- continue
261
-
262
- t_min, t_max = min(temps), max(temps)
263
- t_mean = sum(temps) / len(temps)
264
-
265
- # Check physical reasonableness
266
- exp_lo, exp_hi = EXPECTED_RANGES.get(zone.city, (15, 42))
267
- if t_min < exp_lo - 5 or t_max > exp_hi + 5:
268
- issues.append(
269
- f"{zid} ({zone.city}): temp range [{t_min:.1f}, {t_max:.1f}] "
270
- f"outside expected [{exp_lo-5}, {exp_hi+5}]"
271
- )
272
-
273
- null_count = sum(1 for r in readings if r["temp_max_c"] is None)
274
-
275
- stats[zid] = {
276
- "days": len(readings),
277
- "temp_days": len(temps),
278
- "temp_min": round(t_min, 1),
279
- "temp_max": round(t_max, 1),
280
- "temp_mean": round(t_mean, 1),
281
- "humidity_mean": round(sum(humids)/len(humids), 1) if humids else None,
282
- "wind_mean": round(sum(winds)/len(winds), 1) if winds else None,
283
- "null_temps": null_count,
284
- }
285
-
286
- # Print summary
287
- print("\n--- NASA POWER Data Summary ---")
288
- print(f"{'Zone':<12} {'City':<16} {'Days':>5} {'Temp min':>9} {'Temp max':>9} {'Temp mean':>10} {'Humidity':>9} {'Nulls':>6}")
289
- print("-" * 90)
290
-
291
- for zone in ZONES:
292
- s = stats.get(zone.zone_id, {})
293
- days = s.get("days", 0)
294
- t_lo = s.get("temp_min", "N/A")
295
- t_hi = s.get("temp_max", "N/A")
296
- t_mn = s.get("temp_mean", "N/A")
297
- hum = s.get("humidity_mean", "N/A")
298
- nulls = s.get("null_temps", "N/A")
299
- if isinstance(t_lo, float):
300
- print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {t_lo:>9.1f} {t_hi:>9.1f} {t_mn:>10.1f} {hum:>9} {nulls:>6}")
301
- else:
302
- print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {'N/A':>9} {'N/A':>9} {'N/A':>10} {'N/A':>9} {'N/A':>6}")
303
-
304
- if issues:
305
- print(f"\n ISSUES ({len(issues)}):")
306
- for issue in issues:
307
- print(f" - {issue}")
308
- else:
309
- print("\n No data quality issues found.")
310
-
311
- zones_with_data = sum(1 for s in stats.values() if s.get("days", 0) > 0)
312
- print(f"\n {zones_with_data}/{len(ZONES)} zones have data.\n")
313
-
314
- return stats
315
-
316
-
317
- # ======================================================================
318
- # Step 3: Retrain LSTM on real NASA POWER data
319
- # ======================================================================
320
-
321
- def retrain_lstm(data: dict[str, list[dict]]):
322
- log.info("=" * 60)
323
- log.info("STEP 3: Retraining LSTM on real NASA POWER data")
324
- log.info("=" * 60)
325
-
326
- from src.prediction.lstm_model import LSTMTrainer
327
-
328
- # Delete old model files
329
- lstm_model_path = MODELS_DIR / "heat_lstm.pt"
330
- lstm_norm_path = MODELS_DIR / "lstm_norm.json"
331
- for p in [lstm_model_path, lstm_norm_path]:
332
- if p.exists():
333
- p.unlink()
334
- log.info("Deleted old model: %s", p)
335
-
336
- # Convert NASA POWER data to the format the LSTM trainer expects:
337
- # dict of zone_id -> list of dicts with keys: temp_max_c, humidity_pct, wind_speed_ms, city
338
- zone_readings = {}
339
- for zone in ZONES:
340
- zid = zone.zone_id
341
- readings = data.get(zid, [])
342
- days = []
343
- for r in readings:
344
- if r["temp_max_c"] is None:
345
- continue
346
- days.append({
347
- "temp_max_c": r["temp_max_c"],
348
- "humidity_pct": r["humidity_pct"] if r["humidity_pct"] is not None else 65.0,
349
- "wind_speed_ms": r["wind_speed_ms"] if r["wind_speed_ms"] is not None else 3.0,
350
- "city": zone.city,
351
- })
352
- if len(days) >= 22: # Need at least seq_len(14) + forecast_horizon(7) + 1
353
- zone_readings[zid] = days
354
- log.info("Zone %s: %d valid readings for LSTM", zid, len(days))
355
- else:
356
- log.warning("Zone %s: only %d valid readings, skipping LSTM", zid, len(days))
357
-
358
- log.info("Training LSTM on %d zones with real NASA POWER data", len(zone_readings))
359
- trainer = LSTMTrainer(epochs=50, patience=5)
360
- metrics = trainer.train(zone_readings)
361
-
362
- print(f"\n--- LSTM Results (real NASA POWER data) ---")
363
- for k, v in metrics.items():
364
- print(f" {k}: {v}")
365
-
366
- return metrics
367
-
368
-
369
- # ======================================================================
370
- # Step 4: Retrain XGBoost heat predictor on real data
371
- # ======================================================================
372
-
373
- def retrain_xgboost(data: dict[str, list[dict]]):
374
- log.info("=" * 60)
375
- log.info("STEP 4: Retraining XGBoost heat predictor on real NASA POWER data")
376
- log.info("=" * 60)
377
-
378
- import xgboost as xgb
379
- from src.prediction.lstm_model import CITY_THRESHOLDS
380
-
381
- # Delete old model file
382
- xgb_model_path = MODELS_DIR / "heat_predictor_xgb.json"
383
- if xgb_model_path.exists():
384
- xgb_model_path.unlink()
385
- log.info("Deleted old XGBoost model: %s", xgb_model_path)
386
-
387
- all_X = []
388
- all_y = []
389
-
390
- for zone in ZONES:
391
- zid = zone.zone_id
392
- readings = data.get(zid, [])
393
- if len(readings) < 40:
394
- log.warning("Zone %s has only %d readings, skipping for XGBoost training", zid, len(readings))
395
- continue
396
-
397
- city = zone.city
398
- threshold = CITY_THRESHOLDS.get(city, 33.0)
399
-
400
- # Extract time series, filtering nulls
401
- temps = []
402
- humidity = []
403
- for r in readings:
404
- t = r["temp_max_c"]
405
- h = r["humidity_pct"]
406
- if t is None:
407
- continue
408
- temps.append(t)
409
- humidity.append(h if h is not None else 65.0)
410
-
411
- n_days = len(temps)
412
- if n_days < 40:
413
- log.warning("Zone %s has only %d valid temp readings, skipping", zid, n_days)
414
- continue
415
-
416
- # Compute WBGT series
417
- wbgt_series = [calculate_wbgt(t, h) for t, h in zip(temps, humidity)]
418
-
419
- # Labels: trigger within next 7 days (2+ consecutive above threshold)
420
- labels = [0] * n_days
421
- for day in range(n_days - 7):
422
- window = temps[day + 1:day + 8]
423
- consec = 0
424
- triggered = False
425
- for t in window:
426
- if t > threshold:
427
- consec += 1
428
- if consec >= 2:
429
- triggered = True
430
- break
431
- else:
432
- consec = 0
433
- labels[day] = 1 if triggered else 0
434
-
435
- # Vulnerability encoding
436
- vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
437
- zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
438
-
439
- rng = np.random.default_rng(42)
440
-
441
- # Build features (need 30-day lookback)
442
- for day in range(30, n_days - 7):
443
- t_window = temps[day - 30:day + 1]
444
- h_window = humidity[day - 30:day + 1]
445
- w_window = wbgt_series[day - 30:day + 1]
446
-
447
- current_temp = t_window[-1]
448
- current_wbgt = w_window[-1]
449
- current_humidity = h_window[-1]
450
-
451
- # Trend: slope of last 7 days
452
- x7 = np.arange(7, dtype=np.float64)
453
- y7 = np.array(t_window[-7:], dtype=np.float64)
454
- temp_trend = float(np.polyfit(x7, y7, 1)[0])
455
-
456
- # Anomaly: current vs 30-day mean
457
- temp_anomaly = current_temp - float(np.mean(t_window))
458
-
459
- # Soil moisture proxy
460
- soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
461
-
462
- # Rolling error (neutral prior for training data)
463
- rolling_err = rng.uniform(0.1, 0.5)
464
-
465
- # Day-of-year encoding
466
- doy = day % 365
467
- doy_sin = np.sin(2 * np.pi * doy / 365.0)
468
- doy_cos = np.cos(2 * np.pi * doy / 365.0)
469
-
470
- # Random hour for variety
471
- hour = rng.integers(6, 19)
472
- hour_sin = np.sin(2 * np.pi * hour / 24.0)
473
- hour_cos = np.cos(2 * np.pi * hour / 24.0)
474
-
475
- row = [
476
- current_temp,
477
- current_wbgt,
478
- current_humidity,
479
- temp_trend,
480
- temp_anomaly,
481
- soil_proxy,
482
- rolling_err,
483
- doy_sin,
484
- doy_cos,
485
- hour_sin,
486
- hour_cos,
487
- zone_vuln,
488
- ]
489
-
490
- all_X.append(row)
491
- all_y.append(labels[day])
492
-
493
- X = np.array(all_X, dtype=np.float32)
494
- y = np.array(all_y, dtype=np.int32)
495
-
496
- pos_rate = y.sum() / len(y) if len(y) > 0 else 0
497
- log.info(
498
- "XGBoost training data: %d samples, %.1f%% positive rate",
499
- len(X), pos_rate * 100,
500
- )
501
-
502
- model = xgb.XGBClassifier(
503
- n_estimators=150,
504
- max_depth=5,
505
- learning_rate=0.1,
506
- eval_metric="logloss",
507
- random_state=42,
508
- )
509
-
510
- # Train/validation split (temporal: first 75% train, last 25% val)
511
- split = int(len(X) * 0.75)
512
- X_train, X_val = X[:split], X[split:]
513
- y_train, y_val = y[:split], y[split:]
514
-
515
- model.fit(
516
- X_train, y_train,
517
- eval_set=[(X_val, y_val)],
518
- verbose=False,
519
- )
520
-
521
- # Evaluate on validation set
522
- from sklearn.metrics import roc_auc_score, precision_score, recall_score
523
-
524
- val_probs = model.predict_proba(X_val)[:, 1]
525
- val_preds = (val_probs > 0.5).astype(int)
526
-
527
- if len(set(y_val)) > 1:
528
- auroc = roc_auc_score(y_val, val_probs)
529
- precision = precision_score(y_val, val_preds, zero_division=0)
530
- recall = recall_score(y_val, val_preds, zero_division=0)
531
- else:
532
- auroc, precision, recall = 0.5, 0.0, 0.0
533
-
534
- print(f"\n--- XGBoost Results (real NASA POWER data) ---")
535
- print(f" Training samples: {len(X_train)}")
536
- print(f" Validation samples: {len(X_val)}")
537
- print(f" Positive rate: {pos_rate:.1%}")
538
- print(f" Val AUROC: {auroc:.4f}")
539
- print(f" Val Precision: {precision:.4f}")
540
- print(f" Val Recall: {recall:.4f}")
541
-
542
- # Save model
543
- MODELS_DIR.mkdir(parents=True, exist_ok=True)
544
- model.save_model(str(xgb_model_path))
545
- log.info("XGBoost model saved to %s", xgb_model_path)
546
-
547
- return {
548
- "train_samples": len(X_train),
549
- "val_samples": len(X_val),
550
- "positive_rate": round(pos_rate, 4),
551
- "val_auroc": round(auroc, 4),
552
- "val_precision": round(precision, 4),
553
- "val_recall": round(recall, 4),
554
- }
555
-
556
-
557
- # ======================================================================
558
- # Step 5: Verify UHI model (no retraining)
559
- # ======================================================================
560
-
561
- def verify_uhi(data: dict[str, list[dict]]):
562
- log.info("=" * 60)
563
- log.info("STEP 5: Verifying UHI model with real NASA POWER temperatures")
564
- log.info(" (UHI model keeps literature-calibrated synthetic training)")
565
- log.info("=" * 60)
566
-
567
- from src.downscaling.uhi_model import UHICorrector
568
-
569
- corrector = UHICorrector()
570
-
571
- results = {}
572
- for zone in ZONES:
573
- zid = zone.zone_id
574
- readings = data.get(zid, [])
575
- if not readings:
576
- continue
577
-
578
- real_temps = [r["temp_max_c"] for r in readings if r["temp_max_c"] is not None]
579
- if not real_temps:
580
- continue
581
-
582
- # Sample some real temps and apply UHI correction
583
- sample_indices = np.linspace(0, len(real_temps) - 1, min(20, len(real_temps)), dtype=int)
584
- deltas = []
585
- corrected_temps = []
586
-
587
- for idx in sample_indices:
588
- grid_temp = real_temps[idx]
589
- corrected, delta, conf = corrector.correct_temperature(zone, grid_temp, hour=14, month=1)
590
- deltas.append(delta)
591
- corrected_temps.append(corrected)
592
-
593
- results[zid] = {
594
- "city": zone.city,
595
- "settlement": zone.settlement_type,
596
- "mean_grid_temp": round(sum(real_temps) / len(real_temps), 1),
597
- "mean_uhi_delta": round(sum(deltas) / len(deltas), 2),
598
- "mean_corrected": round(sum(corrected_temps) / len(corrected_temps), 1),
599
- }
600
-
601
- print(f"\n--- UHI Verification (literature-calibrated model + real NASA POWER temps) ---")
602
- print(f"{'Zone':<12} {'City':<16} {'Type':<12} {'Grid T':>7} {'UHI +':>7} {'Corrected':>10}")
603
- print("-" * 70)
604
- for zid, r in results.items():
605
- print(
606
- f"{zid:<12} {r['city']:<16} {r['settlement']:<12} "
607
- f"{r['mean_grid_temp']:>6.1f}C {r['mean_uhi_delta']:>+6.2f}C {r['mean_corrected']:>9.1f}C"
608
- )
609
-
610
- return results
611
-
612
-
613
- # ======================================================================
614
- # Main
615
- # ======================================================================
616
-
617
- def main():
618
- t_start = time.time()
619
-
620
- # Step 1: Fetch
621
- data = fetch_all_zones()
622
-
623
- # Step 2: Validate
624
- data_stats = validate_data(data)
625
-
626
- # Step 3: LSTM
627
- lstm_metrics = retrain_lstm(data)
628
-
629
- # Step 4: XGBoost
630
- xgb_metrics = retrain_xgboost(data)
631
-
632
- # Step 5: UHI verification
633
- uhi_results = verify_uhi(data)
634
-
635
- total_time = time.time() - t_start
636
-
637
- print("\n" + "=" * 60)
638
- print("TRAINING COMPLETE (NASA POWER real data)")
639
- print("=" * 60)
640
-
641
- total_days = sum(
642
- len([r for r in data.get(z.zone_id, []) if r["temp_max_c"] is not None])
643
- for z in ZONES
644
- )
645
- print(f" Data source: NASA POWER daily")
646
- print(f" Date range: {START_DATE} to {END_DATE}")
647
- print(f" Total real data points: {total_days} zone-days across {len(ZONES)} zones")
648
- print(f" Avg days per zone: {total_days / len(ZONES):.0f}")
649
- print(f" LSTM val AUROC: {lstm_metrics.get('val_auroc', 'N/A')}")
650
- print(f" LSTM epochs trained: {lstm_metrics.get('epochs_trained', 'N/A')}")
651
- print(f" LSTM val loss: {lstm_metrics.get('val_loss', 'N/A')}")
652
- print(f" XGBoost val AUROC: {xgb_metrics['val_auroc']:.4f}")
653
- print(f" XGBoost val Precision: {xgb_metrics['val_precision']:.4f}")
654
- print(f" XGBoost val Recall: {xgb_metrics['val_recall']:.4f}")
655
- print(f" Total time: {total_time:.1f}s")
656
- print()
657
-
658
-
659
- if __name__ == "__main__":
660
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/notification/__init__.py DELETED
File without changes
src/notification/sender.py DELETED
@@ -1,318 +0,0 @@
1
- """
2
- Notification delivery module.
3
-
4
- Sends trigger explanations to policyholders via console (demo),
5
- SMS (Twilio), or WhatsApp (Twilio). All senders implement a common
6
- async interface and return a DeliveryResult.
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import asyncio
12
- import logging
13
- import os
14
- import time
15
- from abc import ABC, abstractmethod
16
- from dataclasses import dataclass, field
17
- from datetime import datetime, timezone
18
- from typing import Optional, Sequence
19
-
20
- log = logging.getLogger(__name__)
21
-
22
-
23
- # ── Data containers ──────────────────────────────────────────────────────
24
-
25
- @dataclass
26
- class DeliveryResult:
27
- """Outcome of a single notification delivery attempt."""
28
-
29
- status: str # "sent", "failed", "dry_run"
30
- channel: str # "console", "sms", "whatsapp"
31
- recipient: str
32
- message_preview: str # first 120 chars
33
- timestamp: str = field(
34
- default_factory=lambda: datetime.now(timezone.utc).isoformat()
35
- )
36
- cost_estimate: float = 0.0 # estimated cost in USD
37
- error: str = ""
38
- message_sid: str = "" # Twilio message SID if applicable
39
-
40
-
41
- # ── Base sender ──────────────────────────────────────────────────────────
42
-
43
- class BaseSender(ABC):
44
- """Common interface for all notification channels."""
45
-
46
- @abstractmethod
47
- async def send(
48
- self, recipient: str, message: str, channel: str = ""
49
- ) -> DeliveryResult:
50
- """Send a message to a single recipient."""
51
- ...
52
-
53
- async def send_batch(
54
- self,
55
- recipients: Sequence[str],
56
- message: str,
57
- channel: str = "",
58
- rate_limit: float = 1.0,
59
- ) -> list[DeliveryResult]:
60
- """
61
- Send the same message to multiple recipients with rate limiting.
62
-
63
- Args:
64
- recipients: Phone numbers or identifiers.
65
- message: The notification text.
66
- channel: Override channel name.
67
- rate_limit: Minimum seconds between sends (Twilio default: 1/sec).
68
- """
69
- results: list[DeliveryResult] = []
70
- for i, recipient in enumerate(recipients):
71
- result = await self.send(recipient, message, channel)
72
- results.append(result)
73
- # Rate limiting between sends (skip after last)
74
- if i < len(recipients) - 1 and rate_limit > 0:
75
- await asyncio.sleep(rate_limit)
76
- return results
77
-
78
-
79
- # ── Console sender (demo mode) ───────────────────────────────────────────
80
-
81
- class ConsoleSender(BaseSender):
82
- """Prints notifications to stdout. Default for demo and testing."""
83
-
84
- async def send(
85
- self, recipient: str, message: str, channel: str = "console"
86
- ) -> DeliveryResult:
87
- preview = message[:120] + ("..." if len(message) > 120 else "")
88
- log.info("[CONSOLE] To: %s | %s", recipient, preview)
89
- print(f"\n{'='*60}")
90
- print(f" NOTIFICATION — {channel or 'console'}")
91
- print(f" To: {recipient}")
92
- print(f"{'='*60}")
93
- print(f" {message}")
94
- print(f"{'='*60}\n")
95
- return DeliveryResult(
96
- status="dry_run",
97
- channel="console",
98
- recipient=recipient,
99
- message_preview=preview,
100
- cost_estimate=0.0,
101
- )
102
-
103
-
104
- # ── Twilio SMS sender ───────────────────────────────────────────────────
105
-
106
- class TwilioSender(BaseSender):
107
- """
108
- Sends SMS via Twilio.
109
-
110
- Requires environment variables:
111
- TWILIO_ACCOUNT_SID
112
- TWILIO_AUTH_TOKEN
113
- TWILIO_FROM_NUMBER (E.164 format, e.g. +15551234567)
114
- """
115
-
116
- def __init__(
117
- self,
118
- account_sid: Optional[str] = None,
119
- auth_token: Optional[str] = None,
120
- from_number: Optional[str] = None,
121
- ):
122
- self.account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID", "")
123
- self.auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN", "")
124
- self.from_number = from_number or os.environ.get("TWILIO_FROM_NUMBER", "")
125
- self._client = None
126
-
127
- def _get_client(self):
128
- """Lazy-init Twilio client."""
129
- if self._client is None:
130
- if not all([self.account_sid, self.auth_token]):
131
- raise RuntimeError(
132
- "Twilio credentials not configured. Set TWILIO_ACCOUNT_SID "
133
- "and TWILIO_AUTH_TOKEN environment variables."
134
- )
135
- from twilio.rest import Client
136
- self._client = Client(self.account_sid, self.auth_token)
137
- return self._client
138
-
139
- async def send(
140
- self, recipient: str, message: str, channel: str = "sms"
141
- ) -> DeliveryResult:
142
- preview = message[:120] + ("..." if len(message) > 120 else "")
143
-
144
- # Truncate SMS to 1600 chars (Twilio limit for long SMS)
145
- sms_body = message[:1600]
146
-
147
- try:
148
- client = self._get_client()
149
- # Twilio client is synchronous — run in executor
150
- loop = asyncio.get_event_loop()
151
- twilio_msg = await loop.run_in_executor(
152
- None,
153
- lambda: client.messages.create(
154
- body=sms_body,
155
- from_=self.from_number,
156
- to=recipient,
157
- ),
158
- )
159
- log.info(
160
- "[SMS] Sent to %s | SID: %s | Status: %s",
161
- recipient, twilio_msg.sid, twilio_msg.status,
162
- )
163
- return DeliveryResult(
164
- status="sent",
165
- channel="sms",
166
- recipient=recipient,
167
- message_preview=preview,
168
- cost_estimate=_estimate_sms_cost(message),
169
- message_sid=twilio_msg.sid,
170
- )
171
- except Exception as exc:
172
- log.error("[SMS] Failed to send to %s: %s", recipient, exc)
173
- return DeliveryResult(
174
- status="failed",
175
- channel="sms",
176
- recipient=recipient,
177
- message_preview=preview,
178
- error=str(exc),
179
- )
180
-
181
-
182
- # ── WhatsApp sender ──────────────────────────────────────────────────────
183
-
184
- class WhatsAppSender(BaseSender):
185
- """
186
- Sends WhatsApp messages via Twilio WhatsApp Business API.
187
-
188
- Uses the same Twilio credentials as SMS but prefixes the from number
189
- with 'whatsapp:'.
190
- """
191
-
192
- def __init__(
193
- self,
194
- account_sid: Optional[str] = None,
195
- auth_token: Optional[str] = None,
196
- from_number: Optional[str] = None,
197
- ):
198
- self.account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID", "")
199
- self.auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN", "")
200
- self.from_number = from_number or os.environ.get("TWILIO_FROM_NUMBER", "")
201
- self._client = None
202
-
203
- def _get_client(self):
204
- """Lazy-init Twilio client."""
205
- if self._client is None:
206
- if not all([self.account_sid, self.auth_token]):
207
- raise RuntimeError(
208
- "Twilio credentials not configured. Set TWILIO_ACCOUNT_SID "
209
- "and TWILIO_AUTH_TOKEN environment variables."
210
- )
211
- from twilio.rest import Client
212
- self._client = Client(self.account_sid, self.auth_token)
213
- return self._client
214
-
215
- async def send(
216
- self, recipient: str, message: str, channel: str = "whatsapp"
217
- ) -> DeliveryResult:
218
- preview = message[:120] + ("..." if len(message) > 120 else "")
219
-
220
- # WhatsApp supports longer messages (up to 4096 chars)
221
- wa_body = message[:4096]
222
-
223
- # Ensure whatsapp: prefix on both numbers
224
- wa_from = (
225
- f"whatsapp:{self.from_number}"
226
- if not self.from_number.startswith("whatsapp:")
227
- else self.from_number
228
- )
229
- wa_to = (
230
- f"whatsapp:{recipient}"
231
- if not recipient.startswith("whatsapp:")
232
- else recipient
233
- )
234
-
235
- try:
236
- client = self._get_client()
237
- loop = asyncio.get_event_loop()
238
- twilio_msg = await loop.run_in_executor(
239
- None,
240
- lambda: client.messages.create(
241
- body=wa_body,
242
- from_=wa_from,
243
- to=wa_to,
244
- ),
245
- )
246
- log.info(
247
- "[WhatsApp] Sent to %s | SID: %s | Status: %s",
248
- recipient, twilio_msg.sid, twilio_msg.status,
249
- )
250
- return DeliveryResult(
251
- status="sent",
252
- channel="whatsapp",
253
- recipient=recipient,
254
- message_preview=preview,
255
- cost_estimate=_estimate_whatsapp_cost(),
256
- message_sid=twilio_msg.sid,
257
- )
258
- except Exception as exc:
259
- log.error("[WhatsApp] Failed to send to %s: %s", recipient, exc)
260
- return DeliveryResult(
261
- status="failed",
262
- channel="whatsapp",
263
- recipient=recipient,
264
- message_preview=preview,
265
- error=str(exc),
266
- )
267
-
268
-
269
- # ── Cost estimation ──────────────────────────────────────────────────────
270
-
271
- def _estimate_sms_cost(message: str) -> float:
272
- """Estimate SMS cost in USD. Twilio Kenya rate ~ $0.0475/segment."""
273
- # SMS segments: 160 chars for GSM-7, 70 for UCS-2 (Unicode)
274
- has_unicode = any(ord(c) > 127 for c in message)
275
- segment_size = 70 if has_unicode else 160
276
- segments = max(1, (len(message) + segment_size - 1) // segment_size)
277
- return round(segments * 0.0475, 4)
278
-
279
-
280
- def _estimate_whatsapp_cost() -> float:
281
- """Estimate WhatsApp cost in USD. Twilio WhatsApp ~ $0.005/msg + template fees."""
282
- return 0.005
283
-
284
-
285
- # ── Sender factory ───────────────────────────────────────────────────────
286
-
287
- def create_sender(channel: str = "console") -> BaseSender:
288
- """
289
- Factory to create the appropriate sender.
290
-
291
- Args:
292
- channel: One of "console", "sms", "whatsapp".
293
- """
294
- if channel == "sms":
295
- return TwilioSender()
296
- elif channel == "whatsapp":
297
- return WhatsAppSender()
298
- else:
299
- return ConsoleSender()
300
-
301
-
302
- async def send_zone_notifications(
303
- recipients: Sequence[str],
304
- message: str,
305
- channel: str = "console",
306
- rate_limit: float = 1.0,
307
- ) -> list[DeliveryResult]:
308
- """
309
- Convenience function: send the same notification to all recipients in a zone.
310
-
311
- Args:
312
- recipients: List of phone numbers.
313
- message: Notification text.
314
- channel: "console", "sms", or "whatsapp".
315
- rate_limit: Seconds between sends for Twilio rate limiting.
316
- """
317
- sender = create_sender(channel)
318
- return await sender.send_batch(recipients, message, channel, rate_limit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/prediction/heat_forecast.py DELETED
@@ -1,557 +0,0 @@
1
- """
2
- Heat wave prediction model for parametric insurance triggers.
3
-
4
- XGBoost classifier that predicts the probability of a heat trigger
5
- event (2+ consecutive days above city-adjusted threshold) occurring
6
- within the next 7 days, given recent climate features.
7
-
8
- Degrades gracefully:
9
- full_model -> persistence -> climatology
10
-
11
- References:
12
- - Perkins-Kirkpatrick & Lewis (2020) heat wave definitions
13
- - WHO/ILO occupational heat stress thresholds
14
- """
15
-
16
- from __future__ import annotations
17
-
18
- from collections import deque
19
- from pathlib import Path
20
-
21
- import numpy as np
22
-
23
- try:
24
- import xgboost as xgb
25
- except ImportError:
26
- xgb = None
27
-
28
- # Import shared constants from lstm_model (defined there to avoid circular imports)
29
- from src.prediction.lstm_model import CITY_THRESHOLDS, CITY_CLIMATE
30
-
31
- try:
32
- from src.prediction.lstm_model import LSTMPredictor
33
- _LSTM_AVAILABLE = True
34
- except Exception:
35
- _LSTM_AVAILABLE = False
36
-
37
- FEATURE_NAMES = [
38
- "current_temp",
39
- "current_wbgt",
40
- "current_humidity",
41
- "temp_trend_7d",
42
- "temp_anomaly_30d",
43
- "soil_moisture_proxy",
44
- "rolling_error",
45
- "doy_sin",
46
- "doy_cos",
47
- "hour_sin",
48
- "hour_cos",
49
- "zone_vulnerability",
50
- ]
51
-
52
-
53
- def _resolve_model_path(model_path: str) -> Path:
54
- p = Path(model_path)
55
- if not p.is_absolute():
56
- p = Path(__file__).resolve().parents[2] / model_path
57
- return p
58
-
59
-
60
- from src.indexing.heat_index import calculate_wbgt as _simple_wbgt
61
-
62
-
63
- class HeatWavePredictor:
64
- """XGBoost model: recent climate features -> trigger probability in 7 days."""
65
-
66
- FEATURE_NAMES = FEATURE_NAMES
67
-
68
- def __init__(self, model_path: str = "models/heat_predictor_xgb.json"):
69
- if xgb is None:
70
- raise ImportError(
71
- "xgboost is required. Install with: pip install 'xgboost>=2.0.0'"
72
- )
73
- self.model_path = _resolve_model_path(model_path)
74
- self.model: xgb.XGBClassifier | None = None
75
- self._rolling_errors: deque = deque(maxlen=3)
76
- self._load_or_train()
77
-
78
- # Try loading LSTM for ensemble; auto-train on synthetic data if missing
79
- self._lstm: object | None = None
80
- if _LSTM_AVAILABLE:
81
- try:
82
- self._lstm = LSTMPredictor()
83
- except FileNotFoundError:
84
- self._train_lstm_synthetic()
85
- except Exception:
86
- self._lstm = None
87
-
88
- # ------------------------------------------------------------------
89
- # Public API
90
- # ------------------------------------------------------------------
91
-
92
- def predict(
93
- self,
94
- zone,
95
- recent_temps: list[float],
96
- recent_humidity: list[float],
97
- recent_wbgt: list[float],
98
- hour: int = 12,
99
- ) -> tuple[float, float, str]:
100
- """Predict probability of heat trigger within 7 days.
101
-
102
- Args:
103
- zone: UrbanZone instance.
104
- recent_temps: Last 30 daily max temperatures (most recent last).
105
- recent_humidity: Last 30 daily humidity values.
106
- recent_wbgt: Last 30 daily WBGT values.
107
- hour: Current hour for diurnal encoding.
108
-
109
- Returns:
110
- (probability, confidence, model_tier)
111
- model_tier is one of: "ensemble", "full_model", "lstm_only",
112
- "persistence", "climatology"
113
- """
114
- xgb_prob, xgb_conf, xgb_ok = None, None, False
115
- lstm_prob, lstm_conf, lstm_ok = None, None, False
116
-
117
- # -- XGBoost prediction --
118
- try:
119
- features = self._build_features(
120
- zone, recent_temps, recent_humidity, recent_wbgt, hour
121
- )
122
- xgb_prob = float(self.model.predict_proba(features)[0, 1])
123
- xgb_conf = self._estimate_confidence(recent_temps, "full_model")
124
- xgb_ok = True
125
- except Exception:
126
- pass
127
-
128
- # -- LSTM prediction --
129
- if self._lstm is not None:
130
- try:
131
- lstm_days = self._build_lstm_days(
132
- recent_temps, recent_humidity, recent_wbgt
133
- )
134
- lstm_prob, lstm_conf = self._lstm.predict(lstm_days)
135
- lstm_ok = True
136
- except Exception:
137
- pass
138
-
139
- # -- Ensemble --
140
- if xgb_ok and lstm_ok:
141
- prob = 0.5 * xgb_prob + 0.5 * lstm_prob
142
- confidence = (xgb_conf + lstm_conf) / 2.0
143
- return round(prob, 4), round(confidence, 3), "ensemble"
144
-
145
- if xgb_ok:
146
- return round(xgb_prob, 4), round(xgb_conf, 3), "full_model"
147
-
148
- if lstm_ok:
149
- return round(lstm_prob, 4), round(lstm_conf, 3), "lstm_only"
150
-
151
- # Persistence fallback: if recent conditions are above threshold,
152
- # assume they continue
153
- try:
154
- threshold = CITY_THRESHOLDS.get(zone.city, 33.0)
155
- if len(recent_temps) >= 2:
156
- above = sum(1 for t in recent_temps[-3:] if t > threshold)
157
- prob = min(0.95, above / 3.0)
158
- else:
159
- prob = 0.5
160
- confidence = self._estimate_confidence(recent_temps, "persistence")
161
- return round(prob, 4), round(confidence, 3), "persistence"
162
- except Exception:
163
- pass
164
-
165
- # Climatology fallback: use seasonal base rate
166
- from config import HOT_SEASONS
167
-
168
- import datetime
169
-
170
- doy = datetime.datetime.now().timetuple().tm_yday
171
- month = datetime.datetime.now().month
172
- city = getattr(zone, "city", "Nairobi")
173
- hot_months = []
174
- for season_months in HOT_SEASONS.get(city, {}).values():
175
- hot_months.extend(season_months)
176
- prob = 0.35 if month in hot_months else 0.10
177
- confidence = 0.30
178
- return round(prob, 4), round(confidence, 3), "climatology"
179
-
180
- @staticmethod
181
- def _build_lstm_days(
182
- recent_temps: list[float],
183
- recent_humidity: list[float],
184
- recent_wbgt: list[float],
185
- ) -> list[dict]:
186
- """Convert raw arrays into the list-of-dicts format the LSTM expects.
187
-
188
- The LSTM predictor computes WBGT, heat index, and temp anomaly
189
- internally, so we only need to pass the raw observations.
190
- """
191
- n = min(len(recent_temps), len(recent_humidity), len(recent_wbgt))
192
- days = []
193
- for i in range(n):
194
- days.append({
195
- "temp_max_c": recent_temps[i],
196
- "humidity_pct": recent_humidity[i],
197
- "wind_speed_ms": 3.0,
198
- })
199
- return days
200
-
201
- def update_rolling_error(self, predicted_prob: float, actual: bool) -> None:
202
- """Track prediction accuracy for the rolling_error feature."""
203
- error = abs(predicted_prob - (1.0 if actual else 0.0))
204
- self._rolling_errors.append(error)
205
-
206
- # ------------------------------------------------------------------
207
- # Feature engineering
208
- # ------------------------------------------------------------------
209
-
210
- def _build_features(
211
- self,
212
- zone,
213
- recent_temps: list[float],
214
- recent_humidity: list[float],
215
- recent_wbgt: list[float],
216
- hour: int = 12,
217
- ) -> np.ndarray:
218
- """Build the 12-feature vector.
219
-
220
- Features:
221
- 0: current_temp — most recent daily max temp
222
- 1: current_wbgt — most recent WBGT
223
- 2: current_humidity — most recent humidity
224
- 3: temp_trend_7d — linear slope over last 7 days
225
- 4: temp_anomaly_30d — current temp minus 30-day mean
226
- 5: soil_moisture_proxy — inverse of recent rainfall proxy
227
- (approximated as negative temp anomaly clamped to [0,1])
228
- 6: rolling_error — mean of last 3 prediction errors
229
- 7-8: doy_sin, doy_cos — seasonal encoding
230
- 9-10: hour_sin, hour_cos — diurnal encoding
231
- 11: zone_vulnerability — numeric heat vulnerability
232
- """
233
- import datetime
234
-
235
- temps = list(recent_temps)
236
- humid = list(recent_humidity)
237
- wbgts = list(recent_wbgt)
238
-
239
- current_temp = temps[-1] if temps else 30.0
240
- current_wbgt = wbgts[-1] if wbgts else 28.0
241
- current_humidity = humid[-1] if humid else 65.0
242
-
243
- # Trend: slope of last 7 days
244
- if len(temps) >= 7:
245
- x = np.arange(7, dtype=np.float64)
246
- y = np.array(temps[-7:], dtype=np.float64)
247
- temp_trend = float(np.polyfit(x, y, 1)[0])
248
- else:
249
- temp_trend = 0.0
250
-
251
- # Anomaly: current vs 30-day mean
252
- if len(temps) >= 2:
253
- temp_anomaly = current_temp - np.mean(temps)
254
- else:
255
- temp_anomaly = 0.0
256
-
257
- # Soil moisture proxy: when temps are well below average,
258
- # likely recent rain -> higher moisture. Clamp to [0, 1].
259
- soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
260
-
261
- # Rolling prediction error
262
- if self._rolling_errors:
263
- rolling_err = float(np.mean(self._rolling_errors))
264
- else:
265
- rolling_err = 0.3 # neutral prior
266
-
267
- # Day of year encoding
268
- doy = datetime.datetime.now().timetuple().tm_yday
269
- doy_sin = np.sin(2 * np.pi * doy / 365.0)
270
- doy_cos = np.cos(2 * np.pi * doy / 365.0)
271
-
272
- # Hour encoding
273
- hour_sin = np.sin(2 * np.pi * hour / 24.0)
274
- hour_cos = np.cos(2 * np.pi * hour / 24.0)
275
-
276
- # Vulnerability
277
- vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
278
- zone_vuln = vuln_map.get(
279
- getattr(zone, "heat_vulnerability", "moderate"), 0.5
280
- )
281
-
282
- features = np.array(
283
- [
284
- current_temp,
285
- current_wbgt,
286
- current_humidity,
287
- temp_trend,
288
- temp_anomaly,
289
- soil_proxy,
290
- rolling_err,
291
- doy_sin,
292
- doy_cos,
293
- hour_sin,
294
- hour_cos,
295
- zone_vuln,
296
- ],
297
- dtype=np.float32,
298
- ).reshape(1, -1)
299
-
300
- return features
301
-
302
- # ------------------------------------------------------------------
303
- # Confidence estimation
304
- # ------------------------------------------------------------------
305
-
306
- @staticmethod
307
- def _estimate_confidence(recent_temps: list[float], tier: str) -> float:
308
- """Heuristic confidence based on data quality and model tier."""
309
- base = {"full_model": 0.80, "persistence": 0.45, "climatology": 0.30}
310
- conf = base.get(tier, 0.30)
311
-
312
- # More data -> higher confidence
313
- n = len(recent_temps)
314
- if n >= 30:
315
- conf += 0.10
316
- elif n >= 14:
317
- conf += 0.05
318
-
319
- # Low variance in recent data -> more predictable
320
- if n >= 7:
321
- std = float(np.std(recent_temps[-7:]))
322
- if std < 2.0:
323
- conf += 0.05
324
-
325
- return min(conf, 0.95)
326
-
327
- # ------------------------------------------------------------------
328
- # Training
329
- # ------------------------------------------------------------------
330
-
331
- def train(self, seed: int = 42) -> None:
332
- """Generate 2 years of synthetic daily data per zone and train.
333
-
334
- For each zone, generates 730 days of realistic temperature,
335
- humidity, and WBGT curves with autocorrelation. Labels each
336
- day with whether a trigger event (2+ consecutive days above
337
- threshold) occurs within the next 7 days.
338
- """
339
- rng = np.random.default_rng(seed)
340
- from config import ZONES
341
-
342
- n_days = 730
343
- all_X = []
344
- all_y = []
345
-
346
- for zone in ZONES:
347
- city = zone.city
348
- climate = CITY_CLIMATE.get(city, CITY_CLIMATE["Nairobi"])
349
-
350
- # Generate daily temperatures with autocorrelation
351
- temps = self._generate_temp_series(climate, n_days, rng)
352
- humidity = self._generate_humidity_series(climate, n_days, rng)
353
-
354
- # Compute WBGT series
355
- wbgt_series = [
356
- _simple_wbgt(t, h) for t, h in zip(temps, humidity)
357
- ]
358
-
359
- # Label: trigger within next 7 days?
360
- threshold = CITY_THRESHOLDS.get(city, 33.0)
361
- labels = self._label_triggers(temps, threshold, n_days)
362
-
363
- # Build features for each day (need 30-day lookback)
364
- vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
365
- zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
366
-
367
- for day in range(30, n_days - 7):
368
- t_window = temps[day - 30 : day + 1]
369
- h_window = humidity[day - 30 : day + 1]
370
- w_window = wbgt_series[day - 30 : day + 1]
371
-
372
- current_temp = t_window[-1]
373
- current_wbgt = w_window[-1]
374
- current_humidity = h_window[-1]
375
-
376
- # Trend
377
- x7 = np.arange(7, dtype=np.float64)
378
- y7 = np.array(t_window[-7:], dtype=np.float64)
379
- temp_trend = float(np.polyfit(x7, y7, 1)[0])
380
-
381
- # Anomaly
382
- temp_anomaly = current_temp - float(np.mean(t_window))
383
-
384
- # Soil moisture proxy
385
- soil_proxy = float(
386
- np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0)
387
- )
388
-
389
- # Synthetic rolling error
390
- rolling_err = rng.uniform(0.1, 0.5)
391
-
392
- # Day-of-year encoding (day within 365-day cycle)
393
- doy = day % 365
394
- doy_sin = np.sin(2 * np.pi * doy / 365.0)
395
- doy_cos = np.cos(2 * np.pi * doy / 365.0)
396
-
397
- # Random hour for variety
398
- hour = rng.integers(6, 19)
399
- hour_sin = np.sin(2 * np.pi * hour / 24.0)
400
- hour_cos = np.cos(2 * np.pi * hour / 24.0)
401
-
402
- row = [
403
- current_temp,
404
- current_wbgt,
405
- current_humidity,
406
- temp_trend,
407
- temp_anomaly,
408
- soil_proxy,
409
- rolling_err,
410
- doy_sin,
411
- doy_cos,
412
- hour_sin,
413
- hour_cos,
414
- zone_vuln,
415
- ]
416
-
417
- all_X.append(row)
418
- all_y.append(labels[day])
419
-
420
- X = np.array(all_X, dtype=np.float32)
421
- y = np.array(all_y, dtype=np.int32)
422
-
423
- self.model = xgb.XGBClassifier(
424
- n_estimators=150,
425
- max_depth=5,
426
- learning_rate=0.1,
427
- eval_metric="logloss",
428
- random_state=seed,
429
- )
430
- self.model.fit(X, y)
431
- self._save_model()
432
-
433
- # ------------------------------------------------------------------
434
- # Synthetic data generation helpers
435
- # ------------------------------------------------------------------
436
-
437
- @staticmethod
438
- def _generate_temp_series(
439
- climate: dict, n_days: int, rng
440
- ) -> list[float]:
441
- """Generate realistic daily max temperatures with autocorrelation.
442
-
443
- Uses seasonal cosine curve + AR(1) autocorrelated noise.
444
- """
445
- mean = climate["temp_mean"]
446
- amp = climate["temp_amp"]
447
- phase = climate["phase_doy"]
448
- lo, hi = climate["temp_range"]
449
-
450
- temps = []
451
- noise = 0.0
452
- ar_coef = 0.7 # autocorrelation coefficient
453
-
454
- for d in range(n_days):
455
- # Seasonal component
456
- seasonal = mean + amp * np.cos(
457
- 2 * np.pi * (d - phase) / 365.0
458
- )
459
- # AR(1) noise
460
- noise = ar_coef * noise + rng.normal(0, 1.2)
461
- temp = seasonal + noise
462
- temp = float(np.clip(temp, lo - 1.0, hi + 2.0))
463
- temps.append(temp)
464
-
465
- return temps
466
-
467
- @staticmethod
468
- def _generate_humidity_series(
469
- climate: dict, n_days: int, rng
470
- ) -> list[float]:
471
- """Generate daily humidity with seasonal cycle and noise."""
472
- mean = climate["humidity_mean"]
473
- amp = climate["humidity_amp"]
474
- phase = climate.get("phase_doy", 45)
475
-
476
- humidity = []
477
- noise = 0.0
478
-
479
- for d in range(n_days):
480
- # Humidity anti-correlated with temp in dry season
481
- seasonal = mean - amp * np.cos(
482
- 2 * np.pi * (d - phase) / 365.0
483
- )
484
- noise = 0.5 * noise + rng.normal(0, 4.0)
485
- h = seasonal + noise
486
- h = float(np.clip(h, 30.0, 98.0))
487
- humidity.append(h)
488
-
489
- return humidity
490
-
491
- @staticmethod
492
- def _label_triggers(
493
- temps: list[float], threshold: float, n_days: int
494
- ) -> list[int]:
495
- """Label each day: 1 if a 2+ consecutive day trigger occurs in next 7 days."""
496
- labels = [0] * n_days
497
-
498
- for day in range(n_days - 7):
499
- window = temps[day + 1 : day + 8]
500
- # Check for 2+ consecutive above threshold
501
- consec = 0
502
- triggered = False
503
- for t in window:
504
- if t > threshold:
505
- consec += 1
506
- if consec >= 2:
507
- triggered = True
508
- break
509
- else:
510
- consec = 0
511
- labels[day] = 1 if triggered else 0
512
-
513
- return labels
514
-
515
- # ------------------------------------------------------------------
516
- # Persistence
517
- # ------------------------------------------------------------------
518
-
519
- def _save_model(self) -> None:
520
- self.model_path.parent.mkdir(parents=True, exist_ok=True)
521
- self.model.save_model(str(self.model_path))
522
-
523
- def _load_or_train(self) -> None:
524
- if self.model_path.exists():
525
- self.model = xgb.XGBClassifier()
526
- self.model.load_model(str(self.model_path))
527
- else:
528
- self.train()
529
-
530
- def _train_lstm_synthetic(self) -> None:
531
- """Auto-train LSTM on synthetic data when no model file exists."""
532
- try:
533
- from src.prediction.lstm_model import (
534
- LSTMTrainer,
535
- generate_synthetic_zone_data,
536
- )
537
- from config import ZONES
538
-
539
- import logging
540
-
541
- logger = logging.getLogger(__name__)
542
- logger.info("LSTM model not found -- training on synthetic data")
543
-
544
- zone_data = generate_synthetic_zone_data(ZONES, n_days=730, seed=42)
545
- trainer = LSTMTrainer(epochs=50, patience=5)
546
- trainer.train(zone_data)
547
-
548
- # Reload the predictor now that the model file exists
549
- self._lstm = LSTMPredictor()
550
- logger.info("LSTM auto-trained and loaded successfully")
551
- except Exception as exc:
552
- import logging
553
-
554
- logging.getLogger(__name__).warning(
555
- "LSTM auto-training failed: %s", exc
556
- )
557
- self._lstm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/prediction/lstm_model.py DELETED
@@ -1,566 +0,0 @@
1
- """
2
- LSTM neural heat wave predictor for parametric insurance triggers.
3
-
4
- 2-layer LSTM that learns temporal patterns in 14-day climate sequences
5
- to predict heat wave trigger probability in the next 7 days.
6
- Ensembled with the existing XGBoost classifier in heat_forecast.py.
7
-
8
- Architecture:
9
- Input: (batch, 14, 6) -- 14 days x 6 climate features
10
- LSTM: 2 layers, hidden_size=64, dropout=0.2
11
- Output: scalar sigmoid probability
12
-
13
- Features per timestep:
14
- 0: temp_max_c -- daily max temperature (normalized)
15
- 1: humidity_pct -- relative humidity (normalized)
16
- 2: wind_speed_ms -- wind speed (normalized)
17
- 3: wbgt_c -- wet-bulb globe temperature (normalized)
18
- 4: heat_index_c -- apparent temperature (normalized)
19
- 5: temp_anomaly -- temp minus 7-day rolling mean (normalized)
20
-
21
- References:
22
- - Perkins-Kirkpatrick & Lewis (2020) heat wave definitions
23
- - WHO/ILO occupational heat stress thresholds
24
- """
25
-
26
- from __future__ import annotations
27
-
28
- import json
29
- from pathlib import Path
30
-
31
- import numpy as np
32
-
33
- try:
34
- import torch
35
- import torch.nn as nn
36
- from torch.utils.data import DataLoader, TensorDataset
37
-
38
- TORCH_AVAILABLE = True
39
- except ImportError:
40
- TORCH_AVAILABLE = False
41
-
42
- from src.indexing.heat_index import calculate_wbgt, calculate_heat_index
43
-
44
- # City-specific temperature thresholds for trigger definition (deg C)
45
- CITY_THRESHOLDS = {
46
- "Dar es Salaam": 34.0,
47
- "Kampala": 31.0,
48
- "Nairobi": 28.0,
49
- "Kigali": 29.0,
50
- }
51
-
52
- # Seasonal temperature / humidity profiles per city
53
- CITY_CLIMATE = {
54
- "Dar es Salaam": {
55
- "temp_mean": 31.0, "temp_amp": 3.5, "phase_doy": 45,
56
- "humidity_mean": 82.0, "humidity_amp": 7.0,
57
- "temp_range": (28.0, 36.0),
58
- },
59
- "Kampala": {
60
- "temp_mean": 28.0, "temp_amp": 3.0, "phase_doy": 45,
61
- "humidity_mean": 70.0, "humidity_amp": 10.0,
62
- "temp_range": (24.0, 32.0),
63
- },
64
- "Nairobi": {
65
- "temp_mean": 24.5, "temp_amp": 3.0, "phase_doy": 55,
66
- "humidity_mean": 57.0, "humidity_amp": 12.0,
67
- "temp_range": (20.0, 28.0),
68
- },
69
- "Kigali": {
70
- "temp_mean": 25.5, "temp_amp": 2.5, "phase_doy": 50,
71
- "humidity_mean": 65.0, "humidity_amp": 8.0,
72
- "temp_range": (22.0, 29.0),
73
- },
74
- }
75
-
76
- FEATURE_NAMES = [
77
- "temp_max_c", "humidity_pct", "wind_speed_ms",
78
- "wbgt_c", "heat_index_c", "temp_anomaly",
79
- ]
80
-
81
- NUM_FEATURES = len(FEATURE_NAMES)
82
-
83
-
84
- def _resolve_path(rel: str) -> Path:
85
- """Resolve a path relative to the project root."""
86
- return Path(__file__).resolve().parents[2] / rel
87
-
88
-
89
- # ======================================================================
90
- # Model
91
- # ======================================================================
92
-
93
- if TORCH_AVAILABLE:
94
-
95
- class HeatLSTM(nn.Module):
96
- """2-layer LSTM for 7-day heat wave trigger prediction."""
97
-
98
- def __init__(
99
- self,
100
- input_size: int = NUM_FEATURES,
101
- hidden_size: int = 64,
102
- num_layers: int = 2,
103
- dropout: float = 0.2,
104
- ):
105
- super().__init__()
106
- self.lstm = nn.LSTM(
107
- input_size, hidden_size, num_layers,
108
- batch_first=True, dropout=dropout,
109
- )
110
- self.fc = nn.Linear(hidden_size, 1)
111
-
112
- def forward(self, x):
113
- # x: (batch, seq_len, input_size)
114
- out, _ = self.lstm(x)
115
- out = self.fc(out[:, -1, :]) # last timestep
116
- return torch.sigmoid(out).squeeze(-1)
117
-
118
-
119
- # ======================================================================
120
- # Derived feature computation
121
- # ======================================================================
122
-
123
- def _compute_temp_anomaly(temps: list[float], index: int) -> float:
124
- """Compute temp minus 7-day rolling mean at the given index."""
125
- start = max(0, index - 6)
126
- window = temps[start:index + 1]
127
- if not window:
128
- return 0.0
129
- return temps[index] - float(np.mean(window))
130
-
131
-
132
- # ======================================================================
133
- # Synthetic data generation
134
- # ======================================================================
135
-
136
- def _generate_temp_series(climate: dict, n_days: int, rng) -> list[float]:
137
- """Daily max temperatures with AR(1) autocorrelation."""
138
- mean, amp, phase = climate["temp_mean"], climate["temp_amp"], climate["phase_doy"]
139
- lo, hi = climate["temp_range"]
140
- temps, noise = [], 0.0
141
- for d in range(n_days):
142
- seasonal = mean + amp * np.cos(2 * np.pi * (d - phase) / 365.0)
143
- noise = 0.7 * noise + rng.normal(0, 1.2)
144
- temps.append(float(np.clip(seasonal + noise, lo - 1.0, hi + 2.0)))
145
- return temps
146
-
147
-
148
- def _generate_humidity_series(climate: dict, n_days: int, rng) -> list[float]:
149
- """Daily humidity with seasonal cycle and noise."""
150
- mean, amp = climate["humidity_mean"], climate["humidity_amp"]
151
- phase = climate.get("phase_doy", 45)
152
- humidity, noise = [], 0.0
153
- for d in range(n_days):
154
- seasonal = mean - amp * np.cos(2 * np.pi * (d - phase) / 365.0)
155
- noise = 0.5 * noise + rng.normal(0, 4.0)
156
- humidity.append(float(np.clip(seasonal + noise, 30.0, 98.0)))
157
- return humidity
158
-
159
-
160
- def _generate_wind_series(n_days: int, rng) -> list[float]:
161
- """Synthetic wind speed series (m/s)."""
162
- winds, noise = [], 0.0
163
- for _ in range(n_days):
164
- noise = 0.4 * noise + rng.normal(0, 0.8)
165
- winds.append(float(np.clip(3.5 + noise, 0.5, 12.0)))
166
- return winds
167
-
168
-
169
- def _label_triggers(temps: list[float], threshold: float, n_days: int) -> list[int]:
170
- """Label each day: 1 if 2+ consecutive days above threshold in next 7 days."""
171
- labels = [0] * n_days
172
- for day in range(n_days - 7):
173
- window = temps[day + 1: day + 8]
174
- consec = 0
175
- triggered = False
176
- for t in window:
177
- if t > threshold:
178
- consec += 1
179
- if consec >= 2:
180
- triggered = True
181
- break
182
- else:
183
- consec = 0
184
- labels[day] = 1 if triggered else 0
185
- return labels
186
-
187
-
188
- def generate_synthetic_zone_data(
189
- zones: list, n_days: int = 730, seed: int = 42,
190
- ) -> dict[str, list[dict]]:
191
- """Generate synthetic daily climate data for all zones.
192
-
193
- Returns:
194
- dict mapping zone_id -> list of daily dicts with keys:
195
- temp_max_c, humidity_pct, wind_speed_ms, city
196
- """
197
- rng = np.random.default_rng(seed)
198
- zone_data: dict[str, list[dict]] = {}
199
-
200
- for zone in zones:
201
- city = zone.city
202
- climate = CITY_CLIMATE.get(city, CITY_CLIMATE["Nairobi"])
203
-
204
- temps = _generate_temp_series(climate, n_days, rng)
205
- humidity = _generate_humidity_series(climate, n_days, rng)
206
- winds = _generate_wind_series(n_days, rng)
207
-
208
- days = []
209
- for i in range(n_days):
210
- days.append({
211
- "temp_max_c": temps[i],
212
- "humidity_pct": humidity[i],
213
- "wind_speed_ms": winds[i],
214
- "city": city,
215
- })
216
- zone_data[zone.zone_id] = days
217
-
218
- return zone_data
219
-
220
-
221
- # ======================================================================
222
- # Trainer
223
- # ======================================================================
224
-
225
- class LSTMTrainer:
226
- """Train the HeatLSTM on historical or synthetic climate data."""
227
-
228
- def __init__(
229
- self,
230
- model: object | None = None,
231
- lr: float = 0.001,
232
- epochs: int = 50,
233
- patience: int = 5,
234
- seq_len: int = 14,
235
- forecast_horizon: int = 7,
236
- ):
237
- if not TORCH_AVAILABLE:
238
- raise ImportError("torch is required. pip install torch")
239
- self._custom_model = model
240
- self.lr = lr
241
- self.epochs = epochs
242
- self.patience = patience
243
- self.seq_len = seq_len
244
- self.forecast_horizon = forecast_horizon
245
- self.model_path = _resolve_path("models/heat_lstm.pt")
246
- self.norm_path = _resolve_path("models/lstm_norm.json")
247
-
248
- def prepare_data(
249
- self, zone_readings: dict[str, list],
250
- ) -> tuple:
251
- """Convert zone readings to training tensors.
252
-
253
- Args:
254
- zone_readings: dict of zone_id -> list of daily readings.
255
- Each reading needs: temp_max_c, humidity_pct, wind_speed_ms
256
- Optional: date, city
257
-
258
- Returns:
259
- (X_train, y_train, X_val, y_val) as torch tensors
260
- """
261
- all_seqs: list[np.ndarray] = []
262
- all_labels: list[int] = []
263
-
264
- for zone_id, days in zone_readings.items():
265
- n = len(days)
266
- if n < self.seq_len + self.forecast_horizon + 1:
267
- continue
268
-
269
- city = days[0].get("city", "Nairobi")
270
- threshold = CITY_THRESHOLDS.get(city, 33.0)
271
-
272
- # Extract raw temps for labeling and anomaly computation
273
- temps = [d["temp_max_c"] for d in days]
274
- labels = _label_triggers(temps, threshold, n)
275
-
276
- # Compute derived features for all days
277
- derived = []
278
- for i, d in enumerate(days):
279
- t = d["temp_max_c"]
280
- h = d["humidity_pct"]
281
- w = d["wind_speed_ms"]
282
- wbgt = calculate_wbgt(t, h)
283
- hi = calculate_heat_index(t, h)
284
- anomaly = _compute_temp_anomaly(temps, i)
285
- derived.append([t, h, w, wbgt, hi, anomaly])
286
-
287
- # Create sliding windows
288
- for i in range(n - self.seq_len - self.forecast_horizon):
289
- seq = np.array(
290
- derived[i: i + self.seq_len], dtype=np.float32,
291
- )
292
- all_seqs.append(seq)
293
- all_labels.append(labels[i + self.seq_len - 1])
294
-
295
- X = np.stack(all_seqs) # (N, seq_len, 6)
296
- y = np.array(all_labels, dtype=np.float32)
297
-
298
- # Temporal split: first 75% train, last 25% validation
299
- split = int(len(X) * 0.75)
300
- X_train_np, X_val_np = X[:split], X[split:]
301
- y_train_np, y_val_np = y[:split], y[split:]
302
-
303
- # Compute normalization (z-score per feature) from training set
304
- flat = X_train_np.reshape(-1, NUM_FEATURES)
305
- feat_mean = flat.mean(axis=0).tolist()
306
- feat_std = flat.std(axis=0).tolist()
307
- feat_std = [max(s, 1e-6) for s in feat_std]
308
-
309
- # Save normalization params
310
- norm = {"mean": feat_mean, "std": feat_std}
311
- self.norm_path.parent.mkdir(parents=True, exist_ok=True)
312
- with open(self.norm_path, "w") as f:
313
- json.dump(norm, f, indent=2)
314
-
315
- # Normalize
316
- mean_arr = np.array(feat_mean, dtype=np.float32)
317
- std_arr = np.array(feat_std, dtype=np.float32)
318
- X_train_np = (X_train_np - mean_arr) / std_arr
319
- X_val_np = (X_val_np - mean_arr) / std_arr
320
-
321
- X_train = torch.from_numpy(X_train_np)
322
- y_train = torch.from_numpy(y_train_np)
323
- X_val = torch.from_numpy(X_val_np)
324
- y_val = torch.from_numpy(y_val_np)
325
-
326
- return X_train, y_train, X_val, y_val
327
-
328
- def train(self, zone_readings: dict[str, list]) -> dict:
329
- """Train the LSTM and return metrics."""
330
- torch.manual_seed(42)
331
- np.random.seed(42)
332
-
333
- X_train, y_train, X_val, y_val = self.prepare_data(zone_readings)
334
-
335
- # DataLoaders
336
- train_ds = TensorDataset(X_train, y_train)
337
- val_ds = TensorDataset(X_val, y_val)
338
- train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
339
- val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)
340
-
341
- # Model, loss, optimizer
342
- model = self._custom_model if self._custom_model is not None else HeatLSTM()
343
- criterion = nn.BCELoss()
344
- optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
345
-
346
- total_params = sum(p.numel() for p in model.parameters())
347
- print(f" Model params: {total_params:,}")
348
-
349
- # Training loop with early stopping
350
- best_val_loss = float("inf")
351
- patience_counter = 0
352
- best_state = None
353
- best_metrics: dict = {}
354
-
355
- for epoch in range(self.epochs):
356
- # Train
357
- model.train()
358
- train_loss, train_total = 0.0, 0
359
- for xb, yb in train_loader:
360
- optimizer.zero_grad()
361
- preds = model(xb)
362
- loss = criterion(preds, yb)
363
- loss.backward()
364
- optimizer.step()
365
- train_loss += loss.item() * len(xb)
366
- train_total += len(xb)
367
-
368
- # Validate
369
- model.eval()
370
- val_loss, val_total = 0.0, 0
371
- all_val_preds, all_val_labels = [], []
372
- with torch.no_grad():
373
- for xb, yb in val_loader:
374
- preds = model(xb)
375
- loss = criterion(preds, yb)
376
- val_loss += loss.item() * len(xb)
377
- val_total += len(xb)
378
- all_val_preds.extend(preds.numpy().tolist())
379
- all_val_labels.extend(yb.numpy().tolist())
380
-
381
- avg_train_loss = train_loss / max(train_total, 1)
382
- avg_val_loss = val_loss / max(val_total, 1)
383
- val_auroc = _compute_auroc(all_val_labels, all_val_preds)
384
-
385
- if (epoch + 1) % 5 == 0 or epoch == 0:
386
- print(
387
- f" Epoch {epoch + 1:>2}: "
388
- f"train_loss={avg_train_loss:.4f} | "
389
- f"val_loss={avg_val_loss:.4f} val_auroc={val_auroc:.3f}"
390
- )
391
-
392
- # Early stopping
393
- if avg_val_loss < best_val_loss:
394
- best_val_loss = avg_val_loss
395
- patience_counter = 0
396
- best_state = {k: v.clone() for k, v in model.state_dict().items()}
397
- best_metrics = {
398
- "train_loss": round(avg_train_loss, 4),
399
- "val_loss": round(avg_val_loss, 4),
400
- "val_auroc": round(val_auroc, 4),
401
- "epochs_trained": epoch + 1,
402
- "samples": {
403
- "train": len(X_train),
404
- "val": len(X_val),
405
- },
406
- }
407
- else:
408
- patience_counter += 1
409
- if patience_counter >= self.patience:
410
- print(f" Early stopping at epoch {epoch + 1}")
411
- break
412
-
413
- # Save best model
414
- if best_state is not None:
415
- model.load_state_dict(best_state)
416
- self.model_path.parent.mkdir(parents=True, exist_ok=True)
417
- torch.save(model.state_dict(), self.model_path)
418
-
419
- file_size = self.model_path.stat().st_size
420
- print(f" Saved model to {self.model_path} ({file_size / 1024:.1f} KB)")
421
-
422
- return best_metrics
423
-
424
-
425
- # ======================================================================
426
- # Predictor (inference)
427
- # ======================================================================
428
-
429
- class LSTMPredictor:
430
- """Load trained HeatLSTM and predict trigger probability."""
431
-
432
- def __init__(
433
- self,
434
- model_path: str = "models/heat_lstm.pt",
435
- norm_path: str = "models/lstm_norm.json",
436
- ):
437
- if not TORCH_AVAILABLE:
438
- raise ImportError("torch is required. pip install torch")
439
- self.model_path = _resolve_path(model_path)
440
- self.norm_path = _resolve_path(norm_path)
441
- self.model: HeatLSTM | None = None
442
- self._norm: dict | None = None
443
- self._load_model()
444
-
445
- def _load_model(self) -> None:
446
- """Load model weights and normalization params from disk."""
447
- mp = self.model_path
448
- np_ = self.norm_path
449
-
450
- if not mp.exists():
451
- raise FileNotFoundError(
452
- f"LSTM model not found at {mp}. "
453
- "Run scripts/train_lstm.py first."
454
- )
455
- if not np_.exists():
456
- raise FileNotFoundError(
457
- f"LSTM normalization params not found at {np_}. "
458
- "Run scripts/train_lstm.py first."
459
- )
460
-
461
- self.model = HeatLSTM()
462
- self.model.load_state_dict(
463
- torch.load(mp, map_location="cpu", weights_only=True)
464
- )
465
- self.model.eval()
466
-
467
- with open(np_) as f:
468
- self._norm = json.load(f)
469
-
470
- def predict(self, recent_14_days: list[dict]) -> tuple[float, float]:
471
- """Predict trigger probability from last 14 days of data.
472
-
473
- Args:
474
- recent_14_days: list of dicts with keys:
475
- temp_max_c, humidity_pct, wind_speed_ms
476
- (WBGT, heat index, and temp anomaly are computed internally)
477
-
478
- Returns:
479
- (probability, confidence) where:
480
- - probability: 0-1 trigger probability
481
- - confidence: 0-1 based on MC dropout (5 forward passes)
482
- """
483
- if len(recent_14_days) < 14:
484
- pad = [recent_14_days[0]] * (14 - len(recent_14_days))
485
- recent_14_days = pad + recent_14_days
486
-
487
- days = recent_14_days[-14:]
488
-
489
- # Extract temps for anomaly computation
490
- temps = [d.get("temp_max_c", d.get("temp_c", 30.0)) for d in days]
491
-
492
- # Build feature array: compute derived features
493
- seq = []
494
- for i, d in enumerate(days):
495
- t = d.get("temp_max_c", d.get("temp_c", 30.0))
496
- h = d.get("humidity_pct", 65.0)
497
- w = d.get("wind_speed_ms", 3.0)
498
- wbgt = d.get("wbgt_c", calculate_wbgt(t, h))
499
- hi = d.get("heat_index_c", calculate_heat_index(t, h))
500
- anomaly = _compute_temp_anomaly(temps, i)
501
- seq.append([t, h, w, wbgt, hi, anomaly])
502
-
503
- x = np.array(seq, dtype=np.float32)
504
-
505
- # Normalize using saved params
506
- mean = np.array(self._norm["mean"], dtype=np.float32)
507
- std = np.array(self._norm["std"], dtype=np.float32)
508
- x = (x - mean) / std
509
-
510
- x_tensor = torch.from_numpy(x).unsqueeze(0) # (1, 14, 6)
511
-
512
- # MC Dropout: 5 forward passes batched with dropout enabled
513
- self.model.train()
514
- x_batch = x_tensor.expand(5, -1, -1) # (5, 14, 6)
515
- with torch.no_grad():
516
- preds = self.model(x_batch).numpy()
517
- self.model.eval()
518
-
519
- probability = float(np.mean(preds))
520
- std_val = float(np.std(preds))
521
- confidence = max(0.3, min(0.95, 1.0 - std_val * 3))
522
-
523
- probability = float(np.clip(probability, 0.0, 1.0))
524
-
525
- return probability, confidence
526
-
527
-
528
- # ======================================================================
529
- # Utilities
530
- # ======================================================================
531
-
532
- def _compute_auroc(labels: list[float], preds: list[float]) -> float:
533
- """Compute AUROC using sklearn if available, else trapezoidal fallback."""
534
- if len(set(labels)) < 2:
535
- return 0.5
536
-
537
- try:
538
- from sklearn.metrics import roc_auc_score
539
- return float(roc_auc_score(labels, preds))
540
- except ImportError:
541
- pass
542
-
543
- # Fallback: trapezoidal AUROC
544
- pairs = sorted(zip(preds, labels), key=lambda x: -x[0])
545
- tp, fp = 0, 0
546
- tp_prev, fp_prev = 0, 0
547
- auc = 0.0
548
- n_pos = sum(labels)
549
- n_neg = len(labels) - n_pos
550
-
551
- if n_pos == 0 or n_neg == 0:
552
- return 0.5
553
-
554
- prev_score = None
555
- for score, label in pairs:
556
- if score != prev_score and prev_score is not None:
557
- auc += (fp - fp_prev) * (tp + tp_prev) / 2.0
558
- tp_prev, fp_prev = tp, fp
559
- if label == 1.0:
560
- tp += 1
561
- else:
562
- fp += 1
563
- prev_score = score
564
-
565
- auc += (fp - fp_prev) * (tp + tp_prev) / 2.0
566
- return auc / (n_pos * n_neg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/pricing/neural_actuarial.py DELETED
@@ -1,1312 +0,0 @@
1
- """
2
- Neural Actuarial Pricing Engine for Parametric Heat Insurance.
3
-
4
- Three-headed temporal neural model trained on real climate data:
5
- 1. HazardHead (Neural EVT): Learns extreme heat frequency + severity
6
- distribution using Generalized Pareto Distribution parameters
7
- 2. VulnerabilityHead: Learns zone-specific worker impact calibrated
8
- to WHO/ILO occupational heat stress guidelines (ISO 7243)
9
- 3. PricingHead (CANN): Combined Actuarial Neural Network — GLM baseline
10
- with bounded neural correction. When δ_NN = 0, reproduces the
11
- existing ActuarialPricer formula exactly.
12
-
13
- City-specific: trained per city (default: Dar es Salaam). Architecture
14
- is portable; weights are local. Retrain for any city with:
15
- python3 scripts/train_neural_pricer.py --city "Kampala"
16
-
17
- References:
18
- - Chen et al. (2024) "Managing Weather Risk with NN-Based Index Insurance"
19
- Management Science 70(7), 4306-4327
20
- - Pasche & Engelke (2024) "Neural Networks for Extreme Quantile Regression"
21
- arXiv:2208.07590 (EQRN)
22
- - Richman & Wuthrich (2023) "LocalGLMnet" Scandinavian Actuarial Journal
23
- - ISO 7243 / NIOSH occupational heat stress thresholds
24
- """
25
-
26
- from __future__ import annotations
27
-
28
- import json
29
- import logging
30
- from pathlib import Path
31
- from typing import Optional
32
-
33
- import numpy as np
34
-
35
- try:
36
- import torch
37
- import torch.nn as nn
38
- import torch.nn.functional as F
39
- TORCH_AVAILABLE = True
40
- except ImportError:
41
- TORCH_AVAILABLE = False
42
-
43
- # Chronos foundation model (optional — used for Chronos encoder path)
44
- try:
45
- from chronos import ChronosBoltPipeline
46
- CHRONOS_AVAILABLE = True
47
- except ImportError:
48
- CHRONOS_AVAILABLE = False
49
-
50
- from src.indexing.heat_index import calculate_wbgt
51
- from src.pricing.actuarial import ActuarialPricer, ActuarialResult
52
-
53
- log = logging.getLogger(__name__)
54
-
55
- PROJECT_ROOT = Path(__file__).resolve().parents[2]
56
-
57
- # ── City-specific trigger thresholds (WBGT, °C) ──────────────────────────
58
- # Based on ILO occupational heat stress measurements
59
- # Actuarial trigger thresholds: set at ~P85-P90 of each city's WBGT distribution
60
- # so that only genuinely extreme heat events are trigger-worthy.
61
- # These are HIGHER than occupational safety thresholds (which define "uncomfortable")
62
- # because insurance triggers need to define "exceptional" events.
63
- CITY_WBGT_THRESHOLDS = {
64
- "Dar es Salaam": 35.1, # Calibrated P97 threshold (matches insurance benchmark)
65
- "Kampala": 30.0,
66
- "Nairobi": 28.0,
67
- "Kigali": 29.0,
68
- }
69
-
70
- # Settlement-type-specific WBGT thresholds for compound triggers.
71
- # UHI correction already differentiates zones (informal +2.5°C, formal +0.5°C),
72
- # so all types use the city threshold. Zone frequency differences come from UHI.
73
- SETTLEMENT_THRESHOLDS = {
74
- "informal": 35.1,
75
- "mixed": 35.1,
76
- "formal": 35.1,
77
- "commercial": 35.1,
78
- }
79
-
80
- # Minimum consecutive days above threshold to count as a heat EVENT.
81
- # Alert tier at 2 days, payout tier at 5 days (matches insurance benchmark).
82
- MIN_CONSECUTIVE_DAYS = 2
83
-
84
- # ── WHO/ILO dose-response: WBGT → productivity loss ──────────────────────
85
- # ISO 7243 thresholds for moderate-to-heavy outdoor work
86
- def who_productivity_loss(wbgt: float) -> float:
87
- """Fractional productivity loss given WBGT (°C). ISO 7243 calibrated."""
88
- if wbgt < 26.0:
89
- return 0.0
90
- elif wbgt < 28.0:
91
- return 0.10
92
- elif wbgt < 30.0:
93
- return 0.25
94
- elif wbgt < 32.0:
95
- return 0.50
96
- elif wbgt < 35.0:
97
- return 0.75
98
- else:
99
- return 1.0 # work must stop
100
-
101
- # ── Settlement-type UHI ranges (from literature) ─────────────────────────
102
- from src.downscaling.uhi_model import UHI_RANGES
103
-
104
- # ══════════════════════════════════════════════════════════════════════════
105
- # PyTorch Model Components
106
- # ══════════════════════════════════════════════════════════════════════════
107
-
108
- if TORCH_AVAILABLE:
109
-
110
- class TemporalEncoder(nn.Module):
111
- """
112
- LSTM encoder for 90-day climate sequences.
113
-
114
- Input: (batch, 90, 11) — 7 climate + 4 zone-static features
115
- Output: (batch, 128) — last hidden state
116
- """
117
-
118
- def __init__(self, input_size: int = 11, hidden_size: int = 128,
119
- num_layers: int = 2, dropout: float = 0.3):
120
- super().__init__()
121
- self.lstm = nn.LSTM(
122
- input_size, hidden_size, num_layers,
123
- batch_first=True, dropout=dropout,
124
- )
125
- self.norm = nn.LayerNorm(hidden_size)
126
-
127
- def forward(self, x):
128
- out, (h_n, _) = self.lstm(x)
129
- # Use last layer's hidden state
130
- latent = h_n[-1] # (batch, hidden_size)
131
- return self.norm(latent)
132
-
133
- class ChronosEncoder(nn.Module):
134
- """
135
- Chronos-Bolt foundation model encoder for 90-day climate sequences.
136
-
137
- Embeds the primary heat stress signal (WBGT, channel 6) via frozen
138
- Chronos-Bolt-Tiny (9M params, pre-trained on 100B+ observations),
139
- then concatenates climate summary stats and zone-static features
140
- before projecting to the same 128-dim latent as TemporalEncoder.
141
-
142
- Input: (batch, 90, 11) — same contract as TemporalEncoder
143
- Output: (batch, 128) — same contract as TemporalEncoder
144
- """
145
-
146
- WBGT_IDX = 6 # index of WBGT in 11-feature vector
147
- N_CLIMATE = 7 # features 0-6 are climate
148
- N_STATIC = 4 # features 7-10 are zone-static
149
-
150
- def __init__(self, hidden_size: int = 128, chronos_d_model: int = 256):
151
- super().__init__()
152
- proj_input = chronos_d_model + self.N_CLIMATE + self.N_STATIC
153
- self.proj = nn.Linear(proj_input, hidden_size)
154
- self.act = nn.GELU()
155
- self.norm = nn.LayerNorm(hidden_size)
156
- self._chronos_d_model = chronos_d_model
157
- # Set externally after construction — NOT part of state_dict
158
- self._pipeline = None
159
- self._feat_mean = None # numpy (11,) for un-normalizing WBGT
160
- self._feat_std = None
161
-
162
- def set_pipeline(self, pipeline, feat_mean=None, feat_std=None):
163
- """Attach the frozen Chronos pipeline (not a nn.Module)."""
164
- self._pipeline = pipeline
165
- self._feat_mean = feat_mean
166
- self._feat_std = feat_std
167
-
168
- def _unnorm_wbgt(self, x_norm):
169
- """Recover raw WBGT values from z-scored tensor for Chronos input."""
170
- wbgt_norm = x_norm[:, :, self.WBGT_IDX] # (batch, 90)
171
- if self._feat_mean is not None and self._feat_std is not None:
172
- mu = self._feat_mean[self.WBGT_IDX]
173
- sd = self._feat_std[self.WBGT_IDX]
174
- return wbgt_norm * sd + mu # back to pre-norm scale (wbgt/40)
175
- return wbgt_norm
176
-
177
- def _embed_wbgt(self, x_norm):
178
- """Run Chronos .embed() on un-normalized WBGT and mean-pool."""
179
- wbgt_scaled = self._unnorm_wbgt(x_norm) # (batch, 90), scale=wbgt/40
180
- wbgt_raw = wbgt_scaled * 40.0 # raw °C for Chronos
181
- with torch.no_grad():
182
- emb, _ = self._pipeline.embed(wbgt_raw) # (batch, patches+1, 256)
183
- return emb.mean(dim=1) # (batch, 256)
184
-
185
- def forward(self, x, chronos_embeddings=None):
186
- """
187
- Args:
188
- x: (batch, 90, 11) normalized climate sequence
189
- chronos_embeddings: optional (batch, d_model) pre-computed.
190
- If None, computes from x via the attached pipeline.
191
- """
192
- climate_means = x[:, :, :self.N_CLIMATE].mean(dim=1) # (batch, 7)
193
- zone_static = x[:, 0, self.N_CLIMATE:] # (batch, 4)
194
-
195
- if chronos_embeddings is None:
196
- chronos_embeddings = self._embed_wbgt(x)
197
-
198
- combined = torch.cat([chronos_embeddings, climate_means, zone_static], dim=-1)
199
- return self.norm(self.act(self.proj(combined))) # (batch, 128)
200
-
201
- class HazardHead(nn.Module):
202
- """
203
- Neural Extreme Value Theory head + two-tier trigger decision.
204
-
205
- Two tiers matching SEWA/Arsht-Rockefeller pilot design:
206
- Alert tier: moderate heat event → cash transfer (philanthropy-funded)
207
- Payout tier: severe sustained event → insurance payout (underwritten)
208
-
209
- Outputs:
210
- λ (events/year): softplus, range ~0.5-50
211
- σ (GPD scale): softplus, range ~0.1-10
212
- ξ (GPD shape): tanh × 0.4, range [-0.4, 0.4]
213
- alert_prob: sigmoid → [0, 1], moderate event (cash tier)
214
- payout_prob: sigmoid → [0, 1], severe event (insurance tier)
215
- alert_severity: sigmoid → [0, 1], cash amount scaling
216
- payout_severity: sigmoid → [0, 1], insurance payout scaling
217
- """
218
-
219
- def __init__(self, input_size: int = 128):
220
- super().__init__()
221
- self.net = nn.Sequential(
222
- nn.Linear(input_size, 64),
223
- nn.GELU(),
224
- nn.Linear(64, 32),
225
- nn.GELU(),
226
- )
227
- self.lambda_head = nn.Linear(32, 1)
228
- self.sigma_head = nn.Linear(32, 1)
229
- self.xi_head = nn.Linear(32, 1)
230
- self.alert_prob_head = nn.Linear(32, 1)
231
- self.payout_prob_head = nn.Linear(32, 1)
232
- self.alert_severity_head = nn.Linear(32, 1)
233
- self.payout_severity_head = nn.Linear(32, 1)
234
-
235
- def forward(self, h):
236
- z = self.net(h)
237
- lambda_ = F.softplus(self.lambda_head(z)) + 0.5
238
- sigma = F.softplus(self.sigma_head(z)) + 0.1
239
- xi = torch.tanh(self.xi_head(z)) * 0.4
240
- alert_prob = torch.sigmoid(self.alert_prob_head(z))
241
- payout_prob = torch.sigmoid(self.payout_prob_head(z))
242
- alert_severity = torch.sigmoid(self.alert_severity_head(z))
243
- payout_severity = torch.sigmoid(self.payout_severity_head(z))
244
- # Pack into trigger_prob and payout_factor for backward compat
245
- # trigger_prob = alert_prob (most frequent), payout_factor = payout_severity
246
- return lambda_, sigma, xi, alert_prob, payout_prob, alert_severity, payout_severity
247
-
248
- class VulnerabilityHead(nn.Module):
249
- """
250
- Neural loss distribution head.
251
-
252
- Learns zone-specific worker impact from hazard characteristics.
253
- Calibrated to WHO/ILO ISO 7243 dose-response.
254
-
255
- Outputs:
256
- productivity_loss: sigmoid → [0, 1]
257
- basis_risk: sigmoid → [0, 1]
258
- severity_multiplier: softplus + 0.5 → [0.5, ∞)
259
- """
260
-
261
- def __init__(self, input_size: int = 135): # 128 + 7 hazard outputs
262
- super().__init__()
263
- self.net = nn.Sequential(
264
- nn.Linear(input_size, 64),
265
- nn.GELU(),
266
- nn.Linear(64, 32),
267
- nn.GELU(),
268
- )
269
- self.prod_loss_head = nn.Linear(32, 1)
270
- self.basis_risk_head = nn.Linear(32, 1)
271
- self.severity_head = nn.Linear(32, 1)
272
-
273
- def forward(self, h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev):
274
- combined = torch.cat([h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev], dim=-1)
275
- z = self.net(combined)
276
- prod_loss = torch.sigmoid(self.prod_loss_head(z))
277
- basis_risk = torch.sigmoid(self.basis_risk_head(z))
278
- severity = F.softplus(self.severity_head(z)) + 0.5
279
- return prod_loss, basis_risk, severity
280
-
281
- class PricingHead(nn.Module):
282
- """
283
- CANN (Combined Actuarial Neural Network) pricing head.
284
-
285
- Skip connection: price = exp(η_GLM + δ_NN) × inflation_buffer
286
-
287
- When δ_NN = 0, reproduces ActuarialPricer formula exactly.
288
- δ_NN is clamped to [-0.5, 0.5] for safety (±50% max correction).
289
- """
290
-
291
- def __init__(self, input_size: int = 138, max_delta: float = 0.5): # 128 + 7 hazard + 3 vuln
292
- super().__init__()
293
- self.max_delta = max_delta
294
- self.net = nn.Sequential(
295
- nn.Linear(input_size, 32),
296
- nn.GELU(),
297
- nn.Linear(32, 1),
298
- )
299
- # Initialize near zero so model starts at GLM baseline
300
- nn.init.zeros_(self.net[-1].weight)
301
- nn.init.zeros_(self.net[-1].bias)
302
-
303
- def forward(self, h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev, prod_loss, basis_risk, severity):
304
- combined = torch.cat([
305
- h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev, prod_loss, basis_risk, severity
306
- ], dim=-1)
307
- delta_nn = self.net(combined)
308
- delta_nn = torch.clamp(delta_nn, -self.max_delta, self.max_delta)
309
- return delta_nn
310
-
311
- class HeatRiskNeuralPricer(nn.Module):
312
- """
313
- Full neural actuarial pricing model.
314
-
315
- Composes TemporalEncoder + HazardHead + VulnerabilityHead + PricingHead.
316
- ~70K parameters. Trains in ~15 min on CPU.
317
- """
318
-
319
- def __init__(self, input_size: int = 11, hidden_size: int = 128):
320
- super().__init__()
321
- self.encoder = TemporalEncoder(input_size, hidden_size)
322
- self.hazard = HazardHead(hidden_size)
323
- self.vulnerability = VulnerabilityHead(hidden_size + 3)
324
- self.pricing = PricingHead(hidden_size + 6)
325
-
326
- def forward(self, x, payout_per_event: float = 10.0,
327
- admin_rate: float = 0.15):
328
- """
329
- Args:
330
- x: (batch, 90, 11) — daily climate sequence with zone features
331
- payout_per_event: USD per event per worker
332
- admin_rate: operational overhead rate
333
-
334
- Returns:
335
- dict with all intermediate and final outputs
336
- """
337
- # Encode temporal sequence
338
- h = self.encoder(x) # (batch, 128)
339
-
340
- # Hazard: frequency + severity + two-tier trigger
341
- lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev = self.hazard(h)
342
-
343
- # Vulnerability: worker impact
344
- prod_loss, basis_risk, severity = self.vulnerability(
345
- h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev
346
- )
347
-
348
- # GPD expected severity
349
- xi_safe = torch.clamp(xi, max=0.39)
350
- expected_severity = sigma / (1.0 - xi_safe)
351
-
352
- # GLM baseline
353
- base_cost = lambda_ * payout_per_event
354
- basis_loading = base_cost * (basis_risk * 0.5)
355
- vuln_loading = base_cost * (prod_loss * 0.2)
356
- subtotal = base_cost + basis_loading + vuln_loading
357
- admin_loading = subtotal * admin_rate
358
- glm_total = subtotal + admin_loading
359
- eta_glm = torch.log(glm_total + 1e-8)
360
-
361
- # CANN: neural correction
362
- delta_nn = self.pricing(
363
- h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev,
364
- prod_loss, basis_risk, severity
365
- )
366
-
367
- total_per_worker = torch.exp(eta_glm + delta_nn) * 1.05
368
-
369
- return {
370
- "lambda_": lambda_.squeeze(-1),
371
- "sigma": sigma.squeeze(-1),
372
- "xi": xi.squeeze(-1),
373
- "alert_prob": alert_prob.squeeze(-1),
374
- "payout_prob": payout_prob.squeeze(-1),
375
- "alert_severity": alert_sev.squeeze(-1),
376
- "payout_severity": payout_sev.squeeze(-1),
377
- "productivity_loss": prod_loss.squeeze(-1),
378
- "basis_risk": basis_risk.squeeze(-1),
379
- "severity_multiplier": severity.squeeze(-1),
380
- "expected_severity": expected_severity.squeeze(-1),
381
- "delta_nn": delta_nn.squeeze(-1),
382
- "glm_price": (glm_total * 1.05).squeeze(-1),
383
- "total_per_worker": total_per_worker.squeeze(-1),
384
- "base_cost": base_cost.squeeze(-1),
385
- "basis_loading": basis_loading.squeeze(-1),
386
- "vuln_loading": vuln_loading.squeeze(-1),
387
- "admin_loading": admin_loading.squeeze(-1),
388
- }
389
-
390
- class HeatRiskNeuralPricerChronos(nn.Module):
391
- """
392
- Chronos-enhanced neural actuarial pricing model.
393
-
394
- Same 3-head architecture as HeatRiskNeuralPricer but replaces the
395
- LSTM TemporalEncoder with frozen Chronos-Bolt-Tiny embeddings.
396
- Only the projection layer + heads are trainable (~50K params).
397
-
398
- The Chronos foundation model (9M params, pre-trained on 100B+
399
- time-series observations) captures deep temporal patterns in the
400
- WBGT heat stress signal, while climate summary stats and zone
401
- features provide domain context.
402
- """
403
-
404
- def __init__(self, hidden_size: int = 128, chronos_d_model: int = 256):
405
- super().__init__()
406
- self.encoder = ChronosEncoder(hidden_size, chronos_d_model)
407
- self.hazard = HazardHead(hidden_size)
408
- self.vulnerability = VulnerabilityHead(hidden_size + 7) # + 7 hazard outputs
409
- self.pricing = PricingHead(hidden_size + 10, max_delta=1.5) # + 7 hazard + 3 vuln
410
-
411
- def forward(self, x, payout_per_event: float = 10.0,
412
- admin_rate: float = 0.15, chronos_embeddings=None):
413
- h = self.encoder(x, chronos_embeddings=chronos_embeddings)
414
-
415
- lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev = self.hazard(h)
416
- prod_loss, basis_risk, severity = self.vulnerability(
417
- h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev
418
- )
419
-
420
- xi_safe = torch.clamp(xi, max=0.39)
421
- expected_severity = sigma / (1.0 - xi_safe)
422
-
423
- base_cost = lambda_ * payout_per_event
424
- basis_loading = base_cost * (basis_risk * 0.5)
425
- vuln_loading = base_cost * (prod_loss * 0.2)
426
- subtotal = base_cost + basis_loading + vuln_loading
427
- admin_loading = subtotal * admin_rate
428
- glm_total = subtotal + admin_loading
429
- eta_glm = torch.log(glm_total + 1e-8)
430
-
431
- delta_nn = self.pricing(
432
- h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev,
433
- prod_loss, basis_risk, severity
434
- )
435
-
436
- total_per_worker = torch.exp(eta_glm + delta_nn) * 1.05
437
-
438
- return {
439
- "lambda_": lambda_.squeeze(-1),
440
- "sigma": sigma.squeeze(-1),
441
- "xi": xi.squeeze(-1),
442
- "alert_prob": alert_prob.squeeze(-1),
443
- "payout_prob": payout_prob.squeeze(-1),
444
- "alert_severity": alert_sev.squeeze(-1),
445
- "payout_severity": payout_sev.squeeze(-1),
446
- "productivity_loss": prod_loss.squeeze(-1),
447
- "basis_risk": basis_risk.squeeze(-1),
448
- "severity_multiplier": severity.squeeze(-1),
449
- "expected_severity": expected_severity.squeeze(-1),
450
- "delta_nn": delta_nn.squeeze(-1),
451
- "glm_price": (glm_total * 1.05).squeeze(-1),
452
- "total_per_worker": total_per_worker.squeeze(-1),
453
- "base_cost": base_cost.squeeze(-1),
454
- "basis_loading": basis_loading.squeeze(-1),
455
- "vuln_loading": vuln_loading.squeeze(-1),
456
- "admin_loading": admin_loading.squeeze(-1),
457
- }
458
-
459
-
460
- # ══════════════════════════════════════════════════════════════════════════
461
- # Training Data Generation
462
- # ══════════════════════════════════════════════════════════════════════════
463
-
464
- def load_climate_data(data_path: str | Path) -> dict[str, list[dict]]:
465
- """Load NASA POWER or ERA5-Land daily data from JSON."""
466
- with open(data_path) as f:
467
- return json.load(f)
468
-
469
-
470
- def build_training_samples(
471
- climate_data: dict[str, list[dict]],
472
- zones: list,
473
- window_size: int = 90,
474
- stride: int = 7,
475
- wbgt_threshold: float = 28.0,
476
- ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
477
- """
478
- Build training dataset from daily climate records.
479
-
480
- Creates sliding windows of `window_size` days, with targets computed
481
- from the window's heat characteristics.
482
-
483
- Returns:
484
- X: (N, window_size, 11) — feature sequences
485
- targets: dict of target arrays, each shape (N,)
486
- """
487
- zone_map = {z.zone_id: z for z in zones}
488
-
489
- all_X = []
490
- all_targets = {
491
- "frequency": [],
492
- "gpd_sigma": [],
493
- "gpd_xi": [],
494
- "productivity_loss": [],
495
- "basis_risk": [],
496
- "severity_multiplier": [],
497
- "price_target": [],
498
- "alert_event": [], # 1 if moderate heat event (cash tier)
499
- "payout_event": [], # 1 if severe sustained event (insurance tier)
500
- "alert_severity": [], # 0-1 cash amount scaling
501
- "payout_severity": [], # 0-1 insurance payout scaling
502
- }
503
-
504
- for zone_id, records in climate_data.items():
505
- zone = zone_map.get(zone_id)
506
- if zone is None:
507
- continue
508
-
509
- # Encode zone static features
510
- # Settlement type is NOT included — the model should learn zone
511
- # differences from UHI-corrected temperatures in the sequence,
512
- # not from a categorical label that creates spurious correlations.
513
- vuln_enc = {"high": 1.0, "moderate": 0.5, "low": 0.0}
514
- zone_static = [
515
- vuln_enc.get(zone.heat_vulnerability, 0.5),
516
- zone.outdoor_exposure_pct,
517
- zone.elevation_m / 2000.0,
518
- 0.0, # padding to keep feature dim at 11
519
- ]
520
-
521
- # UHI parameters for this zone
522
- uhi_lo, uhi_hi = UHI_RANGES.get(zone.settlement_type, (1.0, 2.0))
523
- mean_uhi = (uhi_lo + uhi_hi) / 2.0
524
-
525
- n = len(records)
526
- if n < window_size + 30:
527
- continue
528
-
529
- for start in range(0, n - window_size - 7, stride):
530
- window = records[start:start + window_size]
531
-
532
- # Apply UHI correction so each zone sees different temps
533
- # even when the underlying grid data is the same
534
- rng_uhi = np.random.RandomState(hash(zone_id) & 0x7FFFFFFF)
535
- uhi_noise_std = (uhi_hi - uhi_lo) / 4.0 # small daily variation
536
-
537
- seq = []
538
- wbgts = [] # UHI-corrected (what workers feel)
539
- grid_wbgts = [] # raw grid (what satellite measures)
540
- for i, day in enumerate(window):
541
- t_max_grid = day.get("temp_max_c") or 30.0
542
- t_min = day.get("temp_min_c") or 24.0
543
- hum = day.get("humidity_pct") or 75.0
544
- wind = day.get("wind_speed_ms") or 3.0
545
- solar = day.get("solar_rad_wm2") or 200.0
546
- precip = day.get("precip_mm") or 0.0
547
-
548
- # Zone-specific UHI correction on temperature
549
- uhi_delta = mean_uhi + rng_uhi.normal(0, uhi_noise_std)
550
- t_max = t_max_grid + uhi_delta
551
-
552
- wbgt = calculate_wbgt(t_max, hum)
553
- wbgts.append(wbgt)
554
-
555
- grid_wbgts.append(calculate_wbgt(t_max_grid, hum))
556
-
557
- features = [
558
- t_max / 40.0, # UHI-corrected temp (zone-specific!)
559
- t_min / 30.0,
560
- hum / 100.0,
561
- wind / 10.0,
562
- solar / 400.0,
563
- precip / 50.0,
564
- wbgt / 40.0, # UHI-corrected WBGT
565
- ] + zone_static
566
-
567
- seq.append(features)
568
-
569
- all_X.append(seq)
570
-
571
- # ── Compute targets ──
572
-
573
- # 1. Hazard: frequency from compound triggers (consecutive days above threshold)
574
- zone_thresh = SETTLEMENT_THRESHOLDS.get(zone.settlement_type, wbgt_threshold)
575
- run_length = 0
576
- run_peak = 0.0
577
- events = []
578
- for w in wbgts:
579
- if w > zone_thresh:
580
- run_length += 1
581
- run_peak = max(run_peak, w)
582
- else:
583
- if run_length >= MIN_CONSECUTIVE_DAYS:
584
- events.append(run_peak - zone_thresh) # severity = peak excess
585
- run_length = 0
586
- run_peak = 0.0
587
- if run_length >= MIN_CONSECUTIVE_DAYS:
588
- events.append(run_peak - zone_thresh)
589
-
590
- exceedances = events # for GPD fitting
591
- frequency = len(events) * (365.0 / window_size) # annualize
592
-
593
- if len(exceedances) >= 3:
594
- try:
595
- from scipy.stats import genpareto
596
- xi_fit, _, sigma_fit = genpareto.fit(exceedances, floc=0)
597
- xi_fit = max(-0.4, min(0.4, xi_fit))
598
- sigma_fit = max(0.1, min(10.0, sigma_fit))
599
- except Exception:
600
- sigma_fit = float(np.std(exceedances)) + 0.5
601
- xi_fit = 0.1
602
- else:
603
- sigma_fit = 1.0
604
- xi_fit = 0.05
605
-
606
- # 2. Vulnerability: WHO/ILO dose-response on UHI-corrected WBGT
607
- prod_losses = [who_productivity_loss(w) for w in wbgts]
608
- mean_prod_loss = float(np.mean(prod_losses)) * zone.outdoor_exposure_pct
609
-
610
- # 3. Basis risk: gap between grid trigger and UHI-corrected trigger
611
- grid_triggers = sum(1 for w in grid_wbgts if w > wbgt_threshold)
612
- corrected_triggers = sum(1 for w in wbgts if w > wbgt_threshold)
613
- if corrected_triggers > 0:
614
- false_negative_rate = max(0, corrected_triggers - grid_triggers) / corrected_triggers
615
- else:
616
- false_negative_rate = 0.0
617
- basis_risk_score = 0.3 * false_negative_rate + 0.2 * (mean_uhi / 6.0)
618
- basis_risk_score = min(1.0, basis_risk_score + 0.05) # floor
619
-
620
- # 4. Severity multiplier: ratio of corrected to grid impact
621
- grid_impact = sum(who_productivity_loss(w) for w in wbgts)
622
- corrected_impact = sum(who_productivity_loss(w) for w in wbgts)
623
- severity_mult = (corrected_impact / max(grid_impact, 0.01))
624
- severity_mult = max(0.5, min(3.0, severity_mult))
625
-
626
- # 5. Price target from the real ActuarialPricer
627
- from src.pricing.actuarial import ActuarialPricer
628
- _glm = ActuarialPricer()
629
- payout = 10.0
630
- glm_result = _glm.price_zone(
631
- zone=zone,
632
- predicted_frequency=frequency,
633
- basis_risk_score=basis_risk_score,
634
- payout_per_event=payout,
635
- enrolled=zone.worker_population_est,
636
- )
637
- price_target = glm_result.cost_per_worker_year
638
-
639
- # 6. Two-tier event detection (matching insurance benchmark)
640
- # Uses the same threshold as the parametric trigger — duration
641
- # is the discriminator, not a separate severity gate.
642
- TRIGGER_WBGT = wbgt_threshold # 35.1°C for Dar es Salaam
643
- last_7 = wbgts[-7:]
644
- vuln_mult = {"high": 1.5, "moderate": 1.0, "low": 0.7}
645
- v_mult = vuln_mult.get(zone.heat_vulnerability, 1.0)
646
-
647
- # Count consecutive days above trigger threshold at end of window
648
- consec_at_end = 0
649
- for w in reversed(last_7):
650
- if w > TRIGGER_WBGT:
651
- consec_at_end += 1
652
- else:
653
- break
654
- peak_wbgt = max(last_7) if last_7 else 0
655
-
656
- # Alert tier: 2+ consecutive days above threshold
657
- # Workers get cash transfer + safety SMS
658
- alert_event = 1.0 if consec_at_end >= 2 else 0.0
659
- if alert_event > 0:
660
- peak_excess = max(0, peak_wbgt - TRIGGER_WBGT)
661
- alert_sev = min(1.0, (consec_at_end / 5.0) * (peak_excess / 4.0) * v_mult)
662
- else:
663
- alert_sev = 0.0
664
-
665
- # Payout tier: 5+ consecutive days above threshold
666
- # Workers get full insurance payout
667
- payout_event = 1.0 if consec_at_end >= 5 else 0.0
668
- if payout_event > 0:
669
- peak_excess = max(0, peak_wbgt - TRIGGER_WBGT)
670
- payout_sev = min(1.0, (consec_at_end / 7.0) * (peak_excess / 3.0) * v_mult)
671
- else:
672
- payout_sev = 0.0
673
-
674
- all_targets["frequency"].append(frequency)
675
- all_targets["gpd_sigma"].append(sigma_fit)
676
- all_targets["gpd_xi"].append(xi_fit)
677
- all_targets["productivity_loss"].append(mean_prod_loss)
678
- all_targets["basis_risk"].append(basis_risk_score)
679
- all_targets["severity_multiplier"].append(severity_mult)
680
- all_targets["price_target"].append(price_target)
681
- all_targets["alert_event"].append(alert_event)
682
- all_targets["payout_event"].append(payout_event)
683
- all_targets["alert_severity"].append(alert_sev)
684
- all_targets["payout_severity"].append(payout_sev)
685
-
686
- X = np.array(all_X, dtype=np.float32)
687
- targets = {k: np.array(v, dtype=np.float32) for k, v in all_targets.items()}
688
-
689
- return X, targets
690
-
691
-
692
- # ══════════════════════════════════════════════════════════════════════════
693
- # Trainer
694
- # ══════════════════════════════════════════════════════════════════════════
695
-
696
- class NeuralPricerTrainer:
697
- """Train the HeatRiskNeuralPricer (LSTM or Chronos encoder) on climate data."""
698
-
699
- def __init__(self, lr: float = 1e-3, epochs: int = 80,
700
- patience: int = 10, weight_decay: float = 1e-4,
701
- encoder: str = "lstm"):
702
- if not TORCH_AVAILABLE:
703
- raise ImportError("torch is required")
704
- self.lr = lr
705
- self.epochs = epochs
706
- self.patience = patience
707
- self.weight_decay = weight_decay
708
- self.encoder = encoder
709
- if encoder == "chronos":
710
- self.model_path = PROJECT_ROOT / "models" / "chronos_pricer_dar.pt"
711
- self.norm_path = PROJECT_ROOT / "models" / "chronos_pricer_dar_norm.json"
712
- else:
713
- self.model_path = PROJECT_ROOT / "models" / "neural_pricer_dar.pt"
714
- self.norm_path = PROJECT_ROOT / "models" / "neural_pricer_dar_norm.json"
715
-
716
- def train(self, X: np.ndarray, targets: dict[str, np.ndarray],
717
- val_split: float = 0.2) -> dict:
718
- """
719
- Train the model and return metrics.
720
-
721
- Args:
722
- X: (N, 90, 11) feature sequences
723
- targets: dict of target arrays
724
- val_split: fraction for validation (temporal split)
725
- """
726
- torch.manual_seed(42)
727
- np.random.seed(42)
728
-
729
- N = len(X)
730
- split = int(N * (1 - val_split))
731
-
732
- # Temporal split (not random — avoids data leakage)
733
- X_train, X_val = X[:split], X[split:]
734
- t_train = {k: v[:split] for k, v in targets.items()}
735
- t_val = {k: v[split:] for k, v in targets.items()}
736
-
737
- # ── Pre-compute Chronos embeddings (before z-score norm) ──
738
- chronos_train = chronos_val = None
739
- chronos_d_model = 256 # default
740
- if self.encoder == "chronos":
741
- if not CHRONOS_AVAILABLE:
742
- raise ImportError("chronos-forecasting required for --encoder chronos")
743
- print(" Loading Chronos-Bolt-Tiny for embedding pre-computation...")
744
- pipeline = ChronosBoltPipeline.from_pretrained(
745
- "amazon/chronos-bolt-tiny", device_map="cpu",
746
- dtype=torch.float32,
747
- )
748
- chronos_d_model = pipeline.model.config.d_model
749
- print(f" Chronos d_model: {chronos_d_model}")
750
-
751
- # Extract raw WBGT (before normalization): column 6, scaled as wbgt/40
752
- wbgt_raw_all = X[:, :, 6] * 40.0 # (N, 90) in °C
753
- print(f" Pre-computing Chronos embeddings for {N} samples...")
754
- all_embs = []
755
- chunk = 256
756
- for i in range(0, N, chunk):
757
- batch = torch.from_numpy(wbgt_raw_all[i:i + chunk].astype(np.float32))
758
- with torch.no_grad():
759
- emb, _ = pipeline.embed(batch)
760
- all_embs.append(emb.mean(dim=1)) # (chunk, d_model)
761
- chronos_all = torch.cat(all_embs, dim=0) # (N, d_model)
762
- chronos_train = chronos_all[:split]
763
- chronos_val = chronos_all[split:]
764
- print(f" Chronos embeddings: {chronos_all.shape}")
765
- del pipeline # free memory
766
-
767
- # Normalize features (z-score from training set)
768
- flat = X_train.reshape(-1, X_train.shape[-1])
769
- feat_mean = flat.mean(axis=0)
770
- feat_std = np.maximum(flat.std(axis=0), 1e-6)
771
-
772
- self.norm_path.parent.mkdir(parents=True, exist_ok=True)
773
- norm_data = {"mean": feat_mean.tolist(), "std": feat_std.tolist()}
774
- if self.encoder == "chronos":
775
- norm_data["chronos_d_model"] = chronos_d_model
776
- with open(self.norm_path, "w") as f:
777
- json.dump(norm_data, f)
778
-
779
- X_train = (X_train - feat_mean) / feat_std
780
- X_val = (X_val - feat_mean) / feat_std
781
-
782
- # Convert to tensors
783
- X_train_t = torch.from_numpy(X_train)
784
- X_val_t = torch.from_numpy(X_val)
785
- targets_train = {k: torch.from_numpy(v) for k, v in t_train.items()}
786
- targets_val = {k: torch.from_numpy(v) for k, v in t_val.items()}
787
-
788
- if self.encoder == "chronos":
789
- model = HeatRiskNeuralPricerChronos(chronos_d_model=chronos_d_model)
790
- else:
791
- model = HeatRiskNeuralPricer()
792
- optimizer = torch.optim.AdamW(
793
- model.parameters(), lr=self.lr, weight_decay=self.weight_decay
794
- )
795
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
796
- optimizer, patience=5, factor=0.5
797
- )
798
-
799
- total_params = sum(p.numel() for p in model.parameters())
800
- print(f" Encoder: {self.encoder}")
801
- print(f" Model parameters: {total_params:,}")
802
- print(f" Training samples: {len(X_train)}, Validation: {len(X_val)}")
803
-
804
- best_val_loss = float("inf")
805
- patience_counter = 0
806
- best_state = None
807
- best_metrics = {}
808
-
809
- batch_size = 256
810
- n_batches = max(1, len(X_train) // batch_size)
811
-
812
- for epoch in range(self.epochs):
813
- # ── Train ──
814
- model.train()
815
- perm = torch.randperm(len(X_train_t))
816
- epoch_loss = 0.0
817
-
818
- for b in range(n_batches):
819
- idx = perm[b * batch_size:(b + 1) * batch_size]
820
- xb = X_train_t[idx]
821
- tb = {k: v[idx] for k, v in targets_train.items()}
822
-
823
- fwd_kwargs = {}
824
- if chronos_train is not None:
825
- fwd_kwargs["chronos_embeddings"] = chronos_train[idx]
826
-
827
- outputs = model(xb, **fwd_kwargs)
828
- loss = self._compute_loss(outputs, tb)
829
-
830
- optimizer.zero_grad()
831
- loss.backward()
832
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
833
- optimizer.step()
834
- epoch_loss += loss.item()
835
-
836
- avg_train_loss = epoch_loss / n_batches
837
-
838
- # ── Validate ──
839
- model.eval()
840
- with torch.no_grad():
841
- val_kwargs = {}
842
- if chronos_val is not None:
843
- val_kwargs["chronos_embeddings"] = chronos_val
844
- val_outputs = model(X_val_t, **val_kwargs)
845
- val_loss = self._compute_loss(val_outputs, targets_val).item()
846
-
847
- scheduler.step(val_loss)
848
-
849
- if (epoch + 1) % 10 == 0 or epoch == 0:
850
- delta_std = val_outputs["delta_nn"].std().item()
851
- print(
852
- f" Epoch {epoch + 1:>3}: "
853
- f"train={avg_train_loss:.4f} val={val_loss:.4f} "
854
- f"δ_NN_std={delta_std:.3f}"
855
- )
856
-
857
- if val_loss < best_val_loss:
858
- best_val_loss = val_loss
859
- patience_counter = 0
860
- best_state = {k: v.clone() for k, v in model.state_dict().items()}
861
- best_metrics = self._compute_metrics(val_outputs, targets_val, epoch + 1)
862
- else:
863
- patience_counter += 1
864
- if patience_counter >= self.patience:
865
- print(f" Early stopping at epoch {epoch + 1}")
866
- break
867
-
868
- # Save best model
869
- if best_state:
870
- model.load_state_dict(best_state)
871
- self.model_path.parent.mkdir(parents=True, exist_ok=True)
872
- torch.save(model.state_dict(), self.model_path)
873
-
874
- size_kb = self.model_path.stat().st_size / 1024
875
- print(f" Saved to {self.model_path} ({size_kb:.0f} KB)")
876
-
877
- return best_metrics
878
-
879
- def _compute_loss(self, outputs, targets):
880
- """Multi-task loss combining hazard, vulnerability, and pricing."""
881
- # Hazard: Poisson NLL for frequency
882
- lambda_ = outputs["lambda_"]
883
- freq_target = targets["frequency"]
884
- poisson_nll = lambda_ - freq_target * torch.log(lambda_ + 1e-8)
885
-
886
- # Hazard: GPD parameter MSE (as proxy for GPD deviance)
887
- sigma_loss = F.mse_loss(outputs["sigma"], targets["gpd_sigma"])
888
- xi_loss = F.mse_loss(outputs["xi"], targets["gpd_xi"])
889
-
890
- # Two-tier triggers: BCE for event detection, MSE for severity
891
- L_alert = F.binary_cross_entropy(
892
- outputs["alert_prob"], targets["alert_event"]
893
- )
894
- L_payout = F.binary_cross_entropy(
895
- outputs["payout_prob"], targets["payout_event"]
896
- )
897
- L_alert_sev = F.mse_loss(
898
- outputs["alert_severity"], targets["alert_severity"]
899
- )
900
- L_payout_sev = F.mse_loss(
901
- outputs["payout_severity"], targets["payout_severity"]
902
- )
903
- L_trigger = L_alert + L_payout
904
- L_severity = L_alert_sev + L_payout_sev
905
-
906
- L_hazard = poisson_nll.mean() + sigma_loss + xi_loss + L_trigger + L_severity
907
-
908
- # Vulnerability: MSE vs WHO/ILO targets
909
- L_vuln = (
910
- F.mse_loss(outputs["productivity_loss"], targets["productivity_loss"])
911
- + F.mse_loss(outputs["basis_risk"], targets["basis_risk"])
912
- + F.mse_loss(outputs["severity_multiplier"], targets["severity_multiplier"])
913
- )
914
-
915
- # Pricing: log-MSE (only on samples with nonzero target)
916
- price_pred = outputs["total_per_worker"]
917
- price_target = targets["price_target"]
918
- valid_mask = price_target > 1.0 # skip zero-frequency windows
919
- if valid_mask.any():
920
- L_pricing = F.mse_loss(
921
- torch.log(price_pred[valid_mask] + 1.0),
922
- torch.log(price_target[valid_mask] + 1.0),
923
- )
924
- else:
925
- L_pricing = torch.tensor(0.0)
926
-
927
- # Regularization: penalize large neural corrections
928
- L_reg = 0.01 * (outputs["delta_nn"] ** 2).mean()
929
-
930
- return 1.0 * L_hazard + 1.0 * L_vuln + 2.0 * L_pricing + 0.1 * L_reg
931
-
932
- def _compute_metrics(self, outputs, targets, epoch):
933
- """Compute evaluation metrics from validation outputs."""
934
- price_pred = outputs["total_per_worker"].detach().numpy()
935
- price_target = targets["price_target"].numpy()
936
-
937
- # MAPE (only on nonzero targets)
938
- valid = price_target > 1.0
939
- if valid.any():
940
- mape = float(np.mean(np.abs(price_pred[valid] - price_target[valid]) / price_target[valid]) * 100)
941
- else:
942
- mape = float("nan")
943
-
944
- # Spearman rank correlation
945
- from scipy.stats import spearmanr
946
- rho, _ = spearmanr(price_pred, price_target)
947
-
948
- # Delta NN statistics
949
- delta = outputs["delta_nn"].detach().numpy()
950
-
951
- return {
952
- "epoch": epoch,
953
- "val_loss": float(outputs["total_per_worker"].mean().item()),
954
- "price_mape_pct": round(mape, 1),
955
- "rank_correlation": round(float(rho), 3),
956
- "delta_nn_mean": round(float(np.mean(delta)), 4),
957
- "delta_nn_std": round(float(np.std(delta)), 4),
958
- "mean_lambda": round(float(outputs["lambda_"].mean().item()), 1),
959
- "mean_basis_risk": round(float(outputs["basis_risk"].mean().item()), 3),
960
- "mean_prod_loss": round(float(outputs["productivity_loss"].mean().item()), 3),
961
- }
962
-
963
-
964
- # ══════════════════════════════════════════════════════════════════════════
965
- # Inference Wrapper (drop-in for ActuarialPricer)
966
- # ══════════════════════════════════════════════════════════════════════════
967
-
968
- class NeuralActuarialPricer:
969
- """
970
- Drop-in replacement for ActuarialPricer.
971
-
972
- Preserves the same price_zone() signature and returns ActuarialResult.
973
- Three-tier fallback: Chronos encoder → LSTM encoder → GLM baseline.
974
- """
975
-
976
- def __init__(
977
- self,
978
- admin_rate: float = 0.15,
979
- ):
980
- self.admin_rate = admin_rate
981
- self._fallback = ActuarialPricer(admin_rate)
982
- self._model: Optional[object] = None
983
- self._norm: Optional[dict] = None
984
- self._chronos_pipeline = None
985
- self._encoder_type = "glm"
986
- self._trigger_head = None
987
-
988
- if not TORCH_AVAILABLE:
989
- log.warning("torch not available, using fallback pricer")
990
- return
991
-
992
- # Tier 1: Try Chronos encoder
993
- chronos_mp = PROJECT_ROOT / "models" / "chronos_pricer_dar.pt"
994
- chronos_np = PROJECT_ROOT / "models" / "chronos_pricer_dar_norm.json"
995
-
996
- if CHRONOS_AVAILABLE and chronos_mp.exists() and chronos_np.exists():
997
- try:
998
- with open(chronos_np) as f:
999
- self._norm = json.load(f)
1000
- d_model = self._norm.get("chronos_d_model", 256)
1001
- self._model = HeatRiskNeuralPricerChronos(chronos_d_model=d_model)
1002
- self._model.load_state_dict(
1003
- torch.load(chronos_mp, map_location="cpu", weights_only=True)
1004
- )
1005
- self._model.eval()
1006
- # Load Chronos pipeline
1007
- self._chronos_pipeline = ChronosBoltPipeline.from_pretrained(
1008
- "amazon/chronos-bolt-tiny", device_map="cpu",
1009
- dtype=torch.float32,
1010
- )
1011
- # Wire up the encoder
1012
- feat_mean = np.array(self._norm["mean"], dtype=np.float32)
1013
- feat_std = np.array(self._norm["std"], dtype=np.float32)
1014
- self._model.encoder.set_pipeline(
1015
- self._chronos_pipeline, feat_mean, feat_std
1016
- )
1017
- self._encoder_type = "chronos"
1018
- log.info("Chronos neural actuarial pricer loaded (d_model=%d)", d_model)
1019
-
1020
- # Load retrained trigger head (benchmark-driven, event-level)
1021
- trigger_path = PROJECT_ROOT / "models" / "trigger_head_retrained.pt"
1022
- if trigger_path.exists():
1023
- try:
1024
- ckpt = torch.load(trigger_path, map_location="cpu", weights_only=True)
1025
- d_in = ckpt["d_model"]
1026
-
1027
- class _TriggerHead(torch.nn.Module):
1028
- def __init__(self, d_in):
1029
- super().__init__()
1030
- self.net = torch.nn.Sequential(
1031
- torch.nn.Linear(d_in, 128), torch.nn.GELU(), torch.nn.Dropout(0.3),
1032
- torch.nn.Linear(128, 64), torch.nn.GELU(), torch.nn.Dropout(0.2),
1033
- torch.nn.Linear(64, 3),
1034
- )
1035
- def forward(self, x):
1036
- return self.net(x)
1037
-
1038
- self._trigger_head = _TriggerHead(d_in)
1039
- self._trigger_head.load_state_dict(ckpt["state_dict"])
1040
- self._trigger_head.eval()
1041
- log.info("Retrained trigger head loaded (d_in=%d)", d_in)
1042
- except Exception as e:
1043
- log.warning("Retrained trigger head failed to load: %s", e)
1044
- self._trigger_head = None
1045
-
1046
- return
1047
- except Exception as e:
1048
- log.warning("Chronos pricer failed to load: %s — trying LSTM", e)
1049
- self._model = None
1050
- self._norm = None
1051
- self._chronos_pipeline = None
1052
-
1053
- # Tier 2: Try LSTM encoder
1054
- lstm_mp = PROJECT_ROOT / "models" / "neural_pricer_dar.pt"
1055
- lstm_np = PROJECT_ROOT / "models" / "neural_pricer_dar_norm.json"
1056
-
1057
- if lstm_mp.exists() and lstm_np.exists():
1058
- try:
1059
- self._model = HeatRiskNeuralPricer()
1060
- self._model.load_state_dict(
1061
- torch.load(lstm_mp, map_location="cpu", weights_only=True)
1062
- )
1063
- self._model.eval()
1064
- with open(lstm_np) as f:
1065
- self._norm = json.load(f)
1066
- self._encoder_type = "lstm"
1067
- log.info("LSTM neural actuarial pricer loaded from %s", lstm_mp)
1068
- except Exception as e:
1069
- log.warning("LSTM pricer failed to load: %s — using GLM fallback", e)
1070
- self._model = None
1071
- else:
1072
- log.info("No neural pricer weights found, using GLM fallback")
1073
-
1074
- @property
1075
- def is_neural(self) -> bool:
1076
- return self._model is not None
1077
-
1078
- def price_zone(
1079
- self,
1080
- zone,
1081
- predicted_frequency: float,
1082
- basis_risk_score: float,
1083
- payout_per_event: float,
1084
- enrolled: int,
1085
- climate_history: Optional[list[dict]] = None,
1086
- ) -> ActuarialResult:
1087
- """
1088
- Price a zone using the neural model if available.
1089
-
1090
- Args:
1091
- zone: UrbanZone from config
1092
- predicted_frequency: annual trigger events (used by fallback)
1093
- basis_risk_score: 0-1 (used by fallback)
1094
- payout_per_event: USD per event per worker
1095
- enrolled: number of workers
1096
- climate_history: optional list of daily dicts with
1097
- temp_max_c, temp_min_c, humidity_pct, wind_speed_ms, etc.
1098
- If provided and model is loaded, uses neural pricing.
1099
- """
1100
- if self._model is None or climate_history is None or len(climate_history) < 30:
1101
- return self._fallback.price_zone(
1102
- zone, predicted_frequency, basis_risk_score,
1103
- payout_per_event, enrolled,
1104
- )
1105
-
1106
- try:
1107
- return self._neural_price(
1108
- zone, payout_per_event, enrolled, climate_history
1109
- )
1110
- except Exception as e:
1111
- log.warning("Neural pricing failed for %s, using fallback: %s",
1112
- zone.zone_id, e)
1113
- return self._fallback.price_zone(
1114
- zone, predicted_frequency, basis_risk_score,
1115
- payout_per_event, enrolled,
1116
- )
1117
-
1118
- def _neural_price(
1119
- self, zone, payout_per_event: float, enrolled: int,
1120
- climate_history: list[dict],
1121
- ) -> ActuarialResult:
1122
- """Run the neural model and construct ActuarialResult."""
1123
- # Build feature sequence (last 90 days)
1124
- # Must match build_training_samples zone_static encoding
1125
- vuln_enc = {"high": 1.0, "moderate": 0.5, "low": 0.0}
1126
- zone_static = [
1127
- vuln_enc.get(zone.heat_vulnerability, 0.5),
1128
- zone.outdoor_exposure_pct,
1129
- zone.elevation_m / 2000.0,
1130
- 0.0,
1131
- ]
1132
-
1133
- history = climate_history[-90:]
1134
- if len(history) < 90:
1135
- # Pad with first record
1136
- pad = [history[0]] * (90 - len(history))
1137
- history = pad + history
1138
-
1139
- # Apply UHI correction (same as training)
1140
- uhi_lo, uhi_hi = UHI_RANGES.get(zone.settlement_type, (1.0, 2.0))
1141
- mean_uhi = (uhi_lo + uhi_hi) / 2.0
1142
-
1143
- seq = []
1144
- for day in history:
1145
- t_max_grid = day.get("temp_max_c") or day.get("temp_c") or 30.0
1146
- t_max = float(t_max_grid) + mean_uhi # UHI-corrected
1147
- t_min = day.get("temp_min_c") or t_max - 6.0
1148
- hum = day.get("humidity_pct") or 75.0
1149
- wind = day.get("wind_speed_ms") or 3.0
1150
- solar = day.get("solar_rad_wm2") or 200.0
1151
- precip = day.get("precip_mm") or 0.0
1152
- wbgt = calculate_wbgt(t_max, float(hum))
1153
-
1154
- seq.append([
1155
- t_max / 40.0, float(t_min) / 30.0,
1156
- float(hum) / 100.0, float(wind) / 10.0,
1157
- float(solar) / 400.0, float(precip) / 50.0,
1158
- wbgt / 40.0,
1159
- ] + zone_static)
1160
-
1161
- x = np.array([seq], dtype=np.float32)
1162
-
1163
- # Normalize
1164
- if self._norm:
1165
- mean = np.array(self._norm["mean"], dtype=np.float32)
1166
- std = np.array(self._norm["std"], dtype=np.float32)
1167
- x = (x - mean) / std
1168
-
1169
- x_tensor = torch.from_numpy(x)
1170
-
1171
- # For Chronos encoder: compute embedding from raw WBGT
1172
- chronos_emb = None
1173
- if self._encoder_type == "chronos" and self._chronos_pipeline is not None:
1174
- wbgt_seq = np.array(
1175
- [calculate_wbgt(
1176
- float(day.get("temp_max_c") or day.get("temp_c") or 30.0)
1177
- + mean_uhi,
1178
- float(day.get("humidity_pct") or 75.0),
1179
- ) for day in history],
1180
- dtype=np.float32,
1181
- )
1182
- wbgt_tensor = torch.from_numpy(wbgt_seq).unsqueeze(0) # (1, 90)
1183
- with torch.no_grad():
1184
- emb, _ = self._chronos_pipeline.embed(wbgt_tensor)
1185
- chronos_emb = emb.mean(dim=1) # (1, 256)
1186
-
1187
- with torch.no_grad():
1188
- outputs = self._model(
1189
- x_tensor, payout_per_event, self.admin_rate,
1190
- **({} if chronos_emb is None else {"chronos_embeddings": chronos_emb})
1191
- )
1192
-
1193
- # Extract values
1194
- lambda_ = outputs["lambda_"].item()
1195
- sigma = outputs["sigma"].item()
1196
- xi = outputs["xi"].item()
1197
- alert_prob = outputs["alert_prob"].item()
1198
- payout_prob = outputs["payout_prob"].item()
1199
- alert_severity = outputs["alert_severity"].item()
1200
- payout_severity = outputs["payout_severity"].item()
1201
-
1202
- # Override alert/payout probabilities with the retrained trigger head
1203
- # if it's available. The retrained head was trained on event-level labels
1204
- # (benchmark-driven: 8% → 28% hit rate on insurance trigger benchmark).
1205
- # Pricing math (lambda, sigma, xi, severity) is NOT affected.
1206
- if self._trigger_head is not None and chronos_emb is not None:
1207
- try:
1208
- ALERT_THRESH = 35.1
1209
- last_7 = [calculate_wbgt(
1210
- float(d.get("temp_max_c") or d.get("temp_c") or 30.0) + mean_uhi,
1211
- float(d.get("humidity_pct") or 75.0),
1212
- ) for d in history[-7:]]
1213
- last_14 = [calculate_wbgt(
1214
- float(d.get("temp_max_c") or d.get("temp_c") or 30.0) + mean_uhi,
1215
- float(d.get("humidity_pct") or 75.0),
1216
- ) for d in history[-14:]]
1217
- all_wbgts = [calculate_wbgt(
1218
- float(d.get("temp_max_c") or d.get("temp_c") or 30.0) + mean_uhi,
1219
- float(d.get("humidity_pct") or 75.0),
1220
- ) for d in history[:30]]
1221
- extra = torch.tensor([[
1222
- float(np.mean(last_7)), float(np.mean(last_14)),
1223
- max(last_7) - min(last_7),
1224
- sum(1 for w in last_14 if w >= ALERT_THRESH),
1225
- float(np.mean(last_7)) - float(np.mean(all_wbgts)) if all_wbgts else 0.0,
1226
- max(last_7),
1227
- ]], dtype=torch.float32)
1228
- trigger_input = torch.cat([chronos_emb, extra], dim=1)
1229
- with torch.no_grad():
1230
- trigger_logits = self._trigger_head(trigger_input)
1231
- trigger_probs = torch.softmax(trigger_logits, dim=1)[0]
1232
- # Map 3-class probs to alert/payout probs
1233
- alert_prob = float(trigger_probs[1] + trigger_probs[2]) # alert OR payout
1234
- payout_prob = float(trigger_probs[2]) # payout only
1235
- except Exception:
1236
- pass # fall back to original model outputs
1237
- prod_loss = outputs["productivity_loss"].item()
1238
- neural_basis_risk = outputs["basis_risk"].item()
1239
- severity_mult = outputs["severity_multiplier"].item()
1240
- delta_nn = outputs["delta_nn"].item()
1241
- total_per_worker = outputs["total_per_worker"].item()
1242
- glm_price = outputs["glm_price"].item()
1243
-
1244
- enrolled = max(enrolled, 1)
1245
-
1246
- # Decompose for transparency
1247
- base_cost = lambda_ * payout_per_event * enrolled
1248
- basis_loading = base_cost * (neural_basis_risk * 0.5)
1249
- vuln_loading = base_cost * (prod_loss * 0.2)
1250
- subtotal = base_cost + basis_loading + vuln_loading
1251
- admin_loading = subtotal * self.admin_rate
1252
-
1253
- neural_correction_pct = (total_per_worker / (glm_price + 1e-8) - 1.0) * 100
1254
-
1255
- cost_breakdown = {
1256
- "base_frequency_cost": round(base_cost, 2),
1257
- "basis_risk_adjustment": round(basis_loading, 2),
1258
- "vulnerability_adjustment": round(vuln_loading, 2),
1259
- "admin_overhead": round(admin_loading, 2),
1260
- "total": round(total_per_worker * enrolled, 2),
1261
- "neural_correction_pct": round(neural_correction_pct, 1),
1262
- "glm_baseline_per_worker": round(glm_price, 2),
1263
- "neural_price_per_worker": round(total_per_worker, 2),
1264
- "gpd_shape_xi": round(xi, 3),
1265
- "gpd_scale_sigma": round(sigma, 3),
1266
- "learned_frequency": round(lambda_, 1),
1267
- "alert_prob": round(alert_prob, 3),
1268
- "payout_prob": round(payout_prob, 3),
1269
- "alert_severity": round(alert_severity, 3),
1270
- "payout_severity": round(payout_severity, 3),
1271
- # Backward compat
1272
- "trigger_prob": round(alert_prob, 3),
1273
- "payout_factor": round(payout_severity, 3),
1274
- # Funding decomposition (SEWA pilot structure)
1275
- # Cash tier: $2-5 per event, ~12 events/year
1276
- "cash_per_event": round(2.0 + alert_severity * 3.0, 2),
1277
- # Insurance tier: $7-20 per event, ~3 events/year
1278
- "insurance_per_event": round(7.0 + payout_severity * 13.0, 2),
1279
- # Annual costs by tier
1280
- "annual_cash_cost": round((2.0 + alert_severity * 3.0) * min(lambda_, 15), 2),
1281
- "annual_insurance_cost": round((7.0 + payout_severity * 13.0) * max(0, lambda_ - 10) * 0.3, 2),
1282
- # Worker pays max $3/year (capped based on informal daily wage ~$3-5)
1283
- "worker_contribution": min(3.0, round(total_per_worker * 0.15, 2)),
1284
- # Philanthropy covers cash tier + vulnerability gap
1285
- "philanthropy_share": round(total_per_worker * 0.45, 2),
1286
- # Insurer covers remainder
1287
- "insurer_premium": round(total_per_worker * 0.40, 2),
1288
- "learned_basis_risk": round(neural_basis_risk, 3),
1289
- "productivity_loss_rate": round(prod_loss, 3),
1290
- "severity_multiplier": round(severity_mult, 3),
1291
- "explanation": (
1292
- f"{zone.name}: Neural EVT predicts {lambda_:.1f} events/year "
1293
- f"(GPD ξ={xi:.2f}, σ={sigma:.1f}), "
1294
- f"learned basis risk {neural_basis_risk:.0%}, "
1295
- f"WHO productivity loss {prod_loss:.0%}, "
1296
- f"neural correction {neural_correction_pct:+.1f}% vs GLM"
1297
- ),
1298
- }
1299
-
1300
- return ActuarialResult(
1301
- zone_id=zone.zone_id,
1302
- zone_name=zone.name,
1303
- city=zone.city,
1304
- cost_per_worker_year=round(total_per_worker, 2),
1305
- expected_annual_payouts=round(base_cost, 2),
1306
- frequency_component=round(lambda_ * payout_per_event, 2),
1307
- basis_risk_loading=round(basis_loading, 2),
1308
- vulnerability_loading=round(vuln_loading, 2),
1309
- admin_loading=round(admin_loading, 2),
1310
- cost_breakdown=cost_breakdown,
1311
- enrolled_workers=enrolled,
1312
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/eval_heat_predictor.py DELETED
@@ -1,157 +0,0 @@
1
- """Evaluate heat wave prediction models."""
2
- import json
3
- import os
4
- import numpy as np
5
- import pytest
6
- from sklearn.metrics import roc_auc_score, precision_score, recall_score
7
- from src.prediction.heat_forecast import HeatWavePredictor, CITY_THRESHOLDS
8
- from config import ZONES
9
-
10
-
11
- def _generate_test_data(zone, n_days=365, seed=123):
12
- """Generate synthetic test data with known trigger labels."""
13
- rng = np.random.RandomState(seed)
14
- # Use the city threshold from the model's own config
15
- threshold = CITY_THRESHOLDS.get(zone.city, 33.0)
16
-
17
- temps = []
18
- humidities = []
19
- wbgts = []
20
-
21
- # Seasonal pattern + noise
22
- for day in range(n_days):
23
- seasonal = 3 * np.sin(2 * np.pi * day / 365)
24
- t = threshold - 2 + seasonal + rng.randn() * 3
25
- h = 65 + rng.randn() * 10
26
- w = 0.7 * t + 0.3 * h * 0.3 - 10 # simplified WBGT
27
- temps.append(t)
28
- humidities.append(max(20, min(100, h)))
29
- wbgts.append(w)
30
-
31
- # Generate ground-truth labels (trigger if 2+ consecutive days above threshold in next 7)
32
- labels = []
33
- for i in range(n_days - 7):
34
- future = temps[i + 1:i + 8]
35
- consecutive = 0
36
- max_consec = 0
37
- for t in future:
38
- if t >= threshold:
39
- consecutive += 1
40
- max_consec = max(max_consec, consecutive)
41
- else:
42
- consecutive = 0
43
- labels.append(1 if max_consec >= 2 else 0)
44
-
45
- return temps, humidities, wbgts, labels
46
-
47
-
48
- def test_predictor_output_valid():
49
- """Predictions should be valid probabilities with confidence."""
50
- predictor = HeatWavePredictor()
51
- zone = ZONES[0]
52
- temps = [30 + np.random.randn() * 3 for _ in range(30)]
53
- humidity = [70 + np.random.randn() * 5 for _ in range(30)]
54
- wbgt = [28 + np.random.randn() * 2 for _ in range(30)]
55
-
56
- prob, conf, tier = predictor.predict(zone, temps, humidity, wbgt)
57
- assert 0 <= prob <= 1, f"Probability {prob} out of [0,1]"
58
- assert 0 <= conf <= 1, f"Confidence {conf} out of [0,1]"
59
- assert tier in ("ensemble", "full_model", "lstm_only", "persistence", "climatology")
60
-
61
-
62
- def test_predictor_tier_fallback():
63
- """Test that minimal data degrades to a fallback tier with lower confidence."""
64
- predictor = HeatWavePredictor()
65
- zone = ZONES[0]
66
-
67
- # Full data -> should get full_model, ensemble, or lstm_only
68
- full_temps = [30 + np.random.randn() * 3 for _ in range(90)]
69
- full_hum = [70 + np.random.randn() * 5 for _ in range(90)]
70
- full_wbgt = [28 + np.random.randn() * 2 for _ in range(90)]
71
- prob, conf, tier = predictor.predict(zone, full_temps, full_hum, full_wbgt)
72
- assert tier in ("ensemble", "full_model", "lstm_only")
73
-
74
- # Minimal data -> should fall back to persistence or climatology
75
- min_temps = [30, 31, 32]
76
- min_hum = [70, 70, 70]
77
- min_wbgt = [28, 28, 28]
78
- prob2, conf2, tier2 = predictor.predict(zone, min_temps, min_hum, min_wbgt)
79
- assert tier2 in ("persistence", "climatology", "ensemble", "full_model", "lstm_only")
80
- # Less data should generally mean equal or less confidence
81
- assert conf2 <= conf + 0.1, f"Minimal-data confidence ({conf2}) should not greatly exceed full-data ({conf})"
82
-
83
-
84
- def test_predictor_discrimination():
85
- """Model should assign higher probability to hot sequences."""
86
- predictor = HeatWavePredictor()
87
- zone = ZONES[0]
88
-
89
- # Hot sequence (should trigger)
90
- hot = [36 + i * 0.2 for i in range(30)]
91
- hot_hum = [80] * 30
92
- hot_wbgt = [32 + i * 0.1 for i in range(30)]
93
-
94
- # Cool sequence (should not trigger)
95
- cool = [22 + np.sin(i / 5) for i in range(30)]
96
- cool_hum = [50] * 30
97
- cool_wbgt = [20 + np.sin(i / 5) for i in range(30)]
98
-
99
- p_hot, _, _ = predictor.predict(zone, hot, hot_hum, hot_wbgt)
100
- p_cool, _, _ = predictor.predict(zone, cool, cool_hum, cool_wbgt)
101
-
102
- assert p_hot > p_cool, f"Hot prob ({p_hot:.3f}) should > cool prob ({p_cool:.3f})"
103
-
104
-
105
- def test_predictor_metrics():
106
- """Compute AUROC and calibration on synthetic held-out data."""
107
- predictor = HeatWavePredictor()
108
- results = {}
109
-
110
- # Sample one zone per city
111
- seen_cities = set()
112
- sample_zones = []
113
- for z in ZONES:
114
- if z.city not in seen_cities:
115
- sample_zones.append(z)
116
- seen_cities.add(z.city)
117
-
118
- for zone in sample_zones:
119
- temps, humidities, wbgts, labels = _generate_test_data(zone)
120
-
121
- predictions = []
122
- for i in range(30, len(labels)):
123
- prob, _, tier = predictor.predict(
124
- zone, temps[i - 30:i], humidities[i - 30:i], wbgts[i - 30:i]
125
- )
126
- predictions.append(prob)
127
-
128
- # Align labels with predictions
129
- y_true = labels[30:30 + len(predictions)]
130
- y_pred = predictions[:len(y_true)]
131
-
132
- if len(set(y_true)) > 1: # need both classes for AUROC
133
- auroc = roc_auc_score(y_true, y_pred)
134
- binary = [1 if p > 0.5 else 0 for p in y_pred]
135
- precision = precision_score(y_true, binary, zero_division=0)
136
- recall = recall_score(y_true, binary, zero_division=0)
137
- else:
138
- auroc = float('nan')
139
- precision = float('nan')
140
- recall = float('nan')
141
-
142
- results[zone.zone_id] = {
143
- "city": zone.city,
144
- "auroc": round(auroc, 3) if not np.isnan(auroc) else None,
145
- "precision": round(precision, 3) if not np.isnan(precision) else None,
146
- "recall": round(recall, 3) if not np.isnan(recall) else None,
147
- "n_samples": len(y_true),
148
- "positive_rate": round(sum(y_true) / len(y_true), 3) if y_true else 0,
149
- }
150
-
151
- os.makedirs("tests/eval_results", exist_ok=True)
152
- with open("tests/eval_results/heat_predictor_eval.json", "w") as f:
153
- json.dump(results, f, indent=2)
154
-
155
- # At least one zone should have AUROC > 0.5 (better than random)
156
- valid_aurocs = [r["auroc"] for r in results.values() if r["auroc"] is not None]
157
- assert any(a > 0.5 for a in valid_aurocs), f"No zone has AUROC > 0.5: {valid_aurocs}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/eval_neural_pricer.py DELETED
@@ -1,303 +0,0 @@
1
- """Evaluate neural actuarial pricing model against GLM baseline and sanity checks."""
2
- import json
3
- import os
4
- import numpy as np
5
- import pytest
6
- from scipy.stats import spearmanr
7
- from pathlib import Path
8
-
9
- from config import ZONES, ZONE_MAP, PRIMARY_CITY, PRIMARY_CITY_SLUG
10
- from src.pricing.neural_actuarial import NeuralActuarialPricer
11
- from src.pricing.actuarial import ActuarialPricer, ActuarialResult
12
-
13
- PROJECT_ROOT = Path(__file__).resolve().parents[1]
14
- ERA5_PATH = PROJECT_ROOT / "data" / f"era5land_{PRIMARY_CITY_SLUG}.json"
15
-
16
- ACTIVE_ZONES = [z for z in ZONES if z.city == PRIMARY_CITY]
17
-
18
- # Default pricing parameters
19
- PAYOUT_PER_EVENT = 10.0
20
- DEFAULT_FREQUENCY = 12.0
21
- DEFAULT_BASIS_RISK = 0.3
22
-
23
-
24
- def _load_era5_history() -> dict[str, list[dict]]:
25
- """Load ERA5-Land data and return 90 days of history per primary-city zone."""
26
- with open(ERA5_PATH) as f:
27
- raw = json.load(f)
28
- # Take the last 90 days for each zone
29
- return {zid: records[-90:] for zid, records in raw.items()}
30
-
31
-
32
- @pytest.fixture(scope="module")
33
- def era5_history():
34
- return _load_era5_history()
35
-
36
-
37
- @pytest.fixture(scope="module")
38
- def neural_pricer():
39
- return NeuralActuarialPricer()
40
-
41
-
42
- @pytest.fixture(scope="module")
43
- def glm_pricer():
44
- return ActuarialPricer()
45
-
46
-
47
- @pytest.fixture(scope="module")
48
- def neural_results(neural_pricer, era5_history):
49
- """Price all primary-city zones with the neural model."""
50
- results = {}
51
- for zone in ACTIVE_ZONES:
52
- history = era5_history.get(zone.zone_id)
53
- r = neural_pricer.price_zone(
54
- zone=zone,
55
- predicted_frequency=DEFAULT_FREQUENCY,
56
- basis_risk_score=DEFAULT_BASIS_RISK,
57
- payout_per_event=PAYOUT_PER_EVENT,
58
- enrolled=zone.worker_population_est,
59
- climate_history=history,
60
- )
61
- results[zone.zone_id] = r
62
- return results
63
-
64
-
65
- @pytest.fixture(scope="module")
66
- def glm_results(glm_pricer):
67
- """Price all primary-city zones with the GLM baseline."""
68
- results = {}
69
- for zone in ACTIVE_ZONES:
70
- r = glm_pricer.price_zone(
71
- zone=zone,
72
- predicted_frequency=DEFAULT_FREQUENCY,
73
- basis_risk_score=DEFAULT_BASIS_RISK,
74
- payout_per_event=PAYOUT_PER_EVENT,
75
- enrolled=zone.worker_population_est,
76
- )
77
- results[zone.zone_id] = r
78
- return results
79
-
80
-
81
- # ── Test 1: Model loads ──────────────────────────────────────────────────
82
-
83
- def test_neural_model_loads(neural_pricer):
84
- """NeuralActuarialPricer should load the trained PyTorch model."""
85
- assert neural_pricer._model is not None, (
86
- "Neural model failed to load — check models/neural_pricer_dar.pt exists"
87
- )
88
- assert neural_pricer.is_neural
89
-
90
-
91
- # ── Test 2: Price accuracy vs GLM ────────────────────────────────────────
92
-
93
- def test_price_accuracy_vs_glm(neural_results, glm_results):
94
- """Neural prices should correlate with GLM and have bounded per-zone divergence.
95
-
96
- The neural model learns frequencies from 20 years of ERA5-Land data while the
97
- GLM uses a fixed default frequency, so absolute levels differ. We test that:
98
- (a) neural prices are all positive and finite, and
99
- (b) the coefficient of variation within neural prices is within 20% MAPE
100
- of the CV within GLM prices (structural similarity).
101
- """
102
- neural_prices = []
103
- glm_prices = []
104
- details = {}
105
- for zone in ACTIVE_ZONES:
106
- zid = zone.zone_id
107
- neural_price = neural_results[zid].cost_per_worker_year
108
- glm_price = glm_results[zid].cost_per_worker_year
109
- neural_prices.append(neural_price)
110
- glm_prices.append(glm_price)
111
- details[zid] = {
112
- "neural": round(neural_price, 2),
113
- "glm": round(glm_price, 2),
114
- }
115
-
116
- # All neural prices should be positive and finite
117
- for zid, p in zip([z.zone_id for z in ACTIVE_ZONES], neural_prices):
118
- assert p > 0 and np.isfinite(p), f"{zid}: invalid neural price {p}"
119
-
120
- # Compare relative spread: CV(neural) vs CV(glm)
121
- cv_neural = float(np.std(neural_prices) / np.mean(neural_prices))
122
- cv_glm = float(np.std(glm_prices) / np.mean(glm_prices))
123
- details["cv_neural"] = round(cv_neural, 3)
124
- details["cv_glm"] = round(cv_glm, 3)
125
-
126
- # Save results
127
- os.makedirs("tests/eval_results", exist_ok=True)
128
- with open("tests/eval_results/neural_pricer_eval.json", "w") as f:
129
- json.dump(details, f, indent=2)
130
-
131
- # Neural model should have meaningful price variation across zones (CV > 0.01)
132
- assert cv_neural > 0.01, (
133
- f"Neural prices have negligible variation (CV={cv_neural:.3f}) — model may be collapsing"
134
- )
135
-
136
-
137
- # ── Test 3: Rank preservation ───────────────────────────────────────────���
138
-
139
- def test_rank_preservation(neural_results, glm_results):
140
- """Spearman rank correlation between neural and GLM zone rankings should be positive.
141
-
142
- The neural model learns from real climate data while GLM uses fixed inputs, so
143
- some rank reordering is expected. We require rho > 0.4 (moderate positive
144
- correlation) — both models should agree on broad risk ordering.
145
- """
146
- zone_ids = [z.zone_id for z in ACTIVE_ZONES]
147
- neural_prices = [neural_results[zid].cost_per_worker_year for zid in zone_ids]
148
- glm_prices = [glm_results[zid].cost_per_worker_year for zid in zone_ids]
149
-
150
- rho, pval = spearmanr(neural_prices, glm_prices)
151
- assert rho > 0.4, (
152
- f"Spearman correlation {rho:.3f} below 0.4 — neural model disagrees "
153
- f"too strongly with GLM on zone risk ordering"
154
- )
155
-
156
-
157
- # ── Test 4: Neural correction bounded ───────────────────────────────────
158
-
159
- def test_delta_nn_bounded(neural_results):
160
- """Neural correction delta should be within [-50, 50]% and mean near 0."""
161
- corrections = []
162
- for zone in ACTIVE_ZONES:
163
- zid = zone.zone_id
164
- breakdown = neural_results[zid].cost_breakdown
165
- correction = breakdown["neural_correction_pct"]
166
- corrections.append(correction)
167
- assert -50 <= correction <= 50, (
168
- f"{zid}: neural_correction_pct {correction:.1f}% outside [-50, 50]"
169
- )
170
-
171
- mean_correction = float(np.mean(corrections))
172
- # Mean correction should be within the bounded range (not saturating at limits)
173
- assert abs(mean_correction) < 45, (
174
- f"Mean neural correction {mean_correction:.1f}% — model is saturating "
175
- f"at the correction boundary"
176
- )
177
-
178
-
179
- # ── Test 5: Informal settlements priced higher ──────────────────────────
180
-
181
- def test_informal_priced_higher(neural_results):
182
- """Average price for informal settlement zones should exceed formal zones."""
183
- informal_prices = []
184
- formal_prices = []
185
- for zone in ACTIVE_ZONES:
186
- price = neural_results[zone.zone_id].cost_per_worker_year
187
- if zone.settlement_type == "informal":
188
- informal_prices.append(price)
189
- elif zone.settlement_type == "formal":
190
- formal_prices.append(price)
191
-
192
- assert len(informal_prices) > 0 and len(formal_prices) > 0, (
193
- f"Need both informal and formal zones in {PRIMARY_CITY}"
194
- )
195
-
196
- mean_informal = float(np.mean(informal_prices))
197
- mean_formal = float(np.mean(formal_prices))
198
- assert mean_informal > mean_formal, (
199
- f"Informal mean (${mean_informal:.2f}) should exceed "
200
- f"formal mean (${mean_formal:.2f})"
201
- )
202
-
203
-
204
- # ── Test 6: Hazard parameters valid ─────────────────────────────────────
205
-
206
- def test_hazard_parameters_valid(neural_results):
207
- """GPD shape xi in [-0.4, 0.4] and scale sigma in [0.1, 10] for all zones."""
208
- for zone in ACTIVE_ZONES:
209
- zid = zone.zone_id
210
- breakdown = neural_results[zid].cost_breakdown
211
- xi = breakdown["gpd_shape_xi"]
212
- sigma = breakdown["gpd_scale_sigma"]
213
-
214
- assert -0.4 <= xi <= 0.4, (
215
- f"{zid}: GPD shape xi={xi:.3f} outside [-0.4, 0.4]"
216
- )
217
- assert 0.1 <= sigma <= 10.0, (
218
- f"{zid}: GPD scale sigma={sigma:.3f} outside [0.1, 10]"
219
- )
220
-
221
-
222
- # ── Test 7: Fallback without climate history ─────────────────────────────
223
-
224
- def test_fallback_without_climate_history(neural_pricer, glm_pricer):
225
- """Calling price_zone() with climate_history=None should return valid GLM result."""
226
- zone = ACTIVE_ZONES[0]
227
- result = neural_pricer.price_zone(
228
- zone=zone,
229
- predicted_frequency=DEFAULT_FREQUENCY,
230
- basis_risk_score=DEFAULT_BASIS_RISK,
231
- payout_per_event=PAYOUT_PER_EVENT,
232
- enrolled=zone.worker_population_est,
233
- climate_history=None,
234
- )
235
- glm_result = glm_pricer.price_zone(
236
- zone=zone,
237
- predicted_frequency=DEFAULT_FREQUENCY,
238
- basis_risk_score=DEFAULT_BASIS_RISK,
239
- payout_per_event=PAYOUT_PER_EVENT,
240
- enrolled=zone.worker_population_est,
241
- )
242
-
243
- assert isinstance(result, ActuarialResult)
244
- assert result.cost_per_worker_year > 0, "Fallback price should be positive"
245
- # Should match GLM exactly when no climate history
246
- assert result.cost_per_worker_year == glm_result.cost_per_worker_year, (
247
- f"Fallback price ${result.cost_per_worker_year} != GLM ${glm_result.cost_per_worker_year}"
248
- )
249
- # Should NOT contain neural-specific keys
250
- assert "neural_correction_pct" not in result.cost_breakdown
251
-
252
-
253
- # ── Test 8: Climate sensitivity ──────────────────────────────────────────
254
-
255
- def test_climate_sensitivity(neural_pricer, neural_results, era5_history):
256
- """Perturbing temperatures up by 2C should increase the price for high-risk zones."""
257
- high_risk_zones = [z for z in ACTIVE_ZONES if z.heat_vulnerability == "high"]
258
- assert len(high_risk_zones) >= 3, "Need at least 3 high-risk zones"
259
-
260
- price_increases = 0
261
- details = {}
262
-
263
- for zone in high_risk_zones:
264
- history = era5_history.get(zone.zone_id)
265
- if history is None:
266
- continue
267
-
268
- # Reuse baseline from neural_results fixture (avoids redundant LSTM forward pass)
269
- baseline = neural_results[zone.zone_id]
270
-
271
- # Perturbed: +2C on all temperature fields
272
- perturbed = []
273
- for day in history:
274
- d = dict(day)
275
- d["temp_max_c"] = d.get("temp_max_c", 30.0) + 2.0
276
- d["temp_min_c"] = d.get("temp_min_c", 24.0) + 2.0
277
- perturbed.append(d)
278
-
279
- warmer = neural_pricer.price_zone(
280
- zone=zone,
281
- predicted_frequency=DEFAULT_FREQUENCY,
282
- basis_risk_score=DEFAULT_BASIS_RISK,
283
- payout_per_event=PAYOUT_PER_EVENT,
284
- enrolled=zone.worker_population_est,
285
- climate_history=perturbed,
286
- )
287
-
288
- if warmer.cost_per_worker_year > baseline.cost_per_worker_year:
289
- price_increases += 1
290
-
291
- details[zone.zone_id] = {
292
- "baseline": round(baseline.cost_per_worker_year, 2),
293
- "warmer_2c": round(warmer.cost_per_worker_year, 2),
294
- "increased": warmer.cost_per_worker_year > baseline.cost_per_worker_year,
295
- }
296
-
297
- # At least some high-risk zones should show a price increase.
298
- # The model may already be in a saturated regime for the hottest informal
299
- # zones (where WBGT is near the ceiling), so we require at least 3 zones.
300
- assert price_increases >= 3, (
301
- f"Only {price_increases}/{len(high_risk_zones)} high-risk zones showed "
302
- f"price increase with +2C warming. Details: {details}"
303
- )