#!/usr/bin/env python
"""Token-level (per-residue) evaluation for the bidirectional ProGen2 encoder.

The per-residue counterpart to eval_protein.py: the only evals that exercise the
encoder's CONTEXTUAL TOKEN representations (the proposal's token-level promise)
rather than a pooled vector. ProGen2 tokenizes one amino acid -> one token
(verified), so token i aligns to residue i with no special-token offset.

Tasks (--task):
  ss3       3-state secondary structure   AI4Protein/ssp_q3 (HF)           TAPE
  ptm       phosphosite PTM (binary)      PhosphositePTM.*.csv             ProteinBERT
  disorder  intrinsic disorder (binary)   disorder_secondary_structure.*   ProteinBERT

PTM and disorder are distributed only as CSVs in the ProteinBERT data repo
(github.com/Brandes-Lab/proteinbert_data_files), staged locally under --data-dir.
Their per-residue label is a contiguous digit string (one char per residue);
SS3's label is a numpy-printed array of space-separated ints.

Protocol: freeze encoder, take last-layer per-residue hidden states, fit a
closed-form one-hot ridge classifier over residues. Report per-residue accuracy;
for the imbalanced binary tasks (ptm, disorder) also report majority-class
baseline and ROC-AUC (the honest metric under class imbalance). Normal equations
(A=X'X, B=X'Y) are accumulated incrementally so residue features never all sit
in memory at once.

Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations

import argparse
import os
import re
import sys

import torch

# Make the project-local ``src`` package importable when this file is run as a
# script from another working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Per-task configuration.
#   source       "hf" (Hugging Face dataset) or "csv" (local ProteinBERT CSV)
#   dataset/base HF dataset id, or CSV base name ("<base>.<split>.csv")
#   seq/label    column names holding the sequence and the per-residue labels
#   fmt          label encoding understood by parse_labels ("tokens" / "chars")
#   num_classes  number of target classes for the one-hot ridge probe
#   binary       whether majority baseline and ROC-AUC are also reported
TOKEN_TASKS = {
    "ss3": {"source": "hf", "dataset": "AI4Protein/ssp_q3", "seq": "aa_seq",
            "label": "label", "fmt": "tokens", "num_classes": 3, "binary": False},
    "ptm": {"source": "csv", "base": "PhosphositePTM", "seq": "seq",
            "label": "label", "fmt": "chars", "num_classes": 2, "binary": True},
    "disorder": {"source": "csv", "base": "disorder_secondary_structure", "seq": "seq",
                 "label": "label", "fmt": "chars", "num_classes": 2, "binary": True},
}


def parse_args():
    """Parse the command-line options and return the argparse namespace."""
    p = argparse.ArgumentParser()
    p.add_argument("--model-name", default="hugohrban/progen2-base")
    p.add_argument("--adapter", default=None, help="trained LoRA adapter dir (optional)")
    p.add_argument("--task", default="ss3", choices=list(TOKEN_TASKS))
    p.add_argument("--data-dir", default="./pbert_data",
                   help="dir holding the ProteinBERT CSVs (for ptm/disorder)")
    p.add_argument("--max-length", type=int, default=512)
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--max-train", type=int, default=6000, help="cap train seqs (0 = full)")
    p.add_argument("--max-test", type=int, default=0, help="cap test seqs (0 = full)")
    p.add_argument("--ridge-alpha", type=float, default=10.0)
    return p.parse_args()


def parse_labels(s, fmt):
    """Decode one per-residue label field into a list of ints.

    ``fmt == "chars"``: every digit character of ``str(s)`` is one label; any
    non-digit character is dropped.
    NOTE(review): dropping a non-digit character shifts every later label by
    one residue -- confirm the CSV label strings contain digits only.

    Any other ``fmt`` ("tokens"): every (optionally negative) integer found in
    ``str(s)`` is one label, which copes with numpy-printed arrays.
    """
    if fmt == "chars":
        return [int(c) for c in str(s).strip() if c.isdigit()]
    return [int(x) for x in re.findall(r"-?\d+", str(s))]  # "tokens"


def load_split(cfg, split, cap, data_dir):
    """Return (list[str] seqs, list[list[int]] per-residue labels).

    ``cfg`` is one entry of TOKEN_TASKS. For the "hf" source the split is
    loaded with ``datasets.load_dataset`` and, when ``cap`` is non-zero and
    smaller than the split, a seed-0 shuffled subset of ``cap`` rows is kept.
    For the CSV source the file ``<data_dir>/<base>.<split>.csv`` is read and
    only the first ``cap`` rows are kept (``cap == 0`` keeps everything).
    """
    if cfg["source"] == "hf":
        from datasets import load_dataset

        ds = load_dataset(cfg["dataset"], split=split)
        if cap and len(ds) > cap:
            ds = ds.shuffle(seed=0).select(range(cap))
        seqs = list(ds[cfg["seq"]])
        labs = [parse_labels(x, cfg["fmt"]) for x in ds[cfg["label"]]]
    else:
        import csv as _csv

        path = os.path.join(data_dir, f"{cfg['base']}.{split}.csv")
        seqs, labs = [], []
        # newline="" lets the csv module handle line endings itself, as its
        # documentation requires; the encoding is pinned so parsing does not
        # depend on the node's locale.
        with open(path, newline="", encoding="utf-8") as fh:
            r = _csv.DictReader(fh)
            for i, row in enumerate(r):
                if cap and i >= cap:
                    break
                seqs.append(row[cfg["seq"]])
                labs.append(parse_labels(row[cfg["label"]], cfg["fmt"]))
    return seqs, labs


def build_base(model_name, device):
    """Load the causal LM in bfloat16, patch it to bidirectional attention and
    move it to ``device``."""
    from transformers import AutoModelForCausalLM
    import transformers.modeling_utils as _mu

    # Project-local helper, imported lazily like the other heavy dependencies
    # so the pure helpers in this module stay importable without it.
    from src.bidir_progen import make_bidirectional

    # Give PreTrainedModel an ``all_tied_weights_keys`` attribute when the
    # installed transformers does not define one.
    # NOTE(review): presumably needed by the remote model code -- confirm.
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    n = make_bidirectional(base)
    print(f"[build] bidirectional patch applied to {n} attention modules", flush=True)
    return base.to(device)


@torch.no_grad()
def residue_features(model, tok, seqs, labels, device, max_length, batch_size):
    """Yield (feats (R,H) float64, labs (R,) long) per batch — labelled residues
    only, aligned token i <-> residue i (no special tokens added).

    For each sequence the first ``n`` positions are used, where ``n`` is the
    smaller of its attention-mask length and its label count; sequences with
    ``n <= 0`` are skipped. Slicing from position 0 is only correct when the
    tokenizer pads on the right.
    """
    for i in range(0, len(seqs), batch_size):
        chunk_s = seqs[i:i + batch_size]
        chunk_l = labels[i:i + batch_size]
        enc = tok(chunk_s, padding=True, truncation=True, max_length=max_length,
                  add_special_tokens=False, return_tensors="pt").to(device)
        h = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                  output_hidden_states=True).hidden_states[-1]  # (B, T, H)
        # One device->host transfer per batch instead of one per sequence.
        lengths = enc["attention_mask"].sum(dim=1).tolist()
        feats, labs = [], []
        for b, n_tok in enumerate(lengths):
            n = min(int(n_tok), len(chunk_l[b]))
            if n <= 0:
                continue
            feats.append(h[b, :n].float().cpu())
            labs.append(torch.tensor(chunk_l[b][:n], dtype=torch.long))
        if feats:
            yield torch.cat(feats, 0).double(), torch.cat(labs, 0)


def roc_auc(scores, labels):
    """Rank-based (Mann-Whitney U) AUC for binary labels; no sklearn dependency.

    Positives are labels equal to 1, negatives are labels equal to 0. Tied
    scores share the average of the ranks they span, so the result does not
    depend on how the sort happens to order ties. Returns NaN when either
    class is absent.
    """
    s = torch.as_tensor(scores, dtype=torch.float64)
    y = torch.as_tensor(labels, dtype=torch.long)
    n_pos = int((y == 1).sum())
    n_neg = int((y == 0).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # The group of equal scores number g occupies 1-based ranks
    # (last[g] - count[g] + 1) .. last[g]; its mid-rank is their mean.
    _, inverse, counts = torch.unique(s, sorted=True, return_inverse=True,
                                      return_counts=True)
    counts = counts.to(torch.float64)
    mid_rank = counts.cumsum(0) - (counts - 1.0) / 2.0
    ranks = mid_rank[inverse]
    sum_pos = float(ranks[y == 1].sum())
    return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)


def _append_bias(feats):
    """Return ``feats`` with a trailing all-ones (intercept) column, float64."""
    return torch.cat([feats, torch.ones(feats.size(0), 1, dtype=torch.float64)], 1)


def _labelled_only(feats, labs, num_classes):
    """Drop residues whose label is negative (treated as unlabelled).

    A negative label would otherwise index the one-hot target from the end
    (silently training on the wrong class) and make ``torch.bincount`` fail.
    Labels >= ``num_classes`` indicate a task/config mismatch and raise.
    """
    if labs.numel() and int(labs.max()) >= num_classes:
        raise ValueError(
            f"label {int(labs.max())} out of range for {num_classes} classes")
    keep = labs >= 0
    if bool(keep.all()):
        return feats, labs
    return feats[keep], labs[keep]


def evaluate(tag, model, tok, cfg, tr, te, device, args):
    """Fit the ridge probe on ``tr`` and score it on ``te``.

    ``tr`` / ``te`` are (seqs, labels) pairs as returned by load_split. The
    normal equations A = X'X, B = X'Y are accumulated batch by batch over the
    bias-augmented features and solved with ``args.ridge_alpha`` added to the
    whole diagonal (so the intercept is regularized too). Prints one result
    line prefixed with ``tag`` and returns ``(acc, auc)``; ``auc`` is None for
    non-binary tasks and NaN when it cannot be computed.

    Raises:
        ValueError: if the training split yields no labelled residues, or a
            label is >= ``cfg["num_classes"]``.
    """
    model.eval()
    C = cfg["num_classes"]
    A = B = None
    n_res = 0
    for feats, labs in residue_features(model, tok, tr[0], tr[1], device,
                                        args.max_length, args.batch_size):
        feats, labs = _labelled_only(feats, labs, C)
        if labs.numel() == 0:
            continue
        if A is None:
            D = feats.size(1) + 1  # hidden size + intercept column
            A = torch.zeros(D, D, dtype=torch.float64)
            B = torch.zeros(D, C, dtype=torch.float64)
        X = _append_bias(feats)
        Y = torch.zeros(feats.size(0), C, dtype=torch.float64)
        Y[torch.arange(feats.size(0)), labs] = 1.0
        A += X.t() @ X
        B += X.t() @ Y
        n_res += feats.size(0)
    if A is None:
        raise ValueError(
            f"no labelled training residues for task {args.task!r}; cannot fit the probe")
    W = torch.linalg.solve(
        A + args.ridge_alpha * torch.eye(A.size(0), dtype=torch.float64), B)

    correct = total = 0
    cls_count = torch.zeros(C, dtype=torch.long)
    scores, ys = [], []
    for feats, labs in residue_features(model, tok, te[0], te[1], device,
                                        args.max_length, args.batch_size):
        feats, labs = _labelled_only(feats, labs, C)
        if labs.numel() == 0:
            continue
        logits = _append_bias(feats) @ W
        pred = logits.argmax(1)
        correct += int((pred == labs).sum())
        total += labs.numel()
        cls_count += torch.bincount(labs, minlength=C)
        if cfg["binary"]:
            # Margin of class 1 over class 0 is the ranking score for AUC.
            scores.append((logits[:, 1] - logits[:, 0]).cpu())
            ys.append(labs.cpu())

    acc = correct / max(total, 1)
    auc = None
    if cfg["binary"]:
        # torch.cat rejects an empty list, so an empty test split gives NaN.
        auc = roc_auc(torch.cat(scores), torch.cat(ys)) if scores else float("nan")
    msg = f"[{tag}] {args.task} acc={acc:.4f} (train_res={n_res}, eval_res={total})"
    if cfg["binary"]:
        majority = float(cls_count.max()) / max(total, 1)
        pos_frac = float(cls_count[1]) / max(total, 1)
        msg += f" | majority={majority:.4f} pos_frac={pos_frac:.4f} AUC={auc:.4f}"
    print(msg, flush=True)
    return acc, auc


def main():
    """Evaluate the frozen bidirectional base model and, when --adapter is
    given, the adapter-augmented model on the selected task, then print a
    summary line."""
    args = parse_args()
    from transformers import AutoTokenizer

    cfg = TOKEN_TASKS[args.task]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)
    # residue_features keeps positions [0, n) of every row, which is only the
    # real residues when padding and truncation both act on the right.
    tok.padding_side = "right"
    tok.truncation_side = "right"

    tr = load_split(cfg, "train", args.max_train, args.data_dir)
    te = load_split(cfg, "test", args.max_test, args.data_dir)
    print(f"[data] {args.task} train_seqs={len(tr[0])} test_seqs={len(te[0])} "
          f"classes={cfg['num_classes']}", flush=True)

    base = build_base(args.model_name, device)
    base_acc, base_auc = evaluate("baseline (frozen, bidir)", base, tok, cfg, tr, te,
                                  device, args)

    trained_acc = trained_auc = None
    if args.adapter:
        from peft import PeftModel

        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained_acc, trained_auc = evaluate("trained (adapter)", trained, tok, cfg,
                                            tr, te, device, args)

    print(f"\n=== SUMMARY ({args.task}) ===", flush=True)
    # Headline metric: AUC for the binary tasks, accuracy otherwise.
    metric, b, t = (("AUC", base_auc, trained_auc) if cfg["binary"]
                    else ("acc", base_acc, trained_acc))
    line = f" {args.task:20s} [{metric}] baseline={b:.4f}"
    if t is not None:
        line += f" trained={t:.4f} delta={t - b:+.4f}"
    print(line, flush=True)


if __name__ == "__main__":
    main()