AlexSychovUN commited on
Commit
2fdd454
·
1 Parent(s): 6afa7ea

Updated code

Browse files
inference.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from torch_geometric.loader import DataLoader
8
+ from dataset import BindingDataset
9
+ from model import BindingAffinityModel
10
+ from tqdm import tqdm
11
+ from scipy.stats import pearsonr
12
+ from torch.utils.data import random_split
13
+
14
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ MODEL_PATH = "best_model_gat.pth"
16
+
17
def set_seed(seed=42):
    """Seed every RNG in play (python, torch CPU/CUDA, numpy) for reproducibility.

    Returns a freshly seeded ``torch.Generator`` suitable for passing to
    ``random_split`` so the data split itself is deterministic.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
        seeder(seed)
    return torch.Generator().manual_seed(seed)
23
+
24
def predict_and_plot():
    """Evaluate the trained GAT model on the held-out test split and plot results.

    Re-creates the same 80/20 split used during training (identically seeded
    generator), runs inference on the test set, reports RMSE / MAE / Pearson
    correlation, and saves a true-vs-predicted scatter plot.
    """
    gen = set_seed(42)
    print("Loading data...")

    dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
    dataframe.dropna(inplace=True)
    dataset = BindingDataset(dataframe)
    if len(dataset) == 0:
        print("Dataset is empty")
        return

    # Same seed -> random_split reproduces the exact test indices from training.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    _, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)

    loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    num_features = test_dataset[0].x.shape[1]

    print("Loading model...")
    # NOTE: the updated BindingAffinityModel takes `hidden_channels`; the old
    # `hidden_channels_gnn` keyword no longer exists and raised TypeError.
    model = BindingAffinityModel(num_node_features=num_features, hidden_channels=128).to(DEVICE)
    # map_location lets a GPU-trained checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()

    y_true = []
    y_pred = []
    print("Predicting...")
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)

            y_true.extend(batch.y.cpu().numpy())
            # view(-1) (not squeeze()) keeps a 1-D result even when the last
            # batch has a single graph; squeeze() would produce a 0-d scalar
            # and break extend().
            y_pred.extend(out.view(-1).cpu().numpy())
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))
    pearson_corr, _ = pearsonr(y_true, y_pred)  # Pearson correlation

    print("Results:")
    print(f"RMSE: {rmse:.4f}")
    print(f"MAE: {mae:.4f}")
    print(f"Pearson Correlation: {pearson_corr:.4f}")

    plt.figure(figsize=(9, 9))
    plt.scatter(y_true, y_pred, alpha=0.4, s=15, c='blue', label='Predictions')
    # Identity line: a perfect predictor would place every point on it.
    plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='red', linestyle='--', linewidth=2,
             label='Ideal')

    plt.xlabel('Experimental Affinity (pK)')
    plt.ylabel('Predicted Affinity (pK)')
    plt.title(f'Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plot_file = 'final_results_gat.png'
    plt.savefig(plot_file)
    print(f"График сохранен в {plot_file}")
    plt.show()

if __name__ == "__main__":
    predict_and_plot()
model.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  import torch.nn as nn
5
 
6
 
7
- from torch_geometric.nn import GCNConv, global_mean_pool
8
 
9
  class PositionalEncoding(nn.Module):
10
  def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
@@ -39,15 +39,40 @@ class PositionalEncoding(nn.Module):
39
 
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class LigandGNN(nn.Module):
43
- def __init__(self, input_dim, hidden_channels):
44
  super().__init__()
45
- self.hidden_channels = hidden_channels
46
-
47
- self.conv1 = GCNConv(input_dim, hidden_channels)
48
- self.conv2 = GCNConv(hidden_channels, hidden_channels)
49
- self.conv3 = GCNConv(hidden_channels, hidden_channels)
50
- self.dropout = nn.Dropout(0.2)
51
 
52
  def forward(self, x, edge_index, batch):
53
  x = self.conv1(x, edge_index)
@@ -56,19 +81,20 @@ class LigandGNN(nn.Module):
56
 
57
  x = self.conv2(x, edge_index)
58
  x = x.relu()
59
- x = self.conv3(x, edge_index)
60
  x = self.dropout(x)
61
 
62
- # Averaging nodes and got the molecula vector
63
- x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
 
 
64
  return x
65
 
66
  class ProteinTransformer(nn.Module):
67
- def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128):
68
  super().__init__()
69
  self.d_model = d_model
70
  self.embedding = nn.Embedding(vocab_size, d_model)
71
- self.pos_encoder = PositionalEncoding(d_model)
72
  encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
73
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
74
 
@@ -91,18 +117,18 @@ class ProteinTransformer(nn.Module):
91
  return x
92
 
93
  class BindingAffinityModel(nn.Module):
94
- def __init__(self, num_node_features, hidden_channels_gnn):
95
  super().__init__()
96
  # Tower 1 - Ligand GNN
97
- self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels_gnn)
98
  # Tower 2 - Protein Transformer
99
- self.protein_transformer = ProteinTransformer(vocab_size=26)
100
 
101
  self.head = nn.Sequential(
102
- nn.Linear(128 + 128, 256),
103
  nn.ReLU(),
104
- nn.Dropout(0.2),
105
- nn.Linear(256, 1),
106
  )
107
  def forward(self, x, edge_index, batch, protein_seq):
108
  ligand_vec = self.ligand_gnn(x, edge_index, batch)
 
4
  import torch.nn as nn
5
 
6
 
7
+ from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
8
 
9
  class PositionalEncoding(nn.Module):
10
  def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
 
39
 
40
 
41
 
42
+ # class LigandGNN(nn.Module): # GCN CONV
43
+ # def __init__(self, input_dim, hidden_channels):
44
+ # super().__init__()
45
+ # self.hidden_channels = hidden_channels
46
+ #
47
+ # self.conv1 = GCNConv(input_dim, hidden_channels)
48
+ # self.conv2 = GCNConv(hidden_channels, hidden_channels)
49
+ # self.conv3 = GCNConv(hidden_channels, hidden_channels)
50
+ # self.dropout = nn.Dropout(0.2)
51
+ #
52
+ # def forward(self, x, edge_index, batch):
53
+ # x = self.conv1(x, edge_index)
54
+ # x = x.relu()
55
+ # x = self.dropout(x)
56
+ #
57
+ # x = self.conv2(x, edge_index)
58
+ # x = x.relu()
59
+ # x = self.conv3(x, edge_index)
60
+ # x = self.dropout(x)
61
+ #
62
+ # # Averaging nodes and got the molecula vector
63
+ # x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
64
+ # return x
65
+
66
+
67
  class LigandGNN(nn.Module):
68
+ def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
69
  super().__init__()
70
+ # Heads=4 means we use 4 attention heads
71
+ # Concat=False, we average the heads instead of concatenating them, to keep the output dimension same as hidden_channels
72
+ self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
73
+ self.conv2 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
74
+ self.conv3 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
75
+ self.dropout = nn.Dropout(dropout)
76
 
77
  def forward(self, x, edge_index, batch):
78
  x = self.conv1(x, edge_index)
 
81
 
82
  x = self.conv2(x, edge_index)
83
  x = x.relu()
 
84
  x = self.dropout(x)
85
 
86
+ x = self.conv3(x, edge_index)
87
+
88
+ # Global Mean Pooling
89
+ x = global_mean_pool(x, batch)
90
  return x
91
 
92
  class ProteinTransformer(nn.Module):
93
+ def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
94
  super().__init__()
95
  self.d_model = d_model
96
  self.embedding = nn.Embedding(vocab_size, d_model)
97
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
98
  encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
99
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
100
 
 
117
  return x
118
 
119
  class BindingAffinityModel(nn.Module):
120
+ def __init__(self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2):
121
  super().__init__()
122
  # Tower 1 - Ligand GNN
123
+ self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels, heads=gat_heads, dropout=dropout)
124
  # Tower 2 - Protein Transformer
125
+ self.protein_transformer = ProteinTransformer(vocab_size=26, d_model=hidden_channels, output_dim=hidden_channels, dropout=dropout)
126
 
127
  self.head = nn.Sequential(
128
+ nn.Linear(hidden_channels*2, hidden_channels),
129
  nn.ReLU(),
130
+ nn.Dropout(dropout),
131
+ nn.Linear(hidden_channels, 1),
132
  )
133
  def forward(self, x, edge_index, batch, protein_seq):
134
  ligand_vec = self.ligand_gnn(x, edge_index, batch)
optuna_train.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pandas as pd
4
+ import optuna
5
+ from torch.nn.functional import dropout
6
+ from torch.utils.data import random_split
7
+ from torch_geometric.loader import DataLoader
8
+ from dataset import BindingDataset
9
+ from model import BindingAffinityModel
10
+ from tqdm import tqdm
11
+ import sys
12
+
13
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ EPOCHS_PER_TRIAL = 10
15
+
16
# Load the dataset once at import time; every Optuna trial reuses the same
# fixed 80/20 train/test split.
dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
dataframe.dropna(inplace=True)
dataset = BindingDataset(dataframe)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# Seed the split generator so the hyperparameter search evaluates on the same
# split as train.py / inference.py (both seed with 42); an unseeded split made
# trial results irreproducible and incomparable to final training.
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)
num_features = train_dataset[0].x.shape[1]
23
+
24
def train(model, loader, optimizer, criterion):
    """Run one optimization epoch over `loader` (per-batch losses are not kept)."""
    model.train()
    for batch in loader:
        batch = batch.to(DEVICE)
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
        loss = criterion(prediction.squeeze(), batch.y.squeeze())
        loss.backward()
        optimizer.step()
33
+
34
+
35
def test(model, loader, criterion):
    """Return the mean per-batch loss of `model` over `loader` (gradient-free)."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
            batch_losses.append(criterion(out.squeeze(), batch.y.squeeze()).item())
    return sum(batch_losses) / len(loader)
