File size: 6,470 Bytes
a01dc02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# test_phase1_complete.py

import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

from model_loader import load_model_and_processor, SUPPORTED_MODELS
from predictor import predict_image, create_prediction_plot
from explainer import explain_attention, explain_gradcam, explain_gradient_shap
from utils import preprocess_image, create_comparison_figure, get_top_predictions_dict
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

def _make_test_image():
    """Build a 300x200 RGB test image with two coloured rectangles.

    The background is (150, 75, 75). A green (75, 150, 75) square and a blue
    (75, 75, 150) rectangle give the explainers some spatial structure to
    react to, unlike a single flat colour.
    """
    image = Image.new('RGB', (300, 200), color=(150, 75, 75))
    # paste(colour, box) fills a (left, upper, right, lower) region in one
    # call; right/lower are exclusive, so these boxes cover exactly the pixels
    # x in range(50, 150), y in range(50, 150) and
    # x in range(180, 280), y in range(30, 100).
    image.paste((75, 150, 75), (50, 50, 150, 150))   # Green rectangle
    image.paste((75, 75, 150), (180, 30, 280, 100))   # Blue rectangle
    return image


def _show_titled_figures(titled_figures):
    """Give each figure a bold suptitle, then display them all with one plt.show().

    Args:
        titled_figures: iterable of (title, figure) pairs.

    Why a single show(): plt.show() displays *every* open figure, and with a
    blocking GUI backend the windows are destroyed once the user closes them.
    Calling show() once per figure therefore displayed everything on the first
    call, and each later plt.figure(old_number) created a new, empty figure
    that received only the title.
    """
    for title, fig in titled_figures:
        fig.suptitle(title, fontweight='bold', y=1.02)
    plt.show()


def test_phase1_complete():
    """
    Complete Phase 1 Test - Tests all components together.

    Runs the whole pipeline on a synthetic image: model loading,
    preprocessing, prediction (plus its plot and dict helpers), the attention,
    GradCAM and GradientSHAP explainers, and the comparison figure. It then
    displays every figure with a numbered title and prints a summary.

    Returns:
        bool: True if every step ran without raising; False otherwise (the
        exception message and traceback are printed).

    NOTE(review): every exception is caught and converted into the return
    value, so a pytest run will not report a failure here. The boolean is
    consumed by the ``__main__`` driver.
    """
    print("🧪 ViT Auditing Toolkit - Phase 1 Complete Test")
    print("=" * 50)

    try:
        # Test 1: Model Loading
        print("1. Testing Model Loading...")
        model, processor = load_model_and_processor()
        # NOTE(review): assumes the default model loaded above is the
        # 'ViT-Base' entry of SUPPORTED_MODELS — confirm in model_loader.
        print(f"   ✅ Loaded: {SUPPORTED_MODELS['ViT-Base']}")

        # Test 2: Synthetic image + preprocessing via utils
        print("2. Testing Image Preprocessing...")
        test_image = _make_test_image()
        processed_image = preprocess_image(test_image, target_size=224)
        print(f"   ✅ Original size: {test_image.size}, Processed: {processed_image.size}")

        # Test 3: Prediction Pipeline
        print("3. Testing Prediction Pipeline...")
        probs, _indices, labels = predict_image(processed_image, model, processor, top_k=5)
        pred_fig = create_prediction_plot(probs, labels)
        # Smoke test of the utils helper only; its result is not inspected.
        get_top_predictions_dict(probs, labels)
        print(f"   ✅ Top prediction: {labels[0]} ({probs[0]:.2%})")

        # Test 4: Attention Explanation
        print("4. Testing Attention Visualization...")
        attention_fig = explain_attention(model, processor, processed_image, layer_index=6, head_index=0)
        print("   ✅ Attention visualization generated")

        # Test 5: GradCAM Explanation (figure + overlay image reused in Test 7)
        print("5. Testing GradCAM...")
        gradcam_fig, gradcam_overlay = explain_gradcam(model, processor, processed_image)
        print("   ✅ GradCAM visualization generated")

        # Test 6: GradientSHAP Explanation (few samples to keep the test quick)
        print("6. Testing GradientSHAP...")
        shap_fig = explain_gradient_shap(model, processor, processed_image, n_samples=3)
        print("   ✅ GradientSHAP visualization generated")

        # Test 7: Utils - Comparison Figure
        print("7. Testing Utils - Comparison Figure...")
        comparison_fig = create_comparison_figure(
            processed_image,
            [gradcam_overlay],
            ['GradCAM Overlay']
        )
        print("   ✅ Comparison figure generated")

        # Display Results: title every figure first, then show them together.
        print("\n📊 DISPLAYING RESULTS:")
        print("=" * 30)
        _show_titled_figures([
            ("1. Model Predictions", pred_fig),
            ("2. Attention Visualization", attention_fig),
            ("3. GradCAM Explanation", gradcam_fig),
            ("4. GradientSHAP Explanation", shap_fig),
            ("5. Comparison View", comparison_fig),
        ])

        # Summary
        print("\n🎉 PHASE 1 COMPLETE SUMMARY:")
        print("=" * 35)
        print("✅ Model Loading & Preprocessing")
        print("✅ Prediction Pipeline with Visualization")
        print("✅ Attention Visualization")
        print("✅ GradCAM Explanations")
        print("✅ GradientSHAP Explanations")
        print("✅ Utility Functions")
        print("✅ All components integrated successfully!")
        print("\n🚀 Ready for Phase 2: Dashboard Integration!")

        return True

    except Exception as e:
        print(f"\n❌ Phase 1 Test Failed: {e}")
        import traceback
        traceback.print_exc()
        return False

