ynuozhang committed
Commit df85e24
Parent(s): 21ea966
env
training_classifiers/binding_affinity_iptm.py
DELETED
@@ -1,132 +0,0 @@
-#!/usr/bin/env python3
-"""
-extract_iptm_affinity_csv_all.py
-
-Writes:
-- out_dir/wt_iptm_affinity_all.csv
-- out_dir/smiles_iptm_affinity_all.csv
-
-Also prints:
-- N
-- Spearman rho (affinity vs iptm)
-- Pearson r (affinity vs iptm)
-"""
-
-from pathlib import Path
-import numpy as np
-import pandas as pd
-
-
-def corr_stats(df: pd.DataFrame, x: str, y: str):
-    # pandas handles NaNs if we already dropped them; still be safe
-    xx = pd.to_numeric(df[x], errors="coerce")
-    yy = pd.to_numeric(df[y], errors="coerce")
-    m = xx.notna() & yy.notna()
-    xx = xx[m]
-    yy = yy[m]
-    n = int(m.sum())
-
-    # Pearson r
-    pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
-    # Spearman rho
-    spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")
-
-    return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}
-
-
-def clean_one(
-    in_csv: Path,
-    out_csv: Path,
-    iptm_col: str,
-    affinity_col: str = "affinity",
-    keep_cols=(),
-):
-    df = pd.read_csv(in_csv)
-
-    # affinity + iptm must exist
-    need = [affinity_col, iptm_col]
-    missing = [c for c in need if c not in df.columns]
-    if missing:
-        raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")
-
-    # coerce numeric
-    df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
-    df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")
-
-    # drop NaNs in either
-    df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)
-
-    # output cols (standardize names)
-    out = pd.DataFrame({
-        "affinity": df[affinity_col].astype(float),
-        "iptm": df[iptm_col].astype(float),
-    })
-
-    # keep split if present (handy for coloring later, but not used for corr)
-    if "split" in df.columns:
-        out.insert(0, "split", df["split"].astype(str))
-
-    # optional extras for labeling/debug
-    for c in keep_cols:
-        if c in df.columns:
-            out[c] = df[c]
-
-    out_csv.parent.mkdir(parents=True, exist_ok=True)
-    out.to_csv(out_csv, index=False)
-
-    stats = corr_stats(out, "iptm", "affinity")
-    print(f"[write] {out_csv}")
-    print(f"  N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")
-
-    # also save stats json next to csv
-    stats_path = out_csv.with_suffix(".stats.json")
-    with open(stats_path, "w") as f:
-        import json
-        json.dump(
-            {
-                "input_csv": str(in_csv),
-                "output_csv": str(out_csv),
-                "iptm_col": iptm_col,
-                "affinity_col": affinity_col,
-                **stats,
-            },
-            f,
-            indent=2,
-        )
-
-
-def main():
-    import argparse
-    ap = argparse.ArgumentParser()
-    ap.add_argument("--wt_meta_csv", type=str, required=True)
-    ap.add_argument("--smiles_meta_csv", type=str, required=True)
-    ap.add_argument("--out_dir", type=str, required=True)
-
-    ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
-    ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
-    ap.add_argument("--affinity_col", type=str, default="affinity")
-    args = ap.parse_args()
-
-    out_dir = Path(args.out_dir)
-
-    clean_one(
-        Path(args.wt_meta_csv),
-        out_dir / "wt_iptm_affinity_all.csv",
-        iptm_col=args.wt_iptm_col,
-        affinity_col=args.affinity_col,
-        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
-    )
-
-    clean_one(
-        Path(args.smiles_meta_csv),
-        out_dir / "smiles_iptm_affinity_all.csv",
-        iptm_col=args.smiles_iptm_col,
-        affinity_col=args.affinity_col,
-        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
-    )
-
-    print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")
-
-
-if __name__ == "__main__":
-    main()
training_classifiers/binding_affinity_split.py
DELETED
@@ -1,847 +0,0 @@
-#!/usr/bin/env python3
-import os
-import math
-from pathlib import Path
-import sys
-from contextlib import contextmanager
-
-import numpy as np
-import pandas as pd
-import torch
-
-# tqdm is optional; we'll disable it by default in notebooks
-from tqdm import tqdm
-
-sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
-from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
-
-from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
-from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
-
-# -------------------------
-# Config
-# -------------------------
-CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
-
-OUT_ROOT = Path(
-    "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
-)
-
-# WT (seq) embedding model
-WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
-WT_MAX_LEN = 1022
-WT_BATCH = 32
-
-# SMILES embedding model + tokenizer
-SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
-TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
-TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
-SMI_MAX_LEN = 768
-SMI_BATCH = 128
-
-# Split config
-TRAIN_FRAC = 0.80
-RANDOM_SEED = 1986
-AFFINITY_Q_BINS = 30
-
-# Columns expected in CSV
-COL_SEQ1 = "seq1"
-COL_SEQ2 = "seq2"
-COL_AFF = "affinity"
-COL_F2S = "Fasta2SMILES"
-COL_REACT = "REACT_SMILES"
-COL_WT_IPTM = "wt_iptm_score"
-COL_SMI_IPTM = "smiles_iptm_score"
-
-# Device
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-# -------------------------
-# Quiet / notebook-safe output controls
-# -------------------------
-QUIET = True  # suppress most prints
-USE_TQDM = False  # disable tqdm bars (recommended in Jupyter to avoid crashing)
-LOG_FILE = None  # optionally: OUT_ROOT / "build.log"
-
-def log(msg: str):
-    if LOG_FILE is not None:
-        Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
-        with open(LOG_FILE, "a") as f:
-            f.write(msg.rstrip() + "\n")
-    if not QUIET:
-        print(msg)
-
-def pbar(it, **kwargs):
-    return tqdm(it, **kwargs) if USE_TQDM else it
-
-@contextmanager
-def section(title: str):
-    log(f"\n=== {title} ===")
-    yield
-    log(f"=== done: {title} ===")
-
-
-# -------------------------
-# Helpers
-# -------------------------
-def has_uaa(seq: str) -> bool:
-    return "X" in str(seq).upper()
-
-def affinity_to_class(a: float) -> str:
-    # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
-    if a >= 9.0:
-        return "High"
-    elif a >= 7.0:
-        return "Moderate"
-    else:
-        return "Low"
-
-def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
-    df = df.copy()
-
-    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
-    df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
-
-    df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
-
-    try:
-        df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
-        strat_col = "aff_bin"
-    except Exception:
-        df["aff_bin"] = df["affinity_class"]
-        strat_col = "aff_bin"
-
-    rng = np.random.RandomState(RANDOM_SEED)
-
-    df["split"] = None
-    for _, g in df.groupby(strat_col, observed=True):
-        idx = g.index.to_numpy()
-        rng.shuffle(idx)
-        n_train = int(math.floor(len(idx) * TRAIN_FRAC))
-        df.loc[idx[:n_train], "split"] = "train"
-        df.loc[idx[n_train:], "split"] = "val"
-
-    df["split"] = df["split"].fillna("train")
-    return df
-
-def _summ(x):
-    x = np.asarray(x, dtype=float)
-    x = x[~np.isnan(x)]
-    if len(x) == 0:
-        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
-    return {
-        "n": int(len(x)),
-        "mean": float(np.mean(x)),
-        "std": float(np.std(x)),
-        "p50": float(np.quantile(x, 0.50)),
-        "p95": float(np.quantile(x, 0.95)),
-    }
-
-def _len_stats(seqs):
-    lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
-    if len(lens) == 0:
-        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
-    return {
-        "n": int(len(lens)),
-        "mean": float(lens.mean()),
-        "std": float(lens.std()),
-        "p50": float(np.quantile(lens, 0.50)),
-        "p95": float(np.quantile(lens, 0.95)),
-    }
-
-def verify_split_before_embedding(
-    df2: pd.DataFrame,
-    affinity_col: str,
-    split_col: str,
-    seq_col: str,
-    iptm_col: str,
-    aff_class_col: str = "affinity_class",
-    aff_bins: int = 30,
-    save_report_prefix: str | None = None,
-    verbose: bool = False,
-):
-    """
-    Notebook-safe: by default prints only ONE line via `log()`.
-    Optionally writes CSV reports (stats + class proportions).
-    """
-    df2 = df2.copy()
-    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
-    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
-
-    assert split_col in df2.columns, f"Missing split col: {split_col}"
-    assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
-    assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
-
-    try:
-        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
-    except Exception:
-        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
-
-    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
-    va = df2[df2[split_col] == "val"].reset_index(drop=True)
-
-    tr_aff = _summ(tr[affinity_col].to_numpy())
-    va_aff = _summ(va[affinity_col].to_numpy())
-    tr_len = _len_stats(tr[seq_col].tolist())
-    va_len = _len_stats(va[seq_col].tolist())
-
-    # bin drift
-    bin_ct = (
-        df2.groupby([split_col, "_aff_bin_dbg"])
-        .size()
-        .groupby(level=0)
-        .apply(lambda s: s / s.sum())
-    )
-    tr_bins = bin_ct.loc["train"]
-    va_bins = bin_ct.loc["val"]
-    all_bins = tr_bins.index.union(va_bins.index)
-    tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
-    va_bins = va_bins.reindex(all_bins, fill_value=0.0)
-    max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
-
-    msg = (
-        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
-        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
-        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
-        f"max_bin_diff={max_bin_diff:.4f}"
-    )
-    log(msg)
-
-    if verbose and (not QUIET):
-        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
-        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
-        print("\n[verbose] affinity_class counts:\n", class_ct)
-        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
-
-    if save_report_prefix is not None:
-        out = Path(save_report_prefix)
-        out.parent.mkdir(parents=True, exist_ok=True)
-
-        stats_df = pd.DataFrame([
-            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
-            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
-        ])
-        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
-        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
-
-        stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
-        class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
-
-
-# -------------------------
-# WT pooled (ESM2)
-# -------------------------
-@torch.no_grad()
-def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
-    embs = []
-    for i in pbar(range(0, len(seqs), batch_size)):
-        batch = seqs[i:i + batch_size]
-        inputs = tokenizer(
-            batch,
-            padding=True,
-            truncation=True,
-            max_length=max_length,
-            return_tensors="pt",
-        )
-        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-        out = model(**inputs)
-        h = out.last_hidden_state  # (B, L, H)
-
-        attn = inputs["attention_mask"].unsqueeze(-1)  # (B, L, 1)
-        summed = (h * attn).sum(dim=1)  # (B, H)
-        denom = attn.sum(dim=1).clamp(min=1e-9)  # (B, 1)
-        pooled = (summed / denom).detach().cpu().numpy()
-        embs.append(pooled)
-
-    return np.vstack(embs)
-
-
-# -------------------------
-# WT unpooled (ESM2)
-# -------------------------
-@torch.no_grad()
-def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
-    tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
-    tok = {k: v.to(DEVICE) for k, v in tok.items()}
-    out = model(**tok)
-    h = out.last_hidden_state[0]  # (L, H)
-    attn = tok["attention_mask"][0].bool()  # (L,)
-    ids = tok["input_ids"][0]
-
-    keep = attn.clone()
-    if cls_id is not None:
-        keep &= (ids != cls_id)
-    if eos_id is not None:
-        keep &= (ids != eos_id)
-
-    return h[keep].detach().cpu().to(torch.float16).numpy()
-
-def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
-    """
-    Expects df_split to have:
-    - target_sequence (seq1)
-    - sequence (binder seq2; WT binder)
-    - label, affinity_class, COL_AFF, COL_WT_IPTM
-    Saves a dataset where each row contains BOTH:
-    - target_embedding (Lt,H), target_attention_mask, target_length
-    - binder_embedding (Lb,H), binder_attention_mask, binder_length
-    """
-    cls_id = tokenizer.cls_token_id
-    eos_id = tokenizer.eos_token_id
-    H = model.config.hidden_size
-
-    features = Features({
-        "target_sequence": Value("string"),
-        "sequence": Value("string"),
-        "label": Value("float32"),
-        "affinity": Value("float32"),
-        "affinity_class": Value("string"),
-
-        "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
-        "target_attention_mask": HFSequence(Value("int8")),
-        "target_length": Value("int64"),
-
-        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
-        "binder_attention_mask": HFSequence(Value("int8")),
-        "binder_length": Value("int64"),
-
-        COL_WT_IPTM: Value("float32"),
-        COL_AFF: Value("float32"),
-    })
-
-    def gen_rows(df: pd.DataFrame):
-        for r in pbar(df.itertuples(index=False), total=len(df)):
-            tgt = str(getattr(r, "target_sequence")).strip()
-            bnd = str(getattr(r, "sequence")).strip()
-
-            y = float(getattr(r, "label"))
-            aff = float(getattr(r, COL_AFF))
-            acls = str(getattr(r, "affinity_class"))
-
-            iptm = getattr(r, COL_WT_IPTM)
-            iptm = float(iptm) if pd.notna(iptm) else np.nan
-
-            # token embeddings for target + binder (both ESM)
-            t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)  # (Lt,H)
-            b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)  # (Lb,H)
-
-            t_list = t_emb.tolist()
-            b_list = b_emb.tolist()
-            Lt = len(t_list)
-            Lb = len(b_list)
-
-            yield {
-                "target_sequence": tgt,
-                "sequence": bnd,
-                "label": np.float32(y),
-                "affinity": np.float32(aff),
-                "affinity_class": acls,
-
-                "target_embedding": t_list,
-                "target_attention_mask": [1] * Lt,
-                "target_length": int(Lt),
-
-                "binder_embedding": b_list,
-                "binder_attention_mask": [1] * Lb,
-                "binder_length": int(Lb),
-
-                COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
-                COL_AFF: np.float32(aff),
-            }
-
-    out_dir.mkdir(parents=True, exist_ok=True)
-    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
-    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
-    return ds
-
-def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
-                                         smi_tok, smi_roformer):
-    """
-    df_split must have:
-    - target_sequence (seq1)
-    - sequence (binder smiles string)
-    - label, affinity_class, COL_AFF, COL_SMI_IPTM
-    Saves rows with:
-    target_embedding (Lt,Ht) from ESM
-    binder_embedding (Lb,Hb) from PeptideCLM
-    """
-    cls_id = wt_tokenizer.cls_token_id
-    eos_id = wt_tokenizer.eos_token_id
-    Ht = wt_model_unpooled.config.hidden_size
-
-    # Infer Hb from one forward pass? easiest: run one mini batch outside in main if you want.
-    # Here: we'll infer from model config if available.
-    Hb = getattr(smi_roformer.config, "hidden_size", None)
-    if Hb is None:
-        Hb = getattr(smi_roformer.config, "dim", None)
-    if Hb is None:
-        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
-
-    features = Features({
-        "target_sequence": Value("string"),
-        "sequence": Value("string"),
-        "label": Value("float32"),
-        "affinity": Value("float32"),
-        "affinity_class": Value("string"),
-
-        "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
-        "target_attention_mask": HFSequence(Value("int8")),
-        "target_length": Value("int64"),
-
-        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
-        "binder_attention_mask": HFSequence(Value("int8")),
-        "binder_length": Value("int64"),
-
-        COL_SMI_IPTM: Value("float32"),
-        COL_AFF: Value("float32"),
-    })
-
-    def gen_rows(df: pd.DataFrame):
-        for r in pbar(df.itertuples(index=False), total=len(df)):
-            tgt = str(getattr(r, "target_sequence")).strip()
-            bnd = str(getattr(r, "sequence")).strip()
-
-            y = float(getattr(r, "label"))
-            aff = float(getattr(r, COL_AFF))
-            acls = str(getattr(r, "affinity_class"))
-
-            iptm = getattr(r, COL_SMI_IPTM)
-            iptm = float(iptm) if pd.notna(iptm) else np.nan
-
-            # target token embeddings (ESM)
-            t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
-            t_list = t_emb.tolist()
-            Lt = len(t_list)
-
-            # binder token embeddings (PeptideCLM) — single-item batch
-            _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
-                [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
-            )
-            b_emb = tok_list[0]  # np.float16 (Lb, Hb)
-            b_list = b_emb.tolist()
-            Lb = int(lengths[0])
-            b_mask = mask_list[0].astype(np.int8).tolist()
-
-            yield {
-                "target_sequence": tgt,
-                "sequence": bnd,
-                "label": np.float32(y),
-                "affinity": np.float32(aff),
-                "affinity_class": acls,
-
-                "target_embedding": t_list,
-                "target_attention_mask": [1] * Lt,
-                "target_length": int(Lt),
-
-                "binder_embedding": b_list,
-                "binder_attention_mask": [int(x) for x in b_mask],
-                "binder_length": int(Lb),
-
-                COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
-                COL_AFF: np.float32(aff),
-            }
-
-    out_dir.mkdir(parents=True, exist_ok=True)
-    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
-    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
-    return ds
-
-
-# -------------------------
-# SMILES pooled + unpooled (PeptideCLM)
-# -------------------------
-def get_special_ids(tokenizer_obj):
-    cand = [
-        getattr(tokenizer_obj, "pad_token_id", None),
-        getattr(tokenizer_obj, "cls_token_id", None),
-        getattr(tokenizer_obj, "sep_token_id", None),
-        getattr(tokenizer_obj, "bos_token_id", None),
-        getattr(tokenizer_obj, "eos_token_id", None),
-        getattr(tokenizer_obj, "mask_token_id", None),
-    ]
-    return sorted({x for x in cand if x is not None})
-
-@torch.no_grad()
-def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
-    tok = tokenizer_obj(
-        batch_sequences,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-        max_length=max_length,
-    )
-    input_ids = tok["input_ids"].to(DEVICE)
-    attention_mask = tok["attention_mask"].to(DEVICE)
-
-    outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
-    last_hidden = outputs.last_hidden_state  # (B, L, H)
-
-    special_ids = get_special_ids(tokenizer_obj)
-    valid = attention_mask.bool()
-    if len(special_ids) > 0:
-        sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
-        if hasattr(torch, "isin"):
-            valid = valid & (~torch.isin(input_ids, sid))
-        else:
-            m = torch.zeros_like(valid)
-            for s in special_ids:
-                m |= (input_ids == s)
-            valid = valid & (~m)
-
-    valid_f = valid.unsqueeze(-1).float()
-    summed = torch.sum(last_hidden * valid_f, dim=1)
-    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
-    pooled = (summed / denom).detach().cpu().numpy()
-
-    token_emb_list, mask_list, lengths = [], [], []
-    for b in range(last_hidden.shape[0]):
-        emb = last_hidden[b, valid[b]]  # (Li, H)
-        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
-        li = emb.shape[0]
-        lengths.append(int(li))
-        mask_list.append(np.ones((li,), dtype=np.int8))
-
-    return pooled, token_emb_list, mask_list, lengths
-
-def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
-    pooled_all = []
-    token_emb_all = []
-    mask_all = []
-    lengths_all = []
-
-    for i in pbar(range(0, len(seqs), batch_size)):
-        batch = seqs[i:i + batch_size]
-        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
-            batch, tokenizer_obj, model_roformer, max_length
-        )
-        pooled_all.append(pooled)
-        token_emb_all.extend(tok_list)
-        mask_all.extend(m_list)
-        lengths_all.extend(lens)
-
-    return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
-
-# -------------------------
-# Target embedding cache (NO extra ESM runs)
-# We will compute target pooled embeddings ONCE from WT view, then reuse for SMILES.
-# -------------------------
-def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
-    wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
-    wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
-
-    # compute target pooled embeddings once
-    tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
-    tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
-
-    wt_train_tgt_emb = wt_pooled_embeddings(
-        tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
-    )
-    wt_val_tgt_emb = wt_pooled_embeddings(
-        tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
-    )
-
-    # build dict: target_sequence -> embedding (float32 array)
-    # if duplicates exist, last wins; you can add checks if needed
-    train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
-    val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
-    return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
-# -------------------------
-# Main
-# -------------------------
-def main():
-    log(f"[INFO] DEVICE: {DEVICE}")
-    OUT_ROOT.mkdir(parents=True, exist_ok=True)
-
-    # 1) Load
-    with section("load csv + dedup"):
-        df = pd.read_csv(CSV_PATH)
-        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
-            if c in df.columns:
-                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
-
-        # Dedup on the full identity tuple you want
-        DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
-        df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
-
-        print("Rows after dedup on", DEDUP_COLS, ":", len(df))
-
-        need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
-        missing = [c for c in need if c not in df.columns]
-        if missing:
-            raise ValueError(f"Missing required columns: {missing}")
-
-        # numeric affinity for both branches
-        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
-
-    # 2) Build WT subset + SMILES subset separately (NO global dropping)
-    with section("prepare wt/smiles subsets"):
-        # WT: requires a canonical peptide sequence (no X) + affinity
-        df_wt = df.copy()
-        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
-        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
-        df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
-        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
-
-        # SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
-        df_smi = df.copy()
-        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
-        df_smi = df_smi[
-            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
-        ].reset_index(drop=True)  # an empty iptm means something is wrong with the SMILES sequence
-
-        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
-        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
-        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
-        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
-        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
-
-        log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
-
-    # 3) Split separately (different sizes and memberships are expected)
-    with section("split wt and smiles separately"):
-        df_wt2 = make_distribution_matched_split(df_wt)
-        df_smi2 = make_distribution_matched_split(df_smi)
-
-        # save split tables
-        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
-        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
-        df_wt2.to_csv(wt_split_csv, index=False)
-        df_smi2.to_csv(smi_split_csv, index=False)
-        log(f"Saved WT split meta: {wt_split_csv}")
-        log(f"Saved SMILES split meta: {smi_split_csv}")
-
-        # lightweight double-check (one-line)
-        verify_split_before_embedding(
-            df2=df_wt2,
-            affinity_col=COL_AFF,
-            split_col="split",
-            seq_col="wt_sequence",
-            iptm_col=COL_WT_IPTM,
-            aff_class_col="affinity_class",
-            aff_bins=AFFINITY_Q_BINS,
-            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
-            verbose=False,
-        )
-        verify_split_before_embedding(
-            df2=df_smi2,
-            affinity_col=COL_AFF,
-            split_col="split",
-            seq_col="smiles_sequence",
-            iptm_col=COL_SMI_IPTM,
-            aff_class_col="affinity_class",
-            aff_bins=AFFINITY_Q_BINS,
-            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
-            verbose=False,
-        )
-
-    # Prepare split views
-    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
-        out = df_in.copy()
-        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()  # <-- NEW
-        out["sequence"] = out[binder_seq_col].astype(str).str.strip()  # binder
-        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
-        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
-        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
-        out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
-        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
-
-    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
-    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
-
-    # -------------------------
-    # Split views
-    # -------------------------
-    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
-    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
-    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
-    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
-
-
-    # =========================
-    # TARGET pooled embeddings (ESM) — SEPARATE per branch
-    # =========================
-    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
-        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
-        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
-
-        # ---- WT targets ----
-        wt_train_tgt_emb = wt_pooled_embeddings(
-            wt_train["target_sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-
-        wt_val_tgt_emb = wt_pooled_embeddings(
-            wt_val["target_sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-
-        # ---- SMILES targets (independent; may include UAA-only targets) ----
-        smi_train_tgt_emb = wt_pooled_embeddings(
-            smi_train["target_sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-
-        smi_val_tgt_emb = wt_pooled_embeddings(
-            smi_val["target_sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-
-
-    # =========================
-    # WT pooled binder embeddings (binder = WT peptide)
-    # =========================
-    with section("WT pooled binder embeddings + save"):
-        wt_train_emb = wt_pooled_embeddings(
-            wt_train["sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-
-        wt_val_emb = wt_pooled_embeddings(
-            wt_val["sequence"].astype(str).str.strip().tolist(),
-            wt_tok, wt_esm,
-            batch_size=WT_BATCH,
-            max_length=WT_MAX_LEN,
-        ).astype(np.float32)
-
-        wt_train_ds = Dataset.from_dict({
-            "target_sequence": wt_train["target_sequence"].tolist(),
-            "sequence": wt_train["sequence"].tolist(),
-            "label": wt_train["label"].astype(float).tolist(),
-            "target_embedding": wt_train_tgt_emb,
-            "embedding": wt_train_emb,
-            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
-            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
-            "affinity_class": wt_train["affinity_class"].tolist(),
-        })
-
-        wt_val_ds = Dataset.from_dict({
-            "target_sequence": wt_val["target_sequence"].tolist(),
-            "sequence": wt_val["sequence"].tolist(),
-            "label": wt_val["label"].astype(float).tolist(),
-            "target_embedding": wt_val_tgt_emb,
-            "embedding": wt_val_emb,
-            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
-            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
-            "affinity_class": wt_val["affinity_class"].tolist(),
-        })
-
-        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
-        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
-        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
-        log(f"Saved WT pooled -> {wt_pooled_out}")
-
-
-    # =========================
-    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
-    # =========================
-    with section("SMILES pooled binder embeddings + save"):
-        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
-        smi_roformer = (
-            AutoModelForMaskedLM
-            .from_pretrained(SMI_MODEL_NAME)
-            .roformer
-            .to(DEVICE)
-            .eval()
-        )
-
-        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
-            smi_train["sequence"].astype(str).str.strip().tolist(),
-            smi_tok, smi_roformer,
-            batch_size=SMI_BATCH,
-            max_length=SMI_MAX_LEN,
-        )
-
-        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
-            smi_val["sequence"].astype(str).str.strip().tolist(),
-            smi_tok, smi_roformer,
-            batch_size=SMI_BATCH,
-            max_length=SMI_MAX_LEN,
-        )
-
-        smi_train_ds = Dataset.from_dict({
-            "target_sequence": smi_train["target_sequence"].tolist(),
-            "sequence": smi_train["sequence"].tolist(),
-            "label": smi_train["label"].astype(float).tolist(),
-            "target_embedding": smi_train_tgt_emb,
-            "embedding": smi_train_pooled.astype(np.float32),
-            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
-            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
-            "affinity_class": smi_train["affinity_class"].tolist(),
-        })
-
-        smi_val_ds = Dataset.from_dict({
-            "target_sequence": smi_val["target_sequence"].tolist(),
-            "sequence": smi_val["sequence"].tolist(),
-            "label": smi_val["label"].astype(float).tolist(),
-            "target_embedding": smi_val_tgt_emb,
-            "embedding": smi_val_pooled.astype(np.float32),
-            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
-            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
-            "affinity_class": smi_val["affinity_class"].tolist(),
-        })
-
-        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
-        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
-        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
-        log(f"Saved SMILES pooled -> {smi_pooled_out}")
-
-
-    # =========================
-    # WT unpooled paired (ESM target + ESM binder) + save
-    # =========================
-    with section("WT unpooled paired embeddings + save"):
-        wt_tok_unpooled = wt_tok  # reuse tokenizer
-        wt_esm_unpooled = wt_esm  # reuse model
-
-        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
-        wt_unpooled_dd = DatasetDict({
-            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
-                                               wt_tok_unpooled, wt_esm_unpooled),
-            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
-                                             wt_tok_unpooled, wt_esm_unpooled),
-        })
-        # (Optional) also save as DatasetDict root if you want a single load_from_disk path:
-        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
-        log(f"Saved WT unpooled -> {wt_unpooled_out}")
-
-
-    # =========================
-    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
-    # =========================
-    with section("SMILES unpooled paired embeddings + save"):
-        # reuse already-loaded smi_tok/smi_roformer from pooled section if still in scope;
-        # otherwise re-init here:
-        # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
-        # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()
-
-        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
-        smi_unpooled_dd = DatasetDict({
-            "train": build_smiles_unpooled_paired_dataset(
-                smi_train, smi_unpooled_out / "train",
-                wt_tok, wt_esm,
-                smi_tok, smi_roformer
-            ),
-            "val": build_smiles_unpooled_paired_dataset(
-                smi_val, smi_unpooled_out / "val",
-                wt_tok, wt_esm,
-                smi_tok, smi_roformer
-            ),
-        })
-        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
-        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
-
-    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
-
-
-if __name__ == "__main__":
-    main()
training_classifiers/binding_wt.bash
DELETED
@@ -1,31 +0,0 @@
-#!/bin/bash
-#SBATCH --job-name=b-data
-#SBATCH --partition=dgx-b200
-#SBATCH --gpus=1
-#SBATCH --cpus-per-task=10
-#SBATCH --mem=100G
-#SBATCH --time=48:00:00
-#SBATCH --output=%x_%j.out
-
-HOME_LOC=/vast/projects/pranam/lab/yz927
-SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
-DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
-OBJECTIVE='binding_affinity'
-WT='smiles' #wt/smiles
-STATUS='pooled' #pooled/unpooled
-DATA_FILE="pair_wt_${WT}_${STATUS}"
-LOG_LOC=$SCRIPT_LOC
-DATE=$(date +%m_%d)
-SPECIAL_PREFIX="binding_affinity_data_generation"
-
-# Create log directory if it doesn't exist
-mkdir -p $LOG_LOC
-
-cd $SCRIPT_LOC
-source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
-conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal
-
-python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
-
-echo "Script completed at $(date)"
-conda deactivate