File size: 20,043 Bytes
3a61277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
"""
Model wrapper: loads XGBoost model + all feature data, provides prediction + SHAP explanation.

At startup, builds a merged feature table (same joins as model.py's load_features()),
then keeps only the most-recent-year row per (airport_A, airport_B, Month) to reduce memory.
Cached at the Streamlit session level via @st.cache_resource.
"""
from __future__ import annotations
import os
import numpy as np
import pandas as pd
import xgboost as xgb

# Data directories, resolved relative to this file so lookups do not depend
# on the working directory the app is launched from.
PROCESSED = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, "data", "processed")
)
RAW = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, "data", "raw")
)

# An observed_bad_rate strictly above this value is labelled target=1 when
# build_features_df() rebuilds the feature table from sequence_features.
RISK_THRESHOLD = 0.25

# ── Isotonic calibration (model score → observed bad rate scale) ─────────────
import json as _json

_cal_path = os.path.join(PROCESSED, "calibration_isotonic.json")
if not os.path.exists(_cal_path):
    def _calibrate(p: float) -> float:
        """Return *p* unchanged; used when no calibration file is present."""
        return p
else:
    with open(_cal_path) as _f:
        _cal = _json.load(_f)
    # Knots of the piecewise-linear calibration curve stored in the JSON file.
    _CAL_X = np.array(_cal["x"])
    _CAL_Y = np.array(_cal["y"])

    def _calibrate(p: float) -> float:
        """Map a raw model probability onto the calibrated scale.

        Linearly interpolates *p* between the (x, y) knots loaded from
        ``calibration_isotonic.json`` and returns a plain Python float.
        """
        return float(np.interp(p, _CAL_X, _CAL_Y))

# Calibrated risk thresholds (on observed-bad-rate scale)
HIGH_THRESHOLD = 0.30   # ≥30% of sequences historically disrupted
MOD_THRESHOLD  = 0.20   # ≥20% of sequences historically disrupted

# ── Open-Meteo per-airport weather lookup (replaces GSOM median imputation) ──
# Old approach: fill NaN GSOM features with month-level POPULATION medians.
#   Problem: the model was trained with NaN for ~55% of airports; replacing
#   NaN with medians at inference time creates a distribution shift.
# New approach: load actual Open-Meteo ERA5 climate per airport×month.
#   Coverage: 100% of airports (ERA5 is a global reanalysis, no station gaps).
#   The column mapping keeps XGBoost feature names unchanged (no retraining needed).
#
# Fallback: if the Open-Meteo file has not been generated yet, the legacy GSOM
# medians are used instead.

# Open-Meteo column name → GSOM-era feature name that the model-facing
# columns are built from (prefixed with A_/B_/pair_max_ below).
_OM_TO_GSOM = dict(
    avg_wind_mph="avg_wind_speed",
    max_gust_mph="max_wind_gust",
    total_precip_in="total_precip",
    precip_days="precip_days",
    severe_wx_days="extreme_precip",
)
_GSOM_WEATHER_COLS = [*_OM_TO_GSOM.values()]   # model-facing names

_om_path = os.path.join(PROCESSED, "openmeteo_airport_monthly.parquet")
_gsom_med_path = os.path.join(PROCESSED, "gsom_month_medians.json")

# Exactly one of _OM_LOOKUP / _GSOM_MEDIANS is bound after this section:
# the Open-Meteo table when its parquet exists, otherwise the legacy medians.
_USE_OPENMETEO = os.path.exists(_om_path)
if _USE_OPENMETEO:
    # Load Open-Meteo climatological normals (mean across 2015-2024 per airport×month)
    _om_raw = pd.read_parquet(_om_path)
    _OM_LOOKUP: pd.DataFrame = (
        _om_raw
        .groupby(["iata", "month"])[list(_OM_TO_GSOM)]
        .mean()
        .reset_index()
        # One pass: weather columns to model-facing names, "month" to "Month".
        .rename(columns={**_OM_TO_GSOM, "month": "Month"})
    )
elif os.path.exists(_gsom_med_path):
    # Legacy fallback: global monthly medians. JSON object keys are strings,
    # so month keys are converted back to int for lookup by month number.
    with open(_gsom_med_path) as _gf:
        _GSOM_MEDIANS: dict[str, dict[int, float]] = {
            feature: {int(m): median for m, median in by_month.items()}
            for feature, by_month in _json.load(_gf).items()
        }
