File size: 1,567 Bytes
15c5ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from utils.utils import set_seed
def train(model, train_loader, validation_loader, criterion, optimizer, scheduler, device, num_epochs=60):
    """Run a standard train/validate loop and return the trained model.

    Args:
        model: Module called as ``model(sequences, smiles)``; expected to
            return per-sample predictions (trailing singleton dim allowed).
        train_loader: Iterable yielding ``(sequences, smiles, affinities)``
            batches used for optimization.
        validation_loader: Same batch structure, used for evaluation only.
        criterion: Loss function, e.g. ``nn.MSELoss``.
        optimizer: Optimizer stepping ``model``'s parameters.
        scheduler: LR scheduler whose ``step`` accepts a metric
            (e.g. ``ReduceLROnPlateau``); stepped once per epoch on the
            mean validation loss.
        device: Device the input tensors are moved to (the model is
            assumed to already live there — confirm at call site).
        num_epochs: Number of full passes over ``train_loader``.

    Returns:
        The same ``model`` object, after training in place.
    """
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as train_progress:
            for batch_idx, (sequences, smiles, affinities) in enumerate(train_progress, start=1):
                optimizer.zero_grad()
                sequences, smiles = sequences.to(device), smiles.to(device)
                predictions = model(sequences, smiles)
                # squeeze(-1), not squeeze(): a bare squeeze() also drops the
                # batch dimension when the last batch has size 1, which makes
                # the loss silently broadcast against `affinities`.
                loss = criterion(predictions.squeeze(-1), affinities.to(device).float())
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                # Show the running mean over batches seen so far (dividing by
                # the full epoch length under-reports the loss mid-epoch).
                train_progress.set_postfix(loss=train_loss / batch_idx)

        # ---- validation pass ----
        model.eval()
        validation_loss = 0.0
        with torch.no_grad():
            for sequences, smiles, affinities in validation_loader:
                sequences, smiles = sequences.to(device), smiles.to(device)
                predictions = model(sequences, smiles)
                loss = criterion(predictions.squeeze(-1), affinities.to(device).float())
                validation_loss += loss.item()

        # max(..., 1) guards against an empty loader (ZeroDivisionError).
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_validation_loss = validation_loss / max(len(validation_loader), 1)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}, Validation Loss: {avg_validation_loss}")
        # Step on the *mean* validation loss: for ReduceLROnPlateau this keeps
        # the plateau metric independent of the number of validation batches
        # (identical behavior under the default threshold_mode='rel').
        scheduler.step(avg_validation_loss)
    return model