Initial upload: ESM-2 based stability predictor
Browse files- README.md +213 -0
- config.json +26 -0
- stability_predictor.pt +3 -0
- stability_predictor.py +244 -0
README.md
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- biology
|
| 5 |
+
- peptide
|
| 6 |
+
- protein
|
| 7 |
+
- stability
|
| 8 |
+
- esm2
|
| 9 |
+
- thermostability
|
| 10 |
+
- drug-discovery
|
| 11 |
+
- pytorch
|
| 12 |
+
language:
|
| 13 |
+
- en
|
| 14 |
+
library_name: pytorch
|
| 15 |
+
pipeline_tag: text-classification
|
| 16 |
+
datasets:
|
| 17 |
+
- FLIP
|
| 18 |
+
metrics:
|
| 19 |
+
- r2
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
# Peptide Stability Predictor
|
| 23 |
+
|
| 24 |
+
Predict thermal stability of peptide/protein sequences using ESM-2 embeddings.
|
| 25 |
+
|
| 26 |
+
## Model Description
|
| 27 |
+
|
| 28 |
+
This model predicts the thermal stability (melting temperature proxy) of peptide and protein sequences using frozen ESM-2 embeddings passed through a trained MLP regression head. It was trained on the FLIP Meltome benchmark dataset.
|
| 29 |
+
|
| 30 |
+
### Architecture
|
| 31 |
+
|
| 32 |
+
| Component | Details |
|
| 33 |
+
|-----------|---------|
|
| 34 |
+
| Backbone | ESM-2 (esm2_t6_8M_UR50D, 8M parameters, frozen) |
|
| 35 |
+
| Embedding dim | 320 |
|
| 36 |
+
| MLP Head | Linear(320→256) → ReLU → Dropout(0.1) → Linear(256→128) → ReLU → Dropout(0.1) → Linear(128→1) |
|
| 37 |
+
| Output | Normalized stability score |
|
| 38 |
+
|
| 39 |
+
### Training Details
|
| 40 |
+
|
| 41 |
+
| Property | Value |
|
| 42 |
+
|----------|-------|
|
| 43 |
+
| Dataset | FLIP Meltome benchmark |
|
| 44 |
+
| Validation R² | 0.616 |
|
| 45 |
+
| Epochs | 16 (early stopped from 30) |
|
| 46 |
+
| Learning rate | 1e-3 |
|
| 47 |
+
| Batch size | 8 |
|
| 48 |
+
| Dropout | 0.1 |
|
| 49 |
+
|
| 50 |
+
## Quick Start
|
| 51 |
+
|
| 52 |
+
### Requirements
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
pip install torch fair-esm huggingface_hub
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Usage
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
import torch
|
| 62 |
+
from huggingface_hub import hf_hub_download
|
| 63 |
+
|
| 64 |
+
# Download model checkpoint
|
| 65 |
+
checkpoint_path = hf_hub_download(
|
| 66 |
+
repo_id="littleworth/peptide-stability-predictor",
|
| 67 |
+
filename="stability_predictor.pt"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Load checkpoint
|
| 71 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
| 72 |
+
|
| 73 |
+
# Download model class
|
| 74 |
+
model_file = hf_hub_download(
|
| 75 |
+
repo_id="littleworth/peptide-stability-predictor",
|
| 76 |
+
filename="stability_predictor.py"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Import model class
|
| 80 |
+
import importlib.util
|
| 81 |
+
spec = importlib.util.spec_from_file_location("stability_predictor", model_file)
|
| 82 |
+
sp_module = importlib.util.module_from_spec(spec)
|
| 83 |
+
spec.loader.exec_module(sp_module)
|
| 84 |
+
StabilityPredictor = sp_module.StabilityPredictor
|
| 85 |
+
|
| 86 |
+
# Initialize model (this will download ESM-2 on first run)
|
| 87 |
+
model = StabilityPredictor(esm_model="esm2_t6_8M_UR50D")
|
| 88 |
+
|
| 89 |
+
# Load trained weights (only the MLP head, ESM-2 is frozen)
|
| 90 |
+
# Filter to only load head weights
|
| 91 |
+
head_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items()
|
| 92 |
+
if k.startswith('head.')}
|
| 93 |
+
model.head.load_state_dict({k.replace('head.', ''): v for k, v in head_state_dict.items()})
|
| 94 |
+
model.eval()
|
| 95 |
+
|
| 96 |
+
# Predict stability
|
| 97 |
+
sequences = [
|
| 98 |
+
"MKTLYFLGASV",
|
| 99 |
+
"AEITVKLSPGMNCF",
|
| 100 |
+
"GFLWKASTDERIPMNCVYH",
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 104 |
+
model = model.to(device)
|
| 105 |
+
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
scores = model(sequences)
|
| 108 |
+
|
| 109 |
+
print("Stability predictions:")
|
| 110 |
+
for seq, score in zip(sequences, scores.tolist()):
|
| 111 |
+
print(f" {seq}: {score:.4f}")
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
### Alternative: Using predict() method
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
# Using the convenience method (returns Python list)
|
| 118 |
+
scores = model.predict(sequences)
|
| 119 |
+
print(scores) # [0.7234, 0.6521, 0.5892]
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## Example Output
|
| 123 |
+
|
| 124 |
+
```
|
| 125 |
+
Stability predictions:
|
| 126 |
+
MKTLYFLGASV: 0.7234
|
| 127 |
+
AEITVKLSPGMNCF: 0.6521
|
| 128 |
+
GFLWKASTDERIPMNCVYH: 0.5892
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
## Files in This Repository
|
| 132 |
+
|
| 133 |
+
| File | Description |
|
| 134 |
+
|------|-------------|
|
| 135 |
+
| `stability_predictor.pt` | Model checkpoint (MLP head weights) |
|
| 136 |
+
| `stability_predictor.py` | Model architecture definition |
|
| 137 |
+
| `config.json` | Model configuration |
|
| 138 |
+
|
| 139 |
+
## Checkpoint Contents
|
| 140 |
+
|
| 141 |
+
```python
|
| 142 |
+
{
|
| 143 |
+
'epoch': 16,
|
| 144 |
+
'model_state_dict': {...}, # MLP head weights
|
| 145 |
+
'optimizer_state_dict': {...},
|
| 146 |
+
'val_r2': 0.616,
|
| 147 |
+
'config': {
|
| 148 |
+
'esm_model': 'esm2_t6_8M_UR50D',
|
| 149 |
+
'hidden_dims': [256, 128],
|
| 150 |
+
'dropout': 0.1
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
## Intended Use
|
| 156 |
+
|
| 157 |
+
- **Primary use**: Scoring peptide/protein stability for drug discovery
|
| 158 |
+
- **Secondary uses**:
|
| 159 |
+
- Filtering generated peptide candidates
|
| 160 |
+
- Research on protein thermostability
|
| 161 |
+
- Feature engineering for downstream ML models
|
| 162 |
+
|
| 163 |
+
## Limitations
|
| 164 |
+
|
| 165 |
+
- Trained on FLIP Meltome data which may not generalize to all protein families
|
| 166 |
+
- Outputs normalized scores, not absolute melting temperatures
|
| 167 |
+
- Predictions are computational estimates requiring experimental validation
|
| 168 |
+
- Best accuracy for sequences similar to training distribution
|
| 169 |
+
|
| 170 |
+
## Performance
|
| 171 |
+
|
| 172 |
+
| Metric | Value |
|
| 173 |
+
|--------|-------|
|
| 174 |
+
| Validation R² | 0.616 |
|
| 175 |
+
| Training epochs | 16 |
|
| 176 |
+
| Early stopping patience | 15 |
|
| 177 |
+
|
| 178 |
+
## Dependencies
|
| 179 |
+
|
| 180 |
+
- PyTorch >= 2.0
|
| 181 |
+
- fair-esm (Facebook's ESM library)
|
| 182 |
+
- huggingface_hub
|
| 183 |
+
|
| 184 |
+
## Ethical Considerations
|
| 185 |
+
|
| 186 |
+
This model provides computational predictions of protein stability. Predictions should be validated experimentally before making decisions about therapeutic development. The model does not guarantee accuracy for sequences outside its training distribution.
|
| 187 |
+
|
| 188 |
+
## Training Data
|
| 189 |
+
|
| 190 |
+
- **FLIP Meltome benchmark**: A dataset of protein sequences with measured thermal stability values
|
| 191 |
+
- Training/validation split following FLIP benchmark protocols
|
| 192 |
+
|
| 193 |
+
## Citation
|
| 194 |
+
|
| 195 |
+
```bibtex
|
| 196 |
+
@software{peptide_stability_2025,
|
| 197 |
+
author = {Wijaya, Edward},
|
| 198 |
+
title = {Peptide Stability Predictor},
|
| 199 |
+
year = {2025},
|
| 200 |
+
url = {https://huggingface.co/littleworth/peptide-stability-predictor},
|
| 201 |
+
note = {ESM-2 based thermal stability prediction}
|
| 202 |
+
}
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## References
|
| 206 |
+
|
| 207 |
+
- [FLIP Benchmark](https://github.com/J-SNACKKB/FLIP) - Dallago et al., 2021
|
| 208 |
+
- [ESM-2](https://github.com/facebookresearch/esm) - Lin et al., 2022
|
| 209 |
+
- [ESM-2 Paper](https://www.science.org/doi/10.1126/science.ade2574) - Lin et al., Science 2023
|
| 210 |
+
|
| 211 |
+
## License
|
| 212 |
+
|
| 213 |
+
MIT License
|
config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "stability_predictor",
|
| 3 |
+
"architecture": "esm2_mlp",
|
| 4 |
+
"esm_model": "esm2_t6_8M_UR50D",
|
| 5 |
+
"esm_params": 8000000,
|
| 6 |
+
"embed_dim": 320,
|
| 7 |
+
"repr_layer": 6,
|
| 8 |
+
"freeze_esm": true,
|
| 9 |
+
"head": {
|
| 10 |
+
"hidden_dims": [256, 128],
|
| 11 |
+
"dropout": 0.1,
|
| 12 |
+
"activation": "relu"
|
| 13 |
+
},
|
| 14 |
+
"training": {
|
| 15 |
+
"dataset": "FLIP_meltome",
|
| 16 |
+
"epochs": 16,
|
| 17 |
+
"learning_rate": 0.001,
|
| 18 |
+
"batch_size": 8,
|
| 19 |
+
"early_stopping_patience": 15,
|
| 20 |
+
"validation_r2": 0.616
|
| 21 |
+
},
|
| 22 |
+
"output": {
|
| 23 |
+
"type": "regression",
|
| 24 |
+
"description": "Normalized thermal stability score"
|
| 25 |
+
}
|
| 26 |
+
}
|
stability_predictor.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:67d1eb517578cf141bd113b1491f61176ff0beb63181d5879f2490d220c534d4
|
| 3 |
+
size 31483165
|
stability_predictor.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ESM-2 based stability predictor for peptide/protein sequences.
|
| 2 |
+
|
| 3 |
+
This module implements a stability predictor using ESM-2 embeddings as input
|
| 4 |
+
to an MLP regression head. The model predicts thermal stability (melting
|
| 5 |
+
temperature) based on sequence information.
|
| 6 |
+
|
| 7 |
+
Architecture:
|
| 8 |
+
Input: Peptide/protein sequence
|
| 9 |
+
↓
|
| 10 |
+
ESM-2 (frozen): Extract mean-pooled embeddings
|
| 11 |
+
↓
|
| 12 |
+
MLP: embedding_dim → hidden_dims → 1
|
| 13 |
+
↓
|
| 14 |
+
Output: Stability score (normalized)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
from typing import List, Optional, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class StabilityPredictor(nn.Module):
|
| 27 |
+
"""ESM-2 based stability predictor.
|
| 28 |
+
|
| 29 |
+
Uses frozen ESM-2 embeddings as input to an MLP head for predicting
|
| 30 |
+
thermal stability. The model is designed to be trained on datasets
|
| 31 |
+
like FLIP stability (meltome) task.
|
| 32 |
+
|
| 33 |
+
Attributes:
|
| 34 |
+
esm: ESM-2 language model (frozen)
|
| 35 |
+
alphabet: ESM-2 tokenizer
|
| 36 |
+
head: MLP regression head
|
| 37 |
+
embed_dim: Dimension of ESM-2 embeddings
|
| 38 |
+
repr_layer: Which layer to extract representations from
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
esm_model: str = "esm2_t6_8M_UR50D",
|
| 44 |
+
hidden_dims: Optional[List[int]] = None,
|
| 45 |
+
dropout: float = 0.1,
|
| 46 |
+
freeze_esm: bool = True,
|
| 47 |
+
device: Optional[str] = None,
|
| 48 |
+
):
|
| 49 |
+
"""Initialize stability predictor.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
esm_model: Name of ESM-2 model to use. Options:
|
| 53 |
+
- esm2_t6_8M_UR50D (8M params, 320 dim, fastest)
|
| 54 |
+
- esm2_t12_35M_UR50D (35M params, 480 dim)
|
| 55 |
+
- esm2_t33_650M_UR50D (650M params, 1280 dim, most accurate)
|
| 56 |
+
hidden_dims: Hidden layer dimensions for MLP head.
|
| 57 |
+
Default: [256, 128]
|
| 58 |
+
dropout: Dropout rate for MLP layers
|
| 59 |
+
freeze_esm: Whether to freeze ESM-2 parameters
|
| 60 |
+
device: Device to load model on. Auto-detected if None.
|
| 61 |
+
"""
|
| 62 |
+
super().__init__()
|
| 63 |
+
|
| 64 |
+
if hidden_dims is None:
|
| 65 |
+
hidden_dims = [256, 128]
|
| 66 |
+
|
| 67 |
+
self.esm_model_name = esm_model
|
| 68 |
+
self.freeze_esm = freeze_esm
|
| 69 |
+
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
| 70 |
+
|
| 71 |
+
# Load ESM-2
|
| 72 |
+
self._load_esm(esm_model)
|
| 73 |
+
|
| 74 |
+
if freeze_esm:
|
| 75 |
+
for param in self.esm.parameters():
|
| 76 |
+
param.requires_grad = False
|
| 77 |
+
self.esm.eval()
|
| 78 |
+
|
| 79 |
+
# Build MLP head
|
| 80 |
+
layers = []
|
| 81 |
+
in_dim = self.embed_dim
|
| 82 |
+
for h_dim in hidden_dims:
|
| 83 |
+
layers.extend([
|
| 84 |
+
nn.Linear(in_dim, h_dim),
|
| 85 |
+
nn.ReLU(),
|
| 86 |
+
nn.Dropout(dropout),
|
| 87 |
+
])
|
| 88 |
+
in_dim = h_dim
|
| 89 |
+
layers.append(nn.Linear(in_dim, 1))
|
| 90 |
+
|
| 91 |
+
self.head = nn.Sequential(*layers)
|
| 92 |
+
|
| 93 |
+
logger.info(f"StabilityPredictor initialized with {esm_model}, "
|
| 94 |
+
f"hidden_dims={hidden_dims}, freeze_esm={freeze_esm}")
|
| 95 |
+
|
| 96 |
+
def _load_esm(self, esm_model: str):
|
| 97 |
+
"""Load ESM-2 model and set embedding dimensions."""
|
| 98 |
+
import esm
|
| 99 |
+
|
| 100 |
+
logger.info(f"Loading ESM-2 model: {esm_model}")
|
| 101 |
+
|
| 102 |
+
if esm_model == "esm2_t6_8M_UR50D":
|
| 103 |
+
self.esm, self.alphabet = esm.pretrained.esm2_t6_8M_UR50D()
|
| 104 |
+
self.embed_dim = 320
|
| 105 |
+
self.repr_layer = 6
|
| 106 |
+
elif esm_model == "esm2_t12_35M_UR50D":
|
| 107 |
+
self.esm, self.alphabet = esm.pretrained.esm2_t12_35M_UR50D()
|
| 108 |
+
self.embed_dim = 480
|
| 109 |
+
self.repr_layer = 12
|
| 110 |
+
elif esm_model == "esm2_t33_650M_UR50D":
|
| 111 |
+
self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
|
| 112 |
+
self.embed_dim = 1280
|
| 113 |
+
self.repr_layer = 33
|
| 114 |
+
else:
|
| 115 |
+
raise ValueError(f"Unknown ESM model: {esm_model}")
|
| 116 |
+
|
| 117 |
+
self.batch_converter = self.alphabet.get_batch_converter()
|
| 118 |
+
|
| 119 |
+
def get_embeddings(self, sequences: List[str]) -> torch.Tensor:
|
| 120 |
+
"""Extract ESM-2 embeddings for sequences.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
sequences: List of amino acid sequences
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Tensor of shape (batch_size, embed_dim) with mean-pooled embeddings
|
| 127 |
+
"""
|
| 128 |
+
# Prepare data for ESM
|
| 129 |
+
data = [(f"seq{i}", seq) for i, seq in enumerate(sequences)]
|
| 130 |
+
_, _, batch_tokens = self.batch_converter(data)
|
| 131 |
+
batch_tokens = batch_tokens.to(next(self.esm.parameters()).device)
|
| 132 |
+
|
| 133 |
+
# Forward pass through ESM-2
|
| 134 |
+
with torch.no_grad() if self.freeze_esm else torch.enable_grad():
|
| 135 |
+
results = self.esm(
|
| 136 |
+
batch_tokens,
|
| 137 |
+
repr_layers=[self.repr_layer],
|
| 138 |
+
return_contacts=False
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Mean pool over sequence positions (excluding BOS and EOS tokens)
|
| 142 |
+
embeddings = []
|
| 143 |
+
for i, seq in enumerate(sequences):
|
| 144 |
+
seq_len = len(seq)
|
| 145 |
+
# Tokens are: [BOS, seq..., EOS, PAD...]
|
| 146 |
+
# We want indices 1 to seq_len+1 (exclusive of EOS)
|
| 147 |
+
emb = results["representations"][self.repr_layer][i, 1:seq_len+1, :]
|
| 148 |
+
embeddings.append(emb.mean(dim=0))
|
| 149 |
+
|
| 150 |
+
return torch.stack(embeddings)
|
| 151 |
+
|
| 152 |
+
def forward(self, sequences: Union[str, List[str]]) -> torch.Tensor:
|
| 153 |
+
"""Predict stability for sequences.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
sequences: Single sequence or list of sequences
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Tensor of shape (batch_size,) with stability predictions
|
| 160 |
+
"""
|
| 161 |
+
if isinstance(sequences, str):
|
| 162 |
+
sequences = [sequences]
|
| 163 |
+
|
| 164 |
+
embeddings = self.get_embeddings(sequences)
|
| 165 |
+
predictions = self.head(embeddings).squeeze(-1)
|
| 166 |
+
|
| 167 |
+
return predictions
|
| 168 |
+
|
| 169 |
+
def predict(self, sequences: Union[str, List[str]]) -> List[float]:
|
| 170 |
+
"""Predict stability scores (convenience method).
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
sequences: Single sequence or list of sequences
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
List of stability scores
|
| 177 |
+
"""
|
| 178 |
+
self.eval()
|
| 179 |
+
with torch.no_grad():
|
| 180 |
+
preds = self.forward(sequences)
|
| 181 |
+
return preds.cpu().tolist()
|
| 182 |
+
|
| 183 |
+
def to(self, device: Union[str, torch.device]) -> 'StabilityPredictor':
|
| 184 |
+
"""Move model to device."""
|
| 185 |
+
self.device = str(device)
|
| 186 |
+
self.esm = self.esm.to(device)
|
| 187 |
+
self.head = self.head.to(device)
|
| 188 |
+
return super().to(device)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class BindingPredictor(StabilityPredictor):
|
| 192 |
+
"""ESM-2 based binding predictor.
|
| 193 |
+
|
| 194 |
+
Same architecture as StabilityPredictor but intended for binding
|
| 195 |
+
affinity prediction. Currently only supports binary classification
|
| 196 |
+
(binder vs non-binder) due to Propedia dataset limitations.
|
| 197 |
+
|
| 198 |
+
For regression tasks, additional data with continuous binding affinities
|
| 199 |
+
(e.g., from PDBbind) would be needed.
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(
|
| 203 |
+
self,
|
| 204 |
+
esm_model: str = "esm2_t6_8M_UR50D",
|
| 205 |
+
hidden_dims: Optional[List[int]] = None,
|
| 206 |
+
dropout: float = 0.1,
|
| 207 |
+
freeze_esm: bool = True,
|
| 208 |
+
device: Optional[str] = None,
|
| 209 |
+
use_sigmoid: bool = True,
|
| 210 |
+
):
|
| 211 |
+
"""Initialize binding predictor.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
esm_model: Name of ESM-2 model to use
|
| 215 |
+
hidden_dims: Hidden layer dimensions for MLP head
|
| 216 |
+
dropout: Dropout rate
|
| 217 |
+
freeze_esm: Whether to freeze ESM-2
|
| 218 |
+
device: Device to load model on
|
| 219 |
+
use_sigmoid: Whether to apply sigmoid for binary classification
|
| 220 |
+
"""
|
| 221 |
+
super().__init__(
|
| 222 |
+
esm_model=esm_model,
|
| 223 |
+
hidden_dims=hidden_dims,
|
| 224 |
+
dropout=dropout,
|
| 225 |
+
freeze_esm=freeze_esm,
|
| 226 |
+
device=device,
|
| 227 |
+
)
|
| 228 |
+
self.use_sigmoid = use_sigmoid
|
| 229 |
+
logger.info(f"BindingPredictor initialized, use_sigmoid={use_sigmoid}")
|
| 230 |
+
|
| 231 |
+
def forward(self, sequences: Union[str, List[str]]) -> torch.Tensor:
|
| 232 |
+
"""Predict binding score for sequences.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
sequences: Single sequence or list of sequences
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Tensor of shape (batch_size,) with binding predictions.
|
| 239 |
+
If use_sigmoid=True, values are in [0, 1].
|
| 240 |
+
"""
|
| 241 |
+
preds = super().forward(sequences)
|
| 242 |
+
if self.use_sigmoid:
|
| 243 |
+
preds = torch.sigmoid(preds)
|
| 244 |
+
return preds
|