Spaces:
Sleeping
Sleeping
| """ | |
| Visualization Utilities | |
| Tools for visualizing model predictions, uncertainty, and interpretability. | |
| """ | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from typing import Optional, List | |
| from pathlib import Path | |
def plot_predictions_vs_targets(
    predictions: np.ndarray,
    targets: np.ndarray,
    uncertainties: Optional[np.ndarray] = None,
    save_path: Optional[str] = None,
    title: str = "Predictions vs Targets",
):
    """
    Scatter predicted values against true values with a y = x reference line.

    When ``uncertainties`` is given, each point is coloured by its uncertainty
    (viridis colormap) and a colorbar labelled 'Uncertainty (std)' is added.

    Args:
        predictions: Predicted values (array-like; flattened to 1-D).
        targets: Target values (array-like; flattened to 1-D).
        uncertainties: Optional per-sample uncertainties (std), used as the
            point colour. Flattened to 1-D.
        save_path: If truthy, the figure is written there at 300 dpi;
            otherwise it is displayed with ``plt.show()``.
        title: Plot title.
    """
    # Flatten so e.g. (n, 1) model outputs and (n,) targets pair up per sample.
    preds = np.asarray(predictions, dtype=float).ravel()
    trues = np.asarray(targets, dtype=float).ravel()
    sigma = None if uncertainties is None else np.asarray(uncertainties, dtype=float).ravel()

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Scatter plot
        if sigma is not None:
            points = ax.scatter(trues, preds, c=sigma, cmap='viridis', alpha=0.6, s=20)
            fig.colorbar(points, ax=ax, label='Uncertainty (std)')
        else:
            ax.scatter(trues, preds, alpha=0.6, s=20)

        # Perfect prediction line. nanmin/nanmax so that a single NaN value
        # does not turn the endpoints into NaN and erase the line.
        min_val = min(np.nanmin(trues), np.nanmin(preds))
        max_val = max(np.nanmax(trues), np.nanmax(preds))
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect prediction')

        # Labels and title
        ax.set_xlabel('True Values', fontsize=12)
        ax.set_ylabel('Predicted Values', fontsize=12)
        ax.set_title(title, fontsize=14)
        ax.legend()
        ax.grid(alpha=0.3)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raised, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
