File size: 5,424 Bytes
7e792a6
 
 
2fdd454
7e792a6
e33b6c9
7e792a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e33b6c9
7e792a6
 
 
2fdd454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e792a6
2fdd454
7e792a6
2fdd454
 
 
e33b6c9
 
 
 
 
 
2fdd454
7e792a6
 
 
 
 
 
 
 
 
 
2fdd454
 
 
 
7e792a6
 
e33b6c9
7e792a6
2fdd454
7e792a6
 
 
1390640
e33b6c9
 
 
7e792a6
 
 
 
 
 
e33b6c9
7e792a6
 
 
 
 
 
 
 
 
 
 
 
 
e33b6c9
7e792a6
e33b6c9
 
 
7e792a6
 
e33b6c9
 
 
 
 
 
7e792a6
e33b6c9
 
 
 
 
 
7e792a6
 
e33b6c9
7e792a6
2fdd454
 
7e792a6
e33b6c9
7e792a6
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import math
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    The encoding table is computed once for ``seq_len`` positions and stored
    in the ``pe`` buffer (shape ``(1, seq_len, d_model)``), so it is saved
    with the module but is not a trainable parameter.

    Args:
        d_model: Size of the last (embedding) dimension. Both even and odd
            sizes are supported.
        seq_len: Number of positions to precompute.
        dropout: Dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Table of shape (seq_len, d_model), filled column-wise below.
        pe = torch.zeros(seq_len, d_model)

        # Column vector of positions, shape (seq_len, 1).
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies, computed in log space; one per even column.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        # Sine goes into the even columns.
        pe[:, 0::2] = torch.sin(angles)
        # Cosine goes into the odd columns. For an odd d_model there is one
        # fewer odd column than even columns, so trim the angles to match;
        # without the slice the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch dimension: (1, seq_len, d_model).
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.shape[1]`` positions.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x``, after dropout.
        """
        # ``pe`` is a buffer, not a parameter, so the slice carries no gradient.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)


# class LigandGNN(nn.Module): # GCN CONV
#     def __init__(self, input_dim, hidden_channels):
#         super().__init__()
#         self.hidden_channels = hidden_channels
#
#         self.conv1 = GCNConv(input_dim, hidden_channels)
#         self.conv2 = GCNConv(hidden_channels, hidden_channels)
#         self.conv3 = GCNConv(hidden_channels, hidden_channels)
#         self.dropout = nn.Dropout(0.2)
#
#     def forward(self, x, edge_index, batch):
#         x = self.conv1(x, edge_index)
#         x = x.relu()
#         x = self.dropout(x)
#
#         x = self.conv2(x, edge_index)
#         x = x.relu()
#         x = self.conv3(x, edge_index)
#         x = self.dropout(x)
#
#         # Average the node embeddings to get the molecule vector
#         x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
#         return x


class LigandGNN(nn.Module):
    """Graph-attention encoder producing one vector per ligand graph.

    Three ``GATConv`` layers are applied in sequence; the first two are
    followed by ReLU and dropout, the third is left linear. Node features are
    then mean-pooled per graph.

    Args:
        input_dim: Number of input features per node.
        hidden_channels: Output width of every attention layer.
        heads: Number of attention heads in each layer.
        dropout: Dropout probability applied after the first two layers.
    """

    def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
        super().__init__()
        # concat=False averages the heads instead of concatenating them, so
        # every layer keeps an output width of hidden_channels.
        gat_options = {"heads": heads, "concat": False}
        self.conv1 = GATConv(input_dim, hidden_channels, **gat_options)
        self.conv2 = GATConv(hidden_channels, hidden_channels, **gat_options)
        self.conv3 = GATConv(hidden_channels, hidden_channels, **gat_options)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Encode node features and pool them into per-graph vectors.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every ``GATConv``.
            batch: Per-node graph assignment used by ``global_mean_pool``.

        Returns:
            The mean-pooled output of the last attention layer.
        """
        node_feats = x
        # First two layers: attention -> ReLU -> dropout.
        for layer in (self.conv1, self.conv2):
            node_feats = self.dropout(layer(node_feats, edge_index).relu())

        # Last layer has no activation or dropout before pooling.
        node_feats = self.conv3(node_feats, edge_index)
        return global_mean_pool(node_feats, batch)


