| import torch | |
| import matplotlib.pyplot as plt | |
| from torchvision.utils import make_grid | |
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Serialize the current training state to ``path`` via ``torch.save``.

    The stored object is a dict with, in this order, the keys
    ``'epoch'``, ``'model_state_dict'``, ``'optimizer_state_dict'`` and
    ``'loss'``. The model and optimizer entries come from their
    ``state_dict()`` methods. A confirmation line is printed afterwards.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(state, path)
    print(f"Checkpoint saved at epoch {epoch}")
def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint must be a dict containing the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'epoch'`` and ``'loss'``. This is the layout
    that ``save_checkpoint`` writes. It is read with ``weights_only=True``, so
    only tensors and plain containers/primitives are unpickled.

    Args:
        model: Module that receives the saved weights via ``load_state_dict``
            (modified in place).
        optimizer: Optimizer that receives the saved state via
            ``load_state_dict`` (modified in place).
        path: File path or file-like object passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` or a
            ``torch.device`` to load a checkpoint that was saved from GPU
            tensors onto a machine without that device. ``None`` (the default,
            identical to ``torch.load``'s own default) restores tensors onto
            the devices they were saved from.

    Returns:
        Tuple ``(model, optimizer, epoch, loss)``. The model and optimizer are
        the same objects that were passed in. ``epoch`` and ``loss`` are the
        stored values.
    """
    checkpoint = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded, resuming from epoch {epoch}")
    return model, optimizer, epoch, loss
def plot_training_curves(epochs, train_acc1, test_acc1, train_acc5, test_acc5, train_losses, test_losses, learning_rates):
    """Show a 2x2 figure of per-epoch training metrics.

    The panels, in row-major order, are:

    1. top-1 accuracy (train and test curves)
    2. top-5 accuracy (train and test curves)
    3. loss (train and test curves)
    4. learning rate

    Every series is plotted against ``epochs``. Each panel gets an 'Epoch'
    x-label, a y-label, a legend and a title. The figure is laid out with
    ``tight_layout`` and displayed with ``plt.show()``.
    """
    # (title, y-axis label, ((series, legend label), ...)) for each subplot slot.
    panels = (
        ('Top-1 Accuracy', 'Accuracy',
         ((train_acc1, 'Train Top-1 Acc'), (test_acc1, 'Test Top-1 Acc'))),
        ('Top-5 Accuracy', 'Accuracy',
         ((train_acc5, 'Train Top-5 Acc'), (test_acc5, 'Test Top-5 Acc'))),
        ('Loss', 'Loss',
         ((train_losses, 'Train Loss'), (test_losses, 'Test Loss'))),
        ('Learning Rate', 'Learning Rate',
         ((learning_rates, 'Learning Rate'),)),
    )

    plt.figure(figsize=(12, 8))
    for slot, (panel_title, y_label, series) in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        for values, legend_label in series:
            plt.plot(epochs, values, label=legend_label)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.title(panel_title)
    plt.tight_layout()
    plt.show()
def plot_misclassified_samples(misclassified_images, misclassified_labels, misclassified_preds, classes,
                               max_samples=16, nrow=4):
    """Display up to ``max_samples`` misclassified images in a grid, annotated with labels.

    The images are tiled with ``torchvision.utils.make_grid``. Each image is
    min-max rescaled independently (``normalize=True, scale_each=True``). Each
    grid cell is overlaid with its true class ("T:") and predicted class ("P:").

    Args:
        misclassified_images: Images in any form ``make_grid`` accepts, e.g. a
            list of CHW tensors or a stacked NCHW tensor. May live on the GPU
            or require grad; it is detached and moved to the CPU before drawing.
        misclassified_labels: Iterable of true class indices, aligned with
            ``misclassified_images``.
        misclassified_preds: Iterable of predicted class indices, aligned with
            ``misclassified_images``.
        classes: Indexable mapping from class index to display name. If this
            or either label iterable is ``None``, the grid is drawn without
            annotations.
        max_samples: Maximum number of images to show (default 16).
        nrow: Number of images per grid row (default 4).

    Nothing is printed or drawn when ``misclassified_images`` is ``None`` or
    empty.
    """
    # Use len() rather than truthiness: bool() of a multi-element tensor raises.
    if misclassified_images is None or len(misclassified_images) == 0:
        return
    print("\nDisplaying some misclassified samples:")

    samples = misclassified_images[:max_samples]
    num_shown = len(samples)
    padding = 2  # make_grid's default, passed explicitly so cells can be located below
    misclassified_grid = make_grid(samples, nrow=nrow, padding=padding, normalize=True, scale_each=True)

    plt.figure(figsize=(8, 8))
    # imshow needs a host-side array: drop any autograd graph and move off the GPU.
    plt.imshow(misclassified_grid.detach().cpu().permute(1, 2, 0).numpy())

    if classes is not None and misclassified_labels is not None and misclassified_preds is not None:
        img_h, img_w = samples[0].shape[-2:]
        # make_grid returns a lone image as-is, without any padding border.
        pad = padding if num_shown > 1 else 0
        for idx, true_idx, pred_idx in zip(range(num_shown), misclassified_labels, misclassified_preds):
            # make_grid fills cells row-major, each (H + pad) x (W + pad), offset by pad.
            row, col = divmod(idx, nrow)
            x = col * (img_w + pad) + pad
            y = row * (img_h + pad) + pad
            plt.text(
                x, y,
                f"T: {classes[int(true_idx)]}\nP: {classes[int(pred_idx)]}",
                color="white", fontsize=7, ha="left", va="top",
                bbox=dict(facecolor="black", alpha=0.6, edgecolor="none", pad=1),
            )

    plt.title("Misclassified Samples")
    plt.axis('off')
    plt.show()