45
+
46
+
47
def objective(trial):
    """Optuna objective: train for EPOCHS_PER_TRIAL epochs, return final test MSE.

    Searches learning rate and weight decay on a log scale; the per-epoch
    validation loss is reported to Optuna so its pruner can stop bad trials
    early.
    """
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)  # learning rate in [1e-5, 1e-2]
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)  # L2 penalty in [1e-6, 1e-3]

    # NOTE: the model's keyword is `hidden_channels`; the old
    # `hidden_channels_gnn` no longer exists in the updated model and
    # raised TypeError on every trial.
    model = BindingAffinityModel(num_node_features=num_features, hidden_channels=128).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    val_loss = float('inf')  # defined even if EPOCHS_PER_TRIAL is 0
    for epoch in range(EPOCHS_PER_TRIAL):
        train(model, train_loader, optimizer, criterion)
        val_loss = test(model, test_loader, criterion)

        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return val_loss
67
+
68
+
69
if __name__ == "__main__":
    # Minimize held-out MSE over the searched hyperparameters.
    study = optuna.create_study(direction="minimize")
    print("Start hyperparameter optimization...")

    study.optimize(objective, n_trials=10)
    print("\n--- Optimization Finished ---")
    print("Best parameters found: ", study.best_params)
    print("Best Test MSE: ", study.best_value)

    # Persist the full trial history for later inspection.
    study.trials_dataframe().to_csv("optuna_results.csv")
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  torch
 
 
2
 
