"""Evaluate a trained BindingAffinityModel on the held-out test split and plot the results."""

import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from scipy.stats import pearsonr
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from dataset import BindingDataset
from model import BindingAffinityModel
from train import DROPOUT, HIDDEN_CHANNELS

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "runs/GCN20260128_223529/models/model_ep091_mse2.3011.pth"


def set_seed(seed=42):
    """Seed Python, NumPy and torch RNGs; return a torch.Generator seeded with *seed*.

    The returned generator is passed to ``random_split`` so the split is reproducible.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    return torch.Generator().manual_seed(seed)


def _build_test_loader(gen):
    """Load the CSV and return ``(test_dataset, loader)``, or ``None`` if the dataset is empty."""
    print("Loading data...")
    dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
    dataframe.dropna(inplace=True)
    dataset = BindingDataset(dataframe)
    if len(dataset) == 0:
        print("Dataset is empty")
        return None

    # 80/20 split driven by a seeded generator.
    # NOTE(review): this presumably must match the split used in train.py so the
    # test set is truly held out. Confirm that both use the same seed and ratio.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    _, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
    loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    return test_dataset, loader


def _load_model(num_features):
    """Build the model with the training hyperparameters and load weights from MODEL_PATH."""
    print("Loading model...")
    model = BindingAffinityModel(
        num_node_features=num_features,
        hidden_channels=HIDDEN_CHANNELS,
        # For a GAT-based model, also pass gat_heads=GAT_HEADS (imported from train).
        dropout=DROPOUT,
    ).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _collect_predictions(model, loader):
    """Run *model* over *loader* and return flat 1-D ``(y_true, y_pred)`` numpy arrays."""
    print("Predicting...")
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
            # Use reshape(-1) rather than squeeze(), for two reasons:
            #  - squeeze() turns a size-1 final batch into a 0-d array, which cannot
            #    be iterated.
            #  - flattening both targets and outputs keeps them 1-D, so the error
            #    below is element-wise rather than an (N, N) broadcast.
            true_chunks.append(batch.y.reshape(-1).cpu().numpy())
            pred_chunks.append(out.reshape(-1).cpu().numpy())
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


def _compute_metrics(y_true, y_pred):
    """Return ``(rmse, mae, pearson_r)``; ``pearson_r`` is NaN when there are fewer than 2 samples."""
    # Compute in float64 to avoid float32 accumulation error in the means.
    errors = np.asarray(y_true, dtype=np.float64) - np.asarray(y_pred, dtype=np.float64)
    rmse = float(np.sqrt(np.mean(errors**2)))
    mae = float(np.mean(np.abs(errors)))
    if len(errors) >= 2:
        pearson_corr, _ = pearsonr(y_true, y_pred)
    else:
        # pearsonr raises ValueError for inputs shorter than 2.
        pearson_corr = float("nan")
    return rmse, mae, float(pearson_corr)


def _plot_results(y_true, y_pred, rmse, pearson_corr, plot_file="final_results.png"):
    """Scatter predicted vs. experimental affinity with a y = x line; save to *plot_file* and show."""
    plt.figure(figsize=(9, 9))
    plt.scatter(y_true, y_pred, alpha=0.4, s=15, c="blue", label="Predictions")
    # Span the reference line over both axes' data so it covers every point.
    lo = float(min(y_true.min(), y_pred.min()))
    hi = float(max(y_true.max(), y_pred.max()))
    plt.plot(
        [lo, hi],
        [lo, hi],
        color="red",
        linestyle="--",
        linewidth=2,
        label="Ideal",
    )
    plt.xlabel("Experimental Affinity (pK)")
    plt.ylabel("Predicted Affinity (pK)")
    plt.title(
        f"Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}"
    )
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(plot_file)
    print(f"Plot is saved to {plot_file}")
    plt.show()


def predict_and_plot():
    """Evaluate the checkpoint at MODEL_PATH on the test split, print metrics and plot them.

    Prints RMSE, MAE and the Pearson correlation, then saves a predicted-vs-experimental
    scatter plot to ``final_results.png``. Returns early (printing a message) if the
    dataset is empty.
    """
    gen = set_seed(42)

    built = _build_test_loader(gen)
    if built is None:
        return
    test_dataset, loader = built

    # Node-feature width is read from the first test graph's feature matrix.
    num_features = test_dataset[0].x.shape[1]
    model = _load_model(num_features)

    y_true, y_pred = _collect_predictions(model, loader)
    rmse, mae, pearson_corr = _compute_metrics(y_true, y_pred)

    print("Results:")
    print(f"RMSE: {rmse:.4f}")
    print(f"MAE: {mae:.4f}")
    print(f"Pearson Correlation: {pearson_corr:.4f}")

    _plot_results(y_true, y_pred, rmse, pearson_corr)


if __name__ == "__main__":
    predict_and_plot()