littleworth committed
Commit f17375e · verified · 1 Parent(s): 247b6ba

Initial upload: ESM-2 based stability predictor

Files changed (4)
  1. README.md +213 -0
  2. config.json +26 -0
  3. stability_predictor.pt +3 -0
  4. stability_predictor.py +244 -0
README.md ADDED
@@ -0,0 +1,213 @@
---
license: mit
tags:
- biology
- peptide
- protein
- stability
- esm2
- thermostability
- drug-discovery
- pytorch
language:
- en
library_name: pytorch
pipeline_tag: text-classification
datasets:
- FLIP
metrics:
- r2
---

# Peptide Stability Predictor

Predict the thermal stability of peptide/protein sequences using ESM-2 embeddings.

## Model Description

This model predicts the thermal stability (a melting-temperature proxy) of peptide and protein sequences using frozen ESM-2 embeddings passed through a trained MLP regression head. It was trained on the FLIP Meltome benchmark dataset.

### Architecture

| Component | Details |
|-----------|---------|
| Backbone | ESM-2 (esm2_t6_8M_UR50D, 8M parameters, frozen) |
| Embedding dim | 320 |
| MLP Head | Linear(320→256) → ReLU → Dropout(0.1) → Linear(256→128) → ReLU → Dropout(0.1) → Linear(128→1) |
| Output | Normalized stability score |

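For reference, the MLP head in the table above is a small feed-forward network. A minimal PyTorch sketch of that head (the actual implementation ships in `stability_predictor.py` in this repository):

```python
import torch.nn as nn

# Regression head from the architecture table:
# 320 -> 256 -> 128 -> 1, with ReLU and Dropout(0.1) after each hidden layer.
head = nn.Sequential(
    nn.Linear(320, 256), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(128, 1),
)
```
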
### Training Details

| Property | Value |
|----------|-------|
| Dataset | FLIP Meltome benchmark |
| Validation R² | 0.616 |
| Epochs | 16 (early stopped from 30) |
| Learning rate | 1e-3 |
| Batch size | 8 |
| Dropout | 0.1 |

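The training script is not part of this repository, so the following is only an illustrative sketch of how the head could be fit with the hyperparameters above (frozen ESM-2, Adam at 1e-3, MSE loss, early stopping); the data loaders and the exact stopping criterion are assumptions:

```python
import torch
import torch.nn as nn

def train_head(model, train_loader, val_loader, epochs=30, lr=1e-3, patience=15):
    """Train only the MLP head; the ESM-2 backbone stays frozen."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.head.parameters(), lr=lr)
    criterion = nn.MSELoss()

    best_val, bad_epochs = float("inf"), 0
    for epoch in range(epochs):
        model.head.train()
        # hypothetical loader yielding (sequences, normalized stability targets)
        for sequences, targets in train_loader:
            optimizer.zero_grad()
            preds = model(list(sequences))
            loss = criterion(preds, targets.to(device).float())
            loss.backward()
            optimizer.step()

        # Early stopping on validation loss here; the released checkpoint reports validation R².
        model.head.eval()
        with torch.no_grad():
            val_loss = sum(
                criterion(model(list(s)), t.to(device).float()).item()
                for s, t in val_loader
            ) / max(len(val_loader), 1)
        if val_loss < best_val:
            best_val, bad_epochs = val_loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
```
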
## Quick Start

### Requirements

```bash
pip install torch fair-esm huggingface_hub
```

### Usage

```python
import torch
from huggingface_hub import hf_hub_download

# Download model checkpoint
checkpoint_path = hf_hub_download(
    repo_id="littleworth/peptide-stability-predictor",
    filename="stability_predictor.pt"
)

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Download model class
model_file = hf_hub_download(
    repo_id="littleworth/peptide-stability-predictor",
    filename="stability_predictor.py"
)

# Import model class
import importlib.util
spec = importlib.util.spec_from_file_location("stability_predictor", model_file)
sp_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(sp_module)
StabilityPredictor = sp_module.StabilityPredictor

# Initialize model (this will download ESM-2 on first run)
model = StabilityPredictor(esm_model="esm2_t6_8M_UR50D")

# Load trained weights (only the MLP head, ESM-2 is frozen)
# Filter to only load head weights
head_state_dict = {k: v for k, v in checkpoint['model_state_dict'].items()
                   if k.startswith('head.')}
model.head.load_state_dict({k.replace('head.', ''): v for k, v in head_state_dict.items()})
model.eval()

# Predict stability
sequences = [
    "MKTLYFLGASV",
    "AEITVKLSPGMNCF",
    "GFLWKASTDERIPMNCVYH",
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():
    scores = model(sequences)

print("Stability predictions:")
for seq, score in zip(sequences, scores.tolist()):
    print(f"  {seq}: {score:.4f}")
```

### Alternative: Using predict() method

```python
# Using the convenience method (returns Python list)
scores = model.predict(sequences)
print(scores)  # [0.7234, 0.6521, 0.5892]
```

## Example Output

```
Stability predictions:
  MKTLYFLGASV: 0.7234
  AEITVKLSPGMNCF: 0.6521
  GFLWKASTDERIPMNCVYH: 0.5892
```

## Files in This Repository

| File | Description |
|------|-------------|
| `stability_predictor.pt` | Model checkpoint (MLP head weights) |
| `stability_predictor.py` | Model architecture definition |
| `config.json` | Model configuration |

## Checkpoint Contents

```python
{
    'epoch': 16,
    'model_state_dict': {...},  # MLP head weights
    'optimizer_state_dict': {...},
    'val_r2': 0.616,
    'config': {
        'esm_model': 'esm2_t6_8M_UR50D',
        'hidden_dims': [256, 128],
        'dropout': 0.1
    }
}
```

## Intended Use

- **Primary use**: Scoring peptide/protein stability for drug discovery
- **Secondary uses**:
  - Filtering generated peptide candidates (see the sketch after this list)
  - Research on protein thermostability
  - Feature engineering for downstream ML models

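As a sketch of the candidate-filtering use case listed above, predicted scores can be used to rank or cut a batch of generated peptides; the threshold here is purely illustrative, and `model` is the predictor loaded in the Quick Start:

```python
# Hypothetical filtering step for generated candidates.
candidates = ["MKTLYFLGASV", "AEITVKLSPGMNCF", "GFLWKASTDERIPMNCVYH"]
scores = model.predict(candidates)  # normalized stability scores

threshold = 0.6  # illustrative cutoff, not a validated value
stable_candidates = [seq for seq, s in zip(candidates, scores) if s >= threshold]
print(stable_candidates)
```
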
## Limitations

- Trained on FLIP Meltome data, which may not generalize to all protein families
- Outputs normalized scores, not absolute melting temperatures
- Predictions are computational estimates requiring experimental validation
- Best accuracy for sequences similar to the training distribution

## Performance

| Metric | Value |
|--------|-------|
| Validation R² | 0.616 |
| Training epochs | 16 |
| Early stopping patience | 15 |

## Dependencies

- PyTorch >= 2.0
- fair-esm (Facebook's ESM library)
- huggingface_hub

## Ethical Considerations

This model provides computational predictions of protein stability. Predictions should be validated experimentally before making decisions about therapeutic development. The model does not guarantee accuracy for sequences outside its training distribution.

## Training Data

- **FLIP Meltome benchmark**: a dataset of protein sequences with measured thermal stability values
- Training/validation split following FLIP benchmark protocols

## Citation

```bibtex
@software{peptide_stability_2025,
  author = {Wijaya, Edward},
  title = {Peptide Stability Predictor},
  year = {2025},
  url = {https://huggingface.co/littleworth/peptide-stability-predictor},
  note = {ESM-2 based thermal stability prediction}
}
```

## References

- [FLIP Benchmark](https://github.com/J-SNACKKB/FLIP) - Dallago et al., 2021
- [ESM-2](https://github.com/facebookresearch/esm) - Lin et al., 2022
- [ESM-2 Paper](https://www.science.org/doi/10.1126/science.ade2574) - Lin et al., Science 2023

## License

MIT License
config.json ADDED
@@ -0,0 +1,26 @@
{
  "model_type": "stability_predictor",
  "architecture": "esm2_mlp",
  "esm_model": "esm2_t6_8M_UR50D",
  "esm_params": 8000000,
  "embed_dim": 320,
  "repr_layer": 6,
  "freeze_esm": true,
  "head": {
    "hidden_dims": [256, 128],
    "dropout": 0.1,
    "activation": "relu"
  },
  "training": {
    "dataset": "FLIP_meltome",
    "epochs": 16,
    "learning_rate": 0.001,
    "batch_size": 8,
    "early_stopping_patience": 15,
    "validation_r2": 0.616
  },
  "output": {
    "type": "regression",
    "description": "Normalized thermal stability score"
  }
}
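
The fields above mirror the constructor arguments of `StabilityPredictor`. A minimal sketch of reading this file to rebuild the architecture (it does not load the trained head weights, which the README's Usage section covers; `StabilityPredictor` is imported as shown there):

```python
import json
from huggingface_hub import hf_hub_download

# Fetch and parse the configuration shipped with the repository.
config_path = hf_hub_download(
    repo_id="littleworth/peptide-stability-predictor",
    filename="config.json",
)
with open(config_path) as f:
    config = json.load(f)

# Rebuild the model from the stored hyperparameters.
model = StabilityPredictor(
    esm_model=config["esm_model"],
    hidden_dims=config["head"]["hidden_dims"],
    dropout=config["head"]["dropout"],
    freeze_esm=config["freeze_esm"],
)
```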
stability_predictor.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:67d1eb517578cf141bd113b1491f61176ff0beb63181d5879f2490d220c534d4
size 31483165
stability_predictor.py ADDED
@@ -0,0 +1,244 @@
"""ESM-2 based stability predictor for peptide/protein sequences.

This module implements a stability predictor using ESM-2 embeddings as input
to an MLP regression head. The model predicts thermal stability (melting
temperature) based on sequence information.

Architecture:
    Input: Peptide/protein sequence

    ESM-2 (frozen): Extract mean-pooled embeddings

    MLP: embedding_dim → hidden_dims → 1

    Output: Stability score (normalized)
"""

import logging
from typing import List, Optional, Union

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class StabilityPredictor(nn.Module):
    """ESM-2 based stability predictor.

    Uses frozen ESM-2 embeddings as input to an MLP head for predicting
    thermal stability. The model is designed to be trained on datasets
    like the FLIP stability (meltome) task.

    Attributes:
        esm: ESM-2 language model (frozen)
        alphabet: ESM-2 tokenizer
        head: MLP regression head
        embed_dim: Dimension of ESM-2 embeddings
        repr_layer: Which layer to extract representations from
    """

    def __init__(
        self,
        esm_model: str = "esm2_t6_8M_UR50D",
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        freeze_esm: bool = True,
        device: Optional[str] = None,
    ):
        """Initialize stability predictor.

        Args:
            esm_model: Name of ESM-2 model to use. Options:
                - esm2_t6_8M_UR50D (8M params, 320 dim, fastest)
                - esm2_t12_35M_UR50D (35M params, 480 dim)
                - esm2_t33_650M_UR50D (650M params, 1280 dim, most accurate)
            hidden_dims: Hidden layer dimensions for MLP head.
                Default: [256, 128]
            dropout: Dropout rate for MLP layers
            freeze_esm: Whether to freeze ESM-2 parameters
            device: Device to load model on. Auto-detected if None.
        """
        super().__init__()

        if hidden_dims is None:
            hidden_dims = [256, 128]

        self.esm_model_name = esm_model
        self.freeze_esm = freeze_esm
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load ESM-2
        self._load_esm(esm_model)

        if freeze_esm:
            for param in self.esm.parameters():
                param.requires_grad = False
            self.esm.eval()

        # Build MLP head
        layers = []
        in_dim = self.embed_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, 1))

        self.head = nn.Sequential(*layers)

        logger.info(f"StabilityPredictor initialized with {esm_model}, "
                    f"hidden_dims={hidden_dims}, freeze_esm={freeze_esm}")

    def _load_esm(self, esm_model: str):
        """Load ESM-2 model and set embedding dimensions."""
        import esm

        logger.info(f"Loading ESM-2 model: {esm_model}")

        if esm_model == "esm2_t6_8M_UR50D":
            self.esm, self.alphabet = esm.pretrained.esm2_t6_8M_UR50D()
            self.embed_dim = 320
            self.repr_layer = 6
        elif esm_model == "esm2_t12_35M_UR50D":
            self.esm, self.alphabet = esm.pretrained.esm2_t12_35M_UR50D()
            self.embed_dim = 480
            self.repr_layer = 12
        elif esm_model == "esm2_t33_650M_UR50D":
            self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
            self.embed_dim = 1280
            self.repr_layer = 33
        else:
            raise ValueError(f"Unknown ESM model: {esm_model}")

        self.batch_converter = self.alphabet.get_batch_converter()

    def get_embeddings(self, sequences: List[str]) -> torch.Tensor:
        """Extract ESM-2 embeddings for sequences.

        Args:
            sequences: List of amino acid sequences

        Returns:
            Tensor of shape (batch_size, embed_dim) with mean-pooled embeddings
        """
        # Prepare data for ESM
        data = [(f"seq{i}", seq) for i, seq in enumerate(sequences)]
        _, _, batch_tokens = self.batch_converter(data)
        batch_tokens = batch_tokens.to(next(self.esm.parameters()).device)

        # Forward pass through ESM-2
        with torch.no_grad() if self.freeze_esm else torch.enable_grad():
            results = self.esm(
                batch_tokens,
                repr_layers=[self.repr_layer],
                return_contacts=False
            )

        # Mean pool over sequence positions (excluding BOS and EOS tokens)
        embeddings = []
        for i, seq in enumerate(sequences):
            seq_len = len(seq)
            # Tokens are: [BOS, seq..., EOS, PAD...]
            # We want indices 1 to seq_len+1 (exclusive of EOS)
            emb = results["representations"][self.repr_layer][i, 1:seq_len + 1, :]
            embeddings.append(emb.mean(dim=0))

        return torch.stack(embeddings)

    def forward(self, sequences: Union[str, List[str]]) -> torch.Tensor:
        """Predict stability for sequences.

        Args:
            sequences: Single sequence or list of sequences

        Returns:
            Tensor of shape (batch_size,) with stability predictions
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        embeddings = self.get_embeddings(sequences)
        predictions = self.head(embeddings).squeeze(-1)

        return predictions

    def predict(self, sequences: Union[str, List[str]]) -> List[float]:
        """Predict stability scores (convenience method).

        Args:
            sequences: Single sequence or list of sequences

        Returns:
            List of stability scores
        """
        self.eval()
        with torch.no_grad():
            preds = self.forward(sequences)
        return preds.cpu().tolist()

    def to(self, device: Union[str, torch.device]) -> 'StabilityPredictor':
        """Move model to device."""
        self.device = str(device)
        self.esm = self.esm.to(device)
        self.head = self.head.to(device)
        return super().to(device)


class BindingPredictor(StabilityPredictor):
    """ESM-2 based binding predictor.

    Same architecture as StabilityPredictor but intended for binding
    affinity prediction. Currently only supports binary classification
    (binder vs non-binder) due to Propedia dataset limitations.

    For regression tasks, additional data with continuous binding affinities
    (e.g., from PDBbind) would be needed.
    """

    def __init__(
        self,
        esm_model: str = "esm2_t6_8M_UR50D",
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        freeze_esm: bool = True,
        device: Optional[str] = None,
        use_sigmoid: bool = True,
    ):
        """Initialize binding predictor.

        Args:
            esm_model: Name of ESM-2 model to use
            hidden_dims: Hidden layer dimensions for MLP head
            dropout: Dropout rate
            freeze_esm: Whether to freeze ESM-2
            device: Device to load model on
            use_sigmoid: Whether to apply sigmoid for binary classification
        """
        super().__init__(
            esm_model=esm_model,
            hidden_dims=hidden_dims,
            dropout=dropout,
            freeze_esm=freeze_esm,
            device=device,
        )
        self.use_sigmoid = use_sigmoid
        logger.info(f"BindingPredictor initialized, use_sigmoid={use_sigmoid}")

    def forward(self, sequences: Union[str, List[str]]) -> torch.Tensor:
        """Predict binding score for sequences.

        Args:
            sequences: Single sequence or list of sequences

        Returns:
            Tensor of shape (batch_size,) with binding predictions.
            If use_sigmoid=True, values are in [0, 1].
        """
        preds = super().forward(sequences)
        if self.use_sigmoid:
            preds = torch.sigmoid(preds)
        return preds
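
`BindingPredictor` is defined above but not covered by the README's Quick Start, and no trained binding checkpoint is distributed in this repository. A minimal usage sketch based only on the class definition (outputs from a freshly initialized head are not meaningful predictions):

```python
# Illustrative only: instantiate the binding head and score a few sequences.
# Without trained weights the scores below are effectively random.
binder = BindingPredictor(esm_model="esm2_t6_8M_UR50D", use_sigmoid=True)
binder.eval()

probs = binder.predict(["MKTLYFLGASV", "AEITVKLSPGMNCF"])
print(probs)  # values in [0, 1] when use_sigmoid=True
```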