3
  numpy
4
  pandas
 
1
  torch
2
+ pytorch-lightning
3
+ optuna
4
 
5
  numpy
6
  pandas
train.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import torch.nn as nn
3
  import pandas as pd
@@ -6,15 +8,30 @@ from torch_geometric.loader import DataLoader
6
  from dataset import BindingDataset
7
  from model import BindingAffinityModel
8
  from tqdm import tqdm
 
 
 
9
 
10
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
11
 
 
 
 
 
 
 
12
 
13
 
14
- def train_epoch(epoch, model, loader, optimizer, criterion):
15
  model.train()
16
  total_loss = 0
17
- for batch in tqdm(loader, desc=f"Training epoch: {epoch}"):
 
 
18
  batch = batch.to(DEVICE)
19
  optimizer.zero_grad()
20
 
@@ -23,21 +40,35 @@ def train_epoch(epoch, model, loader, optimizer, criterion):
23
 
24
  loss.backward()
25
  optimizer.step()
26
- total_loss += loss.item()
27
- return total_loss / len(loader)
 
 
 
 
 
 
 
 
28
 
29
- def evaluate(epoch, model, loader, criterion):
30
  model.eval()
31
  total_loss = 0
32
  with torch.no_grad():
33
- for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}"):
34
  batch = batch.to(DEVICE)
35
  out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
36
  loss = criterion(out.squeeze(), batch.y.squeeze())
37
  total_loss += loss.item()
38
- return total_loss / len(loader)
 
 
 
39
 
40
  def main():
 
 
 
41
  # Load dataset
42
  dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
43
  dataframe.dropna(inplace=True)
@@ -52,10 +83,10 @@ def main():
52
 
53
  train_size = int(0.8 * len(dataset))
54
  test_size = len(dataset) - train_size
55
- train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
56
 
57
- train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
58
- test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
59
  num_features = train_dataset[0].x.shape[1]
