"""Export validation-set predictions from a trained cross-attention binder model.

Loads a paired DatasetDict (pooled or per-residue embeddings), rebuilds the
model from a checkpoint, runs inference on the validation split, and writes
regression and classification predictions to a CSV file.
"""

import argparse
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_from_disk, DatasetDict


def seed_all(seed=1986):
    """Seed the Python, NumPy, and Torch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


seed_all(1986)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
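
# Note: the seeding above covers the Python, NumPy, and Torch RNGs only; fully
# deterministic GPU runs may additionally require
# torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False.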


def load_split_paired(path: str):
    dd = load_from_disk(path)
    if not isinstance(dd, DatasetDict):
        raise ValueError(f"Expected DatasetDict at {path}")
    if "train" not in dd or "val" not in dd:
        raise ValueError(f"DatasetDict missing train/val at {path}")
    return dd["train"], dd["val"]


def collate_pair_pooled(batch):
    """Stack pooled (one-vector-per-sequence) embeddings into (B, H) tensors."""
    Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
    Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
    return Pt, Pb, y


def collate_pair_unpooled(batch):
    """Pad per-residue embeddings to the batch max length and build boolean validity masks."""
    B = len(batch)
    Ht = len(batch[0]["target_embedding"][0])
    Hb = len(batch[0]["binder_embedding"][0])
    Lt_max = max(int(x["target_length"]) for x in batch)
    Lb_max = max(int(x["binder_length"]) for x in batch)

    Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
    Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
    Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
    Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
    y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)

    for i, x in enumerate(batch):
        t = torch.tensor(x["target_embedding"], dtype=torch.float32)
        b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
        lt, lb = t.shape[0], b.shape[0]
        Pt[i, :lt] = t
        Pb[i, :lb] = b
        Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
        Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)

    return Pt, Mt, Pb, Mb, y
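
# Assumed per-row schema for the collate functions (field names taken from this file):
#   target_embedding / binder_embedding : pooled mode -> [H] floats; unpooled mode -> [L][H] floats
#   target_attention_mask / binder_attention_mask : [L] 0/1 validity flags (unpooled mode)
#   target_length / binder_length : int sequence lengths (unpooled mode)
#   label : float label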


class CrossAttnPooled(nn.Module):
    """Bidirectional cross-attention between two pooled embeddings.

    Each pooled vector is treated as a length-1 sequence so target and binder
    can cross-attend; a shared trunk feeds a regression head and a 3-class head.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        # batch_first=False expects (seq, batch, hidden); each pooled vector
        # becomes a length-1 sequence of shape (1, B, hidden).
        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)

        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            # LayerNorm normalizes the last dim, so no transposes are needed here.
            t = L["n1t"](t + t_attn)
            t = L["n2t"](t + L["fft"](t))

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"](b + b_attn)
            b = L["n2b"](b + L["ffb"](b))

        z = torch.cat([t[0], b[0]], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
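
# Pooled usage sketch (illustrative dimensions, not from a real dataset):
#   model = CrossAttnPooled(Ht=1280, Hb=960)
#   reg, logits = model(torch.randn(4, 1280), torch.randn(4, 960))
#   # reg: shape (4,), logits: shape (4, 3)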


class CrossAttnUnpooled(nn.Module):
    """Bidirectional cross-attention over per-residue embeddings with padding masks."""

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def masked_mean(self, X, M):
        """Mean over the sequence dim, counting only positions where M is True."""
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T)
        Bx = self.b_proj(B)

        # key_padding_mask expects True at *padding* positions, so invert the validity masks.
        kp_t = ~Mt
        kp_b = ~Mb

        for L in self.layers:
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        t_pool = self.masked_mean(T, Mt)
        b_pool = self.masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
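
# Unpooled usage sketch (illustrative dimensions, not from a real dataset):
#   model = CrossAttnUnpooled(Ht=1280, Hb=960)
#   T, Mt = torch.randn(2, 50, 1280), torch.ones(2, 50, dtype=torch.bool)
#   B, Mb = torch.randn(2, 30, 960), torch.ones(2, 30, dtype=torch.bool)
#   reg, logits = model(T, Mt, B, Mb)   # reg: shape (2,), logits: shape (2, 3)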


def softmax_np(logits: np.ndarray) -> np.ndarray:
    """Numerically stable row-wise softmax."""
    x = logits - logits.max(axis=1, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=1, keepdims=True)


def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray:
    """Collapse class probabilities to a scalar: the probability-weighted mean of the class centers."""
    centers = np.asarray(class_centers, dtype=np.float32)[None, :]
    return (probs * centers).sum(axis=1)
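
# Worked example: with the default class_centers (9.5, 8.0, 6.0) and one row of
# probabilities [0.7, 0.2, 0.1], the expected score is
# 0.7*9.5 + 0.2*8.0 + 0.1*6.0 = 6.65 + 1.60 + 0.60 = 8.85.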


def load_checkpoint(ckpt_path: str, mode: str, train_ds):
    """Rebuild the model from the checkpoint's hyperparameters and load its weights."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    params = ckpt.get("best_params", {})

    hidden = int(params.get("hidden_dim", 512))
    n_heads = int(params.get("n_heads", 8))
    n_layers = int(params.get("n_layers", 3))
    dropout = float(params.get("dropout", 0.1))

    if mode == "pooled":
        # Pooled embeddings are 1-D per example, so the input width is the vector length.
        Ht = len(train_ds[0]["target_embedding"])
        Hb = len(train_ds[0]["binder_embedding"])
        model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
    else:
        # Unpooled embeddings are (length, hidden) per example.
        Ht = len(train_ds[0]["target_embedding"][0])
        Hb = len(train_ds[0]["binder_embedding"][0])
        model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)

    model.load_state_dict(ckpt["state_dict"], strict=True)
    model.to(DEVICE).eval()
    return model
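
# Expected checkpoint layout (inferred from the loading code above): a dict with
# "state_dict" holding the model weights and, optionally, "best_params" with
# hidden_dim / n_heads / n_layers / dropout; missing entries fall back to defaults.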


@torch.no_grad()
def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str,
                         out_csv: str, batch_size: int, num_workers: int,
                         class_centers=(9.5, 8.0, 6.0)):
    """Run the checkpointed model over the validation split and write predictions to CSV."""
    train_ds, val_ds = load_split_paired(dataset_path)
    model = load_checkpoint(ckpt_path, mode, train_ds)

    y_all, pred_reg_all, logits_all = [], [], []
    if mode == "pooled":
        loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True,
                            collate_fn=collate_pair_pooled)
        for t, b, y in loader:
            t = t.to(DEVICE, non_blocking=True)
            b = b.to(DEVICE, non_blocking=True)
            pred_reg, logits = model(t, b)
            y_all.append(y.numpy())
            pred_reg_all.append(pred_reg.detach().cpu().numpy())
            logits_all.append(logits.detach().cpu().numpy())
    else:
        loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True,
                            collate_fn=collate_pair_unpooled)
        for T, Mt, B, Mb, y in loader:
            T = T.to(DEVICE, non_blocking=True)
            Mt = Mt.to(DEVICE, non_blocking=True)
            B = B.to(DEVICE, non_blocking=True)
            Mb = Mb.to(DEVICE, non_blocking=True)
            pred_reg, logits = model(T, Mt, B, Mb)
            y_all.append(y.numpy())
            pred_reg_all.append(pred_reg.detach().cpu().numpy())
            logits_all.append(logits.detach().cpu().numpy())

    y_true = np.concatenate(y_all)
    y_pred_reg = np.concatenate(pred_reg_all)
    logits = np.concatenate(logits_all)

    probs = softmax_np(logits)
    y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers)

    out = Path(out_csv)
    out.parent.mkdir(parents=True, exist_ok=True)

    header = [
        "split", "mode",
        "y_true",
        "y_pred_reg",
        "p_high", "p_moderate", "p_low",
        "y_pred_cls_score",
        "center_high", "center_moderate", "center_low",
    ]

    centers = list(class_centers)
    rows = np.column_stack([
        y_true,
        y_pred_reg,
        probs[:, 0], probs[:, 1], probs[:, 2],
        y_pred_cls_score,
        np.full_like(y_true, centers[0], dtype=np.float32),
        np.full_like(y_true, centers[1], dtype=np.float32),
        np.full_like(y_true, centers[2], dtype=np.float32),
    ])

    with out.open("w") as f:
        f.write(",".join(header) + "\n")
        for i in range(rows.shape[0]):
            f.write(
                "val," + mode + "," +
                ",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) +
                "\n"
            )

    print(f"[Data] Val N={len(y_true)} | mode={mode}")
    print(f"[Saved] {out}")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)")
    ap.add_argument("--ckpt", required=True, help="Path to best_model.pt")
    ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True)
    ap.add_argument("--out_csv", required=True)
    ap.add_argument("--batch_size", type=int, default=128)
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--center_high", type=float, default=9.5)
    ap.add_argument("--center_moderate", type=float, default=8.0)
    ap.add_argument("--center_low", type=float, default=6.0)
    args = ap.parse_args()

    export_val_preds_csv(
        dataset_path=args.dataset_path,
        ckpt_path=args.ckpt,
        mode=args.mode,
        out_csv=args.out_csv,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        class_centers=(args.center_high, args.center_moderate, args.center_low),
    )


if __name__ == "__main__":
    main()
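
# Example invocation (script name and paths are hypothetical):
#   python export_val_preds.py \
#       --dataset_path data/pair_pooled \
#       --ckpt runs/pooled/best_model.pt \
#       --mode pooled \
#       --out_csv preds/val_pooled_preds.csv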