class ProteinTransformer(nn.Module):
    """Encode padded token-id sequences into one fixed-size vector each.

    Tokens are embedded, scaled by ``sqrt(d_model)``, given positional
    encodings, run through a Transformer encoder, mean-pooled over the
    non-padding positions and projected to ``output_dim``. Token id ``0`` is
    treated as padding.

    Args:
        vocab_size: Number of distinct token ids (including the pad id 0).
        d_model: Embedding and encoder width.
        N: Number of encoder layers.
        h: Number of attention heads per encoder layer.
        output_dim: Size of the returned vector.
        dropout: Dropout probability of the positional encoder. The encoder
            layers are built without an explicit dropout argument and
            therefore keep their library default.
    """

    def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=h, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)

        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """Return a ``(batch_size, output_dim)`` embedding for each sequence.

        Args:
            x: Integer tensor of shape ``(batch_size, seq_len)``; 0 marks
                padding.
        """
        padding_mask = x == 0  # True at PAD positions
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        x = self.transformer(x, src_key_padding_mask=padding_mask)

        # Zero the padded positions with masked_fill rather than multiplying
        # by a 0/1 mask: attention may emit NaN for a row that is entirely
        # padding (depends on PyTorch version/backend), and NaN * 0 is still
        # NaN, which would poison the pooled vector.
        x = x.masked_fill(padding_mask.unsqueeze(-1), 0.0)

        # Masked global average pooling over the sequence dimension.
        sum_x = x.sum(dim=1)
        valid = (~padding_mask).float().unsqueeze(-1)
        # Clamp so an all-padding row divides by a tiny number, not zero.
        token_counts = valid.sum(dim=1).clamp(min=1e-9)
        x = sum_x / token_counts
        x = self.fc(x)
        return x


class BindingAffinityModel(nn.Module):
    """Two-tower regressor combining a ligand graph and a protein sequence.

    The ligand tower (``LigandGNN``) and the protein tower
    (``ProteinTransformer``) each produce a ``hidden_channels``-wide vector;
    the two are concatenated and passed through a small MLP that outputs a
    single value per sample.

    Args:
        num_node_features: Number of input features per ligand node.
        hidden_channels: Width of both towers and of the MLP hidden layer.
        gat_heads: Number of attention heads in the ligand tower.
        dropout: Dropout probability used in both towers and in the head.
        vocab_size: Size of the protein token vocabulary. Defaults to 26,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        num_node_features,
        hidden_channels=128,
        gat_heads=4,
        dropout=0.2,
        vocab_size=26,
    ):
        super().__init__()
        # Tower 1 - Ligand GNN
        self.ligand_gnn = LigandGNN(
            input_dim=num_node_features,
            hidden_channels=hidden_channels,
            heads=gat_heads,
            dropout=dropout,
        )
        # Tower 2 - Protein Transformer
        self.protein_transformer = ProteinTransformer(
            vocab_size=vocab_size,
            d_model=hidden_channels,
            output_dim=hidden_channels,
            dropout=dropout,
        )

        # Regression head over the concatenated tower outputs.
        self.head = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        )

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict one value per (ligand graph, protein sequence) pair.

        Args:
            x: Ligand node feature matrix.
            edge_index: Ligand graph connectivity.
            batch: Per-node graph assignment for the ligand batch.
            protein_seq: Protein token ids; reshaped to ``(batch_size, -1)``,
                so it must hold an equal number of tokens per sample
                (presumably sequences arrive concatenated along dim 0 by the
                data loader -- confirm against the caller).

        Returns:
            Tensor of shape ``(batch_size, 1)``.
        """
        ligand_vec = self.ligand_gnn(x, edge_index, batch)
        # The pooled ligand output has one row per graph, so read the batch
        # size from it instead of calling ``batch.max().item()`` again: this
        # avoids a second host/device sync and does not require ``batch`` to
        # be a tensor.
        batch_size = ligand_vec.size(0)
        protein_seq = protein_seq.view(batch_size, -1)

        protein_vec = self.protein_transformer(protein_seq)
        combined = torch.cat([ligand_vec, protein_vec], dim=1)
        return self.head(combined)