""" Visualization Utilities Tools for visualizing model predictions, uncertainty, and interpretability. """ import torch import numpy as np import matplotlib.pyplot as plt import seaborn as sns from typing import Optional, List from pathlib import Path def plot_predictions_vs_targets( predictions: np.ndarray, targets: np.ndarray, uncertainties: Optional[np.ndarray] = None, save_path: Optional[str] = None, title: str = "Predictions vs Targets", ): """ Plot predicted vs actual values with optional uncertainty. Args: predictions: Predicted values targets: Target values uncertainties: Optional uncertainties (std) save_path: Path to save figure title: Plot title """ fig, ax = plt.subplots(figsize=(8, 8)) # Scatter plot if uncertainties is not None: scatter = ax.scatter(targets, predictions, c=uncertainties, cmap='viridis', alpha=0.6, s=20) plt.colorbar(scatter, ax=ax, label='Uncertainty (std)') else: ax.scatter(targets, predictions, alpha=0.6, s=20) # Perfect prediction line min_val = min(targets.min(), predictions.min()) max_val = max(targets.max(), predictions.max()) ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect prediction') # Labels and title ax.set_xlabel('True Values', fontsize=12) ax.set_ylabel('Predicted Values', fontsize=12) ax.set_title(title, fontsize=14) ax.legend() ax.grid(alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') else: plt.show() plt.close() def plot_uncertainty_calibration( predictions: np.ndarray, targets: np.ndarray, uncertainties: np.ndarray, num_bins: int = 10, save_path: Optional[str] = None, ): """ Plot uncertainty calibration curve. Args: predictions: Predicted values targets: Target values uncertainties: Predicted uncertainties num_bins: Number of bins for calibration save_path: Path to save figure """ errors = np.abs(predictions - targets) # Bin by uncertainty bin_edges = np.percentile(uncertainties, np.linspace(0, 100, num_bins + 1)) bin_centers = [] observed_errors = [] for i in range(num_bins): if i == num_bins - 1: mask = (uncertainties >= bin_edges[i]) & (uncertainties <= bin_edges[i + 1]) else: mask = (uncertainties >= bin_edges[i]) & (uncertainties < bin_edges[i + 1]) if mask.sum() > 0: bin_centers.append(uncertainties[mask].mean()) observed_errors.append(errors[mask].mean()) # Plot fig, ax = plt.subplots(figsize=(8, 6)) ax.scatter(bin_centers, observed_errors, s=100, alpha=0.7) ax.plot(bin_centers, bin_centers, 'r--', lw=2, label='Perfect calibration') ax.set_xlabel('Predicted Uncertainty', fontsize=12) ax.set_ylabel('Observed Error', fontsize=12) ax.set_title('Uncertainty Calibration', fontsize=14) ax.legend() ax.grid(alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') else: plt.show() plt.close() def plot_training_curves( train_losses: List[float], val_losses: List[float], save_path: Optional[str] = None, ): """ Plot training and validation loss curves. Args: train_losses: Training losses per epoch val_losses: Validation losses per epoch save_path: Path to save figure """ fig, ax = plt.subplots(figsize=(10, 6)) epochs = range(1, len(train_losses) + 1) ax.plot(epochs, train_losses, label='Train Loss', linewidth=2) ax.plot(epochs, val_losses, label='Val Loss', linewidth=2) ax.set_xlabel('Epoch', fontsize=12) ax.set_ylabel('Loss', fontsize=12) ax.set_title('Training Curves', fontsize=14) ax.legend() ax.grid(alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') else: plt.show() plt.close() def plot_error_distribution( predictions: np.ndarray, targets: np.ndarray, save_path: Optional[str] = None, ): """ Plot distribution of prediction errors. Args: predictions: Predicted values targets: Target values save_path: Path to save figure """ errors = predictions - targets fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) # Histogram ax1.hist(errors, bins=50, alpha=0.7, edgecolor='black') ax1.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero error') ax1.set_xlabel('Prediction Error', fontsize=12) ax1.set_ylabel('Frequency', fontsize=12) ax1.set_title('Error Distribution', fontsize=14) ax1.legend() ax1.grid(alpha=0.3) # Q-Q plot from scipy import stats stats.probplot(errors, dist="norm", plot=ax2) ax2.set_title('Q-Q Plot', fontsize=14) ax2.grid(alpha=0.3) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') else: plt.show() plt.close() def create_results_summary( results: dict, save_dir: str = "results/figures", ): """ Create comprehensive visualization summary. Args: results: Dictionary with predictions, targets, uncertainties save_dir: Directory to save figures """ save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) predictions = results['predictions'] targets = results['targets'] uncertainties = results.get('uncertainties') # 1. Predictions vs Targets plot_predictions_vs_targets( predictions, targets, uncertainties, save_path=save_dir / "predictions_vs_targets.png" ) # 2. Uncertainty Calibration if uncertainties is not None: plot_uncertainty_calibration( predictions, targets, uncertainties, save_path=save_dir / "uncertainty_calibration.png" ) # 3. Error Distribution plot_error_distribution( predictions, targets, save_path=save_dir / "error_distribution.png" ) print(f"Visualizations saved to {save_dir}")