| import torch |
| import torch.nn as nn |
| from torch_geometric.nn import GATConv |
| from torch_geometric.utils import to_dense_batch |
| import torch.nn.functional as F |
|
|
|
|
class CrossAttentionLayer(nn.Module):
    """Single cross-attention block: ligand tokens attend over protein tokens.

    Multi-head attention (ligand as query, protein as key/value) and a
    position-wise feed-forward network, each wrapped in a residual
    connection with post-layer-norm. The head-averaged attention map from
    the most recent forward pass is cached on the CPU in
    ``last_attention_weights`` for offline inspection.
    """

    def __init__(self, feature_dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Queries come from the ligand, keys/values from the protein;
        # batch_first keeps every tensor as (batch, seq, feature).
        self.attention = nn.MultiheadAttention(
            feature_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(feature_dim)
        # Standard transformer FFN with a 4x hidden expansion.
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
        )
        self.norm_ff = nn.LayerNorm(feature_dim)
        # Head-averaged attention weights of the latest call (CPU tensor).
        self.last_attention_weights = None

    def forward(self, ligand_features, protein_features, key_padding_mask=None):
        """Update ligand features by attending over protein features.

        Args:
            ligand_features: (batch, n_ligand, feature_dim) query tensor.
            protein_features: (batch, n_protein, feature_dim) key/value tensor.
            key_padding_mask: optional (batch, n_protein) bool mask; True
                marks protein positions to exclude from attention.

        Returns:
            Tensor of shape (batch, n_ligand, feature_dim).
        """
        attended, weights = self.attention(
            query=ligand_features,
            key=protein_features,
            value=protein_features,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=True,
        )
        # Keep a detached CPU copy of the attention map for visualization.
        self.last_attention_weights = weights.detach().cpu()

        # Residual + post-norm around the attention sub-layer.
        out = self.norm(ligand_features + attended)
        # Residual + post-norm around the feed-forward sub-layer.
        return self.norm_ff(out + self.ff(out))
|
|
|
|
class BindingAffinityModel(nn.Module):
    """Predict a scalar protein-ligand binding affinity.

    The ligand graph is encoded with a 3-layer GAT stack; the protein
    sequence with a token embedding plus a local 1-D convolution. Ligand
    atoms then cross-attend over protein residues, are masked-mean-pooled
    into one vector per graph, and a small MLP maps that to one value.
    """

    def __init__(
        self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
    ):
        super().__init__()
        self.dropout = dropout
        self.hidden_channels = hidden_channels

        # Ligand encoder: three GAT layers. concat=False averages the
        # attention heads so the width stays at hidden_channels throughout.
        self.gat1 = GATConv(
            num_node_features, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat2 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat3 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )

        # Protein encoder. Vocabulary of 22 — presumably 20 amino acids +
        # unknown + padding; TODO confirm against the tokenizer.
        # Bug fix: padding_idx=0 matches the `protein_seq == 0` mask used
        # in forward(); the pad embedding now stays zero and gets no
        # gradient instead of being trained like a real residue (which
        # also leaked into neighbors through the kernel-3 convolution).
        self.protein_embedding = nn.Embedding(22, hidden_channels, padding_idx=0)
        self.prot_conv = nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size=3, padding=1
        )

        # Ligand-queries-protein cross-attention block.
        self.cross_attention = CrossAttentionLayer(
            feature_dim=hidden_channels, num_heads=4, dropout=dropout
        )

        # Regression head.
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, 1)

    def _encode_ligand(self, x, edge_index):
        """Run the GAT stack over ligand node features; returns node embeddings."""
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return F.elu(self.gat3(x, edge_index))

    def _encode_protein(self, protein_seq):
        """Embed token ids and apply the 1-D conv; returns (B, L, C)."""
        p = self.protein_embedding(protein_seq)
        p = p.permute(0, 2, 1)  # Conv1d expects (B, C, L)
        p = F.relu(self.prot_conv(p))
        return p.permute(0, 2, 1)

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict affinity for a batch of ligand graphs + protein sequences.

        Args:
            x: (num_nodes, num_node_features) ligand node features.
            edge_index: (2, num_edges) ligand graph connectivity.
            batch: (num_nodes,) graph-id of each node (torch_geometric batch).
            protein_seq: protein token ids, reshaped to (batch, seq_len);
                token id 0 is padding.

        Returns:
            (batch, 1) tensor of predicted affinities.
        """
        x = self._encode_ligand(x, edge_index)

        # Pad per-graph node sets into a dense (B, max_atoms, C) tensor;
        # ligand_mask marks the real (non-pad) atom slots.
        ligand_dense, ligand_mask = to_dense_batch(x, batch)

        batch_size = ligand_dense.size(0)
        protein_seq = protein_seq.view(batch_size, -1)

        p = self._encode_protein(protein_seq)

        # Token id 0 is padding; keep those residues out of the attention.
        protein_pad_mask = protein_seq == 0

        x_cross = self.cross_attention(
            ligand_dense, p, key_padding_mask=protein_pad_mask
        )

        # Masked mean pooling over real atoms only.
        mask_expanded = ligand_mask.unsqueeze(-1)
        x_cross = x_cross * mask_expanded
        sum_features = torch.sum(x_cross, dim=1)
        num_atoms = torch.sum(mask_expanded, dim=1)
        # eps guards against a (degenerate) graph with zero atoms
        pooled_x = sum_features / (num_atoms + 1e-6)

        out = F.relu(self.fc1(pooled_x))
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.fc2(out)
|
|