ratishsp's picture
Bidirectional ProGen2 LoRA adapter + 9-task benchmark + code
e6bc942
Raw
History Blame Contribute Delete
9.72 kB
#!/usr/bin/env python
"""Token-level (per-residue) evaluation for the bidirectional ProGen2 encoder.
The per-residue counterpart to eval_protein.py: the only evals that exercise the
encoder's CONTEXTUAL TOKEN representations (the proposal's token-level promise)
rather than a pooled vector. ProGen2 tokenizes one amino acid -> one token
(verified), so token i aligns to residue i with no special-token offset.
Tasks (--task):
  ss3       3-state secondary structure   AI4Protein/ssp_q3 (HF)               TAPE
  ptm       phosphosite PTM (binary)      PhosphositePTM.*.csv                 ProteinBERT
  disorder  intrinsic disorder (binary)   disorder_secondary_structure.*.csv   ProteinBERT
PTM and disorder are distributed only as CSVs in the ProteinBERT data repo
(github.com/Brandes-Lab/proteinbert_data_files), staged locally under --data-dir.
Their per-residue label is a contiguous digit string (one char per residue); SS3's
label is a numpy-printed array of space-separated ints.
Protocol: freeze encoder, take last-layer per-residue hidden states, fit a closed-form
one-hot ridge classifier over residues. Report per-residue accuracy; for the imbalanced
binary tasks (ptm, disorder) also report majority-class baseline and ROC-AUC (the honest
metric under class imbalance). Normal equations (A=X'X, B=X'Y) are accumulated
incrementally so residue features never all sit in memory at once.
Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations
import argparse
import os
import re
import sys
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.bidir_progen import make_bidirectional # noqa: E402
# Per-task configuration, keyed by the --task name (insertion order is the CLI
# choices order). Fields:
#   source       "hf" (Hugging Face dataset) or "csv" (ProteinBERT CSV under --data-dir)
#   dataset/base HF dataset id, or CSV filename stem ({base}.{split}.csv)
#   seq, label   column names holding the sequence and the per-residue labels
#   fmt          label encoding understood by parse_labels ("tokens" / "chars")
#   num_classes  width of the one-hot ridge target
#   binary       whether to additionally report majority baseline and ROC-AUC
TOKEN_TASKS = dict(
    ss3=dict(source="hf", dataset="AI4Protein/ssp_q3", seq="aa_seq", label="label",
             fmt="tokens", num_classes=3, binary=False),
    ptm=dict(source="csv", base="PhosphositePTM", seq="seq", label="label",
             fmt="chars", num_classes=2, binary=True),
    disorder=dict(source="csv", base="disorder_secondary_structure", seq="seq",
                  label="label", fmt="chars", num_classes=2, binary=True),
)
def parse_args():
    """Parse the command-line options for the per-residue probe run."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model / adapter.
    add("--model-name", default="hugohrban/progen2-base")
    add("--adapter", default=None, help="trained LoRA adapter dir (optional)")

    # Task and data.
    add("--task", default="ss3", choices=list(TOKEN_TASKS))
    add("--data-dir", default="./pbert_data",
        help="dir holding the ProteinBERT CSVs (for ptm/disorder)")
    add("--max-length", type=int, default=512)
    add("--batch-size", type=int, default=16)
    add("--max-train", type=int, default=6000, help="cap train seqs (0 = full)")
    add("--max-test", type=int, default=0, help="cap test seqs (0 = full)")

    # Ridge probe.
    add("--ridge-alpha", type=float, default=10.0)

    return parser.parse_args()
def parse_labels(s, fmt):
    """Decode one serialized per-residue label field into a list of ints.

    ``s`` is converted with ``str()`` first, so non-string values are accepted.

    * ``fmt == "chars"``: each digit character is one residue's label; every
      other character is skipped.
    * any other ``fmt`` (the "tokens" format): each optionally-signed integer
      substring is one label, e.g. a numpy-printed array ``"[0 1 2]"``.
    """
    text = str(s)
    if fmt != "chars":
        return list(map(int, re.findall(r"-?\d+", text)))
    return list(map(int, filter(str.isdigit, text)))