def plot_uncertainty_calibration(
    predictions: np.ndarray,
    targets: np.ndarray,
    uncertainties: np.ndarray,
    num_bins: int = 10,
    save_path: Optional[str] = None,
):
    """
    Plot a binned uncertainty-calibration curve.

    Samples are grouped into ``num_bins`` roughly equal-count bins using
    percentile edges of the predicted uncertainty. For each non-empty bin the
    root mean predicted variance, sqrt(mean(sigma**2)), is plotted against the
    observed root mean squared error, sqrt(mean((prediction - target)**2)).
    For a variance-calibrated model the two agree, so points lie on y = x.

    RMSE (not mean absolute error) is compared with the predicted standard
    deviation because only then is y = x the calibrated reference: for
    Gaussian errors the mean absolute error is about sqrt(2/pi) ~= 0.8 sigma,
    so an MAE-vs-std plot sits below y = x even for perfect uncertainties.

    Args:
        predictions: Predicted values (array-like; flattened to 1-D).
        targets: Target values (array-like; flattened to 1-D).
        uncertainties: Predicted uncertainties, treated as standard
            deviations (not variances). Flattened to 1-D.
        num_bins: Number of percentile bins; must be >= 1.
        save_path: If truthy, the figure is written there at 300 dpi;
            otherwise it is displayed with ``plt.show()``.

    Raises:
        ValueError: If ``num_bins`` < 1, the inputs have different numbers of
            elements, or no sample has finite prediction, target and
            uncertainty.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    # Flatten so (n, 1) and (n,) inputs pair up per sample instead of
    # broadcasting to an (n, n) matrix.
    preds = np.asarray(predictions, dtype=float).ravel()
    trues = np.asarray(targets, dtype=float).ravel()
    sigma = np.asarray(uncertainties, dtype=float).ravel()
    if not (preds.size == trues.size == sigma.size):
        raise ValueError(
            "predictions, targets and uncertainties must have the same number "
            f"of elements, got {preds.size}, {trues.size} and {sigma.size}"
        )

    sq_errors = (preds - trues) ** 2
    # Drop NaN/inf samples: they would otherwise poison the percentile edges.
    valid = np.isfinite(sq_errors) & np.isfinite(sigma)
    sq_errors, sigma = sq_errors[valid], sigma[valid]
    if sigma.size == 0:
        raise ValueError("no samples with finite predictions, targets and uncertainties")

    # Bin by uncertainty percentiles. Bins are [edge_i, edge_{i+1}), except
    # the last, which is closed on the right so the maximum is included.
    # Counting interior edges <= sigma reproduces exactly that assignment,
    # including when repeated values make some edges coincide.
    bin_edges = np.percentile(sigma, np.linspace(0, 100, num_bins + 1))
    bin_idx = np.searchsorted(bin_edges[1:-1], sigma, side='right')

    counts = np.bincount(bin_idx, minlength=num_bins)
    nonempty = counts > 0
    mean_var = np.bincount(bin_idx, weights=sigma ** 2, minlength=num_bins)[nonempty] / counts[nonempty]
    mean_sq_err = np.bincount(bin_idx, weights=sq_errors, minlength=num_bins)[nonempty] / counts[nonempty]
    predicted_rmv = np.sqrt(mean_var)
    observed_rmse = np.sqrt(mean_sq_err)

    # Plot
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.scatter(predicted_rmv, observed_rmse, s=100, alpha=0.7)
        # Span y = x over the whole data range so the reference line stays
        # visible even when only one bin is non-empty.
        lo = min(predicted_rmv.min(), observed_rmse.min())
        hi = max(predicted_rmv.max(), observed_rmse.max())
        ax.plot([lo, hi], [lo, hi], 'r--', lw=2, label='Perfect calibration')
        ax.set_xlabel('Predicted Uncertainty (RMS std)', fontsize=12)
        ax.set_ylabel('Observed Error (RMSE)', fontsize=12)
        ax.set_title('Uncertainty Calibration', fontsize=14)
        ax.legend()
        ax.grid(alpha=0.3)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
def plot_training_curves(
    train_losses: List[float],
    val_losses: List[float],
    save_path: Optional[str] = None,
    val_epochs: Optional[List[int]] = None,
):
    """
    Plot training and validation loss curves against epoch number.

    Args:
        train_losses: Training loss per epoch; plotted at epochs
            1..len(train_losses).
        val_losses: Validation losses. When empty, only the training curve is
            drawn.
        save_path: If truthy, the figure is written there at 300 dpi;
            otherwise it is displayed with ``plt.show()``.
        val_epochs: Optional epoch numbers at which each validation loss was
            recorded (e.g. when validating every k epochs). Defaults to
            1..len(val_losses), which matches ``train_losses`` when
            validation runs every epoch.

    Raises:
        ValueError: If ``val_epochs`` is given and its length differs from
            ``val_losses``.
    """
    # Each curve gets its own x positions, so validation histories that are
    # shorter than the training history no longer cause a shape mismatch.
    train_x = range(1, len(train_losses) + 1)
    if val_epochs is None:
        val_x = list(range(1, len(val_losses) + 1))
    else:
        val_x = list(val_epochs)
        if len(val_x) != len(val_losses):
            raise ValueError(
                f"val_epochs has {len(val_x)} entries but val_losses has {len(val_losses)}"
            )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(train_x, train_losses, label='Train Loss', linewidth=2)
        # len() rather than truthiness so numpy arrays are accepted too.
        if len(val_losses) > 0:
            ax.plot(val_x, val_losses, label='Val Loss', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title('Training Curves', fontsize=14)
        ax.legend()
        ax.grid(alpha=0.3)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
def plot_error_distribution(
    predictions: np.ndarray,
    targets: np.ndarray,
    save_path: Optional[str] = None,
):
    """
    Plot the distribution of signed prediction errors (prediction - target).

    The left panel is a 50-bin histogram with a dashed zero-error marker. The
    right panel is a normal Q-Q plot from ``scipy.stats.probplot``.

    Args:
        predictions: Predicted values (array-like; flattened to 1-D).
        targets: Target values (array-like; flattened to 1-D).
        save_path: If truthy, the figure is written there at 300 dpi;
            otherwise it is displayed with ``plt.show()``.

    Raises:
        ValueError: If predictions and targets have different numbers of
            elements.
    """
    # Local import: scipy is only required when this function is called.
    # Done before creating the figure so a missing scipy fails fast.
    from scipy import stats

    preds = np.asarray(predictions, dtype=float).ravel()
    trues = np.asarray(targets, dtype=float).ravel()
    if preds.size != trues.size:
        raise ValueError(
            "predictions and targets must have the same number of elements, "
            f"got {preds.size} and {trues.size}"
        )
    # Flattening first matters: (n, 1) - (n,) would silently broadcast to an
    # (n, n) matrix of pairwise differences.
    errors = preds - trues

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Histogram
        ax1.hist(errors, bins=50, alpha=0.7, edgecolor='black')
        ax1.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero error')
        ax1.set_xlabel('Prediction Error', fontsize=12)
        ax1.set_ylabel('Frequency', fontsize=12)
        ax1.set_title('Error Distribution', fontsize=14)
        ax1.legend()
        ax1.grid(alpha=0.3)

        # Q-Q plot against a normal distribution
        stats.probplot(errors, dist="norm", plot=ax2)
        ax2.set_title('Q-Q Plot', fontsize=14)
        ax2.grid(alpha=0.3)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)
def create_results_summary(
    results: dict,
    save_dir: str = "results/figures",
):
    """
    Render the standard evaluation figures into a directory.

    The directory (and any missing parents) is created first. Files written:
      - predictions_vs_targets.png (always)
      - uncertainty_calibration.png (only if results.get('uncertainties')
        is not None)
      - error_distribution.png (always)
    A confirmation line naming the directory is printed at the end.

    Args:
        results: Mapping with required keys 'predictions' and 'targets' and
            an optional 'uncertainties' entry.
        save_dir: Output directory for the PNG files.
    """
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    preds, truth = results['predictions'], results['targets']
    sigma = results.get('uncertainties')

    plot_predictions_vs_targets(
        preds, truth, sigma,
        save_path=out_dir / "predictions_vs_targets.png",
    )

    # The calibration curve only makes sense when uncertainties exist.
    if sigma is not None:
        plot_uncertainty_calibration(
            preds, truth, sigma,
            save_path=out_dir / "uncertainty_calibration.png",
        )

    plot_error_distribution(
        preds, truth,
        save_path=out_dir / "error_distribution.png",
    )

    print(f"Visualizations saved to {out_dir}")