AlexSychovUN committed on
Commit
6afa7ea
·
1 Parent(s): 62dcdc0

Added pytorch lightning version

Browse files
Files changed (2) hide show
  1. model_pl.py +52 -0
  2. train_pl.py +50 -0
model_pl.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
7
+ from torch.optim import Adam
8
+
9
+ from model import LigandGNN, ProteinTransformer
10
+
11
class BindingAffinityModelPL(pl.LightningModule):
    """LightningModule for predicting protein–ligand binding affinity.

    Combines a GNN embedding of the ligand graph with a Transformer
    embedding of the protein sequence, then regresses a scalar affinity
    through a small MLP head trained with MSE loss.

    Args:
        num_node_features: Dimensionality of ligand graph node features.
        hidden_channels_gnn: Hidden width of the ligand GNN.
        lr: Learning rate for the Adam optimizer.
        weight_decay: L2 regularization for Adam. Defaults to 1e-4,
            which was previously hard-coded in configure_optimizers.
    """

    def __init__(self, num_node_features, hidden_channels_gnn, lr, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters()  # Save hyperparameters for checkpointing / easy access
        self.lr = lr
        self.weight_decay = weight_decay

        self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels_gnn)
        # NOTE(review): vocab_size=26 presumably covers the amino-acid
        # alphabet — confirm against ProteinTransformer's tokenization.
        self.protein_transformer = ProteinTransformer(vocab_size=26)
        # Head input 128 + 128 assumes both encoders emit 128-dim vectors
        # — TODO confirm against LigandGNN / ProteinTransformer outputs.
        self.head = nn.Sequential(
            nn.Linear(128 + 128, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1)
        )
        self.criterion = nn.MSELoss()

    def forward(self, x, edge_index, batch, protein_seq):
        """Return the predicted affinity, shape (batch_size, 1)."""
        ligand_vec = self.ligand_gnn(x, edge_index, batch)
        # Recover the graph-batch size from the PyG batch-assignment vector.
        batch_size = batch.max().item() + 1
        protein_seq = protein_seq.view(batch_size, -1)

        protein_vec = self.protein_transformer(protein_seq)
        combined = torch.cat([ligand_vec, protein_vec], dim=1)
        return self.head(combined)

    def _shared_step(self, batch):
        """Compute the MSE loss for one batch; shared by train and val steps."""
        out = self(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
        # view(-1) instead of squeeze(): squeeze() would collapse a
        # batch dimension of size 1 (a final partial batch) to a 0-d
        # tensor, which can silently change loss semantics.
        return self.criterion(out.view(-1), batch.y.view(-1))

    def training_step(self, batch, batch_idx):
        # No .to(device), zero_grad, or backward here — Lightning handles those.
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam with the configured learning rate and weight decay."""
        return Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
train_pl.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
3
+ from torch_geometric.loader import DataLoader
4
+ from dataset import BindingDataset
5
+ from torch.utils.data import random_split
6
+ from model_pl import BindingAffinityModelPL
7
+ import pandas as pd
8
+
9
def main():
    """Train the binding-affinity model on the PDBbind refined dataset.

    Loads and cleans the CSV, builds the graph dataset, performs an
    80/20 train/validation split, and runs a Lightning training loop
    with checkpointing and early stopping.
    """
    lr = 0.0005
    # Load dataset; drop rows with missing values before graph construction.
    dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
    dataframe.dropna(inplace=True)
    print("Dataset loaded with {} samples".format(len(dataframe)))
    dataset = BindingDataset(dataframe)
    print("Dataset transformed with {} samples".format(len(dataset)))

    # Guard clause: nothing to train on.
    if len(dataset) == 0:
        print("Dataset is empty")
        return

    # 80/20 train/validation split.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    num_features = train_dataset[0].x.shape[1]
    print("Number of node features:", num_features)

    # BUG FIX: num_features was computed above but a hard-coded 84 was
    # passed to the model; derive the model width from the data instead.
    model = BindingAffinityModelPL(num_node_features=num_features, hidden_channels_gnn=128, lr=lr)

    # save_top_k=3 with a fixed filename would only yield '-vN' suffixed
    # duplicates; embed epoch and val_loss so the kept checkpoints are
    # distinguishable.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='checkpoints/',
        filename='best-checkpoint-{epoch:02d}-{val_loss:.4f}',
        save_top_k=3,
        mode='min'
    )
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",  # Use GPU if available
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback]
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()