import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch_geometric.loader import DataLoader
from dataset import BindingDataset
from torch.utils.data import random_split
from model_pl import BindingAffinityModelPL
import pandas as pd
def main():
    """Train the binding-affinity model from ``pdbbind_refined_dataset.csv``.

    Steps, in order:

    1. Read the CSV and drop every row that contains a missing value.
    2. Wrap the frame in ``BindingDataset`` and split it 80/20 into a
       training and a validation subset.
    3. Build ``BindingAffinityModelPL`` with the node-feature width read
       from the first training sample.
    4. Fit with a ``ModelCheckpoint`` and an ``EarlyStopping`` callback,
       both monitoring ``val_loss``.

    Prints a message and returns without training when the dataset is
    empty or too small to yield a non-empty training split.
    """
    lr = 0.0005

    # Load dataset
    dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
    dataframe.dropna(inplace=True)
    print("Dataset loaded with {} samples".format(len(dataframe)))

    dataset = BindingDataset(dataframe)
    print("Dataset transformed with {} samples".format(len(dataset)))
    if len(dataset) == 0:
        print("Dataset is empty")
        return

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size == 0:
        # int(0.8 * 1) == 0: a single-sample dataset would leave the training
        # split empty and make the train_dataset[0] lookup below fail.
        print("Dataset is too small to split into train and validation sets")
        return

    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Derive the input width from the data instead of hard-coding it, so the
    # model always matches the features the dataset actually produces.
    num_features = train_dataset[0].x.shape[1]
    print("Number of node features:", num_features)

    model = BindingAffinityModelPL(
        num_node_features=num_features,
        hidden_channels_gnn=128,
        lr=lr,
    )

    # NOTE(review): save_top_k=3 is combined with a fixed filename; Lightning
    # is expected to add a version suffix to keep the files apart - confirm.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath="checkpoints/",
        filename="best-checkpoint",
        save_top_k=3,
        mode="min",
    )
    early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",  # Use GPU if available
        devices=1,
        callbacks=[checkpoint_callback, early_stop_callback],
    )
    trainer.fit(model, train_loader, val_loader)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()