import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from utils.utils import set_seed
def train(model, train_loader, validation_loader, criterion, optimizer, scheduler, device, num_epochs=60):
    """Train a two-input regression model and validate it once per epoch.

    Each loader is expected to yield ``(sequences, smiles, affinities)``
    batches; ``sequences`` and ``smiles`` are moved to ``device`` and fed to
    ``model(sequences, smiles)``, whose output is compared against the float
    affinities with ``criterion``.

    Args:
        model: module called as ``model(sequences, smiles)``; presumably
            returns shape ``(batch, 1)`` or ``(batch,)`` — TODO confirm.
        train_loader: iterable of training batches.
        validation_loader: iterable of validation batches.
        criterion: loss function, e.g. ``nn.MSELoss``.
        optimizer: optimizer stepping ``model``'s parameters.
        scheduler: plateau-style scheduler; ``step`` is called with the
            mean validation loss after every epoch.
        device: target device for inputs (model is assumed already there).
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        The trained ``model`` (trained in place; returned for convenience).
    """
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as train_progress:
            for batch_count, (sequences, smiles, affinities) in enumerate(train_progress, start=1):
                optimizer.zero_grad()
                sequences, smiles = sequences.to(device), smiles.to(device)
                predictions = model(sequences, smiles)
                # squeeze(-1) removes only a trailing singleton dim; a bare
                # squeeze() would also drop the batch dim when batch size is 1,
                # producing a 0-d prediction and a broadcast/shape-mismatched loss.
                loss = criterion(predictions.squeeze(-1), affinities.to(device).float())
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                # Running mean over the batches seen so far — dividing by
                # len(train_loader) mid-epoch understates the loss.
                train_progress.set_postfix(loss=train_loss / batch_count)

        # ---- validation pass ----
        model.eval()
        validation_loss = 0.0
        with torch.no_grad():
            for sequences, smiles, affinities in validation_loader:
                sequences, smiles = sequences.to(device), smiles.to(device)
                predictions = model(sequences, smiles)
                loss = criterion(predictions.squeeze(-1), affinities.to(device).float())
                validation_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        avg_validation_loss = validation_loss / len(validation_loader)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}, Validation Loss: {avg_validation_loss}")
        # Step on the MEAN validation loss (the value we also report) so the
        # plateau threshold is independent of the number of validation batches.
        scheduler.step(avg_validation_loss)
    return model