# NOTE: removed web-scrape residue (file-size banner, commit hash, and a
# 1-97 line-number gutter) that preceded the actual Python source.
import torch
import torch.nn as nn
import numpy as np
from utils.utils import set_seed
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from model.model import BAPULM
from utils.preprocessing import preprocess_function, EmbeddingExtractor
import yaml



class BindingAffinityDataset(Dataset):
    """Dataset over a preprocessed DataFrame of protein/ligand pairs.

    Each row must provide a protein sequence (``seq``), a canonical SMILES
    string (``smiles_can``), and the target affinity
    (``neg_log10_affinity_M``). Items are returned as a
    (sequence, smiles, affinity) triple.
    """

    def __init__(self, data):
        # Hold the DataFrame as-is; rows are fetched positionally later.
        self.data = data

    def __len__(self):
        # One dataset item per DataFrame row.
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        return row['seq'], row['smiles_can'], row['neg_log10_affinity_M']


def main():
    """Run BAPULM inference over the configured benchmark CSV files.

    Reads ``config.yaml`` for the device, checkpoint path, batch size, and
    benchmark file list; restores the model; embeds each
    (protein sequence, ligand SMILES) pair; predicts binding affinity; and
    prints RMSE / MSE / MAE / Pearson r for every file.
    """
    # Load the configuration file
    with open('config.yaml', 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Fix RNG seeds so inference-time embedding extraction is reproducible.
    set_seed(2102)

    # Initialize device and model; fall back to CPU when CUDA is unavailable.
    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
    torch.cuda.empty_cache()
    model = BAPULM().to(device)
    # map_location is required so a checkpoint saved on GPU still loads on a
    # CPU-only host (matching the CPU fallback chosen above).
    model.load_state_dict(torch.load(config['model_inference_path'], map_location=device))
    model.eval()

    # Initialize the embedding extractor (produces protein + molecule embeddings).
    embedding_extractor = EmbeddingExtractor(device)

    # Process each benchmark CSV file independently.
    for csv_file in config['benchmark_files']:
        df = preprocess_function(pd.read_csv(csv_file))
        dataset = BindingAffinityDataset(df)
        validation_loader = DataLoader(
            dataset, batch_size=config['inference_batch_size'], shuffle=False
        )

        true_affinities = []
        predicted_affinities = []

        with torch.no_grad():
            for sequences, smiles, affinities in tqdm(validation_loader, desc=f"Processing {csv_file}"):
                prot_embeddings = []
                mol_embeddings = []
                for seq, smile in zip(sequences, smiles):
                    prot_embedding, mol_embedding = embedding_extractor.get_combined_embedding(seq, smile)
                    prot_embeddings.append(prot_embedding)
                    mol_embeddings.append(mol_embedding)

                prot_batch = torch.cat(prot_embeddings, dim=0).to(device)
                mol_batch = torch.cat(mol_embeddings, dim=0).to(device)

                # atleast_1d: a trailing batch of size 1 would otherwise
                # squeeze to a 0-d array, and list.extend() on a 0-d numpy
                # array raises TypeError.
                predictions = np.atleast_1d(
                    model(prot_batch, mol_batch).squeeze().cpu().numpy()
                )
                # Targets never need to visit the accelerator for metrics.
                true_affinities.extend(affinities.numpy())
                predicted_affinities.extend(predictions)

        true_affinities = np.array(true_affinities)
        predicted_affinities = np.array(predicted_affinities)

        # De-normalize model outputs back to the affinity scale.
        # NOTE(review): these constants are presumably the training-set
        # mean/std of neg_log10_affinity_M — confirm against training code.
        mean = 6.51286529169358
        scale = 1.5614094578916633
        predicted_affinities = predicted_affinities * scale + mean

        # Calculate metrics
        mse = mean_squared_error(true_affinities, predicted_affinities)
        mae = mean_absolute_error(true_affinities, predicted_affinities)
        pearson_corr, _ = pearsonr(true_affinities, predicted_affinities)
        rmse = np.sqrt(mse)

        print(f"Results for {csv_file}:")
        print(f"RMSE: {rmse}")
        print(f"MSE: {mse}")
        print(f"MAE: {mae}")
        print(f"Pearson Correlation Coefficient: {pearson_corr}")

if __name__ == '__main__':
    # Run inference only when executed as a script, not when imported.
    main()