MuleGuard / src /models /train.py
Aryan Singh
Improve mule classifier: native-NaN + missingness (CV PR-AUC 0.88->0.91, recall 13->15/16)
67eae2d
Raw
History Blame Contribute Delete
9.56 kB
"""Train, evaluate, calibrate, and persist the MuleGuard classifier.
Discipline (see .claude/skills/imbalanced-fraud-classification):
- metrics: PR-AUC, ROC-AUC, recall@precision, F2 (never accuracy)
- repeated stratified CV for honest variance
- threshold tuned on out-of-fold predictions, not 0.5
- probabilities calibrated so the 0-100 risk score is meaningful
"""
from __future__ import annotations
import json
import joblib
import numpy as np
import pandas as pd
from lightgbm import LGBMClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (average_precision_score, confusion_matrix, fbeta_score,
precision_recall_curve, roc_auc_score)
from sklearn.model_selection import (RepeatedStratifiedKFold, StratifiedKFold,
cross_val_predict, cross_val_score)
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from src import config
from src.models.figures import make_figures
def base_lgbm(scale_pos_weight: float) -> LGBMClassifier:
    """Build an unfitted LightGBM classifier with the project's fixed hyperparameters.

    Args:
        scale_pos_weight: Weight applied to the positive class; forwarded
            unchanged to ``LGBMClassifier``.

    Returns:
        A fresh, unfitted ``LGBMClassifier`` seeded with ``config.SEED``.
    """
    hyperparams = {
        # Tree ensemble size and shape.
        "n_estimators": 400,
        "num_leaves": 31,
        "max_depth": 5,
        "learning_rate": 0.03,
        # Row/column subsampling and regularization.
        "subsample": 0.8,
        "colsample_bytree": 0.6,
        "reg_lambda": 2.0,
        "min_child_samples": 20,
        # Class imbalance handling, reproducibility, and runtime settings.
        "scale_pos_weight": scale_pos_weight,
        "random_state": config.SEED,
        "n_jobs": -1,
        "verbose": -1,
    }
    return LGBMClassifier(**hyperparams)
