Joblib
File size: 7,430 Bytes
d06775d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import json
import argparse
import numpy as np
import pandas as pd
import xgboost as xgb
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from datasets import load_from_disk, DatasetDict

def safe_spearmanr(y_true, y_pred):
    """Return Spearman's rank correlation as a plain float, never NaN.

    A coefficient that comes back as ``None`` or NaN is mapped to 0.0 so the
    result can always be aggregated and serialized.
    """
    coefficient = spearmanr(y_true, y_pred).correlation
    if coefficient is None or np.isnan(coefficient):
        return 0.0
    return float(coefficient)

def eval_regression(y_true, y_pred):
    """Compute the regression metrics reported by this script.

    Args:
        y_true: Ground-truth target values.
        y_pred: Predicted values, aligned element-wise with ``y_true``.

    Returns:
        dict with float entries ``spearman_rho``, ``rmse``, ``mae`` and ``r2``.
    """
    # Guard only the import: a scikit-learn build without
    # ``root_mean_squared_error`` falls back to sqrt(MSE), while genuine
    # computation errors (bad shapes, NaNs, ...) propagate instead of being
    # swallowed and recomputed.
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:
        rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    else:
        rmse = float(root_mean_squared_error(y_true, y_pred))

    return {
        "spearman_rho": safe_spearmanr(y_true, y_pred),
        "rmse": rmse,
        "mae":  float(mean_absolute_error(y_true, y_pred)),
        "r2":   float(r2_score(y_true, y_pred)),
    }

# ======================== Bootstrap CI =========================================

def bootstrap_ci_reg(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    n_bootstrap: int = 2000,
    ci: float = 0.95,
    seed: int = 1986,
) -> dict:
    """
    Percentile bootstrap CI for regression metrics.
    Uses percentile method (not t-CI) because:
      - Spearman rho is bounded [-1, 1] - t-CI can produce impossible values near extremes
      - RMSE is strictly positive - symmetric t-CI is inappropriate near 0
      - Percentile bootstrap makes no distributional assumptions

    Fisher z-transform CI for rho is also computed as a cross-check, at the
    same confidence level ``ci`` as the bootstrap intervals.

    Args:
        y_true: Ground-truth values (any 1-D array-like; converted with
            ``np.asarray``).
        y_pred: Predictions aligned element-wise with ``y_true``.
        n_bootstrap: Number of resamples to draw with replacement.
        ci: Two-sided confidence level, e.g. 0.95.
        seed: Seed for the NumPy random generator, for reproducibility.

    Returns:
        dict keyed by metric name (``spearman_rho``, ``rmse``, ``mae``,
        ``r2``), each holding ``mean``, ``std``, ``ci_low``, ``ci_high``,
        ``report`` and ``n_bootstrap`` (the number of resamples actually
        used). ``spearman_rho`` additionally carries ``fisher_z_ci``; the
        top-level ``n_samples`` is the input length.

    Raises:
        ValueError: If the inputs are empty, differ in length, or no
            resample contained at least two distinct ``y_true`` values.
    """
    # Local import: only the Fisher z critical value needs it.
    from scipy.stats import norm

    # Accept lists / Series as well as arrays; fancy indexing below needs ndarrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n = len(y_true)
    if n != len(y_pred):
        raise ValueError(
            f"y_true and y_pred must have the same length (got {n} and {len(y_pred)})"
        )
    if n == 0:
        raise ValueError("Cannot bootstrap an empty sample")

    rng = np.random.default_rng(seed=seed)
    alpha = 1 - ci
    lo, hi = alpha / 2, 1 - alpha / 2

    boot_metrics = {k: [] for k in ["spearman_rho", "rmse", "mae", "r2"]}

    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        yt, yp = y_true[idx], y_pred[idx]
        # Skip resamples whose targets are all identical.
        if len(np.unique(yt)) < 2:
            continue
        m = eval_regression(yt, yp)
        for k in boot_metrics:
            boot_metrics[k].append(m[k])

    if not boot_metrics["spearman_rho"]:
        # Fail with a clear message instead of an opaque error from
        # np.quantile on an empty array.
        raise ValueError(
            "No valid bootstrap resample: need n_bootstrap > 0 and at least "
            "two distinct y_true values"
        )

    results = {}
    for name, values in boot_metrics.items():
        arr = np.array(values)
        mean = float(arr.mean())
        q_lo, q_hi = (float(q) for q in np.quantile(arr, [lo, hi]))
        results[name] = {
            "mean":    mean,
            "std":     float(arr.std()),
            "ci_low":  q_lo,
            "ci_high": q_hi,
            "report":  f"{mean:.4f} [{q_lo:.4f}, {q_hi:.4f}]",
            "n_bootstrap": len(arr),
        }

    # Fisher z-transform CI for Spearman rho (cross-check, more accurate near ±1)
    rho_obs = safe_spearmanr(y_true, y_pred)
    # z-transform: arctanh(rho), SE = 1/sqrt(n-3). The clip keeps arctanh
    # finite at |rho| == 1; max(..., 1) avoids a zero/negative denominator
    # for n <= 3.
    # NOTE(review): 1/sqrt(n-3) is the SE usually quoted for Pearson's r;
    # confirm whether a Spearman-specific SE is wanted here.
    z    = np.arctanh(np.clip(rho_obs, -0.9999, 0.9999))
    se_z = 1.0 / np.sqrt(max(n - 3, 1))
    # Critical value follows ``ci`` (about 1.96 for ci=0.95) rather than a
    # hard-coded 1.96.
    z_crit = float(norm.ppf(hi))
    fisher_lo = float(np.tanh(z - z_crit * se_z))
    fisher_hi = float(np.tanh(z + z_crit * se_z))
    results["spearman_rho"]["fisher_z_ci"] = {
        "ci_low":  fisher_lo,
        "ci_high": fisher_hi,
        "report":  f"[{fisher_lo:.4f}, {fisher_hi:.4f}]",
        "note": "Fisher z-transform CI - more accurate when rho > 0.9",
    }

    results["n_samples"] = int(n)
    return results


