Feature Extraction
PEFT
Safetensors
protein
protein-language-model
embeddings
lora
llm2vec
progen2
bidirectional
Instructions to use ratishsp/progen2-base-bidirectional-llm2vec with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use ratishsp/progen2-base-bidirectional-llm2vec with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
File size: 9,716 Bytes
#!/usr/bin/env python
"""Token-level (per-residue) evaluation for the bidirectional ProGen2 encoder.
The per-residue counterpart to eval_protein.py: the only evals that exercise the
encoder's CONTEXTUAL TOKEN representations (the proposal's token-level promise)
rather than a pooled vector. ProGen2 tokenizes one amino acid -> one token
(verified), so token i aligns to residue i with no special-token offset.
Tasks (--task):
ss3 3-state secondary structure AI4Protein/ssp_q3 (HF) TAPE
ptm phosphosite PTM (binary) PhosphositePTM.*.csv ProteinBERT
disorder intrinsic disorder (binary) disorder_secondary_structure.* ProteinBERT
PTM and disorder are distributed only as CSVs in the ProteinBERT data repo
(github.com/Brandes-Lab/proteinbert_data_files), staged locally under --data-dir.
Their per-residue label is a contiguous digit string (one char per residue); SS3's
label is a numpy-printed array of space-separated ints.
Protocol: freeze encoder, take last-layer per-residue hidden states, fit a closed-form
one-hot ridge classifier over residues. Report per-residue accuracy; for the imbalanced
binary tasks (ptm, disorder) also report majority-class baseline and ROC-AUC (the honest
metric under class imbalance). Normal equations (A=X'X, B=X'Y) are accumulated
incrementally so residue features never all sit in memory at once.
Run on a GPU node in the pinned transformers-4.44.2 venv.
"""
from __future__ import annotations
import argparse
import os
import re
import sys
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.bidir_progen import make_bidirectional # noqa: E402
# Fields shared by the two ProteinBERT CSV tasks: both are binary per-residue
# labels stored as a digit string (one character per residue).
_PBERT_BINARY_CSV = {
    "source": "csv",
    "seq": "seq",
    "label": "label",
    "fmt": "chars",
    "num_classes": 2,
    "binary": True,
}

# Task registry: where each token-level task is loaded from, which columns hold
# the sequence and label, how the label text is parsed, and the class count.
TOKEN_TASKS = {
    "ss3": dict(
        source="hf",
        dataset="AI4Protein/ssp_q3",
        seq="aa_seq",
        label="label",
        fmt="tokens",
        num_classes=3,
        binary=False,
    ),
    "ptm": {**_PBERT_BINARY_CSV, "base": "PhosphositePTM"},
    "disorder": {**_PBERT_BINARY_CSV, "base": "disorder_secondary_structure"},
}
def parse_args():
    """Parse this script's command-line options.

    Returns the argparse namespace: model / adapter location, task choice,
    ProteinBERT data directory, tokenization and batching limits, train/test
    sequence caps, and the ridge penalty.
    """
    parser = argparse.ArgumentParser()
    # Model selection.
    parser.add_argument("--model-name", default="hugohrban/progen2-base")
    parser.add_argument("--adapter", default=None,
                        help="trained LoRA adapter dir (optional)")
    # Task and data location.
    parser.add_argument("--task", default="ss3", choices=list(TOKEN_TASKS))
    parser.add_argument("--data-dir", default="./pbert_data",
                        help="dir holding the ProteinBERT CSVs (for ptm/disorder)")
    # Integer knobs, registered in the same order as before.
    int_flags = (
        ("--max-length", 512, {}),
        ("--batch-size", 16, {}),
        ("--max-train", 6000, {"help": "cap train seqs (0 = full)"}),
        ("--max-test", 0, {"help": "cap test seqs (0 = full)"}),
    )
    for flag, default, extra in int_flags:
        parser.add_argument(flag, type=int, default=default, **extra)
    # Probe regularisation.
    parser.add_argument("--ridge-alpha", type=float, default=10.0)
    return parser.parse_args()
def parse_labels(s, fmt):
    """Convert a raw per-residue label field into a list of ints.

    fmt == "chars": the field is a digit string with one digit per residue;
    surrounding whitespace is stripped and non-digit characters are skipped.
    Any other fmt (used as "tokens"): every optionally-negative integer found
    in the text is returned in order, e.g. a numpy-printed "[0 1 2]".
    """
    text = str(s)
    if fmt != "chars":
        return list(map(int, re.findall(r"-?\d+", text)))
    return [int(ch) for ch in text.strip() if ch.isdigit()]
