| | import math |
| | import torch |
| | import torch.nn as nn |
| | from torch_geometric.nn import GCNConv, GATConv, global_mean_pool |
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """Inject sinusoidal positional information into token embeddings.

    Precomputes the standard Transformer sin/cos table once in ``__init__``
    and adds the first ``x.shape[1]`` rows to the input on each forward pass.

    Args:
        d_model: embedding dimension (odd values are supported).
        seq_len: maximum sequence length covered by the precomputed table.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # Fix: restrict cos to d_model // 2 columns. For odd d_model the
        # odd-index slice has one fewer column than `angles`, and the
        # original unsliced assignment raised a shape-mismatch RuntimeError.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A registered buffer moves with .to(device), is saved in state_dict,
        # and never tracks gradients — so no requires_grad_(False) is needed.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (batch, seq, d_model)."""
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
class LigandGNN(nn.Module):
    """Three-layer GAT encoder that pools node features into one graph vector.

    Every attention layer averages its heads (``concat=False``), so the
    hidden width stays ``hidden_channels`` throughout; mean pooling over the
    batch index then yields one fixed-size embedding per graph.
    """

    def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
        super().__init__()
        self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
        self.conv3 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Encode node features ``x`` into per-graph vectors keyed by ``batch``."""
        # ReLU + dropout follow the first two layers only; the third layer
        # feeds the pooling directly.
        for conv in (self.conv1, self.conv2):
            x = self.dropout(torch.relu(conv(x, edge_index)))
        x = self.conv3(x, edge_index)
        return global_mean_pool(x, batch)
| |
|
| |
|
class ProteinTransformer(nn.Module):
    """Transformer encoder that embeds a padded integer protein sequence.

    Token id 0 is treated as padding: it is masked out of self-attention
    via ``src_key_padding_mask`` and excluded from the mean pooling.
    """

    def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=N)
        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """Map token ids (batch, seq) to pooled embeddings (batch, output_dim)."""
        pad = x == 0
        hidden = self.pos_encoder(self.embedding(x) * math.sqrt(self.d_model))
        hidden = self.transformer(hidden, src_key_padding_mask=pad)

        # Masked mean pool: zero out padded positions, then divide by the
        # count of real tokens (clamped so an all-padding row avoids 0/0).
        keep = (~pad).float().unsqueeze(-1)
        pooled = (hidden * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1e-9)
        return self.fc(pooled)
| |
|
| |
|
class BindingAffinityModel(nn.Module):
    """Predict a scalar binding affinity from a ligand graph and a protein sequence.

    Combines a GAT-based ligand encoder with a Transformer protein encoder,
    concatenates the two embeddings, and regresses a single value per pair.
    """

    def __init__(self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2):
        super().__init__()
        self.ligand_gnn = LigandGNN(
            input_dim=num_node_features,
            hidden_channels=hidden_channels,
            heads=gat_heads,
            dropout=dropout,
        )
        # vocab_size=26: presumably amino-acid codes plus padding id 0 —
        # TODO(review): confirm against the upstream sequence tokenizer.
        self.protein_transformer = ProteinTransformer(
            vocab_size=26,
            d_model=hidden_channels,
            output_dim=hidden_channels,
            dropout=dropout,
        )
        self.head = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        )

    def forward(self, x, edge_index, batch, protein_seq):
        """Return (batch, 1) affinity predictions for each ligand/protein pair."""
        ligand_vec = self.ligand_gnn(x, edge_index, batch)
        # protein_seq arrives flattened; recover the batch dimension from the
        # ligand batch vector (assumes graph ids 0..B-1 are all present —
        # verify against the data loader).
        n_graphs = int(batch.max()) + 1
        protein_vec = self.protein_transformer(protein_seq.view(n_graphs, -1))
        return self.head(torch.cat([ligand_vec, protein_vec], dim=1))
| |
|