""" Reliability analysis and visualization for calibration. Generates: - Reliability diagrams - Confidence histograms - Confidence vs accuracy plots - Calibration summary statistics """ from pathlib import Path from typing import Dict, Tuple, Optional import json import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as mpatches def generate_reliability_diagram( probs: np.ndarray, labels: np.ndarray, output_path: Optional[Path] = None, title: str = "Calibration Reliability Diagram", n_bins: int = 10, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Generate reliability diagram (calibration curve). Shows relationship between predicted confidence and actual accuracy. Perfect calibration is a diagonal line. Args: probs: [N, C] probability matrix labels: [N] ground truth labels output_path: Where to save the plot title: Plot title n_bins: Number of confidence bins Returns: Tuple of (bin_confidences, bin_accuracies, bin_sizes) """ # Get predictions and confidence predictions = np.argmax(probs, axis=1) confidence = np.max(probs, axis=1) accuracy = (predictions == labels).astype(float) # Bin by confidence bin_boundaries = np.linspace(0, 1 + 1e-6, n_bins + 1) bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2 bin_accs = [] bin_confs = [] bin_sizes = [] for i in range(n_bins): mask = (confidence >= bin_boundaries[i]) & (confidence < bin_boundaries[i + 1]) if i == n_bins - 1: mask = (confidence >= bin_boundaries[i]) & (confidence <= 1.0 + 1e-6) if np.sum(mask) > 0: bin_acc = accuracy[mask].mean() bin_conf = confidence[mask].mean() bin_size = np.sum(mask) bin_accs.append(bin_acc) bin_confs.append(bin_conf) bin_sizes.append(bin_size) bin_accs = np.array(bin_accs) bin_confs = np.array(bin_confs) bin_sizes = np.array(bin_sizes) # Create plot fig, ax = plt.subplots(figsize=(8, 8)) # Perfect calibration line ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Perfect calibration') # Calibration curve ax.plot(bin_confs, bin_accs, 'b-o', linewidth=2, markersize=8, label='Model calibration') # Fill area between curve and diagonal ax.fill_between(bin_confs, bin_accs, bin_confs, alpha=0.3) ax.set_xlabel('Predicted Confidence', fontsize=12) ax.set_ylabel('Accuracy', fontsize=12) ax.set_title(title, fontsize=14, fontweight='bold') ax.set_xlim([0, 1]) ax.set_ylim([0, 1]) ax.grid(alpha=0.3) ax.legend() if output_path: fig.savefig(output_path, dpi=150, bbox_inches='tight') plt.close(fig) return bin_confs, bin_accs, bin_sizes def generate_confidence_histogram( probs: np.ndarray, labels: np.ndarray, output_path: Optional[Path] = None, title: str = "Confidence Distribution", ) -> Dict[str, np.ndarray]: """ Generate confidence histogram by correctness. Args: probs: [N, C] probability matrix labels: [N] ground truth labels output_path: Where to save the plot title: Plot title Returns: Dict with correct and incorrect confidences """ predictions = np.argmax(probs, axis=1) confidence = np.max(probs, axis=1) correct = (predictions == labels).astype(bool) fig, ax = plt.subplots(figsize=(10, 6)) ax.hist(confidence[correct], bins=30, alpha=0.6, label='Correct', color='green', edgecolor='black') ax.hist(confidence[~correct], bins=30, alpha=0.6, label='Incorrect', color='red', edgecolor='black') ax.set_xlabel('Predicted Confidence', fontsize=12) ax.set_ylabel('Frequency', fontsize=12) ax.set_title(title, fontsize=14, fontweight='bold') ax.legend() ax.grid(alpha=0.3) if output_path: fig.savefig(output_path, dpi=150, bbox_inches='tight') plt.close(fig) return { 'correct_confidences': confidence[correct], 'incorrect_confidences': confidence[~correct], } def generate_confidence_vs_accuracy( probs: np.ndarray, labels: np.ndarray, output_path: Optional[Path] = None, title: str = "Confidence vs Accuracy", ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Generate scatter plot of confidence vs accuracy per-sample. Args: probs: [N, C] probability matrix labels: [N] ground truth labels output_path: Where to save the plot title: Plot title Returns: Tuple of (confidence, accuracy, per-sample metrics) """ predictions = np.argmax(probs, axis=1) confidence = np.max(probs, axis=1) accuracy = (predictions == labels).astype(float) fig, ax = plt.subplots(figsize=(10, 8)) # Color by accuracy colors = np.where(accuracy == 1, 'green', 'red') sizes = np.where(accuracy == 1, 50, 100) ax.scatter(confidence, accuracy + np.random.normal(0, 0.02, len(accuracy)), c=colors, s=sizes, alpha=0.5, edgecolors='black', linewidth=0.5) # Trend line using bins bins = np.linspace(0, 1, 11) bin_centers = (bins[:-1] + bins[1:]) / 2 bin_accs = [] for i in range(len(bins) - 1): mask = (confidence >= bins[i]) & (confidence < bins[i + 1]) if np.sum(mask) > 0: bin_accs.append(accuracy[mask].mean()) else: bin_accs.append(0) ax.plot(bin_centers, bin_accs, 'b-', linewidth=2, label='Binned accuracy') ax.set_xlabel('Predicted Confidence', fontsize=12) ax.set_ylabel('Correctness (jittered)', fontsize=12) ax.set_title(title, fontsize=14, fontweight='bold') ax.legend() ax.grid(alpha=0.3) green_patch = mpatches.Patch(color='green', label='Correct') red_patch = mpatches.Patch(color='red', label='Incorrect') ax.legend(handles=[green_patch, red_patch, ax.lines[0]], fontsize=10) if output_path: fig.savefig(output_path, dpi=150, bbox_inches='tight') plt.close(fig) return confidence, accuracy, accuracy def generate_calibration_summary( probs_before: np.ndarray, probs_after: np.ndarray, labels: np.ndarray, output_path: Optional[Path] = None, ) -> Dict: """ Generate side-by-side calibration comparison. Args: probs_before: Pre-calibration probabilities [N, C] probs_after: Post-calibration probabilities [N, C] labels: Ground truth labels [N] output_path: Where to save the plot Returns: Dict with comparison metrics """ from .calibration_metrics import compute_ece, compute_brier_score, compute_log_loss # Compute metrics ece_before, _, _, _ = compute_ece(probs_before, labels) ece_after, _, _, _ = compute_ece(probs_after, labels) brier_before = compute_brier_score(probs_before, labels) brier_after = compute_brier_score(probs_after, labels) ll_before = compute_log_loss(probs_before, labels) ll_after = compute_log_loss(probs_after, labels) predictions_before = np.argmax(probs_before, axis=1) predictions_after = np.argmax(probs_after, axis=1) acc_before = np.mean(predictions_before == labels) acc_after = np.mean(predictions_after == labels) # Create comparison plot fig, axes = plt.subplots(1, 3, figsize=(15, 4)) # ECE comparison ax = axes[0] metrics = ['ECE'] values_before = [ece_before] values_after = [ece_after] x = np.arange(len(metrics)) width = 0.35 ax.bar(x - width/2, values_before, width, label='Before', alpha=0.8) ax.bar(x + width/2, values_after, width, label='After', alpha=0.8) ax.set_ylabel('ECE') ax.set_title('Expected Calibration Error') ax.legend() ax.set_xticks(x) ax.set_xticklabels(metrics) # Brier score comparison ax = axes[1] metrics = ['Brier'] values_before = [brier_before] values_after = [brier_after] x = np.arange(len(metrics)) ax.bar(x - width/2, values_before, width, label='Before', alpha=0.8) ax.bar(x + width/2, values_after, width, label='After', alpha=0.8) ax.set_ylabel('Brier Score') ax.set_title('Brier Score (lower is better)') ax.legend() ax.set_xticks(x) ax.set_xticklabels(metrics) # Log loss comparison ax = axes[2] metrics = ['Log Loss'] values_before = [ll_before] values_after = [ll_after] x = np.arange(len(metrics)) ax.bar(x - width/2, values_before, width, label='Before', alpha=0.8) ax.bar(x + width/2, values_after, width, label='After', alpha=0.8) ax.set_ylabel('Log Loss') ax.set_title('Log Loss (lower is better)') ax.legend() ax.set_xticks(x) ax.set_xticklabels(metrics) fig.suptitle('Calibration Improvement', fontsize=14, fontweight='bold', y=1.02) if output_path: fig.savefig(output_path, dpi=150, bbox_inches='tight') plt.close(fig) return { 'ece_before': ece_before, 'ece_after': ece_after, 'ece_improvement': ece_before - ece_after, 'brier_before': brier_before, 'brier_after': brier_after, 'brier_improvement': brier_before - brier_after, 'log_loss_before': ll_before, 'log_loss_after': ll_after, 'log_loss_improvement': ll_before - ll_after, 'accuracy_before': acc_before, 'accuracy_after': acc_after, }