Spaces:
Configuration error
Configuration error
def _run_epoch(model, loader, criterion, device, description, optimizer=None):
    """
    Run one full pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given, the model is switched to training mode and its
    weights are updated after every batch. Otherwise the model is switched to
    evaluation mode and the pass runs under ``torch.no_grad()``.

    Parameters:
    -----------
    model : torch.nn.Module
        Model to run; expected to already live on ``device``
    loader : iterable
        Yields ``(images, labels)`` batches
    criterion : callable
        Loss function called as ``criterion(outputs, labels)``
    device : str
        Device the batches are moved to before the forward pass
    description : str
        Text shown next to the tqdm progress bar
    optimizer : torch.optim.Optimizer, optional
        If provided, a training pass is performed; if None, an evaluation pass

    Returns:
    --------
    tuple of (float, float)
        Mean per-batch loss and accuracy in percent. Both are 0.0 when the
        loader yields no batches / no samples.
    """
    import contextlib

    import torch
    from tqdm import tqdm

    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    correct = 0
    seen = 0
    batches = 0

    progress = tqdm(loader, desc=description, leave=False)
    # Gradients are only needed when weights are going to be updated.
    grad_context = contextlib.nullcontext() if is_training else torch.no_grad()
    with grad_context:
        for images, labels in progress:
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()  # Reset gradients from the previous batch

            outputs = model(images)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            loss_sum += batch_loss
            # Count batches here instead of calling len(loader), so loaders
            # without a length (iterable-style datasets) also work.
            batches += 1

            _, predicted = torch.max(outputs, 1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()

            running_accuracy = 100 * correct / seen if seen else 0.0
            progress.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'Accuracy': f'{running_accuracy:.2f}%'
            })

    # Guard against empty loaders instead of raising ZeroDivisionError.
    mean_loss = loss_sum / batches if batches else 0.0
    accuracy = 100 * correct / seen if seen else 0.0
    return mean_loss, accuracy


def _save_metric_plot(epochs_range, series, ylabel, title, path):
    """
    Draw one line per ``(label, values)`` pair in ``series`` and save the figure to ``path``.
    """
    import matplotlib.pyplot as plt

    plt.figure()
    for label, values in series:
        plt.plot(epochs_range, values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(path)
    plt.close()  # Release the figure so repeated calls do not accumulate memory


def train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs, output_folder, device="cpu"):
    """
    Train a neural network model and evaluate it on validation and test data every epoch.

    Additionally, plots accuracy and loss per epoch using matplotlib and saves them as images.

    This function performs a complete training loop, including:
    - Moving the model to the specified device (CPU/GPU)
    - Training the model for a specified number of epochs
    - Tracking and logging training, validation, and testing metrics
    - Saving the best (based on validation accuracy) and last model weights
    - Plotting and saving accuracy and loss graphs per epoch

    The test metrics are only logged and plotted; they are never used to pick
    the best checkpoint.

    Parameters:
    -----------
    model : torch.nn.Module
        The neural network model to be trained
    train_loader : torch.utils.data.DataLoader
        Loader yielding (images, labels) batches used for training the model
    val_loader : torch.utils.data.DataLoader
        Loader yielding (images, labels) batches used for validation after each epoch
    test_loader : torch.utils.data.DataLoader
        Loader yielding (images, labels) batches used for testing after each epoch
    optimizer : torch.optim.Optimizer
        Optimization algorithm for updating model weights
    criterion : torch.nn.Module
        Loss function, called as criterion(outputs, labels)
    epochs : int
        Number of complete passes through the entire training dataset
    output_folder : str
        Folder path where the model weights and plots will be saved
        (created if it does not exist)
    device : str, optional
        Computing device to use for training (default is "cpu")
        Can be "cpu" or "cuda" for GPU training

    Returns:
    --------
    None

    Side Effects:
    -------------
    - Prints training, validation, and testing metrics for each epoch
    - Saves the best performing model (based on validation accuracy) to
      "<output_folder>/best_model.pth"; the first epoch is always saved so the
      file exists even if validation accuracy never rises above 0%
    - Saves the final model to "<output_folder>/last_model.pth"
    - Saves the loss plot as "loss_plot.png" and accuracy plot as
      "accuracy_plot.png" in the output folder

    Example:
    --------
    >>> model = MyModel()
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> criterion = nn.CrossEntropyLoss()
    >>> train_model(model, train_loader, val_loader, test_loader, optimizer, criterion, epochs=10, output_folder="weights")
    """
    import os

    import torch

    # Ensure the output folder exists
    os.makedirs(output_folder, exist_ok=True)

    print(f"Device Found: {device}, Starting Training 🚀")

    # Move model to the specified device
    model = model.to(device)

    best_val_accuracy = 0.0
    # Tracks whether best_model.pth was written at least once; without it a run
    # whose validation accuracy stays at 0% would never produce a best checkpoint.
    best_saved = False

    # Per-epoch metrics, kept for plotting
    train_losses, val_losses, test_losses = [], [], []
    train_accuracies, val_accuracies, test_accuracies = [], [], []

    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device,
            f"Epoch {epoch+1}/{epochs} (Training)", optimizer=optimizer,
        )
        val_loss, val_accuracy = _run_epoch(
            model, val_loader, criterion, device,
            f"Epoch {epoch+1}/{epochs} (Validation)",
        )
        test_loss, test_accuracy = _run_epoch(
            model, test_loader, criterion, device,
            f"Epoch {epoch+1}/{epochs} (Testing)",
        )

        # Store metrics for plotting
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        test_accuracies.append(test_accuracy)

        # Log the metrics for this epoch
        print(
            f"Epoch [{epoch+1}/{epochs}]: "
            f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% | "
            f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% | "
            f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
        )

        # Save the best model based on validation accuracy
        if not best_saved or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_saved = True
            torch.save(model.state_dict(), os.path.join(output_folder, "best_model.pth"))

    # Save the last model
    torch.save(model.state_dict(), os.path.join(output_folder, "last_model.pth"))

    print("Training completed. Best validation accuracy: {:.2f}%".format(best_val_accuracy))

    # ----------------------
    # Plotting Metrics with Matplotlib
    # ----------------------
    epochs_range = range(1, epochs + 1)

    loss_plot_path = os.path.join(output_folder, 'loss_plot.png')
    _save_metric_plot(
        epochs_range,
        [
            ('Train Loss', train_losses),
            ('Validation Loss', val_losses),
            ('Test Loss', test_losses),
        ],
        'Loss',
        'Loss per Epoch',
        loss_plot_path,
    )
    print(f"Loss plot saved to {loss_plot_path}")

    acc_plot_path = os.path.join(output_folder, 'accuracy_plot.png')
    _save_metric_plot(
        epochs_range,
        [
            ('Train Accuracy', train_accuracies),
            ('Validation Accuracy', val_accuracies),
            ('Test Accuracy', test_accuracies),
        ],
        'Accuracy (%)',
        'Accuracy per Epoch',
        acc_plot_path,
    )
    print(f"Accuracy plot saved to {acc_plot_path}")