Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
File size: 10,004 Bytes
e6bc942 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 | #!/usr/bin/env python
"""TAPE downstream evaluation for the bidirectional ProGen2 encoder.
Protein analog of eval_sts.py: freeze the encoder, mean-pool a fixed embedding
per sequence, fit a linear (ridge) probe on the train split, and report its score
on the test split (Spearman ρ for regression tasks, top-1 accuracy for
classification tasks) — the standard frozen-features protocol for judging
embedding quality. Reports the FROZEN baseline (bidirectional ProGen2, no
adaptation) vs the TRAINED adapter, so we can see whether MNTP/SimCSE actually
improved the encoder.
Multi-task: pass --tasks stability,fluorescence,homology to harden the verdict
across more than one downstream task. Pass --max-train 0 / --max-test 0 to use the
FULL splits (no cap).
Two probe kinds (both frozen mean-pooled embedding + closed-form linear probe):
reg -> ridge regression, reported as Spearman ρ
clf -> one-hot ridge classifier (argmax), reported as top-1 accuracy
Tasks (sequence-level; TAPE = Rao 2019, ProteinBERT = Brandes 2022):
stability -> AI4Protein/TAPE_Stability reg TAPE
fluorescence -> AI4Protein/TAPE_Fluorescence reg TAPE
homology -> GleghornLab/bom_remote_homology clf TAPE (remote homology)
fold -> GleghornLab/fold_prediction clf TAPE (fold classification)
signalpeptide -> GrimSqueaker/SignalP_Binary clf ProteinBERT
neuropeptide -> GrimSqueaker/ProFET_NP_SP_Cleaved clf ProteinBERT (cleaved precursor)
(Token-level PTM and disorder live in eval_token.py — they are per-residue and only
distributed as CSVs in the ProteinBERT data repo, not on HF.)
Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations
import argparse
import os
import sys
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.bidir_progen import make_bidirectional, mean_pool # noqa: E402
# Task registry. "dataset" is the HF hub id, "kind" selects the probe
# ("reg": ridge regression scored by Spearman; "clf": one-hot ridge classifier
# scored by accuracy), and the optional "test_split" names the evaluation split
# (main() uses "test" when the key is absent).
TASKS = dict(
    stability=dict(dataset="AI4Protein/TAPE_Stability", kind="reg"),
    fluorescence=dict(dataset="AI4Protein/TAPE_Fluorescence", kind="reg"),
    homology=dict(dataset="GleghornLab/bom_remote_homology", kind="clf",
                  test_split="test"),
    fold=dict(dataset="GleghornLab/fold_prediction", kind="clf",
              test_split="test"),
    # ProteinBERT sequence-level benchmarks (Brandes 2022), via HF mirrors:
    signalpeptide=dict(dataset="GrimSqueaker/SignalP_Binary", kind="clf",
                       test_split="test"),
    neuropeptide=dict(dataset="GrimSqueaker/ProFET_NP_SP_Cleaved", kind="clf",
                      test_split="test"),
)

# Column names tried, in this order, when locating the sequence / label column
# of a loaded split.
SEQ_COLS = (
    "aa_seq", "seqs", "seq", "sequence", "primary",
)
LABEL_COLS = (
    "label", "labels", "target", "log_fluorescence", "stability_score",
)
def parse_args():
    """Build the CLI parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with model_name, adapter, tasks, max_length,
        batch_size, max_train, max_test and ridge_alpha.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--model-name", default="hugohrban/progen2-base")
    add("--adapter", default=None, help="trained LoRA adapter dir (optional)")
    add("--tasks", default="stability,fluorescence,homology",
        help="comma-separated subset of: " + ",".join(TASKS))
    # Plain integer knobs without help text, registered in their original order.
    for flag, default in (("--max-length", 512), ("--batch-size", 32)):
        add(flag, type=int, default=default)
    add("--max-train", type=int, default=0, help="cap train seqs (0 = full split)")
    add("--max-test", type=int, default=0, help="cap test seqs (0 = full split)")
    add("--ridge-alpha", type=float, default=10.0)
    return parser.parse_args()
def _average_ranks(x):
    """Return 1-based ranks of the 1-D tensor ``x``, averaging tied values.

    This is the "average" tie method used by scipy.stats.spearmanr, so the
    torch fallback in spearman() agrees with scipy on inputs containing ties.
    """
    _, inverse, counts = torch.unique(
        x, sorted=True, return_inverse=True, return_counts=True
    )
    ends = counts.cumsum(0).to(torch.float64)
    # A tie group occupying ranks (end - count + 1) .. end gets their mean.
    avg = ends - (counts.to(torch.float64) - 1.0) / 2.0
    return avg[inverse]


