Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python | |
| """TAPE downstream evaluation for the bidirectional ProGen2 encoder. | |
| Protein analog of eval_sts.py: freeze the encoder, mean-pool a fixed embedding | |
| per sequence, fit a linear (ridge) probe on the train split, and report Spearman | |
| on the test split — the standard frozen-features protocol for judging embedding | |
| quality. Reports the FROZEN baseline (bidirectional ProGen2, no adaptation) vs the | |
| TRAINED adapter, so we can see whether MNTP/SimCSE actually improved the encoder. | |
| Multi-task: pass --tasks stability,fluorescence,homology to harden the verdict | |
| across more than one downstream task. Pass --max-train 0 / --max-test 0 to use the | |
| FULL splits (no cap). | |
| Two probe kinds (both frozen mean-pooled embedding + closed-form linear probe): | |
| reg -> ridge regression, reported as Spearman ρ | |
| clf -> one-hot ridge classifier (argmax), reported as top-1 accuracy | |
| Tasks (sequence-level; TAPE = Rao 2019, ProteinBERT = Brandes 2022): | |
| stability -> AI4Protein/TAPE_Stability reg TAPE | |
| fluorescence -> AI4Protein/TAPE_Fluorescence reg TAPE | |
| homology -> GleghornLab/bom_remote_homology clf TAPE (remote homology) | |
| fold -> GleghornLab/fold_prediction clf TAPE (fold classification) | |
| signalpeptide -> GrimSqueaker/SignalP_Binary clf ProteinBERT | |
| neuropeptide -> GrimSqueaker/ProFET_NP_SP_Cleaved clf ProteinBERT (cleaved precursor) | |
| (Token-level PTM and disorder live in eval_token.py — they are per-residue and only | |
| distributed as CSVs in the ProteinBERT data repo, not on HF.) | |
| Run on a GPU node in the pinned transformers-4.44.2 venv. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn.functional as F | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from src.bidir_progen import make_bidirectional, mean_pool # noqa: E402 | |
# Registry of sequence-level benchmarks, keyed by the name accepted by --tasks.
# Each entry gives the Hugging Face dataset id, the probe kind ("reg": ridge
# regression scored by Spearman; "clf": one-hot ridge classifier scored by
# accuracy) and, optionally, the split to evaluate on ("test" when omitted).
TASKS = {
    # TAPE (Rao 2019) regression tasks.
    "stability": dict(dataset="AI4Protein/TAPE_Stability", kind="reg"),
    "fluorescence": dict(dataset="AI4Protein/TAPE_Fluorescence", kind="reg"),
    # TAPE classification tasks (remote homology, fold classification).
    "homology": dict(dataset="GleghornLab/bom_remote_homology", kind="clf",
                     test_split="test"),
    "fold": dict(dataset="GleghornLab/fold_prediction", kind="clf",
                 test_split="test"),
    # ProteinBERT sequence-level benchmarks (Brandes 2022), via HF mirrors.
    "signalpeptide": dict(dataset="GrimSqueaker/SignalP_Binary", kind="clf",
                          test_split="test"),
    "neuropeptide": dict(dataset="GrimSqueaker/ProFET_NP_SP_Cleaved", kind="clf",
                         test_split="test"),
}

# Candidate column names in priority order: load_split() uses the first one
# that is present in a dataset.
SEQ_COLS = (
    "aa_seq",
    "seqs",
    "seq",
    "sequence",
    "primary",
)
LABEL_COLS = (
    "label",
    "labels",
    "target",
    "log_fluorescence",
    "stability_score",
)
def parse_args():
    """Parse the command line for a probing run.

    Returns an ``argparse.Namespace`` with ``model_name``, ``adapter``,
    ``tasks`` (the raw comma-separated string), ``max_length``,
    ``batch_size``, ``max_train``, ``max_test`` and ``ridge_alpha``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Which model / adapter to evaluate.
    add("--model-name", default="hugohrban/progen2-base")
    add("--adapter", default=None, help="trained LoRA adapter dir (optional)")

    # Which benchmarks to run; the help text lists every registered task.
    known_tasks = ",".join(TASKS)
    add("--tasks", default="stability,fluorescence,homology",
        help="comma-separated subset of: " + known_tasks)

    # Tokenisation / batching.
    add("--max-length", type=int, default=512)
    add("--batch-size", type=int, default=32)

    # Split-size caps (0 disables the cap).
    add("--max-train", type=int, default=0, help="cap train seqs (0 = full split)")
    add("--max-test", type=int, default=0, help="cap test seqs (0 = full split)")

    # Strength of the L2 penalty used by the closed-form probes.
    add("--ridge-alpha", type=float, default=10.0)

    return parser.parse_args()