def test_individual_components():
    """
    Test individual components for debugging.

    Calls each Phase 1 component once on a plain 224x224 red image and prints
    "PASS" after each call returns without raising. The first exception stops
    the remaining steps. Return values are not inspected beyond that.

    Figures created here are never displayed, so they are closed before
    returning to avoid accumulating open matplotlib figures.

    Returns:
        bool: True if every step ran without raising; False otherwise (the
        exception message and traceback are printed). Callers that ignore
        the result are unaffected.
    """
    print("\n🔧 Individual Component Tests:")
    print("-" * 30)

    figures = []  # Closed in `finally`, whether or not a step failed.
    try:
        # Test model loading
        model, processor = load_model_and_processor()
        print("✅ Model loading: PASS")

        # Test image creation (a flat colour is enough to exercise each path)
        test_img = Image.new('RGB', (224, 224), color='red')
        print("✅ Image creation: PASS")

        # Test prediction
        predict_image(test_img, model, processor)
        print("✅ Prediction: PASS")

        # Test attention
        figures.append(explain_attention(model, processor, test_img))
        print("✅ Attention: PASS")

        # Test GradCAM (the overlay image is not needed here)
        gc_fig, _gc_overlay = explain_gradcam(model, processor, test_img)
        figures.append(gc_fig)
        print("✅ GradCAM: PASS")

        # Test SHAP
        figures.append(explain_gradient_shap(model, processor, test_img, n_samples=2))
        print("✅ GradientSHAP: PASS")

        # Test utils (imported here so an ImportError is reported as a failure)
        from utils import normalize_heatmap
        test_heatmap = np.random.rand(10, 10)
        normalize_heatmap(test_heatmap)
        print("✅ Utils: PASS")

    except Exception as e:
        print(f"❌ Component test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        for fig in figures:
            plt.close(fig)

    print("\n🎉 All individual components working!")
    return True

if __name__ == "__main__":
    # Run the end-to-end pipeline test first. The per-component tests always
    # follow: as a quick recheck when it passes, or to localise the problem
    # when it fails (in which case a warning is printed beforehand).
    phase1_ok = test_phase1_complete()
    if not phase1_ok:
        print("\n⚠️  Running individual component tests for debugging...")
    test_individual_components()