#!/usr/bin/env python3 # export_val_preds_csv.py import argparse from pathlib import Path import numpy as np import torch import torch.nn as nn from torch.utils.data import DataLoader from datasets import load_from_disk, DatasetDict # ----------------------------- # Repro / device # ----------------------------- def seed_all(seed=1986): import random random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) seed_all(1986) DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # ----------------------------- # Load paired DatasetDict # ----------------------------- def load_split_paired(path: str): dd = load_from_disk(path) if not isinstance(dd, DatasetDict): raise ValueError(f"Expected DatasetDict at {path}") if "train" not in dd or "val" not in dd: raise ValueError(f"DatasetDict missing train/val at {path}") return dd["train"], dd["val"] # ----------------------------- # Collate fns (same as yours) # ----------------------------- def collate_pair_pooled(batch): Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32) return Pt, Pb, y def collate_pair_unpooled(batch): B = len(batch) Ht = len(batch[0]["target_embedding"][0]) Hb = len(batch[0]["binder_embedding"][0]) Lt_max = max(int(x["target_length"]) for x in batch) Lb_max = max(int(x["binder_length"]) for x in batch) Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32) Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32) Mt = torch.zeros(B, Lt_max, dtype=torch.bool) Mb = torch.zeros(B, Lb_max, dtype=torch.bool) y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32) for i, x in enumerate(batch): t = torch.tensor(x["target_embedding"], dtype=torch.float32) b = torch.tensor(x["binder_embedding"], dtype=torch.float32) lt, lb = t.shape[0], b.shape[0] Pt[i, :lt] = t Pb[i, :lb] = b Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool) Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool) return Pt, Mt, Pb, Mb, y # ----------------------------- # Models (same as yours) # ----------------------------- class CrossAttnPooled(nn.Module): def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): super().__init__() self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) self.layers = nn.ModuleList([]) for _ in range(n_layers): self.layers.append(nn.ModuleDict({ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden), "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden), "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), })) self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) self.reg = nn.Linear(hidden, 1) self.cls = nn.Linear(hidden, 3) def forward(self, t_vec, b_vec): t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H) b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H) for L in self.layers: t_attn, _ = L["attn_tb"](t, b, b) t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1) t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1) b_attn, _ = L["attn_bt"](b, t, t) b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1) b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1) z = torch.cat([t[0], b[0]], dim=-1) h = self.shared(z) return self.reg(h).squeeze(-1), self.cls(h) class CrossAttnUnpooled(nn.Module): def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): super().__init__() self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) self.layers = nn.ModuleList([]) for _ in range(n_layers): self.layers.append(nn.ModuleDict({ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden), "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden), "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), })) self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) self.reg = nn.Linear(hidden, 1) self.cls = nn.Linear(hidden, 3) def masked_mean(self, X, M): Mf = M.unsqueeze(-1).float() denom = Mf.sum(dim=1).clamp(min=1.0) return (X * Mf).sum(dim=1) / denom def forward(self, T, Mt, B, Mb): T = self.t_proj(T) Bx = self.b_proj(B) kp_t = ~Mt kp_b = ~Mb for L in self.layers: T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b) T = L["n1t"](T + T_attn) T = L["n2t"](T + L["fft"](T)) B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t) Bx = L["n1b"](Bx + B_attn) Bx = L["n2b"](Bx + L["ffb"](Bx)) t_pool = self.masked_mean(T, Mt) b_pool = self.masked_mean(Bx, Mb) z = torch.cat([t_pool, b_pool], dim=-1) h = self.shared(z) return self.reg(h).squeeze(-1), self.cls(h) # ----------------------------- # Helpers # ----------------------------- def softmax_np(logits: np.ndarray) -> np.ndarray: x = logits - logits.max(axis=1, keepdims=True) ex = np.exp(x) return ex / ex.sum(axis=1, keepdims=True) def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray: centers = np.asarray(class_centers, dtype=np.float32)[None, :] # (1,3) return (probs * centers).sum(axis=1) def load_checkpoint(ckpt_path: str, mode: str, train_ds): ckpt = torch.load(ckpt_path, map_location="cpu") params = ckpt.get("best_params", {}) hidden = int(params.get("hidden_dim", 512)) n_heads = int(params.get("n_heads", 8)) n_layers = int(params.get("n_layers", 3)) dropout = float(params.get("dropout", 0.1)) if mode == "pooled": Ht = len(train_ds[0]["target_embedding"]) Hb = len(train_ds[0]["binder_embedding"]) model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout) else: Ht = len(train_ds[0]["target_embedding"][0]) Hb = len(train_ds[0]["binder_embedding"][0]) model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout) model.load_state_dict(ckpt["state_dict"], strict=True) model.to(DEVICE).eval() return model @torch.no_grad() def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str, out_csv: str, batch_size: int, num_workers: int, class_centers=(9.5, 8.0, 6.0)): train_ds, val_ds = load_split_paired(dataset_path) model = load_checkpoint(ckpt_path, mode, train_ds) if mode == "pooled": loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_pair_pooled) y_all, pred_reg_all, logits_all = [], [], [] for t, b, y in loader: t = t.to(DEVICE, non_blocking=True) b = b.to(DEVICE, non_blocking=True) pred_reg, logits = model(t, b) y_all.append(y.numpy()) pred_reg_all.append(pred_reg.detach().cpu().numpy()) logits_all.append(logits.detach().cpu().numpy()) else: loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=collate_pair_unpooled) y_all, pred_reg_all, logits_all = [], [], [] for T, Mt, B, Mb, y in loader: T = T.to(DEVICE, non_blocking=True) Mt = Mt.to(DEVICE, non_blocking=True) B = B.to(DEVICE, non_blocking=True) Mb = Mb.to(DEVICE, non_blocking=True) pred_reg, logits = model(T, Mt, B, Mb) y_all.append(y.numpy()) pred_reg_all.append(pred_reg.detach().cpu().numpy()) logits_all.append(logits.detach().cpu().numpy()) y_true = np.concatenate(y_all) y_pred_reg = np.concatenate(pred_reg_all) logits = np.concatenate(logits_all) probs = softmax_np(logits) # (N,3) y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers) # Build CSV rows out = Path(out_csv) out.parent.mkdir(parents=True, exist_ok=True) header = [ "split", "mode", "y_true", "y_pred_reg", "p_high", "p_moderate", "p_low", "y_pred_cls_score", "center_high", "center_moderate", "center_low", ] centers = list(class_centers) rows = np.column_stack([ y_true, y_pred_reg, probs[:, 0], probs[:, 1], probs[:, 2], y_pred_cls_score, np.full_like(y_true, centers[0], dtype=np.float32), np.full_like(y_true, centers[1], dtype=np.float32), np.full_like(y_true, centers[2], dtype=np.float32), ]) with out.open("w") as f: f.write(",".join(header) + "\n") for i in range(rows.shape[0]): f.write( "val," + mode + "," + ",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) + "\n" ) print(f"[Data] Val N={len(y_true)} | mode={mode}") print(f"[Saved] {out}") def main(): ap = argparse.ArgumentParser() ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)") ap.add_argument("--ckpt", required=True, help="Path to best_model.pt") ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True) ap.add_argument("--out_csv", required=True) ap.add_argument("--batch_size", type=int, default=128) ap.add_argument("--num_workers", type=int, default=4) # Optional: choose class-centers for expected-score conversion ap.add_argument("--center_high", type=float, default=9.5) ap.add_argument("--center_moderate", type=float, default=8.0) ap.add_argument("--center_low", type=float, default=6.0) args = ap.parse_args() export_val_preds_csv( dataset_path=args.dataset_path, ckpt_path=args.ckpt, mode=args.mode, out_csv=args.out_csv, batch_size=args.batch_size, num_workers=args.num_workers, class_centers=(args.center_high, args.center_moderate, args.center_low), ) if __name__ == "__main__": main()