MuleGuard / src /models /explain.py
MuleGuard
MuleGuard: end-to-end mule-account detection + HF Space deploy
af879c2
Raw
History Blame Contribute Delete
2.13 kB
"""Generate global SHAP explainability: beeswarm + bar importance over the test
holdout, and append an explainability section to the model report."""
from __future__ import annotations
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from src import config
from src.models.scoring import load_artifacts
# Number of features shown in the beeswarm and bar-importance plots.
_TOP_N_PLOT = 15
# Number of top features named in the report and the console summary.
_TOP_N_REPORT = 8


def _positive_class_shap(sv) -> np.ndarray:
    """Return one explainer's SHAP output as a 2-D array for class index 1.

    ``shap_values`` may return a list with one array per class, an array with a
    trailing class axis (3-D), or a single 2-D array. The first two forms are
    reduced to class index 1; a 2-D array is assumed to already be for that class.
    """
    if isinstance(sv, list):
        sv = sv[1]
    arr = np.asarray(sv, dtype=float)
    if arr.ndim == 3:
        arr = arr[..., 1]
    return arr


def _mean_shap_values(explainers, X: pd.DataFrame) -> np.ndarray:
    """Average class-1 SHAP values over all ``explainers`` (one per CV-fold estimator).

    Returns:
        Array with the same shape as ``X``.

    Raises:
        ValueError: if ``explainers`` is empty, or an explainer's (class-reduced)
            output does not match ``X.shape``.
    """
    n_explainers = len(explainers)
    if n_explainers == 0:
        # Averaging over nothing would silently produce all-zero importances.
        raise ValueError("No SHAP explainers in loaded artifacts; nothing to explain.")
    total = np.zeros(X.shape, dtype=float)
    for expl in explainers:
        sv = _positive_class_shap(expl.shap_values(X))
        if sv.shape != total.shape:
            raise ValueError(
                f"Explainer returned SHAP values of shape {sv.shape}; "
                f"expected {total.shape}."
            )
        total += sv
    return total / n_explainers


def _save_beeswarm(shap_vals: np.ndarray, X: pd.DataFrame, path,
                   max_display: int = _TOP_N_PLOT) -> None:
    """Save a SHAP beeswarm (summary) plot of the top ``max_display`` features to ``path``."""
    plt.figure()
    try:
        shap.summary_plot(shap_vals, X, show=False, max_display=max_display)
        plt.tight_layout()
        plt.savefig(path, dpi=130, bbox_inches="tight")
    finally:
        # Close even on failure so repeated runs don't accumulate open figures.
        plt.close()


def _save_importance_bar(shap_vals: np.ndarray, columns, path,
                         top_n: int = _TOP_N_PLOT) -> list[str]:
    """Save a horizontal bar chart of mean |SHAP| per feature to ``path``.

    Returns:
        Names of the ``top_n`` features by mean |SHAP|, most important first.
    """
    mean_abs = np.abs(shap_vals).mean(axis=0)
    order = np.argsort(-mean_abs)[:top_n]
    feats = [str(columns[i]) for i in order]
    # Reverse the y positions so the most important feature is drawn at the top.
    ypos = np.arange(len(order))[::-1]
    plt.figure(figsize=(6, 5))
    try:
        plt.barh(ypos, mean_abs[order], color="#1f4e79")
        plt.yticks(ypos, feats)
        plt.xlabel("Mean |SHAP|")
        plt.title("Global feature importance (SHAP)")
        plt.tight_layout()
        plt.savefig(path, dpi=130)
    finally:
        plt.close()
    return feats


def _append_report(report_path, feats: list[str], top_n: int = _TOP_N_REPORT) -> None:
    """Append the SHAP explainability section to the markdown report at ``report_path``.

    NOTE(review): the image links are relative (``figures/...``), so they resolve
    only if the figures directory sits next to the report — confirm
    ``config.FIGURES_DIR == config.REPORTS_DIR / "figures"``.
    NOTE(review): the section is appended unconditionally; re-running this script
    adds a duplicate section unless the report is regenerated first.
    """
    extra = [
        "\n## Explainability (SHAP)\n",
        "Every alert ships with reason codes. Global drivers (test set):\n",
        "![SHAP importance](figures/shap_importance.png)\n",
        "![SHAP beeswarm](figures/shap_beeswarm.png)\n",
        "Top global drivers: " + ", ".join(feats[:top_n]) + ".\n",
    ]
    with report_path.open("a", encoding="utf-8") as f:
        f.write("\n".join(extra))


def main() -> None:
    """Explain the model globally on the test holdout.

    Loads the trained artifacts, transforms the test split into model features,
    averages SHAP values across the CV-fold explainers, writes a beeswarm and a
    bar-importance figure, and appends an explainability section to
    ``model_report.md``.
    """
    art = load_artifacts()
    test_raw = pd.read_parquet(config.TEST_SPLIT_PATH).drop(columns=[config.TARGET])
    X = art.builder.transform(test_raw)

    # Average SHAP across the CV-fold base estimators.
    shap_vals = _mean_shap_values(art.explainers, X)

    config.FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    _save_beeswarm(shap_vals, X, config.FIGURES_DIR / "shap_beeswarm.png")
    feats = _save_importance_bar(
        shap_vals, X.columns, config.FIGURES_DIR / "shap_importance.png"
    )

    config.REPORTS_DIR.mkdir(parents=True, exist_ok=True)
    _append_report(config.REPORTS_DIR / "model_report.md", feats)
    print("SHAP figures written. Top drivers:", ", ".join(feats[:_TOP_N_REPORT]))


if __name__ == "__main__":
    main()