#!/usr/bin/env python """TAPE downstream evaluation for the bidirectional ProGen2 encoder. Protein analog of eval_sts.py: freeze the encoder, mean-pool a fixed embedding per sequence, fit a linear (ridge) probe on the train split, and report Spearman on the test split — the standard frozen-features protocol for judging embedding quality. Reports the FROZEN baseline (bidirectional ProGen2, no adaptation) vs the TRAINED adapter, so we can see whether MNTP/SimCSE actually improved the encoder. Multi-task: pass --tasks stability,fluorescence,homology to harden the verdict across more than one downstream task. Pass --max-train 0 / --max-test 0 to use the FULL splits (no cap). Two probe kinds (both frozen mean-pooled embedding + closed-form linear probe): reg -> ridge regression, reported as Spearman ρ clf -> one-hot ridge classifier (argmax), reported as top-1 accuracy Tasks (sequence-level; TAPE = Rao 2019, ProteinBERT = Brandes 2022): stability -> AI4Protein/TAPE_Stability reg TAPE fluorescence -> AI4Protein/TAPE_Fluorescence reg TAPE homology -> GleghornLab/bom_remote_homology clf TAPE (remote homology) fold -> GleghornLab/fold_prediction clf TAPE (fold classification) signalpeptide -> GrimSqueaker/SignalP_Binary clf ProteinBERT neuropeptide -> GrimSqueaker/ProFET_NP_SP_Cleaved clf ProteinBERT (cleaved precursor) (Token-level PTM and disorder live in eval_token.py — they are per-residue and only distributed as CSVs in the ProteinBERT data repo, not on HF.) Run on a GPU node in the pinned transformers-4.44.2 venv. """ from __future__ import annotations import argparse import os import sys import torch import torch.nn.functional as F sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from src.bidir_progen import make_bidirectional, mean_pool # noqa: E402 TASKS = { "stability": {"dataset": "AI4Protein/TAPE_Stability", "kind": "reg"}, "fluorescence": {"dataset": "AI4Protein/TAPE_Fluorescence", "kind": "reg"}, "homology": {"dataset": "GleghornLab/bom_remote_homology", "kind": "clf", "test_split": "test"}, "fold": {"dataset": "GleghornLab/fold_prediction", "kind": "clf", "test_split": "test"}, # ProteinBERT sequence-level benchmarks (Brandes 2022), via HF mirrors: "signalpeptide": {"dataset": "GrimSqueaker/SignalP_Binary", "kind": "clf", "test_split": "test"}, "neuropeptide": {"dataset": "GrimSqueaker/ProFET_NP_SP_Cleaved", "kind": "clf", "test_split": "test"}, } SEQ_COLS = ("aa_seq", "seqs", "seq", "sequence", "primary") LABEL_COLS = ("label", "labels", "target", "log_fluorescence", "stability_score") def parse_args(): p = argparse.ArgumentParser() p.add_argument("--model-name", default="hugohrban/progen2-base") p.add_argument("--adapter", default=None, help="trained LoRA adapter dir (optional)") p.add_argument("--tasks", default="stability,fluorescence,homology", help="comma-separated subset of: " + ",".join(TASKS)) p.add_argument("--max-length", type=int, default=512) p.add_argument("--batch-size", type=int, default=32) p.add_argument("--max-train", type=int, default=0, help="cap train seqs (0 = full split)") p.add_argument("--max-test", type=int, default=0, help="cap test seqs (0 = full split)") p.add_argument("--ridge-alpha", type=float, default=10.0) return p.parse_args() def spearman(a, b): try: from scipy.stats import spearmanr return float(spearmanr(a, b).correlation) except Exception: ta = torch.tensor(a, dtype=torch.float64); tb = torch.tensor(b, dtype=torch.float64) ra = ta.argsort().argsort().double(); rb = tb.argsort().argsort().double() ra = ra - ra.mean(); rb = rb - rb.mean() return float((ra @ rb) / (ra.norm() * rb.norm() + 1e-12)) @torch.no_grad() def encode(model, tok, seqs, device, max_length, batch_size): embs = [] for i in range(0, len(seqs), batch_size): enc = tok(seqs[i:i + batch_size], padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(device) out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], output_hidden_states=True) pooled = mean_pool(out.hidden_states[-1], enc["attention_mask"]) embs.append(F.normalize(pooled.float(), dim=-1).cpu()) return torch.cat(embs, 0) def ridge_probe(xtr, ytr, xte, alpha): # Closed-form ridge (no sklearn dependency): w = (X'X + aI)^-1 X'y, with bias. Xtr = torch.cat([xtr, torch.ones(xtr.size(0), 1)], 1).double() Xte = torch.cat([xte, torch.ones(xte.size(0), 1)], 1).double() ytr = torch.tensor(ytr, dtype=torch.float64).unsqueeze(1) d = Xtr.size(1) A = Xtr.t() @ Xtr + alpha * torch.eye(d, dtype=torch.float64) w = torch.linalg.solve(A, Xtr.t() @ ytr) return (Xte @ w).squeeze(1).tolist() def clf_probe(xtr, ytr, xte, alpha, num_classes): # Closed-form one-hot ridge classifier: W = (X'X + aI)^-1 X'Y_onehot, argmax. # A legitimate linear probe, dependency-free and consistent with ridge_probe. Xtr = torch.cat([xtr, torch.ones(xtr.size(0), 1)], 1).double() Xte = torch.cat([xte, torch.ones(xte.size(0), 1)], 1).double() yt = torch.tensor(ytr, dtype=torch.long) Y = torch.zeros(Xtr.size(0), num_classes, dtype=torch.float64) Y[torch.arange(Xtr.size(0)), yt] = 1.0 d = Xtr.size(1) A = Xtr.t() @ Xtr + alpha * torch.eye(d, dtype=torch.float64) W = torch.linalg.solve(A, Xtr.t() @ Y) # (d, C) return (Xte @ W).argmax(1).tolist() def pairwise_cos(xte, cap=2000): # O(n^2) — subsample so a big test split doesn't blow up memory. x = xte[:cap] return float((x @ x.t())[~torch.eye(x.size(0), dtype=torch.bool)].mean()) def build_base(model_name, device): from transformers import AutoModelForCausalLM import transformers.modeling_utils as _mu if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel): _mu.PreTrainedModel.all_tied_weights_keys = {} base = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True, torch_dtype=torch.bfloat16, attn_implementation="eager", ) n = make_bidirectional(base) print(f"[build] bidirectional patch applied to {n} attention modules", flush=True) return base.to(device) def _pick(cols, candidates, what): for c in candidates: if c in cols: return c raise KeyError(f"no {what} column in {cols} (tried {candidates})") def load_split(dataset, split, cap, cast=float): from datasets import load_dataset try: ds = load_dataset(dataset, split=split) except Exception: alt = {"test": "valid", "valid": "test"}.get(split, split) ds = load_dataset(dataset, split=alt) seq_c = _pick(ds.column_names, SEQ_COLS, "sequence") lab_c = _pick(ds.column_names, LABEL_COLS, "label") if cap and len(ds) > cap: ds = ds.shuffle(seed=0).select(range(cap)) return list(ds[seq_c]), [cast(x) for x in ds[lab_c]] def main(): args = parse_args() from transformers import AutoTokenizer tasks = [t.strip() for t in args.tasks.split(",") if t.strip()] for t in tasks: if t not in TASKS: sys.exit(f"unknown task '{t}'; choose from {list(TASKS)}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) if tok.pad_token_id is None: tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0) # Load every task's splits up front so we encode each model once per task. data = {} for t in tasks: meta = TASKS[t] d, kind = meta["dataset"], meta["kind"] cast = int if kind == "clf" else float tr_seq, tr_y = load_split(d, "train", args.max_train, cast) te_seq, te_y = load_split(d, meta.get("test_split", "test"), args.max_test, cast) nc = (max(tr_y + te_y) + 1) if kind == "clf" else None data[t] = (tr_seq, tr_y, te_seq, te_y, kind, nc) extra = f" classes={nc}" if nc else "" print(f"[data] {t} ({kind}): train={len(tr_seq)} test={len(te_seq)}{extra}", flush=True) base = build_base(args.model_name, device) def run(model, label): out = {} for t in tasks: tr_seq, tr_y, te_seq, te_y, kind, nc = data[t] xtr = encode(model, tok, tr_seq, device, args.max_length, args.batch_size) xte = encode(model, tok, te_seq, device, args.max_length, args.batch_size) if kind == "clf": pred = clf_probe(xtr, tr_y, xte, args.ridge_alpha, nc) score = float(sum(int(p == y) for p, y in zip(pred, te_y)) / len(te_y)) metric = "Acc" else: pred = ridge_probe(xtr, tr_y, xte, args.ridge_alpha) score = spearman(pred, te_y) metric = "Spearman" print(f"[{label}] {t} {metric}={score:.4f} (probe on {len(tr_seq)} train, " f"eval {len(te_seq)} test) | mean_pairwise_cos={pairwise_cos(xte):.4f}", flush=True) out[t] = (metric, score) return out base.eval() base_s = run(base, "baseline (frozen, bidir)") trained_s = {} if args.adapter: from peft import PeftModel trained = PeftModel.from_pretrained(base, args.adapter).to(device) trained.eval() trained_s = run(trained, "trained (adapter)") print("\n=== SUMMARY ===", flush=True) for t in tasks: metric, b = base_s[t] line = f" {t:14s} [{metric}] baseline={b:.4f}" if t in trained_s: tr = trained_s[t][1] line += f" trained={tr:.4f} delta={tr - b:+.4f}" print(line, flush=True) if __name__ == "__main__": main()