| | import pytorch_lightning as pl |
| | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping |
| | from torch_geometric.loader import DataLoader |
| | from dataset import BindingDataset |
| | from torch.utils.data import random_split |
| | from model_pl import BindingAffinityModelPL |
| | import pandas as pd |
| |
|
| |
|
def main(csv_path="pdbbind_refined_dataset.csv", lr=5e-4):
    """Train the binding-affinity GNN on the PDBbind refined set.

    Loads the dataset CSV, drops incomplete rows, builds graph samples,
    splits 80/20 into train/validation, and fits the Lightning model with
    checkpointing and early stopping, both monitoring ``val_loss``.

    Args:
        csv_path: Path to the PDBbind CSV file (default matches the
            original hard-coded path, so existing callers are unaffected).
        lr: Learning rate forwarded to the Lightning module.
    """
    dataframe = pd.read_csv(csv_path)
    dataframe.dropna(inplace=True)  # drop incomplete rows before graph construction
    print(f"Dataset loaded with {len(dataframe)} samples")

    dataset = BindingDataset(dataframe)
    print(f"Dataset transformed with {len(dataset)} samples")

    # Guard clause: nothing to train on.
    if len(dataset) == 0:
        print("Dataset is empty")
        return

    # 80/20 train/validation split.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size == 0:
        # With < 2 usable samples the split yields an empty training set and
        # train_dataset[0] below would raise IndexError.
        print("Dataset too small to split into train/validation")
        return
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # BUG FIX: derive the input width from the data itself. The model was
    # previously built with a hard-coded num_node_features=84 while this
    # computed value was only printed — a silent mismatch if featurization
    # ever changes.
    num_features = train_dataset[0].x.shape[1]
    print("Number of node features:", num_features)

    model = BindingAffinityModelPL(
        num_node_features=num_features,
        hidden_channels_gnn=128,
        lr=lr,
    )

    # Embed epoch/val_loss in the filename so the top-3 checkpoints
    # (save_top_k=3) get distinct, informative names instead of relying on
    # Lightning's automatic -v1/-v2 suffixing of a static name.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints/",
        filename="best-checkpoint-{epoch:02d}-{val_loss:.4f}",
        save_top_k=3,
        mode="min",
    )
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback],
    )
    trainer.fit(model, train_loader, val_loader)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|