import os

import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

from data.dataset import BindingAffinityDataset
from model.model import BAPULM
from utils.train import train
from utils.utils import set_seed


def main():
    """Train a BAPULM model as described by ``config.yaml`` and save its weights.

    The function reads ``config.yaml`` from the current working directory,
    splits the dataset into training and validation subsets, builds the
    model, optimizer, loss and learning-rate scheduler, hands them to
    ``train`` and finally writes the returned model's ``state_dict`` to disk.

    Config keys read:
        device, dataset_path, train_split, train_batch_size, learning_rate,
        scheduler_factor, scheduler_patience, num_epochs,
        model_train_save_path.

    Raises:
        FileNotFoundError: If ``config.yaml`` cannot be opened.
        KeyError: If one of the config keys listed above is missing
            (assuming the YAML document is a mapping).
    """
    with open('config.yaml', 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    # Prepare the checkpoint location up front so that a missing output
    # directory is detected before training instead of after it.
    save_path = config['model_train_save_path']
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Use the configured device only when CUDA is available; otherwise CPU.
    device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')

    dataset = BindingAffinityDataset(config['dataset_path'])

    # The validation subset receives whatever the (truncated) training
    # share leaves over, so the two sizes always sum to len(dataset).
    train_size = int(config['train_split'] * len(dataset))
    valid_size = len(dataset) - train_size
    # NOTE(review): set_seed is imported by this file but never called here,
    # so this split is not seeded by this function - confirm that is intended.
    train_data, valid_data = random_split(dataset, [train_size, valid_size])
    print(f"Total dataset size: {len(dataset)}")
    print(f"Training data size: {len(train_data)}")
    print(f"Validation data size: {len(valid_data)}")

    train_loader = DataLoader(train_data, batch_size=config['train_batch_size'], shuffle=True)
    # Validation reuses the training batch size and keeps a fixed order.
    validation_loader = DataLoader(valid_data, batch_size=config['train_batch_size'], shuffle=False)

    model = BAPULM().to(device)
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    criterion = nn.MSELoss()
    scheduler = ReduceLROnPlateau(
        optimizer,
        factor=config['scheduler_factor'],
        patience=config['scheduler_patience'],
    )

    # train() returns the model whose weights are persisted below.
    model = train(
        model,
        train_loader,
        validation_loader,
        criterion,
        optimizer,
        scheduler,
        device,
        num_epochs=config['num_epochs'],
    )

    torch.save(model.state_dict(), save_path)
    print("Model training complete and saved.")


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()