import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_dense_batch
import torch.nn.functional as F


class CrossAttentionLayer(nn.Module):
    def __init__(self, feature_dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Main cross-attention layer; feature_dim is the size of the hidden representation
        # shared by the ligand and protein features
        self.attention = nn.MultiheadAttention(
            feature_dim, num_heads, dropout=dropout, batch_first=True
        )

        # Normalization layer for stabilizing training
        self.norm = nn.LayerNorm(feature_dim)

        # Feedforward network for further processing, classical transformer style
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.GELU(),  # GELU is the standard activation in transformer feed-forward blocks
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
        )
        self.norm_ff = nn.LayerNorm(feature_dim)
        self.last_attention_weights = None

    def forward(self, ligand_features, protein_features, key_padding_mask=None):
        # ligand_features: [Batch, Atoms, Dim] - atoms
        # protein_features: [Batch, Residues, Dim] - amino acids
        # Cross attention:
        # Query = Ligand (What we want to find out)
        # Key, Value = Protein (Where we look for information)
        # Result: "Ligand enriched with knowledge about proteins"
        attention_output, attn_weights = self.attention(
            query=ligand_features,
            key=protein_features,
            value=protein_features,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=True,
        )
        # Averaged attention weights [Batch, Atoms, Residues], kept for later inspection/visualization
        self.last_attention_weights = attn_weights.detach().cpu()

        # Residual connection (x + attention(x)) and normalization
        ligand_features = self.norm(ligand_features + attention_output)

        # Feedforward network with residual connection and normalization
        ff_output = self.ff(ligand_features)
        ligand_features = self.norm_ff(ligand_features + ff_output)

        return ligand_features
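
# Illustrative shape check for CrossAttentionLayer (toy sizes, assumed for
# demonstration only; they are not taken from the original project):
#   layer = CrossAttentionLayer(feature_dim=256, num_heads=4)
#   ligand = torch.randn(8, 40, 256)                    # [Batch, Atoms, Dim]
#   protein = torch.randn(8, 500, 256)                  # [Batch, Residues, Dim]
#   pad_mask = torch.zeros(8, 500, dtype=torch.bool)    # True marks padded residues
#   out = layer(ligand, protein, key_padding_mask=pad_mask)
#   out.shape                           # torch.Size([8, 40, 256])
#   layer.last_attention_weights.shape  # torch.Size([8, 40, 500])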


class BindingAffinityModel(nn.Module):
    def __init__(
        self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
    ):
        super().__init__()
        self.dropout = dropout
        self.hidden_channels = hidden_channels

        # Tower 1 - Ligand GNN: three stacked GAT layers, so every atom can "see" up to 3 bonds away;
        # the attention mechanism lets the model weigh the importance of each neighbour
        self.gat1 = GATConv(
            num_node_features, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat2 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat3 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )

        # Tower 2 - Protein encoder; vocabulary size 22 = 21 amino-acid tokens + 1 PAD token
        self.protein_embedding = nn.Embedding(22, hidden_channels)
        # 1D convolution over the embedded sequence to capture local context between neighbouring residues
        self.prot_conv = nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size=3, padding=1
        )

        # Cross-Attention Layer, atoms attending to amino acids
        self.cross_attention = CrossAttentionLayer(
            feature_dim=hidden_channels, num_heads=4, dropout=dropout
        )

        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, 1)  # Final output for regression, pKd

    def forward(self, x, edge_index, batch, protein_seq):
        # Ligand GNN forward pass (Graph -> Node Embeddings)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat3(x, edge_index))  # [Total_Atoms, Hidden_Channels]

        # Convert graph into tensor [Batch, Max_Atoms, Hidden_Channels]
        # to_dense_batch pads each graph with zeros up to the size of the largest graph in the batch
        ligand_dense, ligand_mask = to_dense_batch(x, batch)
        # ligand_dense: [Batch, Max_Atoms, Hidden_Channels]
        # ligand_mask: [Batch, Max_Atoms] True where there is a real atom, False where there is padding

        batch_size = ligand_dense.size(0)
        protein_seq = protein_seq.view(batch_size, -1)  # [Batch, Seq_Len]

        # Protein forward pass protein_seq: [Batch, Seq_Len]
        p = self.protein_embedding(protein_seq)  # [Batch, Seq_Len, Hidden_Channels]

        # A simple 1D convolution to capture local context between neighbouring amino acids
        p = p.permute(0, 2, 1)  # Change to [Batch, Hidden_Channels, Seq_Len] for Conv1d
        p = F.relu(self.prot_conv(p))
        p = p.permute(0, 2, 1)  # [Batch, Seq, Hidden_Channels]

        # Padding mask for the protein: PyTorch MultiheadAttention expects key_padding_mask
        # to be True at every position that should be ignored, i.e. wherever the token is PAD (= 0)
        protein_pad_mask = protein_seq == 0

        # Cross-Attention
        x_cross = self.cross_attention(
            ligand_dense, p, key_padding_mask=protein_pad_mask
        )

        # Masked mean pooling over atoms: one vector per molecule, counting only real atoms and ignoring padding
        # ligand_mask is True for real atoms, False for padded positions
        mask_expanded = ligand_mask.unsqueeze(-1)  # [Batch, Max_Atoms, 1]

        # Zero out the padded atom features
        x_cross = x_cross * mask_expanded

        # Sum the features of real atoms / number of real atoms to get the mean
        sum_features = torch.sum(x_cross, dim=1)  # [Batch, Hidden_Channels]
        num_atoms = torch.sum(mask_expanded, dim=1)  # [Batch, 1]
        pooled_x = sum_features / (num_atoms + 1e-6)  # Avoid division by zero

        # MLP Head
        out = F.relu(self.fc1(pooled_x))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.fc2(out)
        return out
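

if __name__ == "__main__":
    # Minimal smoke test on random data. This is an illustrative sketch only: the
    # atom feature size, sequence length, and PAD convention below are assumptions,
    # not values from the original project's preprocessing.
    torch.manual_seed(0)
    num_node_features = 32          # assumed atom feature size
    batch_size, seq_len = 2, 100    # assumed padded protein sequence length

    # Two tiny "molecules" of 5 and 3 atoms with arbitrary bond connectivity
    x = torch.randn(8, num_node_features)
    edge_index = torch.tensor(
        [[0, 1, 2, 3, 5, 6], [1, 2, 3, 4, 6, 7]], dtype=torch.long
    )
    batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])  # graph assignment per atom

    # Random amino-acid token ids in [1, 21]; 0 is reserved for PAD
    protein_seq = torch.randint(1, 22, (batch_size, seq_len))
    protein_seq[:, 80:] = 0         # pretend the tail of each sequence is padding

    model = BindingAffinityModel(num_node_features)
    model.eval()
    with torch.no_grad():
        pred = model(x, edge_index, batch, protein_seq)
    print(pred.shape)               # torch.Size([2, 1]) -> one predicted pKd per complex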