def spearman(a, b):
    """Spearman rank correlation between two equal-length sequences of numbers.

    Uses scipy when it can be imported. Otherwise it falls back to a pure-torch
    Pearson correlation of tie-averaged ranks, which matches scipy's
    definition. Previously the fallback ranked ties arbitrarily.

    Returns:
        The correlation as a float. It is NaN when either input has no
        variation, as scipy also returns.
    """
    try:
        from scipy.stats import spearmanr
    except ImportError:
        # Only a missing scipy triggers the fallback; genuine errors raised by
        # spearmanr itself (e.g. mismatched lengths) now propagate.
        spearmanr = None
    if spearmanr is not None:
        return float(spearmanr(a, b).correlation)

    ra = _average_ranks(torch.as_tensor(a, dtype=torch.float64).flatten())
    rb = _average_ranks(torch.as_tensor(b, dtype=torch.float64).flatten())
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    denom = ra.norm() * rb.norm()
    if denom == 0:
        return float("nan")
    return float((ra @ rb) / denom)
@torch.no_grad()
def encode(model, tok, seqs, device, max_length, batch_size):
    """Embed ``seqs`` batch by batch without tracking gradients.

    Each batch is tokenized, padded and truncated to ``max_length``, then run
    through ``model``. The last hidden layer is mean-pooled with the attention
    mask (via mean_pool), cast to float32, L2-normalized along the last dim and
    moved to CPU.

    Returns:
        The per-batch results concatenated along dim 0, in input order.
        Presumably (len(seqs), hidden) — TODO confirm mean_pool's output shape.
    """
    chunks = []
    for start in range(0, len(seqs), batch_size):
        batch = seqs[start:start + batch_size]
        enc = tok(batch, padding=True, truncation=True,
                  max_length=max_length, return_tensors="pt")
        enc = enc.to(device)
        mask = enc["attention_mask"]
        hidden = model(
            input_ids=enc["input_ids"],
            attention_mask=mask,
            output_hidden_states=True,
        ).hidden_states[-1]
        vec = mean_pool(hidden, mask).float()
        chunks.append(F.normalize(vec, dim=-1).cpu())
    return torch.cat(chunks, 0)
def _with_bias(x):
    """Return the 2-D tensor ``x`` as float64 with a trailing column of ones."""
    x = x.double()
    return torch.cat([x, x.new_ones((x.size(0), 1))], dim=1)


def _ridge_weights(xtr, targets, alpha, penalize_bias):
    """Closed-form ridge fit shared by both probes (no sklearn dependency).

    Solves (X'X + alpha * R) W = X'T. X is ``xtr`` with a bias column appended.
    R is the identity, except that its bias entry is zeroed unless
    ``penalize_bias`` is set. That is the usual ridge convention: shrinking the
    intercept toward 0 would push the target mean or class prior into the
    feature weights. With alpha > 0 and at least one training row, the system
    is positive definite either way.

    Returns:
        W, float64, with one row per feature plus a final bias row and one
        column per column of ``targets``.
    """
    X = _with_bias(xtr)
    reg = torch.eye(X.size(1), dtype=torch.float64)
    if not penalize_bias:
        reg[-1, -1] = 0.0
    A = X.t() @ X + alpha * reg
    return torch.linalg.solve(A, X.t() @ targets)


def ridge_probe(xtr, ytr, xte, alpha, penalize_bias=False):
    """Fit a ridge regressor on (xtr, ytr) and predict for xte.

    Args:
        xtr: (n_train, d) feature tensor.
        ytr: n_train real-valued targets.
        xte: (n_test, d) feature tensor.
        alpha: L2 penalty on the feature weights.
        penalize_bias: also apply ``alpha`` to the intercept, as the original
            formulation did. Off by default.

    Returns:
        List of n_test float predictions.
    """
    targets = torch.tensor(ytr, dtype=torch.float64).unsqueeze(1)
    w = _ridge_weights(xtr, targets, alpha, penalize_bias)
    return (_with_bias(xte) @ w).squeeze(1).tolist()