else:
    _GSOM_MEDIANS = {}

# Model column names for each side of the pair and for the pair-level max.
_GSOM_COLS_A    = ["A_" + c for c in _GSOM_WEATHER_COLS]
_GSOM_COLS_B    = ["B_" + c for c in _GSOM_WEATHER_COLS]
_GSOM_COLS_PAIR = ["pair_max_" + c for c in _GSOM_WEATHER_COLS]
_ALL_GSOM_COLS  = [*_GSOM_COLS_A, *_GSOM_COLS_B, *_GSOM_COLS_PAIR]


def _apply_gsom_imputation(X: pd.DataFrame, month: int,
                           airport_a: str = "", airport_b: str = "") -> tuple[pd.DataFrame, set[str]]:
    """Fill missing weather-climate features and report which columns changed.

    Works on a copy of *X*; the caller's frame is never modified.

    Source of the fill values:
      * Open-Meteo loaded and both airport codes given: NaNs in the ``A_*`` /
        ``B_*`` weather columns are replaced by that airport's value for
        *month* from ``_OM_LOOKUP``. An airport/month missing from the lookup
        leaves the column as NaN.
      * Open-Meteo loaded but an airport code is empty: nothing is imputed.
        The legacy medians are not loaded in that configuration, so there is
        no table to fall back on (previously this path raised NameError).
      * Open-Meteo not loaded: NaNs in every GSOM column (A, B and pair) are
        replaced by the global median for *month* from ``_GSOM_MEDIANS``.

    Afterwards each ``pair_max_*`` weather column present in *X* is recomputed
    as the row-wise max of its ``A_*`` and ``B_*`` columns.

    Args:
        X: Feature frame for prediction (NaN marks a missing value).
        month: Calendar month number used as the lookup key.
        airport_a: IATA code of airport A; empty string if unknown.
        airport_b: IATA code of airport B; empty string if unknown.

    Returns:
        ``(filled_df, filled)`` where ``filled`` holds the names of columns
        that actually received a fill value, plus every ``pair_max_*`` column
        recomputed from at least one filled input.
    """
    X = X.copy()
    filled: set[str] = set()

    if _USE_OPENMETEO:
        if airport_a and airport_b:
            month_rows = _OM_LOOKUP[_OM_LOOKUP["Month"] == month]
            for side, iata in (("A", airport_a), ("B", airport_b)):
                hits = month_rows[month_rows["iata"] == iata]
                if hits.empty:
                    continue
                climate = hits.iloc[0]
                for col in _GSOM_WEATHER_COLS:
                    target = f"{side}_{col}"
                    if target not in X.columns or col not in climate.index:
                        continue
                    value = climate[col]
                    # A NaN normal cannot fill anything, so it must not be
                    # reported as an imputed column.
                    if pd.isna(value) or not X[target].isna().any():
                        continue
                    X[target] = X[target].fillna(value)
                    filled.add(target)
    else:
        # Legacy fallback: global monthly medians
        for col in _ALL_GSOM_COLS:
            if col not in X.columns or col not in _GSOM_MEDIANS:
                continue
            fill_val = _GSOM_MEDIANS[col].get(month, np.nan)
            if pd.isna(fill_val) or not X[col].isna().any():
                continue
            X[col] = X[col].fillna(fill_val)
            filled.add(col)

    # Always re-derive pair-level max features from A/B values
    for col in _GSOM_WEATHER_COLS:
        a_col, b_col, pair_col = f"A_{col}", f"B_{col}", f"pair_max_{col}"
        if a_col in X.columns and b_col in X.columns and pair_col in X.columns:
            X[pair_col] = X[[a_col, b_col]].max(axis=1)
            # The pair value now depends on imputed data; flag it on both
            # paths so explanations can mark it consistently.
            if a_col in filled or b_col in filled:
                filled.add(pair_col)

    return X, filled


