| """ |
| MitoInteract Model Class Definition |
| Copy this file to load the model for inference. |
| """ |
| import torch |
| import torch.nn as nn |
| from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer |
|
|
| class MitoInteract(nn.Module): |
| def __init__( |
| self, |
| esm_model_name="facebook/esm2_t33_650M_UR50D", |
| mol_model_name="seyonec/ChemBERTa-zinc-base-v1", |
| protein_dim=1280, |
| mol_dim=768, |
| proj_dim=256, |
| n_heads=8, |
| dropout=0.1, |
| freeze_encoders=True, |
| ): |
        super().__init__()
        self.freeze_encoders = freeze_encoders

        # Pretrained encoders: ESM-2 for protein sequences, ChemBERTa for SMILES.
        self.esm = EsmModel.from_pretrained(esm_model_name)
        self.protein_dim = protein_dim
        self.mol_encoder = AutoModel.from_pretrained(mol_model_name)
        self.mol_dim = mol_dim
        if freeze_encoders:
            for p in self.esm.parameters():
                p.requires_grad = False
            for p in self.mol_encoder.parameters():
                p.requires_grad = False

        # Project both encoders into a shared space for cross-attention.
        self.prot_proj = nn.Sequential(
            nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
        self.mol_proj = nn.Sequential(
            nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))

        # Bidirectional cross-attention: the pooled molecule attends over protein
        # tokens (mol2prot) and the pooled protein attends over molecule tokens
        # (prot2mol).
        self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.ln_mol2prot = nn.LayerNorm(proj_dim)
        self.ln_prot2mol = nn.LayerNorm(proj_dim)

        # Regression head over the concatenated attention outputs -> scalar pKd.
        fused_dim = proj_dim * 2
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 1))

    def encode_protein(self, input_ids, attention_mask):
        # Frozen encoders never need gradients; otherwise defer to the caller's
        # grad mode (torch.enable_grad() here would wrongly re-enable gradients
        # inside a torch.no_grad() inference block).
        ctx = torch.no_grad() if self.freeze_encoders else nullcontext()
        with ctx:
            out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool token embeddings over non-padding positions.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return pooled, out.last_hidden_state

    def encode_molecule(self, input_ids, attention_mask):
        ctx = torch.no_grad() if self.freeze_encoders else nullcontext()
        with ctx:
            out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # ChemBERTa (RoBERTa-based) exposes a tanh-pooled [CLS] embedding.
        return out.pooler_output, out.last_hidden_state

    def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask):
        prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask)
        mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask)

        # Project token sequences (keys/values) and pooled vectors (queries).
        prot_seq_proj = self.prot_proj(prot_seq)
        mol_seq_proj = self.mol_proj(mol_seq)
        prot_q = self.prot_proj(prot_pooled).unsqueeze(1)  # (B, 1, proj_dim)
        mol_q = self.mol_proj(mol_pooled).unsqueeze(1)     # (B, 1, proj_dim)

        # key_padding_mask expects True at padding positions.
        prot_pad_mask = (prot_attention_mask == 0)
        mol_pad_mask = (mol_attention_mask == 0)

        # Each pooled query attends over the other modality's token sequence.
        h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask)
        h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask)
        h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1))
        h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1))

        # Fuse both views and regress one pKd value per pair: shape (B,).
        fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1)
        return self.mlp(fused).squeeze(-1)
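
# A quick shape sanity check, kept as a commented-out sketch so importing this
# file stays side-effect free. The token ids below are arbitrary placeholders,
# not tokenizer output; instantiating MitoInteract() downloads both encoders.
#
#   model = MitoInteract()
#   model.eval()
#   prot_ids = torch.randint(4, 20, (2, 50))
#   mol_ids = torch.randint(4, 20, (2, 30))
#   with torch.no_grad():
#       out = model(prot_ids, torch.ones(2, 50, dtype=torch.long),
#                   mol_ids, torch.ones(2, 30, dtype=torch.long))
#   assert out.shape == (2,)  # one pKd prediction per protein-molecule pair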


def load_model(checkpoint_path, device="cpu"):
    """Load a trained MitoInteract model.

    The checkpoint must hold a "config" dict of hyperparameters and a
    "model_state_dict"; see the save sketch below this function.
    """
    # weights_only=False is needed because the checkpoint stores a plain config
    # dict alongside the weights; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint["config"]
    model = MitoInteract(
        esm_model_name=config["esm_model"],
        mol_model_name=config["mol_model"],
        protein_dim=config["protein_dim"],
        mol_dim=config["mol_dim"],
        proj_dim=config["proj_dim"],
        n_heads=config["n_heads"],
        dropout=config["dropout"],
        freeze_encoders=True,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model, config
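
# A minimal sketch of saving a checkpoint that load_model() above can read.
# The trained `model` and the hyperparameter values here are assumptions
# standing in for your own training run; only the two top-level keys and the
# config entries read by load_model() are required.
#
#   torch.save({
#       "config": {
#           "esm_model": "facebook/esm2_t33_650M_UR50D",
#           "mol_model": "seyonec/ChemBERTa-zinc-base-v1",
#           "protein_dim": 1280, "mol_dim": 768,
#           "proj_dim": 256, "n_heads": 8, "dropout": 0.1,
#       },
#       "model_state_dict": model.state_dict(),
#   }, "mitointeract.pt")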


def predict_binding(model, protein_seq, smiles, device="cpu"):
    """Predict binding affinity (pKd) for a single protein-molecule pair."""
    # The tokenizer names must match the encoders the checkpoint was trained
    # with (config["esm_model"] / config["mol_model"]). For repeated calls,
    # load the tokenizers once outside this function.
    prot_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    mol_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

    # Proteins beyond 512 residues and SMILES beyond 200 tokens are truncated.
    prot_enc = prot_tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512)
    mol_enc = mol_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=200)

    model = model.to(device)
    with torch.no_grad():
        pKd = model(
            prot_enc["input_ids"].to(device), prot_enc["attention_mask"].to(device),
            mol_enc["input_ids"].to(device), mol_enc["attention_mask"].to(device),
        )

    # pKd = -log10(Kd in M), so Kd = 10^(-pKd) M, i.e. 10^(-pKd) * 1e6 uM.
    pKd_val = pKd.item()
    Kd_uM = 10 ** (-pKd_val) * 1e6
    return {"pKd": pKd_val, "Kd_uM": Kd_uM}