Synav committed on
Commit
746ab4a
·
verified ·
1 Parent(s): 35419fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -16
app.py CHANGED
@@ -360,26 +360,44 @@ def train_and_save(
360
  # ----- METRICS BLOCK (MISSING) -----
361
  roc_auc = float(roc_auc_score(y_test, proba))
362
  fpr, tpr, roc_thresholds = roc_curve(y_test, proba)
363
- best_thr, best_val, cls = find_best_threshold(y_test, proba, metric="f1")
364
-
365
 
366
  metrics = {
367
  "roc_auc": roc_auc,
368
  "n_train": int(len(X_train)),
369
  "n_test": int(len(X_test)),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  "best_threshold_by": "f1",
371
  "best_threshold": float(best_thr),
372
- "accuracy@0.5": cls["accuracy"],
373
- "balanced_accuracy@0.5": cls["balanced_accuracy"],
374
- "precision@0.5": cls["precision"],
375
- "recall@0.5": cls["recall"],
376
- "f1@0.5": cls["f1"],
377
- "sensitivity@0.5": cls["sensitivity"],
378
- "specificity@0.5": cls["specificity"],
379
- "confusion_matrix@0.5": {
380
- "tn": cls["tn"], "fp": cls["fp"],
381
- "fn": cls["fn"], "tp": cls["tp"],
 
382
  },
 
383
  "roc_curve": {
384
  "fpr": [float(x) for x in fpr],
385
  "tpr": [float(x) for x in tpr],
@@ -389,6 +407,7 @@ def train_and_save(
389
  "calibration": compute_calibration(y_test, proba, n_bins, cal_strategy),
390
  "decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)),
391
  }
 
392
 
393
 
394
  joblib.dump(pipe, "model.joblib")
@@ -409,7 +428,7 @@ def train_and_save(
409
  "svd_components": int(svd_components) if use_dimred else None,
410
  "use_feature_selection": bool(use_feature_selection),
411
  "l1_C": float(l1_C) if use_feature_selection else None,
412
- "selection_method": "SelectFromModel(L1 saga, threshold=mean)" if use_feature_selection else None,
413
  "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
414
  },
415
  "positive_class": str(pos_class),
@@ -868,9 +887,11 @@ with tab_train:
868
  # Show key metrics at threshold 0.5
869
  c1, c2, c3, c4 = st.columns(4)
870
  c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
871
- c2.metric("Sensitivity (Recall)", f"{m['sensitivity@0.5']:.3f}")
872
- c3.metric("Specificity", f"{m['specificity@0.5']:.3f}")
873
- c4.metric("F1", f"{m['f1@0.5']:.3f}")
 
 
874
 
875
  c5, c6, c7, c8 = st.columns(4)
876
  c5.metric("Precision", f"{m['precision@0.5']:.3f}")
 
360
  # ----- METRICS BLOCK (MISSING) -----
361
  roc_auc = float(roc_auc_score(y_test, proba))
362
  fpr, tpr, roc_thresholds = roc_curve(y_test, proba)
363
+ cls_05 = compute_classification_metrics(y_test, proba, threshold=0.5)
364
+ best_thr, best_val, cls_best = find_best_threshold(y_test, proba, metric="f1")
365
 
366
  metrics = {
367
  "roc_auc": roc_auc,
368
  "n_train": int(len(X_train)),
369
  "n_test": int(len(X_test)),
370
+
371
+ # reference @0.5
372
+ "threshold@0.5": 0.5,
373
+ "accuracy@0.5": cls_05["accuracy"],
374
+ "balanced_accuracy@0.5": cls_05["balanced_accuracy"],
375
+ "precision@0.5": cls_05["precision"],
376
+ "recall@0.5": cls_05["recall"],
377
+ "f1@0.5": cls_05["f1"],
378
+ "sensitivity@0.5": cls_05["sensitivity"],
379
+ "specificity@0.5": cls_05["specificity"],
380
+ "confusion_matrix@0.5": {
381
+ "tn": cls_05["tn"], "fp": cls_05["fp"],
382
+ "fn": cls_05["fn"], "tp": cls_05["tp"],
383
+ },
384
+
385
+ # primary: best F1 threshold
386
  "best_threshold_by": "f1",
387
  "best_threshold": float(best_thr),
388
+ "best_f1": float(cls_best["f1"]),
389
+ "accuracy@best": cls_best["accuracy"],
390
+ "balanced_accuracy@best": cls_best["balanced_accuracy"],
391
+ "precision@best": cls_best["precision"],
392
+ "recall@best": cls_best["recall"],
393
+ "f1@best": cls_best["f1"],
394
+ "sensitivity@best": cls_best["sensitivity"],
395
+ "specificity@best": cls_best["specificity"],
396
+ "confusion_matrix@best": {
397
+ "tn": cls_best["tn"], "fp": cls_best["fp"],
398
+ "fn": cls_best["fn"], "tp": cls_best["tp"],
399
  },
400
+
401
  "roc_curve": {
402
  "fpr": [float(x) for x in fpr],
403
  "tpr": [float(x) for x in tpr],
 
407
  "calibration": compute_calibration(y_test, proba, n_bins, cal_strategy),
408
  "decision_curve": decision_curve_analysis(y_test, proba, np.linspace(0.01, 0.99, dca_points)),
409
  }
410
+
411
 
412
 
413
  joblib.dump(pipe, "model.joblib")
 
428
  "svd_components": int(svd_components) if use_dimred else None,
429
  "use_feature_selection": bool(use_feature_selection),
430
  "l1_C": float(l1_C) if use_feature_selection else None,
431
+ "selection_method": "SelectFromModel(L1 saga, threshold=median)" if use_feature_selection else None,
432
  "note": "If SVD is enabled, SHAP becomes component-level (less interpretable)."
433
  },
434
  "positive_class": str(pos_class),
 
887
  # Show key metrics at threshold 0.5
888
  c1, c2, c3, c4 = st.columns(4)
889
  c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
890
+ c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
891
+ c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
892
+ c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
893
+ st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")
894
+
895
 
896
  c5, c6, c7, c8 = st.columns(4)
897
  c5.metric("Precision", f"{m['precision@0.5']:.3f}")