File size: 5,640 Bytes
dd5a03c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# test_advanced_features.py

import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from model_loader import load_model_and_processor
from auditor import create_auditors, CounterfactualAnalyzer, ConfidenceCalibrationAnalyzer, BiasDetector
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

def create_test_subsets():
    """Build synthetic image subgroups for the bias detection demo.

    Three subgroups are filled with one flat colour each; the fourth
    ('Mixed Colors') varies its colour from image to image.

    Returns:
        A ``(subsets, subset_names)`` tuple: a list of four lists, each
        holding ten 224x224 RGB images, and the parallel list of labels.
    """
    image_size = (224, 224)
    images_per_subset = 10
    subset_names = ['Red Dominant', 'Green Dominant', 'Blue Dominant', 'Mixed Colors']

    # Subgroups that use the same colour for every image.
    solid_colors = {
        'Red Dominant': (200, 50, 50),
        'Green Dominant': (50, 200, 50),
        'Blue Dominant': (50, 50, 200),
    }

    def color_for(label, index):
        fixed = solid_colors.get(label)
        if fixed is not None:
            return fixed
        # Any label without a fixed colour gets a per-image gradient.
        return (50 + index * 20, 100 + index * 10, 150 - index * 15)

    subsets = [
        [
            Image.new('RGB', image_size, color=color_for(label, index))
            for index in range(images_per_subset)
        ]
        for label in subset_names
    ]

    return subsets, subset_names

def test_advanced_features():
    """Run a smoke test of the advanced auditing features.

    Loads the model and processor, creates the auditors, and exercises three
    analyses in turn on synthetic images: counterfactual patch perturbation,
    confidence calibration, and subgroup bias detection. The figure attached
    to each result is then shown and the collected metrics are printed.

    Returns:
        True if every step completed without raising, False otherwise. On
        failure the exception message and traceback are printed instead of
        being propagated to the caller.
    """
    print("🔬 Testing Advanced Auditing Features")
    print("=" * 50)

    try:
        # Load model
        model, processor = load_model_and_processor()

        # Create auditors
        auditors = create_auditors(model, processor)
        print("✅ Auditors created: Counterfactual, Calibration, Bias Detection")

        # Create test image: a flat background with a contrasting square.
        test_image = Image.new('RGB', (224, 224), color=(150, 100, 100))
        # The box is (left, upper, right, lower) with exclusive right/lower
        # edges, so this fills x and y in [50, 150) in a single call.
        test_image.paste((100, 200, 100), (50, 50, 150, 150))

        print("\n1. Testing Counterfactual Analysis...")
        counterfactual_results = auditors['counterfactual'].patch_perturbation_analysis(
            test_image, patch_size=32, perturbation_type='blur'
        )
        print("   ✅ Counterfactual analysis completed")
        print(f"   📊 Avg confidence change: {counterfactual_results['avg_confidence_change']:.4f}")
        print(f"   🔀 Prediction flip rate: {counterfactual_results['prediction_flip_rate']:.2%}")

        print("\n2. Testing Confidence Calibration...")
        # Create dummy test set
        test_images = [test_image] * 5  # Simple test with same image
        calibration_results = auditors['calibration'].analyze_calibration(test_images)
        print("   ✅ Calibration analysis completed")
        print(f"   📈 Mean confidence: {calibration_results['metrics']['mean_confidence']:.3f}")
        print(f"   🎯 Overconfident rate: {calibration_results['metrics']['overconfident_rate']:.2%}")

        print("\n3. Testing Bias Detection...")
        test_subsets, subset_names = create_test_subsets()
        bias_results = auditors['bias'].analyze_subgroup_performance(test_subsets, subset_names)
        print("   ✅ Bias detection analysis completed")
        print(f"   📊 Analyzed {len(subset_names)} subgroups")

        # Display results
        print("\n📊 DISPLAYING ADVANCED ANALYSIS RESULTS:")
        print("=" * 40)

        # Each result carries its figure under the 'figure' key; re-select it
        # by number so the title lands on the right figure before showing it.
        titled_results = (
            (counterfactual_results, "1. Counterfactual Analysis - Patch Sensitivity"),
            (calibration_results, "2. Confidence Calibration Analysis"),
            (bias_results, "3. Bias Detection - Subgroup Analysis"),
        )
        for results, title in titled_results:
            plt.figure(results['figure'].number)
            plt.suptitle(title, fontweight='bold', y=0.98)
            plt.show()

        # Print detailed metrics
        print("\n📈 DETAILED METRICS:")
        print("-" * 20)

        print("\n🎯 Counterfactual Analysis:")
        for key, value in counterfactual_results.items():
            # The figure object is not a metric; skip it.
            if key != 'figure':
                print(f"   {key}: {value}")

        print("\n📊 Calibration Analysis:")
        for key, value in calibration_results['metrics'].items():
            print(f"   {key}: {value}")

        print("\n⚖️ Bias Detection:")
        print("   Subgroup Metrics:")
        for subgroup, metrics in bias_results['subgroup_metrics'].items():
            print(f"     {subgroup}:")
            for metric, value in metrics.items():
                print(f"       {metric}: {value}")

        print("\n🎉 ADVANCED FEATURES SUMMARY:")
        print("=" * 35)
        print("✅ Counterfactual Analysis - Patch Sensitivity")
        print("✅ Confidence Calibration - Reliability Analysis")
        print("✅ Bias Detection - Subgroup Performance")
        print("✅ All advanced auditing features working!")

        return True

    except Exception as e:
        # Top-level boundary of the smoke test: report the failure in full
        # and signal it through the return value.
        print(f"❌ Advanced features test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

if __name__ == "__main__":
    # Script entry point: run the smoke test and report the outcome.
    success = test_advanced_features()

    if success:
        print("\n🚀 All Phase 1 + Advanced Features Complete!")
        print("   Ready for Phase 2: Dashboard Integration!")
    else:
        print("\n⚠️ Some advanced features need debugging")
        # Exit non-zero so shells and CI pipelines can detect the failure.
        sys.exit(1)