Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| def plot_training_curves(train_acc, val_acc, train_loss, val_loss, model_name, save_dir='plots'): | |
| """ | |
| Plot and save training/validation curves. | |
| Args: | |
| train_acc: List of training accuracies | |
| val_acc: List of validation accuracies | |
| train_loss: List of training losses | |
| val_loss: List of validation losses | |
| model_name: Name of the model for plot titles and filenames | |
| save_dir: Directory to save plots | |
| """ | |
| os.makedirs(save_dir, exist_ok=True) | |
| epochs = range(1, len(train_acc) + 1) | |
| # Set up the plotting style | |
| plt.style.use('seaborn-v0_8') | |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) | |
| # Plot accuracy curves | |
| ax1.plot(epochs, train_acc, 'b-', label='Training Accuracy', linewidth=2) | |
| ax1.plot(epochs, val_acc, 'r-', label='Validation Accuracy', linewidth=2) | |
| ax1.set_title(f'{model_name} - Accuracy Curves', fontsize=14, fontweight='bold') | |
| ax1.set_xlabel('Epochs', fontsize=12) | |
| ax1.set_ylabel('Accuracy (%)', fontsize=12) | |
| ax1.legend(fontsize=11) | |
| ax1.grid(True, alpha=0.3) | |
| ax1.set_ylim([0, 100]) | |
| # Add best accuracy annotations | |
| best_train_acc = max(train_acc) | |
| best_val_acc = max(val_acc) | |
| best_train_epoch = train_acc.index(best_train_acc) + 1 | |
| best_val_epoch = val_acc.index(best_val_acc) + 1 | |
| ax1.annotate(f'Best: {best_val_acc:.2f}%', | |
| xy=(best_val_epoch, best_val_acc), | |
| xytext=(10, 10), textcoords='offset points', | |
| bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.7), | |
| arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0')) | |
| # Plot loss curves | |
| ax2.plot(epochs, train_loss, 'b-', label='Training Loss', linewidth=2) | |
| ax2.plot(epochs, val_loss, 'r-', label='Validation Loss', linewidth=2) | |
| ax2.set_title(f'{model_name} - Loss Curves', fontsize=14, fontweight='bold') | |
| ax2.set_xlabel('Epochs', fontsize=12) | |
| ax2.set_ylabel('Loss', fontsize=12) | |
| ax2.legend(fontsize=11) | |
| ax2.grid(True, alpha=0.3) | |
| # Add minimum loss annotation | |
| min_val_loss = min(val_loss) | |
| min_loss_epoch = val_loss.index(min_val_loss) + 1 | |
| ax2.annotate(f'Min: {min_val_loss:.4f}', | |
| xy=(min_loss_epoch, min_val_loss), | |
| xytext=(10, 10), textcoords='offset points', | |
| bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.7), | |
| arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0')) | |
| plt.tight_layout() | |
| # Save plot | |
| save_path = os.path.join(save_dir, f'{model_name.lower()}_training_curves.png') | |
| plt.savefig(save_path, dpi=300, bbox_inches='tight') | |
| plt.show() | |
| print(f"π Training curves saved to: {save_path}") | |
| return fig | |
| def plot_model_comparison(models_history, save_dir='plots'): | |
| """ | |
| Compare training curves of multiple models. | |
| Args: | |
| models_history: Dict with model names as keys and history dicts as values | |
| save_dir: Directory to save plots | |
| """ | |
| os.makedirs(save_dir, exist_ok=True) | |
| plt.style.use('seaborn-v0_8') | |
| fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) | |
| colors = ['blue', 'red', 'green', 'orange', 'purple'] | |
| for i, (model_name, history) in enumerate(models_history.items()): | |
| epochs = range(1, len(history['val_acc']) + 1) | |
| color = colors[i % len(colors)] | |
| # Plot validation accuracy | |
| ax1.plot(epochs, history['val_acc'], color=color, | |
| label=f'{model_name}', linewidth=2, marker='o', markersize=3) | |
| # Plot validation loss | |
| ax2.plot(epochs, history['val_loss'], color=color, | |
| label=f'{model_name}', linewidth=2, marker='o', markersize=3) | |
| # Configure accuracy plot | |
| ax1.set_title('Model Comparison - Validation Accuracy', fontsize=14, fontweight='bold') | |
| ax1.set_xlabel('Epochs', fontsize=12) | |
| ax1.set_ylabel('Validation Accuracy (%)', fontsize=12) | |
| ax1.legend(fontsize=11) | |
| ax1.grid(True, alpha=0.3) | |
| ax1.set_ylim([0, 100]) | |
| # Configure loss plot | |
| ax2.set_title('Model Comparison - Validation Loss', fontsize=14, fontweight='bold') | |
| ax2.set_xlabel('Epochs', fontsize=12) | |
| ax2.set_ylabel('Validation Loss', fontsize=12) | |
| ax2.legend(fontsize=11) | |
| ax2.grid(True, alpha=0.3) | |
| plt.tight_layout() | |
| # Save comparison plot | |
| save_path = os.path.join(save_dir, 'model_comparison.png') | |
| plt.savefig(save_path, dpi=300, bbox_inches='tight') | |
| plt.show() | |
| print(f"π Model comparison saved to: {save_path}") | |
| return fig | |
| def plot_learning_rate_schedule(lr_history, model_name, save_dir='plots'): | |
| """ | |
| Plot learning rate schedule over epochs. | |
| Args: | |
| lr_history: List of learning rates per epoch | |
| model_name: Name of the model | |
| save_dir: Directory to save plots | |
| """ | |
| os.makedirs(save_dir, exist_ok=True) | |
| epochs = range(1, len(lr_history) + 1) | |
| plt.figure(figsize=(10, 6)) | |
| plt.plot(epochs, lr_history, 'g-', linewidth=2) | |
| plt.title(f'{model_name} - Learning Rate Schedule', fontsize=14, fontweight='bold') | |
| plt.xlabel('Epochs', fontsize=12) | |
| plt.ylabel('Learning Rate', fontsize=12) | |
| plt.grid(True, alpha=0.3) | |
| plt.yscale('log') | |
| # Save plot | |
| save_path = os.path.join(save_dir, f'{model_name.lower()}_lr_schedule.png') | |
| plt.savefig(save_path, dpi=300, bbox_inches='tight') | |
| plt.show() | |
| print(f"π LR schedule saved to: {save_path}") | |
| return plt.gcf() |