ethanolivertroy commited on
Commit
98ed1b7
·
verified ·
1 Parent(s): 5a62889

MitoInteract v1 - Pearson R=-0.9107

Browse files
Files changed (5) hide show
  1. README.md +94 -0
  2. config.json +24 -0
  3. full_model.pt +3 -0
  4. mitointeract_weights.pt +3 -0
  5. model.py +112 -0
README.md ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - biology
5
+ - protein
6
+ - drug-target-interaction
7
+ - mitochondria
8
+ - apoptosis
9
+ - binding-affinity
10
+ - esm2
11
+ - chemberta
12
+ datasets:
13
+ - jglaser/binding_affinity
14
+ metrics:
15
+ - pearsonr
16
+ - spearmanr
17
+ - rmse
18
+ - mae
19
+ ---
20
+
21
+ # MitoInteract: Protein-Molecule Binding Affinity Prediction for Mitochondrial Apoptosis Research
22
+
23
+ ## Overview
24
+
25
+ MitoInteract is a dual-encoder model that predicts **binding affinity (pKd)** between any protein and any molecule.
26
+ It combines:
27
+ - **ESM-2 650M** (protein encoder) for protein sequence understanding
28
+ - **ChemBERTa** (molecule encoder) for SMILES-based molecular representation
29
+ - **Bidirectional cross-attention** fusion layer
30
+ - **4-layer MLP** regression head
31
+
32
+ ## Intended Use
33
+
34
+ This model is designed for **mitochondrial apoptosis research**, enabling researchers to:
35
+ - Predict how ceramides interact with mitochondrial membrane proteins (VDAC1, VDAC2)
36
+ - Screen BCL-2 family protein interactions with BH3 mimetic drugs (venetoclax, navitoclax, ABT-737)
37
+ - Explore protein-lipid interactions in the apoptosis pathway
38
+ - Run in-silico binding experiments before wet-lab validation
39
+
40
+ ## Quick Start
41
+
42
+ ```python
43
+ from model import load_model, predict_binding
44
+
45
+ # Load model
46
+ model, config = load_model("full_model.pt", device="cuda")
47
+
48
+ # Predict ceramide C16 binding to VDAC1
49
+ result = predict_binding(
50
+ model,
51
+ protein_seq="MPPYLTFGLKAGALLPLTLPYVRAEAVTKLKLTLNAFEGASK...", # VDAC1
52
+ smiles="CCCCCCCCCCCCCCCC(=O)N[C@@H](CO)[C@H](O)/C=C/CCCCCCCCCCCCC", # Ceramide C16
53
+ device="cuda"
54
+ )
55
+ print(f"Predicted pKd: {result['pKd']:.3f}")
56
+ print(f"Predicted Kd: {result['Kd_uM']:.3f} µM")
57
+ ```
58
+
59
+ ## Key Apoptosis Targets
60
+
61
+ | Protein | Role in Apoptosis |
62
+ |---------|-------------------|
63
+ | BCL-2 | Anti-apoptotic, prevents MOMP |
64
+ | BCL-XL | Anti-apoptotic, sequesters BAX/BAK |
65
+ | BAX | Pro-apoptotic, forms pores in outer membrane |
66
+ | BAK | Pro-apoptotic, oligomerizes in membrane |
67
+ | VDAC1 | Voltage-dependent anion channel, ceramide target |
68
+ | Cytochrome c | Released during MOMP, activates caspase cascade |
69
+
70
+ ## Key Molecules
71
+
72
+ | Molecule | Role |
73
+ |----------|------|
74
+ | Ceramide C16 | Lipid mediator, promotes MOMP via VDAC |
75
+ | Ceramide C2 | Short-chain ceramide analog |
76
+ | Venetoclax | BCL-2 inhibitor (FDA-approved) |
77
+ | Navitoclax | BCL-2/BCL-XL dual inhibitor |
78
+ | ABT-737 | BCL-2/BCL-XL/BCL-w inhibitor |
79
+ | Cardiolipin | Mitochondrial inner membrane lipid |
80
+
81
+ ## Training Details
82
+
83
+ - **Dataset**: jglaser/binding_affinity (1.9M protein-ligand pairs)
84
+ - **Architecture**: ESM-2 650M (frozen) + ChemBERTa (frozen) + Cross-Attention + MLP
85
+ - **Training**: AdamW, lr=1e-3, cosine schedule, early stopping
86
+ - **Best Validation Pearson R**: -0.9107
87
+
88
+ ## Citation
89
+
90
+ Based on:
91
+ - BAPULM (arxiv:2411.04150) - frozen encoder + MLP pattern
92
+ - SSM-DTA (arxiv:2206.09818) - CLS cross-attention fusion
93
+ - ESM-2 (arxiv:2202.03555) - protein language model
94
+ - ChemBERTa (arxiv:2010.09885) - molecular language model
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "esm_model": "facebook/esm2_t12_35M_UR50D",
3
+ "mol_model": "seyonec/ChemBERTa-zinc-base-v1",
4
+ "protein_dim": 480,
5
+ "mol_dim": 768,
6
+ "proj_dim": 256,
7
+ "n_heads": 8,
8
+ "dropout": 0.1,
9
+ "freeze_encoders": true,
10
+ "max_prot_len": 512,
11
+ "max_mol_len": 200,
12
+ "max_train_samples": 32,
13
+ "max_val_samples": 16,
14
+ "val_split": 0.05,
15
+ "batch_size": 4,
16
+ "lr": 0.001,
17
+ "weight_decay": 0.01,
18
+ "epochs": 2,
19
+ "warmup_steps": 500,
20
+ "grad_clip": 1.0,
21
+ "patience": 5,
22
+ "hub_model_id": "ethanolivertroy/MitoInteract",
23
+ "output_dir": "/app/mitointeract_output"
24
+ }
full_model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a3a4e5808ff918d4e5b24d2bfc96c4c90e63be9616b8833afcb09feec33324b
3
+ size 315671990
mitointeract_weights.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:086de9279feeb3da3ae89826d3eb1d76ac9a2fbf3bc7ecbd6c3fcd492b604c2e
3
+ size 5129631
model.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MitoInteract Model Class Definition
3
+ Copy this file to load the model for inference.
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer
8
+
9
+ class MitoInteract(nn.Module):
10
+ def __init__(
11
+ self,
12
+ esm_model_name="facebook/esm2_t33_650M_UR50D",
13
+ mol_model_name="seyonec/ChemBERTa-zinc-base-v1",
14
+ protein_dim=1280,
15
+ mol_dim=768,
16
+ proj_dim=256,
17
+ n_heads=8,
18
+ dropout=0.1,
19
+ freeze_encoders=True,
20
+ ):
21
+ super().__init__()
22
+ self.freeze_encoders = freeze_encoders
23
+ self.esm = EsmModel.from_pretrained(esm_model_name)
24
+ self.protein_dim = protein_dim
25
+ self.mol_encoder = AutoModel.from_pretrained(mol_model_name)
26
+ self.mol_dim = mol_dim
27
+ if freeze_encoders:
28
+ for p in self.esm.parameters(): p.requires_grad = False
29
+ for p in self.mol_encoder.parameters(): p.requires_grad = False
30
+ self.prot_proj = nn.Sequential(
31
+ nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
32
+ self.mol_proj = nn.Sequential(
33
+ nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
34
+ self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
35
+ self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
36
+ self.ln_mol2prot = nn.LayerNorm(proj_dim)
37
+ self.ln_prot2mol = nn.LayerNorm(proj_dim)
38
+ fused_dim = proj_dim * 2
39
+ self.mlp = nn.Sequential(
40
+ nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
41
+ nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
42
+ nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout),
43
+ nn.Linear(128, 1))
44
+
45
+ def encode_protein(self, input_ids, attention_mask):
46
+ ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
47
+ with ctx:
48
+ out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
49
+ mask = attention_mask.unsqueeze(-1).float()
50
+ pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
51
+ return pooled, out.last_hidden_state
52
+
53
+ def encode_molecule(self, input_ids, attention_mask):
54
+ ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
55
+ with ctx:
56
+ out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask)
57
+ return out.pooler_output, out.last_hidden_state
58
+
59
+ def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask):
60
+ prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask)
61
+ mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask)
62
+ prot_seq_proj = self.prot_proj(prot_seq)
63
+ mol_seq_proj = self.mol_proj(mol_seq)
64
+ prot_q = self.prot_proj(prot_pooled).unsqueeze(1)
65
+ mol_q = self.mol_proj(mol_pooled).unsqueeze(1)
66
+ prot_pad_mask = (prot_attention_mask == 0)
67
+ mol_pad_mask = (mol_attention_mask == 0)
68
+ h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask)
69
+ h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask)
70
+ h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1))
71
+ h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1))
72
+ fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1)
73
+ return self.mlp(fused).squeeze(-1)
74
+
75
+
76
+ def load_model(checkpoint_path, device="cpu"):
77
+ """Load trained MitoInteract model."""
78
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
79
+ config = checkpoint["config"]
80
+ model = MitoInteract(
81
+ esm_model_name=config["esm_model"],
82
+ mol_model_name=config["mol_model"],
83
+ protein_dim=config["protein_dim"],
84
+ mol_dim=config["mol_dim"],
85
+ proj_dim=config["proj_dim"],
86
+ n_heads=config["n_heads"],
87
+ dropout=config["dropout"],
88
+ freeze_encoders=True,
89
+ )
90
+ model.load_state_dict(checkpoint["model_state_dict"])
91
+ model.eval()
92
+ return model, config
93
+
94
+
95
+ def predict_binding(model, protein_seq, smiles, device="cpu"):
96
+ """Predict binding affinity (pKd) for a protein-molecule pair."""
97
+ prot_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
98
+ mol_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
99
+
100
+ prot_enc = prot_tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512)
101
+ mol_enc = mol_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=200)
102
+
103
+ model = model.to(device)
104
+ with torch.no_grad():
105
+ pKd = model(
106
+ prot_enc["input_ids"].to(device), prot_enc["attention_mask"].to(device),
107
+ mol_enc["input_ids"].to(device), mol_enc["attention_mask"].to(device),
108
+ )
109
+
110
+ pKd_val = pKd.item()
111
+ Kd_uM = 10 ** (-pKd_val) * 1e6
112
+ return {"pKd": pKd_val, "Kd_uM": Kd_uM}