def spearman(a, b):
    """Spearman rank correlation between two equal-length sequences of numbers.

    Uses ``scipy.stats.spearmanr`` when scipy is importable. Otherwise it falls
    back to a pure-torch implementation that, like scipy, gives tied values
    their average rank and returns NaN when either input is constant.
    Only a missing scipy triggers the fallback; errors raised by scipy itself
    propagate instead of being silently replaced by a different computation.
    """
    try:
        from scipy.stats import spearmanr
    except ImportError:
        return _spearman_torch(a, b)
    return float(spearmanr(a, b).correlation)


def _average_ranks(x):
    """1-based ranks of a 1-D float64 tensor; tied values share their mean rank."""
    order = x.argsort()
    # Group equal values in sorted order; `inverse` maps each sorted element to its group.
    _, inverse, counts = torch.unique_consecutive(
        x[order], return_inverse=True, return_counts=True)
    ends = counts.cumsum(0).to(torch.float64)           # last 1-based rank in each group
    starts = ends - counts.to(torch.float64) + 1.0      # first 1-based rank in each group
    ranks = torch.empty_like(x)
    ranks[order] = ((starts + ends) / 2.0)[inverse]
    return ranks


def _spearman_torch(a, b):
    """Dependency-free Spearman: Pearson correlation of tie-averaged ranks."""
    ra = _average_ranks(torch.tensor(a, dtype=torch.float64))
    rb = _average_ranks(torch.tensor(b, dtype=torch.float64))
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    denom = ra.norm() * rb.norm()
    if denom == 0:
        # Constant input has no rank variance: correlation is undefined (scipy gives NaN).
        return float("nan")
    return float((ra @ rb) / denom)
def encode(model, tok, seqs, device, max_length, batch_size):
    """Embed sequences as L2-normalised, mean-pooled last hidden states.

    Sequences are tokenised in batches of ``batch_size``, padded and truncated
    to ``max_length`` tokens. Each batch is run through ``model`` with
    ``output_hidden_states=True``. The last entry of ``hidden_states`` is pooled
    with ``mean_pool`` under the attention mask and L2-normalised along the
    last dim.

    Returns:
        A float32 CPU tensor formed by concatenating the per-batch pooled
        embeddings along dim 0. This is presumably one row per input sequence,
        assuming mean_pool returns (batch, hidden); confirm against
        src.bidir_progen.
    """
    embs = []
    # Inference only. model.eval() does not disable autograd, and nothing here
    # turns off requires_grad. Without no_grad every forward pass would be
    # recorded, and the .cpu() copies would keep each batch's graph (including
    # its device activations) alive, so memory would grow with len(seqs).
    with torch.no_grad():
        for start in range(0, len(seqs), batch_size):
            batch = seqs[start:start + batch_size]
            enc = tok(batch, padding=True, truncation=True,
                      max_length=max_length, return_tensors="pt").to(device)
            out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                        output_hidden_states=True)
            pooled = mean_pool(out.hidden_states[-1], enc["attention_mask"])
            embs.append(F.normalize(pooled.float(), dim=-1).cpu())
    return torch.cat(embs, 0)
