# Skin-Cancer / utils/trainer.py
# Last commit: f94b780 "1st Commit" (umergohar)
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    If ``optimizer`` is given, the model is trained: train mode, gradients
    enabled, one ``optimizer.step()`` per batch. Otherwise the model is
    evaluated: eval mode, gradients disabled.

    The loss is averaged over batches. Accuracy is the percentage of samples
    whose arg-max prediction over dim 1 equals the label, which assumes the
    model outputs have the class dimension at dim 1 (TODO confirm for custom
    models). Either metric is NaN when there was nothing to average, rather
    than raising ``ZeroDivisionError``.
    """
    import torch
    from tqdm import tqdm

    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    running_loss = 0.0
    num_batches = 0  # counted explicitly so loaders without __len__ also work
    correct = 0
    total = 0

    progress = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            # Move tensors to the specified device
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()  # Reset gradients
            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)
            if training:
                loss.backward()  # Backpropagation
                optimizer.step()  # Update weights

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Accuracy': f'{100 * correct / total:.2f}%'
            })

    mean_loss = running_loss / num_batches if num_batches else float("nan")
    accuracy = 100 * correct / total if total else float("nan")
    return mean_loss, accuracy


def _plot_metrics(series, ylabel, title, path):
    """Plot one line per ``(label, values)`` pair against epoch number and save to ``path``."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for label, values in series:
        ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)  # release the figure so repeated calls don't leak memory


def train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs, output_folder, device="cpu"):
    """
    Train a classification model, then save its weights and metric plots.

    Each epoch does the following:
    - Trains on ``train_loader``, with one optimizer step per batch.
    - Evaluates on ``val_loader`` and then on ``test_loader``, in eval mode
      with gradients disabled.
    - Prints the loss and accuracy for all three splits.
    - Saves the weights to ``<output_folder>/best_model.pth`` whenever the
      validation accuracy improves. The first epoch always counts as an
      improvement.

    Only the validation accuracy drives model selection. The test metrics are
    reported for monitoring only.

    After training, the function:
    - Saves the final weights to ``<output_folder>/last_model.pth``.
    - Writes ``loss_plot.png`` and ``accuracy_plot.png`` to ``output_folder``.

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network to train. It is moved to ``device``.
    train_loader : iterable of (images, labels) batches, e.g. torch.utils.data.DataLoader
        Training data.
    val_loader : iterable of (images, labels) batches
        Validation data, used to pick the best model.
    test_loader : iterable of (images, labels) batches
        Test data, evaluated after every epoch for reporting.
    optimizer : torch.optim.Optimizer
        Optimizer that updates the model parameters.
    criterion : torch.nn.Module
        Loss function, called as ``criterion(outputs, labels)``.
    epochs : int
        Number of passes over ``train_loader``.
    output_folder : str
        Directory for the weights and plots. It is created if missing.
    device : str, optional
        Device to train on, e.g. "cpu" (the default) or "cuda".

    Returns:
    --------
    None

    Notes:
    ------
    Accuracy compares the arg-max over dim 1 of the model output with
    ``labels``. Labels are therefore expected to be integer class indices.
    A loader that yields no batches gives NaN metrics instead of raising
    ``ZeroDivisionError``.

    Example:
    --------
    >>> model = MyModel()
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> criterion = nn.CrossEntropyLoss()
    >>> train_model(model, train_loader, val_loader, test_loader, optimizer,
    ...             criterion, epochs=10, output_folder="weights")
    """
    import os
    import torch

    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)
    print(f"Device Found: {device}, Starting Training 🚀")

    # Move model to the specified device
    model = model.to(device)

    # None means no epoch has been scored yet. With the old 0.0 start value and
    # a strict ">", best_model.pth was never written if val accuracy stayed at 0.
    best_val_accuracy = None

    # Per-epoch metrics for plotting
    train_losses, val_losses, test_losses = [], [], []
    train_accuracies, val_accuracies, test_accuracies = [], [], []

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"

        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device, f"{tag} (Training)", optimizer=optimizer
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device, f"{tag} (Validation)"
        )
        test_loss, test_accuracy = _run_epoch(
            model, test_loader, criterion, device, f"{tag} (Testing)"
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        test_accuracies.append(test_accuracy)

        print(
            f"Epoch [{epoch+1}/{epochs}]: "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% | "
            f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% | "
            f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
        )

        # Save the best model based on validation accuracy only
        if best_val_accuracy is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), os.path.join(output_folder, "best_model.pth"))

    # Save the last model (the initial weights if epochs == 0)
    torch.save(model.state_dict(), os.path.join(output_folder, "last_model.pth"))
    if best_val_accuracy is None:
        print("Training completed. No epochs were run.")
    else:
        print("Training completed. Best validation accuracy: {:.2f}%".format(best_val_accuracy))

    # Plot metrics with matplotlib
    loss_plot_path = os.path.join(output_folder, 'loss_plot.png')
    _plot_metrics(
        [('Train Loss', train_losses),
         ('Validation Loss', val_losses),
         ('Test Loss', test_losses)],
        'Loss', 'Loss per Epoch', loss_plot_path,
    )
    print(f"Loss plot saved to {loss_plot_path}")

    acc_plot_path = os.path.join(output_folder, 'accuracy_plot.png')
    _plot_metrics(
        [('Train Accuracy', train_accuracies),
         ('Validation Accuracy', val_accuracies),
         ('Test Accuracy', test_accuracies)],
        'Accuracy (%)', 'Accuracy per Epoch', acc_plot_path,
    )
    print(f"Accuracy plot saved to {acc_plot_path}")