"""
MitoInteract Model Class Definition
Copy this file to load the model for inference.
"""
import torch
import torch.nn as nn
from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer
class MitoInteract(nn.Module):
    """Cross-attention binding-affinity model: a frozen ESM-2 protein encoder and a
    ChemBERTa molecule encoder feed bidirectional cross-attention, and the fused
    representation is passed to an MLP regression head that predicts pKd."""

    def __init__(
        self,
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        mol_model_name="seyonec/ChemBERTa-zinc-base-v1",
        protein_dim=1280,
        mol_dim=768,
        proj_dim=256,
        n_heads=8,
        dropout=0.1,
        freeze_encoders=True,
    ):
        super().__init__()
        self.freeze_encoders = freeze_encoders

        # Pretrained encoders: ESM-2 for protein sequences, ChemBERTa for SMILES.
        self.esm = EsmModel.from_pretrained(esm_model_name)
        self.protein_dim = protein_dim
        self.mol_encoder = AutoModel.from_pretrained(mol_model_name)
        self.mol_dim = mol_dim
        if freeze_encoders:
            for p in self.esm.parameters():
                p.requires_grad = False
            for p in self.mol_encoder.parameters():
                p.requires_grad = False

        # Project both modalities into a shared space before cross-attention.
        self.prot_proj = nn.Sequential(
            nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
        self.mol_proj = nn.Sequential(
            nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))

        # Bidirectional cross-attention: molecule attends to protein residues and vice versa.
        self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.ln_mol2prot = nn.LayerNorm(proj_dim)
        self.ln_prot2mol = nn.LayerNorm(proj_dim)

        # Regression head over the concatenated cross-attended representations.
        fused_dim = proj_dim * 2
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 1))

    def encode_protein(self, input_ids, attention_mask):
        """Encode a protein with ESM-2; return a mean-pooled embedding and per-residue states."""
        ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
        with ctx:
            out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        # Masked mean pooling over residue positions (padding excluded).
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return pooled, out.last_hidden_state

    def encode_molecule(self, input_ids, attention_mask):
        """Encode a SMILES string with ChemBERTa; return the pooled and per-token states."""
        ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad()
        with ctx:
            out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return out.pooler_output, out.last_hidden_state

    def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask):
        prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask)
        mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask)

        # Project token-level states and pooled summaries into the shared space.
        prot_seq_proj = self.prot_proj(prot_seq)
        mol_seq_proj = self.mol_proj(mol_seq)
        prot_q = self.prot_proj(prot_pooled).unsqueeze(1)
        mol_q = self.mol_proj(mol_pooled).unsqueeze(1)

        # key_padding_mask is True at padded positions, which MultiheadAttention ignores.
        prot_pad_mask = (prot_attention_mask == 0)
        mol_pad_mask = (mol_attention_mask == 0)

        # The pooled protein query attends over molecule tokens, and vice versa.
        h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask)
        h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask)
        h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1))
        h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1))

        # Concatenate both directions and regress a single pKd value per pair.
        fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1)
        return self.mlp(fused).squeeze(-1)

def load_model(checkpoint_path, device="cpu"):
"""Load trained MitoInteract model."""
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
config = checkpoint["config"]
model = MitoInteract(
esm_model_name=config["esm_model"],
mol_model_name=config["mol_model"],
protein_dim=config["protein_dim"],
mol_dim=config["mol_dim"],
proj_dim=config["proj_dim"],
n_heads=config["n_heads"],
dropout=config["dropout"],
freeze_encoders=True,
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
return model, config
def predict_binding(model, protein_seq, smiles, device="cpu"):
"""Predict binding affinity (pKd) for a protein-molecule pair."""
prot_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
mol_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
prot_enc = prot_tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512)
mol_enc = mol_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=200)
model = model.to(device)
with torch.no_grad():
pKd = model(
prot_enc["input_ids"].to(device), prot_enc["attention_mask"].to(device),
mol_enc["input_ids"].to(device), mol_enc["attention_mask"].to(device),
)
pKd_val = pKd.item()
Kd_uM = 10 ** (-pKd_val) * 1e6
return {"pKd": pKd_val, "Kd_uM": Kd_uM}