# AlexSychovUN's picture
# Prepared for deploy
# de183ef
import random
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from dataset import BindingDataset
from model import BindingAffinityModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from datetime import datetime
import os
# Hyperparameter presets. The GCN preset below is the active one; the GAT V2
# and GAT presets are kept commented out for reference.
#
# GAT V2
# BATCH_SIZE = 16
# LR = 0.0005
# WEIGHT_DECAY = 1e-5
# EPOCS = 100
# DROPOUT = 0.4
# GAT_HEADS = 4
# HIDDEN_CHANNELS = 256
#
# GAT
# BATCH_SIZE = 16
# LR = 0.00064
# WEIGHT_DECAY = 7.06e-6
# EPOCS = 100
# DROPOUT = 0.325
# GAT_HEADS = 2
# HIDDEN_CHANNELS = 256
#
# GCN (active)
BATCH_SIZE = 16
LR = 0.001
WEIGHT_DECAY = 7.06e-6
EPOCS = 100  # number of training epochs; name kept as-is because main() reads EPOCS
DROPOUT = 0.3
HIDDEN_CHANNELS = 256
# main() always passes gat_heads=GAT_HEADS to BindingAffinityModel, but the GCN
# preset never defined it, so the script died with NameError before training.
# NOTE(review): 4 is taken from the GAT V2 preset above (the run prefix in
# LOG_DIR is also "GATV2") -- confirm the value intended for this preset.
GAT_HEADS = 4

# Train on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One TensorBoard run directory per launch, suffixed with the start timestamp.
# NOTE(review): the prefix says GATV2 although the active preset is labelled
# GCN -- confirm which one is intended.
LOG_DIR = f"runs/GATV2{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Number of best (lowest test loss) checkpoints main() keeps on disk.
TOP_K = 3
SAVES_DIR = LOG_DIR + "/models"
def set_seed(seed=42):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (via both
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``) with ``seed``.

    Args:
        seed: Integer seed applied to all generators.

    Returns:
        A fresh ``torch.Generator`` seeded with ``seed``, for callers that
        need an explicit generator (e.g. for a dataset split).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
def train_epoch(epoch, model, loader, optimizer, criterion, writer):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        epoch: 1-based epoch number; used for the progress-bar label and to
            derive the global step of each logged batch.
        model: Module called as
            ``model(x, edge_index, batch, protein_seq)``.
        loader: Iterable of batches that also supports ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to the squeezed predictions and targets.
        writer: TensorBoard writer; receives ``Loss/Train_Step`` per batch.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    num_batches = len(loader)
    # Global step of the first batch of this epoch (epochs start at 1).
    step_offset = (epoch - 1) * num_batches
    running_loss = 0

    progress = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
    for batch_idx, graphs in enumerate(progress):
        graphs = graphs.to(DEVICE)

        optimizer.zero_grad()
        prediction = model(
            graphs.x, graphs.edge_index, graphs.batch, graphs.protein_seq
        )
        batch_loss = criterion(prediction.squeeze(), graphs.y.squeeze())
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        writer.add_scalar("Loss/Train_Step", loss_value, step_offset + batch_idx)
        progress.set_postfix(loss=loss_value)

    return running_loss / num_batches
