"""Calibration evaluation utilities: ECE/MCE, ranking metrics, and diagnostic plots."""

import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# scikit-learn and matplotlib are optional at import time so the bin-based
# metrics stay usable in environments where they are not installed.
try:
    from sklearn.metrics import average_precision_score, roc_auc_score
except ImportError:  # pragma: no cover - environment dependent
    average_precision_score = None
    roc_auc_score = None

try:
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover - environment dependent
    plt = None

logger = logging.getLogger(__name__)


class CalibrationEvaluator:
    """Evaluates how well predicted confidences match observed accuracy."""

    def __init__(self) -> None:
        # Stateless; kept for API compatibility with existing callers.
        pass

    @staticmethod
    def _bin_stats(predictions: np.ndarray, labels: np.ndarray,
                   n_bins: int) -> List[Tuple[float, float, int, float, float]]:
        """Group predictions into equal-width confidence bins over [0, 1].

        Returns (bin_lower, bin_upper, count, accuracy, confidence) for each
        non-empty bin.  The first bin is closed on the left so predictions of
        exactly 0.0 are counted (the previous `> lower` test silently dropped
        them from every bin).
        """
        boundaries = np.linspace(0, 1, n_bins + 1)
        stats: List[Tuple[float, float, int, float, float]] = []
        for i in range(n_bins):
            lo, hi = boundaries[i], boundaries[i + 1]
            if i == 0:
                in_bin = (predictions >= lo) & (predictions <= hi)
            else:
                in_bin = (predictions > lo) & (predictions <= hi)
            count = int(in_bin.sum())
            if count:
                stats.append((float(lo), float(hi), count,
                              float(labels[in_bin].mean()),
                              float(predictions[in_bin].mean())))
        return stats

    def expected_calibration_error(self, predictions: List[float],
                                   labels: List[int],
                                   n_bins: int = 10) -> float:
        """Expected Calibration Error: bin-weighted mean of |confidence - accuracy|.

        Returns 0.0 for empty input.
        """
        if not predictions or not labels:
            return 0.0
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels)
        total = preds.size
        ece = 0.0
        for _lo, _hi, count, accuracy, confidence in self._bin_stats(preds, labs, n_bins):
            # Weight each bin's gap by the fraction of samples it holds.
            ece += abs(confidence - accuracy) * (count / total)
        return ece

    def maximum_calibration_error(self, predictions: List[float],
                                  labels: List[int],
                                  n_bins: int = 10) -> float:
        """Maximum Calibration Error: worst single-bin |confidence - accuracy|.

        Returns 0.0 for empty input.
        """
        if not predictions or not labels:
            return 0.0
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels)
        mce = 0.0
        for _lo, _hi, _count, accuracy, confidence in self._bin_stats(preds, labs, n_bins):
            mce = max(mce, abs(confidence - accuracy))
        return mce

    def reliability_diagram(self, predictions: List[float], labels: List[int],
                            n_bins: int = 10,
                            save_path: Optional[str] = None) -> Dict[str, Any]:
        """Compute reliability-diagram data; plot it when matplotlib is available.

        Returns a dict with per-(non-empty)-bin 'bin_centers', 'accuracies',
        'confidences' and 'counts'; empty dict for empty input.
        """
        if not predictions or not labels:
            return {}
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels)

        bin_centers: List[float] = []
        accuracies: List[float] = []
        confidences: List[float] = []
        counts: List[int] = []
        for lo, hi, count, accuracy, confidence in self._bin_stats(preds, labs, n_bins):
            bin_centers.append((lo + hi) / 2)
            accuracies.append(accuracy)
            confidences.append(confidence)
            counts.append(count)

        if plt is not None:
            plt.figure(figsize=(8, 6))
            plt.bar(bin_centers, accuracies, width=0.1, alpha=0.7, label='Accuracy')
            plt.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
            plt.xlabel('Confidence')
            plt.ylabel('Accuracy')
            plt.title('Reliability Diagram')
            plt.legend()
            plt.grid(True, alpha=0.3)
            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
            # Always close: the figure previously leaked whenever no
            # save_path was supplied.
            plt.close()
        else:
            logger.warning("matplotlib unavailable; skipping reliability plot")

        return {
            'bin_centers': bin_centers,
            'accuracies': accuracies,
            'confidences': confidences,
            'counts': counts,
        }

    def auroc(self, predictions: List[float], labels: List[int]) -> float:
        """Area under the ROC curve; 0.0 when undefined or sklearn is missing."""
        if not predictions or not labels:
            return 0.0
        if roc_auc_score is None:
            logger.warning("scikit-learn unavailable; AUROC reported as 0.0")
            return 0.0
        try:
            return float(roc_auc_score(labels, predictions))
        except ValueError as exc:  # e.g. only one class present in labels
            logger.warning("AUROC undefined: %s", exc)
            return 0.0

    def auprc(self, predictions: List[float], labels: List[int]) -> float:
        """Area under the precision-recall curve; 0.0 when undefined or sklearn is missing."""
        if not predictions or not labels:
            return 0.0
        if average_precision_score is None:
            logger.warning("scikit-learn unavailable; AUPRC reported as 0.0")
            return 0.0
        try:
            return float(average_precision_score(labels, predictions))
        except ValueError as exc:
            logger.warning("AUPRC undefined: %s", exc)
            return 0.0

    def risk_coverage_curve(self, predictions: List[float], labels: List[int],
                            risk_thresholds: Optional[List[float]] = None
                            ) -> Dict[str, Any]:
        """Coverage and accuracy of the subset with prediction <= each threshold.

        NOTE(review): predictions are treated here as *risk* scores (lower is
        safer), whereas the other methods treat them as confidences — confirm
        callers pass the intended quantity.
        """
        if not predictions or not labels:
            return {'thresholds': [], 'coverage': [], 'accuracy': []}
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels)
        # Accept plain Python lists too: the old code called .tolist() on the
        # argument and crashed with AttributeError when given a list.
        thresholds = (np.linspace(0, 1, 21) if risk_thresholds is None
                      else np.asarray(risk_thresholds, dtype=float))

        coverages: List[float] = []
        accuracies: List[float] = []
        for threshold in thresholds:
            selected = preds <= threshold
            if selected.any():
                coverages.append(float(selected.mean()))
                accuracies.append(float(labs[selected].mean()))
            else:
                coverages.append(0.0)
                accuracies.append(0.0)

        return {
            'thresholds': thresholds.tolist(),
            'coverage': coverages,
            'accuracy': accuracies,
        }

    def evaluate_calibration(self, predictions: List[float],
                             labels: List[int]) -> Dict[str, Any]:
        """Comprehensive calibration report: ECE, MCE, AUROC, AUPRC, risk-coverage.

        Return type is Dict[str, Any] (not Dict[str, float]) because
        'risk_coverage' holds a nested dict.
        """
        if not predictions or not labels:
            return {
                'ece': 0.0,
                'mce': 0.0,
                'auroc': 0.0,
                'auprc': 0.0,
            }
        metrics: Dict[str, Any] = {
            'ece': self.expected_calibration_error(predictions, labels),
            'mce': self.maximum_calibration_error(predictions, labels),
            'auroc': self.auroc(predictions, labels),
            'auprc': self.auprc(predictions, labels),
        }
        metrics['risk_coverage'] = self.risk_coverage_curve(predictions, labels)
        return metrics

    def plot_calibration_curves(self, predictions: List[float],
                                labels: List[int],
                                save_path: Optional[str] = None) -> None:
        """Save/close a 2x2 diagnostic figure: reliability diagram, risk-coverage
        curve, confidence histogram, and accuracy-vs-confidence.
        """
        if not predictions or not labels:
            return
        if plt is None:
            logger.warning("matplotlib unavailable; skipping calibration plots")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Reliability diagram (top-left).
        reliability_data = self.reliability_diagram(predictions, labels)
        if reliability_data:
            axes[0, 0].bar(reliability_data['bin_centers'],
                           reliability_data['accuracies'],
                           width=0.1, alpha=0.7)
            axes[0, 0].plot([0, 1], [0, 1], 'r--')
            axes[0, 0].set_xlabel('Confidence')
            axes[0, 0].set_ylabel('Accuracy')
            axes[0, 0].set_title('Reliability Diagram')
            axes[0, 0].grid(True, alpha=0.3)

        # Risk-coverage curve (top-right).
        risk_coverage = self.risk_coverage_curve(predictions, labels)
        if risk_coverage['thresholds']:
            axes[0, 1].plot(risk_coverage['coverage'],
                            risk_coverage['accuracy'], 'b-')
            axes[0, 1].set_xlabel('Coverage')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].set_title('Risk-Coverage Curve')
            axes[0, 1].grid(True, alpha=0.3)

        # Confidence distribution (bottom-left).
        axes[1, 0].hist(predictions, bins=20, alpha=0.7, edgecolor='black')
        axes[1, 0].set_xlabel('Confidence')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].set_title('Confidence Distribution')
        axes[1, 0].grid(True, alpha=0.3)

        # Accuracy vs confidence (bottom-right).  Convert once, outside the loop.
        preds_arr = np.asarray(predictions, dtype=float)
        labs_arr = np.asarray(labels)
        edges = np.linspace(0, 1, 11)
        panel_accuracies: List[float] = []
        for i in range(len(edges) - 1):
            if i == len(edges) - 2:
                # Last bin closed on the right so predictions of exactly 1.0
                # are not dropped (old code used half-open bins throughout).
                mask = (preds_arr >= edges[i]) & (preds_arr <= edges[i + 1])
            else:
                mask = (preds_arr >= edges[i]) & (preds_arr < edges[i + 1])
            panel_accuracies.append(float(labs_arr[mask].mean()) if mask.any() else 0)
        axes[1, 1].plot(edges[:-1], panel_accuracies, 'bo-')
        axes[1, 1].plot([0, 1], [0, 1], 'r--')
        axes[1, 1].set_xlabel('Confidence')
        axes[1, 1].set_ylabel('Accuracy')
        axes[1, 1].set_title('Accuracy vs Confidence')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        # Always close the figure to avoid leaking it when no save_path is given.
        plt.close(fig)