def ridge_probe(xtr, ytr, xte, alpha):
    """Fit closed-form ridge regression on (xtr, ytr) and predict on xte.

    Solves ``w = (X'X + alpha * P)^-1 X'y``, where ``X`` is ``xtr`` with an
    appended column of ones (the intercept). ``P`` is the identity with a 0 in
    the intercept slot, so only the feature weights are shrunk and the
    intercept is left unregularised, as in standard ridge. Penalising the
    intercept would pull predictions toward 0 whenever the targets are not
    centred. No sklearn dependency.

    Args:
        xtr: 2-D training features, (n_train, d).
        ytr: sequence of n_train float targets.
        xte: 2-D test features, (n_test, d).
        alpha: L2 penalty on the feature weights.

    Returns:
        List of n_test float predictions.

    Raises:
        ValueError: if there are no training rows.
    """
    if xtr.size(0) == 0:
        raise ValueError("ridge_probe needs at least one training row")
    Xtr = torch.cat([xtr, xtr.new_ones(xtr.size(0), 1)], 1).double()
    Xte = torch.cat([xte, xte.new_ones(xte.size(0), 1)], 1).double()
    y = torch.tensor(ytr, dtype=torch.float64, device=Xtr.device).unsqueeze(1)
    penalty = torch.ones(Xtr.size(1), dtype=torch.float64, device=Xtr.device)
    penalty[-1] = 0.0  # do not shrink the intercept
    # With n_train > 0 and alpha > 0 the system is positive definite.
    A = Xtr.t() @ Xtr + alpha * torch.diag(penalty)
    w = torch.linalg.solve(A, Xtr.t() @ y)
    return (Xte @ w).squeeze(1).tolist()
def clf_probe(xtr, ytr, xte, alpha, num_classes):
    """Closed-form one-hot ridge classifier; returns the argmax class per test row.

    Solves ``W = (X'X + alpha * P)^-1 X'Y``, where ``X`` is ``xtr`` with an
    appended column of ones and ``Y`` is the one-hot encoding of ``ytr``. ``P``
    is the identity with a 0 in the intercept slot, so the per-class intercepts
    (which carry the class prior) are not shrunk. This is a legitimate linear
    probe, dependency-free and consistent with ridge_probe.

    Args:
        xtr: 2-D training features, (n_train, d).
        ytr: sequence of n_train integer labels, each in [0, num_classes).
        xte: 2-D test features, (n_test, d).
        alpha: L2 penalty on the feature weights.
        num_classes: number of one-hot columns.

    Returns:
        List of n_test predicted class indices.

    Raises:
        ValueError: if there are no training rows.
    """
    if xtr.size(0) == 0:
        raise ValueError("clf_probe needs at least one training row")
    Xtr = torch.cat([xtr, xtr.new_ones(xtr.size(0), 1)], 1).double()
    Xte = torch.cat([xte, xte.new_ones(xte.size(0), 1)], 1).double()
    labels = torch.tensor(ytr, dtype=torch.long, device=Xtr.device)
    Y = torch.zeros(Xtr.size(0), num_classes, dtype=torch.float64, device=Xtr.device)
    Y[torch.arange(Xtr.size(0), device=Xtr.device), labels] = 1.0
    penalty = torch.ones(Xtr.size(1), dtype=torch.float64, device=Xtr.device)
    penalty[-1] = 0.0  # do not shrink the intercepts
    A = Xtr.t() @ Xtr + alpha * torch.diag(penalty)
    W = torch.linalg.solve(A, Xtr.t() @ Y)  # (d + 1, num_classes)
    return (Xte @ W).argmax(1).tolist()