def evaluate(epoch, model, loader, criterion, writer):
    """Compute the mean batch loss of ``model`` over ``loader`` without gradients.

    Args:
        epoch: Epoch number; used for the progress-bar label and as the
            TensorBoard step of the logged value.
        model: Module called as
            ``model(x, edge_index, batch, protein_seq)``.
        loader: Iterable of batches that also supports ``len()``.
        criterion: Loss applied to the squeezed predictions and targets.
        writer: TensorBoard writer; receives ``Loss/Test`` once per call.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        progress = tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False)
        for graphs in progress:
            graphs = graphs.to(DEVICE)
            prediction = model(
                graphs.x, graphs.edge_index, graphs.batch, graphs.protein_seq
            )
            batch_loss = criterion(prediction.squeeze(), graphs.y.squeeze())
            batch_losses.append(batch_loss.item())

    mean_loss = sum(batch_losses) / len(loader)
    writer.add_scalar("Loss/Test", mean_loss, epoch)
    return mean_loss
def _record_checkpoint(model, epoch, test_loss, top_models):
    """Save ``model`` if ``test_loss`` ranks among the TOP_K lowest so far.

    ``top_models`` is a list of ``{"loss", "path", "epoch"}`` dicts kept sorted
    by ascending loss; it is updated in place. A checkpoint that would not make
    the list is never written, so no file is created only to be deleted.

    Returns:
        The 1-based rank of this epoch's checkpoint, or ``None`` if the epoch
        did not make the list.
    """
    has_room = len(top_models) < TOP_K
    # Strict "<": on a tie with the current worst entry the older checkpoint
    # stays, which is what a stable sort followed by pop() would keep.
    beats_worst = bool(top_models) and test_loss < top_models[-1]["loss"]
    if not (has_room or beats_worst):
        return None

    filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
    torch.save(model.state_dict(), filename)
    top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
    top_models.sort(key=lambda entry: entry["loss"])
    if len(top_models) > TOP_K:
        evicted = top_models.pop()
        os.remove(evicted["path"])

    return next(
        (
            rank
            for rank, entry in enumerate(top_models, start=1)
            if entry["epoch"] == epoch
        ),
        None,
    )


def main():
    """Train the binding-affinity model and keep the TOP_K best checkpoints.

    Steps: seed the RNGs, read ``pdbbind_refined_dataset.csv`` and drop rows
    with missing values, wrap it in ``BindingDataset``, split it 80/20 with the
    seeded generator, then train for ``EPOCS`` epochs with Adam, MSE loss and a
    ``ReduceLROnPlateau`` scheduler driven by the test loss. Per-epoch progress
    is printed and logged to TensorBoard under ``LOG_DIR``; checkpoints go to
    ``SAVES_DIR``.

    Returns early (after printing a message) when the dataset is empty.

    NOTE: ``GAT_HEADS`` must be defined among the module-level hyperparameters,
    since it is always passed to ``BindingAffinityModel``.
    """
    gen = set_seed(42)
    writer = SummaryWriter(LOG_DIR)
    # try/finally so the writer is closed on the early return and on errors,
    # not only after a complete training run.
    try:
        os.makedirs(SAVES_DIR, exist_ok=True)
        print(f"Logging to {LOG_DIR}...")
        print(f"Model saves to {SAVES_DIR}...")

        # Load dataset
        dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
        dataframe.dropna(inplace=True)
        print(f"Dataset loaded with {len(dataframe)} samples")
        dataset = BindingDataset(dataframe, max_seq_length=1200)
        print(f"Dataset transformed with {len(dataset)} samples")
        if len(dataset) == 0:
            print("Dataset is empty")
            return

        # 80/20 split; the seeded generator makes the split reproducible.
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(
            dataset, [train_size, test_size], generator=gen
        )
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        num_features = train_dataset[0].x.shape[1]
        print("Number of node features:", num_features)

        model = BindingAffinityModel(
            num_node_features=num_features,
            hidden_channels=HIDDEN_CHANNELS,
            gat_heads=GAT_HEADS,
            dropout=DROPOUT,
        ).to(DEVICE)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
        )
        # factor of 0.5 means reducing lr to half when triggered
        # patience of 5 means wait for 5 epochs before reducing lr
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=5
        )
        criterion = nn.MSELoss()

        # Best checkpoints so far, sorted by ascending test loss.
        top_models = []
        print(f"Starting training on {DEVICE}")
        for epoch in range(1, EPOCS + 1):
            train_loss = train_epoch(
                epoch, model, train_loader, optimizer, criterion, writer
            )
            test_loss = evaluate(epoch, model, test_loader, criterion, writer)

            # Compare the LR before/after the scheduler step to report reductions.
            old_lr = optimizer.param_groups[0]["lr"]
            scheduler.step(test_loss)
            new_lr = optimizer.param_groups[0]["lr"]
            if new_lr != old_lr:
                print(
                    f"\nEpoch {epoch}: Scheduler reduced LR from {old_lr:.6f} to {new_lr:.6f}!"
                )

            # No newline yet: the checkpoint status is appended to the same line.
            print(
                f"Epoch {epoch:02d} | LR: {new_lr:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
                end="",
            )
            rank = _record_checkpoint(model, epoch, test_loss, top_models)
            if rank is not None:
                print(f"-- Model saved (Rank: {rank})")
            else:
                print("")
    finally:
        writer.close()

    print("Training finished.")
    print("Top models saved:")
    for position, entry in enumerate(top_models, start=1):
        print(f"{position}. {entry['path']} (MSE: {entry['loss']:.4f})")


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()