---
license: bsd-3-clause
base_model: hugohrban/progen2-base
library_name: peft
pipeline_tag: feature-extraction
tags:
- protein
- protein-language-model
- embeddings
- lora
- peft
- llm2vec
- progen2
- bidirectional
---
# ProGen2-base → Bidirectional Protein Encoder (LLM2Vec recipe)
A LoRA adapter that converts the generative (decoder-only) protein language model
[`hugohrban/progen2-base`](https://huggingface.co/hugohrban/progen2-base) into a
bidirectional sequence encoder for producing fixed-length protein embeddings,
following the [LLM2Vec](https://arxiv.org/abs/2404.05961) recipe adapted to proteins.
The adaptation has three ingredients:
1. **Bidirectional attention**: the causal triangular attention mask is removed so every
residue attends to the whole sequence.
2. **Two-stage training**: masked next-token prediction (MNTP) first, then a separate
SimCSE contrastive stage that continues training the MNTP adapter. In SimCSE, the same
sequence encoded twice under dropout forms a positive pair.
3. **LoRA**: only low-rank adapters are trained; the base weights are frozen.
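To make the contrastive stage (ingredient 2) concrete, here is a minimal NumPy sketch of a SimCSE-style InfoNCE loss. It is an illustration only, not the code in this repo: the function name `simcse_loss` is hypothetical, random vectors stand in for encoder outputs, and Gaussian noise stands in for dropout. The actual training loop may differ in batching and details.

```python
import numpy as np

def simcse_loss(z1, z2, temperature=0.10):
    """InfoNCE loss: z1[i] and z2[i] are two noisy views of the same sequence i."""
    # L2-normalise so the dot product is a cosine similarity.
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature  # (batch, batch) similarity matrix
    # Row-wise log-softmax, stabilised by subtracting the row maximum.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # The positive for row i is on the diagonal; other columns are in-batch negatives.
    return float(-np.mean(np.diag(log_probs)))

rng = np.random.default_rng(0)
base = rng.normal(size=(8, 32))  # stand-in for 8 pooled sequence embeddings
# Assumption: small Gaussian perturbations mimic two dropout passes.
view_a = base + 0.05 * rng.normal(size=base.shape)
view_b = base + 0.05 * rng.normal(size=base.shape)

aligned = simcse_loss(view_a, view_b)                 # positives on the diagonal
shuffled = simcse_loss(view_a, view_b[::-1].copy())   # positives deliberately misplaced
```

The loss is low when each row's own second view is its nearest neighbour in the batch and high when the pairing is broken, which is the signal the contrastive stage optimises.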
## Results: full benchmark (frozen encoder + linear probe)
Evaluated on all 9 tasks from the [TAPE](https://github.com/songlab-cal/tape) (Rao et al.
2019) and [ProteinBERT](https://github.com/nadavbra/protein_bert) (Brandes et al. 2022)
benchmark suites. Protocol: freeze the encoder, take a fixed representation (mean-pooled for
sequence-level tasks, per-residue for token-level), fit a closed-form linear probe on the
train split, and score the test split. "Baseline" = the same bidirectional ProGen2 without
this adapter.
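As a rough sketch of what a closed-form linear probe on frozen features can look like, the example below fits ridge regression via the normal equations on synthetic data. The small L2 term, the function names `fit_linear_probe` and `predict`, and the random features are all assumptions made for illustration; the evaluation scripts in this repo may use a different solver or regulariser.

```python
import numpy as np

def fit_linear_probe(X, y, l2=1e-3):
    """Closed-form ridge regression with a bias term: w = (X^T X + l2*I)^-1 X^T y."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
    reg = l2 * np.eye(Xb.shape[1])
    reg[-1, -1] = 0.0                              # do not penalise the bias
    return np.linalg.solve(Xb.T @ Xb + reg, Xb.T @ y)

def predict(X, w):
    return np.hstack([X, np.ones((X.shape[0], 1))]) @ w

# Synthetic stand-ins for frozen, mean-pooled embeddings and a regression target.
rng = np.random.default_rng(0)
true_w = rng.normal(size=16)
X_train = rng.normal(size=(200, 16))
X_test = rng.normal(size=(50, 16))
y_train = X_train @ true_w + 0.5 + 0.01 * rng.normal(size=200)
y_test = X_test @ true_w + 0.5

w = fit_linear_probe(X_train, y_train)  # fit on the train split only
test_mse = float(np.mean((predict(X_test, w) - y_test) ** 2))
```

Because the encoder is never updated, any difference in probe score between two encoders reflects the quality of their representations rather than the capacity of the probe.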
**Validation vs. held-out.** One of the nine tasks, Stability (a protein-stability
regression benchmark from TAPE), was used to pick the recipe: across a 10-config sweep, the
hyperparameters scoring highest on Stability were kept. The other 8 tasks were then evaluated
once on that fixed configuration, so their numbers are held-out generalization, not values
the recipe was tuned to maximize.
**Sequence-level**
| Task | metric | baseline | this adapter | Δ |
|------|--------|----------|------------------|------|
| Stability | Spearman ρ | 0.403 | 0.661 | +0.258 |
| Fluorescence | Spearman ρ | 0.157 | 0.367 | +0.210 |
| Remote homology | accuracy | 0.063 | 0.102 | +0.039 |
| Fold class. | accuracy | 0.099 | 0.210 | +0.112 |
| Signal peptide | accuracy | 0.838 | 0.931 | +0.093 |
| Neuropeptide | accuracy | 0.691 | 0.932 | +0.240 |
**Token-level (per-residue)**
| Task | metric | baseline | this adapter | Δ |
|------|--------|----------|------------------|------|
| Secondary structure (SS3) | Q3 accuracy | 0.535 | 0.590 | +0.056 |
| PTM (phosphosite) | ROC-AUC | 0.908 | 0.920 | +0.012 |
| Disorder | ROC-AUC | 0.745 | 0.822 | +0.077 |
The adapter improves the frozen representation on every task, at both the sequence and
per-residue level. PTM and disorder are reported as ROC-AUC because they are heavily
imbalanced (majority class ≈ 0.98), making raw accuracy uninformative. Absolute values are
those of a *frozen linear probe* (an embedding-quality measurement), not task-specific
fine-tuning.
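The point about imbalance can be checked with a toy example: on a label set that is 98% negative, a predictor that always outputs the majority class scores 0.98 accuracy while carrying no ranking information, so its ROC-AUC is 0.5. The helper functions and the toy labels below are illustrative and are not taken from the benchmark code.

```python
def accuracy(labels, preds):
    return sum(int(l == p) for l, p in zip(labels, preds)) / len(labels)

def roc_auc(labels, scores):
    """Probability that a random positive outscores a random negative (ties count 0.5)."""
    pos = [s for l, s in zip(labels, scores) if l == 1]
    neg = [s for l, s in zip(labels, scores) if l == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

labels = [1] * 2 + [0] * 98      # toy data: 2% positives, 98% negatives
constant_scores = [0.0] * 100    # a "model" that gives every residue the same score

majority_acc = accuracy(labels, [0] * 100)       # 0.98
majority_auc = roc_auc(labels, constant_scores)  # 0.5
```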
## Usage
```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE = "hugohrban/progen2-base"
ADAPTER = "ratishsp/progen2-base-bidirectional-llm2vec"  # this repo

tok = AutoTokenizer.from_pretrained(BASE, trust_remote_code=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

model = AutoModelForCausalLM.from_pretrained(
    BASE, trust_remote_code=True, torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # required: the bidirectional mask edit needs eager attention
)

# 1) make attention bidirectional: flip each attention module's causal `bias` buffer to all-True
for m in model.modules():
    b = getattr(m, "bias", None)
    if isinstance(b, torch.Tensor) and b.dtype == torch.bool:
        b.fill_(True)

# 2) load the LoRA adapter
model = PeftModel.from_pretrained(model, ADAPTER)
model.eval().cuda()

@torch.no_grad()
def embed(seqs):
    enc = tok(seqs, padding=True, truncation=True, max_length=512, return_tensors="pt").to("cuda")
    h = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask,
              output_hidden_states=True).hidden_states[-1]
    mask = enc.attention_mask.unsqueeze(-1).float()
    pooled = (h * mask).sum(1) / mask.sum(1).clamp(min=1)  # mean pool over non-padding tokens
    return F.normalize(pooled.float(), dim=-1)

emb = embed(["MKTAYIAKQR", "MVLSPADKTNVKAAW"])
```
The full, tested implementation is bundled in this repo under [`code/`](./code): training
(`pretrain.py`), the bidirectional conversion and objectives (`src/bidir_progen.py`:
`make_bidirectional`, `mean_pool`), and all benchmark evals (`eval_protein.py`, `eval_token.py`).
## Training details
| Setting | Value |
|---|---|
| Base model | `hugohrban/progen2-base` (764M, BSD-3-Clause) |
| Pretraining data | UniRef50 (`agemagician/uniref50`), streamed |
| Stage 1 (MNTP) | 2,000 steps, max-len 512 |
| Stage 2 (SimCSE) | 2,000 steps, temperature 0.10, dropout 0.1, max-len 256 |
| LoRA | r=16, α=32, targets: `qkv_proj`, `out_proj`, `fc_in`, `fc_out` |
| Optimizer | AdamW, lr 1e-4, warmup 50 |
| Precision | bf16 |
This was the best of a 10-configuration sweep. Key findings: the MNTP stage is necessary
(SimCSE-only transfers far worse), SimCSE temperature 0.10 beat 0.05/0.02, and the short
recipe outperformed longer / larger-data runs.
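For readers unfamiliar with the LoRA settings in the table, the following NumPy sketch shows what `r=16` and `α=32` mean for a single linear layer under the standard LoRA formulation: the frozen weight `W` receives a rank-`r` correction `B @ A` scaled by `α / r`. The layer sizes are toy values chosen for illustration, not ProGen2's real dimensions, and `lora_forward` is a hypothetical helper rather than a PEFT function.

```python
import numpy as np

r, alpha = 16, 32       # LoRA rank and scaling from the training table
d_in, d_out = 64, 48    # toy layer sizes (assumption, not the real model dims)

rng = np.random.default_rng(0)
W = rng.normal(size=(d_out, d_in))          # frozen base weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable, small random init
B = np.zeros((d_out, r))                    # trainable, zero init

def lora_forward(x, W, A, B, r, alpha):
    # Base projection plus the low-rank correction scaled by alpha / r.
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

x = rng.normal(size=(4, d_in))
# With B at zero, the adapted layer reproduces the base layer exactly.
unchanged_at_init = np.allclose(lora_forward(x, W, A, B, r, alpha), x @ W.T)

B = rng.normal(scale=0.01, size=(d_out, r))  # pretend training has updated B
delta = (alpha / r) * B @ A                  # rank <= r update to W
merged_matches = np.allclose(lora_forward(x, W, A, B, r, alpha), x @ (W + delta).T)

trainable, frozen = A.size + B.size, W.size  # only A and B receive gradients
```

The last equality is why an adapter can be merged into the base weights after training without changing its outputs.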
## Limitations
- Numbers come from a frozen linear probe, not task fine-tuning; they measure embedding
quality, so absolute values trail task-specific SOTA (e.g. SS3 Q3 ≈ 0.59 vs ~0.8+ for
dedicated predictors).
- Sequence-level uses a single mean-pooled vector; no specialised pooling head.
- No per-task hyperparameter tuning: the recipe was tuned once, on Stability (see Results).
## License & provenance
- Adapter weights: BSD-3-Clause, matching the `progen2-base` base model.
- Bundled code (`code/`): MIT; see [`code/LICENSE`](./code/LICENSE).
Research artifact produced on the [DCAI Gefion](https://dcai.dk/) cluster. The base
model and all benchmark datasets (TAPE via AI4Protein/GleghornLab mirrors; ProteinBERT
signal-peptide/neuropeptide via GrimSqueaker mirrors; PTM/disorder from the ProteinBERT data
repo) are publicly available, so the results are reproducible from public sources.
### Framework versions
- PEFT 0.12.0
- transformers 4.44.2