Cut dead neural/LSTM pricing + predictor paths
Browse filesThese were left behind after the earlier XGBoost heat-wave predictor
cut (93f565b) and the move to empirical burn analysis for pricing.
Nothing in src/pipeline.py or src/api.py imports them; the eval tests
they ship with couldn't pass either (neural_pricer_dar.pt never
existed on disk, so test_neural_model_loads fails by construction).
Removed:
- src/prediction/heat_forecast.py (HeatWavePredictor)
- src/prediction/lstm_model.py (LSTM trainer + CITY_THRESHOLDS)
- src/pricing/neural_actuarial.py (neural pricer; replaced by burn analysis)
- src/notification/sender.py + __init__.py (notify step was cut in a
prior pipeline pass; no one imports this anymore)
- scripts/train_neural_pricer.py, train_on_era5.py, train_on_nasa_power.py,
train_lstm.py, backtest_pricing.py (all drove the dead paths)
- tests/eval_heat_predictor.py, eval_neural_pricer.py (evaluated dead code)
- models/heat_predictor_xgb.json, trigger_head_retrained.pt
- (untracked) models/heat_lstm.pt, lstm_norm.json,
scripts/retrain_trigger_heads.py
Kept: BurnAnalysisPricer, UHI XGBoost correction (still live),
GraphCast forecast_trigger, RAG index, basis_risk assessor.
Smoke-tested: src.pipeline + src.api import clean post-cut.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- models/heat_predictor_xgb.json +0 -0
- models/trigger_head_retrained.pt +0 -3
- scripts/backtest_pricing.py +0 -508
- scripts/train_lstm.py +0 -58
- scripts/train_neural_pricer.py +0 -142
- scripts/train_on_era5.py +0 -491
- scripts/train_on_nasa_power.py +0 -660
- src/notification/__init__.py +0 -0
- src/notification/sender.py +0 -318
- src/prediction/heat_forecast.py +0 -557
- src/prediction/lstm_model.py +0 -566
- src/pricing/neural_actuarial.py +0 -1312
- tests/eval_heat_predictor.py +0 -157
- tests/eval_neural_pricer.py +0 -303
|
The diff for this file is too large to render.
See raw diff
|
|
|
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:6fb225a2eabbb94697c1c4a3f772d0850c54f48c83855a318282c54d77a1f186
|
| 3 |
-
size 168777
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,508 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Actuarial backtesting: predicted frequency vs. actual trigger events.
|
| 4 |
-
|
| 5 |
-
Runs the NeuralActuarialPricer (or GLM fallback) against 19 historical
|
| 6 |
-
hot seasons (Dec-Mar 2005-06 through 2023-24) using ERA5-Land data for
|
| 7 |
-
all 15 Dar es Salaam zones. Counts actual alert and payout events
|
| 8 |
-
using the same thresholds as the benchmark panel, then compares to the
|
| 9 |
-
model's learned_frequency (lambda).
|
| 10 |
-
|
| 11 |
-
Usage:
|
| 12 |
-
python scripts/backtest_pricing.py
|
| 13 |
-
python scripts/backtest_pricing.py --zone DAR-JAN # single zone
|
| 14 |
-
python scripts/backtest_pricing.py --all-zones # all 15 zones
|
| 15 |
-
python scripts/backtest_pricing.py --no-uhi # skip UHI correction
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
from __future__ import annotations
|
| 19 |
-
|
| 20 |
-
import argparse
|
| 21 |
-
import json
|
| 22 |
-
import sys
|
| 23 |
-
from datetime import date, timedelta
|
| 24 |
-
from pathlib import Path
|
| 25 |
-
|
| 26 |
-
import numpy as np
|
| 27 |
-
|
| 28 |
-
# ── Project imports ──────────────────────────────────────────────────────
|
| 29 |
-
PROJECT_ROOT = Path.home() / "climate-risk-engine"
|
| 30 |
-
sys.path.insert(0, str(PROJECT_ROOT))
|
| 31 |
-
|
| 32 |
-
from config import ZONE_MAP, ZONES
|
| 33 |
-
from src.indexing.heat_index import calculate_wbgt
|
| 34 |
-
from src.downscaling.uhi_model import UHI_RANGES
|
| 35 |
-
|
| 36 |
-
# ── Constants ────────────────────────────────────────────────────────────
|
| 37 |
-
# Hot season: December through March (Dar es Salaam)
|
| 38 |
-
HOT_SEASON_MONTHS = {12, 1, 2, 3}
|
| 39 |
-
|
| 40 |
-
# Benchmark trigger thresholds (from neural_actuarial.py and CLAUDE.md)
|
| 41 |
-
WINDOW_ENTRY_WBGT = 35.1 # approximate window-entry threshold
|
| 42 |
-
ALERT_MIN_DAYS = 2 # consecutive days for alert tier
|
| 43 |
-
PAYOUT_MIN_DAYS = 5 # consecutive days for payout tier
|
| 44 |
-
PAYOUT_PEAK_WBGT = 30.7 # peak WBGT for full payout qualification
|
| 45 |
-
|
| 46 |
-
# Climate history window fed to the pricer
|
| 47 |
-
HISTORY_DAYS = 90
|
| 48 |
-
|
| 49 |
-
# Default payout per event per worker (USD)
|
| 50 |
-
PAYOUT_PER_EVENT = 10.0
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# ── Helpers ──────────────────────────────────────────────────────────────
|
| 54 |
-
|
| 55 |
-
def load_era5_data() -> dict[str, list[dict]]:
|
| 56 |
-
"""Load ERA5-Land daily records keyed by zone_id."""
|
| 57 |
-
path = PROJECT_ROOT / "data" / "era5land_dar_es_salaam.json"
|
| 58 |
-
with open(path) as f:
|
| 59 |
-
return json.load(f)
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def build_date_index(records: list[dict]) -> dict[str, int]:
|
| 63 |
-
"""Map date string -> index in records list."""
|
| 64 |
-
return {r["date"]: i for i, r in enumerate(records)}
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def season_label(year: int) -> str:
|
| 68 |
-
"""e.g. season_label(2005) -> '2005-06'"""
|
| 69 |
-
return f"{year}-{str(year + 1)[-2:]}"
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def get_season_dates(start_year: int) -> tuple[date, date]:
|
| 73 |
-
"""Return (first day of Dec, last day of Mar) for a hot season."""
|
| 74 |
-
season_start = date(start_year, 12, 1)
|
| 75 |
-
season_end = date(start_year + 1, 3, 31)
|
| 76 |
-
return season_start, season_end
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def get_history_window(
|
| 80 |
-
records: list[dict],
|
| 81 |
-
date_idx: dict[str, int],
|
| 82 |
-
season_start: date,
|
| 83 |
-
) -> list[dict]:
|
| 84 |
-
"""Get the 90-day climate history ending just before the season start."""
|
| 85 |
-
end = season_start - timedelta(days=1)
|
| 86 |
-
start = end - timedelta(days=HISTORY_DAYS - 1)
|
| 87 |
-
window = []
|
| 88 |
-
d = start
|
| 89 |
-
while d <= end:
|
| 90 |
-
ds = d.isoformat()
|
| 91 |
-
if ds in date_idx:
|
| 92 |
-
window.append(records[date_idx[ds]])
|
| 93 |
-
d += timedelta(days=1)
|
| 94 |
-
return window
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def get_season_records(
|
| 98 |
-
records: list[dict],
|
| 99 |
-
date_idx: dict[str, int],
|
| 100 |
-
season_start: date,
|
| 101 |
-
season_end: date,
|
| 102 |
-
) -> list[dict]:
|
| 103 |
-
"""Get all daily records within the hot season window."""
|
| 104 |
-
out = []
|
| 105 |
-
d = season_start
|
| 106 |
-
while d <= season_end:
|
| 107 |
-
ds = d.isoformat()
|
| 108 |
-
if ds in date_idx:
|
| 109 |
-
out.append(records[date_idx[ds]])
|
| 110 |
-
d += timedelta(days=1)
|
| 111 |
-
return out
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def count_trigger_events(
|
| 115 |
-
season_records: list[dict],
|
| 116 |
-
uhi_delta: float,
|
| 117 |
-
apply_uhi: bool = True,
|
| 118 |
-
) -> dict:
|
| 119 |
-
"""
|
| 120 |
-
Count alert and payout events during a season.
|
| 121 |
-
|
| 122 |
-
A "window" = consecutive days where WBGT >= WINDOW_ENTRY_WBGT.
|
| 123 |
-
- Alert event: window duration >= ALERT_MIN_DAYS
|
| 124 |
-
- Payout event: window duration >= PAYOUT_MIN_DAYS AND peak WBGT >= PAYOUT_PEAK_WBGT
|
| 125 |
-
|
| 126 |
-
Returns counts with and without UHI correction.
|
| 127 |
-
"""
|
| 128 |
-
results = {}
|
| 129 |
-
for label, use_uhi in [("uhi", True), ("grid", False)]:
|
| 130 |
-
if label == "uhi" and not apply_uhi:
|
| 131 |
-
continue
|
| 132 |
-
|
| 133 |
-
delta = uhi_delta if use_uhi else 0.0
|
| 134 |
-
wbgts = []
|
| 135 |
-
for rec in season_records:
|
| 136 |
-
t_max = (rec.get("temp_max_c") or 30.0) + delta
|
| 137 |
-
hum = rec.get("humidity_pct") or 75.0
|
| 138 |
-
wbgts.append(calculate_wbgt(t_max, hum))
|
| 139 |
-
|
| 140 |
-
# Find consecutive windows above threshold
|
| 141 |
-
alert_events = 0
|
| 142 |
-
payout_events = 0
|
| 143 |
-
run_length = 0
|
| 144 |
-
run_peak = 0.0
|
| 145 |
-
|
| 146 |
-
for w in wbgts:
|
| 147 |
-
if w >= WINDOW_ENTRY_WBGT:
|
| 148 |
-
run_length += 1
|
| 149 |
-
run_peak = max(run_peak, w)
|
| 150 |
-
else:
|
| 151 |
-
if run_length >= ALERT_MIN_DAYS:
|
| 152 |
-
alert_events += 1
|
| 153 |
-
if run_length >= PAYOUT_MIN_DAYS and run_peak >= PAYOUT_PEAK_WBGT:
|
| 154 |
-
payout_events += 1
|
| 155 |
-
run_length = 0
|
| 156 |
-
run_peak = 0.0
|
| 157 |
-
# Close trailing run
|
| 158 |
-
if run_length >= ALERT_MIN_DAYS:
|
| 159 |
-
alert_events += 1
|
| 160 |
-
if run_length >= PAYOUT_MIN_DAYS and run_peak >= PAYOUT_PEAK_WBGT:
|
| 161 |
-
payout_events += 1
|
| 162 |
-
|
| 163 |
-
results[label] = {
|
| 164 |
-
"alert_events": alert_events,
|
| 165 |
-
"payout_events": payout_events,
|
| 166 |
-
"mean_wbgt": round(float(np.mean(wbgts)), 2) if wbgts else 0.0,
|
| 167 |
-
"max_wbgt": round(float(np.max(wbgts)), 2) if wbgts else 0.0,
|
| 168 |
-
"days_above_threshold": sum(1 for w in wbgts if w >= WINDOW_ENTRY_WBGT),
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
return results
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
# ── Main backtest ────────────────────────────────────────────────────────
|
| 175 |
-
|
| 176 |
-
def run_backtest(
|
| 177 |
-
zone_ids: list[str] | None = None,
|
| 178 |
-
apply_uhi: bool = True,
|
| 179 |
-
verbose: bool = True,
|
| 180 |
-
) -> list[dict]:
|
| 181 |
-
"""
|
| 182 |
-
Run the backtest across seasons and zones.
|
| 183 |
-
|
| 184 |
-
Returns list of row dicts with season, zone, predicted, and actual counts.
|
| 185 |
-
"""
|
| 186 |
-
# Load pricer (neural if available, else GLM fallback)
|
| 187 |
-
try:
|
| 188 |
-
from src.pricing.neural_actuarial import NeuralActuarialPricer
|
| 189 |
-
pricer = NeuralActuarialPricer()
|
| 190 |
-
pricer_type = "chronos" if pricer._encoder_type == "chronos" else (
|
| 191 |
-
"lstm" if pricer._encoder_type == "lstm" else "glm"
|
| 192 |
-
)
|
| 193 |
-
except Exception as e:
|
| 194 |
-
print(f"[warn] Neural pricer unavailable ({e}), using GLM fallback")
|
| 195 |
-
from src.pricing.actuarial import ActuarialPricer as _AP
|
| 196 |
-
pricer = _AP()
|
| 197 |
-
pricer_type = "glm"
|
| 198 |
-
|
| 199 |
-
if verbose:
|
| 200 |
-
print(f"Pricer: {pricer_type}")
|
| 201 |
-
print(f"UHI correction: {'ON' if apply_uhi else 'OFF'}")
|
| 202 |
-
print()
|
| 203 |
-
|
| 204 |
-
# Load ERA5 data
|
| 205 |
-
era5 = load_era5_data()
|
| 206 |
-
|
| 207 |
-
# Resolve zone list
|
| 208 |
-
dar_zones = [z for z in ZONES if z.city == "Dar es Salaam"]
|
| 209 |
-
if zone_ids:
|
| 210 |
-
dar_zones = [z for z in dar_zones if z.zone_id in zone_ids]
|
| 211 |
-
if not dar_zones:
|
| 212 |
-
print("No matching zones found.")
|
| 213 |
-
return []
|
| 214 |
-
|
| 215 |
-
if verbose:
|
| 216 |
-
print(f"Zones: {[z.zone_id for z in dar_zones]}")
|
| 217 |
-
print()
|
| 218 |
-
|
| 219 |
-
# Seasons: Dec 2005 through Mar 2024 -> start years 2005..2023
|
| 220 |
-
season_years = list(range(2005, 2024))
|
| 221 |
-
rows = []
|
| 222 |
-
|
| 223 |
-
for zone in dar_zones:
|
| 224 |
-
records = era5.get(zone.zone_id, [])
|
| 225 |
-
if not records:
|
| 226 |
-
print(f"[warn] No ERA5 data for {zone.zone_id}, skipping")
|
| 227 |
-
continue
|
| 228 |
-
|
| 229 |
-
date_idx = build_date_index(records)
|
| 230 |
-
uhi_lo, uhi_hi = UHI_RANGES.get(zone.settlement_type, (1.0, 2.0))
|
| 231 |
-
mean_uhi = (uhi_lo + uhi_hi) / 2.0
|
| 232 |
-
|
| 233 |
-
for sy in season_years:
|
| 234 |
-
season_start, season_end = get_season_dates(sy)
|
| 235 |
-
|
| 236 |
-
# 1. Get 90-day history ending before season start
|
| 237 |
-
history = get_history_window(records, date_idx, season_start)
|
| 238 |
-
if len(history) < 30:
|
| 239 |
-
continue # not enough history
|
| 240 |
-
|
| 241 |
-
# 2. Run pricer to get predicted frequency (lambda)
|
| 242 |
-
predicted_freq = None
|
| 243 |
-
try:
|
| 244 |
-
if pricer_type == "glm":
|
| 245 |
-
# GLM needs a frequency estimate -- use historical rate
|
| 246 |
-
# from the previous season as a naive baseline
|
| 247 |
-
result = pricer.price_zone(
|
| 248 |
-
zone=zone,
|
| 249 |
-
predicted_frequency=10.0, # placeholder
|
| 250 |
-
basis_risk_score=0.3,
|
| 251 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 252 |
-
enrolled=zone.worker_population_est,
|
| 253 |
-
)
|
| 254 |
-
predicted_freq = 10.0 # GLM doesn't learn frequency
|
| 255 |
-
else:
|
| 256 |
-
result = pricer.price_zone(
|
| 257 |
-
zone=zone,
|
| 258 |
-
predicted_frequency=10.0,
|
| 259 |
-
basis_risk_score=0.3,
|
| 260 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 261 |
-
enrolled=zone.worker_population_est,
|
| 262 |
-
climate_history=history,
|
| 263 |
-
)
|
| 264 |
-
predicted_freq = result.cost_breakdown.get(
|
| 265 |
-
"learned_frequency", 10.0
|
| 266 |
-
)
|
| 267 |
-
except Exception as e:
|
| 268 |
-
if verbose:
|
| 269 |
-
print(f" [warn] Pricer failed for {zone.zone_id} "
|
| 270 |
-
f"season {season_label(sy)}: {e}")
|
| 271 |
-
continue
|
| 272 |
-
|
| 273 |
-
# 3. Count actual events during the season
|
| 274 |
-
season_recs = get_season_records(
|
| 275 |
-
records, date_idx, season_start, season_end
|
| 276 |
-
)
|
| 277 |
-
if not season_recs:
|
| 278 |
-
continue
|
| 279 |
-
|
| 280 |
-
actuals = count_trigger_events(
|
| 281 |
-
season_recs, mean_uhi, apply_uhi=apply_uhi
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
# Build row -- use UHI-corrected actuals for primary comparison
|
| 285 |
-
# if UHI is on, otherwise use grid actuals
|
| 286 |
-
primary = actuals.get("uhi", actuals.get("grid", {}))
|
| 287 |
-
grid = actuals.get("grid", {})
|
| 288 |
-
|
| 289 |
-
row = {
|
| 290 |
-
"season": season_label(sy),
|
| 291 |
-
"zone_id": zone.zone_id,
|
| 292 |
-
"zone_name": zone.name,
|
| 293 |
-
"settlement_type": zone.settlement_type,
|
| 294 |
-
"predicted_freq": round(predicted_freq, 1),
|
| 295 |
-
"actual_alert_uhi": primary.get("alert_events", 0),
|
| 296 |
-
"actual_payout_uhi": primary.get("payout_events", 0),
|
| 297 |
-
"actual_alert_grid": grid.get("alert_events", 0),
|
| 298 |
-
"actual_payout_grid": grid.get("payout_events", 0),
|
| 299 |
-
"mean_wbgt_uhi": primary.get("mean_wbgt", 0),
|
| 300 |
-
"max_wbgt_uhi": primary.get("max_wbgt", 0),
|
| 301 |
-
"days_above_uhi": primary.get("days_above_threshold", 0),
|
| 302 |
-
"mean_wbgt_grid": grid.get("mean_wbgt", 0),
|
| 303 |
-
"max_wbgt_grid": grid.get("max_wbgt", 0),
|
| 304 |
-
"days_above_grid": grid.get("days_above_threshold", 0),
|
| 305 |
-
"season_days": len(season_recs),
|
| 306 |
-
}
|
| 307 |
-
rows.append(row)
|
| 308 |
-
|
| 309 |
-
return rows
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
def print_summary(rows: list[dict], apply_uhi: bool = True) -> None:
|
| 313 |
-
"""Print formatted summary tables."""
|
| 314 |
-
if not rows:
|
| 315 |
-
print("No backtest results to display.")
|
| 316 |
-
return
|
| 317 |
-
|
| 318 |
-
# ── Per-season detail table ──────────────────────────────────────
|
| 319 |
-
zones_in_data = sorted(set(r["zone_id"] for r in rows))
|
| 320 |
-
single_zone = len(zones_in_data) == 1
|
| 321 |
-
|
| 322 |
-
print("=" * 95)
|
| 323 |
-
print("ACTUARIAL BACKTEST: Predicted Frequency vs. Actual Trigger Events")
|
| 324 |
-
print("=" * 95)
|
| 325 |
-
print()
|
| 326 |
-
|
| 327 |
-
if single_zone:
|
| 328 |
-
zone_id = zones_in_data[0]
|
| 329 |
-
zone_name = rows[0]["zone_name"]
|
| 330 |
-
print(f"Zone: {zone_id} ({zone_name})")
|
| 331 |
-
print(f"Threshold: WBGT >= {WINDOW_ENTRY_WBGT} C (window entry)")
|
| 332 |
-
print(f"Alert: >= {ALERT_MIN_DAYS} consecutive days | "
|
| 333 |
-
f"Payout: >= {PAYOUT_MIN_DAYS} days AND peak WBGT >= {PAYOUT_PEAK_WBGT}")
|
| 334 |
-
print()
|
| 335 |
-
|
| 336 |
-
header = (
|
| 337 |
-
f"{'Season':<10} {'Pred_Freq':>9} "
|
| 338 |
-
f"{'Alert_UHI':>9} {'Payout_UHI':>10} "
|
| 339 |
-
f"{'Alert_Grid':>10} {'Payout_Grid':>11} "
|
| 340 |
-
f"{'MeanWBGT':>8} {'MaxWBGT':>7} {'Days>Thr':>8}"
|
| 341 |
-
)
|
| 342 |
-
print(header)
|
| 343 |
-
print("-" * len(header))
|
| 344 |
-
|
| 345 |
-
for r in sorted(rows, key=lambda x: x["season"]):
|
| 346 |
-
print(
|
| 347 |
-
f"{r['season']:<10} {r['predicted_freq']:>9.1f} "
|
| 348 |
-
f"{r['actual_alert_uhi']:>9} {r['actual_payout_uhi']:>10} "
|
| 349 |
-
f"{r['actual_alert_grid']:>10} {r['actual_payout_grid']:>11} "
|
| 350 |
-
f"{r['mean_wbgt_uhi']:>8.1f} {r['max_wbgt_uhi']:>7.1f} "
|
| 351 |
-
f"{r['days_above_uhi']:>8}"
|
| 352 |
-
)
|
| 353 |
-
else:
|
| 354 |
-
# Multi-zone: aggregate by zone
|
| 355 |
-
print(f"Zones: {len(zones_in_data)} | Seasons: "
|
| 356 |
-
f"{sorted(set(r['season'] for r in rows))[0]} to "
|
| 357 |
-
f"{sorted(set(r['season'] for r in rows))[-1]}")
|
| 358 |
-
print(f"Threshold: WBGT >= {WINDOW_ENTRY_WBGT} C")
|
| 359 |
-
print()
|
| 360 |
-
|
| 361 |
-
header = (
|
| 362 |
-
f"{'Zone':<10} {'Type':<11} {'Seasons':>7} "
|
| 363 |
-
f"{'MeanPred':>8} {'MeanAlertU':>10} {'MeanPayU':>8} "
|
| 364 |
-
f"{'MeanAlertG':>10} {'MeanPayG':>8} {'MeanWBGT':>8}"
|
| 365 |
-
)
|
| 366 |
-
print(header)
|
| 367 |
-
print("-" * len(header))
|
| 368 |
-
|
| 369 |
-
for zid in zones_in_data:
|
| 370 |
-
zrows = [r for r in rows if r["zone_id"] == zid]
|
| 371 |
-
n = len(zrows)
|
| 372 |
-
print(
|
| 373 |
-
f"{zid:<10} {zrows[0]['settlement_type']:<11} {n:>7} "
|
| 374 |
-
f"{np.mean([r['predicted_freq'] for r in zrows]):>8.1f} "
|
| 375 |
-
f"{np.mean([r['actual_alert_uhi'] for r in zrows]):>10.1f} "
|
| 376 |
-
f"{np.mean([r['actual_payout_uhi'] for r in zrows]):>8.1f} "
|
| 377 |
-
f"{np.mean([r['actual_alert_grid'] for r in zrows]):>10.1f} "
|
| 378 |
-
f"{np.mean([r['actual_payout_grid'] for r in zrows]):>8.1f} "
|
| 379 |
-
f"{np.mean([r['mean_wbgt_uhi'] for r in zrows]):>8.1f}"
|
| 380 |
-
)
|
| 381 |
-
|
| 382 |
-
# ── Aggregate metrics ────────────────────────────────────────────
|
| 383 |
-
print()
|
| 384 |
-
print("=" * 95)
|
| 385 |
-
print("AGGREGATE METRICS")
|
| 386 |
-
print("=" * 95)
|
| 387 |
-
print()
|
| 388 |
-
|
| 389 |
-
pred = np.array([r["predicted_freq"] for r in rows])
|
| 390 |
-
# Annualize actuals: season is ~121 days (Dec-Mar), so scale up
|
| 391 |
-
# But predicted_freq is already annual. We compare predicted annual
|
| 392 |
-
# to actual season counts directly -- the model should predict events
|
| 393 |
-
# per year, but events in Dec-Mar ARE the bulk of the hot season.
|
| 394 |
-
# So the comparison is: does the model's annual lambda match the
|
| 395 |
-
# observed rate when we look at the actual hot season?
|
| 396 |
-
|
| 397 |
-
for label_suffix, alert_key, payout_key in [
|
| 398 |
-
("(UHI-corrected)", "actual_alert_uhi", "actual_payout_uhi"),
|
| 399 |
-
("(grid / no UHI)", "actual_alert_grid", "actual_payout_grid"),
|
| 400 |
-
]:
|
| 401 |
-
if not apply_uhi and label_suffix == "(UHI-corrected)":
|
| 402 |
-
continue
|
| 403 |
-
|
| 404 |
-
actual_alert = np.array([r[alert_key] for r in rows], dtype=float)
|
| 405 |
-
actual_payout = np.array([r[payout_key] for r in rows], dtype=float)
|
| 406 |
-
|
| 407 |
-
print(f"--- {label_suffix} ---")
|
| 408 |
-
print(f" Mean predicted frequency (annual lambda): {pred.mean():.2f}")
|
| 409 |
-
print(f" Mean actual alert events per season: {actual_alert.mean():.2f}")
|
| 410 |
-
print(f" Mean actual payout events per season: {actual_payout.mean():.2f}")
|
| 411 |
-
print()
|
| 412 |
-
print(f" Predicted/Actual alert ratio: "
|
| 413 |
-
f"{pred.mean() / max(actual_alert.mean(), 0.01):.2f}x")
|
| 414 |
-
print(f" Predicted/Actual payout ratio: "
|
| 415 |
-
f"{pred.mean() / max(actual_payout.mean(), 0.01):.2f}x")
|
| 416 |
-
|
| 417 |
-
# Correlation (only meaningful if there's variance)
|
| 418 |
-
if actual_alert.std() > 0 and pred.std() > 0:
|
| 419 |
-
corr_alert = float(np.corrcoef(pred, actual_alert)[0, 1])
|
| 420 |
-
print(f" Pearson correlation (pred vs alert): {corr_alert:.3f}")
|
| 421 |
-
else:
|
| 422 |
-
print(f" Pearson correlation (pred vs alert): N/A (no variance)")
|
| 423 |
-
|
| 424 |
-
if actual_payout.std() > 0 and pred.std() > 0:
|
| 425 |
-
corr_payout = float(np.corrcoef(pred, actual_payout)[0, 1])
|
| 426 |
-
print(f" Pearson correlation (pred vs payout): {corr_payout:.3f}")
|
| 427 |
-
else:
|
| 428 |
-
print(f" Pearson correlation (pred vs payout): N/A (no variance)")
|
| 429 |
-
|
| 430 |
-
# RMSE
|
| 431 |
-
rmse_alert = float(np.sqrt(np.mean((pred - actual_alert) ** 2)))
|
| 432 |
-
rmse_payout = float(np.sqrt(np.mean((pred - actual_payout) ** 2)))
|
| 433 |
-
print(f" RMSE (pred vs alert): {rmse_alert:.2f}")
|
| 434 |
-
print(f" RMSE (pred vs payout): {rmse_payout:.2f}")
|
| 435 |
-
|
| 436 |
-
# Per-season hit rate: how often does predicted > 0 match actual > 0
|
| 437 |
-
pred_positive = pred > 0.5
|
| 438 |
-
actual_alert_positive = actual_alert > 0
|
| 439 |
-
actual_payout_positive = actual_payout > 0
|
| 440 |
-
if len(pred) > 0:
|
| 441 |
-
alert_hit = float(
|
| 442 |
-
np.mean(pred_positive == actual_alert_positive) * 100
|
| 443 |
-
)
|
| 444 |
-
payout_hit = float(
|
| 445 |
-
np.mean(pred_positive == actual_payout_positive) * 100
|
| 446 |
-
)
|
| 447 |
-
print(f" Direction accuracy (alert): {alert_hit:.0f}%")
|
| 448 |
-
print(f" Direction accuracy (payout): {payout_hit:.0f}%")
|
| 449 |
-
print()
|
| 450 |
-
|
| 451 |
-
# ── By settlement type ───────────────────────────────────────────
|
| 452 |
-
stypes = sorted(set(r["settlement_type"] for r in rows))
|
| 453 |
-
if len(stypes) > 1:
|
| 454 |
-
print("--- By settlement type ---")
|
| 455 |
-
for st in stypes:
|
| 456 |
-
st_rows = [r for r in rows if r["settlement_type"] == st]
|
| 457 |
-
st_pred = np.mean([r["predicted_freq"] for r in st_rows])
|
| 458 |
-
st_alert = np.mean([r["actual_alert_uhi"] for r in st_rows])
|
| 459 |
-
st_payout = np.mean([r["actual_payout_uhi"] for r in st_rows])
|
| 460 |
-
print(f" {st:<12} pred={st_pred:.1f} alert={st_alert:.1f} "
|
| 461 |
-
f"payout={st_payout:.1f} "
|
| 462 |
-
f"ratio={st_pred / max(st_alert, 0.01):.1f}x")
|
| 463 |
-
print()
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
def main():
|
| 467 |
-
parser = argparse.ArgumentParser(
|
| 468 |
-
description="Backtest pricing model against historical ERA5-Land data"
|
| 469 |
-
)
|
| 470 |
-
parser.add_argument(
|
| 471 |
-
"--zone", type=str, default="DAR-JAN",
|
| 472 |
-
help="Zone ID to test (default: DAR-JAN)"
|
| 473 |
-
)
|
| 474 |
-
parser.add_argument(
|
| 475 |
-
"--all-zones", action="store_true",
|
| 476 |
-
help="Run backtest across all 15 Dar es Salaam zones"
|
| 477 |
-
)
|
| 478 |
-
parser.add_argument(
|
| 479 |
-
"--no-uhi", action="store_true",
|
| 480 |
-
help="Disable UHI correction (grid-only WBGT)"
|
| 481 |
-
)
|
| 482 |
-
parser.add_argument(
|
| 483 |
-
"--quiet", action="store_true",
|
| 484 |
-
help="Suppress per-season detail, show only aggregate metrics"
|
| 485 |
-
)
|
| 486 |
-
args = parser.parse_args()
|
| 487 |
-
|
| 488 |
-
apply_uhi = not args.no_uhi
|
| 489 |
-
|
| 490 |
-
if args.all_zones:
|
| 491 |
-
zone_ids = None # all Dar zones
|
| 492 |
-
else:
|
| 493 |
-
zone_ids = [args.zone]
|
| 494 |
-
|
| 495 |
-
rows = run_backtest(
|
| 496 |
-
zone_ids=zone_ids,
|
| 497 |
-
apply_uhi=apply_uhi,
|
| 498 |
-
verbose=not args.quiet,
|
| 499 |
-
)
|
| 500 |
-
|
| 501 |
-
if not args.quiet:
|
| 502 |
-
print()
|
| 503 |
-
|
| 504 |
-
print_summary(rows, apply_uhi=apply_uhi)
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
if __name__ == "__main__":
|
| 508 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
"""Train the LSTM heat wave predictor.
|
| 2 |
-
|
| 3 |
-
Tries ERA5 data first, falls back to synthetic data generation
|
| 4 |
-
(same seasonal + AR(1) approach as the existing XGBoost trainer).
|
| 5 |
-
|
| 6 |
-
Usage:
|
| 7 |
-
python3 scripts/train_lstm.py
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
import sys
|
| 11 |
-
import time
|
| 12 |
-
|
| 13 |
-
sys.path.insert(0, ".")
|
| 14 |
-
|
| 15 |
-
from src.prediction.lstm_model import LSTMTrainer, generate_synthetic_zone_data
|
| 16 |
-
from config import ZONES
|
| 17 |
-
|
| 18 |
-
print("=" * 60)
|
| 19 |
-
print("LSTM Heat Wave Predictor -- Training")
|
| 20 |
-
print("=" * 60)
|
| 21 |
-
|
| 22 |
-
# Try ERA5 data first, fall back to synthetic
|
| 23 |
-
zone_data = None
|
| 24 |
-
|
| 25 |
-
try:
|
| 26 |
-
from src.ingestion.era5_fetcher import fetch_era5_sync
|
| 27 |
-
|
| 28 |
-
print("\nFetching ERA5 data for training...")
|
| 29 |
-
raw = fetch_era5_sync(ZONES, days_back=365)
|
| 30 |
-
# Convert to training format if fetch succeeds
|
| 31 |
-
if raw and len(raw) > 0:
|
| 32 |
-
zone_data = raw
|
| 33 |
-
print(f" Loaded ERA5 data for {len(zone_data)} zones")
|
| 34 |
-
except Exception as e:
|
| 35 |
-
print(f"\nERA5 unavailable ({e}), using synthetic training data")
|
| 36 |
-
|
| 37 |
-
if zone_data is None:
|
| 38 |
-
print("\nGenerating synthetic training data (2 years x 20 zones)...")
|
| 39 |
-
t0 = time.time()
|
| 40 |
-
zone_data = generate_synthetic_zone_data(ZONES, n_days=730, seed=42)
|
| 41 |
-
elapsed = time.time() - t0
|
| 42 |
-
total_days = sum(len(v) for v in zone_data.values())
|
| 43 |
-
print(f" Generated {total_days:,} zone-days in {elapsed:.1f}s")
|
| 44 |
-
|
| 45 |
-
# Train
|
| 46 |
-
print("\nTraining LSTM...")
|
| 47 |
-
t0 = time.time()
|
| 48 |
-
trainer = LSTMTrainer(epochs=50, patience=5)
|
| 49 |
-
metrics = trainer.train(zone_data)
|
| 50 |
-
elapsed = time.time() - t0
|
| 51 |
-
|
| 52 |
-
print(f"\nTraining complete in {elapsed:.1f}s")
|
| 53 |
-
print(f" Epochs trained: {metrics.get('epochs_trained', '?')}")
|
| 54 |
-
print(f" Train loss: {metrics.get('train_loss', '?')}")
|
| 55 |
-
print(f" Val loss: {metrics.get('val_loss', '?')}")
|
| 56 |
-
print(f" Val AUROC: {metrics.get('val_auroc', '?')}")
|
| 57 |
-
print(f" Samples: {metrics.get('samples', '?')}")
|
| 58 |
-
print("=" * 60)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Train the Neural Actuarial Pricing Engine on real climate data.
|
| 4 |
-
|
| 5 |
-
Usage:
|
| 6 |
-
python3 scripts/train_neural_pricer.py # default: config.PRIMARY_CITY
|
| 7 |
-
python3 scripts/train_neural_pricer.py --city Kampala # Other city
|
| 8 |
-
|
| 9 |
-
Requires: data/era5land_{slug}.json where slug is derived from the primary city
|
| 10 |
-
(config.PRIMARY_CITY) or the --city flag.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
import argparse
|
| 14 |
-
import json
|
| 15 |
-
import sys
|
| 16 |
-
from pathlib import Path
|
| 17 |
-
|
| 18 |
-
# Add project root to path
|
| 19 |
-
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 20 |
-
|
| 21 |
-
from config import ZONES, ZONE_MAP, slug_for
|
| 22 |
-
from src.pricing.neural_actuarial import (
|
| 23 |
-
build_training_samples,
|
| 24 |
-
load_climate_data,
|
| 25 |
-
NeuralPricerTrainer,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def main():
|
| 30 |
-
parser = argparse.ArgumentParser(description="Train neural actuarial pricer")
|
| 31 |
-
parser.add_argument("--city", default="Dar es Salaam", help="City to train for")
|
| 32 |
-
parser.add_argument("--data", default=None, help="Path to climate data JSON")
|
| 33 |
-
parser.add_argument("--epochs", type=int, default=80)
|
| 34 |
-
parser.add_argument("--lr", type=float, default=1e-3)
|
| 35 |
-
parser.add_argument("--patience", type=int, default=10)
|
| 36 |
-
parser.add_argument("--encoder", choices=["chronos", "lstm"], default="chronos",
|
| 37 |
-
help="Encoder type: chronos (foundation model) or lstm (legacy)")
|
| 38 |
-
args = parser.parse_args()
|
| 39 |
-
|
| 40 |
-
data_dir = Path(__file__).resolve().parents[1] / "data"
|
| 41 |
-
|
| 42 |
-
# Find climate data file
|
| 43 |
-
if args.data:
|
| 44 |
-
data_path = Path(args.data)
|
| 45 |
-
else:
|
| 46 |
-
city_slug = slug_for(args.city)
|
| 47 |
-
# Try ERA5-Land first, fall back to NASA POWER
|
| 48 |
-
era5_path = data_dir / f"era5land_{city_slug}.json"
|
| 49 |
-
nasa_path = data_dir / f"nasa_power_{city_slug}.json"
|
| 50 |
-
if era5_path.exists():
|
| 51 |
-
data_path = era5_path
|
| 52 |
-
print(f"Using ERA5-Land data: {data_path}")
|
| 53 |
-
elif nasa_path.exists():
|
| 54 |
-
data_path = nasa_path
|
| 55 |
-
print(f"Using NASA POWER data: {data_path}")
|
| 56 |
-
else:
|
| 57 |
-
print(f"ERROR: No climate data found for {args.city}")
|
| 58 |
-
print(f" Looked for: {era5_path}")
|
| 59 |
-
print(f" Looked for: {nasa_path}")
|
| 60 |
-
print(f" Run scripts/fetch_nasa_power_dar.py first")
|
| 61 |
-
sys.exit(1)
|
| 62 |
-
|
| 63 |
-
# Filter zones for the target city
|
| 64 |
-
city_zones = [z for z in ZONES if z.city == args.city]
|
| 65 |
-
if not city_zones:
|
| 66 |
-
print(f"ERROR: No zones found for city '{args.city}'")
|
| 67 |
-
print(f" Available cities: {sorted(set(z.city for z in ZONES))}")
|
| 68 |
-
sys.exit(1)
|
| 69 |
-
|
| 70 |
-
print(f"\n{'='*60}")
|
| 71 |
-
print(f"Neural Actuarial Pricer — {args.city}")
|
| 72 |
-
print(f"{'='*60}")
|
| 73 |
-
print(f" Zones: {len(city_zones)} ({', '.join(z.name for z in city_zones)})")
|
| 74 |
-
print(f" Data: {data_path}")
|
| 75 |
-
|
| 76 |
-
# Load climate data
|
| 77 |
-
climate_data = load_climate_data(data_path)
|
| 78 |
-
zone_ids = {z.zone_id for z in city_zones}
|
| 79 |
-
climate_data = {k: v for k, v in climate_data.items() if k in zone_ids}
|
| 80 |
-
|
| 81 |
-
if not climate_data:
|
| 82 |
-
print(f"ERROR: No matching zone data found in {data_path}")
|
| 83 |
-
print(f" Expected zone IDs: {zone_ids}")
|
| 84 |
-
print(f" Found zone IDs: {set(load_climate_data(data_path).keys())}")
|
| 85 |
-
sys.exit(1)
|
| 86 |
-
|
| 87 |
-
print(f" Records: {sum(len(v) for v in climate_data.values()):,} zone-days")
|
| 88 |
-
|
| 89 |
-
# Build training samples with city-specific WBGT threshold
|
| 90 |
-
from src.pricing.neural_actuarial import CITY_WBGT_THRESHOLDS
|
| 91 |
-
wbgt_thresh = CITY_WBGT_THRESHOLDS.get(args.city, 35.0)
|
| 92 |
-
print(f"\nBuilding training samples (90-day windows, stride=7, WBGT threshold={wbgt_thresh}°C)...")
|
| 93 |
-
X, targets = build_training_samples(climate_data, city_zones, wbgt_threshold=wbgt_thresh)
|
| 94 |
-
print(f" Samples: {len(X):,}")
|
| 95 |
-
print(f" Shape: {X.shape}")
|
| 96 |
-
print(f" Target ranges:")
|
| 97 |
-
for k, v in targets.items():
|
| 98 |
-
print(f" {k}: [{v.min():.3f}, {v.max():.3f}], mean={v.mean():.3f}")
|
| 99 |
-
|
| 100 |
-
# Train
|
| 101 |
-
print(f"\nTraining (encoder={args.encoder}, epochs={args.epochs}, lr={args.lr}, patience={args.patience})...")
|
| 102 |
-
trainer = NeuralPricerTrainer(
|
| 103 |
-
lr=args.lr, epochs=args.epochs, patience=args.patience,
|
| 104 |
-
encoder=args.encoder,
|
| 105 |
-
)
|
| 106 |
-
metrics = trainer.train(X, targets)
|
| 107 |
-
|
| 108 |
-
print(f"\n{'='*60}")
|
| 109 |
-
print(f"Training complete")
|
| 110 |
-
print(f"{'='*60}")
|
| 111 |
-
for k, v in metrics.items():
|
| 112 |
-
print(f" {k}: {v}")
|
| 113 |
-
|
| 114 |
-
# Quick inference test
|
| 115 |
-
print(f"\nInference test:")
|
| 116 |
-
from src.pricing.neural_actuarial import NeuralActuarialPricer
|
| 117 |
-
pricer = NeuralActuarialPricer()
|
| 118 |
-
print(f" Neural model loaded: {pricer.is_neural}")
|
| 119 |
-
|
| 120 |
-
if pricer.is_neural:
|
| 121 |
-
for zone in city_zones:
|
| 122 |
-
history = climate_data[zone.zone_id][-90:]
|
| 123 |
-
result = pricer.price_zone(
|
| 124 |
-
zone=zone,
|
| 125 |
-
predicted_frequency=10.0, # ignored by neural model
|
| 126 |
-
basis_risk_score=0.2, # ignored by neural model
|
| 127 |
-
payout_per_event=10.0,
|
| 128 |
-
enrolled=zone.worker_population_est,
|
| 129 |
-
climate_history=history,
|
| 130 |
-
)
|
| 131 |
-
cb = result.cost_breakdown
|
| 132 |
-
print(
|
| 133 |
-
f" {zone.name:12s} ({zone.settlement_type:9s}): "
|
| 134 |
-
f"${result.cost_per_worker_year:.2f}/worker/yr "
|
| 135 |
-
f"(λ={cb.get('learned_frequency', '?')}, "
|
| 136 |
-
f"basis_risk={cb.get('learned_basis_risk', '?')}, "
|
| 137 |
-
f"δ_NN={cb.get('neural_correction_pct', '?'):+.1f}%)"
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
if __name__ == "__main__":
|
| 142 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,491 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Train all ML models on real ERA5 reanalysis data.
|
| 3 |
-
|
| 4 |
-
Steps:
|
| 5 |
-
1. Fetch 2 years of ERA5 data for all 20 zones via Google ARCO Zarr store
|
| 6 |
-
2. Validate data quality (coverage, temp ranges, nulls)
|
| 7 |
-
3. Retrain XGBoost heat predictor on real data
|
| 8 |
-
4. Retrain LSTM on real data
|
| 9 |
-
5. Verify UHI model works with real ERA5 temps
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
import sys
|
| 13 |
-
import os
|
| 14 |
-
import time
|
| 15 |
-
import logging
|
| 16 |
-
import math
|
| 17 |
-
|
| 18 |
-
import numpy as np
|
| 19 |
-
|
| 20 |
-
# Project root on sys.path
|
| 21 |
-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 22 |
-
|
| 23 |
-
from config import ZONES, ZONE_MAP
|
| 24 |
-
from src.ingestion.era5_fetcher import fetch_era5_sync
|
| 25 |
-
from src.ingestion.models import DailyReading
|
| 26 |
-
from src.indexing.heat_index import calculate_wbgt
|
| 27 |
-
|
| 28 |
-
logging.basicConfig(
|
| 29 |
-
level=logging.INFO,
|
| 30 |
-
format="%(asctime)s %(name)s %(levelname)s %(message)s",
|
| 31 |
-
datefmt="%H:%M:%S",
|
| 32 |
-
)
|
| 33 |
-
log = logging.getLogger("train_era5")
|
| 34 |
-
|
| 35 |
-
# Expected temp ranges per city (max daily temps, deg C)
|
| 36 |
-
EXPECTED_RANGES = {
|
| 37 |
-
"Nairobi": (18, 35),
|
| 38 |
-
"Dar es Salaam": (25, 40),
|
| 39 |
-
"Kampala": (22, 36),
|
| 40 |
-
"Kigali": (20, 34),
|
| 41 |
-
}
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
# ======================================================================
|
| 45 |
-
# Step 1: Fetch ERA5 data
|
| 46 |
-
# ======================================================================
|
| 47 |
-
|
| 48 |
-
def fetch_data():
|
| 49 |
-
log.info("=" * 60)
|
| 50 |
-
log.info("STEP 1: Fetching 2 years of ERA5 data for %d zones", len(ZONES))
|
| 51 |
-
log.info("=" * 60)
|
| 52 |
-
t0 = time.time()
|
| 53 |
-
data = fetch_era5_sync(ZONES, days_back=730)
|
| 54 |
-
elapsed = time.time() - t0
|
| 55 |
-
log.info("Fetch complete in %.1f seconds", elapsed)
|
| 56 |
-
return data
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
# ======================================================================
|
| 60 |
-
# Step 2: Validate data quality
|
| 61 |
-
# ======================================================================
|
| 62 |
-
|
| 63 |
-
def validate_data(data: dict[str, list[DailyReading]]):
|
| 64 |
-
log.info("=" * 60)
|
| 65 |
-
log.info("STEP 2: Validating ERA5 data quality")
|
| 66 |
-
log.info("=" * 60)
|
| 67 |
-
|
| 68 |
-
issues = []
|
| 69 |
-
stats = {}
|
| 70 |
-
|
| 71 |
-
for zone in ZONES:
|
| 72 |
-
zid = zone.zone_id
|
| 73 |
-
readings = data.get(zid, [])
|
| 74 |
-
|
| 75 |
-
if not readings:
|
| 76 |
-
issues.append(f"{zid}: NO DATA")
|
| 77 |
-
stats[zid] = {"days": 0, "issue": "no data"}
|
| 78 |
-
continue
|
| 79 |
-
|
| 80 |
-
temps = [r.temp_max_c for r in readings if r.temp_max_c is not None]
|
| 81 |
-
humids = [r.humidity_pct for r in readings if r.humidity_pct is not None]
|
| 82 |
-
winds = [r.wind_speed_ms for r in readings if r.wind_speed_ms is not None]
|
| 83 |
-
|
| 84 |
-
if not temps:
|
| 85 |
-
issues.append(f"{zid}: all temps are null")
|
| 86 |
-
stats[zid] = {"days": len(readings), "issue": "all null temps"}
|
| 87 |
-
continue
|
| 88 |
-
|
| 89 |
-
t_min, t_max = min(temps), max(temps)
|
| 90 |
-
t_mean = sum(temps) / len(temps)
|
| 91 |
-
|
| 92 |
-
# Check physical reasonableness
|
| 93 |
-
exp_lo, exp_hi = EXPECTED_RANGES.get(zone.city, (15, 42))
|
| 94 |
-
if t_min < exp_lo - 5 or t_max > exp_hi + 5:
|
| 95 |
-
issues.append(
|
| 96 |
-
f"{zid} ({zone.city}): temp range [{t_min:.1f}, {t_max:.1f}] "
|
| 97 |
-
f"outside expected [{exp_lo-5}, {exp_hi+5}]"
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
null_count = sum(1 for r in readings if r.temp_max_c is None)
|
| 101 |
-
|
| 102 |
-
stats[zid] = {
|
| 103 |
-
"days": len(readings),
|
| 104 |
-
"temp_days": len(temps),
|
| 105 |
-
"temp_min": round(t_min, 1),
|
| 106 |
-
"temp_max": round(t_max, 1),
|
| 107 |
-
"temp_mean": round(t_mean, 1),
|
| 108 |
-
"humidity_mean": round(sum(humids)/len(humids), 1) if humids else None,
|
| 109 |
-
"wind_mean": round(sum(winds)/len(winds), 1) if winds else None,
|
| 110 |
-
"null_temps": null_count,
|
| 111 |
-
}
|
| 112 |
-
|
| 113 |
-
# Print summary
|
| 114 |
-
print("\n--- ERA5 Data Summary ---")
|
| 115 |
-
print(f"{'Zone':<12} {'City':<16} {'Days':>5} {'Temp min':>9} {'Temp max':>9} {'Temp mean':>10} {'Humidity':>9} {'Nulls':>6}")
|
| 116 |
-
print("-" * 90)
|
| 117 |
-
|
| 118 |
-
by_city = {}
|
| 119 |
-
for zone in ZONES:
|
| 120 |
-
s = stats.get(zone.zone_id, {})
|
| 121 |
-
days = s.get("days", 0)
|
| 122 |
-
t_lo = s.get("temp_min", "N/A")
|
| 123 |
-
t_hi = s.get("temp_max", "N/A")
|
| 124 |
-
t_mn = s.get("temp_mean", "N/A")
|
| 125 |
-
hum = s.get("humidity_mean", "N/A")
|
| 126 |
-
nulls = s.get("null_temps", "N/A")
|
| 127 |
-
print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {t_lo:>9} {t_hi:>9} {t_mn:>10} {hum:>9} {nulls:>6}")
|
| 128 |
-
|
| 129 |
-
city = zone.city
|
| 130 |
-
if city not in by_city:
|
| 131 |
-
by_city[city] = []
|
| 132 |
-
by_city[city].append(s)
|
| 133 |
-
|
| 134 |
-
print("\n--- Per-city aggregated temp ranges ---")
|
| 135 |
-
for city, zone_stats in by_city.items():
|
| 136 |
-
all_mins = [s["temp_min"] for s in zone_stats if s.get("temp_min") is not None]
|
| 137 |
-
all_maxs = [s["temp_max"] for s in zone_stats if s.get("temp_max") is not None]
|
| 138 |
-
if all_mins and all_maxs:
|
| 139 |
-
print(f" {city:<16}: {min(all_mins):.1f} - {max(all_maxs):.1f} C")
|
| 140 |
-
|
| 141 |
-
if issues:
|
| 142 |
-
print(f"\n ISSUES ({len(issues)}):")
|
| 143 |
-
for issue in issues:
|
| 144 |
-
print(f" - {issue}")
|
| 145 |
-
else:
|
| 146 |
-
print("\n No data quality issues found.")
|
| 147 |
-
|
| 148 |
-
zones_with_data = sum(1 for s in stats.values() if s.get("days", 0) > 0)
|
| 149 |
-
assert zones_with_data == len(ZONES), f"Only {zones_with_data}/{len(ZONES)} zones have data"
|
| 150 |
-
print(f"\n All {zones_with_data} zones have data.\n")
|
| 151 |
-
|
| 152 |
-
return stats
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
# ======================================================================
|
| 156 |
-
# Step 3: Retrain XGBoost heat predictor on real data
|
| 157 |
-
# ======================================================================
|
| 158 |
-
|
| 159 |
-
def retrain_xgboost(data: dict[str, list[DailyReading]]):
|
| 160 |
-
log.info("=" * 60)
|
| 161 |
-
log.info("STEP 3: Retraining XGBoost heat predictor on real ERA5 data")
|
| 162 |
-
log.info("=" * 60)
|
| 163 |
-
|
| 164 |
-
from src.prediction.heat_forecast import HeatWavePredictor, CITY_THRESHOLDS, CITY_CLIMATE
|
| 165 |
-
from src.prediction.lstm_model import CITY_CLIMATE as _ # ensure import works
|
| 166 |
-
|
| 167 |
-
import xgboost as xgb
|
| 168 |
-
|
| 169 |
-
# We replicate the training logic from HeatWavePredictor.train() but
|
| 170 |
-
# use real ERA5 temps/humidity instead of synthetic series.
|
| 171 |
-
|
| 172 |
-
all_X = []
|
| 173 |
-
all_y = []
|
| 174 |
-
|
| 175 |
-
for zone in ZONES:
|
| 176 |
-
zid = zone.zone_id
|
| 177 |
-
readings = data.get(zid, [])
|
| 178 |
-
if len(readings) < 40:
|
| 179 |
-
log.warning("Zone %s has only %d readings, skipping for XGBoost training", zid, len(readings))
|
| 180 |
-
continue
|
| 181 |
-
|
| 182 |
-
city = zone.city
|
| 183 |
-
threshold = CITY_THRESHOLDS.get(city, 33.0)
|
| 184 |
-
|
| 185 |
-
# Extract time series from real data
|
| 186 |
-
temps = []
|
| 187 |
-
humidity = []
|
| 188 |
-
for r in readings:
|
| 189 |
-
t = r.temp_max_c
|
| 190 |
-
h = r.humidity_pct
|
| 191 |
-
if t is None:
|
| 192 |
-
continue
|
| 193 |
-
temps.append(t)
|
| 194 |
-
humidity.append(h if h is not None else 65.0)
|
| 195 |
-
|
| 196 |
-
n_days = len(temps)
|
| 197 |
-
if n_days < 40:
|
| 198 |
-
log.warning("Zone %s has only %d valid temp readings, skipping", zid, n_days)
|
| 199 |
-
continue
|
| 200 |
-
|
| 201 |
-
# Compute WBGT series
|
| 202 |
-
wbgt_series = [calculate_wbgt(t, h) for t, h in zip(temps, humidity)]
|
| 203 |
-
|
| 204 |
-
# Labels: trigger within next 7 days (2+ consecutive above threshold)
|
| 205 |
-
labels = [0] * n_days
|
| 206 |
-
for day in range(n_days - 7):
|
| 207 |
-
window = temps[day + 1:day + 8]
|
| 208 |
-
consec = 0
|
| 209 |
-
triggered = False
|
| 210 |
-
for t in window:
|
| 211 |
-
if t > threshold:
|
| 212 |
-
consec += 1
|
| 213 |
-
if consec >= 2:
|
| 214 |
-
triggered = True
|
| 215 |
-
break
|
| 216 |
-
else:
|
| 217 |
-
consec = 0
|
| 218 |
-
labels[day] = 1 if triggered else 0
|
| 219 |
-
|
| 220 |
-
# Vulnerability encoding
|
| 221 |
-
vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 222 |
-
zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
|
| 223 |
-
|
| 224 |
-
rng = np.random.default_rng(42)
|
| 225 |
-
|
| 226 |
-
# Build features (need 30-day lookback)
|
| 227 |
-
for day in range(30, n_days - 7):
|
| 228 |
-
t_window = temps[day - 30:day + 1]
|
| 229 |
-
h_window = humidity[day - 30:day + 1]
|
| 230 |
-
w_window = wbgt_series[day - 30:day + 1]
|
| 231 |
-
|
| 232 |
-
current_temp = t_window[-1]
|
| 233 |
-
current_wbgt = w_window[-1]
|
| 234 |
-
current_humidity = h_window[-1]
|
| 235 |
-
|
| 236 |
-
# Trend: slope of last 7 days
|
| 237 |
-
x7 = np.arange(7, dtype=np.float64)
|
| 238 |
-
y7 = np.array(t_window[-7:], dtype=np.float64)
|
| 239 |
-
temp_trend = float(np.polyfit(x7, y7, 1)[0])
|
| 240 |
-
|
| 241 |
-
# Anomaly: current vs 30-day mean
|
| 242 |
-
temp_anomaly = current_temp - float(np.mean(t_window))
|
| 243 |
-
|
| 244 |
-
# Soil moisture proxy
|
| 245 |
-
soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
|
| 246 |
-
|
| 247 |
-
# Rolling error (use neutral prior for training data)
|
| 248 |
-
rolling_err = rng.uniform(0.1, 0.5)
|
| 249 |
-
|
| 250 |
-
# Day-of-year encoding (use day index within 365-day cycle)
|
| 251 |
-
doy = day % 365
|
| 252 |
-
doy_sin = np.sin(2 * np.pi * doy / 365.0)
|
| 253 |
-
doy_cos = np.cos(2 * np.pi * doy / 365.0)
|
| 254 |
-
|
| 255 |
-
# Random hour for variety
|
| 256 |
-
hour = rng.integers(6, 19)
|
| 257 |
-
hour_sin = np.sin(2 * np.pi * hour / 24.0)
|
| 258 |
-
hour_cos = np.cos(2 * np.pi * hour / 24.0)
|
| 259 |
-
|
| 260 |
-
row = [
|
| 261 |
-
current_temp,
|
| 262 |
-
current_wbgt,
|
| 263 |
-
current_humidity,
|
| 264 |
-
temp_trend,
|
| 265 |
-
temp_anomaly,
|
| 266 |
-
soil_proxy,
|
| 267 |
-
rolling_err,
|
| 268 |
-
doy_sin,
|
| 269 |
-
doy_cos,
|
| 270 |
-
hour_sin,
|
| 271 |
-
hour_cos,
|
| 272 |
-
zone_vuln,
|
| 273 |
-
]
|
| 274 |
-
|
| 275 |
-
all_X.append(row)
|
| 276 |
-
all_y.append(labels[day])
|
| 277 |
-
|
| 278 |
-
X = np.array(all_X, dtype=np.float32)
|
| 279 |
-
y = np.array(all_y, dtype=np.int32)
|
| 280 |
-
|
| 281 |
-
pos_rate = y.sum() / len(y) if len(y) > 0 else 0
|
| 282 |
-
log.info(
|
| 283 |
-
"XGBoost training data: %d samples, %.1f%% positive rate",
|
| 284 |
-
len(X), pos_rate * 100,
|
| 285 |
-
)
|
| 286 |
-
|
| 287 |
-
# Create a fresh predictor to get the model object, then retrain
|
| 288 |
-
predictor = HeatWavePredictor.__new__(HeatWavePredictor)
|
| 289 |
-
predictor.model_path = HeatWavePredictor.__init__.__defaults__[0] # fallback
|
| 290 |
-
from pathlib import Path
|
| 291 |
-
predictor.model_path = Path(__file__).resolve().parents[1] / "models" / "heat_predictor_xgb.json"
|
| 292 |
-
predictor._rolling_errors = []
|
| 293 |
-
|
| 294 |
-
model = xgb.XGBClassifier(
|
| 295 |
-
n_estimators=150,
|
| 296 |
-
max_depth=5,
|
| 297 |
-
learning_rate=0.1,
|
| 298 |
-
eval_metric="logloss",
|
| 299 |
-
random_state=42,
|
| 300 |
-
)
|
| 301 |
-
|
| 302 |
-
# Train/validation split (temporal: first 75% train, last 25% val)
|
| 303 |
-
split = int(len(X) * 0.75)
|
| 304 |
-
X_train, X_val = X[:split], X[split:]
|
| 305 |
-
y_train, y_val = y[:split], y[split:]
|
| 306 |
-
|
| 307 |
-
model.fit(
|
| 308 |
-
X_train, y_train,
|
| 309 |
-
eval_set=[(X_val, y_val)],
|
| 310 |
-
verbose=False,
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
# Evaluate on validation set
|
| 314 |
-
from sklearn.metrics import roc_auc_score, precision_score, recall_score
|
| 315 |
-
|
| 316 |
-
val_probs = model.predict_proba(X_val)[:, 1]
|
| 317 |
-
val_preds = (val_probs > 0.5).astype(int)
|
| 318 |
-
|
| 319 |
-
if len(set(y_val)) > 1:
|
| 320 |
-
auroc = roc_auc_score(y_val, val_probs)
|
| 321 |
-
precision = precision_score(y_val, val_preds, zero_division=0)
|
| 322 |
-
recall = recall_score(y_val, val_preds, zero_division=0)
|
| 323 |
-
else:
|
| 324 |
-
auroc, precision, recall = 0.5, 0.0, 0.0
|
| 325 |
-
|
| 326 |
-
print(f"\n--- XGBoost Results (real ERA5 data) ---")
|
| 327 |
-
print(f" Training samples: {len(X_train)}")
|
| 328 |
-
print(f" Validation samples: {len(X_val)}")
|
| 329 |
-
print(f" Positive rate: {pos_rate:.1%}")
|
| 330 |
-
print(f" Val AUROC: {auroc:.4f}")
|
| 331 |
-
print(f" Val Precision: {precision:.4f}")
|
| 332 |
-
print(f" Val Recall: {recall:.4f}")
|
| 333 |
-
|
| 334 |
-
# Save model
|
| 335 |
-
predictor.model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 336 |
-
model.save_model(str(predictor.model_path))
|
| 337 |
-
log.info("XGBoost model saved to %s", predictor.model_path)
|
| 338 |
-
|
| 339 |
-
return {
|
| 340 |
-
"train_samples": len(X_train),
|
| 341 |
-
"val_samples": len(X_val),
|
| 342 |
-
"positive_rate": round(pos_rate, 4),
|
| 343 |
-
"val_auroc": round(auroc, 4),
|
| 344 |
-
"val_precision": round(precision, 4),
|
| 345 |
-
"val_recall": round(recall, 4),
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
# ======================================================================
|
| 350 |
-
# Step 4: Retrain LSTM on real data
|
| 351 |
-
# ======================================================================
|
| 352 |
-
|
| 353 |
-
def retrain_lstm(data: dict[str, list[DailyReading]]):
|
| 354 |
-
log.info("=" * 60)
|
| 355 |
-
log.info("STEP 4: Retraining LSTM on real ERA5 data")
|
| 356 |
-
log.info("=" * 60)
|
| 357 |
-
|
| 358 |
-
from src.prediction.lstm_model import LSTMTrainer
|
| 359 |
-
|
| 360 |
-
# Convert ERA5 DailyReading objects into the format the LSTM trainer expects:
|
| 361 |
-
# dict of zone_id -> list of dicts with keys: temp_max_c, humidity_pct, wind_speed_ms, city
|
| 362 |
-
zone_readings = {}
|
| 363 |
-
for zone in ZONES:
|
| 364 |
-
zid = zone.zone_id
|
| 365 |
-
readings = data.get(zid, [])
|
| 366 |
-
days = []
|
| 367 |
-
for r in readings:
|
| 368 |
-
if r.temp_max_c is None:
|
| 369 |
-
continue
|
| 370 |
-
days.append({
|
| 371 |
-
"temp_max_c": r.temp_max_c,
|
| 372 |
-
"humidity_pct": r.humidity_pct if r.humidity_pct is not None else 65.0,
|
| 373 |
-
"wind_speed_ms": r.wind_speed_ms if r.wind_speed_ms is not None else 3.0,
|
| 374 |
-
"city": zone.city,
|
| 375 |
-
})
|
| 376 |
-
if len(days) > 30:
|
| 377 |
-
zone_readings[zid] = days
|
| 378 |
-
log.info("Zone %s: %d valid readings for LSTM", zid, len(days))
|
| 379 |
-
else:
|
| 380 |
-
log.warning("Zone %s: only %d valid readings, skipping LSTM", zid, len(days))
|
| 381 |
-
|
| 382 |
-
log.info("Training LSTM on %d zones", len(zone_readings))
|
| 383 |
-
trainer = LSTMTrainer(epochs=50, patience=5)
|
| 384 |
-
metrics = trainer.train(zone_readings)
|
| 385 |
-
|
| 386 |
-
print(f"\n--- LSTM Results (real ERA5 data) ---")
|
| 387 |
-
for k, v in metrics.items():
|
| 388 |
-
print(f" {k}: {v}")
|
| 389 |
-
|
| 390 |
-
return metrics
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
# ======================================================================
|
| 394 |
-
# Step 5: Verify UHI model with real ERA5 temps
|
| 395 |
-
# ======================================================================
|
| 396 |
-
|
| 397 |
-
def verify_uhi(data: dict[str, list[DailyReading]]):
|
| 398 |
-
log.info("=" * 60)
|
| 399 |
-
log.info("STEP 5: Verifying UHI model with real ERA5 temperatures")
|
| 400 |
-
log.info("=" * 60)
|
| 401 |
-
|
| 402 |
-
from src.downscaling.uhi_model import UHICorrector
|
| 403 |
-
|
| 404 |
-
corrector = UHICorrector()
|
| 405 |
-
|
| 406 |
-
results = {}
|
| 407 |
-
for zone in ZONES:
|
| 408 |
-
zid = zone.zone_id
|
| 409 |
-
readings = data.get(zid, [])
|
| 410 |
-
if not readings:
|
| 411 |
-
continue
|
| 412 |
-
|
| 413 |
-
# Use real ERA5 temps as grid baseline
|
| 414 |
-
real_temps = [r.temp_max_c for r in readings if r.temp_max_c is not None]
|
| 415 |
-
if not real_temps:
|
| 416 |
-
continue
|
| 417 |
-
|
| 418 |
-
# Sample a few real temps and apply UHI correction
|
| 419 |
-
sample_indices = np.linspace(0, len(real_temps) - 1, min(20, len(real_temps)), dtype=int)
|
| 420 |
-
deltas = []
|
| 421 |
-
corrected_temps = []
|
| 422 |
-
|
| 423 |
-
for idx in sample_indices:
|
| 424 |
-
grid_temp = real_temps[idx]
|
| 425 |
-
corrected, delta, conf = corrector.correct_temperature(zone, grid_temp, hour=14, month=1)
|
| 426 |
-
deltas.append(delta)
|
| 427 |
-
corrected_temps.append(corrected)
|
| 428 |
-
|
| 429 |
-
results[zid] = {
|
| 430 |
-
"city": zone.city,
|
| 431 |
-
"settlement": zone.settlement_type,
|
| 432 |
-
"mean_grid_temp": round(sum(real_temps) / len(real_temps), 1),
|
| 433 |
-
"mean_uhi_delta": round(sum(deltas) / len(deltas), 2),
|
| 434 |
-
"mean_corrected": round(sum(corrected_temps) / len(corrected_temps), 1),
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
-
print(f"\n--- UHI Verification with Real ERA5 Temps ---")
|
| 438 |
-
print(f"{'Zone':<12} {'City':<16} {'Type':<12} {'Grid T':>7} {'UHI +':>7} {'Corrected':>10}")
|
| 439 |
-
print("-" * 70)
|
| 440 |
-
for zid, r in results.items():
|
| 441 |
-
print(
|
| 442 |
-
f"{zid:<12} {r['city']:<16} {r['settlement']:<12} "
|
| 443 |
-
f"{r['mean_grid_temp']:>6.1f}C {r['mean_uhi_delta']:>+6.2f}C {r['mean_corrected']:>9.1f}C"
|
| 444 |
-
)
|
| 445 |
-
|
| 446 |
-
return results
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
# ======================================================================
|
| 450 |
-
# Main
|
| 451 |
-
# ======================================================================
|
| 452 |
-
|
| 453 |
-
def main():
|
| 454 |
-
t_start = time.time()
|
| 455 |
-
|
| 456 |
-
# Step 1: Fetch
|
| 457 |
-
data = fetch_data()
|
| 458 |
-
|
| 459 |
-
# Step 2: Validate
|
| 460 |
-
data_stats = validate_data(data)
|
| 461 |
-
|
| 462 |
-
# Step 3: XGBoost
|
| 463 |
-
xgb_metrics = retrain_xgboost(data)
|
| 464 |
-
|
| 465 |
-
# Step 4: LSTM
|
| 466 |
-
lstm_metrics = retrain_lstm(data)
|
| 467 |
-
|
| 468 |
-
# Step 5: UHI verification
|
| 469 |
-
uhi_results = verify_uhi(data)
|
| 470 |
-
|
| 471 |
-
total_time = time.time() - t_start
|
| 472 |
-
|
| 473 |
-
print("\n" + "=" * 60)
|
| 474 |
-
print("TRAINING COMPLETE")
|
| 475 |
-
print("=" * 60)
|
| 476 |
-
|
| 477 |
-
total_days = sum(
|
| 478 |
-
len([r for r in data.get(z.zone_id, []) if r.temp_max_c is not None])
|
| 479 |
-
for z in ZONES
|
| 480 |
-
)
|
| 481 |
-
print(f" Total real data points: {total_days} zone-days across {len(ZONES)} zones")
|
| 482 |
-
print(f" XGBoost val AUROC: {xgb_metrics['val_auroc']:.4f}")
|
| 483 |
-
print(f" LSTM val AUROC: {lstm_metrics.get('val_auroc', 'N/A')}")
|
| 484 |
-
print(f" LSTM epochs trained: {lstm_metrics.get('epochs_trained', 'N/A')}")
|
| 485 |
-
print(f" LSTM final val loss: {lstm_metrics.get('val_loss', 'N/A')}")
|
| 486 |
-
print(f" Total time: {total_time:.1f}s")
|
| 487 |
-
print()
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
if __name__ == "__main__":
|
| 491 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,660 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Train all ML models on real NASA POWER daily data.
|
| 3 |
-
|
| 4 |
-
Steps:
|
| 5 |
-
1. Fetch 2 years of NASA POWER data for all 20 zones (with caching)
|
| 6 |
-
2. Validate data quality (coverage, temp ranges, nulls)
|
| 7 |
-
3. Retrain LSTM heat predictor on real data
|
| 8 |
-
4. Retrain XGBoost heat predictor on real data
|
| 9 |
-
5. Verify UHI model still works (no retraining — literature-calibrated)
|
| 10 |
-
|
| 11 |
-
Usage:
|
| 12 |
-
python3 scripts/train_on_nasa_power.py
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
import sys
|
| 16 |
-
import os
|
| 17 |
-
import json
|
| 18 |
-
import time
|
| 19 |
-
import logging
|
| 20 |
-
from datetime import date, timedelta
|
| 21 |
-
from pathlib import Path
|
| 22 |
-
|
| 23 |
-
import numpy as np
|
| 24 |
-
import httpx
|
| 25 |
-
|
| 26 |
-
# Project root on sys.path
|
| 27 |
-
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 28 |
-
sys.path.insert(0, str(PROJECT_ROOT))
|
| 29 |
-
|
| 30 |
-
from config import ZONES, ZONE_MAP
|
| 31 |
-
from src.indexing.heat_index import calculate_wbgt
|
| 32 |
-
|
| 33 |
-
logging.basicConfig(
|
| 34 |
-
level=logging.INFO,
|
| 35 |
-
format="%(asctime)s %(name)s %(levelname)s %(message)s",
|
| 36 |
-
datefmt="%H:%M:%S",
|
| 37 |
-
)
|
| 38 |
-
log = logging.getLogger("train_nasa_power")
|
| 39 |
-
|
| 40 |
-
# Paths
|
| 41 |
-
CACHE_DIR = PROJECT_ROOT / "data" / "nasa_power_cache"
|
| 42 |
-
MODELS_DIR = PROJECT_ROOT / "models"
|
| 43 |
-
|
| 44 |
-
# NASA POWER config (from config.py)
|
| 45 |
-
NASA_POWER_URL = "https://power.larc.nasa.gov/api/temporal/daily/point"
|
| 46 |
-
NASA_POWER_PARAMS = ["T2M", "T2M_MAX", "T2M_MIN", "RH2M", "WS2M", "ALLSKY_SFC_SW_DWN"]
|
| 47 |
-
NASA_MISSING = -999.0
|
| 48 |
-
|
| 49 |
-
# Date range: 2 years ending ~yesterday (API has 2-day lag)
|
| 50 |
-
END_DATE = date(2026, 3, 29)
|
| 51 |
-
START_DATE = date(2024, 4, 1)
|
| 52 |
-
|
| 53 |
-
# Expected temp ranges per city (max daily temps, deg C)
|
| 54 |
-
EXPECTED_RANGES = {
|
| 55 |
-
"Nairobi": (18, 35),
|
| 56 |
-
"Dar es Salaam": (25, 40),
|
| 57 |
-
"Kampala": (22, 36),
|
| 58 |
-
"Kigali": (20, 34),
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# ======================================================================
|
| 63 |
-
# Step 1: Fetch NASA POWER data (with caching)
|
| 64 |
-
# ======================================================================
|
| 65 |
-
|
| 66 |
-
def _safe_val(val) -> float | None:
|
| 67 |
-
"""Return None for NASA's -999 missing sentinel."""
|
| 68 |
-
if val is None:
|
| 69 |
-
return None
|
| 70 |
-
try:
|
| 71 |
-
f = float(val)
|
| 72 |
-
return None if f == NASA_MISSING else round(f, 2)
|
| 73 |
-
except (ValueError, TypeError):
|
| 74 |
-
return None
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def fetch_zone_data(zone, start: date, end: date) -> list[dict]:
|
| 78 |
-
"""Fetch NASA POWER data for a single zone, using cache if available.
|
| 79 |
-
|
| 80 |
-
Returns list of dicts with keys:
|
| 81 |
-
date, temp_mean_c, temp_max_c, temp_min_c, humidity_pct,
|
| 82 |
-
wind_speed_ms, solar_radiation
|
| 83 |
-
"""
|
| 84 |
-
cache_file = CACHE_DIR / f"{zone.zone_id}.json"
|
| 85 |
-
|
| 86 |
-
# Check cache
|
| 87 |
-
if cache_file.exists():
|
| 88 |
-
with open(cache_file) as f:
|
| 89 |
-
cached = json.load(f)
|
| 90 |
-
# Verify cache covers our date range
|
| 91 |
-
if (cached.get("start_date") == start.isoformat()
|
| 92 |
-
and cached.get("end_date") == end.isoformat()
|
| 93 |
-
and len(cached.get("readings", [])) > 0):
|
| 94 |
-
log.info("Cache hit for zone %s (%d readings)", zone.zone_id, len(cached["readings"]))
|
| 95 |
-
return cached["readings"]
|
| 96 |
-
|
| 97 |
-
# Fetch from NASA POWER API
|
| 98 |
-
params = {
|
| 99 |
-
"parameters": ",".join(NASA_POWER_PARAMS),
|
| 100 |
-
"community": "AG",
|
| 101 |
-
"longitude": zone.longitude,
|
| 102 |
-
"latitude": zone.latitude,
|
| 103 |
-
"start": start.strftime("%Y%m%d"),
|
| 104 |
-
"end": end.strftime("%Y%m%d"),
|
| 105 |
-
"format": "JSON",
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
log.info("Fetching NASA POWER for zone %s (%s, %s) ...", zone.zone_id, zone.city, zone.name)
|
| 109 |
-
|
| 110 |
-
max_retries = 3
|
| 111 |
-
for attempt in range(max_retries):
|
| 112 |
-
try:
|
| 113 |
-
resp = httpx.get(NASA_POWER_URL, params=params, timeout=60.0)
|
| 114 |
-
|
| 115 |
-
if resp.status_code == 429:
|
| 116 |
-
wait = 15 * (attempt + 1)
|
| 117 |
-
log.warning(" Rate limited (429), backing off %ds ...", wait)
|
| 118 |
-
time.sleep(wait)
|
| 119 |
-
continue
|
| 120 |
-
|
| 121 |
-
resp.raise_for_status()
|
| 122 |
-
data = resp.json()
|
| 123 |
-
break
|
| 124 |
-
except httpx.TimeoutException:
|
| 125 |
-
wait = 10 * (attempt + 1)
|
| 126 |
-
log.warning(" Timeout (attempt %d/%d), retrying in %ds ...", attempt + 1, max_retries, wait)
|
| 127 |
-
time.sleep(wait)
|
| 128 |
-
except httpx.HTTPStatusError as exc:
|
| 129 |
-
if exc.response.status_code >= 500:
|
| 130 |
-
wait = 10 * (attempt + 1)
|
| 131 |
-
log.warning(" Server error %d (attempt %d/%d), retrying ...", exc.response.status_code, attempt + 1, max_retries)
|
| 132 |
-
time.sleep(wait)
|
| 133 |
-
else:
|
| 134 |
-
log.error(" HTTP %d: %s", exc.response.status_code, exc.response.text[:200])
|
| 135 |
-
return []
|
| 136 |
-
else:
|
| 137 |
-
log.error(" Failed after %d attempts for zone %s", max_retries, zone.zone_id)
|
| 138 |
-
return []
|
| 139 |
-
|
| 140 |
-
# Parse response
|
| 141 |
-
try:
|
| 142 |
-
props = data["properties"]["parameter"]
|
| 143 |
-
except (KeyError, TypeError):
|
| 144 |
-
log.error(" Unexpected response structure for zone %s", zone.zone_id)
|
| 145 |
-
return []
|
| 146 |
-
|
| 147 |
-
t2m_data = props.get("T2M", {})
|
| 148 |
-
t2m_max_data = props.get("T2M_MAX", {})
|
| 149 |
-
t2m_min_data = props.get("T2M_MIN", {})
|
| 150 |
-
rh2m_data = props.get("RH2M", {})
|
| 151 |
-
ws2m_data = props.get("WS2M", {})
|
| 152 |
-
solar_data = props.get("ALLSKY_SFC_SW_DWN", {})
|
| 153 |
-
|
| 154 |
-
all_days = sorted(t2m_data.keys())
|
| 155 |
-
readings = []
|
| 156 |
-
|
| 157 |
-
for day_str in all_days:
|
| 158 |
-
try:
|
| 159 |
-
formatted_date = f"{day_str[:4]}-{day_str[4:6]}-{day_str[6:8]}"
|
| 160 |
-
except (IndexError, TypeError):
|
| 161 |
-
continue
|
| 162 |
-
|
| 163 |
-
temp_mean = _safe_val(t2m_data.get(day_str))
|
| 164 |
-
temp_max = _safe_val(t2m_max_data.get(day_str))
|
| 165 |
-
temp_min = _safe_val(t2m_min_data.get(day_str))
|
| 166 |
-
humidity = _safe_val(rh2m_data.get(day_str))
|
| 167 |
-
wind = _safe_val(ws2m_data.get(day_str))
|
| 168 |
-
solar = _safe_val(solar_data.get(day_str))
|
| 169 |
-
|
| 170 |
-
# Skip days where key fields are all missing
|
| 171 |
-
if temp_max is None and temp_mean is None:
|
| 172 |
-
continue
|
| 173 |
-
|
| 174 |
-
readings.append({
|
| 175 |
-
"date": formatted_date,
|
| 176 |
-
"temp_mean_c": temp_mean,
|
| 177 |
-
"temp_max_c": temp_max if temp_max is not None else temp_mean,
|
| 178 |
-
"temp_min_c": temp_min,
|
| 179 |
-
"humidity_pct": humidity,
|
| 180 |
-
"wind_speed_ms": wind,
|
| 181 |
-
"solar_radiation": solar,
|
| 182 |
-
})
|
| 183 |
-
|
| 184 |
-
# Cache the results
|
| 185 |
-
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 186 |
-
cache_obj = {
|
| 187 |
-
"zone_id": zone.zone_id,
|
| 188 |
-
"city": zone.city,
|
| 189 |
-
"start_date": start.isoformat(),
|
| 190 |
-
"end_date": end.isoformat(),
|
| 191 |
-
"latitude": zone.latitude,
|
| 192 |
-
"longitude": zone.longitude,
|
| 193 |
-
"readings": readings,
|
| 194 |
-
}
|
| 195 |
-
with open(cache_file, "w") as f:
|
| 196 |
-
json.dump(cache_obj, f, indent=2)
|
| 197 |
-
|
| 198 |
-
log.info(" Zone %s: %d days fetched and cached", zone.zone_id, len(readings))
|
| 199 |
-
return readings
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def fetch_all_zones():
|
| 203 |
-
"""Fetch NASA POWER data for all zones with rate-limiting delay."""
|
| 204 |
-
log.info("=" * 60)
|
| 205 |
-
log.info("STEP 1: Fetching NASA POWER data for %d zones", len(ZONES))
|
| 206 |
-
log.info(" Date range: %s to %s (%d days)", START_DATE, END_DATE, (END_DATE - START_DATE).days + 1)
|
| 207 |
-
log.info("=" * 60)
|
| 208 |
-
|
| 209 |
-
all_data: dict[str, list[dict]] = {}
|
| 210 |
-
t0 = time.time()
|
| 211 |
-
|
| 212 |
-
for i, zone in enumerate(ZONES):
|
| 213 |
-
readings = fetch_zone_data(zone, START_DATE, END_DATE)
|
| 214 |
-
all_data[zone.zone_id] = readings
|
| 215 |
-
|
| 216 |
-
# Rate limiting delay between API calls (skip for cached results)
|
| 217 |
-
if i < len(ZONES) - 1:
|
| 218 |
-
time.sleep(0.5)
|
| 219 |
-
|
| 220 |
-
elapsed = time.time() - t0
|
| 221 |
-
total_readings = sum(len(v) for v in all_data.values())
|
| 222 |
-
zones_with_data = sum(1 for v in all_data.values() if v)
|
| 223 |
-
|
| 224 |
-
log.info(
|
| 225 |
-
"Fetch complete in %.1fs: %d/%d zones with data, %d total readings",
|
| 226 |
-
elapsed, zones_with_data, len(ZONES), total_readings,
|
| 227 |
-
)
|
| 228 |
-
|
| 229 |
-
return all_data
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
# ======================================================================
|
| 233 |
-
# Step 2: Validate data quality
|
| 234 |
-
# ======================================================================
|
| 235 |
-
|
| 236 |
-
def validate_data(data: dict[str, list[dict]]):
|
| 237 |
-
log.info("=" * 60)
|
| 238 |
-
log.info("STEP 2: Validating NASA POWER data quality")
|
| 239 |
-
log.info("=" * 60)
|
| 240 |
-
|
| 241 |
-
issues = []
|
| 242 |
-
stats = {}
|
| 243 |
-
|
| 244 |
-
for zone in ZONES:
|
| 245 |
-
zid = zone.zone_id
|
| 246 |
-
readings = data.get(zid, [])
|
| 247 |
-
|
| 248 |
-
if not readings:
|
| 249 |
-
issues.append(f"{zid}: NO DATA")
|
| 250 |
-
stats[zid] = {"days": 0, "issue": "no data"}
|
| 251 |
-
continue
|
| 252 |
-
|
| 253 |
-
temps = [r["temp_max_c"] for r in readings if r["temp_max_c"] is not None]
|
| 254 |
-
humids = [r["humidity_pct"] for r in readings if r["humidity_pct"] is not None]
|
| 255 |
-
winds = [r["wind_speed_ms"] for r in readings if r["wind_speed_ms"] is not None]
|
| 256 |
-
|
| 257 |
-
if not temps:
|
| 258 |
-
issues.append(f"{zid}: all temps are null")
|
| 259 |
-
stats[zid] = {"days": len(readings), "issue": "all null temps"}
|
| 260 |
-
continue
|
| 261 |
-
|
| 262 |
-
t_min, t_max = min(temps), max(temps)
|
| 263 |
-
t_mean = sum(temps) / len(temps)
|
| 264 |
-
|
| 265 |
-
# Check physical reasonableness
|
| 266 |
-
exp_lo, exp_hi = EXPECTED_RANGES.get(zone.city, (15, 42))
|
| 267 |
-
if t_min < exp_lo - 5 or t_max > exp_hi + 5:
|
| 268 |
-
issues.append(
|
| 269 |
-
f"{zid} ({zone.city}): temp range [{t_min:.1f}, {t_max:.1f}] "
|
| 270 |
-
f"outside expected [{exp_lo-5}, {exp_hi+5}]"
|
| 271 |
-
)
|
| 272 |
-
|
| 273 |
-
null_count = sum(1 for r in readings if r["temp_max_c"] is None)
|
| 274 |
-
|
| 275 |
-
stats[zid] = {
|
| 276 |
-
"days": len(readings),
|
| 277 |
-
"temp_days": len(temps),
|
| 278 |
-
"temp_min": round(t_min, 1),
|
| 279 |
-
"temp_max": round(t_max, 1),
|
| 280 |
-
"temp_mean": round(t_mean, 1),
|
| 281 |
-
"humidity_mean": round(sum(humids)/len(humids), 1) if humids else None,
|
| 282 |
-
"wind_mean": round(sum(winds)/len(winds), 1) if winds else None,
|
| 283 |
-
"null_temps": null_count,
|
| 284 |
-
}
|
| 285 |
-
|
| 286 |
-
# Print summary
|
| 287 |
-
print("\n--- NASA POWER Data Summary ---")
|
| 288 |
-
print(f"{'Zone':<12} {'City':<16} {'Days':>5} {'Temp min':>9} {'Temp max':>9} {'Temp mean':>10} {'Humidity':>9} {'Nulls':>6}")
|
| 289 |
-
print("-" * 90)
|
| 290 |
-
|
| 291 |
-
for zone in ZONES:
|
| 292 |
-
s = stats.get(zone.zone_id, {})
|
| 293 |
-
days = s.get("days", 0)
|
| 294 |
-
t_lo = s.get("temp_min", "N/A")
|
| 295 |
-
t_hi = s.get("temp_max", "N/A")
|
| 296 |
-
t_mn = s.get("temp_mean", "N/A")
|
| 297 |
-
hum = s.get("humidity_mean", "N/A")
|
| 298 |
-
nulls = s.get("null_temps", "N/A")
|
| 299 |
-
if isinstance(t_lo, float):
|
| 300 |
-
print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {t_lo:>9.1f} {t_hi:>9.1f} {t_mn:>10.1f} {hum:>9} {nulls:>6}")
|
| 301 |
-
else:
|
| 302 |
-
print(f"{zone.zone_id:<12} {zone.city:<16} {days:>5} {'N/A':>9} {'N/A':>9} {'N/A':>10} {'N/A':>9} {'N/A':>6}")
|
| 303 |
-
|
| 304 |
-
if issues:
|
| 305 |
-
print(f"\n ISSUES ({len(issues)}):")
|
| 306 |
-
for issue in issues:
|
| 307 |
-
print(f" - {issue}")
|
| 308 |
-
else:
|
| 309 |
-
print("\n No data quality issues found.")
|
| 310 |
-
|
| 311 |
-
zones_with_data = sum(1 for s in stats.values() if s.get("days", 0) > 0)
|
| 312 |
-
print(f"\n {zones_with_data}/{len(ZONES)} zones have data.\n")
|
| 313 |
-
|
| 314 |
-
return stats
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
# ======================================================================
|
| 318 |
-
# Step 3: Retrain LSTM on real NASA POWER data
|
| 319 |
-
# ======================================================================
|
| 320 |
-
|
| 321 |
-
def retrain_lstm(data: dict[str, list[dict]]):
|
| 322 |
-
log.info("=" * 60)
|
| 323 |
-
log.info("STEP 3: Retraining LSTM on real NASA POWER data")
|
| 324 |
-
log.info("=" * 60)
|
| 325 |
-
|
| 326 |
-
from src.prediction.lstm_model import LSTMTrainer
|
| 327 |
-
|
| 328 |
-
# Delete old model files
|
| 329 |
-
lstm_model_path = MODELS_DIR / "heat_lstm.pt"
|
| 330 |
-
lstm_norm_path = MODELS_DIR / "lstm_norm.json"
|
| 331 |
-
for p in [lstm_model_path, lstm_norm_path]:
|
| 332 |
-
if p.exists():
|
| 333 |
-
p.unlink()
|
| 334 |
-
log.info("Deleted old model: %s", p)
|
| 335 |
-
|
| 336 |
-
# Convert NASA POWER data to the format the LSTM trainer expects:
|
| 337 |
-
# dict of zone_id -> list of dicts with keys: temp_max_c, humidity_pct, wind_speed_ms, city
|
| 338 |
-
zone_readings = {}
|
| 339 |
-
for zone in ZONES:
|
| 340 |
-
zid = zone.zone_id
|
| 341 |
-
readings = data.get(zid, [])
|
| 342 |
-
days = []
|
| 343 |
-
for r in readings:
|
| 344 |
-
if r["temp_max_c"] is None:
|
| 345 |
-
continue
|
| 346 |
-
days.append({
|
| 347 |
-
"temp_max_c": r["temp_max_c"],
|
| 348 |
-
"humidity_pct": r["humidity_pct"] if r["humidity_pct"] is not None else 65.0,
|
| 349 |
-
"wind_speed_ms": r["wind_speed_ms"] if r["wind_speed_ms"] is not None else 3.0,
|
| 350 |
-
"city": zone.city,
|
| 351 |
-
})
|
| 352 |
-
if len(days) >= 22: # Need at least seq_len(14) + forecast_horizon(7) + 1
|
| 353 |
-
zone_readings[zid] = days
|
| 354 |
-
log.info("Zone %s: %d valid readings for LSTM", zid, len(days))
|
| 355 |
-
else:
|
| 356 |
-
log.warning("Zone %s: only %d valid readings, skipping LSTM", zid, len(days))
|
| 357 |
-
|
| 358 |
-
log.info("Training LSTM on %d zones with real NASA POWER data", len(zone_readings))
|
| 359 |
-
trainer = LSTMTrainer(epochs=50, patience=5)
|
| 360 |
-
metrics = trainer.train(zone_readings)
|
| 361 |
-
|
| 362 |
-
print(f"\n--- LSTM Results (real NASA POWER data) ---")
|
| 363 |
-
for k, v in metrics.items():
|
| 364 |
-
print(f" {k}: {v}")
|
| 365 |
-
|
| 366 |
-
return metrics
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
# ======================================================================
|
| 370 |
-
# Step 4: Retrain XGBoost heat predictor on real data
|
| 371 |
-
# ======================================================================
|
| 372 |
-
|
| 373 |
-
def retrain_xgboost(data: dict[str, list[dict]]):
|
| 374 |
-
log.info("=" * 60)
|
| 375 |
-
log.info("STEP 4: Retraining XGBoost heat predictor on real NASA POWER data")
|
| 376 |
-
log.info("=" * 60)
|
| 377 |
-
|
| 378 |
-
import xgboost as xgb
|
| 379 |
-
from src.prediction.lstm_model import CITY_THRESHOLDS
|
| 380 |
-
|
| 381 |
-
# Delete old model file
|
| 382 |
-
xgb_model_path = MODELS_DIR / "heat_predictor_xgb.json"
|
| 383 |
-
if xgb_model_path.exists():
|
| 384 |
-
xgb_model_path.unlink()
|
| 385 |
-
log.info("Deleted old XGBoost model: %s", xgb_model_path)
|
| 386 |
-
|
| 387 |
-
all_X = []
|
| 388 |
-
all_y = []
|
| 389 |
-
|
| 390 |
-
for zone in ZONES:
|
| 391 |
-
zid = zone.zone_id
|
| 392 |
-
readings = data.get(zid, [])
|
| 393 |
-
if len(readings) < 40:
|
| 394 |
-
log.warning("Zone %s has only %d readings, skipping for XGBoost training", zid, len(readings))
|
| 395 |
-
continue
|
| 396 |
-
|
| 397 |
-
city = zone.city
|
| 398 |
-
threshold = CITY_THRESHOLDS.get(city, 33.0)
|
| 399 |
-
|
| 400 |
-
# Extract time series, filtering nulls
|
| 401 |
-
temps = []
|
| 402 |
-
humidity = []
|
| 403 |
-
for r in readings:
|
| 404 |
-
t = r["temp_max_c"]
|
| 405 |
-
h = r["humidity_pct"]
|
| 406 |
-
if t is None:
|
| 407 |
-
continue
|
| 408 |
-
temps.append(t)
|
| 409 |
-
humidity.append(h if h is not None else 65.0)
|
| 410 |
-
|
| 411 |
-
n_days = len(temps)
|
| 412 |
-
if n_days < 40:
|
| 413 |
-
log.warning("Zone %s has only %d valid temp readings, skipping", zid, n_days)
|
| 414 |
-
continue
|
| 415 |
-
|
| 416 |
-
# Compute WBGT series
|
| 417 |
-
wbgt_series = [calculate_wbgt(t, h) for t, h in zip(temps, humidity)]
|
| 418 |
-
|
| 419 |
-
# Labels: trigger within next 7 days (2+ consecutive above threshold)
|
| 420 |
-
labels = [0] * n_days
|
| 421 |
-
for day in range(n_days - 7):
|
| 422 |
-
window = temps[day + 1:day + 8]
|
| 423 |
-
consec = 0
|
| 424 |
-
triggered = False
|
| 425 |
-
for t in window:
|
| 426 |
-
if t > threshold:
|
| 427 |
-
consec += 1
|
| 428 |
-
if consec >= 2:
|
| 429 |
-
triggered = True
|
| 430 |
-
break
|
| 431 |
-
else:
|
| 432 |
-
consec = 0
|
| 433 |
-
labels[day] = 1 if triggered else 0
|
| 434 |
-
|
| 435 |
-
# Vulnerability encoding
|
| 436 |
-
vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 437 |
-
zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
|
| 438 |
-
|
| 439 |
-
rng = np.random.default_rng(42)
|
| 440 |
-
|
| 441 |
-
# Build features (need 30-day lookback)
|
| 442 |
-
for day in range(30, n_days - 7):
|
| 443 |
-
t_window = temps[day - 30:day + 1]
|
| 444 |
-
h_window = humidity[day - 30:day + 1]
|
| 445 |
-
w_window = wbgt_series[day - 30:day + 1]
|
| 446 |
-
|
| 447 |
-
current_temp = t_window[-1]
|
| 448 |
-
current_wbgt = w_window[-1]
|
| 449 |
-
current_humidity = h_window[-1]
|
| 450 |
-
|
| 451 |
-
# Trend: slope of last 7 days
|
| 452 |
-
x7 = np.arange(7, dtype=np.float64)
|
| 453 |
-
y7 = np.array(t_window[-7:], dtype=np.float64)
|
| 454 |
-
temp_trend = float(np.polyfit(x7, y7, 1)[0])
|
| 455 |
-
|
| 456 |
-
# Anomaly: current vs 30-day mean
|
| 457 |
-
temp_anomaly = current_temp - float(np.mean(t_window))
|
| 458 |
-
|
| 459 |
-
# Soil moisture proxy
|
| 460 |
-
soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
|
| 461 |
-
|
| 462 |
-
# Rolling error (neutral prior for training data)
|
| 463 |
-
rolling_err = rng.uniform(0.1, 0.5)
|
| 464 |
-
|
| 465 |
-
# Day-of-year encoding
|
| 466 |
-
doy = day % 365
|
| 467 |
-
doy_sin = np.sin(2 * np.pi * doy / 365.0)
|
| 468 |
-
doy_cos = np.cos(2 * np.pi * doy / 365.0)
|
| 469 |
-
|
| 470 |
-
# Random hour for variety
|
| 471 |
-
hour = rng.integers(6, 19)
|
| 472 |
-
hour_sin = np.sin(2 * np.pi * hour / 24.0)
|
| 473 |
-
hour_cos = np.cos(2 * np.pi * hour / 24.0)
|
| 474 |
-
|
| 475 |
-
row = [
|
| 476 |
-
current_temp,
|
| 477 |
-
current_wbgt,
|
| 478 |
-
current_humidity,
|
| 479 |
-
temp_trend,
|
| 480 |
-
temp_anomaly,
|
| 481 |
-
soil_proxy,
|
| 482 |
-
rolling_err,
|
| 483 |
-
doy_sin,
|
| 484 |
-
doy_cos,
|
| 485 |
-
hour_sin,
|
| 486 |
-
hour_cos,
|
| 487 |
-
zone_vuln,
|
| 488 |
-
]
|
| 489 |
-
|
| 490 |
-
all_X.append(row)
|
| 491 |
-
all_y.append(labels[day])
|
| 492 |
-
|
| 493 |
-
X = np.array(all_X, dtype=np.float32)
|
| 494 |
-
y = np.array(all_y, dtype=np.int32)
|
| 495 |
-
|
| 496 |
-
pos_rate = y.sum() / len(y) if len(y) > 0 else 0
|
| 497 |
-
log.info(
|
| 498 |
-
"XGBoost training data: %d samples, %.1f%% positive rate",
|
| 499 |
-
len(X), pos_rate * 100,
|
| 500 |
-
)
|
| 501 |
-
|
| 502 |
-
model = xgb.XGBClassifier(
|
| 503 |
-
n_estimators=150,
|
| 504 |
-
max_depth=5,
|
| 505 |
-
learning_rate=0.1,
|
| 506 |
-
eval_metric="logloss",
|
| 507 |
-
random_state=42,
|
| 508 |
-
)
|
| 509 |
-
|
| 510 |
-
# Train/validation split (temporal: first 75% train, last 25% val)
|
| 511 |
-
split = int(len(X) * 0.75)
|
| 512 |
-
X_train, X_val = X[:split], X[split:]
|
| 513 |
-
y_train, y_val = y[:split], y[split:]
|
| 514 |
-
|
| 515 |
-
model.fit(
|
| 516 |
-
X_train, y_train,
|
| 517 |
-
eval_set=[(X_val, y_val)],
|
| 518 |
-
verbose=False,
|
| 519 |
-
)
|
| 520 |
-
|
| 521 |
-
# Evaluate on validation set
|
| 522 |
-
from sklearn.metrics import roc_auc_score, precision_score, recall_score
|
| 523 |
-
|
| 524 |
-
val_probs = model.predict_proba(X_val)[:, 1]
|
| 525 |
-
val_preds = (val_probs > 0.5).astype(int)
|
| 526 |
-
|
| 527 |
-
if len(set(y_val)) > 1:
|
| 528 |
-
auroc = roc_auc_score(y_val, val_probs)
|
| 529 |
-
precision = precision_score(y_val, val_preds, zero_division=0)
|
| 530 |
-
recall = recall_score(y_val, val_preds, zero_division=0)
|
| 531 |
-
else:
|
| 532 |
-
auroc, precision, recall = 0.5, 0.0, 0.0
|
| 533 |
-
|
| 534 |
-
print(f"\n--- XGBoost Results (real NASA POWER data) ---")
|
| 535 |
-
print(f" Training samples: {len(X_train)}")
|
| 536 |
-
print(f" Validation samples: {len(X_val)}")
|
| 537 |
-
print(f" Positive rate: {pos_rate:.1%}")
|
| 538 |
-
print(f" Val AUROC: {auroc:.4f}")
|
| 539 |
-
print(f" Val Precision: {precision:.4f}")
|
| 540 |
-
print(f" Val Recall: {recall:.4f}")
|
| 541 |
-
|
| 542 |
-
# Save model
|
| 543 |
-
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
| 544 |
-
model.save_model(str(xgb_model_path))
|
| 545 |
-
log.info("XGBoost model saved to %s", xgb_model_path)
|
| 546 |
-
|
| 547 |
-
return {
|
| 548 |
-
"train_samples": len(X_train),
|
| 549 |
-
"val_samples": len(X_val),
|
| 550 |
-
"positive_rate": round(pos_rate, 4),
|
| 551 |
-
"val_auroc": round(auroc, 4),
|
| 552 |
-
"val_precision": round(precision, 4),
|
| 553 |
-
"val_recall": round(recall, 4),
|
| 554 |
-
}
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
# ======================================================================
|
| 558 |
-
# Step 5: Verify UHI model (no retraining)
|
| 559 |
-
# ======================================================================
|
| 560 |
-
|
| 561 |
-
def verify_uhi(data: dict[str, list[dict]]):
|
| 562 |
-
log.info("=" * 60)
|
| 563 |
-
log.info("STEP 5: Verifying UHI model with real NASA POWER temperatures")
|
| 564 |
-
log.info(" (UHI model keeps literature-calibrated synthetic training)")
|
| 565 |
-
log.info("=" * 60)
|
| 566 |
-
|
| 567 |
-
from src.downscaling.uhi_model import UHICorrector
|
| 568 |
-
|
| 569 |
-
corrector = UHICorrector()
|
| 570 |
-
|
| 571 |
-
results = {}
|
| 572 |
-
for zone in ZONES:
|
| 573 |
-
zid = zone.zone_id
|
| 574 |
-
readings = data.get(zid, [])
|
| 575 |
-
if not readings:
|
| 576 |
-
continue
|
| 577 |
-
|
| 578 |
-
real_temps = [r["temp_max_c"] for r in readings if r["temp_max_c"] is not None]
|
| 579 |
-
if not real_temps:
|
| 580 |
-
continue
|
| 581 |
-
|
| 582 |
-
# Sample some real temps and apply UHI correction
|
| 583 |
-
sample_indices = np.linspace(0, len(real_temps) - 1, min(20, len(real_temps)), dtype=int)
|
| 584 |
-
deltas = []
|
| 585 |
-
corrected_temps = []
|
| 586 |
-
|
| 587 |
-
for idx in sample_indices:
|
| 588 |
-
grid_temp = real_temps[idx]
|
| 589 |
-
corrected, delta, conf = corrector.correct_temperature(zone, grid_temp, hour=14, month=1)
|
| 590 |
-
deltas.append(delta)
|
| 591 |
-
corrected_temps.append(corrected)
|
| 592 |
-
|
| 593 |
-
results[zid] = {
|
| 594 |
-
"city": zone.city,
|
| 595 |
-
"settlement": zone.settlement_type,
|
| 596 |
-
"mean_grid_temp": round(sum(real_temps) / len(real_temps), 1),
|
| 597 |
-
"mean_uhi_delta": round(sum(deltas) / len(deltas), 2),
|
| 598 |
-
"mean_corrected": round(sum(corrected_temps) / len(corrected_temps), 1),
|
| 599 |
-
}
|
| 600 |
-
|
| 601 |
-
print(f"\n--- UHI Verification (literature-calibrated model + real NASA POWER temps) ---")
|
| 602 |
-
print(f"{'Zone':<12} {'City':<16} {'Type':<12} {'Grid T':>7} {'UHI +':>7} {'Corrected':>10}")
|
| 603 |
-
print("-" * 70)
|
| 604 |
-
for zid, r in results.items():
|
| 605 |
-
print(
|
| 606 |
-
f"{zid:<12} {r['city']:<16} {r['settlement']:<12} "
|
| 607 |
-
f"{r['mean_grid_temp']:>6.1f}C {r['mean_uhi_delta']:>+6.2f}C {r['mean_corrected']:>9.1f}C"
|
| 608 |
-
)
|
| 609 |
-
|
| 610 |
-
return results
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
# ======================================================================
|
| 614 |
-
# Main
|
| 615 |
-
# ======================================================================
|
| 616 |
-
|
| 617 |
-
def main():
|
| 618 |
-
t_start = time.time()
|
| 619 |
-
|
| 620 |
-
# Step 1: Fetch
|
| 621 |
-
data = fetch_all_zones()
|
| 622 |
-
|
| 623 |
-
# Step 2: Validate
|
| 624 |
-
data_stats = validate_data(data)
|
| 625 |
-
|
| 626 |
-
# Step 3: LSTM
|
| 627 |
-
lstm_metrics = retrain_lstm(data)
|
| 628 |
-
|
| 629 |
-
# Step 4: XGBoost
|
| 630 |
-
xgb_metrics = retrain_xgboost(data)
|
| 631 |
-
|
| 632 |
-
# Step 5: UHI verification
|
| 633 |
-
uhi_results = verify_uhi(data)
|
| 634 |
-
|
| 635 |
-
total_time = time.time() - t_start
|
| 636 |
-
|
| 637 |
-
print("\n" + "=" * 60)
|
| 638 |
-
print("TRAINING COMPLETE (NASA POWER real data)")
|
| 639 |
-
print("=" * 60)
|
| 640 |
-
|
| 641 |
-
total_days = sum(
|
| 642 |
-
len([r for r in data.get(z.zone_id, []) if r["temp_max_c"] is not None])
|
| 643 |
-
for z in ZONES
|
| 644 |
-
)
|
| 645 |
-
print(f" Data source: NASA POWER daily")
|
| 646 |
-
print(f" Date range: {START_DATE} to {END_DATE}")
|
| 647 |
-
print(f" Total real data points: {total_days} zone-days across {len(ZONES)} zones")
|
| 648 |
-
print(f" Avg days per zone: {total_days / len(ZONES):.0f}")
|
| 649 |
-
print(f" LSTM val AUROC: {lstm_metrics.get('val_auroc', 'N/A')}")
|
| 650 |
-
print(f" LSTM epochs trained: {lstm_metrics.get('epochs_trained', 'N/A')}")
|
| 651 |
-
print(f" LSTM val loss: {lstm_metrics.get('val_loss', 'N/A')}")
|
| 652 |
-
print(f" XGBoost val AUROC: {xgb_metrics['val_auroc']:.4f}")
|
| 653 |
-
print(f" XGBoost val Precision: {xgb_metrics['val_precision']:.4f}")
|
| 654 |
-
print(f" XGBoost val Recall: {xgb_metrics['val_recall']:.4f}")
|
| 655 |
-
print(f" Total time: {total_time:.1f}s")
|
| 656 |
-
print()
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
if __name__ == "__main__":
|
| 660 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
File without changes
|
|
@@ -1,318 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Notification delivery module.
|
| 3 |
-
|
| 4 |
-
Sends trigger explanations to policyholders via console (demo),
|
| 5 |
-
SMS (Twilio), or WhatsApp (Twilio). All senders implement a common
|
| 6 |
-
async interface and return a DeliveryResult.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import asyncio
|
| 12 |
-
import logging
|
| 13 |
-
import os
|
| 14 |
-
import time
|
| 15 |
-
from abc import ABC, abstractmethod
|
| 16 |
-
from dataclasses import dataclass, field
|
| 17 |
-
from datetime import datetime, timezone
|
| 18 |
-
from typing import Optional, Sequence
|
| 19 |
-
|
| 20 |
-
log = logging.getLogger(__name__)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
# ── Data containers ──────────────────────────────────────────────────────
|
| 24 |
-
|
| 25 |
-
@dataclass
|
| 26 |
-
class DeliveryResult:
|
| 27 |
-
"""Outcome of a single notification delivery attempt."""
|
| 28 |
-
|
| 29 |
-
status: str # "sent", "failed", "dry_run"
|
| 30 |
-
channel: str # "console", "sms", "whatsapp"
|
| 31 |
-
recipient: str
|
| 32 |
-
message_preview: str # first 120 chars
|
| 33 |
-
timestamp: str = field(
|
| 34 |
-
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
| 35 |
-
)
|
| 36 |
-
cost_estimate: float = 0.0 # estimated cost in USD
|
| 37 |
-
error: str = ""
|
| 38 |
-
message_sid: str = "" # Twilio message SID if applicable
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# ── Base sender ──────────────────────────────────────────────────────────
|
| 42 |
-
|
| 43 |
-
class BaseSender(ABC):
|
| 44 |
-
"""Common interface for all notification channels."""
|
| 45 |
-
|
| 46 |
-
@abstractmethod
|
| 47 |
-
async def send(
|
| 48 |
-
self, recipient: str, message: str, channel: str = ""
|
| 49 |
-
) -> DeliveryResult:
|
| 50 |
-
"""Send a message to a single recipient."""
|
| 51 |
-
...
|
| 52 |
-
|
| 53 |
-
async def send_batch(
|
| 54 |
-
self,
|
| 55 |
-
recipients: Sequence[str],
|
| 56 |
-
message: str,
|
| 57 |
-
channel: str = "",
|
| 58 |
-
rate_limit: float = 1.0,
|
| 59 |
-
) -> list[DeliveryResult]:
|
| 60 |
-
"""
|
| 61 |
-
Send the same message to multiple recipients with rate limiting.
|
| 62 |
-
|
| 63 |
-
Args:
|
| 64 |
-
recipients: Phone numbers or identifiers.
|
| 65 |
-
message: The notification text.
|
| 66 |
-
channel: Override channel name.
|
| 67 |
-
rate_limit: Minimum seconds between sends (Twilio default: 1/sec).
|
| 68 |
-
"""
|
| 69 |
-
results: list[DeliveryResult] = []
|
| 70 |
-
for i, recipient in enumerate(recipients):
|
| 71 |
-
result = await self.send(recipient, message, channel)
|
| 72 |
-
results.append(result)
|
| 73 |
-
# Rate limiting between sends (skip after last)
|
| 74 |
-
if i < len(recipients) - 1 and rate_limit > 0:
|
| 75 |
-
await asyncio.sleep(rate_limit)
|
| 76 |
-
return results
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
# ── Console sender (demo mode) ───────────────────────────────────────────
|
| 80 |
-
|
| 81 |
-
class ConsoleSender(BaseSender):
|
| 82 |
-
"""Prints notifications to stdout. Default for demo and testing."""
|
| 83 |
-
|
| 84 |
-
async def send(
|
| 85 |
-
self, recipient: str, message: str, channel: str = "console"
|
| 86 |
-
) -> DeliveryResult:
|
| 87 |
-
preview = message[:120] + ("..." if len(message) > 120 else "")
|
| 88 |
-
log.info("[CONSOLE] To: %s | %s", recipient, preview)
|
| 89 |
-
print(f"\n{'='*60}")
|
| 90 |
-
print(f" NOTIFICATION — {channel or 'console'}")
|
| 91 |
-
print(f" To: {recipient}")
|
| 92 |
-
print(f"{'='*60}")
|
| 93 |
-
print(f" {message}")
|
| 94 |
-
print(f"{'='*60}\n")
|
| 95 |
-
return DeliveryResult(
|
| 96 |
-
status="dry_run",
|
| 97 |
-
channel="console",
|
| 98 |
-
recipient=recipient,
|
| 99 |
-
message_preview=preview,
|
| 100 |
-
cost_estimate=0.0,
|
| 101 |
-
)
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
# ── Twilio SMS sender ───────────────────────────────────────────────────
|
| 105 |
-
|
| 106 |
-
class TwilioSender(BaseSender):
|
| 107 |
-
"""
|
| 108 |
-
Sends SMS via Twilio.
|
| 109 |
-
|
| 110 |
-
Requires environment variables:
|
| 111 |
-
TWILIO_ACCOUNT_SID
|
| 112 |
-
TWILIO_AUTH_TOKEN
|
| 113 |
-
TWILIO_FROM_NUMBER (E.164 format, e.g. +15551234567)
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
-
def __init__(
|
| 117 |
-
self,
|
| 118 |
-
account_sid: Optional[str] = None,
|
| 119 |
-
auth_token: Optional[str] = None,
|
| 120 |
-
from_number: Optional[str] = None,
|
| 121 |
-
):
|
| 122 |
-
self.account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID", "")
|
| 123 |
-
self.auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN", "")
|
| 124 |
-
self.from_number = from_number or os.environ.get("TWILIO_FROM_NUMBER", "")
|
| 125 |
-
self._client = None
|
| 126 |
-
|
| 127 |
-
def _get_client(self):
|
| 128 |
-
"""Lazy-init Twilio client."""
|
| 129 |
-
if self._client is None:
|
| 130 |
-
if not all([self.account_sid, self.auth_token]):
|
| 131 |
-
raise RuntimeError(
|
| 132 |
-
"Twilio credentials not configured. Set TWILIO_ACCOUNT_SID "
|
| 133 |
-
"and TWILIO_AUTH_TOKEN environment variables."
|
| 134 |
-
)
|
| 135 |
-
from twilio.rest import Client
|
| 136 |
-
self._client = Client(self.account_sid, self.auth_token)
|
| 137 |
-
return self._client
|
| 138 |
-
|
| 139 |
-
async def send(
|
| 140 |
-
self, recipient: str, message: str, channel: str = "sms"
|
| 141 |
-
) -> DeliveryResult:
|
| 142 |
-
preview = message[:120] + ("..." if len(message) > 120 else "")
|
| 143 |
-
|
| 144 |
-
# Truncate SMS to 1600 chars (Twilio limit for long SMS)
|
| 145 |
-
sms_body = message[:1600]
|
| 146 |
-
|
| 147 |
-
try:
|
| 148 |
-
client = self._get_client()
|
| 149 |
-
# Twilio client is synchronous — run in executor
|
| 150 |
-
loop = asyncio.get_event_loop()
|
| 151 |
-
twilio_msg = await loop.run_in_executor(
|
| 152 |
-
None,
|
| 153 |
-
lambda: client.messages.create(
|
| 154 |
-
body=sms_body,
|
| 155 |
-
from_=self.from_number,
|
| 156 |
-
to=recipient,
|
| 157 |
-
),
|
| 158 |
-
)
|
| 159 |
-
log.info(
|
| 160 |
-
"[SMS] Sent to %s | SID: %s | Status: %s",
|
| 161 |
-
recipient, twilio_msg.sid, twilio_msg.status,
|
| 162 |
-
)
|
| 163 |
-
return DeliveryResult(
|
| 164 |
-
status="sent",
|
| 165 |
-
channel="sms",
|
| 166 |
-
recipient=recipient,
|
| 167 |
-
message_preview=preview,
|
| 168 |
-
cost_estimate=_estimate_sms_cost(message),
|
| 169 |
-
message_sid=twilio_msg.sid,
|
| 170 |
-
)
|
| 171 |
-
except Exception as exc:
|
| 172 |
-
log.error("[SMS] Failed to send to %s: %s", recipient, exc)
|
| 173 |
-
return DeliveryResult(
|
| 174 |
-
status="failed",
|
| 175 |
-
channel="sms",
|
| 176 |
-
recipient=recipient,
|
| 177 |
-
message_preview=preview,
|
| 178 |
-
error=str(exc),
|
| 179 |
-
)
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
# ── WhatsApp sender ──────────────────────────────────────────────────────
|
| 183 |
-
|
| 184 |
-
class WhatsAppSender(BaseSender):
|
| 185 |
-
"""
|
| 186 |
-
Sends WhatsApp messages via Twilio WhatsApp Business API.
|
| 187 |
-
|
| 188 |
-
Uses the same Twilio credentials as SMS but prefixes the from number
|
| 189 |
-
with 'whatsapp:'.
|
| 190 |
-
"""
|
| 191 |
-
|
| 192 |
-
def __init__(
|
| 193 |
-
self,
|
| 194 |
-
account_sid: Optional[str] = None,
|
| 195 |
-
auth_token: Optional[str] = None,
|
| 196 |
-
from_number: Optional[str] = None,
|
| 197 |
-
):
|
| 198 |
-
self.account_sid = account_sid or os.environ.get("TWILIO_ACCOUNT_SID", "")
|
| 199 |
-
self.auth_token = auth_token or os.environ.get("TWILIO_AUTH_TOKEN", "")
|
| 200 |
-
self.from_number = from_number or os.environ.get("TWILIO_FROM_NUMBER", "")
|
| 201 |
-
self._client = None
|
| 202 |
-
|
| 203 |
-
def _get_client(self):
|
| 204 |
-
"""Lazy-init Twilio client."""
|
| 205 |
-
if self._client is None:
|
| 206 |
-
if not all([self.account_sid, self.auth_token]):
|
| 207 |
-
raise RuntimeError(
|
| 208 |
-
"Twilio credentials not configured. Set TWILIO_ACCOUNT_SID "
|
| 209 |
-
"and TWILIO_AUTH_TOKEN environment variables."
|
| 210 |
-
)
|
| 211 |
-
from twilio.rest import Client
|
| 212 |
-
self._client = Client(self.account_sid, self.auth_token)
|
| 213 |
-
return self._client
|
| 214 |
-
|
| 215 |
-
async def send(
|
| 216 |
-
self, recipient: str, message: str, channel: str = "whatsapp"
|
| 217 |
-
) -> DeliveryResult:
|
| 218 |
-
preview = message[:120] + ("..." if len(message) > 120 else "")
|
| 219 |
-
|
| 220 |
-
# WhatsApp supports longer messages (up to 4096 chars)
|
| 221 |
-
wa_body = message[:4096]
|
| 222 |
-
|
| 223 |
-
# Ensure whatsapp: prefix on both numbers
|
| 224 |
-
wa_from = (
|
| 225 |
-
f"whatsapp:{self.from_number}"
|
| 226 |
-
if not self.from_number.startswith("whatsapp:")
|
| 227 |
-
else self.from_number
|
| 228 |
-
)
|
| 229 |
-
wa_to = (
|
| 230 |
-
f"whatsapp:{recipient}"
|
| 231 |
-
if not recipient.startswith("whatsapp:")
|
| 232 |
-
else recipient
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
try:
|
| 236 |
-
client = self._get_client()
|
| 237 |
-
loop = asyncio.get_event_loop()
|
| 238 |
-
twilio_msg = await loop.run_in_executor(
|
| 239 |
-
None,
|
| 240 |
-
lambda: client.messages.create(
|
| 241 |
-
body=wa_body,
|
| 242 |
-
from_=wa_from,
|
| 243 |
-
to=wa_to,
|
| 244 |
-
),
|
| 245 |
-
)
|
| 246 |
-
log.info(
|
| 247 |
-
"[WhatsApp] Sent to %s | SID: %s | Status: %s",
|
| 248 |
-
recipient, twilio_msg.sid, twilio_msg.status,
|
| 249 |
-
)
|
| 250 |
-
return DeliveryResult(
|
| 251 |
-
status="sent",
|
| 252 |
-
channel="whatsapp",
|
| 253 |
-
recipient=recipient,
|
| 254 |
-
message_preview=preview,
|
| 255 |
-
cost_estimate=_estimate_whatsapp_cost(),
|
| 256 |
-
message_sid=twilio_msg.sid,
|
| 257 |
-
)
|
| 258 |
-
except Exception as exc:
|
| 259 |
-
log.error("[WhatsApp] Failed to send to %s: %s", recipient, exc)
|
| 260 |
-
return DeliveryResult(
|
| 261 |
-
status="failed",
|
| 262 |
-
channel="whatsapp",
|
| 263 |
-
recipient=recipient,
|
| 264 |
-
message_preview=preview,
|
| 265 |
-
error=str(exc),
|
| 266 |
-
)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
# ── Cost estimation ──────────────────────────────────────────────────────
|
| 270 |
-
|
| 271 |
-
def _estimate_sms_cost(message: str) -> float:
|
| 272 |
-
"""Estimate SMS cost in USD. Twilio Kenya rate ~ $0.0475/segment."""
|
| 273 |
-
# SMS segments: 160 chars for GSM-7, 70 for UCS-2 (Unicode)
|
| 274 |
-
has_unicode = any(ord(c) > 127 for c in message)
|
| 275 |
-
segment_size = 70 if has_unicode else 160
|
| 276 |
-
segments = max(1, (len(message) + segment_size - 1) // segment_size)
|
| 277 |
-
return round(segments * 0.0475, 4)
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
def _estimate_whatsapp_cost() -> float:
|
| 281 |
-
"""Estimate WhatsApp cost in USD. Twilio WhatsApp ~ $0.005/msg + template fees."""
|
| 282 |
-
return 0.005
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
# ── Sender factory ───────────────────────────────────────────────────────
|
| 286 |
-
|
| 287 |
-
def create_sender(channel: str = "console") -> BaseSender:
|
| 288 |
-
"""
|
| 289 |
-
Factory to create the appropriate sender.
|
| 290 |
-
|
| 291 |
-
Args:
|
| 292 |
-
channel: One of "console", "sms", "whatsapp".
|
| 293 |
-
"""
|
| 294 |
-
if channel == "sms":
|
| 295 |
-
return TwilioSender()
|
| 296 |
-
elif channel == "whatsapp":
|
| 297 |
-
return WhatsAppSender()
|
| 298 |
-
else:
|
| 299 |
-
return ConsoleSender()
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
async def send_zone_notifications(
|
| 303 |
-
recipients: Sequence[str],
|
| 304 |
-
message: str,
|
| 305 |
-
channel: str = "console",
|
| 306 |
-
rate_limit: float = 1.0,
|
| 307 |
-
) -> list[DeliveryResult]:
|
| 308 |
-
"""
|
| 309 |
-
Convenience function: send the same notification to all recipients in a zone.
|
| 310 |
-
|
| 311 |
-
Args:
|
| 312 |
-
recipients: List of phone numbers.
|
| 313 |
-
message: Notification text.
|
| 314 |
-
channel: "console", "sms", or "whatsapp".
|
| 315 |
-
rate_limit: Seconds between sends for Twilio rate limiting.
|
| 316 |
-
"""
|
| 317 |
-
sender = create_sender(channel)
|
| 318 |
-
return await sender.send_batch(recipients, message, channel, rate_limit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,557 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Heat wave prediction model for parametric insurance triggers.
|
| 3 |
-
|
| 4 |
-
XGBoost classifier that predicts the probability of a heat trigger
|
| 5 |
-
event (2+ consecutive days above city-adjusted threshold) occurring
|
| 6 |
-
within the next 7 days, given recent climate features.
|
| 7 |
-
|
| 8 |
-
Degrades gracefully:
|
| 9 |
-
full_model -> persistence -> climatology
|
| 10 |
-
|
| 11 |
-
References:
|
| 12 |
-
- Perkins-Kirkpatrick & Lewis (2020) heat wave definitions
|
| 13 |
-
- WHO/ILO occupational heat stress thresholds
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
from __future__ import annotations
|
| 17 |
-
|
| 18 |
-
from collections import deque
|
| 19 |
-
from pathlib import Path
|
| 20 |
-
|
| 21 |
-
import numpy as np
|
| 22 |
-
|
| 23 |
-
try:
|
| 24 |
-
import xgboost as xgb
|
| 25 |
-
except ImportError:
|
| 26 |
-
xgb = None
|
| 27 |
-
|
| 28 |
-
# Import shared constants from lstm_model (defined there to avoid circular imports)
|
| 29 |
-
from src.prediction.lstm_model import CITY_THRESHOLDS, CITY_CLIMATE
|
| 30 |
-
|
| 31 |
-
try:
|
| 32 |
-
from src.prediction.lstm_model import LSTMPredictor
|
| 33 |
-
_LSTM_AVAILABLE = True
|
| 34 |
-
except Exception:
|
| 35 |
-
_LSTM_AVAILABLE = False
|
| 36 |
-
|
| 37 |
-
FEATURE_NAMES = [
|
| 38 |
-
"current_temp",
|
| 39 |
-
"current_wbgt",
|
| 40 |
-
"current_humidity",
|
| 41 |
-
"temp_trend_7d",
|
| 42 |
-
"temp_anomaly_30d",
|
| 43 |
-
"soil_moisture_proxy",
|
| 44 |
-
"rolling_error",
|
| 45 |
-
"doy_sin",
|
| 46 |
-
"doy_cos",
|
| 47 |
-
"hour_sin",
|
| 48 |
-
"hour_cos",
|
| 49 |
-
"zone_vulnerability",
|
| 50 |
-
]
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def _resolve_model_path(model_path: str) -> Path:
|
| 54 |
-
p = Path(model_path)
|
| 55 |
-
if not p.is_absolute():
|
| 56 |
-
p = Path(__file__).resolve().parents[2] / model_path
|
| 57 |
-
return p
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
from src.indexing.heat_index import calculate_wbgt as _simple_wbgt
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class HeatWavePredictor:
|
| 64 |
-
"""XGBoost model: recent climate features -> trigger probability in 7 days."""
|
| 65 |
-
|
| 66 |
-
FEATURE_NAMES = FEATURE_NAMES
|
| 67 |
-
|
| 68 |
-
def __init__(self, model_path: str = "models/heat_predictor_xgb.json"):
|
| 69 |
-
if xgb is None:
|
| 70 |
-
raise ImportError(
|
| 71 |
-
"xgboost is required. Install with: pip install 'xgboost>=2.0.0'"
|
| 72 |
-
)
|
| 73 |
-
self.model_path = _resolve_model_path(model_path)
|
| 74 |
-
self.model: xgb.XGBClassifier | None = None
|
| 75 |
-
self._rolling_errors: deque = deque(maxlen=3)
|
| 76 |
-
self._load_or_train()
|
| 77 |
-
|
| 78 |
-
# Try loading LSTM for ensemble; auto-train on synthetic data if missing
|
| 79 |
-
self._lstm: object | None = None
|
| 80 |
-
if _LSTM_AVAILABLE:
|
| 81 |
-
try:
|
| 82 |
-
self._lstm = LSTMPredictor()
|
| 83 |
-
except FileNotFoundError:
|
| 84 |
-
self._train_lstm_synthetic()
|
| 85 |
-
except Exception:
|
| 86 |
-
self._lstm = None
|
| 87 |
-
|
| 88 |
-
# ------------------------------------------------------------------
|
| 89 |
-
# Public API
|
| 90 |
-
# ------------------------------------------------------------------
|
| 91 |
-
|
| 92 |
-
def predict(
|
| 93 |
-
self,
|
| 94 |
-
zone,
|
| 95 |
-
recent_temps: list[float],
|
| 96 |
-
recent_humidity: list[float],
|
| 97 |
-
recent_wbgt: list[float],
|
| 98 |
-
hour: int = 12,
|
| 99 |
-
) -> tuple[float, float, str]:
|
| 100 |
-
"""Predict probability of heat trigger within 7 days.
|
| 101 |
-
|
| 102 |
-
Args:
|
| 103 |
-
zone: UrbanZone instance.
|
| 104 |
-
recent_temps: Last 30 daily max temperatures (most recent last).
|
| 105 |
-
recent_humidity: Last 30 daily humidity values.
|
| 106 |
-
recent_wbgt: Last 30 daily WBGT values.
|
| 107 |
-
hour: Current hour for diurnal encoding.
|
| 108 |
-
|
| 109 |
-
Returns:
|
| 110 |
-
(probability, confidence, model_tier)
|
| 111 |
-
model_tier is one of: "ensemble", "full_model", "lstm_only",
|
| 112 |
-
"persistence", "climatology"
|
| 113 |
-
"""
|
| 114 |
-
xgb_prob, xgb_conf, xgb_ok = None, None, False
|
| 115 |
-
lstm_prob, lstm_conf, lstm_ok = None, None, False
|
| 116 |
-
|
| 117 |
-
# -- XGBoost prediction --
|
| 118 |
-
try:
|
| 119 |
-
features = self._build_features(
|
| 120 |
-
zone, recent_temps, recent_humidity, recent_wbgt, hour
|
| 121 |
-
)
|
| 122 |
-
xgb_prob = float(self.model.predict_proba(features)[0, 1])
|
| 123 |
-
xgb_conf = self._estimate_confidence(recent_temps, "full_model")
|
| 124 |
-
xgb_ok = True
|
| 125 |
-
except Exception:
|
| 126 |
-
pass
|
| 127 |
-
|
| 128 |
-
# -- LSTM prediction --
|
| 129 |
-
if self._lstm is not None:
|
| 130 |
-
try:
|
| 131 |
-
lstm_days = self._build_lstm_days(
|
| 132 |
-
recent_temps, recent_humidity, recent_wbgt
|
| 133 |
-
)
|
| 134 |
-
lstm_prob, lstm_conf = self._lstm.predict(lstm_days)
|
| 135 |
-
lstm_ok = True
|
| 136 |
-
except Exception:
|
| 137 |
-
pass
|
| 138 |
-
|
| 139 |
-
# -- Ensemble --
|
| 140 |
-
if xgb_ok and lstm_ok:
|
| 141 |
-
prob = 0.5 * xgb_prob + 0.5 * lstm_prob
|
| 142 |
-
confidence = (xgb_conf + lstm_conf) / 2.0
|
| 143 |
-
return round(prob, 4), round(confidence, 3), "ensemble"
|
| 144 |
-
|
| 145 |
-
if xgb_ok:
|
| 146 |
-
return round(xgb_prob, 4), round(xgb_conf, 3), "full_model"
|
| 147 |
-
|
| 148 |
-
if lstm_ok:
|
| 149 |
-
return round(lstm_prob, 4), round(lstm_conf, 3), "lstm_only"
|
| 150 |
-
|
| 151 |
-
# Persistence fallback: if recent conditions are above threshold,
|
| 152 |
-
# assume they continue
|
| 153 |
-
try:
|
| 154 |
-
threshold = CITY_THRESHOLDS.get(zone.city, 33.0)
|
| 155 |
-
if len(recent_temps) >= 2:
|
| 156 |
-
above = sum(1 for t in recent_temps[-3:] if t > threshold)
|
| 157 |
-
prob = min(0.95, above / 3.0)
|
| 158 |
-
else:
|
| 159 |
-
prob = 0.5
|
| 160 |
-
confidence = self._estimate_confidence(recent_temps, "persistence")
|
| 161 |
-
return round(prob, 4), round(confidence, 3), "persistence"
|
| 162 |
-
except Exception:
|
| 163 |
-
pass
|
| 164 |
-
|
| 165 |
-
# Climatology fallback: use seasonal base rate
|
| 166 |
-
from config import HOT_SEASONS
|
| 167 |
-
|
| 168 |
-
import datetime
|
| 169 |
-
|
| 170 |
-
doy = datetime.datetime.now().timetuple().tm_yday
|
| 171 |
-
month = datetime.datetime.now().month
|
| 172 |
-
city = getattr(zone, "city", "Nairobi")
|
| 173 |
-
hot_months = []
|
| 174 |
-
for season_months in HOT_SEASONS.get(city, {}).values():
|
| 175 |
-
hot_months.extend(season_months)
|
| 176 |
-
prob = 0.35 if month in hot_months else 0.10
|
| 177 |
-
confidence = 0.30
|
| 178 |
-
return round(prob, 4), round(confidence, 3), "climatology"
|
| 179 |
-
|
| 180 |
-
@staticmethod
|
| 181 |
-
def _build_lstm_days(
|
| 182 |
-
recent_temps: list[float],
|
| 183 |
-
recent_humidity: list[float],
|
| 184 |
-
recent_wbgt: list[float],
|
| 185 |
-
) -> list[dict]:
|
| 186 |
-
"""Convert raw arrays into the list-of-dicts format the LSTM expects.
|
| 187 |
-
|
| 188 |
-
The LSTM predictor computes WBGT, heat index, and temp anomaly
|
| 189 |
-
internally, so we only need to pass the raw observations.
|
| 190 |
-
"""
|
| 191 |
-
n = min(len(recent_temps), len(recent_humidity), len(recent_wbgt))
|
| 192 |
-
days = []
|
| 193 |
-
for i in range(n):
|
| 194 |
-
days.append({
|
| 195 |
-
"temp_max_c": recent_temps[i],
|
| 196 |
-
"humidity_pct": recent_humidity[i],
|
| 197 |
-
"wind_speed_ms": 3.0,
|
| 198 |
-
})
|
| 199 |
-
return days
|
| 200 |
-
|
| 201 |
-
def update_rolling_error(self, predicted_prob: float, actual: bool) -> None:
|
| 202 |
-
"""Track prediction accuracy for the rolling_error feature."""
|
| 203 |
-
error = abs(predicted_prob - (1.0 if actual else 0.0))
|
| 204 |
-
self._rolling_errors.append(error)
|
| 205 |
-
|
| 206 |
-
# ------------------------------------------------------------------
|
| 207 |
-
# Feature engineering
|
| 208 |
-
# ------------------------------------------------------------------
|
| 209 |
-
|
| 210 |
-
def _build_features(
|
| 211 |
-
self,
|
| 212 |
-
zone,
|
| 213 |
-
recent_temps: list[float],
|
| 214 |
-
recent_humidity: list[float],
|
| 215 |
-
recent_wbgt: list[float],
|
| 216 |
-
hour: int = 12,
|
| 217 |
-
) -> np.ndarray:
|
| 218 |
-
"""Build the 12-feature vector.
|
| 219 |
-
|
| 220 |
-
Features:
|
| 221 |
-
0: current_temp — most recent daily max temp
|
| 222 |
-
1: current_wbgt — most recent WBGT
|
| 223 |
-
2: current_humidity — most recent humidity
|
| 224 |
-
3: temp_trend_7d — linear slope over last 7 days
|
| 225 |
-
4: temp_anomaly_30d — current temp minus 30-day mean
|
| 226 |
-
5: soil_moisture_proxy — inverse of recent rainfall proxy
|
| 227 |
-
(approximated as negative temp anomaly clamped to [0,1])
|
| 228 |
-
6: rolling_error — mean of last 3 prediction errors
|
| 229 |
-
7-8: doy_sin, doy_cos — seasonal encoding
|
| 230 |
-
9-10: hour_sin, hour_cos — diurnal encoding
|
| 231 |
-
11: zone_vulnerability — numeric heat vulnerability
|
| 232 |
-
"""
|
| 233 |
-
import datetime
|
| 234 |
-
|
| 235 |
-
temps = list(recent_temps)
|
| 236 |
-
humid = list(recent_humidity)
|
| 237 |
-
wbgts = list(recent_wbgt)
|
| 238 |
-
|
| 239 |
-
current_temp = temps[-1] if temps else 30.0
|
| 240 |
-
current_wbgt = wbgts[-1] if wbgts else 28.0
|
| 241 |
-
current_humidity = humid[-1] if humid else 65.0
|
| 242 |
-
|
| 243 |
-
# Trend: slope of last 7 days
|
| 244 |
-
if len(temps) >= 7:
|
| 245 |
-
x = np.arange(7, dtype=np.float64)
|
| 246 |
-
y = np.array(temps[-7:], dtype=np.float64)
|
| 247 |
-
temp_trend = float(np.polyfit(x, y, 1)[0])
|
| 248 |
-
else:
|
| 249 |
-
temp_trend = 0.0
|
| 250 |
-
|
| 251 |
-
# Anomaly: current vs 30-day mean
|
| 252 |
-
if len(temps) >= 2:
|
| 253 |
-
temp_anomaly = current_temp - np.mean(temps)
|
| 254 |
-
else:
|
| 255 |
-
temp_anomaly = 0.0
|
| 256 |
-
|
| 257 |
-
# Soil moisture proxy: when temps are well below average,
|
| 258 |
-
# likely recent rain -> higher moisture. Clamp to [0, 1].
|
| 259 |
-
soil_proxy = float(np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0))
|
| 260 |
-
|
| 261 |
-
# Rolling prediction error
|
| 262 |
-
if self._rolling_errors:
|
| 263 |
-
rolling_err = float(np.mean(self._rolling_errors))
|
| 264 |
-
else:
|
| 265 |
-
rolling_err = 0.3 # neutral prior
|
| 266 |
-
|
| 267 |
-
# Day of year encoding
|
| 268 |
-
doy = datetime.datetime.now().timetuple().tm_yday
|
| 269 |
-
doy_sin = np.sin(2 * np.pi * doy / 365.0)
|
| 270 |
-
doy_cos = np.cos(2 * np.pi * doy / 365.0)
|
| 271 |
-
|
| 272 |
-
# Hour encoding
|
| 273 |
-
hour_sin = np.sin(2 * np.pi * hour / 24.0)
|
| 274 |
-
hour_cos = np.cos(2 * np.pi * hour / 24.0)
|
| 275 |
-
|
| 276 |
-
# Vulnerability
|
| 277 |
-
vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 278 |
-
zone_vuln = vuln_map.get(
|
| 279 |
-
getattr(zone, "heat_vulnerability", "moderate"), 0.5
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
features = np.array(
|
| 283 |
-
[
|
| 284 |
-
current_temp,
|
| 285 |
-
current_wbgt,
|
| 286 |
-
current_humidity,
|
| 287 |
-
temp_trend,
|
| 288 |
-
temp_anomaly,
|
| 289 |
-
soil_proxy,
|
| 290 |
-
rolling_err,
|
| 291 |
-
doy_sin,
|
| 292 |
-
doy_cos,
|
| 293 |
-
hour_sin,
|
| 294 |
-
hour_cos,
|
| 295 |
-
zone_vuln,
|
| 296 |
-
],
|
| 297 |
-
dtype=np.float32,
|
| 298 |
-
).reshape(1, -1)
|
| 299 |
-
|
| 300 |
-
return features
|
| 301 |
-
|
| 302 |
-
# ------------------------------------------------------------------
|
| 303 |
-
# Confidence estimation
|
| 304 |
-
# ------------------------------------------------------------------
|
| 305 |
-
|
| 306 |
-
@staticmethod
|
| 307 |
-
def _estimate_confidence(recent_temps: list[float], tier: str) -> float:
|
| 308 |
-
"""Heuristic confidence based on data quality and model tier."""
|
| 309 |
-
base = {"full_model": 0.80, "persistence": 0.45, "climatology": 0.30}
|
| 310 |
-
conf = base.get(tier, 0.30)
|
| 311 |
-
|
| 312 |
-
# More data -> higher confidence
|
| 313 |
-
n = len(recent_temps)
|
| 314 |
-
if n >= 30:
|
| 315 |
-
conf += 0.10
|
| 316 |
-
elif n >= 14:
|
| 317 |
-
conf += 0.05
|
| 318 |
-
|
| 319 |
-
# Low variance in recent data -> more predictable
|
| 320 |
-
if n >= 7:
|
| 321 |
-
std = float(np.std(recent_temps[-7:]))
|
| 322 |
-
if std < 2.0:
|
| 323 |
-
conf += 0.05
|
| 324 |
-
|
| 325 |
-
return min(conf, 0.95)
|
| 326 |
-
|
| 327 |
-
# ------------------------------------------------------------------
|
| 328 |
-
# Training
|
| 329 |
-
# ------------------------------------------------------------------
|
| 330 |
-
|
| 331 |
-
def train(self, seed: int = 42) -> None:
|
| 332 |
-
"""Generate 2 years of synthetic daily data per zone and train.
|
| 333 |
-
|
| 334 |
-
For each zone, generates 730 days of realistic temperature,
|
| 335 |
-
humidity, and WBGT curves with autocorrelation. Labels each
|
| 336 |
-
day with whether a trigger event (2+ consecutive days above
|
| 337 |
-
threshold) occurs within the next 7 days.
|
| 338 |
-
"""
|
| 339 |
-
rng = np.random.default_rng(seed)
|
| 340 |
-
from config import ZONES
|
| 341 |
-
|
| 342 |
-
n_days = 730
|
| 343 |
-
all_X = []
|
| 344 |
-
all_y = []
|
| 345 |
-
|
| 346 |
-
for zone in ZONES:
|
| 347 |
-
city = zone.city
|
| 348 |
-
climate = CITY_CLIMATE.get(city, CITY_CLIMATE["Nairobi"])
|
| 349 |
-
|
| 350 |
-
# Generate daily temperatures with autocorrelation
|
| 351 |
-
temps = self._generate_temp_series(climate, n_days, rng)
|
| 352 |
-
humidity = self._generate_humidity_series(climate, n_days, rng)
|
| 353 |
-
|
| 354 |
-
# Compute WBGT series
|
| 355 |
-
wbgt_series = [
|
| 356 |
-
_simple_wbgt(t, h) for t, h in zip(temps, humidity)
|
| 357 |
-
]
|
| 358 |
-
|
| 359 |
-
# Label: trigger within next 7 days?
|
| 360 |
-
threshold = CITY_THRESHOLDS.get(city, 33.0)
|
| 361 |
-
labels = self._label_triggers(temps, threshold, n_days)
|
| 362 |
-
|
| 363 |
-
# Build features for each day (need 30-day lookback)
|
| 364 |
-
vuln_map = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 365 |
-
zone_vuln = vuln_map.get(zone.heat_vulnerability, 0.5)
|
| 366 |
-
|
| 367 |
-
for day in range(30, n_days - 7):
|
| 368 |
-
t_window = temps[day - 30 : day + 1]
|
| 369 |
-
h_window = humidity[day - 30 : day + 1]
|
| 370 |
-
w_window = wbgt_series[day - 30 : day + 1]
|
| 371 |
-
|
| 372 |
-
current_temp = t_window[-1]
|
| 373 |
-
current_wbgt = w_window[-1]
|
| 374 |
-
current_humidity = h_window[-1]
|
| 375 |
-
|
| 376 |
-
# Trend
|
| 377 |
-
x7 = np.arange(7, dtype=np.float64)
|
| 378 |
-
y7 = np.array(t_window[-7:], dtype=np.float64)
|
| 379 |
-
temp_trend = float(np.polyfit(x7, y7, 1)[0])
|
| 380 |
-
|
| 381 |
-
# Anomaly
|
| 382 |
-
temp_anomaly = current_temp - float(np.mean(t_window))
|
| 383 |
-
|
| 384 |
-
# Soil moisture proxy
|
| 385 |
-
soil_proxy = float(
|
| 386 |
-
np.clip(1.0 - (temp_anomaly + 2.0) / 4.0, 0.0, 1.0)
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
# Synthetic rolling error
|
| 390 |
-
rolling_err = rng.uniform(0.1, 0.5)
|
| 391 |
-
|
| 392 |
-
# Day-of-year encoding (day within 365-day cycle)
|
| 393 |
-
doy = day % 365
|
| 394 |
-
doy_sin = np.sin(2 * np.pi * doy / 365.0)
|
| 395 |
-
doy_cos = np.cos(2 * np.pi * doy / 365.0)
|
| 396 |
-
|
| 397 |
-
# Random hour for variety
|
| 398 |
-
hour = rng.integers(6, 19)
|
| 399 |
-
hour_sin = np.sin(2 * np.pi * hour / 24.0)
|
| 400 |
-
hour_cos = np.cos(2 * np.pi * hour / 24.0)
|
| 401 |
-
|
| 402 |
-
row = [
|
| 403 |
-
current_temp,
|
| 404 |
-
current_wbgt,
|
| 405 |
-
current_humidity,
|
| 406 |
-
temp_trend,
|
| 407 |
-
temp_anomaly,
|
| 408 |
-
soil_proxy,
|
| 409 |
-
rolling_err,
|
| 410 |
-
doy_sin,
|
| 411 |
-
doy_cos,
|
| 412 |
-
hour_sin,
|
| 413 |
-
hour_cos,
|
| 414 |
-
zone_vuln,
|
| 415 |
-
]
|
| 416 |
-
|
| 417 |
-
all_X.append(row)
|
| 418 |
-
all_y.append(labels[day])
|
| 419 |
-
|
| 420 |
-
X = np.array(all_X, dtype=np.float32)
|
| 421 |
-
y = np.array(all_y, dtype=np.int32)
|
| 422 |
-
|
| 423 |
-
self.model = xgb.XGBClassifier(
|
| 424 |
-
n_estimators=150,
|
| 425 |
-
max_depth=5,
|
| 426 |
-
learning_rate=0.1,
|
| 427 |
-
eval_metric="logloss",
|
| 428 |
-
random_state=seed,
|
| 429 |
-
)
|
| 430 |
-
self.model.fit(X, y)
|
| 431 |
-
self._save_model()
|
| 432 |
-
|
| 433 |
-
# ------------------------------------------------------------------
|
| 434 |
-
# Synthetic data generation helpers
|
| 435 |
-
# ------------------------------------------------------------------
|
| 436 |
-
|
| 437 |
-
@staticmethod
|
| 438 |
-
def _generate_temp_series(
|
| 439 |
-
climate: dict, n_days: int, rng
|
| 440 |
-
) -> list[float]:
|
| 441 |
-
"""Generate realistic daily max temperatures with autocorrelation.
|
| 442 |
-
|
| 443 |
-
Uses seasonal cosine curve + AR(1) autocorrelated noise.
|
| 444 |
-
"""
|
| 445 |
-
mean = climate["temp_mean"]
|
| 446 |
-
amp = climate["temp_amp"]
|
| 447 |
-
phase = climate["phase_doy"]
|
| 448 |
-
lo, hi = climate["temp_range"]
|
| 449 |
-
|
| 450 |
-
temps = []
|
| 451 |
-
noise = 0.0
|
| 452 |
-
ar_coef = 0.7 # autocorrelation coefficient
|
| 453 |
-
|
| 454 |
-
for d in range(n_days):
|
| 455 |
-
# Seasonal component
|
| 456 |
-
seasonal = mean + amp * np.cos(
|
| 457 |
-
2 * np.pi * (d - phase) / 365.0
|
| 458 |
-
)
|
| 459 |
-
# AR(1) noise
|
| 460 |
-
noise = ar_coef * noise + rng.normal(0, 1.2)
|
| 461 |
-
temp = seasonal + noise
|
| 462 |
-
temp = float(np.clip(temp, lo - 1.0, hi + 2.0))
|
| 463 |
-
temps.append(temp)
|
| 464 |
-
|
| 465 |
-
return temps
|
| 466 |
-
|
| 467 |
-
@staticmethod
|
| 468 |
-
def _generate_humidity_series(
|
| 469 |
-
climate: dict, n_days: int, rng
|
| 470 |
-
) -> list[float]:
|
| 471 |
-
"""Generate daily humidity with seasonal cycle and noise."""
|
| 472 |
-
mean = climate["humidity_mean"]
|
| 473 |
-
amp = climate["humidity_amp"]
|
| 474 |
-
phase = climate.get("phase_doy", 45)
|
| 475 |
-
|
| 476 |
-
humidity = []
|
| 477 |
-
noise = 0.0
|
| 478 |
-
|
| 479 |
-
for d in range(n_days):
|
| 480 |
-
# Humidity anti-correlated with temp in dry season
|
| 481 |
-
seasonal = mean - amp * np.cos(
|
| 482 |
-
2 * np.pi * (d - phase) / 365.0
|
| 483 |
-
)
|
| 484 |
-
noise = 0.5 * noise + rng.normal(0, 4.0)
|
| 485 |
-
h = seasonal + noise
|
| 486 |
-
h = float(np.clip(h, 30.0, 98.0))
|
| 487 |
-
humidity.append(h)
|
| 488 |
-
|
| 489 |
-
return humidity
|
| 490 |
-
|
| 491 |
-
@staticmethod
|
| 492 |
-
def _label_triggers(
|
| 493 |
-
temps: list[float], threshold: float, n_days: int
|
| 494 |
-
) -> list[int]:
|
| 495 |
-
"""Label each day: 1 if a 2+ consecutive day trigger occurs in next 7 days."""
|
| 496 |
-
labels = [0] * n_days
|
| 497 |
-
|
| 498 |
-
for day in range(n_days - 7):
|
| 499 |
-
window = temps[day + 1 : day + 8]
|
| 500 |
-
# Check for 2+ consecutive above threshold
|
| 501 |
-
consec = 0
|
| 502 |
-
triggered = False
|
| 503 |
-
for t in window:
|
| 504 |
-
if t > threshold:
|
| 505 |
-
consec += 1
|
| 506 |
-
if consec >= 2:
|
| 507 |
-
triggered = True
|
| 508 |
-
break
|
| 509 |
-
else:
|
| 510 |
-
consec = 0
|
| 511 |
-
labels[day] = 1 if triggered else 0
|
| 512 |
-
|
| 513 |
-
return labels
|
| 514 |
-
|
| 515 |
-
# ------------------------------------------------------------------
|
| 516 |
-
# Persistence
|
| 517 |
-
# ------------------------------------------------------------------
|
| 518 |
-
|
| 519 |
-
def _save_model(self) -> None:
|
| 520 |
-
self.model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 521 |
-
self.model.save_model(str(self.model_path))
|
| 522 |
-
|
| 523 |
-
def _load_or_train(self) -> None:
|
| 524 |
-
if self.model_path.exists():
|
| 525 |
-
self.model = xgb.XGBClassifier()
|
| 526 |
-
self.model.load_model(str(self.model_path))
|
| 527 |
-
else:
|
| 528 |
-
self.train()
|
| 529 |
-
|
| 530 |
-
def _train_lstm_synthetic(self) -> None:
|
| 531 |
-
"""Auto-train LSTM on synthetic data when no model file exists."""
|
| 532 |
-
try:
|
| 533 |
-
from src.prediction.lstm_model import (
|
| 534 |
-
LSTMTrainer,
|
| 535 |
-
generate_synthetic_zone_data,
|
| 536 |
-
)
|
| 537 |
-
from config import ZONES
|
| 538 |
-
|
| 539 |
-
import logging
|
| 540 |
-
|
| 541 |
-
logger = logging.getLogger(__name__)
|
| 542 |
-
logger.info("LSTM model not found -- training on synthetic data")
|
| 543 |
-
|
| 544 |
-
zone_data = generate_synthetic_zone_data(ZONES, n_days=730, seed=42)
|
| 545 |
-
trainer = LSTMTrainer(epochs=50, patience=5)
|
| 546 |
-
trainer.train(zone_data)
|
| 547 |
-
|
| 548 |
-
# Reload the predictor now that the model file exists
|
| 549 |
-
self._lstm = LSTMPredictor()
|
| 550 |
-
logger.info("LSTM auto-trained and loaded successfully")
|
| 551 |
-
except Exception as exc:
|
| 552 |
-
import logging
|
| 553 |
-
|
| 554 |
-
logging.getLogger(__name__).warning(
|
| 555 |
-
"LSTM auto-training failed: %s", exc
|
| 556 |
-
)
|
| 557 |
-
self._lstm = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,566 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
LSTM neural heat wave predictor for parametric insurance triggers.
|
| 3 |
-
|
| 4 |
-
2-layer LSTM that learns temporal patterns in 14-day climate sequences
|
| 5 |
-
to predict heat wave trigger probability in the next 7 days.
|
| 6 |
-
Ensembled with the existing XGBoost classifier in heat_forecast.py.
|
| 7 |
-
|
| 8 |
-
Architecture:
|
| 9 |
-
Input: (batch, 14, 6) -- 14 days x 6 climate features
|
| 10 |
-
LSTM: 2 layers, hidden_size=64, dropout=0.2
|
| 11 |
-
Output: scalar sigmoid probability
|
| 12 |
-
|
| 13 |
-
Features per timestep:
|
| 14 |
-
0: temp_max_c -- daily max temperature (normalized)
|
| 15 |
-
1: humidity_pct -- relative humidity (normalized)
|
| 16 |
-
2: wind_speed_ms -- wind speed (normalized)
|
| 17 |
-
3: wbgt_c -- wet-bulb globe temperature (normalized)
|
| 18 |
-
4: heat_index_c -- apparent temperature (normalized)
|
| 19 |
-
5: temp_anomaly -- temp minus 7-day rolling mean (normalized)
|
| 20 |
-
|
| 21 |
-
References:
|
| 22 |
-
- Perkins-Kirkpatrick & Lewis (2020) heat wave definitions
|
| 23 |
-
- WHO/ILO occupational heat stress thresholds
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
from __future__ import annotations
|
| 27 |
-
|
| 28 |
-
import json
|
| 29 |
-
from pathlib import Path
|
| 30 |
-
|
| 31 |
-
import numpy as np
|
| 32 |
-
|
| 33 |
-
try:
|
| 34 |
-
import torch
|
| 35 |
-
import torch.nn as nn
|
| 36 |
-
from torch.utils.data import DataLoader, TensorDataset
|
| 37 |
-
|
| 38 |
-
TORCH_AVAILABLE = True
|
| 39 |
-
except ImportError:
|
| 40 |
-
TORCH_AVAILABLE = False
|
| 41 |
-
|
| 42 |
-
from src.indexing.heat_index import calculate_wbgt, calculate_heat_index
|
| 43 |
-
|
| 44 |
-
# City-specific temperature thresholds for trigger definition (deg C)
|
| 45 |
-
CITY_THRESHOLDS = {
|
| 46 |
-
"Dar es Salaam": 34.0,
|
| 47 |
-
"Kampala": 31.0,
|
| 48 |
-
"Nairobi": 28.0,
|
| 49 |
-
"Kigali": 29.0,
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
# Seasonal temperature / humidity profiles per city
|
| 53 |
-
CITY_CLIMATE = {
|
| 54 |
-
"Dar es Salaam": {
|
| 55 |
-
"temp_mean": 31.0, "temp_amp": 3.5, "phase_doy": 45,
|
| 56 |
-
"humidity_mean": 82.0, "humidity_amp": 7.0,
|
| 57 |
-
"temp_range": (28.0, 36.0),
|
| 58 |
-
},
|
| 59 |
-
"Kampala": {
|
| 60 |
-
"temp_mean": 28.0, "temp_amp": 3.0, "phase_doy": 45,
|
| 61 |
-
"humidity_mean": 70.0, "humidity_amp": 10.0,
|
| 62 |
-
"temp_range": (24.0, 32.0),
|
| 63 |
-
},
|
| 64 |
-
"Nairobi": {
|
| 65 |
-
"temp_mean": 24.5, "temp_amp": 3.0, "phase_doy": 55,
|
| 66 |
-
"humidity_mean": 57.0, "humidity_amp": 12.0,
|
| 67 |
-
"temp_range": (20.0, 28.0),
|
| 68 |
-
},
|
| 69 |
-
"Kigali": {
|
| 70 |
-
"temp_mean": 25.5, "temp_amp": 2.5, "phase_doy": 50,
|
| 71 |
-
"humidity_mean": 65.0, "humidity_amp": 8.0,
|
| 72 |
-
"temp_range": (22.0, 29.0),
|
| 73 |
-
},
|
| 74 |
-
}
|
| 75 |
-
|
| 76 |
-
FEATURE_NAMES = [
|
| 77 |
-
"temp_max_c", "humidity_pct", "wind_speed_ms",
|
| 78 |
-
"wbgt_c", "heat_index_c", "temp_anomaly",
|
| 79 |
-
]
|
| 80 |
-
|
| 81 |
-
NUM_FEATURES = len(FEATURE_NAMES)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def _resolve_path(rel: str) -> Path:
|
| 85 |
-
"""Resolve a path relative to the project root."""
|
| 86 |
-
return Path(__file__).resolve().parents[2] / rel
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
# ======================================================================
|
| 90 |
-
# Model
|
| 91 |
-
# ======================================================================
|
| 92 |
-
|
| 93 |
-
if TORCH_AVAILABLE:
|
| 94 |
-
|
| 95 |
-
class HeatLSTM(nn.Module):
|
| 96 |
-
"""2-layer LSTM for 7-day heat wave trigger prediction."""
|
| 97 |
-
|
| 98 |
-
def __init__(
|
| 99 |
-
self,
|
| 100 |
-
input_size: int = NUM_FEATURES,
|
| 101 |
-
hidden_size: int = 64,
|
| 102 |
-
num_layers: int = 2,
|
| 103 |
-
dropout: float = 0.2,
|
| 104 |
-
):
|
| 105 |
-
super().__init__()
|
| 106 |
-
self.lstm = nn.LSTM(
|
| 107 |
-
input_size, hidden_size, num_layers,
|
| 108 |
-
batch_first=True, dropout=dropout,
|
| 109 |
-
)
|
| 110 |
-
self.fc = nn.Linear(hidden_size, 1)
|
| 111 |
-
|
| 112 |
-
def forward(self, x):
|
| 113 |
-
# x: (batch, seq_len, input_size)
|
| 114 |
-
out, _ = self.lstm(x)
|
| 115 |
-
out = self.fc(out[:, -1, :]) # last timestep
|
| 116 |
-
return torch.sigmoid(out).squeeze(-1)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# ======================================================================
|
| 120 |
-
# Derived feature computation
|
| 121 |
-
# ======================================================================
|
| 122 |
-
|
| 123 |
-
def _compute_temp_anomaly(temps: list[float], index: int) -> float:
|
| 124 |
-
"""Compute temp minus 7-day rolling mean at the given index."""
|
| 125 |
-
start = max(0, index - 6)
|
| 126 |
-
window = temps[start:index + 1]
|
| 127 |
-
if not window:
|
| 128 |
-
return 0.0
|
| 129 |
-
return temps[index] - float(np.mean(window))
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
# ======================================================================
|
| 133 |
-
# Synthetic data generation
|
| 134 |
-
# ======================================================================
|
| 135 |
-
|
| 136 |
-
def _generate_temp_series(climate: dict, n_days: int, rng) -> list[float]:
|
| 137 |
-
"""Daily max temperatures with AR(1) autocorrelation."""
|
| 138 |
-
mean, amp, phase = climate["temp_mean"], climate["temp_amp"], climate["phase_doy"]
|
| 139 |
-
lo, hi = climate["temp_range"]
|
| 140 |
-
temps, noise = [], 0.0
|
| 141 |
-
for d in range(n_days):
|
| 142 |
-
seasonal = mean + amp * np.cos(2 * np.pi * (d - phase) / 365.0)
|
| 143 |
-
noise = 0.7 * noise + rng.normal(0, 1.2)
|
| 144 |
-
temps.append(float(np.clip(seasonal + noise, lo - 1.0, hi + 2.0)))
|
| 145 |
-
return temps
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def _generate_humidity_series(climate: dict, n_days: int, rng) -> list[float]:
|
| 149 |
-
"""Daily humidity with seasonal cycle and noise."""
|
| 150 |
-
mean, amp = climate["humidity_mean"], climate["humidity_amp"]
|
| 151 |
-
phase = climate.get("phase_doy", 45)
|
| 152 |
-
humidity, noise = [], 0.0
|
| 153 |
-
for d in range(n_days):
|
| 154 |
-
seasonal = mean - amp * np.cos(2 * np.pi * (d - phase) / 365.0)
|
| 155 |
-
noise = 0.5 * noise + rng.normal(0, 4.0)
|
| 156 |
-
humidity.append(float(np.clip(seasonal + noise, 30.0, 98.0)))
|
| 157 |
-
return humidity
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def _generate_wind_series(n_days: int, rng) -> list[float]:
|
| 161 |
-
"""Synthetic wind speed series (m/s)."""
|
| 162 |
-
winds, noise = [], 0.0
|
| 163 |
-
for _ in range(n_days):
|
| 164 |
-
noise = 0.4 * noise + rng.normal(0, 0.8)
|
| 165 |
-
winds.append(float(np.clip(3.5 + noise, 0.5, 12.0)))
|
| 166 |
-
return winds
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def _label_triggers(temps: list[float], threshold: float, n_days: int) -> list[int]:
|
| 170 |
-
"""Label each day: 1 if 2+ consecutive days above threshold in next 7 days."""
|
| 171 |
-
labels = [0] * n_days
|
| 172 |
-
for day in range(n_days - 7):
|
| 173 |
-
window = temps[day + 1: day + 8]
|
| 174 |
-
consec = 0
|
| 175 |
-
triggered = False
|
| 176 |
-
for t in window:
|
| 177 |
-
if t > threshold:
|
| 178 |
-
consec += 1
|
| 179 |
-
if consec >= 2:
|
| 180 |
-
triggered = True
|
| 181 |
-
break
|
| 182 |
-
else:
|
| 183 |
-
consec = 0
|
| 184 |
-
labels[day] = 1 if triggered else 0
|
| 185 |
-
return labels
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def generate_synthetic_zone_data(
|
| 189 |
-
zones: list, n_days: int = 730, seed: int = 42,
|
| 190 |
-
) -> dict[str, list[dict]]:
|
| 191 |
-
"""Generate synthetic daily climate data for all zones.
|
| 192 |
-
|
| 193 |
-
Returns:
|
| 194 |
-
dict mapping zone_id -> list of daily dicts with keys:
|
| 195 |
-
temp_max_c, humidity_pct, wind_speed_ms, city
|
| 196 |
-
"""
|
| 197 |
-
rng = np.random.default_rng(seed)
|
| 198 |
-
zone_data: dict[str, list[dict]] = {}
|
| 199 |
-
|
| 200 |
-
for zone in zones:
|
| 201 |
-
city = zone.city
|
| 202 |
-
climate = CITY_CLIMATE.get(city, CITY_CLIMATE["Nairobi"])
|
| 203 |
-
|
| 204 |
-
temps = _generate_temp_series(climate, n_days, rng)
|
| 205 |
-
humidity = _generate_humidity_series(climate, n_days, rng)
|
| 206 |
-
winds = _generate_wind_series(n_days, rng)
|
| 207 |
-
|
| 208 |
-
days = []
|
| 209 |
-
for i in range(n_days):
|
| 210 |
-
days.append({
|
| 211 |
-
"temp_max_c": temps[i],
|
| 212 |
-
"humidity_pct": humidity[i],
|
| 213 |
-
"wind_speed_ms": winds[i],
|
| 214 |
-
"city": city,
|
| 215 |
-
})
|
| 216 |
-
zone_data[zone.zone_id] = days
|
| 217 |
-
|
| 218 |
-
return zone_data
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
# ======================================================================
|
| 222 |
-
# Trainer
|
| 223 |
-
# ======================================================================
|
| 224 |
-
|
| 225 |
-
class LSTMTrainer:
|
| 226 |
-
"""Train the HeatLSTM on historical or synthetic climate data."""
|
| 227 |
-
|
| 228 |
-
def __init__(
|
| 229 |
-
self,
|
| 230 |
-
model: object | None = None,
|
| 231 |
-
lr: float = 0.001,
|
| 232 |
-
epochs: int = 50,
|
| 233 |
-
patience: int = 5,
|
| 234 |
-
seq_len: int = 14,
|
| 235 |
-
forecast_horizon: int = 7,
|
| 236 |
-
):
|
| 237 |
-
if not TORCH_AVAILABLE:
|
| 238 |
-
raise ImportError("torch is required. pip install torch")
|
| 239 |
-
self._custom_model = model
|
| 240 |
-
self.lr = lr
|
| 241 |
-
self.epochs = epochs
|
| 242 |
-
self.patience = patience
|
| 243 |
-
self.seq_len = seq_len
|
| 244 |
-
self.forecast_horizon = forecast_horizon
|
| 245 |
-
self.model_path = _resolve_path("models/heat_lstm.pt")
|
| 246 |
-
self.norm_path = _resolve_path("models/lstm_norm.json")
|
| 247 |
-
|
| 248 |
-
def prepare_data(
|
| 249 |
-
self, zone_readings: dict[str, list],
|
| 250 |
-
) -> tuple:
|
| 251 |
-
"""Convert zone readings to training tensors.
|
| 252 |
-
|
| 253 |
-
Args:
|
| 254 |
-
zone_readings: dict of zone_id -> list of daily readings.
|
| 255 |
-
Each reading needs: temp_max_c, humidity_pct, wind_speed_ms
|
| 256 |
-
Optional: date, city
|
| 257 |
-
|
| 258 |
-
Returns:
|
| 259 |
-
(X_train, y_train, X_val, y_val) as torch tensors
|
| 260 |
-
"""
|
| 261 |
-
all_seqs: list[np.ndarray] = []
|
| 262 |
-
all_labels: list[int] = []
|
| 263 |
-
|
| 264 |
-
for zone_id, days in zone_readings.items():
|
| 265 |
-
n = len(days)
|
| 266 |
-
if n < self.seq_len + self.forecast_horizon + 1:
|
| 267 |
-
continue
|
| 268 |
-
|
| 269 |
-
city = days[0].get("city", "Nairobi")
|
| 270 |
-
threshold = CITY_THRESHOLDS.get(city, 33.0)
|
| 271 |
-
|
| 272 |
-
# Extract raw temps for labeling and anomaly computation
|
| 273 |
-
temps = [d["temp_max_c"] for d in days]
|
| 274 |
-
labels = _label_triggers(temps, threshold, n)
|
| 275 |
-
|
| 276 |
-
# Compute derived features for all days
|
| 277 |
-
derived = []
|
| 278 |
-
for i, d in enumerate(days):
|
| 279 |
-
t = d["temp_max_c"]
|
| 280 |
-
h = d["humidity_pct"]
|
| 281 |
-
w = d["wind_speed_ms"]
|
| 282 |
-
wbgt = calculate_wbgt(t, h)
|
| 283 |
-
hi = calculate_heat_index(t, h)
|
| 284 |
-
anomaly = _compute_temp_anomaly(temps, i)
|
| 285 |
-
derived.append([t, h, w, wbgt, hi, anomaly])
|
| 286 |
-
|
| 287 |
-
# Create sliding windows
|
| 288 |
-
for i in range(n - self.seq_len - self.forecast_horizon):
|
| 289 |
-
seq = np.array(
|
| 290 |
-
derived[i: i + self.seq_len], dtype=np.float32,
|
| 291 |
-
)
|
| 292 |
-
all_seqs.append(seq)
|
| 293 |
-
all_labels.append(labels[i + self.seq_len - 1])
|
| 294 |
-
|
| 295 |
-
X = np.stack(all_seqs) # (N, seq_len, 6)
|
| 296 |
-
y = np.array(all_labels, dtype=np.float32)
|
| 297 |
-
|
| 298 |
-
# Temporal split: first 75% train, last 25% validation
|
| 299 |
-
split = int(len(X) * 0.75)
|
| 300 |
-
X_train_np, X_val_np = X[:split], X[split:]
|
| 301 |
-
y_train_np, y_val_np = y[:split], y[split:]
|
| 302 |
-
|
| 303 |
-
# Compute normalization (z-score per feature) from training set
|
| 304 |
-
flat = X_train_np.reshape(-1, NUM_FEATURES)
|
| 305 |
-
feat_mean = flat.mean(axis=0).tolist()
|
| 306 |
-
feat_std = flat.std(axis=0).tolist()
|
| 307 |
-
feat_std = [max(s, 1e-6) for s in feat_std]
|
| 308 |
-
|
| 309 |
-
# Save normalization params
|
| 310 |
-
norm = {"mean": feat_mean, "std": feat_std}
|
| 311 |
-
self.norm_path.parent.mkdir(parents=True, exist_ok=True)
|
| 312 |
-
with open(self.norm_path, "w") as f:
|
| 313 |
-
json.dump(norm, f, indent=2)
|
| 314 |
-
|
| 315 |
-
# Normalize
|
| 316 |
-
mean_arr = np.array(feat_mean, dtype=np.float32)
|
| 317 |
-
std_arr = np.array(feat_std, dtype=np.float32)
|
| 318 |
-
X_train_np = (X_train_np - mean_arr) / std_arr
|
| 319 |
-
X_val_np = (X_val_np - mean_arr) / std_arr
|
| 320 |
-
|
| 321 |
-
X_train = torch.from_numpy(X_train_np)
|
| 322 |
-
y_train = torch.from_numpy(y_train_np)
|
| 323 |
-
X_val = torch.from_numpy(X_val_np)
|
| 324 |
-
y_val = torch.from_numpy(y_val_np)
|
| 325 |
-
|
| 326 |
-
return X_train, y_train, X_val, y_val
|
| 327 |
-
|
| 328 |
-
def train(self, zone_readings: dict[str, list]) -> dict:
|
| 329 |
-
"""Train the LSTM and return metrics."""
|
| 330 |
-
torch.manual_seed(42)
|
| 331 |
-
np.random.seed(42)
|
| 332 |
-
|
| 333 |
-
X_train, y_train, X_val, y_val = self.prepare_data(zone_readings)
|
| 334 |
-
|
| 335 |
-
# DataLoaders
|
| 336 |
-
train_ds = TensorDataset(X_train, y_train)
|
| 337 |
-
val_ds = TensorDataset(X_val, y_val)
|
| 338 |
-
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
|
| 339 |
-
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False)
|
| 340 |
-
|
| 341 |
-
# Model, loss, optimizer
|
| 342 |
-
model = self._custom_model if self._custom_model is not None else HeatLSTM()
|
| 343 |
-
criterion = nn.BCELoss()
|
| 344 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
|
| 345 |
-
|
| 346 |
-
total_params = sum(p.numel() for p in model.parameters())
|
| 347 |
-
print(f" Model params: {total_params:,}")
|
| 348 |
-
|
| 349 |
-
# Training loop with early stopping
|
| 350 |
-
best_val_loss = float("inf")
|
| 351 |
-
patience_counter = 0
|
| 352 |
-
best_state = None
|
| 353 |
-
best_metrics: dict = {}
|
| 354 |
-
|
| 355 |
-
for epoch in range(self.epochs):
|
| 356 |
-
# Train
|
| 357 |
-
model.train()
|
| 358 |
-
train_loss, train_total = 0.0, 0
|
| 359 |
-
for xb, yb in train_loader:
|
| 360 |
-
optimizer.zero_grad()
|
| 361 |
-
preds = model(xb)
|
| 362 |
-
loss = criterion(preds, yb)
|
| 363 |
-
loss.backward()
|
| 364 |
-
optimizer.step()
|
| 365 |
-
train_loss += loss.item() * len(xb)
|
| 366 |
-
train_total += len(xb)
|
| 367 |
-
|
| 368 |
-
# Validate
|
| 369 |
-
model.eval()
|
| 370 |
-
val_loss, val_total = 0.0, 0
|
| 371 |
-
all_val_preds, all_val_labels = [], []
|
| 372 |
-
with torch.no_grad():
|
| 373 |
-
for xb, yb in val_loader:
|
| 374 |
-
preds = model(xb)
|
| 375 |
-
loss = criterion(preds, yb)
|
| 376 |
-
val_loss += loss.item() * len(xb)
|
| 377 |
-
val_total += len(xb)
|
| 378 |
-
all_val_preds.extend(preds.numpy().tolist())
|
| 379 |
-
all_val_labels.extend(yb.numpy().tolist())
|
| 380 |
-
|
| 381 |
-
avg_train_loss = train_loss / max(train_total, 1)
|
| 382 |
-
avg_val_loss = val_loss / max(val_total, 1)
|
| 383 |
-
val_auroc = _compute_auroc(all_val_labels, all_val_preds)
|
| 384 |
-
|
| 385 |
-
if (epoch + 1) % 5 == 0 or epoch == 0:
|
| 386 |
-
print(
|
| 387 |
-
f" Epoch {epoch + 1:>2}: "
|
| 388 |
-
f"train_loss={avg_train_loss:.4f} | "
|
| 389 |
-
f"val_loss={avg_val_loss:.4f} val_auroc={val_auroc:.3f}"
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
# Early stopping
|
| 393 |
-
if avg_val_loss < best_val_loss:
|
| 394 |
-
best_val_loss = avg_val_loss
|
| 395 |
-
patience_counter = 0
|
| 396 |
-
best_state = {k: v.clone() for k, v in model.state_dict().items()}
|
| 397 |
-
best_metrics = {
|
| 398 |
-
"train_loss": round(avg_train_loss, 4),
|
| 399 |
-
"val_loss": round(avg_val_loss, 4),
|
| 400 |
-
"val_auroc": round(val_auroc, 4),
|
| 401 |
-
"epochs_trained": epoch + 1,
|
| 402 |
-
"samples": {
|
| 403 |
-
"train": len(X_train),
|
| 404 |
-
"val": len(X_val),
|
| 405 |
-
},
|
| 406 |
-
}
|
| 407 |
-
else:
|
| 408 |
-
patience_counter += 1
|
| 409 |
-
if patience_counter >= self.patience:
|
| 410 |
-
print(f" Early stopping at epoch {epoch + 1}")
|
| 411 |
-
break
|
| 412 |
-
|
| 413 |
-
# Save best model
|
| 414 |
-
if best_state is not None:
|
| 415 |
-
model.load_state_dict(best_state)
|
| 416 |
-
self.model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 417 |
-
torch.save(model.state_dict(), self.model_path)
|
| 418 |
-
|
| 419 |
-
file_size = self.model_path.stat().st_size
|
| 420 |
-
print(f" Saved model to {self.model_path} ({file_size / 1024:.1f} KB)")
|
| 421 |
-
|
| 422 |
-
return best_metrics
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
# ======================================================================
|
| 426 |
-
# Predictor (inference)
|
| 427 |
-
# ======================================================================
|
| 428 |
-
|
| 429 |
-
class LSTMPredictor:
|
| 430 |
-
"""Load trained HeatLSTM and predict trigger probability."""
|
| 431 |
-
|
| 432 |
-
def __init__(
|
| 433 |
-
self,
|
| 434 |
-
model_path: str = "models/heat_lstm.pt",
|
| 435 |
-
norm_path: str = "models/lstm_norm.json",
|
| 436 |
-
):
|
| 437 |
-
if not TORCH_AVAILABLE:
|
| 438 |
-
raise ImportError("torch is required. pip install torch")
|
| 439 |
-
self.model_path = _resolve_path(model_path)
|
| 440 |
-
self.norm_path = _resolve_path(norm_path)
|
| 441 |
-
self.model: HeatLSTM | None = None
|
| 442 |
-
self._norm: dict | None = None
|
| 443 |
-
self._load_model()
|
| 444 |
-
|
| 445 |
-
def _load_model(self) -> None:
|
| 446 |
-
"""Load model weights and normalization params from disk."""
|
| 447 |
-
mp = self.model_path
|
| 448 |
-
np_ = self.norm_path
|
| 449 |
-
|
| 450 |
-
if not mp.exists():
|
| 451 |
-
raise FileNotFoundError(
|
| 452 |
-
f"LSTM model not found at {mp}. "
|
| 453 |
-
"Run scripts/train_lstm.py first."
|
| 454 |
-
)
|
| 455 |
-
if not np_.exists():
|
| 456 |
-
raise FileNotFoundError(
|
| 457 |
-
f"LSTM normalization params not found at {np_}. "
|
| 458 |
-
"Run scripts/train_lstm.py first."
|
| 459 |
-
)
|
| 460 |
-
|
| 461 |
-
self.model = HeatLSTM()
|
| 462 |
-
self.model.load_state_dict(
|
| 463 |
-
torch.load(mp, map_location="cpu", weights_only=True)
|
| 464 |
-
)
|
| 465 |
-
self.model.eval()
|
| 466 |
-
|
| 467 |
-
with open(np_) as f:
|
| 468 |
-
self._norm = json.load(f)
|
| 469 |
-
|
| 470 |
-
def predict(self, recent_14_days: list[dict]) -> tuple[float, float]:
|
| 471 |
-
"""Predict trigger probability from last 14 days of data.
|
| 472 |
-
|
| 473 |
-
Args:
|
| 474 |
-
recent_14_days: list of dicts with keys:
|
| 475 |
-
temp_max_c, humidity_pct, wind_speed_ms
|
| 476 |
-
(WBGT, heat index, and temp anomaly are computed internally)
|
| 477 |
-
|
| 478 |
-
Returns:
|
| 479 |
-
(probability, confidence) where:
|
| 480 |
-
- probability: 0-1 trigger probability
|
| 481 |
-
- confidence: 0-1 based on MC dropout (5 forward passes)
|
| 482 |
-
"""
|
| 483 |
-
if len(recent_14_days) < 14:
|
| 484 |
-
pad = [recent_14_days[0]] * (14 - len(recent_14_days))
|
| 485 |
-
recent_14_days = pad + recent_14_days
|
| 486 |
-
|
| 487 |
-
days = recent_14_days[-14:]
|
| 488 |
-
|
| 489 |
-
# Extract temps for anomaly computation
|
| 490 |
-
temps = [d.get("temp_max_c", d.get("temp_c", 30.0)) for d in days]
|
| 491 |
-
|
| 492 |
-
# Build feature array: compute derived features
|
| 493 |
-
seq = []
|
| 494 |
-
for i, d in enumerate(days):
|
| 495 |
-
t = d.get("temp_max_c", d.get("temp_c", 30.0))
|
| 496 |
-
h = d.get("humidity_pct", 65.0)
|
| 497 |
-
w = d.get("wind_speed_ms", 3.0)
|
| 498 |
-
wbgt = d.get("wbgt_c", calculate_wbgt(t, h))
|
| 499 |
-
hi = d.get("heat_index_c", calculate_heat_index(t, h))
|
| 500 |
-
anomaly = _compute_temp_anomaly(temps, i)
|
| 501 |
-
seq.append([t, h, w, wbgt, hi, anomaly])
|
| 502 |
-
|
| 503 |
-
x = np.array(seq, dtype=np.float32)
|
| 504 |
-
|
| 505 |
-
# Normalize using saved params
|
| 506 |
-
mean = np.array(self._norm["mean"], dtype=np.float32)
|
| 507 |
-
std = np.array(self._norm["std"], dtype=np.float32)
|
| 508 |
-
x = (x - mean) / std
|
| 509 |
-
|
| 510 |
-
x_tensor = torch.from_numpy(x).unsqueeze(0) # (1, 14, 6)
|
| 511 |
-
|
| 512 |
-
# MC Dropout: 5 forward passes batched with dropout enabled
|
| 513 |
-
self.model.train()
|
| 514 |
-
x_batch = x_tensor.expand(5, -1, -1) # (5, 14, 6)
|
| 515 |
-
with torch.no_grad():
|
| 516 |
-
preds = self.model(x_batch).numpy()
|
| 517 |
-
self.model.eval()
|
| 518 |
-
|
| 519 |
-
probability = float(np.mean(preds))
|
| 520 |
-
std_val = float(np.std(preds))
|
| 521 |
-
confidence = max(0.3, min(0.95, 1.0 - std_val * 3))
|
| 522 |
-
|
| 523 |
-
probability = float(np.clip(probability, 0.0, 1.0))
|
| 524 |
-
|
| 525 |
-
return probability, confidence
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
# ======================================================================
|
| 529 |
-
# Utilities
|
| 530 |
-
# ======================================================================
|
| 531 |
-
|
| 532 |
-
def _compute_auroc(labels: list[float], preds: list[float]) -> float:
|
| 533 |
-
"""Compute AUROC using sklearn if available, else trapezoidal fallback."""
|
| 534 |
-
if len(set(labels)) < 2:
|
| 535 |
-
return 0.5
|
| 536 |
-
|
| 537 |
-
try:
|
| 538 |
-
from sklearn.metrics import roc_auc_score
|
| 539 |
-
return float(roc_auc_score(labels, preds))
|
| 540 |
-
except ImportError:
|
| 541 |
-
pass
|
| 542 |
-
|
| 543 |
-
# Fallback: trapezoidal AUROC
|
| 544 |
-
pairs = sorted(zip(preds, labels), key=lambda x: -x[0])
|
| 545 |
-
tp, fp = 0, 0
|
| 546 |
-
tp_prev, fp_prev = 0, 0
|
| 547 |
-
auc = 0.0
|
| 548 |
-
n_pos = sum(labels)
|
| 549 |
-
n_neg = len(labels) - n_pos
|
| 550 |
-
|
| 551 |
-
if n_pos == 0 or n_neg == 0:
|
| 552 |
-
return 0.5
|
| 553 |
-
|
| 554 |
-
prev_score = None
|
| 555 |
-
for score, label in pairs:
|
| 556 |
-
if score != prev_score and prev_score is not None:
|
| 557 |
-
auc += (fp - fp_prev) * (tp + tp_prev) / 2.0
|
| 558 |
-
tp_prev, fp_prev = tp, fp
|
| 559 |
-
if label == 1.0:
|
| 560 |
-
tp += 1
|
| 561 |
-
else:
|
| 562 |
-
fp += 1
|
| 563 |
-
prev_score = score
|
| 564 |
-
|
| 565 |
-
auc += (fp - fp_prev) * (tp + tp_prev) / 2.0
|
| 566 |
-
return auc / (n_pos * n_neg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,1312 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Neural Actuarial Pricing Engine for Parametric Heat Insurance.
|
| 3 |
-
|
| 4 |
-
Three-headed temporal neural model trained on real climate data:
|
| 5 |
-
1. HazardHead (Neural EVT): Learns extreme heat frequency + severity
|
| 6 |
-
distribution using Generalized Pareto Distribution parameters
|
| 7 |
-
2. VulnerabilityHead: Learns zone-specific worker impact calibrated
|
| 8 |
-
to WHO/ILO occupational heat stress guidelines (ISO 7243)
|
| 9 |
-
3. PricingHead (CANN): Combined Actuarial Neural Network — GLM baseline
|
| 10 |
-
with bounded neural correction. When δ_NN = 0, reproduces the
|
| 11 |
-
existing ActuarialPricer formula exactly.
|
| 12 |
-
|
| 13 |
-
City-specific: trained per city (default: Dar es Salaam). Architecture
|
| 14 |
-
is portable; weights are local. Retrain for any city with:
|
| 15 |
-
python3 scripts/train_neural_pricer.py --city "Kampala"
|
| 16 |
-
|
| 17 |
-
References:
|
| 18 |
-
- Chen et al. (2024) "Managing Weather Risk with NN-Based Index Insurance"
|
| 19 |
-
Management Science 70(7), 4306-4327
|
| 20 |
-
- Pasche & Engelke (2024) "Neural Networks for Extreme Quantile Regression"
|
| 21 |
-
arXiv:2208.07590 (EQRN)
|
| 22 |
-
- Richman & Wuthrich (2023) "LocalGLMnet" Scandinavian Actuarial Journal
|
| 23 |
-
- ISO 7243 / NIOSH occupational heat stress thresholds
|
| 24 |
-
"""
|
| 25 |
-
|
| 26 |
-
from __future__ import annotations
|
| 27 |
-
|
| 28 |
-
import json
|
| 29 |
-
import logging
|
| 30 |
-
from pathlib import Path
|
| 31 |
-
from typing import Optional
|
| 32 |
-
|
| 33 |
-
import numpy as np
|
| 34 |
-
|
| 35 |
-
try:
|
| 36 |
-
import torch
|
| 37 |
-
import torch.nn as nn
|
| 38 |
-
import torch.nn.functional as F
|
| 39 |
-
TORCH_AVAILABLE = True
|
| 40 |
-
except ImportError:
|
| 41 |
-
TORCH_AVAILABLE = False
|
| 42 |
-
|
| 43 |
-
# Chronos foundation model (optional — used for Chronos encoder path)
|
| 44 |
-
try:
|
| 45 |
-
from chronos import ChronosBoltPipeline
|
| 46 |
-
CHRONOS_AVAILABLE = True
|
| 47 |
-
except ImportError:
|
| 48 |
-
CHRONOS_AVAILABLE = False
|
| 49 |
-
|
| 50 |
-
from src.indexing.heat_index import calculate_wbgt
|
| 51 |
-
from src.pricing.actuarial import ActuarialPricer, ActuarialResult
|
| 52 |
-
|
| 53 |
-
log = logging.getLogger(__name__)
|
| 54 |
-
|
| 55 |
-
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
| 56 |
-
|
| 57 |
-
# ── City-specific trigger thresholds (WBGT, °C) ──────────────────────────
|
| 58 |
-
# Based on ILO occupational heat stress measurements
|
| 59 |
-
# Actuarial trigger thresholds: set at ~P85-P90 of each city's WBGT distribution
|
| 60 |
-
# so that only genuinely extreme heat events are trigger-worthy.
|
| 61 |
-
# These are HIGHER than occupational safety thresholds (which define "uncomfortable")
|
| 62 |
-
# because insurance triggers need to define "exceptional" events.
|
| 63 |
-
CITY_WBGT_THRESHOLDS = {
|
| 64 |
-
"Dar es Salaam": 35.1, # Calibrated P97 threshold (matches insurance benchmark)
|
| 65 |
-
"Kampala": 30.0,
|
| 66 |
-
"Nairobi": 28.0,
|
| 67 |
-
"Kigali": 29.0,
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
# Settlement-type-specific WBGT thresholds for compound triggers.
|
| 71 |
-
# UHI correction already differentiates zones (informal +2.5°C, formal +0.5°C),
|
| 72 |
-
# so all types use the city threshold. Zone frequency differences come from UHI.
|
| 73 |
-
SETTLEMENT_THRESHOLDS = {
|
| 74 |
-
"informal": 35.1,
|
| 75 |
-
"mixed": 35.1,
|
| 76 |
-
"formal": 35.1,
|
| 77 |
-
"commercial": 35.1,
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
# Minimum consecutive days above threshold to count as a heat EVENT.
|
| 81 |
-
# Alert tier at 2 days, payout tier at 5 days (matches insurance benchmark).
|
| 82 |
-
MIN_CONSECUTIVE_DAYS = 2
|
| 83 |
-
|
| 84 |
-
# ── WHO/ILO dose-response: WBGT → productivity loss ──────────────────────
|
| 85 |
-
# ISO 7243 thresholds for moderate-to-heavy outdoor work
|
| 86 |
-
def who_productivity_loss(wbgt: float) -> float:
|
| 87 |
-
"""Fractional productivity loss given WBGT (°C). ISO 7243 calibrated."""
|
| 88 |
-
if wbgt < 26.0:
|
| 89 |
-
return 0.0
|
| 90 |
-
elif wbgt < 28.0:
|
| 91 |
-
return 0.10
|
| 92 |
-
elif wbgt < 30.0:
|
| 93 |
-
return 0.25
|
| 94 |
-
elif wbgt < 32.0:
|
| 95 |
-
return 0.50
|
| 96 |
-
elif wbgt < 35.0:
|
| 97 |
-
return 0.75
|
| 98 |
-
else:
|
| 99 |
-
return 1.0 # work must stop
|
| 100 |
-
|
| 101 |
-
# ── Settlement-type UHI ranges (from literature) ─────────────────────────
|
| 102 |
-
from src.downscaling.uhi_model import UHI_RANGES
|
| 103 |
-
|
| 104 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 105 |
-
# PyTorch Model Components
|
| 106 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 107 |
-
|
| 108 |
-
if TORCH_AVAILABLE:
|
| 109 |
-
|
| 110 |
-
class TemporalEncoder(nn.Module):
|
| 111 |
-
"""
|
| 112 |
-
LSTM encoder for 90-day climate sequences.
|
| 113 |
-
|
| 114 |
-
Input: (batch, 90, 11) — 7 climate + 4 zone-static features
|
| 115 |
-
Output: (batch, 128) — last hidden state
|
| 116 |
-
"""
|
| 117 |
-
|
| 118 |
-
def __init__(self, input_size: int = 11, hidden_size: int = 128,
|
| 119 |
-
num_layers: int = 2, dropout: float = 0.3):
|
| 120 |
-
super().__init__()
|
| 121 |
-
self.lstm = nn.LSTM(
|
| 122 |
-
input_size, hidden_size, num_layers,
|
| 123 |
-
batch_first=True, dropout=dropout,
|
| 124 |
-
)
|
| 125 |
-
self.norm = nn.LayerNorm(hidden_size)
|
| 126 |
-
|
| 127 |
-
def forward(self, x):
|
| 128 |
-
out, (h_n, _) = self.lstm(x)
|
| 129 |
-
# Use last layer's hidden state
|
| 130 |
-
latent = h_n[-1] # (batch, hidden_size)
|
| 131 |
-
return self.norm(latent)
|
| 132 |
-
|
| 133 |
-
class ChronosEncoder(nn.Module):
|
| 134 |
-
"""
|
| 135 |
-
Chronos-Bolt foundation model encoder for 90-day climate sequences.
|
| 136 |
-
|
| 137 |
-
Embeds the primary heat stress signal (WBGT, channel 6) via frozen
|
| 138 |
-
Chronos-Bolt-Tiny (9M params, pre-trained on 100B+ observations),
|
| 139 |
-
then concatenates climate summary stats and zone-static features
|
| 140 |
-
before projecting to the same 128-dim latent as TemporalEncoder.
|
| 141 |
-
|
| 142 |
-
Input: (batch, 90, 11) — same contract as TemporalEncoder
|
| 143 |
-
Output: (batch, 128) — same contract as TemporalEncoder
|
| 144 |
-
"""
|
| 145 |
-
|
| 146 |
-
WBGT_IDX = 6 # index of WBGT in 11-feature vector
|
| 147 |
-
N_CLIMATE = 7 # features 0-6 are climate
|
| 148 |
-
N_STATIC = 4 # features 7-10 are zone-static
|
| 149 |
-
|
| 150 |
-
def __init__(self, hidden_size: int = 128, chronos_d_model: int = 256):
|
| 151 |
-
super().__init__()
|
| 152 |
-
proj_input = chronos_d_model + self.N_CLIMATE + self.N_STATIC
|
| 153 |
-
self.proj = nn.Linear(proj_input, hidden_size)
|
| 154 |
-
self.act = nn.GELU()
|
| 155 |
-
self.norm = nn.LayerNorm(hidden_size)
|
| 156 |
-
self._chronos_d_model = chronos_d_model
|
| 157 |
-
# Set externally after construction — NOT part of state_dict
|
| 158 |
-
self._pipeline = None
|
| 159 |
-
self._feat_mean = None # numpy (11,) for un-normalizing WBGT
|
| 160 |
-
self._feat_std = None
|
| 161 |
-
|
| 162 |
-
def set_pipeline(self, pipeline, feat_mean=None, feat_std=None):
|
| 163 |
-
"""Attach the frozen Chronos pipeline (not a nn.Module)."""
|
| 164 |
-
self._pipeline = pipeline
|
| 165 |
-
self._feat_mean = feat_mean
|
| 166 |
-
self._feat_std = feat_std
|
| 167 |
-
|
| 168 |
-
def _unnorm_wbgt(self, x_norm):
|
| 169 |
-
"""Recover raw WBGT values from z-scored tensor for Chronos input."""
|
| 170 |
-
wbgt_norm = x_norm[:, :, self.WBGT_IDX] # (batch, 90)
|
| 171 |
-
if self._feat_mean is not None and self._feat_std is not None:
|
| 172 |
-
mu = self._feat_mean[self.WBGT_IDX]
|
| 173 |
-
sd = self._feat_std[self.WBGT_IDX]
|
| 174 |
-
return wbgt_norm * sd + mu # back to pre-norm scale (wbgt/40)
|
| 175 |
-
return wbgt_norm
|
| 176 |
-
|
| 177 |
-
def _embed_wbgt(self, x_norm):
|
| 178 |
-
"""Run Chronos .embed() on un-normalized WBGT and mean-pool."""
|
| 179 |
-
wbgt_scaled = self._unnorm_wbgt(x_norm) # (batch, 90), scale=wbgt/40
|
| 180 |
-
wbgt_raw = wbgt_scaled * 40.0 # raw °C for Chronos
|
| 181 |
-
with torch.no_grad():
|
| 182 |
-
emb, _ = self._pipeline.embed(wbgt_raw) # (batch, patches+1, 256)
|
| 183 |
-
return emb.mean(dim=1) # (batch, 256)
|
| 184 |
-
|
| 185 |
-
def forward(self, x, chronos_embeddings=None):
|
| 186 |
-
"""
|
| 187 |
-
Args:
|
| 188 |
-
x: (batch, 90, 11) normalized climate sequence
|
| 189 |
-
chronos_embeddings: optional (batch, d_model) pre-computed.
|
| 190 |
-
If None, computes from x via the attached pipeline.
|
| 191 |
-
"""
|
| 192 |
-
climate_means = x[:, :, :self.N_CLIMATE].mean(dim=1) # (batch, 7)
|
| 193 |
-
zone_static = x[:, 0, self.N_CLIMATE:] # (batch, 4)
|
| 194 |
-
|
| 195 |
-
if chronos_embeddings is None:
|
| 196 |
-
chronos_embeddings = self._embed_wbgt(x)
|
| 197 |
-
|
| 198 |
-
combined = torch.cat([chronos_embeddings, climate_means, zone_static], dim=-1)
|
| 199 |
-
return self.norm(self.act(self.proj(combined))) # (batch, 128)
|
| 200 |
-
|
| 201 |
-
class HazardHead(nn.Module):
|
| 202 |
-
"""
|
| 203 |
-
Neural Extreme Value Theory head + two-tier trigger decision.
|
| 204 |
-
|
| 205 |
-
Two tiers matching SEWA/Arsht-Rockefeller pilot design:
|
| 206 |
-
Alert tier: moderate heat event → cash transfer (philanthropy-funded)
|
| 207 |
-
Payout tier: severe sustained event → insurance payout (underwritten)
|
| 208 |
-
|
| 209 |
-
Outputs:
|
| 210 |
-
λ (events/year): softplus, range ~0.5-50
|
| 211 |
-
σ (GPD scale): softplus, range ~0.1-10
|
| 212 |
-
ξ (GPD shape): tanh × 0.4, range [-0.4, 0.4]
|
| 213 |
-
alert_prob: sigmoid → [0, 1], moderate event (cash tier)
|
| 214 |
-
payout_prob: sigmoid → [0, 1], severe event (insurance tier)
|
| 215 |
-
alert_severity: sigmoid → [0, 1], cash amount scaling
|
| 216 |
-
payout_severity: sigmoid → [0, 1], insurance payout scaling
|
| 217 |
-
"""
|
| 218 |
-
|
| 219 |
-
def __init__(self, input_size: int = 128):
|
| 220 |
-
super().__init__()
|
| 221 |
-
self.net = nn.Sequential(
|
| 222 |
-
nn.Linear(input_size, 64),
|
| 223 |
-
nn.GELU(),
|
| 224 |
-
nn.Linear(64, 32),
|
| 225 |
-
nn.GELU(),
|
| 226 |
-
)
|
| 227 |
-
self.lambda_head = nn.Linear(32, 1)
|
| 228 |
-
self.sigma_head = nn.Linear(32, 1)
|
| 229 |
-
self.xi_head = nn.Linear(32, 1)
|
| 230 |
-
self.alert_prob_head = nn.Linear(32, 1)
|
| 231 |
-
self.payout_prob_head = nn.Linear(32, 1)
|
| 232 |
-
self.alert_severity_head = nn.Linear(32, 1)
|
| 233 |
-
self.payout_severity_head = nn.Linear(32, 1)
|
| 234 |
-
|
| 235 |
-
def forward(self, h):
|
| 236 |
-
z = self.net(h)
|
| 237 |
-
lambda_ = F.softplus(self.lambda_head(z)) + 0.5
|
| 238 |
-
sigma = F.softplus(self.sigma_head(z)) + 0.1
|
| 239 |
-
xi = torch.tanh(self.xi_head(z)) * 0.4
|
| 240 |
-
alert_prob = torch.sigmoid(self.alert_prob_head(z))
|
| 241 |
-
payout_prob = torch.sigmoid(self.payout_prob_head(z))
|
| 242 |
-
alert_severity = torch.sigmoid(self.alert_severity_head(z))
|
| 243 |
-
payout_severity = torch.sigmoid(self.payout_severity_head(z))
|
| 244 |
-
# Pack into trigger_prob and payout_factor for backward compat
|
| 245 |
-
# trigger_prob = alert_prob (most frequent), payout_factor = payout_severity
|
| 246 |
-
return lambda_, sigma, xi, alert_prob, payout_prob, alert_severity, payout_severity
|
| 247 |
-
|
| 248 |
-
class VulnerabilityHead(nn.Module):
|
| 249 |
-
"""
|
| 250 |
-
Neural loss distribution head.
|
| 251 |
-
|
| 252 |
-
Learns zone-specific worker impact from hazard characteristics.
|
| 253 |
-
Calibrated to WHO/ILO ISO 7243 dose-response.
|
| 254 |
-
|
| 255 |
-
Outputs:
|
| 256 |
-
productivity_loss: sigmoid → [0, 1]
|
| 257 |
-
basis_risk: sigmoid → [0, 1]
|
| 258 |
-
severity_multiplier: softplus + 0.5 → [0.5, ∞)
|
| 259 |
-
"""
|
| 260 |
-
|
| 261 |
-
def __init__(self, input_size: int = 135): # 128 + 7 hazard outputs
|
| 262 |
-
super().__init__()
|
| 263 |
-
self.net = nn.Sequential(
|
| 264 |
-
nn.Linear(input_size, 64),
|
| 265 |
-
nn.GELU(),
|
| 266 |
-
nn.Linear(64, 32),
|
| 267 |
-
nn.GELU(),
|
| 268 |
-
)
|
| 269 |
-
self.prod_loss_head = nn.Linear(32, 1)
|
| 270 |
-
self.basis_risk_head = nn.Linear(32, 1)
|
| 271 |
-
self.severity_head = nn.Linear(32, 1)
|
| 272 |
-
|
| 273 |
-
def forward(self, h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev):
|
| 274 |
-
combined = torch.cat([h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev], dim=-1)
|
| 275 |
-
z = self.net(combined)
|
| 276 |
-
prod_loss = torch.sigmoid(self.prod_loss_head(z))
|
| 277 |
-
basis_risk = torch.sigmoid(self.basis_risk_head(z))
|
| 278 |
-
severity = F.softplus(self.severity_head(z)) + 0.5
|
| 279 |
-
return prod_loss, basis_risk, severity
|
| 280 |
-
|
| 281 |
-
class PricingHead(nn.Module):
|
| 282 |
-
"""
|
| 283 |
-
CANN (Combined Actuarial Neural Network) pricing head.
|
| 284 |
-
|
| 285 |
-
Skip connection: price = exp(η_GLM + δ_NN) × inflation_buffer
|
| 286 |
-
|
| 287 |
-
When δ_NN = 0, reproduces ActuarialPricer formula exactly.
|
| 288 |
-
δ_NN is clamped to [-0.5, 0.5] for safety (±50% max correction).
|
| 289 |
-
"""
|
| 290 |
-
|
| 291 |
-
def __init__(self, input_size: int = 138, max_delta: float = 0.5): # 128 + 7 hazard + 3 vuln
|
| 292 |
-
super().__init__()
|
| 293 |
-
self.max_delta = max_delta
|
| 294 |
-
self.net = nn.Sequential(
|
| 295 |
-
nn.Linear(input_size, 32),
|
| 296 |
-
nn.GELU(),
|
| 297 |
-
nn.Linear(32, 1),
|
| 298 |
-
)
|
| 299 |
-
# Initialize near zero so model starts at GLM baseline
|
| 300 |
-
nn.init.zeros_(self.net[-1].weight)
|
| 301 |
-
nn.init.zeros_(self.net[-1].bias)
|
| 302 |
-
|
| 303 |
-
def forward(self, h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev, prod_loss, basis_risk, severity):
|
| 304 |
-
combined = torch.cat([
|
| 305 |
-
h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev, prod_loss, basis_risk, severity
|
| 306 |
-
], dim=-1)
|
| 307 |
-
delta_nn = self.net(combined)
|
| 308 |
-
delta_nn = torch.clamp(delta_nn, -self.max_delta, self.max_delta)
|
| 309 |
-
return delta_nn
|
| 310 |
-
|
| 311 |
-
class HeatRiskNeuralPricer(nn.Module):
|
| 312 |
-
"""
|
| 313 |
-
Full neural actuarial pricing model.
|
| 314 |
-
|
| 315 |
-
Composes TemporalEncoder + HazardHead + VulnerabilityHead + PricingHead.
|
| 316 |
-
~70K parameters. Trains in ~15 min on CPU.
|
| 317 |
-
"""
|
| 318 |
-
|
| 319 |
-
def __init__(self, input_size: int = 11, hidden_size: int = 128):
|
| 320 |
-
super().__init__()
|
| 321 |
-
self.encoder = TemporalEncoder(input_size, hidden_size)
|
| 322 |
-
self.hazard = HazardHead(hidden_size)
|
| 323 |
-
self.vulnerability = VulnerabilityHead(hidden_size + 3)
|
| 324 |
-
self.pricing = PricingHead(hidden_size + 6)
|
| 325 |
-
|
| 326 |
-
def forward(self, x, payout_per_event: float = 10.0,
|
| 327 |
-
admin_rate: float = 0.15):
|
| 328 |
-
"""
|
| 329 |
-
Args:
|
| 330 |
-
x: (batch, 90, 11) — daily climate sequence with zone features
|
| 331 |
-
payout_per_event: USD per event per worker
|
| 332 |
-
admin_rate: operational overhead rate
|
| 333 |
-
|
| 334 |
-
Returns:
|
| 335 |
-
dict with all intermediate and final outputs
|
| 336 |
-
"""
|
| 337 |
-
# Encode temporal sequence
|
| 338 |
-
h = self.encoder(x) # (batch, 128)
|
| 339 |
-
|
| 340 |
-
# Hazard: frequency + severity + two-tier trigger
|
| 341 |
-
lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev = self.hazard(h)
|
| 342 |
-
|
| 343 |
-
# Vulnerability: worker impact
|
| 344 |
-
prod_loss, basis_risk, severity = self.vulnerability(
|
| 345 |
-
h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev
|
| 346 |
-
)
|
| 347 |
-
|
| 348 |
-
# GPD expected severity
|
| 349 |
-
xi_safe = torch.clamp(xi, max=0.39)
|
| 350 |
-
expected_severity = sigma / (1.0 - xi_safe)
|
| 351 |
-
|
| 352 |
-
# GLM baseline
|
| 353 |
-
base_cost = lambda_ * payout_per_event
|
| 354 |
-
basis_loading = base_cost * (basis_risk * 0.5)
|
| 355 |
-
vuln_loading = base_cost * (prod_loss * 0.2)
|
| 356 |
-
subtotal = base_cost + basis_loading + vuln_loading
|
| 357 |
-
admin_loading = subtotal * admin_rate
|
| 358 |
-
glm_total = subtotal + admin_loading
|
| 359 |
-
eta_glm = torch.log(glm_total + 1e-8)
|
| 360 |
-
|
| 361 |
-
# CANN: neural correction
|
| 362 |
-
delta_nn = self.pricing(
|
| 363 |
-
h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev,
|
| 364 |
-
prod_loss, basis_risk, severity
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
total_per_worker = torch.exp(eta_glm + delta_nn) * 1.05
|
| 368 |
-
|
| 369 |
-
return {
|
| 370 |
-
"lambda_": lambda_.squeeze(-1),
|
| 371 |
-
"sigma": sigma.squeeze(-1),
|
| 372 |
-
"xi": xi.squeeze(-1),
|
| 373 |
-
"alert_prob": alert_prob.squeeze(-1),
|
| 374 |
-
"payout_prob": payout_prob.squeeze(-1),
|
| 375 |
-
"alert_severity": alert_sev.squeeze(-1),
|
| 376 |
-
"payout_severity": payout_sev.squeeze(-1),
|
| 377 |
-
"productivity_loss": prod_loss.squeeze(-1),
|
| 378 |
-
"basis_risk": basis_risk.squeeze(-1),
|
| 379 |
-
"severity_multiplier": severity.squeeze(-1),
|
| 380 |
-
"expected_severity": expected_severity.squeeze(-1),
|
| 381 |
-
"delta_nn": delta_nn.squeeze(-1),
|
| 382 |
-
"glm_price": (glm_total * 1.05).squeeze(-1),
|
| 383 |
-
"total_per_worker": total_per_worker.squeeze(-1),
|
| 384 |
-
"base_cost": base_cost.squeeze(-1),
|
| 385 |
-
"basis_loading": basis_loading.squeeze(-1),
|
| 386 |
-
"vuln_loading": vuln_loading.squeeze(-1),
|
| 387 |
-
"admin_loading": admin_loading.squeeze(-1),
|
| 388 |
-
}
|
| 389 |
-
|
| 390 |
-
class HeatRiskNeuralPricerChronos(nn.Module):
|
| 391 |
-
"""
|
| 392 |
-
Chronos-enhanced neural actuarial pricing model.
|
| 393 |
-
|
| 394 |
-
Same 3-head architecture as HeatRiskNeuralPricer but replaces the
|
| 395 |
-
LSTM TemporalEncoder with frozen Chronos-Bolt-Tiny embeddings.
|
| 396 |
-
Only the projection layer + heads are trainable (~50K params).
|
| 397 |
-
|
| 398 |
-
The Chronos foundation model (9M params, pre-trained on 100B+
|
| 399 |
-
time-series observations) captures deep temporal patterns in the
|
| 400 |
-
WBGT heat stress signal, while climate summary stats and zone
|
| 401 |
-
features provide domain context.
|
| 402 |
-
"""
|
| 403 |
-
|
| 404 |
-
def __init__(self, hidden_size: int = 128, chronos_d_model: int = 256):
|
| 405 |
-
super().__init__()
|
| 406 |
-
self.encoder = ChronosEncoder(hidden_size, chronos_d_model)
|
| 407 |
-
self.hazard = HazardHead(hidden_size)
|
| 408 |
-
self.vulnerability = VulnerabilityHead(hidden_size + 7) # + 7 hazard outputs
|
| 409 |
-
self.pricing = PricingHead(hidden_size + 10, max_delta=1.5) # + 7 hazard + 3 vuln
|
| 410 |
-
|
| 411 |
-
def forward(self, x, payout_per_event: float = 10.0,
|
| 412 |
-
admin_rate: float = 0.15, chronos_embeddings=None):
|
| 413 |
-
h = self.encoder(x, chronos_embeddings=chronos_embeddings)
|
| 414 |
-
|
| 415 |
-
lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev = self.hazard(h)
|
| 416 |
-
prod_loss, basis_risk, severity = self.vulnerability(
|
| 417 |
-
h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
-
xi_safe = torch.clamp(xi, max=0.39)
|
| 421 |
-
expected_severity = sigma / (1.0 - xi_safe)
|
| 422 |
-
|
| 423 |
-
base_cost = lambda_ * payout_per_event
|
| 424 |
-
basis_loading = base_cost * (basis_risk * 0.5)
|
| 425 |
-
vuln_loading = base_cost * (prod_loss * 0.2)
|
| 426 |
-
subtotal = base_cost + basis_loading + vuln_loading
|
| 427 |
-
admin_loading = subtotal * admin_rate
|
| 428 |
-
glm_total = subtotal + admin_loading
|
| 429 |
-
eta_glm = torch.log(glm_total + 1e-8)
|
| 430 |
-
|
| 431 |
-
delta_nn = self.pricing(
|
| 432 |
-
h, lambda_, sigma, xi, alert_prob, payout_prob, alert_sev, payout_sev,
|
| 433 |
-
prod_loss, basis_risk, severity
|
| 434 |
-
)
|
| 435 |
-
|
| 436 |
-
total_per_worker = torch.exp(eta_glm + delta_nn) * 1.05
|
| 437 |
-
|
| 438 |
-
return {
|
| 439 |
-
"lambda_": lambda_.squeeze(-1),
|
| 440 |
-
"sigma": sigma.squeeze(-1),
|
| 441 |
-
"xi": xi.squeeze(-1),
|
| 442 |
-
"alert_prob": alert_prob.squeeze(-1),
|
| 443 |
-
"payout_prob": payout_prob.squeeze(-1),
|
| 444 |
-
"alert_severity": alert_sev.squeeze(-1),
|
| 445 |
-
"payout_severity": payout_sev.squeeze(-1),
|
| 446 |
-
"productivity_loss": prod_loss.squeeze(-1),
|
| 447 |
-
"basis_risk": basis_risk.squeeze(-1),
|
| 448 |
-
"severity_multiplier": severity.squeeze(-1),
|
| 449 |
-
"expected_severity": expected_severity.squeeze(-1),
|
| 450 |
-
"delta_nn": delta_nn.squeeze(-1),
|
| 451 |
-
"glm_price": (glm_total * 1.05).squeeze(-1),
|
| 452 |
-
"total_per_worker": total_per_worker.squeeze(-1),
|
| 453 |
-
"base_cost": base_cost.squeeze(-1),
|
| 454 |
-
"basis_loading": basis_loading.squeeze(-1),
|
| 455 |
-
"vuln_loading": vuln_loading.squeeze(-1),
|
| 456 |
-
"admin_loading": admin_loading.squeeze(-1),
|
| 457 |
-
}
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 461 |
-
# Training Data Generation
|
| 462 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 463 |
-
|
| 464 |
-
def load_climate_data(data_path: str | Path) -> dict[str, list[dict]]:
|
| 465 |
-
"""Load NASA POWER or ERA5-Land daily data from JSON."""
|
| 466 |
-
with open(data_path) as f:
|
| 467 |
-
return json.load(f)
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
def build_training_samples(
|
| 471 |
-
climate_data: dict[str, list[dict]],
|
| 472 |
-
zones: list,
|
| 473 |
-
window_size: int = 90,
|
| 474 |
-
stride: int = 7,
|
| 475 |
-
wbgt_threshold: float = 28.0,
|
| 476 |
-
) -> tuple[np.ndarray, dict[str, np.ndarray]]:
|
| 477 |
-
"""
|
| 478 |
-
Build training dataset from daily climate records.
|
| 479 |
-
|
| 480 |
-
Creates sliding windows of `window_size` days, with targets computed
|
| 481 |
-
from the window's heat characteristics.
|
| 482 |
-
|
| 483 |
-
Returns:
|
| 484 |
-
X: (N, window_size, 11) — feature sequences
|
| 485 |
-
targets: dict of target arrays, each shape (N,)
|
| 486 |
-
"""
|
| 487 |
-
zone_map = {z.zone_id: z for z in zones}
|
| 488 |
-
|
| 489 |
-
all_X = []
|
| 490 |
-
all_targets = {
|
| 491 |
-
"frequency": [],
|
| 492 |
-
"gpd_sigma": [],
|
| 493 |
-
"gpd_xi": [],
|
| 494 |
-
"productivity_loss": [],
|
| 495 |
-
"basis_risk": [],
|
| 496 |
-
"severity_multiplier": [],
|
| 497 |
-
"price_target": [],
|
| 498 |
-
"alert_event": [], # 1 if moderate heat event (cash tier)
|
| 499 |
-
"payout_event": [], # 1 if severe sustained event (insurance tier)
|
| 500 |
-
"alert_severity": [], # 0-1 cash amount scaling
|
| 501 |
-
"payout_severity": [], # 0-1 insurance payout scaling
|
| 502 |
-
}
|
| 503 |
-
|
| 504 |
-
for zone_id, records in climate_data.items():
|
| 505 |
-
zone = zone_map.get(zone_id)
|
| 506 |
-
if zone is None:
|
| 507 |
-
continue
|
| 508 |
-
|
| 509 |
-
# Encode zone static features
|
| 510 |
-
# Settlement type is NOT included — the model should learn zone
|
| 511 |
-
# differences from UHI-corrected temperatures in the sequence,
|
| 512 |
-
# not from a categorical label that creates spurious correlations.
|
| 513 |
-
vuln_enc = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 514 |
-
zone_static = [
|
| 515 |
-
vuln_enc.get(zone.heat_vulnerability, 0.5),
|
| 516 |
-
zone.outdoor_exposure_pct,
|
| 517 |
-
zone.elevation_m / 2000.0,
|
| 518 |
-
0.0, # padding to keep feature dim at 11
|
| 519 |
-
]
|
| 520 |
-
|
| 521 |
-
# UHI parameters for this zone
|
| 522 |
-
uhi_lo, uhi_hi = UHI_RANGES.get(zone.settlement_type, (1.0, 2.0))
|
| 523 |
-
mean_uhi = (uhi_lo + uhi_hi) / 2.0
|
| 524 |
-
|
| 525 |
-
n = len(records)
|
| 526 |
-
if n < window_size + 30:
|
| 527 |
-
continue
|
| 528 |
-
|
| 529 |
-
for start in range(0, n - window_size - 7, stride):
|
| 530 |
-
window = records[start:start + window_size]
|
| 531 |
-
|
| 532 |
-
# Apply UHI correction so each zone sees different temps
|
| 533 |
-
# even when the underlying grid data is the same
|
| 534 |
-
rng_uhi = np.random.RandomState(hash(zone_id) & 0x7FFFFFFF)
|
| 535 |
-
uhi_noise_std = (uhi_hi - uhi_lo) / 4.0 # small daily variation
|
| 536 |
-
|
| 537 |
-
seq = []
|
| 538 |
-
wbgts = [] # UHI-corrected (what workers feel)
|
| 539 |
-
grid_wbgts = [] # raw grid (what satellite measures)
|
| 540 |
-
for i, day in enumerate(window):
|
| 541 |
-
t_max_grid = day.get("temp_max_c") or 30.0
|
| 542 |
-
t_min = day.get("temp_min_c") or 24.0
|
| 543 |
-
hum = day.get("humidity_pct") or 75.0
|
| 544 |
-
wind = day.get("wind_speed_ms") or 3.0
|
| 545 |
-
solar = day.get("solar_rad_wm2") or 200.0
|
| 546 |
-
precip = day.get("precip_mm") or 0.0
|
| 547 |
-
|
| 548 |
-
# Zone-specific UHI correction on temperature
|
| 549 |
-
uhi_delta = mean_uhi + rng_uhi.normal(0, uhi_noise_std)
|
| 550 |
-
t_max = t_max_grid + uhi_delta
|
| 551 |
-
|
| 552 |
-
wbgt = calculate_wbgt(t_max, hum)
|
| 553 |
-
wbgts.append(wbgt)
|
| 554 |
-
|
| 555 |
-
grid_wbgts.append(calculate_wbgt(t_max_grid, hum))
|
| 556 |
-
|
| 557 |
-
features = [
|
| 558 |
-
t_max / 40.0, # UHI-corrected temp (zone-specific!)
|
| 559 |
-
t_min / 30.0,
|
| 560 |
-
hum / 100.0,
|
| 561 |
-
wind / 10.0,
|
| 562 |
-
solar / 400.0,
|
| 563 |
-
precip / 50.0,
|
| 564 |
-
wbgt / 40.0, # UHI-corrected WBGT
|
| 565 |
-
] + zone_static
|
| 566 |
-
|
| 567 |
-
seq.append(features)
|
| 568 |
-
|
| 569 |
-
all_X.append(seq)
|
| 570 |
-
|
| 571 |
-
# ── Compute targets ──
|
| 572 |
-
|
| 573 |
-
# 1. Hazard: frequency from compound triggers (consecutive days above threshold)
|
| 574 |
-
zone_thresh = SETTLEMENT_THRESHOLDS.get(zone.settlement_type, wbgt_threshold)
|
| 575 |
-
run_length = 0
|
| 576 |
-
run_peak = 0.0
|
| 577 |
-
events = []
|
| 578 |
-
for w in wbgts:
|
| 579 |
-
if w > zone_thresh:
|
| 580 |
-
run_length += 1
|
| 581 |
-
run_peak = max(run_peak, w)
|
| 582 |
-
else:
|
| 583 |
-
if run_length >= MIN_CONSECUTIVE_DAYS:
|
| 584 |
-
events.append(run_peak - zone_thresh) # severity = peak excess
|
| 585 |
-
run_length = 0
|
| 586 |
-
run_peak = 0.0
|
| 587 |
-
if run_length >= MIN_CONSECUTIVE_DAYS:
|
| 588 |
-
events.append(run_peak - zone_thresh)
|
| 589 |
-
|
| 590 |
-
exceedances = events # for GPD fitting
|
| 591 |
-
frequency = len(events) * (365.0 / window_size) # annualize
|
| 592 |
-
|
| 593 |
-
if len(exceedances) >= 3:
|
| 594 |
-
try:
|
| 595 |
-
from scipy.stats import genpareto
|
| 596 |
-
xi_fit, _, sigma_fit = genpareto.fit(exceedances, floc=0)
|
| 597 |
-
xi_fit = max(-0.4, min(0.4, xi_fit))
|
| 598 |
-
sigma_fit = max(0.1, min(10.0, sigma_fit))
|
| 599 |
-
except Exception:
|
| 600 |
-
sigma_fit = float(np.std(exceedances)) + 0.5
|
| 601 |
-
xi_fit = 0.1
|
| 602 |
-
else:
|
| 603 |
-
sigma_fit = 1.0
|
| 604 |
-
xi_fit = 0.05
|
| 605 |
-
|
| 606 |
-
# 2. Vulnerability: WHO/ILO dose-response on UHI-corrected WBGT
|
| 607 |
-
prod_losses = [who_productivity_loss(w) for w in wbgts]
|
| 608 |
-
mean_prod_loss = float(np.mean(prod_losses)) * zone.outdoor_exposure_pct
|
| 609 |
-
|
| 610 |
-
# 3. Basis risk: gap between grid trigger and UHI-corrected trigger
|
| 611 |
-
grid_triggers = sum(1 for w in grid_wbgts if w > wbgt_threshold)
|
| 612 |
-
corrected_triggers = sum(1 for w in wbgts if w > wbgt_threshold)
|
| 613 |
-
if corrected_triggers > 0:
|
| 614 |
-
false_negative_rate = max(0, corrected_triggers - grid_triggers) / corrected_triggers
|
| 615 |
-
else:
|
| 616 |
-
false_negative_rate = 0.0
|
| 617 |
-
basis_risk_score = 0.3 * false_negative_rate + 0.2 * (mean_uhi / 6.0)
|
| 618 |
-
basis_risk_score = min(1.0, basis_risk_score + 0.05) # floor
|
| 619 |
-
|
| 620 |
-
# 4. Severity multiplier: ratio of corrected to grid impact
|
| 621 |
-
grid_impact = sum(who_productivity_loss(w) for w in wbgts)
|
| 622 |
-
corrected_impact = sum(who_productivity_loss(w) for w in wbgts)
|
| 623 |
-
severity_mult = (corrected_impact / max(grid_impact, 0.01))
|
| 624 |
-
severity_mult = max(0.5, min(3.0, severity_mult))
|
| 625 |
-
|
| 626 |
-
# 5. Price target from the real ActuarialPricer
|
| 627 |
-
from src.pricing.actuarial import ActuarialPricer
|
| 628 |
-
_glm = ActuarialPricer()
|
| 629 |
-
payout = 10.0
|
| 630 |
-
glm_result = _glm.price_zone(
|
| 631 |
-
zone=zone,
|
| 632 |
-
predicted_frequency=frequency,
|
| 633 |
-
basis_risk_score=basis_risk_score,
|
| 634 |
-
payout_per_event=payout,
|
| 635 |
-
enrolled=zone.worker_population_est,
|
| 636 |
-
)
|
| 637 |
-
price_target = glm_result.cost_per_worker_year
|
| 638 |
-
|
| 639 |
-
# 6. Two-tier event detection (matching insurance benchmark)
|
| 640 |
-
# Uses the same threshold as the parametric trigger — duration
|
| 641 |
-
# is the discriminator, not a separate severity gate.
|
| 642 |
-
TRIGGER_WBGT = wbgt_threshold # 35.1°C for Dar es Salaam
|
| 643 |
-
last_7 = wbgts[-7:]
|
| 644 |
-
vuln_mult = {"high": 1.5, "moderate": 1.0, "low": 0.7}
|
| 645 |
-
v_mult = vuln_mult.get(zone.heat_vulnerability, 1.0)
|
| 646 |
-
|
| 647 |
-
# Count consecutive days above trigger threshold at end of window
|
| 648 |
-
consec_at_end = 0
|
| 649 |
-
for w in reversed(last_7):
|
| 650 |
-
if w > TRIGGER_WBGT:
|
| 651 |
-
consec_at_end += 1
|
| 652 |
-
else:
|
| 653 |
-
break
|
| 654 |
-
peak_wbgt = max(last_7) if last_7 else 0
|
| 655 |
-
|
| 656 |
-
# Alert tier: 2+ consecutive days above threshold
|
| 657 |
-
# Workers get cash transfer + safety SMS
|
| 658 |
-
alert_event = 1.0 if consec_at_end >= 2 else 0.0
|
| 659 |
-
if alert_event > 0:
|
| 660 |
-
peak_excess = max(0, peak_wbgt - TRIGGER_WBGT)
|
| 661 |
-
alert_sev = min(1.0, (consec_at_end / 5.0) * (peak_excess / 4.0) * v_mult)
|
| 662 |
-
else:
|
| 663 |
-
alert_sev = 0.0
|
| 664 |
-
|
| 665 |
-
# Payout tier: 5+ consecutive days above threshold
|
| 666 |
-
# Workers get full insurance payout
|
| 667 |
-
payout_event = 1.0 if consec_at_end >= 5 else 0.0
|
| 668 |
-
if payout_event > 0:
|
| 669 |
-
peak_excess = max(0, peak_wbgt - TRIGGER_WBGT)
|
| 670 |
-
payout_sev = min(1.0, (consec_at_end / 7.0) * (peak_excess / 3.0) * v_mult)
|
| 671 |
-
else:
|
| 672 |
-
payout_sev = 0.0
|
| 673 |
-
|
| 674 |
-
all_targets["frequency"].append(frequency)
|
| 675 |
-
all_targets["gpd_sigma"].append(sigma_fit)
|
| 676 |
-
all_targets["gpd_xi"].append(xi_fit)
|
| 677 |
-
all_targets["productivity_loss"].append(mean_prod_loss)
|
| 678 |
-
all_targets["basis_risk"].append(basis_risk_score)
|
| 679 |
-
all_targets["severity_multiplier"].append(severity_mult)
|
| 680 |
-
all_targets["price_target"].append(price_target)
|
| 681 |
-
all_targets["alert_event"].append(alert_event)
|
| 682 |
-
all_targets["payout_event"].append(payout_event)
|
| 683 |
-
all_targets["alert_severity"].append(alert_sev)
|
| 684 |
-
all_targets["payout_severity"].append(payout_sev)
|
| 685 |
-
|
| 686 |
-
X = np.array(all_X, dtype=np.float32)
|
| 687 |
-
targets = {k: np.array(v, dtype=np.float32) for k, v in all_targets.items()}
|
| 688 |
-
|
| 689 |
-
return X, targets
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 693 |
-
# Trainer
|
| 694 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 695 |
-
|
| 696 |
-
class NeuralPricerTrainer:
|
| 697 |
-
"""Train the HeatRiskNeuralPricer (LSTM or Chronos encoder) on climate data."""
|
| 698 |
-
|
| 699 |
-
def __init__(self, lr: float = 1e-3, epochs: int = 80,
|
| 700 |
-
patience: int = 10, weight_decay: float = 1e-4,
|
| 701 |
-
encoder: str = "lstm"):
|
| 702 |
-
if not TORCH_AVAILABLE:
|
| 703 |
-
raise ImportError("torch is required")
|
| 704 |
-
self.lr = lr
|
| 705 |
-
self.epochs = epochs
|
| 706 |
-
self.patience = patience
|
| 707 |
-
self.weight_decay = weight_decay
|
| 708 |
-
self.encoder = encoder
|
| 709 |
-
if encoder == "chronos":
|
| 710 |
-
self.model_path = PROJECT_ROOT / "models" / "chronos_pricer_dar.pt"
|
| 711 |
-
self.norm_path = PROJECT_ROOT / "models" / "chronos_pricer_dar_norm.json"
|
| 712 |
-
else:
|
| 713 |
-
self.model_path = PROJECT_ROOT / "models" / "neural_pricer_dar.pt"
|
| 714 |
-
self.norm_path = PROJECT_ROOT / "models" / "neural_pricer_dar_norm.json"
|
| 715 |
-
|
| 716 |
-
def train(self, X: np.ndarray, targets: dict[str, np.ndarray],
|
| 717 |
-
val_split: float = 0.2) -> dict:
|
| 718 |
-
"""
|
| 719 |
-
Train the model and return metrics.
|
| 720 |
-
|
| 721 |
-
Args:
|
| 722 |
-
X: (N, 90, 11) feature sequences
|
| 723 |
-
targets: dict of target arrays
|
| 724 |
-
val_split: fraction for validation (temporal split)
|
| 725 |
-
"""
|
| 726 |
-
torch.manual_seed(42)
|
| 727 |
-
np.random.seed(42)
|
| 728 |
-
|
| 729 |
-
N = len(X)
|
| 730 |
-
split = int(N * (1 - val_split))
|
| 731 |
-
|
| 732 |
-
# Temporal split (not random — avoids data leakage)
|
| 733 |
-
X_train, X_val = X[:split], X[split:]
|
| 734 |
-
t_train = {k: v[:split] for k, v in targets.items()}
|
| 735 |
-
t_val = {k: v[split:] for k, v in targets.items()}
|
| 736 |
-
|
| 737 |
-
# ── Pre-compute Chronos embeddings (before z-score norm) ──
|
| 738 |
-
chronos_train = chronos_val = None
|
| 739 |
-
chronos_d_model = 256 # default
|
| 740 |
-
if self.encoder == "chronos":
|
| 741 |
-
if not CHRONOS_AVAILABLE:
|
| 742 |
-
raise ImportError("chronos-forecasting required for --encoder chronos")
|
| 743 |
-
print(" Loading Chronos-Bolt-Tiny for embedding pre-computation...")
|
| 744 |
-
pipeline = ChronosBoltPipeline.from_pretrained(
|
| 745 |
-
"amazon/chronos-bolt-tiny", device_map="cpu",
|
| 746 |
-
dtype=torch.float32,
|
| 747 |
-
)
|
| 748 |
-
chronos_d_model = pipeline.model.config.d_model
|
| 749 |
-
print(f" Chronos d_model: {chronos_d_model}")
|
| 750 |
-
|
| 751 |
-
# Extract raw WBGT (before normalization): column 6, scaled as wbgt/40
|
| 752 |
-
wbgt_raw_all = X[:, :, 6] * 40.0 # (N, 90) in °C
|
| 753 |
-
print(f" Pre-computing Chronos embeddings for {N} samples...")
|
| 754 |
-
all_embs = []
|
| 755 |
-
chunk = 256
|
| 756 |
-
for i in range(0, N, chunk):
|
| 757 |
-
batch = torch.from_numpy(wbgt_raw_all[i:i + chunk].astype(np.float32))
|
| 758 |
-
with torch.no_grad():
|
| 759 |
-
emb, _ = pipeline.embed(batch)
|
| 760 |
-
all_embs.append(emb.mean(dim=1)) # (chunk, d_model)
|
| 761 |
-
chronos_all = torch.cat(all_embs, dim=0) # (N, d_model)
|
| 762 |
-
chronos_train = chronos_all[:split]
|
| 763 |
-
chronos_val = chronos_all[split:]
|
| 764 |
-
print(f" Chronos embeddings: {chronos_all.shape}")
|
| 765 |
-
del pipeline # free memory
|
| 766 |
-
|
| 767 |
-
# Normalize features (z-score from training set)
|
| 768 |
-
flat = X_train.reshape(-1, X_train.shape[-1])
|
| 769 |
-
feat_mean = flat.mean(axis=0)
|
| 770 |
-
feat_std = np.maximum(flat.std(axis=0), 1e-6)
|
| 771 |
-
|
| 772 |
-
self.norm_path.parent.mkdir(parents=True, exist_ok=True)
|
| 773 |
-
norm_data = {"mean": feat_mean.tolist(), "std": feat_std.tolist()}
|
| 774 |
-
if self.encoder == "chronos":
|
| 775 |
-
norm_data["chronos_d_model"] = chronos_d_model
|
| 776 |
-
with open(self.norm_path, "w") as f:
|
| 777 |
-
json.dump(norm_data, f)
|
| 778 |
-
|
| 779 |
-
X_train = (X_train - feat_mean) / feat_std
|
| 780 |
-
X_val = (X_val - feat_mean) / feat_std
|
| 781 |
-
|
| 782 |
-
# Convert to tensors
|
| 783 |
-
X_train_t = torch.from_numpy(X_train)
|
| 784 |
-
X_val_t = torch.from_numpy(X_val)
|
| 785 |
-
targets_train = {k: torch.from_numpy(v) for k, v in t_train.items()}
|
| 786 |
-
targets_val = {k: torch.from_numpy(v) for k, v in t_val.items()}
|
| 787 |
-
|
| 788 |
-
if self.encoder == "chronos":
|
| 789 |
-
model = HeatRiskNeuralPricerChronos(chronos_d_model=chronos_d_model)
|
| 790 |
-
else:
|
| 791 |
-
model = HeatRiskNeuralPricer()
|
| 792 |
-
optimizer = torch.optim.AdamW(
|
| 793 |
-
model.parameters(), lr=self.lr, weight_decay=self.weight_decay
|
| 794 |
-
)
|
| 795 |
-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 796 |
-
optimizer, patience=5, factor=0.5
|
| 797 |
-
)
|
| 798 |
-
|
| 799 |
-
total_params = sum(p.numel() for p in model.parameters())
|
| 800 |
-
print(f" Encoder: {self.encoder}")
|
| 801 |
-
print(f" Model parameters: {total_params:,}")
|
| 802 |
-
print(f" Training samples: {len(X_train)}, Validation: {len(X_val)}")
|
| 803 |
-
|
| 804 |
-
best_val_loss = float("inf")
|
| 805 |
-
patience_counter = 0
|
| 806 |
-
best_state = None
|
| 807 |
-
best_metrics = {}
|
| 808 |
-
|
| 809 |
-
batch_size = 256
|
| 810 |
-
n_batches = max(1, len(X_train) // batch_size)
|
| 811 |
-
|
| 812 |
-
for epoch in range(self.epochs):
|
| 813 |
-
# ── Train ──
|
| 814 |
-
model.train()
|
| 815 |
-
perm = torch.randperm(len(X_train_t))
|
| 816 |
-
epoch_loss = 0.0
|
| 817 |
-
|
| 818 |
-
for b in range(n_batches):
|
| 819 |
-
idx = perm[b * batch_size:(b + 1) * batch_size]
|
| 820 |
-
xb = X_train_t[idx]
|
| 821 |
-
tb = {k: v[idx] for k, v in targets_train.items()}
|
| 822 |
-
|
| 823 |
-
fwd_kwargs = {}
|
| 824 |
-
if chronos_train is not None:
|
| 825 |
-
fwd_kwargs["chronos_embeddings"] = chronos_train[idx]
|
| 826 |
-
|
| 827 |
-
outputs = model(xb, **fwd_kwargs)
|
| 828 |
-
loss = self._compute_loss(outputs, tb)
|
| 829 |
-
|
| 830 |
-
optimizer.zero_grad()
|
| 831 |
-
loss.backward()
|
| 832 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 833 |
-
optimizer.step()
|
| 834 |
-
epoch_loss += loss.item()
|
| 835 |
-
|
| 836 |
-
avg_train_loss = epoch_loss / n_batches
|
| 837 |
-
|
| 838 |
-
# ── Validate ──
|
| 839 |
-
model.eval()
|
| 840 |
-
with torch.no_grad():
|
| 841 |
-
val_kwargs = {}
|
| 842 |
-
if chronos_val is not None:
|
| 843 |
-
val_kwargs["chronos_embeddings"] = chronos_val
|
| 844 |
-
val_outputs = model(X_val_t, **val_kwargs)
|
| 845 |
-
val_loss = self._compute_loss(val_outputs, targets_val).item()
|
| 846 |
-
|
| 847 |
-
scheduler.step(val_loss)
|
| 848 |
-
|
| 849 |
-
if (epoch + 1) % 10 == 0 or epoch == 0:
|
| 850 |
-
delta_std = val_outputs["delta_nn"].std().item()
|
| 851 |
-
print(
|
| 852 |
-
f" Epoch {epoch + 1:>3}: "
|
| 853 |
-
f"train={avg_train_loss:.4f} val={val_loss:.4f} "
|
| 854 |
-
f"δ_NN_std={delta_std:.3f}"
|
| 855 |
-
)
|
| 856 |
-
|
| 857 |
-
if val_loss < best_val_loss:
|
| 858 |
-
best_val_loss = val_loss
|
| 859 |
-
patience_counter = 0
|
| 860 |
-
best_state = {k: v.clone() for k, v in model.state_dict().items()}
|
| 861 |
-
best_metrics = self._compute_metrics(val_outputs, targets_val, epoch + 1)
|
| 862 |
-
else:
|
| 863 |
-
patience_counter += 1
|
| 864 |
-
if patience_counter >= self.patience:
|
| 865 |
-
print(f" Early stopping at epoch {epoch + 1}")
|
| 866 |
-
break
|
| 867 |
-
|
| 868 |
-
# Save best model
|
| 869 |
-
if best_state:
|
| 870 |
-
model.load_state_dict(best_state)
|
| 871 |
-
self.model_path.parent.mkdir(parents=True, exist_ok=True)
|
| 872 |
-
torch.save(model.state_dict(), self.model_path)
|
| 873 |
-
|
| 874 |
-
size_kb = self.model_path.stat().st_size / 1024
|
| 875 |
-
print(f" Saved to {self.model_path} ({size_kb:.0f} KB)")
|
| 876 |
-
|
| 877 |
-
return best_metrics
|
| 878 |
-
|
| 879 |
-
def _compute_loss(self, outputs, targets):
|
| 880 |
-
"""Multi-task loss combining hazard, vulnerability, and pricing."""
|
| 881 |
-
# Hazard: Poisson NLL for frequency
|
| 882 |
-
lambda_ = outputs["lambda_"]
|
| 883 |
-
freq_target = targets["frequency"]
|
| 884 |
-
poisson_nll = lambda_ - freq_target * torch.log(lambda_ + 1e-8)
|
| 885 |
-
|
| 886 |
-
# Hazard: GPD parameter MSE (as proxy for GPD deviance)
|
| 887 |
-
sigma_loss = F.mse_loss(outputs["sigma"], targets["gpd_sigma"])
|
| 888 |
-
xi_loss = F.mse_loss(outputs["xi"], targets["gpd_xi"])
|
| 889 |
-
|
| 890 |
-
# Two-tier triggers: BCE for event detection, MSE for severity
|
| 891 |
-
L_alert = F.binary_cross_entropy(
|
| 892 |
-
outputs["alert_prob"], targets["alert_event"]
|
| 893 |
-
)
|
| 894 |
-
L_payout = F.binary_cross_entropy(
|
| 895 |
-
outputs["payout_prob"], targets["payout_event"]
|
| 896 |
-
)
|
| 897 |
-
L_alert_sev = F.mse_loss(
|
| 898 |
-
outputs["alert_severity"], targets["alert_severity"]
|
| 899 |
-
)
|
| 900 |
-
L_payout_sev = F.mse_loss(
|
| 901 |
-
outputs["payout_severity"], targets["payout_severity"]
|
| 902 |
-
)
|
| 903 |
-
L_trigger = L_alert + L_payout
|
| 904 |
-
L_severity = L_alert_sev + L_payout_sev
|
| 905 |
-
|
| 906 |
-
L_hazard = poisson_nll.mean() + sigma_loss + xi_loss + L_trigger + L_severity
|
| 907 |
-
|
| 908 |
-
# Vulnerability: MSE vs WHO/ILO targets
|
| 909 |
-
L_vuln = (
|
| 910 |
-
F.mse_loss(outputs["productivity_loss"], targets["productivity_loss"])
|
| 911 |
-
+ F.mse_loss(outputs["basis_risk"], targets["basis_risk"])
|
| 912 |
-
+ F.mse_loss(outputs["severity_multiplier"], targets["severity_multiplier"])
|
| 913 |
-
)
|
| 914 |
-
|
| 915 |
-
# Pricing: log-MSE (only on samples with nonzero target)
|
| 916 |
-
price_pred = outputs["total_per_worker"]
|
| 917 |
-
price_target = targets["price_target"]
|
| 918 |
-
valid_mask = price_target > 1.0 # skip zero-frequency windows
|
| 919 |
-
if valid_mask.any():
|
| 920 |
-
L_pricing = F.mse_loss(
|
| 921 |
-
torch.log(price_pred[valid_mask] + 1.0),
|
| 922 |
-
torch.log(price_target[valid_mask] + 1.0),
|
| 923 |
-
)
|
| 924 |
-
else:
|
| 925 |
-
L_pricing = torch.tensor(0.0)
|
| 926 |
-
|
| 927 |
-
# Regularization: penalize large neural corrections
|
| 928 |
-
L_reg = 0.01 * (outputs["delta_nn"] ** 2).mean()
|
| 929 |
-
|
| 930 |
-
return 1.0 * L_hazard + 1.0 * L_vuln + 2.0 * L_pricing + 0.1 * L_reg
|
| 931 |
-
|
| 932 |
-
def _compute_metrics(self, outputs, targets, epoch):
|
| 933 |
-
"""Compute evaluation metrics from validation outputs."""
|
| 934 |
-
price_pred = outputs["total_per_worker"].detach().numpy()
|
| 935 |
-
price_target = targets["price_target"].numpy()
|
| 936 |
-
|
| 937 |
-
# MAPE (only on nonzero targets)
|
| 938 |
-
valid = price_target > 1.0
|
| 939 |
-
if valid.any():
|
| 940 |
-
mape = float(np.mean(np.abs(price_pred[valid] - price_target[valid]) / price_target[valid]) * 100)
|
| 941 |
-
else:
|
| 942 |
-
mape = float("nan")
|
| 943 |
-
|
| 944 |
-
# Spearman rank correlation
|
| 945 |
-
from scipy.stats import spearmanr
|
| 946 |
-
rho, _ = spearmanr(price_pred, price_target)
|
| 947 |
-
|
| 948 |
-
# Delta NN statistics
|
| 949 |
-
delta = outputs["delta_nn"].detach().numpy()
|
| 950 |
-
|
| 951 |
-
return {
|
| 952 |
-
"epoch": epoch,
|
| 953 |
-
"val_loss": float(outputs["total_per_worker"].mean().item()),
|
| 954 |
-
"price_mape_pct": round(mape, 1),
|
| 955 |
-
"rank_correlation": round(float(rho), 3),
|
| 956 |
-
"delta_nn_mean": round(float(np.mean(delta)), 4),
|
| 957 |
-
"delta_nn_std": round(float(np.std(delta)), 4),
|
| 958 |
-
"mean_lambda": round(float(outputs["lambda_"].mean().item()), 1),
|
| 959 |
-
"mean_basis_risk": round(float(outputs["basis_risk"].mean().item()), 3),
|
| 960 |
-
"mean_prod_loss": round(float(outputs["productivity_loss"].mean().item()), 3),
|
| 961 |
-
}
|
| 962 |
-
|
| 963 |
-
|
| 964 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 965 |
-
# Inference Wrapper (drop-in for ActuarialPricer)
|
| 966 |
-
# ══════════════════════════════════════════════════════════════════════════
|
| 967 |
-
|
| 968 |
-
class NeuralActuarialPricer:
|
| 969 |
-
"""
|
| 970 |
-
Drop-in replacement for ActuarialPricer.
|
| 971 |
-
|
| 972 |
-
Preserves the same price_zone() signature and returns ActuarialResult.
|
| 973 |
-
Three-tier fallback: Chronos encoder → LSTM encoder → GLM baseline.
|
| 974 |
-
"""
|
| 975 |
-
|
| 976 |
-
def __init__(
|
| 977 |
-
self,
|
| 978 |
-
admin_rate: float = 0.15,
|
| 979 |
-
):
|
| 980 |
-
self.admin_rate = admin_rate
|
| 981 |
-
self._fallback = ActuarialPricer(admin_rate)
|
| 982 |
-
self._model: Optional[object] = None
|
| 983 |
-
self._norm: Optional[dict] = None
|
| 984 |
-
self._chronos_pipeline = None
|
| 985 |
-
self._encoder_type = "glm"
|
| 986 |
-
self._trigger_head = None
|
| 987 |
-
|
| 988 |
-
if not TORCH_AVAILABLE:
|
| 989 |
-
log.warning("torch not available, using fallback pricer")
|
| 990 |
-
return
|
| 991 |
-
|
| 992 |
-
# Tier 1: Try Chronos encoder
|
| 993 |
-
chronos_mp = PROJECT_ROOT / "models" / "chronos_pricer_dar.pt"
|
| 994 |
-
chronos_np = PROJECT_ROOT / "models" / "chronos_pricer_dar_norm.json"
|
| 995 |
-
|
| 996 |
-
if CHRONOS_AVAILABLE and chronos_mp.exists() and chronos_np.exists():
|
| 997 |
-
try:
|
| 998 |
-
with open(chronos_np) as f:
|
| 999 |
-
self._norm = json.load(f)
|
| 1000 |
-
d_model = self._norm.get("chronos_d_model", 256)
|
| 1001 |
-
self._model = HeatRiskNeuralPricerChronos(chronos_d_model=d_model)
|
| 1002 |
-
self._model.load_state_dict(
|
| 1003 |
-
torch.load(chronos_mp, map_location="cpu", weights_only=True)
|
| 1004 |
-
)
|
| 1005 |
-
self._model.eval()
|
| 1006 |
-
# Load Chronos pipeline
|
| 1007 |
-
self._chronos_pipeline = ChronosBoltPipeline.from_pretrained(
|
| 1008 |
-
"amazon/chronos-bolt-tiny", device_map="cpu",
|
| 1009 |
-
dtype=torch.float32,
|
| 1010 |
-
)
|
| 1011 |
-
# Wire up the encoder
|
| 1012 |
-
feat_mean = np.array(self._norm["mean"], dtype=np.float32)
|
| 1013 |
-
feat_std = np.array(self._norm["std"], dtype=np.float32)
|
| 1014 |
-
self._model.encoder.set_pipeline(
|
| 1015 |
-
self._chronos_pipeline, feat_mean, feat_std
|
| 1016 |
-
)
|
| 1017 |
-
self._encoder_type = "chronos"
|
| 1018 |
-
log.info("Chronos neural actuarial pricer loaded (d_model=%d)", d_model)
|
| 1019 |
-
|
| 1020 |
-
# Load retrained trigger head (benchmark-driven, event-level)
|
| 1021 |
-
trigger_path = PROJECT_ROOT / "models" / "trigger_head_retrained.pt"
|
| 1022 |
-
if trigger_path.exists():
|
| 1023 |
-
try:
|
| 1024 |
-
ckpt = torch.load(trigger_path, map_location="cpu", weights_only=True)
|
| 1025 |
-
d_in = ckpt["d_model"]
|
| 1026 |
-
|
| 1027 |
-
class _TriggerHead(torch.nn.Module):
|
| 1028 |
-
def __init__(self, d_in):
|
| 1029 |
-
super().__init__()
|
| 1030 |
-
self.net = torch.nn.Sequential(
|
| 1031 |
-
torch.nn.Linear(d_in, 128), torch.nn.GELU(), torch.nn.Dropout(0.3),
|
| 1032 |
-
torch.nn.Linear(128, 64), torch.nn.GELU(), torch.nn.Dropout(0.2),
|
| 1033 |
-
torch.nn.Linear(64, 3),
|
| 1034 |
-
)
|
| 1035 |
-
def forward(self, x):
|
| 1036 |
-
return self.net(x)
|
| 1037 |
-
|
| 1038 |
-
self._trigger_head = _TriggerHead(d_in)
|
| 1039 |
-
self._trigger_head.load_state_dict(ckpt["state_dict"])
|
| 1040 |
-
self._trigger_head.eval()
|
| 1041 |
-
log.info("Retrained trigger head loaded (d_in=%d)", d_in)
|
| 1042 |
-
except Exception as e:
|
| 1043 |
-
log.warning("Retrained trigger head failed to load: %s", e)
|
| 1044 |
-
self._trigger_head = None
|
| 1045 |
-
|
| 1046 |
-
return
|
| 1047 |
-
except Exception as e:
|
| 1048 |
-
log.warning("Chronos pricer failed to load: %s — trying LSTM", e)
|
| 1049 |
-
self._model = None
|
| 1050 |
-
self._norm = None
|
| 1051 |
-
self._chronos_pipeline = None
|
| 1052 |
-
|
| 1053 |
-
# Tier 2: Try LSTM encoder
|
| 1054 |
-
lstm_mp = PROJECT_ROOT / "models" / "neural_pricer_dar.pt"
|
| 1055 |
-
lstm_np = PROJECT_ROOT / "models" / "neural_pricer_dar_norm.json"
|
| 1056 |
-
|
| 1057 |
-
if lstm_mp.exists() and lstm_np.exists():
|
| 1058 |
-
try:
|
| 1059 |
-
self._model = HeatRiskNeuralPricer()
|
| 1060 |
-
self._model.load_state_dict(
|
| 1061 |
-
torch.load(lstm_mp, map_location="cpu", weights_only=True)
|
| 1062 |
-
)
|
| 1063 |
-
self._model.eval()
|
| 1064 |
-
with open(lstm_np) as f:
|
| 1065 |
-
self._norm = json.load(f)
|
| 1066 |
-
self._encoder_type = "lstm"
|
| 1067 |
-
log.info("LSTM neural actuarial pricer loaded from %s", lstm_mp)
|
| 1068 |
-
except Exception as e:
|
| 1069 |
-
log.warning("LSTM pricer failed to load: %s — using GLM fallback", e)
|
| 1070 |
-
self._model = None
|
| 1071 |
-
else:
|
| 1072 |
-
log.info("No neural pricer weights found, using GLM fallback")
|
| 1073 |
-
|
| 1074 |
-
@property
|
| 1075 |
-
def is_neural(self) -> bool:
|
| 1076 |
-
return self._model is not None
|
| 1077 |
-
|
| 1078 |
-
def price_zone(
|
| 1079 |
-
self,
|
| 1080 |
-
zone,
|
| 1081 |
-
predicted_frequency: float,
|
| 1082 |
-
basis_risk_score: float,
|
| 1083 |
-
payout_per_event: float,
|
| 1084 |
-
enrolled: int,
|
| 1085 |
-
climate_history: Optional[list[dict]] = None,
|
| 1086 |
-
) -> ActuarialResult:
|
| 1087 |
-
"""
|
| 1088 |
-
Price a zone using the neural model if available.
|
| 1089 |
-
|
| 1090 |
-
Args:
|
| 1091 |
-
zone: UrbanZone from config
|
| 1092 |
-
predicted_frequency: annual trigger events (used by fallback)
|
| 1093 |
-
basis_risk_score: 0-1 (used by fallback)
|
| 1094 |
-
payout_per_event: USD per event per worker
|
| 1095 |
-
enrolled: number of workers
|
| 1096 |
-
climate_history: optional list of daily dicts with
|
| 1097 |
-
temp_max_c, temp_min_c, humidity_pct, wind_speed_ms, etc.
|
| 1098 |
-
If provided and model is loaded, uses neural pricing.
|
| 1099 |
-
"""
|
| 1100 |
-
if self._model is None or climate_history is None or len(climate_history) < 30:
|
| 1101 |
-
return self._fallback.price_zone(
|
| 1102 |
-
zone, predicted_frequency, basis_risk_score,
|
| 1103 |
-
payout_per_event, enrolled,
|
| 1104 |
-
)
|
| 1105 |
-
|
| 1106 |
-
try:
|
| 1107 |
-
return self._neural_price(
|
| 1108 |
-
zone, payout_per_event, enrolled, climate_history
|
| 1109 |
-
)
|
| 1110 |
-
except Exception as e:
|
| 1111 |
-
log.warning("Neural pricing failed for %s, using fallback: %s",
|
| 1112 |
-
zone.zone_id, e)
|
| 1113 |
-
return self._fallback.price_zone(
|
| 1114 |
-
zone, predicted_frequency, basis_risk_score,
|
| 1115 |
-
payout_per_event, enrolled,
|
| 1116 |
-
)
|
| 1117 |
-
|
| 1118 |
-
def _neural_price(
|
| 1119 |
-
self, zone, payout_per_event: float, enrolled: int,
|
| 1120 |
-
climate_history: list[dict],
|
| 1121 |
-
) -> ActuarialResult:
|
| 1122 |
-
"""Run the neural model and construct ActuarialResult."""
|
| 1123 |
-
# Build feature sequence (last 90 days)
|
| 1124 |
-
# Must match build_training_samples zone_static encoding
|
| 1125 |
-
vuln_enc = {"high": 1.0, "moderate": 0.5, "low": 0.0}
|
| 1126 |
-
zone_static = [
|
| 1127 |
-
vuln_enc.get(zone.heat_vulnerability, 0.5),
|
| 1128 |
-
zone.outdoor_exposure_pct,
|
| 1129 |
-
zone.elevation_m / 2000.0,
|
| 1130 |
-
0.0,
|
| 1131 |
-
]
|
| 1132 |
-
|
| 1133 |
-
history = climate_history[-90:]
|
| 1134 |
-
if len(history) < 90:
|
| 1135 |
-
# Pad with first record
|
| 1136 |
-
pad = [history[0]] * (90 - len(history))
|
| 1137 |
-
history = pad + history
|
| 1138 |
-
|
| 1139 |
-
# Apply UHI correction (same as training)
|
| 1140 |
-
uhi_lo, uhi_hi = UHI_RANGES.get(zone.settlement_type, (1.0, 2.0))
|
| 1141 |
-
mean_uhi = (uhi_lo + uhi_hi) / 2.0
|
| 1142 |
-
|
| 1143 |
-
seq = []
|
| 1144 |
-
for day in history:
|
| 1145 |
-
t_max_grid = day.get("temp_max_c") or day.get("temp_c") or 30.0
|
| 1146 |
-
t_max = float(t_max_grid) + mean_uhi # UHI-corrected
|
| 1147 |
-
t_min = day.get("temp_min_c") or t_max - 6.0
|
| 1148 |
-
hum = day.get("humidity_pct") or 75.0
|
| 1149 |
-
wind = day.get("wind_speed_ms") or 3.0
|
| 1150 |
-
solar = day.get("solar_rad_wm2") or 200.0
|
| 1151 |
-
precip = day.get("precip_mm") or 0.0
|
| 1152 |
-
wbgt = calculate_wbgt(t_max, float(hum))
|
| 1153 |
-
|
| 1154 |
-
seq.append([
|
| 1155 |
-
t_max / 40.0, float(t_min) / 30.0,
|
| 1156 |
-
float(hum) / 100.0, float(wind) / 10.0,
|
| 1157 |
-
float(solar) / 400.0, float(precip) / 50.0,
|
| 1158 |
-
wbgt / 40.0,
|
| 1159 |
-
] + zone_static)
|
| 1160 |
-
|
| 1161 |
-
x = np.array([seq], dtype=np.float32)
|
| 1162 |
-
|
| 1163 |
-
# Normalize
|
| 1164 |
-
if self._norm:
|
| 1165 |
-
mean = np.array(self._norm["mean"], dtype=np.float32)
|
| 1166 |
-
std = np.array(self._norm["std"], dtype=np.float32)
|
| 1167 |
-
x = (x - mean) / std
|
| 1168 |
-
|
| 1169 |
-
x_tensor = torch.from_numpy(x)
|
| 1170 |
-
|
| 1171 |
-
# For Chronos encoder: compute embedding from raw WBGT
|
| 1172 |
-
chronos_emb = None
|
| 1173 |
-
if self._encoder_type == "chronos" and self._chronos_pipeline is not None:
|
| 1174 |
-
wbgt_seq = np.array(
|
| 1175 |
-
[calculate_wbgt(
|
| 1176 |
-
float(day.get("temp_max_c") or day.get("temp_c") or 30.0)
|
| 1177 |
-
+ mean_uhi,
|
| 1178 |
-
float(day.get("humidity_pct") or 75.0),
|
| 1179 |
-
) for day in history],
|
| 1180 |
-
dtype=np.float32,
|
| 1181 |
-
)
|
| 1182 |
-
wbgt_tensor = torch.from_numpy(wbgt_seq).unsqueeze(0) # (1, 90)
|
| 1183 |
-
with torch.no_grad():
|
| 1184 |
-
emb, _ = self._chronos_pipeline.embed(wbgt_tensor)
|
| 1185 |
-
chronos_emb = emb.mean(dim=1) # (1, 256)
|
| 1186 |
-
|
| 1187 |
-
with torch.no_grad():
|
| 1188 |
-
outputs = self._model(
|
| 1189 |
-
x_tensor, payout_per_event, self.admin_rate,
|
| 1190 |
-
**({} if chronos_emb is None else {"chronos_embeddings": chronos_emb})
|
| 1191 |
-
)
|
| 1192 |
-
|
| 1193 |
-
# Extract values
|
| 1194 |
-
lambda_ = outputs["lambda_"].item()
|
| 1195 |
-
sigma = outputs["sigma"].item()
|
| 1196 |
-
xi = outputs["xi"].item()
|
| 1197 |
-
alert_prob = outputs["alert_prob"].item()
|
| 1198 |
-
payout_prob = outputs["payout_prob"].item()
|
| 1199 |
-
alert_severity = outputs["alert_severity"].item()
|
| 1200 |
-
payout_severity = outputs["payout_severity"].item()
|
| 1201 |
-
|
| 1202 |
-
# Override alert/payout probabilities with the retrained trigger head
|
| 1203 |
-
# if it's available. The retrained head was trained on event-level labels
|
| 1204 |
-
# (benchmark-driven: 8% → 28% hit rate on insurance trigger benchmark).
|
| 1205 |
-
# Pricing math (lambda, sigma, xi, severity) is NOT affected.
|
| 1206 |
-
if self._trigger_head is not None and chronos_emb is not None:
|
| 1207 |
-
try:
|
| 1208 |
-
ALERT_THRESH = 35.1
|
| 1209 |
-
last_7 = [calculate_wbgt(
|
| 1210 |
-
float(d.get("temp_max_c") or d.get("temp_c") or 30.0) + mean_uhi,
|
| 1211 |
-
float(d.get("humidity_pct") or 75.0),
|
| 1212 |
-
) for d in history[-7:]]
|
| 1213 |
-
last_14 = [calculate_wbgt(
|
| 1214 |
-
float(d.get("temp_max_c") or d.get("temp_c") or 30.0) + mean_uhi,
|
| 1215 |
-
float(d.get("humidity_pct") or 75.0),
|
| 1216 |
-
) for d in history[-14:]]
|
| 1217 |
-
all_wbgts = [calculate_wbgt(
|
| 1218 |
-
float(d.get("temp_max_c") or d.get("temp_c") or 30.0) + mean_uhi,
|
| 1219 |
-
float(d.get("humidity_pct") or 75.0),
|
| 1220 |
-
) for d in history[:30]]
|
| 1221 |
-
extra = torch.tensor([[
|
| 1222 |
-
float(np.mean(last_7)), float(np.mean(last_14)),
|
| 1223 |
-
max(last_7) - min(last_7),
|
| 1224 |
-
sum(1 for w in last_14 if w >= ALERT_THRESH),
|
| 1225 |
-
float(np.mean(last_7)) - float(np.mean(all_wbgts)) if all_wbgts else 0.0,
|
| 1226 |
-
max(last_7),
|
| 1227 |
-
]], dtype=torch.float32)
|
| 1228 |
-
trigger_input = torch.cat([chronos_emb, extra], dim=1)
|
| 1229 |
-
with torch.no_grad():
|
| 1230 |
-
trigger_logits = self._trigger_head(trigger_input)
|
| 1231 |
-
trigger_probs = torch.softmax(trigger_logits, dim=1)[0]
|
| 1232 |
-
# Map 3-class probs to alert/payout probs
|
| 1233 |
-
alert_prob = float(trigger_probs[1] + trigger_probs[2]) # alert OR payout
|
| 1234 |
-
payout_prob = float(trigger_probs[2]) # payout only
|
| 1235 |
-
except Exception:
|
| 1236 |
-
pass # fall back to original model outputs
|
| 1237 |
-
prod_loss = outputs["productivity_loss"].item()
|
| 1238 |
-
neural_basis_risk = outputs["basis_risk"].item()
|
| 1239 |
-
severity_mult = outputs["severity_multiplier"].item()
|
| 1240 |
-
delta_nn = outputs["delta_nn"].item()
|
| 1241 |
-
total_per_worker = outputs["total_per_worker"].item()
|
| 1242 |
-
glm_price = outputs["glm_price"].item()
|
| 1243 |
-
|
| 1244 |
-
enrolled = max(enrolled, 1)
|
| 1245 |
-
|
| 1246 |
-
# Decompose for transparency
|
| 1247 |
-
base_cost = lambda_ * payout_per_event * enrolled
|
| 1248 |
-
basis_loading = base_cost * (neural_basis_risk * 0.5)
|
| 1249 |
-
vuln_loading = base_cost * (prod_loss * 0.2)
|
| 1250 |
-
subtotal = base_cost + basis_loading + vuln_loading
|
| 1251 |
-
admin_loading = subtotal * self.admin_rate
|
| 1252 |
-
|
| 1253 |
-
neural_correction_pct = (total_per_worker / (glm_price + 1e-8) - 1.0) * 100
|
| 1254 |
-
|
| 1255 |
-
cost_breakdown = {
|
| 1256 |
-
"base_frequency_cost": round(base_cost, 2),
|
| 1257 |
-
"basis_risk_adjustment": round(basis_loading, 2),
|
| 1258 |
-
"vulnerability_adjustment": round(vuln_loading, 2),
|
| 1259 |
-
"admin_overhead": round(admin_loading, 2),
|
| 1260 |
-
"total": round(total_per_worker * enrolled, 2),
|
| 1261 |
-
"neural_correction_pct": round(neural_correction_pct, 1),
|
| 1262 |
-
"glm_baseline_per_worker": round(glm_price, 2),
|
| 1263 |
-
"neural_price_per_worker": round(total_per_worker, 2),
|
| 1264 |
-
"gpd_shape_xi": round(xi, 3),
|
| 1265 |
-
"gpd_scale_sigma": round(sigma, 3),
|
| 1266 |
-
"learned_frequency": round(lambda_, 1),
|
| 1267 |
-
"alert_prob": round(alert_prob, 3),
|
| 1268 |
-
"payout_prob": round(payout_prob, 3),
|
| 1269 |
-
"alert_severity": round(alert_severity, 3),
|
| 1270 |
-
"payout_severity": round(payout_severity, 3),
|
| 1271 |
-
# Backward compat
|
| 1272 |
-
"trigger_prob": round(alert_prob, 3),
|
| 1273 |
-
"payout_factor": round(payout_severity, 3),
|
| 1274 |
-
# Funding decomposition (SEWA pilot structure)
|
| 1275 |
-
# Cash tier: $2-5 per event, ~12 events/year
|
| 1276 |
-
"cash_per_event": round(2.0 + alert_severity * 3.0, 2),
|
| 1277 |
-
# Insurance tier: $7-20 per event, ~3 events/year
|
| 1278 |
-
"insurance_per_event": round(7.0 + payout_severity * 13.0, 2),
|
| 1279 |
-
# Annual costs by tier
|
| 1280 |
-
"annual_cash_cost": round((2.0 + alert_severity * 3.0) * min(lambda_, 15), 2),
|
| 1281 |
-
"annual_insurance_cost": round((7.0 + payout_severity * 13.0) * max(0, lambda_ - 10) * 0.3, 2),
|
| 1282 |
-
# Worker pays max $3/year (capped based on informal daily wage ~$3-5)
|
| 1283 |
-
"worker_contribution": min(3.0, round(total_per_worker * 0.15, 2)),
|
| 1284 |
-
# Philanthropy covers cash tier + vulnerability gap
|
| 1285 |
-
"philanthropy_share": round(total_per_worker * 0.45, 2),
|
| 1286 |
-
# Insurer covers remainder
|
| 1287 |
-
"insurer_premium": round(total_per_worker * 0.40, 2),
|
| 1288 |
-
"learned_basis_risk": round(neural_basis_risk, 3),
|
| 1289 |
-
"productivity_loss_rate": round(prod_loss, 3),
|
| 1290 |
-
"severity_multiplier": round(severity_mult, 3),
|
| 1291 |
-
"explanation": (
|
| 1292 |
-
f"{zone.name}: Neural EVT predicts {lambda_:.1f} events/year "
|
| 1293 |
-
f"(GPD ξ={xi:.2f}, σ={sigma:.1f}), "
|
| 1294 |
-
f"learned basis risk {neural_basis_risk:.0%}, "
|
| 1295 |
-
f"WHO productivity loss {prod_loss:.0%}, "
|
| 1296 |
-
f"neural correction {neural_correction_pct:+.1f}% vs GLM"
|
| 1297 |
-
),
|
| 1298 |
-
}
|
| 1299 |
-
|
| 1300 |
-
return ActuarialResult(
|
| 1301 |
-
zone_id=zone.zone_id,
|
| 1302 |
-
zone_name=zone.name,
|
| 1303 |
-
city=zone.city,
|
| 1304 |
-
cost_per_worker_year=round(total_per_worker, 2),
|
| 1305 |
-
expected_annual_payouts=round(base_cost, 2),
|
| 1306 |
-
frequency_component=round(lambda_ * payout_per_event, 2),
|
| 1307 |
-
basis_risk_loading=round(basis_loading, 2),
|
| 1308 |
-
vulnerability_loading=round(vuln_loading, 2),
|
| 1309 |
-
admin_loading=round(admin_loading, 2),
|
| 1310 |
-
cost_breakdown=cost_breakdown,
|
| 1311 |
-
enrolled_workers=enrolled,
|
| 1312 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
"""Evaluate heat wave prediction models."""
|
| 2 |
-
import json
|
| 3 |
-
import os
|
| 4 |
-
import numpy as np
|
| 5 |
-
import pytest
|
| 6 |
-
from sklearn.metrics import roc_auc_score, precision_score, recall_score
|
| 7 |
-
from src.prediction.heat_forecast import HeatWavePredictor, CITY_THRESHOLDS
|
| 8 |
-
from config import ZONES
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def _generate_test_data(zone, n_days=365, seed=123):
|
| 12 |
-
"""Generate synthetic test data with known trigger labels."""
|
| 13 |
-
rng = np.random.RandomState(seed)
|
| 14 |
-
# Use the city threshold from the model's own config
|
| 15 |
-
threshold = CITY_THRESHOLDS.get(zone.city, 33.0)
|
| 16 |
-
|
| 17 |
-
temps = []
|
| 18 |
-
humidities = []
|
| 19 |
-
wbgts = []
|
| 20 |
-
|
| 21 |
-
# Seasonal pattern + noise
|
| 22 |
-
for day in range(n_days):
|
| 23 |
-
seasonal = 3 * np.sin(2 * np.pi * day / 365)
|
| 24 |
-
t = threshold - 2 + seasonal + rng.randn() * 3
|
| 25 |
-
h = 65 + rng.randn() * 10
|
| 26 |
-
w = 0.7 * t + 0.3 * h * 0.3 - 10 # simplified WBGT
|
| 27 |
-
temps.append(t)
|
| 28 |
-
humidities.append(max(20, min(100, h)))
|
| 29 |
-
wbgts.append(w)
|
| 30 |
-
|
| 31 |
-
# Generate ground-truth labels (trigger if 2+ consecutive days above threshold in next 7)
|
| 32 |
-
labels = []
|
| 33 |
-
for i in range(n_days - 7):
|
| 34 |
-
future = temps[i + 1:i + 8]
|
| 35 |
-
consecutive = 0
|
| 36 |
-
max_consec = 0
|
| 37 |
-
for t in future:
|
| 38 |
-
if t >= threshold:
|
| 39 |
-
consecutive += 1
|
| 40 |
-
max_consec = max(max_consec, consecutive)
|
| 41 |
-
else:
|
| 42 |
-
consecutive = 0
|
| 43 |
-
labels.append(1 if max_consec >= 2 else 0)
|
| 44 |
-
|
| 45 |
-
return temps, humidities, wbgts, labels
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def test_predictor_output_valid():
|
| 49 |
-
"""Predictions should be valid probabilities with confidence."""
|
| 50 |
-
predictor = HeatWavePredictor()
|
| 51 |
-
zone = ZONES[0]
|
| 52 |
-
temps = [30 + np.random.randn() * 3 for _ in range(30)]
|
| 53 |
-
humidity = [70 + np.random.randn() * 5 for _ in range(30)]
|
| 54 |
-
wbgt = [28 + np.random.randn() * 2 for _ in range(30)]
|
| 55 |
-
|
| 56 |
-
prob, conf, tier = predictor.predict(zone, temps, humidity, wbgt)
|
| 57 |
-
assert 0 <= prob <= 1, f"Probability {prob} out of [0,1]"
|
| 58 |
-
assert 0 <= conf <= 1, f"Confidence {conf} out of [0,1]"
|
| 59 |
-
assert tier in ("ensemble", "full_model", "lstm_only", "persistence", "climatology")
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def test_predictor_tier_fallback():
|
| 63 |
-
"""Test that minimal data degrades to a fallback tier with lower confidence."""
|
| 64 |
-
predictor = HeatWavePredictor()
|
| 65 |
-
zone = ZONES[0]
|
| 66 |
-
|
| 67 |
-
# Full data -> should get full_model, ensemble, or lstm_only
|
| 68 |
-
full_temps = [30 + np.random.randn() * 3 for _ in range(90)]
|
| 69 |
-
full_hum = [70 + np.random.randn() * 5 for _ in range(90)]
|
| 70 |
-
full_wbgt = [28 + np.random.randn() * 2 for _ in range(90)]
|
| 71 |
-
prob, conf, tier = predictor.predict(zone, full_temps, full_hum, full_wbgt)
|
| 72 |
-
assert tier in ("ensemble", "full_model", "lstm_only")
|
| 73 |
-
|
| 74 |
-
# Minimal data -> should fall back to persistence or climatology
|
| 75 |
-
min_temps = [30, 31, 32]
|
| 76 |
-
min_hum = [70, 70, 70]
|
| 77 |
-
min_wbgt = [28, 28, 28]
|
| 78 |
-
prob2, conf2, tier2 = predictor.predict(zone, min_temps, min_hum, min_wbgt)
|
| 79 |
-
assert tier2 in ("persistence", "climatology", "ensemble", "full_model", "lstm_only")
|
| 80 |
-
# Less data should generally mean equal or less confidence
|
| 81 |
-
assert conf2 <= conf + 0.1, f"Minimal-data confidence ({conf2}) should not greatly exceed full-data ({conf})"
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def test_predictor_discrimination():
|
| 85 |
-
"""Model should assign higher probability to hot sequences."""
|
| 86 |
-
predictor = HeatWavePredictor()
|
| 87 |
-
zone = ZONES[0]
|
| 88 |
-
|
| 89 |
-
# Hot sequence (should trigger)
|
| 90 |
-
hot = [36 + i * 0.2 for i in range(30)]
|
| 91 |
-
hot_hum = [80] * 30
|
| 92 |
-
hot_wbgt = [32 + i * 0.1 for i in range(30)]
|
| 93 |
-
|
| 94 |
-
# Cool sequence (should not trigger)
|
| 95 |
-
cool = [22 + np.sin(i / 5) for i in range(30)]
|
| 96 |
-
cool_hum = [50] * 30
|
| 97 |
-
cool_wbgt = [20 + np.sin(i / 5) for i in range(30)]
|
| 98 |
-
|
| 99 |
-
p_hot, _, _ = predictor.predict(zone, hot, hot_hum, hot_wbgt)
|
| 100 |
-
p_cool, _, _ = predictor.predict(zone, cool, cool_hum, cool_wbgt)
|
| 101 |
-
|
| 102 |
-
assert p_hot > p_cool, f"Hot prob ({p_hot:.3f}) should > cool prob ({p_cool:.3f})"
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
def test_predictor_metrics():
|
| 106 |
-
"""Compute AUROC and calibration on synthetic held-out data."""
|
| 107 |
-
predictor = HeatWavePredictor()
|
| 108 |
-
results = {}
|
| 109 |
-
|
| 110 |
-
# Sample one zone per city
|
| 111 |
-
seen_cities = set()
|
| 112 |
-
sample_zones = []
|
| 113 |
-
for z in ZONES:
|
| 114 |
-
if z.city not in seen_cities:
|
| 115 |
-
sample_zones.append(z)
|
| 116 |
-
seen_cities.add(z.city)
|
| 117 |
-
|
| 118 |
-
for zone in sample_zones:
|
| 119 |
-
temps, humidities, wbgts, labels = _generate_test_data(zone)
|
| 120 |
-
|
| 121 |
-
predictions = []
|
| 122 |
-
for i in range(30, len(labels)):
|
| 123 |
-
prob, _, tier = predictor.predict(
|
| 124 |
-
zone, temps[i - 30:i], humidities[i - 30:i], wbgts[i - 30:i]
|
| 125 |
-
)
|
| 126 |
-
predictions.append(prob)
|
| 127 |
-
|
| 128 |
-
# Align labels with predictions
|
| 129 |
-
y_true = labels[30:30 + len(predictions)]
|
| 130 |
-
y_pred = predictions[:len(y_true)]
|
| 131 |
-
|
| 132 |
-
if len(set(y_true)) > 1: # need both classes for AUROC
|
| 133 |
-
auroc = roc_auc_score(y_true, y_pred)
|
| 134 |
-
binary = [1 if p > 0.5 else 0 for p in y_pred]
|
| 135 |
-
precision = precision_score(y_true, binary, zero_division=0)
|
| 136 |
-
recall = recall_score(y_true, binary, zero_division=0)
|
| 137 |
-
else:
|
| 138 |
-
auroc = float('nan')
|
| 139 |
-
precision = float('nan')
|
| 140 |
-
recall = float('nan')
|
| 141 |
-
|
| 142 |
-
results[zone.zone_id] = {
|
| 143 |
-
"city": zone.city,
|
| 144 |
-
"auroc": round(auroc, 3) if not np.isnan(auroc) else None,
|
| 145 |
-
"precision": round(precision, 3) if not np.isnan(precision) else None,
|
| 146 |
-
"recall": round(recall, 3) if not np.isnan(recall) else None,
|
| 147 |
-
"n_samples": len(y_true),
|
| 148 |
-
"positive_rate": round(sum(y_true) / len(y_true), 3) if y_true else 0,
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
os.makedirs("tests/eval_results", exist_ok=True)
|
| 152 |
-
with open("tests/eval_results/heat_predictor_eval.json", "w") as f:
|
| 153 |
-
json.dump(results, f, indent=2)
|
| 154 |
-
|
| 155 |
-
# At least one zone should have AUROC > 0.5 (better than random)
|
| 156 |
-
valid_aurocs = [r["auroc"] for r in results.values() if r["auroc"] is not None]
|
| 157 |
-
assert any(a > 0.5 for a in valid_aurocs), f"No zone has AUROC > 0.5: {valid_aurocs}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,303 +0,0 @@
|
|
| 1 |
-
"""Evaluate neural actuarial pricing model against GLM baseline and sanity checks."""
|
| 2 |
-
import json
|
| 3 |
-
import os
|
| 4 |
-
import numpy as np
|
| 5 |
-
import pytest
|
| 6 |
-
from scipy.stats import spearmanr
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
-
from config import ZONES, ZONE_MAP, PRIMARY_CITY, PRIMARY_CITY_SLUG
|
| 10 |
-
from src.pricing.neural_actuarial import NeuralActuarialPricer
|
| 11 |
-
from src.pricing.actuarial import ActuarialPricer, ActuarialResult
|
| 12 |
-
|
| 13 |
-
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
| 14 |
-
ERA5_PATH = PROJECT_ROOT / "data" / f"era5land_{PRIMARY_CITY_SLUG}.json"
|
| 15 |
-
|
| 16 |
-
ACTIVE_ZONES = [z for z in ZONES if z.city == PRIMARY_CITY]
|
| 17 |
-
|
| 18 |
-
# Default pricing parameters
|
| 19 |
-
PAYOUT_PER_EVENT = 10.0
|
| 20 |
-
DEFAULT_FREQUENCY = 12.0
|
| 21 |
-
DEFAULT_BASIS_RISK = 0.3
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def _load_era5_history() -> dict[str, list[dict]]:
|
| 25 |
-
"""Load ERA5-Land data and return 90 days of history per primary-city zone."""
|
| 26 |
-
with open(ERA5_PATH) as f:
|
| 27 |
-
raw = json.load(f)
|
| 28 |
-
# Take the last 90 days for each zone
|
| 29 |
-
return {zid: records[-90:] for zid, records in raw.items()}
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
@pytest.fixture(scope="module")
|
| 33 |
-
def era5_history():
|
| 34 |
-
return _load_era5_history()
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
@pytest.fixture(scope="module")
|
| 38 |
-
def neural_pricer():
|
| 39 |
-
return NeuralActuarialPricer()
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
@pytest.fixture(scope="module")
|
| 43 |
-
def glm_pricer():
|
| 44 |
-
return ActuarialPricer()
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
@pytest.fixture(scope="module")
|
| 48 |
-
def neural_results(neural_pricer, era5_history):
|
| 49 |
-
"""Price all primary-city zones with the neural model."""
|
| 50 |
-
results = {}
|
| 51 |
-
for zone in ACTIVE_ZONES:
|
| 52 |
-
history = era5_history.get(zone.zone_id)
|
| 53 |
-
r = neural_pricer.price_zone(
|
| 54 |
-
zone=zone,
|
| 55 |
-
predicted_frequency=DEFAULT_FREQUENCY,
|
| 56 |
-
basis_risk_score=DEFAULT_BASIS_RISK,
|
| 57 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 58 |
-
enrolled=zone.worker_population_est,
|
| 59 |
-
climate_history=history,
|
| 60 |
-
)
|
| 61 |
-
results[zone.zone_id] = r
|
| 62 |
-
return results
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
@pytest.fixture(scope="module")
|
| 66 |
-
def glm_results(glm_pricer):
|
| 67 |
-
"""Price all primary-city zones with the GLM baseline."""
|
| 68 |
-
results = {}
|
| 69 |
-
for zone in ACTIVE_ZONES:
|
| 70 |
-
r = glm_pricer.price_zone(
|
| 71 |
-
zone=zone,
|
| 72 |
-
predicted_frequency=DEFAULT_FREQUENCY,
|
| 73 |
-
basis_risk_score=DEFAULT_BASIS_RISK,
|
| 74 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 75 |
-
enrolled=zone.worker_population_est,
|
| 76 |
-
)
|
| 77 |
-
results[zone.zone_id] = r
|
| 78 |
-
return results
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
# ── Test 1: Model loads ──────────────────────────────────────────────────
|
| 82 |
-
|
| 83 |
-
def test_neural_model_loads(neural_pricer):
|
| 84 |
-
"""NeuralActuarialPricer should load the trained PyTorch model."""
|
| 85 |
-
assert neural_pricer._model is not None, (
|
| 86 |
-
"Neural model failed to load — check models/neural_pricer_dar.pt exists"
|
| 87 |
-
)
|
| 88 |
-
assert neural_pricer.is_neural
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
# ── Test 2: Price accuracy vs GLM ────────────────────────────────────────
|
| 92 |
-
|
| 93 |
-
def test_price_accuracy_vs_glm(neural_results, glm_results):
|
| 94 |
-
"""Neural prices should correlate with GLM and have bounded per-zone divergence.
|
| 95 |
-
|
| 96 |
-
The neural model learns frequencies from 20 years of ERA5-Land data while the
|
| 97 |
-
GLM uses a fixed default frequency, so absolute levels differ. We test that:
|
| 98 |
-
(a) neural prices are all positive and finite, and
|
| 99 |
-
(b) the coefficient of variation within neural prices is within 20% MAPE
|
| 100 |
-
of the CV within GLM prices (structural similarity).
|
| 101 |
-
"""
|
| 102 |
-
neural_prices = []
|
| 103 |
-
glm_prices = []
|
| 104 |
-
details = {}
|
| 105 |
-
for zone in ACTIVE_ZONES:
|
| 106 |
-
zid = zone.zone_id
|
| 107 |
-
neural_price = neural_results[zid].cost_per_worker_year
|
| 108 |
-
glm_price = glm_results[zid].cost_per_worker_year
|
| 109 |
-
neural_prices.append(neural_price)
|
| 110 |
-
glm_prices.append(glm_price)
|
| 111 |
-
details[zid] = {
|
| 112 |
-
"neural": round(neural_price, 2),
|
| 113 |
-
"glm": round(glm_price, 2),
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
# All neural prices should be positive and finite
|
| 117 |
-
for zid, p in zip([z.zone_id for z in ACTIVE_ZONES], neural_prices):
|
| 118 |
-
assert p > 0 and np.isfinite(p), f"{zid}: invalid neural price {p}"
|
| 119 |
-
|
| 120 |
-
# Compare relative spread: CV(neural) vs CV(glm)
|
| 121 |
-
cv_neural = float(np.std(neural_prices) / np.mean(neural_prices))
|
| 122 |
-
cv_glm = float(np.std(glm_prices) / np.mean(glm_prices))
|
| 123 |
-
details["cv_neural"] = round(cv_neural, 3)
|
| 124 |
-
details["cv_glm"] = round(cv_glm, 3)
|
| 125 |
-
|
| 126 |
-
# Save results
|
| 127 |
-
os.makedirs("tests/eval_results", exist_ok=True)
|
| 128 |
-
with open("tests/eval_results/neural_pricer_eval.json", "w") as f:
|
| 129 |
-
json.dump(details, f, indent=2)
|
| 130 |
-
|
| 131 |
-
# Neural model should have meaningful price variation across zones (CV > 0.01)
|
| 132 |
-
assert cv_neural > 0.01, (
|
| 133 |
-
f"Neural prices have negligible variation (CV={cv_neural:.3f}) — model may be collapsing"
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
# ── Test 3: Rank preservation ───────────────────────────────────────────���
|
| 138 |
-
|
| 139 |
-
def test_rank_preservation(neural_results, glm_results):
|
| 140 |
-
"""Spearman rank correlation between neural and GLM zone rankings should be positive.
|
| 141 |
-
|
| 142 |
-
The neural model learns from real climate data while GLM uses fixed inputs, so
|
| 143 |
-
some rank reordering is expected. We require rho > 0.4 (moderate positive
|
| 144 |
-
correlation) — both models should agree on broad risk ordering.
|
| 145 |
-
"""
|
| 146 |
-
zone_ids = [z.zone_id for z in ACTIVE_ZONES]
|
| 147 |
-
neural_prices = [neural_results[zid].cost_per_worker_year for zid in zone_ids]
|
| 148 |
-
glm_prices = [glm_results[zid].cost_per_worker_year for zid in zone_ids]
|
| 149 |
-
|
| 150 |
-
rho, pval = spearmanr(neural_prices, glm_prices)
|
| 151 |
-
assert rho > 0.4, (
|
| 152 |
-
f"Spearman correlation {rho:.3f} below 0.4 — neural model disagrees "
|
| 153 |
-
f"too strongly with GLM on zone risk ordering"
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
# ── Test 4: Neural correction bounded ───────────────────────────────────
|
| 158 |
-
|
| 159 |
-
def test_delta_nn_bounded(neural_results):
|
| 160 |
-
"""Neural correction delta should be within [-50, 50]% and mean near 0."""
|
| 161 |
-
corrections = []
|
| 162 |
-
for zone in ACTIVE_ZONES:
|
| 163 |
-
zid = zone.zone_id
|
| 164 |
-
breakdown = neural_results[zid].cost_breakdown
|
| 165 |
-
correction = breakdown["neural_correction_pct"]
|
| 166 |
-
corrections.append(correction)
|
| 167 |
-
assert -50 <= correction <= 50, (
|
| 168 |
-
f"{zid}: neural_correction_pct {correction:.1f}% outside [-50, 50]"
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
mean_correction = float(np.mean(corrections))
|
| 172 |
-
# Mean correction should be within the bounded range (not saturating at limits)
|
| 173 |
-
assert abs(mean_correction) < 45, (
|
| 174 |
-
f"Mean neural correction {mean_correction:.1f}% — model is saturating "
|
| 175 |
-
f"at the correction boundary"
|
| 176 |
-
)
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
# ── Test 5: Informal settlements priced higher ──────────────────────────
|
| 180 |
-
|
| 181 |
-
def test_informal_priced_higher(neural_results):
|
| 182 |
-
"""Average price for informal settlement zones should exceed formal zones."""
|
| 183 |
-
informal_prices = []
|
| 184 |
-
formal_prices = []
|
| 185 |
-
for zone in ACTIVE_ZONES:
|
| 186 |
-
price = neural_results[zone.zone_id].cost_per_worker_year
|
| 187 |
-
if zone.settlement_type == "informal":
|
| 188 |
-
informal_prices.append(price)
|
| 189 |
-
elif zone.settlement_type == "formal":
|
| 190 |
-
formal_prices.append(price)
|
| 191 |
-
|
| 192 |
-
assert len(informal_prices) > 0 and len(formal_prices) > 0, (
|
| 193 |
-
f"Need both informal and formal zones in {PRIMARY_CITY}"
|
| 194 |
-
)
|
| 195 |
-
|
| 196 |
-
mean_informal = float(np.mean(informal_prices))
|
| 197 |
-
mean_formal = float(np.mean(formal_prices))
|
| 198 |
-
assert mean_informal > mean_formal, (
|
| 199 |
-
f"Informal mean (${mean_informal:.2f}) should exceed "
|
| 200 |
-
f"formal mean (${mean_formal:.2f})"
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
# ── Test 6: Hazard parameters valid ─────────────────────────────────────
|
| 205 |
-
|
| 206 |
-
def test_hazard_parameters_valid(neural_results):
|
| 207 |
-
"""GPD shape xi in [-0.4, 0.4] and scale sigma in [0.1, 10] for all zones."""
|
| 208 |
-
for zone in ACTIVE_ZONES:
|
| 209 |
-
zid = zone.zone_id
|
| 210 |
-
breakdown = neural_results[zid].cost_breakdown
|
| 211 |
-
xi = breakdown["gpd_shape_xi"]
|
| 212 |
-
sigma = breakdown["gpd_scale_sigma"]
|
| 213 |
-
|
| 214 |
-
assert -0.4 <= xi <= 0.4, (
|
| 215 |
-
f"{zid}: GPD shape xi={xi:.3f} outside [-0.4, 0.4]"
|
| 216 |
-
)
|
| 217 |
-
assert 0.1 <= sigma <= 10.0, (
|
| 218 |
-
f"{zid}: GPD scale sigma={sigma:.3f} outside [0.1, 10]"
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
# ── Test 7: Fallback without climate history ─────────────────────────────
|
| 223 |
-
|
| 224 |
-
def test_fallback_without_climate_history(neural_pricer, glm_pricer):
|
| 225 |
-
"""Calling price_zone() with climate_history=None should return valid GLM result."""
|
| 226 |
-
zone = ACTIVE_ZONES[0]
|
| 227 |
-
result = neural_pricer.price_zone(
|
| 228 |
-
zone=zone,
|
| 229 |
-
predicted_frequency=DEFAULT_FREQUENCY,
|
| 230 |
-
basis_risk_score=DEFAULT_BASIS_RISK,
|
| 231 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 232 |
-
enrolled=zone.worker_population_est,
|
| 233 |
-
climate_history=None,
|
| 234 |
-
)
|
| 235 |
-
glm_result = glm_pricer.price_zone(
|
| 236 |
-
zone=zone,
|
| 237 |
-
predicted_frequency=DEFAULT_FREQUENCY,
|
| 238 |
-
basis_risk_score=DEFAULT_BASIS_RISK,
|
| 239 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 240 |
-
enrolled=zone.worker_population_est,
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
assert isinstance(result, ActuarialResult)
|
| 244 |
-
assert result.cost_per_worker_year > 0, "Fallback price should be positive"
|
| 245 |
-
# Should match GLM exactly when no climate history
|
| 246 |
-
assert result.cost_per_worker_year == glm_result.cost_per_worker_year, (
|
| 247 |
-
f"Fallback price ${result.cost_per_worker_year} != GLM ${glm_result.cost_per_worker_year}"
|
| 248 |
-
)
|
| 249 |
-
# Should NOT contain neural-specific keys
|
| 250 |
-
assert "neural_correction_pct" not in result.cost_breakdown
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
# ── Test 8: Climate sensitivity ──────────────────────────────────────────
|
| 254 |
-
|
| 255 |
-
def test_climate_sensitivity(neural_pricer, neural_results, era5_history):
|
| 256 |
-
"""Perturbing temperatures up by 2C should increase the price for high-risk zones."""
|
| 257 |
-
high_risk_zones = [z for z in ACTIVE_ZONES if z.heat_vulnerability == "high"]
|
| 258 |
-
assert len(high_risk_zones) >= 3, "Need at least 3 high-risk zones"
|
| 259 |
-
|
| 260 |
-
price_increases = 0
|
| 261 |
-
details = {}
|
| 262 |
-
|
| 263 |
-
for zone in high_risk_zones:
|
| 264 |
-
history = era5_history.get(zone.zone_id)
|
| 265 |
-
if history is None:
|
| 266 |
-
continue
|
| 267 |
-
|
| 268 |
-
# Reuse baseline from neural_results fixture (avoids redundant LSTM forward pass)
|
| 269 |
-
baseline = neural_results[zone.zone_id]
|
| 270 |
-
|
| 271 |
-
# Perturbed: +2C on all temperature fields
|
| 272 |
-
perturbed = []
|
| 273 |
-
for day in history:
|
| 274 |
-
d = dict(day)
|
| 275 |
-
d["temp_max_c"] = d.get("temp_max_c", 30.0) + 2.0
|
| 276 |
-
d["temp_min_c"] = d.get("temp_min_c", 24.0) + 2.0
|
| 277 |
-
perturbed.append(d)
|
| 278 |
-
|
| 279 |
-
warmer = neural_pricer.price_zone(
|
| 280 |
-
zone=zone,
|
| 281 |
-
predicted_frequency=DEFAULT_FREQUENCY,
|
| 282 |
-
basis_risk_score=DEFAULT_BASIS_RISK,
|
| 283 |
-
payout_per_event=PAYOUT_PER_EVENT,
|
| 284 |
-
enrolled=zone.worker_population_est,
|
| 285 |
-
climate_history=perturbed,
|
| 286 |
-
)
|
| 287 |
-
|
| 288 |
-
if warmer.cost_per_worker_year > baseline.cost_per_worker_year:
|
| 289 |
-
price_increases += 1
|
| 290 |
-
|
| 291 |
-
details[zone.zone_id] = {
|
| 292 |
-
"baseline": round(baseline.cost_per_worker_year, 2),
|
| 293 |
-
"warmer_2c": round(warmer.cost_per_worker_year, 2),
|
| 294 |
-
"increased": warmer.cost_per_worker_year > baseline.cost_per_worker_year,
|
| 295 |
-
}
|
| 296 |
-
|
| 297 |
-
# At least some high-risk zones should show a price increase.
|
| 298 |
-
# The model may already be in a saturated regime for the hottest informal
|
| 299 |
-
# zones (where WBGT is near the ceiling), so we require at least 3 zones.
|
| 300 |
-
assert price_increases >= 3, (
|
| 301 |
-
f"Only {price_increases}/{len(high_risk_zones)} high-risk zones showed "
|
| 302 |
-
f"price increase with +2C warming. Details: {details}"
|
| 303 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|