|
|
"""ESM-2 based stability predictor for peptide/protein sequences. |
|
|
|
|
|
This module implements a stability predictor using ESM-2 embeddings as input |
|
|
to an MLP regression head. The model predicts thermal stability (melting |
|
|
temperature) based on sequence information. |
|
|
|
|
|
Architecture: |
|
|
Input: Peptide/protein sequence |
|
|
↓ |
|
|
ESM-2 (frozen): Extract mean-pooled embeddings |
|
|
↓ |
|
|
MLP: embedding_dim → hidden_dims → 1 |
|
|
↓ |
|
|
Output: Stability score (normalized) |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from typing import List, Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class StabilityPredictor(nn.Module): |
|
|
"""ESM-2 based stability predictor. |
|
|
|
|
|
Uses frozen ESM-2 embeddings as input to an MLP head for predicting |
|
|
thermal stability. The model is designed to be trained on datasets |
|
|
like FLIP stability (meltome) task. |
|
|
|
|
|
Attributes: |
|
|
esm: ESM-2 language model (frozen) |
|
|
alphabet: ESM-2 tokenizer |
|
|
head: MLP regression head |
|
|
embed_dim: Dimension of ESM-2 embeddings |
|
|
repr_layer: Which layer to extract representations from |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
esm_model: str = "esm2_t6_8M_UR50D", |
|
|
hidden_dims: Optional[List[int]] = None, |
|
|
dropout: float = 0.1, |
|
|
freeze_esm: bool = True, |
|
|
device: Optional[str] = None, |
|
|
): |
|
|
"""Initialize stability predictor. |
|
|
|
|
|
Args: |
|
|
esm_model: Name of ESM-2 model to use. Options: |
|
|
- esm2_t6_8M_UR50D (8M params, 320 dim, fastest) |
|
|
- esm2_t12_35M_UR50D (35M params, 480 dim) |
|
|
- esm2_t33_650M_UR50D (650M params, 1280 dim, most accurate) |
|
|
hidden_dims: Hidden layer dimensions for MLP head. |
|
|
Default: [256, 128] |
|
|
dropout: Dropout rate for MLP layers |
|
|
freeze_esm: Whether to freeze ESM-2 parameters |
|
|
device: Device to load model on. Auto-detected if None. |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
if hidden_dims is None: |
|
|
hidden_dims = [256, 128] |
|
|
|
|
|
self.esm_model_name = esm_model |
|
|
self.freeze_esm = freeze_esm |
|
|
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
|
|
|
self._load_esm(esm_model) |
|
|
|
|
|
if freeze_esm: |
|
|
for param in self.esm.parameters(): |
|
|
param.requires_grad = False |
|
|
self.esm.eval() |
|
|
|
|
|
|
|
|
layers = [] |
|
|
in_dim = self.embed_dim |
|
|
for h_dim in hidden_dims: |
|
|
layers.extend([ |
|
|
nn.Linear(in_dim, h_dim), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(dropout), |
|
|
]) |
|
|
in_dim = h_dim |
|
|
layers.append(nn.Linear(in_dim, 1)) |
|
|
|
|
|
self.head = nn.Sequential(*layers) |
|
|
|
|
|
logger.info(f"StabilityPredictor initialized with {esm_model}, " |
|
|
f"hidden_dims={hidden_dims}, freeze_esm={freeze_esm}") |
|
|
|
|
|
def _load_esm(self, esm_model: str): |
|
|
"""Load ESM-2 model and set embedding dimensions.""" |
|
|
import esm |
|
|
|
|
|
logger.info(f"Loading ESM-2 model: {esm_model}") |
|
|
|
|
|
if esm_model == "esm2_t6_8M_UR50D": |
|
|
self.esm, self.alphabet = esm.pretrained.esm2_t6_8M_UR50D() |
|
|
self.embed_dim = 320 |
|
|
self.repr_layer = 6 |
|
|
elif esm_model == "esm2_t12_35M_UR50D": |
|
|
self.esm, self.alphabet = esm.pretrained.esm2_t12_35M_UR50D() |
|
|
self.embed_dim = 480 |
|
|
self.repr_layer = 12 |
|
|
elif esm_model == "esm2_t33_650M_UR50D": |
|
|
self.esm, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D() |
|
|
self.embed_dim = 1280 |
|
|
self.repr_layer = 33 |
|
|
else: |
|
|
raise ValueError(f"Unknown ESM model: {esm_model}") |
|
|
|
|
|
self.batch_converter = self.alphabet.get_batch_converter() |
|
|
|
|
|
def get_embeddings(self, sequences: List[str]) -> torch.Tensor: |
|
|
"""Extract ESM-2 embeddings for sequences. |
|
|
|
|
|
Args: |
|
|
sequences: List of amino acid sequences |
|
|
|
|
|
Returns: |
|
|
Tensor of shape (batch_size, embed_dim) with mean-pooled embeddings |
|
|
""" |
|
|
|
|
|
data = [(f"seq{i}", seq) for i, seq in enumerate(sequences)] |
|
|
_, _, batch_tokens = self.batch_converter(data) |
|
|
batch_tokens = batch_tokens.to(next(self.esm.parameters()).device) |
|
|
|
|
|
|
|
|
with torch.no_grad() if self.freeze_esm else torch.enable_grad(): |
|
|
results = self.esm( |
|
|
batch_tokens, |
|
|
repr_layers=[self.repr_layer], |
|
|
return_contacts=False |
|
|
) |
|
|
|
|
|
|
|
|
embeddings = [] |
|
|
for i, seq in enumerate(sequences): |
|
|
seq_len = len(seq) |
|
|
|
|
|
|
|
|
emb = results["representations"][self.repr_layer][i, 1:seq_len+1, :] |
|
|
embeddings.append(emb.mean(dim=0)) |
|
|
|
|
|
return torch.stack(embeddings) |
|
|
|
|
|
def forward(self, sequences: Union[str, List[str]]) -> torch.Tensor: |
|
|
"""Predict stability for sequences. |
|
|
|
|
|
Args: |
|
|
sequences: Single sequence or list of sequences |
|
|
|
|
|
Returns: |
|
|
Tensor of shape (batch_size,) with stability predictions |
|
|
""" |
|
|
if isinstance(sequences, str): |
|
|
sequences = [sequences] |
|
|
|
|
|
embeddings = self.get_embeddings(sequences) |
|
|
predictions = self.head(embeddings).squeeze(-1) |
|
|
|
|
|
return predictions |
|
|
|
|
|
def predict(self, sequences: Union[str, List[str]]) -> List[float]: |
|
|
"""Predict stability scores (convenience method). |
|
|
|
|
|
Args: |
|
|
sequences: Single sequence or list of sequences |
|
|
|
|
|
Returns: |
|
|
List of stability scores |
|
|
""" |
|
|
self.eval() |
|
|
with torch.no_grad(): |
|
|
preds = self.forward(sequences) |
|
|
return preds.cpu().tolist() |
|
|
|
|
|
def to(self, device: Union[str, torch.device]) -> 'StabilityPredictor': |
|
|
"""Move model to device.""" |
|
|
self.device = str(device) |
|
|
self.esm = self.esm.to(device) |
|
|
self.head = self.head.to(device) |
|
|
return super().to(device) |
|
|
|
|
|
|
|
|
class BindingPredictor(StabilityPredictor): |
|
|
"""ESM-2 based binding predictor. |
|
|
|
|
|
Same architecture as StabilityPredictor but intended for binding |
|
|
affinity prediction. Currently only supports binary classification |
|
|
(binder vs non-binder) due to Propedia dataset limitations. |
|
|
|
|
|
For regression tasks, additional data with continuous binding affinities |
|
|
(e.g., from PDBbind) would be needed. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
esm_model: str = "esm2_t6_8M_UR50D", |
|
|
hidden_dims: Optional[List[int]] = None, |
|
|
dropout: float = 0.1, |
|
|
freeze_esm: bool = True, |
|
|
device: Optional[str] = None, |
|
|
use_sigmoid: bool = True, |
|
|
): |
|
|
"""Initialize binding predictor. |
|
|
|
|
|
Args: |
|
|
esm_model: Name of ESM-2 model to use |
|
|
hidden_dims: Hidden layer dimensions for MLP head |
|
|
dropout: Dropout rate |
|
|
freeze_esm: Whether to freeze ESM-2 |
|
|
device: Device to load model on |
|
|
use_sigmoid: Whether to apply sigmoid for binary classification |
|
|
""" |
|
|
super().__init__( |
|
|
esm_model=esm_model, |
|
|
hidden_dims=hidden_dims, |
|
|
dropout=dropout, |
|
|
freeze_esm=freeze_esm, |
|
|
device=device, |
|
|
) |
|
|
self.use_sigmoid = use_sigmoid |
|
|
logger.info(f"BindingPredictor initialized, use_sigmoid={use_sigmoid}") |
|
|
|
|
|
def forward(self, sequences: Union[str, List[str]]) -> torch.Tensor: |
|
|
"""Predict binding score for sequences. |
|
|
|
|
|
Args: |
|
|
sequences: Single sequence or list of sequences |
|
|
|
|
|
Returns: |
|
|
Tensor of shape (batch_size,) with binding predictions. |
|
|
If use_sigmoid=True, values are in [0, 1]. |
|
|
""" |
|
|
preds = super().forward(sequences) |
|
|
if self.use_sigmoid: |
|
|
preds = torch.sigmoid(preds) |
|
|
return preds |
|
|
|