import random
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from dataset import BindingDataset
from model import BindingAffinityModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from datetime import datetime
import os
# GAT V2
# BATCH_SIZE = 16
# LR = 0.0005
# WEIGHT_DECAY = 1e-5
# EPOCHS = 100
# DROPOUT = 0.4
# GAT_HEADS = 4
# HIDDEN_CHANNELS = 256
# GAT
# BATCH_SIZE = 16
# LR = 0.00064
# WEIGHT_DECAY = 7.06e-6
# EPOCHS = 100
# DROPOUT = 0.325
# GAT_HEADS = 2
# HIDDEN_CHANNELS = 256
# GCN
BATCH_SIZE = 16
LR = 0.001
WEIGHT_DECAY = 7.06e-6
EPOCHS = 100
DROPOUT = 0.3
GAT_HEADS = 4  # assumed value; the BindingAffinityModel constructor below expects an attention-head count
HIDDEN_CHANNELS = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOG_DIR = f"runs/GATV2_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
TOP_K = 3
SAVES_DIR = LOG_DIR + "/models"
def set_seed(seed=42):
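    """Seed Python, NumPy, and PyTorch RNGs, and return a seeded Generator for the dataset split."""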
random.seed(seed)
torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every GPU, not just the current device
np.random.seed(seed)
return torch.Generator().manual_seed(seed)
def train_epoch(epoch, model, loader, optimizer, criterion, writer):
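    """Run one optimization pass over `loader`, logging per-step loss to TensorBoard; returns the mean batch loss."""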
model.train()
total_loss = 0
loop = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
for i, batch in enumerate(loop):
batch = batch.to(DEVICE)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
loss = criterion(out.squeeze(), batch.y.squeeze())
loss.backward()
optimizer.step()
current_loss = loss.item()
total_loss += current_loss
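        # Global step counts batches across epochs so the TensorBoard curve is continuous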
global_step = (epoch - 1) * len(loader) + i
writer.add_scalar("Loss/Train_Step", current_loss, global_step)
        loop.set_postfix(loss=current_loss)
avg_loss = total_loss / len(loader)
return avg_loss
def evaluate(epoch, model, loader, criterion, writer):
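    """Compute the mean loss over `loader` without gradient tracking and log it to TensorBoard."""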
model.eval()
total_loss = 0
with torch.no_grad():
for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False):
batch = batch.to(DEVICE)
out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
loss = criterion(out.squeeze(), batch.y.squeeze())
total_loss += loss.item()
avg_loss = total_loss / len(loader)
writer.add_scalar("Loss/Test", avg_loss, epoch)
return avg_loss
def main():
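    """End-to-end pipeline: load the PDBbind CSV, split 80/20, train, and keep the TOP_K best checkpoints."""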
gen = set_seed(42)
writer = SummaryWriter(LOG_DIR)
    os.makedirs(SAVES_DIR, exist_ok=True)
print(f"Logging to {LOG_DIR}...")
print(f"Model saves to {SAVES_DIR}...")
# Load dataset
dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
dataframe.dropna(inplace=True)
print("Dataset loaded with {} samples".format(len(dataframe)))
dataset = BindingDataset(dataframe, max_seq_length=1200)
print("Dataset transformed with {} samples".format(len(dataset)))
if len(dataset) == 0:
print("Dataset is empty")
return
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
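    # 80/20 train/test split, made reproducible by the seeded generator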
train_dataset, test_dataset = random_split(
dataset, [train_size, test_size], generator=gen
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
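    # Infer the node feature dimensionality from the first training graph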
num_features = train_dataset[0].x.shape[1]
print("Number of node features:", num_features)
model = BindingAffinityModel(
num_node_features=num_features,
hidden_channels=HIDDEN_CHANNELS,
gat_heads=GAT_HEADS,
dropout=DROPOUT,
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # factor=0.5 halves the learning rate when triggered;
    # patience=5 waits five non-improving epochs before reducing it
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", factor=0.5, patience=5
)
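    # Binding affinity prediction is a regression task, so mean-squared error is the objective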
criterion = nn.MSELoss()
top_models = []
print(f"Starting training on {DEVICE}")
    for epoch in range(1, EPOCHS + 1):
train_loss = train_epoch(
epoch, model, train_loader, optimizer, criterion, writer
)
test_loss = evaluate(epoch, model, test_loader, criterion, writer)
old_lr = optimizer.param_groups[0]["lr"]
scheduler.step(test_loss)
new_lr = optimizer.param_groups[0]["lr"]
if new_lr != old_lr:
print(
f"\nEpoch {epoch}: Scheduler reduced LR from {old_lr:.6f} to {new_lr:.6f}!"
)
print(
f"Epoch {epoch:02d} | LR: {new_lr:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
end="",
)
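        # Checkpoint every epoch, then prune so only the TOP_K best (by test MSE) remain on disk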
filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
torch.save(model.state_dict(), filename)
top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
top_models.sort(key=lambda x: x["loss"])
if len(top_models) > TOP_K:
worst_model = top_models.pop()
os.remove(worst_model["path"])
if any(m["epoch"] == epoch for m in top_models):
rank = [m["epoch"] for m in top_models].index(epoch) + 1
print(f"-- Model saved (Rank: {rank})")
else:
print("")
writer.close()
print("Training finished.")
print("Top models saved:")
for i, m in enumerate(top_models):
print(f"{i + 1}. {m['path']} (MSE: {m['loss']:.4f})")
if __name__ == "__main__":
main()