Joblib
ynuozhang commited on
Commit
3e669de
·
1 Parent(s): b90bb8d

clean up legacy _smiles folders, stray diagnostic files, and half_life non-best models

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py +0 -132
  2. training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py +0 -847
  3. training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py +0 -414
  4. training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash +0 -31
  5. training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py +0 -508
  6. training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py +0 -309
  7. training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt +0 -234
  8. training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py +0 -417
  9. training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py +0 -468
  10. training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py +0 -410
  11. training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py +0 -426
  12. training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py +0 -420
  13. training_classifiers/binding_affinity/val_smiles_pooled.csv +0 -3
  14. training_classifiers/binding_affinity/val_smiles_unpooled.csv +0 -3
  15. training_classifiers/binding_affinity/val_wt_pooled.csv +0 -3
  16. training_classifiers/binding_affinity/val_wt_unpooled.csv +0 -3
  17. training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt +0 -3
  18. training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt +0 -3
  19. training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv +0 -3
  20. training_classifiers/half_life/cnn_smiles/cv_oof_predictions.csv +0 -3
  21. training_classifiers/half_life/cnn_unpooled_peptideclm/best_model.pt +0 -3
  22. training_classifiers/half_life/cnn_unpooled_smiles/cv_oof_predictions.csv +0 -3
  23. training_classifiers/half_life/enet_gpu_smiles/cv_oof_predictions.csv +0 -3
  24. training_classifiers/half_life/enet_peptideclm/smiles_halflife_best_enet.joblib +0 -3
  25. training_classifiers/half_life/mlp_smiles/cv_oof_predictions.csv +0 -3
  26. training_classifiers/half_life/mlp_unpooled_peptideclm/best_model.pt +0 -3
  27. training_classifiers/half_life/mlp_unpooled_smiles/cv_oof_predictions.csv +0 -3
  28. training_classifiers/half_life/svr_gpu_smiles/cv_oof_predictions.csv +0 -3
  29. training_classifiers/half_life/svr_peptideclm/smiles_halflife_best_svr.joblib +0 -3
  30. training_classifiers/half_life/transformer_smiles/cv_oof_predictions.csv +0 -3
  31. training_classifiers/half_life/transformer_unpooled_peptideclm/best_model.pt +0 -3
  32. training_classifiers/half_life/transformer_unpooled_smiles/cv_oof_predictions.csv +0 -3
  33. training_classifiers/half_life/transformer_wt_log/oof_pred_vs_true.png +0 -0
  34. training_classifiers/half_life/transformer_wt_log/oof_predictions.csv +0 -3
  35. training_classifiers/half_life/transformer_wt_log/oof_residual_hist.png +0 -0
  36. training_classifiers/half_life/transformer_wt_log/oof_residual_vs_pred.png +0 -0
  37. training_classifiers/half_life/transformer_wt_log/optimization_summary.txt +0 -33
  38. training_classifiers/half_life/transformer_wt_log/study_trials.csv +0 -3
  39. training_classifiers/half_life/transformer_wt_raw/oof_pred_vs_true.png +0 -0
  40. training_classifiers/half_life/transformer_wt_raw/oof_predictions.csv +0 -3
  41. training_classifiers/half_life/transformer_wt_raw/oof_residual_hist.png +0 -0
  42. training_classifiers/half_life/transformer_wt_raw/oof_residual_vs_pred.png +0 -0
  43. training_classifiers/half_life/transformer_wt_raw/optimization_summary.txt +0 -33
  44. training_classifiers/half_life/transformer_wt_raw/study_trials.csv +0 -3
  45. training_classifiers/half_life/xgb_smiles/cv_oof_predictions.csv +0 -3
  46. training_classifiers/half_life/xgb_wt_log/oof_pred_vs_true.png +0 -0
  47. training_classifiers/half_life/xgb_wt_log/oof_predictions.csv +0 -3
  48. training_classifiers/half_life/xgb_wt_log/oof_residual_hist.png +0 -0
  49. training_classifiers/half_life/xgb_wt_log/oof_residual_vs_pred.png +0 -0
  50. training_classifiers/half_life/xgb_wt_log/optimization_summary.txt +0 -27
