Commit 2fdd454
Parent(s): 6afa7ea
Updated code

Files changed:
- inference.py +85 -0
- model.py +45 -19
- optuna_train.py +79 -0
- requirements.txt +2 -0
- train.py +56 -16
- transformer_from_scratch/attention_visual.ipynb +0 -0
- transformer_from_scratch/model.py +1 -1
inference.py
ADDED
@@ -0,0 +1,85 @@
+import random
+
+import torch
+import pandas as pd
+import matplotlib.pyplot as plt
+import numpy as np
+from torch_geometric.loader import DataLoader
+from dataset import BindingDataset
+from model import BindingAffinityModel
+from tqdm import tqdm
+from scipy.stats import pearsonr
+from torch.utils.data import random_split
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+MODEL_PATH = "best_model_gat.pth"
+
+def set_seed(seed=42):
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    return torch.Generator().manual_seed(seed)
+
+def predict_and_plot():
+    gen = set_seed(42)
+    print("Loading data...")
+
+    dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
+    dataframe.dropna(inplace=True)
+    dataset = BindingDataset(dataframe)
+    if len(dataset) == 0:
+        print("Dataset is empty")
+        return
+
+    train_size = int(0.8 * len(dataset))
+    test_size = len(dataset) - train_size
+    _, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)  # same seeded generator as training, so the same test split is recovered
+
+    loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+    num_features = test_dataset[0].x.shape[1]
+
+    print("Loading model...")
+    model = BindingAffinityModel(num_node_features=num_features, hidden_channels=128).to(DEVICE)  # keyword matches the updated model signature
+    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
+    model.eval()
+
+    y_true = []
+    y_pred = []
+    print("Predicting...")
+    with torch.no_grad():
+        for batch in tqdm(loader):
+            batch = batch.to(DEVICE)
+            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
+
+            y_true.extend(batch.y.cpu().numpy())
+            y_pred.extend(out.squeeze().cpu().numpy())
+    y_true = np.array(y_true)
+    y_pred = np.array(y_pred)
+
+    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
+    mae = np.mean(np.abs(y_true - y_pred))
+    pearson_corr, _ = pearsonr(y_true, y_pred)  # Pearson correlation
+
+    print("Results:")
+    print(f"RMSE: {rmse:.4f}")
+    print(f"MAE: {mae:.4f}")
+    print(f"Pearson Correlation: {pearson_corr:.4f}")
+
+    plt.figure(figsize=(9, 9))
+    plt.scatter(y_true, y_pred, alpha=0.4, s=15, c='blue', label='Predictions')
+    plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='red', linestyle='--', linewidth=2,
+             label='Ideal')
+
+    plt.xlabel('Experimental Affinity (pK)')
+    plt.ylabel('Predicted Affinity (pK)')
+    plt.title(f'Binding Affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}')
+    plt.legend()
+    plt.grid(True, alpha=0.3)
+    plot_file = 'final_results_gat.png'
+    plt.savefig(plot_file)
+    print(f"Plot saved to {plot_file}")
+    plt.show()
+
+if __name__ == "__main__":
+    predict_and_plot()
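Note on the split recovery above: inference.py does not store the held-out indices. Instead, set_seed returns a torch.Generator seeded identically to training, and random_split with that generator reproduces the same partition. A minimal sketch of that property, using a toy TensorDataset (not part of the commit):

import torch
from torch.utils.data import TensorDataset, random_split

data = TensorDataset(torch.arange(10))

# Two independently constructed generators with the same seed
split_a = random_split(data, [8, 2], generator=torch.Generator().manual_seed(42))
split_b = random_split(data, [8, 2], generator=torch.Generator().manual_seed(42))

# The subsets select identical indices, so inference.py evaluates on
# exactly the test set that train.py held out.
print(split_a[1].indices == split_b[1].indices)  # True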
model.py
CHANGED
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 
 
-from torch_geometric.nn import GCNConv, global_mean_pool
+from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
 
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
@@ -39,15 +39,40 @@ class PositionalEncoding(nn.Module):
 
 
 
+# class LigandGNN(nn.Module):  # GCNConv version, kept for reference
+#     def __init__(self, input_dim, hidden_channels):
+#         super().__init__()
+#         self.hidden_channels = hidden_channels
+#
+#         self.conv1 = GCNConv(input_dim, hidden_channels)
+#         self.conv2 = GCNConv(hidden_channels, hidden_channels)
+#         self.conv3 = GCNConv(hidden_channels, hidden_channels)
+#         self.dropout = nn.Dropout(0.2)
+#
+#     def forward(self, x, edge_index, batch):
+#         x = self.conv1(x, edge_index)
+#         x = x.relu()
+#         x = self.dropout(x)
+#
+#         x = self.conv2(x, edge_index)
+#         x = x.relu()
+#         x = self.conv3(x, edge_index)
+#         x = self.dropout(x)
+#
+#         # Average the node embeddings to get the molecule vector
+#         x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
+#         return x
+
+
 class LigandGNN(nn.Module):
-    def __init__(self, input_dim, hidden_channels):
+    def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
         super().__init__()
-        self.hidden_channels = hidden_channels
-
-        self.conv1 = GCNConv(input_dim, hidden_channels)
-        self.conv2 = GCNConv(hidden_channels, hidden_channels)
-        self.conv3 = GCNConv(hidden_channels, hidden_channels)
-        self.dropout = nn.Dropout(0.2)
+        # heads=4 means four attention heads per layer.
+        # concat=False averages the heads instead of concatenating them,
+        # keeping the output dimension equal to hidden_channels.
+        self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
+        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
+        self.conv3 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, edge_index, batch):
         x = self.conv1(x, edge_index)
@@ -56,19 +81,20 @@ class LigandGNN(nn.Module):
 
         x = self.conv2(x, edge_index)
         x = x.relu()
-        x = self.conv3(x, edge_index)
         x = self.dropout(x)
 
-        # Average the node embeddings to get the molecule vector
-        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
+        x = self.conv3(x, edge_index)
+
+        # Global Mean Pooling
+        x = global_mean_pool(x, batch)
         return x
 
 class ProteinTransformer(nn.Module):
-    def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128):
+    def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
         super().__init__()
         self.d_model = d_model
         self.embedding = nn.Embedding(vocab_size, d_model)
-        self.pos_encoder = PositionalEncoding(d_model)
+        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)  # pass by keyword: the second positional parameter is seq_len
         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
 
@@ -91,18 +117,18 @@ class ProteinTransformer(nn.Module):
         return x
 
 class BindingAffinityModel(nn.Module):
-    def __init__(self, num_node_features, hidden_channels_gnn=128):
+    def __init__(self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2):
         super().__init__()
         # Tower 1 - Ligand GNN
-        self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels_gnn)
+        self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels, heads=gat_heads, dropout=dropout)
         # Tower 2 - Protein Transformer
-        self.protein_transformer = ProteinTransformer(vocab_size=26)
+        self.protein_transformer = ProteinTransformer(vocab_size=26, d_model=hidden_channels, output_dim=hidden_channels, dropout=dropout)
 
         self.head = nn.Sequential(
-            nn.Linear(
+            nn.Linear(hidden_channels*2, hidden_channels),
             nn.ReLU(),
-            nn.Dropout(
-            nn.Linear(
+            nn.Dropout(dropout),
+            nn.Linear(hidden_channels, 1),
         )
     def forward(self, x, edge_index, batch, protein_seq):
         ligand_vec = self.ligand_gnn(x, edge_index, batch)
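The GCNConv-to-GATConv swap above keeps all tensor shapes stable because concat=False averages the attention heads. A small shape check illustrating the difference (toy tensors, not from the repo):

import torch
from torch_geometric.nn import GATConv

x = torch.randn(6, 16)                       # 6 nodes, 16 features each
edge_index = torch.tensor([[0, 1, 2, 3, 4],  # toy edge list
                           [1, 2, 3, 4, 5]])

concat_heads = GATConv(16, 32, heads=4, concat=True)     # heads are stacked
averaged_heads = GATConv(16, 32, heads=4, concat=False)  # heads are averaged

print(concat_heads(x, edge_index).shape)    # torch.Size([6, 128]) = 32 * 4
print(averaged_heads(x, edge_index).shape)  # torch.Size([6, 32])

Averaging keeps each layer's output at hidden_channels, so conv2 and conv3 can reuse the same input/output sizes the GCN version had, and the fusion head's nn.Linear(hidden_channels*2, ...) needs no change.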
optuna_train.py
ADDED
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+import pandas as pd
+import optuna
+from torch.nn.functional import dropout
+from torch.utils.data import random_split
+from torch_geometric.loader import DataLoader
+from dataset import BindingDataset
+from model import BindingAffinityModel
+from tqdm import tqdm
+import sys
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+EPOCHS_PER_TRIAL = 10
+
+dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
+dataframe.dropna(inplace=True)
+dataset = BindingDataset(dataframe)
+train_size = int(0.8 * len(dataset))
+test_size = len(dataset) - train_size
+train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
+num_features = train_dataset[0].x.shape[1]
+
+def train(model, loader, optimizer, criterion):
+    model.train()
+    for batch in loader:
+        batch = batch.to(DEVICE)
+        optimizer.zero_grad()
+        out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
+        loss = criterion(out.squeeze(), batch.y.squeeze())
+        loss.backward()
+        optimizer.step()
+
+
+def test(model, loader, criterion):
+    model.eval()
+    total_loss = 0
+    with torch.no_grad():
+        for batch in loader:
+            batch = batch.to(DEVICE)
+            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
+            loss = criterion(out.squeeze(), batch.y.squeeze())
+            total_loss += loss.item()
+    return total_loss / len(loader)
+
+
+def objective(trial):
+    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)  # learning rate from 0.00001 to 0.01
+    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)  # weight decay from 0.000001 to 0.001
+
+    model = BindingAffinityModel(num_node_features=num_features, hidden_channels=128).to(DEVICE)  # keyword matches the updated model signature
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+    criterion = nn.MSELoss()
+
+    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+
+    for epoch in range(EPOCHS_PER_TRIAL):
+        train(model, train_loader, optimizer, criterion)
+        val_loss = test(model, test_loader, criterion)
+
+        trial.report(val_loss, epoch)
+        if trial.should_prune():
+            raise optuna.exceptions.TrialPruned()
+    return val_loss
+
+
+if __name__ == "__main__":
+    study = optuna.create_study(direction="minimize")
+    print("Start hyperparameter optimization...")
+
+    study.optimize(objective, n_trials=10)
+    print("\n--- Optimization Finished ---")
+    print("Best parameters found: ", study.best_params)
+    print("Best Test MSE: ", study.best_value)
+
+    df_results = study.trials_dataframe()
+    df_results.to_csv("optuna_results.csv")
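objective() reports the validation loss after every epoch so Optuna can abandon unpromising trials early via trial.should_prune(). create_study falls back to a MedianPruner when none is given; a sketch of configuring it explicitly (the parameter values here are illustrative, not from the commit):

import optuna

study = optuna.create_study(
    direction="minimize",
    # Prune a trial whose intermediate loss is worse than the median of
    # completed trials; skip pruning for the first 5 trials and first 3 epochs.
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=3),
)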
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
 torch
+pytorch-lightning
+optuna
 
 numpy
 pandas
train.py
CHANGED
@@ -1,3 +1,5 @@
+import random
+
 import torch
 import torch.nn as nn
 import pandas as pd
@@ -6,15 +8,30 @@ from torch_geometric.loader import DataLoader
 from dataset import BindingDataset
 from model import BindingAffinityModel
 from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+import numpy as np
+from datetime import datetime
 
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+BATCH_SIZE = 32
+LR = 0.0005
+EPOCHS = 30
+LOG_DIR = f"runs/experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
 
+def set_seed(seed=42):
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    return torch.Generator().manual_seed(seed)
 
 
-def train_epoch(epoch, model, loader, optimizer, criterion):
+def train_epoch(epoch, model, loader, optimizer, criterion, writer):
     model.train()
     total_loss = 0
-    for batch in tqdm(loader, desc=f"Training epoch: {epoch}"):
+
+    loop = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
+    for i, batch in enumerate(loop):
         batch = batch.to(DEVICE)
         optimizer.zero_grad()
@@ -23,21 +40,35 @@ def train_epoch(epoch, model, loader, optimizer, criterion):
 
         loss.backward()
         optimizer.step()
-        total_loss += loss.item()
-    return total_loss / len(loader)
+        current_loss = loss.item()
+        total_loss += current_loss
+
+        global_step = (epoch - 1) * len(loader) + i
+        writer.add_scalar('Loss/Train_Step', current_loss, global_step)
+
+        loop.set_postfix(loss=current_loss)
+
+    avg_loss = total_loss / len(loader)
+    return avg_loss
 
-def evaluate(epoch, model, loader, criterion):
+def evaluate(epoch, model, loader, criterion, writer):
     model.eval()
     total_loss = 0
     with torch.no_grad():
-        for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}"):
+        for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False):
             batch = batch.to(DEVICE)
             out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
             loss = criterion(out.squeeze(), batch.y.squeeze())
             total_loss += loss.item()
-    return total_loss / len(loader)
+
+    avg_loss = total_loss / len(loader)
+    writer.add_scalar('Loss/Test', avg_loss, epoch)
+    return avg_loss
 
 def main():
+    gen = set_seed(42)
+    writer = SummaryWriter(LOG_DIR)
+    print(f"Logging to {LOG_DIR}...")
     # Load dataset
     dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
     dataframe.dropna(inplace=True)
@@ -52,10 +83,10 @@ def main():
 
     train_size = int(0.8 * len(dataset))
     test_size = len(dataset) - train_size
-    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
+    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
 
-    train_loader = DataLoader(train_dataset, batch_size=
-    test_loader = DataLoader(test_dataset, batch_size=
+    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
     num_features = train_dataset[0].x.shape[1]
     print("Number of node features:", num_features)
@@ -63,13 +94,22 @@ def main():
 
     optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-4)
     criterion = nn.MSELoss()
 
+    best_test_loss = float('inf')
+
     print(f"Starting training on {DEVICE}")
-    for epoch in range(
-        train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion)
-        test_loss = evaluate(epoch, model, test_loader, criterion)
+    for epoch in range(1, EPOCHS + 1):  # epochs are 1-indexed for logging; run all EPOCHS of them
+        train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion, writer)
+        test_loss = evaluate(epoch, model, test_loader, criterion, writer)
+
+        print(f'Epoch {epoch:02d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
+        if test_loss < best_test_loss:
+            best_test_loss = test_loss
+            torch.save(model.state_dict(), 'best_model_gat.pth')
+            print(f'Best model saved with Test Loss MSE: {best_test_loss:.4f}')
+
+    writer.close()
+    print("Training finished.")
 
 
 if __name__ == "__main__":
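train_epoch logs one scalar per optimizer step and evaluate logs one per epoch, both under LOG_DIR. A minimal sketch of producing and inspecting such logs (the run name "runs/demo" is illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/demo")
for step in range(100):
    # Same tag layout as train.py: a per-step training loss curve
    writer.add_scalar('Loss/Train_Step', 1.0 / (step + 1), step)
writer.close()

# Then view the curves in a browser:
#   tensorboard --logdir runs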
transformer_from_scratch/attention_visual.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.
transformer_from_scratch/model.py
CHANGED
@@ -119,7 +119,7 @@ class MultiHeadAttention(nn.Module):
             1, 2
         )
 
-        x, attention_scores = MultiHeadAttention.attention(
+        x, self.attention_scores = MultiHeadAttention.attention(
             query, key, value, mask, self.dropout
         )
 
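Binding the scores to self is what lets attention_visual.ipynb read them after a forward pass instead of recomputing attention. A hypothetical access pattern (plot_head and the attention_block argument are illustrative, not taken from the notebook):

import matplotlib.pyplot as plt

def plot_head(attention_block, head=0):
    # attention_scores was saved during the last forward pass;
    # assumed shape: [batch, heads, seq_len, seq_len]
    scores = attention_block.attention_scores[0, head].detach().cpu()
    plt.imshow(scores, cmap='viridis')
    plt.colorbar()
    plt.title(f'Attention head {head}')
    plt.show()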