Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Any | |
| import numpy as np | |
| from sklearn.metrics import roc_auc_score, average_precision_score | |
| import matplotlib.pyplot as plt | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class CalibrationEvaluator:
    """Compute calibration metrics and diagnostic plots for binary classifiers.

    Every public method takes parallel sequences: ``predictions`` are
    probabilities/confidences in [0, 1] and ``labels`` are binary ground
    truth (0/1). Empty input returns a neutral value (0.0 / {} / empty
    lists) instead of raising, matching the original best-effort contract.
    """

    def __init__(self):
        pass

    @staticmethod
    def _bin_masks(predictions: np.ndarray, n_bins: int):
        """Yield ``(in_bin_mask, bin_lower, bin_upper)`` for equal-width bins.

        Bins are half-open ``(lower, upper]`` except the first, which also
        includes its lower edge so predictions of exactly 0.0 are not
        silently dropped from every bin.
        """
        boundaries = np.linspace(0, 1, n_bins + 1)
        for i in range(n_bins):
            lo, hi = boundaries[i], boundaries[i + 1]
            in_bin = (predictions > lo) & (predictions <= hi)
            if i == 0:
                # close the first bin on the left: include predictions == 0.0
                in_bin |= predictions == lo
            yield in_bin, lo, hi

    def expected_calibration_error(self, predictions: List[float],
                                   labels: List[int], n_bins: int = 10) -> float:
        """Expected Calibration Error (ECE).

        Weighted average, over ``n_bins`` equal-width confidence bins, of
        ``|mean confidence - empirical accuracy|``, weighted by the fraction
        of samples falling in each bin. Returns 0.0 for empty input.
        """
        if not predictions or not labels:
            return 0.0
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels, dtype=float)
        ece = 0.0
        for in_bin, _, _ in self._bin_masks(preds, n_bins):
            prop_in_bin = in_bin.mean()
            if prop_in_bin > 0:
                accuracy_in_bin = labs[in_bin].mean()
                avg_confidence_in_bin = preds[in_bin].mean()
                ece += abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return float(ece)

    def maximum_calibration_error(self, predictions: List[float],
                                  labels: List[int], n_bins: int = 10) -> float:
        """Maximum Calibration Error (MCE).

        Largest ``|mean confidence - empirical accuracy|`` over all
        non-empty confidence bins. Returns 0.0 for empty input.
        """
        if not predictions or not labels:
            return 0.0
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels, dtype=float)
        mce = 0.0
        for in_bin, _, _ in self._bin_masks(preds, n_bins):
            if in_bin.any():
                accuracy_in_bin = labs[in_bin].mean()
                avg_confidence_in_bin = preds[in_bin].mean()
                mce = max(mce, float(abs(avg_confidence_in_bin - accuracy_in_bin)))
        return mce

    def reliability_diagram(self, predictions: List[float], labels: List[int],
                            n_bins: int = 10, save_path: str = None) -> Dict[str, Any]:
        """Build (and optionally save) a reliability diagram.

        Returns per-non-empty-bin statistics:
        ``{'bin_centers', 'accuracies', 'confidences', 'counts'}``; an empty
        dict for empty input. The figure is saved only when ``save_path`` is
        given, and is always closed afterwards.
        """
        if not predictions or not labels:
            return {}
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels, dtype=float)
        bin_centers, accuracies, confidences, counts = [], [], [], []
        for in_bin, lo, hi in self._bin_masks(preds, n_bins):
            count = int(in_bin.sum())
            if count > 0:
                bin_centers.append((lo + hi) / 2)
                accuracies.append(labs[in_bin].mean())
                confidences.append(preds[in_bin].mean())
                counts.append(count)
        plt.figure(figsize=(8, 6))
        plt.bar(bin_centers, accuracies, width=0.1, alpha=0.7, label='Accuracy')
        plt.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
        plt.xlabel('Confidence')
        plt.ylabel('Accuracy')
        plt.title('Reliability Diagram')
        plt.legend()
        plt.grid(True, alpha=0.3)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        return {
            'bin_centers': bin_centers,
            'accuracies': accuracies,
            'confidences': confidences,
            'counts': counts
        }

    def auroc(self, predictions: List[float], labels: List[int]) -> float:
        """Area Under the ROC Curve; 0.0 for empty or degenerate input."""
        if not predictions or not labels:
            return 0.0
        try:
            return roc_auc_score(labels, predictions)
        except ValueError:
            # sklearn raises ValueError when labels contain a single class
            return 0.0

    def auprc(self, predictions: List[float], labels: List[int]) -> float:
        """Area Under the Precision-Recall Curve; 0.0 for empty/degenerate input."""
        if not predictions or not labels:
            return 0.0
        try:
            return average_precision_score(labels, predictions)
        except ValueError:
            # sklearn raises ValueError for degenerate label sets
            return 0.0

    def risk_coverage_curve(self, predictions: List[float], labels: List[int],
                            risk_thresholds: List[float] = None) -> Dict[str, Any]:
        """Risk-coverage analysis: treat ``predictions`` as risk scores.

        For each threshold t, select samples with risk <= t and report the
        selected fraction (coverage) and their label mean (accuracy).
        ``risk_thresholds`` may be any sequence (defaults to 21 points on
        [0, 1]); it is coerced to an array so plain lists work too.
        """
        if not predictions or not labels:
            return {'thresholds': [], 'coverage': [], 'accuracy': []}
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels, dtype=float)
        if risk_thresholds is None:
            thresholds = np.linspace(0, 1, 21)
        else:
            # coerce: the original crashed on .tolist() for plain lists
            thresholds = np.asarray(risk_thresholds, dtype=float)
        coverages, accuracies = [], []
        for threshold in thresholds:
            selected = preds <= threshold
            if selected.any():
                coverages.append(selected.mean())
                accuracies.append(labs[selected].mean())
            else:
                coverages.append(0.0)
                accuracies.append(0.0)
        return {
            'thresholds': thresholds.tolist(),
            'coverage': coverages,
            'accuracy': accuracies
        }

    def evaluate_calibration(self, predictions: List[float], labels: List[int]) -> Dict[str, float]:
        """Comprehensive calibration evaluation: ECE, MCE, AUROC, AUPRC
        plus the full risk-coverage curve under key ``'risk_coverage'``.
        """
        if not predictions or not labels:
            return {
                'ece': 0.0,
                'mce': 0.0,
                'auroc': 0.0,
                'auprc': 0.0
            }
        metrics = {
            'ece': self.expected_calibration_error(predictions, labels),
            'mce': self.maximum_calibration_error(predictions, labels),
            'auroc': self.auroc(predictions, labels),
            'auprc': self.auprc(predictions, labels)
        }
        metrics['risk_coverage'] = self.risk_coverage_curve(predictions, labels)
        return metrics

    def plot_calibration_curves(self, predictions: List[float], labels: List[int],
                                save_path: str = None) -> None:
        """Plot a 2x2 panel: reliability diagram, risk-coverage curve,
        confidence histogram, and accuracy-vs-confidence curve.

        Saves to ``save_path`` when given; the figure is always closed.
        """
        if not predictions or not labels:
            return
        preds = np.asarray(predictions, dtype=float)
        labs = np.asarray(labels, dtype=float)
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        # Reliability diagram (note: the helper also builds/closes its own
        # transient figure as a side effect — harmless, preserved behavior)
        reliability_data = self.reliability_diagram(predictions, labels)
        if reliability_data:
            axes[0, 0].bar(reliability_data['bin_centers'], reliability_data['accuracies'],
                           width=0.1, alpha=0.7)
            axes[0, 0].plot([0, 1], [0, 1], 'r--')
            axes[0, 0].set_xlabel('Confidence')
            axes[0, 0].set_ylabel('Accuracy')
            axes[0, 0].set_title('Reliability Diagram')
            axes[0, 0].grid(True, alpha=0.3)
        # Risk-coverage curve
        risk_coverage = self.risk_coverage_curve(predictions, labels)
        if risk_coverage['thresholds']:
            axes[0, 1].plot(risk_coverage['coverage'], risk_coverage['accuracy'], 'b-')
            axes[0, 1].set_xlabel('Coverage')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].set_title('Risk-Coverage Curve')
            axes[0, 1].grid(True, alpha=0.3)
        # Confidence distribution
        axes[1, 0].hist(predictions, bins=20, alpha=0.7, edgecolor='black')
        axes[1, 0].set_xlabel('Confidence')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].set_title('Confidence Distribution')
        axes[1, 0].grid(True, alpha=0.3)
        # Accuracy vs Confidence: plot at true bin midpoints (the original
        # plotted at bin lower edges while calling them centers)
        edges = np.linspace(0, 1, 11)
        centers = (edges[:-1] + edges[1:]) / 2
        accuracies = []
        for lo, hi in zip(edges[:-1], edges[1:]):
            mask = (preds >= lo) & (preds < hi)
            accuracies.append(labs[mask].mean() if mask.any() else 0.0)
        axes[1, 1].plot(centers, accuracies, 'bo-')
        axes[1, 1].plot([0, 1], [0, 1], 'r--')
        axes[1, 1].set_xlabel('Confidence')
        axes[1, 1].set_ylabel('Accuracy')
        axes[1, 1].set_title('Accuracy vs Confidence')
        axes[1, 1].grid(True, alpha=0.3)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()