training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py DELETED
@@ -1,132 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- extract_iptm_affinity_csv_all.py
4
-
5
- Writes:
6
- - out_dir/wt_iptm_affinity_all.csv
7
- - out_dir/smiles_iptm_affinity_all.csv
8
-
9
- Also prints:
10
- - N
11
- - Spearman rho (affinity vs iptm)
12
- - Pearson r (affinity vs iptm)
13
- """
14
-
15
- from pathlib import Path
16
- import numpy as np
17
- import pandas as pd
18
-
19
-
20
- def corr_stats(df: pd.DataFrame, x: str, y: str):
21
- # pandas handles NaNs if we already dropped them; still be safe
22
- xx = pd.to_numeric(df[x], errors="coerce")
23
- yy = pd.to_numeric(df[y], errors="coerce")
24
- m = xx.notna() & yy.notna()
25
- xx = xx[m]
26
- yy = yy[m]
27
- n = int(m.sum())
28
-
29
- # Pearson r
30
- pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
31
- # Spearman rho
32
- spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")
33
-
34
- return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}
35
-
36
-
37
- def clean_one(
38
- in_csv: Path,
39
- out_csv: Path,
40
- iptm_col: str,
41
- affinity_col: str = "affinity",
42
- keep_cols=(),
43
- ):
44
- df = pd.read_csv(in_csv)
45
-
46
- # affinity + iptm must exist
47
- need = [affinity_col, iptm_col]
48
- missing = [c for c in need if c not in df.columns]
49
- if missing:
50
- raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")
51
-
52
- # coerce numeric
53
- df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
54
- df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")
55
-
56
- # drop NaNs in either
57
- df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)
58
-
59
- # output cols (standardize names)
60
- out = pd.DataFrame({
61
- "affinity": df[affinity_col].astype(float),
62
- "iptm": df[iptm_col].astype(float),
63
- })
64
-
65
- # keep split if present (handy for coloring later, but not used for corr)
66
- if "split" in df.columns:
67
- out.insert(0, "split", df["split"].astype(str))
68
-
69
- # optional extras for labeling/debug
70
- for c in keep_cols:
71
- if c in df.columns:
72
- out[c] = df[c]
73
-
74
- out_csv.parent.mkdir(parents=True, exist_ok=True)
75
- out.to_csv(out_csv, index=False)
76
-
77
- stats = corr_stats(out, "iptm", "affinity")
78
- print(f"[write] {out_csv}")
79
- print(f" N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")
80
-
81
- # also save stats json next to csv
82
- stats_path = out_csv.with_suffix(".stats.json")
83
- with open(stats_path, "w") as f:
84
- import json
85
- json.dump(
86
- {
87
- "input_csv": str(in_csv),
88
- "output_csv": str(out_csv),
89
- "iptm_col": iptm_col,
90
- "affinity_col": affinity_col,
91
- **stats,
92
- },
93
- f,
94
- indent=2,
95
- )
96
-
97
-
98
- def main():
99
- import argparse
100
- ap = argparse.ArgumentParser()
101
- ap.add_argument("--wt_meta_csv", type=str, required=True)
102
- ap.add_argument("--smiles_meta_csv", type=str, required=True)
103
- ap.add_argument("--out_dir", type=str, required=True)
104
-
105
- ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
106
- ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
107
- ap.add_argument("--affinity_col", type=str, default="affinity")
108
- args = ap.parse_args()
109
-
110
- out_dir = Path(args.out_dir)
111
-
112
- clean_one(
113
- Path(args.wt_meta_csv),
114
- out_dir / "wt_iptm_affinity_all.csv",
115
- iptm_col=args.wt_iptm_col,
116
- affinity_col=args.affinity_col,
117
- keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
118
- )
119
-
120
- clean_one(
121
- Path(args.smiles_meta_csv),
122
- out_dir / "smiles_iptm_affinity_all.csv",
123
- iptm_col=args.smiles_iptm_col,
124
- affinity_col=args.affinity_col,
125
- keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
126
- )
127
-
128
- print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")
129
-
130
-
131
- if __name__ == "__main__":
132
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py DELETED
@@ -1,847 +0,0 @@
1
- #!/usr/bin/env python3
2
- import os
3
- import math
4
- from pathlib import Path
5
- import sys
6
- from contextlib import contextmanager
7
-
8
- import numpy as np
9
- import pandas as pd
10
- import torch
11
-
12
- # tqdm is optional; we’ll disable it by default in notebooks
13
- from tqdm import tqdm
14
-
15
- sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
16
- from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
17
-
18
- from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
19
- from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
20
-
21
- # -------------------------
22
- # Config
23
- # -------------------------
24
- CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
25
-
26
- OUT_ROOT = Path(
27
- "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
28
- )
29
-
30
- # WT (seq) embedding model
31
- WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
32
- WT_MAX_LEN = 1022
33
- WT_BATCH = 32
34
-
35
- # SMILES embedding model + tokenizer
36
- SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
37
- TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
38
- TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
39
- SMI_MAX_LEN = 768
40
- SMI_BATCH = 128
41
-
42
- # Split config
43
- TRAIN_FRAC = 0.80
44
- RANDOM_SEED = 1986
45
- AFFINITY_Q_BINS = 30
46
-
47
- # Columns expected in CSV
48
- COL_SEQ1 = "seq1"
49
- COL_SEQ2 = "seq2"
50
- COL_AFF = "affinity"
51
- COL_F2S = "Fasta2SMILES"
52
- COL_REACT = "REACT_SMILES"
53
- COL_WT_IPTM = "wt_iptm_score"
54
- COL_SMI_IPTM = "smiles_iptm_score"
55
-
56
- # Device
57
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
-
59
- # -------------------------
60
- # Quiet / notebook-safe output controls
61
- # -------------------------
62
- QUIET = True # suppress most prints
63
- USE_TQDM = False # disable tqdm bars (recommended in Jupyter to avoid crashing)
64
- LOG_FILE = None # optionally: OUT_ROOT / "build.log"
65
-
66
- def log(msg: str):
67
- if LOG_FILE is not None:
68
- Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
69
- with open(LOG_FILE, "a") as f:
70
- f.write(msg.rstrip() + "\n")
71
- if not QUIET:
72
- print(msg)
73
-
74
- def pbar(it, **kwargs):
75
- return tqdm(it, **kwargs) if USE_TQDM else it
76
-
77
- @contextmanager
78
- def section(title: str):
79
- log(f"\n=== {title} ===")
80
- yield
81
- log(f"=== done: {title} ===")
82
-
83
-
84
- # -------------------------
85
- # Helpers
86
- # -------------------------
87
- def has_uaa(seq: str) -> bool:
88
- return "X" in str(seq).upper()
89
-
90
- def affinity_to_class(a: float) -> str:
91
- # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
92
- if a >= 9.0:
93
- return "High"
94
- elif a >= 7.0:
95
- return "Moderate"
96
- else:
97
- return "Low"
98
-
99
- def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
100
- df = df.copy()
101
-
102
- df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
103
- df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
104
-
105
- df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
106
-
107
- try:
108
- df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
109
- strat_col = "aff_bin"
110
- except Exception:
111
- df["aff_bin"] = df["affinity_class"]
112
- strat_col = "aff_bin"
113
-
114
- rng = np.random.RandomState(RANDOM_SEED)
115
-
116
- df["split"] = None
117
- for _, g in df.groupby(strat_col, observed=True):
118
- idx = g.index.to_numpy()
119
- rng.shuffle(idx)
120
- n_train = int(math.floor(len(idx) * TRAIN_FRAC))
121
- df.loc[idx[:n_train], "split"] = "train"
122
- df.loc[idx[n_train:], "split"] = "val"
123
-
124
- df["split"] = df["split"].fillna("train")
125
- return df
126
-
127
- def _summ(x):
128
- x = np.asarray(x, dtype=float)
129
- x = x[~np.isnan(x)]
130
- if len(x) == 0:
131
- return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
132
- return {
133
- "n": int(len(x)),
134
- "mean": float(np.mean(x)),
135
- "std": float(np.std(x)),
136
- "p50": float(np.quantile(x, 0.50)),
137
- "p95": float(np.quantile(x, 0.95)),
138
- }
139
-
140
- def _len_stats(seqs):
141
- lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
142
- if len(lens) == 0:
143
- return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
144
- return {
145
- "n": int(len(lens)),
146
- "mean": float(lens.mean()),
147
- "std": float(lens.std()),
148
- "p50": float(np.quantile(lens, 0.50)),
149
- "p95": float(np.quantile(lens, 0.95)),
150
- }
151
-
152
- def verify_split_before_embedding(
153
- df2: pd.DataFrame,
154
- affinity_col: str,
155
- split_col: str,
156
- seq_col: str,
157
- iptm_col: str,
158
- aff_class_col: str = "affinity_class",
159
- aff_bins: int = 30,
160
- save_report_prefix: str | None = None,
161
- verbose: bool = False,
162
- ):
163
- """
164
- Notebook-safe: by default prints only ONE line via `log()`.
165
- Optionally writes CSV reports (stats + class proportions).
166
- """
167
- df2 = df2.copy()
168
- df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
169
- df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
170
-
171
- assert split_col in df2.columns, f"Missing split col: {split_col}"
172
- assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
173
- assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
174
-
175
- try:
176
- df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
177
- except Exception:
178
- df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
179
-
180
- tr = df2[df2[split_col] == "train"].reset_index(drop=True)
181
- va = df2[df2[split_col] == "val"].reset_index(drop=True)
182
-
183
- tr_aff = _summ(tr[affinity_col].to_numpy())
184
- va_aff = _summ(va[affinity_col].to_numpy())
185
- tr_len = _len_stats(tr[seq_col].tolist())
186
- va_len = _len_stats(va[seq_col].tolist())
187
-
188
- # bin drift
189
- bin_ct = (
190
- df2.groupby([split_col, "_aff_bin_dbg"])
191
- .size()
192
- .groupby(level=0)
193
- .apply(lambda s: s / s.sum())
194
- )
195
- tr_bins = bin_ct.loc["train"]
196
- va_bins = bin_ct.loc["val"]
197
- all_bins = tr_bins.index.union(va_bins.index)
198
- tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
199
- va_bins = va_bins.reindex(all_bins, fill_value=0.0)
200
- max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
201
-
202
- msg = (
203
- f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
204
- f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
205
- f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
206
- f"max_bin_diff={max_bin_diff:.4f}"
207
- )
208
- log(msg)
209
-
210
- if verbose and (not QUIET):
211
- class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
212
- class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
213
- print("\n[verbose] affinity_class counts:\n", class_ct)
214
- print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
215
-
216
- if save_report_prefix is not None:
217
- out = Path(save_report_prefix)
218
- out.parent.mkdir(parents=True, exist_ok=True)
219
-
220
- stats_df = pd.DataFrame([
221
- {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
222
- {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
223
- ])
224
- class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
225
- class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
226
-
227
- stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
228
- class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
229
-
230
-
231
- # -------------------------
232
- # WT pooled (ESM2)
233
- # -------------------------
234
- @torch.no_grad()
235
- def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
236
- embs = []
237
- for i in pbar(range(0, len(seqs), batch_size)):
238
- batch = seqs[i:i + batch_size]
239
- inputs = tokenizer(
240
- batch,
241
- padding=True,
242
- truncation=True,
243
- max_length=max_length,
244
- return_tensors="pt",
245
- )
246
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
247
- out = model(**inputs)
248
- h = out.last_hidden_state # (B, L, H)
249
-
250
- attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
251
- summed = (h * attn).sum(dim=1) # (B, H)
252
- denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
253
- pooled = (summed / denom).detach().cpu().numpy()
254
- embs.append(pooled)
255
-
256
- return np.vstack(embs)
257
-
258
-
259
- # -------------------------
260
- # WT unpooled (ESM2)
261
- # -------------------------
262
- @torch.no_grad()
263
- def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
264
- tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
265
- tok = {k: v.to(DEVICE) for k, v in tok.items()}
266
- out = model(**tok)
267
- h = out.last_hidden_state[0] # (L, H)
268
- attn = tok["attention_mask"][0].bool() # (L,)
269
- ids = tok["input_ids"][0]
270
-
271
- keep = attn.clone()
272
- if cls_id is not None:
273
- keep &= (ids != cls_id)
274
- if eos_id is not None:
275
- keep &= (ids != eos_id)
276
-
277
- return h[keep].detach().cpu().to(torch.float16).numpy()
278
-
279
- def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
280
- """
281
- Expects df_split to have:
282
- - target_sequence (seq1)
283
- - sequence (binder seq2; WT binder)
284
- - label, affinity_class, COL_AFF, COL_WT_IPTM
285
- Saves a dataset where each row contains BOTH:
286
- - target_embedding (Lt,H), target_attention_mask, target_length
287
- - binder_embedding (Lb,H), binder_attention_mask, binder_length
288
- """
289
- cls_id = tokenizer.cls_token_id
290
- eos_id = tokenizer.eos_token_id
291
- H = model.config.hidden_size
292
-
293
- features = Features({
294
- "target_sequence": Value("string"),
295
- "sequence": Value("string"),
296
- "label": Value("float32"),
297
- "affinity": Value("float32"),
298
- "affinity_class": Value("string"),
299
-
300
- "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
301
- "target_attention_mask": HFSequence(Value("int8")),
302
- "target_length": Value("int64"),
303
-
304
- "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
305
- "binder_attention_mask": HFSequence(Value("int8")),
306
- "binder_length": Value("int64"),
307
-
308
- COL_WT_IPTM: Value("float32"),
309
- COL_AFF: Value("float32"),
310
- })
311
-
312
- def gen_rows(df: pd.DataFrame):
313
- for r in pbar(df.itertuples(index=False), total=len(df)):
314
- tgt = str(getattr(r, "target_sequence")).strip()
315
- bnd = str(getattr(r, "sequence")).strip()
316
-
317
- y = float(getattr(r, "label"))
318
- aff = float(getattr(r, COL_AFF))
319
- acls = str(getattr(r, "affinity_class"))
320
-
321
- iptm = getattr(r, COL_WT_IPTM)
322
- iptm = float(iptm) if pd.notna(iptm) else np.nan
323
-
324
- # token embeddings for target + binder (both ESM)
325
- t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
326
- b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
327
-
328
- t_list = t_emb.tolist()
329
- b_list = b_emb.tolist()
330
- Lt = len(t_list)
331
- Lb = len(b_list)
332
-
333
- yield {
334
- "target_sequence": tgt,
335
- "sequence": bnd,
336
- "label": np.float32(y),
337
- "affinity": np.float32(aff),
338
- "affinity_class": acls,
339
-
340
- "target_embedding": t_list,
341
- "target_attention_mask": [1] * Lt,
342
- "target_length": int(Lt),
343
-
344
- "binder_embedding": b_list,
345
- "binder_attention_mask": [1] * Lb,
346
- "binder_length": int(Lb),
347
-
348
- COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
349
- COL_AFF: np.float32(aff),
350
- }
351
-
352
- out_dir.mkdir(parents=True, exist_ok=True)
353
- ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
354
- ds.save_to_disk(str(out_dir), max_shard_size="1GB")
355
- return ds
356
-
357
- def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
358
- smi_tok, smi_roformer):
359
- """
360
- df_split must have:
361
- - target_sequence (seq1)
362
- - sequence (binder smiles string)
363
- - label, affinity_class, COL_AFF, COL_SMI_IPTM
364
- Saves rows with:
365
- target_embedding (Lt,Ht) from ESM
366
- binder_embedding (Lb,Hb) from PeptideCLM
367
- """
368
- cls_id = wt_tokenizer.cls_token_id
369
- eos_id = wt_tokenizer.eos_token_id
370
- Ht = wt_model_unpooled.config.hidden_size
371
-
372
- # Infer Hb from one forward pass? easiest: run one mini batch outside in main if you want.
373
- # Here: we’ll infer from model config if available.
374
- Hb = getattr(smi_roformer.config, "hidden_size", None)
375
- if Hb is None:
376
- Hb = getattr(smi_roformer.config, "dim", None)
377
- if Hb is None:
378
- raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
379
-
380
- features = Features({
381
- "target_sequence": Value("string"),
382
- "sequence": Value("string"),
383
- "label": Value("float32"),
384
- "affinity": Value("float32"),
385
- "affinity_class": Value("string"),
386
-
387
- "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
388
- "target_attention_mask": HFSequence(Value("int8")),
389
- "target_length": Value("int64"),
390
-
391
- "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
392
- "binder_attention_mask": HFSequence(Value("int8")),
393
- "binder_length": Value("int64"),
394
-
395
- COL_SMI_IPTM: Value("float32"),
396
- COL_AFF: Value("float32"),
397
- })
398
-
399
- def gen_rows(df: pd.DataFrame):
400
- for r in pbar(df.itertuples(index=False), total=len(df)):
401
- tgt = str(getattr(r, "target_sequence")).strip()
402
- bnd = str(getattr(r, "sequence")).strip()
403
-
404
- y = float(getattr(r, "label"))
405
- aff = float(getattr(r, COL_AFF))
406
- acls = str(getattr(r, "affinity_class"))
407
-
408
- iptm = getattr(r, COL_SMI_IPTM)
409
- iptm = float(iptm) if pd.notna(iptm) else np.nan
410
-
411
- # target token embeddings (ESM)
412
- t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
413
- t_list = t_emb.tolist()
414
- Lt = len(t_list)
415
-
416
- # binder token embeddings (PeptideCLM) — single-item batch
417
- _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
418
- [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
419
- )
420
- b_emb = tok_list[0] # np.float16 (Lb, Hb)
421
- b_list = b_emb.tolist()
422
- Lb = int(lengths[0])
423
- b_mask = mask_list[0].astype(np.int8).tolist()
424
-
425
- yield {
426
- "target_sequence": tgt,
427
- "sequence": bnd,
428
- "label": np.float32(y),
429
- "affinity": np.float32(aff),
430
- "affinity_class": acls,
431
-
432
- "target_embedding": t_list,
433
- "target_attention_mask": [1] * Lt,
434
- "target_length": int(Lt),
435
-
436
- "binder_embedding": b_list,
437
- "binder_attention_mask": [int(x) for x in b_mask],
438
- "binder_length": int(Lb),
439
-
440
- COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
441
- COL_AFF: np.float32(aff),
442
- }
443
-
444
- out_dir.mkdir(parents=True, exist_ok=True)
445
- ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
446
- ds.save_to_disk(str(out_dir), max_shard_size="1GB")
447
- return ds
448
-
449
-
450
- # -------------------------
451
- # SMILES pooled + unpooled (PeptideCLM)
452
- # -------------------------
453
- def get_special_ids(tokenizer_obj):
454
- cand = [
455
- getattr(tokenizer_obj, "pad_token_id", None),
456
- getattr(tokenizer_obj, "cls_token_id", None),
457
- getattr(tokenizer_obj, "sep_token_id", None),
458
- getattr(tokenizer_obj, "bos_token_id", None),
459
- getattr(tokenizer_obj, "eos_token_id", None),
460
- getattr(tokenizer_obj, "mask_token_id", None),
461
- ]
462
- return sorted({x for x in cand if x is not None})
463
-
464
- @torch.no_grad()
465
- def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
466
- tok = tokenizer_obj(
467
- batch_sequences,
468
- return_tensors="pt",
469
- padding=True,
470
- truncation=True,
471
- max_length=max_length,
472
- )
473
- input_ids = tok["input_ids"].to(DEVICE)
474
- attention_mask = tok["attention_mask"].to(DEVICE)
475
-
476
- outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
477
- last_hidden = outputs.last_hidden_state # (B, L, H)
478
-
479
- special_ids = get_special_ids(tokenizer_obj)
480
- valid = attention_mask.bool()
481
- if len(special_ids) > 0:
482
- sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
483
- if hasattr(torch, "isin"):
484
- valid = valid & (~torch.isin(input_ids, sid))
485
- else:
486
- m = torch.zeros_like(valid)
487
- for s in special_ids:
488
- m |= (input_ids == s)
489
- valid = valid & (~m)
490
-
491
- valid_f = valid.unsqueeze(-1).float()
492
- summed = torch.sum(last_hidden * valid_f, dim=1)
493
- denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
494
- pooled = (summed / denom).detach().cpu().numpy()
495
-
496
- token_emb_list, mask_list, lengths = [], [], []
497
- for b in range(last_hidden.shape[0]):
498
- emb = last_hidden[b, valid[b]] # (Li, H)
499
- token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
500
- li = emb.shape[0]
501
- lengths.append(int(li))
502
- mask_list.append(np.ones((li,), dtype=np.int8))
503
-
504
- return pooled, token_emb_list, mask_list, lengths
505
-
506
- def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
507
- pooled_all = []
508
- token_emb_all = []
509
- mask_all = []
510
- lengths_all = []
511
-
512
- for i in pbar(range(0, len(seqs), batch_size)):
513
- batch = seqs[i:i + batch_size]
514
- pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
515
- batch, tokenizer_obj, model_roformer, max_length
516
- )
517
- pooled_all.append(pooled)
518
- token_emb_all.extend(tok_list)
519
- mask_all.extend(m_list)
520
- lengths_all.extend(lens)
521
-
522
- return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
523
-
524
- # -------------------------
525
- # Target embedding cache (NO extra ESM runs)
526
- # We will compute target pooled embeddings ONCE from WT view, then reuse for SMILES.
527
- # -------------------------
528
- def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
529
- wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
530
- wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
531
-
532
- # compute target pooled embeddings once
533
- tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
534
- tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
535
-
536
- wt_train_tgt_emb = wt_pooled_embeddings(
537
- tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
538
- )
539
- wt_val_tgt_emb = wt_pooled_embeddings(
540
- tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
541
- )
542
-
543
- # build dict: target_sequence -> embedding (float32 array)
544
- # if duplicates exist, last wins; you can add checks if needed
545
- train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
546
- val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
547
- return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
548
- # -------------------------
549
- # Main
550
- # -------------------------
551
- def main():
552
- log(f"[INFO] DEVICE: {DEVICE}")
553
- OUT_ROOT.mkdir(parents=True, exist_ok=True)
554
-
555
- # 1) Load
556
- with section("load csv + dedup"):
557
- df = pd.read_csv(CSV_PATH)
558
- for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
559
- if c in df.columns:
560
- df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
561
-
562
- # Dedup on the full identity tuple you want
563
- DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
564
- df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
565
-
566
- print("Rows after dedup on", DEDUP_COLS, ":", len(df))
567
-
568
- need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
569
- missing = [c for c in need if c not in df.columns]
570
- if missing:
571
- raise ValueError(f"Missing required columns: {missing}")
572
-
573
- # numeric affinity for both branches
574
- df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
575
-
576
- # 2) Build WT subset + SMILES subset separately (NO global dropping)
577
- with section("prepare wt/smiles subsets"):
578
- # WT: requires a canonical peptide sequence (no X) + affinity
579
- df_wt = df.copy()
580
- df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
581
- df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
582
- df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
583
- df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
584
-
585
- # SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
586
- df_smi = df.copy()
587
- df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
588
- df_smi = df_smi[
589
- pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
590
- ].reset_index(drop=True) # empty iptm means sth wrong with their smiles sequenc
591
-
592
- is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
593
- df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
594
- df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
595
- df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
596
- df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
597
-
598
- log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
599
-
600
- # 3) Split separately (different sizes and memberships are expected)
601
- with section("split wt and smiles separately"):
602
- df_wt2 = make_distribution_matched_split(df_wt)
603
- df_smi2 = make_distribution_matched_split(df_smi)
604
-
605
- # save split tables
606
- wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
607
- smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
608
- df_wt2.to_csv(wt_split_csv, index=False)
609
- df_smi2.to_csv(smi_split_csv, index=False)
610
- log(f"Saved WT split meta: {wt_split_csv}")
611
- log(f"Saved SMILES split meta: {smi_split_csv}")
612
-
613
- # lightweight double-check (one-line)
614
- verify_split_before_embedding(
615
- df2=df_wt2,
616
- affinity_col=COL_AFF,
617
- split_col="split",
618
- seq_col="wt_sequence",
619
- iptm_col=COL_WT_IPTM,
620
- aff_class_col="affinity_class",
621
- aff_bins=AFFINITY_Q_BINS,
622
- save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
623
- verbose=False,
624
- )
625
- verify_split_before_embedding(
626
- df2=df_smi2,
627
- affinity_col=COL_AFF,
628
- split_col="split",
629
- seq_col="smiles_sequence",
630
- iptm_col=COL_SMI_IPTM,
631
- aff_class_col="affinity_class",
632
- aff_bins=AFFINITY_Q_BINS,
633
- save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
634
- verbose=False,
635
- )
636
-
637
- # Prepare split views
638
- def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
639
- out = df_in.copy()
640
- out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # <-- NEW
641
- out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
642
- out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
643
- out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
644
- out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
645
- out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
646
- return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
647
-
648
- wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
649
- smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
650
-
651
- # -------------------------
652
- # Split views
653
- # -------------------------
654
- wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
655
- wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
656
- smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
657
- smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
658
-
659
-
660
- # =========================
661
- # TARGET pooled embeddings (ESM) — SEPARATE per branch
662
- # =========================
663
- with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
664
- wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
665
- wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
666
-
667
- # ---- WT targets ----
668
- wt_train_tgt_emb = wt_pooled_embeddings(
669
- wt_train["target_sequence"].astype(str).str.strip().tolist(),
670
- wt_tok, wt_esm,
671
- batch_size=WT_BATCH,
672
- max_length=WT_MAX_LEN,
673
- ).astype(np.float32)
674
-
675
- wt_val_tgt_emb = wt_pooled_embeddings(
676
- wt_val["target_sequence"].astype(str).str.strip().tolist(),
677
- wt_tok, wt_esm,
678
- batch_size=WT_BATCH,
679
- max_length=WT_MAX_LEN,
680
- ).astype(np.float32)
681
-
682
- # ---- SMILES targets (independent; may include UAA-only targets) ----
683
- smi_train_tgt_emb = wt_pooled_embeddings(
684
- smi_train["target_sequence"].astype(str).str.strip().tolist(),
685
- wt_tok, wt_esm,
686
- batch_size=WT_BATCH,
687
- max_length=WT_MAX_LEN,
688
- ).astype(np.float32)
689
-
690
- smi_val_tgt_emb = wt_pooled_embeddings(
691
- smi_val["target_sequence"].astype(str).str.strip().tolist(),
692
- wt_tok, wt_esm,
693
- batch_size=WT_BATCH,
694
- max_length=WT_MAX_LEN,
695
- ).astype(np.float32)
696
-
697
-
698
- # =========================
699
- # WT pooled binder embeddings (binder = WT peptide)
700
- # =========================
701
- with section("WT pooled binder embeddings + save"):
702
- wt_train_emb = wt_pooled_embeddings(
703
- wt_train["sequence"].astype(str).str.strip().tolist(),
704
- wt_tok, wt_esm,
705
- batch_size=WT_BATCH,
706
- max_length=WT_MAX_LEN,
707
- ).astype(np.float32)
708
-
709
- wt_val_emb = wt_pooled_embeddings(
710
- wt_val["sequence"].astype(str).str.strip().tolist(),
711
- wt_tok, wt_esm,
712
- batch_size=WT_BATCH,
713
- max_length=WT_MAX_LEN,
714
- ).astype(np.float32)
715
-
716
- wt_train_ds = Dataset.from_dict({
717
- "target_sequence": wt_train["target_sequence"].tolist(),
718
- "sequence": wt_train["sequence"].tolist(),
719
- "label": wt_train["label"].astype(float).tolist(),
720
- "target_embedding": wt_train_tgt_emb,
721
- "embedding": wt_train_emb,
722
- COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
723
- COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
724
- "affinity_class": wt_train["affinity_class"].tolist(),
725
- })
726
-
727
- wt_val_ds = Dataset.from_dict({
728
- "target_sequence": wt_val["target_sequence"].tolist(),
729
- "sequence": wt_val["sequence"].tolist(),
730
- "label": wt_val["label"].astype(float).tolist(),
731
- "target_embedding": wt_val_tgt_emb,
732
- "embedding": wt_val_emb,
733
- COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
734
- COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
735
- "affinity_class": wt_val["affinity_class"].tolist(),
736
- })
737
-
738
- wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
739
- wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
740
- wt_pooled_dd.save_to_disk(str(wt_pooled_out))
741
- log(f"Saved WT pooled -> {wt_pooled_out}")
742
-
743
-
744
- # =========================
745
- # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
746
- # =========================
747
- with section("SMILES pooled binder embeddings + save"):
748
- smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
749
- smi_roformer = (
750
- AutoModelForMaskedLM
751
- .from_pretrained(SMI_MODEL_NAME)
752
- .roformer
753
- .to(DEVICE)
754
- .eval()
755
- )
756
-
757
- smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
758
- smi_train["sequence"].astype(str).str.strip().tolist(),
759
- smi_tok, smi_roformer,
760
- batch_size=SMI_BATCH,
761
- max_length=SMI_MAX_LEN,
762
- )
763
-
764
- smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
765
- smi_val["sequence"].astype(str).str.strip().tolist(),
766
- smi_tok, smi_roformer,
767
- batch_size=SMI_BATCH,
768
- max_length=SMI_MAX_LEN,
769
- )
770
-
771
- smi_train_ds = Dataset.from_dict({
772
- "target_sequence": smi_train["target_sequence"].tolist(),
773
- "sequence": smi_train["sequence"].tolist(),
774
- "label": smi_train["label"].astype(float).tolist(),
775
- "target_embedding": smi_train_tgt_emb,
776
- "embedding": smi_train_pooled.astype(np.float32),
777
- COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
778
- COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
779
- "affinity_class": smi_train["affinity_class"].tolist(),
780
- })
781
-
782
- smi_val_ds = Dataset.from_dict({
783
- "target_sequence": smi_val["target_sequence"].tolist(),
784
- "sequence": smi_val["sequence"].tolist(),
785
- "label": smi_val["label"].astype(float).tolist(),
786
- "target_embedding": smi_val_tgt_emb,
787
- "embedding": smi_val_pooled.astype(np.float32),
788
- COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
789
- COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
790
- "affinity_class": smi_val["affinity_class"].tolist(),
791
- })
792
-
793
- smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
794
- smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
795
- smi_pooled_dd.save_to_disk(str(smi_pooled_out))
796
- log(f"Saved SMILES pooled -> {smi_pooled_out}")
797
-
798
-
799
- # =========================
800
- # WT unpooled paired (ESM target + ESM binder) + save
801
- # =========================
802
- with section("WT unpooled paired embeddings + save"):
803
- wt_tok_unpooled = wt_tok # reuse tokenizer
804
- wt_esm_unpooled = wt_esm # reuse model
805
-
806
- wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
807
- wt_unpooled_dd = DatasetDict({
808
- "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
809
- wt_tok_unpooled, wt_esm_unpooled),
810
- "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
811
- wt_tok_unpooled, wt_esm_unpooled),
812
- })
813
- # (Optional) also save as DatasetDict root if you want a single load_from_disk path:
814
- wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
815
- log(f"Saved WT unpooled -> {wt_unpooled_out}")
816
-
817
-
818
- # =========================
819
- # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
820
- # =========================
821
- with section("SMILES unpooled paired embeddings + save"):
822
- # reuse already-loaded smi_tok/smi_roformer from pooled section if still in scope;
823
- # otherwise re-init here:
824
- # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
825
- # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()
826
-
827
- smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
828
- smi_unpooled_dd = DatasetDict({
829
- "train": build_smiles_unpooled_paired_dataset(
830
- smi_train, smi_unpooled_out / "train",
831
- wt_tok, wt_esm,
832
- smi_tok, smi_roformer
833
- ),
834
- "val": build_smiles_unpooled_paired_dataset(
835
- smi_val, smi_unpooled_out / "val",
836
- wt_tok, wt_esm,
837
- smi_tok, smi_roformer
838
- ),
839
- })
840
- smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
841
- log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
842
-
843
- log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
844
-
845
-
846
- if __name__ == "__main__":
847
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py DELETED
@@ -1,414 +0,0 @@
1
- import os, json
2
- from pathlib import Path
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- from torch.utils.data import DataLoader
7
- import optuna
8
- from datasets import load_from_disk, DatasetDict
9
- from scipy.stats import spearmanr
10
- from lightning.pytorch import seed_everything
11
- seed_everything(1986)
12
-
13
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
-
15
-
16
- def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
17
- rho = spearmanr(y_true, y_pred).correlation
18
- if rho is None or np.isnan(rho):
19
- return 0.0
20
- return float(rho)
21
-
22
-
23
- # -----------------------------
24
- # Affinity class thresholds (final spec)
25
- # High >= 9 ; Moderate 7-9 ; Low < 7
26
- # 0=High, 1=Moderate, 2=Low
27
- # -----------------------------
28
- def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
29
- high = y >= 9.0
30
- low = y < 7.0
31
- mid = ~(high | low)
32
- cls = torch.zeros_like(y, dtype=torch.long)
33
- cls[mid] = 1
34
- cls[low] = 2
35
- return cls
36
-
37
-
38
- # -----------------------------
39
- # Load paired DatasetDict
40
- # -----------------------------
41
- def load_split_paired(path: str):
42
- dd = load_from_disk(path)
43
- if not isinstance(dd, DatasetDict):
44
- raise ValueError(f"Expected DatasetDict at {path}")
45
- if "train" not in dd or "val" not in dd:
46
- raise ValueError(f"DatasetDict missing train/val at {path}")
47
- return dd["train"], dd["val"]
48
-
49
-
50
- # -----------------------------
51
- # Collate: pooled paired
52
- # -----------------------------
53
- def collate_pair_pooled(batch):
54
- Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) # (B,Ht)
55
- Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) # (B,Hb)
56
- y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
57
- return Pt, Pb, y
58
-
59
-
60
- # -----------------------------
61
- # Collate: unpooled paired
62
- # -----------------------------
63
- def collate_pair_unpooled(batch):
64
- B = len(batch)
65
- Ht = len(batch[0]["target_embedding"][0])
66
- Hb = len(batch[0]["binder_embedding"][0])
67
- Lt_max = max(int(x["target_length"]) for x in batch)
68
- Lb_max = max(int(x["binder_length"]) for x in batch)
69
-
70
- Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
71
- Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
72
- Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
73
- Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
74
- y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
75
-
76
- for i, x in enumerate(batch):
77
- t = torch.tensor(x["target_embedding"], dtype=torch.float32)
78
- b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
79
- lt, lb = t.shape[0], b.shape[0]
80
- Pt[i, :lt] = t
81
- Pb[i, :lb] = b
82
- Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
83
- Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
84
-
85
- return Pt, Mt, Pb, Mb, y
86
-
87
-
88
- # -----------------------------
89
- # Cross-attention models
90
- # -----------------------------
91
- class CrossAttnPooled(nn.Module):
92
- """
93
- pooled vectors -> treat as single-token sequences for cross attention
94
- """
95
- def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
96
- super().__init__()
97
- self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
98
- self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
99
-
100
- self.layers = nn.ModuleList([])
101
- for _ in range(n_layers):
102
- self.layers.append(nn.ModuleDict({
103
- "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
104
- "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
105
- "n1t": nn.LayerNorm(hidden),
106
- "n2t": nn.LayerNorm(hidden),
107
- "n1b": nn.LayerNorm(hidden),
108
- "n2b": nn.LayerNorm(hidden),
109
- "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
110
- "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
111
- }))
112
-
113
- self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
114
- self.reg = nn.Linear(hidden, 1)
115
- self.cls = nn.Linear(hidden, 3)
116
-
117
- def forward(self, t_vec, b_vec):
118
- # (B,Ht),(B,Hb)
119
- t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
120
- b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
121
-
122
- for L in self.layers:
123
- t_attn, _ = L["attn_tb"](t, b, b)
124
- t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
125
- t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
126
-
127
- b_attn, _ = L["attn_bt"](b, t, t)
128
- b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
129
- b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
130
-
131
- t0 = t[0]
132
- b0 = b[0]
133
- z = torch.cat([t0, b0], dim=-1)
134
- h = self.shared(z)
135
- return self.reg(h).squeeze(-1), self.cls(h)
136
-
137
-
138
- class CrossAttnUnpooled(nn.Module):
139
- """
140
- token sequences with masks; alternating cross attention.
141
- """
142
- def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
143
- super().__init__()
144
- self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
145
- self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
146
-
147
- self.layers = nn.ModuleList([])
148
- for _ in range(n_layers):
149
- self.layers.append(nn.ModuleDict({
150
- "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
151
- "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
152
- "n1t": nn.LayerNorm(hidden),
153
- "n2t": nn.LayerNorm(hidden),
154
- "n1b": nn.LayerNorm(hidden),
155
- "n2b": nn.LayerNorm(hidden),
156
- "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
157
- "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
158
- }))
159
-
160
- self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
161
- self.reg = nn.Linear(hidden, 1)
162
- self.cls = nn.Linear(hidden, 3)
163
-
164
- def masked_mean(self, X, M):
165
- Mf = M.unsqueeze(-1).float()
166
- denom = Mf.sum(dim=1).clamp(min=1.0)
167
- return (X * Mf).sum(dim=1) / denom
168
-
169
- def forward(self, T, Mt, B, Mb):
170
- # T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
171
- T = self.t_proj(T)
172
- Bx = self.b_proj(B)
173
-
174
- kp_t = ~Mt # key_padding_mask True = pad
175
- kp_b = ~Mb
176
-
177
- for L in self.layers:
178
- # T attends to B
179
- T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
180
- T = L["n1t"](T + T_attn)
181
- T = L["n2t"](T + L["fft"](T))
182
-
183
- # B attends to T
184
- B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
185
- Bx = L["n1b"](Bx + B_attn)
186
- Bx = L["n2b"](Bx + L["ffb"](Bx))
187
-
188
- t_pool = self.masked_mean(T, Mt)
189
- b_pool = self.masked_mean(Bx, Mb)
190
- z = torch.cat([t_pool, b_pool], dim=-1)
191
- h = self.shared(z)
192
- return self.reg(h).squeeze(-1), self.cls(h)
193
-
194
-
195
- # -----------------------------
196
- # Train/eval
197
- # -----------------------------
198
- @torch.no_grad()
199
- def eval_spearman_pooled(model, loader):
200
- model.eval()
201
- ys, ps = [], []
202
- for t, b, y in loader:
203
- t = t.to(DEVICE, non_blocking=True)
204
- b = b.to(DEVICE, non_blocking=True)
205
- pred, _ = model(t, b)
206
- ys.append(y.numpy())
207
- ps.append(pred.detach().cpu().numpy())
208
- return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
209
-
210
- @torch.no_grad()
211
- def eval_spearman_unpooled(model, loader):
212
- model.eval()
213
- ys, ps = [], []
214
- for T, Mt, B, Mb, y in loader:
215
- T = T.to(DEVICE, non_blocking=True)
216
- Mt = Mt.to(DEVICE, non_blocking=True)
217
- B = B.to(DEVICE, non_blocking=True)
218
- Mb = Mb.to(DEVICE, non_blocking=True)
219
- pred, _ = model(T, Mt, B, Mb)
220
- ys.append(y.numpy())
221
- ps.append(pred.detach().cpu().numpy())
222
- return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
223
-
224
- def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
225
- model.train()
226
- for t, b, y in loader:
227
- t = t.to(DEVICE, non_blocking=True)
228
- b = b.to(DEVICE, non_blocking=True)
229
- y = y.to(DEVICE, non_blocking=True)
230
- y_cls = affinity_to_class_tensor(y)
231
-
232
- opt.zero_grad(set_to_none=True)
233
- pred, logits = model(t, b)
234
- L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
235
- L.backward()
236
- if clip is not None:
237
- torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
238
- opt.step()
239
-
240
- def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
241
- model.train()
242
- for T, Mt, B, Mb, y in loader:
243
- T = T.to(DEVICE, non_blocking=True)
244
- Mt = Mt.to(DEVICE, non_blocking=True)
245
- B = B.to(DEVICE, non_blocking=True)
246
- Mb = Mb.to(DEVICE, non_blocking=True)
247
- y = y.to(DEVICE, non_blocking=True)
248
- y_cls = affinity_to_class_tensor(y)
249
-
250
- opt.zero_grad(set_to_none=True)
251
- pred, logits = model(T, Mt, B, Mb)
252
- L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
253
- L.backward()
254
- if clip is not None:
255
- torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
256
- opt.step()
257
-
258
-
259
- # -----------------------------
260
- # Optuna objective
261
- # -----------------------------
262
- def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
263
- lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
264
- wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
265
- dropout = trial.suggest_float("dropout", 0.0, 0.4)
266
- hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
267
- n_heads = trial.suggest_categorical("n_heads", [4, 8])
268
- n_layers = trial.suggest_int("n_layers", 1, 4)
269
- cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
270
- batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
271
-
272
- # infer dims from first row
273
- if mode == "pooled":
274
- Ht = len(train_ds[0]["target_embedding"])
275
- Hb = len(train_ds[0]["binder_embedding"])
276
- collate = collate_pair_pooled
277
- model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
278
- train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
279
- val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
280
- eval_fn = eval_spearman_pooled
281
- train_fn = train_one_epoch_pooled
282
-
283
- else:
284
- Ht = len(train_ds[0]["target_embedding"][0])
285
- Hb = len(train_ds[0]["binder_embedding"][0])
286
- collate = collate_pair_unpooled
287
- model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
288
- train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
289
- val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
290
- eval_fn = eval_spearman_unpooled
291
- train_fn = train_one_epoch_unpooled
292
-
293
- opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
294
- loss_reg = nn.MSELoss()
295
- loss_cls = nn.CrossEntropyLoss()
296
-
297
- best = -1e9
298
- bad = 0
299
- patience = 10
300
-
301
- for ep in range(1, 61):
302
- train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
303
- rho = eval_fn(model, val_loader)
304
-
305
- trial.report(rho, ep)
306
- if trial.should_prune():
307
- raise optuna.TrialPruned()
308
-
309
- if rho > best + 1e-6:
310
- best = rho
311
- bad = 0
312
- else:
313
- bad += 1
314
- if bad >= patience:
315
- break
316
-
317
- return float(best)
318
-
319
-
320
- # -----------------------------
321
- # Run: optuna + refit best
322
- # -----------------------------
323
- def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
324
- out_dir = Path(out_dir)
325
- out_dir.mkdir(parents=True, exist_ok=True)
326
-
327
- train_ds, val_ds = load_split_paired(dataset_path)
328
- print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")
329
-
330
- study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
331
- study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)
332
-
333
- study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
334
- best = study.best_trial
335
- best_params = dict(best.params)
336
-
337
- # refit longer
338
- lr = float(best_params["lr"])
339
- wd = float(best_params["weight_decay"])
340
- dropout = float(best_params["dropout"])
341
- hidden = int(best_params["hidden_dim"])
342
- n_heads = int(best_params["n_heads"])
343
- n_layers = int(best_params["n_layers"])
344
- cls_w = float(best_params["cls_weight"])
345
- batch = int(best_params["batch_size"])
346
-
347
- loss_reg = nn.MSELoss()
348
- loss_cls = nn.CrossEntropyLoss()
349
-
350
- if mode == "pooled":
351
- Ht = len(train_ds[0]["target_embedding"])
352
- Hb = len(train_ds[0]["binder_embedding"])
353
- model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
354
- collate = collate_pair_pooled
355
- train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
356
- val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
357
- eval_fn = eval_spearman_pooled
358
- train_fn = train_one_epoch_pooled
359
- else:
360
- Ht = len(train_ds[0]["target_embedding"][0])
361
- Hb = len(train_ds[0]["binder_embedding"][0])
362
- model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
363
- collate = collate_pair_unpooled
364
- train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
365
- val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
366
- eval_fn = eval_spearman_unpooled
367
- train_fn = train_one_epoch_unpooled
368
-
369
- opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
370
-
371
- best_rho = -1e9
372
- bad = 0
373
- patience = 20
374
- best_state = None
375
-
376
- for ep in range(1, 201):
377
- train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
378
- rho = eval_fn(model, val_loader)
379
-
380
- if rho > best_rho + 1e-6:
381
- best_rho = rho
382
- bad = 0
383
- best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
384
- else:
385
- bad += 1
386
- if bad >= patience:
387
- break
388
-
389
- if best_state is not None:
390
- model.load_state_dict(best_state)
391
-
392
- # save
393
- torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
394
- with open(out_dir / "best_params.json", "w") as f:
395
- json.dump(best_params, f, indent=2)
396
-
397
- print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
398
-
399
-
400
- if __name__ == "__main__":
401
- import argparse
402
- ap = argparse.ArgumentParser()
403
- ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
404
- ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
405
- ap.add_argument("--out_dir", type=str, required=True)
406
- ap.add_argument("--n_trials", type=int, default=50)
407
- args = ap.parse_args()
408
-
409
- run(
410
- dataset_path=args.dataset_path,
411
- out_dir=args.out_dir,
412
- mode=args.mode,
413
- n_trials=args.n_trials,
414
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash DELETED
@@ -1,31 +0,0 @@
1
- #!/bin/bash
2
- #SBATCH --job-name=b-data
3
- #SBATCH --partition=dgx-b200
4
- #SBATCH --gpus=1
5
- #SBATCH --cpus-per-task=10
6
- #SBATCH --mem=100G
7
- #SBATCH --time=48:00:00
8
- #SBATCH --output=%x_%j.out
9
-
10
- HOME_LOC=/vast/projects/pranam/lab/yz927
11
- SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
12
- DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
13
- OBJECTIVE='binding_affinity'
14
- WT='smiles' #wt/smiles
15
- STATUS='pooled' #pooled/unpooled
16
- DATA_FILE="pair_wt_${WT}_${STATUS}"
17
- LOG_LOC=$SCRIPT_LOC
18
- DATE=$(date +%m_%d)
19
- SPECIAL_PREFIX="binding_affinity_data_generation"
20
-
21
- # Create log directory if it doesn't exist
22
- mkdir -p $LOG_LOC
23
-
24
- cd $SCRIPT_LOC
25
- source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
26
- conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal
27
-
28
- python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
29
-
30
- echo "Script completed at $(date)"
31
- conda deactivate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py DELETED
@@ -1,508 +0,0 @@
1
- #!/usr/bin/env python3
2
- # finetune_xgb_halflife_cv_optuna.py
3
-
4
- import os
5
- import json
6
- import math
7
- import hashlib
8
- from dataclasses import dataclass
9
- from typing import Dict, Any, Optional, Tuple, List
10
-
11
- import numpy as np
12
- import pandas as pd
13
- import optuna
14
-
15
- from sklearn.model_selection import KFold
16
- from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
17
- from scipy.stats import spearmanr
18
-
19
- import torch
20
- from transformers import AutoTokenizer, AutoModel
21
-
22
- import xgboost as xgb
23
-
24
-
25
- # -----------------------------
26
- # Repro
27
- # -----------------------------
28
- SEED = 1986
29
- np.random.seed(SEED)
30
- torch.manual_seed(SEED)
31
-
32
-
33
- # -----------------------------
34
- # Metrics (mirrors your stability script style)
35
- # -----------------------------
36
- def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
37
- rho = spearmanr(y_true, y_pred).correlation
38
- if rho is None or np.isnan(rho):
39
- return 0.0
40
- return float(rho)
41
-
42
- def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
43
- rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
44
- mae = float(mean_absolute_error(y_true, y_pred))
45
- r2 = float(r2_score(y_true, y_pred))
46
- rho = float(safe_spearmanr(y_true, y_pred))
47
- return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
48
-
49
-
50
- # -----------------------------
51
- # ESM-2 embeddings (cached)
52
- # -----------------------------
53
- @dataclass
54
- class ESMEmbedderConfig:
55
- model_name: str = "facebook/esm2_t33_650M_UR50D"
56
- batch_size: int = 8
57
- max_length: int = 1024 # truncate very long proteins
58
- fp16: bool = True
59
-
60
- class ESM2Embedder:
61
- """
62
- Mean-pooled last hidden state (excluding special tokens) -> (H,) per sequence.
63
- """
64
- def __init__(self, cfg: ESMEmbedderConfig, device: str = "cuda"):
65
- self.cfg = cfg
66
- self.device = device if (device == "cuda" and torch.cuda.is_available()) else "cpu"
67
- self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, do_lower_case=False)
68
- self.model = AutoModel.from_pretrained(cfg.model_name)
69
- self.model.eval()
70
- self.model.to(self.device)
71
-
72
- # Turn off gradients
73
- for p in self.model.parameters():
74
- p.requires_grad = False
75
-
76
- @torch.inference_mode()
77
- def embed(self, seqs: List[str]) -> np.ndarray:
78
- out = []
79
- bs = self.cfg.batch_size
80
-
81
- use_amp = (self.cfg.fp16 and self.device == "cuda")
82
- autocast = torch.cuda.amp.autocast if use_amp else torch.cpu.amp.autocast # safe fallback
83
-
84
- for i in range(0, len(seqs), bs):
85
- batch = [s.strip().upper() for s in seqs[i:i+bs]]
86
- toks = self.tokenizer(
87
- batch,
88
- return_tensors="pt",
89
- padding=True,
90
- truncation=True,
91
- max_length=self.cfg.max_length,
92
- add_special_tokens=True,
93
- )
94
- toks = {k: v.to(self.device) for k, v in toks.items()}
95
- attn = toks["attention_mask"] # (B, L)
96
-
97
- with autocast(enabled=use_amp):
98
- h = self.model(**toks).last_hidden_state # (B, L, H)
99
-
100
- # mask out special tokens: first token is <cls>; last non-pad token is usually <eos>
101
- mask = attn.clone()
102
- mask[:, 0] = 0
103
- lengths = attn.sum(dim=1) # includes special tokens
104
- # zero out last real token position per sequence
105
- eos_pos = (lengths - 1).clamp(min=0)
106
- mask[torch.arange(mask.size(0), device=mask.device), eos_pos] = 0
107
-
108
- denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1) # (B,1)
109
- pooled = (h * mask.unsqueeze(-1)).sum(dim=1) / denom # (B,H)
110
- out.append(pooled.float().detach().cpu().numpy())
111
-
112
- return np.concatenate(out, axis=0).astype(np.float32)
113
-
114
-
115
- def dataset_fingerprint(seqs: List[str], y: np.ndarray, extra: str = "") -> str:
116
- h = hashlib.sha256()
117
- for s in seqs:
118
- h.update(s.encode("utf-8"))
119
- h.update(b"\n")
120
- h.update(np.asarray(y, dtype=np.float32).tobytes())
121
- h.update(extra.encode("utf-8"))
122
- return h.hexdigest()[:16]
123
-
124
-
125
- def load_or_compute_embeddings(
126
- df: pd.DataFrame,
127
- out_dir: str,
128
- embed_cfg: ESMEmbedderConfig,
129
- device: str,
130
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
131
- os.makedirs(out_dir, exist_ok=True)
132
-
133
- seqs = df["sequence"].astype(str).tolist()
134
- y = df["half_life_hours"].astype(float).to_numpy(dtype=np.float32)
135
-
136
- fp = dataset_fingerprint(seqs, y, extra=f"{embed_cfg.model_name}|{embed_cfg.max_length}")
137
- emb_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.npy")
138
- meta_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.json")
139
-
140
- if os.path.exists(emb_path) and os.path.exists(meta_path):
141
- X = np.load(emb_path).astype(np.float32)
142
- return X, y, np.asarray(seqs)
143
-
144
- embedder = ESM2Embedder(embed_cfg, device=device)
145
- X = embedder.embed(seqs) # (N,H)
146
-
147
- np.save(emb_path, X)
148
- with open(meta_path, "w") as f:
149
- json.dump(
150
- {
151
- "fingerprint": fp,
152
- "model_name": embed_cfg.model_name,
153
- "max_length": embed_cfg.max_length,
154
- "n": len(seqs),
155
- "dim": int(X.shape[1]),
156
- },
157
- f,
158
- indent=2,
159
- )
160
- return X, y, np.asarray(seqs)
161
-
162
-
163
- # -----------------------------
164
- # XGBoost training (supports "finetune" via xgb_model)
165
- # -----------------------------
166
- def train_xgb_reg(
167
- X_train: np.ndarray,
168
- y_train: np.ndarray,
169
- X_val: np.ndarray,
170
- y_val: np.ndarray,
171
- params: Dict[str, Any],
172
- base_model_json: Optional[str] = None,
173
- ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray, int]:
174
- dtrain = xgb.DMatrix(X_train, label=y_train)
175
- dval = xgb.DMatrix(X_val, label=y_val)
176
-
177
- num_boost_round = int(params.pop("num_boost_round"))
178
- early_stopping_rounds = int(params.pop("early_stopping_rounds"))
179
-
180
- # Important: load a fresh base model each fold (avoid leakage)
181
- xgb_model = None
182
- if base_model_json is not None:
183
- booster0 = xgb.Booster()
184
- booster0.load_model(base_model_json)
185
- xgb_model = booster0
186
-
187
- booster = xgb.train(
188
- params=params,
189
- dtrain=dtrain,
190
- num_boost_round=num_boost_round,
191
- evals=[(dval, "val")],
192
- early_stopping_rounds=early_stopping_rounds,
193
- verbose_eval=False,
194
- xgb_model=xgb_model, # <-- "finetune": continue boosting from base model
195
- )
196
-
197
- p_train = booster.predict(dtrain)
198
- p_val = booster.predict(dval)
199
- best_iter = int(getattr(booster, "best_iteration", num_boost_round - 1))
200
- return booster, p_train, p_val, best_iter
201
-
202
-
203
- # -----------------------------
204
- # Optuna objective: 5-fold mean Spearman rho
205
- # -----------------------------
206
- def make_cv_objective(
207
- X: np.ndarray,
208
- y: np.ndarray,
209
- n_splits: int,
210
- device: str,
211
- base_model_json: Optional[str],
212
- target_transform: str,
213
- ):
214
- kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
215
-
216
- # Optional target transform (sometimes helps with heavy-tailed half-life)
217
- if target_transform == "log1p":
218
- y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
219
- elif target_transform == "none":
220
- y_used = y.astype(np.float32)
221
- else:
222
- raise ValueError(f"Unknown target_transform: {target_transform}")
223
-
224
- def objective(trial: optuna.Trial) -> float:
225
- # Hyperparam ranges patterned after your stability script :contentReference[oaicite:1]{index=1}
226
- params = {
227
- "objective": "reg:squarederror",
228
- "eval_metric": "rmse",
229
-
230
- "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
231
- "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
232
- "gamma": trial.suggest_float("gamma", 0.0, 10.0),
233
-
234
- "max_depth": trial.suggest_int("max_depth", 2, 12),
235
- "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 200.0, log=True),
236
- "subsample": trial.suggest_float("subsample", 0.5, 1.0),
237
- "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
238
-
239
- "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
240
-
241
- "tree_method": "hist",
242
- "device": "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu",
243
- }
244
- params["num_boost_round"] = trial.suggest_int("num_boost_round", 30, 1500)
245
- params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 10, 150)
246
-
247
- fold_metrics = []
248
- fold_best_iters = []
249
-
250
- for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
251
- Xtr, ytr = X[tr_idx], y_used[tr_idx]
252
- Xva, yva = X[va_idx], y_used[va_idx]
253
-
254
- _, _, p_va, best_iter = train_xgb_reg(
255
- Xtr, ytr, Xva, yva, params.copy(),
256
- base_model_json=base_model_json,
257
- )
258
-
259
- m = eval_regression(yva, p_va)
260
- fold_metrics.append(m)
261
- fold_best_iters.append(best_iter)
262
-
263
- mean_rho = float(np.mean([m["spearman_rho"] for m in fold_metrics]))
264
- mean_rmse = float(np.mean([m["rmse"] for m in fold_metrics]))
265
- mean_mae = float(np.mean([m["mae"] for m in fold_metrics]))
266
- mean_r2 = float(np.mean([m["r2"] for m in fold_metrics]))
267
- mean_best_iter = float(np.mean(fold_best_iters))
268
-
269
- trial.set_user_attr("cv_spearman_rho", mean_rho)
270
- trial.set_user_attr("cv_rmse", mean_rmse)
271
- trial.set_user_attr("cv_mae", mean_mae)
272
- trial.set_user_attr("cv_r2", mean_r2)
273
- trial.set_user_attr("cv_mean_best_iter", mean_best_iter)
274
-
275
- # maximize Spearman rho (same as your stability workflow :contentReference[oaicite:2]{index=2})
276
- return mean_rho
277
-
278
- return objective
279
-
280
-
281
- def refit_and_save(
282
- X: np.ndarray,
283
- y: np.ndarray,
284
- seqs: np.ndarray,
285
- out_dir: str,
286
- best_params: Dict[str, Any],
287
- n_splits: int,
288
- device: str,
289
- base_model_json: Optional[str],
290
- target_transform: str,
291
- ):
292
- os.makedirs(out_dir, exist_ok=True)
293
-
294
- # Transform target consistently
295
- if target_transform == "log1p":
296
- y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
297
- else:
298
- y_used = y.astype(np.float32)
299
-
300
- kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
301
-
302
- # 1) get OOF preds + average best_iteration
303
- oof_pred = np.zeros_like(y_used, dtype=np.float32)
304
- best_iters = []
305
- fold_rows = []
306
-
307
- for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
308
- Xtr, ytr = X[tr_idx], y_used[tr_idx]
309
- Xva, yva = X[va_idx], y_used[va_idx]
310
-
311
- _, _, p_va, best_iter = train_xgb_reg(
312
- Xtr, ytr, Xva, yva, best_params.copy(),
313
- base_model_json=base_model_json,
314
- )
315
- oof_pred[va_idx] = p_va.astype(np.float32)
316
- best_iters.append(best_iter)
317
-
318
- m = eval_regression(yva, p_va)
319
- fold_rows.append({"fold": fold, **m, "best_iter": int(best_iter)})
320
-
321
- fold_df = pd.DataFrame(fold_rows)
322
- fold_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)
323
-
324
- cv_metrics = eval_regression(y_used, oof_pred)
325
- with open(os.path.join(out_dir, "cv_oof_summary.json"), "w") as f:
326
- json.dump(cv_metrics, f, indent=2)
327
-
328
- oof_df = pd.DataFrame({
329
- "sequence": seqs,
330
- "y_true_used": y_used.astype(float),
331
- "y_pred_oof": oof_pred.astype(float),
332
- "residual": (y_used - oof_pred).astype(float),
333
- })
334
- oof_df.to_csv(os.path.join(out_dir, "cv_oof_predictions.csv"), index=False)
335
-
336
- mean_best_iter = int(round(float(np.mean(best_iters))))
337
- final_rounds = max(mean_best_iter + 1, 10)
338
-
339
- # 2) train final model on ALL data (no early stopping here; use final_rounds)
340
- dtrain_all = xgb.DMatrix(X, label=y_used)
341
-
342
- xgb_model = None
343
- if base_model_json is not None:
344
- booster0 = xgb.Booster()
345
- booster0.load_model(base_model_json)
346
- xgb_model = booster0
347
-
348
- final_params = best_params.copy()
349
- final_params.pop("early_stopping_rounds", None)
350
- final_params["device"] = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
351
-
352
- booster = xgb.train(
353
- params=final_params,
354
- dtrain=dtrain_all,
355
- num_boost_round=int(final_params.pop("num_boost_round", final_rounds)),
356
- evals=[],
357
- verbose_eval=False,
358
- xgb_model=xgb_model,
359
- )
360
-
361
- model_path = os.path.join(out_dir, "best_model_finetuned.json")
362
- booster.save_model(model_path)
363
-
364
- with open(os.path.join(out_dir, "final_training_notes.json"), "w") as f:
365
- json.dump(
366
- {
367
- "target_transform": target_transform,
368
- "final_rounds_used": int(final_rounds),
369
- "cv_oof_metrics_on_used_target": cv_metrics,
370
- "model_path": model_path,
371
- },
372
- f,
373
- indent=2,
374
- )
375
-
376
- print("=" * 72)
377
- print("[Final] CV OOF metrics (on transformed target if enabled):")
378
- print(json.dumps(cv_metrics, indent=2))
379
- print(f"[Final] Saved finetuned model -> {model_path}")
380
- print("=" * 72)
381
-
382
-
383
- def main():
384
- import argparse
385
-
386
- parser = argparse.ArgumentParser()
387
- parser.add_argument("--csv_path", type=str, default="/scratch/pranamlab/tong/data/halflife/wt_halflife_merged_dedup.csv")
388
- parser.add_argument("--out_dir", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb")
389
-
390
- # If provided, we will "finetune" by continuing boosting from this model
391
- parser.add_argument("--base_model_json", type=str, default='/scratch/pranamlab/tong/PeptiVerse/src/stability/xgboost/best_model.json', help="Path to an existing XGBoost .json model to continue training from")
392
-
393
- # ESM embedding config
394
- parser.add_argument("--esm_model", type=str, default="facebook/esm2_t33_650M_UR50D")
395
- parser.add_argument("--esm_batch_size", type=int, default=8)
396
- parser.add_argument("--esm_max_length", type=int, default=1024)
397
- parser.add_argument("--no_fp16", action="store_true")
398
-
399
- # Training config
400
- parser.add_argument("--n_trials", type=int, default=200)
401
- parser.add_argument("--n_splits", type=int, default=5)
402
- parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
403
- parser.add_argument("--target_transform", type=str, default="none", choices=["none", "log1p"])
404
-
405
- args = parser.parse_args()
406
- os.makedirs(args.out_dir, exist_ok=True)
407
-
408
- # Load data
409
- df = pd.read_csv(args.csv_path)
410
- if "sequence" not in df.columns or "half_life_hours" not in df.columns:
411
- raise ValueError("CSV must contain columns: sequence, half_life_hours")
412
-
413
- df = df.dropna(subset=["sequence", "half_life_hours"]).copy()
414
- df["sequence"] = df["sequence"].astype(str).str.strip()
415
- df = df[df["sequence"].str.len() > 0]
416
- df = df.drop_duplicates(subset=["sequence"], keep="first").reset_index(drop=True)
417
-
418
- print(f"[Data] N={len(df)} from {args.csv_path}")
419
-
420
- # Embeddings (cached)
421
- embed_cfg = ESMEmbedderConfig(
422
- model_name=args.esm_model,
423
- batch_size=args.esm_batch_size,
424
- max_length=args.esm_max_length,
425
- fp16=(not args.no_fp16),
426
- )
427
- X, y, seqs = load_or_compute_embeddings(df, args.out_dir, embed_cfg, device=args.device)
428
- print(f"[Embeddings] X={X.shape} (float32)")
429
-
430
- # Optuna study
431
- sampler = optuna.samplers.TPESampler(seed=SEED)
432
- study = optuna.create_study(
433
- direction="maximize", # like your stability script :contentReference[oaicite:3]{index=3}
434
- sampler=sampler,
435
- pruner=optuna.pruners.MedianPruner(),
436
- )
437
-
438
- objective = make_cv_objective(
439
- X=X,
440
- y=y,
441
- n_splits=args.n_splits,
442
- device=args.device,
443
- base_model_json=args.base_model_json,
444
- target_transform=args.target_transform,
445
- )
446
- study.optimize(objective, n_trials=args.n_trials)
447
-
448
- # Save trials
449
- trials_df = study.trials_dataframe()
450
- trials_df.to_csv(os.path.join(args.out_dir, "study_trials.csv"), index=False)
451
-
452
- best = study.best_trial
453
- best_params = dict(best.params)
454
-
455
- # Build full param dict for refit
456
- best_xgb_params = {
457
- "objective": "reg:squarederror",
458
- "eval_metric": "rmse",
459
- "lambda": best_params["lambda"],
460
- "alpha": best_params["alpha"],
461
- "gamma": best_params["gamma"],
462
- "max_depth": best_params["max_depth"],
463
- "min_child_weight": best_params["min_child_weight"],
464
- "subsample": best_params["subsample"],
465
- "colsample_bytree": best_params["colsample_bytree"],
466
- "learning_rate": best_params["learning_rate"],
467
- "tree_method": "hist",
468
- "device": "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu",
469
- "num_boost_round": best_params["num_boost_round"],
470
- "early_stopping_rounds": best_params["early_stopping_rounds"],
471
- }
472
-
473
- # Summary
474
- summary = {
475
- "best_trial_number": int(best.number),
476
- "best_value_cv_spearman_rho": float(best.value),
477
- "best_user_attrs": best.user_attrs,
478
- "best_params": best_params,
479
- "best_xgb_params_full": best_xgb_params,
480
- "base_model_json": args.base_model_json,
481
- "target_transform": args.target_transform,
482
- "esm_model": args.esm_model,
483
- "esm_max_length": args.esm_max_length,
484
- }
485
- with open(os.path.join(args.out_dir, "optimization_summary.json"), "w") as f:
486
- json.dump(summary, f, indent=2)
487
-
488
- print("=" * 72)
489
- print("[Optuna] Best CV Spearman rho:", float(best.value))
490
- print("[Optuna] Best params:\n", json.dumps(best_params, indent=2))
491
- print("=" * 72)
492
-
493
- # Refit + save final finetuned model + OOF predictions
494
- refit_and_save(
495
- X=X,
496
- y=y,
497
- seqs=seqs,
498
- out_dir=args.out_dir,
499
- best_params=best_xgb_params,
500
- n_splits=args.n_splits,
501
- device=args.device,
502
- base_model_json=args.base_model_json,
503
- target_transform=args.target_transform,
504
- )
505
-
506
-
507
- if __name__ == "__main__":
508
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py DELETED
@@ -1,309 +0,0 @@
1
- #!/usr/bin/env python3
2
- # export_val_preds_csv.py
3
-
4
- import argparse
5
- from pathlib import Path
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- from torch.utils.data import DataLoader
10
- from datasets import load_from_disk, DatasetDict
11
-
12
- # -----------------------------
13
- # Repro / device
14
- # -----------------------------
15
- def seed_all(seed=1986):
16
- import random
17
- random.seed(seed)
18
- np.random.seed(seed)
19
- torch.manual_seed(seed)
20
- torch.cuda.manual_seed_all(seed)
21
-
22
- seed_all(1986)
23
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
-
25
-
26
- # -----------------------------
27
- # Load paired DatasetDict
28
- # -----------------------------
29
- def load_split_paired(path: str):
30
- dd = load_from_disk(path)
31
- if not isinstance(dd, DatasetDict):
32
- raise ValueError(f"Expected DatasetDict at {path}")
33
- if "train" not in dd or "val" not in dd:
34
- raise ValueError(f"DatasetDict missing train/val at {path}")
35
- return dd["train"], dd["val"]
36
-
37
-
38
- # -----------------------------
39
- # Collate fns (same as yours)
40
- # -----------------------------
41
- def collate_pair_pooled(batch):
42
- Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
43
- Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
44
- y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
45
- return Pt, Pb, y
46
-
47
- def collate_pair_unpooled(batch):
48
- B = len(batch)
49
- Ht = len(batch[0]["target_embedding"][0])
50
- Hb = len(batch[0]["binder_embedding"][0])
51
- Lt_max = max(int(x["target_length"]) for x in batch)
52
- Lb_max = max(int(x["binder_length"]) for x in batch)
53
-
54
- Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
55
- Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
56
- Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
57
- Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
58
- y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
59
-
60
- for i, x in enumerate(batch):
61
- t = torch.tensor(x["target_embedding"], dtype=torch.float32)
62
- b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
63
- lt, lb = t.shape[0], b.shape[0]
64
- Pt[i, :lt] = t
65
- Pb[i, :lb] = b
66
- Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
67
- Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
68
-
69
- return Pt, Mt, Pb, Mb, y
70
-
71
-
72
- # -----------------------------
73
- # Models (same as yours)
74
- # -----------------------------
75
- class CrossAttnPooled(nn.Module):
76
- def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
77
- super().__init__()
78
- self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
79
- self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
80
-
81
- self.layers = nn.ModuleList([])
82
- for _ in range(n_layers):
83
- self.layers.append(nn.ModuleDict({
84
- "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
85
- "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
86
- "n1t": nn.LayerNorm(hidden),
87
- "n2t": nn.LayerNorm(hidden),
88
- "n1b": nn.LayerNorm(hidden),
89
- "n2b": nn.LayerNorm(hidden),
90
- "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
91
- "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
92
- }))
93
-
94
- self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
95
- self.reg = nn.Linear(hidden, 1)
96
- self.cls = nn.Linear(hidden, 3)
97
-
98
- def forward(self, t_vec, b_vec):
99
- t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
100
- b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
101
-
102
- for L in self.layers:
103
- t_attn, _ = L["attn_tb"](t, b, b)
104
- t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
105
- t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
106
-
107
- b_attn, _ = L["attn_bt"](b, t, t)
108
- b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
109
- b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
110
-
111
- z = torch.cat([t[0], b[0]], dim=-1)
112
- h = self.shared(z)
113
- return self.reg(h).squeeze(-1), self.cls(h)
114
-
115
-
116
- class CrossAttnUnpooled(nn.Module):
117
- def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
118
- super().__init__()
119
- self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
120
- self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
121
-
122
- self.layers = nn.ModuleList([])
123
- for _ in range(n_layers):
124
- self.layers.append(nn.ModuleDict({
125
- "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
126
- "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
127
- "n1t": nn.LayerNorm(hidden),
128
- "n2t": nn.LayerNorm(hidden),
129
- "n1b": nn.LayerNorm(hidden),
130
- "n2b": nn.LayerNorm(hidden),
131
- "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
132
- "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
133
- }))
134
-
135
- self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
136
- self.reg = nn.Linear(hidden, 1)
137
- self.cls = nn.Linear(hidden, 3)
138
-
139
- def masked_mean(self, X, M):
140
- Mf = M.unsqueeze(-1).float()
141
- denom = Mf.sum(dim=1).clamp(min=1.0)
142
- return (X * Mf).sum(dim=1) / denom
143
-
144
- def forward(self, T, Mt, B, Mb):
145
- T = self.t_proj(T)
146
- Bx = self.b_proj(B)
147
-
148
- kp_t = ~Mt
149
- kp_b = ~Mb
150
-
151
- for L in self.layers:
152
- T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
153
- T = L["n1t"](T + T_attn)
154
- T = L["n2t"](T + L["fft"](T))
155
-
156
- B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
157
- Bx = L["n1b"](Bx + B_attn)
158
- Bx = L["n2b"](Bx + L["ffb"](Bx))
159
-
160
- t_pool = self.masked_mean(T, Mt)
161
- b_pool = self.masked_mean(Bx, Mb)
162
- z = torch.cat([t_pool, b_pool], dim=-1)
163
- h = self.shared(z)
164
- return self.reg(h).squeeze(-1), self.cls(h)
165
-
166
-
167
- # -----------------------------
168
- # Helpers
169
- # -----------------------------
170
- def softmax_np(logits: np.ndarray) -> np.ndarray:
171
- x = logits - logits.max(axis=1, keepdims=True)
172
- ex = np.exp(x)
173
- return ex / ex.sum(axis=1, keepdims=True)
174
-
175
- def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray:
176
- centers = np.asarray(class_centers, dtype=np.float32)[None, :] # (1,3)
177
- return (probs * centers).sum(axis=1)
178
-
179
- def load_checkpoint(ckpt_path: str, mode: str, train_ds):
180
- ckpt = torch.load(ckpt_path, map_location="cpu")
181
- params = ckpt.get("best_params", {})
182
-
183
- hidden = int(params.get("hidden_dim", 512))
184
- n_heads = int(params.get("n_heads", 8))
185
- n_layers = int(params.get("n_layers", 3))
186
- dropout = float(params.get("dropout", 0.1))
187
-
188
- if mode == "pooled":
189
- Ht = len(train_ds[0]["target_embedding"])
190
- Hb = len(train_ds[0]["binder_embedding"])
191
- model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
192
- else:
193
- Ht = len(train_ds[0]["target_embedding"][0])
194
- Hb = len(train_ds[0]["binder_embedding"][0])
195
- model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
196
-
197
- model.load_state_dict(ckpt["state_dict"], strict=True)
198
- model.to(DEVICE).eval()
199
- return model
200
-
201
-
202
- @torch.no_grad()
203
- def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str,
204
- out_csv: str, batch_size: int, num_workers: int,
205
- class_centers=(9.5, 8.0, 6.0)):
206
- train_ds, val_ds = load_split_paired(dataset_path)
207
- model = load_checkpoint(ckpt_path, mode, train_ds)
208
-
209
- if mode == "pooled":
210
- loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
211
- num_workers=num_workers, pin_memory=True,
212
- collate_fn=collate_pair_pooled)
213
- y_all, pred_reg_all, logits_all = [], [], []
214
- for t, b, y in loader:
215
- t = t.to(DEVICE, non_blocking=True)
216
- b = b.to(DEVICE, non_blocking=True)
217
- pred_reg, logits = model(t, b)
218
- y_all.append(y.numpy())
219
- pred_reg_all.append(pred_reg.detach().cpu().numpy())
220
- logits_all.append(logits.detach().cpu().numpy())
221
-
222
- else:
223
- loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
224
- num_workers=num_workers, pin_memory=True,
225
- collate_fn=collate_pair_unpooled)
226
- y_all, pred_reg_all, logits_all = [], [], []
227
- for T, Mt, B, Mb, y in loader:
228
- T = T.to(DEVICE, non_blocking=True)
229
- Mt = Mt.to(DEVICE, non_blocking=True)
230
- B = B.to(DEVICE, non_blocking=True)
231
- Mb = Mb.to(DEVICE, non_blocking=True)
232
- pred_reg, logits = model(T, Mt, B, Mb)
233
- y_all.append(y.numpy())
234
- pred_reg_all.append(pred_reg.detach().cpu().numpy())
235
- logits_all.append(logits.detach().cpu().numpy())
236
-
237
- y_true = np.concatenate(y_all)
238
- y_pred_reg = np.concatenate(pred_reg_all)
239
- logits = np.concatenate(logits_all)
240
-
241
- probs = softmax_np(logits) # (N,3)
242
- y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers)
243
-
244
- # Build CSV rows
245
- out = Path(out_csv)
246
- out.parent.mkdir(parents=True, exist_ok=True)
247
-
248
- header = [
249
- "split", "mode",
250
- "y_true",
251
- "y_pred_reg",
252
- "p_high", "p_moderate", "p_low",
253
- "y_pred_cls_score",
254
- "center_high", "center_moderate", "center_low",
255
- ]
256
-
257
- centers = list(class_centers)
258
- rows = np.column_stack([
259
- y_true,
260
- y_pred_reg,
261
- probs[:, 0], probs[:, 1], probs[:, 2],
262
- y_pred_cls_score,
263
- np.full_like(y_true, centers[0], dtype=np.float32),
264
- np.full_like(y_true, centers[1], dtype=np.float32),
265
- np.full_like(y_true, centers[2], dtype=np.float32),
266
- ])
267
-
268
- with out.open("w") as f:
269
- f.write(",".join(header) + "\n")
270
- for i in range(rows.shape[0]):
271
- f.write(
272
- "val," + mode + "," +
273
- ",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) +
274
- "\n"
275
- )
276
-
277
- print(f"[Data] Val N={len(y_true)} | mode={mode}")
278
- print(f"[Saved] {out}")
279
-
280
-
281
- def main():
282
- ap = argparse.ArgumentParser()
283
- ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)")
284
- ap.add_argument("--ckpt", required=True, help="Path to best_model.pt")
285
- ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True)
286
- ap.add_argument("--out_csv", required=True)
287
- ap.add_argument("--batch_size", type=int, default=128)
288
- ap.add_argument("--num_workers", type=int, default=4)
289
-
290
- # Optional: choose class-centers for expected-score conversion
291
- ap.add_argument("--center_high", type=float, default=9.5)
292
- ap.add_argument("--center_moderate", type=float, default=8.0)
293
- ap.add_argument("--center_low", type=float, default=6.0)
294
-
295
- args = ap.parse_args()
296
-
297
- export_val_preds_csv(
298
- dataset_path=args.dataset_path,
299
- ckpt_path=args.ckpt,
300
- mode=args.mode,
301
- out_csv=args.out_csv,
302
- batch_size=args.batch_size,
303
- num_workers=args.num_workers,
304
- class_centers=(args.center_high, args.center_moderate, args.center_low),
305
- )
306
-
307
-
308
- if __name__ == "__main__":
309
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt DELETED
@@ -1,234 +0,0 @@
1
- ./hemolysis/cnn_smiles/optimization_summary.txt
2
- ./hemolysis/cnn_smiles/pr_curve.png
3
- ./hemolysis/cnn_smiles/roc_curve.png
4
- ./hemolysis/cnn_smiles/study_trials.csv
5
- ./hemolysis/cnn_smiles/train_predictions.csv
6
- ./hemolysis/cnn_smiles/val_predictions.csv
7
- ./hemolysis/cnn_wt/optimization_summary.txt
8
- ./hemolysis/cnn_wt/pr_curve.png
9
- ./hemolysis/cnn_wt/roc_curve.png
10
- ./hemolysis/cnn_wt/study_trials.csv
11
- ./hemolysis/cnn_wt/train_predictions.csv
12
- ./hemolysis/cnn_wt/val_predictions.csv
13
- ./hemolysis/enet_gpu/optimization_summary.txt
14
- ./hemolysis/enet_gpu/pr_curve.png
15
- ./hemolysis/enet_gpu/roc_curve.png
16
- ./hemolysis/enet_gpu/study_trials.csv
17
- ./hemolysis/enet_gpu/train_predictions.csv
18
- ./hemolysis/enet_gpu/val_predictions.csv
19
- ./hemolysis/enet_gpu_smiles/optimization_summary.txt
20
- ./hemolysis/enet_gpu_smiles/pr_curve.png
21
- ./hemolysis/enet_gpu_smiles/roc_curve.png
22
- ./hemolysis/enet_gpu_smiles/study_trials.csv
23
- ./hemolysis/enet_gpu_smiles/train_predictions.csv
24
- ./hemolysis/enet_gpu_smiles/val_predictions.csv
25
- ./hemolysis/enet_gpu_wt/optimization_summary.txt
26
- ./hemolysis/enet_gpu_wt/pr_curve.png
27
- ./hemolysis/enet_gpu_wt/roc_curve.png
28
- ./hemolysis/enet_gpu_wt/study_trials.csv
29
- ./hemolysis/enet_gpu_wt/train_predictions.csv
30
- ./hemolysis/enet_gpu_wt/val_predictions.csv
31
- ./hemolysis/mlp_smiles/optimization_summary.txt
32
- ./hemolysis/mlp_smiles/pr_curve.png
33
- ./hemolysis/mlp_smiles/roc_curve.png
34
- ./hemolysis/mlp_smiles/study_trials.csv
35
- ./hemolysis/mlp_smiles/train_predictions.csv
36
- ./hemolysis/mlp_smiles/val_predictions.csv
37
- ./hemolysis/mlp_wt/optimization_summary.txt
38
- ./hemolysis/mlp_wt/pr_curve.png
39
- ./hemolysis/mlp_wt/roc_curve.png
40
- ./hemolysis/mlp_wt/study_trials.csv
41
- ./hemolysis/mlp_wt/train_predictions.csv
42
- ./hemolysis/mlp_wt/val_predictions.csv
43
- ./hemolysis/svm_gpu_wt/optimization_summary.txt
44
- ./hemolysis/svm_gpu_wt/pr_curve.png
45
- ./hemolysis/svm_gpu_wt/roc_curve.png
46
- ./hemolysis/svm_gpu_wt/study_trials.csv
47
- ./hemolysis/svm_gpu_wt/train_predictions.csv
48
- ./hemolysis/svm_gpu_wt/val_predictions.csv
49
- ./hemolysis/transformer_smiles/optimization_summary.txt
50
- ./hemolysis/transformer_smiles/pr_curve.png
51
- ./hemolysis/transformer_smiles/roc_curve.png
52
- ./hemolysis/transformer_smiles/study_trials.csv
53
- ./hemolysis/transformer_smiles/train_predictions.csv
54
- ./hemolysis/transformer_smiles/val_predictions.csv
55
- ./hemolysis/transformer_wt/optimization_summary.txt
56
- ./hemolysis/transformer_wt/pr_curve.png
57
- ./hemolysis/transformer_wt/roc_curve.png
58
- ./hemolysis/transformer_wt/study_trials.csv
59
- ./hemolysis/transformer_wt/train_predictions.csv
60
- ./hemolysis/transformer_wt/val_predictions.csv
61
- ./hemolysis/xgb/optimization_summary.txt
62
- ./hemolysis/xgb/pr_curve.png
63
- ./hemolysis/xgb/roc_curve.png
64
- ./hemolysis/xgb/study_trials.csv
65
- ./hemolysis/xgb/train_predictions.csv
66
- ./hemolysis/xgb/val_predictions.csv
67
- ./hemolysis/xgb_smiles/optimization_summary.txt
68
- ./hemolysis/xgb_smiles/pr_curve.png
69
- ./hemolysis/xgb_smiles/roc_curve.png
70
- ./hemolysis/xgb_smiles/study_trials.csv
71
- ./hemolysis/xgb_smiles/train_predictions.csv
72
- ./hemolysis/xgb_smiles/val_predictions.csv
73
- ./hemolysis/xgb_wt/optimization_summary.txt
74
- ./hemolysis/xgb_wt/pr_curve.png
75
- ./hemolysis/xgb_wt/roc_curve.png
76
- ./hemolysis/xgb_wt/study_trials.csv
77
- ./hemolysis/xgb_wt/train_predictions.csv
78
- ./hemolysis/xgb_wt/val_predictions.csv
79
- ./nf/cnn/optimization_summary.txt
80
- ./nf/cnn/pr_curve.png
81
- ./nf/cnn/roc_curve.png
82
- ./nf/cnn/study_trials.csv
83
- ./nf/cnn/train_predictions.csv
84
- ./nf/cnn/val_predictions.csv
85
- ./nf/cnn_wt/optimization_summary.txt
86
- ./nf/cnn_wt/pr_curve.png
87
- ./nf/cnn_wt/roc_curve.png
88
- ./nf/cnn_wt/study_trials.csv
89
- ./nf/cnn_wt/train_predictions.csv
90
- ./nf/cnn_wt/val_predictions.csv
91
- ./nf/enet_gpu/optimization_summary.txt
92
- ./nf/enet_gpu/pr_curve.png
93
- ./nf/enet_gpu/roc_curve.png
94
- ./nf/enet_gpu/study_trials.csv
95
- ./nf/enet_gpu/train_predictions.csv
96
- ./nf/enet_gpu/val_predictions.csv
97
- ./nf/enet_gpu_smiles/optimization_summary.txt
98
- ./nf/enet_gpu_smiles/pr_curve.png
99
- ./nf/enet_gpu_smiles/roc_curve.png
100
- ./nf/enet_gpu_smiles/study_trials.csv
101
- ./nf/enet_gpu_smiles/train_predictions.csv
102
- ./nf/enet_gpu_smiles/val_predictions.csv
103
- ./nf/enet_gpu_wt/optimization_summary.txt
104
- ./nf/enet_gpu_wt/pr_curve.png
105
- ./nf/enet_gpu_wt/roc_curve.png
106
- ./nf/enet_gpu_wt/study_trials.csv
107
- ./nf/enet_gpu_wt/train_predictions.csv
108
- ./nf/enet_gpu_wt/val_predictions.csv
109
- ./nf/mlp/optimization_summary.txt
110
- ./nf/mlp/pr_curve.png
111
- ./nf/mlp/roc_curve.png
112
- ./nf/mlp/study_trials.csv
113
- ./nf/mlp/train_predictions.csv
114
- ./nf/mlp/val_predictions.csv
115
- ./nf/mlp_wt/optimization_summary.txt
116
- ./nf/mlp_wt/pr_curve.png
117
- ./nf/mlp_wt/roc_curve.png
118
- ./nf/mlp_wt/study_trials.csv
119
- ./nf/mlp_wt/train_predictions.csv
120
- ./nf/mlp_wt/val_predictions.csv
121
- ./nf/svm_gpu/optimization_summary.txt
122
- ./nf/svm_gpu/pr_curve.png
123
- ./nf/svm_gpu/roc_curve.png
124
- ./nf/svm_gpu/study_trials.csv
125
- ./nf/svm_gpu/train_predictions.csv
126
- ./nf/svm_gpu/val_predictions.csv
127
- ./nf/svm_gpu_wt/optimization_summary.txt
128
- ./nf/svm_gpu_wt/pr_curve.png
129
- ./nf/svm_gpu_wt/roc_curve.png
130
- ./nf/svm_gpu_wt/study_trials.csv
131
- ./nf/svm_gpu_wt/train_predictions.csv
132
- ./nf/svm_gpu_wt/val_predictions.csv
133
- ./nf/transformer/optimization_summary.txt
134
- ./nf/transformer/pr_curve.png
135
- ./nf/transformer/roc_curve.png
136
- ./nf/transformer/study_trials.csv
137
- ./nf/transformer/train_predictions.csv
138
- ./nf/transformer/val_predictions.csv
139
- ./nf/transformer_wt/optimization_summary.txt
140
- ./nf/transformer_wt/pr_curve.png
141
- ./nf/transformer_wt/roc_curve.png
142
- ./nf/transformer_wt/study_trials.csv
143
- ./nf/transformer_wt/train_predictions.csv
144
- ./nf/transformer_wt/val_predictions.csv
145
- ./nf/xgb_wt/optimization_summary.txt
146
- ./nf/xgb_wt/pr_curve.png
147
- ./nf/xgb_wt/roc_curve.png
148
- ./nf/xgb_wt/study_trials.csv
149
- ./nf/xgb_wt/train_predictions.csv
150
- ./nf/xgb_wt/val_predictions.csv
151
- ./permeability_caco2/cnn_smiles/optimization_summary.txt
152
- ./permeability_caco2/cnn_smiles/study_trials.csv
153
- ./permeability_caco2/cnn_smiles/train_predictions.csv
154
- ./permeability_caco2/cnn_smiles/val_predictions.csv
155
- ./permeability_caco2/enet_gpu_smiles/optimization_summary.txt
156
- ./permeability_caco2/enet_gpu_smiles/study_trials.csv
157
- ./permeability_caco2/enet_gpu_smiles/train_predictions.csv
158
- ./permeability_caco2/enet_gpu_smiles/val_predictions.csv
159
- ./permeability_caco2/mlp_smiles/optimization_summary.txt
160
- ./permeability_caco2/mlp_smiles/study_trials.csv
161
- ./permeability_caco2/mlp_smiles/train_predictions.csv
162
- ./permeability_caco2/mlp_smiles/val_predictions.csv
163
- ./permeability_caco2/svr_smiles/optimization_summary.txt
164
- ./permeability_caco2/svr_smiles/study_trials.csv
165
- ./permeability_caco2/svr_smiles/train_predictions.csv
166
- ./permeability_caco2/svr_smiles/val_predictions.csv
167
- ./permeability_caco2/transformer_smiles/optimization_summary.txt
168
- ./permeability_caco2/transformer_smiles/study_trials.csv
169
- ./permeability_caco2/transformer_smiles/train_predictions.csv
170
- ./permeability_caco2/transformer_smiles/val_predictions.csv
171
- ./permeability_caco2/xgb_reg_smiles/optimization_summary.txt
172
- ./permeability_caco2/xgb_reg_smiles/study_trials.csv
173
- ./permeability_caco2/xgb_reg_smiles/train_predictions.csv
174
- ./permeability_caco2/xgb_reg_smiles/val_predictions.csv
175
- ./permeability_pampa/cnn_smiles/optimization_summary.txt
176
- ./permeability_pampa/cnn_smiles/study_trials.csv
177
- ./permeability_pampa/cnn_smiles/train_predictions.csv
178
- ./permeability_pampa/cnn_smiles/val_predictions.csv
179
- ./permeability_pampa/enet_gpu_smiles/optimization_summary.txt
180
- ./permeability_pampa/enet_gpu_smiles/study_trials.csv
181
- ./permeability_pampa/enet_gpu_smiles/train_predictions.csv
182
- ./permeability_pampa/enet_gpu_smiles/val_predictions.csv
183
- ./permeability_pampa/mlp_smiles/optimization_summary.txt
184
- ./permeability_pampa/mlp_smiles/study_trials.csv
185
- ./permeability_pampa/mlp_smiles/train_predictions.csv
186
- ./permeability_pampa/mlp_smiles/val_predictions.csv
187
- ./permeability_pampa/transformer_smiles/optimization_summary.txt
188
- ./permeability_pampa/transformer_smiles/study_trials.csv
189
- ./permeability_pampa/transformer_smiles/train_predictions.csv
190
- ./permeability_pampa/transformer_smiles/val_predictions.csv
191
- ./permeability_pampa/xgb_reg_smiles/optimization_summary.txt
192
- ./permeability_pampa/xgb_reg_smiles/study_trials.csv
193
- ./permeability_pampa/xgb_reg_smiles/train_predictions.csv
194
- ./permeability_pampa/xgb_reg_smiles/val_predictions.csv
195
- ./solubility/cnn_wt/optimization_summary.txt
196
- ./solubility/cnn_wt/pr_curve.png
197
- ./solubility/cnn_wt/roc_curve.png
198
- ./solubility/cnn_wt/study_trials.csv
199
- ./solubility/cnn_wt/train_predictions.csv
200
- ./solubility/cnn_wt/val_predictions.csv
201
- ./solubility/enet_gpu/optimization_summary.txt
202
- ./solubility/enet_gpu/pr_curve.png
203
- ./solubility/enet_gpu/roc_curve.png
204
- ./solubility/enet_gpu/study_trials.csv
205
- ./solubility/enet_gpu/train_predictions.csv
206
- ./solubility/enet_gpu/val_predictions.csv
207
- ./solubility/mlp_wt/optimization_summary.txt
208
- ./solubility/mlp_wt/pr_curve.png
209
- ./solubility/mlp_wt/roc_curve.png
210
- ./solubility/mlp_wt/study_trials.csv
211
- ./solubility/mlp_wt/train_predictions.csv
212
- ./solubility/mlp_wt/val_predictions.csv
213
- ./solubility/svm_gpu/optimization_summary.txt
214
- ./solubility/svm_gpu/pr_curve.png
215
- ./solubility/svm_gpu/roc_curve.png
216
- ./solubility/svm_gpu/study_trials.csv
217
- ./solubility/svm_gpu/train_predictions.csv
218
- ./solubility/svm_gpu/val_predictions.csv
219
- ./solubility/transformer_wt/optimization_summary.txt
220
- ./solubility/transformer_wt/pr_curve.png
221
- ./solubility/transformer_wt/roc_curve.png
222
- ./solubility/transformer_wt/study_trials.csv
223
- ./solubility/transformer_wt/train_predictions.csv
224
- ./solubility/transformer_wt/val_predictions.csv
225
- ./solubility/xgb/optimization_summary.txt
226
- ./solubility/xgb/pr_curve.png
227
- ./solubility/xgb/roc_curve.png
228
- ./solubility/xgb/study_trials.csv
229
- ./solubility/xgb/train_predictions.csv
230
- ./solubility/xgb/val_predictions.csv
231
- ./binding_affinity/wt_wt_pooled/optuna_trials.csv
232
- ./binding_affinity/wt_smiles_pooled/optuna_trials.csv
233
- ./binding_affinity/wt_smiles_unpooled/optuna_trials.csv
234
- ./binding_affinity/wt_wt_unpooled/optuna_trials.csv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py DELETED
@@ -1,417 +0,0 @@
1
- import os
2
- import json
3
- import joblib
4
- import optuna
5
- import numpy as np
6
- import pandas as pd
7
- import matplotlib.pyplot as plt
8
-
9
- from dataclasses import dataclass
10
- from typing import Dict, Any, Tuple, Optional
11
-
12
- from datasets import load_from_disk, DatasetDict
13
- from sklearn.metrics import (
14
- f1_score, roc_auc_score, average_precision_score,
15
- precision_recall_curve, roc_curve
16
- )
17
- from sklearn.linear_model import LogisticRegression
18
- from sklearn.ensemble import AdaBoostClassifier
19
- from sklearn.tree import DecisionTreeClassifier
20
- from linearboost import LinearBoostClassifier
21
-
22
- import xgboost as xgb
23
- from lightning.pytorch import seed_everything
24
-
25
- seed_everything(1986)
26
-
27
- # -----------------------------
28
- # Data loading
29
- # -----------------------------
30
- @dataclass
31
- class SplitData:
32
- X_train: np.ndarray
33
- y_train: np.ndarray
34
- seq_train: Optional[np.ndarray]
35
- X_val: np.ndarray
36
- y_val: np.ndarray
37
- seq_val: Optional[np.ndarray]
38
-
39
-
40
- def _stack_embeddings(col) -> np.ndarray:
41
- # HF datasets often store embeddings as list-of-floats per row
42
- arr = np.asarray(col, dtype=np.float32)
43
- if arr.ndim != 2:
44
- arr = np.stack(col).astype(np.float32)
45
- return arr
46
-
47
-
48
- def load_split_data(dataset_path: str) -> SplitData:
49
- ds = load_from_disk(dataset_path)
50
-
51
- # Case A: DatasetDict with train/val
52
- if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
53
- train_ds, val_ds = ds["train"], ds["val"]
54
- else:
55
- # Case B: Single dataset with "split" column
56
- if "split" not in ds.column_names:
57
- raise ValueError(
58
- "Dataset must be a DatasetDict(train/val) or have a 'split' column."
59
- )
60
- train_ds = ds.filter(lambda x: x["split"] == "train")
61
- val_ds = ds.filter(lambda x: x["split"] == "val")
62
-
63
- for required in ["embedding", "label"]:
64
- if required not in train_ds.column_names:
65
- raise ValueError(f"Missing column '{required}' in train split.")
66
- if required not in val_ds.column_names:
67
- raise ValueError(f"Missing column '{required}' in val split.")
68
-
69
- X_train = _stack_embeddings(train_ds["embedding"])
70
- y_train = np.asarray(train_ds["label"], dtype=np.int64)
71
-
72
- X_val = _stack_embeddings(val_ds["embedding"])
73
- y_val = np.asarray(val_ds["label"], dtype=np.int64)
74
-
75
- seq_train = None
76
- seq_val = None
77
- if "sequence" in train_ds.column_names:
78
- seq_train = np.asarray(train_ds["sequence"])
79
- if "sequence" in val_ds.column_names:
80
- seq_val = np.asarray(val_ds["sequence"])
81
-
82
- return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
83
-
84
-
85
- # -----------------------------
86
- # Metrics + thresholding
87
- # -----------------------------
88
- def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
89
- """
90
- Find threshold maximizing F1 on the given set.
91
- Returns (best_threshold, best_f1).
92
- """
93
- precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
94
- # precision_recall_curve returns thresholds of length n-1
95
- # compute F1 for those thresholds
96
- f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
97
- best_idx = int(np.nanargmax(f1s))
98
- return float(thresholds[best_idx]), float(f1s[best_idx])
99
-
100
-
101
- def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
102
- y_pred = (y_prob >= threshold).astype(int)
103
- return {
104
- "f1": float(f1_score(y_true, y_pred)),
105
- "auc": float(roc_auc_score(y_true, y_prob)),
106
- "ap": float(average_precision_score(y_true, y_prob)),
107
- "threshold": float(threshold),
108
- }
109
-
110
-
111
- # -----------------------------
112
- # Model factories
113
- # -----------------------------
114
- def train_xgb(
115
- X_train, y_train, X_val, y_val, params: Dict[str, Any]
116
- ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
117
- dtrain = xgb.DMatrix(X_train, label=y_train)
118
- dval = xgb.DMatrix(X_val, label=y_val)
119
-
120
- num_boost_round = int(params.pop("num_boost_round"))
121
- early_stopping_rounds = int(params.pop("early_stopping_rounds"))
122
-
123
- booster = xgb.train(
124
- params=params,
125
- dtrain=dtrain,
126
- num_boost_round=num_boost_round,
127
- evals=[(dval, "val")],
128
- early_stopping_rounds=early_stopping_rounds,
129
- verbose_eval=False,
130
- )
131
-
132
- p_train = booster.predict(dtrain)
133
- p_val = booster.predict(dval)
134
- return booster, p_train, p_val
135
-
136
-
137
- def train_adaboost(
138
- X_train, y_train, X_val, y_val, params: Dict[str, Any]
139
- ) -> Tuple[AdaBoostClassifier, np.ndarray, np.ndarray]:
140
- base_depth = int(params.pop("base_depth"))
141
- clf = AdaBoostClassifier(
142
- estimator=DecisionTreeClassifier(max_depth=base_depth),
143
- n_estimators=int(params["n_estimators"]),
144
- learning_rate=float(params["learning_rate"]),
145
- algorithm="SAMME",
146
- )
147
- clf.fit(X_train, y_train)
148
- p_train = clf.predict_proba(X_train)[:, 1]
149
- p_val = clf.predict_proba(X_val)[:, 1]
150
- return clf, p_train, p_val
151
-
152
-
153
- def train_linearboost(X_train, y_train, X_val, y_val, params):
154
- clf = LinearBoostClassifier(**params)
155
- clf.fit(X_train, y_train)
156
- p_train = clf.predict_proba(X_train)[:, 1]
157
- p_val = clf.predict_proba(X_val)[:, 1]
158
- return clf, p_train, p_val
159
-
160
-
161
- def suggest_linearboost_params(trial):
162
- # Core boosting params
163
- params = {
164
- "n_estimators": trial.suggest_int("n_estimators", 50, 800),
165
- "learning_rate": trial.suggest_float("learning_rate", 0.01, 1.0, log=True),
166
- "algorithm": trial.suggest_categorical("algorithm", ["SAMME.R", "SAMME"]),
167
- # Scaling choices from docs (you can expand this list if you want)
168
- "scaler": trial.suggest_categorical(
169
- "scaler",
170
- ["minmax", "standard", "robust", "quantile-uniform", "quantile-normal", "power"]
171
- ),
172
- # useful for imbalanced splits
173
- "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
174
- # kernel trick
175
- "kernel": trial.suggest_categorical("kernel", ["linear", "rbf", "poly", "sigmoid"]),
176
- }
177
-
178
- # Kernel-specific params (only when relevant)
179
- if params["kernel"] in ["rbf", "poly"]:
180
- params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
181
- else:
182
- params["gamma"] = None # docs: default treated as 1/n_features for rbf/poly :contentReference[oaicite:5]{index=5}
183
-
184
- if params["kernel"] == "poly":
185
- params["degree"] = trial.suggest_int("degree", 2, 6) # docs default=3 :contentReference[oaicite:6]{index=6}
186
- params["coef0"] = trial.suggest_float("coef0", 0.0, 5.0) # docs default=1 :contentReference[oaicite:7]{index=7}
187
- else:
188
- # safe defaults
189
- params["degree"] = 3
190
- params["coef0"] = 1.0
191
-
192
- return params
193
- # -----------------------------
194
- # Saving artifacts
195
- # -----------------------------
196
- def save_predictions_csv(
197
- out_dir: str,
198
- split_name: str,
199
- y_true: np.ndarray,
200
- y_prob: np.ndarray,
201
- threshold: float,
202
- sequences: Optional[np.ndarray] = None,
203
- ):
204
- os.makedirs(out_dir, exist_ok=True)
205
- df = pd.DataFrame({
206
- "y_true": y_true.astype(int),
207
- "y_prob": y_prob.astype(float),
208
- "y_pred": (y_prob >= threshold).astype(int),
209
- })
210
- if sequences is not None:
211
- df.insert(0, "sequence", sequences)
212
- df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
213
-
214
-
215
- def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
216
- os.makedirs(out_dir, exist_ok=True)
217
-
218
- # PR
219
- precision, recall, _ = precision_recall_curve(y_true, y_prob)
220
- plt.figure()
221
- plt.plot(recall, precision)
222
- plt.xlabel("Recall")
223
- plt.ylabel("Precision")
224
- plt.title("Precision-Recall Curve")
225
- plt.tight_layout()
226
- plt.savefig(os.path.join(out_dir, "pr_curve.png"))
227
- plt.close()
228
-
229
- # ROC
230
- fpr, tpr, _ = roc_curve(y_true, y_prob)
231
- plt.figure()
232
- plt.plot(fpr, tpr)
233
- plt.xlabel("False Positive Rate")
234
- plt.ylabel("True Positive Rate")
235
- plt.title("ROC Curve")
236
- plt.tight_layout()
237
- plt.savefig(os.path.join(out_dir, "roc_curve.png"))
238
- plt.close()
239
-
240
-
241
- # -----------------------------
242
- # Optuna objectives
243
- # -----------------------------
244
- def make_objective(model_name: str, data: SplitData, out_dir: str):
245
- Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
246
-
247
- def objective(trial: optuna.Trial) -> float:
248
- if model_name == "xgb":
249
- params = {
250
- "objective": "binary:logistic",
251
- "eval_metric": "logloss",
252
- "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
253
- "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
254
- "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
255
- "subsample": trial.suggest_float("subsample", 0.5, 1.0),
256
- "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
257
- "max_depth": trial.suggest_int("max_depth", 2, 15),
258
- "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
259
- "gamma": trial.suggest_float("gamma", 0.0, 10.0),
260
- "tree_method": "hist",
261
- "device": "cuda",
262
- }
263
-
264
- # Optional GPU: set env CUDA_VISIBLE_DEVICES externally if you want.
265
- # If you *know* you want GPU and your xgboost supports it:
266
- # params["device"] = "cuda"
267
-
268
- params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
269
- params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
270
-
271
- model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
272
-
273
- elif model_name == "adaboost":
274
- params = {
275
- "n_estimators": trial.suggest_int("n_estimators", 50, 800),
276
- "learning_rate": trial.suggest_float("learning_rate", 1e-3, 2.0, log=True),
277
- "base_depth": trial.suggest_int("base_depth", 1, 4),
278
- }
279
- model, p_tr, p_va = train_adaboost(Xtr, ytr, Xva, yva, params)
280
-
281
- elif model_name == "linearboost":
282
- params = suggest_linearboost_params(trial)
283
- model, p_tr, p_va = train_linearboost(Xtr, ytr, Xva, yva, params)
284
- else:
285
- raise ValueError(f"Unknown model_name={model_name}")
286
-
287
- # Threshold picked on val for fair comparison across models
288
- thr, f1_at_thr = best_f1_threshold(yva, p_va)
289
- metrics = eval_binary(yva, p_va, thr)
290
-
291
- # Track best trial artifacts inside the study directory
292
- trial.set_user_attr("threshold", thr)
293
- trial.set_user_attr("auc", metrics["auc"])
294
- trial.set_user_attr("ap", metrics["ap"])
295
-
296
- return f1_at_thr
297
-
298
- return objective
299
-
300
- # -----------------------------
301
- # Main runner
302
- # -----------------------------
303
- def run_optuna_and_refit(
304
- dataset_path: str,
305
- out_dir: str,
306
- model_name: str,
307
- n_trials: int = 200,
308
- ):
309
- os.makedirs(out_dir, exist_ok=True)
310
-
311
- data = load_split_data(dataset_path)
312
- print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
313
-
314
- study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
315
- study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
316
-
317
- # Save trials table
318
- trials_df = study.trials_dataframe()
319
- trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
320
-
321
- best = study.best_trial
322
- best_params = dict(best.params)
323
- best_thr = float(best.user_attrs["threshold"])
324
- best_auc = float(best.user_attrs["auc"])
325
- best_ap = float(best.user_attrs["ap"])
326
- best_f1 = float(best.value)
327
-
328
- # Refit best model on train (same protocol as objective)
329
- if model_name == "xgb":
330
- # Reconstruct full param dict
331
- params = {
332
- "objective": "binary:logistic",
333
- "eval_metric": "logloss",
334
- "lambda": best_params["lambda"],
335
- "alpha": best_params["alpha"],
336
- "colsample_bytree": best_params["colsample_bytree"],
337
- "subsample": best_params["subsample"],
338
- "learning_rate": best_params["learning_rate"],
339
- "max_depth": best_params["max_depth"],
340
- "min_child_weight": best_params["min_child_weight"],
341
- "gamma": best_params["gamma"],
342
- "tree_method": "hist",
343
- "num_boost_round": best_params["num_boost_round"],
344
- "early_stopping_rounds": best_params["early_stopping_rounds"],
345
- }
346
- model, p_tr, p_va = train_xgb(
347
- data.X_train, data.y_train, data.X_val, data.y_val, params
348
- )
349
- model_path = os.path.join(out_dir, "best_model.json")
350
- model.save_model(model_path)
351
-
352
- elif model_name == "adaboost":
353
- params = best_params
354
- model, p_tr, p_va = train_adaboost(
355
- data.X_train, data.y_train, data.X_val, data.y_val, params
356
- )
357
- model_path = os.path.join(out_dir, "best_model.joblib")
358
- joblib.dump(model, model_path)
359
-
360
- elif model_name == "linearboost":
361
- params = best_params
362
-
363
- model, p_tr, p_va = train_linearboost(
364
- data.X_train, data.y_train, data.X_val, data.y_val, params
365
- )
366
-
367
- model_path = os.path.join(out_dir, "best_model.joblib")
368
- joblib.dump(model, model_path)
369
- else:
370
- raise ValueError(model_name)
371
-
372
- # Save predictions CSVs
373
- save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
374
- save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
375
-
376
- # Plots on val
377
- plot_curves(out_dir, data.y_val, p_va)
378
-
379
- # Summary
380
- summary = [
381
- "=" * 72,
382
- f"MODEL: {model_name}",
383
- f"Best trial: {best.number}",
384
- f"Best F1 (val @ best-threshold): {best_f1:.4f}",
385
- f"Val AUC: {best_auc:.4f}",
386
- f"Val AP: {best_ap:.4f}",
387
- f"Best threshold (picked on val): {best_thr:.4f}",
388
- f"Model saved to: {model_path}",
389
- "Best params:",
390
- json.dumps(best_params, indent=2),
391
- "=" * 72,
392
- ]
393
- with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
394
- f.write("\n".join(summary))
395
- print("\n".join(summary))
396
-
397
-
398
- if __name__ == "__main__":
399
- # Example usage:
400
- # dataset_path = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/data/solubility"
401
- # out_dir = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/src/solubility/xgb"
402
- # run_optuna_and_refit(dataset_path, out_dir, model_name="xgb", n_trials=200)
403
-
404
- import argparse
405
- parser = argparse.ArgumentParser()
406
- parser.add_argument("--dataset_path", type=str, required=True)
407
- parser.add_argument("--out_dir", type=str, required=True)
408
- parser.add_argument("--model", type=str, choices=["xgb", "adaboost", "linearboost"], required=True)
409
- parser.add_argument("--n_trials", type=int, default=200)
410
- args = parser.parse_args()
411
-
412
- run_optuna_and_refit(
413
- dataset_path=args.dataset_path,
414
- out_dir=args.out_dir,
415
- model_name=args.model,
416
- n_trials=args.n_trials,
417
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py DELETED
@@ -1,468 +0,0 @@
1
- import os
2
- import json
3
- import joblib
4
- import optuna
5
- import numpy as np
6
- import pandas as pd
7
- import matplotlib.pyplot as plt
8
- from dataclasses import dataclass
9
- from typing import Dict, Any, Tuple, Optional
10
- from datasets import load_from_disk, DatasetDict
11
- from sklearn.metrics import (
12
- f1_score, roc_auc_score, average_precision_score,
13
- precision_recall_curve, roc_curve
14
- )
15
- from sklearn.linear_model import LogisticRegression
16
- from sklearn.svm import SVC, LinearSVC
17
- from sklearn.calibration import CalibratedClassifierCV
18
- import torch
19
- import time
20
- import xgboost as xgb
21
- from lightning.pytorch import seed_everything
22
- import cupy as cp
23
- from cuml.svm import SVC as cuSVC
24
- from cuml.linear_model import LogisticRegression as cuLogReg
25
- seed_everything(1986)
26
-
27
-
28
- def to_gpu(X: np.ndarray):
29
- if isinstance(X, cp.ndarray):
30
- return X
31
- return cp.asarray(X, dtype=cp.float32)
32
-
33
- def to_cpu(x):
34
- if isinstance(x, cp.ndarray):
35
- return cp.asnumpy(x)
36
- return np.asarray(x)
37
-
38
- @dataclass
39
- class SplitData:
40
- X_train: np.ndarray
41
- y_train: np.ndarray
42
- seq_train: Optional[np.ndarray]
43
- X_val: np.ndarray
44
- y_val: np.ndarray
45
- seq_val: Optional[np.ndarray]
46
-
47
-
48
- def _stack_embeddings(col) -> np.ndarray:
49
- arr = np.asarray(col, dtype=np.float32)
50
- if arr.ndim != 2:
51
- arr = np.stack(col).astype(np.float32)
52
- return arr
53
-
54
-
55
- def load_split_data(dataset_path: str) -> SplitData:
56
- ds = load_from_disk(dataset_path)
57
-
58
- # Case A: DatasetDict with train/val
59
- if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
60
- train_ds, val_ds = ds["train"], ds["val"]
61
- else:
62
- # Case B: Single dataset with "split" column
63
- if "split" not in ds.column_names:
64
- raise ValueError(
65
- "Dataset must be a DatasetDict(train/val) or have a 'split' column."
66
- )
67
- train_ds = ds.filter(lambda x: x["split"] == "train")
68
- val_ds = ds.filter(lambda x: x["split"] == "val")
69
-
70
- for required in ["embedding", "label"]:
71
- if required not in train_ds.column_names:
72
- raise ValueError(f"Missing column '{required}' in train split.")
73
- if required not in val_ds.column_names:
74
- raise ValueError(f"Missing column '{required}' in val split.")
75
-
76
- X_train = _stack_embeddings(train_ds["embedding"])
77
- y_train = np.asarray(train_ds["label"], dtype=np.int64)
78
-
79
- X_val = _stack_embeddings(val_ds["embedding"])
80
- y_val = np.asarray(val_ds["label"], dtype=np.int64)
81
-
82
- seq_train = None
83
- seq_val = None
84
- if "sequence" in train_ds.column_names:
85
- seq_train = np.asarray(train_ds["sequence"])
86
- if "sequence" in val_ds.column_names:
87
- seq_val = np.asarray(val_ds["sequence"])
88
-
89
- return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
90
-
91
-
92
- def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
93
- """
94
- Find threshold maximizing F1 on the given set.
95
- Returns (best_threshold, best_f1).
96
- """
97
- precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
98
- f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
99
- best_idx = int(np.nanargmax(f1s))
100
- return float(thresholds[best_idx]), float(f1s[best_idx])
101
-
102
-
103
- def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
104
- y_pred = (y_prob >= threshold).astype(int)
105
- return {
106
- "f1": float(f1_score(y_true, y_pred)),
107
- "auc": float(roc_auc_score(y_true, y_prob)),
108
- "ap": float(average_precision_score(y_true, y_prob)),
109
- "threshold": float(threshold),
110
- }
111
-
112
-
113
- # -----------------------------
114
- # Model
115
- # -----------------------------
116
- def train_xgb(
117
- X_train, y_train, X_val, y_val, params: Dict[str, Any]
118
- ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
119
- dtrain = xgb.DMatrix(X_train, label=y_train)
120
- dval = xgb.DMatrix(X_val, label=y_val)
121
-
122
- num_boost_round = int(params.pop("num_boost_round"))
123
- early_stopping_rounds = int(params.pop("early_stopping_rounds"))
124
-
125
- booster = xgb.train(
126
- params=params,
127
- dtrain=dtrain,
128
- num_boost_round=num_boost_round,
129
- evals=[(dval, "val")],
130
- early_stopping_rounds=early_stopping_rounds,
131
- verbose_eval=False,
132
- )
133
-
134
- p_train = booster.predict(dtrain)
135
- p_val = booster.predict(dval)
136
- return booster, p_train, p_val
137
-
138
- def train_cuml_svc(X_train, y_train, X_val, y_val, params):
139
- Xtr = to_gpu(X_train)
140
- Xva = to_gpu(X_val)
141
- ytr = to_gpu(y_train).astype(cp.int32)
142
-
143
- clf = cuSVC(
144
- C=float(params["C"]),
145
- kernel=params["kernel"],
146
- gamma=params.get("gamma", "scale"),
147
- class_weight=params.get("class_weight", None),
148
- probability=bool(params.get("probability", True)),
149
- random_state=1986,
150
- max_iter=int(params.get("max_iter", 1000)),
151
- tol=float(params.get("tol", 1e-4)),
152
- )
153
-
154
- clf.fit(Xtr, ytr)
155
-
156
- p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
157
- p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
158
- return clf, p_train, p_val
159
-
160
- def train_cuml_elastic_net(X_train, y_train, X_val, y_val, params):
161
- Xtr = to_gpu(X_train)
162
- Xva = to_gpu(X_val)
163
- ytr = to_gpu(y_train).astype(cp.int32)
164
-
165
- clf = cuLogReg(
166
- penalty="elasticnet",
167
- C=float(params["C"]),
168
- l1_ratio=float(params["l1_ratio"]),
169
- class_weight=params.get("class_weight", None),
170
- max_iter=int(params.get("max_iter", 1000)),
171
- tol=float(params.get("tol", 1e-4)),
172
- solver="qn",
173
- fit_intercept=True,
174
- )
175
- clf.fit(Xtr, ytr)
176
-
177
- p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
178
- p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
179
- return clf, p_train, p_val
180
-
181
-
182
- def train_svm(X_train, y_train, X_val, y_val, params):
183
- """
184
- Kernel SVM via SVC. CPU only in sklearn.
185
- probability=True enables predict_proba but is slower.
186
- """
187
- clf = SVC(
188
- C=float(params["C"]),
189
- kernel=params["kernel"],
190
- gamma=params.get("gamma", "scale"),
191
- class_weight=params.get("class_weight", None),
192
- probability=True,
193
- random_state=1986,
194
- )
195
- clf.fit(X_train, y_train)
196
- p_train = clf.predict_proba(X_train)[:, 1]
197
- p_val = clf.predict_proba(X_val)[:, 1]
198
- return clf, p_train, p_val
199
-
200
-
201
- def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
202
- """
203
- Fast linear SVM (LinearSVC) + probability calibration.
204
- Usually much faster than SVC on large datasets.
205
- """
206
- base = LinearSVC(
207
- C=float(params["C"]),
208
- class_weight=params.get("class_weight", None),
209
- max_iter=int(params.get("max_iter", 5000)),
210
- random_state=1986,
211
- )
212
- # calibration to get probabilities for PR/ROC + thresholding
213
- clf = CalibratedClassifierCV(base, method="sigmoid", cv=3)
214
- clf.fit(X_train, y_train)
215
- p_train = clf.predict_proba(X_train)[:, 1]
216
- p_val = clf.predict_proba(X_val)[:, 1]
217
- return clf, p_train, p_val
218
-
219
- # -----------------------------
220
- # Saving artifacts
221
- # -----------------------------
222
- def save_predictions_csv(
223
- out_dir: str,
224
- split_name: str,
225
- y_true: np.ndarray,
226
- y_prob: np.ndarray,
227
- threshold: float,
228
- sequences: Optional[np.ndarray] = None,
229
- ):
230
- os.makedirs(out_dir, exist_ok=True)
231
- df = pd.DataFrame({
232
- "y_true": y_true.astype(int),
233
- "y_prob": y_prob.astype(float),
234
- "y_pred": (y_prob >= threshold).astype(int),
235
- })
236
- if sequences is not None:
237
- df.insert(0, "sequence", sequences)
238
- df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
239
-
240
-
241
- def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
242
- os.makedirs(out_dir, exist_ok=True)
243
-
244
- # PR
245
- precision, recall, _ = precision_recall_curve(y_true, y_prob)
246
- plt.figure()
247
- plt.plot(recall, precision)
248
- plt.xlabel("Recall")
249
- plt.ylabel("Precision")
250
- plt.title("Precision-Recall Curve")
251
- plt.tight_layout()
252
- plt.savefig(os.path.join(out_dir, "pr_curve.png"))
253
- plt.close()
254
-
255
- # ROC
256
- fpr, tpr, _ = roc_curve(y_true, y_prob)
257
- plt.figure()
258
- plt.plot(fpr, tpr)
259
- plt.xlabel("False Positive Rate")
260
- plt.ylabel("True Positive Rate")
261
- plt.title("ROC Curve")
262
- plt.tight_layout()
263
- plt.savefig(os.path.join(out_dir, "roc_curve.png"))
264
- plt.close()
265
-
266
-
267
- # -----------------------------
268
- # Optuna objectives
269
- # -----------------------------
270
- def make_objective(model_name: str, data: SplitData, out_dir: str):
271
- Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
272
-
273
- def objective(trial: optuna.Trial) -> float:
274
- if model_name == "xgb":
275
- params = {
276
- "objective": "binary:logistic",
277
- "eval_metric": "logloss",
278
- "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
279
- "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
280
- "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
281
- "subsample": trial.suggest_float("subsample", 0.5, 1.0),
282
- "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
283
- "max_depth": trial.suggest_int("max_depth", 2, 15),
284
- "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
285
- "gamma": trial.suggest_float("gamma", 0.0, 10.0),
286
- "tree_method": "hist",
287
- "device": "cuda",
288
- }
289
- params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
290
- params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
291
-
292
- model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
293
-
294
- elif model_name == "svm":
295
- svm_kind = trial.suggest_categorical("svm_kind", ["svc", "linear_calibrated"])
296
-
297
- if svm_kind == "svc":
298
- params = {
299
- "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
300
- "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
301
- "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
302
- }
303
- if params["kernel"] in ["rbf", "poly", "sigmoid"]:
304
- params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
305
- else:
306
- params["gamma"] = "scale"
307
-
308
- model, p_tr, p_va = train_svm(Xtr, ytr, Xva, yva, params)
309
-
310
- else:
311
- params = {
312
- "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
313
- "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
314
- "max_iter": trial.suggest_int("max_iter", 2000, 20000),
315
- }
316
- model, p_tr, p_va = train_linearsvm_calibrated(Xtr, ytr, Xva, yva, params)
317
- elif model_name == "svm_gpu":
318
- params = {
319
- "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
320
- "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
321
- "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
322
- "probability": True,
323
- "max_iter": trial.suggest_int("max_iter", 200, 5000),
324
- "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
325
- }
326
- if params["kernel"] in ["rbf", "poly", "sigmoid"]:
327
- params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
328
- else:
329
- params["gamma"] = "scale"
330
-
331
- model, p_tr, p_va = train_cuml_svc(Xtr, ytr, Xva, yva, params)
332
-
333
- elif model_name == "enet_gpu":
334
- params = {
335
- "C": trial.suggest_float("C", 1e-4, 1e3, log=True),
336
- "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
337
- "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
338
- "max_iter": trial.suggest_int("max_iter", 200, 5000),
339
- "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
340
- }
341
- model, p_tr, p_va = train_cuml_elastic_net(Xtr, ytr, Xva, yva, params)
342
- else:
343
- raise ValueError(f"Unknown model_name={model_name}")
344
-
345
- thr, f1_at_thr = best_f1_threshold(yva, p_va)
346
- metrics = eval_binary(yva, p_va, thr)
347
- trial.set_user_attr("threshold", thr)
348
- trial.set_user_attr("auc", metrics["auc"])
349
- trial.set_user_attr("ap", metrics["ap"])
350
- return f1_at_thr
351
-
352
- return objective
353
-
354
- # -----------------------------
355
- # Main
356
- # -----------------------------
357
- def run_optuna_and_refit(
358
- dataset_path: str,
359
- out_dir: str,
360
- model_name: str,
361
- n_trials: int = 200,
362
- ):
363
- os.makedirs(out_dir, exist_ok=True)
364
-
365
- data = load_split_data(dataset_path)
366
- print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
367
-
368
- study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
369
- study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
370
-
371
- trials_df = study.trials_dataframe()
372
- trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
373
-
374
- best = study.best_trial
375
- best_params = dict(best.params)
376
- best_thr = float(best.user_attrs["threshold"])
377
- best_auc = float(best.user_attrs["auc"])
378
- best_ap = float(best.user_attrs["ap"])
379
- best_f1 = float(best.value)
380
-
381
- # Refit best model on train
382
- if model_name == "xgb":
383
- params = {
384
- "objective": "binary:logistic",
385
- "eval_metric": "logloss",
386
- "lambda": best_params["lambda"],
387
- "alpha": best_params["alpha"],
388
- "colsample_bytree": best_params["colsample_bytree"],
389
- "subsample": best_params["subsample"],
390
- "learning_rate": best_params["learning_rate"],
391
- "max_depth": best_params["max_depth"],
392
- "min_child_weight": best_params["min_child_weight"],
393
- "gamma": best_params["gamma"],
394
- "tree_method": "hist",
395
- "num_boost_round": best_params["num_boost_round"],
396
- "early_stopping_rounds": best_params["early_stopping_rounds"],
397
- }
398
- model, p_tr, p_va = train_xgb(
399
- data.X_train, data.y_train, data.X_val, data.y_val, params
400
- )
401
- model_path = os.path.join(out_dir, "best_model.json")
402
- model.save_model(model_path)
403
-
404
- elif model_name == "svm":
405
- svm_kind = best_params["svm_kind"]
406
- if svm_kind == "svc":
407
- model, p_tr, p_va = train_svm(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
408
- else:
409
- model, p_tr, p_va = train_linearsvm_calibrated(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
410
-
411
- model_path = os.path.join(out_dir, "best_model.joblib")
412
- joblib.dump(model, model_path)
413
- elif model_name == "svm_gpu":
414
- model, p_tr, p_va = train_cuml_svc(
415
- data.X_train, data.y_train, data.X_val, data.y_val, best_params
416
- )
417
- model_path = os.path.join(out_dir, "best_model_cuml_svc.joblib")
418
- joblib.dump(model, model_path)
419
-
420
- elif model_name == "enet_gpu":
421
- model, p_tr, p_va = train_cuml_elastic_net(
422
- data.X_train, data.y_train, data.X_val, data.y_val, best_params
423
- )
424
- model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
425
- joblib.dump(model, model_path)
426
- else:
427
- raise ValueError(model_name)
428
-
429
- # Save predictions CSVs
430
- save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
431
- save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
432
-
433
- # Plots on val
434
- plot_curves(out_dir, data.y_val, p_va)
435
-
436
- summary = [
437
- "=" * 72,
438
- f"MODEL: {model_name}",
439
- f"Best trial: {best.number}",
440
- f"Best F1 (val @ best-threshold): {best_f1:.4f}",
441
- f"Val AUC: {best_auc:.4f}",
442
- f"Val AP: {best_ap:.4f}",
443
- f"Best threshold (picked on val): {best_thr:.4f}",
444
- f"Model saved to: {model_path}",
445
- "Best params:",
446
- json.dumps(best_params, indent=2),
447
- "=" * 72,
448
- ]
449
- with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
450
- f.write("\n".join(summary))
451
- print("\n".join(summary))
452
-
453
-
454
- if __name__ == "__main__":
455
- import argparse
456
- parser = argparse.ArgumentParser()
457
- parser.add_argument("--dataset_path", type=str, required=True)
458
- parser.add_argument("--out_dir", type=str, required=True)
459
- parser.add_argument("--model", type=str, choices=["xgb", "svm_gpu", "enet_gpu"], required=True)
460
- parser.add_argument("--n_trials", type=int, default=200)
461
- args = parser.parse_args()
462
-
463
- run_optuna_and_refit(
464
- dataset_path=args.dataset_path,
465
- out_dir=args.out_dir,
466
- model_name=args.model,
467
- n_trials=args.n_trials,
468
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py DELETED
@@ -1,410 +0,0 @@
1
- import os
2
- import json
3
- import joblib
4
- import optuna
5
- import numpy as np
6
- import pandas as pd
7
- import matplotlib.pyplot as plt
8
- from dataclasses import dataclass
9
- from typing import Dict, Any, Tuple, Optional
10
- from datasets import load_from_disk, DatasetDict
11
- from sklearn.preprocessing import StandardScaler
12
- from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
13
- from sklearn.svm import SVR
14
- import xgboost as xgb
15
- from lightning.pytorch import seed_everything
16
- import cupy as cp
17
- from cuml.linear_model import ElasticNet as cuElasticNet
18
- from scipy.stats import spearmanr
19
- seed_everything(1986)
20
-
21
-
22
- # -----------------------------
23
- # GPU/CPU helpers
24
- # -----------------------------
25
- def to_gpu(X: np.ndarray):
26
- if isinstance(X, cp.ndarray):
27
- return X
28
- return cp.asarray(X, dtype=cp.float32)
29
-
30
- def to_cpu(x):
31
- if isinstance(x, cp.ndarray):
32
- return cp.asnumpy(x)
33
- return np.asarray(x)
34
-
35
-
36
- # -----------------------------
37
- # Data loading
38
- # -----------------------------
39
- @dataclass
40
- class SplitData:
41
- X_train: np.ndarray
42
- y_train: np.ndarray
43
- seq_train: Optional[np.ndarray]
44
- X_val: np.ndarray
45
- y_val: np.ndarray
46
- seq_val: Optional[np.ndarray]
47
-
48
- def _stack_embeddings(col) -> np.ndarray:
49
- arr = np.asarray(col, dtype=np.float32)
50
- if arr.ndim != 2:
51
- arr = np.stack(col).astype(np.float32)
52
- return arr
53
-
54
- def load_split_data(dataset_path: str) -> SplitData:
55
- ds = load_from_disk(dataset_path)
56
-
57
- if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
58
- train_ds, val_ds = ds["train"], ds["val"]
59
- else:
60
- if "split" not in ds.column_names:
61
- raise ValueError("Dataset must be a DatasetDict(train/val) or have a 'split' column.")
62
- train_ds = ds.filter(lambda x: x["split"] == "train")
63
- val_ds = ds.filter(lambda x: x["split"] == "val")
64
-
65
- for required in ["embedding", "label"]:
66
- if required not in train_ds.column_names:
67
- raise ValueError(f"Missing column '{required}' in train split.")
68
- if required not in val_ds.column_names:
69
- raise ValueError(f"Missing column '{required}' in val split.")
70
-
71
- X_train = _stack_embeddings(train_ds["embedding"]).astype(np.float32)
72
- X_val = _stack_embeddings(val_ds["embedding"]).astype(np.float32)
73
-
74
- y_train = np.asarray(train_ds["label"], dtype=np.float32)
75
- y_val = np.asarray(val_ds["label"], dtype=np.float32)
76
-
77
- seq_train = None
78
- seq_val = None
79
- if "sequence" in train_ds.column_names:
80
- seq_train = np.asarray(train_ds["sequence"])
81
- if "sequence" in val_ds.column_names:
82
- seq_val = np.asarray(val_ds["sequence"])
83
-
84
- return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
85
-
86
-
87
- # -----------------------------
88
- # Metrics
89
- # -----------------------------
90
- def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
91
- rho = spearmanr(y_true, y_pred).correlation
92
- if rho is None or np.isnan(rho):
93
- return 0.0
94
- return float(rho)
95
-
96
- def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
97
- # RMSE
98
- try:
99
- from sklearn.metrics import root_mean_squared_error
100
- rmse = root_mean_squared_error(y_true, y_pred)
101
- except Exception:
102
- rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
103
-
104
- mae = float(mean_absolute_error(y_true, y_pred))
105
- r2 = float(r2_score(y_true, y_pred))
106
- rho = float(safe_spearmanr(y_true, y_pred))
107
- return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
108
-
109
-
110
- # -----------------------------
111
- # Model
112
- # -----------------------------
113
- def train_xgb_reg(
114
- X_train, y_train, X_val, y_val, params: Dict[str, Any]
115
- ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
116
- dtrain = xgb.DMatrix(X_train, label=y_train)
117
- dval = xgb.DMatrix(X_val, label=y_val)
118
-
119
- num_boost_round = int(params.pop("num_boost_round"))
120
- early_stopping_rounds = int(params.pop("early_stopping_rounds"))
121
-
122
- booster = xgb.train(
123
- params=params,
124
- dtrain=dtrain,
125
- num_boost_round=num_boost_round,
126
- evals=[(dval, "val")],
127
- early_stopping_rounds=early_stopping_rounds,
128
- verbose_eval=False,
129
- )
130
-
131
- p_train = booster.predict(dtrain)
132
- p_val = booster.predict(dval)
133
- return booster, p_train, p_val
134
-
135
-
136
- def train_cuml_elasticnet_reg(
137
- X_train, y_train, X_val, y_val, params: Dict[str, Any]
138
- ):
139
- Xtr = to_gpu(X_train)
140
- Xva = to_gpu(X_val)
141
- ytr = to_gpu(y_train).astype(cp.float32)
142
-
143
- model = cuElasticNet(
144
- alpha=float(params["alpha"]),
145
- l1_ratio=float(params["l1_ratio"]),
146
- fit_intercept=True,
147
- max_iter=int(params.get("max_iter", 5000)),
148
- tol=float(params.get("tol", 1e-4)),
149
- selection=params.get("selection", "cyclic"),
150
- )
151
- model.fit(Xtr, ytr)
152
-
153
- p_train = to_cpu(model.predict(Xtr))
154
- p_val = to_cpu(model.predict(Xva))
155
- return model, p_train, p_val
156
-
157
-
158
- def train_svr_reg(
159
- X_train, y_train, X_val, y_val, params: Dict[str, Any]
160
- ):
161
- model = SVR(
162
- C=float(params["C"]),
163
- epsilon=float(params["epsilon"]),
164
- kernel=params["kernel"],
165
- gamma=params.get("gamma", "scale"),
166
- )
167
- model.fit(X_train, y_train)
168
- p_train = model.predict(X_train)
169
- p_val = model.predict(X_val)
170
- return model, p_train, p_val
171
-
172
-
173
- # -----------------------------
174
- # Saving + plots
175
- # -----------------------------
176
- def save_predictions_csv(
177
- out_dir: str,
178
- split_name: str,
179
- y_true: np.ndarray,
180
- y_pred: np.ndarray,
181
- sequences: Optional[np.ndarray] = None,
182
- ):
183
- os.makedirs(out_dir, exist_ok=True)
184
- df = pd.DataFrame({
185
- "y_true": y_true.astype(float),
186
- "y_pred": y_pred.astype(float),
187
- "residual": (y_true - y_pred).astype(float),
188
- })
189
- if sequences is not None:
190
- df.insert(0, "sequence", sequences)
191
- df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
192
-
193
- def plot_regression_diagnostics(out_dir: str, y_true: np.ndarray, y_pred: np.ndarray):
194
- os.makedirs(out_dir, exist_ok=True)
195
-
196
- plt.figure()
197
- plt.scatter(y_true, y_pred, s=8, alpha=0.5)
198
- plt.xlabel("y_true")
199
- plt.ylabel("y_pred")
200
- plt.title("Predicted vs True")
201
- plt.tight_layout()
202
- plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
203
- plt.close()
204
-
205
- resid = y_true - y_pred
206
- plt.figure()
207
- plt.hist(resid, bins=50)
208
- plt.xlabel("residual (y_true - y_pred)")
209
- plt.ylabel("count")
210
- plt.title("Residual Histogram")
211
- plt.tight_layout()
212
- plt.savefig(os.path.join(out_dir, "residual_hist.png"))
213
- plt.close()
214
-
215
- plt.figure()
216
- plt.scatter(y_pred, resid, s=8, alpha=0.5)
217
- plt.xlabel("y_pred")
218
- plt.ylabel("residual")
219
- plt.title("Residuals vs Prediction")
220
- plt.tight_layout()
221
- plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
222
- plt.close()
223
-
224
-
225
- # -----------------------------
226
- # Optuna objective (OPTIMIZE SPEARMAN RHO)
227
- # -----------------------------
228
- def make_objective(model_name: str, data: SplitData):
229
- Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
230
-
231
- def objective(trial: optuna.Trial) -> float:
232
- if model_name == "xgb_reg":
233
- params = {
234
- "objective": "reg:squarederror",
235
- "eval_metric": "rmse",
236
- "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
237
- "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
238
- "gamma": trial.suggest_float("gamma", 0.0, 10.0),
239
- "max_depth": trial.suggest_int("max_depth", 2, 16),
240
- "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 500.0, log=True),
241
- "subsample": trial.suggest_float("subsample", 0.5, 1.0),
242
- "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
243
- "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
244
- "tree_method": "hist",
245
- "device": "cuda",
246
- }
247
- params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 2000)
248
- params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
249
-
250
- model, p_tr, p_va = train_xgb_reg(Xtr, ytr, Xva, yva, params.copy())
251
-
252
- elif model_name == "enet_gpu":
253
- params = {
254
- "alpha": trial.suggest_float("alpha", 1e-8, 10.0, log=True),
255
- "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
256
- "max_iter": trial.suggest_int("max_iter", 1000, 20000),
257
- "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
258
- "selection": trial.suggest_categorical("selection", ["cyclic", "random"]),
259
- }
260
- model, p_tr, p_va = train_cuml_elasticnet_reg(Xtr, ytr, Xva, yva, params)
261
-
262
- elif model_name == "svr":
263
- params = {
264
- "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
265
- "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
266
- "epsilon": trial.suggest_float("epsilon", 1e-4, 1.0, log=True),
267
- }
268
- if params["kernel"] in ["rbf", "poly", "sigmoid"]:
269
- params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
270
- else:
271
- params["gamma"] = "scale"
272
-
273
- model, p_tr, p_va = train_svr_reg(Xtr, ytr, Xva, yva, params)
274
-
275
- else:
276
- raise ValueError(f"Unknown model_name={model_name}")
277
-
278
- metrics = eval_regression(yva, p_va)
279
- trial.set_user_attr("spearman_rho", metrics["spearman_rho"])
280
- trial.set_user_attr("rmse", metrics["rmse"])
281
- trial.set_user_attr("mae", metrics["mae"])
282
- trial.set_user_attr("r2", metrics["r2"])
283
-
284
- # OPTUNA OBJECTIVE = maximize Spearman rho
285
- return metrics["spearman_rho"]
286
-
287
- return objective
288
-
289
-
290
- # -----------------------------
291
- # Main
292
- # -----------------------------
293
- def run_optuna_and_refit(
294
- dataset_path: str,
295
- out_dir: str,
296
- model_name: str,
297
- n_trials: int = 200,
298
- standardize_X: bool = True,
299
- ):
300
- os.makedirs(out_dir, exist_ok=True)
301
-
302
- data = load_split_data(dataset_path)
303
- print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
304
-
305
- # Standardize features (SVR + ElasticNet)
306
- if standardize_X:
307
- scaler = StandardScaler()
308
- data.X_train = scaler.fit_transform(data.X_train).astype(np.float32)
309
- data.X_val = scaler.transform(data.X_val).astype(np.float32)
310
- joblib.dump(scaler, os.path.join(out_dir, "scaler.joblib"))
311
- print("[Preprocess] Saved StandardScaler -> scaler.joblib")
312
-
313
- study = optuna.create_study(
314
- direction="maximize",
315
- pruner=optuna.pruners.MedianPruner()
316
- )
317
- study.optimize(make_objective(model_name, data), n_trials=n_trials)
318
-
319
- trials_df = study.trials_dataframe()
320
- trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
321
-
322
- best = study.best_trial
323
- best_params = dict(best.params)
324
-
325
- best_rho = float(best.user_attrs.get("spearman_rho", best.value))
326
- best_rmse = float(best.user_attrs.get("rmse", np.nan))
327
- best_mae = float(best.user_attrs.get("mae", np.nan))
328
- best_r2 = float(best.user_attrs.get("r2", np.nan))
329
-
330
- # Refit best model on train
331
- if model_name == "xgb_reg":
332
- params = {
333
- "objective": "reg:squarederror",
334
- "eval_metric": "rmse",
335
- "lambda": best_params["lambda"],
336
- "alpha": best_params["alpha"],
337
- "gamma": best_params["gamma"],
338
- "max_depth": best_params["max_depth"],
339
- "min_child_weight": best_params["min_child_weight"],
340
- "subsample": best_params["subsample"],
341
- "colsample_bytree": best_params["colsample_bytree"],
342
- "learning_rate": best_params["learning_rate"],
343
- "tree_method": "hist",
344
- "device": "cuda",
345
- "num_boost_round": best_params["num_boost_round"],
346
- "early_stopping_rounds": best_params["early_stopping_rounds"],
347
- }
348
- model, p_tr, p_va = train_xgb_reg(
349
- data.X_train, data.y_train, data.X_val, data.y_val, params
350
- )
351
- model_path = os.path.join(out_dir, "best_model.json")
352
- model.save_model(model_path)
353
-
354
- elif model_name == "enet_gpu":
355
- model, p_tr, p_va = train_cuml_elasticnet_reg(
356
- data.X_train, data.y_train, data.X_val, data.y_val, best_params
357
- )
358
- model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
359
- joblib.dump(model, model_path)
360
-
361
- elif model_name == "svr":
362
- model, p_tr, p_va = train_svr_reg(
363
- data.X_train, data.y_train, data.X_val, data.y_val, best_params
364
- )
365
- model_path = os.path.join(out_dir, "best_model_svr.joblib")
366
- joblib.dump(model, model_path)
367
-
368
- else:
369
- raise ValueError(model_name)
370
-
371
- save_predictions_csv(out_dir, "train", data.y_train, p_tr, data.seq_train)
372
- save_predictions_csv(out_dir, "val", data.y_val, p_va, data.seq_val)
373
-
374
- plot_regression_diagnostics(out_dir, data.y_val, p_va)
375
-
376
- summary = [
377
- "=" * 72,
378
- f"MODEL: {model_name}",
379
- f"Best trial: {best.number}",
380
- f"Val Spearman rho (objective): {best_rho:.6f}",
381
- f"Val RMSE: {best_rmse:.6f}",
382
- f"Val MAE: {best_mae:.6f}",
383
- f"Val R2: {best_r2:.6f}",
384
- f"Model saved to: {model_path}",
385
- "Best params:",
386
- json.dumps(best_params, indent=2),
387
- "=" * 72,
388
- ]
389
- with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
390
- f.write("\n".join(summary))
391
- print("\n".join(summary))
392
-
393
-
394
- if __name__ == "__main__":
395
- import argparse
396
- parser = argparse.ArgumentParser()
397
- parser.add_argument("--dataset_path", type=str, required=True)
398
- parser.add_argument("--out_dir", type=str, required=True)
399
- parser.add_argument("--model", type=str, choices=["xgb_reg", "enet_gpu", "svr"], required=True)
400
- parser.add_argument("--n_trials", type=int, default=200)
401
- parser.add_argument("--no_standardize", action="store_true", help="Disable StandardScaler on X")
402
- args = parser.parse_args()
403
-
404
- run_optuna_and_refit(
405
- dataset_path=args.dataset_path,
406
- out_dir=args.out_dir,
407
- model_name=args.model,
408
- n_trials=args.n_trials,
409
- standardize_X=(not args.no_standardize),
410
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py DELETED
@@ -1,426 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from torch.utils.data import DataLoader
4
- from datasets import load_from_disk, DatasetDict
5
- from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
6
- import torch.nn as nn
7
- import optuna
8
- import os
9
- from typing import Dict, Any, Tuple, Optional
10
- import matplotlib.pyplot as plt
11
- from sklearn.metrics import (
12
- f1_score, roc_auc_score, average_precision_score,
13
- precision_recall_curve, roc_curve
14
- )
15
- import json
16
- import joblib
17
- import pandas as pd
18
- import time
19
-
20
- def infer_in_dim_from_unpooled_ds(ds) -> int:
21
- ex = ds[0]
22
- # ex["embedding"] is (L, H) list/array
23
- return int(len(ex["embedding"][0]))
24
-
25
- def load_split(dataset_path):
26
- ds = load_from_disk(dataset_path)
27
-
28
- if isinstance(ds, DatasetDict):
29
- return ds["train"], ds["val"]
30
-
31
- raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
32
-
33
- def collate_unpooled(batch):
34
- # batch: list of dicts
35
- lengths = [int(x["length"]) for x in batch]
36
- Lmax = max(lengths)
37
- H = len(batch[0]["embedding"][0]) # 1280
38
-
39
- X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
40
- M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
41
- y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)
42
-
43
- for i, x in enumerate(batch):
44
- emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L, H)
45
- L = emb.shape[0]
46
- X[i, :L] = emb
47
- if "attention_mask" in x:
48
- m = torch.tensor(x["attention_mask"], dtype=torch.bool)
49
- M[i, :L] = m[:L]
50
- else:
51
- M[i, :L] = True
52
-
53
- return X, M, y
54
-
55
- # ======================== Helper functions =========================================
56
- def save_predictions_csv(
57
- out_dir: str,
58
- split_name: str,
59
- y_true: np.ndarray,
60
- y_prob: np.ndarray,
61
- threshold: float,
62
- sequences: Optional[np.ndarray] = None,
63
- ):
64
- os.makedirs(out_dir, exist_ok=True)
65
- df = pd.DataFrame({
66
- "y_true": y_true.astype(int),
67
- "y_prob": y_prob.astype(float),
68
- "y_pred": (y_prob >= threshold).astype(int),
69
- })
70
- if sequences is not None:
71
- df.insert(0, "sequence", sequences)
72
- df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
73
-
74
-
75
- def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
76
- os.makedirs(out_dir, exist_ok=True)
77
-
78
- # PR
79
- precision, recall, _ = precision_recall_curve(y_true, y_prob)
80
- plt.figure()
81
- plt.plot(recall, precision)
82
- plt.xlabel("Recall")
83
- plt.ylabel("Precision")
84
- plt.title("Precision-Recall Curve")
85
- plt.tight_layout()
86
- plt.savefig(os.path.join(out_dir, "pr_curve.png"))
87
- plt.close()
88
-
89
- # ROC
90
- fpr, tpr, _ = roc_curve(y_true, y_prob)
91
- plt.figure()
92
- plt.plot(fpr, tpr)
93
- plt.xlabel("False Positive Rate")
94
- plt.ylabel("True Positive Rate")
95
- plt.title("ROC Curve")
96
- plt.tight_layout()
97
- plt.savefig(os.path.join(out_dir, "roc_curve.png"))
98
- plt.close()
99
-
100
- # ======================== Shared OPTUNA training scheme =========================================
101
- def best_f1_threshold(y_true, y_prob):
102
- p, r, thr = precision_recall_curve(y_true, y_prob)
103
- f1s = (2*p[:-1]*r[:-1])/(p[:-1]+r[:-1]+1e-12)
104
- i = int(np.nanargmax(f1s))
105
- return float(thr[i]), float(f1s[i])
106
-
107
- @torch.no_grad()
108
- def eval_probs(model, loader, device):
109
- model.eval()
110
- ys, ps = [], []
111
- for X, M, y in loader:
112
- X, M = X.to(device), M.to(device)
113
- logits = model(X, M)
114
- prob = torch.sigmoid(logits).detach().cpu().numpy()
115
- ys.append(y.numpy())
116
- ps.append(prob)
117
- return np.concatenate(ys), np.concatenate(ps)
118
-
119
- def train_one_epoch(model, loader, optim, criterion, device):
120
- model.train()
121
- for X, M, y in loader:
122
- X, M, y = X.to(device), M.to(device), y.to(device)
123
- optim.zero_grad(set_to_none=True)
124
- logits = model(X, M)
125
- loss = criterion(logits, y)
126
- loss.backward()
127
- optim.step()
128
-
129
- # ======================== MLP =========================================
130
- # Still need mean pooling along lengths
131
- class MaskedMeanPool(nn.Module):
132
- def forward(self, X, M): # X: (B,L,H), M: (B,L)
133
- Mf = M.unsqueeze(-1).float()
134
- denom = Mf.sum(dim=1).clamp(min=1.0)
135
- return (X * Mf).sum(dim=1) / denom # (B,H)
136
-
137
- class MLPClassifier(nn.Module):
138
- def __init__(self, in_dim, hidden=512, dropout=0.1):
139
- super().__init__()
140
- self.pool = MaskedMeanPool()
141
- self.net = nn.Sequential(
142
- nn.Linear(in_dim, hidden),
143
- nn.GELU(),
144
- nn.Dropout(dropout),
145
- nn.Linear(hidden, 1),
146
- )
147
- def forward(self, X, M):
148
- z = self.pool(X, M)
149
- return self.net(z).squeeze(-1) # logits
150
-
151
- # ======================== CNN =========================================
152
- # Treat 1280 dimensions as channels
153
- class CNNClassifier(nn.Module):
154
- def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
155
- super().__init__()
156
- blocks = []
157
- ch = in_ch
158
- for _ in range(layers):
159
- blocks += [
160
- nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
161
- nn.GELU(),
162
- nn.Dropout(dropout),
163
- ]
164
- ch = c
165
- self.conv = nn.Sequential(*blocks)
166
- self.head = nn.Linear(c, 1)
167
-
168
- def forward(self, X, M):
169
- # X: (B,L,H) -> (B,H,L)
170
- Xc = X.transpose(1, 2)
171
- Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
172
-
173
- # masked mean pool over L
174
- Mf = M.unsqueeze(-1).float()
175
- denom = Mf.sum(dim=1).clamp(min=1.0)
176
- pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
177
- return self.head(pooled).squeeze(-1)
178
-
179
- # ========================== Transformer ====================================
180
- class TransformerClassifier(nn.Module):
181
- def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
182
- super().__init__()
183
- self.proj = nn.Linear(in_dim, d_model)
184
- enc_layer = nn.TransformerEncoderLayer(
185
- d_model=d_model, nhead=nhead, dim_feedforward=ff,
186
- dropout=dropout, batch_first=True, activation="gelu"
187
- )
188
- self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
189
- self.head = nn.Linear(d_model, 1)
190
-
191
- def forward(self, X, M):
192
- # src_key_padding_mask: True = pad positions
193
- pad_mask = ~M
194
- Z = self.proj(X) # (B,L,d)
195
- Z = self.enc(Z, src_key_padding_mask=pad_mask) # (B,L,d)
196
-
197
- Mf = M.unsqueeze(-1).float()
198
- denom = Mf.sum(dim=1).clamp(min=1.0)
199
- pooled = (Z * Mf).sum(dim=1) / denom
200
- return self.head(pooled).squeeze(-1)
201
-
202
- # ========================== OPTUNA ====================================
203
-
204
- def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
205
- # hyperparams shared
206
- lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
207
- wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
208
- dropout = trial.suggest_float("dropout", 0.0, 0.5)
209
- batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
210
-
211
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
212
- collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
213
- val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
214
- collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
215
-
216
- in_dim = infer_in_dim_from_unpooled_ds(train_ds)
217
-
218
- if model_name == "mlp":
219
- hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
220
- model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
221
- elif model_name == "cnn":
222
- c = trial.suggest_categorical("channels", [128, 256, 512])
223
- k = trial.suggest_categorical("kernel", [3, 5, 7])
224
- layers = trial.suggest_int("layers", 1, 4)
225
- model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
226
- elif model_name == "transformer":
227
- d = trial.suggest_categorical("d_model", [128, 256, 384])
228
- nhead = trial.suggest_categorical("nhead", [4, 8])
229
- layers = trial.suggest_int("layers", 1, 4)
230
- ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
231
- model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
232
- else:
233
- raise ValueError(model_name)
234
-
235
- model = model.to(device)
236
-
237
- # class imbalance handling
238
- ytr = np.asarray(train_ds["label"], dtype=np.int64)
239
- pos = ytr.sum()
240
- neg = len(ytr) - pos
241
- pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
242
- criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
243
-
244
- optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
245
-
246
- best_f1 = -1.0
247
- patience = 8
248
- bad = 0
249
-
250
- for epoch in range(1, 51):
251
- train_one_epoch(model, train_loader, optim, criterion, device)
252
-
253
- y_true, y_prob = eval_probs(model, val_loader, device)
254
- auc = roc_auc_score(y_true, y_prob)
255
-
256
- thr, f1 = best_f1_threshold(y_true, y_prob)
257
-
258
- trial.set_user_attr("val_auc", float(auc))
259
- trial.set_user_attr("val_f1", float(f1))
260
- trial.set_user_attr("val_thr", float(thr))
261
-
262
- # prune
263
- trial.report(f1, epoch)
264
- if trial.should_prune():
265
- raise optuna.TrialPruned()
266
-
267
- if f1 > best_f1 + 1e-4:
268
- best_f1 = f1
269
- bad = 0
270
- else:
271
- bad += 1
272
- if bad >= patience:
273
- break
274
-
275
- return best_f1
276
-
277
- def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str, n_trials: int = 50, device="cuda:0"):
278
- os.makedirs(out_dir, exist_ok=True)
279
-
280
- train_ds, val_ds = load_split(dataset_path)
281
- print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
282
-
283
- study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
284
- study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device), n_trials=n_trials)
285
-
286
- trials_df = study.trials_dataframe()
287
- trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
288
-
289
- best = study.best_trial
290
- best_params = dict(best.params)
291
- best_f1_optuna = float(best.value)
292
- best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
293
- best_thr = float(best.user_attrs.get("val_thr", 0.5))
294
-
295
- in_dim = infer_in_dim_from_unpooled_ds(train_ds)
296
-
297
- # --- Refit best model ---
298
- batch_size = int(best_params.get("batch_size", 32))
299
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
300
- collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
301
- val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
302
- collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
303
-
304
- # Rebuild
305
- dropout = float(best_params.get("dropout", 0.1))
306
- if model_name == "mlp":
307
- model = MLPClassifier(
308
- in_dim=in_dim,
309
- hidden=int(best_params["hidden"]),
310
- dropout=dropout,
311
- )
312
-
313
- elif model_name == "cnn":
314
- model = CNNClassifier(
315
- in_ch=in_dim,
316
- c=int(best_params["channels"]),
317
- k=int(best_params["kernel"]),
318
- layers=int(best_params["layers"]),
319
- dropout=dropout,
320
- )
321
-
322
- elif model_name == "transformer":
323
- model = TransformerClassifier(
324
- in_dim=in_dim,
325
- d_model=int(best_params["d_model"]),
326
- nhead=int(best_params["nhead"]),
327
- layers=int(best_params["layers"]),
328
- ff=int(best_params["ff"]),
329
- dropout=dropout,
330
- )
331
- else:
332
- raise ValueError(model_name)
333
-
334
- model = model.to(device)
335
-
336
- # loss + optimizer
337
- ytr = np.asarray(train_ds["label"], dtype=np.int64)
338
- pos = ytr.sum()
339
- neg = len(ytr) - pos
340
- pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
341
- criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
342
-
343
- lr = float(best_params["lr"])
344
- wd = float(best_params["weight_decay"])
345
- optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
346
-
347
- # train longer with early stopping on AUC
348
- best_f1_seen, bad, patience = -1.0, 0, 12
349
- best_state = None
350
- best_thr_seen = 0.5
351
- best_auc_seen = -1.0
352
-
353
- for epoch in range(1, 151):
354
- train_one_epoch(model, train_loader, optim, criterion, device)
355
-
356
- y_true, y_prob = eval_probs(model, val_loader, device)
357
- auc = roc_auc_score(y_true, y_prob)
358
- thr, f1 = best_f1_threshold(y_true, y_prob)
359
-
360
- if f1 > best_f1_seen + 1e-4:
361
- best_f1_seen = f1
362
- best_thr_seen = thr
363
- best_auc_seen = auc
364
- bad = 0
365
- best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
366
- else:
367
- bad += 1
368
- if bad >= patience:
369
- break
370
-
371
- if best_state is not None:
372
- model.load_state_dict(best_state)
373
-
374
- # final preds + threshold picked on val
375
- y_true_val, y_prob_val = eval_probs(model, val_loader, device)
376
- best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
377
-
378
- # save model
379
- model_path = os.path.join(out_dir, "best_model.pt")
380
- torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)
381
-
382
- # train preds
383
- y_true_tr, y_prob_tr = eval_probs(model, DataLoader(train_ds, batch_size=64, shuffle=False,
384
- collate_fn=collate_unpooled, num_workers=4, pin_memory=True), device)
385
-
386
- save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
387
- sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
388
- save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
389
- sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)
390
-
391
- plot_curves(out_dir, y_true_val, y_prob_val)
392
-
393
- summary = [
394
- "=" * 72,
395
- f"MODEL: {model_name}",
396
-
397
- # Optuna results (objective = F1)
398
- f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
399
- f"Best Optuna AUC (val, recorded): {best_auc_optuna:.4f}",
400
- f"Best Optuna threshold (val): {best_thr:.4f}",
401
-
402
- # Refit results
403
- f"Refit best AUC (val): {best_auc_seen:.4f}",
404
- f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",
405
-
406
- "Best params:",
407
- json.dumps(best_params, indent=2),
408
- f"Saved model: {model_path}",
409
- "=" * 72,
410
- ]
411
-
412
- with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
413
- f.write("\n".join(summary))
414
- print("\n".join(summary))
415
-
416
- if __name__ == "__main__":
417
- import argparse
418
- parser = argparse.ArgumentParser()
419
- parser.add_argument("--dataset_path", type=str, required=True)
420
- parser.add_argument("--out_dir", type=str, required=True)
421
- parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
422
- parser.add_argument("--n_trials", type=int, default=50)
423
- args = parser.parse_args()
424
-
425
- if args.model in ["mlp", "cnn", "transformer"]:
426
- run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device="cuda:0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py DELETED
@@ -1,420 +0,0 @@
1
- import os, json, time
2
- import numpy as np
3
- import pandas as pd
4
- import matplotlib.pyplot as plt
5
-
6
- import torch
7
- import torch.nn as nn
8
- from torch.utils.data import DataLoader
9
- from datasets import load_from_disk, DatasetDict
10
- import optuna
11
- from dataclasses import dataclass
12
- from typing import Dict, Any, Tuple, Optional
13
- from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
14
- from scipy.stats import spearmanr
15
- from torch.cuda.amp import autocast
16
- from torch.cuda.amp import autocast, GradScaler
17
- scaler = GradScaler(enabled=torch.cuda.is_available())
18
- from lightning.pytorch import seed_everything
19
- seed_everything(1986)
20
-
21
-
22
- def load_split(dataset_path):
23
- ds = load_from_disk(dataset_path)
24
- if isinstance(ds, DatasetDict):
25
- return ds["train"], ds["val"]
26
- raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
27
-
28
- def collate_unpooled_reg(batch):
29
- lengths = [int(x["length"]) for x in batch]
30
- Lmax = max(lengths)
31
- H = len(batch[0]["embedding"][0])
32
-
33
- X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
34
- M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
35
- y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
36
-
37
- for i, x in enumerate(batch):
38
- emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L,H)
39
- L = emb.shape[0]
40
- X[i, :L] = emb
41
- if "attention_mask" in x:
42
- m = torch.tensor(x["attention_mask"], dtype=torch.bool)
43
- M[i, :L] = m[:L]
44
- else:
45
- M[i, :L] = True
46
- return X, M, y
47
-
48
- def infer_in_dim(ds) -> int:
49
- ex = ds[0]
50
- return int(len(ex["embedding"][0]))
51
-
52
- # ============================
53
- # Metrics
54
- # ============================
55
- def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
56
- rho = spearmanr(y_true, y_pred).correlation
57
- if rho is None or np.isnan(rho):
58
- return 0.0
59
- return float(rho)
60
-
61
- def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
62
- # ---- RMSE ----
63
- try:
64
- from sklearn.metrics import root_mean_squared_error
65
- rmse = root_mean_squared_error(y_true, y_pred)
66
- except Exception:
67
- mse = mean_squared_error(y_true, y_pred)
68
- rmse = float(np.sqrt(mse))
69
-
70
- mae = float(mean_absolute_error(y_true, y_pred))
71
- r2 = float(r2_score(y_true, y_pred))
72
- rho = float(safe_spearmanr(y_true, y_pred))
73
- return {"rmse": float(rmse), "mae": mae, "r2": r2, "spearman_rho": rho}
74
-
75
-
76
- # ============================
77
- # Models
78
- # ============================
79
- class MaskedMeanPool(nn.Module):
80
- def forward(self, X, M):
81
- Mf = M.unsqueeze(-1).float()
82
- denom = Mf.sum(dim=1).clamp(min=1.0)
83
- return (X * Mf).sum(dim=1) / denom
84
-
85
- class MLPRegressor(nn.Module):
86
- def __init__(self, in_dim, hidden=512, dropout=0.1):
87
- super().__init__()
88
- self.pool = MaskedMeanPool()
89
- self.net = nn.Sequential(
90
- nn.Linear(in_dim, hidden),
91
- nn.GELU(),
92
- nn.Dropout(dropout),
93
- nn.Linear(hidden, 1),
94
- )
95
- def forward(self, X, M):
96
- z = self.pool(X, M)
97
- return self.net(z).squeeze(-1) # y_pred
98
-
99
- class CNNRegressor(nn.Module):
100
- def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
101
- super().__init__()
102
- blocks = []
103
- ch = in_ch
104
- for _ in range(layers):
105
- blocks += [
106
- nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
107
- nn.GELU(),
108
- nn.Dropout(dropout),
109
- ]
110
- ch = c
111
- self.conv = nn.Sequential(*blocks)
112
- self.head = nn.Linear(c, 1)
113
-
114
- def forward(self, X, M):
115
- Xc = X.transpose(1, 2) # (B,H,L)
116
- Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
117
- Mf = M.unsqueeze(-1).float()
118
- denom = Mf.sum(dim=1).clamp(min=1.0)
119
- pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
120
- return self.head(pooled).squeeze(-1)
121
-
122
- class TransformerRegressor(nn.Module):
123
- def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
124
- super().__init__()
125
- self.proj = nn.Linear(in_dim, d_model)
126
- enc_layer = nn.TransformerEncoderLayer(
127
- d_model=d_model, nhead=nhead, dim_feedforward=ff,
128
- dropout=dropout, batch_first=True, activation="gelu"
129
- )
130
- self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
131
- self.head = nn.Linear(d_model, 1)
132
-
133
- def forward(self, X, M):
134
- pad_mask = ~M
135
- Z = self.proj(X)
136
- Z = self.enc(Z, src_key_padding_mask=pad_mask)
137
- Mf = M.unsqueeze(-1).float()
138
- denom = Mf.sum(dim=1).clamp(min=1.0)
139
- pooled = (Z * Mf).sum(dim=1) / denom
140
- return self.head(pooled).squeeze(-1)
141
-
142
- # ============================
143
- # Train / eval
144
- # ============================
145
- @torch.no_grad()
146
- def eval_preds(model, loader, device):
147
- model.eval()
148
- ys, ps = [], []
149
- for X, M, y in loader:
150
- X, M = X.to(device), M.to(device)
151
- pred = model(X, M).detach().cpu().numpy()
152
- ys.append(y.numpy())
153
- ps.append(pred)
154
- return np.concatenate(ys), np.concatenate(ps)
155
-
156
- def train_one_epoch_reg(model, loader, optim, criterion, device):
157
- model.train()
158
- for X, M, y in loader:
159
- X, M, y = X.to(device), M.to(device), y.to(device)
160
- optim.zero_grad(set_to_none=True)
161
- with autocast(enabled=torch.cuda.is_available()):
162
- pred = model(X, M)
163
- loss = criterion(pred, y)
164
- scaler.scale(loss).backward()
165
- scaler.step(optim)
166
- scaler.update()
167
-
168
- # ============================
169
- # Saving + plots
170
- # ============================
171
- def save_predictions_csv(out_dir, split_name, y_true, y_pred, sequences=None):
172
- os.makedirs(out_dir, exist_ok=True)
173
- df = pd.DataFrame({
174
- "y_true": y_true.astype(float),
175
- "y_pred": y_pred.astype(float),
176
- "residual": (y_true - y_pred).astype(float),
177
- })
178
- if sequences is not None:
179
- df.insert(0, "sequence", sequences)
180
- df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
181
-
182
- def plot_regression_diagnostics(out_dir, y_true, y_pred):
183
- os.makedirs(out_dir, exist_ok=True)
184
-
185
- plt.figure()
186
- plt.scatter(y_true, y_pred, s=8, alpha=0.5)
187
- plt.xlabel("y_true"); plt.ylabel("y_pred")
188
- plt.title("Predicted vs True")
189
- plt.tight_layout()
190
- plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
191
- plt.close()
192
-
193
- resid = y_true - y_pred
194
- plt.figure()
195
- plt.hist(resid, bins=50)
196
- plt.xlabel("residual (y_true - y_pred)"); plt.ylabel("count")
197
- plt.title("Residual Histogram")
198
- plt.tight_layout()
199
- plt.savefig(os.path.join(out_dir, "residual_hist.png"))
200
- plt.close()
201
-
202
- plt.figure()
203
- plt.scatter(y_pred, resid, s=8, alpha=0.5)
204
- plt.xlabel("y_pred"); plt.ylabel("residual")
205
- plt.title("Residuals vs Prediction")
206
- plt.tight_layout()
207
- plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
208
- plt.close()
209
-
210
- # ============================
211
- # Optuna objective
212
- # ============================
213
- def score_from_metrics(metrics: Dict[str, float], objective: str) -> float:
214
- if objective == "spearman":
215
- return metrics["spearman_rho"]
216
- if objective == "r2":
217
- return metrics["r2"]
218
- if objective == "neg_rmse":
219
- return -metrics["rmse"]
220
- raise ValueError(f"Unknown objective={objective}")
221
-
222
- def objective_nn_reg(trial, model_name, train_ds, val_ds, device="cuda:0", objective="spearman"):
223
- lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
224
- wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
225
- dropout = trial.suggest_float("dropout", 0.0, 0.5)
226
- batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
227
-
228
- in_dim = infer_in_dim(train_ds)
229
-
230
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
231
- collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
232
- val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
233
- collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
234
-
235
- if model_name == "mlp":
236
- hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
237
- model = MLPRegressor(in_dim=in_dim, hidden=hidden, dropout=dropout)
238
- elif model_name == "cnn":
239
- c = trial.suggest_categorical("channels", [128, 256, 512])
240
- k = trial.suggest_categorical("kernel", [3, 5, 7])
241
- layers = trial.suggest_int("layers", 1, 4)
242
- model = CNNRegressor(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
243
- elif model_name == "transformer":
244
- d = trial.suggest_categorical("d_model", [128, 256, 384])
245
- nhead = trial.suggest_categorical("nhead", [4, 8])
246
- layers = trial.suggest_int("layers", 1, 4)
247
- ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
248
- model = TransformerRegressor(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
249
- else:
250
- raise ValueError(model_name)
251
-
252
- model = model.to(device)
253
-
254
- loss_name = trial.suggest_categorical("loss", ["mse", "huber"])
255
- if loss_name == "mse":
256
- criterion = nn.MSELoss()
257
- else:
258
- delta = trial.suggest_float("huber_delta", 0.5, 5.0, log=True)
259
- criterion = nn.HuberLoss(delta=delta)
260
-
261
- optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
262
-
263
- best_score = -1e18
264
- patience = 10
265
- bad = 0
266
-
267
- for epoch in range(1, 61):
268
- train_one_epoch_reg(model, train_loader, optim, criterion, device)
269
-
270
- y_true, y_pred = eval_preds(model, val_loader, device)
271
- metrics = eval_regression(y_true, y_pred)
272
- score = score_from_metrics(metrics, objective)
273
-
274
- # log attrs
275
- for k, v in metrics.items():
276
- trial.set_user_attr(f"val_{k}", float(v))
277
-
278
- trial.report(score, epoch)
279
- if trial.should_prune():
280
- raise optuna.TrialPruned()
281
-
282
- if score > best_score + 1e-6:
283
- best_score = score
284
- bad = 0
285
- else:
286
- bad += 1
287
- if bad >= patience:
288
- break
289
-
290
- return float(best_score)
291
-
292
- # ============================
293
- # Main runner
294
- # ============================
295
- def run_optuna_and_refit_nn_reg(dataset_path, out_dir, model_name, n_trials=80, device="cuda:0",
296
- objective="spearman"):
297
- os.makedirs(out_dir, exist_ok=True)
298
-
299
- train_ds, val_ds = load_split(dataset_path)
300
- print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
301
-
302
- study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
303
- study.optimize(lambda t: objective_nn_reg(t, model_name, train_ds, val_ds, device=device, objective=objective),
304
- n_trials=n_trials)
305
-
306
- trials_df = study.trials_dataframe()
307
- trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
308
-
309
- best = study.best_trial
310
- best_params = dict(best.params)
311
-
312
- # rebuild model from best params
313
- in_dim = infer_in_dim(train_ds)
314
- dropout = float(best_params.get("dropout", 0.1))
315
- if model_name == "mlp":
316
- model = MLPRegressor(in_dim=in_dim, hidden=int(best_params["hidden"]), dropout=dropout)
317
- elif model_name == "cnn":
318
- model = CNNRegressor(in_ch=in_dim, c=int(best_params["channels"]),
319
- k=int(best_params["kernel"]), layers=int(best_params["layers"]),
320
- dropout=dropout)
321
- elif model_name == "transformer":
322
- model = TransformerRegressor(in_dim=in_dim, d_model=int(best_params["d_model"]),
323
- nhead=int(best_params["nhead"]), layers=int(best_params["layers"]),
324
- ff=int(best_params["ff"]), dropout=dropout)
325
- else:
326
- raise ValueError(model_name)
327
-
328
- model = model.to(device)
329
-
330
- batch_size = int(best_params.get("batch_size", 32))
331
- train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
332
- collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
333
- val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
334
- collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
335
-
336
- # loss
337
- if best_params.get("loss", "mse") == "mse":
338
- criterion = nn.MSELoss()
339
- else:
340
- criterion = nn.HuberLoss(delta=float(best_params["huber_delta"]))
341
-
342
- optim = torch.optim.AdamW(model.parameters(), lr=float(best_params["lr"]),
343
- weight_decay=float(best_params["weight_decay"]))
344
-
345
- # refit longer with early stopping on the SAME objective
346
- best_score, bad, patience = -1e18, 0, 15
347
- best_state = None
348
-
349
- for epoch in range(1, 201):
350
- train_one_epoch_reg(model, train_loader, optim, criterion, device)
351
-
352
- y_true, y_pred = eval_preds(model, val_loader, device)
353
- metrics = eval_regression(y_true, y_pred)
354
- score = score_from_metrics(metrics, objective)
355
-
356
- if score > best_score + 1e-6:
357
- best_score = score
358
- bad = 0
359
- best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
360
- best_metrics = metrics
361
- else:
362
- bad += 1
363
- if bad >= patience:
364
- break
365
-
366
- if best_state is not None:
367
- model.load_state_dict(best_state)
368
-
369
- # preds
370
- y_true_tr, y_pred_tr = eval_preds(model, DataLoader(train_ds, batch_size=64, shuffle=False,
371
- collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True), device)
372
- y_true_va, y_pred_va = eval_preds(model, val_loader, device)
373
-
374
- seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
375
- seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None
376
- save_predictions_csv(out_dir, "train", y_true_tr, y_pred_tr, seq_train)
377
- save_predictions_csv(out_dir, "val", y_true_va, y_pred_va, seq_val)
378
- plot_regression_diagnostics(out_dir, y_true_va, y_pred_va)
379
-
380
- # save model
381
- model_path = os.path.join(out_dir, "best_model.pt")
382
- torch.save({"state_dict": model.state_dict(), "best_params": best_params, "in_dim": in_dim}, model_path)
383
-
384
- summary = [
385
- "=" * 72,
386
- f"MODEL: {model_name}",
387
- f"OPTUNA objective: {objective} (direction=maximize)",
388
- f"Best trial: {best.number}",
389
- "Best val metrics:",
390
- json.dumps({k: float(v) for k, v in best_metrics.items()}, indent=2),
391
- f"Saved model: {model_path}",
392
- "Best params:",
393
- json.dumps(best_params, indent=2),
394
- "=" * 72,
395
- ]
396
- with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
397
- f.write("\n".join(summary))
398
- print("\n".join(summary))
399
-
400
-
401
- if __name__ == "__main__":
402
- import argparse
403
- parser = argparse.ArgumentParser()
404
- parser.add_argument("--dataset_path", type=str, required=True)
405
- parser.add_argument("--out_dir", type=str, required=True)
406
- parser.add_argument("--model", type=str, choices=["mlp","cnn","transformer"], required=True)
407
- parser.add_argument("--n_trials", type=int, default=80)
408
- parser.add_argument("--objective", type=str, default="spearman",
409
- choices=["spearman","neg_rmse","r2"])
410
- parser.add_argument("--device", type=str, default="cuda:0")
411
- args = parser.parse_args()
412
-
413
- run_optuna_and_refit_nn_reg(
414
- dataset_path=args.dataset_path,
415
- out_dir=args.out_dir,
416
- model_name=args.model,
417
- n_trials=args.n_trials,
418
- device=args.device,
419
- objective=args.objective,
420
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/binding_affinity/val_smiles_pooled.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5410a45a7b65def6cfb94c167b07537abd33b5aac4ecdffe162b7ce4e9bc3d19
3
- size 36525
 
 
 
 
training_classifiers/binding_affinity/val_smiles_unpooled.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cdf71fbb3e7b3b8e8dbfe4ed45b32a2da0049df851f09ee32564825f626cb86c
3
- size 37187
 
 
 
 
training_classifiers/binding_affinity/val_wt_pooled.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b194e7b2b97258320323021b3ffe6143133070212a0215ade22fa91b87a3a861
3
- size 33224
 
 
 
 
training_classifiers/binding_affinity/val_wt_unpooled.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:051325790047e749fbf1daf7bf25a08178297b0c37acaf9439816d09f2b6c1e3
3
- size 33826
 
 
 
 
training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:12f956a7bf04ed602c11fd275377afa73f3f0af1982dbe06c607d8ada304b01c
3
- size 21617397
 
 
 
 
training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3d7ae3d2190b034352a65bda1bce86aa5a96ce3daf74cf10a166f8d9e9af51f0
3
- size 181183221
 
 
 
 
training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b685b92714882d618b42b582000574d83c3be2fbecbec5e0de6b5476948b96c5
3
- size 40700
 
 
 
 
training_classifiers/half_life/cnn_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0a8a57d44cac3fcd701b550a4eaf9e29910540bfb7580a9b8ee997a7227375d2
3
- size 13748
 
 
 
 
training_classifiers/half_life/cnn_unpooled_peptideclm/best_model.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9eaaafffe02663f7cfe67fde25cdebd7d4315af69b393b433c4291b700bc5063
3
- size 16525563
 
 
 
 
training_classifiers/half_life/cnn_unpooled_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:879d8c6f47b02c1ddd86fbe3982d8b0134167521f9f71d2450957dc3bbbb6bd1
3
- size 13705
 
 
 
 
training_classifiers/half_life/enet_gpu_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fd5c84e9788e2db949c6be785f8539178e72fc6fa6bc703daf9574ad0622e0f1
3
- size 13649
 
 
 
 
training_classifiers/half_life/enet_peptideclm/smiles_halflife_best_enet.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0eb93bcb27436e80bce2a6433cbd7502b90de4962731250972eef08a5d96ce69
3
- size 22698
 
 
 
 
training_classifiers/half_life/mlp_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:14b1010a2c0b5d065fa9b82636085806ec7f6f6091c7c2355c6c4717d07fa79b
3
- size 13724
 
 
 
 
training_classifiers/half_life/mlp_unpooled_peptideclm/best_model.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ef78a5e5c555768f91dc646652a39e367287e851a14e2cf85e4006c9355a8368
3
- size 2368455
 
 
 
 
training_classifiers/half_life/mlp_unpooled_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2a037cf36528fc8e04a375c0443577830733636fca83fa9ce44e457e28e4f771
3
- size 13745
 
 
 
 
training_classifiers/half_life/svr_gpu_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a0dd42537c1a5589b78451de8645bfc089b8f7f5839808222bb1e9e033d78c66
3
- size 13746
 
 
 
 
training_classifiers/half_life/svr_peptideclm/smiles_halflife_best_svr.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5579f1407fc8dfd1e42b4ea2a6b619dea8b0eff4ce9a4c0869890cbd1b321851
3
- size 1530479
 
 
 
 
training_classifiers/half_life/transformer_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6c6dcaff0d542c3a6bbaf499aba56e5f440c50aa18b55271ee85feb43851fe92
3
- size 13694
 
 
 
 
training_classifiers/half_life/transformer_unpooled_peptideclm/best_model.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a19240f8067e68a6e2eaff139f90b6d2f37ab5431197c5496894937c01918f7
3
- size 931353
 
 
 
 
training_classifiers/half_life/transformer_unpooled_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7b1111c0b57288092ab97b940598b2d3b44c2ff5299fe55a50a2312d8c2e45af
3
- size 13683
 
 
 
 
training_classifiers/half_life/transformer_wt_log/oof_pred_vs_true.png DELETED
Binary file (16.9 kB)
 
training_classifiers/half_life/transformer_wt_log/oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8ec7b8dee908ef43ba7633a887a988834e24f952711b906472e1b41b833de714
3
- size 14100
 
 
 
 
training_classifiers/half_life/transformer_wt_log/oof_residual_hist.png DELETED
Binary file (15.3 kB)
 
training_classifiers/half_life/transformer_wt_log/oof_residual_vs_pred.png DELETED
Binary file (19.6 kB)
 
training_classifiers/half_life/transformer_wt_log/optimization_summary.txt DELETED
@@ -1,33 +0,0 @@
1
- ========================================================================
2
- MODEL: transformer
3
- Dataset: /scratch/pranamlab/tong/data/halflife/halflife_embedding_unpooled
4
- Target column: log_label
5
- CV folds: 5
6
- Optuna objective: spearman (direction=maximize)
7
- Best trial: 45
8
- OOF metrics:
9
- {
10
- "rmse": 1.0389505624771118,
11
- "mae": 0.722099244594574,
12
- "r2": 0.30950748920440674,
13
- "spearman_rho": 0.5818272477094295
14
- }
15
- OOF score (spearman): 0.581827
16
- Best params:
17
- {
18
- "lr": 0.0003603824115240561,
19
- "weight_decay": 2.9442493502916885e-09,
20
- "dropout": 0.3851371373367485,
21
- "batch_size": 16
22
- }
23
- Final refit epochs (all data): 15
24
- Saved final model: /scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer/final_model.pt
25
- Benchmark (final model on full data):
26
- {
27
- "n_samples": 130,
28
- "wall_time_s": 1.9577592574059963,
29
- "throughput_samples_per_s": 66.40244427818372,
30
- "gpu_ms_per_sample": 0.28296443315652703,
31
- "gpu_peak_mem_MB": 77.5693359375
32
- }
33
- ========================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/half_life/transformer_wt_log/study_trials.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5726f6b8541c7ca85eda9f0e526db0cb10156eadb2c440fd7a66f7a7d1209175
3
- size 10154
 
 
 
 
training_classifiers/half_life/transformer_wt_raw/oof_pred_vs_true.png DELETED
Binary file (17.4 kB)
 
training_classifiers/half_life/transformer_wt_raw/oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3df70b094757f34fa380a28727877694fcb1ec367bbbef28c63b257ecec74e6
3
- size 13516
 
 
 
 
training_classifiers/half_life/transformer_wt_raw/oof_residual_hist.png DELETED
Binary file (14.6 kB)
 
training_classifiers/half_life/transformer_wt_raw/oof_residual_vs_pred.png DELETED
Binary file (18.9 kB)
 
training_classifiers/half_life/transformer_wt_raw/optimization_summary.txt DELETED
@@ -1,33 +0,0 @@
1
- ========================================================================
2
- MODEL: transformer
3
- Dataset: /scratch/pranamlab/tong/data/halflife/halflife_embedding_unpooled
4
- Target column: label
5
- CV folds: 5
6
- Optuna objective: spearman (direction=maximize)
7
- Best trial: 22
8
- OOF metrics:
9
- {
10
- "rmse": 45.00321578979492,
11
- "mae": 11.352466583251953,
12
- "r2": 0.02070075273513794,
13
- "spearman_rho": 0.3759734508605516
14
- }
15
- OOF score (spearman): 0.375973
16
- Best params:
17
- {
18
- "lr": 0.00019977882554167927,
19
- "weight_decay": 1.102955470301081e-07,
20
- "dropout": 1.2707176359392082e-05,
21
- "batch_size": 16
22
- }
23
- Final refit epochs (all data): 14
24
- Saved final model: /scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_transformer_raw/final_model.pt
25
- Benchmark (final model on full data):
26
- {
27
- "n_samples": 130,
28
- "wall_time_s": 1.6299039730802178,
29
- "throughput_samples_per_s": 79.7593000244818,
30
- "gpu_ms_per_sample": 0.23774326214423547,
31
- "gpu_peak_mem_MB": 77.5693359375
32
- }
33
- ========================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training_classifiers/half_life/transformer_wt_raw/study_trials.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7c38352854ad4142c02a4bcb33caee9fe8fa22fca86dcb8c17c05c24f3fa5bca
3
- size 10895
 
 
 
 
training_classifiers/half_life/xgb_smiles/cv_oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e2df43f71aad2cf791b49daa0b3353f524d5a3f3e132fecf1251e96242639ca5
3
- size 13675
 
 
 
 
training_classifiers/half_life/xgb_wt_log/oof_pred_vs_true.png DELETED
Binary file (16.5 kB)
 
training_classifiers/half_life/xgb_wt_log/oof_predictions.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2293b5752ef6bdc7b1ec8ae2f56e11ccbf32aee024d777c86c1e63f390fa89cf
3
- size 14805
 
 
 
 
training_classifiers/half_life/xgb_wt_log/oof_residual_hist.png DELETED
Binary file (15.1 kB)
 
training_classifiers/half_life/xgb_wt_log/oof_residual_vs_pred.png DELETED
Binary file (19.1 kB)
 
training_classifiers/half_life/xgb_wt_log/optimization_summary.txt DELETED
@@ -1,27 +0,0 @@
1
- {
2
- "model": "xgb_reg",
3
- "dataset_path": "/scratch/pranamlab/tong/data/halflife/halflife_embedding",
4
- "target_col": "log_label",
5
- "n_folds": 5,
6
- "best_trial_number": 20,
7
- "best_objective_cv_spearman": 0.5879000126060311,
8
- "oof_metrics": {
9
- "rmse": 1.0810768604278564,
10
- "mae": 0.7866008281707764,
11
- "r2": 0.2524225115776062,
12
- "spearman_rho": 0.557870619380726
13
- },
14
- "model_path": "/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb_log/best_model.json",
15
- "best_params": {
16
- "lambda": 0.0006291983667746282,
17
- "alpha": 0.0820082035401697,
18
- "gamma": 1.2243543209914751,
19
- "max_depth": 3,
20
- "min_child_weight": 1.7773959178614585,
21
- "subsample": 0.568291807635477,
22
- "colsample_bytree": 0.8597778117881122,
23
- "learning_rate": 0.0512590763008084,
24
- "num_boost_round": 1728,
25
- "early_stopping_rounds": 121
26
- }
27
- }