Joblib
ynuozhang committed
Commit df85e24 · 1 Parent(s): 21ea966
training_classifiers/binding_affinity_iptm.py DELETED
@@ -1,132 +0,0 @@
#!/usr/bin/env python3
"""
binding_affinity_iptm.py

Writes:
- out_dir/wt_iptm_affinity_all.csv
- out_dir/smiles_iptm_affinity_all.csv

Also prints:
- N
- Spearman rho (affinity vs iptm)
- Pearson r (affinity vs iptm)
"""

import argparse
import json
from pathlib import Path

import pandas as pd


def corr_stats(df: pd.DataFrame, x: str, y: str):
    # pandas handles NaNs if we already dropped them; still be safe
    xx = pd.to_numeric(df[x], errors="coerce")
    yy = pd.to_numeric(df[y], errors="coerce")
    m = xx.notna() & yy.notna()
    xx = xx[m]
    yy = yy[m]
    n = int(m.sum())

    # Pearson r
    pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
    # Spearman rho
    spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")

    return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}


def clean_one(
    in_csv: Path,
    out_csv: Path,
    iptm_col: str,
    affinity_col: str = "affinity",
    keep_cols=(),
):
    df = pd.read_csv(in_csv)

    # affinity + iptm must exist
    need = [affinity_col, iptm_col]
    missing = [c for c in need if c not in df.columns]
    if missing:
        raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")

    # coerce to numeric
    df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
    df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")

    # drop rows with NaN in either column
    df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)

    # output cols (standardize names)
    out = pd.DataFrame({
        "affinity": df[affinity_col].astype(float),
        "iptm": df[iptm_col].astype(float),
    })

    # keep split if present (handy for coloring later, but not used for corr)
    if "split" in df.columns:
        out.insert(0, "split", df["split"].astype(str))

    # optional extras for labeling/debugging
    for c in keep_cols:
        if c in df.columns:
            out[c] = df[c]

    out_csv.parent.mkdir(parents=True, exist_ok=True)
    out.to_csv(out_csv, index=False)

    stats = corr_stats(out, "iptm", "affinity")
    print(f"[write] {out_csv}")
    print(f"  N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")

    # also save a stats JSON next to the CSV
    stats_path = out_csv.with_suffix(".stats.json")
    with open(stats_path, "w") as f:
        json.dump(
            {
                "input_csv": str(in_csv),
                "output_csv": str(out_csv),
                "iptm_col": iptm_col,
                "affinity_col": affinity_col,
                **stats,
            },
            f,
            indent=2,
        )


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--wt_meta_csv", type=str, required=True)
    ap.add_argument("--smiles_meta_csv", type=str, required=True)
    ap.add_argument("--out_dir", type=str, required=True)

    ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
    ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
    ap.add_argument("--affinity_col", type=str, default="affinity")
    args = ap.parse_args()

    out_dir = Path(args.out_dir)

    clean_one(
        Path(args.wt_meta_csv),
        out_dir / "wt_iptm_affinity_all.csv",
        iptm_col=args.wt_iptm_col,
        affinity_col=args.affinity_col,
        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
    )

    clean_one(
        Path(args.smiles_meta_csv),
        out_dir / "smiles_iptm_affinity_all.csv",
        iptm_col=args.smiles_iptm_col,
        affinity_col=args.affinity_col,
        keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
    )

    print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")


if __name__ == "__main__":
    main()
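For reference, here is a minimal sketch of consuming the files this script emits (an editor's illustration, not part of the commit; the out-directory name is hypothetical):

# Sketch: reload one emitted CSV plus its stats JSON and re-check the correlation.
import json
from pathlib import Path

import pandas as pd

out_dir = Path("out")  # hypothetical; whatever was passed as --out_dir
df = pd.read_csv(out_dir / "wt_iptm_affinity_all.csv")
stats = json.loads((out_dir / "wt_iptm_affinity_all.stats.json").read_text())

# The recomputed Spearman rho should match the value clean_one() saved.
rho = df["iptm"].corr(df["affinity"], method="spearman")
print(f"saved rho={stats['spearman_rho']:.4f} | recomputed rho={rho:.4f}")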
training_classifiers/binding_affinity_split.py DELETED
@@ -1,847 +0,0 @@
#!/usr/bin/env python3
import math
from pathlib import Path
import sys
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch

# tqdm is optional; we'll disable it by default in notebooks
from tqdm import tqdm

sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM

# -------------------------
# Config
# -------------------------
CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")

OUT_ROOT = Path(
    "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
)

# WT (seq) embedding model
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022
WT_BATCH = 32

# SMILES embedding model + tokenizer
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768
SMI_BATCH = 128

# Split config
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30

# Columns expected in the CSV
COL_SEQ1 = "seq1"
COL_SEQ2 = "seq2"
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"

# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# -------------------------
# Quiet / notebook-safe output controls
# -------------------------
QUIET = True      # suppress most prints
USE_TQDM = False  # disable tqdm bars (recommended in Jupyter to avoid crashing)
LOG_FILE = None   # optionally: OUT_ROOT / "build.log"

def log(msg: str):
    if LOG_FILE is not None:
        Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
        with open(LOG_FILE, "a") as f:
            f.write(msg.rstrip() + "\n")
    if not QUIET:
        print(msg)

def pbar(it, **kwargs):
    return tqdm(it, **kwargs) if USE_TQDM else it

@contextmanager
def section(title: str):
    log(f"\n=== {title} ===")
    yield
    log(f"=== done: {title} ===")

# -------------------------
# Helpers
# -------------------------
def has_uaa(seq: str) -> bool:
    return "X" in str(seq).upper()

def affinity_to_class(a: float) -> str:
    # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
    if a >= 9.0:
        return "High"
    elif a >= 7.0:
        return "Moderate"
    else:
        return "Low"

def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()

    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
    df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)

    df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)

    # stratify on affinity quantile bins; fall back to the coarse class
    # labels if qcut cannot form enough distinct bins
    try:
        df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
        strat_col = "aff_bin"
    except Exception:
        df["aff_bin"] = df["affinity_class"]
        strat_col = "aff_bin"

    rng = np.random.RandomState(RANDOM_SEED)

    # shuffle within each stratum, then take the first TRAIN_FRAC for train
    df["split"] = None
    for _, g in df.groupby(strat_col, observed=True):
        idx = g.index.to_numpy()
        rng.shuffle(idx)
        n_train = int(math.floor(len(idx) * TRAIN_FRAC))
        df.loc[idx[:n_train], "split"] = "train"
        df.loc[idx[n_train:], "split"] = "val"

    df["split"] = df["split"].fillna("train")
    return df

def _summ(x):
    x = np.asarray(x, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
    return {
        "n": int(len(x)),
        "mean": float(np.mean(x)),
        "std": float(np.std(x)),
        "p50": float(np.quantile(x, 0.50)),
        "p95": float(np.quantile(x, 0.95)),
    }

def _len_stats(seqs):
    lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
    if len(lens) == 0:
        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
    return {
        "n": int(len(lens)),
        "mean": float(lens.mean()),
        "std": float(lens.std()),
        "p50": float(np.quantile(lens, 0.50)),
        "p95": float(np.quantile(lens, 0.95)),
    }

def verify_split_before_embedding(
    df2: pd.DataFrame,
    affinity_col: str,
    split_col: str,
    seq_col: str,
    iptm_col: str,
    aff_class_col: str = "affinity_class",
    aff_bins: int = 30,
    save_report_prefix: str | None = None,
    verbose: bool = False,
):
    """
    Notebook-safe: by default prints only ONE line via `log()`.
    Optionally writes CSV reports (stats + class proportions).
    """
    df2 = df2.copy()
    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")

    assert split_col in df2.columns, f"Missing split col: {split_col}"
    assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
    assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."

    try:
        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
    except Exception:
        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)

    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
    va = df2[df2[split_col] == "val"].reset_index(drop=True)

    tr_aff = _summ(tr[affinity_col].to_numpy())
    va_aff = _summ(va[affinity_col].to_numpy())
    tr_len = _len_stats(tr[seq_col].tolist())
    va_len = _len_stats(va[seq_col].tolist())

    # bin drift
    bin_ct = (
        df2.groupby([split_col, "_aff_bin_dbg"])
        .size()
        .groupby(level=0)
        .apply(lambda s: s / s.sum())
    )
    tr_bins = bin_ct.loc["train"]
    va_bins = bin_ct.loc["val"]
    all_bins = tr_bins.index.union(va_bins.index)
    tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
    va_bins = va_bins.reindex(all_bins, fill_value=0.0)
    max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))

    msg = (
        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
        f"max_bin_diff={max_bin_diff:.4f}"
    )
    log(msg)

    if verbose and (not QUIET):
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
        print("\n[verbose] affinity_class counts:\n", class_ct)
        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))

    if save_report_prefix is not None:
        out = Path(save_report_prefix)
        out.parent.mkdir(parents=True, exist_ok=True)

        stats_df = pd.DataFrame([
            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
        ])
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()

        stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
        class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)

# -------------------------
# WT pooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
    embs = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        out = model(**inputs)
        h = out.last_hidden_state  # (B, L, H)

        # mean-pool over valid (non-padding) positions only
        attn = inputs["attention_mask"].unsqueeze(-1)  # (B, L, 1)
        summed = (h * attn).sum(dim=1)                 # (B, H)
        denom = attn.sum(dim=1).clamp(min=1e-9)        # (B, 1)
        pooled = (summed / denom).detach().cpu().numpy()
        embs.append(pooled)

    return np.vstack(embs)


# -------------------------
# WT unpooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
    tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    tok = {k: v.to(DEVICE) for k, v in tok.items()}
    out = model(**tok)
    h = out.last_hidden_state[0]            # (L, H)
    attn = tok["attention_mask"][0].bool()  # (L,)
    ids = tok["input_ids"][0]

    # drop CLS/EOS so only residue tokens remain
    keep = attn.clone()
    if cls_id is not None:
        keep &= (ids != cls_id)
    if eos_id is not None:
        keep &= (ids != eos_id)

    return h[keep].detach().cpu().to(torch.float16).numpy()

def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
    """
    Expects df_split to have:
      - target_sequence (seq1)
      - sequence (binder seq2; WT binder)
      - label, affinity_class, COL_AFF, COL_WT_IPTM
    Saves a dataset where each row contains BOTH:
      - target_embedding (Lt,H), target_attention_mask, target_length
      - binder_embedding (Lb,H), binder_attention_mask, binder_length
    """
    cls_id = tokenizer.cls_token_id
    eos_id = tokenizer.eos_token_id
    H = model.config.hidden_size

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_WT_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))

            iptm = getattr(r, COL_WT_IPTM)
            iptm = float(iptm) if pd.notna(iptm) else np.nan

            # token embeddings for target + binder (both ESM)
            t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)  # (Lt,H)
            b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)  # (Lb,H)

            t_list = t_emb.tolist()
            b_list = b_emb.tolist()
            Lt = len(t_list)
            Lb = len(b_list)

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": [1] * Lb,
                "binder_length": int(Lb),

                COL_WT_IPTM: np.float32(iptm),
                COL_AFF: np.float32(aff),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds

def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
                                         smi_tok, smi_roformer):
    """
    df_split must have:
      - target_sequence (seq1)
      - sequence (binder SMILES string)
      - label, affinity_class, COL_AFF, COL_SMI_IPTM
    Saves rows with:
      target_embedding (Lt,Ht) from ESM
      binder_embedding (Lb,Hb) from PeptideCLM
    """
    cls_id = wt_tokenizer.cls_token_id
    eos_id = wt_tokenizer.eos_token_id
    Ht = wt_model_unpooled.config.hidden_size

    # infer the binder hidden size (Hb) from the model config if possible
    Hb = getattr(smi_roformer.config, "hidden_size", None)
    if Hb is None:
        Hb = getattr(smi_roformer.config, "dim", None)
    if Hb is None:
        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_SMI_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))

            iptm = getattr(r, COL_SMI_IPTM)
            iptm = float(iptm) if pd.notna(iptm) else np.nan

            # target token embeddings (ESM)
            t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
            t_list = t_emb.tolist()
            Lt = len(t_list)

            # binder token embeddings (PeptideCLM) — single-item batch
            _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
                [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
            )
            b_emb = tok_list[0]  # np.float16 (Lb, Hb)
            b_list = b_emb.tolist()
            Lb = int(lengths[0])
            b_mask = mask_list[0].astype(np.int8).tolist()

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": b_mask,
                "binder_length": int(Lb),

                COL_SMI_IPTM: np.float32(iptm),
                COL_AFF: np.float32(aff),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds

# -------------------------
# SMILES pooled + unpooled (PeptideCLM)
# -------------------------
def get_special_ids(tokenizer_obj):
    cand = [
        getattr(tokenizer_obj, "pad_token_id", None),
        getattr(tokenizer_obj, "cls_token_id", None),
        getattr(tokenizer_obj, "sep_token_id", None),
        getattr(tokenizer_obj, "bos_token_id", None),
        getattr(tokenizer_obj, "eos_token_id", None),
        getattr(tokenizer_obj, "mask_token_id", None),
    ]
    return sorted({x for x in cand if x is not None})

@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
    tok = tokenizer_obj(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = tok["input_ids"].to(DEVICE)
    attention_mask = tok["attention_mask"].to(DEVICE)

    outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state  # (B, L, H)

    # mask out padding AND special tokens before pooling
    special_ids = get_special_ids(tokenizer_obj)
    valid = attention_mask.bool()
    if len(special_ids) > 0:
        sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
        if hasattr(torch, "isin"):
            valid = valid & (~torch.isin(input_ids, sid))
        else:
            m = torch.zeros_like(valid)
            for s in special_ids:
                m |= (input_ids == s)
            valid = valid & (~m)

    valid_f = valid.unsqueeze(-1).float()
    summed = torch.sum(last_hidden * valid_f, dim=1)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
    pooled = (summed / denom).detach().cpu().numpy()

    token_emb_list, mask_list, lengths = [], [], []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]  # (Li, H)
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
        li = emb.shape[0]
        lengths.append(int(li))
        mask_list.append(np.ones((li,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths

def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []

    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
            batch, tokenizer_obj, model_roformer, max_length
        )
        pooled_all.append(pooled)
        token_emb_all.extend(tok_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)

    return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all

# -------------------------
# Target embedding cache (NO extra ESM runs)
# We compute target pooled embeddings ONCE from the WT view, then reuse for SMILES.
# -------------------------
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
    wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
    wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

    # compute target pooled embeddings once
    tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
    tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()

    wt_train_tgt_emb = wt_pooled_embeddings(
        tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )
    wt_val_tgt_emb = wt_pooled_embeddings(
        tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )

    # build dict: target_sequence -> embedding (float32 array)
    # if duplicates exist, the last one wins; add checks if needed
    train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
    val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
    return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map

# -------------------------
# Main
# -------------------------
def main():
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    # 1) Load
    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        # Dedup on the full identity tuple
        DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)

        print("Rows after dedup on", DEDUP_COLS, ":", len(df))

        need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
        missing = [c for c in need if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # numeric affinity for both branches
        df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    # 2) Build WT subset + SMILES subset separately (NO global dropping)
    with section("prepare wt/smiles subsets"):
        # WT: requires a canonical peptide sequence (no X) + affinity
        df_wt = df.copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES: requires affinity + a usable picked SMILES (UAA -> REACT_SMILES, else -> Fasta2SMILES)
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)  # an empty iptm means something went wrong with that SMILES sequence

        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)

        log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    # 3) Split separately (different sizes and memberships are expected)
    with section("split wt and smiles separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        # save split tables
        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        # lightweight double-check (one log line per branch)
        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    # Prepare split views
    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
        out = df_in.copy()
        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
        out["sequence"] = out[binder_seq_col].astype(str).str.strip()  # binder
        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]

    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    # -------------------------
    # Split views
    # -------------------------
    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    # =========================
    # TARGET pooled embeddings (ESM) — SEPARATE per branch
    # =========================
    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        # ---- WT targets ----
        wt_train_tgt_emb = wt_pooled_embeddings(
            wt_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_tgt_emb = wt_pooled_embeddings(
            wt_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        # ---- SMILES targets (independent; may include UAA-only targets) ----
        smi_train_tgt_emb = wt_pooled_embeddings(
            smi_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_val_tgt_emb = wt_pooled_embeddings(
            smi_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

    # =========================
    # WT pooled binder embeddings (binder = WT peptide)
    # =========================
    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    # =========================
    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
    # =========================
    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    # =========================
    # WT unpooled paired (ESM target + ESM binder) + save
    # =========================
    with section("WT unpooled paired embeddings + save"):
        wt_tok_unpooled = wt_tok  # reuse tokenizer
        wt_esm_unpooled = wt_esm  # reuse model

        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
                                               wt_tok_unpooled, wt_esm_unpooled),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
                                             wt_tok_unpooled, wt_esm_unpooled),
        })
        # (Optional) also save the DatasetDict root if you want a single load_from_disk path:
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    # =========================
    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
    # =========================
    with section("SMILES unpooled paired embeddings + save"):
        # reuse the already-loaded smi_tok/smi_roformer from the pooled section if still in scope;
        # otherwise re-init here:
        # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()

        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()
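Both embedding branches above reduce token states to one vector per sequence by averaging only over valid positions (padding, and for SMILES also special tokens, are masked out). A self-contained sketch of that masked mean-pooling step on dummy tensors (editor's illustration, not part of the committed file):

# Editor's sketch (not part of the commit): the masked mean pooling used in
# wt_pooled_embeddings / smiles_embed_batch_return_both, on dummy tensors.
import torch

B, L, H = 2, 5, 4
h = torch.randn(B, L, H)                 # stand-in for last_hidden_state
attn = torch.tensor([[1, 1, 1, 0, 0],    # sequence 0: 3 real tokens + 2 pads
                     [1, 1, 1, 1, 1]])   # sequence 1: no padding

mask = attn.unsqueeze(-1).float()        # (B, L, 1)
summed = (h * mask).sum(dim=1)           # (B, H): pads contribute zero
denom = mask.sum(dim=1).clamp(min=1e-9)  # (B, 1): number of valid tokens
pooled = summed / denom                  # (B, H): per-sequence mean embedding

# Row 0 must equal the plain mean of its first three token states.
assert torch.allclose(pooled[0], h[0, :3].mean(dim=0), atol=1e-6)
print(pooled.shape)  # torch.Size([2, 4])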
training_classifiers/binding_wt.bash DELETED
@@ -1,31 +0,0 @@
#!/bin/bash
#SBATCH --job-name=b-data
#SBATCH --partition=dgx-b200
#SBATCH --gpus=1
#SBATCH --cpus-per-task=10
#SBATCH --mem=100G
#SBATCH --time=48:00:00
#SBATCH --output=%x_%j.out

HOME_LOC=/vast/projects/pranam/lab/yz927
SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
OBJECTIVE='binding_affinity'
WT='smiles'      # wt/smiles
STATUS='pooled'  # pooled/unpooled
DATA_FILE="pair_wt_${WT}_${STATUS}"
LOG_LOC=$SCRIPT_LOC
DATE=$(date +%m_%d)
SPECIAL_PREFIX="binding_affinity_data_generation"

# Create the log directory if it doesn't exist
mkdir -p "$LOG_LOC"

cd "$SCRIPT_LOC"
source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal

python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1

echo "Script completed at $(date)"
conda deactivate
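Downstream, the DatasetDicts that binding_affinity_split.py writes under training_data_cleaned/binding_affinity can be reloaded with datasets.load_from_disk. A minimal sketch (editor's illustration, not part of the commit; the path is shortened relative to DATA_LOC above):

# Editor's sketch (not part of the commit): reload one generated DatasetDict.
from datasets import load_from_disk

root = "training_data_cleaned/binding_affinity"       # shortened; see DATA_LOC
dd = load_from_disk(f"{root}/pair_wt_smiles_pooled")  # matches DATA_FILE above

train, val = dd["train"], dd["val"]
print(train.column_names)   # target_sequence, sequence, label, target_embedding, embedding, ...
print(len(train), len(val))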