# File: safe_rag/eval/eval_calib.py
# Author: Tairun Meng
# Commit message: Initial commit: SafeRAG project ready for HF Spaces
# Commit: db06013
import logging
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score
logger = logging.getLogger(__name__)
class CalibrationEvaluator:
    """Calibration and ranking metrics for scores paired with binary labels.

    All public methods accept any sequence (list, tuple, NumPy array) of
    scores and of 0/1 labels.  Binned metrics split ``[0, 1]`` into
    ``n_bins`` equal-width bins; scores outside ``[0, 1]`` fall in no bin.
    """

    def __init__(self):
        pass

    @staticmethod
    def _is_empty(predictions, labels) -> bool:
        """Return True when either input is ``None`` or has no elements.

        ``len()`` is used instead of truthiness so NumPy arrays are accepted
        (``not array`` raises for arrays with more than one element).
        """
        return (
            predictions is None
            or labels is None
            or len(predictions) == 0
            or len(labels) == 0
        )

    @staticmethod
    def _bin_statistics(predictions, labels, n_bins: int) -> Dict[str, Any]:
        """Compute per-bin statistics for the non-empty equal-width bins.

        Returns a dict with parallel lists ``bin_centers``, ``accuracies``
        (mean label), ``confidences`` (mean score) and ``counts``, holding
        plain Python ``float``/``int`` values.
        """
        scores = np.asarray(predictions, dtype=float)
        targets = np.asarray(labels, dtype=float)
        edges = np.linspace(0.0, 1.0, n_bins + 1)

        stats: Dict[str, Any] = {
            'bin_centers': [],
            'accuracies': [],
            'confidences': [],
            'counts': [],
        }
        for index, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
            # Bins are (lower, upper], except the first one which is closed on
            # both sides so that a score of exactly 0.0 is not dropped.
            above_lower = scores >= lower if index == 0 else scores > lower
            in_bin = above_lower & (scores <= upper)
            count = int(in_bin.sum())
            if count == 0:
                continue
            stats['bin_centers'].append(float((lower + upper) / 2))
            stats['accuracies'].append(float(targets[in_bin].mean()))
            stats['confidences'].append(float(scores[in_bin].mean()))
            stats['counts'].append(count)
        return stats

    def expected_calibration_error(self, predictions: List[float],
                                   labels: List[int], n_bins: int = 10) -> float:
        """Calculate Expected Calibration Error (ECE).

        ECE is the sum over bins of ``|mean score - mean label|`` weighted by
        the fraction of all samples that fall in the bin.

        Returns:
            The ECE as a float; ``0.0`` for empty input.
        """
        if self._is_empty(predictions, labels):
            return 0.0

        stats = self._bin_statistics(predictions, labels, n_bins)
        total = len(predictions)
        return float(sum(
            abs(confidence - accuracy) * count / total
            for confidence, accuracy, count in zip(
                stats['confidences'], stats['accuracies'], stats['counts']
            )
        ))

    def maximum_calibration_error(self, predictions: List[float],
                                  labels: List[int], n_bins: int = 10) -> float:
        """Calculate Maximum Calibration Error (MCE).

        MCE is the largest ``|mean score - mean label|`` over non-empty bins.

        Returns:
            The MCE as a float; ``0.0`` for empty input or when no score
            falls into any bin.
        """
        if self._is_empty(predictions, labels):
            return 0.0

        stats = self._bin_statistics(predictions, labels, n_bins)
        return float(max(
            (
                abs(confidence - accuracy)
                for confidence, accuracy in zip(
                    stats['confidences'], stats['accuracies']
                )
            ),
            default=0.0,
        ))

    def reliability_diagram(self, predictions: List[float], labels: List[int],
                            n_bins: int = 10,
                            save_path: Optional[str] = None) -> Dict[str, Any]:
        """Compute reliability-diagram data and optionally save the plot.

        Args:
            predictions: Scores, expected in ``[0, 1]``.
            labels: Binary labels aligned with ``predictions``.
            n_bins: Number of equal-width bins.
            save_path: When given, the diagram is rendered and written there.

        Returns:
            ``{}`` for empty input, otherwise a dict with the parallel lists
            ``bin_centers``, ``accuracies``, ``confidences`` and ``counts``
            for the non-empty bins.
        """
        if self._is_empty(predictions, labels):
            return {}

        stats = self._bin_statistics(predictions, labels, n_bins)

        # A figure is only built when it will actually be saved; it was
        # never shown, so rendering it without a save path was wasted work.
        if save_path:
            fig, ax = plt.subplots(figsize=(8, 6))
            try:
                # Bar width follows the bin width (0.1 for the default 10 bins).
                ax.bar(stats['bin_centers'], stats['accuracies'],
                       width=1.0 / max(n_bins, 1), alpha=0.7, label='Accuracy')
                ax.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
                ax.set_xlabel('Confidence')
                ax.set_ylabel('Accuracy')
                ax.set_title('Reliability Diagram')
                ax.legend()
                ax.grid(True, alpha=0.3)
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
            finally:
                # Close this specific figure even if saving fails.
                plt.close(fig)

        return stats

    def auroc(self, predictions: List[float], labels: List[int]) -> float:
        """Calculate Area Under ROC Curve.

        Returns:
            The score from ``roc_auc_score``; ``0.0`` for empty input or when
            scikit-learn cannot compute it (the reason is logged).
        """
        if self._is_empty(predictions, labels):
            return 0.0
        try:
            return float(roc_auc_score(labels, predictions))
        except Exception as exc:
            # Deliberately best-effort, but unlike a bare ``except`` this does
            # not swallow KeyboardInterrupt/SystemExit and records the cause.
            logger.warning("AUROC could not be computed, returning 0.0: %s", exc)
            return 0.0

    def auprc(self, predictions: List[float], labels: List[int]) -> float:
        """Calculate Area Under Precision-Recall Curve.

        Returns:
            The score from ``average_precision_score``; ``0.0`` for empty
            input or when scikit-learn cannot compute it (the reason is
            logged).
        """
        if self._is_empty(predictions, labels):
            return 0.0
        try:
            return float(average_precision_score(labels, predictions))
        except Exception as exc:
            # Deliberately best-effort; see ``auroc``-style handling: log the
            # cause and fall back to 0.0 without trapping interpreter exits.
            logger.warning("AUPRC could not be computed, returning 0.0: %s", exc)
            return 0.0

    def risk_coverage_curve(self, predictions: List[float], labels: List[int],
                            risk_thresholds: Optional[List[float]] = None
                            ) -> Dict[str, Any]:
        """Calculate risk-coverage curve.

        For each threshold, samples with ``score <= threshold`` are selected;
        coverage is the selected fraction and accuracy the mean label of the
        selection (both ``0.0`` when nothing is selected).

        NOTE(review): scores are treated as *risk* here (low is selected),
        whereas the binned metrics treat them as confidence - confirm that
        callers pass the intended quantity.

        Args:
            predictions: Scores compared against each threshold.
            labels: Binary labels aligned with ``predictions``.
            risk_thresholds: Thresholds to evaluate (list or array); defaults
                to 21 evenly spaced values in ``[0, 1]``.

        Returns:
            Dict with parallel lists ``thresholds``, ``coverage``, ``accuracy``.
        """
        if self._is_empty(predictions, labels):
            return {'thresholds': [], 'coverage': [], 'accuracy': []}

        scores = np.asarray(predictions, dtype=float)
        targets = np.asarray(labels, dtype=float)

        if risk_thresholds is None:
            thresholds = np.linspace(0, 1, 21)
        else:
            # Accept plain lists too; ``.tolist()`` below needs an ndarray.
            thresholds = np.atleast_1d(np.asarray(risk_thresholds, dtype=float))

        coverages = []
        accuracies = []
        for threshold in thresholds:
            selected = scores <= threshold
            if selected.any():
                coverages.append(float(selected.mean()))
                accuracies.append(float(targets[selected].mean()))
            else:
                coverages.append(0.0)
                accuracies.append(0.0)

        return {
            'thresholds': thresholds.tolist(),
            'coverage': coverages,
            'accuracy': accuracies
        }

    def evaluate_calibration(self, predictions: List[float],
                             labels: List[int]) -> Dict[str, Any]:
        """Comprehensive calibration evaluation.

        Returns:
            Dict with float entries ``ece``, ``mce``, ``auroc``, ``auprc``
            and, for non-empty input, a ``risk_coverage`` entry holding the
            dict returned by ``risk_coverage_curve``.
        """
        if self._is_empty(predictions, labels):
            return {
                'ece': 0.0,
                'mce': 0.0,
                'auroc': 0.0,
                'auprc': 0.0
            }

        metrics: Dict[str, Any] = {
            'ece': self.expected_calibration_error(predictions, labels),
            'mce': self.maximum_calibration_error(predictions, labels),
            'auroc': self.auroc(predictions, labels),
            'auprc': self.auprc(predictions, labels)
        }

        # Risk-coverage analysis
        metrics['risk_coverage'] = self.risk_coverage_curve(predictions, labels)

        return metrics

    def plot_calibration_curves(self, predictions: List[float], labels: List[int],
                                save_path: Optional[str] = None) -> None:
        """Plot a 2x2 calibration overview and optionally save it.

        Panels: reliability diagram, risk-coverage curve, score histogram,
        and per-bin accuracy against mean confidence.  Does nothing for empty
        input.  The figure is always closed before returning.
        """
        if self._is_empty(predictions, labels):
            return

        n_bins = 10
        # Computed directly so no throwaway figure is created along the way.
        stats = self._bin_statistics(predictions, labels, n_bins)
        risk_coverage = self.risk_coverage_curve(predictions, labels)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            # Reliability diagram
            panel = axes[0, 0]
            panel.bar(stats['bin_centers'], stats['accuracies'],
                      width=1.0 / n_bins, alpha=0.7)
            panel.plot([0, 1], [0, 1], 'r--')
            panel.set_xlabel('Confidence')
            panel.set_ylabel('Accuracy')
            panel.set_title('Reliability Diagram')
            panel.grid(True, alpha=0.3)

            # Risk-coverage curve
            panel = axes[0, 1]
            panel.plot(risk_coverage['coverage'], risk_coverage['accuracy'], 'b-')
            panel.set_xlabel('Coverage')
            panel.set_ylabel('Accuracy')
            panel.set_title('Risk-Coverage Curve')
            panel.grid(True, alpha=0.3)

            # Confidence distribution
            panel = axes[1, 0]
            panel.hist(np.asarray(predictions, dtype=float), bins=20,
                       alpha=0.7, edgecolor='black')
            panel.set_xlabel('Confidence')
            panel.set_ylabel('Count')
            panel.set_title('Confidence Distribution')
            panel.grid(True, alpha=0.3)

            # Accuracy vs Confidence: only non-empty bins are drawn, placed at
            # their mean confidence, so empty bins no longer appear as
            # accuracy 0 and scores equal to 1.0 are included.
            panel = axes[1, 1]
            panel.plot(stats['confidences'], stats['accuracies'], 'bo-')
            panel.plot([0, 1], [0, 1], 'r--')
            panel.set_xlabel('Confidence')
            panel.set_ylabel('Accuracy')
            panel.set_title('Accuracy vs Confidence')
            panel.grid(True, alpha=0.3)

            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Close this figure explicitly rather than pyplot's "current" one.
            plt.close(fig)