def load_split(cfg, split, cap, data_dir):
    """Load one split of a token-level task.

    Args:
        cfg: a TOKEN_TASKS entry. If ``cfg["source"] == "hf"`` the split is
            loaded with ``datasets.load_dataset(cfg["dataset"])``; otherwise
            it is read from ``<data_dir>/<cfg["base"]>.<split>.csv``.
        split: split name, e.g. "train" or "test".
        cap: maximum number of sequences to keep; 0/None keeps everything.
            When the split is larger, a seeded (seed 0) random subset of
            ``cap`` sequences is kept, for both sources.
        data_dir: directory holding the CSV files (unused for "hf").

    Returns:
        (list[str] seqs, list[list[int]] per-residue labels).
    """
    if cfg["source"] == "hf":
        from datasets import load_dataset
        ds = load_dataset(cfg["dataset"], split=split)
        if cap and len(ds) > cap:
            ds = ds.shuffle(seed=0).select(range(cap))
        seqs = list(ds[cfg["seq"]])
        labs = [parse_labels(x, cfg["fmt"]) for x in ds[cfg["label"]]]
        return seqs, labs

    import csv
    import random

    path = os.path.join(data_dir, f"{cfg['base']}.{split}.csv")
    # The csv module requires newline="" so quoted fields containing newlines
    # parse correctly; an explicit encoding avoids depending on the locale.
    with open(path, newline="", encoding="utf-8") as fh:
        rows = list(csv.DictReader(fh))
    if cap and len(rows) > cap:
        # Subsample like the HF branch instead of taking the first rows, which
        # would be biased if the file is ordered (e.g. by protein or length).
        # Sorting the sampled indices keeps the original file order.
        keep = sorted(random.Random(0).sample(range(len(rows)), cap))
        rows = [rows[i] for i in keep]
    seqs = [row[cfg["seq"]] for row in rows]
    labs = [parse_labels(row[cfg["label"]], cfg["fmt"]) for row in rows]
    return seqs, labs
def build_base(model_name, device):
    """Load the causal LM in bfloat16 with eager attention, apply the
    bidirectional patch, and return the model moved to ``device``.

    Prints how many attention modules ``make_bidirectional`` reports patching.
    """
    from transformers import AutoModelForCausalLM
    import transformers.modeling_utils as modeling_utils

    # Compatibility shim: when PreTrainedModel itself does not define
    # ``all_tied_weights_keys``, give it an empty class-level dict. It is set
    # before from_pretrained so the model being loaded sees it.
    # NOTE(review): presumably required by the remote-code model under this
    # pinned transformers version — confirm.
    pretrained_cls = modeling_utils.PreTrainedModel
    if "all_tied_weights_keys" not in vars(pretrained_cls):
        pretrained_cls.all_tied_weights_keys = {}

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )
    patched = make_bidirectional(model)
    print(f"[build] bidirectional patch applied to {patched} attention modules", flush=True)
    return model.to(device)
@torch.no_grad()
def residue_features(model, tok, seqs, labels, device, max_length, batch_size):
    """Yield last-layer per-residue features and their labels, batch by batch.

    Sequences are tokenized without special tokens, so real token i aligns
    with residue i. Real tokens are picked out with the attention mask, which
    keeps that alignment correct whichever side the tokenizer pads on. Each
    sequence is cut to the shorter of (kept tokens, labels). Residues whose
    label is negative are treated as unlabelled (e.g. an ignore marker) and
    dropped, so only labelled residues are yielded.

    Yields:
        (feats, labs): ``feats`` is an (R, H) float64 CPU tensor and ``labs``
        an (R,) long tensor for the R labelled residues of the batch. Batches
        without any labelled residue are skipped.
    """
    for start in range(0, len(seqs), batch_size):
        chunk_s = seqs[start:start + batch_size]
        chunk_l = labels[start:start + batch_size]
        enc = tok(chunk_s, padding=True, truncation=True, max_length=max_length,
                  add_special_tokens=False, return_tensors="pt").to(device)
        mask = enc["attention_mask"]
        h = model(input_ids=enc["input_ids"], attention_mask=mask,
                  output_hidden_states=True).hidden_states[-1]  # (B, T, H)
        feats, labs = [], []
        for b in range(h.size(0)):
            # Real (non-pad) tokens of row b, in sequence order.
            tok_h = h[b][mask[b].bool()]
            n = min(tok_h.size(0), len(chunk_l[b]))
            if n <= 0:
                continue
            y = torch.tensor(chunk_l[b][:n], dtype=torch.long)
            keep = y >= 0
            if not bool(keep.any()):
                continue
            feats.append(tok_h[:n][keep.to(tok_h.device)].float().cpu())
            labs.append(y[keep])
        if feats:
            yield torch.cat(feats, 0).double(), torch.cat(labs, 0)