def clf_probe(xtr, ytr, xte, alpha, num_classes, penalize_bias=False):
    """One-hot ridge classifier: regress class indicators, predict by argmax.

    A legitimate linear probe, dependency-free and consistent with
    ridge_probe (same solver, same intercept handling).

    Args:
        xtr: (n_train, d) feature tensor.
        ytr: n_train integer class labels in [0, num_classes).
        xte: (n_test, d) feature tensor.
        alpha: L2 penalty on the feature weights.
        num_classes: number of one-hot target columns.
        penalize_bias: also apply ``alpha`` to the per-class intercepts.

    Returns:
        List of n_test predicted class indices (ints).
    """
    labels = torch.tensor(ytr, dtype=torch.long)
    n = labels.numel()
    Y = torch.zeros(n, num_classes, dtype=torch.float64)
    Y[torch.arange(n), labels] = 1.0
    W = _ridge_weights(xtr, Y, alpha, penalize_bias)  # (d + 1, C)
    return (_with_bias(xte) @ W).argmax(1).tolist()
def pairwise_cos(xte, cap=2000):
    """Mean off-diagonal pairwise dot product over the first ``cap`` rows of ``xte``.

    encode() L2-normalizes its rows, so for its output this is the mean
    pairwise cosine similarity, a quick collapse/anisotropy diagnostic.

    The sum of x_i·x_j over i != j is computed as
    ||sum_i x_i||^2 - sum_i ||x_i||^2. This is exact and O(n*d) in float64,
    with no (n, n) Gram matrix in memory.

    Args:
        xte: 2-D tensor, one row per embedding.
        cap: use only the first ``cap`` rows. The default matches earlier
            reported numbers. Pass ``None`` to use every row.

    Returns:
        The mean as a float, or NaN when fewer than two rows are used.
    """
    x = (xte if cap is None else xte[:cap]).double()
    n = x.size(0)
    if n < 2:
        return float("nan")
    total = x.sum(0)
    off_diag = total.dot(total) - x.pow(2).sum()
    return float(off_diag / (n * (n - 1)))
def build_base(model_name, device):
    """Load a causal-LM checkpoint in bf16 and make its attention bidirectional.

    Args:
        model_name: HF hub id or local path passed to
            AutoModelForCausalLM.from_pretrained (with trust_remote_code=True).
        device: torch device the model is moved to before returning.

    Returns:
        The patched model, on ``device``.
    """
    from transformers import AutoModelForCausalLM
    import transformers.modeling_utils as _mu
    # Compatibility shim: if PreTrainedModel does not define
    # `all_tied_weights_keys` in its own class __dict__ (vars() ignores
    # inherited attributes), set it to an empty dict before loading.
    # NOTE(review): presumably needed because the remote model code expects an
    # attribute that the pinned transformers version lacks — confirm. Order
    # matters: this must run before from_pretrained.
    if "all_tied_weights_keys" not in vars(_mu.PreTrainedModel):
        _mu.PreTrainedModel.all_tied_weights_keys = {}
    base = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, torch_dtype=torch.bfloat16,
        # Eager attention — presumably so make_bidirectional can patch the
        # attention masking; TODO confirm against src.bidir_progen.
        attn_implementation="eager",
    )
    # make_bidirectional's return value is reported as the number of patched
    # attention modules.
    n = make_bidirectional(base)
    print(f"[build] bidirectional patch applied to {n} attention modules", flush=True)
    return base.to(device)
def _pick(cols, candidates, what):
for c in candidates:
if c in cols:
return c
raise KeyError(f"no {what} column in {cols} (tried {candidates})")
def load_split(dataset, split, cap, cast=float):
    """Load one split of a HF dataset as parallel (sequences, labels) lists.

    If ``split`` is "test" or "valid" and the dataset rejects it, retry once
    with the other of the two and print a notice. Metrics labelled "test" then
    actually come from that split. Any other failure (network, auth, missing
    dataset, or an unknown split other than test/valid) propagates unchanged
    instead of being masked by a second, unrelated load attempt.

    Args:
        dataset: HF hub dataset id.
        split: split name to load.
        cap: if truthy and smaller than the split, keep a seed-0 random subset
            of ``cap`` rows. 0 or None means use the full split.
        cast: applied to every label (int for classification, float for
            regression).

    Returns:
        (list of sequences, list of cast labels).

    Raises:
        KeyError: if no known sequence or label column is present (see _pick).
    """
    from datasets import load_dataset
    try:
        ds = load_dataset(dataset, split=split)
    except (ValueError, KeyError) as err:
        # An unknown split name surfaces as ValueError in `datasets`.
        alt = {"test": "valid", "valid": "test"}.get(split)
        if alt is None:
            raise
        print(f"[data] {dataset}: split '{split}' unavailable ({err}); "
              f"falling back to '{alt}'", flush=True)
        ds = load_dataset(dataset, split=alt)
    seq_c = _pick(ds.column_names, SEQ_COLS, "sequence")
    lab_c = _pick(ds.column_names, LABEL_COLS, "label")
    if cap and len(ds) > cap:
        # Deterministic subsample so repeated runs see the same rows.
        ds = ds.shuffle(seed=0).select(range(cap))
    return list(ds[seq_c]), [cast(x) for x in ds[lab_c]]
