ddi / src /calibration /reliability_analysis.py
github-actions[bot]
Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
d29b763
"""
Reliability analysis and visualization for calibration.
Generates:
- Reliability diagrams
- Confidence histograms
- Confidence vs accuracy plots
- Calibration summary statistics
"""
from pathlib import Path
from typing import Dict, Tuple, Optional
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
def generate_reliability_diagram(
    probs: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
    title: str = "Calibration Reliability Diagram",
    n_bins: int = 10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Plot a reliability diagram (calibration curve) and return its per-bin statistics.

    Each sample is assigned to one of ``n_bins`` equal-width bins by its
    top-class confidence (``max`` over axis 1). For every non-empty bin the
    mean confidence, the fraction of correct argmax predictions and the sample
    count are recorded. A perfectly calibrated model lies on the diagonal.

    Args:
        probs: [N, C] probability matrix (argmax / max are taken over axis 1).
        labels: [N] ground-truth class indices, compared element-wise with the
            argmax predictions.
        output_path: If truthy, the figure is saved there (dpi=150) and closed;
            otherwise the figure is left open in pyplot.
        title: Plot title.
        n_bins: Number of confidence bins.

    Returns:
        Tuple ``(bin_confidences, bin_accuracies, bin_sizes)``. Empty bins are
        skipped, so each array holds one entry per non-empty bin, in ascending
        bin order.
    """
    predicted = np.argmax(probs, axis=1)
    conf = np.max(probs, axis=1)
    hit = (predicted == labels).astype(float)

    # The top edge is nudged just past 1.0 so a confidence of exactly 1.0
    # still falls inside the final bin.
    edges = np.linspace(0, 1 + 1e-6, n_bins + 1)
    last_idx = n_bins - 1

    confs, accs, counts = [], [], []
    for idx in range(n_bins):
        lower, upper = edges[idx], edges[idx + 1]
        if idx == last_idx:
            # Final bin is closed on the right.
            in_bin = (conf >= lower) & (conf <= 1.0 + 1e-6)
        else:
            in_bin = (conf >= lower) & (conf < upper)
        count = np.sum(in_bin)
        if count == 0:
            continue  # empty bins contribute no point to the curve
        accs.append(hit[in_bin].mean())
        confs.append(conf[in_bin].mean())
        counts.append(count)

    bin_confs = np.array(confs)
    bin_accs = np.array(accs)
    bin_sizes = np.array(counts)

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Perfect calibration')
    ax.plot(bin_confs, bin_accs, 'b-o', linewidth=2, markersize=8, label='Model calibration')
    # Shade the gap between the measured curve and the ideal diagonal.
    ax.fill_between(bin_confs, bin_accs, bin_confs, alpha=0.3)

    ax.set_xlabel('Predicted Confidence', fontsize=12)
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits([0, 1])
    ax.grid(alpha=0.3)
    ax.legend()

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    return bin_confs, bin_accs, bin_sizes
def generate_confidence_histogram(
    probs: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
    title: str = "Confidence Distribution",
    n_bins: int = 30,
) -> Dict[str, np.ndarray]:
    """
    Plot overlaid histograms of top-class confidence for correct vs. incorrect predictions.

    Both histograms share a single set of bin edges computed from all
    confidences, so their bars line up and the two distributions can be
    compared directly.

    Args:
        probs: [N, C] probability matrix (argmax / max are taken over axis 1).
        labels: [N] ground-truth class indices, compared element-wise with the
            argmax predictions.
        output_path: If truthy, the figure is saved there (dpi=150) and closed;
            otherwise the figure is left open in pyplot.
        title: Plot title.
        n_bins: Number of histogram bins (defaults to 30, the previous
            hard-coded value).

    Returns:
        Dict with ``'correct_confidences'`` and ``'incorrect_confidences'``:
        the top-class confidence of each correctly / incorrectly predicted sample.
    """
    predictions = np.argmax(probs, axis=1)
    confidence = np.max(probs, axis=1)
    correct = (predictions == labels).astype(bool)

    correct_conf = confidence[correct]
    incorrect_conf = confidence[~correct]

    # Compute the edges once from ALL confidences. Passing an integer `bins`
    # to each hist() call separately would auto-range each subset on its own,
    # giving the two overlaid histograms different bin positions and widths.
    edges = np.histogram_bin_edges(confidence, bins=n_bins)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(correct_conf, bins=edges, alpha=0.6, label='Correct', color='green', edgecolor='black')
    ax.hist(incorrect_conf, bins=edges, alpha=0.6, label='Incorrect', color='red', edgecolor='black')
    ax.set_xlabel('Predicted Confidence', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    return {
        'correct_confidences': correct_conf,
        'incorrect_confidences': incorrect_conf,
    }
def generate_confidence_vs_accuracy(
    probs: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
    title: str = "Confidence vs Accuracy",
    n_bins: int = 10,
    jitter_seed: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Scatter per-sample confidence against correctness, with a binned-accuracy trend line.

    Correct samples (y ~ 1) are drawn green and incorrect ones (y ~ 0) red and
    larger. Gaussian jitter (std 0.02) is added to the plotted y-coordinates
    only, so overlapping points stay visible; returned values are not jittered.

    The trend line is the mean correctness inside ``n_bins`` equal-width
    confidence bins over [0, 1]. The last bin is closed on the right, so a
    confidence of exactly 1.0 is counted. Empty bins are plotted as gaps (NaN)
    instead of a misleading accuracy of 0.

    Args:
        probs: [N, C] probability matrix (argmax / max are taken over axis 1).
        labels: [N] ground-truth class indices, compared element-wise with the
            argmax predictions.
        output_path: If truthy, the figure is saved there (dpi=150) and closed;
            otherwise the figure is left open in pyplot.
        title: Plot title.
        n_bins: Number of bins for the trend line (default 10).
        jitter_seed: Seed for the vertical jitter. ``None`` draws fresh entropy;
            an int makes the figure reproducible. A private Generator is used,
            so the global NumPy RNG state is not touched.

    Returns:
        Tuple ``(confidence, accuracy, accuracy)``: per-sample top-class
        confidence and 0.0/1.0 correctness. The third element is the same array
        as the second; it is kept so callers unpacking three values still work.

    Raises:
        ValueError: If ``n_bins`` is less than 1.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    predictions = np.argmax(probs, axis=1)
    confidence = np.max(probs, axis=1)
    accuracy = (predictions == labels).astype(float)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Color/size by correctness; incorrect points are drawn larger to stand out.
    is_correct = accuracy == 1
    colors = np.where(is_correct, 'green', 'red')
    sizes = np.where(is_correct, 50, 100)
    rng = np.random.default_rng(jitter_seed)
    jitter = rng.normal(0, 0.02, len(accuracy))
    ax.scatter(confidence, accuracy + jitter,
               c=colors, s=sizes, alpha=0.5, edgecolors='black', linewidth=0.5)

    # Trend line: mean correctness per confidence bin.
    edges = np.linspace(0, 1, n_bins + 1)
    bin_centers = (edges[:-1] + edges[1:]) / 2
    bin_accs = np.full(n_bins, np.nan)  # NaN => matplotlib leaves a gap
    for i in range(n_bins):
        lower, upper = edges[i], edges[i + 1]
        if i == n_bins - 1:
            # Close the last bin so confidence == 1.0 is not dropped.
            mask = (confidence >= lower) & (confidence <= upper)
        else:
            mask = (confidence >= lower) & (confidence < upper)
        if mask.any():
            bin_accs[i] = accuracy[mask].mean()
    trend_line, = ax.plot(bin_centers, bin_accs, 'b-', linewidth=2, label='Binned accuracy')

    ax.set_xlabel('Predicted Confidence', fontsize=12)
    ax.set_ylabel('Correctness (jittered)', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(alpha=0.3)

    # Scatter colors are per-point, so build explicit legend handles for them.
    green_patch = mpatches.Patch(color='green', label='Correct')
    red_patch = mpatches.Patch(color='red', label='Incorrect')
    ax.legend(handles=[green_patch, red_patch, trend_line], fontsize=10)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    return confidence, accuracy, accuracy
def _plot_before_after_bar(
    ax,
    metric_name: str,
    before: float,
    after: float,
    ylabel: str,
    title: str,
    width: float = 0.35,
) -> None:
    """Draw a side-by-side 'Before'/'After' bar pair for a single metric on ``ax``."""
    x = np.arange(1)  # one metric per panel -> a single x position
    ax.bar(x - width / 2, [before], width, label='Before', alpha=0.8)
    ax.bar(x + width / 2, [after], width, label='After', alpha=0.8)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.set_xticks(x)
    ax.set_xticklabels([metric_name])


def generate_calibration_summary(
    probs_before: np.ndarray,
    probs_after: np.ndarray,
    labels: np.ndarray,
    output_path: Optional[Path] = None,
) -> Dict:
    """
    Compare calibration metrics before and after calibration, as a 3-panel bar chart.

    Computes ECE, Brier score and log loss (via the sibling
    ``calibration_metrics`` module) and argmax accuracy for both probability
    matrices, then plots one Before/After bar pair per metric (ECE, Brier,
    log loss).

    Args:
        probs_before: Pre-calibration probabilities [N, C].
        probs_after: Post-calibration probabilities [N, C].
        labels: Ground-truth labels [N].
        output_path: If truthy, the figure is saved there (dpi=150) and closed;
            otherwise the figure is left open in pyplot.

    Returns:
        Dict with ``{ece,brier,log_loss}_{before,after}``, the corresponding
        ``*_improvement`` values (``before - after``, so a positive value means
        the metric decreased), and ``accuracy_before`` / ``accuracy_after``.
    """
    from .calibration_metrics import compute_ece, compute_brier_score, compute_log_loss

    # compute_ece returns a 4-tuple here; only its first element (the ECE) is used.
    ece_before, _, _, _ = compute_ece(probs_before, labels)
    ece_after, _, _, _ = compute_ece(probs_after, labels)
    brier_before = compute_brier_score(probs_before, labels)
    brier_after = compute_brier_score(probs_after, labels)
    ll_before = compute_log_loss(probs_before, labels)
    ll_after = compute_log_loss(probs_after, labels)

    acc_before = np.mean(np.argmax(probs_before, axis=1) == labels)
    acc_after = np.mean(np.argmax(probs_after, axis=1) == labels)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    # (tick label, before, after, y-axis label, panel title) for each panel.
    panels = (
        ('ECE', ece_before, ece_after, 'ECE', 'Expected Calibration Error'),
        ('Brier', brier_before, brier_after, 'Brier Score', 'Brier Score (lower is better)'),
        ('Log Loss', ll_before, ll_after, 'Log Loss', 'Log Loss (lower is better)'),
    )
    for ax, (metric_name, before, after, ylabel, panel_title) in zip(axes, panels):
        _plot_before_after_bar(ax, metric_name, before, after, ylabel, panel_title)

    fig.suptitle('Calibration Improvement', fontsize=14, fontweight='bold', y=1.02)

    if output_path:
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    return {
        'ece_before': ece_before,
        'ece_after': ece_after,
        'ece_improvement': ece_before - ece_after,
        'brier_before': brier_before,
        'brier_after': brier_after,
        'brier_improvement': brier_before - brier_after,
        'log_loss_before': ll_before,
        'log_loss_after': ll_after,
        'log_loss_improvement': ll_before - ll_after,
        'accuracy_before': acc_before,
        'accuracy_after': acc_after,
    }