def roc_auc(scores, labels):
    """Binary ROC-AUC from the rank-sum (Mann-Whitney U) statistic; no sklearn.

    Tied scores share the average of the ranks they span (mid-ranks). The
    result therefore does not depend on how the sort orders equal scores, and
    a tied positive/negative pair counts as 1/2, as in the standard AUC
    definition.

    Args:
        scores: array-like of real scores (higher means more positive).
        labels: array-like of labels, same length as ``scores``. Entries equal
            to 1 are positives and 0 are negatives; any other value is ignored.

    Returns:
        The AUC as a Python float, or NaN when either class is absent.
    """
    s = torch.as_tensor(scores, dtype=torch.float64).reshape(-1)
    y = torch.as_tensor(labels, dtype=torch.long).reshape(-1)
    is_pos = y == 1
    is_neg = y == 0
    n_pos = int(is_pos.sum())
    n_neg = int(is_neg.sum())
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    keep = is_pos | is_neg
    s, is_pos = s[keep], is_pos[keep]
    sorted_s, order = torch.sort(s)
    _, inverse, counts = torch.unique_consecutive(
        sorted_s, return_inverse=True, return_counts=True)
    # A run of c equal scores ending at 1-based rank e spans ranks e-c+1..e;
    # its mid-rank is e - (c - 1) / 2.
    ends = counts.cumsum(0).to(torch.float64)
    mid = ends - (counts.to(torch.float64) - 1) / 2
    ranks = torch.empty_like(s)
    ranks[order] = mid[inverse]
    sum_pos = float(ranks[is_pos].sum())
    return (sum_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
def _with_bias(feats):
    # Append a constant-1 column so the linear probe learns an intercept.
    return torch.cat([feats, torch.ones(feats.size(0), 1, dtype=torch.float64)], 1)


def _check_labels(labs, num_classes):
    # Reject labels the one-hot / bincount code cannot represent; a label of -1
    # would otherwise silently one-hot into the last class.
    if labs.numel() and (int(labs.min()) < 0 or int(labs.max()) >= num_classes):
        raise ValueError(
            f"per-residue labels must lie in [0, {num_classes}); "
            f"got range [{int(labs.min())}, {int(labs.max())}]"
        )


def _fit_ridge(batches, num_classes, alpha):
    # Closed-form one-hot ridge fit from streamed (feats, labs) batches.
    # Accumulates A = X'X and B = X'Y so features never all sit in memory.
    # The intercept column is not penalised, as in standard ridge regression.
    # Returns (W of shape (D+1, C), number of training residues).
    A = B = None
    n_res = 0
    for feats, labs in batches:
        _check_labels(labs, num_classes)
        X = _with_bias(feats)
        if A is None:
            d = X.size(1)
            A = torch.zeros(d, d, dtype=torch.float64)
            B = torch.zeros(d, num_classes, dtype=torch.float64)
        Y = torch.zeros(X.size(0), num_classes, dtype=torch.float64)
        Y[torch.arange(X.size(0)), labs] = 1.0
        A += X.t() @ X
        B += X.t() @ Y
        n_res += X.size(0)
    if A is None:
        raise ValueError("no labelled training residues to fit the probe on")
    penalty = alpha * torch.eye(A.size(0), dtype=torch.float64)
    penalty[-1, -1] = 0.0  # leave the intercept unregularised
    return torch.linalg.solve(A + penalty, B), n_res


def _score_probe(batches, W, num_classes, binary):
    # Apply the probe to streamed test batches. Returns
    # (correct, total, per-class label counts, binary scores, binary labels).
    correct = total = 0
    cls_count = torch.zeros(num_classes, dtype=torch.long)
    scores, ys = [], []
    for feats, labs in batches:
        _check_labels(labs, num_classes)
        logits = _with_bias(feats) @ W
        correct += int((logits.argmax(1) == labs).sum())
        total += labs.numel()
        cls_count += torch.bincount(labs, minlength=num_classes)
        if binary:
            # Positive-vs-negative margin used as the ranking score for AUC.
            scores.append((logits[:, 1] - logits[:, 0]).cpu())
            ys.append(labs.cpu())
    return correct, total, cls_count, scores, ys


def evaluate(tag, model, tok, cfg, tr, te, device, args):
    """Fit a ridge probe on frozen train-residue features and report test metrics.

    Args:
        tag: label printed at the start of the result line.
        model: encoder whose forward returns ``hidden_states``; set to eval mode here.
        tok: tokenizer forwarded to ``residue_features``.
        cfg: TOKEN_TASKS entry (reads ``num_classes`` and ``binary``).
        tr, te: (seqs, per-residue labels) for the train and test splits.
        device: device the model runs on.
        args: parsed CLI args (reads max_length, batch_size, ridge_alpha, task).

    Returns:
        (acc, auc): per-residue test accuracy. For binary tasks auc is the
        ROC-AUC (NaN when it cannot be computed); otherwise auc is None.

    Raises:
        ValueError: if there are no labelled training residues, or a label
            falls outside [0, num_classes).
    """
    model.eval()
    C = cfg["num_classes"]
    W, n_res = _fit_ridge(
        residue_features(model, tok, tr[0], tr[1], device,
                         args.max_length, args.batch_size),
        C, args.ridge_alpha)
    correct, total, cls_count, scores, ys = _score_probe(
        residue_features(model, tok, te[0], te[1], device,
                         args.max_length, args.batch_size),
        W, C, cfg["binary"])
    acc = correct / max(total, 1)
    auc = None
    if cfg["binary"]:
        # With no test residues there is nothing to rank; report NaN rather
        # than crash on torch.cat of an empty list.
        auc = roc_auc(torch.cat(scores), torch.cat(ys)) if scores else float("nan")
    msg = f"[{tag}] {args.task} acc={acc:.4f} (train_res={n_res}, eval_res={total})"
    if cfg["binary"]:
        majority = float(cls_count.max()) / max(total, 1)
        pos_frac = float(cls_count[1]) / max(total, 1)
        msg += f" | majority={majority:.4f} pos_frac={pos_frac:.4f} AUC={auc:.4f}"
    print(msg, flush=True)
    return acc, auc
def main():
    """Script entry point.

    Loads the task's train/test splits and the tokenizer, then probes the
    frozen bidirectional base model. If ``--adapter`` is given, it also probes
    the base with the LoRA adapter loaded. Finally it prints a summary line
    using AUC for binary tasks and accuracy otherwise.
    """
    args = parse_args()
    from transformers import AutoTokenizer
    cfg = TOKEN_TASKS[args.task]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tok = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token or tok.convert_ids_to_tokens(0)
    # Token i must line up with residue i: pad on the right so that each
    # sequence's real tokens occupy the leading positions of its row,
    # regardless of the tokenizer's default padding side.
    tok.padding_side = "right"
    tr = load_split(cfg, "train", args.max_train, args.data_dir)
    te = load_split(cfg, "test", args.max_test, args.data_dir)
    print(f"[data] {args.task} train_seqs={len(tr[0])} test_seqs={len(te[0])} "
          f"classes={cfg['num_classes']}", flush=True)
    base = build_base(args.model_name, device)
    base_acc, base_auc = evaluate("baseline (frozen, bidir)", base, tok, cfg, tr, te, device, args)
    trained_acc = trained_auc = None
    if args.adapter:
        from peft import PeftModel
        trained = PeftModel.from_pretrained(base, args.adapter).to(device)
        trained_acc, trained_auc = evaluate("trained (adapter)", trained, tok, cfg, tr, te,
                                            device, args)
    print(f"\n=== SUMMARY ({args.task}) ===", flush=True)
    metric, b, t = ("AUC", base_auc, trained_auc) if cfg["binary"] else ("acc", base_acc, trained_acc)
    line = f" {args.task:20s} [{metric}] baseline={b:.4f}"
    if t is not None:
        line += f" trained={t:.4f} delta={t - b:+.4f}"
    print(line, flush=True)


if __name__ == "__main__":
    main()
|