File size: 5,232 Bytes
98ed1b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
MitoInteract Model Class Definition
Copy this file to load the model for inference.
"""
import functools

import torch
import torch.nn as nn
from transformers import EsmModel, EsmTokenizer, AutoModel, AutoTokenizer

class MitoInteract(nn.Module):
    """Protein-molecule binding regressor with bidirectional cross-attention.

    A protein token sequence is encoded with an ESM model and a molecule token
    sequence with a Hugging Face ``AutoModel``. Both token-level outputs are
    projected to a shared ``proj_dim`` space. A pooled query from each side
    then attends over the other side's tokens. The two attended vectors are
    concatenated and passed through an MLP that outputs one scalar per pair
    (``predict_binding`` interprets it as pKd).

    NOTE: submodule attribute names and ``nn.Sequential`` layer order define
    the ``state_dict`` keys; changing them breaks loading of saved checkpoints.

    Args:
        esm_model_name: Hugging Face id/path of the protein encoder.
        mol_model_name: Hugging Face id/path of the molecule encoder.
        protein_dim: hidden size of the protein encoder's ``last_hidden_state``.
        mol_dim: hidden size of the molecule encoder's outputs.
        proj_dim: shared projection width; must be divisible by ``n_heads``
            (required by ``nn.MultiheadAttention``).
        n_heads: number of heads in each cross-attention layer.
        dropout: dropout probability used in the projections, attention and MLP.
        freeze_encoders: if True, encoder parameters get ``requires_grad=False``
            and encoder forward passes never record autograd graphs.
    """

    def __init__(
        self,
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        mol_model_name="seyonec/ChemBERTa-zinc-base-v1",
        protein_dim=1280,
        mol_dim=768,
        proj_dim=256,
        n_heads=8,
        dropout=0.1,
        freeze_encoders=True,
    ):
        super().__init__()
        self.freeze_encoders = freeze_encoders
        # Pretrained encoders. protein_dim / mol_dim must match their hidden
        # sizes because the projections below consume their outputs directly.
        self.esm = EsmModel.from_pretrained(esm_model_name)
        self.protein_dim = protein_dim
        self.mol_encoder = AutoModel.from_pretrained(mol_model_name)
        self.mol_dim = mol_dim
        if freeze_encoders:
            for p in self.esm.parameters():
                p.requires_grad = False
            for p in self.mol_encoder.parameters():
                p.requires_grad = False
        # Projections into the shared space; applied both per token and to the
        # pooled vectors that serve as attention queries.
        self.prot_proj = nn.Sequential(
            nn.Linear(protein_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
        self.mol_proj = nn.Sequential(
            nn.Linear(mol_dim, proj_dim), nn.LayerNorm(proj_dim), nn.ReLU(), nn.Dropout(dropout))
        self.cross_attn_mol2prot = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_prot2mol = nn.MultiheadAttention(proj_dim, n_heads, dropout=dropout, batch_first=True)
        self.ln_mol2prot = nn.LayerNorm(proj_dim)
        self.ln_prot2mol = nn.LayerNorm(proj_dim)
        fused_dim = proj_dim * 2
        # Regression head: fused (2 * proj_dim) vector -> single scalar.
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(128, 1))

    def _encoder_grad_enabled(self):
        """Return whether encoder forward passes should record gradients.

        Frozen encoders never need gradients. Trainable encoders follow the
        caller's current grad mode, so an outer ``torch.no_grad()`` is
        respected instead of being overridden.
        """
        return torch.is_grad_enabled() and not self.freeze_encoders

    def encode_protein(self, input_ids, attention_mask):
        """Encode protein tokens with the ESM model.

        Args:
            input_ids: protein token ids, presumably ``(batch, seq_len)``.
            attention_mask: 1 for real tokens, 0 for padding; same shape as
                ``input_ids``.

        Returns:
            ``(pooled, hidden)``: ``hidden`` is the encoder's
            ``last_hidden_state`` and ``pooled`` is its mean over positions
            whose mask is 1.
        """
        with torch.set_grad_enabled(self._encoder_grad_enabled()):
            out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state
        # Masked mean: padding contributes to neither the sum nor the count.
        # Casting to hidden's dtype keeps half-precision inference in half
        # precision; clamping the token count at 1 avoids 0/0 for an all-padding
        # row (the numerator is 0 there, so the result is 0 either way).
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
        return pooled, hidden

    def encode_molecule(self, input_ids, attention_mask):
        """Encode molecule tokens with the molecule encoder.

        Args:
            input_ids: molecule token ids, presumably ``(batch, seq_len)``.
            attention_mask: 1 for real tokens, 0 for padding.

        Returns:
            ``(pooler_output, last_hidden_state)`` straight from the encoder.
        """
        with torch.set_grad_enabled(self._encoder_grad_enabled()):
            out = self.mol_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return out.pooler_output, out.last_hidden_state

    def forward(self, prot_input_ids, prot_attention_mask, mol_input_ids, mol_attention_mask):
        """Predict one scalar per protein-molecule pair.

        Args:
            prot_input_ids, prot_attention_mask: protein tokens and padding mask.
            mol_input_ids, mol_attention_mask: molecule tokens and padding mask.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        prot_pooled, prot_seq = self.encode_protein(prot_input_ids, prot_attention_mask)
        mol_pooled, mol_seq = self.encode_molecule(mol_input_ids, mol_attention_mask)
        prot_seq_proj = self.prot_proj(prot_seq)
        mol_seq_proj = self.mol_proj(mol_seq)
        # Pooled vectors become length-1 query sequences for cross-attention.
        prot_q = self.prot_proj(prot_pooled).unsqueeze(1)
        mol_q = self.mol_proj(mol_pooled).unsqueeze(1)
        # nn.MultiheadAttention expects True where a key must be ignored.
        prot_pad_mask = (prot_attention_mask == 0)
        mol_pad_mask = (mol_attention_mask == 0)
        # Protein query attends over molecule tokens, and vice versa.
        h_prot2mol, _ = self.cross_attn_prot2mol(prot_q, mol_seq_proj, mol_seq_proj, key_padding_mask=mol_pad_mask)
        h_mol2prot, _ = self.cross_attn_mol2prot(mol_q, prot_seq_proj, prot_seq_proj, key_padding_mask=prot_pad_mask)
        h_prot2mol = self.ln_prot2mol(h_prot2mol.squeeze(1))
        h_mol2prot = self.ln_mol2prot(h_mol2prot.squeeze(1))
        fused = torch.cat([h_prot2mol, h_mol2prot], dim=-1)
        return self.mlp(fused).squeeze(-1)


