File size: 4,061 Bytes
f4bee9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""
Visualization utilities for model analysis
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns
from pathlib import Path
from sklearn.metrics import confusion_matrix
def setup_plotting():
    """Apply the shared look-and-feel used by the plots in this module.

    Activates matplotlib's ``seaborn-v0_8-darkgrid`` style sheet, selects
    seaborn's ``husl`` palette, and then overrides a handful of rcParams
    (default figure size and font sizes). The settings are written to
    matplotlib's global ``rcParams``, so they affect figures created
    afterwards.
    """
    plt.style.use('seaborn-v0_8-darkgrid')
    sns.set_palette("husl")

    # Figure defaults, applied after the style sheet so they take precedence.
    figure_defaults = {
        'figure.figsize': (10, 6),
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
    }
    plt.rcParams.update(figure_defaults)
def plot_training_history(metrics_file: str, save_path: str = None):
    """Plot training and validation loss/accuracy curves from a metrics file.

    Args:
        metrics_file: Path to a JSON file containing a list of per-epoch
            records. Each record ``m`` is read as ``m['epoch']``,
            ``m['train']['loss']``, ``m['train']['accuracy']``,
            ``m['validation']['loss']`` and ``m['validation']['accuracy']``.
        save_path: Optional output path. When truthy, the figure is also
            written there (150 dpi, tight bounding box).

    Returns:
        The matplotlib ``Figure``: loss panel on the left, accuracy panel on
        the right.

    Raises:
        OSError: If ``metrics_file`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks one of the keys listed above.
    """
    import json

    # JSON text is UTF-8; do not let the platform locale pick the codec.
    with open(metrics_file, 'r', encoding='utf-8') as f:
        metrics = json.load(f)

    # Extract every series before creating the figure so a malformed record
    # fails fast without leaving an orphan figure behind.
    epochs = [m['epoch'] for m in metrics]
    series = {
        key: (
            [m['train'][key] for m in metrics],
            [m['validation'][key] for m in metrics],
        )
        for key in ('loss', 'accuracy')
    }

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # (metric key, legend/title name, y-axis label) for each panel.
    panels = (
        ('loss', 'Loss', 'Loss'),
        ('accuracy', 'Accuracy', 'Accuracy (%)'),
    )
    for ax, (key, name, ylabel) in zip(axes, panels):
        train_values, val_values = series[key]
        ax.plot(epochs, train_values, 'b-',
                label=f'Training {name}', linewidth=2)
        ax.plot(epochs, val_values, 'r-',
                label=f'Validation {name}', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Training and Validation {name}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
def plot_confusion_matrix(model, dataloader, device='cpu', save_path: str = None):
    """Evaluate ``model`` over ``dataloader`` and plot the confusion matrix.

    Args:
        model: Callable model exposing ``eval()``; its output is reduced with
            ``argmax(dim=1)`` to obtain predicted class indices (presumably
            a ``(batch, classes)`` score tensor -- confirm against caller).
        dataloader: Iterable yielding ``(data, target)`` tensor pairs.
        device: Device the inputs are moved to before the forward pass.
        save_path: Optional output path. When truthy, the figure is also
            written there (150 dpi, tight bounding box).

    Returns:
        The matplotlib ``Figure`` containing the annotated confusion matrix.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Remember the caller's mode so that evaluating here does not silently
    # leave a model that is mid-training stuck in eval mode.
    was_training = getattr(model, 'training', False)
    model.eval()

    all_preds = []
    all_targets = []
    try:
        with torch.no_grad():
            for data, target in dataloader:
                output = model(data.to(device))
                pred = output.argmax(dim=1)
                all_preds.extend(pred.cpu().numpy())
                # Targets are only needed on the host; no device round-trip.
                all_targets.extend(target.cpu().numpy())
    finally:
        if was_training:
            model.train()

    if not all_targets:
        raise ValueError(
            "dataloader yielded no samples; cannot build a confusion matrix"
        )

    # Fix the label order explicitly so the axis ticks can be labelled with
    # the classes that the matrix rows/columns actually correspond to.
    labels = np.unique(np.concatenate([all_targets, all_preds]))
    cm = confusion_matrix(all_targets, all_preds, labels=labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)

    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')

    # Annotate each cell; use white text on dark cells for legibility.
    thresh = cm.max() / 2.
    for (i, j), count in np.ndenumerate(cm):
        ax.text(j, i, format(count, 'd'),
                ha="center", va="center",
                color="white" if count > thresh else "black")

    if save_path:
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
def visualize_attacks(original, adversarial, predictions, save_path: str = None):
    """Show original images above their adversarial counterparts.

    At most the first 10 pairs are drawn: originals in the top row,
    adversarial examples in the bottom row, each titled with its prediction.

    Args:
        original: Indexable collection of images; each item must support
            ``.squeeze()`` and be displayable by ``imshow``.
        adversarial: Adversarial images, indexed in parallel with
            ``original``.
        predictions: Mapping with ``'original'`` and ``'adversarial'``
            entries, each indexable by sample position.
        save_path: Optional output path. When truthy, the figure is also
            written there (150 dpi, tight bounding box).

    Returns:
        The matplotlib ``Figure`` with a 2 x n grid of images.

    Raises:
        ValueError: If there is no original/adversarial pair to draw.
        KeyError: If ``predictions`` lacks one of the two expected keys.
    """
    # Never index past the shorter of the two image collections.
    n_samples = min(10, len(original), len(adversarial))
    if n_samples == 0:
        raise ValueError(
            "visualize_attacks needs at least one original/adversarial pair"
        )

    # (images, title prefix, per-sample predictions) for each grid row.
    rows = (
        (original, 'Orig', predictions['original']),
        (adversarial, 'Adv', predictions['adversarial']),
    )

    # squeeze=False keeps ``axes`` two-dimensional even when n_samples == 1;
    # otherwise a single column collapses to a 1-D array and 2-D indexing
    # fails.
    fig, axes = plt.subplots(
        2, n_samples, figsize=(n_samples * 2, 4), squeeze=False
    )

    for row_axes, (images, prefix, row_preds) in zip(axes, rows):
        for i in range(n_samples):
            ax = row_axes[i]
            ax.imshow(images[i].squeeze(), cmap='gray')
            ax.set_title(f"{prefix}: {row_preds[i]}")
            ax.axis('off')

    fig.suptitle('Original vs Adversarial Examples')
    fig.tight_layout()

    if save_path:
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|