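# Train a graph neural network to predict protein-ligand binding affinity on
# the PDBbind refined set: 80/20 train/test split, MSE loss, TensorBoard
# logging, ReduceLROnPlateau scheduling, and retention of the TOP_K best
# checkpoints by test loss.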
import random

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
from dataset import BindingDataset
from model import BindingAffinityModel
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from datetime import datetime
import os

# Alternative hyperparameter presets; exactly one block should be active.
# The GCN block below is the one currently in use.
# GAT V2
# BATCH_SIZE = 16
# LR = 0.0005
# WEIGHT_DECAY = 1e-5
# EPOCHS = 100
# DROPOUT = 0.4
# GAT_HEADS = 4
# HIDDEN_CHANNELS = 256

# GAT
# BATCH_SIZE = 16
# LR = 0.00064
# WEIGHT_DECAY = 7.06e-6
# EPOCHS = 100
# DROPOUT = 0.325
# GAT_HEADS = 2
# HIDDEN_CHANNELS = 256

# GCN
BATCH_SIZE = 16
LR = 0.001
WEIGHT_DECAY = 7.06e-6
EPOCHS = 100
DROPOUT = 0.3
GAT_HEADS = 4  # required by BindingAffinityModel's constructor; value assumed from the GAT V2 preset
HIDDEN_CHANNELS = 256

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOG_DIR = f"runs/GATV2{datetime.now().strftime('%Y%m%d_%H%M%S')}"
TOP_K = 3
SAVES_DIR = LOG_DIR + "/models"


def set_seed(seed=42):
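    """Seed Python, NumPy, and torch RNGs; return a generator for random_split."""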
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    return torch.Generator().manual_seed(seed)


def train_epoch(epoch, model, loader, optimizer, criterion, writer):
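    """Run one optimization pass over loader and return the mean batch loss."""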
    model.train()
    total_loss = 0

    loop = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
    for i, batch in enumerate(loop):
        batch = batch.to(DEVICE)
        optimizer.zero_grad()

        out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
        loss = criterion(out.squeeze(), batch.y.squeeze())

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss

        global_step = (epoch - 1) * len(loader) + i
        writer.add_scalar("Loss/Train_Step", current_loss, global_step)

        loop.set_postfix(loss=current_loss)

    avg_loss = total_loss / len(loader)
    return avg_loss


def evaluate(epoch, model, loader, criterion, writer):
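    """Compute the mean test loss, log it to TensorBoard, and return it."""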
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False):
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
            loss = criterion(out.squeeze(), batch.y.squeeze())
            total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    writer.add_scalar("Loss/Test", avg_loss, epoch)
    return avg_loss


def main():
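    """Load data, build the model, run the training loop, keep best checkpoints."""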
    gen = set_seed(42)
    writer = SummaryWriter(LOG_DIR)

    os.makedirs(SAVES_DIR, exist_ok=True)
    print(f"Logging to {LOG_DIR}...")
    print(f"Model saves to {SAVES_DIR}...")
    # Load dataset
    dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
    dataframe.dropna(inplace=True)
    print("Dataset loaded with {} samples".format(len(dataframe)))
    dataset = BindingDataset(dataframe, max_seq_length=1200)
    print("Dataset transformed with {} samples".format(len(dataset)))

    if len(dataset) == 0:
        print("Dataset is empty")
        return

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=gen
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    num_features = train_dataset[0].x.shape[1]
    print("Number of node features:", num_features)

    model = BindingAffinityModel(
        num_node_features=num_features,
        hidden_channels=HIDDEN_CHANNELS,
        gat_heads=GAT_HEADS,
        dropout=DROPOUT,
    ).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # factor=0.5 halves the LR when triggered; patience=5 waits for 5 epochs
    # without test-loss improvement before reducing
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = nn.MSELoss()

    top_models = []

    print(f"Starting training on {DEVICE}")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_epoch(
            epoch, model, train_loader, optimizer, criterion, writer
        )
        test_loss = evaluate(epoch, model, test_loader, criterion, writer)

        old_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(test_loss)
        new_lr = optimizer.param_groups[0]["lr"]

        if new_lr != old_lr:
            print(
                f"\nEpoch {epoch}: Scheduler reduced LR from {old_lr:.6f} to {new_lr:.6f}!"
            )
        print(
            f"Epoch {epoch:02d} | LR: {new_lr:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
            end="",
        )

        filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"

        torch.save(model.state_dict(), filename)
        top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})

        top_models.sort(key=lambda x: x["loss"])

        if len(top_models) > TOP_K:
            worst_model = top_models.pop()
            os.remove(worst_model["path"])

        if any(m["epoch"] == epoch for m in top_models):
            rank = [m["epoch"] for m in top_models].index(epoch) + 1
            print(f"-- Model saved (Rank: {rank})")
        else:
            print("")

    writer.close()
    print("Training finished.")
    print("Top models saved:")
    for i, m in enumerate(top_models):
        print(f"{i + 1}. {m['path']} (MSE: {m['loss']:.4f})")


if __name__ == "__main__":
    main()