def residual_uncertainty(
    val_preds_df: pd.DataFrame, coverage: float = 0.95
) -> "tuple[pd.DataFrame, dict]":
    """
    Attach a constant-width Gaussian prediction interval to each prediction.

      - Assume residuals ~ N(0, sigma) where sigma = std(residuals)
      - Prediction interval for molecule i: y_pred_i ± z * sigma
      - Uncertainty score = sigma (constant across all molecules for linear models)
      - Dataset-level uncertainty

    Args:
        val_preds_df: Frame with ``y_true`` and ``y_pred`` columns. It is
            copied, not modified.
        coverage: Nominal two-sided coverage of the interval, strictly
            between 0 and 1.

    Returns:
        ``(df, meta)``: ``df`` is the copy with ``pred_interval_low``,
        ``pred_interval_high``, ``pred_interval_width`` and ``abs_error``
        columns added; ``meta`` summarises the residual std, the interval
        half-width, and the nominal and empirical coverage.

    Raises:
        ValueError: If ``coverage`` is not strictly between 0 and 1.
    """
    if not 0.0 < coverage < 1.0:
        raise ValueError(f"coverage must be in (0, 1), got {coverage!r}")

    df = val_preds_df.copy()

    residuals = df["y_true"] - df["y_pred"]
    sigma     = float(residuals.std(ddof=1))

    # Keep the rounded textbook values for the common levels so existing
    # outputs are unchanged; any other level gets the exact normal quantile
    # instead of silently falling back to the 95% value.
    z = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}.get(coverage)
    if z is None:
        from scipy.stats import norm
        z = float(norm.ppf(0.5 + coverage / 2.0))
    half_width = z * sigma

    df["pred_interval_low"]   = df["y_pred"] - half_width
    df["pred_interval_high"]  = df["y_pred"] + half_width
    df["pred_interval_width"] = 2 * half_width   # constant for linear models
    df["abs_error"]           = residuals.abs()

    # what fraction of y_true actually falls inside the interval
    inside = (
        (df["y_true"] >= df["pred_interval_low"])
        & (df["y_true"] <= df["pred_interval_high"])
    )
    empirical_coverage = float(inside.mean())

    meta = {
        "residual_std":       round(sigma, 6),
        "interval_halfwidth": round(half_width, 6),
        "nominal_coverage":   coverage,
        "empirical_coverage": round(empirical_coverage, 4),
        "note": (
            "Prediction interval assumes N(0, sigma) residuals. "
            "Interval width is constant across molecules for linear models."
        ),
    }
    return df, meta

