|
|
""" |
|
|
Visualization utilities for model analysis |
|
|
""" |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import torch |
|
|
import seaborn as sns |
|
|
from pathlib import Path |
|
|
from sklearn.metrics import confusion_matrix |
|
|
|
|
|
def setup_plotting():
    """Configure the shared matplotlib/seaborn look for every plot.

    Activates the ``seaborn-v0_8-darkgrid`` style, picks seaborn's ``husl``
    colour palette, then overrides the default figure size and font sizes
    in ``plt.rcParams``. Mutates global plotting state and returns ``None``.
    """
    # Base style goes first: ``plt.style.use`` writes into rcParams, so
    # overrides applied before it could be overwritten.
    plt.style.use('seaborn-v0_8-darkgrid')
    sns.set_palette("husl")

    overrides = {
        'figure.figsize': (10, 6),
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
    }
    plt.rcParams.update(overrides)
|
|
|
|
|
def plot_training_history(metrics_file: str, save_path: str = None):
    """Plot per-epoch training and validation loss and accuracy.

    Reads a JSON file holding a list of per-epoch records and draws two
    side-by-side panels: loss on the left and accuracy on the right. Each
    panel shows a training curve (blue) and a validation curve (red).

    Args:
        metrics_file: Path to a JSON file containing a list of records shaped
            like ``{"epoch": ..., "train": {"loss": ..., "accuracy": ...},
            "validation": {"loss": ..., "accuracy": ...}}``.
        save_path: If given, the figure is also written to this path at
            150 dpi with a tight bounding box.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn. It is not closed, so
        the caller can show, edit, or close it.

    Raises:
        KeyError: If a record is missing one of the keys listed above.
    """
    import json

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(metrics_file, 'r', encoding='utf-8') as f:
        metrics = json.load(f)

    # Extract everything before creating the figure so a malformed record
    # raises without leaving an orphaned figure behind.
    epochs = [m['epoch'] for m in metrics]
    train_loss = [m['train']['loss'] for m in metrics]
    val_loss = [m['validation']['loss'] for m in metrics]
    train_acc = [m['train']['accuracy'] for m in metrics]
    val_acc = [m['validation']['accuracy'] for m in metrics]

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    _plot_train_val_curves(axes[0], epochs, train_loss, val_loss, 'Loss', 'Loss')
    _plot_train_val_curves(axes[1], epochs, train_acc, val_acc, 'Accuracy', 'Accuracy (%)')

    # Operate on this figure explicitly rather than pyplot's "current" one.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def _plot_train_val_curves(ax, epochs, train_values, val_values, name, ylabel):
    """Draw training (blue) vs. validation (red) curves of one metric on ``ax``."""
    ax.plot(epochs, train_values, 'b-', label=f'Training {name}', linewidth=2)
    ax.plot(epochs, val_values, 'r-', label=f'Validation {name}', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(f'Training and Validation {name}')
    ax.legend()
    ax.grid(True, alpha=0.3)
|
|
|
|
|
def plot_confusion_matrix(model, dataloader, device='cpu', save_path: str = None):
    """Run ``model`` over ``dataloader`` and plot the resulting confusion matrix.

    The prediction for each sample is ``argmax`` over dimension 1 of the
    model output. Rows are true labels and columns are predicted labels.
    Every cell is annotated with its count, and each tick is labelled with
    the actual class value observed in targets or predictions.

    Args:
        model: Callable model with ``eval()``. It is expected to return
            per-class scores along dim 1 (presumably a ``torch.nn.Module``).
            If it was in training mode, that mode is restored afterwards.
        dataloader: Iterable yielding ``(data, target)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.
        save_path: If given, the figure is also written to this path at
            150 dpi with a tight bounding box.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn (not closed).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Remember the caller's mode so that plotting mid-training doesn't leave
    # the model stuck in eval mode. ``getattr`` keeps duck-typed models
    # without a ``training`` flag working as before.
    was_training = getattr(model, 'training', False)
    model.eval()

    all_preds = []
    all_targets = []
    try:
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to(device), target.to(device)
                pred = model(data).argmax(dim=1)
                all_preds.extend(pred.cpu().numpy())
                all_targets.extend(target.cpu().numpy())
    finally:
        if was_training:
            model.train()

    if not all_targets:
        # Otherwise sklearn yields a 0x0 matrix and cm.max() fails obscurely.
        raise ValueError("dataloader yielded no samples; cannot build a confusion matrix")

    # Pin the label set explicitly (same sorted union sklearn would infer) so
    # the tick labels below match the matrix rows/columns exactly.
    labels = np.union1d(all_targets, all_preds)
    cm = confusion_matrix(all_targets, all_preds, labels=labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)

    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')

    # Light text on dark (high-count) cells, dark text on light ones.
    thresh = cm.max() / 2.
    for i, j in np.ndindex(cm.shape):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|
|
|
|
def visualize_attacks(original, adversarial, predictions, save_path: str = None):
    """Show clean inputs above their adversarial counterparts.

    Draws a 2 x N grid, where N = ``min(10, len(original))``. The top row
    shows ``original[i]`` titled with ``predictions['original'][i]``. The
    bottom row shows ``adversarial[i]`` titled with
    ``predictions['adversarial'][i]``. Images are squeezed of singleton
    dimensions and drawn with a grayscale colormap.

    Args:
        original: Indexable sequence of images (torch tensors or array-likes).
        adversarial: Indexable sequence of perturbed images aligned with
            ``original``; must have at least N entries.
        predictions: Mapping with ``'original'`` and ``'adversarial'`` keys,
            each an indexable sequence of predicted labels.
        save_path: If given, the figure is also written to this path at
            150 dpi with a tight bounding box.

    Returns:
        The ``matplotlib.figure.Figure`` that was drawn (not closed).

    Raises:
        ValueError: If ``original`` is empty.
    """
    n_samples = min(10, len(original))
    if n_samples == 0:
        raise ValueError("visualize_attacks needs at least one example to plot")

    # squeeze=False keeps ``axes`` 2-D even when n_samples == 1; by default
    # matplotlib would return a 1-D array and ``axes[row, i]`` would fail.
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 2, 4), squeeze=False)

    rows = (
        (original, 'original', 'Orig'),
        (adversarial, 'adversarial', 'Adv'),
    )
    for row, (images, key, prefix) in enumerate(rows):
        for i in range(n_samples):
            ax = axes[row, i]
            ax.imshow(_as_displayable_image(images[i]), cmap='gray')
            ax.set_title(f"{prefix}: {predictions[key][i]}")
            ax.axis('off')

    fig.suptitle('Original vs Adversarial Examples')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def _as_displayable_image(img):
    """Return ``img`` as a squeezed NumPy array that ``imshow`` can draw."""
    if isinstance(img, torch.Tensor):
        # Tensors may still require grad (e.g. output of a gradient-based
        # attack) or live on an accelerator; NumPy conversion needs a
        # detached CPU tensor.
        img = img.detach().cpu()
    return np.asarray(img).squeeze()
|
|
|