# Names of the model input features, grouped by data source.
# One name per line so additions and removals show up as single-line diffs.
FEATURE_COLS = [
    # Airport A BTS weather stats
    "A_weather_delay_rate",
    "A_weather_cancel_rate",
    "A_avg_weather_delay_min",
    "A_p75_weather_delay_min",
    "A_p95_weather_delay_min",
    "A_nas_delay_rate",
    "A_overall_weather_delay_rate",
    "A_overall_avg_weather_delay_min",
    # Airport B BTS weather stats
    "B_weather_delay_rate",
    "B_weather_cancel_rate",
    "B_avg_weather_delay_min",
    "B_p75_weather_delay_min",
    "B_p95_weather_delay_min",
    "B_nas_delay_rate",
    "B_overall_weather_delay_rate",
    "B_overall_avg_weather_delay_min",
    # Pair-level BTS features
    "pair_combined_weather_rate",
    "pair_max_weather_rate",
    "pair_min_weather_rate",
    "pair_weather_rate_sum",
    "pair_avg_weather_delay_min",
    "both_high_risk",
    # Temporal
    "Month",
    "is_spring_summer",
    "median_turnaround_min",
    # GSOM weather features: airport A
    "A_avg_wind_speed",
    "A_precip_days",
    "A_extreme_precip",
    "A_total_precip",
    "A_max_wind_gust",
    # GSOM weather features: airport B
    "B_avg_wind_speed",
    "B_precip_days",
    "B_extreme_precip",
    "B_total_precip",
    "B_max_wind_gust",
    # GSOM weather features: pair-level max
    "pair_max_avg_wind_speed",
    "pair_max_precip_days",
    "pair_max_extreme_precip",
    "pair_max_total_precip",
    "pair_max_max_wind_gust",
    # DFW hub weather
    "DFW_weather_delay_rate",
    "DFW_weather_cancel_rate",
    "DFW_avg_weather_delay_min",
    "DFW_p95_weather_delay_min",
    # Tail-chain crew duty features
    "tc_legs_before_mean",
    "tc_block_before_mean",
    "tc_duty_start_hour",
    "tc_total_duty_mean",
    "tc_total_duty_p75",
    "tc_fdp_util_mean",
    "tc_fdp_util_p75",
    "tc_fdp_overrun_rate",
    "tc_wocl_rate",
    "tc_legs_after_mean",
    "tc_legs_in_day_mean",
    "tc_downstream_rate",
    "tc_cascade_late_rate",
    "tc_cascade_late_min",
    "tc_cascade_amplif_mean",
    # Airport-level cascade propagation
    "A_ap_cascade_rate",
    "A_ap_cascade_given_late",
    "B_ap_cascade_rate",
    "B_ap_cascade_given_late",
    "pair_cascade_product",
    "pair_max_cascade_rate",
    # Multi-hop DFW cascade
    "mhc_n_hops_mean",
    "mhc_n_hops_max",
    "mhc_total_late_min_mean",
    "mhc_total_late_min_p75",
    "mhc_cascade_hop_rate",
    "mhc_cascade_depth_mean",
    "mhc_unique_airports_mean",
    "mhc_recovery_rate",
]