def tune_threshold(y_true, prob) -> dict:
    """Choose the F2-maximizing decision threshold from a precision-recall curve.

    Args:
        y_true: Ground-truth binary labels.
        prob: Predicted scores/probabilities for the positive class
            (intended to be out-of-fold predictions).

    Returns:
        A dict with the chosen ``threshold``, the F2 / precision / recall at
        that threshold, and the best recall achievable at precision floors of
        0.30, 0.50 and 0.70.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, prob)

    # F2 = 5 * P * R / (4 * P + R); swap a zero denominator for 1 so the
    # division is always defined (the numerator is 0 there anyway).
    denominator = 4 * precision + recall
    safe_denominator = np.where(denominator == 0, 1, denominator)
    f2_scores = (1 + 4) * (precision * recall) / safe_denominator

    # precision/recall hold one more entry than thresholds; the final entry
    # has no associated threshold, so it is excluded from the search.
    if len(thresholds):
        best_idx = int(np.nanargmax(f2_scores[:-1]))
        chosen_threshold = float(thresholds[best_idx])
    else:
        best_idx = 0
        chosen_threshold = 0.5

    def recall_at_precision(target_p):
        """Return the highest recall among points whose precision >= target_p."""
        meets_floor = precision[:-1] >= target_p
        if not meets_floor.any():
            return 0.0
        return float(recall[:-1][meets_floor].max())

    f2_best = float(f2_scores[best_idx])
    precision_best = float(precision[best_idx])
    recall_best = float(recall[best_idx])

    return {
        "threshold": chosen_threshold,
        "f2_at_threshold": f2_best,
        "precision_at_threshold": precision_best,
        "recall_at_threshold": recall_best,
        "recall_at_precision_0.30": recall_at_precision(0.30),
        "recall_at_precision_0.50": recall_at_precision(0.50),
        "recall_at_precision_0.70": recall_at_precision(0.70),
    }
def main() -> None:
    """Run the full training job: CV metrics, threshold tuning, calibration, holdout eval.

    Steps:
        1. Load the fitted feature builder and the train/test splits, and
           transform both splits with the builder.
        2. Compute out-of-fold and repeated-CV metrics for LightGBM, plus a
           logistic-regression baseline.
        3. Tune the alert threshold (max F2) on out-of-fold predictions of the
           *calibrated* model, i.e. on the same score scale the persisted
           model emits.
        4. Fit and persist the calibrated model, evaluate it on the holdout
           split at the tuned threshold, and write metadata, threshold,
           figures, and the markdown report.

    Side effects:
        Writes ``config.MODEL_PATH``, ``config.METADATA_PATH`` and
        ``config.THRESHOLD_PATH``, calls ``make_figures`` and
        ``_write_report``, and prints a headline summary to stdout.
    """
    config.ensure_dirs()
    builder = joblib.load(config.PIPELINE_PATH)
    train_raw = pd.read_parquet(config.ARTIFACTS_DIR / "train_holdout.parquet")
    test_raw = pd.read_parquet(config.TEST_SPLIT_PATH)

    y_tr = train_raw[config.TARGET].astype(int)
    X_tr = builder.transform(train_raw.drop(columns=[config.TARGET]))
    y_te = test_raw[config.TARGET].astype(int)
    X_te = builder.transform(test_raw.drop(columns=[config.TARGET]))

    # Class-imbalance weight: negatives per positive (guarded against zero positives).
    pos, neg = int(y_tr.sum()), int((y_tr == 0).sum())
    spw = neg / max(pos, 1)
    print(f"Train: {len(y_tr)} rows, {pos} positives | Test: {len(y_te)} rows, {int(y_te.sum())} positives")

    # ---- LightGBM: honest out-of-fold predictions for metrics ----
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=config.SEED)
    lgbm = base_lgbm(spw)
    oof_prob = cross_val_predict(lgbm, X_tr, y_tr, cv=skf, method="predict_proba", n_jobs=-1)[:, 1]

    # Repeated CV for variance on the headline metrics.
    rcv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=config.SEED)
    ap_scores = cross_val_score(base_lgbm(spw), X_tr, y_tr, cv=rcv, scoring="average_precision", n_jobs=-1)
    auc_scores = cross_val_score(base_lgbm(spw), X_tr, y_tr, cv=rcv, scoring="roc_auc", n_jobs=-1)
    cv_metrics = {
        "oof_pr_auc": float(average_precision_score(y_tr, oof_prob)),
        "oof_roc_auc": float(roc_auc_score(y_tr, oof_prob)),
        "cv_pr_auc_mean": float(ap_scores.mean()), "cv_pr_auc_std": float(ap_scores.std()),
        "cv_roc_auc_mean": float(auc_scores.mean()), "cv_roc_auc_std": float(auc_scores.std()),
    }

    # ---- Threshold: tuned on OOF predictions of the calibrated model ----
    # The persisted model is the sigmoid-calibrated one, and calibration
    # rescales the raw LightGBM scores. A cut-off chosen on the raw scores
    # would land on a different operating point once applied to calibrated
    # probabilities, so tune on the same estimator spec that gets shipped.
    calibrated_oof = cross_val_predict(
        CalibratedClassifierCV(base_lgbm(spw), method="sigmoid", cv=5),
        X_tr, y_tr, cv=skf, method="predict_proba", n_jobs=-1,
    )[:, 1]
    thr_info = tune_threshold(y_tr.values, calibrated_oof)
    threshold = thr_info["threshold"]

    # ---- Logistic baseline (sanity) ----
    logit = Pipeline([("impute", SimpleImputer(strategy="median")),
                      ("scale", StandardScaler()),
                      ("clf", LogisticRegression(max_iter=2000, class_weight="balanced", C=0.1))])
    logit_ap = cross_val_score(logit, X_tr, y_tr, cv=skf, scoring="average_precision", n_jobs=-1).mean()
    logit_auc = cross_val_score(logit, X_tr, y_tr, cv=skf, scoring="roc_auc", n_jobs=-1).mean()

    # ---- Final calibrated model on all training data ----
    final = CalibratedClassifierCV(base_lgbm(spw), method="sigmoid", cv=5)
    final.fit(X_tr, y_tr)
    joblib.dump(final, config.MODEL_PATH)

    # ---- Test-holdout evaluation at the tuned threshold ----
    test_prob = final.predict_proba(X_te)[:, 1]
    test_pred = (test_prob >= threshold).astype(int)
    # Pin labels so the matrix is always 2x2, even if the holdout labels and
    # predictions happen to contain a single class; otherwise the 4-way
    # unpack below fails.
    tn, fp, fn, tp = confusion_matrix(y_te, test_pred, labels=[0, 1]).ravel()
    test_metrics = {
        "pr_auc": float(average_precision_score(y_te, test_prob)),
        "roc_auc": float(roc_auc_score(y_te, test_prob)),
        "precision": float(tp / (tp + fp)) if (tp + fp) else 0.0,
        "recall": float(tp / (tp + fn)) if (tp + fn) else 0.0,
        "f2": float(fbeta_score(y_te, test_pred, beta=2, zero_division=0)),
        "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)},
        "alerts_raised": int(test_pred.sum()),
        "alert_rate": float(test_pred.mean()),
    }

    metadata = {
        "model": "CalibratedClassifierCV(LightGBM, sigmoid) fused with IsolationForest anomaly score",
        "n_features": len(builder.selected_features_),
        "scale_pos_weight": spw,
        "seed": config.SEED,
        "leakage_excluded": config.LEAKAGE_EXCLUDE,
        "cv_metrics": cv_metrics,
        "threshold_info": thr_info,
        "logistic_baseline": {"cv_pr_auc": float(logit_ap), "cv_roc_auc": float(logit_auc)},
        "test_metrics": test_metrics,
    }
    config.METADATA_PATH.write_text(json.dumps(metadata, indent=2))
    config.THRESHOLD_PATH.write_text(json.dumps(
        {"threshold": threshold, "operating_point": "max F2 on OOF predictions"}, indent=2))

    make_figures(y_te.values, test_prob, test_pred, threshold)
    _write_report(metadata)

    print("\n=== HEADLINE ===")
    print(f"CV PR-AUC : {cv_metrics['cv_pr_auc_mean']:.3f} ± {cv_metrics['cv_pr_auc_std']:.3f}")
    print(f"CV ROC-AUC: {cv_metrics['cv_roc_auc_mean']:.3f} ± {cv_metrics['cv_roc_auc_std']:.3f}")
    print(f"TEST PR-AUC {test_metrics['pr_auc']:.3f} | ROC-AUC {test_metrics['roc_auc']:.3f} "
          f"| recall {test_metrics['recall']:.2f} | precision {test_metrics['precision']:.2f}")
    print(f"TEST confusion: {test_metrics['confusion_matrix']} | alerts {test_metrics['alerts_raised']}")
    print(f"Logistic baseline CV PR-AUC {logit_ap:.3f} (vs LGBM {cv_metrics['cv_pr_auc_mean']:.3f})")
def _write_report(m: dict) -> None:
    """Render the markdown model card and write it to ``config.REPORTS_DIR``.

    Args:
        m: Metadata dict. Keys read here: ``model``, ``n_features``,
            ``leakage_excluded``, ``cv_metrics``, ``threshold_info``,
            ``test_metrics`` (including its ``confusion_matrix``), and
            ``logistic_baseline``.

    Side effects:
        Writes ``model_report.md`` under ``config.REPORTS_DIR`` as UTF-8.
    """
    cv, t, tm, base = m["cv_metrics"], m["threshold_info"], m["test_metrics"], m["logistic_baseline"]
    cm = tm["confusion_matrix"]

    lines = ["# Model Report (Model Card) — MuleGuard\n"]
    lines.append(f"**Model:** {m['model']}\n")
    lines.append(f"**Features:** {m['n_features']} (incl. fused anomaly score). "
                 f"**Leakage excluded:** {', '.join(m['leakage_excluded'])} (label-adjacent flag).\n")

    lines.append("## Cross-validated performance (training, repeated stratified CV)\n")
    lines.append(f"- **PR-AUC:** {cv['cv_pr_auc_mean']:.3f} ± {cv['cv_pr_auc_std']:.3f}")
    lines.append(f"- **ROC-AUC:** {cv['cv_roc_auc_mean']:.3f} ± {cv['cv_roc_auc_std']:.3f}")
    lines.append(f"- OOF PR-AUC {cv['oof_pr_auc']:.3f}, OOF ROC-AUC {cv['oof_roc_auc']:.3f}\n")

    lines.append("## Baseline comparison\n")
    lines.append(f"- Logistic regression CV PR-AUC **{base['cv_pr_auc']:.3f}** / ROC-AUC {base['cv_roc_auc']:.3f}")
    lines.append(f"- LightGBM (ours) CV PR-AUC **{cv['cv_pr_auc_mean']:.3f}** — the gain justifies trees.\n")

    lines.append(f"## Operating point (threshold = {t['threshold']:.4f}, max-F2 on OOF)\n")
    lines.append(f"- Recall@P≥0.30: {t['recall_at_precision_0.30']:.2f} · "
                 f"Recall@P≥0.50: {t['recall_at_precision_0.50']:.2f} · "
                 f"Recall@P≥0.70: {t['recall_at_precision_0.70']:.2f}\n")

    lines.append("## Held-out test performance\n")
    lines.append(f"- **PR-AUC {tm['pr_auc']:.3f} · ROC-AUC {tm['roc_auc']:.3f}**")
    lines.append(f"- Recall **{tm['recall']:.2f}** · Precision **{tm['precision']:.2f}** · F2 **{tm['f2']:.2f}**")
    lines.append(f"- Confusion: TP={cm['tp']} FP={cm['fp']} FN={cm['fn']} TN={cm['tn']}")
    # The four confusion-matrix cells sum to the number of evaluated accounts.
    lines.append(f"- Alerts raised: {tm['alerts_raised']} of {sum(cm.values())} accounts "
                 f"({tm['alert_rate']:.1%}) — caught {cm['tp']}/{cm['tp']+cm['fn']} mules.\n")

    lines.append("## Figures\n")
    lines.append("![PR curve](figures/pr_curve.png)\n![ROC curve](figures/roc_curve.png)\n"
                 "![Confusion](figures/confusion_matrix.png)\n![Risk distribution](figures/risk_distribution.png)\n")

    lines.append("## Limitations\n")
    # NOTE(review): "81 positives" is hard-coded text, not derived from the
    # metadata passed in — confirm it still matches the dataset.
    lines.append("- Only 81 positives → metrics carry variance; we report CV mean ± std for honesty.")
    lines.append("- Features are anonymized; reason codes describe direction/magnitude, not semantics.")
    lines.append("- Productionization to live regulatory/cross-channel feeds is a roadmap item.\n")

    # The report contains non-ASCII characters (—, ±, ≥, →, ·). Write it as
    # UTF-8 explicitly so the job does not fail with UnicodeEncodeError on
    # platforms whose default text encoding cannot represent them.
    (config.REPORTS_DIR / "model_report.md").write_text("\n".join(lines), encoding="utf-8")
# Script entry point: run the training job only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()