Spaces:
Sleeping
Sleeping
File size: 5,811 Bytes
c65e61c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | import matplotlib.pyplot as plt
import numpy as np
import os
def plot_training_curves(train_acc, val_acc, train_loss, val_loss, model_name, save_dir='plots'):
"""
Plot and save training/validation curves.
Args:
train_acc: List of training accuracies
val_acc: List of validation accuracies
train_loss: List of training losses
val_loss: List of validation losses
model_name: Name of the model for plot titles and filenames
save_dir: Directory to save plots
"""
os.makedirs(save_dir, exist_ok=True)
epochs = range(1, len(train_acc) + 1)
# Set up the plotting style
plt.style.use('seaborn-v0_8')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
# Plot accuracy curves
ax1.plot(epochs, train_acc, 'b-', label='Training Accuracy', linewidth=2)
ax1.plot(epochs, val_acc, 'r-', label='Validation Accuracy', linewidth=2)
ax1.set_title(f'{model_name} - Accuracy Curves', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs', fontsize=12)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 100])
# Add best accuracy annotations
best_train_acc = max(train_acc)
best_val_acc = max(val_acc)
best_train_epoch = train_acc.index(best_train_acc) + 1
best_val_epoch = val_acc.index(best_val_acc) + 1
ax1.annotate(f'Best: {best_val_acc:.2f}%',
xy=(best_val_epoch, best_val_acc),
xytext=(10, 10), textcoords='offset points',
bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.7),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
# Plot loss curves
ax2.plot(epochs, train_loss, 'b-', label='Training Loss', linewidth=2)
ax2.plot(epochs, val_loss, 'r-', label='Validation Loss', linewidth=2)
ax2.set_title(f'{model_name} - Loss Curves', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
# Add minimum loss annotation
min_val_loss = min(val_loss)
min_loss_epoch = val_loss.index(min_val_loss) + 1
ax2.annotate(f'Min: {min_val_loss:.4f}',
xy=(min_loss_epoch, min_val_loss),
xytext=(10, 10), textcoords='offset points',
bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.7),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
plt.tight_layout()
# Save plot
save_path = os.path.join(save_dir, f'{model_name.lower()}_training_curves.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"๐ Training curves saved to: {save_path}")
return fig
def plot_model_comparison(models_history, save_dir='plots'):
"""
Compare training curves of multiple models.
Args:
models_history: Dict with model names as keys and history dicts as values
save_dir: Directory to save plots
"""
os.makedirs(save_dir, exist_ok=True)
plt.style.use('seaborn-v0_8')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
colors = ['blue', 'red', 'green', 'orange', 'purple']
for i, (model_name, history) in enumerate(models_history.items()):
epochs = range(1, len(history['val_acc']) + 1)
color = colors[i % len(colors)]
# Plot validation accuracy
ax1.plot(epochs, history['val_acc'], color=color,
label=f'{model_name}', linewidth=2, marker='o', markersize=3)
# Plot validation loss
ax2.plot(epochs, history['val_loss'], color=color,
label=f'{model_name}', linewidth=2, marker='o', markersize=3)
# Configure accuracy plot
ax1.set_title('Model Comparison - Validation Accuracy', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs', fontsize=12)
ax1.set_ylabel('Validation Accuracy (%)', fontsize=12)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 100])
# Configure loss plot
ax2.set_title('Model Comparison - Validation Loss', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs', fontsize=12)
ax2.set_ylabel('Validation Loss', fontsize=12)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
# Save comparison plot
save_path = os.path.join(save_dir, 'model_comparison.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"๐ Model comparison saved to: {save_path}")
return fig
def plot_learning_rate_schedule(lr_history, model_name, save_dir='plots'):
"""
Plot learning rate schedule over epochs.
Args:
lr_history: List of learning rates per epoch
model_name: Name of the model
save_dir: Directory to save plots
"""
os.makedirs(save_dir, exist_ok=True)
epochs = range(1, len(lr_history) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, lr_history, 'g-', linewidth=2)
plt.title(f'{model_name} - Learning Rate Schedule', fontsize=14, fontweight='bold')
plt.xlabel('Epochs', fontsize=12)
plt.ylabel('Learning Rate', fontsize=12)
plt.grid(True, alpha=0.3)
plt.yscale('log')
# Save plot
save_path = os.path.join(save_dir, f'{model_name.lower()}_lr_schedule.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"๐ LR schedule saved to: {save_path}")
return plt.gcf() |