def main():
    """Entry point: probe frozen embeddings on each requested task.

    Steps: parse args, validate task names (exiting with a message on an
    unknown one), load every task's train/test data, build the bidirectional
    base model, score the frozen baseline, optionally wrap it with the LoRA
    adapter from --adapter and score that, then print a per-task summary.
    """
    args = parse_args()
    from transformers import AutoTokenizer
    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    for t in tasks:
        if t not in TASKS:
            sys.exit(f"unknown task '{t}'; choose from {list(TASKS)}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    # Batched padding needs a pad token: reuse EOS, or the token for id 0 when
    # the tokenizer has no EOS either.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)
    # Load every task's splits up front so we encode each model once per task.
    data = {}
    for t in tasks:
        meta = TASKS[t]
        d, kind = meta["dataset"], meta["kind"]
        cast = int if kind == "clf" else float
        tr_seq, tr_y = load_split(d, "train", args.max_train, cast)
        te_seq, te_y = load_split(d, meta.get("test_split", "test"), args.max_test, cast)
        # Class count covers labels from both splits. Assumes 0-based integer
        # class ids — verify for new datasets.
        nc = (max(tr_y + te_y) + 1) if kind == "clf" else None
        data[t] = (tr_seq, tr_y, te_seq, te_y, kind, nc)
        extra = f" classes={nc}" if nc else ""
        print(f"[data] {t} ({kind}): train={len(tr_seq)} test={len(te_seq)}{extra}", flush=True)
    base = build_base(args.model_name, device)

    def run(model, label):
        """Encode every task with ``model``, fit its probe, and return {task: (metric, score)}.

        Closes over tasks, data, tok, device and args from main(). ``label`` is
        only used as the log-line prefix.
        """
        out = {}
        for t in tasks:
            tr_seq, tr_y, te_seq, te_y, kind, nc = data[t]
            xtr = encode(model, tok, tr_seq, device, args.max_length, args.batch_size)
            xte = encode(model, tok, te_seq, device, args.max_length, args.batch_size)
            if kind == "clf":
                pred = clf_probe(xtr, tr_y, xte, args.ridge_alpha, nc)
                # Top-1 accuracy; raises ZeroDivisionError on an empty test split.
                score = float(sum(int(p == y) for p, y in zip(pred, te_y)) / len(te_y))
                metric = "Acc"
            else:
                pred = ridge_probe(xtr, tr_y, xte, args.ridge_alpha)
                score = spearman(pred, te_y)
                metric = "Spearman"
            print(f"[{label}] {t} {metric}={score:.4f} (probe on {len(tr_seq)} train, "
                  f"eval {len(te_seq)} test) | mean_pairwise_cos={pairwise_cos(xte):.4f}",
                  flush=True)
            out[t] = (metric, score)
        return out

    base.eval()
    # Keep this order: the baseline is scored before PeftModel wraps `base`.
    # NOTE(review): PEFT presumably injects adapter layers into `base` in
    # place, so scoring afterwards would no longer measure the frozen baseline — confirm.
    base_s = run(base, "baseline (frozen, bidir)")
    trained_s = {}
    if args.adapter:
        from peft import PeftModel
        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained.eval()
        trained_s = run(trained, "trained (adapter)")
    print("\n=== SUMMARY ===", flush=True)
    for t in tasks:
        metric, b = base_s[t]
        line = f" {t:14s} [{metric}] baseline={b:.4f}"
        # Adapter columns only appear when --adapter was given.
        if t in trained_s:
            tr = trained_s[t][1]
            line += f" trained={tr:.4f} delta={tr - b:+.4f}"
        print(line, flush=True)


if __name__ == "__main__":
    main()
|