# Joblib
# PeptiVerse / training_classifiers /.ipynb_checkpoints /binding_affinity_split-checkpoint.py
# ynuozhang
# update code
# baf3373
#!/usr/bin/env python3
import os
import math
from pathlib import Path
import sys
from contextlib import contextmanager
import numpy as np
import pandas as pd
import torch
# tqdm is imported unconditionally; progress bars are disabled by default (see USE_TQDM) for notebooks
from tqdm import tqdm
sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
# -------------------------
# Config
# -------------------------
# Shared project root; the input CSV, tokenizer files and outputs live under it.
_PROJECT_ROOT = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight"

# Input table and the directory every generated dataset/report is written to.
CSV_PATH = Path(_PROJECT_ROOT) / "c-binding_with_openfold_scores.csv"
OUT_ROOT = Path(_PROJECT_ROOT) / "training_data_cleaned" / "binding_affinity"

# WT (amino-acid sequence) embedding model.
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022
WT_BATCH = 32

# SMILES embedding model and its SPE tokenizer files.
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = f"{_PROJECT_ROOT}/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = f"{_PROJECT_ROOT}/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768
SMI_BATCH = 128

# Split config: train share per affinity stratum, RNG seed, quantile bin count.
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30

# Columns expected in the CSV.
COL_SEQ1 = "seq1"  # target sequence
COL_SEQ2 = "seq2"  # binder sequence
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"

# First CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# -------------------------
# Quiet / notebook-safe output controls
# -------------------------
QUIET = True      # when True, log() writes nothing to stdout
USE_TQDM = False  # tqdm bars off by default (recommended in Jupyter to avoid crashing)
LOG_FILE = None   # optional log path, e.g. OUT_ROOT / "build.log"
def log(msg: str):
    """Record a progress message.

    When ``LOG_FILE`` is set, the message (trailing whitespace stripped, one
    newline appended) is appended to that file, creating parent directories
    as needed. The message is also printed unless ``QUIET`` is True.

    Args:
        msg: Text to record.
    """
    if LOG_FILE is not None:
        log_path = Path(LOG_FILE)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: messages in this script contain non-ASCII characters
        # (e.g. "—", "±") that can fail to encode under a non-UTF-8 locale default.
        with open(log_path, "a", encoding="utf-8") as fh:
            fh.write(msg.rstrip() + "\n")
    if not QUIET:
        print(msg)
def pbar(it, **kwargs):
    """Wrap *it* in a tqdm progress bar only when ``USE_TQDM`` is enabled.

    When disabled, *it* is returned unchanged and *kwargs* are ignored.
    """
    if not USE_TQDM:
        return it
    return tqdm(it, **kwargs)
@contextmanager
def section(title: str):
    """Bracket a pipeline stage with opening and closing ``log()`` lines.

    The closing line is emitted only when the wrapped body finishes without
    raising; an exception propagates after just the opening line.
    """
    opening = f"\n=== {title} ==="
    closing = f"=== done: {title} ==="
    log(opening)
    yield
    log(closing)
# -------------------------
# Helpers
# -------------------------
def has_uaa(seq: str) -> bool:
    """Return True if *seq* (coerced with ``str()``) contains an 'X'/'x' residue."""
    normalized = str(seq).upper()
    return normalized.find("X") != -1
def affinity_to_class(a: float) -> str:
    """Map a numeric affinity to a coarse class label.

    High: a >= 9; Moderate: 7 <= a < 9; Low: everything else (including NaN,
    since every comparison with NaN is False).
    """
    for threshold, label in ((9.0, "High"), (7.0, "Moderate")):
        if a >= threshold:
            return label
    return "Low"
def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
df = df.copy()
df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
try:
df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
strat_col = "aff_bin"
except Exception:
df["aff_bin"] = df["affinity_class"]
strat_col = "aff_bin"
rng = np.random.RandomState(RANDOM_SEED)
df["split"] = None
for _, g in df.groupby(strat_col, observed=True):
idx = g.index.to_numpy()
rng.shuffle(idx)
n_train = int(math.floor(len(idx) * TRAIN_FRAC))
df.loc[idx[:n_train], "split"] = "train"
df.loc[idx[n_train:], "split"] = "val"
df["split"] = df["split"].fillna("train")
return df
def _summ(x):
x = np.asarray(x, dtype=float)
x = x[~np.isnan(x)]
if len(x) == 0:
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
return {
"n": int(len(x)),
"mean": float(np.mean(x)),
"std": float(np.std(x)),
"p50": float(np.quantile(x, 0.50)),
"p95": float(np.quantile(x, 0.95)),
}
def _len_stats(seqs):
lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
if len(lens) == 0:
return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
return {
"n": int(len(lens)),
"mean": float(lens.mean()),
"std": float(lens.std()),
"p50": float(np.quantile(lens, 0.50)),
"p95": float(np.quantile(lens, 0.95)),
}
def verify_split_before_embedding(
df2: pd.DataFrame,
affinity_col: str,
split_col: str,
seq_col: str,
iptm_col: str,
aff_class_col: str = "affinity_class",
aff_bins: int = 30,
save_report_prefix: str | None = None,
verbose: bool = False,
):
"""
Notebook-safe: by default prints only ONE line via `log()`.
Optionally writes CSV reports (stats + class proportions).
"""
df2 = df2.copy()
df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
assert split_col in df2.columns, f"Missing split col: {split_col}"
assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
try:
df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
except Exception:
df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
tr = df2[df2[split_col] == "train"].reset_index(drop=True)
va = df2[df2[split_col] == "val"].reset_index(drop=True)
tr_aff = _summ(tr[affinity_col].to_numpy())
va_aff = _summ(va[affinity_col].to_numpy())
tr_len = _len_stats(tr[seq_col].tolist())
va_len = _len_stats(va[seq_col].tolist())
# bin drift
bin_ct = (
df2.groupby([split_col, "_aff_bin_dbg"])
.size()
.groupby(level=0)
.apply(lambda s: s / s.sum())
)
tr_bins = bin_ct.loc["train"]
va_bins = bin_ct.loc["val"]
all_bins = tr_bins.index.union(va_bins.index)
tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
va_bins = va_bins.reindex(all_bins, fill_value=0.0)
max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
msg = (
f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
f"max_bin_diff={max_bin_diff:.4f}"
)
log(msg)
if verbose and (not QUIET):
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
print("\n[verbose] affinity_class counts:\n", class_ct)
print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
if save_report_prefix is not None:
out = Path(save_report_prefix)
out.parent.mkdir(parents=True, exist_ok=True)
stats_df = pd.DataFrame([
{"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
{"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
])
class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
# -------------------------
# WT pooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
    """Mean-pool the encoder's last hidden states per sequence.

    *seqs* is processed in slices of *batch_size*: each slice is tokenized
    (padded, truncated to *max_length*), moved to ``DEVICE`` and run through
    *model*. Pooling averages over every position whose ``attention_mask`` is
    1, so any special tokens the tokenizer adds are NOT filtered out here
    (``wt_unpooled_one`` removes CLS/EOS; this function does not).

    Args:
        seqs: Sliceable sequence (e.g. list) of sequence strings.
        tokenizer: Callable returning ``input_ids`` / ``attention_mask`` tensors.
        model: Encoder whose output exposes ``last_hidden_state`` (B, L, H).
        batch_size: Sequences per forward pass.
        max_length: Tokenizer truncation length.

    Returns:
        ``np.ndarray`` of shape ``(len(seqs), H)``. Empty *seqs* yields an
        empty float32 array of shape ``(0, H)`` (``H`` read from
        ``model.config.hidden_size`` when present, else 0) rather than letting
        ``np.vstack`` fail on an empty list.
    """
    if len(seqs) == 0:
        hidden = getattr(getattr(model, "config", None), "hidden_size", None) or 0
        return np.zeros((0, hidden), dtype=np.float32)

    embs = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        h = model(**inputs).last_hidden_state  # (B, L, H)
        # Cast the 0/1 mask to the hidden-state dtype so the masked sum and the
        # token count are both computed in floating point.
        attn = inputs["attention_mask"].unsqueeze(-1).to(h.dtype)  # (B, L, 1)
        summed = (h * attn).sum(dim=1)  # (B, H)
        denom = attn.sum(dim=1).clamp(min=1e-9)  # (B, 1); guards against 0
        embs.append((summed / denom).detach().cpu().numpy())
    return np.vstack(embs)
# -------------------------
# WT unpooled (ESM2)
# -------------------------
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
    """Return per-token embeddings for a single sequence, special tokens removed.

    The sequence is tokenized without padding (truncated to *max_length*), run
    through *model* on ``DEVICE``, and only positions that are attended AND
    whose id differs from *cls_id* / *eos_id* (each skipped when None) are kept.

    Returns:
        float16 ``np.ndarray`` of shape ``(kept_tokens, H)`` on the CPU.
    """
    encoded = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    hidden = model(**encoded).last_hidden_state[0]  # (L, H): single-item batch
    token_ids = encoded["input_ids"][0]
    keep = encoded["attention_mask"][0].bool()
    for special in (cls_id, eos_id):
        if special is not None:
            keep = keep & (token_ids != special)
    return hidden[keep].detach().cpu().to(torch.float16).numpy()
def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
    """Build and save a per-token ESM dataset for WT target/binder pairs.

    *df_split* must provide ``target_sequence`` (seq1), ``sequence`` (WT binder,
    seq2), ``label``, ``affinity_class``, ``COL_AFF`` and ``COL_WT_IPTM``.
    Each saved row holds, for BOTH the target and the binder, the
    ``wt_unpooled_one`` token embeddings (CLS/EOS removed, float16, shape
    ``(L, H)``), an all-ones attention mask and the token count, plus the row
    metadata. The dataset is written to *out_dir* (1GB max shards) and returned.
    """
    cls_token = tokenizer.cls_token_id
    eos_token = tokenizer.eos_token_id
    width = model.config.hidden_size

    def token_grid():
        # Variable-length list of fixed-width float16 vectors.
        return HFSequence(HFSequence(Value("float16"), length=width))

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),
        "target_embedding": token_grid(),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),
        "binder_embedding": token_grid(),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),
        COL_WT_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def embed_tokens(text):
        return wt_unpooled_one(text, tokenizer, model, cls_token, eos_token, max_length=WT_MAX_LEN)

    def records(frame: pd.DataFrame):
        for row in pbar(frame.itertuples(index=False), total=len(frame)):
            target = str(row.target_sequence).strip()
            binder = str(row.sequence).strip()
            label = float(row.label)
            affinity = float(getattr(row, COL_AFF))
            affinity_class = str(row.affinity_class)
            raw_iptm = getattr(row, COL_WT_IPTM)
            iptm = float(raw_iptm) if pd.notna(raw_iptm) else np.nan

            target_tokens = embed_tokens(target).tolist()  # (Lt, H)
            binder_tokens = embed_tokens(binder).tolist()  # (Lb, H)
            n_target = len(target_tokens)
            n_binder = len(binder_tokens)

            yield {
                "target_sequence": target,
                "sequence": binder,
                "label": np.float32(label),
                "affinity": np.float32(affinity),
                "affinity_class": affinity_class,
                "target_embedding": target_tokens,
                "target_attention_mask": [1] * n_target,
                "target_length": int(n_target),
                "binder_embedding": binder_tokens,
                "binder_attention_mask": [1] * n_binder,
                "binder_length": int(n_binder),
                # np.float32(nan) is already NaN, so no special case is needed.
                COL_WT_IPTM: np.float32(iptm),
                COL_AFF: np.float32(affinity),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    dataset = Dataset.from_generator(lambda: records(df_split), features=features)
    dataset.save_to_disk(str(out_dir), max_shard_size="1GB")
    return dataset
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
                                         smi_tok, smi_roformer):
    """Build and save a per-token dataset pairing ESM targets with PeptideCLM SMILES binders.

    *df_split* must provide ``target_sequence`` (seq1), ``sequence`` (binder
    SMILES string), ``label``, ``affinity_class``, ``COL_AFF`` and
    ``COL_SMI_IPTM``. Each saved row holds the target's ``wt_unpooled_one``
    token embeddings ``(Lt, Ht)`` and the binder's
    ``smiles_embed_batch_return_both`` token embeddings ``(Lb, Hb)`` (special
    tokens removed in both, float16), their masks and lengths, plus metadata.
    The dataset is written to *out_dir* (1GB max shards) and returned.

    Raises:
        ValueError: If ``smi_roformer.config`` defines neither ``hidden_size``
            nor ``dim``.
    """
    target_cls = wt_tokenizer.cls_token_id
    target_eos = wt_tokenizer.eos_token_id
    target_width = wt_model_unpooled.config.hidden_size

    # Binder width from the PeptideCLM config: hidden_size, else dim.
    binder_width = getattr(smi_roformer.config, "hidden_size", None)
    if binder_width is None:
        binder_width = getattr(smi_roformer.config, "dim", None)
    if binder_width is None:
        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),
        "target_embedding": HFSequence(HFSequence(Value("float16"), length=target_width)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),
        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=binder_width)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),
        COL_SMI_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def records(frame: pd.DataFrame):
        for row in pbar(frame.itertuples(index=False), total=len(frame)):
            target = str(row.target_sequence).strip()
            binder = str(row.sequence).strip()
            label = float(row.label)
            affinity = float(getattr(row, COL_AFF))
            affinity_class = str(row.affinity_class)
            raw_iptm = getattr(row, COL_SMI_IPTM)
            iptm = float(raw_iptm) if pd.notna(raw_iptm) else np.nan

            target_tokens = wt_unpooled_one(
                target, wt_tokenizer, wt_model_unpooled, target_cls, target_eos, max_length=WT_MAX_LEN
            ).tolist()
            n_target = len(target_tokens)

            # Single-item PeptideCLM batch, so element 0 belongs to this binder.
            _, token_embs, token_masks, token_counts = smiles_embed_batch_return_both(
                [binder], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
            )
            binder_tokens = token_embs[0].tolist()
            n_binder = int(token_counts[0])
            binder_mask = [int(v) for v in token_masks[0].astype(np.int8)]

            yield {
                "target_sequence": target,
                "sequence": binder,
                "label": np.float32(label),
                "affinity": np.float32(affinity),
                "affinity_class": affinity_class,
                "target_embedding": target_tokens,
                "target_attention_mask": [1] * n_target,
                "target_length": int(n_target),
                "binder_embedding": binder_tokens,
                "binder_attention_mask": binder_mask,
                "binder_length": int(n_binder),
                # np.float32(nan) is already NaN, so no special case is needed.
                COL_SMI_IPTM: np.float32(iptm),
                COL_AFF: np.float32(affinity),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    dataset = Dataset.from_generator(lambda: records(df_split), features=features)
    dataset.save_to_disk(str(out_dir), max_shard_size="1GB")
    return dataset
# -------------------------
# SMILES pooled + unpooled (PeptideCLM)
# -------------------------
def get_special_ids(tokenizer_obj):
    """Return the sorted, de-duplicated special-token ids defined on *tokenizer_obj*.

    Looks at the pad/cls/sep/bos/eos/mask ``*_token_id`` attributes; missing
    attributes and ``None`` values are skipped.
    """
    attr_names = (
        "pad_token_id",
        "cls_token_id",
        "sep_token_id",
        "bos_token_id",
        "eos_token_id",
        "mask_token_id",
    )
    found = set()
    for attr in attr_names:
        value = getattr(tokenizer_obj, attr, None)
        if value is not None:
            found.add(value)
    return sorted(found)
@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
    """Embed a batch of SMILES strings, returning pooled and per-token outputs.

    Tokens count as "valid" when attended and not one of the tokenizer's
    special ids (``get_special_ids``). Pooling is the mean over valid tokens.

    Returns:
        ``(pooled, token_embs, masks, lengths)``: ``pooled`` is a ``(B, H)``
        array; ``token_embs[i]`` is a float16 ``(Li, H)`` array of valid-token
        embeddings; ``masks[i]`` is ``np.ones(Li, int8)``; ``lengths[i] == Li``.
    """
    encoded = tokenizer_obj(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)
    hidden = model_roformer(input_ids=ids, attention_mask=mask).last_hidden_state  # (B, L, H)

    keep = mask.bool()
    specials = get_special_ids(tokenizer_obj)
    if specials:
        special_tensor = torch.tensor(specials, device=DEVICE, dtype=torch.long)
        if hasattr(torch, "isin"):
            is_special = torch.isin(ids, special_tensor)
        else:
            # Older torch without isin: OR together one equality test per id.
            is_special = torch.zeros_like(keep)
            for sid in specials:
                is_special |= (ids == sid)
        keep = keep & (~is_special)

    weights = keep.unsqueeze(-1).float()
    numer = torch.sum(hidden * weights, dim=1)
    count = torch.clamp(weights.sum(dim=1), min=1e-9)
    pooled = (numer / count).detach().cpu().numpy()

    token_embs, masks, lengths = [], [], []
    for row_hidden, row_keep in zip(hidden, keep):
        tokens = row_hidden[row_keep]  # (Li, H)
        token_embs.append(tokens.detach().cpu().to(torch.float16).numpy())
        n_tokens = int(tokens.shape[0])
        lengths.append(n_tokens)
        masks.append(np.ones((n_tokens,), dtype=np.int8))
    return pooled, token_embs, masks, lengths
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
    """Embed SMILES strings in batches via ``smiles_embed_batch_return_both``.

    Args:
        seqs: Sliceable sequence (e.g. list) of SMILES strings.
        tokenizer_obj: SMILES tokenizer passed through to the batch embedder.
        model_roformer: Encoder passed through to the batch embedder.
        batch_size: Strings per forward pass.
        max_length: Tokenizer truncation length.

    Returns:
        ``(pooled, token_embs, masks, lengths)``: ``pooled`` stacks the
        per-batch pooled arrays into ``(N, H)``; the other three are flat lists
        with one entry per input string, in input order. Empty *seqs* yields a
        float32 ``(0, H)`` array (``H`` from ``model_roformer.config``
        ``hidden_size``/``dim`` when available, else 0) and three empty lists,
        instead of ``np.vstack`` failing on an empty list.
    """
    if len(seqs) == 0:
        cfg = getattr(model_roformer, "config", None)
        hidden = getattr(cfg, "hidden_size", None) or getattr(cfg, "dim", None) or 0
        return np.zeros((0, hidden), dtype=np.float32), [], [], []

    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
            batch, tokenizer_obj, model_roformer, max_length
        )
        pooled_all.append(pooled)
        token_emb_all.extend(tok_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)
    return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
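# -------------------------
# Target embedding cache (NO extra ESM runs)
# Helper that computes WT-view target pooled embeddings once and returns
# target-sequence -> embedding maps. NOTE: main() does not currently call it;
# it embeds targets separately for the WT and SMILES branches.
# -------------------------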
def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
    """Load ESM2 and pool-embed the WT-view target sequences of both splits once.

    Returns:
        ``(tokenizer, model, train_emb, val_emb, train_map, val_map)`` where
        ``*_emb`` are the ``wt_pooled_embeddings`` outputs (one row per input
        row) and ``*_map`` map each target sequence string to its embedding
        row. For a target that occurs more than once, the row of its LAST
        occurrence is kept.
    """
    tokenizer = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
    esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

    def embed_targets(view: pd.DataFrame):
        # Returns the target strings alongside their pooled embeddings.
        targets = view["target_sequence"].astype(str).tolist()
        vectors = wt_pooled_embeddings(
            targets, tokenizer, esm, batch_size=WT_BATCH, max_length=WT_MAX_LEN
        )
        return targets, vectors

    train_targets, train_emb = embed_targets(wt_view_train)
    val_targets, val_emb = embed_targets(wt_view_val)

    # dict(zip(...)) keeps the last value for repeated keys.
    train_lookup = dict(zip(train_targets, train_emb))
    val_lookup = dict(zip(val_targets, val_emb))
    return tokenizer, esm, train_emb, val_emb, train_lookup, val_lookup
# -------------------------
# Main
# -------------------------
def _clean_text(col: pd.Series) -> pd.Series:
    """Return *col* as stripped strings, with missing or blank entries set to NaN."""
    # astype(str) alone turns a missing cell into the literal text "nan", which
    # survives dropna() and would then be embedded as if it were a sequence.
    cleaned = col.astype(str).str.strip()
    return cleaned.where(col.notna() & (cleaned != ""))


def _prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
    """Project a split table onto the column layout the embedding builders expect."""
    out = df_in.copy()
    out["target_sequence"] = _clean_text(out[COL_SEQ1])
    out["sequence"] = _clean_text(out[binder_seq_col])  # binder
    out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
    out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
    out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
    out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
    return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]


def _pooled_dataset(view: pd.DataFrame, target_emb, binder_emb, iptm_col: str):
    """Assemble one pooled split (target + binder vectors + metadata) as a Dataset."""
    return Dataset.from_dict({
        "target_sequence": view["target_sequence"].tolist(),
        "sequence": view["sequence"].tolist(),
        "label": view["label"].astype(float).tolist(),
        "target_embedding": target_emb,
        "embedding": binder_emb,
        iptm_col: view[iptm_col].astype(float).tolist(),
        COL_AFF: view[COL_AFF].astype(float).tolist(),
        "affinity_class": view["affinity_class"].tolist(),
    })


def main():
    """Build every binding-affinity embedding dataset under ``OUT_ROOT``.

    Steps:
      1. Load ``CSV_PATH``, validate the required columns, strip text columns,
         and drop duplicate (seq1, seq2, Fasta2SMILES, REACT_SMILES) rows.
      2. Build the WT subset (non-missing binder without 'X' + affinity) and the
         SMILES subset (affinity + SMILES ipTM + picked SMILES) independently.
      3. Split each subset with ``make_distribution_matched_split``, save the
         split tables, and log a one-line split sanity check for each.
      4. Save four DatasetDicts: ``pair_wt_wt_pooled``, ``pair_wt_smiles_pooled``,
         ``pair_wt_wt_unpooled`` and ``pair_wt_smiles_unpooled``. Targets are
         embedded with the ESM model; binders with ESM (WT) or PeptideCLM (SMILES).

    Raises:
        ValueError: If the CSV lacks any required column.
    """
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    # 1) Load
    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        # Validate first so a missing column gives a clear error instead of a
        # KeyError from drop_duplicates below.
        need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
        missing = [c for c in need if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Dedup on the full identity tuple, after stripping surrounding whitespace.
        dedup_cols = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        for c in dedup_cols:
            df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
        df = df.drop_duplicates(subset=dedup_cols).reset_index(drop=True)
        log(f"Rows after dedup on {dedup_cols} : {len(df)}")

    # Numeric affinity for both branches (unparseable values become NaN).
    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    # 2) Build WT subset + SMILES subset separately (NO global dropping)
    with section("prepare wt/smiles subsets"):
        # WT: requires a real (non-missing, non-blank) peptide without X + affinity.
        df_wt = df.copy()
        df_wt["wt_sequence"] = _clean_text(df_wt[COL_SEQ2])
        df_wt = df_wt.dropna(subset=[COL_AFF, "wt_sequence"])
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES: requires affinity + a SMILES ipTM score + a usable picked SMILES
        # (UAA -> REACT_SMILES, else -> Fasta2SMILES).
        df_smi = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
        # An empty SMILES ipTM score marks a problematic SMILES entry; drop those rows.
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)
        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
    log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    # 3) Split separately (different sizes and memberships are expected)
    with section("split wt and smiles separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        # lightweight double-check (one-line)
        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    # Split views: target (seq1) + binder + label/metadata, per branch.
    wt_view = _prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = _prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    # =========================
    # TARGET pooled embeddings (ESM) — SEPARATE per branch
    # =========================
    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        def esm_pooled(view: pd.DataFrame, col: str) -> np.ndarray:
            # ESM mean-pooled embeddings of one text column of a view, as float32.
            return wt_pooled_embeddings(
                view[col].astype(str).str.strip().tolist(),
                wt_tok, wt_esm,
                batch_size=WT_BATCH,
                max_length=WT_MAX_LEN,
            ).astype(np.float32)

        wt_train_tgt_emb = esm_pooled(wt_train, "target_sequence")
        wt_val_tgt_emb = esm_pooled(wt_val, "target_sequence")
        # SMILES targets are embedded independently: that subset keeps the
        # UAA-binder rows which the WT subset drops.
        smi_train_tgt_emb = esm_pooled(smi_train, "target_sequence")
        smi_val_tgt_emb = esm_pooled(smi_val, "target_sequence")

    # =========================
    # WT pooled binder embeddings (binder = WT peptide)
    # =========================
    with section("WT pooled binder embeddings + save"):
        wt_train_emb = esm_pooled(wt_train, "sequence")
        wt_val_emb = esm_pooled(wt_val, "sequence")
        wt_pooled_dd = DatasetDict({
            "train": _pooled_dataset(wt_train, wt_train_tgt_emb, wt_train_emb, COL_WT_IPTM),
            "val": _pooled_dataset(wt_val, wt_val_tgt_emb, wt_val_emb, COL_WT_IPTM),
        })
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    # =========================
    # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
    # =========================
    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )
        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )
        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )
        smi_pooled_dd = DatasetDict({
            "train": _pooled_dataset(
                smi_train, smi_train_tgt_emb, smi_train_pooled.astype(np.float32), COL_SMI_IPTM
            ),
            "val": _pooled_dataset(
                smi_val, smi_val_tgt_emb, smi_val_pooled.astype(np.float32), COL_SMI_IPTM
            ),
        })
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    # =========================
    # WT unpooled paired (ESM target + ESM binder) + save
    # =========================
    with section("WT unpooled paired embeddings + save"):
        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        # The builder saves each split to its own subdirectory; the DatasetDict
        # is then also saved at the root so a single load_from_disk path works.
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train", wt_tok, wt_esm),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val", wt_tok, wt_esm),
        })
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    # =========================
    # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
    # =========================
    with section("SMILES unpooled paired embeddings + save"):
        # Reuses the ESM and PeptideCLM models loaded in the sections above.
        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer,
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer,
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()