def pairwise_cos(xte, cap=2000):
    """Mean dot product over all ordered pairs of distinct rows among the first ``cap`` rows.

    When the rows are L2-normalised (as encode() returns them), this is the
    mean pairwise cosine similarity, a quick check for embedding collapse.
    Only the first ``cap`` rows are used, not a random sample, because the Gram
    matrix costs O(cap^2) memory.
    """
    sample = xte[:cap]
    n = sample.size(0)
    gram = sample @ sample.t()
    off_diagonal = ~torch.eye(n, dtype=torch.bool)
    return float(gram[off_diagonal].mean())
def build_base(model_name, device):
    """Load ``model_name`` as a causal LM, make its attention bidirectional, move it to ``device``.

    The checkpoint is loaded with ``trust_remote_code=True`` in bfloat16 with
    eager attention and then passed to ``make_bidirectional``. That function's
    return value is printed as the number of patched attention modules.
    """
    import transformers.modeling_utils as modeling_utils
    from transformers import AutoModelForCausalLM

    pretrained_cls = modeling_utils.PreTrainedModel
    # Compatibility shim: if PreTrainedModel does not itself define
    # `all_tied_weights_keys` (vars() ignores inherited attributes), give it an
    # empty dict.
    # NOTE(review): presumably some transformers code path reads this attribute
    # and the remote-code model does not set it; confirm against the pinned
    # transformers version.
    if "all_tied_weights_keys" not in vars(pretrained_cls):
        pretrained_cls.all_tied_weights_keys = {}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    patched = make_bidirectional(model)
    print(f"[build] bidirectional patch applied to {patched} attention modules", flush=True)
    return model.to(device)
def _pick(cols, candidates, what):
    """Return the first entry of ``candidates`` that appears in ``cols``.

    Raises:
        KeyError: naming ``what`` when none of the candidates is present.
    """
    missing = object()
    hit = next((name for name in candidates if name in cols), missing)
    if hit is missing:
        raise KeyError(f"no {what} column in {cols} (tried {candidates})")
    return hit
def load_split(dataset, split, cap, cast=float):
    """Load one split of a Hugging Face dataset as parallel (sequences, labels) lists.

    The sequence and label columns are the first names from SEQ_COLS /
    LABEL_COLS present in the dataset.

    If ``split`` fails to load and is "test" or "valid", the other of the two
    is tried. A warning is printed, because the caller then evaluates on a
    different split than it asked for. For any other split the original error
    is re-raised, since retrying the identical request cannot succeed.

    Args:
        dataset: Hugging Face dataset id.
        split: split name to load.
        cap: if positive and smaller than the split, keep a seeded (seed=0)
            shuffled subset of ``cap`` rows. 0, None or negative keeps the
            whole split.
        cast: applied to every label (e.g. ``int`` for classification).

    Returns:
        ``(sequences, labels)`` as two lists of equal length.
    """
    from datasets import load_dataset
    try:
        ds = load_dataset(dataset, split=split)
    except Exception as err:
        alt = {"test": "valid", "valid": "test"}.get(split)
        if alt is None:
            raise
        print(f"[data] WARNING: could not load split '{split}' of {dataset} "
              f"({type(err).__name__}: {err}); falling back to '{alt}'", flush=True)
        ds = load_dataset(dataset, split=alt)
    seq_c = _pick(ds.column_names, SEQ_COLS, "sequence")
    lab_c = _pick(ds.column_names, LABEL_COLS, "label")
    if cap is not None and 0 < cap < len(ds):
        ds = ds.shuffle(seed=0).select(range(cap))
    return list(ds[seq_c]), [cast(x) for x in ds[lab_c]]
