# src/auditor.py
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFilter
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss
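
# The analyzers below all delegate raw inference to the sibling predictor
# module. A sketch of the assumed contract (predict_image itself is not
# defined in this file; only its call signature appears here):
#
#   from predictor import predict_image
#   probs, indices, labels = predict_image(image, model, processor, top_k=5)
#
# where probs holds the top-k probabilities in descending order, indices the
# matching class indices, and labels the human-readable class names.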

class CounterfactualAnalyzer:
    """Analyze how predictions change with image perturbations."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def patch_perturbation_analysis(self, image, patch_size=16, perturbation_type='blur'):
        """
        Analyze how predictions change when different patches are perturbed.

        Args:
            image: PIL Image
            patch_size: Size (in pixels) of the square patches to perturb
            perturbation_type: Type of perturbation ('blur', 'noise', 'blackout', 'gray')

        Returns:
            dict: Analysis results with visualizations
        """
        original_probs, _, original_labels = self._predict_image(image)
        original_top_label = original_labels[0]
        original_confidence = original_probs[0]

        # Lay a grid of non-overlapping patches over the image
        width, height = image.size
        patches_x = width // patch_size
        patches_y = height // patch_size

        # Store per-patch results
        confidence_changes = []
        prediction_changes = []
        patch_heatmap = np.zeros((patches_y, patches_x))

        for i in range(patches_y):
            for j in range(patches_x):
                # Perturb one patch and re-run the model
                perturbed_img = self._perturb_patch(
                    image.copy(), j, i, patch_size, perturbation_type
                )
                perturbed_probs, _, perturbed_labels = self._predict_image(perturbed_img)
                perturbed_confidence = perturbed_probs[0]
                perturbed_label = perturbed_labels[0]

                # How much did the top-1 confidence move, and did the label flip?
                confidence_change = perturbed_confidence - original_confidence
                prediction_change = 1 if perturbed_label != original_top_label else 0

                confidence_changes.append(confidence_change)
                prediction_changes.append(prediction_change)
                patch_heatmap[i, j] = confidence_change

        # Create visualization
        fig = self._create_counterfactual_visualization(
            image, patch_heatmap, patch_size, original_top_label,
            original_confidence, confidence_changes, prediction_changes
        )

        return {
            'figure': fig,
            'patch_heatmap': patch_heatmap,
            'avg_confidence_change': np.mean(confidence_changes),
            'prediction_flip_rate': np.mean(prediction_changes),
            # Patch whose perturbation causes the largest confidence drop
            'most_sensitive_patch': np.unravel_index(np.argmin(patch_heatmap), patch_heatmap.shape)
        }
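
    # Usage sketch (hypothetical inputs; assumes model/processor are loaded):
    #
    #   analyzer = CounterfactualAnalyzer(model, processor)
    #   result = analyzer.patch_perturbation_analysis(img, patch_size=32,
    #                                                 perturbation_type='blackout')
    #   row, col = result['most_sensitive_patch']  # grid coords of the most sensitive patch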

    def _perturb_patch(self, image, patch_x, patch_y, patch_size, perturbation_type):
        """Apply a perturbation to a specific patch (mutates and returns `image`)."""
        left = patch_x * patch_size
        upper = patch_y * patch_size
        right = left + patch_size
        lower = upper + patch_size
        patch_box = (left, upper, right, lower)

        if perturbation_type == 'blur':
            # Extract the patch, blur it, and paste it back
            patch = image.crop(patch_box)
            blurred_patch = patch.filter(ImageFilter.GaussianBlur(5))
            image.paste(blurred_patch, patch_box)
        elif perturbation_type == 'blackout':
            # Black out the patch
            draw = ImageDraw.Draw(image)
            draw.rectangle(patch_box, fill='black')
        elif perturbation_type == 'gray':
            # Convert the patch to grayscale
            patch = image.crop(patch_box)
            gray_patch = patch.convert('L').convert('RGB')
            image.paste(gray_patch, patch_box)
        elif perturbation_type == 'noise':
            # Add Gaussian noise to the patch. Work in float so negative noise
            # values don't wrap around when cast to uint8.
            patch = np.array(image.crop(patch_box)).astype(np.float32)
            noise = np.random.normal(0, 50, patch.shape)
            noisy_patch = np.clip(patch + noise, 0, 255).astype(np.uint8)
            image.paste(Image.fromarray(noisy_patch), patch_box)

        return image

    def _predict_image(self, image):
        """Helper to get top-k predictions from the shared predictor module."""
        from predictor import predict_image
        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_counterfactual_visualization(self, image, patch_heatmap, patch_size,
                                             original_label, original_confidence,
                                             confidence_changes, prediction_changes):
        """Create visualization for counterfactual analysis."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Original image with the patch grid overlaid
        ax1.imshow(image)
        ax1.set_title(f'Original Image\nPrediction: {original_label} ({original_confidence:.2%})',
                      fontweight='bold')
        ax1.axis('off')
        for i in range(patch_heatmap.shape[0]):
            for j in range(patch_heatmap.shape[1]):
                rect = plt.Rectangle((j * patch_size, i * patch_size),
                                     patch_size, patch_size,
                                     linewidth=1, edgecolor='red',
                                     facecolor='none', alpha=0.3)
                ax1.add_patch(rect)

        # Patch sensitivity heatmap
        im = ax2.imshow(patch_heatmap, cmap='RdBu_r', vmin=-0.5, vmax=0.5)
        ax2.set_title('Patch Sensitivity Heatmap\n(Confidence Change When Perturbed)',
                      fontweight='bold')
        ax2.set_xlabel('Patch X')
        ax2.set_ylabel('Patch Y')
        plt.colorbar(im, ax=ax2, label='Confidence Change')

        # Confidence change distribution
        ax3.hist(confidence_changes, bins=20, alpha=0.7, color='skyblue')
        ax3.axvline(0, color='red', linestyle='--', label='No Change')
        ax3.set_xlabel('Confidence Change')
        ax3.set_ylabel('Frequency')
        ax3.set_title('Distribution of Confidence Changes', fontweight='bold')
        ax3.legend()
        ax3.grid(alpha=0.3)

        # Prediction flip analysis
        flip_rate = np.mean(prediction_changes)
        ax4.bar(['No Flip', 'Flip'], [1 - flip_rate, flip_rate], color=['green', 'red'])
        ax4.set_ylabel('Proportion')
        ax4.set_title(f'Prediction Flip Rate: {flip_rate:.2%}', fontweight='bold')
        ax4.grid(alpha=0.3)

        plt.tight_layout()
        return fig

class ConfidenceCalibrationAnalyzer:
    """Analyze model calibration and confidence metrics."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def analyze_calibration(self, test_images, test_labels=None, n_bins=10):
        """
        Analyze model calibration using confidence scores.

        Args:
            test_images: List of PIL Images for testing
            test_labels: Optional true labels for accuracy calculation
            n_bins: Number of bins for calibration curve

        Returns:
            dict: Calibration analysis results
        """
        all_probs = []
        predictions = []
        max_confidences = []

        # Get predictions and confidences
        for img in test_images:
            probs, indices, labels = self._predict_image(img)
            max_confidences.append(probs[0])
            predictions.append(labels[0])
            all_probs.append(np.asarray(probs))
        max_confidences = np.array(max_confidences)

        # Create calibration analysis. The full top-k distributions are needed
        # for the confidence-gap panel, so pass them alongside the maxima.
        fig = self._create_calibration_visualization(
            max_confidences, all_probs, test_labels, predictions, n_bins
        )

        # Calculate calibration metrics
        calibration_metrics = self._calculate_calibration_metrics(
            max_confidences, test_labels, predictions
        )

        return {
            'figure': fig,
            'metrics': calibration_metrics,
            'confidence_distribution': max_confidences
        }

    def _predict_image(self, image):
        """Helper to get top-k predictions from the shared predictor module."""
        from predictor import predict_image
        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_calibration_visualization(self, confidences, all_probs, true_labels,
                                          predictions, n_bins):
        """Create calibration visualization."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Confidence distribution
        ax1.hist(confidences, bins=20, alpha=0.7, color='lightblue', edgecolor='black')
        ax1.set_xlabel('Confidence Score')
        ax1.set_ylabel('Frequency')
        ax1.set_title('Distribution of Confidence Scores', fontweight='bold')
        ax1.axvline(np.mean(confidences), color='red', linestyle='--',
                    label=f'Mean: {np.mean(confidences):.3f}')
        ax1.legend()
        ax1.grid(alpha=0.3)

        # Reliability diagram (if true labels available)
        if true_labels is not None:
            # Binary correctness of the top-1 prediction
            correct = np.array([pred == true for pred, true in zip(predictions, true_labels)])
            fraction_of_positives, mean_predicted_prob = calibration_curve(
                correct, confidences, n_bins=n_bins, strategy='uniform'
            )
            ax2.plot(mean_predicted_prob, fraction_of_positives, "s-", label='Model')
            ax2.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
            ax2.set_xlabel('Mean Predicted Probability')
            ax2.set_ylabel('Fraction of Positives')
            ax2.set_title('Reliability Diagram', fontweight='bold')
            ax2.legend()
            ax2.grid(alpha=0.3)
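
            # Expected calibration error, standard definition: bin predictions
            # by confidence, then take the sample-weighted mean gap between
            # per-bin accuracy and per-bin mean confidence:
            #
            #   ECE = sum_b (n_b / N) * |acc_b - conf_b|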
            bin_edges = np.linspace(0, 1, n_bins + 1)
            bin_indices = np.digitize(confidences, bin_edges) - 1
            bin_indices = np.clip(bin_indices, 0, n_bins - 1)
            ece = 0.0
            for bin_idx in range(n_bins):
                mask = bin_indices == bin_idx
                if np.sum(mask) > 0:
                    bin_conf = np.mean(confidences[mask])
                    bin_acc = np.mean(correct[mask])
                    ece += (np.sum(mask) / len(confidences)) * np.abs(bin_acc - bin_conf)
            ax2.text(0.1, 0.9, f'ECE: {ece:.3f}', transform=ax2.transAxes,
                     bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

        # Confidence vs accuracy (if true labels available)
        if true_labels is not None:
            confidence_bins = np.linspace(0, 1, n_bins + 1)
            bin_accuracies = []
            bin_confidences = []
            for i in range(n_bins):
                mask = (confidences >= confidence_bins[i]) & (confidences < confidence_bins[i + 1])
                if np.sum(mask) > 0:
                    bin_accuracies.append(np.mean(correct[mask]))
                    bin_confidences.append(np.mean(confidences[mask]))
            ax3.plot(bin_confidences, bin_accuracies, 'o-', label='Model')
            ax3.plot([0, 1], [0, 1], 'k--', label='Ideal')
            ax3.set_xlabel('Average Confidence')
            ax3.set_ylabel('Average Accuracy')
            ax3.set_title('Confidence vs Accuracy', fontweight='bold')
            ax3.legend()
            ax3.grid(alpha=0.3)

        # Gap between the top-1 probability and the mean of the remaining
        # top-k probabilities; larger gaps indicate more decisive predictions
        if len(all_probs) > 0:
            confidence_gap = [p[0] - np.mean(p[1:]) for p in all_probs if len(p) > 1]
            ax4.hist(confidence_gap, bins=20, alpha=0.7, color='lightgreen', edgecolor='black')
            ax4.set_xlabel('Confidence Gap (Top-1 vs Rest)')
            ax4.set_ylabel('Frequency')
            ax4.set_title('Distribution of Confidence Gaps', fontweight='bold')
            ax4.grid(alpha=0.3)

        plt.tight_layout()
        return fig

    def _calculate_calibration_metrics(self, confidences, true_labels, predictions):
        """Calculate summary calibration metrics."""
        metrics = {
            'mean_confidence': float(np.mean(confidences)),
            'confidence_std': float(np.std(confidences)),
            'overconfident_rate': float(np.mean(confidences > 0.8)),
            'underconfident_rate': float(np.mean(confidences < 0.2)),
        }

        if true_labels is not None:
            correct = np.array([pred == true for pred, true in zip(predictions, true_labels)])
            accuracy = np.mean(correct)
            avg_confidence = np.mean(confidences)
            metrics.update({
                'accuracy': float(accuracy),
                # Positive gap means the model is, on average, more confident than accurate
                'confidence_gap': float(avg_confidence - accuracy),
                # Brier score: mean squared error between confidence and correctness
                'brier_score': float(brier_score_loss(correct, confidences))
            })

        return metrics
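
    # Usage sketch (hypothetical inputs; `val_labels` must use the same label
    # strings that predict_image returns):
    #
    #   cal = ConfidenceCalibrationAnalyzer(model, processor)
    #   out = cal.analyze_calibration(val_images, val_labels, n_bins=10)
    #   print(out['metrics'])  # mean_confidence, confidence_gap, brier_score, ...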

class BiasDetector:
    """Detect potential biases in model performance across subgroups."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def analyze_subgroup_performance(self, image_subsets, subset_names, true_labels_subsets=None):
        """
        Analyze performance across different subgroups.

        Args:
            image_subsets: List of image subsets, one per subgroup
            subset_names: Names for each subgroup
            true_labels_subsets: Optional true labels for each subset

        Returns:
            dict: Bias analysis results
        """
        subgroup_metrics = {}

        for i, (subset, name) in enumerate(zip(image_subsets, subset_names)):
            confidences = []
            predictions = []
            for img in subset:
                probs, indices, labels = self._predict_image(img)
                confidences.append(probs[0])
                predictions.append(labels[0])

            metrics = {
                'mean_confidence': np.mean(confidences),
                'confidence_std': np.std(confidences),
                'sample_size': len(subset)
            }

            # Calculate accuracy if true labels were provided
            if true_labels_subsets is not None and i < len(true_labels_subsets):
                true_labels = true_labels_subsets[i]
                correct = [pred == true for pred, true in zip(predictions, true_labels)]
                metrics['accuracy'] = np.mean(correct)
                metrics['error_rate'] = 1 - metrics['accuracy']

            subgroup_metrics[name] = metrics

        # Create bias analysis visualization
        fig = self._create_bias_visualization(subgroup_metrics, true_labels_subsets is not None)

        # Calculate fairness metrics
        fairness_metrics = self._calculate_fairness_metrics(subgroup_metrics)

        return {
            'figure': fig,
            'subgroup_metrics': subgroup_metrics,
            'fairness_metrics': fairness_metrics
        }

    def _predict_image(self, image):
        """Helper to get top-k predictions from the shared predictor module."""
        from predictor import predict_image
        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_bias_visualization(self, subgroup_metrics, has_accuracy):
        """Create visualization for bias analysis."""
        if has_accuracy:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
        else:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        subgroups = list(subgroup_metrics.keys())

        # Mean confidence by subgroup, annotated on each bar
        confidences = [metrics['mean_confidence'] for metrics in subgroup_metrics.values()]
        ax1.bar(subgroups, confidences, color='lightblue', alpha=0.7)
        ax1.set_ylabel('Mean Confidence')
        ax1.set_title('Mean Confidence by Subgroup', fontweight='bold')
        ax1.tick_params(axis='x', rotation=45)
        ax1.grid(axis='y', alpha=0.3)
        for i, v in enumerate(confidences):
            ax1.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

        # Sample sizes, annotated on each bar
        sample_sizes = [metrics['sample_size'] for metrics in subgroup_metrics.values()]
        ax2.bar(subgroups, sample_sizes, color='lightgreen', alpha=0.7)
        ax2.set_ylabel('Sample Size')
        ax2.set_title('Sample Size by Subgroup', fontweight='bold')
        ax2.tick_params(axis='x', rotation=45)
        ax2.grid(axis='y', alpha=0.3)
        for i, v in enumerate(sample_sizes):
            ax2.text(i, v + max(sample_sizes) * 0.01, f'{v}', ha='center', va='bottom')

        # Accuracy by subgroup (if available), annotated on each bar
        if has_accuracy:
            accuracies = [metrics.get('accuracy', 0) for metrics in subgroup_metrics.values()]
            ax3.bar(subgroups, accuracies, color='lightcoral', alpha=0.7)
            ax3.set_ylabel('Accuracy')
            ax3.set_title('Accuracy by Subgroup', fontweight='bold')
            ax3.tick_params(axis='x', rotation=45)
            ax3.grid(axis='y', alpha=0.3)
            for i, v in enumerate(accuracies):
                ax3.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom')

        plt.tight_layout()
        return fig

    def _calculate_fairness_metrics(self, subgroup_metrics):
        """Calculate fairness metrics across subgroups."""
        fairness_metrics = {}

        # Disparity metrics only make sense with accuracy for 2+ subgroups
        has_accuracy = all('accuracy' in metrics for metrics in subgroup_metrics.values())
        if has_accuracy and len(subgroup_metrics) >= 2:
            accuracies = [metrics['accuracy'] for metrics in subgroup_metrics.values()]
            confidences = [metrics['mean_confidence'] for metrics in subgroup_metrics.values()]
            fairness_metrics = {
                'accuracy_range': float(max(accuracies) - min(accuracies)),
                'accuracy_std': float(np.std(accuracies)),
                'confidence_range': float(max(confidences) - min(confidences)),
                # Ratio of best to worst subgroup accuracy (inf if any subgroup is at 0)
                'max_accuracy_disparity': float(max(accuracies) / min(accuracies)) if min(accuracies) > 0 else float('inf'),
            }

        return fairness_metrics

# Convenience function to create all auditors
def create_auditors(model, processor):
    """Create all auditing analyzers."""
    return {
        'counterfactual': CounterfactualAnalyzer(model, processor),
        'calibration': ConfidenceCalibrationAnalyzer(model, processor),
        'bias': BiasDetector(model, processor)
    }
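
# Minimal smoke-test sketch. Assumptions not defined in this file: a sibling
# predictor module providing predict_image(image, model, processor, top_k),
# a local "example.jpg", and a Hugging Face ViT classifier as the model;
# adjust all three to your actual setup.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    model.eval()

    auditors = create_auditors(model, processor)

    image = Image.open("example.jpg").convert("RGB")
    result = auditors['counterfactual'].patch_perturbation_analysis(
        image, patch_size=32, perturbation_type='blur'
    )
    print(f"Prediction flip rate: {result['prediction_flip_rate']:.2%}")
    result['figure'].savefig("counterfactual_audit.png", dpi=150)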