File size: 2,953 Bytes
2fdd454
 
 
 
 
 
 
 
 
 
 
 
de183ef
e33b6c9
 
de183ef
2fdd454
 
 
 
 
 
 
 
 
e33b6c9
2fdd454
 
 
 
e33b6c9
2fdd454
 
 
 
 
 
 
 
 
 
 
 
 
 
e33b6c9
 
 
de183ef
e33b6c9
 
2fdd454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e33b6c9
 
 
 
 
 
 
 
 
2fdd454
e33b6c9
 
 
 
 
2fdd454
 
de183ef
2fdd454
de183ef
2fdd454
 
e33b6c9
2fdd454
e33b6c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import random

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torch_geometric.loader import DataLoader
from dataset import BindingDataset
from model import BindingAffinityModel
from tqdm import tqdm
from scipy.stats import pearsonr
from torch.utils.data import random_split
from train import DROPOUT, HIDDEN_CHANNELS

# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Checkpoint (state dict) evaluated by predict_and_plot(). The file name
# appears to encode the training epoch and an MSE value — confirm against the run logs.
MODEL_PATH = "runs/GCN20260128_223529/models/model_ep091_mse2.3011.pth"


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and return a fresh seeded generator.

    Args:
        seed: value fed to every random number generator.

    Returns:
        A new CPU ``torch.Generator`` seeded with ``seed``, suitable for
        passing to functions that accept an explicit ``generator=``.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


def _load_test_dataset(csv_path, gen):
    """Build a ``BindingDataset`` from ``csv_path`` and return its 20% split.

    Rows containing any missing value are dropped before the dataset is
    built. The 80/20 split is drawn with ``gen`` so it is reproducible; it
    presumably mirrors the split made during training (TODO confirm that
    train.py uses the same seed and ratio, otherwise these "test" metrics
    include training samples).

    Returns:
        The test ``Subset``, or ``None`` when the dataset is empty.
    """
    dataframe = pd.read_csv(csv_path)
    dataframe.dropna(inplace=True)
    dataset = BindingDataset(dataframe)
    if len(dataset) == 0:
        return None

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    _, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
    return test_dataset


def _collect_predictions(model, loader):
    """Run ``model`` over every batch of ``loader``; return flat ``(y_true, y_pred)``."""
    true_chunks = []
    pred_chunks = []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
            # Flatten with reshape(-1) rather than squeeze(): squeeze() turns a
            # one-graph batch into a 0-d array (which cannot be iterated), and a
            # (B, 1)-shaped target would otherwise broadcast against (B,)
            # predictions when the errors are computed.
            true_chunks.append(batch.y.reshape(-1).cpu().numpy())
            pred_chunks.append(out.reshape(-1).cpu().numpy())
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


def _regression_metrics(y_true, y_pred):
    """Return ``(rmse, mae, pearson_r)``; Pearson r is NaN for fewer than two samples."""
    errors = y_true - y_pred
    rmse = float(np.sqrt(np.mean(errors**2)))
    mae = float(np.mean(np.abs(errors)))
    if len(y_true) >= 2:
        pearson_corr, _ = pearsonr(y_true, y_pred)
        pearson_corr = float(pearson_corr)
    else:
        # pearsonr raises ValueError when given fewer than two points.
        pearson_corr = float("nan")
    return rmse, mae, pearson_corr


def _plot_results(y_true, y_pred, rmse, pearson_corr, plot_file):
    """Scatter predicted vs. experimental affinity with a y = x line; save and show."""
    plt.figure(figsize=(9, 9))
    plt.scatter(y_true, y_pred, alpha=0.4, s=15, c="blue", label="Predictions")
    # Span the reference line over both the experimental and the predicted
    # range so out-of-range predictions still sit next to it.
    lo = float(min(y_true.min(), y_pred.min()))
    hi = float(max(y_true.max(), y_pred.max()))
    plt.plot(
        [lo, hi],
        [lo, hi],
        color="red",
        linestyle="--",
        linewidth=2,
        label="Ideal",
    )

    plt.xlabel("Experimental Affinity (pK)")
    plt.ylabel("Predicted Affinity (pK)")
    plt.title(
        f"Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}"
    )
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(plot_file)
    print(f"Plot is saved to {plot_file}")
    plt.show()


def predict_and_plot(
    *,
    model_path=None,
    csv_path="pdbbind_refined_dataset.csv",
    plot_file="final_results.png",
):
    """Evaluate a saved checkpoint on the test split, print metrics and plot them.

    Args:
        model_path: state-dict checkpoint to load; ``None`` means ``MODEL_PATH``.
        csv_path: CSV file the ``BindingDataset`` is built from.
        plot_file: path the predicted-vs-experimental scatter plot is saved to.

    Returns:
        ``{"rmse": ..., "mae": ..., "pearson": ...}``, or ``None`` when the
        dataset is empty.
    """
    if model_path is None:
        model_path = MODEL_PATH

    gen = set_seed(42)
    print("Loading data...")
    test_dataset = _load_test_dataset(csv_path, gen)
    if test_dataset is None:
        print("Dataset is empty")
        return None

    loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    num_features = test_dataset[0].x.shape[1]

    print("Loading model...")
    # NOTE(review): a commented-out `gat_heads=GAT_HEADS` argument used to sit
    # here; re-add it (and import GAT_HEADS) if the model is switched to a
    # variant that takes it.
    model = BindingAffinityModel(
        num_node_features=num_features,
        hidden_channels=HIDDEN_CHANNELS,
        dropout=DROPOUT,
    ).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    print("Predicting...")
    y_true, y_pred = _collect_predictions(model, loader)
    rmse, mae, pearson_corr = _regression_metrics(y_true, y_pred)

    print("Results:")
    print(f"RMSE: {rmse:.4f}")
    print(f"MAE: {mae:.4f}")
    print(f"Pearson Correlation: {pearson_corr:.4f}")

    _plot_results(y_true, y_pred, rmse, pearson_corr, plot_file)
    return {"rmse": rmse, "mae": mae, "pearson": pearson_corr}


if __name__ == "__main__":
    predict_and_plot()