def load_split(cfg, split, cap, data_dir):
    """Return (list[str] seqs, list[list[int]] per-residue labels) for one split.

    Args:
        cfg: a TOKEN_TASKS entry; ``cfg["source"]`` selects the loader.
        split: split name (e.g. "train" / "test"). For HF tasks it is passed to
            ``load_dataset``; for CSV tasks it names ``{base}.{split}.csv``.
        cap: max number of sequences to keep (0 / None = all). HF splits are
            shuffled with seed 0 before capping; CSV splits keep the first
            ``cap`` rows in file order.
        data_dir: directory holding the CSVs (unused for HF tasks).

    Raises:
        FileNotFoundError: if a CSV split is not present under ``data_dir``.
    """
    if cfg["source"] == "hf":
        from datasets import load_dataset
        ds = load_dataset(cfg["dataset"], split=split)
        if cap and len(ds) > cap:
            ds = ds.shuffle(seed=0).select(range(cap))
        seqs = list(ds[cfg["seq"]])
        labs = [parse_labels(x, cfg["fmt"]) for x in ds[cfg["label"]]]
        return seqs, labs

    import csv
    path = os.path.join(data_dir, f"{cfg['base']}.{split}.csv")
    try:
        # newline="" is what the csv module requires so quoted fields containing
        # newlines parse correctly; the encoding is pinned so parsing does not
        # depend on the host locale.
        fh = open(path, newline="", encoding="utf-8")
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"{path} not found; stage the ProteinBERT CSVs under --data-dir") from err

    seqs, labs = [], []
    with fh:
        for i, row in enumerate(csv.DictReader(fh)):
            if cap and i >= cap:
                break
            seqs.append(row[cfg["seq"]])
            labs.append(parse_labels(row[cfg["label"]], cfg["fmt"]))
    return seqs, labs
def build_base(model_name, device):
    """Load the pretrained causal LM in bfloat16 with eager attention, apply the
    bidirectional attention patch, and return the model moved to ``device``."""
    from transformers import AutoModelForCausalLM
    import transformers.modeling_utils as modeling_utils

    # Compatibility shim: if PreTrainedModel does not itself define
    # `all_tied_weights_keys`, install an empty dict on the class before loading.
    # NOTE(review): presumably the trust_remote_code model code reads this
    # attribute; confirm against the pinned transformers version.
    pretrained_cls = modeling_utils.PreTrainedModel
    if "all_tied_weights_keys" not in vars(pretrained_cls):
        pretrained_cls.all_tied_weights_keys = {}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    patched = make_bidirectional(model)
    print(f"[build] bidirectional patch applied to {patched} attention modules", flush=True)
    return model.to(device)
@torch.no_grad()
def residue_features(model, tok, seqs, labels, device, max_length, batch_size):
    """Yield (feats (R,H) float64, labs (R,) long) per batch of sequences.

    Features are the model's last hidden layer. Token i is aligned to residue i,
    because no special tokens are added. Only non-padding positions are used.
    These positions are selected with the attention mask, so the result does not
    depend on the tokenizer's padding side.

    For each sequence, only the first min(#unpadded tokens, #labels) residues are
    kept. That covers sequences truncated to ``max_length`` and label strings
    shorter than the sequence. Sequences with nothing to keep are skipped.
    A batch that yields no residues produces nothing.
    """
    for start in range(0, len(seqs), batch_size):
        batch_seqs = seqs[start:start + batch_size]
        batch_labels = labels[start:start + batch_size]
        enc = tok(batch_seqs, padding=True, truncation=True, max_length=max_length,
                  add_special_tokens=False, return_tensors="pt").to(device)
        mask = enc["attention_mask"]
        h = model(input_ids=enc["input_ids"], attention_mask=mask,
                  output_hidden_states=True).hidden_states[-1]  # (B, T, H)
        feats, labs = [], []
        for b in range(h.size(0)):
            lab = batch_labels[b]
            # Boolean-mask the real positions instead of slicing h[b, :n], which
            # would pick pad positions if the tokenizer pads on the left.
            tokens = h[b][mask[b].bool()]
            n = min(tokens.size(0), len(lab))
            if n <= 0:
                continue
            feats.append(tokens[:n].float().cpu())
            labs.append(torch.tensor(lab[:n], dtype=torch.long))
        if feats:
            yield torch.cat(feats, 0).double(), torch.cat(labs, 0)
