ynuozhang commited on
Commit ·
3e669de
1
Parent(s): b90bb8d
clean up legacy _smiles folders, stray diagnostic files, and half_life non-best models
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py +0 -132
- training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py +0 -847
- training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py +0 -414
- training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash +0 -31
- training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py +0 -508
- training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py +0 -309
- training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt +0 -234
- training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py +0 -417
- training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py +0 -468
- training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py +0 -410
- training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py +0 -426
- training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py +0 -420
- training_classifiers/binding_affinity/val_smiles_pooled.csv +0 -3
- training_classifiers/binding_affinity/val_smiles_unpooled.csv +0 -3
- training_classifiers/binding_affinity/val_wt_pooled.csv +0 -3
- training_classifiers/binding_affinity/val_wt_unpooled.csv +0 -3
- training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt +0 -3
- training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt +0 -3
- training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv +0 -3
- training_classifiers/half_life/cnn_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/cnn_unpooled_peptideclm/best_model.pt +0 -3
- training_classifiers/half_life/cnn_unpooled_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/enet_gpu_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/enet_peptideclm/smiles_halflife_best_enet.joblib +0 -3
- training_classifiers/half_life/mlp_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/mlp_unpooled_peptideclm/best_model.pt +0 -3
- training_classifiers/half_life/mlp_unpooled_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/svr_gpu_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/svr_peptideclm/smiles_halflife_best_svr.joblib +0 -3
- training_classifiers/half_life/transformer_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/transformer_unpooled_peptideclm/best_model.pt +0 -3
- training_classifiers/half_life/transformer_unpooled_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/transformer_wt_log/oof_pred_vs_true.png +0 -0
- training_classifiers/half_life/transformer_wt_log/oof_predictions.csv +0 -3
- training_classifiers/half_life/transformer_wt_log/oof_residual_hist.png +0 -0
- training_classifiers/half_life/transformer_wt_log/oof_residual_vs_pred.png +0 -0
- training_classifiers/half_life/transformer_wt_log/optimization_summary.txt +0 -33
- training_classifiers/half_life/transformer_wt_log/study_trials.csv +0 -3
- training_classifiers/half_life/transformer_wt_raw/oof_pred_vs_true.png +0 -0
- training_classifiers/half_life/transformer_wt_raw/oof_predictions.csv +0 -3
- training_classifiers/half_life/transformer_wt_raw/oof_residual_hist.png +0 -0
- training_classifiers/half_life/transformer_wt_raw/oof_residual_vs_pred.png +0 -0
- training_classifiers/half_life/transformer_wt_raw/optimization_summary.txt +0 -33
- training_classifiers/half_life/transformer_wt_raw/study_trials.csv +0 -3
- training_classifiers/half_life/xgb_smiles/cv_oof_predictions.csv +0 -3
- training_classifiers/half_life/xgb_wt_log/oof_pred_vs_true.png +0 -0
- training_classifiers/half_life/xgb_wt_log/oof_predictions.csv +0 -3
- training_classifiers/half_life/xgb_wt_log/oof_residual_hist.png +0 -0
- training_classifiers/half_life/xgb_wt_log/oof_residual_vs_pred.png +0 -0
- training_classifiers/half_life/xgb_wt_log/optimization_summary.txt +0 -27
training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
extract_iptm_affinity_csv_all.py
|
| 4 |
-
|
| 5 |
-
Writes:
|
| 6 |
-
- out_dir/wt_iptm_affinity_all.csv
|
| 7 |
-
- out_dir/smiles_iptm_affinity_all.csv
|
| 8 |
-
|
| 9 |
-
Also prints:
|
| 10 |
-
- N
|
| 11 |
-
- Spearman rho (affinity vs iptm)
|
| 12 |
-
- Pearson r (affinity vs iptm)
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
import numpy as np
|
| 17 |
-
import pandas as pd
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def corr_stats(df: pd.DataFrame, x: str, y: str):
|
| 21 |
-
# pandas handles NaNs if we already dropped them; still be safe
|
| 22 |
-
xx = pd.to_numeric(df[x], errors="coerce")
|
| 23 |
-
yy = pd.to_numeric(df[y], errors="coerce")
|
| 24 |
-
m = xx.notna() & yy.notna()
|
| 25 |
-
xx = xx[m]
|
| 26 |
-
yy = yy[m]
|
| 27 |
-
n = int(m.sum())
|
| 28 |
-
|
| 29 |
-
# Pearson r
|
| 30 |
-
pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
|
| 31 |
-
# Spearman rho
|
| 32 |
-
spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")
|
| 33 |
-
|
| 34 |
-
return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def clean_one(
|
| 38 |
-
in_csv: Path,
|
| 39 |
-
out_csv: Path,
|
| 40 |
-
iptm_col: str,
|
| 41 |
-
affinity_col: str = "affinity",
|
| 42 |
-
keep_cols=(),
|
| 43 |
-
):
|
| 44 |
-
df = pd.read_csv(in_csv)
|
| 45 |
-
|
| 46 |
-
# affinity + iptm must exist
|
| 47 |
-
need = [affinity_col, iptm_col]
|
| 48 |
-
missing = [c for c in need if c not in df.columns]
|
| 49 |
-
if missing:
|
| 50 |
-
raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")
|
| 51 |
-
|
| 52 |
-
# coerce numeric
|
| 53 |
-
df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
|
| 54 |
-
df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")
|
| 55 |
-
|
| 56 |
-
# drop NaNs in either
|
| 57 |
-
df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)
|
| 58 |
-
|
| 59 |
-
# output cols (standardize names)
|
| 60 |
-
out = pd.DataFrame({
|
| 61 |
-
"affinity": df[affinity_col].astype(float),
|
| 62 |
-
"iptm": df[iptm_col].astype(float),
|
| 63 |
-
})
|
| 64 |
-
|
| 65 |
-
# keep split if present (handy for coloring later, but not used for corr)
|
| 66 |
-
if "split" in df.columns:
|
| 67 |
-
out.insert(0, "split", df["split"].astype(str))
|
| 68 |
-
|
| 69 |
-
# optional extras for labeling/debug
|
| 70 |
-
for c in keep_cols:
|
| 71 |
-
if c in df.columns:
|
| 72 |
-
out[c] = df[c]
|
| 73 |
-
|
| 74 |
-
out_csv.parent.mkdir(parents=True, exist_ok=True)
|
| 75 |
-
out.to_csv(out_csv, index=False)
|
| 76 |
-
|
| 77 |
-
stats = corr_stats(out, "iptm", "affinity")
|
| 78 |
-
print(f"[write] {out_csv}")
|
| 79 |
-
print(f" N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")
|
| 80 |
-
|
| 81 |
-
# also save stats json next to csv
|
| 82 |
-
stats_path = out_csv.with_suffix(".stats.json")
|
| 83 |
-
with open(stats_path, "w") as f:
|
| 84 |
-
import json
|
| 85 |
-
json.dump(
|
| 86 |
-
{
|
| 87 |
-
"input_csv": str(in_csv),
|
| 88 |
-
"output_csv": str(out_csv),
|
| 89 |
-
"iptm_col": iptm_col,
|
| 90 |
-
"affinity_col": affinity_col,
|
| 91 |
-
**stats,
|
| 92 |
-
},
|
| 93 |
-
f,
|
| 94 |
-
indent=2,
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def main():
|
| 99 |
-
import argparse
|
| 100 |
-
ap = argparse.ArgumentParser()
|
| 101 |
-
ap.add_argument("--wt_meta_csv", type=str, required=True)
|
| 102 |
-
ap.add_argument("--smiles_meta_csv", type=str, required=True)
|
| 103 |
-
ap.add_argument("--out_dir", type=str, required=True)
|
| 104 |
-
|
| 105 |
-
ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
|
| 106 |
-
ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
|
| 107 |
-
ap.add_argument("--affinity_col", type=str, default="affinity")
|
| 108 |
-
args = ap.parse_args()
|
| 109 |
-
|
| 110 |
-
out_dir = Path(args.out_dir)
|
| 111 |
-
|
| 112 |
-
clean_one(
|
| 113 |
-
Path(args.wt_meta_csv),
|
| 114 |
-
out_dir / "wt_iptm_affinity_all.csv",
|
| 115 |
-
iptm_col=args.wt_iptm_col,
|
| 116 |
-
affinity_col=args.affinity_col,
|
| 117 |
-
keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
clean_one(
|
| 121 |
-
Path(args.smiles_meta_csv),
|
| 122 |
-
out_dir / "smiles_iptm_affinity_all.csv",
|
| 123 |
-
iptm_col=args.smiles_iptm_col,
|
| 124 |
-
affinity_col=args.affinity_col,
|
| 125 |
-
keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
if __name__ == "__main__":
|
| 132 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py
DELETED
|
@@ -1,847 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
import os
|
| 3 |
-
import math
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
import sys
|
| 6 |
-
from contextlib import contextmanager
|
| 7 |
-
|
| 8 |
-
import numpy as np
|
| 9 |
-
import pandas as pd
|
| 10 |
-
import torch
|
| 11 |
-
|
| 12 |
-
# tqdm is optional; we’ll disable it by default in notebooks
|
| 13 |
-
from tqdm import tqdm
|
| 14 |
-
|
| 15 |
-
sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
|
| 16 |
-
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
|
| 17 |
-
|
| 18 |
-
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
|
| 19 |
-
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
|
| 20 |
-
|
| 21 |
-
# -------------------------
|
| 22 |
-
# Config
|
| 23 |
-
# -------------------------
|
| 24 |
-
CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
|
| 25 |
-
|
| 26 |
-
OUT_ROOT = Path(
|
| 27 |
-
"/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
# WT (seq) embedding model
|
| 31 |
-
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
|
| 32 |
-
WT_MAX_LEN = 1022
|
| 33 |
-
WT_BATCH = 32
|
| 34 |
-
|
| 35 |
-
# SMILES embedding model + tokenizer
|
| 36 |
-
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
|
| 37 |
-
TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
|
| 38 |
-
TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
|
| 39 |
-
SMI_MAX_LEN = 768
|
| 40 |
-
SMI_BATCH = 128
|
| 41 |
-
|
| 42 |
-
# Split config
|
| 43 |
-
TRAIN_FRAC = 0.80
|
| 44 |
-
RANDOM_SEED = 1986
|
| 45 |
-
AFFINITY_Q_BINS = 30
|
| 46 |
-
|
| 47 |
-
# Columns expected in CSV
|
| 48 |
-
COL_SEQ1 = "seq1"
|
| 49 |
-
COL_SEQ2 = "seq2"
|
| 50 |
-
COL_AFF = "affinity"
|
| 51 |
-
COL_F2S = "Fasta2SMILES"
|
| 52 |
-
COL_REACT = "REACT_SMILES"
|
| 53 |
-
COL_WT_IPTM = "wt_iptm_score"
|
| 54 |
-
COL_SMI_IPTM = "smiles_iptm_score"
|
| 55 |
-
|
| 56 |
-
# Device
|
| 57 |
-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 58 |
-
|
| 59 |
-
# -------------------------
|
| 60 |
-
# Quiet / notebook-safe output controls
|
| 61 |
-
# -------------------------
|
| 62 |
-
QUIET = True # suppress most prints
|
| 63 |
-
USE_TQDM = False # disable tqdm bars (recommended in Jupyter to avoid crashing)
|
| 64 |
-
LOG_FILE = None # optionally: OUT_ROOT / "build.log"
|
| 65 |
-
|
| 66 |
-
def log(msg: str):
|
| 67 |
-
if LOG_FILE is not None:
|
| 68 |
-
Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
|
| 69 |
-
with open(LOG_FILE, "a") as f:
|
| 70 |
-
f.write(msg.rstrip() + "\n")
|
| 71 |
-
if not QUIET:
|
| 72 |
-
print(msg)
|
| 73 |
-
|
| 74 |
-
def pbar(it, **kwargs):
|
| 75 |
-
return tqdm(it, **kwargs) if USE_TQDM else it
|
| 76 |
-
|
| 77 |
-
@contextmanager
|
| 78 |
-
def section(title: str):
|
| 79 |
-
log(f"\n=== {title} ===")
|
| 80 |
-
yield
|
| 81 |
-
log(f"=== done: {title} ===")
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
# -------------------------
|
| 85 |
-
# Helpers
|
| 86 |
-
# -------------------------
|
| 87 |
-
def has_uaa(seq: str) -> bool:
|
| 88 |
-
return "X" in str(seq).upper()
|
| 89 |
-
|
| 90 |
-
def affinity_to_class(a: float) -> str:
|
| 91 |
-
# High: >= 9 ; Moderate: [7, 9) ; Low: < 7
|
| 92 |
-
if a >= 9.0:
|
| 93 |
-
return "High"
|
| 94 |
-
elif a >= 7.0:
|
| 95 |
-
return "Moderate"
|
| 96 |
-
else:
|
| 97 |
-
return "Low"
|
| 98 |
-
|
| 99 |
-
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
|
| 100 |
-
df = df.copy()
|
| 101 |
-
|
| 102 |
-
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
|
| 103 |
-
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 104 |
-
|
| 105 |
-
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
|
| 106 |
-
|
| 107 |
-
try:
|
| 108 |
-
df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
|
| 109 |
-
strat_col = "aff_bin"
|
| 110 |
-
except Exception:
|
| 111 |
-
df["aff_bin"] = df["affinity_class"]
|
| 112 |
-
strat_col = "aff_bin"
|
| 113 |
-
|
| 114 |
-
rng = np.random.RandomState(RANDOM_SEED)
|
| 115 |
-
|
| 116 |
-
df["split"] = None
|
| 117 |
-
for _, g in df.groupby(strat_col, observed=True):
|
| 118 |
-
idx = g.index.to_numpy()
|
| 119 |
-
rng.shuffle(idx)
|
| 120 |
-
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
|
| 121 |
-
df.loc[idx[:n_train], "split"] = "train"
|
| 122 |
-
df.loc[idx[n_train:], "split"] = "val"
|
| 123 |
-
|
| 124 |
-
df["split"] = df["split"].fillna("train")
|
| 125 |
-
return df
|
| 126 |
-
|
| 127 |
-
def _summ(x):
|
| 128 |
-
x = np.asarray(x, dtype=float)
|
| 129 |
-
x = x[~np.isnan(x)]
|
| 130 |
-
if len(x) == 0:
|
| 131 |
-
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
|
| 132 |
-
return {
|
| 133 |
-
"n": int(len(x)),
|
| 134 |
-
"mean": float(np.mean(x)),
|
| 135 |
-
"std": float(np.std(x)),
|
| 136 |
-
"p50": float(np.quantile(x, 0.50)),
|
| 137 |
-
"p95": float(np.quantile(x, 0.95)),
|
| 138 |
-
}
|
| 139 |
-
|
| 140 |
-
def _len_stats(seqs):
|
| 141 |
-
lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
|
| 142 |
-
if len(lens) == 0:
|
| 143 |
-
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
|
| 144 |
-
return {
|
| 145 |
-
"n": int(len(lens)),
|
| 146 |
-
"mean": float(lens.mean()),
|
| 147 |
-
"std": float(lens.std()),
|
| 148 |
-
"p50": float(np.quantile(lens, 0.50)),
|
| 149 |
-
"p95": float(np.quantile(lens, 0.95)),
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
def verify_split_before_embedding(
|
| 153 |
-
df2: pd.DataFrame,
|
| 154 |
-
affinity_col: str,
|
| 155 |
-
split_col: str,
|
| 156 |
-
seq_col: str,
|
| 157 |
-
iptm_col: str,
|
| 158 |
-
aff_class_col: str = "affinity_class",
|
| 159 |
-
aff_bins: int = 30,
|
| 160 |
-
save_report_prefix: str | None = None,
|
| 161 |
-
verbose: bool = False,
|
| 162 |
-
):
|
| 163 |
-
"""
|
| 164 |
-
Notebook-safe: by default prints only ONE line via `log()`.
|
| 165 |
-
Optionally writes CSV reports (stats + class proportions).
|
| 166 |
-
"""
|
| 167 |
-
df2 = df2.copy()
|
| 168 |
-
df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
|
| 169 |
-
df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
|
| 170 |
-
|
| 171 |
-
assert split_col in df2.columns, f"Missing split col: {split_col}"
|
| 172 |
-
assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
|
| 173 |
-
assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
|
| 174 |
-
|
| 175 |
-
try:
|
| 176 |
-
df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
|
| 177 |
-
except Exception:
|
| 178 |
-
df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
|
| 179 |
-
|
| 180 |
-
tr = df2[df2[split_col] == "train"].reset_index(drop=True)
|
| 181 |
-
va = df2[df2[split_col] == "val"].reset_index(drop=True)
|
| 182 |
-
|
| 183 |
-
tr_aff = _summ(tr[affinity_col].to_numpy())
|
| 184 |
-
va_aff = _summ(va[affinity_col].to_numpy())
|
| 185 |
-
tr_len = _len_stats(tr[seq_col].tolist())
|
| 186 |
-
va_len = _len_stats(va[seq_col].tolist())
|
| 187 |
-
|
| 188 |
-
# bin drift
|
| 189 |
-
bin_ct = (
|
| 190 |
-
df2.groupby([split_col, "_aff_bin_dbg"])
|
| 191 |
-
.size()
|
| 192 |
-
.groupby(level=0)
|
| 193 |
-
.apply(lambda s: s / s.sum())
|
| 194 |
-
)
|
| 195 |
-
tr_bins = bin_ct.loc["train"]
|
| 196 |
-
va_bins = bin_ct.loc["val"]
|
| 197 |
-
all_bins = tr_bins.index.union(va_bins.index)
|
| 198 |
-
tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
|
| 199 |
-
va_bins = va_bins.reindex(all_bins, fill_value=0.0)
|
| 200 |
-
max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
|
| 201 |
-
|
| 202 |
-
msg = (
|
| 203 |
-
f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
|
| 204 |
-
f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
|
| 205 |
-
f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
|
| 206 |
-
f"max_bin_diff={max_bin_diff:.4f}"
|
| 207 |
-
)
|
| 208 |
-
log(msg)
|
| 209 |
-
|
| 210 |
-
if verbose and (not QUIET):
|
| 211 |
-
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
|
| 212 |
-
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
|
| 213 |
-
print("\n[verbose] affinity_class counts:\n", class_ct)
|
| 214 |
-
print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
|
| 215 |
-
|
| 216 |
-
if save_report_prefix is not None:
|
| 217 |
-
out = Path(save_report_prefix)
|
| 218 |
-
out.parent.mkdir(parents=True, exist_ok=True)
|
| 219 |
-
|
| 220 |
-
stats_df = pd.DataFrame([
|
| 221 |
-
{"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
|
| 222 |
-
{"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
|
| 223 |
-
])
|
| 224 |
-
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
|
| 225 |
-
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
|
| 226 |
-
|
| 227 |
-
stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
|
| 228 |
-
class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
# -------------------------
|
| 232 |
-
# WT pooled (ESM2)
|
| 233 |
-
# -------------------------
|
| 234 |
-
@torch.no_grad()
|
| 235 |
-
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
|
| 236 |
-
embs = []
|
| 237 |
-
for i in pbar(range(0, len(seqs), batch_size)):
|
| 238 |
-
batch = seqs[i:i + batch_size]
|
| 239 |
-
inputs = tokenizer(
|
| 240 |
-
batch,
|
| 241 |
-
padding=True,
|
| 242 |
-
truncation=True,
|
| 243 |
-
max_length=max_length,
|
| 244 |
-
return_tensors="pt",
|
| 245 |
-
)
|
| 246 |
-
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 247 |
-
out = model(**inputs)
|
| 248 |
-
h = out.last_hidden_state # (B, L, H)
|
| 249 |
-
|
| 250 |
-
attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
|
| 251 |
-
summed = (h * attn).sum(dim=1) # (B, H)
|
| 252 |
-
denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
|
| 253 |
-
pooled = (summed / denom).detach().cpu().numpy()
|
| 254 |
-
embs.append(pooled)
|
| 255 |
-
|
| 256 |
-
return np.vstack(embs)
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
# -------------------------
|
| 260 |
-
# WT unpooled (ESM2)
|
| 261 |
-
# -------------------------
|
| 262 |
-
@torch.no_grad()
|
| 263 |
-
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
|
| 264 |
-
tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
|
| 265 |
-
tok = {k: v.to(DEVICE) for k, v in tok.items()}
|
| 266 |
-
out = model(**tok)
|
| 267 |
-
h = out.last_hidden_state[0] # (L, H)
|
| 268 |
-
attn = tok["attention_mask"][0].bool() # (L,)
|
| 269 |
-
ids = tok["input_ids"][0]
|
| 270 |
-
|
| 271 |
-
keep = attn.clone()
|
| 272 |
-
if cls_id is not None:
|
| 273 |
-
keep &= (ids != cls_id)
|
| 274 |
-
if eos_id is not None:
|
| 275 |
-
keep &= (ids != eos_id)
|
| 276 |
-
|
| 277 |
-
return h[keep].detach().cpu().to(torch.float16).numpy()
|
| 278 |
-
|
| 279 |
-
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
|
| 280 |
-
"""
|
| 281 |
-
Expects df_split to have:
|
| 282 |
-
- target_sequence (seq1)
|
| 283 |
-
- sequence (binder seq2; WT binder)
|
| 284 |
-
- label, affinity_class, COL_AFF, COL_WT_IPTM
|
| 285 |
-
Saves a dataset where each row contains BOTH:
|
| 286 |
-
- target_embedding (Lt,H), target_attention_mask, target_length
|
| 287 |
-
- binder_embedding (Lb,H), binder_attention_mask, binder_length
|
| 288 |
-
"""
|
| 289 |
-
cls_id = tokenizer.cls_token_id
|
| 290 |
-
eos_id = tokenizer.eos_token_id
|
| 291 |
-
H = model.config.hidden_size
|
| 292 |
-
|
| 293 |
-
features = Features({
|
| 294 |
-
"target_sequence": Value("string"),
|
| 295 |
-
"sequence": Value("string"),
|
| 296 |
-
"label": Value("float32"),
|
| 297 |
-
"affinity": Value("float32"),
|
| 298 |
-
"affinity_class": Value("string"),
|
| 299 |
-
|
| 300 |
-
"target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
|
| 301 |
-
"target_attention_mask": HFSequence(Value("int8")),
|
| 302 |
-
"target_length": Value("int64"),
|
| 303 |
-
|
| 304 |
-
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
|
| 305 |
-
"binder_attention_mask": HFSequence(Value("int8")),
|
| 306 |
-
"binder_length": Value("int64"),
|
| 307 |
-
|
| 308 |
-
COL_WT_IPTM: Value("float32"),
|
| 309 |
-
COL_AFF: Value("float32"),
|
| 310 |
-
})
|
| 311 |
-
|
| 312 |
-
def gen_rows(df: pd.DataFrame):
|
| 313 |
-
for r in pbar(df.itertuples(index=False), total=len(df)):
|
| 314 |
-
tgt = str(getattr(r, "target_sequence")).strip()
|
| 315 |
-
bnd = str(getattr(r, "sequence")).strip()
|
| 316 |
-
|
| 317 |
-
y = float(getattr(r, "label"))
|
| 318 |
-
aff = float(getattr(r, COL_AFF))
|
| 319 |
-
acls = str(getattr(r, "affinity_class"))
|
| 320 |
-
|
| 321 |
-
iptm = getattr(r, COL_WT_IPTM)
|
| 322 |
-
iptm = float(iptm) if pd.notna(iptm) else np.nan
|
| 323 |
-
|
| 324 |
-
# token embeddings for target + binder (both ESM)
|
| 325 |
-
t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
|
| 326 |
-
b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
|
| 327 |
-
|
| 328 |
-
t_list = t_emb.tolist()
|
| 329 |
-
b_list = b_emb.tolist()
|
| 330 |
-
Lt = len(t_list)
|
| 331 |
-
Lb = len(b_list)
|
| 332 |
-
|
| 333 |
-
yield {
|
| 334 |
-
"target_sequence": tgt,
|
| 335 |
-
"sequence": bnd,
|
| 336 |
-
"label": np.float32(y),
|
| 337 |
-
"affinity": np.float32(aff),
|
| 338 |
-
"affinity_class": acls,
|
| 339 |
-
|
| 340 |
-
"target_embedding": t_list,
|
| 341 |
-
"target_attention_mask": [1] * Lt,
|
| 342 |
-
"target_length": int(Lt),
|
| 343 |
-
|
| 344 |
-
"binder_embedding": b_list,
|
| 345 |
-
"binder_attention_mask": [1] * Lb,
|
| 346 |
-
"binder_length": int(Lb),
|
| 347 |
-
|
| 348 |
-
COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
|
| 349 |
-
COL_AFF: np.float32(aff),
|
| 350 |
-
}
|
| 351 |
-
|
| 352 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 353 |
-
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
|
| 354 |
-
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
|
| 355 |
-
return ds
|
| 356 |
-
|
| 357 |
-
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
|
| 358 |
-
smi_tok, smi_roformer):
|
| 359 |
-
"""
|
| 360 |
-
df_split must have:
|
| 361 |
-
- target_sequence (seq1)
|
| 362 |
-
- sequence (binder smiles string)
|
| 363 |
-
- label, affinity_class, COL_AFF, COL_SMI_IPTM
|
| 364 |
-
Saves rows with:
|
| 365 |
-
target_embedding (Lt,Ht) from ESM
|
| 366 |
-
binder_embedding (Lb,Hb) from PeptideCLM
|
| 367 |
-
"""
|
| 368 |
-
cls_id = wt_tokenizer.cls_token_id
|
| 369 |
-
eos_id = wt_tokenizer.eos_token_id
|
| 370 |
-
Ht = wt_model_unpooled.config.hidden_size
|
| 371 |
-
|
| 372 |
-
# Infer Hb from one forward pass? easiest: run one mini batch outside in main if you want.
|
| 373 |
-
# Here: we’ll infer from model config if available.
|
| 374 |
-
Hb = getattr(smi_roformer.config, "hidden_size", None)
|
| 375 |
-
if Hb is None:
|
| 376 |
-
Hb = getattr(smi_roformer.config, "dim", None)
|
| 377 |
-
if Hb is None:
|
| 378 |
-
raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
|
| 379 |
-
|
| 380 |
-
features = Features({
|
| 381 |
-
"target_sequence": Value("string"),
|
| 382 |
-
"sequence": Value("string"),
|
| 383 |
-
"label": Value("float32"),
|
| 384 |
-
"affinity": Value("float32"),
|
| 385 |
-
"affinity_class": Value("string"),
|
| 386 |
-
|
| 387 |
-
"target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
|
| 388 |
-
"target_attention_mask": HFSequence(Value("int8")),
|
| 389 |
-
"target_length": Value("int64"),
|
| 390 |
-
|
| 391 |
-
"binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
|
| 392 |
-
"binder_attention_mask": HFSequence(Value("int8")),
|
| 393 |
-
"binder_length": Value("int64"),
|
| 394 |
-
|
| 395 |
-
COL_SMI_IPTM: Value("float32"),
|
| 396 |
-
COL_AFF: Value("float32"),
|
| 397 |
-
})
|
| 398 |
-
|
| 399 |
-
def gen_rows(df: pd.DataFrame):
|
| 400 |
-
for r in pbar(df.itertuples(index=False), total=len(df)):
|
| 401 |
-
tgt = str(getattr(r, "target_sequence")).strip()
|
| 402 |
-
bnd = str(getattr(r, "sequence")).strip()
|
| 403 |
-
|
| 404 |
-
y = float(getattr(r, "label"))
|
| 405 |
-
aff = float(getattr(r, COL_AFF))
|
| 406 |
-
acls = str(getattr(r, "affinity_class"))
|
| 407 |
-
|
| 408 |
-
iptm = getattr(r, COL_SMI_IPTM)
|
| 409 |
-
iptm = float(iptm) if pd.notna(iptm) else np.nan
|
| 410 |
-
|
| 411 |
-
# target token embeddings (ESM)
|
| 412 |
-
t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
|
| 413 |
-
t_list = t_emb.tolist()
|
| 414 |
-
Lt = len(t_list)
|
| 415 |
-
|
| 416 |
-
# binder token embeddings (PeptideCLM) — single-item batch
|
| 417 |
-
_, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
|
| 418 |
-
[bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
|
| 419 |
-
)
|
| 420 |
-
b_emb = tok_list[0] # np.float16 (Lb, Hb)
|
| 421 |
-
b_list = b_emb.tolist()
|
| 422 |
-
Lb = int(lengths[0])
|
| 423 |
-
b_mask = mask_list[0].astype(np.int8).tolist()
|
| 424 |
-
|
| 425 |
-
yield {
|
| 426 |
-
"target_sequence": tgt,
|
| 427 |
-
"sequence": bnd,
|
| 428 |
-
"label": np.float32(y),
|
| 429 |
-
"affinity": np.float32(aff),
|
| 430 |
-
"affinity_class": acls,
|
| 431 |
-
|
| 432 |
-
"target_embedding": t_list,
|
| 433 |
-
"target_attention_mask": [1] * Lt,
|
| 434 |
-
"target_length": int(Lt),
|
| 435 |
-
|
| 436 |
-
"binder_embedding": b_list,
|
| 437 |
-
"binder_attention_mask": [int(x) for x in b_mask],
|
| 438 |
-
"binder_length": int(Lb),
|
| 439 |
-
|
| 440 |
-
COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
|
| 441 |
-
COL_AFF: np.float32(aff),
|
| 442 |
-
}
|
| 443 |
-
|
| 444 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 445 |
-
ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
|
| 446 |
-
ds.save_to_disk(str(out_dir), max_shard_size="1GB")
|
| 447 |
-
return ds
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
# -------------------------
|
| 451 |
-
# SMILES pooled + unpooled (PeptideCLM)
|
| 452 |
-
# -------------------------
|
| 453 |
-
def get_special_ids(tokenizer_obj):
|
| 454 |
-
cand = [
|
| 455 |
-
getattr(tokenizer_obj, "pad_token_id", None),
|
| 456 |
-
getattr(tokenizer_obj, "cls_token_id", None),
|
| 457 |
-
getattr(tokenizer_obj, "sep_token_id", None),
|
| 458 |
-
getattr(tokenizer_obj, "bos_token_id", None),
|
| 459 |
-
getattr(tokenizer_obj, "eos_token_id", None),
|
| 460 |
-
getattr(tokenizer_obj, "mask_token_id", None),
|
| 461 |
-
]
|
| 462 |
-
return sorted({x for x in cand if x is not None})
|
| 463 |
-
|
| 464 |
-
@torch.no_grad()
|
| 465 |
-
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
|
| 466 |
-
tok = tokenizer_obj(
|
| 467 |
-
batch_sequences,
|
| 468 |
-
return_tensors="pt",
|
| 469 |
-
padding=True,
|
| 470 |
-
truncation=True,
|
| 471 |
-
max_length=max_length,
|
| 472 |
-
)
|
| 473 |
-
input_ids = tok["input_ids"].to(DEVICE)
|
| 474 |
-
attention_mask = tok["attention_mask"].to(DEVICE)
|
| 475 |
-
|
| 476 |
-
outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
|
| 477 |
-
last_hidden = outputs.last_hidden_state # (B, L, H)
|
| 478 |
-
|
| 479 |
-
special_ids = get_special_ids(tokenizer_obj)
|
| 480 |
-
valid = attention_mask.bool()
|
| 481 |
-
if len(special_ids) > 0:
|
| 482 |
-
sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
|
| 483 |
-
if hasattr(torch, "isin"):
|
| 484 |
-
valid = valid & (~torch.isin(input_ids, sid))
|
| 485 |
-
else:
|
| 486 |
-
m = torch.zeros_like(valid)
|
| 487 |
-
for s in special_ids:
|
| 488 |
-
m |= (input_ids == s)
|
| 489 |
-
valid = valid & (~m)
|
| 490 |
-
|
| 491 |
-
valid_f = valid.unsqueeze(-1).float()
|
| 492 |
-
summed = torch.sum(last_hidden * valid_f, dim=1)
|
| 493 |
-
denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
|
| 494 |
-
pooled = (summed / denom).detach().cpu().numpy()
|
| 495 |
-
|
| 496 |
-
token_emb_list, mask_list, lengths = [], [], []
|
| 497 |
-
for b in range(last_hidden.shape[0]):
|
| 498 |
-
emb = last_hidden[b, valid[b]] # (Li, H)
|
| 499 |
-
token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
|
| 500 |
-
li = emb.shape[0]
|
| 501 |
-
lengths.append(int(li))
|
| 502 |
-
mask_list.append(np.ones((li,), dtype=np.int8))
|
| 503 |
-
|
| 504 |
-
return pooled, token_emb_list, mask_list, lengths
|
| 505 |
-
|
| 506 |
-
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
|
| 507 |
-
pooled_all = []
|
| 508 |
-
token_emb_all = []
|
| 509 |
-
mask_all = []
|
| 510 |
-
lengths_all = []
|
| 511 |
-
|
| 512 |
-
for i in pbar(range(0, len(seqs), batch_size)):
|
| 513 |
-
batch = seqs[i:i + batch_size]
|
| 514 |
-
pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
|
| 515 |
-
batch, tokenizer_obj, model_roformer, max_length
|
| 516 |
-
)
|
| 517 |
-
pooled_all.append(pooled)
|
| 518 |
-
token_emb_all.extend(tok_list)
|
| 519 |
-
mask_all.extend(m_list)
|
| 520 |
-
lengths_all.extend(lens)
|
| 521 |
-
|
| 522 |
-
return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
|
| 523 |
-
|
| 524 |
-
# -------------------------
|
| 525 |
-
# Target embedding cache (NO extra ESM runs)
|
| 526 |
-
# We will compute target pooled embeddings ONCE from WT view, then reuse for SMILES.
|
| 527 |
-
# -------------------------
|
| 528 |
-
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
|
| 529 |
-
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
|
| 530 |
-
wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
|
| 531 |
-
|
| 532 |
-
# compute target pooled embeddings once
|
| 533 |
-
tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
|
| 534 |
-
tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
|
| 535 |
-
|
| 536 |
-
wt_train_tgt_emb = wt_pooled_embeddings(
|
| 537 |
-
tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
|
| 538 |
-
)
|
| 539 |
-
wt_val_tgt_emb = wt_pooled_embeddings(
|
| 540 |
-
tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
|
| 541 |
-
)
|
| 542 |
-
|
| 543 |
-
# build dict: target_sequence -> embedding (float32 array)
|
| 544 |
-
# if duplicates exist, last wins; you can add checks if needed
|
| 545 |
-
train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
|
| 546 |
-
val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
|
| 547 |
-
return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
|
| 548 |
-
# -------------------------
|
| 549 |
-
# Main
|
| 550 |
-
# -------------------------
|
| 551 |
-
def main():
|
| 552 |
-
log(f"[INFO] DEVICE: {DEVICE}")
|
| 553 |
-
OUT_ROOT.mkdir(parents=True, exist_ok=True)
|
| 554 |
-
|
| 555 |
-
# 1) Load
|
| 556 |
-
with section("load csv + dedup"):
|
| 557 |
-
df = pd.read_csv(CSV_PATH)
|
| 558 |
-
for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
|
| 559 |
-
if c in df.columns:
|
| 560 |
-
df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
|
| 561 |
-
|
| 562 |
-
# Dedup on the full identity tuple you want
|
| 563 |
-
DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
|
| 564 |
-
df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
|
| 565 |
-
|
| 566 |
-
print("Rows after dedup on", DEDUP_COLS, ":", len(df))
|
| 567 |
-
|
| 568 |
-
need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
|
| 569 |
-
missing = [c for c in need if c not in df.columns]
|
| 570 |
-
if missing:
|
| 571 |
-
raise ValueError(f"Missing required columns: {missing}")
|
| 572 |
-
|
| 573 |
-
# numeric affinity for both branches
|
| 574 |
-
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
|
| 575 |
-
|
| 576 |
-
# 2) Build WT subset + SMILES subset separately (NO global dropping)
|
| 577 |
-
with section("prepare wt/smiles subsets"):
|
| 578 |
-
# WT: requires a canonical peptide sequence (no X) + affinity
|
| 579 |
-
df_wt = df.copy()
|
| 580 |
-
df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
|
| 581 |
-
df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 582 |
-
df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
|
| 583 |
-
df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
|
| 584 |
-
|
| 585 |
-
# SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
|
| 586 |
-
df_smi = df.copy()
|
| 587 |
-
df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
|
| 588 |
-
df_smi = df_smi[
|
| 589 |
-
pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
|
| 590 |
-
].reset_index(drop=True) # empty iptm means sth wrong with their smiles sequenc
|
| 591 |
-
|
| 592 |
-
is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
|
| 593 |
-
df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
|
| 594 |
-
df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
|
| 595 |
-
df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
|
| 596 |
-
df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
|
| 597 |
-
|
| 598 |
-
log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
|
| 599 |
-
|
| 600 |
-
# 3) Split separately (different sizes and memberships are expected)
|
| 601 |
-
with section("split wt and smiles separately"):
|
| 602 |
-
df_wt2 = make_distribution_matched_split(df_wt)
|
| 603 |
-
df_smi2 = make_distribution_matched_split(df_smi)
|
| 604 |
-
|
| 605 |
-
# save split tables
|
| 606 |
-
wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
|
| 607 |
-
smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
|
| 608 |
-
df_wt2.to_csv(wt_split_csv, index=False)
|
| 609 |
-
df_smi2.to_csv(smi_split_csv, index=False)
|
| 610 |
-
log(f"Saved WT split meta: {wt_split_csv}")
|
| 611 |
-
log(f"Saved SMILES split meta: {smi_split_csv}")
|
| 612 |
-
|
| 613 |
-
# lightweight double-check (one-line)
|
| 614 |
-
verify_split_before_embedding(
|
| 615 |
-
df2=df_wt2,
|
| 616 |
-
affinity_col=COL_AFF,
|
| 617 |
-
split_col="split",
|
| 618 |
-
seq_col="wt_sequence",
|
| 619 |
-
iptm_col=COL_WT_IPTM,
|
| 620 |
-
aff_class_col="affinity_class",
|
| 621 |
-
aff_bins=AFFINITY_Q_BINS,
|
| 622 |
-
save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
|
| 623 |
-
verbose=False,
|
| 624 |
-
)
|
| 625 |
-
verify_split_before_embedding(
|
| 626 |
-
df2=df_smi2,
|
| 627 |
-
affinity_col=COL_AFF,
|
| 628 |
-
split_col="split",
|
| 629 |
-
seq_col="smiles_sequence",
|
| 630 |
-
iptm_col=COL_SMI_IPTM,
|
| 631 |
-
aff_class_col="affinity_class",
|
| 632 |
-
aff_bins=AFFINITY_Q_BINS,
|
| 633 |
-
save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
|
| 634 |
-
verbose=False,
|
| 635 |
-
)
|
| 636 |
-
|
| 637 |
-
# Prepare split views
|
| 638 |
-
def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
|
| 639 |
-
out = df_in.copy()
|
| 640 |
-
out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # <-- NEW
|
| 641 |
-
out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
|
| 642 |
-
out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
|
| 643 |
-
out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
|
| 644 |
-
out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
|
| 645 |
-
out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
|
| 646 |
-
return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
|
| 647 |
-
|
| 648 |
-
wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
|
| 649 |
-
smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
|
| 650 |
-
|
| 651 |
-
# -------------------------
|
| 652 |
-
# Split views
|
| 653 |
-
# -------------------------
|
| 654 |
-
wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
|
| 655 |
-
wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
|
| 656 |
-
smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
|
| 657 |
-
smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
# =========================
|
| 661 |
-
# TARGET pooled embeddings (ESM) — SEPARATE per branch
|
| 662 |
-
# =========================
|
| 663 |
-
with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
|
| 664 |
-
wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
|
| 665 |
-
wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
|
| 666 |
-
|
| 667 |
-
# ---- WT targets ----
|
| 668 |
-
wt_train_tgt_emb = wt_pooled_embeddings(
|
| 669 |
-
wt_train["target_sequence"].astype(str).str.strip().tolist(),
|
| 670 |
-
wt_tok, wt_esm,
|
| 671 |
-
batch_size=WT_BATCH,
|
| 672 |
-
max_length=WT_MAX_LEN,
|
| 673 |
-
).astype(np.float32)
|
| 674 |
-
|
| 675 |
-
wt_val_tgt_emb = wt_pooled_embeddings(
|
| 676 |
-
wt_val["target_sequence"].astype(str).str.strip().tolist(),
|
| 677 |
-
wt_tok, wt_esm,
|
| 678 |
-
batch_size=WT_BATCH,
|
| 679 |
-
max_length=WT_MAX_LEN,
|
| 680 |
-
).astype(np.float32)
|
| 681 |
-
|
| 682 |
-
# ---- SMILES targets (independent; may include UAA-only targets) ----
|
| 683 |
-
smi_train_tgt_emb = wt_pooled_embeddings(
|
| 684 |
-
smi_train["target_sequence"].astype(str).str.strip().tolist(),
|
| 685 |
-
wt_tok, wt_esm,
|
| 686 |
-
batch_size=WT_BATCH,
|
| 687 |
-
max_length=WT_MAX_LEN,
|
| 688 |
-
).astype(np.float32)
|
| 689 |
-
|
| 690 |
-
smi_val_tgt_emb = wt_pooled_embeddings(
|
| 691 |
-
smi_val["target_sequence"].astype(str).str.strip().tolist(),
|
| 692 |
-
wt_tok, wt_esm,
|
| 693 |
-
batch_size=WT_BATCH,
|
| 694 |
-
max_length=WT_MAX_LEN,
|
| 695 |
-
).astype(np.float32)
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
# =========================
|
| 699 |
-
# WT pooled binder embeddings (binder = WT peptide)
|
| 700 |
-
# =========================
|
| 701 |
-
with section("WT pooled binder embeddings + save"):
|
| 702 |
-
wt_train_emb = wt_pooled_embeddings(
|
| 703 |
-
wt_train["sequence"].astype(str).str.strip().tolist(),
|
| 704 |
-
wt_tok, wt_esm,
|
| 705 |
-
batch_size=WT_BATCH,
|
| 706 |
-
max_length=WT_MAX_LEN,
|
| 707 |
-
).astype(np.float32)
|
| 708 |
-
|
| 709 |
-
wt_val_emb = wt_pooled_embeddings(
|
| 710 |
-
wt_val["sequence"].astype(str).str.strip().tolist(),
|
| 711 |
-
wt_tok, wt_esm,
|
| 712 |
-
batch_size=WT_BATCH,
|
| 713 |
-
max_length=WT_MAX_LEN,
|
| 714 |
-
).astype(np.float32)
|
| 715 |
-
|
| 716 |
-
wt_train_ds = Dataset.from_dict({
|
| 717 |
-
"target_sequence": wt_train["target_sequence"].tolist(),
|
| 718 |
-
"sequence": wt_train["sequence"].tolist(),
|
| 719 |
-
"label": wt_train["label"].astype(float).tolist(),
|
| 720 |
-
"target_embedding": wt_train_tgt_emb,
|
| 721 |
-
"embedding": wt_train_emb,
|
| 722 |
-
COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
|
| 723 |
-
COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
|
| 724 |
-
"affinity_class": wt_train["affinity_class"].tolist(),
|
| 725 |
-
})
|
| 726 |
-
|
| 727 |
-
wt_val_ds = Dataset.from_dict({
|
| 728 |
-
"target_sequence": wt_val["target_sequence"].tolist(),
|
| 729 |
-
"sequence": wt_val["sequence"].tolist(),
|
| 730 |
-
"label": wt_val["label"].astype(float).tolist(),
|
| 731 |
-
"target_embedding": wt_val_tgt_emb,
|
| 732 |
-
"embedding": wt_val_emb,
|
| 733 |
-
COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
|
| 734 |
-
COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
|
| 735 |
-
"affinity_class": wt_val["affinity_class"].tolist(),
|
| 736 |
-
})
|
| 737 |
-
|
| 738 |
-
wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
|
| 739 |
-
wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
|
| 740 |
-
wt_pooled_dd.save_to_disk(str(wt_pooled_out))
|
| 741 |
-
log(f"Saved WT pooled -> {wt_pooled_out}")
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
# =========================
|
| 745 |
-
# SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
|
| 746 |
-
# =========================
|
| 747 |
-
with section("SMILES pooled binder embeddings + save"):
|
| 748 |
-
smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
|
| 749 |
-
smi_roformer = (
|
| 750 |
-
AutoModelForMaskedLM
|
| 751 |
-
.from_pretrained(SMI_MODEL_NAME)
|
| 752 |
-
.roformer
|
| 753 |
-
.to(DEVICE)
|
| 754 |
-
.eval()
|
| 755 |
-
)
|
| 756 |
-
|
| 757 |
-
smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
|
| 758 |
-
smi_train["sequence"].astype(str).str.strip().tolist(),
|
| 759 |
-
smi_tok, smi_roformer,
|
| 760 |
-
batch_size=SMI_BATCH,
|
| 761 |
-
max_length=SMI_MAX_LEN,
|
| 762 |
-
)
|
| 763 |
-
|
| 764 |
-
smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
|
| 765 |
-
smi_val["sequence"].astype(str).str.strip().tolist(),
|
| 766 |
-
smi_tok, smi_roformer,
|
| 767 |
-
batch_size=SMI_BATCH,
|
| 768 |
-
max_length=SMI_MAX_LEN,
|
| 769 |
-
)
|
| 770 |
-
|
| 771 |
-
smi_train_ds = Dataset.from_dict({
|
| 772 |
-
"target_sequence": smi_train["target_sequence"].tolist(),
|
| 773 |
-
"sequence": smi_train["sequence"].tolist(),
|
| 774 |
-
"label": smi_train["label"].astype(float).tolist(),
|
| 775 |
-
"target_embedding": smi_train_tgt_emb,
|
| 776 |
-
"embedding": smi_train_pooled.astype(np.float32),
|
| 777 |
-
COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
|
| 778 |
-
COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
|
| 779 |
-
"affinity_class": smi_train["affinity_class"].tolist(),
|
| 780 |
-
})
|
| 781 |
-
|
| 782 |
-
smi_val_ds = Dataset.from_dict({
|
| 783 |
-
"target_sequence": smi_val["target_sequence"].tolist(),
|
| 784 |
-
"sequence": smi_val["sequence"].tolist(),
|
| 785 |
-
"label": smi_val["label"].astype(float).tolist(),
|
| 786 |
-
"target_embedding": smi_val_tgt_emb,
|
| 787 |
-
"embedding": smi_val_pooled.astype(np.float32),
|
| 788 |
-
COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
|
| 789 |
-
COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
|
| 790 |
-
"affinity_class": smi_val["affinity_class"].tolist(),
|
| 791 |
-
})
|
| 792 |
-
|
| 793 |
-
smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
|
| 794 |
-
smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
|
| 795 |
-
smi_pooled_dd.save_to_disk(str(smi_pooled_out))
|
| 796 |
-
log(f"Saved SMILES pooled -> {smi_pooled_out}")
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
# =========================
|
| 800 |
-
# WT unpooled paired (ESM target + ESM binder) + save
|
| 801 |
-
# =========================
|
| 802 |
-
with section("WT unpooled paired embeddings + save"):
|
| 803 |
-
wt_tok_unpooled = wt_tok # reuse tokenizer
|
| 804 |
-
wt_esm_unpooled = wt_esm # reuse model
|
| 805 |
-
|
| 806 |
-
wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
|
| 807 |
-
wt_unpooled_dd = DatasetDict({
|
| 808 |
-
"train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
|
| 809 |
-
wt_tok_unpooled, wt_esm_unpooled),
|
| 810 |
-
"val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
|
| 811 |
-
wt_tok_unpooled, wt_esm_unpooled),
|
| 812 |
-
})
|
| 813 |
-
# (Optional) also save as DatasetDict root if you want a single load_from_disk path:
|
| 814 |
-
wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
|
| 815 |
-
log(f"Saved WT unpooled -> {wt_unpooled_out}")
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
# =========================
|
| 819 |
-
# SMILES unpooled paired (ESM target + PeptideCLM binder) + save
|
| 820 |
-
# =========================
|
| 821 |
-
with section("SMILES unpooled paired embeddings + save"):
|
| 822 |
-
# reuse already-loaded smi_tok/smi_roformer from pooled section if still in scope;
|
| 823 |
-
# otherwise re-init here:
|
| 824 |
-
# smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
|
| 825 |
-
# smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()
|
| 826 |
-
|
| 827 |
-
smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
|
| 828 |
-
smi_unpooled_dd = DatasetDict({
|
| 829 |
-
"train": build_smiles_unpooled_paired_dataset(
|
| 830 |
-
smi_train, smi_unpooled_out / "train",
|
| 831 |
-
wt_tok, wt_esm,
|
| 832 |
-
smi_tok, smi_roformer
|
| 833 |
-
),
|
| 834 |
-
"val": build_smiles_unpooled_paired_dataset(
|
| 835 |
-
smi_val, smi_unpooled_out / "val",
|
| 836 |
-
wt_tok, wt_esm,
|
| 837 |
-
smi_tok, smi_roformer
|
| 838 |
-
),
|
| 839 |
-
})
|
| 840 |
-
smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
|
| 841 |
-
log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
|
| 842 |
-
|
| 843 |
-
log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
if __name__ == "__main__":
|
| 847 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py
DELETED
|
@@ -1,414 +0,0 @@
|
|
| 1 |
-
import os, json
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
-
from torch.utils.data import DataLoader
|
| 7 |
-
import optuna
|
| 8 |
-
from datasets import load_from_disk, DatasetDict
|
| 9 |
-
from scipy.stats import spearmanr
|
| 10 |
-
from lightning.pytorch import seed_everything
|
| 11 |
-
seed_everything(1986)
|
| 12 |
-
|
| 13 |
-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 17 |
-
rho = spearmanr(y_true, y_pred).correlation
|
| 18 |
-
if rho is None or np.isnan(rho):
|
| 19 |
-
return 0.0
|
| 20 |
-
return float(rho)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
# -----------------------------
|
| 24 |
-
# Affinity class thresholds (final spec)
|
| 25 |
-
# High >= 9 ; Moderate 7-9 ; Low < 7
|
| 26 |
-
# 0=High, 1=Moderate, 2=Low
|
| 27 |
-
# -----------------------------
|
| 28 |
-
def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
|
| 29 |
-
high = y >= 9.0
|
| 30 |
-
low = y < 7.0
|
| 31 |
-
mid = ~(high | low)
|
| 32 |
-
cls = torch.zeros_like(y, dtype=torch.long)
|
| 33 |
-
cls[mid] = 1
|
| 34 |
-
cls[low] = 2
|
| 35 |
-
return cls
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# -----------------------------
|
| 39 |
-
# Load paired DatasetDict
|
| 40 |
-
# -----------------------------
|
| 41 |
-
def load_split_paired(path: str):
|
| 42 |
-
dd = load_from_disk(path)
|
| 43 |
-
if not isinstance(dd, DatasetDict):
|
| 44 |
-
raise ValueError(f"Expected DatasetDict at {path}")
|
| 45 |
-
if "train" not in dd or "val" not in dd:
|
| 46 |
-
raise ValueError(f"DatasetDict missing train/val at {path}")
|
| 47 |
-
return dd["train"], dd["val"]
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# -----------------------------
|
| 51 |
-
# Collate: pooled paired
|
| 52 |
-
# -----------------------------
|
| 53 |
-
def collate_pair_pooled(batch):
|
| 54 |
-
Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) # (B,Ht)
|
| 55 |
-
Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) # (B,Hb)
|
| 56 |
-
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
|
| 57 |
-
return Pt, Pb, y
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
# -----------------------------
|
| 61 |
-
# Collate: unpooled paired
|
| 62 |
-
# -----------------------------
|
| 63 |
-
def collate_pair_unpooled(batch):
|
| 64 |
-
B = len(batch)
|
| 65 |
-
Ht = len(batch[0]["target_embedding"][0])
|
| 66 |
-
Hb = len(batch[0]["binder_embedding"][0])
|
| 67 |
-
Lt_max = max(int(x["target_length"]) for x in batch)
|
| 68 |
-
Lb_max = max(int(x["binder_length"]) for x in batch)
|
| 69 |
-
|
| 70 |
-
Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
|
| 71 |
-
Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
|
| 72 |
-
Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
|
| 73 |
-
Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
|
| 74 |
-
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
|
| 75 |
-
|
| 76 |
-
for i, x in enumerate(batch):
|
| 77 |
-
t = torch.tensor(x["target_embedding"], dtype=torch.float32)
|
| 78 |
-
b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
|
| 79 |
-
lt, lb = t.shape[0], b.shape[0]
|
| 80 |
-
Pt[i, :lt] = t
|
| 81 |
-
Pb[i, :lb] = b
|
| 82 |
-
Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
|
| 83 |
-
Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
|
| 84 |
-
|
| 85 |
-
return Pt, Mt, Pb, Mb, y
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
# -----------------------------
|
| 89 |
-
# Cross-attention models
|
| 90 |
-
# -----------------------------
|
| 91 |
-
class CrossAttnPooled(nn.Module):
|
| 92 |
-
"""
|
| 93 |
-
pooled vectors -> treat as single-token sequences for cross attention
|
| 94 |
-
"""
|
| 95 |
-
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
|
| 96 |
-
super().__init__()
|
| 97 |
-
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
|
| 98 |
-
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
|
| 99 |
-
|
| 100 |
-
self.layers = nn.ModuleList([])
|
| 101 |
-
for _ in range(n_layers):
|
| 102 |
-
self.layers.append(nn.ModuleDict({
|
| 103 |
-
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
|
| 104 |
-
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
|
| 105 |
-
"n1t": nn.LayerNorm(hidden),
|
| 106 |
-
"n2t": nn.LayerNorm(hidden),
|
| 107 |
-
"n1b": nn.LayerNorm(hidden),
|
| 108 |
-
"n2b": nn.LayerNorm(hidden),
|
| 109 |
-
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 110 |
-
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 111 |
-
}))
|
| 112 |
-
|
| 113 |
-
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
|
| 114 |
-
self.reg = nn.Linear(hidden, 1)
|
| 115 |
-
self.cls = nn.Linear(hidden, 3)
|
| 116 |
-
|
| 117 |
-
def forward(self, t_vec, b_vec):
|
| 118 |
-
# (B,Ht),(B,Hb)
|
| 119 |
-
t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
|
| 120 |
-
b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
|
| 121 |
-
|
| 122 |
-
for L in self.layers:
|
| 123 |
-
t_attn, _ = L["attn_tb"](t, b, b)
|
| 124 |
-
t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
|
| 125 |
-
t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
|
| 126 |
-
|
| 127 |
-
b_attn, _ = L["attn_bt"](b, t, t)
|
| 128 |
-
b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
|
| 129 |
-
b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
|
| 130 |
-
|
| 131 |
-
t0 = t[0]
|
| 132 |
-
b0 = b[0]
|
| 133 |
-
z = torch.cat([t0, b0], dim=-1)
|
| 134 |
-
h = self.shared(z)
|
| 135 |
-
return self.reg(h).squeeze(-1), self.cls(h)
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
class CrossAttnUnpooled(nn.Module):
|
| 139 |
-
"""
|
| 140 |
-
token sequences with masks; alternating cross attention.
|
| 141 |
-
"""
|
| 142 |
-
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
|
| 143 |
-
super().__init__()
|
| 144 |
-
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
|
| 145 |
-
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
|
| 146 |
-
|
| 147 |
-
self.layers = nn.ModuleList([])
|
| 148 |
-
for _ in range(n_layers):
|
| 149 |
-
self.layers.append(nn.ModuleDict({
|
| 150 |
-
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
|
| 151 |
-
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
|
| 152 |
-
"n1t": nn.LayerNorm(hidden),
|
| 153 |
-
"n2t": nn.LayerNorm(hidden),
|
| 154 |
-
"n1b": nn.LayerNorm(hidden),
|
| 155 |
-
"n2b": nn.LayerNorm(hidden),
|
| 156 |
-
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 157 |
-
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 158 |
-
}))
|
| 159 |
-
|
| 160 |
-
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
|
| 161 |
-
self.reg = nn.Linear(hidden, 1)
|
| 162 |
-
self.cls = nn.Linear(hidden, 3)
|
| 163 |
-
|
| 164 |
-
def masked_mean(self, X, M):
|
| 165 |
-
Mf = M.unsqueeze(-1).float()
|
| 166 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 167 |
-
return (X * Mf).sum(dim=1) / denom
|
| 168 |
-
|
| 169 |
-
def forward(self, T, Mt, B, Mb):
|
| 170 |
-
# T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
|
| 171 |
-
T = self.t_proj(T)
|
| 172 |
-
Bx = self.b_proj(B)
|
| 173 |
-
|
| 174 |
-
kp_t = ~Mt # key_padding_mask True = pad
|
| 175 |
-
kp_b = ~Mb
|
| 176 |
-
|
| 177 |
-
for L in self.layers:
|
| 178 |
-
# T attends to B
|
| 179 |
-
T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
|
| 180 |
-
T = L["n1t"](T + T_attn)
|
| 181 |
-
T = L["n2t"](T + L["fft"](T))
|
| 182 |
-
|
| 183 |
-
# B attends to T
|
| 184 |
-
B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
|
| 185 |
-
Bx = L["n1b"](Bx + B_attn)
|
| 186 |
-
Bx = L["n2b"](Bx + L["ffb"](Bx))
|
| 187 |
-
|
| 188 |
-
t_pool = self.masked_mean(T, Mt)
|
| 189 |
-
b_pool = self.masked_mean(Bx, Mb)
|
| 190 |
-
z = torch.cat([t_pool, b_pool], dim=-1)
|
| 191 |
-
h = self.shared(z)
|
| 192 |
-
return self.reg(h).squeeze(-1), self.cls(h)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
# -----------------------------
|
| 196 |
-
# Train/eval
|
| 197 |
-
# -----------------------------
|
| 198 |
-
@torch.no_grad()
|
| 199 |
-
def eval_spearman_pooled(model, loader):
|
| 200 |
-
model.eval()
|
| 201 |
-
ys, ps = [], []
|
| 202 |
-
for t, b, y in loader:
|
| 203 |
-
t = t.to(DEVICE, non_blocking=True)
|
| 204 |
-
b = b.to(DEVICE, non_blocking=True)
|
| 205 |
-
pred, _ = model(t, b)
|
| 206 |
-
ys.append(y.numpy())
|
| 207 |
-
ps.append(pred.detach().cpu().numpy())
|
| 208 |
-
return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
|
| 209 |
-
|
| 210 |
-
@torch.no_grad()
|
| 211 |
-
def eval_spearman_unpooled(model, loader):
|
| 212 |
-
model.eval()
|
| 213 |
-
ys, ps = [], []
|
| 214 |
-
for T, Mt, B, Mb, y in loader:
|
| 215 |
-
T = T.to(DEVICE, non_blocking=True)
|
| 216 |
-
Mt = Mt.to(DEVICE, non_blocking=True)
|
| 217 |
-
B = B.to(DEVICE, non_blocking=True)
|
| 218 |
-
Mb = Mb.to(DEVICE, non_blocking=True)
|
| 219 |
-
pred, _ = model(T, Mt, B, Mb)
|
| 220 |
-
ys.append(y.numpy())
|
| 221 |
-
ps.append(pred.detach().cpu().numpy())
|
| 222 |
-
return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
|
| 223 |
-
|
| 224 |
-
def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
|
| 225 |
-
model.train()
|
| 226 |
-
for t, b, y in loader:
|
| 227 |
-
t = t.to(DEVICE, non_blocking=True)
|
| 228 |
-
b = b.to(DEVICE, non_blocking=True)
|
| 229 |
-
y = y.to(DEVICE, non_blocking=True)
|
| 230 |
-
y_cls = affinity_to_class_tensor(y)
|
| 231 |
-
|
| 232 |
-
opt.zero_grad(set_to_none=True)
|
| 233 |
-
pred, logits = model(t, b)
|
| 234 |
-
L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
|
| 235 |
-
L.backward()
|
| 236 |
-
if clip is not None:
|
| 237 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
|
| 238 |
-
opt.step()
|
| 239 |
-
|
| 240 |
-
def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
|
| 241 |
-
model.train()
|
| 242 |
-
for T, Mt, B, Mb, y in loader:
|
| 243 |
-
T = T.to(DEVICE, non_blocking=True)
|
| 244 |
-
Mt = Mt.to(DEVICE, non_blocking=True)
|
| 245 |
-
B = B.to(DEVICE, non_blocking=True)
|
| 246 |
-
Mb = Mb.to(DEVICE, non_blocking=True)
|
| 247 |
-
y = y.to(DEVICE, non_blocking=True)
|
| 248 |
-
y_cls = affinity_to_class_tensor(y)
|
| 249 |
-
|
| 250 |
-
opt.zero_grad(set_to_none=True)
|
| 251 |
-
pred, logits = model(T, Mt, B, Mb)
|
| 252 |
-
L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
|
| 253 |
-
L.backward()
|
| 254 |
-
if clip is not None:
|
| 255 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
|
| 256 |
-
opt.step()
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
# -----------------------------
|
| 260 |
-
# Optuna objective
|
| 261 |
-
# -----------------------------
|
| 262 |
-
def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
|
| 263 |
-
lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
|
| 264 |
-
wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
|
| 265 |
-
dropout = trial.suggest_float("dropout", 0.0, 0.4)
|
| 266 |
-
hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
|
| 267 |
-
n_heads = trial.suggest_categorical("n_heads", [4, 8])
|
| 268 |
-
n_layers = trial.suggest_int("n_layers", 1, 4)
|
| 269 |
-
cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
|
| 270 |
-
batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
|
| 271 |
-
|
| 272 |
-
# infer dims from first row
|
| 273 |
-
if mode == "pooled":
|
| 274 |
-
Ht = len(train_ds[0]["target_embedding"])
|
| 275 |
-
Hb = len(train_ds[0]["binder_embedding"])
|
| 276 |
-
collate = collate_pair_pooled
|
| 277 |
-
model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
|
| 278 |
-
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 279 |
-
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 280 |
-
eval_fn = eval_spearman_pooled
|
| 281 |
-
train_fn = train_one_epoch_pooled
|
| 282 |
-
|
| 283 |
-
else:
|
| 284 |
-
Ht = len(train_ds[0]["target_embedding"][0])
|
| 285 |
-
Hb = len(train_ds[0]["binder_embedding"][0])
|
| 286 |
-
collate = collate_pair_unpooled
|
| 287 |
-
model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
|
| 288 |
-
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 289 |
-
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 290 |
-
eval_fn = eval_spearman_unpooled
|
| 291 |
-
train_fn = train_one_epoch_unpooled
|
| 292 |
-
|
| 293 |
-
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
|
| 294 |
-
loss_reg = nn.MSELoss()
|
| 295 |
-
loss_cls = nn.CrossEntropyLoss()
|
| 296 |
-
|
| 297 |
-
best = -1e9
|
| 298 |
-
bad = 0
|
| 299 |
-
patience = 10
|
| 300 |
-
|
| 301 |
-
for ep in range(1, 61):
|
| 302 |
-
train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
|
| 303 |
-
rho = eval_fn(model, val_loader)
|
| 304 |
-
|
| 305 |
-
trial.report(rho, ep)
|
| 306 |
-
if trial.should_prune():
|
| 307 |
-
raise optuna.TrialPruned()
|
| 308 |
-
|
| 309 |
-
if rho > best + 1e-6:
|
| 310 |
-
best = rho
|
| 311 |
-
bad = 0
|
| 312 |
-
else:
|
| 313 |
-
bad += 1
|
| 314 |
-
if bad >= patience:
|
| 315 |
-
break
|
| 316 |
-
|
| 317 |
-
return float(best)
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
# -----------------------------
|
| 321 |
-
# Run: optuna + refit best
|
| 322 |
-
# -----------------------------
|
| 323 |
-
def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
|
| 324 |
-
out_dir = Path(out_dir)
|
| 325 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 326 |
-
|
| 327 |
-
train_ds, val_ds = load_split_paired(dataset_path)
|
| 328 |
-
print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")
|
| 329 |
-
|
| 330 |
-
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
|
| 331 |
-
study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)
|
| 332 |
-
|
| 333 |
-
study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
|
| 334 |
-
best = study.best_trial
|
| 335 |
-
best_params = dict(best.params)
|
| 336 |
-
|
| 337 |
-
# refit longer
|
| 338 |
-
lr = float(best_params["lr"])
|
| 339 |
-
wd = float(best_params["weight_decay"])
|
| 340 |
-
dropout = float(best_params["dropout"])
|
| 341 |
-
hidden = int(best_params["hidden_dim"])
|
| 342 |
-
n_heads = int(best_params["n_heads"])
|
| 343 |
-
n_layers = int(best_params["n_layers"])
|
| 344 |
-
cls_w = float(best_params["cls_weight"])
|
| 345 |
-
batch = int(best_params["batch_size"])
|
| 346 |
-
|
| 347 |
-
loss_reg = nn.MSELoss()
|
| 348 |
-
loss_cls = nn.CrossEntropyLoss()
|
| 349 |
-
|
| 350 |
-
if mode == "pooled":
|
| 351 |
-
Ht = len(train_ds[0]["target_embedding"])
|
| 352 |
-
Hb = len(train_ds[0]["binder_embedding"])
|
| 353 |
-
model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
|
| 354 |
-
collate = collate_pair_pooled
|
| 355 |
-
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 356 |
-
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 357 |
-
eval_fn = eval_spearman_pooled
|
| 358 |
-
train_fn = train_one_epoch_pooled
|
| 359 |
-
else:
|
| 360 |
-
Ht = len(train_ds[0]["target_embedding"][0])
|
| 361 |
-
Hb = len(train_ds[0]["binder_embedding"][0])
|
| 362 |
-
model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
|
| 363 |
-
collate = collate_pair_unpooled
|
| 364 |
-
train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 365 |
-
val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
|
| 366 |
-
eval_fn = eval_spearman_unpooled
|
| 367 |
-
train_fn = train_one_epoch_unpooled
|
| 368 |
-
|
| 369 |
-
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
|
| 370 |
-
|
| 371 |
-
best_rho = -1e9
|
| 372 |
-
bad = 0
|
| 373 |
-
patience = 20
|
| 374 |
-
best_state = None
|
| 375 |
-
|
| 376 |
-
for ep in range(1, 201):
|
| 377 |
-
train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
|
| 378 |
-
rho = eval_fn(model, val_loader)
|
| 379 |
-
|
| 380 |
-
if rho > best_rho + 1e-6:
|
| 381 |
-
best_rho = rho
|
| 382 |
-
bad = 0
|
| 383 |
-
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
| 384 |
-
else:
|
| 385 |
-
bad += 1
|
| 386 |
-
if bad >= patience:
|
| 387 |
-
break
|
| 388 |
-
|
| 389 |
-
if best_state is not None:
|
| 390 |
-
model.load_state_dict(best_state)
|
| 391 |
-
|
| 392 |
-
# save
|
| 393 |
-
torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
|
| 394 |
-
with open(out_dir / "best_params.json", "w") as f:
|
| 395 |
-
json.dump(best_params, f, indent=2)
|
| 396 |
-
|
| 397 |
-
print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
if __name__ == "__main__":
|
| 401 |
-
import argparse
|
| 402 |
-
ap = argparse.ArgumentParser()
|
| 403 |
-
ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
|
| 404 |
-
ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
|
| 405 |
-
ap.add_argument("--out_dir", type=str, required=True)
|
| 406 |
-
ap.add_argument("--n_trials", type=int, default=50)
|
| 407 |
-
args = ap.parse_args()
|
| 408 |
-
|
| 409 |
-
run(
|
| 410 |
-
dataset_path=args.dataset_path,
|
| 411 |
-
out_dir=args.out_dir,
|
| 412 |
-
mode=args.mode,
|
| 413 |
-
n_trials=args.n_trials,
|
| 414 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash
DELETED
|
@@ -1,31 +0,0 @@
|
|
| 1 |
-
#!/bin/bash
|
| 2 |
-
#SBATCH --job-name=b-data
|
| 3 |
-
#SBATCH --partition=dgx-b200
|
| 4 |
-
#SBATCH --gpus=1
|
| 5 |
-
#SBATCH --cpus-per-task=10
|
| 6 |
-
#SBATCH --mem=100G
|
| 7 |
-
#SBATCH --time=48:00:00
|
| 8 |
-
#SBATCH --output=%x_%j.out
|
| 9 |
-
|
| 10 |
-
HOME_LOC=/vast/projects/pranam/lab/yz927
|
| 11 |
-
SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
|
| 12 |
-
DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
|
| 13 |
-
OBJECTIVE='binding_affinity'
|
| 14 |
-
WT='smiles' #wt/smiles
|
| 15 |
-
STATUS='pooled' #pooled/unpooled
|
| 16 |
-
DATA_FILE="pair_wt_${WT}_${STATUS}"
|
| 17 |
-
LOG_LOC=$SCRIPT_LOC
|
| 18 |
-
DATE=$(date +%m_%d)
|
| 19 |
-
SPECIAL_PREFIX="binding_affinity_data_generation"
|
| 20 |
-
|
| 21 |
-
# Create log directory if it doesn't exist
|
| 22 |
-
mkdir -p $LOG_LOC
|
| 23 |
-
|
| 24 |
-
cd $SCRIPT_LOC
|
| 25 |
-
source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
|
| 26 |
-
conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal
|
| 27 |
-
|
| 28 |
-
python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
|
| 29 |
-
|
| 30 |
-
echo "Script completed at $(date)"
|
| 31 |
-
conda deactivate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py
DELETED
|
@@ -1,508 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# finetune_xgb_halflife_cv_optuna.py
|
| 3 |
-
|
| 4 |
-
import os
|
| 5 |
-
import json
|
| 6 |
-
import math
|
| 7 |
-
import hashlib
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
from typing import Dict, Any, Optional, Tuple, List
|
| 10 |
-
|
| 11 |
-
import numpy as np
|
| 12 |
-
import pandas as pd
|
| 13 |
-
import optuna
|
| 14 |
-
|
| 15 |
-
from sklearn.model_selection import KFold
|
| 16 |
-
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 17 |
-
from scipy.stats import spearmanr
|
| 18 |
-
|
| 19 |
-
import torch
|
| 20 |
-
from transformers import AutoTokenizer, AutoModel
|
| 21 |
-
|
| 22 |
-
import xgboost as xgb
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# -----------------------------
|
| 26 |
-
# Repro
|
| 27 |
-
# -----------------------------
|
| 28 |
-
SEED = 1986
|
| 29 |
-
np.random.seed(SEED)
|
| 30 |
-
torch.manual_seed(SEED)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
# -----------------------------
|
| 34 |
-
# Metrics (mirrors your stability script style)
|
| 35 |
-
# -----------------------------
|
| 36 |
-
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 37 |
-
rho = spearmanr(y_true, y_pred).correlation
|
| 38 |
-
if rho is None or np.isnan(rho):
|
| 39 |
-
return 0.0
|
| 40 |
-
return float(rho)
|
| 41 |
-
|
| 42 |
-
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
|
| 43 |
-
rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
|
| 44 |
-
mae = float(mean_absolute_error(y_true, y_pred))
|
| 45 |
-
r2 = float(r2_score(y_true, y_pred))
|
| 46 |
-
rho = float(safe_spearmanr(y_true, y_pred))
|
| 47 |
-
return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# -----------------------------
|
| 51 |
-
# ESM-2 embeddings (cached)
|
| 52 |
-
# -----------------------------
|
| 53 |
-
@dataclass
|
| 54 |
-
class ESMEmbedderConfig:
|
| 55 |
-
model_name: str = "facebook/esm2_t33_650M_UR50D"
|
| 56 |
-
batch_size: int = 8
|
| 57 |
-
max_length: int = 1024 # truncate very long proteins
|
| 58 |
-
fp16: bool = True
|
| 59 |
-
|
| 60 |
-
class ESM2Embedder:
|
| 61 |
-
"""
|
| 62 |
-
Mean-pooled last hidden state (excluding special tokens) -> (H,) per sequence.
|
| 63 |
-
"""
|
| 64 |
-
def __init__(self, cfg: ESMEmbedderConfig, device: str = "cuda"):
|
| 65 |
-
self.cfg = cfg
|
| 66 |
-
self.device = device if (device == "cuda" and torch.cuda.is_available()) else "cpu"
|
| 67 |
-
self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, do_lower_case=False)
|
| 68 |
-
self.model = AutoModel.from_pretrained(cfg.model_name)
|
| 69 |
-
self.model.eval()
|
| 70 |
-
self.model.to(self.device)
|
| 71 |
-
|
| 72 |
-
# Turn off gradients
|
| 73 |
-
for p in self.model.parameters():
|
| 74 |
-
p.requires_grad = False
|
| 75 |
-
|
| 76 |
-
@torch.inference_mode()
|
| 77 |
-
def embed(self, seqs: List[str]) -> np.ndarray:
|
| 78 |
-
out = []
|
| 79 |
-
bs = self.cfg.batch_size
|
| 80 |
-
|
| 81 |
-
use_amp = (self.cfg.fp16 and self.device == "cuda")
|
| 82 |
-
autocast = torch.cuda.amp.autocast if use_amp else torch.cpu.amp.autocast # safe fallback
|
| 83 |
-
|
| 84 |
-
for i in range(0, len(seqs), bs):
|
| 85 |
-
batch = [s.strip().upper() for s in seqs[i:i+bs]]
|
| 86 |
-
toks = self.tokenizer(
|
| 87 |
-
batch,
|
| 88 |
-
return_tensors="pt",
|
| 89 |
-
padding=True,
|
| 90 |
-
truncation=True,
|
| 91 |
-
max_length=self.cfg.max_length,
|
| 92 |
-
add_special_tokens=True,
|
| 93 |
-
)
|
| 94 |
-
toks = {k: v.to(self.device) for k, v in toks.items()}
|
| 95 |
-
attn = toks["attention_mask"] # (B, L)
|
| 96 |
-
|
| 97 |
-
with autocast(enabled=use_amp):
|
| 98 |
-
h = self.model(**toks).last_hidden_state # (B, L, H)
|
| 99 |
-
|
| 100 |
-
# mask out special tokens: first token is <cls>; last non-pad token is usually <eos>
|
| 101 |
-
mask = attn.clone()
|
| 102 |
-
mask[:, 0] = 0
|
| 103 |
-
lengths = attn.sum(dim=1) # includes special tokens
|
| 104 |
-
# zero out last real token position per sequence
|
| 105 |
-
eos_pos = (lengths - 1).clamp(min=0)
|
| 106 |
-
mask[torch.arange(mask.size(0), device=mask.device), eos_pos] = 0
|
| 107 |
-
|
| 108 |
-
denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1) # (B,1)
|
| 109 |
-
pooled = (h * mask.unsqueeze(-1)).sum(dim=1) / denom # (B,H)
|
| 110 |
-
out.append(pooled.float().detach().cpu().numpy())
|
| 111 |
-
|
| 112 |
-
return np.concatenate(out, axis=0).astype(np.float32)
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def dataset_fingerprint(seqs: List[str], y: np.ndarray, extra: str = "") -> str:
|
| 116 |
-
h = hashlib.sha256()
|
| 117 |
-
for s in seqs:
|
| 118 |
-
h.update(s.encode("utf-8"))
|
| 119 |
-
h.update(b"\n")
|
| 120 |
-
h.update(np.asarray(y, dtype=np.float32).tobytes())
|
| 121 |
-
h.update(extra.encode("utf-8"))
|
| 122 |
-
return h.hexdigest()[:16]
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def load_or_compute_embeddings(
|
| 126 |
-
df: pd.DataFrame,
|
| 127 |
-
out_dir: str,
|
| 128 |
-
embed_cfg: ESMEmbedderConfig,
|
| 129 |
-
device: str,
|
| 130 |
-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 131 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 132 |
-
|
| 133 |
-
seqs = df["sequence"].astype(str).tolist()
|
| 134 |
-
y = df["half_life_hours"].astype(float).to_numpy(dtype=np.float32)
|
| 135 |
-
|
| 136 |
-
fp = dataset_fingerprint(seqs, y, extra=f"{embed_cfg.model_name}|{embed_cfg.max_length}")
|
| 137 |
-
emb_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.npy")
|
| 138 |
-
meta_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.json")
|
| 139 |
-
|
| 140 |
-
if os.path.exists(emb_path) and os.path.exists(meta_path):
|
| 141 |
-
X = np.load(emb_path).astype(np.float32)
|
| 142 |
-
return X, y, np.asarray(seqs)
|
| 143 |
-
|
| 144 |
-
embedder = ESM2Embedder(embed_cfg, device=device)
|
| 145 |
-
X = embedder.embed(seqs) # (N,H)
|
| 146 |
-
|
| 147 |
-
np.save(emb_path, X)
|
| 148 |
-
with open(meta_path, "w") as f:
|
| 149 |
-
json.dump(
|
| 150 |
-
{
|
| 151 |
-
"fingerprint": fp,
|
| 152 |
-
"model_name": embed_cfg.model_name,
|
| 153 |
-
"max_length": embed_cfg.max_length,
|
| 154 |
-
"n": len(seqs),
|
| 155 |
-
"dim": int(X.shape[1]),
|
| 156 |
-
},
|
| 157 |
-
f,
|
| 158 |
-
indent=2,
|
| 159 |
-
)
|
| 160 |
-
return X, y, np.asarray(seqs)
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
# -----------------------------
|
| 164 |
-
# XGBoost training (supports "finetune" via xgb_model)
|
| 165 |
-
# -----------------------------
|
| 166 |
-
def train_xgb_reg(
|
| 167 |
-
X_train: np.ndarray,
|
| 168 |
-
y_train: np.ndarray,
|
| 169 |
-
X_val: np.ndarray,
|
| 170 |
-
y_val: np.ndarray,
|
| 171 |
-
params: Dict[str, Any],
|
| 172 |
-
base_model_json: Optional[str] = None,
|
| 173 |
-
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray, int]:
|
| 174 |
-
dtrain = xgb.DMatrix(X_train, label=y_train)
|
| 175 |
-
dval = xgb.DMatrix(X_val, label=y_val)
|
| 176 |
-
|
| 177 |
-
num_boost_round = int(params.pop("num_boost_round"))
|
| 178 |
-
early_stopping_rounds = int(params.pop("early_stopping_rounds"))
|
| 179 |
-
|
| 180 |
-
# Important: load a fresh base model each fold (avoid leakage)
|
| 181 |
-
xgb_model = None
|
| 182 |
-
if base_model_json is not None:
|
| 183 |
-
booster0 = xgb.Booster()
|
| 184 |
-
booster0.load_model(base_model_json)
|
| 185 |
-
xgb_model = booster0
|
| 186 |
-
|
| 187 |
-
booster = xgb.train(
|
| 188 |
-
params=params,
|
| 189 |
-
dtrain=dtrain,
|
| 190 |
-
num_boost_round=num_boost_round,
|
| 191 |
-
evals=[(dval, "val")],
|
| 192 |
-
early_stopping_rounds=early_stopping_rounds,
|
| 193 |
-
verbose_eval=False,
|
| 194 |
-
xgb_model=xgb_model, # <-- "finetune": continue boosting from base model
|
| 195 |
-
)
|
| 196 |
-
|
| 197 |
-
p_train = booster.predict(dtrain)
|
| 198 |
-
p_val = booster.predict(dval)
|
| 199 |
-
best_iter = int(getattr(booster, "best_iteration", num_boost_round - 1))
|
| 200 |
-
return booster, p_train, p_val, best_iter
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
# -----------------------------
|
| 204 |
-
# Optuna objective: 5-fold mean Spearman rho
|
| 205 |
-
# -----------------------------
|
| 206 |
-
def make_cv_objective(
|
| 207 |
-
X: np.ndarray,
|
| 208 |
-
y: np.ndarray,
|
| 209 |
-
n_splits: int,
|
| 210 |
-
device: str,
|
| 211 |
-
base_model_json: Optional[str],
|
| 212 |
-
target_transform: str,
|
| 213 |
-
):
|
| 214 |
-
kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
|
| 215 |
-
|
| 216 |
-
# Optional target transform (sometimes helps with heavy-tailed half-life)
|
| 217 |
-
if target_transform == "log1p":
|
| 218 |
-
y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
|
| 219 |
-
elif target_transform == "none":
|
| 220 |
-
y_used = y.astype(np.float32)
|
| 221 |
-
else:
|
| 222 |
-
raise ValueError(f"Unknown target_transform: {target_transform}")
|
| 223 |
-
|
| 224 |
-
def objective(trial: optuna.Trial) -> float:
|
| 225 |
-
# Hyperparam ranges patterned after your stability script :contentReference[oaicite:1]{index=1}
|
| 226 |
-
params = {
|
| 227 |
-
"objective": "reg:squarederror",
|
| 228 |
-
"eval_metric": "rmse",
|
| 229 |
-
|
| 230 |
-
"lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
|
| 231 |
-
"alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
|
| 232 |
-
"gamma": trial.suggest_float("gamma", 0.0, 10.0),
|
| 233 |
-
|
| 234 |
-
"max_depth": trial.suggest_int("max_depth", 2, 12),
|
| 235 |
-
"min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 200.0, log=True),
|
| 236 |
-
"subsample": trial.suggest_float("subsample", 0.5, 1.0),
|
| 237 |
-
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
|
| 238 |
-
|
| 239 |
-
"learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
|
| 240 |
-
|
| 241 |
-
"tree_method": "hist",
|
| 242 |
-
"device": "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu",
|
| 243 |
-
}
|
| 244 |
-
params["num_boost_round"] = trial.suggest_int("num_boost_round", 30, 1500)
|
| 245 |
-
params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 10, 150)
|
| 246 |
-
|
| 247 |
-
fold_metrics = []
|
| 248 |
-
fold_best_iters = []
|
| 249 |
-
|
| 250 |
-
for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
|
| 251 |
-
Xtr, ytr = X[tr_idx], y_used[tr_idx]
|
| 252 |
-
Xva, yva = X[va_idx], y_used[va_idx]
|
| 253 |
-
|
| 254 |
-
_, _, p_va, best_iter = train_xgb_reg(
|
| 255 |
-
Xtr, ytr, Xva, yva, params.copy(),
|
| 256 |
-
base_model_json=base_model_json,
|
| 257 |
-
)
|
| 258 |
-
|
| 259 |
-
m = eval_regression(yva, p_va)
|
| 260 |
-
fold_metrics.append(m)
|
| 261 |
-
fold_best_iters.append(best_iter)
|
| 262 |
-
|
| 263 |
-
mean_rho = float(np.mean([m["spearman_rho"] for m in fold_metrics]))
|
| 264 |
-
mean_rmse = float(np.mean([m["rmse"] for m in fold_metrics]))
|
| 265 |
-
mean_mae = float(np.mean([m["mae"] for m in fold_metrics]))
|
| 266 |
-
mean_r2 = float(np.mean([m["r2"] for m in fold_metrics]))
|
| 267 |
-
mean_best_iter = float(np.mean(fold_best_iters))
|
| 268 |
-
|
| 269 |
-
trial.set_user_attr("cv_spearman_rho", mean_rho)
|
| 270 |
-
trial.set_user_attr("cv_rmse", mean_rmse)
|
| 271 |
-
trial.set_user_attr("cv_mae", mean_mae)
|
| 272 |
-
trial.set_user_attr("cv_r2", mean_r2)
|
| 273 |
-
trial.set_user_attr("cv_mean_best_iter", mean_best_iter)
|
| 274 |
-
|
| 275 |
-
# maximize Spearman rho (same as your stability workflow :contentReference[oaicite:2]{index=2})
|
| 276 |
-
return mean_rho
|
| 277 |
-
|
| 278 |
-
return objective
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def refit_and_save(
|
| 282 |
-
X: np.ndarray,
|
| 283 |
-
y: np.ndarray,
|
| 284 |
-
seqs: np.ndarray,
|
| 285 |
-
out_dir: str,
|
| 286 |
-
best_params: Dict[str, Any],
|
| 287 |
-
n_splits: int,
|
| 288 |
-
device: str,
|
| 289 |
-
base_model_json: Optional[str],
|
| 290 |
-
target_transform: str,
|
| 291 |
-
):
|
| 292 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 293 |
-
|
| 294 |
-
# Transform target consistently
|
| 295 |
-
if target_transform == "log1p":
|
| 296 |
-
y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
|
| 297 |
-
else:
|
| 298 |
-
y_used = y.astype(np.float32)
|
| 299 |
-
|
| 300 |
-
kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
|
| 301 |
-
|
| 302 |
-
# 1) get OOF preds + average best_iteration
|
| 303 |
-
oof_pred = np.zeros_like(y_used, dtype=np.float32)
|
| 304 |
-
best_iters = []
|
| 305 |
-
fold_rows = []
|
| 306 |
-
|
| 307 |
-
for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
|
| 308 |
-
Xtr, ytr = X[tr_idx], y_used[tr_idx]
|
| 309 |
-
Xva, yva = X[va_idx], y_used[va_idx]
|
| 310 |
-
|
| 311 |
-
_, _, p_va, best_iter = train_xgb_reg(
|
| 312 |
-
Xtr, ytr, Xva, yva, best_params.copy(),
|
| 313 |
-
base_model_json=base_model_json,
|
| 314 |
-
)
|
| 315 |
-
oof_pred[va_idx] = p_va.astype(np.float32)
|
| 316 |
-
best_iters.append(best_iter)
|
| 317 |
-
|
| 318 |
-
m = eval_regression(yva, p_va)
|
| 319 |
-
fold_rows.append({"fold": fold, **m, "best_iter": int(best_iter)})
|
| 320 |
-
|
| 321 |
-
fold_df = pd.DataFrame(fold_rows)
|
| 322 |
-
fold_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)
|
| 323 |
-
|
| 324 |
-
cv_metrics = eval_regression(y_used, oof_pred)
|
| 325 |
-
with open(os.path.join(out_dir, "cv_oof_summary.json"), "w") as f:
|
| 326 |
-
json.dump(cv_metrics, f, indent=2)
|
| 327 |
-
|
| 328 |
-
oof_df = pd.DataFrame({
|
| 329 |
-
"sequence": seqs,
|
| 330 |
-
"y_true_used": y_used.astype(float),
|
| 331 |
-
"y_pred_oof": oof_pred.astype(float),
|
| 332 |
-
"residual": (y_used - oof_pred).astype(float),
|
| 333 |
-
})
|
| 334 |
-
oof_df.to_csv(os.path.join(out_dir, "cv_oof_predictions.csv"), index=False)
|
| 335 |
-
|
| 336 |
-
mean_best_iter = int(round(float(np.mean(best_iters))))
|
| 337 |
-
final_rounds = max(mean_best_iter + 1, 10)
|
| 338 |
-
|
| 339 |
-
# 2) train final model on ALL data (no early stopping here; use final_rounds)
|
| 340 |
-
dtrain_all = xgb.DMatrix(X, label=y_used)
|
| 341 |
-
|
| 342 |
-
xgb_model = None
|
| 343 |
-
if base_model_json is not None:
|
| 344 |
-
booster0 = xgb.Booster()
|
| 345 |
-
booster0.load_model(base_model_json)
|
| 346 |
-
xgb_model = booster0
|
| 347 |
-
|
| 348 |
-
final_params = best_params.copy()
|
| 349 |
-
final_params.pop("early_stopping_rounds", None)
|
| 350 |
-
final_params["device"] = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
|
| 351 |
-
|
| 352 |
-
booster = xgb.train(
|
| 353 |
-
params=final_params,
|
| 354 |
-
dtrain=dtrain_all,
|
| 355 |
-
num_boost_round=int(final_params.pop("num_boost_round", final_rounds)),
|
| 356 |
-
evals=[],
|
| 357 |
-
verbose_eval=False,
|
| 358 |
-
xgb_model=xgb_model,
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
model_path = os.path.join(out_dir, "best_model_finetuned.json")
|
| 362 |
-
booster.save_model(model_path)
|
| 363 |
-
|
| 364 |
-
with open(os.path.join(out_dir, "final_training_notes.json"), "w") as f:
|
| 365 |
-
json.dump(
|
| 366 |
-
{
|
| 367 |
-
"target_transform": target_transform,
|
| 368 |
-
"final_rounds_used": int(final_rounds),
|
| 369 |
-
"cv_oof_metrics_on_used_target": cv_metrics,
|
| 370 |
-
"model_path": model_path,
|
| 371 |
-
},
|
| 372 |
-
f,
|
| 373 |
-
indent=2,
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
print("=" * 72)
|
| 377 |
-
print("[Final] CV OOF metrics (on transformed target if enabled):")
|
| 378 |
-
print(json.dumps(cv_metrics, indent=2))
|
| 379 |
-
print(f"[Final] Saved finetuned model -> {model_path}")
|
| 380 |
-
print("=" * 72)
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
def main():
|
| 384 |
-
import argparse
|
| 385 |
-
|
| 386 |
-
parser = argparse.ArgumentParser()
|
| 387 |
-
parser.add_argument("--csv_path", type=str, default="/scratch/pranamlab/tong/data/halflife/wt_halflife_merged_dedup.csv")
|
| 388 |
-
parser.add_argument("--out_dir", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb")
|
| 389 |
-
|
| 390 |
-
# If provided, we will "finetune" by continuing boosting from this model
|
| 391 |
-
parser.add_argument("--base_model_json", type=str, default='/scratch/pranamlab/tong/PeptiVerse/src/stability/xgboost/best_model.json', help="Path to an existing XGBoost .json model to continue training from")
|
| 392 |
-
|
| 393 |
-
# ESM embedding config
|
| 394 |
-
parser.add_argument("--esm_model", type=str, default="facebook/esm2_t33_650M_UR50D")
|
| 395 |
-
parser.add_argument("--esm_batch_size", type=int, default=8)
|
| 396 |
-
parser.add_argument("--esm_max_length", type=int, default=1024)
|
| 397 |
-
parser.add_argument("--no_fp16", action="store_true")
|
| 398 |
-
|
| 399 |
-
# Training config
|
| 400 |
-
parser.add_argument("--n_trials", type=int, default=200)
|
| 401 |
-
parser.add_argument("--n_splits", type=int, default=5)
|
| 402 |
-
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
|
| 403 |
-
parser.add_argument("--target_transform", type=str, default="none", choices=["none", "log1p"])
|
| 404 |
-
|
| 405 |
-
args = parser.parse_args()
|
| 406 |
-
os.makedirs(args.out_dir, exist_ok=True)
|
| 407 |
-
|
| 408 |
-
# Load data
|
| 409 |
-
df = pd.read_csv(args.csv_path)
|
| 410 |
-
if "sequence" not in df.columns or "half_life_hours" not in df.columns:
|
| 411 |
-
raise ValueError("CSV must contain columns: sequence, half_life_hours")
|
| 412 |
-
|
| 413 |
-
df = df.dropna(subset=["sequence", "half_life_hours"]).copy()
|
| 414 |
-
df["sequence"] = df["sequence"].astype(str).str.strip()
|
| 415 |
-
df = df[df["sequence"].str.len() > 0]
|
| 416 |
-
df = df.drop_duplicates(subset=["sequence"], keep="first").reset_index(drop=True)
|
| 417 |
-
|
| 418 |
-
print(f"[Data] N={len(df)} from {args.csv_path}")
|
| 419 |
-
|
| 420 |
-
# Embeddings (cached)
|
| 421 |
-
embed_cfg = ESMEmbedderConfig(
|
| 422 |
-
model_name=args.esm_model,
|
| 423 |
-
batch_size=args.esm_batch_size,
|
| 424 |
-
max_length=args.esm_max_length,
|
| 425 |
-
fp16=(not args.no_fp16),
|
| 426 |
-
)
|
| 427 |
-
X, y, seqs = load_or_compute_embeddings(df, args.out_dir, embed_cfg, device=args.device)
|
| 428 |
-
print(f"[Embeddings] X={X.shape} (float32)")
|
| 429 |
-
|
| 430 |
-
# Optuna study
|
| 431 |
-
sampler = optuna.samplers.TPESampler(seed=SEED)
|
| 432 |
-
study = optuna.create_study(
|
| 433 |
-
direction="maximize", # like your stability script :contentReference[oaicite:3]{index=3}
|
| 434 |
-
sampler=sampler,
|
| 435 |
-
pruner=optuna.pruners.MedianPruner(),
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
objective = make_cv_objective(
|
| 439 |
-
X=X,
|
| 440 |
-
y=y,
|
| 441 |
-
n_splits=args.n_splits,
|
| 442 |
-
device=args.device,
|
| 443 |
-
base_model_json=args.base_model_json,
|
| 444 |
-
target_transform=args.target_transform,
|
| 445 |
-
)
|
| 446 |
-
study.optimize(objective, n_trials=args.n_trials)
|
| 447 |
-
|
| 448 |
-
# Save trials
|
| 449 |
-
trials_df = study.trials_dataframe()
|
| 450 |
-
trials_df.to_csv(os.path.join(args.out_dir, "study_trials.csv"), index=False)
|
| 451 |
-
|
| 452 |
-
best = study.best_trial
|
| 453 |
-
best_params = dict(best.params)
|
| 454 |
-
|
| 455 |
-
# Build full param dict for refit
|
| 456 |
-
best_xgb_params = {
|
| 457 |
-
"objective": "reg:squarederror",
|
| 458 |
-
"eval_metric": "rmse",
|
| 459 |
-
"lambda": best_params["lambda"],
|
| 460 |
-
"alpha": best_params["alpha"],
|
| 461 |
-
"gamma": best_params["gamma"],
|
| 462 |
-
"max_depth": best_params["max_depth"],
|
| 463 |
-
"min_child_weight": best_params["min_child_weight"],
|
| 464 |
-
"subsample": best_params["subsample"],
|
| 465 |
-
"colsample_bytree": best_params["colsample_bytree"],
|
| 466 |
-
"learning_rate": best_params["learning_rate"],
|
| 467 |
-
"tree_method": "hist",
|
| 468 |
-
"device": "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu",
|
| 469 |
-
"num_boost_round": best_params["num_boost_round"],
|
| 470 |
-
"early_stopping_rounds": best_params["early_stopping_rounds"],
|
| 471 |
-
}
|
| 472 |
-
|
| 473 |
-
# Summary
|
| 474 |
-
summary = {
|
| 475 |
-
"best_trial_number": int(best.number),
|
| 476 |
-
"best_value_cv_spearman_rho": float(best.value),
|
| 477 |
-
"best_user_attrs": best.user_attrs,
|
| 478 |
-
"best_params": best_params,
|
| 479 |
-
"best_xgb_params_full": best_xgb_params,
|
| 480 |
-
"base_model_json": args.base_model_json,
|
| 481 |
-
"target_transform": args.target_transform,
|
| 482 |
-
"esm_model": args.esm_model,
|
| 483 |
-
"esm_max_length": args.esm_max_length,
|
| 484 |
-
}
|
| 485 |
-
with open(os.path.join(args.out_dir, "optimization_summary.json"), "w") as f:
|
| 486 |
-
json.dump(summary, f, indent=2)
|
| 487 |
-
|
| 488 |
-
print("=" * 72)
|
| 489 |
-
print("[Optuna] Best CV Spearman rho:", float(best.value))
|
| 490 |
-
print("[Optuna] Best params:\n", json.dumps(best_params, indent=2))
|
| 491 |
-
print("=" * 72)
|
| 492 |
-
|
| 493 |
-
# Refit + save final finetuned model + OOF predictions
|
| 494 |
-
refit_and_save(
|
| 495 |
-
X=X,
|
| 496 |
-
y=y,
|
| 497 |
-
seqs=seqs,
|
| 498 |
-
out_dir=args.out_dir,
|
| 499 |
-
best_params=best_xgb_params,
|
| 500 |
-
n_splits=args.n_splits,
|
| 501 |
-
device=args.device,
|
| 502 |
-
base_model_json=args.base_model_json,
|
| 503 |
-
target_transform=args.target_transform,
|
| 504 |
-
)
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
if __name__ == "__main__":
|
| 508 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py
DELETED
|
@@ -1,309 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
# export_val_preds_csv.py
|
| 3 |
-
|
| 4 |
-
import argparse
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
from torch.utils.data import DataLoader
|
| 10 |
-
from datasets import load_from_disk, DatasetDict
|
| 11 |
-
|
| 12 |
-
# -----------------------------
|
| 13 |
-
# Repro / device
|
| 14 |
-
# -----------------------------
|
| 15 |
-
def seed_all(seed=1986):
|
| 16 |
-
import random
|
| 17 |
-
random.seed(seed)
|
| 18 |
-
np.random.seed(seed)
|
| 19 |
-
torch.manual_seed(seed)
|
| 20 |
-
torch.cuda.manual_seed_all(seed)
|
| 21 |
-
|
| 22 |
-
seed_all(1986)
|
| 23 |
-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# -----------------------------
|
| 27 |
-
# Load paired DatasetDict
|
| 28 |
-
# -----------------------------
|
| 29 |
-
def load_split_paired(path: str):
|
| 30 |
-
dd = load_from_disk(path)
|
| 31 |
-
if not isinstance(dd, DatasetDict):
|
| 32 |
-
raise ValueError(f"Expected DatasetDict at {path}")
|
| 33 |
-
if "train" not in dd or "val" not in dd:
|
| 34 |
-
raise ValueError(f"DatasetDict missing train/val at {path}")
|
| 35 |
-
return dd["train"], dd["val"]
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# -----------------------------
|
| 39 |
-
# Collate fns (same as yours)
|
| 40 |
-
# -----------------------------
|
| 41 |
-
def collate_pair_pooled(batch):
|
| 42 |
-
Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
|
| 43 |
-
Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
|
| 44 |
-
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
|
| 45 |
-
return Pt, Pb, y
|
| 46 |
-
|
| 47 |
-
def collate_pair_unpooled(batch):
|
| 48 |
-
B = len(batch)
|
| 49 |
-
Ht = len(batch[0]["target_embedding"][0])
|
| 50 |
-
Hb = len(batch[0]["binder_embedding"][0])
|
| 51 |
-
Lt_max = max(int(x["target_length"]) for x in batch)
|
| 52 |
-
Lb_max = max(int(x["binder_length"]) for x in batch)
|
| 53 |
-
|
| 54 |
-
Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
|
| 55 |
-
Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
|
| 56 |
-
Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
|
| 57 |
-
Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
|
| 58 |
-
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
|
| 59 |
-
|
| 60 |
-
for i, x in enumerate(batch):
|
| 61 |
-
t = torch.tensor(x["target_embedding"], dtype=torch.float32)
|
| 62 |
-
b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
|
| 63 |
-
lt, lb = t.shape[0], b.shape[0]
|
| 64 |
-
Pt[i, :lt] = t
|
| 65 |
-
Pb[i, :lb] = b
|
| 66 |
-
Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
|
| 67 |
-
Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
|
| 68 |
-
|
| 69 |
-
return Pt, Mt, Pb, Mb, y
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
# -----------------------------
|
| 73 |
-
# Models (same as yours)
|
| 74 |
-
# -----------------------------
|
| 75 |
-
class CrossAttnPooled(nn.Module):
|
| 76 |
-
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
|
| 77 |
-
super().__init__()
|
| 78 |
-
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
|
| 79 |
-
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
|
| 80 |
-
|
| 81 |
-
self.layers = nn.ModuleList([])
|
| 82 |
-
for _ in range(n_layers):
|
| 83 |
-
self.layers.append(nn.ModuleDict({
|
| 84 |
-
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
|
| 85 |
-
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
|
| 86 |
-
"n1t": nn.LayerNorm(hidden),
|
| 87 |
-
"n2t": nn.LayerNorm(hidden),
|
| 88 |
-
"n1b": nn.LayerNorm(hidden),
|
| 89 |
-
"n2b": nn.LayerNorm(hidden),
|
| 90 |
-
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 91 |
-
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 92 |
-
}))
|
| 93 |
-
|
| 94 |
-
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
|
| 95 |
-
self.reg = nn.Linear(hidden, 1)
|
| 96 |
-
self.cls = nn.Linear(hidden, 3)
|
| 97 |
-
|
| 98 |
-
def forward(self, t_vec, b_vec):
|
| 99 |
-
t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
|
| 100 |
-
b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
|
| 101 |
-
|
| 102 |
-
for L in self.layers:
|
| 103 |
-
t_attn, _ = L["attn_tb"](t, b, b)
|
| 104 |
-
t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
|
| 105 |
-
t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
|
| 106 |
-
|
| 107 |
-
b_attn, _ = L["attn_bt"](b, t, t)
|
| 108 |
-
b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
|
| 109 |
-
b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
|
| 110 |
-
|
| 111 |
-
z = torch.cat([t[0], b[0]], dim=-1)
|
| 112 |
-
h = self.shared(z)
|
| 113 |
-
return self.reg(h).squeeze(-1), self.cls(h)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class CrossAttnUnpooled(nn.Module):
|
| 117 |
-
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
|
| 118 |
-
super().__init__()
|
| 119 |
-
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
|
| 120 |
-
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
|
| 121 |
-
|
| 122 |
-
self.layers = nn.ModuleList([])
|
| 123 |
-
for _ in range(n_layers):
|
| 124 |
-
self.layers.append(nn.ModuleDict({
|
| 125 |
-
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
|
| 126 |
-
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
|
| 127 |
-
"n1t": nn.LayerNorm(hidden),
|
| 128 |
-
"n2t": nn.LayerNorm(hidden),
|
| 129 |
-
"n1b": nn.LayerNorm(hidden),
|
| 130 |
-
"n2b": nn.LayerNorm(hidden),
|
| 131 |
-
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 132 |
-
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
|
| 133 |
-
}))
|
| 134 |
-
|
| 135 |
-
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
|
| 136 |
-
self.reg = nn.Linear(hidden, 1)
|
| 137 |
-
self.cls = nn.Linear(hidden, 3)
|
| 138 |
-
|
| 139 |
-
def masked_mean(self, X, M):
|
| 140 |
-
Mf = M.unsqueeze(-1).float()
|
| 141 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 142 |
-
return (X * Mf).sum(dim=1) / denom
|
| 143 |
-
|
| 144 |
-
def forward(self, T, Mt, B, Mb):
|
| 145 |
-
T = self.t_proj(T)
|
| 146 |
-
Bx = self.b_proj(B)
|
| 147 |
-
|
| 148 |
-
kp_t = ~Mt
|
| 149 |
-
kp_b = ~Mb
|
| 150 |
-
|
| 151 |
-
for L in self.layers:
|
| 152 |
-
T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
|
| 153 |
-
T = L["n1t"](T + T_attn)
|
| 154 |
-
T = L["n2t"](T + L["fft"](T))
|
| 155 |
-
|
| 156 |
-
B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
|
| 157 |
-
Bx = L["n1b"](Bx + B_attn)
|
| 158 |
-
Bx = L["n2b"](Bx + L["ffb"](Bx))
|
| 159 |
-
|
| 160 |
-
t_pool = self.masked_mean(T, Mt)
|
| 161 |
-
b_pool = self.masked_mean(Bx, Mb)
|
| 162 |
-
z = torch.cat([t_pool, b_pool], dim=-1)
|
| 163 |
-
h = self.shared(z)
|
| 164 |
-
return self.reg(h).squeeze(-1), self.cls(h)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
# -----------------------------
|
| 168 |
-
# Helpers
|
| 169 |
-
# -----------------------------
|
| 170 |
-
def softmax_np(logits: np.ndarray) -> np.ndarray:
|
| 171 |
-
x = logits - logits.max(axis=1, keepdims=True)
|
| 172 |
-
ex = np.exp(x)
|
| 173 |
-
return ex / ex.sum(axis=1, keepdims=True)
|
| 174 |
-
|
| 175 |
-
def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray:
|
| 176 |
-
centers = np.asarray(class_centers, dtype=np.float32)[None, :] # (1,3)
|
| 177 |
-
return (probs * centers).sum(axis=1)
|
| 178 |
-
|
| 179 |
-
def load_checkpoint(ckpt_path: str, mode: str, train_ds):
|
| 180 |
-
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 181 |
-
params = ckpt.get("best_params", {})
|
| 182 |
-
|
| 183 |
-
hidden = int(params.get("hidden_dim", 512))
|
| 184 |
-
n_heads = int(params.get("n_heads", 8))
|
| 185 |
-
n_layers = int(params.get("n_layers", 3))
|
| 186 |
-
dropout = float(params.get("dropout", 0.1))
|
| 187 |
-
|
| 188 |
-
if mode == "pooled":
|
| 189 |
-
Ht = len(train_ds[0]["target_embedding"])
|
| 190 |
-
Hb = len(train_ds[0]["binder_embedding"])
|
| 191 |
-
model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
|
| 192 |
-
else:
|
| 193 |
-
Ht = len(train_ds[0]["target_embedding"][0])
|
| 194 |
-
Hb = len(train_ds[0]["binder_embedding"][0])
|
| 195 |
-
model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
|
| 196 |
-
|
| 197 |
-
model.load_state_dict(ckpt["state_dict"], strict=True)
|
| 198 |
-
model.to(DEVICE).eval()
|
| 199 |
-
return model
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
@torch.no_grad()
|
| 203 |
-
def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str,
|
| 204 |
-
out_csv: str, batch_size: int, num_workers: int,
|
| 205 |
-
class_centers=(9.5, 8.0, 6.0)):
|
| 206 |
-
train_ds, val_ds = load_split_paired(dataset_path)
|
| 207 |
-
model = load_checkpoint(ckpt_path, mode, train_ds)
|
| 208 |
-
|
| 209 |
-
if mode == "pooled":
|
| 210 |
-
loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
|
| 211 |
-
num_workers=num_workers, pin_memory=True,
|
| 212 |
-
collate_fn=collate_pair_pooled)
|
| 213 |
-
y_all, pred_reg_all, logits_all = [], [], []
|
| 214 |
-
for t, b, y in loader:
|
| 215 |
-
t = t.to(DEVICE, non_blocking=True)
|
| 216 |
-
b = b.to(DEVICE, non_blocking=True)
|
| 217 |
-
pred_reg, logits = model(t, b)
|
| 218 |
-
y_all.append(y.numpy())
|
| 219 |
-
pred_reg_all.append(pred_reg.detach().cpu().numpy())
|
| 220 |
-
logits_all.append(logits.detach().cpu().numpy())
|
| 221 |
-
|
| 222 |
-
else:
|
| 223 |
-
loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
|
| 224 |
-
num_workers=num_workers, pin_memory=True,
|
| 225 |
-
collate_fn=collate_pair_unpooled)
|
| 226 |
-
y_all, pred_reg_all, logits_all = [], [], []
|
| 227 |
-
for T, Mt, B, Mb, y in loader:
|
| 228 |
-
T = T.to(DEVICE, non_blocking=True)
|
| 229 |
-
Mt = Mt.to(DEVICE, non_blocking=True)
|
| 230 |
-
B = B.to(DEVICE, non_blocking=True)
|
| 231 |
-
Mb = Mb.to(DEVICE, non_blocking=True)
|
| 232 |
-
pred_reg, logits = model(T, Mt, B, Mb)
|
| 233 |
-
y_all.append(y.numpy())
|
| 234 |
-
pred_reg_all.append(pred_reg.detach().cpu().numpy())
|
| 235 |
-
logits_all.append(logits.detach().cpu().numpy())
|
| 236 |
-
|
| 237 |
-
y_true = np.concatenate(y_all)
|
| 238 |
-
y_pred_reg = np.concatenate(pred_reg_all)
|
| 239 |
-
logits = np.concatenate(logits_all)
|
| 240 |
-
|
| 241 |
-
probs = softmax_np(logits) # (N,3)
|
| 242 |
-
y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers)
|
| 243 |
-
|
| 244 |
-
# Build CSV rows
|
| 245 |
-
out = Path(out_csv)
|
| 246 |
-
out.parent.mkdir(parents=True, exist_ok=True)
|
| 247 |
-
|
| 248 |
-
header = [
|
| 249 |
-
"split", "mode",
|
| 250 |
-
"y_true",
|
| 251 |
-
"y_pred_reg",
|
| 252 |
-
"p_high", "p_moderate", "p_low",
|
| 253 |
-
"y_pred_cls_score",
|
| 254 |
-
"center_high", "center_moderate", "center_low",
|
| 255 |
-
]
|
| 256 |
-
|
| 257 |
-
centers = list(class_centers)
|
| 258 |
-
rows = np.column_stack([
|
| 259 |
-
y_true,
|
| 260 |
-
y_pred_reg,
|
| 261 |
-
probs[:, 0], probs[:, 1], probs[:, 2],
|
| 262 |
-
y_pred_cls_score,
|
| 263 |
-
np.full_like(y_true, centers[0], dtype=np.float32),
|
| 264 |
-
np.full_like(y_true, centers[1], dtype=np.float32),
|
| 265 |
-
np.full_like(y_true, centers[2], dtype=np.float32),
|
| 266 |
-
])
|
| 267 |
-
|
| 268 |
-
with out.open("w") as f:
|
| 269 |
-
f.write(",".join(header) + "\n")
|
| 270 |
-
for i in range(rows.shape[0]):
|
| 271 |
-
f.write(
|
| 272 |
-
"val," + mode + "," +
|
| 273 |
-
",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) +
|
| 274 |
-
"\n"
|
| 275 |
-
)
|
| 276 |
-
|
| 277 |
-
print(f"[Data] Val N={len(y_true)} | mode={mode}")
|
| 278 |
-
print(f"[Saved] {out}")
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def main():
|
| 282 |
-
ap = argparse.ArgumentParser()
|
| 283 |
-
ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)")
|
| 284 |
-
ap.add_argument("--ckpt", required=True, help="Path to best_model.pt")
|
| 285 |
-
ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True)
|
| 286 |
-
ap.add_argument("--out_csv", required=True)
|
| 287 |
-
ap.add_argument("--batch_size", type=int, default=128)
|
| 288 |
-
ap.add_argument("--num_workers", type=int, default=4)
|
| 289 |
-
|
| 290 |
-
# Optional: choose class-centers for expected-score conversion
|
| 291 |
-
ap.add_argument("--center_high", type=float, default=9.5)
|
| 292 |
-
ap.add_argument("--center_moderate", type=float, default=8.0)
|
| 293 |
-
ap.add_argument("--center_low", type=float, default=6.0)
|
| 294 |
-
|
| 295 |
-
args = ap.parse_args()
|
| 296 |
-
|
| 297 |
-
export_val_preds_csv(
|
| 298 |
-
dataset_path=args.dataset_path,
|
| 299 |
-
ckpt_path=args.ckpt,
|
| 300 |
-
mode=args.mode,
|
| 301 |
-
out_csv=args.out_csv,
|
| 302 |
-
batch_size=args.batch_size,
|
| 303 |
-
num_workers=args.num_workers,
|
| 304 |
-
class_centers=(args.center_high, args.center_moderate, args.center_low),
|
| 305 |
-
)
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
if __name__ == "__main__":
|
| 309 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt
DELETED
|
@@ -1,234 +0,0 @@
|
|
| 1 |
-
./hemolysis/cnn_smiles/optimization_summary.txt
|
| 2 |
-
./hemolysis/cnn_smiles/pr_curve.png
|
| 3 |
-
./hemolysis/cnn_smiles/roc_curve.png
|
| 4 |
-
./hemolysis/cnn_smiles/study_trials.csv
|
| 5 |
-
./hemolysis/cnn_smiles/train_predictions.csv
|
| 6 |
-
./hemolysis/cnn_smiles/val_predictions.csv
|
| 7 |
-
./hemolysis/cnn_wt/optimization_summary.txt
|
| 8 |
-
./hemolysis/cnn_wt/pr_curve.png
|
| 9 |
-
./hemolysis/cnn_wt/roc_curve.png
|
| 10 |
-
./hemolysis/cnn_wt/study_trials.csv
|
| 11 |
-
./hemolysis/cnn_wt/train_predictions.csv
|
| 12 |
-
./hemolysis/cnn_wt/val_predictions.csv
|
| 13 |
-
./hemolysis/enet_gpu/optimization_summary.txt
|
| 14 |
-
./hemolysis/enet_gpu/pr_curve.png
|
| 15 |
-
./hemolysis/enet_gpu/roc_curve.png
|
| 16 |
-
./hemolysis/enet_gpu/study_trials.csv
|
| 17 |
-
./hemolysis/enet_gpu/train_predictions.csv
|
| 18 |
-
./hemolysis/enet_gpu/val_predictions.csv
|
| 19 |
-
./hemolysis/enet_gpu_smiles/optimization_summary.txt
|
| 20 |
-
./hemolysis/enet_gpu_smiles/pr_curve.png
|
| 21 |
-
./hemolysis/enet_gpu_smiles/roc_curve.png
|
| 22 |
-
./hemolysis/enet_gpu_smiles/study_trials.csv
|
| 23 |
-
./hemolysis/enet_gpu_smiles/train_predictions.csv
|
| 24 |
-
./hemolysis/enet_gpu_smiles/val_predictions.csv
|
| 25 |
-
./hemolysis/enet_gpu_wt/optimization_summary.txt
|
| 26 |
-
./hemolysis/enet_gpu_wt/pr_curve.png
|
| 27 |
-
./hemolysis/enet_gpu_wt/roc_curve.png
|
| 28 |
-
./hemolysis/enet_gpu_wt/study_trials.csv
|
| 29 |
-
./hemolysis/enet_gpu_wt/train_predictions.csv
|
| 30 |
-
./hemolysis/enet_gpu_wt/val_predictions.csv
|
| 31 |
-
./hemolysis/mlp_smiles/optimization_summary.txt
|
| 32 |
-
./hemolysis/mlp_smiles/pr_curve.png
|
| 33 |
-
./hemolysis/mlp_smiles/roc_curve.png
|
| 34 |
-
./hemolysis/mlp_smiles/study_trials.csv
|
| 35 |
-
./hemolysis/mlp_smiles/train_predictions.csv
|
| 36 |
-
./hemolysis/mlp_smiles/val_predictions.csv
|
| 37 |
-
./hemolysis/mlp_wt/optimization_summary.txt
|
| 38 |
-
./hemolysis/mlp_wt/pr_curve.png
|
| 39 |
-
./hemolysis/mlp_wt/roc_curve.png
|
| 40 |
-
./hemolysis/mlp_wt/study_trials.csv
|
| 41 |
-
./hemolysis/mlp_wt/train_predictions.csv
|
| 42 |
-
./hemolysis/mlp_wt/val_predictions.csv
|
| 43 |
-
./hemolysis/svm_gpu_wt/optimization_summary.txt
|
| 44 |
-
./hemolysis/svm_gpu_wt/pr_curve.png
|
| 45 |
-
./hemolysis/svm_gpu_wt/roc_curve.png
|
| 46 |
-
./hemolysis/svm_gpu_wt/study_trials.csv
|
| 47 |
-
./hemolysis/svm_gpu_wt/train_predictions.csv
|
| 48 |
-
./hemolysis/svm_gpu_wt/val_predictions.csv
|
| 49 |
-
./hemolysis/transformer_smiles/optimization_summary.txt
|
| 50 |
-
./hemolysis/transformer_smiles/pr_curve.png
|
| 51 |
-
./hemolysis/transformer_smiles/roc_curve.png
|
| 52 |
-
./hemolysis/transformer_smiles/study_trials.csv
|
| 53 |
-
./hemolysis/transformer_smiles/train_predictions.csv
|
| 54 |
-
./hemolysis/transformer_smiles/val_predictions.csv
|
| 55 |
-
./hemolysis/transformer_wt/optimization_summary.txt
|
| 56 |
-
./hemolysis/transformer_wt/pr_curve.png
|
| 57 |
-
./hemolysis/transformer_wt/roc_curve.png
|
| 58 |
-
./hemolysis/transformer_wt/study_trials.csv
|
| 59 |
-
./hemolysis/transformer_wt/train_predictions.csv
|
| 60 |
-
./hemolysis/transformer_wt/val_predictions.csv
|
| 61 |
-
./hemolysis/xgb/optimization_summary.txt
|
| 62 |
-
./hemolysis/xgb/pr_curve.png
|
| 63 |
-
./hemolysis/xgb/roc_curve.png
|
| 64 |
-
./hemolysis/xgb/study_trials.csv
|
| 65 |
-
./hemolysis/xgb/train_predictions.csv
|
| 66 |
-
./hemolysis/xgb/val_predictions.csv
|
| 67 |
-
./hemolysis/xgb_smiles/optimization_summary.txt
|
| 68 |
-
./hemolysis/xgb_smiles/pr_curve.png
|
| 69 |
-
./hemolysis/xgb_smiles/roc_curve.png
|
| 70 |
-
./hemolysis/xgb_smiles/study_trials.csv
|
| 71 |
-
./hemolysis/xgb_smiles/train_predictions.csv
|
| 72 |
-
./hemolysis/xgb_smiles/val_predictions.csv
|
| 73 |
-
./hemolysis/xgb_wt/optimization_summary.txt
|
| 74 |
-
./hemolysis/xgb_wt/pr_curve.png
|
| 75 |
-
./hemolysis/xgb_wt/roc_curve.png
|
| 76 |
-
./hemolysis/xgb_wt/study_trials.csv
|
| 77 |
-
./hemolysis/xgb_wt/train_predictions.csv
|
| 78 |
-
./hemolysis/xgb_wt/val_predictions.csv
|
| 79 |
-
./nf/cnn/optimization_summary.txt
|
| 80 |
-
./nf/cnn/pr_curve.png
|
| 81 |
-
./nf/cnn/roc_curve.png
|
| 82 |
-
./nf/cnn/study_trials.csv
|
| 83 |
-
./nf/cnn/train_predictions.csv
|
| 84 |
-
./nf/cnn/val_predictions.csv
|
| 85 |
-
./nf/cnn_wt/optimization_summary.txt
|
| 86 |
-
./nf/cnn_wt/pr_curve.png
|
| 87 |
-
./nf/cnn_wt/roc_curve.png
|
| 88 |
-
./nf/cnn_wt/study_trials.csv
|
| 89 |
-
./nf/cnn_wt/train_predictions.csv
|
| 90 |
-
./nf/cnn_wt/val_predictions.csv
|
| 91 |
-
./nf/enet_gpu/optimization_summary.txt
|
| 92 |
-
./nf/enet_gpu/pr_curve.png
|
| 93 |
-
./nf/enet_gpu/roc_curve.png
|
| 94 |
-
./nf/enet_gpu/study_trials.csv
|
| 95 |
-
./nf/enet_gpu/train_predictions.csv
|
| 96 |
-
./nf/enet_gpu/val_predictions.csv
|
| 97 |
-
./nf/enet_gpu_smiles/optimization_summary.txt
|
| 98 |
-
./nf/enet_gpu_smiles/pr_curve.png
|
| 99 |
-
./nf/enet_gpu_smiles/roc_curve.png
|
| 100 |
-
./nf/enet_gpu_smiles/study_trials.csv
|
| 101 |
-
./nf/enet_gpu_smiles/train_predictions.csv
|
| 102 |
-
./nf/enet_gpu_smiles/val_predictions.csv
|
| 103 |
-
./nf/enet_gpu_wt/optimization_summary.txt
|
| 104 |
-
./nf/enet_gpu_wt/pr_curve.png
|
| 105 |
-
./nf/enet_gpu_wt/roc_curve.png
|
| 106 |
-
./nf/enet_gpu_wt/study_trials.csv
|
| 107 |
-
./nf/enet_gpu_wt/train_predictions.csv
|
| 108 |
-
./nf/enet_gpu_wt/val_predictions.csv
|
| 109 |
-
./nf/mlp/optimization_summary.txt
|
| 110 |
-
./nf/mlp/pr_curve.png
|
| 111 |
-
./nf/mlp/roc_curve.png
|
| 112 |
-
./nf/mlp/study_trials.csv
|
| 113 |
-
./nf/mlp/train_predictions.csv
|
| 114 |
-
./nf/mlp/val_predictions.csv
|
| 115 |
-
./nf/mlp_wt/optimization_summary.txt
|
| 116 |
-
./nf/mlp_wt/pr_curve.png
|
| 117 |
-
./nf/mlp_wt/roc_curve.png
|
| 118 |
-
./nf/mlp_wt/study_trials.csv
|
| 119 |
-
./nf/mlp_wt/train_predictions.csv
|
| 120 |
-
./nf/mlp_wt/val_predictions.csv
|
| 121 |
-
./nf/svm_gpu/optimization_summary.txt
|
| 122 |
-
./nf/svm_gpu/pr_curve.png
|
| 123 |
-
./nf/svm_gpu/roc_curve.png
|
| 124 |
-
./nf/svm_gpu/study_trials.csv
|
| 125 |
-
./nf/svm_gpu/train_predictions.csv
|
| 126 |
-
./nf/svm_gpu/val_predictions.csv
|
| 127 |
-
./nf/svm_gpu_wt/optimization_summary.txt
|
| 128 |
-
./nf/svm_gpu_wt/pr_curve.png
|
| 129 |
-
./nf/svm_gpu_wt/roc_curve.png
|
| 130 |
-
./nf/svm_gpu_wt/study_trials.csv
|
| 131 |
-
./nf/svm_gpu_wt/train_predictions.csv
|
| 132 |
-
./nf/svm_gpu_wt/val_predictions.csv
|
| 133 |
-
./nf/transformer/optimization_summary.txt
|
| 134 |
-
./nf/transformer/pr_curve.png
|
| 135 |
-
./nf/transformer/roc_curve.png
|
| 136 |
-
./nf/transformer/study_trials.csv
|
| 137 |
-
./nf/transformer/train_predictions.csv
|
| 138 |
-
./nf/transformer/val_predictions.csv
|
| 139 |
-
./nf/transformer_wt/optimization_summary.txt
|
| 140 |
-
./nf/transformer_wt/pr_curve.png
|
| 141 |
-
./nf/transformer_wt/roc_curve.png
|
| 142 |
-
./nf/transformer_wt/study_trials.csv
|
| 143 |
-
./nf/transformer_wt/train_predictions.csv
|
| 144 |
-
./nf/transformer_wt/val_predictions.csv
|
| 145 |
-
./nf/xgb_wt/optimization_summary.txt
|
| 146 |
-
./nf/xgb_wt/pr_curve.png
|
| 147 |
-
./nf/xgb_wt/roc_curve.png
|
| 148 |
-
./nf/xgb_wt/study_trials.csv
|
| 149 |
-
./nf/xgb_wt/train_predictions.csv
|
| 150 |
-
./nf/xgb_wt/val_predictions.csv
|
| 151 |
-
./permeability_caco2/cnn_smiles/optimization_summary.txt
|
| 152 |
-
./permeability_caco2/cnn_smiles/study_trials.csv
|
| 153 |
-
./permeability_caco2/cnn_smiles/train_predictions.csv
|
| 154 |
-
./permeability_caco2/cnn_smiles/val_predictions.csv
|
| 155 |
-
./permeability_caco2/enet_gpu_smiles/optimization_summary.txt
|
| 156 |
-
./permeability_caco2/enet_gpu_smiles/study_trials.csv
|
| 157 |
-
./permeability_caco2/enet_gpu_smiles/train_predictions.csv
|
| 158 |
-
./permeability_caco2/enet_gpu_smiles/val_predictions.csv
|
| 159 |
-
./permeability_caco2/mlp_smiles/optimization_summary.txt
|
| 160 |
-
./permeability_caco2/mlp_smiles/study_trials.csv
|
| 161 |
-
./permeability_caco2/mlp_smiles/train_predictions.csv
|
| 162 |
-
./permeability_caco2/mlp_smiles/val_predictions.csv
|
| 163 |
-
./permeability_caco2/svr_smiles/optimization_summary.txt
|
| 164 |
-
./permeability_caco2/svr_smiles/study_trials.csv
|
| 165 |
-
./permeability_caco2/svr_smiles/train_predictions.csv
|
| 166 |
-
./permeability_caco2/svr_smiles/val_predictions.csv
|
| 167 |
-
./permeability_caco2/transformer_smiles/optimization_summary.txt
|
| 168 |
-
./permeability_caco2/transformer_smiles/study_trials.csv
|
| 169 |
-
./permeability_caco2/transformer_smiles/train_predictions.csv
|
| 170 |
-
./permeability_caco2/transformer_smiles/val_predictions.csv
|
| 171 |
-
./permeability_caco2/xgb_reg_smiles/optimization_summary.txt
|
| 172 |
-
./permeability_caco2/xgb_reg_smiles/study_trials.csv
|
| 173 |
-
./permeability_caco2/xgb_reg_smiles/train_predictions.csv
|
| 174 |
-
./permeability_caco2/xgb_reg_smiles/val_predictions.csv
|
| 175 |
-
./permeability_pampa/cnn_smiles/optimization_summary.txt
|
| 176 |
-
./permeability_pampa/cnn_smiles/study_trials.csv
|
| 177 |
-
./permeability_pampa/cnn_smiles/train_predictions.csv
|
| 178 |
-
./permeability_pampa/cnn_smiles/val_predictions.csv
|
| 179 |
-
./permeability_pampa/enet_gpu_smiles/optimization_summary.txt
|
| 180 |
-
./permeability_pampa/enet_gpu_smiles/study_trials.csv
|
| 181 |
-
./permeability_pampa/enet_gpu_smiles/train_predictions.csv
|
| 182 |
-
./permeability_pampa/enet_gpu_smiles/val_predictions.csv
|
| 183 |
-
./permeability_pampa/mlp_smiles/optimization_summary.txt
|
| 184 |
-
./permeability_pampa/mlp_smiles/study_trials.csv
|
| 185 |
-
./permeability_pampa/mlp_smiles/train_predictions.csv
|
| 186 |
-
./permeability_pampa/mlp_smiles/val_predictions.csv
|
| 187 |
-
./permeability_pampa/transformer_smiles/optimization_summary.txt
|
| 188 |
-
./permeability_pampa/transformer_smiles/study_trials.csv
|
| 189 |
-
./permeability_pampa/transformer_smiles/train_predictions.csv
|
| 190 |
-
./permeability_pampa/transformer_smiles/val_predictions.csv
|
| 191 |
-
./permeability_pampa/xgb_reg_smiles/optimization_summary.txt
|
| 192 |
-
./permeability_pampa/xgb_reg_smiles/study_trials.csv
|
| 193 |
-
./permeability_pampa/xgb_reg_smiles/train_predictions.csv
|
| 194 |
-
./permeability_pampa/xgb_reg_smiles/val_predictions.csv
|
| 195 |
-
./solubility/cnn_wt/optimization_summary.txt
|
| 196 |
-
./solubility/cnn_wt/pr_curve.png
|
| 197 |
-
./solubility/cnn_wt/roc_curve.png
|
| 198 |
-
./solubility/cnn_wt/study_trials.csv
|
| 199 |
-
./solubility/cnn_wt/train_predictions.csv
|
| 200 |
-
./solubility/cnn_wt/val_predictions.csv
|
| 201 |
-
./solubility/enet_gpu/optimization_summary.txt
|
| 202 |
-
./solubility/enet_gpu/pr_curve.png
|
| 203 |
-
./solubility/enet_gpu/roc_curve.png
|
| 204 |
-
./solubility/enet_gpu/study_trials.csv
|
| 205 |
-
./solubility/enet_gpu/train_predictions.csv
|
| 206 |
-
./solubility/enet_gpu/val_predictions.csv
|
| 207 |
-
./solubility/mlp_wt/optimization_summary.txt
|
| 208 |
-
./solubility/mlp_wt/pr_curve.png
|
| 209 |
-
./solubility/mlp_wt/roc_curve.png
|
| 210 |
-
./solubility/mlp_wt/study_trials.csv
|
| 211 |
-
./solubility/mlp_wt/train_predictions.csv
|
| 212 |
-
./solubility/mlp_wt/val_predictions.csv
|
| 213 |
-
./solubility/svm_gpu/optimization_summary.txt
|
| 214 |
-
./solubility/svm_gpu/pr_curve.png
|
| 215 |
-
./solubility/svm_gpu/roc_curve.png
|
| 216 |
-
./solubility/svm_gpu/study_trials.csv
|
| 217 |
-
./solubility/svm_gpu/train_predictions.csv
|
| 218 |
-
./solubility/svm_gpu/val_predictions.csv
|
| 219 |
-
./solubility/transformer_wt/optimization_summary.txt
|
| 220 |
-
./solubility/transformer_wt/pr_curve.png
|
| 221 |
-
./solubility/transformer_wt/roc_curve.png
|
| 222 |
-
./solubility/transformer_wt/study_trials.csv
|
| 223 |
-
./solubility/transformer_wt/train_predictions.csv
|
| 224 |
-
./solubility/transformer_wt/val_predictions.csv
|
| 225 |
-
./solubility/xgb/optimization_summary.txt
|
| 226 |
-
./solubility/xgb/pr_curve.png
|
| 227 |
-
./solubility/xgb/roc_curve.png
|
| 228 |
-
./solubility/xgb/study_trials.csv
|
| 229 |
-
./solubility/xgb/train_predictions.csv
|
| 230 |
-
./solubility/xgb/val_predictions.csv
|
| 231 |
-
./binding_affinity/wt_wt_pooled/optuna_trials.csv
|
| 232 |
-
./binding_affinity/wt_smiles_pooled/optuna_trials.csv
|
| 233 |
-
./binding_affinity/wt_smiles_unpooled/optuna_trials.csv
|
| 234 |
-
./binding_affinity/wt_wt_unpooled/optuna_trials.csv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py
DELETED
|
@@ -1,417 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import json
|
| 3 |
-
import joblib
|
| 4 |
-
import optuna
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
|
| 9 |
-
from dataclasses import dataclass
|
| 10 |
-
from typing import Dict, Any, Tuple, Optional
|
| 11 |
-
|
| 12 |
-
from datasets import load_from_disk, DatasetDict
|
| 13 |
-
from sklearn.metrics import (
|
| 14 |
-
f1_score, roc_auc_score, average_precision_score,
|
| 15 |
-
precision_recall_curve, roc_curve
|
| 16 |
-
)
|
| 17 |
-
from sklearn.linear_model import LogisticRegression
|
| 18 |
-
from sklearn.ensemble import AdaBoostClassifier
|
| 19 |
-
from sklearn.tree import DecisionTreeClassifier
|
| 20 |
-
from linearboost import LinearBoostClassifier
|
| 21 |
-
|
| 22 |
-
import xgboost as xgb
|
| 23 |
-
from lightning.pytorch import seed_everything
|
| 24 |
-
|
| 25 |
-
seed_everything(1986)
|
| 26 |
-
|
| 27 |
-
# -----------------------------
|
| 28 |
-
# Data loading
|
| 29 |
-
# -----------------------------
|
| 30 |
-
@dataclass
|
| 31 |
-
class SplitData:
|
| 32 |
-
X_train: np.ndarray
|
| 33 |
-
y_train: np.ndarray
|
| 34 |
-
seq_train: Optional[np.ndarray]
|
| 35 |
-
X_val: np.ndarray
|
| 36 |
-
y_val: np.ndarray
|
| 37 |
-
seq_val: Optional[np.ndarray]
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def _stack_embeddings(col) -> np.ndarray:
|
| 41 |
-
# HF datasets often store embeddings as list-of-floats per row
|
| 42 |
-
arr = np.asarray(col, dtype=np.float32)
|
| 43 |
-
if arr.ndim != 2:
|
| 44 |
-
arr = np.stack(col).astype(np.float32)
|
| 45 |
-
return arr
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def load_split_data(dataset_path: str) -> SplitData:
|
| 49 |
-
ds = load_from_disk(dataset_path)
|
| 50 |
-
|
| 51 |
-
# Case A: DatasetDict with train/val
|
| 52 |
-
if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
|
| 53 |
-
train_ds, val_ds = ds["train"], ds["val"]
|
| 54 |
-
else:
|
| 55 |
-
# Case B: Single dataset with "split" column
|
| 56 |
-
if "split" not in ds.column_names:
|
| 57 |
-
raise ValueError(
|
| 58 |
-
"Dataset must be a DatasetDict(train/val) or have a 'split' column."
|
| 59 |
-
)
|
| 60 |
-
train_ds = ds.filter(lambda x: x["split"] == "train")
|
| 61 |
-
val_ds = ds.filter(lambda x: x["split"] == "val")
|
| 62 |
-
|
| 63 |
-
for required in ["embedding", "label"]:
|
| 64 |
-
if required not in train_ds.column_names:
|
| 65 |
-
raise ValueError(f"Missing column '{required}' in train split.")
|
| 66 |
-
if required not in val_ds.column_names:
|
| 67 |
-
raise ValueError(f"Missing column '{required}' in val split.")
|
| 68 |
-
|
| 69 |
-
X_train = _stack_embeddings(train_ds["embedding"])
|
| 70 |
-
y_train = np.asarray(train_ds["label"], dtype=np.int64)
|
| 71 |
-
|
| 72 |
-
X_val = _stack_embeddings(val_ds["embedding"])
|
| 73 |
-
y_val = np.asarray(val_ds["label"], dtype=np.int64)
|
| 74 |
-
|
| 75 |
-
seq_train = None
|
| 76 |
-
seq_val = None
|
| 77 |
-
if "sequence" in train_ds.column_names:
|
| 78 |
-
seq_train = np.asarray(train_ds["sequence"])
|
| 79 |
-
if "sequence" in val_ds.column_names:
|
| 80 |
-
seq_val = np.asarray(val_ds["sequence"])
|
| 81 |
-
|
| 82 |
-
return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
# -----------------------------
|
| 86 |
-
# Metrics + thresholding
|
| 87 |
-
# -----------------------------
|
| 88 |
-
def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
|
| 89 |
-
"""
|
| 90 |
-
Find threshold maximizing F1 on the given set.
|
| 91 |
-
Returns (best_threshold, best_f1).
|
| 92 |
-
"""
|
| 93 |
-
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
|
| 94 |
-
# precision_recall_curve returns thresholds of length n-1
|
| 95 |
-
# compute F1 for those thresholds
|
| 96 |
-
f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
|
| 97 |
-
best_idx = int(np.nanargmax(f1s))
|
| 98 |
-
return float(thresholds[best_idx]), float(f1s[best_idx])
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
|
| 102 |
-
y_pred = (y_prob >= threshold).astype(int)
|
| 103 |
-
return {
|
| 104 |
-
"f1": float(f1_score(y_true, y_pred)),
|
| 105 |
-
"auc": float(roc_auc_score(y_true, y_prob)),
|
| 106 |
-
"ap": float(average_precision_score(y_true, y_prob)),
|
| 107 |
-
"threshold": float(threshold),
|
| 108 |
-
}
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# -----------------------------
|
| 112 |
-
# Model factories
|
| 113 |
-
# -----------------------------
|
| 114 |
-
def train_xgb(
|
| 115 |
-
X_train, y_train, X_val, y_val, params: Dict[str, Any]
|
| 116 |
-
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
|
| 117 |
-
dtrain = xgb.DMatrix(X_train, label=y_train)
|
| 118 |
-
dval = xgb.DMatrix(X_val, label=y_val)
|
| 119 |
-
|
| 120 |
-
num_boost_round = int(params.pop("num_boost_round"))
|
| 121 |
-
early_stopping_rounds = int(params.pop("early_stopping_rounds"))
|
| 122 |
-
|
| 123 |
-
booster = xgb.train(
|
| 124 |
-
params=params,
|
| 125 |
-
dtrain=dtrain,
|
| 126 |
-
num_boost_round=num_boost_round,
|
| 127 |
-
evals=[(dval, "val")],
|
| 128 |
-
early_stopping_rounds=early_stopping_rounds,
|
| 129 |
-
verbose_eval=False,
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
p_train = booster.predict(dtrain)
|
| 133 |
-
p_val = booster.predict(dval)
|
| 134 |
-
return booster, p_train, p_val
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
def train_adaboost(
|
| 138 |
-
X_train, y_train, X_val, y_val, params: Dict[str, Any]
|
| 139 |
-
) -> Tuple[AdaBoostClassifier, np.ndarray, np.ndarray]:
|
| 140 |
-
base_depth = int(params.pop("base_depth"))
|
| 141 |
-
clf = AdaBoostClassifier(
|
| 142 |
-
estimator=DecisionTreeClassifier(max_depth=base_depth),
|
| 143 |
-
n_estimators=int(params["n_estimators"]),
|
| 144 |
-
learning_rate=float(params["learning_rate"]),
|
| 145 |
-
algorithm="SAMME",
|
| 146 |
-
)
|
| 147 |
-
clf.fit(X_train, y_train)
|
| 148 |
-
p_train = clf.predict_proba(X_train)[:, 1]
|
| 149 |
-
p_val = clf.predict_proba(X_val)[:, 1]
|
| 150 |
-
return clf, p_train, p_val
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def train_linearboost(X_train, y_train, X_val, y_val, params):
|
| 154 |
-
clf = LinearBoostClassifier(**params)
|
| 155 |
-
clf.fit(X_train, y_train)
|
| 156 |
-
p_train = clf.predict_proba(X_train)[:, 1]
|
| 157 |
-
p_val = clf.predict_proba(X_val)[:, 1]
|
| 158 |
-
return clf, p_train, p_val
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
def suggest_linearboost_params(trial):
|
| 162 |
-
# Core boosting params
|
| 163 |
-
params = {
|
| 164 |
-
"n_estimators": trial.suggest_int("n_estimators", 50, 800),
|
| 165 |
-
"learning_rate": trial.suggest_float("learning_rate", 0.01, 1.0, log=True),
|
| 166 |
-
"algorithm": trial.suggest_categorical("algorithm", ["SAMME.R", "SAMME"]),
|
| 167 |
-
# Scaling choices from docs (you can expand this list if you want)
|
| 168 |
-
"scaler": trial.suggest_categorical(
|
| 169 |
-
"scaler",
|
| 170 |
-
["minmax", "standard", "robust", "quantile-uniform", "quantile-normal", "power"]
|
| 171 |
-
),
|
| 172 |
-
# useful for imbalanced splits
|
| 173 |
-
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 174 |
-
# kernel trick
|
| 175 |
-
"kernel": trial.suggest_categorical("kernel", ["linear", "rbf", "poly", "sigmoid"]),
|
| 176 |
-
}
|
| 177 |
-
|
| 178 |
-
# Kernel-specific params (only when relevant)
|
| 179 |
-
if params["kernel"] in ["rbf", "poly"]:
|
| 180 |
-
params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
|
| 181 |
-
else:
|
| 182 |
-
params["gamma"] = None # docs: default treated as 1/n_features for rbf/poly :contentReference[oaicite:5]{index=5}
|
| 183 |
-
|
| 184 |
-
if params["kernel"] == "poly":
|
| 185 |
-
params["degree"] = trial.suggest_int("degree", 2, 6) # docs default=3 :contentReference[oaicite:6]{index=6}
|
| 186 |
-
params["coef0"] = trial.suggest_float("coef0", 0.0, 5.0) # docs default=1 :contentReference[oaicite:7]{index=7}
|
| 187 |
-
else:
|
| 188 |
-
# safe defaults
|
| 189 |
-
params["degree"] = 3
|
| 190 |
-
params["coef0"] = 1.0
|
| 191 |
-
|
| 192 |
-
return params
|
| 193 |
-
# -----------------------------
|
| 194 |
-
# Saving artifacts
|
| 195 |
-
# -----------------------------
|
| 196 |
-
def save_predictions_csv(
|
| 197 |
-
out_dir: str,
|
| 198 |
-
split_name: str,
|
| 199 |
-
y_true: np.ndarray,
|
| 200 |
-
y_prob: np.ndarray,
|
| 201 |
-
threshold: float,
|
| 202 |
-
sequences: Optional[np.ndarray] = None,
|
| 203 |
-
):
|
| 204 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 205 |
-
df = pd.DataFrame({
|
| 206 |
-
"y_true": y_true.astype(int),
|
| 207 |
-
"y_prob": y_prob.astype(float),
|
| 208 |
-
"y_pred": (y_prob >= threshold).astype(int),
|
| 209 |
-
})
|
| 210 |
-
if sequences is not None:
|
| 211 |
-
df.insert(0, "sequence", sequences)
|
| 212 |
-
df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
|
| 216 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 217 |
-
|
| 218 |
-
# PR
|
| 219 |
-
precision, recall, _ = precision_recall_curve(y_true, y_prob)
|
| 220 |
-
plt.figure()
|
| 221 |
-
plt.plot(recall, precision)
|
| 222 |
-
plt.xlabel("Recall")
|
| 223 |
-
plt.ylabel("Precision")
|
| 224 |
-
plt.title("Precision-Recall Curve")
|
| 225 |
-
plt.tight_layout()
|
| 226 |
-
plt.savefig(os.path.join(out_dir, "pr_curve.png"))
|
| 227 |
-
plt.close()
|
| 228 |
-
|
| 229 |
-
# ROC
|
| 230 |
-
fpr, tpr, _ = roc_curve(y_true, y_prob)
|
| 231 |
-
plt.figure()
|
| 232 |
-
plt.plot(fpr, tpr)
|
| 233 |
-
plt.xlabel("False Positive Rate")
|
| 234 |
-
plt.ylabel("True Positive Rate")
|
| 235 |
-
plt.title("ROC Curve")
|
| 236 |
-
plt.tight_layout()
|
| 237 |
-
plt.savefig(os.path.join(out_dir, "roc_curve.png"))
|
| 238 |
-
plt.close()
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
# -----------------------------
|
| 242 |
-
# Optuna objectives
|
| 243 |
-
# -----------------------------
|
| 244 |
-
def make_objective(model_name: str, data: SplitData, out_dir: str):
|
| 245 |
-
Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
|
| 246 |
-
|
| 247 |
-
def objective(trial: optuna.Trial) -> float:
|
| 248 |
-
if model_name == "xgb":
|
| 249 |
-
params = {
|
| 250 |
-
"objective": "binary:logistic",
|
| 251 |
-
"eval_metric": "logloss",
|
| 252 |
-
"lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
|
| 253 |
-
"alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
|
| 254 |
-
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
|
| 255 |
-
"subsample": trial.suggest_float("subsample", 0.5, 1.0),
|
| 256 |
-
"learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
|
| 257 |
-
"max_depth": trial.suggest_int("max_depth", 2, 15),
|
| 258 |
-
"min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
|
| 259 |
-
"gamma": trial.suggest_float("gamma", 0.0, 10.0),
|
| 260 |
-
"tree_method": "hist",
|
| 261 |
-
"device": "cuda",
|
| 262 |
-
}
|
| 263 |
-
|
| 264 |
-
# Optional GPU: set env CUDA_VISIBLE_DEVICES externally if you want.
|
| 265 |
-
# If you *know* you want GPU and your xgboost supports it:
|
| 266 |
-
# params["device"] = "cuda"
|
| 267 |
-
|
| 268 |
-
params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
|
| 269 |
-
params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
|
| 270 |
-
|
| 271 |
-
model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
|
| 272 |
-
|
| 273 |
-
elif model_name == "adaboost":
|
| 274 |
-
params = {
|
| 275 |
-
"n_estimators": trial.suggest_int("n_estimators", 50, 800),
|
| 276 |
-
"learning_rate": trial.suggest_float("learning_rate", 1e-3, 2.0, log=True),
|
| 277 |
-
"base_depth": trial.suggest_int("base_depth", 1, 4),
|
| 278 |
-
}
|
| 279 |
-
model, p_tr, p_va = train_adaboost(Xtr, ytr, Xva, yva, params)
|
| 280 |
-
|
| 281 |
-
elif model_name == "linearboost":
|
| 282 |
-
params = suggest_linearboost_params(trial)
|
| 283 |
-
model, p_tr, p_va = train_linearboost(Xtr, ytr, Xva, yva, params)
|
| 284 |
-
else:
|
| 285 |
-
raise ValueError(f"Unknown model_name={model_name}")
|
| 286 |
-
|
| 287 |
-
# Threshold picked on val for fair comparison across models
|
| 288 |
-
thr, f1_at_thr = best_f1_threshold(yva, p_va)
|
| 289 |
-
metrics = eval_binary(yva, p_va, thr)
|
| 290 |
-
|
| 291 |
-
# Track best trial artifacts inside the study directory
|
| 292 |
-
trial.set_user_attr("threshold", thr)
|
| 293 |
-
trial.set_user_attr("auc", metrics["auc"])
|
| 294 |
-
trial.set_user_attr("ap", metrics["ap"])
|
| 295 |
-
|
| 296 |
-
return f1_at_thr
|
| 297 |
-
|
| 298 |
-
return objective
|
| 299 |
-
|
| 300 |
-
# -----------------------------
|
| 301 |
-
# Main runner
|
| 302 |
-
# -----------------------------
|
| 303 |
-
def run_optuna_and_refit(
|
| 304 |
-
dataset_path: str,
|
| 305 |
-
out_dir: str,
|
| 306 |
-
model_name: str,
|
| 307 |
-
n_trials: int = 200,
|
| 308 |
-
):
|
| 309 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 310 |
-
|
| 311 |
-
data = load_split_data(dataset_path)
|
| 312 |
-
print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
|
| 313 |
-
|
| 314 |
-
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
|
| 315 |
-
study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
|
| 316 |
-
|
| 317 |
-
# Save trials table
|
| 318 |
-
trials_df = study.trials_dataframe()
|
| 319 |
-
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
|
| 320 |
-
|
| 321 |
-
best = study.best_trial
|
| 322 |
-
best_params = dict(best.params)
|
| 323 |
-
best_thr = float(best.user_attrs["threshold"])
|
| 324 |
-
best_auc = float(best.user_attrs["auc"])
|
| 325 |
-
best_ap = float(best.user_attrs["ap"])
|
| 326 |
-
best_f1 = float(best.value)
|
| 327 |
-
|
| 328 |
-
# Refit best model on train (same protocol as objective)
|
| 329 |
-
if model_name == "xgb":
|
| 330 |
-
# Reconstruct full param dict
|
| 331 |
-
params = {
|
| 332 |
-
"objective": "binary:logistic",
|
| 333 |
-
"eval_metric": "logloss",
|
| 334 |
-
"lambda": best_params["lambda"],
|
| 335 |
-
"alpha": best_params["alpha"],
|
| 336 |
-
"colsample_bytree": best_params["colsample_bytree"],
|
| 337 |
-
"subsample": best_params["subsample"],
|
| 338 |
-
"learning_rate": best_params["learning_rate"],
|
| 339 |
-
"max_depth": best_params["max_depth"],
|
| 340 |
-
"min_child_weight": best_params["min_child_weight"],
|
| 341 |
-
"gamma": best_params["gamma"],
|
| 342 |
-
"tree_method": "hist",
|
| 343 |
-
"num_boost_round": best_params["num_boost_round"],
|
| 344 |
-
"early_stopping_rounds": best_params["early_stopping_rounds"],
|
| 345 |
-
}
|
| 346 |
-
model, p_tr, p_va = train_xgb(
|
| 347 |
-
data.X_train, data.y_train, data.X_val, data.y_val, params
|
| 348 |
-
)
|
| 349 |
-
model_path = os.path.join(out_dir, "best_model.json")
|
| 350 |
-
model.save_model(model_path)
|
| 351 |
-
|
| 352 |
-
elif model_name == "adaboost":
|
| 353 |
-
params = best_params
|
| 354 |
-
model, p_tr, p_va = train_adaboost(
|
| 355 |
-
data.X_train, data.y_train, data.X_val, data.y_val, params
|
| 356 |
-
)
|
| 357 |
-
model_path = os.path.join(out_dir, "best_model.joblib")
|
| 358 |
-
joblib.dump(model, model_path)
|
| 359 |
-
|
| 360 |
-
elif model_name == "linearboost":
|
| 361 |
-
params = best_params
|
| 362 |
-
|
| 363 |
-
model, p_tr, p_va = train_linearboost(
|
| 364 |
-
data.X_train, data.y_train, data.X_val, data.y_val, params
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
model_path = os.path.join(out_dir, "best_model.joblib")
|
| 368 |
-
joblib.dump(model, model_path)
|
| 369 |
-
else:
|
| 370 |
-
raise ValueError(model_name)
|
| 371 |
-
|
| 372 |
-
# Save predictions CSVs
|
| 373 |
-
save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
|
| 374 |
-
save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
|
| 375 |
-
|
| 376 |
-
# Plots on val
|
| 377 |
-
plot_curves(out_dir, data.y_val, p_va)
|
| 378 |
-
|
| 379 |
-
# Summary
|
| 380 |
-
summary = [
|
| 381 |
-
"=" * 72,
|
| 382 |
-
f"MODEL: {model_name}",
|
| 383 |
-
f"Best trial: {best.number}",
|
| 384 |
-
f"Best F1 (val @ best-threshold): {best_f1:.4f}",
|
| 385 |
-
f"Val AUC: {best_auc:.4f}",
|
| 386 |
-
f"Val AP: {best_ap:.4f}",
|
| 387 |
-
f"Best threshold (picked on val): {best_thr:.4f}",
|
| 388 |
-
f"Model saved to: {model_path}",
|
| 389 |
-
"Best params:",
|
| 390 |
-
json.dumps(best_params, indent=2),
|
| 391 |
-
"=" * 72,
|
| 392 |
-
]
|
| 393 |
-
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
|
| 394 |
-
f.write("\n".join(summary))
|
| 395 |
-
print("\n".join(summary))
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
if __name__ == "__main__":
|
| 399 |
-
# Example usage:
|
| 400 |
-
# dataset_path = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/data/solubility"
|
| 401 |
-
# out_dir = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/src/solubility/xgb"
|
| 402 |
-
# run_optuna_and_refit(dataset_path, out_dir, model_name="xgb", n_trials=200)
|
| 403 |
-
|
| 404 |
-
import argparse
|
| 405 |
-
parser = argparse.ArgumentParser()
|
| 406 |
-
parser.add_argument("--dataset_path", type=str, required=True)
|
| 407 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 408 |
-
parser.add_argument("--model", type=str, choices=["xgb", "adaboost", "linearboost"], required=True)
|
| 409 |
-
parser.add_argument("--n_trials", type=int, default=200)
|
| 410 |
-
args = parser.parse_args()
|
| 411 |
-
|
| 412 |
-
run_optuna_and_refit(
|
| 413 |
-
dataset_path=args.dataset_path,
|
| 414 |
-
out_dir=args.out_dir,
|
| 415 |
-
model_name=args.model,
|
| 416 |
-
n_trials=args.n_trials,
|
| 417 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py
DELETED
|
@@ -1,468 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import json
|
| 3 |
-
import joblib
|
| 4 |
-
import optuna
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
from typing import Dict, Any, Tuple, Optional
|
| 10 |
-
from datasets import load_from_disk, DatasetDict
|
| 11 |
-
from sklearn.metrics import (
|
| 12 |
-
f1_score, roc_auc_score, average_precision_score,
|
| 13 |
-
precision_recall_curve, roc_curve
|
| 14 |
-
)
|
| 15 |
-
from sklearn.linear_model import LogisticRegression
|
| 16 |
-
from sklearn.svm import SVC, LinearSVC
|
| 17 |
-
from sklearn.calibration import CalibratedClassifierCV
|
| 18 |
-
import torch
|
| 19 |
-
import time
|
| 20 |
-
import xgboost as xgb
|
| 21 |
-
from lightning.pytorch import seed_everything
|
| 22 |
-
import cupy as cp
|
| 23 |
-
from cuml.svm import SVC as cuSVC
|
| 24 |
-
from cuml.linear_model import LogisticRegression as cuLogReg
|
| 25 |
-
seed_everything(1986)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def to_gpu(X: np.ndarray):
|
| 29 |
-
if isinstance(X, cp.ndarray):
|
| 30 |
-
return X
|
| 31 |
-
return cp.asarray(X, dtype=cp.float32)
|
| 32 |
-
|
| 33 |
-
def to_cpu(x):
|
| 34 |
-
if isinstance(x, cp.ndarray):
|
| 35 |
-
return cp.asnumpy(x)
|
| 36 |
-
return np.asarray(x)
|
| 37 |
-
|
| 38 |
-
@dataclass
|
| 39 |
-
class SplitData:
|
| 40 |
-
X_train: np.ndarray
|
| 41 |
-
y_train: np.ndarray
|
| 42 |
-
seq_train: Optional[np.ndarray]
|
| 43 |
-
X_val: np.ndarray
|
| 44 |
-
y_val: np.ndarray
|
| 45 |
-
seq_val: Optional[np.ndarray]
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def _stack_embeddings(col) -> np.ndarray:
|
| 49 |
-
arr = np.asarray(col, dtype=np.float32)
|
| 50 |
-
if arr.ndim != 2:
|
| 51 |
-
arr = np.stack(col).astype(np.float32)
|
| 52 |
-
return arr
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def load_split_data(dataset_path: str) -> SplitData:
|
| 56 |
-
ds = load_from_disk(dataset_path)
|
| 57 |
-
|
| 58 |
-
# Case A: DatasetDict with train/val
|
| 59 |
-
if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
|
| 60 |
-
train_ds, val_ds = ds["train"], ds["val"]
|
| 61 |
-
else:
|
| 62 |
-
# Case B: Single dataset with "split" column
|
| 63 |
-
if "split" not in ds.column_names:
|
| 64 |
-
raise ValueError(
|
| 65 |
-
"Dataset must be a DatasetDict(train/val) or have a 'split' column."
|
| 66 |
-
)
|
| 67 |
-
train_ds = ds.filter(lambda x: x["split"] == "train")
|
| 68 |
-
val_ds = ds.filter(lambda x: x["split"] == "val")
|
| 69 |
-
|
| 70 |
-
for required in ["embedding", "label"]:
|
| 71 |
-
if required not in train_ds.column_names:
|
| 72 |
-
raise ValueError(f"Missing column '{required}' in train split.")
|
| 73 |
-
if required not in val_ds.column_names:
|
| 74 |
-
raise ValueError(f"Missing column '{required}' in val split.")
|
| 75 |
-
|
| 76 |
-
X_train = _stack_embeddings(train_ds["embedding"])
|
| 77 |
-
y_train = np.asarray(train_ds["label"], dtype=np.int64)
|
| 78 |
-
|
| 79 |
-
X_val = _stack_embeddings(val_ds["embedding"])
|
| 80 |
-
y_val = np.asarray(val_ds["label"], dtype=np.int64)
|
| 81 |
-
|
| 82 |
-
seq_train = None
|
| 83 |
-
seq_val = None
|
| 84 |
-
if "sequence" in train_ds.column_names:
|
| 85 |
-
seq_train = np.asarray(train_ds["sequence"])
|
| 86 |
-
if "sequence" in val_ds.column_names:
|
| 87 |
-
seq_val = np.asarray(val_ds["sequence"])
|
| 88 |
-
|
| 89 |
-
return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
|
| 93 |
-
"""
|
| 94 |
-
Find threshold maximizing F1 on the given set.
|
| 95 |
-
Returns (best_threshold, best_f1).
|
| 96 |
-
"""
|
| 97 |
-
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
|
| 98 |
-
f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
|
| 99 |
-
best_idx = int(np.nanargmax(f1s))
|
| 100 |
-
return float(thresholds[best_idx]), float(f1s[best_idx])
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
|
| 104 |
-
y_pred = (y_prob >= threshold).astype(int)
|
| 105 |
-
return {
|
| 106 |
-
"f1": float(f1_score(y_true, y_pred)),
|
| 107 |
-
"auc": float(roc_auc_score(y_true, y_prob)),
|
| 108 |
-
"ap": float(average_precision_score(y_true, y_prob)),
|
| 109 |
-
"threshold": float(threshold),
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
# -----------------------------
|
| 114 |
-
# Model
|
| 115 |
-
# -----------------------------
|
| 116 |
-
def train_xgb(
|
| 117 |
-
X_train, y_train, X_val, y_val, params: Dict[str, Any]
|
| 118 |
-
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
|
| 119 |
-
dtrain = xgb.DMatrix(X_train, label=y_train)
|
| 120 |
-
dval = xgb.DMatrix(X_val, label=y_val)
|
| 121 |
-
|
| 122 |
-
num_boost_round = int(params.pop("num_boost_round"))
|
| 123 |
-
early_stopping_rounds = int(params.pop("early_stopping_rounds"))
|
| 124 |
-
|
| 125 |
-
booster = xgb.train(
|
| 126 |
-
params=params,
|
| 127 |
-
dtrain=dtrain,
|
| 128 |
-
num_boost_round=num_boost_round,
|
| 129 |
-
evals=[(dval, "val")],
|
| 130 |
-
early_stopping_rounds=early_stopping_rounds,
|
| 131 |
-
verbose_eval=False,
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
p_train = booster.predict(dtrain)
|
| 135 |
-
p_val = booster.predict(dval)
|
| 136 |
-
return booster, p_train, p_val
|
| 137 |
-
|
| 138 |
-
def train_cuml_svc(X_train, y_train, X_val, y_val, params):
|
| 139 |
-
Xtr = to_gpu(X_train)
|
| 140 |
-
Xva = to_gpu(X_val)
|
| 141 |
-
ytr = to_gpu(y_train).astype(cp.int32)
|
| 142 |
-
|
| 143 |
-
clf = cuSVC(
|
| 144 |
-
C=float(params["C"]),
|
| 145 |
-
kernel=params["kernel"],
|
| 146 |
-
gamma=params.get("gamma", "scale"),
|
| 147 |
-
class_weight=params.get("class_weight", None),
|
| 148 |
-
probability=bool(params.get("probability", True)),
|
| 149 |
-
random_state=1986,
|
| 150 |
-
max_iter=int(params.get("max_iter", 1000)),
|
| 151 |
-
tol=float(params.get("tol", 1e-4)),
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
clf.fit(Xtr, ytr)
|
| 155 |
-
|
| 156 |
-
p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
|
| 157 |
-
p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
|
| 158 |
-
return clf, p_train, p_val
|
| 159 |
-
|
| 160 |
-
def train_cuml_elastic_net(X_train, y_train, X_val, y_val, params):
|
| 161 |
-
Xtr = to_gpu(X_train)
|
| 162 |
-
Xva = to_gpu(X_val)
|
| 163 |
-
ytr = to_gpu(y_train).astype(cp.int32)
|
| 164 |
-
|
| 165 |
-
clf = cuLogReg(
|
| 166 |
-
penalty="elasticnet",
|
| 167 |
-
C=float(params["C"]),
|
| 168 |
-
l1_ratio=float(params["l1_ratio"]),
|
| 169 |
-
class_weight=params.get("class_weight", None),
|
| 170 |
-
max_iter=int(params.get("max_iter", 1000)),
|
| 171 |
-
tol=float(params.get("tol", 1e-4)),
|
| 172 |
-
solver="qn",
|
| 173 |
-
fit_intercept=True,
|
| 174 |
-
)
|
| 175 |
-
clf.fit(Xtr, ytr)
|
| 176 |
-
|
| 177 |
-
p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
|
| 178 |
-
p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
|
| 179 |
-
return clf, p_train, p_val
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
def train_svm(X_train, y_train, X_val, y_val, params):
|
| 183 |
-
"""
|
| 184 |
-
Kernel SVM via SVC. CPU only in sklearn.
|
| 185 |
-
probability=True enables predict_proba but is slower.
|
| 186 |
-
"""
|
| 187 |
-
clf = SVC(
|
| 188 |
-
C=float(params["C"]),
|
| 189 |
-
kernel=params["kernel"],
|
| 190 |
-
gamma=params.get("gamma", "scale"),
|
| 191 |
-
class_weight=params.get("class_weight", None),
|
| 192 |
-
probability=True,
|
| 193 |
-
random_state=1986,
|
| 194 |
-
)
|
| 195 |
-
clf.fit(X_train, y_train)
|
| 196 |
-
p_train = clf.predict_proba(X_train)[:, 1]
|
| 197 |
-
p_val = clf.predict_proba(X_val)[:, 1]
|
| 198 |
-
return clf, p_train, p_val
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
|
| 202 |
-
"""
|
| 203 |
-
Fast linear SVM (LinearSVC) + probability calibration.
|
| 204 |
-
Usually much faster than SVC on large datasets.
|
| 205 |
-
"""
|
| 206 |
-
base = LinearSVC(
|
| 207 |
-
C=float(params["C"]),
|
| 208 |
-
class_weight=params.get("class_weight", None),
|
| 209 |
-
max_iter=int(params.get("max_iter", 5000)),
|
| 210 |
-
random_state=1986,
|
| 211 |
-
)
|
| 212 |
-
# calibration to get probabilities for PR/ROC + thresholding
|
| 213 |
-
clf = CalibratedClassifierCV(base, method="sigmoid", cv=3)
|
| 214 |
-
clf.fit(X_train, y_train)
|
| 215 |
-
p_train = clf.predict_proba(X_train)[:, 1]
|
| 216 |
-
p_val = clf.predict_proba(X_val)[:, 1]
|
| 217 |
-
return clf, p_train, p_val
|
| 218 |
-
|
| 219 |
-
# -----------------------------
|
| 220 |
-
# Saving artifacts
|
| 221 |
-
# -----------------------------
|
| 222 |
-
def save_predictions_csv(
|
| 223 |
-
out_dir: str,
|
| 224 |
-
split_name: str,
|
| 225 |
-
y_true: np.ndarray,
|
| 226 |
-
y_prob: np.ndarray,
|
| 227 |
-
threshold: float,
|
| 228 |
-
sequences: Optional[np.ndarray] = None,
|
| 229 |
-
):
|
| 230 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 231 |
-
df = pd.DataFrame({
|
| 232 |
-
"y_true": y_true.astype(int),
|
| 233 |
-
"y_prob": y_prob.astype(float),
|
| 234 |
-
"y_pred": (y_prob >= threshold).astype(int),
|
| 235 |
-
})
|
| 236 |
-
if sequences is not None:
|
| 237 |
-
df.insert(0, "sequence", sequences)
|
| 238 |
-
df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
|
| 242 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 243 |
-
|
| 244 |
-
# PR
|
| 245 |
-
precision, recall, _ = precision_recall_curve(y_true, y_prob)
|
| 246 |
-
plt.figure()
|
| 247 |
-
plt.plot(recall, precision)
|
| 248 |
-
plt.xlabel("Recall")
|
| 249 |
-
plt.ylabel("Precision")
|
| 250 |
-
plt.title("Precision-Recall Curve")
|
| 251 |
-
plt.tight_layout()
|
| 252 |
-
plt.savefig(os.path.join(out_dir, "pr_curve.png"))
|
| 253 |
-
plt.close()
|
| 254 |
-
|
| 255 |
-
# ROC
|
| 256 |
-
fpr, tpr, _ = roc_curve(y_true, y_prob)
|
| 257 |
-
plt.figure()
|
| 258 |
-
plt.plot(fpr, tpr)
|
| 259 |
-
plt.xlabel("False Positive Rate")
|
| 260 |
-
plt.ylabel("True Positive Rate")
|
| 261 |
-
plt.title("ROC Curve")
|
| 262 |
-
plt.tight_layout()
|
| 263 |
-
plt.savefig(os.path.join(out_dir, "roc_curve.png"))
|
| 264 |
-
plt.close()
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
# -----------------------------
|
| 268 |
-
# Optuna objectives
|
| 269 |
-
# -----------------------------
|
| 270 |
-
def make_objective(model_name: str, data: SplitData, out_dir: str):
|
| 271 |
-
Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
|
| 272 |
-
|
| 273 |
-
def objective(trial: optuna.Trial) -> float:
|
| 274 |
-
if model_name == "xgb":
|
| 275 |
-
params = {
|
| 276 |
-
"objective": "binary:logistic",
|
| 277 |
-
"eval_metric": "logloss",
|
| 278 |
-
"lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
|
| 279 |
-
"alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
|
| 280 |
-
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
|
| 281 |
-
"subsample": trial.suggest_float("subsample", 0.5, 1.0),
|
| 282 |
-
"learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
|
| 283 |
-
"max_depth": trial.suggest_int("max_depth", 2, 15),
|
| 284 |
-
"min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
|
| 285 |
-
"gamma": trial.suggest_float("gamma", 0.0, 10.0),
|
| 286 |
-
"tree_method": "hist",
|
| 287 |
-
"device": "cuda",
|
| 288 |
-
}
|
| 289 |
-
params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
|
| 290 |
-
params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
|
| 291 |
-
|
| 292 |
-
model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
|
| 293 |
-
|
| 294 |
-
elif model_name == "svm":
|
| 295 |
-
svm_kind = trial.suggest_categorical("svm_kind", ["svc", "linear_calibrated"])
|
| 296 |
-
|
| 297 |
-
if svm_kind == "svc":
|
| 298 |
-
params = {
|
| 299 |
-
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 300 |
-
"kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
|
| 301 |
-
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 302 |
-
}
|
| 303 |
-
if params["kernel"] in ["rbf", "poly", "sigmoid"]:
|
| 304 |
-
params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
|
| 305 |
-
else:
|
| 306 |
-
params["gamma"] = "scale"
|
| 307 |
-
|
| 308 |
-
model, p_tr, p_va = train_svm(Xtr, ytr, Xva, yva, params)
|
| 309 |
-
|
| 310 |
-
else:
|
| 311 |
-
params = {
|
| 312 |
-
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 313 |
-
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 314 |
-
"max_iter": trial.suggest_int("max_iter", 2000, 20000),
|
| 315 |
-
}
|
| 316 |
-
model, p_tr, p_va = train_linearsvm_calibrated(Xtr, ytr, Xva, yva, params)
|
| 317 |
-
elif model_name == "svm_gpu":
|
| 318 |
-
params = {
|
| 319 |
-
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 320 |
-
"kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
|
| 321 |
-
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 322 |
-
"probability": True,
|
| 323 |
-
"max_iter": trial.suggest_int("max_iter", 200, 5000),
|
| 324 |
-
"tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
|
| 325 |
-
}
|
| 326 |
-
if params["kernel"] in ["rbf", "poly", "sigmoid"]:
|
| 327 |
-
params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
|
| 328 |
-
else:
|
| 329 |
-
params["gamma"] = "scale"
|
| 330 |
-
|
| 331 |
-
model, p_tr, p_va = train_cuml_svc(Xtr, ytr, Xva, yva, params)
|
| 332 |
-
|
| 333 |
-
elif model_name == "enet_gpu":
|
| 334 |
-
params = {
|
| 335 |
-
"C": trial.suggest_float("C", 1e-4, 1e3, log=True),
|
| 336 |
-
"l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
|
| 337 |
-
"class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
|
| 338 |
-
"max_iter": trial.suggest_int("max_iter", 200, 5000),
|
| 339 |
-
"tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
|
| 340 |
-
}
|
| 341 |
-
model, p_tr, p_va = train_cuml_elastic_net(Xtr, ytr, Xva, yva, params)
|
| 342 |
-
else:
|
| 343 |
-
raise ValueError(f"Unknown model_name={model_name}")
|
| 344 |
-
|
| 345 |
-
thr, f1_at_thr = best_f1_threshold(yva, p_va)
|
| 346 |
-
metrics = eval_binary(yva, p_va, thr)
|
| 347 |
-
trial.set_user_attr("threshold", thr)
|
| 348 |
-
trial.set_user_attr("auc", metrics["auc"])
|
| 349 |
-
trial.set_user_attr("ap", metrics["ap"])
|
| 350 |
-
return f1_at_thr
|
| 351 |
-
|
| 352 |
-
return objective
|
| 353 |
-
|
| 354 |
-
# -----------------------------
|
| 355 |
-
# Main
|
| 356 |
-
# -----------------------------
|
| 357 |
-
def run_optuna_and_refit(
|
| 358 |
-
dataset_path: str,
|
| 359 |
-
out_dir: str,
|
| 360 |
-
model_name: str,
|
| 361 |
-
n_trials: int = 200,
|
| 362 |
-
):
|
| 363 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 364 |
-
|
| 365 |
-
data = load_split_data(dataset_path)
|
| 366 |
-
print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
|
| 367 |
-
|
| 368 |
-
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
|
| 369 |
-
study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
|
| 370 |
-
|
| 371 |
-
trials_df = study.trials_dataframe()
|
| 372 |
-
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
|
| 373 |
-
|
| 374 |
-
best = study.best_trial
|
| 375 |
-
best_params = dict(best.params)
|
| 376 |
-
best_thr = float(best.user_attrs["threshold"])
|
| 377 |
-
best_auc = float(best.user_attrs["auc"])
|
| 378 |
-
best_ap = float(best.user_attrs["ap"])
|
| 379 |
-
best_f1 = float(best.value)
|
| 380 |
-
|
| 381 |
-
# Refit best model on train
|
| 382 |
-
if model_name == "xgb":
|
| 383 |
-
params = {
|
| 384 |
-
"objective": "binary:logistic",
|
| 385 |
-
"eval_metric": "logloss",
|
| 386 |
-
"lambda": best_params["lambda"],
|
| 387 |
-
"alpha": best_params["alpha"],
|
| 388 |
-
"colsample_bytree": best_params["colsample_bytree"],
|
| 389 |
-
"subsample": best_params["subsample"],
|
| 390 |
-
"learning_rate": best_params["learning_rate"],
|
| 391 |
-
"max_depth": best_params["max_depth"],
|
| 392 |
-
"min_child_weight": best_params["min_child_weight"],
|
| 393 |
-
"gamma": best_params["gamma"],
|
| 394 |
-
"tree_method": "hist",
|
| 395 |
-
"num_boost_round": best_params["num_boost_round"],
|
| 396 |
-
"early_stopping_rounds": best_params["early_stopping_rounds"],
|
| 397 |
-
}
|
| 398 |
-
model, p_tr, p_va = train_xgb(
|
| 399 |
-
data.X_train, data.y_train, data.X_val, data.y_val, params
|
| 400 |
-
)
|
| 401 |
-
model_path = os.path.join(out_dir, "best_model.json")
|
| 402 |
-
model.save_model(model_path)
|
| 403 |
-
|
| 404 |
-
elif model_name == "svm":
|
| 405 |
-
svm_kind = best_params["svm_kind"]
|
| 406 |
-
if svm_kind == "svc":
|
| 407 |
-
model, p_tr, p_va = train_svm(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
|
| 408 |
-
else:
|
| 409 |
-
model, p_tr, p_va = train_linearsvm_calibrated(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
|
| 410 |
-
|
| 411 |
-
model_path = os.path.join(out_dir, "best_model.joblib")
|
| 412 |
-
joblib.dump(model, model_path)
|
| 413 |
-
elif model_name == "svm_gpu":
|
| 414 |
-
model, p_tr, p_va = train_cuml_svc(
|
| 415 |
-
data.X_train, data.y_train, data.X_val, data.y_val, best_params
|
| 416 |
-
)
|
| 417 |
-
model_path = os.path.join(out_dir, "best_model_cuml_svc.joblib")
|
| 418 |
-
joblib.dump(model, model_path)
|
| 419 |
-
|
| 420 |
-
elif model_name == "enet_gpu":
|
| 421 |
-
model, p_tr, p_va = train_cuml_elastic_net(
|
| 422 |
-
data.X_train, data.y_train, data.X_val, data.y_val, best_params
|
| 423 |
-
)
|
| 424 |
-
model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
|
| 425 |
-
joblib.dump(model, model_path)
|
| 426 |
-
else:
|
| 427 |
-
raise ValueError(model_name)
|
| 428 |
-
|
| 429 |
-
# Save predictions CSVs
|
| 430 |
-
save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
|
| 431 |
-
save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
|
| 432 |
-
|
| 433 |
-
# Plots on val
|
| 434 |
-
plot_curves(out_dir, data.y_val, p_va)
|
| 435 |
-
|
| 436 |
-
summary = [
|
| 437 |
-
"=" * 72,
|
| 438 |
-
f"MODEL: {model_name}",
|
| 439 |
-
f"Best trial: {best.number}",
|
| 440 |
-
f"Best F1 (val @ best-threshold): {best_f1:.4f}",
|
| 441 |
-
f"Val AUC: {best_auc:.4f}",
|
| 442 |
-
f"Val AP: {best_ap:.4f}",
|
| 443 |
-
f"Best threshold (picked on val): {best_thr:.4f}",
|
| 444 |
-
f"Model saved to: {model_path}",
|
| 445 |
-
"Best params:",
|
| 446 |
-
json.dumps(best_params, indent=2),
|
| 447 |
-
"=" * 72,
|
| 448 |
-
]
|
| 449 |
-
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
|
| 450 |
-
f.write("\n".join(summary))
|
| 451 |
-
print("\n".join(summary))
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
if __name__ == "__main__":
|
| 455 |
-
import argparse
|
| 456 |
-
parser = argparse.ArgumentParser()
|
| 457 |
-
parser.add_argument("--dataset_path", type=str, required=True)
|
| 458 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 459 |
-
parser.add_argument("--model", type=str, choices=["xgb", "svm_gpu", "enet_gpu"], required=True)
|
| 460 |
-
parser.add_argument("--n_trials", type=int, default=200)
|
| 461 |
-
args = parser.parse_args()
|
| 462 |
-
|
| 463 |
-
run_optuna_and_refit(
|
| 464 |
-
dataset_path=args.dataset_path,
|
| 465 |
-
out_dir=args.out_dir,
|
| 466 |
-
model_name=args.model,
|
| 467 |
-
n_trials=args.n_trials,
|
| 468 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py
DELETED
|
@@ -1,410 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import json
|
| 3 |
-
import joblib
|
| 4 |
-
import optuna
|
| 5 |
-
import numpy as np
|
| 6 |
-
import pandas as pd
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
-
from typing import Dict, Any, Tuple, Optional
|
| 10 |
-
from datasets import load_from_disk, DatasetDict
|
| 11 |
-
from sklearn.preprocessing import StandardScaler
|
| 12 |
-
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 13 |
-
from sklearn.svm import SVR
|
| 14 |
-
import xgboost as xgb
|
| 15 |
-
from lightning.pytorch import seed_everything
|
| 16 |
-
import cupy as cp
|
| 17 |
-
from cuml.linear_model import ElasticNet as cuElasticNet
|
| 18 |
-
from scipy.stats import spearmanr
|
| 19 |
-
seed_everything(1986)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# -----------------------------
|
| 23 |
-
# GPU/CPU helpers
|
| 24 |
-
# -----------------------------
|
| 25 |
-
def to_gpu(X: np.ndarray):
|
| 26 |
-
if isinstance(X, cp.ndarray):
|
| 27 |
-
return X
|
| 28 |
-
return cp.asarray(X, dtype=cp.float32)
|
| 29 |
-
|
| 30 |
-
def to_cpu(x):
|
| 31 |
-
if isinstance(x, cp.ndarray):
|
| 32 |
-
return cp.asnumpy(x)
|
| 33 |
-
return np.asarray(x)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# -----------------------------
|
| 37 |
-
# Data loading
|
| 38 |
-
# -----------------------------
|
| 39 |
-
@dataclass
|
| 40 |
-
class SplitData:
|
| 41 |
-
X_train: np.ndarray
|
| 42 |
-
y_train: np.ndarray
|
| 43 |
-
seq_train: Optional[np.ndarray]
|
| 44 |
-
X_val: np.ndarray
|
| 45 |
-
y_val: np.ndarray
|
| 46 |
-
seq_val: Optional[np.ndarray]
|
| 47 |
-
|
| 48 |
-
def _stack_embeddings(col) -> np.ndarray:
|
| 49 |
-
arr = np.asarray(col, dtype=np.float32)
|
| 50 |
-
if arr.ndim != 2:
|
| 51 |
-
arr = np.stack(col).astype(np.float32)
|
| 52 |
-
return arr
|
| 53 |
-
|
| 54 |
-
def load_split_data(dataset_path: str) -> SplitData:
|
| 55 |
-
ds = load_from_disk(dataset_path)
|
| 56 |
-
|
| 57 |
-
if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
|
| 58 |
-
train_ds, val_ds = ds["train"], ds["val"]
|
| 59 |
-
else:
|
| 60 |
-
if "split" not in ds.column_names:
|
| 61 |
-
raise ValueError("Dataset must be a DatasetDict(train/val) or have a 'split' column.")
|
| 62 |
-
train_ds = ds.filter(lambda x: x["split"] == "train")
|
| 63 |
-
val_ds = ds.filter(lambda x: x["split"] == "val")
|
| 64 |
-
|
| 65 |
-
for required in ["embedding", "label"]:
|
| 66 |
-
if required not in train_ds.column_names:
|
| 67 |
-
raise ValueError(f"Missing column '{required}' in train split.")
|
| 68 |
-
if required not in val_ds.column_names:
|
| 69 |
-
raise ValueError(f"Missing column '{required}' in val split.")
|
| 70 |
-
|
| 71 |
-
X_train = _stack_embeddings(train_ds["embedding"]).astype(np.float32)
|
| 72 |
-
X_val = _stack_embeddings(val_ds["embedding"]).astype(np.float32)
|
| 73 |
-
|
| 74 |
-
y_train = np.asarray(train_ds["label"], dtype=np.float32)
|
| 75 |
-
y_val = np.asarray(val_ds["label"], dtype=np.float32)
|
| 76 |
-
|
| 77 |
-
seq_train = None
|
| 78 |
-
seq_val = None
|
| 79 |
-
if "sequence" in train_ds.column_names:
|
| 80 |
-
seq_train = np.asarray(train_ds["sequence"])
|
| 81 |
-
if "sequence" in val_ds.column_names:
|
| 82 |
-
seq_val = np.asarray(val_ds["sequence"])
|
| 83 |
-
|
| 84 |
-
return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
# -----------------------------
|
| 88 |
-
# Metrics
|
| 89 |
-
# -----------------------------
|
| 90 |
-
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 91 |
-
rho = spearmanr(y_true, y_pred).correlation
|
| 92 |
-
if rho is None or np.isnan(rho):
|
| 93 |
-
return 0.0
|
| 94 |
-
return float(rho)
|
| 95 |
-
|
| 96 |
-
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
|
| 97 |
-
# RMSE
|
| 98 |
-
try:
|
| 99 |
-
from sklearn.metrics import root_mean_squared_error
|
| 100 |
-
rmse = root_mean_squared_error(y_true, y_pred)
|
| 101 |
-
except Exception:
|
| 102 |
-
rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
|
| 103 |
-
|
| 104 |
-
mae = float(mean_absolute_error(y_true, y_pred))
|
| 105 |
-
r2 = float(r2_score(y_true, y_pred))
|
| 106 |
-
rho = float(safe_spearmanr(y_true, y_pred))
|
| 107 |
-
return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
# -----------------------------
|
| 111 |
-
# Model
|
| 112 |
-
# -----------------------------
|
| 113 |
-
def train_xgb_reg(
|
| 114 |
-
X_train, y_train, X_val, y_val, params: Dict[str, Any]
|
| 115 |
-
) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
|
| 116 |
-
dtrain = xgb.DMatrix(X_train, label=y_train)
|
| 117 |
-
dval = xgb.DMatrix(X_val, label=y_val)
|
| 118 |
-
|
| 119 |
-
num_boost_round = int(params.pop("num_boost_round"))
|
| 120 |
-
early_stopping_rounds = int(params.pop("early_stopping_rounds"))
|
| 121 |
-
|
| 122 |
-
booster = xgb.train(
|
| 123 |
-
params=params,
|
| 124 |
-
dtrain=dtrain,
|
| 125 |
-
num_boost_round=num_boost_round,
|
| 126 |
-
evals=[(dval, "val")],
|
| 127 |
-
early_stopping_rounds=early_stopping_rounds,
|
| 128 |
-
verbose_eval=False,
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
p_train = booster.predict(dtrain)
|
| 132 |
-
p_val = booster.predict(dval)
|
| 133 |
-
return booster, p_train, p_val
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
def train_cuml_elasticnet_reg(
|
| 137 |
-
X_train, y_train, X_val, y_val, params: Dict[str, Any]
|
| 138 |
-
):
|
| 139 |
-
Xtr = to_gpu(X_train)
|
| 140 |
-
Xva = to_gpu(X_val)
|
| 141 |
-
ytr = to_gpu(y_train).astype(cp.float32)
|
| 142 |
-
|
| 143 |
-
model = cuElasticNet(
|
| 144 |
-
alpha=float(params["alpha"]),
|
| 145 |
-
l1_ratio=float(params["l1_ratio"]),
|
| 146 |
-
fit_intercept=True,
|
| 147 |
-
max_iter=int(params.get("max_iter", 5000)),
|
| 148 |
-
tol=float(params.get("tol", 1e-4)),
|
| 149 |
-
selection=params.get("selection", "cyclic"),
|
| 150 |
-
)
|
| 151 |
-
model.fit(Xtr, ytr)
|
| 152 |
-
|
| 153 |
-
p_train = to_cpu(model.predict(Xtr))
|
| 154 |
-
p_val = to_cpu(model.predict(Xva))
|
| 155 |
-
return model, p_train, p_val
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
def train_svr_reg(
|
| 159 |
-
X_train, y_train, X_val, y_val, params: Dict[str, Any]
|
| 160 |
-
):
|
| 161 |
-
model = SVR(
|
| 162 |
-
C=float(params["C"]),
|
| 163 |
-
epsilon=float(params["epsilon"]),
|
| 164 |
-
kernel=params["kernel"],
|
| 165 |
-
gamma=params.get("gamma", "scale"),
|
| 166 |
-
)
|
| 167 |
-
model.fit(X_train, y_train)
|
| 168 |
-
p_train = model.predict(X_train)
|
| 169 |
-
p_val = model.predict(X_val)
|
| 170 |
-
return model, p_train, p_val
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
# -----------------------------
|
| 174 |
-
# Saving + plots
|
| 175 |
-
# -----------------------------
|
| 176 |
-
def save_predictions_csv(
|
| 177 |
-
out_dir: str,
|
| 178 |
-
split_name: str,
|
| 179 |
-
y_true: np.ndarray,
|
| 180 |
-
y_pred: np.ndarray,
|
| 181 |
-
sequences: Optional[np.ndarray] = None,
|
| 182 |
-
):
|
| 183 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 184 |
-
df = pd.DataFrame({
|
| 185 |
-
"y_true": y_true.astype(float),
|
| 186 |
-
"y_pred": y_pred.astype(float),
|
| 187 |
-
"residual": (y_true - y_pred).astype(float),
|
| 188 |
-
})
|
| 189 |
-
if sequences is not None:
|
| 190 |
-
df.insert(0, "sequence", sequences)
|
| 191 |
-
df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
|
| 192 |
-
|
| 193 |
-
def plot_regression_diagnostics(out_dir: str, y_true: np.ndarray, y_pred: np.ndarray):
|
| 194 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 195 |
-
|
| 196 |
-
plt.figure()
|
| 197 |
-
plt.scatter(y_true, y_pred, s=8, alpha=0.5)
|
| 198 |
-
plt.xlabel("y_true")
|
| 199 |
-
plt.ylabel("y_pred")
|
| 200 |
-
plt.title("Predicted vs True")
|
| 201 |
-
plt.tight_layout()
|
| 202 |
-
plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
|
| 203 |
-
plt.close()
|
| 204 |
-
|
| 205 |
-
resid = y_true - y_pred
|
| 206 |
-
plt.figure()
|
| 207 |
-
plt.hist(resid, bins=50)
|
| 208 |
-
plt.xlabel("residual (y_true - y_pred)")
|
| 209 |
-
plt.ylabel("count")
|
| 210 |
-
plt.title("Residual Histogram")
|
| 211 |
-
plt.tight_layout()
|
| 212 |
-
plt.savefig(os.path.join(out_dir, "residual_hist.png"))
|
| 213 |
-
plt.close()
|
| 214 |
-
|
| 215 |
-
plt.figure()
|
| 216 |
-
plt.scatter(y_pred, resid, s=8, alpha=0.5)
|
| 217 |
-
plt.xlabel("y_pred")
|
| 218 |
-
plt.ylabel("residual")
|
| 219 |
-
plt.title("Residuals vs Prediction")
|
| 220 |
-
plt.tight_layout()
|
| 221 |
-
plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
|
| 222 |
-
plt.close()
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
# -----------------------------
|
| 226 |
-
# Optuna objective (OPTIMIZE SPEARMAN RHO)
|
| 227 |
-
# -----------------------------
|
| 228 |
-
def make_objective(model_name: str, data: SplitData):
|
| 229 |
-
Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
|
| 230 |
-
|
| 231 |
-
def objective(trial: optuna.Trial) -> float:
|
| 232 |
-
if model_name == "xgb_reg":
|
| 233 |
-
params = {
|
| 234 |
-
"objective": "reg:squarederror",
|
| 235 |
-
"eval_metric": "rmse",
|
| 236 |
-
"lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
|
| 237 |
-
"alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
|
| 238 |
-
"gamma": trial.suggest_float("gamma", 0.0, 10.0),
|
| 239 |
-
"max_depth": trial.suggest_int("max_depth", 2, 16),
|
| 240 |
-
"min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 500.0, log=True),
|
| 241 |
-
"subsample": trial.suggest_float("subsample", 0.5, 1.0),
|
| 242 |
-
"colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
|
| 243 |
-
"learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
|
| 244 |
-
"tree_method": "hist",
|
| 245 |
-
"device": "cuda",
|
| 246 |
-
}
|
| 247 |
-
params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 2000)
|
| 248 |
-
params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
|
| 249 |
-
|
| 250 |
-
model, p_tr, p_va = train_xgb_reg(Xtr, ytr, Xva, yva, params.copy())
|
| 251 |
-
|
| 252 |
-
elif model_name == "enet_gpu":
|
| 253 |
-
params = {
|
| 254 |
-
"alpha": trial.suggest_float("alpha", 1e-8, 10.0, log=True),
|
| 255 |
-
"l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
|
| 256 |
-
"max_iter": trial.suggest_int("max_iter", 1000, 20000),
|
| 257 |
-
"tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
|
| 258 |
-
"selection": trial.suggest_categorical("selection", ["cyclic", "random"]),
|
| 259 |
-
}
|
| 260 |
-
model, p_tr, p_va = train_cuml_elasticnet_reg(Xtr, ytr, Xva, yva, params)
|
| 261 |
-
|
| 262 |
-
elif model_name == "svr":
|
| 263 |
-
params = {
|
| 264 |
-
"kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
|
| 265 |
-
"C": trial.suggest_float("C", 1e-3, 1e3, log=True),
|
| 266 |
-
"epsilon": trial.suggest_float("epsilon", 1e-4, 1.0, log=True),
|
| 267 |
-
}
|
| 268 |
-
if params["kernel"] in ["rbf", "poly", "sigmoid"]:
|
| 269 |
-
params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
|
| 270 |
-
else:
|
| 271 |
-
params["gamma"] = "scale"
|
| 272 |
-
|
| 273 |
-
model, p_tr, p_va = train_svr_reg(Xtr, ytr, Xva, yva, params)
|
| 274 |
-
|
| 275 |
-
else:
|
| 276 |
-
raise ValueError(f"Unknown model_name={model_name}")
|
| 277 |
-
|
| 278 |
-
metrics = eval_regression(yva, p_va)
|
| 279 |
-
trial.set_user_attr("spearman_rho", metrics["spearman_rho"])
|
| 280 |
-
trial.set_user_attr("rmse", metrics["rmse"])
|
| 281 |
-
trial.set_user_attr("mae", metrics["mae"])
|
| 282 |
-
trial.set_user_attr("r2", metrics["r2"])
|
| 283 |
-
|
| 284 |
-
# OPTUNA OBJECTIVE = maximize Spearman rho
|
| 285 |
-
return metrics["spearman_rho"]
|
| 286 |
-
|
| 287 |
-
return objective
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
# -----------------------------
|
| 291 |
-
# Main
|
| 292 |
-
# -----------------------------
|
| 293 |
-
def run_optuna_and_refit(
|
| 294 |
-
dataset_path: str,
|
| 295 |
-
out_dir: str,
|
| 296 |
-
model_name: str,
|
| 297 |
-
n_trials: int = 200,
|
| 298 |
-
standardize_X: bool = True,
|
| 299 |
-
):
|
| 300 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 301 |
-
|
| 302 |
-
data = load_split_data(dataset_path)
|
| 303 |
-
print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
|
| 304 |
-
|
| 305 |
-
# Standardize features (SVR + ElasticNet)
|
| 306 |
-
if standardize_X:
|
| 307 |
-
scaler = StandardScaler()
|
| 308 |
-
data.X_train = scaler.fit_transform(data.X_train).astype(np.float32)
|
| 309 |
-
data.X_val = scaler.transform(data.X_val).astype(np.float32)
|
| 310 |
-
joblib.dump(scaler, os.path.join(out_dir, "scaler.joblib"))
|
| 311 |
-
print("[Preprocess] Saved StandardScaler -> scaler.joblib")
|
| 312 |
-
|
| 313 |
-
study = optuna.create_study(
|
| 314 |
-
direction="maximize",
|
| 315 |
-
pruner=optuna.pruners.MedianPruner()
|
| 316 |
-
)
|
| 317 |
-
study.optimize(make_objective(model_name, data), n_trials=n_trials)
|
| 318 |
-
|
| 319 |
-
trials_df = study.trials_dataframe()
|
| 320 |
-
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
|
| 321 |
-
|
| 322 |
-
best = study.best_trial
|
| 323 |
-
best_params = dict(best.params)
|
| 324 |
-
|
| 325 |
-
best_rho = float(best.user_attrs.get("spearman_rho", best.value))
|
| 326 |
-
best_rmse = float(best.user_attrs.get("rmse", np.nan))
|
| 327 |
-
best_mae = float(best.user_attrs.get("mae", np.nan))
|
| 328 |
-
best_r2 = float(best.user_attrs.get("r2", np.nan))
|
| 329 |
-
|
| 330 |
-
# Refit best model on train
|
| 331 |
-
if model_name == "xgb_reg":
|
| 332 |
-
params = {
|
| 333 |
-
"objective": "reg:squarederror",
|
| 334 |
-
"eval_metric": "rmse",
|
| 335 |
-
"lambda": best_params["lambda"],
|
| 336 |
-
"alpha": best_params["alpha"],
|
| 337 |
-
"gamma": best_params["gamma"],
|
| 338 |
-
"max_depth": best_params["max_depth"],
|
| 339 |
-
"min_child_weight": best_params["min_child_weight"],
|
| 340 |
-
"subsample": best_params["subsample"],
|
| 341 |
-
"colsample_bytree": best_params["colsample_bytree"],
|
| 342 |
-
"learning_rate": best_params["learning_rate"],
|
| 343 |
-
"tree_method": "hist",
|
| 344 |
-
"device": "cuda",
|
| 345 |
-
"num_boost_round": best_params["num_boost_round"],
|
| 346 |
-
"early_stopping_rounds": best_params["early_stopping_rounds"],
|
| 347 |
-
}
|
| 348 |
-
model, p_tr, p_va = train_xgb_reg(
|
| 349 |
-
data.X_train, data.y_train, data.X_val, data.y_val, params
|
| 350 |
-
)
|
| 351 |
-
model_path = os.path.join(out_dir, "best_model.json")
|
| 352 |
-
model.save_model(model_path)
|
| 353 |
-
|
| 354 |
-
elif model_name == "enet_gpu":
|
| 355 |
-
model, p_tr, p_va = train_cuml_elasticnet_reg(
|
| 356 |
-
data.X_train, data.y_train, data.X_val, data.y_val, best_params
|
| 357 |
-
)
|
| 358 |
-
model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
|
| 359 |
-
joblib.dump(model, model_path)
|
| 360 |
-
|
| 361 |
-
elif model_name == "svr":
|
| 362 |
-
model, p_tr, p_va = train_svr_reg(
|
| 363 |
-
data.X_train, data.y_train, data.X_val, data.y_val, best_params
|
| 364 |
-
)
|
| 365 |
-
model_path = os.path.join(out_dir, "best_model_svr.joblib")
|
| 366 |
-
joblib.dump(model, model_path)
|
| 367 |
-
|
| 368 |
-
else:
|
| 369 |
-
raise ValueError(model_name)
|
| 370 |
-
|
| 371 |
-
save_predictions_csv(out_dir, "train", data.y_train, p_tr, data.seq_train)
|
| 372 |
-
save_predictions_csv(out_dir, "val", data.y_val, p_va, data.seq_val)
|
| 373 |
-
|
| 374 |
-
plot_regression_diagnostics(out_dir, data.y_val, p_va)
|
| 375 |
-
|
| 376 |
-
summary = [
|
| 377 |
-
"=" * 72,
|
| 378 |
-
f"MODEL: {model_name}",
|
| 379 |
-
f"Best trial: {best.number}",
|
| 380 |
-
f"Val Spearman rho (objective): {best_rho:.6f}",
|
| 381 |
-
f"Val RMSE: {best_rmse:.6f}",
|
| 382 |
-
f"Val MAE: {best_mae:.6f}",
|
| 383 |
-
f"Val R2: {best_r2:.6f}",
|
| 384 |
-
f"Model saved to: {model_path}",
|
| 385 |
-
"Best params:",
|
| 386 |
-
json.dumps(best_params, indent=2),
|
| 387 |
-
"=" * 72,
|
| 388 |
-
]
|
| 389 |
-
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
|
| 390 |
-
f.write("\n".join(summary))
|
| 391 |
-
print("\n".join(summary))
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
if __name__ == "__main__":
|
| 395 |
-
import argparse
|
| 396 |
-
parser = argparse.ArgumentParser()
|
| 397 |
-
parser.add_argument("--dataset_path", type=str, required=True)
|
| 398 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 399 |
-
parser.add_argument("--model", type=str, choices=["xgb_reg", "enet_gpu", "svr"], required=True)
|
| 400 |
-
parser.add_argument("--n_trials", type=int, default=200)
|
| 401 |
-
parser.add_argument("--no_standardize", action="store_true", help="Disable StandardScaler on X")
|
| 402 |
-
args = parser.parse_args()
|
| 403 |
-
|
| 404 |
-
run_optuna_and_refit(
|
| 405 |
-
dataset_path=args.dataset_path,
|
| 406 |
-
out_dir=args.out_dir,
|
| 407 |
-
model_name=args.model,
|
| 408 |
-
n_trials=args.n_trials,
|
| 409 |
-
standardize_X=(not args.no_standardize),
|
| 410 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py
DELETED
|
@@ -1,426 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import torch
|
| 3 |
-
from torch.utils.data import DataLoader
|
| 4 |
-
from datasets import load_from_disk, DatasetDict
|
| 5 |
-
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import optuna
|
| 8 |
-
import os
|
| 9 |
-
from typing import Dict, Any, Tuple, Optional
|
| 10 |
-
import matplotlib.pyplot as plt
|
| 11 |
-
from sklearn.metrics import (
|
| 12 |
-
f1_score, roc_auc_score, average_precision_score,
|
| 13 |
-
precision_recall_curve, roc_curve
|
| 14 |
-
)
|
| 15 |
-
import json
|
| 16 |
-
import joblib
|
| 17 |
-
import pandas as pd
|
| 18 |
-
import time
|
| 19 |
-
|
| 20 |
-
def infer_in_dim_from_unpooled_ds(ds) -> int:
|
| 21 |
-
ex = ds[0]
|
| 22 |
-
# ex["embedding"] is (L, H) list/array
|
| 23 |
-
return int(len(ex["embedding"][0]))
|
| 24 |
-
|
| 25 |
-
def load_split(dataset_path):
|
| 26 |
-
ds = load_from_disk(dataset_path)
|
| 27 |
-
|
| 28 |
-
if isinstance(ds, DatasetDict):
|
| 29 |
-
return ds["train"], ds["val"]
|
| 30 |
-
|
| 31 |
-
raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
|
| 32 |
-
|
| 33 |
-
def collate_unpooled(batch):
|
| 34 |
-
# batch: list of dicts
|
| 35 |
-
lengths = [int(x["length"]) for x in batch]
|
| 36 |
-
Lmax = max(lengths)
|
| 37 |
-
H = len(batch[0]["embedding"][0]) # 1280
|
| 38 |
-
|
| 39 |
-
X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
|
| 40 |
-
M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
|
| 41 |
-
y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)
|
| 42 |
-
|
| 43 |
-
for i, x in enumerate(batch):
|
| 44 |
-
emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L, H)
|
| 45 |
-
L = emb.shape[0]
|
| 46 |
-
X[i, :L] = emb
|
| 47 |
-
if "attention_mask" in x:
|
| 48 |
-
m = torch.tensor(x["attention_mask"], dtype=torch.bool)
|
| 49 |
-
M[i, :L] = m[:L]
|
| 50 |
-
else:
|
| 51 |
-
M[i, :L] = True
|
| 52 |
-
|
| 53 |
-
return X, M, y
|
| 54 |
-
|
| 55 |
-
# ======================== Helper functions =========================================
|
| 56 |
-
def save_predictions_csv(
|
| 57 |
-
out_dir: str,
|
| 58 |
-
split_name: str,
|
| 59 |
-
y_true: np.ndarray,
|
| 60 |
-
y_prob: np.ndarray,
|
| 61 |
-
threshold: float,
|
| 62 |
-
sequences: Optional[np.ndarray] = None,
|
| 63 |
-
):
|
| 64 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 65 |
-
df = pd.DataFrame({
|
| 66 |
-
"y_true": y_true.astype(int),
|
| 67 |
-
"y_prob": y_prob.astype(float),
|
| 68 |
-
"y_pred": (y_prob >= threshold).astype(int),
|
| 69 |
-
})
|
| 70 |
-
if sequences is not None:
|
| 71 |
-
df.insert(0, "sequence", sequences)
|
| 72 |
-
df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
|
| 76 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 77 |
-
|
| 78 |
-
# PR
|
| 79 |
-
precision, recall, _ = precision_recall_curve(y_true, y_prob)
|
| 80 |
-
plt.figure()
|
| 81 |
-
plt.plot(recall, precision)
|
| 82 |
-
plt.xlabel("Recall")
|
| 83 |
-
plt.ylabel("Precision")
|
| 84 |
-
plt.title("Precision-Recall Curve")
|
| 85 |
-
plt.tight_layout()
|
| 86 |
-
plt.savefig(os.path.join(out_dir, "pr_curve.png"))
|
| 87 |
-
plt.close()
|
| 88 |
-
|
| 89 |
-
# ROC
|
| 90 |
-
fpr, tpr, _ = roc_curve(y_true, y_prob)
|
| 91 |
-
plt.figure()
|
| 92 |
-
plt.plot(fpr, tpr)
|
| 93 |
-
plt.xlabel("False Positive Rate")
|
| 94 |
-
plt.ylabel("True Positive Rate")
|
| 95 |
-
plt.title("ROC Curve")
|
| 96 |
-
plt.tight_layout()
|
| 97 |
-
plt.savefig(os.path.join(out_dir, "roc_curve.png"))
|
| 98 |
-
plt.close()
|
| 99 |
-
|
| 100 |
-
# ======================== Shared OPTUNA training scheme =========================================
|
| 101 |
-
def best_f1_threshold(y_true, y_prob):
|
| 102 |
-
p, r, thr = precision_recall_curve(y_true, y_prob)
|
| 103 |
-
f1s = (2*p[:-1]*r[:-1])/(p[:-1]+r[:-1]+1e-12)
|
| 104 |
-
i = int(np.nanargmax(f1s))
|
| 105 |
-
return float(thr[i]), float(f1s[i])
|
| 106 |
-
|
| 107 |
-
@torch.no_grad()
|
| 108 |
-
def eval_probs(model, loader, device):
|
| 109 |
-
model.eval()
|
| 110 |
-
ys, ps = [], []
|
| 111 |
-
for X, M, y in loader:
|
| 112 |
-
X, M = X.to(device), M.to(device)
|
| 113 |
-
logits = model(X, M)
|
| 114 |
-
prob = torch.sigmoid(logits).detach().cpu().numpy()
|
| 115 |
-
ys.append(y.numpy())
|
| 116 |
-
ps.append(prob)
|
| 117 |
-
return np.concatenate(ys), np.concatenate(ps)
|
| 118 |
-
|
| 119 |
-
def train_one_epoch(model, loader, optim, criterion, device):
|
| 120 |
-
model.train()
|
| 121 |
-
for X, M, y in loader:
|
| 122 |
-
X, M, y = X.to(device), M.to(device), y.to(device)
|
| 123 |
-
optim.zero_grad(set_to_none=True)
|
| 124 |
-
logits = model(X, M)
|
| 125 |
-
loss = criterion(logits, y)
|
| 126 |
-
loss.backward()
|
| 127 |
-
optim.step()
|
| 128 |
-
|
| 129 |
-
# ======================== MLP =========================================
|
| 130 |
-
# Still need mean pooling along lengths
|
| 131 |
-
class MaskedMeanPool(nn.Module):
|
| 132 |
-
def forward(self, X, M): # X: (B,L,H), M: (B,L)
|
| 133 |
-
Mf = M.unsqueeze(-1).float()
|
| 134 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 135 |
-
return (X * Mf).sum(dim=1) / denom # (B,H)
|
| 136 |
-
|
| 137 |
-
class MLPClassifier(nn.Module):
|
| 138 |
-
def __init__(self, in_dim, hidden=512, dropout=0.1):
|
| 139 |
-
super().__init__()
|
| 140 |
-
self.pool = MaskedMeanPool()
|
| 141 |
-
self.net = nn.Sequential(
|
| 142 |
-
nn.Linear(in_dim, hidden),
|
| 143 |
-
nn.GELU(),
|
| 144 |
-
nn.Dropout(dropout),
|
| 145 |
-
nn.Linear(hidden, 1),
|
| 146 |
-
)
|
| 147 |
-
def forward(self, X, M):
|
| 148 |
-
z = self.pool(X, M)
|
| 149 |
-
return self.net(z).squeeze(-1) # logits
|
| 150 |
-
|
| 151 |
-
# ======================== CNN =========================================
|
| 152 |
-
# Treat 1280 dimensions as channels
|
| 153 |
-
class CNNClassifier(nn.Module):
|
| 154 |
-
def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
|
| 155 |
-
super().__init__()
|
| 156 |
-
blocks = []
|
| 157 |
-
ch = in_ch
|
| 158 |
-
for _ in range(layers):
|
| 159 |
-
blocks += [
|
| 160 |
-
nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
|
| 161 |
-
nn.GELU(),
|
| 162 |
-
nn.Dropout(dropout),
|
| 163 |
-
]
|
| 164 |
-
ch = c
|
| 165 |
-
self.conv = nn.Sequential(*blocks)
|
| 166 |
-
self.head = nn.Linear(c, 1)
|
| 167 |
-
|
| 168 |
-
def forward(self, X, M):
|
| 169 |
-
# X: (B,L,H) -> (B,H,L)
|
| 170 |
-
Xc = X.transpose(1, 2)
|
| 171 |
-
Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
|
| 172 |
-
|
| 173 |
-
# masked mean pool over L
|
| 174 |
-
Mf = M.unsqueeze(-1).float()
|
| 175 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 176 |
-
pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
|
| 177 |
-
return self.head(pooled).squeeze(-1)
|
| 178 |
-
|
| 179 |
-
# ========================== Transformer ====================================
|
| 180 |
-
class TransformerClassifier(nn.Module):
|
| 181 |
-
def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
|
| 182 |
-
super().__init__()
|
| 183 |
-
self.proj = nn.Linear(in_dim, d_model)
|
| 184 |
-
enc_layer = nn.TransformerEncoderLayer(
|
| 185 |
-
d_model=d_model, nhead=nhead, dim_feedforward=ff,
|
| 186 |
-
dropout=dropout, batch_first=True, activation="gelu"
|
| 187 |
-
)
|
| 188 |
-
self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
|
| 189 |
-
self.head = nn.Linear(d_model, 1)
|
| 190 |
-
|
| 191 |
-
def forward(self, X, M):
|
| 192 |
-
# src_key_padding_mask: True = pad positions
|
| 193 |
-
pad_mask = ~M
|
| 194 |
-
Z = self.proj(X) # (B,L,d)
|
| 195 |
-
Z = self.enc(Z, src_key_padding_mask=pad_mask) # (B,L,d)
|
| 196 |
-
|
| 197 |
-
Mf = M.unsqueeze(-1).float()
|
| 198 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 199 |
-
pooled = (Z * Mf).sum(dim=1) / denom
|
| 200 |
-
return self.head(pooled).squeeze(-1)
|
| 201 |
-
|
| 202 |
-
# ========================== OPTUNA ====================================
|
| 203 |
-
|
| 204 |
-
def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
|
| 205 |
-
# hyperparams shared
|
| 206 |
-
lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
|
| 207 |
-
wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
|
| 208 |
-
dropout = trial.suggest_float("dropout", 0.0, 0.5)
|
| 209 |
-
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
|
| 210 |
-
|
| 211 |
-
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
|
| 212 |
-
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
|
| 213 |
-
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
|
| 214 |
-
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
|
| 215 |
-
|
| 216 |
-
in_dim = infer_in_dim_from_unpooled_ds(train_ds)
|
| 217 |
-
|
| 218 |
-
if model_name == "mlp":
|
| 219 |
-
hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
|
| 220 |
-
model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
|
| 221 |
-
elif model_name == "cnn":
|
| 222 |
-
c = trial.suggest_categorical("channels", [128, 256, 512])
|
| 223 |
-
k = trial.suggest_categorical("kernel", [3, 5, 7])
|
| 224 |
-
layers = trial.suggest_int("layers", 1, 4)
|
| 225 |
-
model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
|
| 226 |
-
elif model_name == "transformer":
|
| 227 |
-
d = trial.suggest_categorical("d_model", [128, 256, 384])
|
| 228 |
-
nhead = trial.suggest_categorical("nhead", [4, 8])
|
| 229 |
-
layers = trial.suggest_int("layers", 1, 4)
|
| 230 |
-
ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
|
| 231 |
-
model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
|
| 232 |
-
else:
|
| 233 |
-
raise ValueError(model_name)
|
| 234 |
-
|
| 235 |
-
model = model.to(device)
|
| 236 |
-
|
| 237 |
-
# class imbalance handling
|
| 238 |
-
ytr = np.asarray(train_ds["label"], dtype=np.int64)
|
| 239 |
-
pos = ytr.sum()
|
| 240 |
-
neg = len(ytr) - pos
|
| 241 |
-
pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
|
| 242 |
-
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
| 243 |
-
|
| 244 |
-
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
|
| 245 |
-
|
| 246 |
-
best_f1 = -1.0
|
| 247 |
-
patience = 8
|
| 248 |
-
bad = 0
|
| 249 |
-
|
| 250 |
-
for epoch in range(1, 51):
|
| 251 |
-
train_one_epoch(model, train_loader, optim, criterion, device)
|
| 252 |
-
|
| 253 |
-
y_true, y_prob = eval_probs(model, val_loader, device)
|
| 254 |
-
auc = roc_auc_score(y_true, y_prob)
|
| 255 |
-
|
| 256 |
-
thr, f1 = best_f1_threshold(y_true, y_prob)
|
| 257 |
-
|
| 258 |
-
trial.set_user_attr("val_auc", float(auc))
|
| 259 |
-
trial.set_user_attr("val_f1", float(f1))
|
| 260 |
-
trial.set_user_attr("val_thr", float(thr))
|
| 261 |
-
|
| 262 |
-
# prune
|
| 263 |
-
trial.report(f1, epoch)
|
| 264 |
-
if trial.should_prune():
|
| 265 |
-
raise optuna.TrialPruned()
|
| 266 |
-
|
| 267 |
-
if f1 > best_f1 + 1e-4:
|
| 268 |
-
best_f1 = f1
|
| 269 |
-
bad = 0
|
| 270 |
-
else:
|
| 271 |
-
bad += 1
|
| 272 |
-
if bad >= patience:
|
| 273 |
-
break
|
| 274 |
-
|
| 275 |
-
return best_f1
|
| 276 |
-
|
| 277 |
-
def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str, n_trials: int = 50, device="cuda:0"):
|
| 278 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 279 |
-
|
| 280 |
-
train_ds, val_ds = load_split(dataset_path)
|
| 281 |
-
print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
|
| 282 |
-
|
| 283 |
-
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
|
| 284 |
-
study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device), n_trials=n_trials)
|
| 285 |
-
|
| 286 |
-
trials_df = study.trials_dataframe()
|
| 287 |
-
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
|
| 288 |
-
|
| 289 |
-
best = study.best_trial
|
| 290 |
-
best_params = dict(best.params)
|
| 291 |
-
best_f1_optuna = float(best.value)
|
| 292 |
-
best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
|
| 293 |
-
best_thr = float(best.user_attrs.get("val_thr", 0.5))
|
| 294 |
-
|
| 295 |
-
in_dim = infer_in_dim_from_unpooled_ds(train_ds)
|
| 296 |
-
|
| 297 |
-
# --- Refit best model ---
|
| 298 |
-
batch_size = int(best_params.get("batch_size", 32))
|
| 299 |
-
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
|
| 300 |
-
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
|
| 301 |
-
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
|
| 302 |
-
collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
|
| 303 |
-
|
| 304 |
-
# Rebuild
|
| 305 |
-
dropout = float(best_params.get("dropout", 0.1))
|
| 306 |
-
if model_name == "mlp":
|
| 307 |
-
model = MLPClassifier(
|
| 308 |
-
in_dim=in_dim,
|
| 309 |
-
hidden=int(best_params["hidden"]),
|
| 310 |
-
dropout=dropout,
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
elif model_name == "cnn":
|
| 314 |
-
model = CNNClassifier(
|
| 315 |
-
in_ch=in_dim,
|
| 316 |
-
c=int(best_params["channels"]),
|
| 317 |
-
k=int(best_params["kernel"]),
|
| 318 |
-
layers=int(best_params["layers"]),
|
| 319 |
-
dropout=dropout,
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
elif model_name == "transformer":
|
| 323 |
-
model = TransformerClassifier(
|
| 324 |
-
in_dim=in_dim,
|
| 325 |
-
d_model=int(best_params["d_model"]),
|
| 326 |
-
nhead=int(best_params["nhead"]),
|
| 327 |
-
layers=int(best_params["layers"]),
|
| 328 |
-
ff=int(best_params["ff"]),
|
| 329 |
-
dropout=dropout,
|
| 330 |
-
)
|
| 331 |
-
else:
|
| 332 |
-
raise ValueError(model_name)
|
| 333 |
-
|
| 334 |
-
model = model.to(device)
|
| 335 |
-
|
| 336 |
-
# loss + optimizer
|
| 337 |
-
ytr = np.asarray(train_ds["label"], dtype=np.int64)
|
| 338 |
-
pos = ytr.sum()
|
| 339 |
-
neg = len(ytr) - pos
|
| 340 |
-
pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
|
| 341 |
-
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
| 342 |
-
|
| 343 |
-
lr = float(best_params["lr"])
|
| 344 |
-
wd = float(best_params["weight_decay"])
|
| 345 |
-
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
|
| 346 |
-
|
| 347 |
-
# train longer with early stopping on AUC
|
| 348 |
-
best_f1_seen, bad, patience = -1.0, 0, 12
|
| 349 |
-
best_state = None
|
| 350 |
-
best_thr_seen = 0.5
|
| 351 |
-
best_auc_seen = -1.0
|
| 352 |
-
|
| 353 |
-
for epoch in range(1, 151):
|
| 354 |
-
train_one_epoch(model, train_loader, optim, criterion, device)
|
| 355 |
-
|
| 356 |
-
y_true, y_prob = eval_probs(model, val_loader, device)
|
| 357 |
-
auc = roc_auc_score(y_true, y_prob)
|
| 358 |
-
thr, f1 = best_f1_threshold(y_true, y_prob)
|
| 359 |
-
|
| 360 |
-
if f1 > best_f1_seen + 1e-4:
|
| 361 |
-
best_f1_seen = f1
|
| 362 |
-
best_thr_seen = thr
|
| 363 |
-
best_auc_seen = auc
|
| 364 |
-
bad = 0
|
| 365 |
-
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
| 366 |
-
else:
|
| 367 |
-
bad += 1
|
| 368 |
-
if bad >= patience:
|
| 369 |
-
break
|
| 370 |
-
|
| 371 |
-
if best_state is not None:
|
| 372 |
-
model.load_state_dict(best_state)
|
| 373 |
-
|
| 374 |
-
# final preds + threshold picked on val
|
| 375 |
-
y_true_val, y_prob_val = eval_probs(model, val_loader, device)
|
| 376 |
-
best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
|
| 377 |
-
|
| 378 |
-
# save model
|
| 379 |
-
model_path = os.path.join(out_dir, "best_model.pt")
|
| 380 |
-
torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)
|
| 381 |
-
|
| 382 |
-
# train preds
|
| 383 |
-
y_true_tr, y_prob_tr = eval_probs(model, DataLoader(train_ds, batch_size=64, shuffle=False,
|
| 384 |
-
collate_fn=collate_unpooled, num_workers=4, pin_memory=True), device)
|
| 385 |
-
|
| 386 |
-
save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
|
| 387 |
-
sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
|
| 388 |
-
save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
|
| 389 |
-
sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)
|
| 390 |
-
|
| 391 |
-
plot_curves(out_dir, y_true_val, y_prob_val)
|
| 392 |
-
|
| 393 |
-
summary = [
|
| 394 |
-
"=" * 72,
|
| 395 |
-
f"MODEL: {model_name}",
|
| 396 |
-
|
| 397 |
-
# Optuna results (objective = F1)
|
| 398 |
-
f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
|
| 399 |
-
f"Best Optuna AUC (val, recorded): {best_auc_optuna:.4f}",
|
| 400 |
-
f"Best Optuna threshold (val): {best_thr:.4f}",
|
| 401 |
-
|
| 402 |
-
# Refit results
|
| 403 |
-
f"Refit best AUC (val): {best_auc_seen:.4f}",
|
| 404 |
-
f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",
|
| 405 |
-
|
| 406 |
-
"Best params:",
|
| 407 |
-
json.dumps(best_params, indent=2),
|
| 408 |
-
f"Saved model: {model_path}",
|
| 409 |
-
"=" * 72,
|
| 410 |
-
]
|
| 411 |
-
|
| 412 |
-
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
|
| 413 |
-
f.write("\n".join(summary))
|
| 414 |
-
print("\n".join(summary))
|
| 415 |
-
|
| 416 |
-
if __name__ == "__main__":
|
| 417 |
-
import argparse
|
| 418 |
-
parser = argparse.ArgumentParser()
|
| 419 |
-
parser.add_argument("--dataset_path", type=str, required=True)
|
| 420 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 421 |
-
parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
|
| 422 |
-
parser.add_argument("--n_trials", type=int, default=50)
|
| 423 |
-
args = parser.parse_args()
|
| 424 |
-
|
| 425 |
-
if args.model in ["mlp", "cnn", "transformer"]:
|
| 426 |
-
run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device="cuda:0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py
DELETED
|
@@ -1,420 +0,0 @@
|
|
| 1 |
-
import os, json, time
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
from torch.utils.data import DataLoader
|
| 9 |
-
from datasets import load_from_disk, DatasetDict
|
| 10 |
-
import optuna
|
| 11 |
-
from dataclasses import dataclass
|
| 12 |
-
from typing import Dict, Any, Tuple, Optional
|
| 13 |
-
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 14 |
-
from scipy.stats import spearmanr
|
| 15 |
-
from torch.cuda.amp import autocast
|
| 16 |
-
from torch.cuda.amp import autocast, GradScaler
|
| 17 |
-
scaler = GradScaler(enabled=torch.cuda.is_available())
|
| 18 |
-
from lightning.pytorch import seed_everything
|
| 19 |
-
seed_everything(1986)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def load_split(dataset_path):
|
| 23 |
-
ds = load_from_disk(dataset_path)
|
| 24 |
-
if isinstance(ds, DatasetDict):
|
| 25 |
-
return ds["train"], ds["val"]
|
| 26 |
-
raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
|
| 27 |
-
|
| 28 |
-
def collate_unpooled_reg(batch):
|
| 29 |
-
lengths = [int(x["length"]) for x in batch]
|
| 30 |
-
Lmax = max(lengths)
|
| 31 |
-
H = len(batch[0]["embedding"][0])
|
| 32 |
-
|
| 33 |
-
X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
|
| 34 |
-
M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
|
| 35 |
-
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
|
| 36 |
-
|
| 37 |
-
for i, x in enumerate(batch):
|
| 38 |
-
emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L,H)
|
| 39 |
-
L = emb.shape[0]
|
| 40 |
-
X[i, :L] = emb
|
| 41 |
-
if "attention_mask" in x:
|
| 42 |
-
m = torch.tensor(x["attention_mask"], dtype=torch.bool)
|
| 43 |
-
M[i, :L] = m[:L]
|
| 44 |
-
else:
|
| 45 |
-
M[i, :L] = True
|
| 46 |
-
return X, M, y
|
| 47 |
-
|
| 48 |
-
def infer_in_dim(ds) -> int:
|
| 49 |
-
ex = ds[0]
|
| 50 |
-
return int(len(ex["embedding"][0]))
|
| 51 |
-
|
| 52 |
-
# ============================
|
| 53 |
-
# Metrics
|
| 54 |
-
# ============================
|
| 55 |
-
def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
| 56 |
-
rho = spearmanr(y_true, y_pred).correlation
|
| 57 |
-
if rho is None or np.isnan(rho):
|
| 58 |
-
return 0.0
|
| 59 |
-
return float(rho)
|
| 60 |
-
|
| 61 |
-
def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
|
| 62 |
-
# ---- RMSE ----
|
| 63 |
-
try:
|
| 64 |
-
from sklearn.metrics import root_mean_squared_error
|
| 65 |
-
rmse = root_mean_squared_error(y_true, y_pred)
|
| 66 |
-
except Exception:
|
| 67 |
-
mse = mean_squared_error(y_true, y_pred)
|
| 68 |
-
rmse = float(np.sqrt(mse))
|
| 69 |
-
|
| 70 |
-
mae = float(mean_absolute_error(y_true, y_pred))
|
| 71 |
-
r2 = float(r2_score(y_true, y_pred))
|
| 72 |
-
rho = float(safe_spearmanr(y_true, y_pred))
|
| 73 |
-
return {"rmse": float(rmse), "mae": mae, "r2": r2, "spearman_rho": rho}
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
# ============================
|
| 77 |
-
# Models
|
| 78 |
-
# ============================
|
| 79 |
-
class MaskedMeanPool(nn.Module):
|
| 80 |
-
def forward(self, X, M):
|
| 81 |
-
Mf = M.unsqueeze(-1).float()
|
| 82 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 83 |
-
return (X * Mf).sum(dim=1) / denom
|
| 84 |
-
|
| 85 |
-
class MLPRegressor(nn.Module):
|
| 86 |
-
def __init__(self, in_dim, hidden=512, dropout=0.1):
|
| 87 |
-
super().__init__()
|
| 88 |
-
self.pool = MaskedMeanPool()
|
| 89 |
-
self.net = nn.Sequential(
|
| 90 |
-
nn.Linear(in_dim, hidden),
|
| 91 |
-
nn.GELU(),
|
| 92 |
-
nn.Dropout(dropout),
|
| 93 |
-
nn.Linear(hidden, 1),
|
| 94 |
-
)
|
| 95 |
-
def forward(self, X, M):
|
| 96 |
-
z = self.pool(X, M)
|
| 97 |
-
return self.net(z).squeeze(-1) # y_pred
|
| 98 |
-
|
| 99 |
-
class CNNRegressor(nn.Module):
|
| 100 |
-
def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
|
| 101 |
-
super().__init__()
|
| 102 |
-
blocks = []
|
| 103 |
-
ch = in_ch
|
| 104 |
-
for _ in range(layers):
|
| 105 |
-
blocks += [
|
| 106 |
-
nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
|
| 107 |
-
nn.GELU(),
|
| 108 |
-
nn.Dropout(dropout),
|
| 109 |
-
]
|
| 110 |
-
ch = c
|
| 111 |
-
self.conv = nn.Sequential(*blocks)
|
| 112 |
-
self.head = nn.Linear(c, 1)
|
| 113 |
-
|
| 114 |
-
def forward(self, X, M):
|
| 115 |
-
Xc = X.transpose(1, 2) # (B,H,L)
|
| 116 |
-
Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
|
| 117 |
-
Mf = M.unsqueeze(-1).float()
|
| 118 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 119 |
-
pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
|
| 120 |
-
return self.head(pooled).squeeze(-1)
|
| 121 |
-
|
| 122 |
-
class TransformerRegressor(nn.Module):
|
| 123 |
-
def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
|
| 124 |
-
super().__init__()
|
| 125 |
-
self.proj = nn.Linear(in_dim, d_model)
|
| 126 |
-
enc_layer = nn.TransformerEncoderLayer(
|
| 127 |
-
d_model=d_model, nhead=nhead, dim_feedforward=ff,
|
| 128 |
-
dropout=dropout, batch_first=True, activation="gelu"
|
| 129 |
-
)
|
| 130 |
-
self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
|
| 131 |
-
self.head = nn.Linear(d_model, 1)
|
| 132 |
-
|
| 133 |
-
def forward(self, X, M):
|
| 134 |
-
pad_mask = ~M
|
| 135 |
-
Z = self.proj(X)
|
| 136 |
-
Z = self.enc(Z, src_key_padding_mask=pad_mask)
|
| 137 |
-
Mf = M.unsqueeze(-1).float()
|
| 138 |
-
denom = Mf.sum(dim=1).clamp(min=1.0)
|
| 139 |
-
pooled = (Z * Mf).sum(dim=1) / denom
|
| 140 |
-
return self.head(pooled).squeeze(-1)
|
| 141 |
-
|
| 142 |
-
# ============================
|
| 143 |
-
# Train / eval
|
| 144 |
-
# ============================
|
| 145 |
-
@torch.no_grad()
|
| 146 |
-
def eval_preds(model, loader, device):
|
| 147 |
-
model.eval()
|
| 148 |
-
ys, ps = [], []
|
| 149 |
-
for X, M, y in loader:
|
| 150 |
-
X, M = X.to(device), M.to(device)
|
| 151 |
-
pred = model(X, M).detach().cpu().numpy()
|
| 152 |
-
ys.append(y.numpy())
|
| 153 |
-
ps.append(pred)
|
| 154 |
-
return np.concatenate(ys), np.concatenate(ps)
|
| 155 |
-
|
| 156 |
-
def train_one_epoch_reg(model, loader, optim, criterion, device):
|
| 157 |
-
model.train()
|
| 158 |
-
for X, M, y in loader:
|
| 159 |
-
X, M, y = X.to(device), M.to(device), y.to(device)
|
| 160 |
-
optim.zero_grad(set_to_none=True)
|
| 161 |
-
with autocast(enabled=torch.cuda.is_available()):
|
| 162 |
-
pred = model(X, M)
|
| 163 |
-
loss = criterion(pred, y)
|
| 164 |
-
scaler.scale(loss).backward()
|
| 165 |
-
scaler.step(optim)
|
| 166 |
-
scaler.update()
|
| 167 |
-
|
| 168 |
-
# ============================
|
| 169 |
-
# Saving + plots
|
| 170 |
-
# ============================
|
| 171 |
-
def save_predictions_csv(out_dir, split_name, y_true, y_pred, sequences=None):
|
| 172 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 173 |
-
df = pd.DataFrame({
|
| 174 |
-
"y_true": y_true.astype(float),
|
| 175 |
-
"y_pred": y_pred.astype(float),
|
| 176 |
-
"residual": (y_true - y_pred).astype(float),
|
| 177 |
-
})
|
| 178 |
-
if sequences is not None:
|
| 179 |
-
df.insert(0, "sequence", sequences)
|
| 180 |
-
df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
|
| 181 |
-
|
| 182 |
-
def plot_regression_diagnostics(out_dir, y_true, y_pred):
|
| 183 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 184 |
-
|
| 185 |
-
plt.figure()
|
| 186 |
-
plt.scatter(y_true, y_pred, s=8, alpha=0.5)
|
| 187 |
-
plt.xlabel("y_true"); plt.ylabel("y_pred")
|
| 188 |
-
plt.title("Predicted vs True")
|
| 189 |
-
plt.tight_layout()
|
| 190 |
-
plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
|
| 191 |
-
plt.close()
|
| 192 |
-
|
| 193 |
-
resid = y_true - y_pred
|
| 194 |
-
plt.figure()
|
| 195 |
-
plt.hist(resid, bins=50)
|
| 196 |
-
plt.xlabel("residual (y_true - y_pred)"); plt.ylabel("count")
|
| 197 |
-
plt.title("Residual Histogram")
|
| 198 |
-
plt.tight_layout()
|
| 199 |
-
plt.savefig(os.path.join(out_dir, "residual_hist.png"))
|
| 200 |
-
plt.close()
|
| 201 |
-
|
| 202 |
-
plt.figure()
|
| 203 |
-
plt.scatter(y_pred, resid, s=8, alpha=0.5)
|
| 204 |
-
plt.xlabel("y_pred"); plt.ylabel("residual")
|
| 205 |
-
plt.title("Residuals vs Prediction")
|
| 206 |
-
plt.tight_layout()
|
| 207 |
-
plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
|
| 208 |
-
plt.close()
|
| 209 |
-
|
| 210 |
-
# ============================
|
| 211 |
-
# Optuna objective
|
| 212 |
-
# ============================
|
| 213 |
-
def score_from_metrics(metrics: Dict[str, float], objective: str) -> float:
|
| 214 |
-
if objective == "spearman":
|
| 215 |
-
return metrics["spearman_rho"]
|
| 216 |
-
if objective == "r2":
|
| 217 |
-
return metrics["r2"]
|
| 218 |
-
if objective == "neg_rmse":
|
| 219 |
-
return -metrics["rmse"]
|
| 220 |
-
raise ValueError(f"Unknown objective={objective}")
|
| 221 |
-
|
| 222 |
-
def objective_nn_reg(trial, model_name, train_ds, val_ds, device="cuda:0", objective="spearman"):
|
| 223 |
-
lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
|
| 224 |
-
wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
|
| 225 |
-
dropout = trial.suggest_float("dropout", 0.0, 0.5)
|
| 226 |
-
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
|
| 227 |
-
|
| 228 |
-
in_dim = infer_in_dim(train_ds)
|
| 229 |
-
|
| 230 |
-
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
|
| 231 |
-
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
|
| 232 |
-
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
|
| 233 |
-
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
|
| 234 |
-
|
| 235 |
-
if model_name == "mlp":
|
| 236 |
-
hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
|
| 237 |
-
model = MLPRegressor(in_dim=in_dim, hidden=hidden, dropout=dropout)
|
| 238 |
-
elif model_name == "cnn":
|
| 239 |
-
c = trial.suggest_categorical("channels", [128, 256, 512])
|
| 240 |
-
k = trial.suggest_categorical("kernel", [3, 5, 7])
|
| 241 |
-
layers = trial.suggest_int("layers", 1, 4)
|
| 242 |
-
model = CNNRegressor(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
|
| 243 |
-
elif model_name == "transformer":
|
| 244 |
-
d = trial.suggest_categorical("d_model", [128, 256, 384])
|
| 245 |
-
nhead = trial.suggest_categorical("nhead", [4, 8])
|
| 246 |
-
layers = trial.suggest_int("layers", 1, 4)
|
| 247 |
-
ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
|
| 248 |
-
model = TransformerRegressor(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
|
| 249 |
-
else:
|
| 250 |
-
raise ValueError(model_name)
|
| 251 |
-
|
| 252 |
-
model = model.to(device)
|
| 253 |
-
|
| 254 |
-
loss_name = trial.suggest_categorical("loss", ["mse", "huber"])
|
| 255 |
-
if loss_name == "mse":
|
| 256 |
-
criterion = nn.MSELoss()
|
| 257 |
-
else:
|
| 258 |
-
delta = trial.suggest_float("huber_delta", 0.5, 5.0, log=True)
|
| 259 |
-
criterion = nn.HuberLoss(delta=delta)
|
| 260 |
-
|
| 261 |
-
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
|
| 262 |
-
|
| 263 |
-
best_score = -1e18
|
| 264 |
-
patience = 10
|
| 265 |
-
bad = 0
|
| 266 |
-
|
| 267 |
-
for epoch in range(1, 61):
|
| 268 |
-
train_one_epoch_reg(model, train_loader, optim, criterion, device)
|
| 269 |
-
|
| 270 |
-
y_true, y_pred = eval_preds(model, val_loader, device)
|
| 271 |
-
metrics = eval_regression(y_true, y_pred)
|
| 272 |
-
score = score_from_metrics(metrics, objective)
|
| 273 |
-
|
| 274 |
-
# log attrs
|
| 275 |
-
for k, v in metrics.items():
|
| 276 |
-
trial.set_user_attr(f"val_{k}", float(v))
|
| 277 |
-
|
| 278 |
-
trial.report(score, epoch)
|
| 279 |
-
if trial.should_prune():
|
| 280 |
-
raise optuna.TrialPruned()
|
| 281 |
-
|
| 282 |
-
if score > best_score + 1e-6:
|
| 283 |
-
best_score = score
|
| 284 |
-
bad = 0
|
| 285 |
-
else:
|
| 286 |
-
bad += 1
|
| 287 |
-
if bad >= patience:
|
| 288 |
-
break
|
| 289 |
-
|
| 290 |
-
return float(best_score)
|
| 291 |
-
|
| 292 |
-
# ============================
|
| 293 |
-
# Main runner
|
| 294 |
-
# ============================
|
| 295 |
-
def run_optuna_and_refit_nn_reg(dataset_path, out_dir, model_name, n_trials=80, device="cuda:0",
|
| 296 |
-
objective="spearman"):
|
| 297 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 298 |
-
|
| 299 |
-
train_ds, val_ds = load_split(dataset_path)
|
| 300 |
-
print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
|
| 301 |
-
|
| 302 |
-
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
|
| 303 |
-
study.optimize(lambda t: objective_nn_reg(t, model_name, train_ds, val_ds, device=device, objective=objective),
|
| 304 |
-
n_trials=n_trials)
|
| 305 |
-
|
| 306 |
-
trials_df = study.trials_dataframe()
|
| 307 |
-
trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
|
| 308 |
-
|
| 309 |
-
best = study.best_trial
|
| 310 |
-
best_params = dict(best.params)
|
| 311 |
-
|
| 312 |
-
# rebuild model from best params
|
| 313 |
-
in_dim = infer_in_dim(train_ds)
|
| 314 |
-
dropout = float(best_params.get("dropout", 0.1))
|
| 315 |
-
if model_name == "mlp":
|
| 316 |
-
model = MLPRegressor(in_dim=in_dim, hidden=int(best_params["hidden"]), dropout=dropout)
|
| 317 |
-
elif model_name == "cnn":
|
| 318 |
-
model = CNNRegressor(in_ch=in_dim, c=int(best_params["channels"]),
|
| 319 |
-
k=int(best_params["kernel"]), layers=int(best_params["layers"]),
|
| 320 |
-
dropout=dropout)
|
| 321 |
-
elif model_name == "transformer":
|
| 322 |
-
model = TransformerRegressor(in_dim=in_dim, d_model=int(best_params["d_model"]),
|
| 323 |
-
nhead=int(best_params["nhead"]), layers=int(best_params["layers"]),
|
| 324 |
-
ff=int(best_params["ff"]), dropout=dropout)
|
| 325 |
-
else:
|
| 326 |
-
raise ValueError(model_name)
|
| 327 |
-
|
| 328 |
-
model = model.to(device)
|
| 329 |
-
|
| 330 |
-
batch_size = int(best_params.get("batch_size", 32))
|
| 331 |
-
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
|
| 332 |
-
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
|
| 333 |
-
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
|
| 334 |
-
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
|
| 335 |
-
|
| 336 |
-
# loss
|
| 337 |
-
if best_params.get("loss", "mse") == "mse":
|
| 338 |
-
criterion = nn.MSELoss()
|
| 339 |
-
else:
|
| 340 |
-
criterion = nn.HuberLoss(delta=float(best_params["huber_delta"]))
|
| 341 |
-
|
| 342 |
-
optim = torch.optim.AdamW(model.parameters(), lr=float(best_params["lr"]),
|
| 343 |
-
weight_decay=float(best_params["weight_decay"]))
|
| 344 |
-
|
| 345 |
-
# refit longer with early stopping on the SAME objective
|
| 346 |
-
best_score, bad, patience = -1e18, 0, 15
|
| 347 |
-
best_state = None
|
| 348 |
-
|
| 349 |
-
for epoch in range(1, 201):
|
| 350 |
-
train_one_epoch_reg(model, train_loader, optim, criterion, device)
|
| 351 |
-
|
| 352 |
-
y_true, y_pred = eval_preds(model, val_loader, device)
|
| 353 |
-
metrics = eval_regression(y_true, y_pred)
|
| 354 |
-
score = score_from_metrics(metrics, objective)
|
| 355 |
-
|
| 356 |
-
if score > best_score + 1e-6:
|
| 357 |
-
best_score = score
|
| 358 |
-
bad = 0
|
| 359 |
-
best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
|
| 360 |
-
best_metrics = metrics
|
| 361 |
-
else:
|
| 362 |
-
bad += 1
|
| 363 |
-
if bad >= patience:
|
| 364 |
-
break
|
| 365 |
-
|
| 366 |
-
if best_state is not None:
|
| 367 |
-
model.load_state_dict(best_state)
|
| 368 |
-
|
| 369 |
-
# preds
|
| 370 |
-
y_true_tr, y_pred_tr = eval_preds(model, DataLoader(train_ds, batch_size=64, shuffle=False,
|
| 371 |
-
collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True), device)
|
| 372 |
-
y_true_va, y_pred_va = eval_preds(model, val_loader, device)
|
| 373 |
-
|
| 374 |
-
seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
|
| 375 |
-
seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None
|
| 376 |
-
save_predictions_csv(out_dir, "train", y_true_tr, y_pred_tr, seq_train)
|
| 377 |
-
save_predictions_csv(out_dir, "val", y_true_va, y_pred_va, seq_val)
|
| 378 |
-
plot_regression_diagnostics(out_dir, y_true_va, y_pred_va)
|
| 379 |
-
|
| 380 |
-
# save model
|
| 381 |
-
model_path = os.path.join(out_dir, "best_model.pt")
|
| 382 |
-
torch.save({"state_dict": model.state_dict(), "best_params": best_params, "in_dim": in_dim}, model_path)
|
| 383 |
-
|
| 384 |
-
summary = [
|
| 385 |
-
"=" * 72,
|
| 386 |
-
f"MODEL: {model_name}",
|
| 387 |
-
f"OPTUNA objective: {objective} (direction=maximize)",
|
| 388 |
-
f"Best trial: {best.number}",
|
| 389 |
-
"Best val metrics:",
|
| 390 |
-
json.dumps({k: float(v) for k, v in best_metrics.items()}, indent=2),
|
| 391 |
-
f"Saved model: {model_path}",
|
| 392 |
-
"Best params:",
|
| 393 |
-
json.dumps(best_params, indent=2),
|
| 394 |
-
"=" * 72,
|
| 395 |
-
]
|
| 396 |
-
with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
|
| 397 |
-
f.write("\n".join(summary))
|
| 398 |
-
print("\n".join(summary))
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
if __name__ == "__main__":
|
| 402 |
-
import argparse
|
| 403 |
-
parser = argparse.ArgumentParser()
|
| 404 |
-
parser.add_argument("--dataset_path", type=str, required=True)
|
| 405 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 406 |
-
parser.add_argument("--model", type=str, choices=["mlp","cnn","transformer"], required=True)
|
| 407 |
-
parser.add_argument("--n_trials", type=int, default=80)
|
| 408 |
-
parser.add_argument("--objective", type=str, default="spearman",
|
| 409 |
-
choices=["spearman","neg_rmse","r2"])
|
| 410 |
-
parser.add_argument("--device", type=str, default="cuda:0")
|
| 411 |
-
args = parser.parse_args()
|
| 412 |
-
|
| 413 |
-
run_optuna_and_refit_nn_reg(
|
| 414 |
-
dataset_path=args.dataset_path,
|
| 415 |
-
out_dir=args.out_dir,
|
| 416 |
-
model_name=args.model,
|
| 417 |
-
n_trials=args.n_trials,
|
| 418 |
-
device=args.device,
|
| 419 |
-
objective=args.objective,
|
| 420 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/val_smiles_pooled.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5410a45a7b65def6cfb94c167b07537abd33b5aac4ecdffe162b7ce4e9bc3d19
|
| 3 |
-
size 36525
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/val_smiles_unpooled.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:cdf71fbb3e7b3b8e8dbfe4ed45b32a2da0049df851f09ee32564825f626cb86c
|
| 3 |
-
size 37187
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/val_wt_pooled.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b194e7b2b97258320323021b3ffe6143133070212a0215ade22fa91b87a3a861
|
| 3 |
-
size 33224
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/val_wt_unpooled.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:051325790047e749fbf1daf7bf25a08178297b0c37acaf9439816d09f2b6c1e3
|
| 3 |
-
size 33826
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:12f956a7bf04ed602c11fd275377afa73f3f0af1982dbe06c607d8ada304b01c
|
| 3 |
-
size 21617397
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3d7ae3d2190b034352a65bda1bce86aa5a96ce3daf74cf10a166f8d9e9af51f0
|
| 3 |
-
size 181183221
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b685b92714882d618b42b582000574d83c3be2fbecbec5e0de6b5476948b96c5
|
| 3 |
-
size 40700
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/cnn_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:0a8a57d44cac3fcd701b550a4eaf9e29910540bfb7580a9b8ee997a7227375d2
|
| 3 |
-
size 13748
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/cnn_unpooled_peptideclm/best_model.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:9eaaafffe02663f7cfe67fde25cdebd7d4315af69b393b433c4291b700bc5063
|
| 3 |
-
size 16525563
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/cnn_unpooled_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:879d8c6f47b02c1ddd86fbe3982d8b0134167521f9f71d2450957dc3bbbb6bd1
|
| 3 |
-
size 13705
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/enet_gpu_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:fd5c84e9788e2db949c6be785f8539178e72fc6fa6bc703daf9574ad0622e0f1
|
| 3 |
-
size 13649
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/enet_peptideclm/smiles_halflife_best_enet.joblib
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:0eb93bcb27436e80bce2a6433cbd7502b90de4962731250972eef08a5d96ce69
|
| 3 |
-
size 22698
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/mlp_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:14b1010a2c0b5d065fa9b82636085806ec7f6f6091c7c2355c6c4717d07fa79b
|
| 3 |
-
size 13724
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/mlp_unpooled_peptideclm/best_model.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:ef78a5e5c555768f91dc646652a39e367287e851a14e2cf85e4006c9355a8368
|
| 3 |
-
size 2368455
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/mlp_unpooled_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2a037cf36528fc8e04a375c0443577830733636fca83fa9ce44e457e28e4f771
|
| 3 |
-
size 13745
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/svr_gpu_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a0dd42537c1a5589b78451de8645bfc089b8f7f5839808222bb1e9e033d78c66
|
| 3 |
-
size 13746
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/svr_peptideclm/smiles_halflife_best_svr.joblib
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5579f1407fc8dfd1e42b4ea2a6b619dea8b0eff4ce9a4c0869890cbd1b321851
|
| 3 |
-
size 1530479
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:6c6dcaff0d542c3a6bbaf499aba56e5f440c50aa18b55271ee85feb43851fe92
|
| 3 |
-
size 13694
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_unpooled_peptideclm/best_model.pt
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:9a19240f8067e68a6e2eaff139f90b6d2f37ab5431197c5496894937c01918f7
|
| 3 |
-
size 931353
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_unpooled_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7b1111c0b57288092ab97b940598b2d3b44c2ff5299fe55a50a2312d8c2e45af
|
| 3 |
-
size 13683
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_wt_log/oof_pred_vs_true.png
DELETED
|
Binary file (16.9 kB)
|
|
|
training_classifiers/half_life/transformer_wt_log/oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8ec7b8dee908ef43ba7633a887a988834e24f952711b906472e1b41b833de714
|
| 3 |
-
size 14100
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_wt_log/oof_residual_hist.png
DELETED
|
Binary file (15.3 kB)
|
|
|
training_classifiers/half_life/transformer_wt_log/oof_residual_vs_pred.png
DELETED
|
Binary file (19.6 kB)
|
|
|
training_classifiers/half_life/transformer_wt_log/optimization_summary.txt
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
========================================================================
|
| 2 |
-
MODEL: transformer
|
| 3 |
-
Dataset: /scratch/pranamlab/tong/data/halflife/halflife_embedding_unpooled
|
| 4 |
-
Target column: log_label
|
| 5 |
-
CV folds: 5
|
| 6 |
-
Optuna objective: spearman (direction=maximize)
|
| 7 |
-
Best trial: 45
|
| 8 |
-
OOF metrics:
|
| 9 |
-
{
|
| 10 |
-
"rmse": 1.0389505624771118,
|
| 11 |
-
"mae": 0.722099244594574,
|
| 12 |
-
"r2": 0.30950748920440674,
|
| 13 |
-
"spearman_rho": 0.5818272477094295
|
| 14 |
-
}
|
| 15 |
-
OOF score (spearman): 0.581827
|
| 16 |
-
Best params:
|
| 17 |
-
{
|
| 18 |
-
"lr": 0.0003603824115240561,
|
| 19 |
-
"weight_decay": 2.9442493502916885e-09,
|
| 20 |
-
"dropout": 0.3851371373367485,
|
| 21 |
-
"batch_size": 16
|
| 22 |
-
}
|
| 23 |
-
Final refit epochs (all data): 15
|
| 24 |
-
Saved final model: /scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer/final_model.pt
|
| 25 |
-
Benchmark (final model on full data):
|
| 26 |
-
{
|
| 27 |
-
"n_samples": 130,
|
| 28 |
-
"wall_time_s": 1.9577592574059963,
|
| 29 |
-
"throughput_samples_per_s": 66.40244427818372,
|
| 30 |
-
"gpu_ms_per_sample": 0.28296443315652703,
|
| 31 |
-
"gpu_peak_mem_MB": 77.5693359375
|
| 32 |
-
}
|
| 33 |
-
========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_wt_log/study_trials.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5726f6b8541c7ca85eda9f0e526db0cb10156eadb2c440fd7a66f7a7d1209175
|
| 3 |
-
size 10154
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_wt_raw/oof_pred_vs_true.png
DELETED
|
Binary file (17.4 kB)
|
|
|
training_classifiers/half_life/transformer_wt_raw/oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c3df70b094757f34fa380a28727877694fcb1ec367bbbef28c63b257ecec74e6
|
| 3 |
-
size 13516
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_wt_raw/oof_residual_hist.png
DELETED
|
Binary file (14.6 kB)
|
|
|
training_classifiers/half_life/transformer_wt_raw/oof_residual_vs_pred.png
DELETED
|
Binary file (18.9 kB)
|
|
|
training_classifiers/half_life/transformer_wt_raw/optimization_summary.txt
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
========================================================================
|
| 2 |
-
MODEL: transformer
|
| 3 |
-
Dataset: /scratch/pranamlab/tong/data/halflife/halflife_embedding_unpooled
|
| 4 |
-
Target column: label
|
| 5 |
-
CV folds: 5
|
| 6 |
-
Optuna objective: spearman (direction=maximize)
|
| 7 |
-
Best trial: 22
|
| 8 |
-
OOF metrics:
|
| 9 |
-
{
|
| 10 |
-
"rmse": 45.00321578979492,
|
| 11 |
-
"mae": 11.352466583251953,
|
| 12 |
-
"r2": 0.02070075273513794,
|
| 13 |
-
"spearman_rho": 0.3759734508605516
|
| 14 |
-
}
|
| 15 |
-
OOF score (spearman): 0.375973
|
| 16 |
-
Best params:
|
| 17 |
-
{
|
| 18 |
-
"lr": 0.00019977882554167927,
|
| 19 |
-
"weight_decay": 1.102955470301081e-07,
|
| 20 |
-
"dropout": 1.2707176359392082e-05,
|
| 21 |
-
"batch_size": 16
|
| 22 |
-
}
|
| 23 |
-
Final refit epochs (all data): 14
|
| 24 |
-
Saved final model: /scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer_raw/final_model.pt
|
| 25 |
-
Benchmark (final model on full data):
|
| 26 |
-
{
|
| 27 |
-
"n_samples": 130,
|
| 28 |
-
"wall_time_s": 1.6299039730802178,
|
| 29 |
-
"throughput_samples_per_s": 79.7593000244818,
|
| 30 |
-
"gpu_ms_per_sample": 0.23774326214423547,
|
| 31 |
-
"gpu_peak_mem_MB": 77.5693359375
|
| 32 |
-
}
|
| 33 |
-
========================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/transformer_wt_raw/study_trials.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7c38352854ad4142c02a4bcb33caee9fe8fa22fca86dcb8c17c05c24f3fa5bca
|
| 3 |
-
size 10895
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/xgb_smiles/cv_oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e2df43f71aad2cf791b49daa0b3353f524d5a3f3e132fecf1251e96242639ca5
|
| 3 |
-
size 13675
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/xgb_wt_log/oof_pred_vs_true.png
DELETED
|
Binary file (16.5 kB)
|
|
|
training_classifiers/half_life/xgb_wt_log/oof_predictions.csv
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2293b5752ef6bdc7b1ec8ae2f56e11ccbf32aee024d777c86c1e63f390fa89cf
|
| 3 |
-
size 14805
|
|
|
|
|
|
|
|
|
|
|
|
training_classifiers/half_life/xgb_wt_log/oof_residual_hist.png
DELETED
|
Binary file (15.1 kB)
|
|
|
training_classifiers/half_life/xgb_wt_log/oof_residual_vs_pred.png
DELETED
|
Binary file (19.1 kB)
|
|
|
training_classifiers/half_life/xgb_wt_log/optimization_summary.txt
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"model": "xgb_reg",
|
| 3 |
-
"dataset_path": "/scratch/pranamlab/tong/data/halflife/halflife_embedding",
|
| 4 |
-
"target_col": "log_label",
|
| 5 |
-
"n_folds": 5,
|
| 6 |
-
"best_trial_number": 20,
|
| 7 |
-
"best_objective_cv_spearman": 0.5879000126060311,
|
| 8 |
-
"oof_metrics": {
|
| 9 |
-
"rmse": 1.0810768604278564,
|
| 10 |
-
"mae": 0.7866008281707764,
|
| 11 |
-
"r2": 0.2524225115776062,
|
| 12 |
-
"spearman_rho": 0.557870619380726
|
| 13 |
-
},
|
| 14 |
-
"model_path": "/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb_log/best_model.json",
|
| 15 |
-
"best_params": {
|
| 16 |
-
"lambda": 0.0006291983667746282,
|
| 17 |
-
"alpha": 0.0820082035401697,
|
| 18 |
-
"gamma": 1.2243543209914751,
|
| 19 |
-
"max_depth": 3,
|
| 20 |
-
"min_child_weight": 1.7773959178614585,
|
| 21 |
-
"subsample": 0.568291807635477,
|
| 22 |
-
"colsample_bytree": 0.8597778117881122,
|
| 23 |
-
"learning_rate": 0.0512590763008084,
|
| 24 |
-
"num_boost_round": 1728,
|
| 25 |
-
"early_stopping_rounds": 121
|
| 26 |
-
}
|
| 27 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|