#!/usr/bin/env python3
# export_val_preds_csv.py
import argparse
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict
# -----------------------------
# Repro / device
# -----------------------------
def seed_all(seed=1986):
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
seed_all(1986)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------
# Load paired DatasetDict
# -----------------------------
def load_split_paired(path: str):
    """Load a saved DatasetDict and return its (train, val) splits."""
dd = load_from_disk(path)
if not isinstance(dd, DatasetDict):
raise ValueError(f"Expected DatasetDict at {path}")
if "train" not in dd or "val" not in dd:
raise ValueError(f"DatasetDict missing train/val at {path}")
return dd["train"], dd["val"]
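# The paired dataset is assumed to carry the fields below (inferred from the
# collate fns that follow; adjust if your schema differs):
#   pooled:   target_embedding (Ht,), binder_embedding (Hb,), label
#   unpooled: target_embedding (Lt, Ht), binder_embedding (Lb, Hb),
#             target_length, binder_length,
#             target_attention_mask, binder_attention_mask, label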
# -----------------------------
# Collate fns (must match the ones used during training)
# -----------------------------
def collate_pair_pooled(batch):
    """Stack pooled per-sequence embeddings: Pt (B, Ht), Pb (B, Hb), labels y (B,)."""
Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
return Pt, Pb, y
def collate_pair_unpooled(batch):
    """Pad token-level embeddings to the batch max lengths and build boolean validity masks."""
B = len(batch)
Ht = len(batch[0]["target_embedding"][0])
Hb = len(batch[0]["binder_embedding"][0])
Lt_max = max(int(x["target_length"]) for x in batch)
Lb_max = max(int(x["binder_length"]) for x in batch)
Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
for i, x in enumerate(batch):
t = torch.tensor(x["target_embedding"], dtype=torch.float32)
b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
lt, lb = t.shape[0], b.shape[0]
Pt[i, :lt] = t
Pb[i, :lb] = b
Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
return Pt, Mt, Pb, Mb, y
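# Mask convention: Mt / Mb are True at real tokens and False at padding. The
# models below invert them (~mask) to form key_padding_mask, which
# nn.MultiheadAttention expects to be True at positions to ignore.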
# -----------------------------
# Models (architectures must match the training checkpoint)
# -----------------------------
class CrossAttnPooled(nn.Module):
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
super().__init__()
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
self.layers = nn.ModuleList([])
for _ in range(n_layers):
self.layers.append(nn.ModuleDict({
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
"n1t": nn.LayerNorm(hidden),
"n2t": nn.LayerNorm(hidden),
"n1b": nn.LayerNorm(hidden),
"n2b": nn.LayerNorm(hidden),
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
}))
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
self.reg = nn.Linear(hidden, 1)
self.cls = nn.Linear(hidden, 3)
def forward(self, t_vec, b_vec):
t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"](t + t_attn)  # LayerNorm acts on the last (feature) dim
            t = L["n2t"](t + L["fft"](t))
            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"](b + b_attn)
            b = L["n2b"](b + L["ffb"](b))
z = torch.cat([t[0], b[0]], dim=-1)
h = self.shared(z)
return self.reg(h).squeeze(-1), self.cls(h)
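# CrossAttnPooled treats each pooled embedding as a length-1 sequence: inputs
# t_vec (B, Ht) and b_vec (B, Hb); returns a regression output (B,) and 3-way
# class logits (B, 3).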
class CrossAttnUnpooled(nn.Module):
def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
super().__init__()
self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
self.layers = nn.ModuleList([])
for _ in range(n_layers):
self.layers.append(nn.ModuleDict({
"attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
"attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
"n1t": nn.LayerNorm(hidden),
"n2t": nn.LayerNorm(hidden),
"n1b": nn.LayerNorm(hidden),
"n2b": nn.LayerNorm(hidden),
"fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
"ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
}))
self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
self.reg = nn.Linear(hidden, 1)
self.cls = nn.Linear(hidden, 3)
def masked_mean(self, X, M):
Mf = M.unsqueeze(-1).float()
denom = Mf.sum(dim=1).clamp(min=1.0)
return (X * Mf).sum(dim=1) / denom
def forward(self, T, Mt, B, Mb):
T = self.t_proj(T)
Bx = self.b_proj(B)
kp_t = ~Mt
kp_b = ~Mb
for L in self.layers:
T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
T = L["n1t"](T + T_attn)
T = L["n2t"](T + L["fft"](T))
B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
Bx = L["n1b"](Bx + B_attn)
Bx = L["n2b"](Bx + L["ffb"](Bx))
t_pool = self.masked_mean(T, Mt)
b_pool = self.masked_mean(Bx, Mb)
z = torch.cat([t_pool, b_pool], dim=-1)
h = self.shared(z)
return self.reg(h).squeeze(-1), self.cls(h)
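# CrossAttnUnpooled works on token-level embeddings: T (B, Lt, Ht) with mask
# Mt (B, Lt), and B (B, Lb, Hb) with mask Mb (B, Lb). After the cross-attention
# stack, masked mean-pooling reduces each side to one vector for the shared head.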
# -----------------------------
# Helpers
# -----------------------------
def softmax_np(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability before exponentiating.
    x = logits - logits.max(axis=1, keepdims=True)
ex = np.exp(x)
return ex / ex.sum(axis=1, keepdims=True)
def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray:
centers = np.asarray(class_centers, dtype=np.float32)[None, :] # (1,3)
return (probs * centers).sum(axis=1)
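# Worked example (with the default centers 9.5 / 8.0 / 6.0): probs of
# [0.7, 0.2, 0.1] give 0.7*9.5 + 0.2*8.0 + 0.1*6.0 = 8.85.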
def load_checkpoint(ckpt_path: str, mode: str, train_ds):
    """Rebuild the model from the saved hyperparameters and load its weights."""
ckpt = torch.load(ckpt_path, map_location="cpu")
params = ckpt.get("best_params", {})
hidden = int(params.get("hidden_dim", 512))
n_heads = int(params.get("n_heads", 8))
n_layers = int(params.get("n_layers", 3))
dropout = float(params.get("dropout", 0.1))
if mode == "pooled":
Ht = len(train_ds[0]["target_embedding"])
Hb = len(train_ds[0]["binder_embedding"])
model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
else:
Ht = len(train_ds[0]["target_embedding"][0])
Hb = len(train_ds[0]["binder_embedding"][0])
model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
model.load_state_dict(ckpt["state_dict"], strict=True)
model.to(DEVICE).eval()
return model
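# Expected checkpoint layout (as read above): a dict with "state_dict" plus an
# optional "best_params" entry holding hidden_dim, n_heads, n_layers, and
# dropout; missing entries fall back to the construction defaults.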
@torch.no_grad()
def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str,
                         out_csv: str, batch_size: int, num_workers: int,
                         class_centers=(9.5, 8.0, 6.0)):
    """Run the model over the validation split and write per-example predictions to CSV."""
train_ds, val_ds = load_split_paired(dataset_path)
model = load_checkpoint(ckpt_path, mode, train_ds)
    y_all, pred_reg_all, logits_all = [], [], []
    if mode == "pooled":
        loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True,
                            collate_fn=collate_pair_pooled)
for t, b, y in loader:
t = t.to(DEVICE, non_blocking=True)
b = b.to(DEVICE, non_blocking=True)
pred_reg, logits = model(t, b)
y_all.append(y.numpy())
pred_reg_all.append(pred_reg.detach().cpu().numpy())
logits_all.append(logits.detach().cpu().numpy())
else:
loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
num_workers=num_workers, pin_memory=True,
collate_fn=collate_pair_unpooled)
for T, Mt, B, Mb, y in loader:
T = T.to(DEVICE, non_blocking=True)
Mt = Mt.to(DEVICE, non_blocking=True)
B = B.to(DEVICE, non_blocking=True)
Mb = Mb.to(DEVICE, non_blocking=True)
pred_reg, logits = model(T, Mt, B, Mb)
y_all.append(y.numpy())
pred_reg_all.append(pred_reg.detach().cpu().numpy())
logits_all.append(logits.detach().cpu().numpy())
y_true = np.concatenate(y_all)
y_pred_reg = np.concatenate(pred_reg_all)
logits = np.concatenate(logits_all)
probs = softmax_np(logits) # (N,3)
y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers)
# Build CSV rows
out = Path(out_csv)
out.parent.mkdir(parents=True, exist_ok=True)
header = [
"split", "mode",
"y_true",
"y_pred_reg",
"p_high", "p_moderate", "p_low",
"y_pred_cls_score",
"center_high", "center_moderate", "center_low",
]
centers = list(class_centers)
rows = np.column_stack([
y_true,
y_pred_reg,
probs[:, 0], probs[:, 1], probs[:, 2],
y_pred_cls_score,
np.full_like(y_true, centers[0], dtype=np.float32),
np.full_like(y_true, centers[1], dtype=np.float32),
np.full_like(y_true, centers[2], dtype=np.float32),
])
with out.open("w") as f:
f.write(",".join(header) + "\n")
for i in range(rows.shape[0]):
f.write(
"val," + mode + "," +
",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) +
"\n"
)
print(f"[Data] Val N={len(y_true)} | mode={mode}")
print(f"[Saved] {out}")
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)")
ap.add_argument("--ckpt", required=True, help="Path to best_model.pt")
ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True)
ap.add_argument("--out_csv", required=True)
ap.add_argument("--batch_size", type=int, default=128)
ap.add_argument("--num_workers", type=int, default=4)
# Optional: choose class-centers for expected-score conversion
ap.add_argument("--center_high", type=float, default=9.5)
ap.add_argument("--center_moderate", type=float, default=8.0)
ap.add_argument("--center_low", type=float, default=6.0)
args = ap.parse_args()
export_val_preds_csv(
dataset_path=args.dataset_path,
ckpt_path=args.ckpt,
mode=args.mode,
out_csv=args.out_csv,
batch_size=args.batch_size,
num_workers=args.num_workers,
class_centers=(args.center_high, args.center_moderate, args.center_low),
)
if __name__ == "__main__":
main()
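# Example invocation (paths are hypothetical placeholders):
#   python export_val_preds_csv.py \
#       --dataset_path data/pair_pooled \
#       --ckpt runs/pooled/best_model.pt \
#       --mode pooled \
#       --out_csv preds/val_preds_pooled.csv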