MuleGuard / src /models /scoring.py
Aryan Singh
Improve mule classifier: native-NaN + missingness (CV PR-AUC 0.88->0.91, recall 13->15/16)
67eae2d
Raw
History Blame Contribute Delete
3.79 kB
"""Shared scoring + explanation logic used by the API, simulator, and dashboard.
Guarantees train/serve parity: everything loads the SAME persisted FeatureBuilder,
calibrated model, and tuned threshold.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from functools import lru_cache
import joblib
import numpy as np
import pandas as pd
import shap
from src import config
@dataclass
class Artifacts:
    """Bundle of persisted objects shared by scoring and explanation code."""

    # Feature pipeline; scoring code calls ``builder.transform(raw_df)`` on it.
    builder: object
    # Classifier exposing ``predict_proba``: CalibratedClassifierCV(LightGBM).
    model: object
    # Decision threshold read from the persisted threshold JSON.
    threshold: float
    # Contents of the persisted feature-list JSON.
    features: list
    # Contents of the persisted metadata JSON.
    metadata: dict
    # Underlying estimators pulled out of the calibrated model (may be empty).
    base_estimators: list = field(default_factory=list)
    # One SHAP explainer per entry in ``base_estimators`` (may be empty).
    explainers: list = field(default_factory=list)
@lru_cache(maxsize=1)
def load_artifacts() -> Artifacts:
    """Load the persisted pipeline, model, threshold, feature list and metadata.

    The result is memoised with ``lru_cache(maxsize=1)``, so every caller in
    the process receives the same ``Artifacts`` instance after the first call.

    Returns:
        An ``Artifacts`` bundle that also carries the estimators found inside
        the calibrated model and one ``shap.TreeExplainer`` per estimator.
    """

    def _read_json(path):
        # ``config`` paths expose ``read_text()``; parse the file as JSON.
        return json.loads(path.read_text())

    builder = joblib.load(config.PIPELINE_PATH)
    model = joblib.load(config.MODEL_PATH)
    threshold = _read_json(config.THRESHOLD_PATH)["threshold"]
    features = _read_json(config.FEATURE_LIST_PATH)
    metadata = _read_json(config.METADATA_PATH)

    # Dig the fitted estimators out of the calibrated wrapper so SHAP explains
    # the model that is actually served; there is one per calibration fold.
    # Both attribute names are tried -- presumably to cover different
    # scikit-learn versions (confirm against the pinned version).
    folds = getattr(model, "calibrated_classifiers_", [])
    candidates = (
        getattr(fold, "estimator", None) or getattr(fold, "base_estimator", None)
        for fold in folds
    )
    base_estimators = [est for est in candidates if est is not None]
    explainers = [shap.TreeExplainer(est) for est in base_estimators]

    return Artifacts(
        builder=builder,
        model=model,
        threshold=threshold,
        features=features,
        metadata=metadata,
        base_estimators=base_estimators,
        explainers=explainers,
    )
def score_frame(raw_df: pd.DataFrame) -> pd.DataFrame:
    """Score raw account rows with the shared artifacts.

    Args:
        raw_df: Raw account rows; passed to the persisted builder's
            ``transform`` before prediction.

    Returns:
        A frame indexed like ``raw_df`` with columns ``probability``
        (second column of ``predict_proba``), ``risk_score`` (from
        ``config.prob_to_risk``), ``decision`` (``FLAG_MULE`` when the risk
        score is at least 50, otherwise ``CLEAR``) and ``risk_tier`` (from
        ``config.risk_tier``).
    """
    artifacts = load_artifacts()
    design = artifacts.builder.transform(raw_df)
    probabilities = artifacts.model.predict_proba(design)[:, 1]
    risk_scores = [
        config.prob_to_risk(p, artifacts.threshold) for p in probabilities
    ]

    scored = pd.DataFrame(
        {"probability": probabilities, "risk_score": risk_scores},
        index=raw_df.index,
    )
    is_flagged = scored["risk_score"] >= 50
    scored["decision"] = np.where(is_flagged, "FLAG_MULE", "CLEAR")
    scored["risk_tier"] = [config.risk_tier(s) for s in scored["risk_score"]]
    return scored
def _positive_class_shap(shap_output) -> np.ndarray:
    """Normalise one explainer's ``shap_values`` result to a 2-D class-1 array.

    Accepts the three layouts an explainer may hand back:
    a list ``[class0, class1]``, a single 2-D array ``(rows, features)``, or a
    single 3-D array ``(rows, features, outputs)``.
    """
    if isinstance(shap_output, list):  # older shap returns [class0, class1]
        shap_output = shap_output[1]
    arr = np.asarray(shap_output)
    if arr.ndim == 3:
        # Newer shap releases can return one stacked array with the outputs on
        # the last axis; keep class 1 (or the only output if there is just one).
        arr = arr[:, :, 1] if arr.shape[2] > 1 else arr[:, :, 0]
    return arr


def explain_row(raw_row: pd.DataFrame, top_n: int = 5) -> list[dict]:
    """Top-N features driving an account's risk, via averaged SHAP over CV folds.

    Args:
        raw_row: Raw account data; only the first transformed row is explained.
        top_n: Maximum number of reasons to return, ranked by absolute SHAP.

    Returns:
        A list of dicts with keys ``feature``, ``value`` (the transformed
        feature value as a float), ``shap`` (mean contribution across the
        loaded explainers) and ``direction`` (``"increases risk"`` for a
        positive contribution, ``"lowers risk"`` otherwise). When no explainers
        are loaded, every contribution is 0.0.
    """
    art = load_artifacts()
    X = art.builder.transform(raw_row)

    vals = np.zeros(X.shape[1])
    for expl in art.explainers:
        vals += _positive_class_shap(expl.shap_values(X))[0]
    # max(..., 1) keeps the division safe when no explainers were loaded.
    vals /= max(len(art.explainers), 1)

    order = np.argsort(-np.abs(vals))[:top_n]
    feats = list(X.columns)
    return [
        {
            "feature": feats[i],
            "value": float(X.iloc[0, i]),
            "shap": float(vals[i]),
            "direction": "increases risk" if vals[i] > 0 else "lowers risk",
        }
        for i in order
    ]
def narrative(risk_score: float, reasons: list[dict]) -> str:
    """Turn a risk score and its reason codes into a one-sentence alert.

    The wording follows the decision. A score of 50 or more names up to three
    features with positive SHAP (what pushed risk up); any lower score names
    up to three features with negative SHAP (what held risk down). A fixed
    fallback phrase is used when no reason has the matching sign.

    Args:
        risk_score: Risk score on the 0-100 scale.
        reasons: Reason dicts carrying at least ``feature`` and ``shap`` keys.

    Returns:
        The formatted narrative sentence.
    """
    flagged = risk_score >= 50

    if not flagged:
        protective = [
            reason["feature"] for reason in reasons if reason["shap"] < 0
        ]
        top = protective[:3]
        drivers = "no material risk signals" if not top else ", ".join(top)
        return f"Risk {risk_score:.0f}/100 — kept low mainly by {drivers}."

    aggravating = [
        reason["feature"] for reason in reasons if reason["shap"] > 0
    ]
    top = aggravating[:3]
    drivers = "diffuse elevated signals" if not top else ", ".join(top)
    return f"Risk {risk_score:.0f}/100 — elevated mainly by {drivers}."