ratishsp's picture
Bidirectional ProGen2 LoRA adapter + 9-task benchmark + code
e6bc942
Raw
History Blame Contribute Delete
10 kB
#!/usr/bin/env python
"""TAPE downstream evaluation for the bidirectional ProGen2 encoder.
Protein analog of eval_sts.py: freeze the encoder, mean-pool a fixed embedding
per sequence, fit a linear (ridge) probe on the train split, and report Spearman
on the test split — the standard frozen-features protocol for judging embedding
quality. Reports the FROZEN baseline (bidirectional ProGen2, no adaptation) vs the
TRAINED adapter, so we can see whether MNTP/SimCSE actually improved the encoder.
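Multi-task: --tasks takes a comma-separated list (default: stability,fluorescence,homology)
so the verdict can be checked across more than one downstream task. --max-train /
--max-test default to 0, which uses the FULL splits; a positive value caps a split to a
seeded random subset of that size.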
Two probe kinds (both frozen mean-pooled embedding + closed-form linear probe):
reg -> ridge regression, reported as Spearman ρ
clf -> one-hot ridge classifier (argmax), reported as top-1 accuracy
Tasks (sequence-level; TAPE = Rao 2019, ProteinBERT = Brandes 2022):
stability -> AI4Protein/TAPE_Stability reg TAPE
fluorescence -> AI4Protein/TAPE_Fluorescence reg TAPE
homology -> GleghornLab/bom_remote_homology clf TAPE (remote homology)
fold -> GleghornLab/fold_prediction clf TAPE (fold classification)
signalpeptide -> GrimSqueaker/SignalP_Binary clf ProteinBERT
neuropeptide -> GrimSqueaker/ProFET_NP_SP_Cleaved clf ProteinBERT (cleaved precursor)
(Token-level PTM and disorder live in eval_token.py — they are per-residue and only
distributed as CSVs in the ProteinBERT data repo, not on HF.)
Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations
import argparse
import os
import sys
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.bidir_progen import make_bidirectional, mean_pool # noqa: E402
# Downstream tasks: Hugging Face dataset id, probe kind ("reg" -> ridge
# regression scored by Spearman, "clf" -> one-hot ridge classifier scored by
# top-1 accuracy) and, optionally, the name of the evaluation split
# (main() uses "test" when it is absent).
TASKS = dict(
    stability=dict(dataset="AI4Protein/TAPE_Stability", kind="reg"),
    fluorescence=dict(dataset="AI4Protein/TAPE_Fluorescence", kind="reg"),
    homology=dict(dataset="GleghornLab/bom_remote_homology", kind="clf",
                  test_split="test"),
    fold=dict(dataset="GleghornLab/fold_prediction", kind="clf",
              test_split="test"),
    # ProteinBERT sequence-level benchmarks (Brandes 2022), via HF mirrors:
    signalpeptide=dict(dataset="GrimSqueaker/SignalP_Binary", kind="clf",
                       test_split="test"),
    neuropeptide=dict(dataset="GrimSqueaker/ProFET_NP_SP_Cleaved", kind="clf",
                      test_split="test"),
)

# Candidate column names, tried in this order when locating the sequence and
# label columns of a dataset split.
SEQ_COLS = tuple("aa_seq seqs seq sequence primary".split())
LABEL_COLS = tuple("label labels target log_fluorescence stability_score".split())
def parse_args():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Options: base model name, optional LoRA adapter directory, comma-separated
    task list, tokenizer max length, encoding batch size, train/test caps
    (0 means the full split) and the ridge regularisation strength.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--model-name", default="hugohrban/progen2-base")
    add("--adapter", default=None, help="trained LoRA adapter dir (optional)")
    add("--tasks", default="stability,fluorescence,homology",
        help="comma-separated subset of: " + ",".join(TASKS))
    for flag, int_default in (("--max-length", 512), ("--batch-size", 32)):
        add(flag, type=int, default=int_default)
    add("--max-train", type=int, default=0, help="cap train seqs (0 = full split)")
    add("--max-test", type=int, default=0, help="cap test seqs (0 = full split)")
    add("--ridge-alpha", type=float, default=10.0)
    return parser.parse_args()
def spearman(a, b):
    """Return Spearman's rank correlation between the sequences ``a`` and ``b``.

    Uses ``scipy.stats.spearmanr`` when SciPy can be imported. Otherwise it
    falls back to a pure-torch computation: Pearson correlation of the ranks.
    Both paths give tied values their average rank, so they agree on data with
    ties, such as duplicated regression labels.

    Only a missing SciPy triggers the fallback. Any other error raised by
    ``spearmanr`` propagates instead of being silently masked.
    """
    try:
        from scipy.stats import spearmanr
    except ImportError:
        spearmanr = None
    if spearmanr is not None:
        return float(spearmanr(a, b).correlation)

    ra = _average_ranks(a)
    rb = _average_ranks(b)
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    # The epsilon keeps a constant input from dividing by zero (result: 0.0).
    return float((ra @ rb) / (ra.norm() * rb.norm() + 1e-12))


def _average_ranks(values):
    """Return the 1-based ranks of ``values`` as float64, averaging ties."""
    x = torch.as_tensor(values, dtype=torch.float64)
    order = torch.argsort(x)
    # Group equal values in sorted order. Each group of size c that ends at
    # 1-based position e occupies positions e-c+1 .. e, whose mean is e-(c-1)/2.
    _, inverse, counts = torch.unique_consecutive(
        x[order], return_inverse=True, return_counts=True)
    counts = counts.to(torch.float64)
    group_rank = counts.cumsum(0) - (counts - 1) / 2
    ranks = torch.empty_like(x)
    ranks[order] = group_rank[inverse]
    return ranks
@torch.no_grad()
def encode(model, tok, seqs, device, max_length, batch_size):
    """Embed ``seqs`` with ``model``, one L2-normalised vector per sequence.

    Sequences are tokenised ``batch_size`` at a time (padded, truncated to
    ``max_length``) and run through ``model`` with hidden states returned.
    The last hidden layer is mean-pooled with ``mean_pool`` using the attention
    mask. The pooled vectors are cast to float32, unit-normalised and moved to
    the CPU. Gradients are disabled. The per-batch results are concatenated
    along dim 0, in input order.
    """
    pieces = []
    for lo in range(0, len(seqs), batch_size):
        batch = tok(
            seqs[lo:lo + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)
        mask = batch["attention_mask"]
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=mask,
            output_hidden_states=True,
        )
        vectors = mean_pool(outputs.hidden_states[-1], mask).float()
        pieces.append(F.normalize(vectors, dim=-1).cpu())
    return torch.cat(pieces, 0)
def _with_bias(x):
    """Return the 2-D tensor ``x`` as float64 with a trailing column of ones (bias feature)."""
    x = x.double()
    return torch.cat([x, torch.ones(x.size(0), 1, dtype=torch.float64)], 1)


def _ridge_solve(xtr, targets, alpha):
    """Closed-form ridge fit; returns weights of shape (d + 1, k).

    ``xtr`` is an (n, d) feature tensor and ``targets`` an (n, k) float64
    tensor. The fit minimises ||X W - Y||^2 + alpha * ||W_features||^2, where
    X is ``xtr`` with a bias column appended.

    The bias row is deliberately NOT penalised, as in standard ridge
    regression. Shrinking it toward zero would force the feature weights to
    absorb the target mean or class priors, which distorts predictions
    whenever the targets are not zero-centred. With ``alpha == 0`` the system
    is singular if X has rank < d + 1, and ``torch.linalg.solve`` raises.
    """
    X = _with_bias(xtr)
    reg = torch.full((X.size(1),), float(alpha), dtype=torch.float64)
    reg[-1] = 0.0  # unpenalised intercept
    A = X.t() @ X + torch.diag(reg)
    return torch.linalg.solve(A, X.t() @ targets)


def ridge_probe(xtr, ytr, xte, alpha):
    """Fit a ridge regressor on (``xtr``, ``ytr``) and predict for ``xte``.

    Args:
        xtr: (n, d) training features.
        ytr: n numeric training targets.
        xte: (m, d) evaluation features.
        alpha: L2 strength on the feature weights (the intercept is unpenalised).

    Returns:
        A list of m float predictions.
    """
    y = torch.tensor(ytr, dtype=torch.float64).unsqueeze(1)
    w = _ridge_solve(xtr, y, alpha)
    return (_with_bias(xte) @ w).squeeze(1).tolist()


def clf_probe(xtr, ytr, xte, alpha, num_classes):
    """One-hot ridge classifier: regress one-hot labels, predict the argmax.

    This is a dependency-free linear probe that shares its solver with
    ``ridge_probe``.

    Args:
        xtr: (n, d) training features.
        ytr: n integer class ids in ``[0, num_classes)``.
        xte: (m, d) evaluation features.
        alpha: L2 strength on the feature weights (the intercept is unpenalised).
        num_classes: width of the one-hot target matrix.

    Returns:
        A list of m predicted class ids.
    """
    labels = torch.tensor(ytr, dtype=torch.long)
    onehot = torch.zeros(labels.size(0), num_classes, dtype=torch.float64)
    onehot[torch.arange(labels.size(0)), labels] = 1.0
    W = _ridge_solve(xtr, onehot, alpha)  # (d + 1, num_classes)
    return (_with_bias(xte) @ W).argmax(1).tolist()
def pairwise_cos(xte, cap=2000):
    """Mean off-diagonal inner product among the first ``cap`` rows of ``xte``.

    For L2-normalised rows (as ``encode`` produces), this is the mean pairwise
    cosine similarity. The Gram matrix costs O(n^2) memory, so only the first
    ``cap`` rows are used.
    """
    sub = xte[:cap]
    gram = sub @ sub.t()
    keep = torch.eye(sub.size(0), dtype=torch.bool).logical_not()
    return float(gram[keep].mean())
def build_base(model_name, device):
    """Load ``model_name`` as a bf16 causal LM, make it bidirectional, move it to ``device``.

    The checkpoint is loaded with ``trust_remote_code=True`` and eager
    attention. ``make_bidirectional`` is then applied to it; its return value
    is logged as the number of patched attention modules.
    """
    from transformers import AutoModelForCausalLM
    import transformers.modeling_utils as modeling_utils

    pretrained_cls = modeling_utils.PreTrainedModel
    # NOTE(review): compatibility shim. Presumably the remote model code reads
    # `all_tied_weights_keys`, which this transformers version does not define
    # on PreTrainedModel itself; confirm against the pinned version.
    if "all_tied_weights_keys" not in vars(pretrained_cls):
        pretrained_cls.all_tied_weights_keys = {}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    patched = make_bidirectional(model)
    print(f"[build] bidirectional patch applied to {patched} attention modules", flush=True)
    return model.to(device)
def _pick(cols, candidates, what):
for c in candidates:
if c in cols:
return c
raise KeyError(f"no {what} column in {cols} (tried {candidates})")
def load_split(dataset, split, cap, cast=float):
    """Load one split of a Hugging Face dataset as ``(sequences, labels)`` lists.

    Args:
        dataset: Hugging Face dataset id.
        split: split name. If it cannot be loaded and it is "test" or "valid",
            the other of the two is tried instead. A warning is printed so
            that evaluating on a substitute split is never silent.
        cap: if positive and smaller than the split, a seeded (seed=0) random
            subset of ``cap`` rows is kept. 0, None or negative values keep
            the full split.
        cast: applied to every label (e.g. ``int`` for class ids).

    Raises:
        KeyError: if no known sequence or label column is present.
        Exception: whatever ``load_dataset`` raises when neither the split nor
            its alternative loads. The first failure is attached as context.
    """
    from datasets import load_dataset
    try:
        ds = load_dataset(dataset, split=split)
    except Exception as err:
        alt = {"test": "valid", "valid": "test"}.get(split, split)
        if alt == split:
            raise  # no alternative split: retrying the same one cannot help
        print(f"[data] WARNING: {dataset} split '{split}' failed to load ({err!r}); "
              f"falling back to '{alt}'", flush=True)
        ds = load_dataset(dataset, split=alt)
    seq_c = _pick(ds.column_names, SEQ_COLS, "sequence")
    lab_c = _pick(ds.column_names, LABEL_COLS, "label")
    if cap and 0 < cap < len(ds):
        ds = ds.shuffle(seed=0).select(range(cap))
    return list(ds[seq_c]), [cast(x) for x in ds[lab_c]]
def _probe_score(kind, xtr, tr_y, xte, te_y, alpha, num_classes):
    """Fit the task's linear probe on train embeddings and score it on test.

    Returns ``(metric_name, score)``: top-1 accuracy ("Acc") for "clf" tasks,
    otherwise the Spearman correlation of ridge predictions ("Spearman").
    """
    if kind != "clf":
        preds = ridge_probe(xtr, tr_y, xte, alpha)
        return "Spearman", spearman(preds, te_y)
    preds = clf_probe(xtr, tr_y, xte, alpha, num_classes)
    hits = sum(int(p == y) for p, y in zip(preds, te_y))
    return "Acc", float(hits / len(te_y))


def main():
    """Run the frozen-feature probe benchmark and print a per-task summary.

    All requested task splits are loaded once. The frozen bidirectional base
    model is evaluated first, then the LoRA adapter (if ``--adapter`` is set),
    and finally baseline/trained scores and their deltas are printed.
    """
    args = parse_args()
    from transformers import AutoTokenizer

    tasks = [name.strip() for name in args.tasks.split(",") if name.strip()]
    unknown = [name for name in tasks if name not in TASKS]
    if unknown:
        sys.exit(f"unknown task '{unknown[0]}'; choose from {list(TASKS)}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tok.pad_token_id is None:
        # Padded batching needs a pad token: reuse EOS, else the token with id 0.
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)

    # Load every task's splits up front so each model is encoded once per task.
    data = {}
    for name in tasks:
        spec = TASKS[name]
        kind = spec["kind"]
        is_clf = kind == "clf"
        cast = int if is_clf else float
        tr_seq, tr_y = load_split(spec["dataset"], "train", args.max_train, cast)
        te_seq, te_y = load_split(spec["dataset"], spec.get("test_split", "test"),
                                  args.max_test, cast)
        nc = max(tr_y + te_y) + 1 if is_clf else None
        data[name] = (tr_seq, tr_y, te_seq, te_y, kind, nc)
        suffix = f" classes={nc}" if nc else ""
        print(f"[data] {name} ({kind}): train={len(tr_seq)} test={len(te_seq)}{suffix}",
              flush=True)

    base = build_base(args.model_name, device)

    def run(model, label):
        """Encode, probe and score every task with ``model``; return {task: (metric, score)}."""
        scores = {}
        for name in tasks:
            tr_seq, tr_y, te_seq, te_y, kind, nc = data[name]
            xtr = encode(model, tok, tr_seq, device, args.max_length, args.batch_size)
            xte = encode(model, tok, te_seq, device, args.max_length, args.batch_size)
            metric, score = _probe_score(kind, xtr, tr_y, xte, te_y, args.ridge_alpha, nc)
            print(f"[{label}] {name} {metric}={score:.4f} (probe on {len(tr_seq)} train, "
                  f"eval {len(te_seq)} test) | mean_pairwise_cos={pairwise_cos(xte):.4f}",
                  flush=True)
            scores[name] = (metric, score)
        return scores

    base.eval()
    base_scores = run(base, "baseline (frozen, bidir)")
    adapter_scores = {}
    if args.adapter:
        from peft import PeftModel
        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained.eval()
        adapter_scores = run(trained, "trained (adapter)")

    print("\n=== SUMMARY ===", flush=True)
    for name in tasks:
        metric, b = base_scores[name]
        row = f" {name:14s} [{metric}] baseline={b:.4f}"
        if name in adapter_scores:
            tr = adapter_scores[name][1]
            row += f" trained={tr:.4f} delta={tr - b:+.4f}"
        print(row, flush=True)
if __name__ == "__main__":
    # Script entry point: run the benchmark only when executed directly, not on import.
    main()