"""Train, evaluate, calibrate, and persist the MuleGuard classifier. Discipline (see .claude/skills/imbalanced-fraud-classification): - metrics: PR-AUC, ROC-AUC, recall@precision, F2 (never accuracy) - repeated stratified CV for honest variance - threshold tuned on out-of-fold predictions, not 0.5 - probabilities calibrated so the 0-100 risk score is meaningful """ from __future__ import annotations import json import joblib import numpy as np import pandas as pd from lightgbm import LGBMClassifier from sklearn.calibration import CalibratedClassifierCV from sklearn.linear_model import LogisticRegression from sklearn.metrics import (average_precision_score, confusion_matrix, fbeta_score, precision_recall_curve, roc_auc_score) from sklearn.model_selection import (RepeatedStratifiedKFold, StratifiedKFold, cross_val_predict, cross_val_score) from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from src import config from src.models.figures import make_figures def base_lgbm(scale_pos_weight: float) -> LGBMClassifier: return LGBMClassifier( n_estimators=400, num_leaves=31, max_depth=5, learning_rate=0.03, subsample=0.8, colsample_bytree=0.6, reg_lambda=2.0, min_child_samples=20, scale_pos_weight=scale_pos_weight, random_state=config.SEED, n_jobs=-1, verbose=-1, ) def tune_threshold(y_true, prob) -> dict: """Pick the threshold maximizing F2 (recall-weighted) on OOF predictions.""" prec, rec, thr = precision_recall_curve(y_true, prob) # thr has len-1 vs prec/rec; align. f2 = (1 + 4) * (prec * rec) / np.where((4 * prec + rec) == 0, 1, (4 * prec + rec)) best_i = int(np.nanargmax(f2[:-1])) if len(thr) else 0 best_thr = float(thr[best_i]) if len(thr) else 0.5 def recall_at_precision(target_p): ok = prec[:-1] >= target_p return float(rec[:-1][ok].max()) if ok.any() else 0.0 return { "threshold": best_thr, "f2_at_threshold": float(f2[best_i]), "precision_at_threshold": float(prec[best_i]), "recall_at_threshold": float(rec[best_i]), "recall_at_precision_0.30": recall_at_precision(0.30), "recall_at_precision_0.50": recall_at_precision(0.50), "recall_at_precision_0.70": recall_at_precision(0.70), } def main() -> None: config.ensure_dirs() builder = joblib.load(config.PIPELINE_PATH) train_raw = pd.read_parquet(config.ARTIFACTS_DIR / "train_holdout.parquet") test_raw = pd.read_parquet(config.TEST_SPLIT_PATH) y_tr = train_raw[config.TARGET].astype(int) X_tr = builder.transform(train_raw.drop(columns=[config.TARGET])) y_te = test_raw[config.TARGET].astype(int) X_te = builder.transform(test_raw.drop(columns=[config.TARGET])) pos, neg = int(y_tr.sum()), int((y_tr == 0).sum()) spw = neg / max(pos, 1) print(f"Train: {len(y_tr)} rows, {pos} positives | Test: {len(y_te)} rows, {int(y_te.sum())} positives") # ---- LightGBM: honest out-of-fold predictions for metrics + threshold ---- skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=config.SEED) lgbm = base_lgbm(spw) oof_prob = cross_val_predict(lgbm, X_tr, y_tr, cv=skf, method="predict_proba", n_jobs=-1)[:, 1] # Repeated CV for variance on the headline metrics. rcv = RepeatedStratifiedKFold(n_splits=5, n_repeats=4, random_state=config.SEED) ap_scores = cross_val_score(base_lgbm(spw), X_tr, y_tr, cv=rcv, scoring="average_precision", n_jobs=-1) auc_scores = cross_val_score(base_lgbm(spw), X_tr, y_tr, cv=rcv, scoring="roc_auc", n_jobs=-1) cv_metrics = { "oof_pr_auc": float(average_precision_score(y_tr, oof_prob)), "oof_roc_auc": float(roc_auc_score(y_tr, oof_prob)), "cv_pr_auc_mean": float(ap_scores.mean()), "cv_pr_auc_std": float(ap_scores.std()), "cv_roc_auc_mean": float(auc_scores.mean()), "cv_roc_auc_std": float(auc_scores.std()), } thr_info = tune_threshold(y_tr.values, oof_prob) threshold = thr_info["threshold"] # ---- Logistic baseline (sanity) ---- logit = Pipeline([("impute", SimpleImputer(strategy="median")), ("scale", StandardScaler()), ("clf", LogisticRegression(max_iter=2000, class_weight="balanced", C=0.1))]) logit_ap = cross_val_score(logit, X_tr, y_tr, cv=skf, scoring="average_precision", n_jobs=-1).mean() logit_auc = cross_val_score(logit, X_tr, y_tr, cv=skf, scoring="roc_auc", n_jobs=-1).mean() # ---- Final calibrated model on all training data ---- final = CalibratedClassifierCV(base_lgbm(spw), method="sigmoid", cv=5) final.fit(X_tr, y_tr) joblib.dump(final, config.MODEL_PATH) # ---- Test-holdout evaluation at the tuned threshold ---- test_prob = final.predict_proba(X_te)[:, 1] test_pred = (test_prob >= threshold).astype(int) tn, fp, fn, tp = confusion_matrix(y_te, test_pred).ravel() test_metrics = { "pr_auc": float(average_precision_score(y_te, test_prob)), "roc_auc": float(roc_auc_score(y_te, test_prob)), "precision": float(tp / (tp + fp)) if (tp + fp) else 0.0, "recall": float(tp / (tp + fn)) if (tp + fn) else 0.0, "f2": float(fbeta_score(y_te, test_pred, beta=2, zero_division=0)), "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}, "alerts_raised": int(test_pred.sum()), "alert_rate": float(test_pred.mean()), } metadata = { "model": "CalibratedClassifierCV(LightGBM, sigmoid) fused with IsolationForest anomaly score", "n_features": len(builder.selected_features_), "scale_pos_weight": spw, "seed": config.SEED, "leakage_excluded": config.LEAKAGE_EXCLUDE, "cv_metrics": cv_metrics, "threshold_info": thr_info, "logistic_baseline": {"cv_pr_auc": float(logit_ap), "cv_roc_auc": float(logit_auc)}, "test_metrics": test_metrics, } config.METADATA_PATH.write_text(json.dumps(metadata, indent=2)) config.THRESHOLD_PATH.write_text(json.dumps( {"threshold": threshold, "operating_point": "max F2 on OOF predictions"}, indent=2)) make_figures(y_te.values, test_prob, test_pred, threshold) _write_report(metadata) print("\n=== HEADLINE ===") print(f"CV PR-AUC : {cv_metrics['cv_pr_auc_mean']:.3f} ± {cv_metrics['cv_pr_auc_std']:.3f}") print(f"CV ROC-AUC: {cv_metrics['cv_roc_auc_mean']:.3f} ± {cv_metrics['cv_roc_auc_std']:.3f}") print(f"TEST PR-AUC {test_metrics['pr_auc']:.3f} | ROC-AUC {test_metrics['roc_auc']:.3f} " f"| recall {test_metrics['recall']:.2f} | precision {test_metrics['precision']:.2f}") print(f"TEST confusion: {test_metrics['confusion_matrix']} | alerts {test_metrics['alerts_raised']}") print(f"Logistic baseline CV PR-AUC {logit_ap:.3f} (vs LGBM {cv_metrics['cv_pr_auc_mean']:.3f})") def _write_report(m: dict) -> None: cv, t, tm, base = m["cv_metrics"], m["threshold_info"], m["test_metrics"], m["logistic_baseline"] cm = tm["confusion_matrix"] L = ["# Model Report (Model Card) — MuleGuard\n"] L.append(f"**Model:** {m['model']}\n") L.append(f"**Features:** {m['n_features']} (incl. fused anomaly score). " f"**Leakage excluded:** {', '.join(m['leakage_excluded'])} (label-adjacent flag).\n") L.append("## Cross-validated performance (training, repeated stratified CV)\n") L.append(f"- **PR-AUC:** {cv['cv_pr_auc_mean']:.3f} ± {cv['cv_pr_auc_std']:.3f}") L.append(f"- **ROC-AUC:** {cv['cv_roc_auc_mean']:.3f} ± {cv['cv_roc_auc_std']:.3f}") L.append(f"- OOF PR-AUC {cv['oof_pr_auc']:.3f}, OOF ROC-AUC {cv['oof_roc_auc']:.3f}\n") L.append("## Baseline comparison\n") L.append(f"- Logistic regression CV PR-AUC **{base['cv_pr_auc']:.3f}** / ROC-AUC {base['cv_roc_auc']:.3f}") L.append(f"- LightGBM (ours) CV PR-AUC **{cv['cv_pr_auc_mean']:.3f}** — the gain justifies trees.\n") L.append(f"## Operating point (threshold = {t['threshold']:.4f}, max-F2 on OOF)\n") L.append(f"- Recall@P≥0.30: {t['recall_at_precision_0.30']:.2f} · " f"Recall@P≥0.50: {t['recall_at_precision_0.50']:.2f} · " f"Recall@P≥0.70: {t['recall_at_precision_0.70']:.2f}\n") L.append("## Held-out test performance\n") L.append(f"- **PR-AUC {tm['pr_auc']:.3f} · ROC-AUC {tm['roc_auc']:.3f}**") L.append(f"- Recall **{tm['recall']:.2f}** · Precision **{tm['precision']:.2f}** · F2 **{tm['f2']:.2f}**") L.append(f"- Confusion: TP={cm['tp']} FP={cm['fp']} FN={cm['fn']} TN={cm['tn']}") L.append(f"- Alerts raised: {tm['alerts_raised']} of {sum(cm.values())} accounts " f"({tm['alert_rate']:.1%}) — caught {cm['tp']}/{cm['tp']+cm['fn']} mules.\n") L.append("## Figures\n") L.append("![PR curve](figures/pr_curve.png)\n![ROC curve](figures/roc_curve.png)\n" "![Confusion](figures/confusion_matrix.png)\n![Risk distribution](figures/risk_distribution.png)\n") L.append("## Limitations\n") L.append("- Only 81 positives → metrics carry variance; we report CV mean ± std for honesty.") L.append("- Features are anonymized; reason codes describe direction/magnitude, not semantics.") L.append("- Productionization to live regulatory/cross-channel feeds is a roadmap item.\n") (config.REPORTS_DIR / "model_report.md").write_text("\n".join(L)) if __name__ == "__main__": main()