def save_ci_report(ci_results: dict, out_dir: str, model_name: str = ""):
    """Persist bootstrap CI results as JSON and print a readable summary.

    Args:
        ci_results: Mapping with one entry per metric (``spearman_rho``,
            ``rmse``, ``mae``, ``r2``), each carrying a ``report`` string,
            plus a top-level ``n_samples``. The ``spearman_rho`` entry must
            carry ``n_bootstrap`` and may carry ``fisher_z_ci``.
        out_dir: Directory for ``bootstrap_ci_reg.json``; created if missing.
        model_name: Label shown in the printed header.
    """
    os.makedirs(out_dir, exist_ok=True)
    report_path = os.path.join(out_dir, "bootstrap_ci_reg.json")
    with open(report_path, "w") as handle:
        json.dump(ci_results, handle, indent=2)

    print(f"\n=== Bootstrap 95% CI - Regression ({model_name}) ===")
    for metric in ("spearman_rho", "rmse", "mae", "r2"):
        entry = ci_results[metric]
        print(f"  {metric:15s}: {entry['report']}")
        # Only the Spearman entry has the Fisher z cross-check line.
        if metric != "spearman_rho" or "fisher_z_ci" not in entry:
            continue
        fisher = entry["fisher_z_ci"]
        print(f"    Fisher z CI  : {fisher['report']}  ← use this if rho > 0.9")

    n_val = ci_results["n_samples"]
    n_boot = ci_results["spearman_rho"]["n_bootstrap"]
    print(f"  n_val={n_val}, n_bootstrap={n_boot}")
    print(f"Saved to {report_path}")

if __name__ == "__main__":
    # Command-line entry point. Both modes read the predictions CSV and write
    # the bootstrap CI report; "uncertainty_residual" additionally writes the
    # per-row residual prediction intervals and their summary metadata.
    parser = argparse.ArgumentParser(
        # Raw formatter so the explicit line breaks in the --mode help render.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("--mode", required=True,
                        choices=["ci", "uncertainty_residual"],
                        help=(
                            "ci                  : bootstrap CI from val_predictions.csv\n"
                            "uncertainty_residual: residual interval for ElasticNet/SVR"
                        ))
    parser.add_argument("--val_preds",    type=str, help="Path to val_predictions.csv")
    parser.add_argument("--out_dir",      type=str, required=True)
    parser.add_argument("--model_name",   type=str, default="")
    parser.add_argument("--n_bootstrap",  type=int, default=2000)
    args = parser.parse_args()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not args.val_preds:
        parser.error("--val_preds required")

    df_preds = pd.read_csv(args.val_preds)
    ci = bootstrap_ci_reg(df_preds["y_true"].values, df_preds["y_pred"].values,
                          n_bootstrap=args.n_bootstrap)
    save_ci_report(ci, args.out_dir, args.model_name)

    if args.mode == "uncertainty_residual":
        df_unc, meta = residual_uncertainty(df_preds)
        path = os.path.join(args.out_dir, "val_uncertainty_residual.csv")
        df_unc.to_csv(path, index=False)
        meta_path = os.path.join(args.out_dir, "residual_interval_meta.json")
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)
        print("\nResidual interval summary:")
        print(f"  Residual std       : {meta['residual_std']:.4f}")
        print(f"  95% interval ± {meta['interval_halfwidth']:.4f}")
        print(f"  Empirical coverage : {meta['empirical_coverage']:.4f}  (nominal={meta['nominal_coverage']})")
        print(f"  Saved to {path}")