AlexSychovUN's picture
Updated all code
e33b6c9
import math
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.dropout = nn.Dropout(dropout)
# Create a matrix of shape (seq_len, d_model)
pe = torch.zeros(seq_len, d_model)
# Create a vector of shape (seq_len, 1)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(
1
) # (Seq_len, 1)
# Compute the positional encodings once in log space.
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
# Apply the sin to even positions
pe[:, 0::2] = torch.sin(position * div_term)
# Apply the cos to odd positions
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, Seq_len, d_model) batch dimension
self.register_buffer("pe", pe)
def forward(self, x):
# x: [batch_size, seq_len, d_model]
x = x + (self.pe[:, : x.shape[1], :]).requires_grad_(False)
return self.dropout(x)
# class LigandGNN(nn.Module): # GCN CONV
# def __init__(self, input_dim, hidden_channels):
# super().__init__()
# self.hidden_channels = hidden_channels
#
# self.conv1 = GCNConv(input_dim, hidden_channels)
# self.conv2 = GCNConv(hidden_channels, hidden_channels)
# self.conv3 = GCNConv(hidden_channels, hidden_channels)
# self.dropout = nn.Dropout(0.2)
#
# def forward(self, x, edge_index, batch):
# x = self.conv1(x, edge_index)
# x = x.relu()
# x = self.dropout(x)
#
# x = self.conv2(x, edge_index)
# x = x.relu()
# x = self.conv3(x, edge_index)
# x = self.dropout(x)
#
# # Averaging nodes and got the molecula vector
# x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
# return x
class LigandGNN(nn.Module):
def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
super().__init__()
# Heads=4 means we use 4 attention heads
# Concat=False, we average the heads instead of concatenating them, to keep the output dimension same as hidden_channels
self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
self.conv2 = GATConv(
hidden_channels, hidden_channels, heads=heads, concat=False
)
self.conv3 = GATConv(
hidden_channels, hidden_channels, heads=heads, concat=False
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index)
x = x.relu()
x = self.dropout(x)
x = self.conv2(x, edge_index)
x = x.relu()
x = self.dropout(x)
x = self.conv3(x, edge_index)
# Global Mean Pooling
x = global_mean_pool(x, batch)
return x
class ProteinTransformer(nn.Module):
def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
super().__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=h, batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
self.fc = nn.Linear(d_model, output_dim)
def forward(self, x):
# x: [batch_size, seq_len]
padding_mask = x == 0 # mask for PAD tokens
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoder(x)
x = self.transformer(x, src_key_padding_mask=padding_mask)
mask = (~padding_mask).float().unsqueeze(-1)
x = x * mask
sum_x = x.sum(dim=1) # Global average pooling
token_counts = mask.sum(dim=1).clamp(min=1e-9)
x = sum_x / token_counts
x = self.fc(x)
return x
class BindingAffinityModel(nn.Module):
def __init__(
self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2
):
super().__init__()
# Tower 1 - Ligand GNN
self.ligand_gnn = LigandGNN(
input_dim=num_node_features,
hidden_channels=hidden_channels,
heads=gat_heads,
dropout=dropout,
)
# Tower 2 - Protein Transformer
self.protein_transformer = ProteinTransformer(
vocab_size=26,
d_model=hidden_channels,
output_dim=hidden_channels,
dropout=dropout,
)
self.head = nn.Sequential(
nn.Linear(hidden_channels * 2, hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, 1),
)
def forward(self, x, edge_index, batch, protein_seq):
ligand_vec = self.ligand_gnn(x, edge_index, batch)
batch_size = batch.max().item() + 1
protein_seq = protein_seq.view(batch_size, -1)
protein_vec = self.protein_transformer(protein_seq)
combined = torch.cat([ligand_vec, protein_vec], dim=1)
return self.head(combined)