Spaces:
Running
Running
| """ | |
| Reliability analysis and visualization for calibration. | |
| Generates: | |
| - Reliability diagrams | |
| - Confidence histograms | |
| - Confidence vs accuracy plots | |
| - Calibration summary statistics | |
| """ | |
| from pathlib import Path | |
| from typing import Dict, Tuple, Optional | |
| import json | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
def _bin_by_confidence(
    confidence: np.ndarray,
    accuracy: np.ndarray,
    n_bins: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Average confidence and correctness inside equal-width confidence bins.

    Bins that contain no samples are skipped, so the returned arrays can be
    shorter than n_bins (and are empty when there are no samples at all).

    Args:
        confidence: [N] top-class probability per sample
        accuracy: [N] per-sample correctness (1.0 correct, 0.0 incorrect)
        n_bins: Number of confidence bins

    Returns:
        Tuple of (bin_confidences, bin_accuracies, bin_sizes)
    """
    # The upper edge is nudged slightly past 1.0 so that a confidence of
    # exactly 1.0 (or one that overshoots by float rounding) is still binned.
    upper_limit = 1 + 1e-6
    edges = np.linspace(0, upper_limit, n_bins + 1)

    mean_confs = []
    mean_accs = []
    counts = []
    for i in range(n_bins):
        in_bin = confidence >= edges[i]
        if i == n_bins - 1:
            # Only the last bin is closed on the right.
            in_bin &= confidence <= upper_limit
        else:
            in_bin &= confidence < edges[i + 1]

        count = int(np.count_nonzero(in_bin))
        if count == 0:
            continue

        mean_accs.append(accuracy[in_bin].mean())
        mean_confs.append(confidence[in_bin].mean())
        counts.append(count)

    # Explicit integer dtype keeps bin_sizes integral even when no bin is
    # populated (np.array([]) would otherwise default to float64).
    return (
        np.array(mean_confs),
        np.array(mean_accs),
        np.array(counts, dtype=np.int64),
    )


def generate_reliability_diagram(
    probs: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
    title: str = "Calibration Reliability Diagram",
    n_bins: int = 10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate reliability diagram (calibration curve).

    Plots the mean accuracy of each populated confidence bin against its mean
    predicted confidence, together with the diagonal that represents perfect
    calibration. The gap between the curve and the diagonal is shaded.

    Args:
        probs: [N, C] probability matrix
        labels: [N] ground truth labels
        output_path: Where to save the plot; nothing is written when falsy
        title: Plot title
        n_bins: Number of confidence bins

    Returns:
        Tuple of (bin_confidences, bin_accuracies, bin_sizes), containing one
        entry per non-empty bin
    """
    # Top-class prediction, its probability, and whether it was right.
    predictions = np.argmax(probs, axis=1)
    confidence = np.max(probs, axis=1)
    accuracy = (predictions == labels).astype(float)

    bin_confs, bin_accs, bin_sizes = _bin_by_confidence(confidence, accuracy, n_bins)

    fig, ax = plt.subplots(figsize=(8, 8))

    # Perfect calibration line
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Perfect calibration')

    # Calibration curve
    ax.plot(bin_confs, bin_accs, 'b-o', linewidth=2, markersize=8, label='Model calibration')

    # Shade the gap between the curve and the diagonal (y == x at bin_confs).
    ax.fill_between(bin_confs, bin_accs, bin_confs, alpha=0.3)

    ax.set_xlabel('Predicted Confidence', fontsize=12)
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.grid(alpha=0.3)
    ax.legend()

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Always release the figure so repeated calls do not accumulate open
    # figures in pyplot's registry, whether or not the plot was saved.
    plt.close(fig)

    return bin_confs, bin_accs, bin_sizes
def generate_confidence_histogram(
    probs: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
    title: str = "Confidence Distribution",
) -> Dict[str, np.ndarray]:
    """
    Generate confidence histogram by correctness.

    Overlays the top-class confidence distribution of correctly and
    incorrectly classified samples. Both histograms are drawn on the same
    30 bin edges so their bars line up and can be compared directly.

    Args:
        probs: [N, C] probability matrix
        labels: [N] ground truth labels
        output_path: Where to save the plot; nothing is written when falsy
        title: Plot title

    Returns:
        Dict with 'correct_confidences' and 'incorrect_confidences' arrays
    """
    predictions = np.argmax(probs, axis=1)
    confidence = np.max(probs, axis=1)
    correct = np.asarray(predictions == labels, dtype=bool)

    correct_conf = confidence[correct]
    incorrect_conf = confidence[~correct]

    # Derive one set of edges from ALL confidences. Passing bins=30 to each
    # hist() call separately would bin each subset over its own min/max range,
    # giving the two overlaid histograms different bar widths and positions.
    bin_edges = np.histogram_bin_edges(confidence, bins=30)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(correct_conf, bins=bin_edges, alpha=0.6, label='Correct',
            color='green', edgecolor='black')
    ax.hist(incorrect_conf, bins=bin_edges, alpha=0.6, label='Incorrect',
            color='red', edgecolor='black')

    ax.set_xlabel('Predicted Confidence', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Always release the figure so repeated calls do not accumulate open
    # figures in pyplot's registry, whether or not the plot was saved.
    plt.close(fig)

    return {
        'correct_confidences': correct_conf,
        'incorrect_confidences': incorrect_conf,
    }
def generate_confidence_vs_accuracy(
    probs: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
    title: str = "Confidence vs Accuracy",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Generate scatter plot of confidence vs accuracy per-sample.

    Each sample is drawn at (confidence, correctness) with a small random
    vertical jitter so overlapping points stay visible. A trend line shows
    the mean accuracy of each populated confidence bin (10 equal-width bins
    over [0, 1]).

    Args:
        probs: [N, C] probability matrix
        labels: [N] ground truth labels
        output_path: Where to save the plot; nothing is written when falsy
        title: Plot title

    Returns:
        Tuple of (confidence, accuracy, accuracy); the third element repeats
        the per-sample correctness array
    """
    predictions = np.argmax(probs, axis=1)
    confidence = np.max(probs, axis=1)
    accuracy = (predictions == labels).astype(float)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Correct samples: small green markers; incorrect: larger red markers.
    colors = np.where(accuracy == 1, 'green', 'red')
    sizes = np.where(accuracy == 1, 50, 100)
    jitter = np.random.normal(0, 0.02, len(accuracy))
    ax.scatter(confidence, accuracy + jitter,
               c=colors, s=sizes, alpha=0.5, edgecolors='black', linewidth=0.5)

    # Trend line using bins
    edges = np.linspace(0, 1, 11)
    centers = (edges[:-1] + edges[1:]) / 2
    last = len(centers) - 1
    trend_x = []
    trend_y = []
    for i, center in enumerate(centers):
        # The last bin is closed on the right; with a strict "<" every sample
        # whose confidence is exactly 1.0 would be left out of the trend.
        if i == last:
            below_upper = confidence <= edges[i + 1]
        else:
            below_upper = confidence < edges[i + 1]
        in_bin = (confidence >= edges[i]) & below_upper

        # Skip empty bins instead of plotting them as accuracy 0, which would
        # drag the line to zero wherever no sample happens to fall.
        if in_bin.any():
            trend_x.append(center)
            trend_y.append(accuracy[in_bin].mean())

    (trend_line,) = ax.plot(trend_x, trend_y, 'b-', linewidth=2, label='Binned accuracy')

    ax.set_xlabel('Predicted Confidence', fontsize=12)
    ax.set_ylabel('Correctness (jittered)', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(alpha=0.3)

    # Scatter colours are per-point, so proxy patches stand in for them in the
    # legend next to the real trend-line handle.
    green_patch = mpatches.Patch(color='green', label='Correct')
    red_patch = mpatches.Patch(color='red', label='Incorrect')
    ax.legend(handles=[green_patch, red_patch, trend_line], fontsize=10)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Always release the figure so repeated calls do not accumulate open
    # figures in pyplot's registry, whether or not the plot was saved.
    plt.close(fig)

    return confidence, accuracy, accuracy
def _plot_before_after_bars(ax, before, after, tick_label, ylabel, title, width=0.35):
    """Draw one 'Before' / 'After' bar pair for a single metric on ax."""
    x = np.arange(1)
    ax.bar(x - width / 2, [before], width, label='Before', alpha=0.8)
    ax.bar(x + width / 2, [after], width, label='After', alpha=0.8)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.set_xticks(x)
    ax.set_xticklabels([tick_label])


def generate_calibration_summary(
    probs_before: np.ndarray,
    probs_after: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
) -> Dict:
    """
    Generate side-by-side calibration comparison.

    Computes ECE, Brier score, log loss and accuracy for both probability
    matrices and draws one before/after bar panel each for ECE, Brier score
    and log loss.

    Args:
        probs_before: Pre-calibration probabilities [N, C]
        probs_after: Post-calibration probabilities [N, C]
        labels: Ground truth labels [N]
        output_path: Where to save the plot; nothing is written when falsy

    Returns:
        Dict with the before/after value of every metric plus an
        '<metric>_improvement' entry (before minus after) for ECE, Brier
        score and log loss
    """
    # Imported here rather than at module level, as in the original code.
    from .calibration_metrics import compute_ece, compute_brier_score, compute_log_loss

    # compute_ece returns a 4-tuple; only its first element is used here.
    ece_before, _, _, _ = compute_ece(probs_before, labels)
    ece_after, _, _, _ = compute_ece(probs_after, labels)

    brier_before = compute_brier_score(probs_before, labels)
    brier_after = compute_brier_score(probs_after, labels)

    ll_before = compute_log_loss(probs_before, labels)
    ll_after = compute_log_loss(probs_after, labels)

    acc_before = np.mean(np.argmax(probs_before, axis=1) == labels)
    acc_after = np.mean(np.argmax(probs_after, axis=1) == labels)

    # One row per panel: (tick label, before, after, y-axis label, title).
    panels = [
        ('ECE', ece_before, ece_after, 'ECE', 'Expected Calibration Error'),
        ('Brier', brier_before, brier_after, 'Brier Score', 'Brier Score (lower is better)'),
        ('Log Loss', ll_before, ll_after, 'Log Loss', 'Log Loss (lower is better)'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (tick_label, before, after, ylabel, panel_title) in zip(axes, panels):
        _plot_before_after_bars(ax, before, after, tick_label, ylabel, panel_title)

    fig.suptitle('Calibration Improvement', fontsize=14, fontweight='bold', y=1.02)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Always release the figure so repeated calls do not accumulate open
    # figures in pyplot's registry, whether or not the plot was saved.
    plt.close(fig)

    return {
        'ece_before': ece_before,
        'ece_after': ece_after,
        'ece_improvement': ece_before - ece_after,
        'brier_before': brier_before,
        'brier_after': brier_after,
        'brier_improvement': brier_before - brier_after,
        'log_loss_before': ll_before,
        'log_loss_after': ll_after,
        'log_loss_improvement': ll_before - ll_after,
        'accuracy_before': acc_before,
        'accuracy_after': acc_after,
    }