# Ariyan-Pro -- Enterprise Adversarial ML Governance Engine v5.0 LTS (f4bee9e)
"""
Visualization utilities for model analysis
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
import seaborn as sns
from pathlib import Path
from sklearn.metrics import confusion_matrix
def setup_plotting():
    """Configure the global matplotlib/seaborn look used by the plots.

    Applies the seaborn "darkgrid" matplotlib style, selects seaborn's
    "husl" palette and sets default figure size and font sizes through
    ``plt.rcParams``. These are global settings, not per-figure ones.

    Matplotlib 3.6 renamed its bundled seaborn styles to ``seaborn-v0_8-*``;
    earlier releases only know the unprefixed name. The first candidate that
    is registered in ``plt.style.available`` is used.
    """
    style_candidates = ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid')
    # When no candidate is registered, fall back to the first name so that
    # matplotlib raises the same error it raised before this lookup existed.
    style_name = next(
        (name for name in style_candidates if name in plt.style.available),
        style_candidates[0],
    )
    plt.style.use(style_name)
    sns.set_palette("husl")

    # Set figure defaults
    plt.rcParams.update({
        'figure.figsize': (10, 6),
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
    })
def plot_training_history(metrics_file: str, save_path: str = None):
    """Plot training and validation loss/accuracy curves side by side.

    Args:
        metrics_file: Path to a JSON file containing a list of per-epoch
            records. Each record is read as ``record['epoch']`` and
            ``record[split][metric]`` for ``split`` in ``'train'`` /
            ``'validation'`` and ``metric`` in ``'loss'`` / ``'accuracy'``.
        save_path: Optional output path; when truthy the figure is saved
            there at 150 dpi with a tight bounding box.

    Returns:
        The matplotlib figure holding the loss panel (left) and the
        accuracy panel (right).

    Raises:
        KeyError: If a record lacks one of the fields listed above.
    """
    import json

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(metrics_file, 'r', encoding='utf-8') as f:
        metrics = json.load(f)

    # (metric key, name used in legend/title, y-axis label)
    panels = (
        ('loss', 'Loss', 'Loss'),
        ('accuracy', 'Accuracy', 'Accuracy (%)'),
    )

    # Extract every series before creating the figure so that malformed
    # input fails without leaving a half-built figure behind.
    epochs = [record['epoch'] for record in metrics]
    series = {
        key: (
            [record['train'][key] for record in metrics],
            [record['validation'][key] for record in metrics],
        )
        for key, _, _ in panels
    }

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    for ax, (key, name, ylabel) in zip(axes, panels):
        train_values, val_values = series[key]
        ax.plot(epochs, train_values, 'b-', label=f'Training {name}', linewidth=2)
        ax.plot(epochs, val_values, 'r-', label=f'Validation {name}', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(f'Training and Validation {name}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Operate on this figure explicitly rather than on pyplot's implicit
    # "current figure".
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return fig
def plot_confusion_matrix(model, dataloader, device='cpu', save_path: str = None):
    """Plot the confusion matrix of ``model``'s predictions over ``dataloader``.

    Args:
        model: Callable exposing ``eval()``. Its output for a batch is
            reduced with ``argmax(dim=1)`` to get one predicted class per
            sample.
        dataloader: Iterable yielding ``(data, target)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.
        save_path: Optional output path; when truthy the figure is saved
            there at 150 dpi with a tight bounding box.

    Returns:
        The matplotlib figure with the annotated confusion matrix.

    Raises:
        ValueError: If ``dataloader`` yields no samples.

    Note:
        ``model.eval()`` is called and the model is left in that mode.
    """
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    y_true = np.asarray(all_targets).ravel()
    y_pred = np.asarray(all_preds).ravel()
    if y_true.size == 0:
        raise ValueError("dataloader yielded no samples; nothing to plot")

    # Fix the class order explicitly and reuse it for the tick labels, so a
    # row/column position always shows the class id it stands for (positions
    # and ids differ whenever some class id is missing from the data).
    class_ids = np.unique(np.concatenate((y_true, y_pred)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    # Plot
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)

    # Labels
    tick_positions = np.arange(len(class_ids))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(class_ids)
    ax.set_yticklabels(class_ids)
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')

    # Annotate each cell; switch to white text on dark (high-count) cells.
    thresh = cm.max() / 2.
    for i, j in np.ndindex(*cm.shape):
        ax.text(j, i, format(cm[i, j], 'd'),
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return fig
def visualize_attacks(original, adversarial, predictions, save_path: str = None):
    """Show original samples (top row) above their adversarial versions.

    Args:
        original: Indexable, sized collection of images; at most the first
            10 are displayed.
        adversarial: Adversarial counterparts, indexed like ``original``.
        predictions: Mapping with ``'original'`` and ``'adversarial'``
            entries, each indexable by sample position; the values are shown
            in the subplot titles.
        save_path: Optional output path; when truthy the figure is saved
            there at 150 dpi with a tight bounding box.

    Returns:
        The matplotlib figure with a 2 x n grid of grayscale images.

    Raises:
        ValueError: If ``original`` is empty.
    """
    n_samples = min(10, len(original))
    if n_samples == 0:
        raise ValueError("visualize_attacks requires at least one sample")

    # squeeze=False keeps `axes` two-dimensional even for a single sample;
    # the default collapses it to 1-D and breaks the [row, col] indexing.
    fig, axes = plt.subplots(
        2, n_samples, figsize=(n_samples * 2, 4), squeeze=False
    )

    # (image collection, key into `predictions`, title prefix) per row
    rows = (
        (original, 'original', 'Orig'),
        (adversarial, 'adversarial', 'Adv'),
    )
    for row, (images, key, prefix) in enumerate(rows):
        for col in range(n_samples):
            ax = axes[row, col]
            ax.imshow(_to_displayable(images[col]), cmap='gray')
            ax.set_title(f"{prefix}: {predictions[key][col]}")
            ax.axis('off')

    fig.suptitle('Original vs Adversarial Examples')
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return fig


def _to_displayable(image):
    """Return ``image`` squeezed; torch tensors are first moved to CPU numpy."""
    if isinstance(image, torch.Tensor):
        # detach() drops autograd tracking, which blocks conversion to numpy.
        image = image.detach().cpu().numpy()
    return image.squeeze()