def main():
    """Probe the frozen bidirectional encoder (and optionally a LoRA adapter) on the selected tasks.

    Steps:
      1. Parse CLI args and validate the requested task names.
      2. Load every task's train/test splits up front.
      3. Build the bidirectional base model.
      4. Score each task with a closed-form linear probe on frozen embeddings.
      5. If ``--adapter`` is given, repeat the scoring with the adapter applied.
      6. Print a per-task summary, including the trained-minus-baseline delta
         when an adapter was evaluated.

    Exits via ``sys.exit`` with a message if an unknown task name is requested.
    """
    args = parse_args()
    from transformers import AutoTokenizer
    # "a, b,,c" -> ["a", "b", "c"]: whitespace stripped, empty entries dropped.
    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    for t in tasks:
        if t not in TASKS:
            sys.exit(f"unknown task '{t}'; choose from {list(TASKS)}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    # encode() tokenises with padding=True, which needs a pad token. Fall back
    # to EOS, or to the token for id 0 when there is no EOS either.
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)
    # Load every task's splits up front so we encode each model once per task.
    data = {}
    for t in tasks:
        meta = TASKS[t]
        d, kind = meta["dataset"], meta["kind"]
        # Classification labels become int class indices; regression targets become floats.
        cast = int if kind == "clf" else float
        tr_seq, tr_y = load_split(d, "train", args.max_train, cast)
        te_seq, te_y = load_split(d, meta.get("test_split", "test"), args.max_test, cast)
        # Class count = largest label seen in train or test, plus one.
        # NOTE(review): assumes labels are 0-based class indices; confirm per dataset.
        nc = (max(tr_y + te_y) + 1) if kind == "clf" else None
        # Per-task record: (train seqs, train labels, test seqs, test labels, kind, num classes or None).
        data[t] = (tr_seq, tr_y, te_seq, te_y, kind, nc)
        extra = f" classes={nc}" if nc else ""
        print(f"[data] {t} ({kind}): train={len(tr_seq)} test={len(te_seq)}{extra}", flush=True)
    base = build_base(args.model_name, device)

    def run(model, label):
        """Encode every task with `model`, fit and score its probe, and return {task: (metric, score)}.

        Classification tasks are scored by top-1 accuracy and regression tasks
        by Spearman. ``label`` only prefixes the printed lines.
        """
        out = {}
        for t in tasks:
            tr_seq, tr_y, te_seq, te_y, kind, nc = data[t]
            xtr = encode(model, tok, tr_seq, device, args.max_length, args.batch_size)
            xte = encode(model, tok, te_seq, device, args.max_length, args.batch_size)
            if kind == "clf":
                pred = clf_probe(xtr, tr_y, xte, args.ridge_alpha, nc)
                # Fraction of test rows whose predicted class equals the label.
                score = float(sum(int(p == y) for p, y in zip(pred, te_y)) / len(te_y))
                metric = "Acc"
            else:
                pred = ridge_probe(xtr, tr_y, xte, args.ridge_alpha)
                score = spearman(pred, te_y)
                metric = "Spearman"
            # Also logs mean pairwise similarity of the test embeddings (a collapse check).
            print(f"[{label}] {t} {metric}={score:.4f} (probe on {len(tr_seq)} train, "
                  f"eval {len(te_seq)} test) | mean_pairwise_cos={pairwise_cos(xte):.4f}",
                  flush=True)
            out[t] = (metric, score)
        return out

    base.eval()
    # The baseline is scored before any adapter is loaded.
    # NOTE(review): PeftModel.from_pretrained presumably injects the LoRA layers
    # into `base` itself, so this ordering matters; confirm before reordering.
    base_s = run(base, "baseline (frozen, bidir)")
    trained_s = {}
    if args.adapter:
        from peft import PeftModel
        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained.eval()
        trained_s = run(trained, "trained (adapter)")
    print("\n=== SUMMARY ===", flush=True)
    for t in tasks:
        metric, b = base_s[t]
        line = f" {t:14s} [{metric}] baseline={b:.4f}"
        if t in trained_s:
            tr = trained_s[t][1]
            # Positive delta means the adapter scored higher than the frozen baseline.
            line += f" trained={tr:.4f} delta={tr - b:+.4f}"
        print(line, flush=True)


if __name__ == "__main__":
    main()