import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_dense_batch


class CrossAttentionLayer(nn.Module):
    """Post-norm transformer-style block in which ligand atoms attend to protein residues."""

    def __init__(self, feature_dim, num_heads=4, dropout=0.1):
        super().__init__()

        # Multi-head cross-attention; batch_first=True means (batch, seq, feature_dim) tensors.
        self.attention = nn.MultiheadAttention(
            feature_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(feature_dim)

        # Position-wise feed-forward sub-layer with the usual 4x expansion.
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
        )
        self.norm_ff = nn.LayerNorm(feature_dim)

        # Holds the attention map from the most recent forward pass for inspection.
        self.last_attention_weights = None

    def forward(self, ligand_features, protein_features, key_padding_mask=None):
        """
        ligand_features:  (batch, num_atoms, feature_dim) -- queries
        protein_features: (batch, seq_len, feature_dim)   -- keys and values
        key_padding_mask: (batch, seq_len) bool, True at padded protein positions
        """
        # Ligand atoms (queries) attend over protein residues (keys/values).
        attention_output, attn_weights = self.attention(
            query=ligand_features,
            key=protein_features,
            value=protein_features,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=True,  # average over heads -> (batch, num_atoms, seq_len)
        )
        # Keep a detached copy of the attention map for later analysis.
        self.last_attention_weights = attn_weights.detach().cpu()

        # Residual connection + layer norm around the attention sub-layer.
        ligand_features = self.norm(ligand_features + attention_output)

        # Residual connection + layer norm around the feed-forward sub-layer.
        ff_output = self.ff(ligand_features)
        ligand_features = self.norm_ff(ligand_features + ff_output)

        return ligand_features


class BindingAffinityModel(nn.Module):
    def __init__(
        self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
    ):
        super().__init__()
        self.dropout = dropout
        self.hidden_channels = hidden_channels

        # Ligand branch: three GAT layers over the molecular graph.
        # concat=False averages the attention heads, so the width stays hidden_channels.
        self.gat1 = GATConv(
            num_node_features, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat2 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat3 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )

        # Protein branch: 22-token residue embedding (index 0 is treated as padding,
        # see the mask in forward) followed by a 1D convolution over the sequence.
        self.protein_embedding = nn.Embedding(22, hidden_channels)
        self.prot_conv = nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size=3, padding=1
        )

        # Cross-attention: ligand atoms attend to protein residues.
        self.cross_attention = CrossAttentionLayer(
            feature_dim=hidden_channels, num_heads=4, dropout=dropout
        )

        # Regression head on the pooled ligand representation.
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch, protein_seq):
        # Ligand branch: three GAT layers with ELU activations and dropout.
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat3(x, edge_index))

        # Convert the flat PyG node tensor into a dense batch:
        # ligand_dense is (batch, max_atoms, hidden); ligand_mask marks real atoms.
        ligand_dense, ligand_mask = to_dense_batch(x, batch)

        batch_size = ligand_dense.size(0)
        protein_seq = protein_seq.view(batch_size, -1)

        # Protein branch: embed residue tokens, then a 1D conv over the sequence.
        p = self.protein_embedding(protein_seq)
        p = p.permute(0, 2, 1)  # (batch, hidden, seq_len) for Conv1d
        p = F.relu(self.prot_conv(p))
        p = p.permute(0, 2, 1)  # back to (batch, seq_len, hidden)

        # Padding positions (token 0) are excluded from cross-attention.
        protein_pad_mask = protein_seq == 0

        # Ligand atoms attend to protein residues.
        x_cross = self.cross_attention(
            ligand_dense, p, key_padding_mask=protein_pad_mask
        )

        # Zero out padded atom positions before pooling.
        mask_expanded = ligand_mask.unsqueeze(-1)
        x_cross = x_cross * mask_expanded

        # Masked mean pooling over the real atoms of each ligand.
        sum_features = torch.sum(x_cross, dim=1)
        num_atoms = torch.sum(mask_expanded, dim=1)
        pooled_x = sum_features / (num_atoms + 1e-6)

        # Regression head predicts a single binding-affinity value per complex.
        out = F.relu(self.fc1(pooled_x))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.fc2(out)
        return out
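

# Minimal smoke-test sketch (not part of the original model code): runs one
# forward pass with random tensors to illustrate the expected input shapes.
# The atom-feature size (30), graph sizes, random edges, and protein length
# (100) are arbitrary assumptions chosen only for illustration.
if __name__ == "__main__":
    num_node_features = 30  # assumed ligand atom feature size
    model = BindingAffinityModel(num_node_features, hidden_channels=64)

    # Two ligand graphs flattened PyG-style: 12 atoms total, graphs of 5 and 7 atoms.
    x = torch.randn(12, num_node_features)
    edge_index = torch.randint(0, 12, (2, 24))  # random, chemically meaningless edges
    batch = torch.tensor([0] * 5 + [1] * 7)

    # Integer-encoded protein sequences, one per graph; token 0 is padding.
    protein_seq = torch.randint(1, 22, (2, 100))
    protein_seq[0, 80:] = 0  # pad the tail of the first sequence

    pred = model(x, edge_index, batch, protein_seq)
    print(pred.shape)  # torch.Size([2, 1])

    # Per-atom attention over protein residues from the last forward pass:
    # shape (batch, max_atoms, seq_len) = (2, 7, 100).
    print(model.cross_attention.last_attention_weights.shape)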