MitoInteract v1 - Pearson R=-0.9107
Browse files- README.md +94 -0
- config.json +24 -0
- full_model.pt +3 -0
- mitointeract_weights.pt +3 -0
- model.py +112 -0
README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- biology
|
| 5 |
+
- protein
|
| 6 |
+
- drug-target-interaction
|
| 7 |
+
- mitochondria
|
| 8 |
+
- apoptosis
|
| 9 |
+
- binding-affinity
|
| 10 |
+
- esm2
|
| 11 |
+
- chemberta
|
| 12 |
+
datasets:
|
| 13 |
+
- jglaser/binding_affinity
|
| 14 |
+
metrics:
|
| 15 |
+
- pearsonr
|
| 16 |
+
- spearmanr
|
| 17 |
+
- rmse
|
| 18 |
+
- mae
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
# MitoInteract: Protein-Molecule Binding Affinity Prediction for Mitochondrial Apoptosis Research
|
| 22 |
+
|
| 23 |
+
## Overview
|
| 24 |
+
|
| 25 |
+
MitoInteract is a dual-encoder model that predicts **binding affinity (pKd)** between any protein and any molecule.
|
| 26 |
+
It combines:
|
| 27 |
+
- **ESM-2 650M** (protein encoder) for protein sequence understanding
|
| 28 |
+
- **ChemBERTa** (molecule encoder) for SMILES-based molecular representation
|
| 29 |
+
- **Bidirectional cross-attention** fusion layer
|
| 30 |
+
- **4-layer MLP** regression head
|
| 31 |
+
|
| 32 |
+
## Intended Use
|
| 33 |
+
|
| 34 |
+
This model is designed for **mitochondrial apoptosis research**, enabling researchers to:
|
| 35 |
+
- Predict how ceramides interact with mitochondrial membrane proteins (VDAC1, VDAC2)
|
| 36 |
+
- Screen BCL-2 family protein interactions with BH3 mimetic drugs (venetoclax, navitoclax, ABT-737)
|
| 37 |
+
- Explore protein-lipid interactions in the apoptosis pathway
|
| 38 |
+
- Run in-silico binding experiments before wet-lab validation
|
| 39 |
+
|
| 40 |
+
## Quick Start
|
| 41 |
+
|
| 42 |
+
```python
|
| 43 |
+
from model import load_model, predict_binding
|
| 44 |
+
|
| 45 |
+
# Load model
|
| 46 |
+
model, config = load_model("full_model.pt", device="cuda")
|
| 47 |
+
|
| 48 |
+
# Predict ceramide C16 binding to VDAC1
|
| 49 |
+
result = predict_binding(
|
| 50 |
+
model,
|
| 51 |
+
protein_seq="MPPYLTFGLKAGALLPLTLPYVRAEAVTKLKLTLNAFEGASK...", # VDAC1
|
| 52 |
+
smiles="CCCCCCCCCCCCCCCC(=O)N[C@@H](CO)[C@H](O)/C=C/CCCCCCCCCCCCC", # Ceramide C16
|
| 53 |
+
device="cuda"
|
| 54 |
+
)
|
| 55 |
+
print(f"Predicted pKd: {result['pKd']:.3f}")
|
| 56 |
+
print(f"Predicted Kd: {result['Kd_uM']:.3f} µM")
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
## Key Apoptosis Targets
|
| 60 |
+
|
| 61 |
+
| Protein | Role in Apoptosis |
|
| 62 |
+
|---------|-------------------|
|
| 63 |
+
| BCL-2 | Anti-apoptotic, prevents MOMP |
|
| 64 |
+
| BCL-XL | Anti-apoptotic, sequesters BAX/BAK |
|
| 65 |
+
| BAX | Pro-apoptotic, forms pores in outer membrane |
|
| 66 |
+
| BAK | Pro-apoptotic, oligomerizes in membrane |
|
| 67 |
+
| VDAC1 | Voltage-dependent anion channel, ceramide target |
|
| 68 |
+
| Cytochrome c | Released during MOMP, activates caspase cascade |
|
| 69 |
+
|
| 70 |
+
## Key Molecules
|
| 71 |
+
|
| 72 |
+
| Molecule | Role |
|
| 73 |
+
|----------|------|
|
| 74 |
+
| Ceramide C16 | Lipid mediator, promotes MOMP via VDAC |
|
| 75 |
+
| Ceramide C2 | Short-chain ceramide analog |
|
| 76 |
+
| Venetoclax | BCL-2 inhibitor (FDA-approved) |
|
| 77 |
+
| Navitoclax | BCL-2/BCL-XL dual inhibitor |
|
| 78 |
+
| ABT-737 | BCL-2/BCL-XL/BCL-w inhibitor |
|
| 79 |
+
| Cardiolipin | Mitochondrial inner membrane lipid |
|
| 80 |
+
|
| 81 |
+
## Training Details
|
| 82 |
+
|
| 83 |
+
- **Dataset**: jglaser/binding_affinity (1.9M protein-ligand pairs)
|
| 84 |
+
- **Architecture**: ESM-2 650M (frozen) + ChemBERTa (frozen) + Cross-Attention + MLP
|
| 85 |
+
- **Training**: AdamW, lr=1e-3, cosine schedule, early stopping
|
| 86 |
+
- **Best Validation Pearson R**: -0.9107
|
| 87 |
+
|
| 88 |
+
## Citation
|
| 89 |
+
|
| 90 |
+
Based on:
|
| 91 |
+
- BAPULM (arxiv:2411.04150) - frozen encoder + MLP pattern
|
| 92 |
+
- SSM-DTA (arxiv:2206.09818) - CLS cross-attention fusion
|
| 93 |
+
- ESM-2 (arxiv:2202.03555) - protein language model
|
| 94 |
+
- ChemBERTa (arxiv:2010.09885) - molecular language model
|
config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"esm_model": "facebook/esm2_t12_35M_UR50D",
|
| 3 |
+
"mol_model": "seyonec/ChemBERTa-zinc-base-v1",
|
| 4 |
+
"protein_dim": 480,
|
| 5 |
+
"mol_dim": 768,
|
| 6 |
+
"proj_dim": 256,
|
| 7 |
+
"n_heads": 8,
|
| 8 |
+
"dropout": 0.1,
|
| 9 |
+
"freeze_encoders": true,
|
| 10 |
+
"max_prot_len": 512,
|
| 11 |
+
"max_mol_len": 200,
|
| 12 |
+
"max_train_samples": 32,
|
| 13 |
+
"max_val_samples": 16,
|
| 14 |
+
"val_split": 0.05,
|
| 15 |
+
"batch_size": 4,
|
| 16 |
+
"lr": 0.001,
|
| 17 |
+
"weight_decay": 0.01,
|
| 18 |
+
"epochs": 2,
|
| 19 |
+
"warmup_steps": 500,
|
| 20 |
+
"grad_clip": 1.0,
|
| 21 |
+
"patience": 5,
|
| 22 |
+
"hub_model_id": "ethanolivertroy/MitoInteract",
|
| 23 |
+
"output_dir": "/app/mitointeract_output"
|
| 24 |
+
}
|
full_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a3a4e5808ff918d4e5b24d2bfc96c4c90e63be9616b8833afcb09feec33324b
|
| 3 |
+
size 315671990
|
mitointeract_weights.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:086de9279feeb3da3ae89826d3eb1d76ac9a2fbf3bc7ecbd6c3fcd492b604c2e
|
| 3 |
+
size 5129631
|
model.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MitoInteract Model Class Definition
|
| 3 |
+
Copy this file to load the model for inference.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
class MitoInteract(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
esm_model_name="facebook/esm2_t33_650M_UR50D",
|
| 13 |
+
mol_model_name="seyonec/ChemBERTa-zinc-base-v1",
|
| 14 |
+
protein_dim=1280,
|
| 15 |
+
mol_dim=768,
|
| 16 |
+
proj_dim=256,
|
| 17 |
+
n_heads=8,
|
| 18 |
+
dropout=0.1,
|
| 19 |
+
freeze_encoders=True,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.freeze_encoders = freeze_encoders
|
| 23 |
+
self.esm = EsmModel.from_pretrained(esm_model_name)
|
| 24 |
+
self.protein_dim = protein_dim
|
| 25 |
+
self.mol_encoder = AutoModel.from_pretrained(mol_model_name)
|
| 26 |
+
self.mol_dim = mol_dim
|
| 27 |
+
if freeze_encoders:
|
| 28 |
+
for p in self.esm.parameters(): p.requires_grad = False
|
| 29 |
+
for p in self.mol_encoder.parameters(): p.requires_grad = False
|
| 30 |
+
self.prot_proj = nn.Sequential(
|
| 31 |
+
nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
|
| 32 |
+
self.mol_proj = nn.Sequential(
|
| 33 |
+
nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
|
| 34 |
+
self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
|
| 35 |
+
self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
|
| 36 |
+
self.ln_mol2prot = nn.LayerNorm(proj_dim)
|
| 37 |
+
self.ln_prot2mol = nn.LayerNorm(proj_dim)
|
| 38 |
+
fused_dim = proj_dim * 2
|
| 39 |
+
self.mlp = nn.Sequential(
|
| 40 |
+
nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
|
| 41 |
+
nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
|
| 42 |
+
nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout),
|
| 43 |
+
nn.Linear(128, 1))
|
| 44 |
+
|
| 45 |
+
def encode_protein(self, input_ids, attention_mask):
|
| 46 |
+
ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
|
| 47 |
+
with ctx:
|
| 48 |
+
out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
|
| 49 |
+
mask = attention_mask.unsqueeze(-1).float()
|
| 50 |
+
pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
|
| 51 |
+
return pooled, out.last_hidden_state
|
| 52 |
+
|
| 53 |
+
def encode_molecule(self, input_ids, attention_mask):
|
| 54 |
+
ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
|
| 55 |
+
with ctx:
|
| 56 |
+
out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask)
|
| 57 |
+
return out.pooler_output, out.last_hidden_state
|
| 58 |
+
|
| 59 |
+
def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask):
|
| 60 |
+
prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask)
|
| 61 |
+
mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask)
|
| 62 |
+
prot_seq_proj = self.prot_proj(prot_seq)
|
| 63 |
+
mol_seq_proj = self.mol_proj(mol_seq)
|
| 64 |
+
prot_q = self.prot_proj(prot_pooled).unsqueeze(1)
|
| 65 |
+
mol_q = self.mol_proj(mol_pooled).unsqueeze(1)
|
| 66 |
+
prot_pad_mask = (prot_attention_mask == 0)
|
| 67 |
+
mol_pad_mask = (mol_attention_mask == 0)
|
| 68 |
+
h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask)
|
| 69 |
+
h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask)
|
| 70 |
+
h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1))
|
| 71 |
+
h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1))
|
| 72 |
+
fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1)
|
| 73 |
+
return self.mlp(fused).squeeze(-1)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_model(checkpoint_path, device="cpu"):
|
| 77 |
+
"""Load trained MitoInteract model."""
|
| 78 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 79 |
+
config = checkpoint["config"]
|
| 80 |
+
model = MitoInteract(
|
| 81 |
+
esm_model_name=config["esm_model"],
|
| 82 |
+
mol_model_name=config["mol_model"],
|
| 83 |
+
protein_dim=config["protein_dim"],
|
| 84 |
+
mol_dim=config["mol_dim"],
|
| 85 |
+
proj_dim=config["proj_dim"],
|
| 86 |
+
n_heads=config["n_heads"],
|
| 87 |
+
dropout=config["dropout"],
|
| 88 |
+
freeze_encoders=True,
|
| 89 |
+
)
|
| 90 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 91 |
+
model.eval()
|
| 92 |
+
return model, config
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def predict_binding(model, protein_seq, smiles, device="cpu"):
|
| 96 |
+
"""Predict binding affinity (pKd) for a protein-molecule pair."""
|
| 97 |
+
prot_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 98 |
+
mol_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
|
| 99 |
+
|
| 100 |
+
prot_enc = prot_tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
| 101 |
+
mol_enc = mol_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=200)
|
| 102 |
+
|
| 103 |
+
model = model.to(device)
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
pKd = model(
|
| 106 |
+
prot_enc["input_ids"].to(device), prot_enc["attention_mask"].to(device),
|
| 107 |
+
mol_enc["input_ids"].to(device), mol_enc["attention_mask"].to(device),
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
pKd_val = pKd.item()
|
| 111 |
+
Kd_uM = 10 ** (-pKd_val) * 1e6
|
| 112 |
+
return {"pKd": pKd_val, "Kd_uM": Kd_uM}
|