# BAPULM / utils / train.py
# Moreza009's picture
# Upload folder using huggingface_hub
# 15c5ffb verified
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from utils.utils import set_seed
def _batch_loss(model, criterion, sequences, smiles, affinities, device):
    """Run one forward pass on a batch and return the criterion's loss tensor.

    ``sequences``, ``smiles`` and ``affinities`` are moved to ``device``; the
    affinities are additionally cast to float before being used as targets.
    """
    sequences, smiles = sequences.to(device), smiles.to(device)
    targets = affinities.to(device).float()
    predictions = model(sequences, smiles)
    # A bare squeeze() removes *every* size-1 axis, including the batch axis
    # when a batch holds a single sample (e.g. a (1, 1) output becomes 0-d),
    # which leaves the criterion to broadcast mismatched shapes. Give the
    # predictions the target's shape whenever the element counts agree and
    # keep the original squeeze() only as a fallback.
    if predictions.numel() == targets.numel():
        predictions = predictions.reshape(targets.shape)
    else:
        predictions = predictions.squeeze()
    return criterion(predictions, targets)


def _mean_loss(total, batches):
    """Return ``total / batches``, or NaN when no batch was processed."""
    return total / batches if batches else float("nan")


def train(model, train_loader, validation_loader, criterion, optimizer, scheduler, device, num_epochs=60):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: Callable as ``model(sequences, smiles)``; must provide
            ``train()`` and ``eval()``.
        train_loader: Iterable yielding ``(sequences, smiles, affinities)``
            batches used for optimisation.
        validation_loader: Iterable yielding batches of the same form, used
            for evaluation under ``torch.no_grad()``.
        criterion: Loss callable invoked as ``criterion(predictions, targets)``.
        optimizer: Optimizer whose ``zero_grad()``/``step()`` are called once
            per training batch.
        scheduler: Scheduler whose ``step`` is called once per epoch with the
            validation loss summed over all validation batches.
        device: Device the batches are moved to before the forward pass.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, after training. It is left in eval mode
        because validation is the last phase of every epoch.
    """
    for epoch in range(num_epochs):
        # ---- training phase ----
        model.train()
        train_loss = 0.0
        train_batches = 0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as train_progress:
            for sequences, smiles, affinities in train_progress:
                optimizer.zero_grad()
                loss = _batch_loss(model, criterion, sequences, smiles, affinities, device)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                train_batches += 1
                # Divide by the batches seen so far so the bar shows a true
                # running mean rather than a value that ramps up from zero.
                train_progress.set_postfix(loss=train_loss / train_batches)

        # ---- validation phase ----
        model.eval()
        validation_loss = 0.0
        validation_batches = 0
        with torch.no_grad():
            for sequences, smiles, affinities in validation_loader:
                loss = _batch_loss(model, criterion, sequences, smiles, affinities, device)
                validation_loss += loss.item()
                validation_batches += 1

        # Batches are counted rather than taken from len(loader) so that empty
        # loaders and loaders without __len__ do not raise here.
        print(f"Epoch {epoch+1}, Train Loss: {_mean_loss(train_loss, train_batches)}, Validation Loss: {_mean_loss(validation_loss, validation_batches)}")
        # The scheduler keeps receiving the summed (not averaged) validation
        # loss, exactly as before, so existing scheduler settings behave the same.
        scheduler.step(validation_loss)
    return model