# Human-readable display label for each feature name. explain_pair() looks
# names up here and falls back to the raw feature name when one is missing.
# Entries follow the same grouping and order as FEATURE_COLS.
FEATURE_LABELS = {
    # Airport A ("Origin") BTS weather stats
    "A_weather_delay_rate": "Origin: Weather Delay Rate",
    "A_weather_cancel_rate": "Origin: Weather Cancel Rate",
    "A_avg_weather_delay_min": "Origin: Avg Weather Delay (min)",
    "A_p75_weather_delay_min": "Origin: P75 Weather Delay (min)",
    "A_p95_weather_delay_min": "Origin: P95 Weather Delay (min)",
    "A_nas_delay_rate": "Origin: NAS Delay Rate",
    "A_overall_weather_delay_rate": "Origin: Overall Weather Delay Rate",
    "A_overall_avg_weather_delay_min": "Origin: Overall Avg Weather Delay (min)",
    # Airport B ("Dest") BTS weather stats
    "B_weather_delay_rate": "Dest: Weather Delay Rate",
    "B_weather_cancel_rate": "Dest: Weather Cancel Rate",
    "B_avg_weather_delay_min": "Dest: Avg Weather Delay (min)",
    "B_p75_weather_delay_min": "Dest: P75 Weather Delay (min)",
    "B_p95_weather_delay_min": "Dest: P95 Weather Delay (min)",
    "B_nas_delay_rate": "Dest: NAS Delay Rate",
    "B_overall_weather_delay_rate": "Dest: Overall Weather Delay Rate",
    "B_overall_avg_weather_delay_min": "Dest: Overall Avg Weather Delay (min)",
    # Pair-level BTS features
    "pair_combined_weather_rate": "Pair: Combined Weather Rate",
    "pair_max_weather_rate": "Pair: Max Weather Rate",
    "pair_min_weather_rate": "Pair: Min Weather Rate",
    "pair_weather_rate_sum": "Pair: Weather Rate Sum",
    "pair_avg_weather_delay_min": "Pair: Avg Weather Delay (min)",
    "both_high_risk": "Both Airports High Risk",
    # Temporal
    "Month": "Month",
    "is_spring_summer": "Spring/Summer Season",
    "median_turnaround_min": "Median Turnaround (min)",
    # GSOM weather features (A, B, then pair-level max)
    "A_avg_wind_speed": "Origin: Avg Wind Speed",
    "A_precip_days": "Origin: Precipitation Days",
    "A_extreme_precip": "Origin: Extreme Precip Events",
    "A_total_precip": "Origin: Total Precipitation",
    "A_max_wind_gust": "Origin: Max Wind Gust",
    "B_avg_wind_speed": "Dest: Avg Wind Speed",
    "B_precip_days": "Dest: Precipitation Days",
    "B_extreme_precip": "Dest: Extreme Precip Events",
    "B_total_precip": "Dest: Total Precipitation",
    "B_max_wind_gust": "Dest: Max Wind Gust",
    "pair_max_avg_wind_speed": "Pair: Max Avg Wind Speed",
    "pair_max_precip_days": "Pair: Max Precip Days",
    "pair_max_extreme_precip": "Pair: Max Extreme Precip",
    "pair_max_total_precip": "Pair: Max Total Precip",
    "pair_max_max_wind_gust": "Pair: Max Wind Gust",
    # DFW hub weather
    "DFW_weather_delay_rate": "DFW Hub: Weather Delay Rate",
    "DFW_weather_cancel_rate": "DFW Hub: Weather Cancel Rate",
    "DFW_avg_weather_delay_min": "DFW Hub: Avg Weather Delay (min)",
    "DFW_p95_weather_delay_min": "DFW Hub: P95 Weather Delay (min)",
    # Tail-chain crew duty features
    "tc_legs_before_mean": "Crew: Legs Before DFW (avg)",
    "tc_block_before_mean": "Crew: Block Time Before (avg min)",
    "tc_duty_start_hour": "Crew: Duty Start Hour",
    "tc_total_duty_mean": "Crew: Total Duty Time (avg min)",
    "tc_total_duty_p75": "Crew: Total Duty Time (P75 min)",
    "tc_fdp_util_mean": "Crew: FDP Utilization (avg)",
    "tc_fdp_util_p75": "Crew: FDP Utilization (P75)",
    "tc_fdp_overrun_rate": "Crew: FDP Overrun Rate",
    "tc_wocl_rate": "Crew: WOCL Overlap Rate",
    "tc_legs_after_mean": "Crew: Legs After DFW (avg)",
    "tc_legs_in_day_mean": "Crew: Total Legs in Day (avg)",
    # Tail-chain features shown under the "Cascade" heading
    "tc_downstream_rate": "Cascade: Downstream Late Rate",
    "tc_cascade_late_rate": "Cascade: B→DFW Late Rate",
    "tc_cascade_late_min": "Cascade: B→DFW Avg Late (min)",
    "tc_cascade_amplif_mean": "Cascade: Delay Amplification",
    # Airport-level cascade propagation
    "A_ap_cascade_rate": "Origin: Airport Cascade Rate",
    "A_ap_cascade_given_late": "Origin: Cascade Rate Given Late",
    "B_ap_cascade_rate": "Dest: Airport Cascade Rate",
    "B_ap_cascade_given_late": "Dest: Cascade Rate Given Late",
    "pair_cascade_product": "Pair: Cascade Rate Product",
    "pair_max_cascade_rate": "Pair: Max Cascade Rate",
    # Multi-hop DFW cascade
    "mhc_n_hops_mean": "Multi-Hop: Avg Downstream Hops",
    "mhc_n_hops_max": "Multi-Hop: Max Downstream Hops",
    "mhc_total_late_min_mean": "Multi-Hop: Avg Total Late (min)",
    "mhc_total_late_min_p75": "Multi-Hop: P75 Total Late (min)",
    "mhc_cascade_hop_rate": "Multi-Hop: Cascade Rate",
    "mhc_cascade_depth_mean": "Multi-Hop: Avg Cascade Depth",
    "mhc_unique_airports_mean": "Multi-Hop: Avg Airports Affected",
    "mhc_recovery_rate": "Multi-Hop: Recovery Rate",
}


