File size: 1,798 Bytes
15c5ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

from data.dataset import BindingAffinityDataset
from model.model import BAPULM
from utils.train import train
from utils.utils import set_seed

def main(config_path='config.yaml'):
    """Train a BAPULM model from a YAML config and save its weights.

    Loads ``BindingAffinityDataset``, randomly splits it into training and
    validation subsets, and trains ``BAPULM`` via ``utils.train.train``. Training
    uses Adam, MSE loss and a ``ReduceLROnPlateau`` scheduler. Finally it saves
    the returned model's ``state_dict``.

    Config keys read: ``device``, ``dataset_path``, ``train_split``,
    ``train_batch_size``, ``learning_rate``, ``scheduler_factor``,
    ``scheduler_patience``, ``num_epochs``, ``model_train_save_path``.
    ``seed`` is optional; when present, it makes the split and training
    reproducible.

    Args:
        config_path: Path to the YAML config file. Defaults to
            ``'config.yaml'`` relative to the current working directory.
    """
    # Load config (safe_load: no arbitrary Python object construction).
    with open(config_path, 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    # Optional seeding. Without a 'seed' key, the run is unseeded, as before.
    seed = config.get('seed')
    split_kwargs = {}
    if seed is not None:
        # NOTE(review): assumes utils.utils.set_seed accepts the seed as its
        # single positional argument — confirm against its definition.
        set_seed(seed)
        # Dedicated generator: the train/valid split then depends only on the
        # seed, not on how much global RNG state was consumed earlier.
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    # Create the checkpoint directory now, so a bad save path fails before
    # training instead of after it.
    save_path = config['model_train_save_path']
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Fall back to CPU whenever CUDA is unavailable, whatever config['device'] says.
    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')

    # Load dataset
    dataset = BindingAffinityDataset(config['dataset_path'])

    # Split dataset; the validation subset takes the remainder, so the sizes
    # always sum to len(dataset).
    train_size = int(config['train_split'] * len(dataset))
    valid_size = len(dataset) - train_size
    train_data, valid_data = random_split(dataset, [train_size, valid_size], **split_kwargs)
    print(f"Total dataset size: {len(dataset)}")
    print(f"Training data size: {len(train_data)}")
    print(f"Validation data size: {len(valid_data)}")

    # DataLoaders. Validation reuses the training batch size and is not shuffled.
    train_loader = DataLoader(train_data, batch_size=config['train_batch_size'], shuffle=True)
    validation_loader = DataLoader(valid_data, batch_size=config['train_batch_size'], shuffle=False)

    # Model setup
    model = BAPULM().to(device)
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.MSELoss()
    scheduler = ReduceLROnPlateau(optimizer, factor=config['scheduler_factor'], patience=config['scheduler_patience'])

    # Training; train() returns the model whose weights are saved below.
    model = train(model, train_loader, validation_loader, criterion, optimizer, scheduler, device, num_epochs=config['num_epochs'])

    # Save model weights only (state_dict), not the pickled module.
    torch.save(model.state_dict(), save_path)
    print("Model training complete and saved.")


if __name__ == '__main__':
    main()