CNN_Benchmark / utils /visualizations.py
Shreshth2002's picture
Upload folder using huggingface_hub
c65e61c verified
import matplotlib.pyplot as plt
import numpy as np
import os
def plot_training_curves(train_acc, val_acc, train_loss, val_loss, model_name, save_dir='plots'):
"""
Plot and save training/validation curves.
Args:
train_acc: List of training accuracies
val_acc: List of validation accuracies
train_loss: List of training losses
val_loss: List of validation losses
model_name: Name of the model for plot titles and filenames
save_dir: Directory to save plots
"""
os.makedirs(save_dir, exist_ok=True)
epochs = range(1, len(train_acc) + 1)
# Set up the plotting style
plt.style.use('seaborn-v0_8')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
# Plot accuracy curves
ax1.plot(epochs, train_acc, 'b-', label='Training Accuracy', linewidth=2)
ax1.plot(epochs, val_acc, 'r-', label='Validation Accuracy', linewidth=2)
ax1.set_title(f'{model_name} - Accuracy Curves', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs', fontsize=12)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 100])
# Add best accuracy annotations
best_train_acc = max(train_acc)
best_val_acc = max(val_acc)
best_train_epoch = train_acc.index(best_train_acc) + 1
best_val_epoch = val_acc.index(best_val_acc) + 1
ax1.annotate(f'Best: {best_val_acc:.2f}%',
xy=(best_val_epoch, best_val_acc),
xytext=(10, 10), textcoords='offset points',
bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.7),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
# Plot loss curves
ax2.plot(epochs, train_loss, 'b-', label='Training Loss', linewidth=2)
ax2.plot(epochs, val_loss, 'r-', label='Validation Loss', linewidth=2)
ax2.set_title(f'{model_name} - Loss Curves', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
# Add minimum loss annotation
min_val_loss = min(val_loss)
min_loss_epoch = val_loss.index(min_val_loss) + 1
ax2.annotate(f'Min: {min_val_loss:.4f}',
xy=(min_loss_epoch, min_val_loss),
xytext=(10, 10), textcoords='offset points',
bbox=dict(boxstyle='round,pad=0.3', facecolor='red', alpha=0.7),
arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
plt.tight_layout()
# Save plot
save_path = os.path.join(save_dir, f'{model_name.lower()}_training_curves.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"πŸ“Š Training curves saved to: {save_path}")
return fig
def plot_model_comparison(models_history, save_dir='plots'):
"""
Compare training curves of multiple models.
Args:
models_history: Dict with model names as keys and history dicts as values
save_dir: Directory to save plots
"""
os.makedirs(save_dir, exist_ok=True)
plt.style.use('seaborn-v0_8')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
colors = ['blue', 'red', 'green', 'orange', 'purple']
for i, (model_name, history) in enumerate(models_history.items()):
epochs = range(1, len(history['val_acc']) + 1)
color = colors[i % len(colors)]
# Plot validation accuracy
ax1.plot(epochs, history['val_acc'], color=color,
label=f'{model_name}', linewidth=2, marker='o', markersize=3)
# Plot validation loss
ax2.plot(epochs, history['val_loss'], color=color,
label=f'{model_name}', linewidth=2, marker='o', markersize=3)
# Configure accuracy plot
ax1.set_title('Model Comparison - Validation Accuracy', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epochs', fontsize=12)
ax1.set_ylabel('Validation Accuracy (%)', fontsize=12)
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_ylim([0, 100])
# Configure loss plot
ax2.set_title('Model Comparison - Validation Loss', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epochs', fontsize=12)
ax2.set_ylabel('Validation Loss', fontsize=12)
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
plt.tight_layout()
# Save comparison plot
save_path = os.path.join(save_dir, 'model_comparison.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"πŸ“Š Model comparison saved to: {save_path}")
return fig
def plot_learning_rate_schedule(lr_history, model_name, save_dir='plots'):
"""
Plot learning rate schedule over epochs.
Args:
lr_history: List of learning rates per epoch
model_name: Name of the model
save_dir: Directory to save plots
"""
os.makedirs(save_dir, exist_ok=True)
epochs = range(1, len(lr_history) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, lr_history, 'g-', linewidth=2)
plt.title(f'{model_name} - Learning Rate Schedule', fontsize=14, fontweight='bold')
plt.xlabel('Epochs', fontsize=12)
plt.ylabel('Learning Rate', fontsize=12)
plt.grid(True, alpha=0.3)
plt.yscale('log')
# Save plot
save_path = os.path.join(save_dir, f'{model_name.lower()}_lr_schedule.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"πŸ“Š LR schedule saved to: {save_path}")
return plt.gcf()