def _get_dfw_weather() -> pd.DataFrame:
    """Read the cached DFW monthly weather table, if it exists.

    Returns:
        The contents of ``dfw_weather_monthly.parquet`` under ``PROCESSED``,
        or an empty DataFrame when that file is absent so callers can test
        ``.empty`` instead of handling an exception.
    """
    parquet_file = os.path.join(PROCESSED, "dfw_weather_monthly.parquet")
    if not os.path.exists(parquet_file):
        return pd.DataFrame()
    return pd.read_parquet(parquet_file)


# On-disk cache of the merged feature table: build_features_df() returns it
# when present and rewrites it after every rebuild.
_APP_CACHE = os.path.join(PROCESSED, "app_features_cache.parquet")


def _latest_year_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Keep, per (airport_A, airport_B, Month), the whole row with the greatest Year.

    ``GroupBy.last()`` is deliberately not used: it returns the last
    *non-null value of each column*, which can stitch one output row together
    from several different years.
    """
    keys = ["airport_A", "airport_B", "Month"]
    latest = (
        df.dropna(subset=keys)                      # groupby() also discarded NaN keys
        .sort_values("Year", kind="stable")
        .drop_duplicates(subset=keys, keep="last")  # last row after sorting = latest Year
        .sort_values(keys, kind="stable")
        .reset_index(drop=True)
    )
    # Key columns first, matching the layout groupby(as_index=False) produced.
    return latest[keys + [c for c in latest.columns if c not in keys]]


def build_features_df(force_rebuild: bool = False) -> pd.DataFrame:
    """
    Load merged feature table for the app.

    Build path (priority order):
      1. Load existing app_features_cache.parquet if present (fast, ~1s)
      2. If sequence_features.parquet exists: rebuild via full join pipeline
         (DFW weather, tail-chain, airport cascade and multi-hop cascade
         tables are left-joined when their parquet files exist), then keep
         only the most-recent-year row per (airport_A, airport_B, Month)
      3. Fallback: rebuild from committed processed parquets + Open-Meteo
         (works without sequence_features.parquet — delegates to
         enrich_openmeteo.build_pair_features())

    Any rebuild (2 or 3) is written back to app_features_cache.parquet.
    Set force_rebuild=True or delete app_features_cache.parquet to regenerate.

    Args:
        force_rebuild: When True, ignore an existing cache file and rebuild.

    Returns:
        The merged feature table.
    """
    if not force_rebuild and os.path.exists(_APP_CACHE):
        return pd.read_parquet(_APP_CACHE)

    seq_path = os.path.join(PROCESSED, "sequence_features.parquet")
    if os.path.exists(seq_path):
        # ── Full pipeline rebuild from sequence_features ──────────────────
        print("Building feature cache from sequence_features (~60s)...")
        df = pd.read_parquet(seq_path)
        df["target"] = (df["observed_bad_rate"] > RISK_THRESHOLD).astype(int)

        if "pair_max_weather_rate" not in df.columns and "A_weather_delay_rate" in df.columns:
            df["pair_max_weather_rate"] = df[["A_weather_delay_rate",
                                              "B_weather_delay_rate"]].max(axis=1)

        dfw = _get_dfw_weather()
        if not dfw.empty:
            df = df.merge(dfw, on="Month", how="left")

        # Tables keyed by pair, month and year.
        pair_year_keys = ["airport_A", "airport_B", "Month", "Year"]

        tc_path = os.path.join(PROCESSED, "tail_chain_features.parquet")
        if os.path.exists(tc_path):
            df = df.merge(pd.read_parquet(tc_path), on=pair_year_keys, how="left")

        ap_path = os.path.join(PROCESSED, "airport_cascade_features.parquet")
        if os.path.exists(ap_path):
            ap = pd.read_parquet(ap_path)
            ap_feat = [c for c in ap.columns if c not in ("airport", "Month")]
            # The per-airport table is joined twice, once for each side of
            # the pair, with its feature columns prefixed A_ap_ / B_ap_.
            for side in ("A", "B"):
                rename = {c: f"{side}_ap_{c}" for c in ap_feat}
                merged = ap.rename(columns={"airport": f"airport_{side}", **rename})
                df = df.merge(merged[[f"airport_{side}", "Month"] + list(rename.values())],
                              on=[f"airport_{side}", "Month"], how="left")
            if "A_ap_cascade_rate" in df.columns and "B_ap_cascade_rate" in df.columns:
                df["pair_cascade_product"]  = df["A_ap_cascade_rate"] * df["B_ap_cascade_rate"]
                df["pair_max_cascade_rate"] = df[["A_ap_cascade_rate",
                                                   "B_ap_cascade_rate"]].max(axis=1)

        mhc_path = os.path.join(PROCESSED, "multihop_cascade_features.parquet")
        if os.path.exists(mhc_path):
            df = df.merge(pd.read_parquet(mhc_path), on=pair_year_keys, how="left")

        df = _latest_year_rows(df)

    else:
        # ── Fallback: rebuild from committed parquets + Open-Meteo ────────
        print("sequence_features.parquet not found — rebuilding from processed files...")
        import importlib
        import sys
        src_dir = os.path.join(os.path.dirname(__file__), "..", "src")
        # Avoid stacking a duplicate sys.path entry on every fallback call.
        if src_dir not in sys.path:
            sys.path.insert(0, src_dir)
        enrich = importlib.import_module("enrich_openmeteo")
        df = enrich.build_pair_features()

    df.to_parquet(_APP_CACHE, index=False)
    print(f"Feature cache saved → {_APP_CACHE} ({len(df):,} rows)")
    return df


class RiskPredictor:
    """Serve risk predictions and SHAP explanations for airport pairs.

    Holds the merged feature table indexed by (airport_A, airport_B, Month)
    and the XGBoost classifier loaded from ``xgb_model.json`` in PROCESSED.
    """

    def __init__(self, features_df: pd.DataFrame):
        """Index *features_df* by pair and month, then load the saved model.

        Args:
            features_df: Table containing ``airport_A``, ``airport_B`` and
                ``Month`` columns plus the model's feature columns.
        """
        self.df = features_df.set_index(["airport_A", "airport_B", "Month"])
        model_path = os.path.join(PROCESSED, "xgb_model.json")
        self.model = xgb.XGBClassifier()
        self.model.load_model(model_path)
        # Use exact feature names the model was trained with (authoritative)
        self.feature_cols = self.model.get_booster().feature_names
        self._explainer = None   # built on first access to .explainer

    @property
    def explainer(self):
        """SHAP TreeExplainer for the model, created once on first use."""
        if self._explainer is None:
            # Imported lazily so shap is only required when an explanation
            # is actually requested.
            import shap
            self._explainer = shap.TreeExplainer(self.model)
        return self._explainer

    def predict_pair(self, airport_a: str, airport_b: str, month: int) -> dict | None:
        """Predict the calibrated risk for one (airport_a, airport_b, month).

        Returns:
            ``None`` if the pair/month is not in the feature table, otherwise
            a dict with keys ``risk_score`` (calibrated probability),
            ``label``, ``color``, ``observed_bad_rate``, ``n_sequences``,
            ``X`` (model input after imputation), ``X_raw`` (before
            imputation), ``gsom_imputed`` (names of imputed columns) and
            ``row`` (the source feature row with ``Month`` restored).
        """
        try:
            row = self.df.loc[(airport_a, airport_b, month)]
        except KeyError:
            return None

        # A non-unique index yields several rows; use the first one.
        if isinstance(row, pd.DataFrame):
            row = row.iloc[0]

        # Index consumed airport_A, airport_B, Month — add Month back for model
        row = row.copy()
        row["Month"] = month

        X_raw = row[self.feature_cols].to_frame().T.astype(float)
        X, gsom_imputed = _apply_gsom_imputation(X_raw, month, airport_a, airport_b)
        prob_raw = float(self.model.predict_proba(X)[0, 1])
        prob = _calibrate(prob_raw)   # map to observed-bad-rate scale

        # int(NaN) raises ValueError, so a missing count is reported as 0.
        n_sequences = row.get("n_sequences", 0)

        return {
            "risk_score": prob,
            "label": _risk_label(prob),
            "color": _risk_color(prob),
            "observed_bad_rate": float(row.get("observed_bad_rate", np.nan)),
            "n_sequences": int(n_sequences) if pd.notna(n_sequences) else 0,
            "X": X,
            "X_raw": X_raw,
            "gsom_imputed": gsom_imputed,   # set of column names that were filled
            "row": row,
        }

    def explain_pair(self, X: pd.DataFrame, top_n: int = 15,
                     gsom_imputed: set[str] | None = None) -> pd.DataFrame:
        """Return the *top_n* feature contributions sorted by |SHAP value|.

        Args:
            X: Model input; only its first row is explained.
            top_n: Number of features to keep.
            gsom_imputed: Names of columns whose values were imputed (from
                Open-Meteo normals or legacy medians). Their labels get a
                trailing star and ``imputed`` is True.

        Returns:
            DataFrame with columns ``feature``, ``shap_value``,
            ``feature_value``, ``label``, ``imputed`` and ``abs_shap``.
        """
        shap_vals = self.explainer.shap_values(X)
        # Some shap versions return one array per class; use the entry at
        # index 1, matching the predict_proba column read in predict_pair().
        if isinstance(shap_vals, list):
            shap_vals = shap_vals[1]
        vals = shap_vals[0]
        feat_names = X.columns.tolist()
        imputed_set = gsom_imputed or set()
        result = pd.DataFrame({
            "feature": feat_names,
            "shap_value": vals,
            "feature_value": X.iloc[0].values,
            "label": [
                (FEATURE_LABELS.get(f, f) + " ★")   # star = imputed
                if f in imputed_set else FEATURE_LABELS.get(f, f)
                for f in feat_names
            ],
            "imputed": [f in imputed_set for f in feat_names],
        })
        result["abs_shap"] = result["shap_value"].abs()
        return result.sort_values("abs_shap", ascending=False).head(top_n).reset_index(drop=True)

    def predict_all_months(self, airport_a: str, airport_b: str) -> pd.DataFrame:
        """Risk score for every month (1-12) for a given pair.

        Months without data get ``risk_score`` NaN and ``label`` "No data".
        """
        rows = []
        for m in range(1, 13):
            res = self.predict_pair(airport_a, airport_b, m)
            rows.append({
                "Month": m,
                "risk_score": res["risk_score"] if res else np.nan,
                "label": res["label"] if res else "No data",
            })
        return pd.DataFrame(rows)

    @property
    def airports_a(self) -> list[str]:
        """Sorted unique ``airport_A`` codes present in the feature table."""
        return sorted(self.df.index.get_level_values("airport_A").unique())

    @property
    def airports_b(self) -> list[str]:
        """Sorted unique ``airport_B`` codes present in the feature table."""
        return sorted(self.df.index.get_level_values("airport_B").unique())


def _risk_label(score: float) -> str:
    """Return the display tier name for a calibrated risk score.

    Tiers are checked from highest to lowest; a score below every floor
    (or one that fails both comparisons, such as NaN) is "LOW RISK".
    """
    tiers = (
        (HIGH_THRESHOLD, "HIGH RISK"),
        (MOD_THRESHOLD, "MODERATE RISK"),
    )
    for floor, name in tiers:
        if score >= floor:
            return name
    return "LOW RISK"


def _risk_color(score: float) -> str:
    """Return the hex colour for a calibrated risk score's tier.

    Uses the same thresholds as _risk_label(): red for high, orange for
    moderate, green otherwise.
    """
    if score >= HIGH_THRESHOLD:
        tier_color = "#d62728"
    elif score >= MOD_THRESHOLD:
        tier_color = "#ff7f0e"
    else:
        tier_color = "#2ca02c"
    return tier_color