| | import pytorch_lightning as pl |
| | import torch |
| | import torch.nn as nn |
| | from torch.optim import Adam |
| |
|
| | from model import LigandGNN, ProteinTransformer |
| |
|
| |
|
class BindingAffinityModelPL(pl.LightningModule):
    """Regress a scalar affinity for a ligand-graph / protein-sequence pair.

    The ligand graph is encoded by ``LigandGNN`` and the protein token sequence
    by ``ProteinTransformer``. The two embeddings are concatenated and fed to a
    small MLP head that outputs one value per sample. The model is trained with
    mean-squared error.

    Args:
        num_node_features: Ligand node-feature width, forwarded to
            ``LigandGNN`` as ``input_dim``.
        hidden_channels_gnn: Hidden width forwarded to ``LigandGNN``.
        lr: Learning rate for Adam.
        vocab_size: Token vocabulary size forwarded to ``ProteinTransformer``.
        dropout: Dropout probability inside the regression head.
        weight_decay: Weight decay passed to Adam.
    """

    def __init__(
        self,
        num_node_features,
        hidden_channels_gnn,
        lr,
        vocab_size=26,
        dropout=0.2,
        weight_decay=1e-4,
    ):
        super().__init__()
        # Records every constructor argument (including the defaulted ones)
        # in ``self.hparams`` and in checkpoints.
        self.save_hyperparameters()
        self.lr = lr
        self.weight_decay = weight_decay

        self.ligand_gnn = LigandGNN(
            input_dim=num_node_features, hidden_channels=hidden_channels_gnn
        )
        self.protein_transformer = ProteinTransformer(vocab_size=vocab_size)
        # NOTE(review): the input width assumes both encoders emit 128-dim
        # embeddings; keep in sync with LigandGNN / ProteinTransformer.
        self.head = nn.Sequential(
            nn.Linear(128 + 128, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )
        self.criterion = nn.MSELoss()

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict affinities for a batch of ligand/protein pairs.

        Args:
            x: Ligand node features, passed to ``LigandGNN``.
            edge_index: Ligand connectivity, passed to ``LigandGNN``.
            batch: Per-node sample index (presumably a PyG-style ``batch``
                vector), passed to ``LigandGNN``.
            protein_seq: Protein token ids for the whole batch. They are
                reshaped to ``(num_samples, -1)``, so this presumably holds
                one equal-length (padded) sequence per sample. Confirm this
                against the dataset.

        Returns:
            Tensor of shape ``(num_samples, 1)`` with the predictions.
        """
        ligand_vec = self.ligand_gnn(x, edge_index, batch)
        # Take the sample count from the pooled ligand embeddings instead of
        # ``batch.max().item() + 1``. This avoids a device->host sync every
        # step, and the concatenation below requires the row counts to agree
        # anyway.
        num_samples = ligand_vec.size(0)
        protein_vec = self.protein_transformer(protein_seq.view(num_samples, -1))
        combined = torch.cat([ligand_vec, protein_vec], dim=1)
        return self.head(combined)

    def _shared_step(self, batch, log_name):
        """Compute and log the MSE loss for one batch; return the loss."""
        out = self(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
        # Flatten both sides to (num_samples,) so MSELoss never broadcasts
        # mismatched shapes (squeeze() yields 0-d tensors when the batch has
        # one sample). Also align the target dtype with the prediction.
        pred = out.reshape(-1)
        target = batch.y.reshape(-1).to(pred.dtype)
        loss = self.criterion(pred, target)

        # Pass batch_size explicitly. Otherwise Lightning has to guess it from
        # the graph batch and may pick a per-node tensor, which would weight
        # the epoch-level average by node count rather than sample count.
        self.log(
            log_name,
            loss,
            prog_bar=True,
            on_step=False,
            on_epoch=True,
            batch_size=pred.size(0),
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` (logged as ``val_loss``)."""
        return self._shared_step(batch, "val_loss")

    def configure_optimizers(self):
        """Create Adam over all parameters with the configured lr/decay."""
        return Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
| |
|