Joblib
File size: 5,412 Bytes
d06775d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import json
import argparse
import numpy as np
import pandas as pd
import xgboost as xgb
from scipy import stats
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve
from datasets import load_from_disk, DatasetDict

def best_f1_threshold(y_true, y_prob):
    """Return (threshold, f1) maximizing F1 over the precision-recall curve.

    The final (precision, recall) point returned by sklearn has no associated
    threshold, so it is dropped before scoring each candidate threshold.
    """
    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    # Epsilon in the denominator guards against 0/0 when P + R == 0.
    f1_candidates = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-12)
    best_idx = int(np.nanargmax(f1_candidates))
    return float(thresholds[best_idx]), float(f1_candidates[best_idx])


def bootstrap_ci(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Non-parametric bootstrap CI for F1 (at val-optimal threshold) and AUC.

    Resamples (y_true, y_prob) pairs with replacement; resamples containing
    a single class are skipped, since F1/AUC are undefined there.

    Parameters
    ----------
    y_true : array of binary labels (0/1).
    y_prob : array of predicted positive-class probabilities.
    n_bootstrap : number of bootstrap resamples to draw.
    ci : central confidence level (0.95 -> 2.5% / 97.5% quantiles).
    seed : RNG seed for reproducibility.

    Returns
    -------
    dict with per-metric entries ("f1", "auc") holding mean/std/ci_low/
    ci_high/report/n_bootstrap, plus "threshold_used" and "n_samples".

    Raises
    ------
    ValueError
        If y_true contains fewer than two classes — previously this fell
        through to empty score lists and crashed in np.quantile.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    if len(np.unique(y_true)) < 2:
        raise ValueError("bootstrap_ci requires both classes present in y_true")

    rng = np.random.default_rng(seed=seed)
    n = len(y_true)

    # Threshold picked once on the full val set, then held fixed across resamples
    thr, _ = best_f1_threshold(y_true, y_prob)

    f1_scores, auc_scores = [], []

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_prob[idx]

        # Skip degenerate bootstraps (only one class)
        if len(np.unique(yt)) < 2:
            continue

        f1_scores.append(f1_score(yt, (yp >= thr).astype(int), zero_division=0))
        auc_scores.append(roc_auc_score(yt, yp))

    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    results = {}
    for name, arr in [("f1", f1_scores), ("auc", auc_scores)]:
        arr = np.array(arr)
        results[name] = {
            "mean":    float(arr.mean()),
            "std":     float(arr.std()),
            "ci_low":  float(np.quantile(arr, lo)),
            "ci_high": float(np.quantile(arr, hi)),
            "report":  f"{arr.mean():.4f} [{np.quantile(arr, lo):.4f}, {np.quantile(arr, hi):.4f}]",
            "n_bootstrap": len(arr),
        }

    results["threshold_used"] = float(thr)
    results["n_samples"] = int(n)
    return results

def prob_margin_uncertainty(val_preds_df: pd.DataFrame) -> pd.DataFrame:
    """
    Uncertainty = distance from the decision boundary in probability space.

    |prob - 0.5| == 0.0 means maximally uncertain (coin-flip), 0.5 means
    maximally confident. Both outputs are normalized to [0, 1]:
      confidence  = 2 * |prob - 0.5|   (0 = uncertain, 1 = confident)
      uncertainty = 1 - confidence     (0 = confident, 1 = uncertain)

    Returns a copy of the input frame with "uncertainty" and "confidence"
    columns appended; the input is not mutated.
    """
    out = val_preds_df.copy()
    margin = (out["y_prob"] - 0.5).abs()
    out["uncertainty"] = 1 - 2 * margin
    out["confidence"] = 1 - out["uncertainty"]
    return out

def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    """Write the bootstrap CI dict to out_dir/bootstrap_ci.json and print a summary.

    Creates out_dir if it does not exist. Expects the dict shape produced by
    bootstrap_ci (per-metric "report"/"n_bootstrap", plus "threshold_used"
    and "n_samples").
    """
    os.makedirs(out_dir, exist_ok=True)
    report_path = os.path.join(out_dir, "bootstrap_ci.json")
    with open(report_path, "w") as fh:
        json.dump(ci_results, fh, indent=2)

    header = f"\n=== Bootstrap 95% CI ({model_name}) ==="
    print(header)
    print(f"  F1  : {ci_results['f1']['report']}")
    print(f"  AUC : {ci_results['auc']['report']}")
    details = (
        f"  (threshold={ci_results['threshold_used']:.4f}, "
        f"n_bootstrap={ci_results['f1']['n_bootstrap']}, "
        f"n_val={ci_results['n_samples']})"
    )
    print(details)
    print(f"Saved to {report_path}")


def save_uncertainty_csv(df: pd.DataFrame, out_dir: str, fname: str = "val_uncertainty.csv"):
    """Write the per-molecule uncertainty frame to out_dir/fname and print means.

    Creates out_dir if missing. Expects "uncertainty" and "confidence"
    columns (as produced by prob_margin_uncertainty).
    """
    os.makedirs(out_dir, exist_ok=True)
    csv_path = os.path.join(out_dir, fname)
    df.to_csv(csv_path, index=False)

    mean_unc = df["uncertainty"].mean()
    mean_conf = df["confidence"].mean()
    print(f"\n=== Per-molecule uncertainty ===")
    print(f"  Mean uncertainty : {mean_unc:.4f}")
    print(f"  Mean confidence  : {mean_conf:.4f}")
    print(f"  Saved to {csv_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["ci", "uncertainty_xgb", "uncertainty_prob"],
                        required=True,
                        help=(
                            "ci               : bootstrap CI from val_predictions.csv (all models)\n"
                            "uncertainty_prob : margin uncertainty for SVM/ElasticNet/XGB"
                        ))
    parser.add_argument("--val_preds",    type=str, help="Path to val_predictions.csv")
    parser.add_argument("--model_path",   type=str, help="Path to best_model.json (XGB only)")
    parser.add_argument("--dataset_path", type=str, help="HuggingFace dataset path (XGB uncertainty only)")
    parser.add_argument("--out_dir",      type=str, required=True)
    parser.add_argument("--model_name",   type=str, default="", help="Label for report (xgb_smiles)")
    parser.add_argument("--n_bootstrap",  type=int, default=2000)
    args = parser.parse_args()

    if args.mode == "ci":
        # parser.error (not assert) so validation survives `python -O`.
        if not args.val_preds:
            parser.error("--val_preds required for ci mode")
        df  = pd.read_csv(args.val_preds)
        ci  = bootstrap_ci(df["y_true"].values, df["y_prob"].values,
                           n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
    elif args.mode == "uncertainty_prob":
        if not args.val_preds:
            parser.error("--val_preds required for uncertainty_prob")
        df_preds = pd.read_csv(args.val_preds)
        # CI on the same prediction set
        ci = bootstrap_ci(df_preds["y_true"].values, df_preds["y_prob"].values,
                          n_bootstrap=args.n_bootstrap)
        save_ci_report(ci, args.out_dir, args.model_name)
        # Uncertainty from margin
        df_unc = prob_margin_uncertainty(df_preds)
        save_uncertainty_csv(df_unc, args.out_dir, "val_uncertainty_prob.csv")
    else:
        # "uncertainty_xgb" is an accepted --mode choice but previously fell
        # through with no branch, exiting silently. Fail loudly instead.
        parser.error(f"--mode {args.mode} is not implemented in this script")