AlexSychovUN committed on
Commit
1390640
·
1 Parent(s): 2fdd454

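Updated all code: pass dropout to PositionalEncoding by keyword (model.py); seed the train/test split, tune architecture and batch size, and persist the Optuna study to SQLite with a median pruner (optuna_train.py); apply tuned hyperparameters and keep the top-3 checkpoints by test MSE (train.py)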

Browse files
Files changed (3) hide show
  1. model.py +1 -1
  2. optuna_train.py +39 -12
  3. train.py +45 -12
model.py CHANGED
@@ -94,7 +94,7 @@ class ProteinTransformer(nn.Module):
94
  super().__init__()
95
  self.d_model = d_model
96
  self.embedding = nn.Embedding(vocab_size, d_model)
97
- self.pos_encoder = PositionalEncoding(d_model, dropout)
98
  encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
99
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
100
 
 
94
  super().__init__()
95
  self.d_model = d_model
96
  self.embedding = nn.Embedding(vocab_size, d_model)
97
+ self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
98
  encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
99
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
100
 
optuna_train.py CHANGED
@@ -1,24 +1,34 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import pandas as pd
4
- import optuna
5
- from torch.nn.functional import dropout
6
- from torch.utils.data import random_split
7
  from torch_geometric.loader import DataLoader
 
8
  from dataset import BindingDataset
9
  from model import BindingAffinityModel
10
- from tqdm import tqdm
11
- import sys
12
 
13
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
- EPOCHS_PER_TRIAL = 10
 
 
 
 
 
 
 
 
15
 
16
  dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
17
  dataframe.dropna(inplace=True)
18
  dataset = BindingDataset(dataframe)
 
 
 
19
  train_size = int(0.8 * len(dataset))
20
  test_size = len(dataset) - train_size
21
- train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
22
  num_features = train_dataset[0].x.shape[1]
23
 
24
  def train(model, loader, optimizer, criterion):
@@ -45,21 +55,31 @@ def test(model, loader, criterion):
45
 
46
 
47
  def objective(trial):
 
 
 
 
 
 
 
48
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True) # Learning rate from 0.00001 to 0.01
49
  weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True) # Weight decay from 0.000001 to 0.001
 
50
 
51
- model = BindingAffinityModel(num_node_features=num_features, hidden_channels_gnn=128).to(DEVICE)
52
 
53
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
54
  criterion = nn.MSELoss()
55
 
56
- train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
57
- test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
58
 
59
  for epoch in range(EPOCHS_PER_TRIAL):
60
  train(model, train_loader, optimizer, criterion)
61
  val_loss = test(model, test_loader, criterion)
62
 
 
 
63
  trial.report(val_loss, epoch)
64
  if trial.should_prune():
65
  raise optuna.exceptions.TrialPruned()
@@ -67,10 +87,17 @@ def objective(trial):
67
 
68
 
69
  if __name__ == "__main__":
70
- study = optuna.create_study(direction="minimize")
 
 
 
 
 
 
 
71
  print("Start hyperparameter optimization...")
72
 
73
- study.optimize(objective, n_trials=10)
74
  print("\n--- Optimization Finished ---")
75
  print("Best parameters found: ", study.best_params)
76
  print("Best Test MSE: ", study.best_value)
 
1
+ import optuna
2
  import torch
3
  import torch.nn as nn
4
  import pandas as pd
5
+ import random
6
+ import numpy as np
 
7
  from torch_geometric.loader import DataLoader
8
+ from torch.utils.data import random_split
9
  from dataset import BindingDataset
10
  from model import BindingAffinityModel
 
 
11
 
12
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ N_TRIALS = 20
14
+ EPOCHS_PER_TRIAL = 15
15
+
16
+ def set_seed(seed=42):
17
+ random.seed(seed)
18
+ np.random.seed(seed)
19
+ torch.manual_seed(seed)
20
+ torch.cuda.manual_seed(seed)
21
+ return torch.Generator().manual_seed(seed)
22
 
23
  dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
24
  dataframe.dropna(inplace=True)
25
  dataset = BindingDataset(dataframe)
26
+
27
+ gen = set_seed(42)
28
+
29
  train_size = int(0.8 * len(dataset))
30
  test_size = len(dataset) - train_size
31
+ train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
32
  num_features = train_dataset[0].x.shape[1]
33
 
34
  def train(model, loader, optimizer, criterion):
 
55
 
56
 
57
  def objective(trial):
58
+ # Architecture
59
+ hidden_dim = trial.suggest_categorical("hidden_dim", [64, 128, 256])
60
+ gat_heads = trial.suggest_categorical("gat_heads", [2, 4, 8])
61
+ dropout = trial.suggest_float("dropout", 0.1, 0.5)
62
+
63
+ # Learning
64
+
65
  lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True) # Learning rate from 0.00001 to 0.01
66
  weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True) # Weight decay from 0.000001 to 0.001
67
+ batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
68
 
69
+ model = BindingAffinityModel(num_node_features=num_features, hidden_channels=hidden_dim, gat_heads=gat_heads, dropout=dropout).to(DEVICE)
70
 
71
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
72
  criterion = nn.MSELoss()
73
 
74
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
75
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
76
 
77
  for epoch in range(EPOCHS_PER_TRIAL):
78
  train(model, train_loader, optimizer, criterion)
79
  val_loss = test(model, test_loader, criterion)
80
 
81
+ print(f"Trial {trial.number} | Epoch {epoch + 1}/{EPOCHS_PER_TRIAL} | Val Loss: {val_loss:.4f}")
82
+
83
  trial.report(val_loss, epoch)
84
  if trial.should_prune():
85
  raise optuna.exceptions.TrialPruned()
 
87
 
88
 
89
  if __name__ == "__main__":
90
+ storage_name = "sqlite:///db.sqlite3"
91
+ study = optuna.create_study(
92
+ direction="minimize",
93
+ pruner=optuna.pruners.MedianPruner(),
94
+ storage=storage_name,
95
+ study_name="binding_prediction_optimization",
96
+ load_if_exists=True
97
+ )
98
  print("Start hyperparameter optimization...")
99
 
100
+ study.optimize(objective, n_trials=N_TRIALS)
101
  print("\n--- Optimization Finished ---")
102
  print("Best parameters found: ", study.best_params)
103
  print("Best Test MSE: ", study.best_value)
train.py CHANGED
@@ -11,12 +11,20 @@ from tqdm import tqdm
11
  from torch.utils.tensorboard import SummaryWriter
12
  import numpy as np
13
  from datetime import datetime
 
 
 
 
 
 
 
 
 
14
 
15
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
- BATCH_SIZE = 32
17
- LR = 0.0005
18
- EPOCS = 30
19
  LOG_DIR = f"runs/experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
 
 
20
 
21
  def set_seed(seed=42):
22
  random.seed(seed)
@@ -68,7 +76,11 @@ def evaluate(epoch, model, loader, criterion, writer):
68
  def main():
69
  gen = set_seed(42)
70
  writer = SummaryWriter(LOG_DIR)
 
 
 
71
  print(f"Logging to {LOG_DIR}...")
 
72
  # Load dataset
73
  dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
74
  dataframe.dropna(inplace=True)
@@ -90,26 +102,47 @@ def main():
90
  num_features = train_dataset[0].x.shape[1]
91
  print("Number of node features:", num_features)
92
 
93
- model = BindingAffinityModel(num_node_features=num_features, hidden_channels_gnn=128).to(DEVICE)
94
- optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
 
 
 
 
 
95
  criterion = nn.MSELoss()
96
 
97
- best_test_loss = float('inf')
98
 
99
  print(f"Starting training on {DEVICE}")
100
- for epoch in range(1, EPOCS):
101
  train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion, writer)
102
  test_loss = evaluate(epoch, model, test_loader, criterion, writer)
103
 
104
-
105
  print(f'Epoch {epoch:02d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
106
- if test_loss < best_test_loss:
107
- best_test_loss = test_loss
108
- torch.save(model.state_dict(), f'best_model_gat.pth')
109
- print(f'Best model saved with Test Loss MSE: {best_test_loss:.4f}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  writer.close()
112
  print("Training finished.")
 
 
 
113
 
114
 
115
  if __name__ == "__main__":
 
11
  from torch.utils.tensorboard import SummaryWriter
12
  import numpy as np
13
  from datetime import datetime
14
+ import os
15
+
16
+
17
+ BATCH_SIZE = 16
18
+ LR = 0.00064
19
+ WEIGHT_DECAY = 7.06e-6
20
+ EPOCS = 100
21
+ DROPOUT = 0.325
22
+ GAT_HEADS = 2
23
 
24
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
25
  LOG_DIR = f"runs/experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
26
+ TOP_K = 3
27
+ SAVES_DIR = LOG_DIR + "/models"
28
 
29
  def set_seed(seed=42):
30
  random.seed(seed)
 
76
  def main():
77
  gen = set_seed(42)
78
  writer = SummaryWriter(LOG_DIR)
79
+
80
+ if not os.path.exists(SAVES_DIR):
81
+ os.makedirs(SAVES_DIR)
82
  print(f"Logging to {LOG_DIR}...")
83
+ print(f"Model saves to {SAVES_DIR}...")
84
  # Load dataset
85
  dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
86
  dataframe.dropna(inplace=True)
 
102
  num_features = train_dataset[0].x.shape[1]
103
  print("Number of node features:", num_features)
104
 
105
+ model = BindingAffinityModel(
106
+ num_node_features=num_features,
107
+ hidden_channels=256,
108
+ gat_heads=GAT_HEADS,
109
+ dropout=DROPOUT
110
+ ).to(DEVICE)
111
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
112
  criterion = nn.MSELoss()
113
 
114
+ top_models = []
115
 
116
  print(f"Starting training on {DEVICE}")
117
+ for epoch in range(1, EPOCS + 1):
118
  train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion, writer)
119
  test_loss = evaluate(epoch, model, test_loader, criterion, writer)
120
 
 
121
  print(f'Epoch {epoch:02d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
122
+
123
+ filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
124
+
125
+ torch.save(model.state_dict(), filename)
126
+ top_models.append({'loss': test_loss, 'path': filename, 'epoch': epoch})
127
+
128
+ top_models.sort(key=lambda x: x['loss'])
129
+
130
+ if len(top_models) > TOP_K:
131
+ worst_model = top_models.pop()
132
+ os.remove(worst_model['path'])
133
+
134
+ if any(m['epoch'] == epoch for m in top_models):
135
+ rank = [m['epoch'] for m in top_models].index(epoch) + 1
136
+ print(f'-- Model saved (Rank: {rank})')
137
+ else:
138
+ print("")
139
+
140
 
141
  writer.close()
142
  print("Training finished.")
143
+ print("Top models saved:")
144
+ for i, m in enumerate(top_models):
145
+ print(f"{i + 1}. {m['path']} (MSE: {m['loss']:.4f})")
146
 
147
 
148
  if __name__ == "__main__":