File size: 1,723 Bytes
6afa7ea
 
 
 
 
 
 
 
e33b6c9
6afa7ea
 
 
e33b6c9
6afa7ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e33b6c9
 
 
6afa7ea
e33b6c9
6afa7ea
 
 
 
 
e33b6c9
6afa7ea
e33b6c9
6afa7ea
 
 
e33b6c9
6afa7ea
e33b6c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch_geometric.loader import DataLoader
from dataset import BindingDataset
from torch.utils.data import random_split
from model_pl import BindingAffinityModelPL
import pandas as pd


def main(
    *,
    csv_path="pdbbind_refined_dataset.csv",
    lr=0.0005,
    batch_size=32,
    max_epochs=20,
    seed=None,
):
    """Train ``BindingAffinityModelPL`` on a binding-affinity CSV.

    Reads the CSV, drops every row containing a missing value, wraps the
    frame in ``BindingDataset``, makes a random 80/20 train/validation split
    and fits the model with checkpointing and early stopping on ``val_loss``.

    All arguments are keyword-only and default to the values this script
    originally hard-coded, so a bare ``main()`` call behaves as before.

    Args:
        csv_path: Path of the input CSV file.
        lr: Learning rate passed to the model.
        batch_size: Batch size for both the training and validation loaders.
        max_epochs: Upper bound on training epochs; early stopping may end
            training sooner.
        seed: When not ``None``, passed to ``pl.seed_everything`` so the
            random split and loader shuffling are reproducible. ``None`` keeps
            the original unseeded behaviour.
    """
    if seed is not None:
        # seed_everything seeds torch's global RNG, which random_split and the
        # shuffling DataLoader draw from when no explicit generator is given.
        pl.seed_everything(seed)

    # Load dataset
    dataframe = pd.read_csv(csv_path)
    dataframe.dropna(inplace=True)
    print("Dataset loaded with {} samples".format(len(dataframe)))
    dataset = BindingDataset(dataframe)
    n_samples = len(dataset)
    print("Dataset transformed with {} samples".format(n_samples))

    if n_samples == 0:
        print("Dataset is empty")
        return
    if n_samples < 2:
        # int(0.8 * 1) == 0 would leave the training split empty, and
        # train_dataset[0] below would raise IndexError.
        print("Dataset needs at least 2 samples to split into train/val")
        return

    # 80/20 split; for n_samples >= 2 both parts are non-empty.
    train_size = int(0.8 * n_samples)
    val_size = n_samples - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Derive the model's input width from the data rather than hard-coding it,
    # so the first layer always matches what BindingDataset produces.
    # Assumes each sample is a graph whose `x` is (num_nodes, num_features)
    # — TODO confirm against BindingDataset.
    num_features = train_dataset[0].x.shape[1]
    print("Number of node features:", num_features)

    model = BindingAffinityModelPL(
        num_node_features=num_features, hidden_channels_gnn=128, lr=lr
    )
    # NOTE(review): a fixed filename combined with save_top_k=3 relies on
    # Lightning's version suffixes to keep the three files apart; consider a
    # templated filename (e.g. including epoch/val_loss) instead.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints/",
        filename="best-checkpoint",
        save_top_k=3,
        mode="min",
    )
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",  # Use GPU if available
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback],
    )
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    # Only start training when run as a script, not when imported.
    main()