def load_model(checkpoint_path, device="cpu"):
    """Load a trained MitoInteract model from a checkpoint for inference.

    The checkpoint must be a mapping with a ``"config"`` entry holding the
    constructor hyper-parameters (``esm_model``, ``mol_model``,
    ``protein_dim``, ``mol_dim``, ``proj_dim``, ``n_heads``, ``dropout``) and
    a ``"model_state_dict"`` entry holding the trained weights.

    Args:
        checkpoint_path: location passed to ``torch.load``.
        device: device to map checkpoint tensors to and to place the model on.

    Returns:
        ``(model, config)``: the model on ``device`` in eval mode, and the
        checkpoint's ``"config"`` entry.

    Raises:
        KeyError: if the checkpoint lacks one of the entries listed above.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = checkpoint["config"]
    model = MitoInteract(
        esm_model_name=config["esm_model"],
        mol_model_name=config["mol_model"],
        protein_dim=config["protein_dim"],
        mol_dim=config["mol_dim"],
        proj_dim=config["proj_dim"],
        n_heads=config["n_heads"],
        dropout=config["dropout"],
        freeze_encoders=True,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # load_state_dict copies values into the freshly constructed parameters,
    # which live on the default (CPU) device; map_location alone does not move
    # the model, so place it on the requested device explicitly.
    model.to(device)
    model.eval()
    return model, config


@functools.lru_cache(maxsize=8)
def _load_tokenizers(prot_model_name, mol_model_name):
    """Load and memoize the (protein, molecule) tokenizer pair for the given names."""
    return (
        EsmTokenizer.from_pretrained(prot_model_name),
        AutoTokenizer.from_pretrained(mol_model_name),
    )


def predict_binding(
    model,
    protein_seq,
    smiles,
    device="cpu",
    *,
    prot_model_name="facebook/esm2_t33_650M_UR50D",
    mol_model_name="seyonec/ChemBERTa-zinc-base-v1",
):
    """Predict binding affinity (pKd) for a protein-molecule pair.

    Args:
        model: a MitoInteract-compatible module; it is moved to ``device``.
        protein_seq: protein sequence (truncated to 512 tokens).
        smiles: molecule SMILES string (truncated to 200 tokens).
        device: device to run inference on.
        prot_model_name: tokenizer id matching the model's protein encoder.
        mol_model_name: tokenizer id matching the model's molecule encoder.

    Returns:
        dict with ``"pKd"`` (model output) and ``"Kd_uM"`` (the same value
        converted to a dissociation constant in micromolar).
    """
    # Tokenizers are cached so repeated calls do not re-run from_pretrained.
    prot_tokenizer, mol_tokenizer = _load_tokenizers(prot_model_name, mol_model_name)

    prot_enc = prot_tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=512)
    mol_enc = mol_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=200)

    model = model.to(device)
    # MitoInteract's MLP head uses BatchNorm1d, which raises in training mode
    # for a single example, and dropout would make predictions random. Run in
    # eval mode, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pKd = model(
                prot_enc["input_ids"].to(device), prot_enc["attention_mask"].to(device),
                mol_enc["input_ids"].to(device), mol_enc["attention_mask"].to(device),
            )
    finally:
        model.train(was_training)

    pKd_val = pKd.item()
    # pKd = -log10(Kd in mol/L), so Kd in micromolar = 10**(-pKd) * 1e6.
    Kd_uM = 10 ** (-pKd_val) * 1e6
    return {"pKd": pKd_val, "Kd_uM": Kd_uM}