Update app.py
Browse files
app.py
CHANGED
|
@@ -360,26 +360,44 @@ def train_and_save(
|
|
| 360 |
# ----- METRICS BLOCK (MISSING) -----
|
| 361 |
roc_auc = float(roc_auc_score(y_test, proba))
|
| 362 |
fpr, tpr, roc_thresholds = roc_curve(y_test, proba)
|
| 363 |
-
|
| 364 |
-
|
| 365 |
|
| 366 |
metrics = {
|
| 367 |
"roc_auc": roc_auc,
|
| 368 |
"n_train": int(len(X_train)),
|
| 369 |
"n_test": int(len(X_test)),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
"best_threshold_by": "f1",
|
| 371 |
"best_threshold": float(best_thr),
|
| 372 |
-
"
|
| 373 |
-
"
|
| 374 |
-
"
|
| 375 |
-
"
|
| 376 |
-
"
|
| 377 |
-
"
|
| 378 |
-
"
|
| 379 |
-
"
|
| 380 |
-
|
| 381 |
-
"
|
|
|
|
| 382 |
},
|
|
|
|
| 383 |
"roc_curve": {
|
| 384 |
"fpr": [float(x) for x in fpr],
|
| 385 |
"tpr": [float(x) for x in tpr],
|
|
@@ -389,6 +407,7 @@ def train_and_save(
|
|
| 389 |
"calibration": compute_calibration(y_test, proba, n_bins, cal_strategy),
|
| 390 |
"decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)),
|
| 391 |
}
|
|
|
|
| 392 |
|
| 393 |
|
| 394 |
joblib.dump(pipe, "model.joblib")
|
|
@@ -409,7 +428,7 @@ def train_and_save(
|
|
| 409 |
"svd_components": int(svd_components) if use_dimred else None,
|
| 410 |
"use_feature_selection": bool(use_feature_selection),
|
| 411 |
"l1_C": float(l1_C) if use_feature_selection else None,
|
| 412 |
-
"selection_method": "SelectFromModel(L1 saga, threshold=
|
| 413 |
"note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
|
| 414 |
},
|
| 415 |
"positive_class": str(pos_class),
|
|
@@ -868,9 +887,11 @@ with tab_train:
|
|
| 868 |
# Show key metrics at threshold 0.5
|
| 869 |
c1, c2, c3, c4 = st.columns(4)
|
| 870 |
c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
|
| 871 |
-
c2.metric("Sensitivity (
|
| 872 |
-
c3.metric("Specificity", f"{m['specificity@
|
| 873 |
-
c4.metric("F1", f"{m['f1@
|
|
|
|
|
|
|
| 874 |
|
| 875 |
c5, c6, c7, c8 = st.columns(4)
|
| 876 |
c5.metric("Precision", f"{m['precision@0.5']:.3f}")
|
|
|
|
| 360 |
# ----- METRICS BLOCK (MISSING) -----
|
| 361 |
roc_auc = float(roc_auc_score(y_test, proba))
|
| 362 |
fpr, tpr, roc_thresholds = roc_curve(y_test, proba)
|
| 363 |
+
cls_05 = compute_classification_metrics(y_test, proba, threshold=0.5)
|
| 364 |
+
best_thr, best_val, cls_best = find_best_threshold(y_test, proba, metric="f1")
|
| 365 |
|
| 366 |
metrics = {
|
| 367 |
"roc_auc": roc_auc,
|
| 368 |
"n_train": int(len(X_train)),
|
| 369 |
"n_test": int(len(X_test)),
|
| 370 |
+
|
| 371 |
+
# reference @0.5
|
| 372 |
+
"threshold@0.5": 0.5,
|
| 373 |
+
"accuracy@0.5": cls_05["accuracy"],
|
| 374 |
+
"balanced_accuracy@0.5": cls_05["balanced_accuracy"],
|
| 375 |
+
"precision@0.5": cls_05["precision"],
|
| 376 |
+
"recall@0.5": cls_05["recall"],
|
| 377 |
+
"f1@0.5": cls_05["f1"],
|
| 378 |
+
"sensitivity@0.5": cls_05["sensitivity"],
|
| 379 |
+
"specificity@0.5": cls_05["specificity"],
|
| 380 |
+
"confusion_matrix@0.5": {
|
| 381 |
+
"tn": cls_05["tn"], "fp": cls_05["fp"],
|
| 382 |
+
"fn": cls_05["fn"], "tp": cls_05["tp"],
|
| 383 |
+
},
|
| 384 |
+
|
| 385 |
+
# primary: best F1 threshold
|
| 386 |
"best_threshold_by": "f1",
|
| 387 |
"best_threshold": float(best_thr),
|
| 388 |
+
"best_f1": float(cls_best["f1"]),
|
| 389 |
+
"accuracy@best": cls_best["accuracy"],
|
| 390 |
+
"balanced_accuracy@best": cls_best["balanced_accuracy"],
|
| 391 |
+
"precision@best": cls_best["precision"],
|
| 392 |
+
"recall@best": cls_best["recall"],
|
| 393 |
+
"f1@best": cls_best["f1"],
|
| 394 |
+
"sensitivity@best": cls_best["sensitivity"],
|
| 395 |
+
"specificity@best": cls_best["specificity"],
|
| 396 |
+
"confusion_matrix@best": {
|
| 397 |
+
"tn": cls_best["tn"], "fp": cls_best["fp"],
|
| 398 |
+
"fn": cls_best["fn"], "tp": cls_best["tp"],
|
| 399 |
},
|
| 400 |
+
|
| 401 |
"roc_curve": {
|
| 402 |
"fpr": [float(x) for x in fpr],
|
| 403 |
"tpr": [float(x) for x in tpr],
|
|
|
|
| 407 |
"calibration": compute_calibration(y_test, proba, n_bins, cal_strategy),
|
| 408 |
"decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)),
|
| 409 |
}
|
| 410 |
+
|
| 411 |
|
| 412 |
|
| 413 |
joblib.dump(pipe, "model.joblib")
|
|
|
|
| 428 |
"svd_components": int(svd_components) if use_dimred else None,
|
| 429 |
"use_feature_selection": bool(use_feature_selection),
|
| 430 |
"l1_C": float(l1_C) if use_feature_selection else None,
|
| 431 |
+
"selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
|
| 432 |
"note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
|
| 433 |
},
|
| 434 |
"positive_class": str(pos_class),
|
|
|
|
| 887 |
# Show key metrics at threshold 0.5
|
| 888 |
c1, c2, c3, c4 = st.columns(4)
|
| 889 |
c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
|
| 890 |
+
c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
|
| 891 |
+
c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
|
| 892 |
+
c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
|
| 893 |
+
st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")
|
| 894 |
+
|
| 895 |
|
| 896 |
c5, c6, c7, c8 = st.columns(4)
|
| 897 |
c5.metric("Precision", f"{m['precision@0.5']:.3f}")
|