def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given, the model is switched to training mode and its
    weights are updated after every batch. Otherwise the model is switched to
    evaluation mode and autograd is disabled for the pass.

    ``mean_loss`` is the mean of the per-batch values returned by ``criterion``.
    ``accuracy_percent`` compares ``outputs.argmax(dim=1)`` with ``labels``.
    Both are ``nan`` when ``loader`` yields no batches/samples, instead of
    raising ``ZeroDivisionError``. Batches are counted while iterating, so
    loaders without ``__len__`` are supported.
    """
    import torch
    from tqdm import tqdm

    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()

    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0

    progress = tqdm(loader, desc=desc, leave=False)
    # Gradients are only needed when we are going to call backward().
    with torch.set_grad_enabled(training):
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()  # reset gradients from the previous step
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Accuracy': f'{100 * correct / total:.2f}%' if total else 'n/a',
            })

    mean_loss = running_loss / num_batches if num_batches else float("nan")
    accuracy = 100 * correct / total if total else float("nan")
    return mean_loss, accuracy


def _save_metric_plot(epochs_range, series, ylabel, title, path):
    """Plot each ``(label, values)`` pair in ``series`` against ``epochs_range`` and save to ``path``."""
    import matplotlib.pyplot as plt

    fig = plt.figure()
    try:
        for label, values in series:
            plt.plot(epochs_range, values, label=label)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.savefig(path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def train_model(model, train_loader, val_loader, test_loader, optimizer, criterion,
                epochs, output_folder, device="cpu"):
    """
    Train a neural network and record train/validation/test metrics per epoch.

    For each epoch the function:
      - trains the model for one pass over ``train_loader``;
      - evaluates it on ``val_loader`` and on ``test_loader``;
      - prints the loss and accuracy of all three splits;
      - saves the weights to ``<output_folder>/best_model.pth`` whenever the
        validation accuracy improves.

    After training, the final weights are saved to
    ``<output_folder>/last_model.pth``. Loss and accuracy curves are saved to
    ``<output_folder>/loss_plot.png`` and ``<output_folder>/accuracy_plot.png``.

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be trained. It is moved to ``device``.
    train_loader : iterable of (images, labels)
        Batches used for training, e.g. a ``torch.utils.data.DataLoader``.
    val_loader : iterable of (images, labels)
        Batches used for validation. Validation accuracy selects the best model.
    test_loader : iterable of (images, labels)
        Batches evaluated every epoch for monitoring only. They play no part in
        model selection.
    optimizer : torch.optim.Optimizer
        Optimizer that updates the model weights.
    criterion : torch.nn.Module
        Loss function applied as ``criterion(outputs, labels)``.
    epochs : int
        Number of complete passes over ``train_loader``.
    output_folder : str
        Folder where weights and plots are written. It is created if missing.
    device : str, optional
        Computing device, e.g. "cpu" (default) or "cuda".

    Returns:
    --------
    None

    Notes:
    ------
    - Accuracy is computed as ``outputs.argmax(dim=1) == labels``. This
      presumably assumes class-score outputs of shape (batch, num_classes);
      confirm for your model.
    - A split whose loader yields no batches reports ``nan`` loss and accuracy.
    - ``best_model.pth`` is written on the first epoch that has a defined
      validation accuracy, even if that accuracy is 0%.

    Example:
    --------
    >>> model = MyModel()
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> criterion = nn.CrossEntropyLoss()
    >>> train_model(model, train_loader, val_loader, test_loader, optimizer,
    ...             criterion, epochs=10, output_folder="weights")
    """
    import math
    import os
    import torch

    os.makedirs(output_folder, exist_ok=True)
    print(f"Device Found: {device}, Starting Training 🚀")

    model = model.to(device)

    # None means "no validation result yet". Starting at 0.0 with a strict '>'
    # would skip saving the best model when validation accuracy stays at 0%.
    best_val_accuracy = None
    best_model_path = os.path.join(output_folder, "best_model.pth")

    # Per-epoch metrics, kept for plotting.
    train_losses, val_losses, test_losses = [], [], []
    train_accuracies, val_accuracies, test_accuracies = [], [], []

    for epoch in range(epochs):
        tag = f"Epoch {epoch+1}/{epochs}"

        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device, f"{tag} (Training)", optimizer=optimizer
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device, f"{tag} (Validation)"
        )
        test_loss, test_accuracy = _run_epoch(
            model, test_loader, criterion, device, f"{tag} (Testing)"
        )

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        test_accuracies.append(test_accuracy)

        print(
            f"Epoch [{epoch+1}/{epochs}]: "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% | "
            f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% | "
            f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
        )

        # Keep the weights with the highest validation accuracy seen so far.
        if not math.isnan(val_accuracy) and (
            best_val_accuracy is None or val_accuracy > best_val_accuracy
        ):
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), best_model_path)

    # Weights after the final epoch.
    torch.save(model.state_dict(), os.path.join(output_folder, "last_model.pth"))

    if best_val_accuracy is None:
        print("Training completed. No validation accuracy was recorded.")
    else:
        print("Training completed. Best validation accuracy: {:.2f}%".format(best_val_accuracy))

    # ----------------------
    # Plotting Metrics with Matplotlib
    # ----------------------
    epochs_range = range(1, epochs + 1)

    loss_plot_path = os.path.join(output_folder, 'loss_plot.png')
    _save_metric_plot(
        epochs_range,
        [('Train Loss', train_losses),
         ('Validation Loss', val_losses),
         ('Test Loss', test_losses)],
        'Loss',
        'Loss per Epoch',
        loss_plot_path,
    )
    print(f"Loss plot saved to {loss_plot_path}")

    acc_plot_path = os.path.join(output_folder, 'accuracy_plot.png')
    _save_metric_plot(
        epochs_range,
        [('Train Accuracy', train_accuracies),
         ('Validation Accuracy', val_accuracies),
         ('Test Accuracy', test_accuracies)],
        'Accuracy (%)',
        'Accuracy per Epoch',
        acc_plot_path,
    )
    print(f"Accuracy plot saved to {acc_plot_path}")