60
  print("Number of node features:", num_features)
61
 
@@ -63,13 +94,22 @@ def main():
63
  optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
64
  criterion = nn.MSELoss()
65
 
66
- num_epochs = 20
 
67
  print(f"Starting training on {DEVICE}")
68
- for epoch in range(num_epochs):
69
- train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion)
70
- test_loss = evaluate(epoch, model, test_loader, criterion)
71
- print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
72
- torch.save(model.state_dict(), './model.pth')
 
 
 
 
 
 
 
 
73
 
74
 
75
  if __name__ == "__main__":
 
1
+ import random
2
+
3
  import torch
4
  import torch.nn as nn
5
  import pandas as pd
 
8
  from dataset import BindingDataset
9
  from model import BindingAffinityModel
10
  from tqdm import tqdm
11
+ from torch.utils.tensorboard import SummaryWriter
12
+ import numpy as np
13
+ from datetime import datetime
14
 
15
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ BATCH_SIZE = 32
17
+ LR = 0.0005
18
+ EPOCS = 30
19
+ LOG_DIR = f"runs/experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
20
 
21
+ def set_seed(seed=42):
22
+ random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed(seed)
25
+ np.random.seed(seed)
26
+ return torch.Generator().manual_seed(seed)
27
 
28
 
29
+ def train_epoch(epoch, model, loader, optimizer, criterion, writer):
30
  model.train()
31
  total_loss = 0
32
+
33
+ loop = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
34
+ for i, batch in enumerate(loop):
35
  batch = batch.to(DEVICE)
36
  optimizer.zero_grad()
37
 
 
40
 
41
  loss.backward()
42
  optimizer.step()
43
+ current_loss = loss.item()
44
+ total_loss += current_loss
45
+
46
+ global_step = (epoch - 1) * len(loader) + i
47
+ writer.add_scalar('Loss/Train_Step', current_loss, global_step)
48
+
49
+ loop.set_postfix(loss = loss.item())
50
+
51
+ avg_loss = total_loss / len(loader)
52
+ return avg_loss
53
 
54
+ def evaluate(epoch, model, loader, criterion, writer):
55
  model.eval()
56
  total_loss = 0
57
  with torch.no_grad():
58
+ for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False):
59
  batch = batch.to(DEVICE)
60
  out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
61
  loss = criterion(out.squeeze(), batch.y.squeeze())
62
  total_loss += loss.item()
63
+
64
+ avg_loss = total_loss / len(loader)
65
+ writer.add_scalar('Loss/Test', avg_loss, epoch)
66
+ return avg_loss
67
 
68
  def main():
69
+ gen = set_seed(42)
70
+ writer = SummaryWriter(LOG_DIR)
71
+ print(f"Logging to {LOG_DIR}...")
72
  # Load dataset
73
  dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
74
  dataframe.dropna(inplace=True)
 
83
 
84
  train_size = int(0.8 * len(dataset))
85
  test_size = len(dataset) - train_size
86
+ train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
87
 
88
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
89
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
90
  num_features = train_dataset[0].x.shape[1]
91
  print("Number of node features:", num_features)
92
 
 
94
  optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
95
  criterion = nn.MSELoss()
96
 
97
+ best_test_loss = float('inf')
98
+
99
  print(f"Starting training on {DEVICE}")
100
+ for epoch in range(1, EPOCS):
101
+ train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion, writer)
102
+ test_loss = evaluate(epoch, model, test_loader, criterion, writer)
103
+
104
+
105
+ print(f'Epoch {epoch:02d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
106
+ if test_loss < best_test_loss:
107
+ best_test_loss = test_loss
108
+ torch.save(model.state_dict(), f'best_model_gat.pth')
109
+ print(f'Best model saved with Test Loss MSE: {best_test_loss:.4f}')
110
+
111
+ writer.close()
112
+ print("Training finished.")
113
 
114
 
115
  if __name__ == "__main__":
transformer_from_scratch/attention_visual.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
transformer_from_scratch/model.py CHANGED
@@ -119,7 +119,7 @@ class MultiHeadAttention(nn.Module):
119
  1, 2
120
  )
121
 
122
- x, attention_scores = MultiHeadAttention.attention(
123
  query, key, value, mask, self.dropout
124
  )
125
 
 
119
  1, 2
120
  )
121
 
122
+ x, self.attention_scores = MultiHeadAttention.attention(
123
  query, key, value, mask, self.dropout
124
  )
125