import os

import matplotlib
import matplotlib.pyplot as plt
import torch

matplotlib.style.use('ggplot')


class SaveBestModel:
    """
    Save the model whenever the validation loss improves.

    Keeps track of the lowest validation loss seen so far. On each call,
    if the current validation loss is strictly lower than that best value,
    the best value is updated and the model's state dict is written to
    ``<out_dir>/best_<name>.pth``.
    """

    def __init__(
        self, best_valid_loss=float('inf')
    ):
        # Defaults to +inf so that the first call always saves a checkpoint.
        self.best_valid_loss = best_valid_loss

    def __call__(
        self, current_valid_loss,
        epoch, model, out_dir, name
    ):
        """
        Save ``model`` if ``current_valid_loss`` beats the best loss so far.

        :param current_valid_loss: Validation loss for this epoch; compared
            with ``<`` against the stored best value.
        :param epoch: Epoch index. It is printed and stored as ``epoch + 1``
            (presumably the caller passes a zero-based index — confirm).
        :param model: Object exposing ``state_dict()``, e.g. a
            ``torch.nn.Module``.
        :param out_dir: Existing directory to write the checkpoint into.
        :param name: Base name; the file is ``best_<name>.pth``.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                }, os.path.join(out_dir, 'best_'+name+'.pth'))


def save_model(epochs, model, optimizer, criterion, out_dir, name):
    """
    Save a training checkpoint to ``<out_dir>/<name>.pth``.

    The checkpoint is a dict with keys ``'epoch'`` (``epochs`` as given),
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``.
    ``'loss'`` holds the ``criterion`` object itself, pickled by
    ``torch.save``, not a numeric loss value.

    :param epochs: Epoch count stored verbatim under ``'epoch'``.
    :param model: Object exposing ``state_dict()``, e.g. a ``torch.nn.Module``.
    :param optimizer: Object exposing ``state_dict()``, e.g. a
        ``torch.optim.Optimizer``.
    :param criterion: Stored as-is under ``'loss'``.
    :param out_dir: Existing directory to write the checkpoint into.
    :param name: Base file name without extension.
    """
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
        }, os.path.join(out_dir, name+'.pth'))


def _save_train_valid_plot(train_values, valid_values, metric, ylabel, path):
    """Plot one train/validation curve pair and write it to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 7))
    try:
        ax.plot(
            train_values, color='tab:blue', linestyle='-',
            label=f'train {metric}'
        )
        ax.plot(
            valid_values, color='tab:red', linestyle='-',
            label=f'validation {metric}'
        )
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(path)
    finally:
        # Close the figure even if plotting or saving fails. Otherwise
        # pyplot keeps every figure alive across calls, leaking memory.
        plt.close(fig)


def save_plots(train_acc, valid_acc, train_loss, valid_loss, out_dir):
    """
    Save the accuracy and loss curves to disk.

    Writes ``<out_dir>/accuracy.png`` and ``<out_dir>/loss.png``. Each plot
    shows the per-epoch training curve (blue) and validation curve (red).
    The figures are closed after saving.

    :param train_acc: Per-epoch training accuracy values.
    :param valid_acc: Per-epoch validation accuracy values.
    :param train_loss: Per-epoch training loss values.
    :param valid_loss: Per-epoch validation loss values.
    :param out_dir: Existing directory to write the images into.
    """
    _save_train_valid_plot(
        train_acc, valid_acc, 'accuracy', 'Accuracy',
        os.path.join(out_dir, 'accuracy.png')
    )
    _save_train_valid_plot(
        train_loss, valid_loss, 'loss', 'Loss',
        os.path.join(out_dir, 'loss.png')
    )