""" MitoInteract Model Class Definition Copy this file to load the model for inference. """ import torch import torch.nn as nn from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer class MitoInteract(nn.Module): def __init__( self, esm_model_name="facebook/esm2_t33_650M_UR50D", mol_model_name="seyonec/ChemBERTa-zinc-base-v1", protein_dim=1280, mol_dim=768, proj_dim=256, n_heads=8, dropout=0.1, freeze_encoders=True, ): super().__init__() self.freeze_encoders = freeze_encoders self.esm = EsmModel.from_pretrained(esm_model_name) self.protein_dim = protein_dim self.mol_encoder = AutoModel.from_pretrained(mol_model_name) self.mol_dim = mol_dim if freeze_encoders: for p in self.esm.parameters(): p.requires_grad = False for p in self.mol_encoder.parameters(): p.requires_grad = False self.prot_proj = nn.Sequential( nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout)) self.mol_proj = nn.Sequential( nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout)) self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True) self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True) self.ln_mol2prot = nn.LayerNorm(proj_dim) self.ln_prot2mol = nn.LayerNorm(proj_dim) fused_dim = proj_dim * 2 self.mlp = nn.Sequential( nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout), nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, 1)) def encode_protein(self, input_ids, attention_mask): ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad() with ctx: out = self.esm(input_ids=input_ids, attention_mask=attention_mask) mask = attention_mask.unsqueeze(-1).float() pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9) return pooled, out.last_hidden_state def encode_molecule(self, input_ids, attention_mask): ctx = torch.no_grad() if self.freeze_encoders else torch.enable_grad() with ctx: out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask) return out.pooler_output, out.last_hidden_state def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask): prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask) mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask) prot_seq_proj = self.prot_proj(prot_seq) mol_seq_proj = self.mol_proj(mol_seq) prot_q = self.prot_proj(prot_pooled).unsqueeze(1) mol_q = self.mol_proj(mol_pooled).unsqueeze(1) prot_pad_mask = (prot_attention_mask == 0) mol_pad_mask = (mol_attention_mask == 0) h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask) h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask) h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1)) h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1)) fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1) return self.mlp(fused).squeeze(-1) def load_model(checkpoint_path, device="cpu"): """Load trained MitoInteract model.""" checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) config = checkpoint["config"] model = MitoInteract( esm_model_name=config["esm_model"], mol_model_name=config["mol_model"], protein_dim=config["protein_dim"], mol_dim=config["mol_dim"], proj_dim=config["proj_dim"], 
        n_heads=config["n_heads"],
        dropout=config["dropout"],
        freeze_encoders=True,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, config


def predict_binding(model, protein_seq, smiles, device="cpu"):
    """Predict binding affinity (pKd) for a protein-molecule pair."""
    prot_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    mol_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

    prot_enc = prot_tokenizer(
        protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    mol_enc = mol_tokenizer(
        smiles, return_tensors="pt", padding=True, truncation=True, max_length=200
    )

    model = model.to(device)
    with torch.no_grad():
        pKd = model(
            prot_enc["input_ids"].to(device),
            prot_enc["attention_mask"].to(device),
            mol_enc["input_ids"].to(device),
            mol_enc["attention_mask"].to(device),
        )

    pKd_val = pKd.item()
    # pKd = -log10(Kd in M), so Kd in M = 10 ** (-pKd); convert to micromolar.
    Kd_uM = 10 ** (-pKd_val) * 1e6
    return {"pKd": pKd_val, "Kd_uM": Kd_uM}
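

# --- Example usage (sketch) ---------------------------------------------------
# Minimal illustration of load_model() + predict_binding(). The checkpoint path,
# protein sequence, and SMILES below are placeholders, not artifacts shipped with
# this file; substitute your own trained checkpoint and inputs.
if __name__ == "__main__":
    model, config = load_model("mitointeract_checkpoint.pt", device="cpu")  # placeholder path
    example_protein = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # placeholder sequence
    example_smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin, as an illustrative ligand
    result = predict_binding(model, example_protein, example_smiles, device="cpu")
    print(f"pKd: {result['pKd']:.2f}  Kd: {result['Kd_uM']:.2f} uM")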