def roc_auc(scores, labels):
    """Binary ROC-AUC via the rank-sum (Mann-Whitney U) statistic; no sklearn dependency.

    Tied scores receive their average (mid) rank. As a result, each tied
    positive/negative pair contributes 1/2, which is the standard AUC convention.
    Assigning arbitrary distinct ranks to ties would bias the result.

    Args:
        scores: 1-D array-like of real-valued scores (higher = more positive).
        labels: 1-D array-like of 0/1 labels aligned with ``scores``.

    Returns:
        The AUC as a float, or NaN when either class is absent.
    """
    s = torch.as_tensor(scores, dtype=torch.float64)
    y = torch.as_tensor(labels, dtype=torch.long)
    n_pos = int((y == 1).sum())
    n_neg = int((y == 0).sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    sorted_s, order = torch.sort(s)
    # Group equal consecutive sorted scores; each group of size c ending at
    # 1-based rank e occupies ranks e-c+1..e, whose mean is e - (c-1)/2.
    _, group, counts = torch.unique_consecutive(sorted_s, return_inverse=True,
                                                return_counts=True)
    counts = counts.to(torch.float64)
    avg_rank = counts.cumsum(0) - (counts - 1) / 2
    ranks = torch.empty_like(s)
    ranks[order] = avg_rank[group]
    sum_pos = float(ranks[y == 1].sum())
    return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
def _with_bias(feats):
    """Append a constant-1 column so the ridge solution includes an intercept."""
    ones = torch.ones(feats.size(0), 1, dtype=torch.float64)
    return torch.cat([feats, ones], 1)


def _labelled_only(feats, labs, num_classes):
    """Keep residues whose label is in [0, num_classes).

    Negative labels are treated as "no label" and dropped. Without this, they
    would silently one-hot into the last class through negative indexing.
    A label >= num_classes signals a config/data mismatch and raises ValueError.
    """
    too_big = labs >= num_classes
    if bool(too_big.any()):
        raise ValueError(f"label {int(labs[too_big].max())} out of range "
                         f"for num_classes={num_classes}")
    keep = labs >= 0
    return feats[keep], labs[keep]


def evaluate(tag, model, tok, cfg, tr, te, device, args):
    """Fit a closed-form one-hot ridge probe on ``tr`` residues and score it on ``te``.

    The normal equations A = X'X and B = X'Y are accumulated batch by batch, so
    the residue features are never all held in memory. The ridge penalty
    ``args.ridge_alpha`` is applied to every column of A, including the bias.

    Args:
        tag: label printed at the start of the result line.
        model: encoder whose last hidden layer provides residue features.
        tok: tokenizer matching ``model``.
        cfg: a TOKEN_TASKS entry (uses ``num_classes`` and ``binary``).
        tr, te: (seqs, per-residue labels) pairs as returned by load_split.
        device: device the model runs on.
        args: parsed CLI args (task, max_length, batch_size, ridge_alpha).

    Returns:
        (acc, auc): per-residue test accuracy, and ROC-AUC for binary tasks
        (None for non-binary tasks; NaN if the test residues lack a class).

    Raises:
        ValueError: if no labelled training residues are produced, or a label
            is >= num_classes.
    """
    model.eval()
    C = cfg["num_classes"]
    A = B = None
    n_res = 0
    for feats, labs in residue_features(model, tok, tr[0], tr[1], device,
                                        args.max_length, args.batch_size):
        feats, labs = _labelled_only(feats, labs, C)
        if labs.numel() == 0:
            continue
        X = _with_bias(feats)
        Y = torch.zeros(X.size(0), C, dtype=torch.float64)
        Y[torch.arange(X.size(0)), labs] = 1.0
        if A is None:
            D = X.size(1)
            A = torch.zeros(D, D, dtype=torch.float64)
            B = torch.zeros(D, C, dtype=torch.float64)
        A += X.t() @ X
        B += X.t() @ Y
        n_res += X.size(0)
    if A is None:
        raise ValueError(f"[{tag}] no labelled training residues for task {args.task}")
    W = torch.linalg.solve(A + args.ridge_alpha * torch.eye(A.size(0), dtype=torch.float64), B)

    correct = total = 0
    cls_count = torch.zeros(C, dtype=torch.long)
    scores, ys = [], []
    for feats, labs in residue_features(model, tok, te[0], te[1], device,
                                        args.max_length, args.batch_size):
        feats, labs = _labelled_only(feats, labs, C)
        if labs.numel() == 0:
            continue
        logits = _with_bias(feats) @ W
        pred = logits.argmax(1)
        correct += int((pred == labs).sum())
        total += labs.numel()
        cls_count += torch.bincount(labs, minlength=C)
        if cfg["binary"]:
            # Margin of class 1 over class 0 is the ranking score for AUC.
            scores.append(logits[:, 1] - logits[:, 0])
            ys.append(labs)

    acc = correct / max(total, 1)
    auc = None
    if cfg["binary"]:
        # torch.cat([]) would raise on an empty test set; report NaN instead.
        auc = roc_auc(torch.cat(scores), torch.cat(ys)) if scores else float("nan")
    msg = f"[{tag}] {args.task} acc={acc:.4f} (train_res={n_res}, eval_res={total})"
    if cfg["binary"]:
        majority = float(cls_count.max()) / max(total, 1)
        pos_frac = float(cls_count[1]) / max(total, 1)
        msg += f" | majority={majority:.4f} pos_frac={pos_frac:.4f} AUC={auc:.4f}"
    print(msg, flush=True)
    return acc, auc
def main():
    """Run the frozen-encoder probe (and optionally the LoRA-adapted encoder) and
    print a one-line summary of the headline metric: AUC for binary tasks,
    accuracy otherwise."""
    args = parse_args()
    from transformers import AutoTokenizer

    cfg = TOKEN_TASKS[args.task]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tok.pad_token_id is None:
        # padding=True needs a pad token: use EOS if defined, else the token with id 0.
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)

    train_split = load_split(cfg, "train", args.max_train, args.data_dir)
    test_split = load_split(cfg, "test", args.max_test, args.data_dir)
    print(f"[data] {args.task} train_seqs={len(train_split[0])} test_seqs={len(test_split[0])} "
          f"classes={cfg['num_classes']}", flush=True)

    base = build_base(args.model_name, device)
    base_acc, base_auc = evaluate("baseline (frozen, bidir)", base, tok, cfg,
                                  train_split, test_split, device, args)

    trained_acc = trained_auc = None
    if args.adapter:
        from peft import PeftModel
        adapted = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained_acc, trained_auc = evaluate("trained (adapter)", adapted, tok, cfg,
                                            train_split, test_split, device, args)

    print(f"\n=== SUMMARY ({args.task}) ===", flush=True)
    use_auc = cfg["binary"]
    metric = "AUC" if use_auc else "acc"
    before = base_auc if use_auc else base_acc
    after = trained_auc if use_auc else trained_acc
    parts = [f" {args.task:20s} [{metric}] baseline={before:.4f}"]
    if after is not None:
        parts.append(f" trained={after:.4f} delta={after - before:+.4f}")
    print("".join(parts), flush=True)


if __name__ == "__main__":
    main()