AA-EPPS-Data-Challenge / src /stress_test.py
itaykadosh's picture
Initial upload: AA EPPS Data Challenge app
bef09da verified
"""
Stress test & real-world verification for the crew sequence risk model.
Test 1 — 2024 holdout backtest
Reload model, score 2024 rows (never seen during training), check:
- ROC-AUC, Average Precision
- Decile calibration: do higher-score deciles have proportionally higher actual bad rates?
- Monotonicity: is each decile worse than the previous?
Test 2 — Historic storm event replay
For known catastrophic weather events, download IEM METAR for affected airports,
compute per-day weather severity, and check whether the model + weather penalty
would have correctly flagged those sequences as high-risk.
Events:
A. Winter Storm Elliott — Dec 23-26 2022 (ORD, BUF, CLT, DEN, MDW)
B. DFW summer convective — Jul 10 2023 (DFW-centric, peak convective season)
C. Southeast ice storm — Jan 16 2024 (ATL, CLT, RDU, MIA)
"""
import os, sys
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.metrics import roc_auc_score, average_precision_score
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
sys.path.insert(0, os.path.dirname(__file__))
from weather import download_historical_weather, _compute_severity
PROC = os.path.join(os.path.dirname(__file__), "..", "data", "processed")
PLOTS = os.path.join(PROC, "plots")
os.makedirs(PLOTS, exist_ok=True)
FEATURE_COLS = [
"A_weather_delay_rate", "A_weather_cancel_rate", "A_avg_weather_delay_min",
"A_p75_weather_delay_min", "A_p95_weather_delay_min", "A_nas_delay_rate",
"A_overall_weather_delay_rate", "A_overall_avg_weather_delay_min",
"B_weather_delay_rate", "B_weather_cancel_rate", "B_avg_weather_delay_min",
"B_p75_weather_delay_min", "B_p95_weather_delay_min", "B_nas_delay_rate",
"B_overall_weather_delay_rate", "B_overall_avg_weather_delay_min",
"pair_combined_weather_rate", "pair_max_weather_rate", "pair_min_weather_rate",
"pair_weather_rate_sum", "pair_avg_weather_delay_min", "both_high_risk",
"Month", "is_spring_summer", "median_turnaround_min",
]
STORM_EVENTS = {
"Winter Storm Elliott (Dec 23-26 2022)": {
"airports": ["ORD", "BUF", "CLT", "DEN", "MDW", "LGA", "DFW"],
"year": 2022,
"dates": pd.date_range("2022-12-23", "2022-12-26"),
"expected": "CRITICAL — massive cancellation event across US",
},
"DFW Convective Outbreak (Jul 10 2023)": {
"airports": ["DFW", "DAL", "AUS", "SAT", "HOU", "MSY", "ORD"],
"year": 2023,
"dates": pd.date_range("2023-07-08", "2023-07-12"),
"expected": "HIGH — peak summer convective at DFW hub",
},
"Southeast Ice Storm (Jan 16-17 2024)": {
"airports": ["ATL", "CLT", "RDU", "GSO", "BNA", "MIA", "DFW"],
"year": 2024,
"dates": pd.date_range("2024-01-15", "2024-01-18"),
"expected": "HIGH — ice/freezing rain across Southeast",
},
}
# ---------------------------------------------------------------------------
# Test 1: 2024 holdout backtest
# ---------------------------------------------------------------------------
def test1_holdout_backtest():
print("\n" + "="*65)
print("TEST 1 — 2024 Holdout Backtest")
print("="*65)
df = pd.read_parquet(os.path.join(PROC, "sequence_features.parquet"))
season_cols = [c for c in df.columns if c.startswith("season_")]
feat_cols = [c for c in FEATURE_COLS + season_cols if c in df.columns]
val = df[df["Year"] == 2024].copy()
X_val = val[feat_cols].astype(float)
y_val = val["target"].astype(int)
model = xgb.XGBClassifier()
model.load_model(os.path.join(PROC, "xgb_model.json"))
proba = model.predict_proba(X_val)[:, 1]
roc = roc_auc_score(y_val, proba)
ap = average_precision_score(y_val, proba)
print(f"\nROC-AUC (2024): {roc:.4f}")
print(f"Average Precision (2024):{ap:.4f}")
print(f"Baseline bad rate: {y_val.mean():.3f} ({y_val.mean():.1%})")
# Decile calibration
val = val.copy()
val["proba"] = proba
val["decile"] = pd.qcut(val["proba"], q=10, labels=False, duplicates="drop")
decile_tbl = (
val.groupby("decile")
.agg(
n=("target", "count"),
actual_bad_rate=("target", "mean"),
mean_predicted=("proba", "mean"),
mean_observed=("observed_bad_rate", "mean"),
)
.reset_index()
)
print("\nDecile Calibration (0=lowest predicted risk, 9=highest):")
print(decile_tbl.to_string(index=False, float_format="{:.3f}".format))
# Monotonicity check
rates = decile_tbl["actual_bad_rate"].values
monotone = all(rates[i] <= rates[i+1] for i in range(len(rates)-1))
print(f"\nMonotonicity (each decile worse than prev): {'PASS' if monotone else 'FAIL'}")
# Lift: top decile vs bottom decile
top_rate = decile_tbl.iloc[-1]["actual_bad_rate"]
bottom_rate = decile_tbl.iloc[0]["actual_bad_rate"]
print(f"Lift (top/bottom decile): {top_rate/bottom_rate:.1f}x")
# Plot calibration
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
ax = axes[0]
ax.plot(decile_tbl["mean_predicted"], decile_tbl["actual_bad_rate"],
"o-", color="steelblue", label="Actual bad rate")
ax.plot([0,1],[0,1], "k--", alpha=0.4, label="Perfect calibration")
ax.set_xlabel("Mean predicted probability")
ax.set_ylabel("Actual bad rate")
ax.set_title("Calibration — 2024 holdout\n(by predicted probability decile)")
ax.legend()
ax.grid(alpha=0.3)
ax = axes[1]
bars = ax.bar(decile_tbl["decile"], decile_tbl["actual_bad_rate"],
color=plt.cm.RdYlGn_r(decile_tbl["actual_bad_rate"]))
ax.axhline(y_val.mean(), color="navy", linestyle="--", label=f"Baseline {y_val.mean():.1%}")
ax.set_xlabel("Risk score decile (0=lowest, 9=highest)")
ax.set_ylabel("Actual bad rate")
ax.set_title(f"Actual bad rate by decile — 2024 holdout\nROC-AUC={roc:.3f} AP={ap:.3f}")
ax.legend()
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
out = os.path.join(PLOTS, "stress_test1_calibration.png")
plt.savefig(out, dpi=150)
plt.close()
print(f"\nCalibration plot → {out}")
return decile_tbl
# ---------------------------------------------------------------------------
# Test 2: Historic storm event replay
# ---------------------------------------------------------------------------
def test2_storm_replay():
print("\n" + "="*65)
print("TEST 2 — Historic Storm Event Replay")
print("="*65)
all_results = []
for event_name, ev in STORM_EVENTS.items():
print(f"\n--- {event_name} ---")
print(f"Expected: {ev['expected']}")
airports = ev["airports"]
year = ev["year"]
dates = ev["dates"]
# Download/cache IEM METAR for these airports × year
print(f"Fetching IEM METAR for {airports} in {year}...")
wx = download_historical_weather(airports, [year])
if wx.empty:
print(" WARNING: No weather data returned")
continue
# Filter to event dates
wx["date"] = pd.to_datetime(wx["date"])
ev_wx = wx[wx["date"].isin(dates)].copy()
if ev_wx.empty:
print(" WARNING: No data for event dates")
continue
print(f"\n Airport weather during event:")
cols = ["iata", "date", "weather_severity", "has_thunderstorm",
"has_snow_ice", "has_fog", "min_visibility_mi",
"max_wind_kt", "ceiling_ft"]
cols_present = [c for c in cols if c in ev_wx.columns]
print(ev_wx[cols_present].sort_values(["date","iata"]).to_string(index=False, float_format="{:.2f}".format))
# Summary stats per airport
summary = ev_wx.groupby("iata")["weather_severity"].agg(["mean","max"])
summary.columns = ["avg_severity", "max_severity"]
print(f"\n Severity summary:")
print(summary.sort_values("max_severity", ascending=False).to_string(float_format="{:.3f}".format))
# For each airport pair involving DFW, compute what live_score would have been
# using peak weather day severity
print(f"\n Pair-level impact (A→DFW→B, worst day):")
dfwrow = ev_wx[ev_wx["iata"] == "DFW"]
dfw_max_sev = dfwrow["weather_severity"].max() if not dfwrow.empty else 0.0
for iata in airports:
if iata == "DFW":
continue
rows = ev_wx[ev_wx["iata"] == iata]
if rows.empty:
continue
ap_max_sev = rows["weather_severity"].max()
# Simulate weather penalty formula from predict.py
penalty = max(dfw_max_sev * 0.6, ap_max_sev * 0.4)
all_results.append({
"event": event_name,
"airport": iata,
"dfw_severity":dfw_max_sev,
"ap_severity": ap_max_sev,
"wx_penalty": penalty,
"flag": "HIGH" if penalty > 0.3 else ("MOD" if penalty > 0.15 else "LOW"),
})
print(f" {iata}: airport_sev={ap_max_sev:.2f} dfw_sev={dfw_max_sev:.2f} "
f"penalty={penalty:.2f}{all_results[-1]['flag']}")
# Summary table
if all_results:
res_df = pd.DataFrame(all_results)
print("\n\n" + "="*65)
print("STORM REPLAY SUMMARY")
print("="*65)
print(res_df.groupby(["event","flag"]).size().unstack(fill_value=0).to_string())
# Plot
fig, ax = plt.subplots(figsize=(12, 6))
events_short = {e: e.split("(")[0].strip() for e in res_df["event"].unique()}
res_df["event_short"] = res_df["event"].map(events_short)
colors = {"HIGH": "#d62728", "MOD": "#ff7f0e", "LOW": "#2ca02c"}
for i, (ev_short, grp) in enumerate(res_df.groupby("event_short")):
x = range(len(grp))
bars = ax.bar(
[j + i*0.28 for j in range(len(grp))],
grp["wx_penalty"].values,
width=0.25,
label=ev_short,
color=[colors[f] for f in grp["flag"]],
alpha=0.8,
)
ax.axhline(0.30, color="red", linestyle="--", linewidth=1, label="HIGH threshold (0.30)")
ax.axhline(0.15, color="orange", linestyle="--", linewidth=1, label="MOD threshold (0.15)")
ax.set_ylabel("Weather penalty")
ax.set_title("Storm Event Replay — Weather Penalty per Airport Pair\n(higher = model would have flagged pair as high-risk)")
ax.legend(loc="upper right", fontsize=8)
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
out = os.path.join(PLOTS, "stress_test2_storm_replay.png")
plt.savefig(out, dpi=150)
plt.close()
print(f"\nStorm replay plot → {out}")
return all_results
if __name__ == "__main__":
t1 = test1_holdout_backtest()
t2 = test2_storm_replay()
print("\n" + "="*65)
print("STRESS TEST COMPLETE")
print("="*65)
print("Plots saved to data/processed/plots/")
print(" stress_test1_calibration.png")
print(" stress_test2_storm_replay.png")