Dyuti Dasmahapatra commited on
Commit
a01dc02
Β·
1 Parent(s): 0561f5e

complete Phase 1 - core ViT auditing toolkit implementation

Browse files
requirements.txt CHANGED
@@ -1,16 +1,10 @@
1
- # Core ML & Data
2
- torch>=2.0.0
3
- torchvision>=0.15.0
4
- transformers>=4.30.0
5
- numpy
6
-
7
- # Explainable AI
8
- captum>=0.6.0
9
-
10
- # Dashboard & Visualization
11
- gradio>=3.40.0
12
- Pillow>=9.0.0
13
  matplotlib>=3.7.0
14
-
15
- # Utilities (for potential image handling)
16
- requests>=2.25.0
 
1
+ # requirements.txt - UPDATED FOR CURRENT PYTORCH VERSIONS
2
+ torch>=2.2.0,<2.9.0
3
+ torchvision>=0.17.0,<0.19.0
4
+ transformers>=4.35.0
5
+ captum>=0.7.0
6
+ gradio>=4.19.0
7
+ Pillow>=10.0.0
 
 
 
 
 
8
  matplotlib>=3.7.0
9
+ numpy>=1.24.0
10
+ requests>=2.28.0
 
src/explainer.py CHANGED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/explainer.py
2
+
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import captum
8
+ from captum.attr import LayerGradCam, GradientShap
9
+ from captum.attr import visualization as viz
10
+ import torch.nn.functional as F
11
+
12
+ class ViTWrapper(torch.nn.Module):
13
+ """
14
+ Wrapper class to make Hugging Face ViT compatible with Captum.
15
+ This returns raw tensors instead of Hugging Face output objects.
16
+ """
17
+ def __init__(self, model):
18
+ super().__init__()
19
+ self.model = model
20
+
21
+ def forward(self, x):
22
+ # Hugging Face models expect pixel_values key
23
+ outputs = self.model(pixel_values=x)
24
+ return outputs.logits
25
+
26
+ class AttentionHook:
27
+ """Hook to capture attention weights from ViT model"""
28
+ def __init__(self):
29
+ self.attention_weights = None
30
+
31
+ def __call__(self, module, input, output):
32
+ # For ViT, attention weights are usually the second output
33
+ if len(output) >= 2:
34
+ self.attention_weights = output[1] # attention weights
35
+ else:
36
+ self.attention_weights = None
37
+
38
+ def explain_attention(model, processor, image, layer_index=6, head_index=0):
39
+ """
40
+ Extract and visualize attention weights using hooks.
41
+ """
42
+ try:
43
+ device = next(model.parameters()).device
44
+
45
+ # Preprocess image
46
+ inputs = processor(images=image, return_tensors="pt")
47
+ inputs = {k: v.to(device) for k, v in inputs.items()}
48
+
49
+ # Register hook to capture attention
50
+ hook = AttentionHook()
51
+
52
+ # Try different layer access patterns
53
+ try:
54
+ # For standard ViT structure
55
+ target_layer = model.vit.encoder.layer[layer_index].attention.attention
56
+ handle = target_layer.register_forward_hook(hook)
57
+ except:
58
+ try:
59
+ # Alternative structure
60
+ target_layer = model.vit.encoder.layers[layer_index].attention.attention
61
+ handle = target_layer.register_forward_hook(hook)
62
+ except:
63
+ raise ValueError(f"Could not access layer {layer_index} for attention hook")
64
+
65
+ # Forward pass to capture attention
66
+ with torch.no_grad():
67
+ _ = model(**inputs)
68
+
69
+ # Remove hook
70
+ handle.remove()
71
+
72
+ if hook.attention_weights is None:
73
+ raise ValueError("No attention weights captured by hook")
74
+
75
+ # Get attention weights
76
+ attention_weights = hook.attention_weights # Shape: (batch, heads, seq_len, seq_len)
77
+ attention_map = attention_weights[0, head_index] # Shape: (seq_len, seq_len)
78
+
79
+ # Remove CLS token attention to other tokens
80
+ patch_attention = attention_map[1:, 1:] # Remove CLS token rows and columns
81
+
82
+ # Create visualization
83
+ fig, ax = plt.subplots(figsize=(8, 6))
84
+
85
+ # Display attention matrix
86
+ im = ax.imshow(patch_attention.cpu().numpy(), cmap='viridis', aspect='auto')
87
+
88
+ ax.set_title(f'Attention Map - Layer {layer_index}, Head {head_index}', fontsize=14, fontweight='bold')
89
+ ax.set_xlabel('Key Patches')
90
+ ax.set_ylabel('Query Patches')
91
+
92
+ # Add colorbar
93
+ plt.colorbar(im, ax=ax)
94
+
95
+ plt.tight_layout()
96
+ return fig
97
+
98
+ except Exception as e:
99
+ print(f"Error in attention visualization: {str(e)}")
100
+ # Return a simple error plot
101
+ fig, ax = plt.subplots(figsize=(8, 6))
102
+ ax.text(0.5, 0.5, f"Attention visualization failed:\n{str(e)}",
103
+ ha='center', va='center', transform=ax.transAxes, fontsize=10)
104
+ ax.set_title('Attention Visualization Error')
105
+ return fig
106
+
107
+ def explain_gradcam(model, processor, image, target_layer_index=-2):
108
+ """
109
+ Generate GradCAM heatmap for the predicted class.
110
+ """
111
+ try:
112
+ device = next(model.parameters()).device
113
+
114
+ # Preprocess image
115
+ inputs = processor(images=image, return_tensors="pt")
116
+ input_tensor = inputs['pixel_values'].to(device)
117
+
118
+ # Get prediction
119
+ with torch.no_grad():
120
+ outputs = model(input_tensor)
121
+ predicted_class = outputs.logits.argmax(dim=1).item()
122
+
123
+ # Get the target layer
124
+ try:
125
+ target_layer = model.vit.encoder.layer[target_layer_index].attention.attention
126
+ except:
127
+ target_layer = model.vit.encoder.layers[target_layer_index].attention.attention
128
+
129
+ # Create wrapped model for Captum compatibility
130
+ wrapped_model = ViTWrapper(model)
131
+
132
+ # Initialize GradCAM with wrapped model
133
+ gradcam = LayerGradCam(wrapped_model, target_layer)
134
+
135
+ # Generate attribution - handle tuple output
136
+ attribution = gradcam.attribute(input_tensor, target=predicted_class)
137
+
138
+ # FIX: Handle tuple output by taking the first element
139
+ if isinstance(attribution, tuple):
140
+ attribution = attribution[0]
141
+
142
+ # Convert attribution to heatmap
143
+ attribution = attribution.squeeze().cpu().detach().numpy()
144
+
145
+ # Normalize attribution
146
+ if attribution.max() > attribution.min():
147
+ attribution = (attribution - attribution.min()) / (attribution.max() - attribution.min())
148
+ else:
149
+ attribution = np.zeros_like(attribution)
150
+
151
+ # Resize heatmap to match original image
152
+ original_size = image.size
153
+ heatmap = Image.fromarray((attribution * 255).astype(np.uint8))
154
+ heatmap = heatmap.resize(original_size, Image.Resampling.LANCZOS)
155
+ heatmap = np.array(heatmap)
156
+
157
+ # Create visualization figure
158
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
159
+
160
+ # Original image
161
+ ax1.imshow(image)
162
+ ax1.set_title('Original Image')
163
+ ax1.axis('off')
164
+
165
+ # Heatmap
166
+ ax2.imshow(heatmap, cmap='hot')
167
+ ax2.set_title('GradCAM Heatmap')
168
+ ax2.axis('off')
169
+
170
+ # Overlay
171
+ ax3.imshow(image)
172
+ ax3.imshow(heatmap, cmap='hot', alpha=0.5)
173
+ ax3.set_title('Overlay')
174
+ ax3.axis('off')
175
+
176
+ plt.tight_layout()
177
+
178
+ # Create overlay image for dashboard
179
+ heatmap_rgb = (plt.cm.hot(heatmap / 255.0)[:, :, :3] * 255).astype(np.uint8)
180
+ overlay_img = Image.fromarray(heatmap_rgb)
181
+ overlay_img = overlay_img.resize(original_size, Image.Resampling.LANCZOS)
182
+
183
+ # Blend with original
184
+ original_rgba = image.convert('RGBA')
185
+ overlay_rgba = overlay_img.convert('RGBA')
186
+ blended = Image.blend(original_rgba, overlay_rgba, alpha=0.5)
187
+
188
+ return fig, blended.convert('RGB')
189
+
190
+ except Exception as e:
191
+ print(f"Error in GradCAM: {str(e)}")
192
+ fig, ax = plt.subplots(figsize=(8, 6))
193
+ ax.text(0.5, 0.5, f"GradCAM failed:\n{str(e)}",
194
+ ha='center', va='center', transform=ax.transAxes, fontsize=10)
195
+ ax.set_title('GradCAM Error')
196
+ return fig, image
197
+
198
+ def explain_gradient_shap(model, processor, image, n_samples=5):
199
+ """
200
+ Generate GradientSHAP explanations.
201
+ """
202
+ try:
203
+ device = next(model.parameters()).device
204
+
205
+ # Preprocess image
206
+ inputs = processor(images=image, return_tensors="pt")
207
+ input_tensor = inputs['pixel_values'].to(device)
208
+
209
+ # Get prediction
210
+ with torch.no_grad():
211
+ outputs = model(input_tensor)
212
+ predicted_class = outputs.logits.argmax(dim=1).item()
213
+
214
+ # Create baseline (black image)
215
+ baseline = torch.zeros_like(input_tensor)
216
+
217
+ # Create wrapped model for Captum compatibility
218
+ wrapped_model = ViTWrapper(model)
219
+
220
+ # Initialize GradientSHAP with wrapped model
221
+ gradient_shap = GradientShap(wrapped_model)
222
+
223
+ # Generate attribution
224
+ attribution = gradient_shap.attribute(
225
+ input_tensor,
226
+ baselines=baseline,
227
+ n_samples=n_samples,
228
+ target=predicted_class
229
+ )
230
+
231
+ # Summarize attribution across channels
232
+ attribution = attribution.squeeze().sum(dim=0).cpu().detach().numpy()
233
+
234
+ # Normalize
235
+ if attribution.max() > attribution.min():
236
+ attribution = (attribution - attribution.min()) / (attribution.max() - attribution.min())
237
+ else:
238
+ attribution = np.zeros_like(attribution)
239
+
240
+ # Create visualization
241
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
242
+
243
+ # Original image
244
+ ax1.imshow(image)
245
+ ax1.set_title('Original Image')
246
+ ax1.axis('off')
247
+
248
+ # SHAP attribution
249
+ im = ax2.imshow(attribution, cmap='coolwarm')
250
+ ax2.set_title('GradientSHAP Attribution')
251
+ ax2.axis('off')
252
+ plt.colorbar(im, ax=ax2)
253
+
254
+ # Overlay
255
+ ax3.imshow(image, alpha=0.7)
256
+ im_overlay = ax3.imshow(attribution, cmap='coolwarm', alpha=0.5)
257
+ ax3.set_title('Attribution Overlay')
258
+ ax3.axis('off')
259
+ plt.colorbar(im_overlay, ax=ax3)
260
+
261
+ plt.tight_layout()
262
+ return fig
263
+
264
+ except Exception as e:
265
+ print(f"Error in GradientSHAP: {str(e)}")
266
+ fig, ax = plt.subplots(figsize=(8, 6))
267
+ ax.text(0.5, 0.5, f"GradientSHAP failed:\n{str(e)}",
268
+ ha='center', va='center', transform=ax.transAxes, fontsize=10)
269
+ ax.set_title('GradientSHAP Error')
270
+ return fig
src/model_loader.py CHANGED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/model_loader.py
2
+
3
+ from transformers import ViTImageProcessor, ViTForImageClassification
4
+ import torch
5
+
6
+ def load_model_and_processor(model_name="google/vit-base-patch16-224"):
7
+ """
8
+ Load a Vision Transformer model and its corresponding processor from Hugging Face.
9
+ """
10
+ try:
11
+ print(f"Loading model {model_name}...")
12
+
13
+ # Load processor and model with eager attention implementation
14
+ processor = ViTImageProcessor.from_pretrained(model_name)
15
+
16
+ # Force eager attention implementation to get attention weights
17
+ model = ViTForImageClassification.from_pretrained(
18
+ model_name,
19
+ attn_implementation="eager" # This enables attention output
20
+ )
21
+
22
+ # Now we can safely set output_attentions
23
+ model.config.output_attentions = True
24
+
25
+ # Set device
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = model.to(device)
28
+
29
+ # Set model to evaluation mode
30
+ model.eval()
31
+
32
+ print(f"βœ… Model and processor loaded successfully on {device}!")
33
+ print(f" Using attention implementation: {model.config._attn_implementation}")
34
+ return model, processor
35
+
36
+ except Exception as e:
37
+ print(f"Error loading model {model_name}: {str(e)}")
38
+ raise
39
+
40
+ # Supported models
41
+ SUPPORTED_MODELS = {
42
+ "ViT-Base": "google/vit-base-patch16-224",
43
+ "ViT-Large": "google/vit-large-patch16-224",
44
+ }
src/predictor.py CHANGED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/predictor.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ def predict_image(image, model, processor, top_k=5):
10
+ """
11
+ Perform inference on an image and return top-k predictions.
12
+
13
+ Args:
14
+ image (PIL.Image): Input image to classify.
15
+ model: Loaded ViT model.
16
+ processor: Loaded ViT processor.
17
+ top_k (int): Number of top predictions to return.
18
+
19
+ Returns:
20
+ tuple: (top_probs, top_indices, top_labels) - Probabilities, class indices, and label names.
21
+ """
22
+ try:
23
+ # Get the device from the model
24
+ device = next(model.parameters()).device
25
+
26
+ # Preprocess the image - note: current processors return pixel_values
27
+ inputs = processor(images=image, return_tensors="pt")
28
+ inputs = {k: v.to(device) for k, v in inputs.items()}
29
+
30
+ # Perform inference
31
+ with torch.no_grad():
32
+ outputs = model(**inputs)
33
+ logits = outputs.logits
34
+
35
+ # Apply softmax to get probabilities
36
+ probabilities = F.softmax(logits, dim=-1)[0]
37
+
38
+ # Get top-k predictions
39
+ top_probs, top_indices = torch.topk(probabilities, top_k)
40
+
41
+ # Convert to Python lists and numpy arrays
42
+ top_probs = top_probs.cpu().numpy()
43
+ top_indices = top_indices.cpu().numpy()
44
+
45
+ # Get human-readable labels
46
+ top_labels = [model.config.id2label[idx] for idx in top_indices]
47
+
48
+ return top_probs, top_indices, top_labels
49
+
50
+ except Exception as e:
51
+ print(f"Error during prediction: {str(e)}")
52
+ raise
53
+
54
+ def create_prediction_plot(probs, labels):
55
+ """
56
+ Create a clean, professional bar chart for top predictions.
57
+
58
+ Args:
59
+ probs (np.array): Array of probabilities.
60
+ labels (list): List of label names.
61
+
62
+ Returns:
63
+ matplotlib.figure.Figure: The generated plot figure.
64
+ """
65
+ fig, ax = plt.subplots(figsize=(8, 4))
66
+
67
+ # Create horizontal bar chart
68
+ y_pos = np.arange(len(labels))
69
+ bars = ax.barh(y_pos, probs, color='skyblue', alpha=0.8)
70
+ ax.set_yticks(y_pos)
71
+ ax.set_yticklabels(labels, fontsize=10)
72
+ ax.set_xlabel('Confidence', fontsize=12)
73
+ ax.set_title('Top Predictions', fontsize=14, fontweight='bold')
74
+
75
+ # Add probability text on bars
76
+ for i, (bar, prob) in enumerate(zip(bars, probs)):
77
+ width = bar.get_width()
78
+ ax.text(width + 0.01, bar.get_y() + bar.get_height()/2,
79
+ f'{prob:.2%}', va='center', fontsize=9)
80
+
81
+ # Set x-axis limit and style
82
+ ax.set_xlim(0, max(probs) * 1.15) # Add some padding for text
83
+ ax.grid(axis='x', alpha=0.3, linestyle='--')
84
+
85
+ plt.tight_layout()
86
+ return fig
src/utils.py CHANGED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/utils.py
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ import torch
7
+
8
+ def preprocess_image(image, target_size=224):
9
+ """
10
+ Preprocess image for ViT model.
11
+
12
+ Args:
13
+ image: PIL Image or file path
14
+ target_size: Target size for resizing
15
+
16
+ Returns:
17
+ PIL.Image: Preprocessed image
18
+ """
19
+ if isinstance(image, str):
20
+ # If it's a file path, load the image
21
+ image = Image.open(image)
22
+
23
+ # Convert to RGB if necessary
24
+ if image.mode != 'RGB':
25
+ image = image.convert('RGB')
26
+
27
+ # Resize image
28
+ image = image.resize((target_size, target_size))
29
+
30
+ return image
31
+
32
+ def normalize_heatmap(heatmap):
33
+ """
34
+ Normalize heatmap to [0, 1] range.
35
+
36
+ Args:
37
+ heatmap: numpy array of heatmap values
38
+
39
+ Returns:
40
+ numpy.array: Normalized heatmap
41
+ """
42
+ if heatmap.max() > heatmap.min():
43
+ return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
44
+ else:
45
+ return np.zeros_like(heatmap)
46
+
47
+ def overlay_heatmap(image, heatmap, alpha=0.5, colormap='hot'):
48
+ """
49
+ Overlay heatmap on original image.
50
+
51
+ Args:
52
+ image: PIL Image
53
+ heatmap: numpy array of heatmap values
54
+ alpha: Transparency for heatmap overlay
55
+ colormap: Matplotlib colormap name
56
+
57
+ Returns:
58
+ PIL.Image: Image with heatmap overlay
59
+ """
60
+ # Normalize heatmap
61
+ heatmap = normalize_heatmap(heatmap)
62
+
63
+ # Convert heatmap to RGB using colormap
64
+ cmap = plt.get_cmap(colormap)
65
+ heatmap_rgb = (cmap(heatmap)[:, :, :3] * 255).astype(np.uint8)
66
+
67
+ # Resize heatmap to match image size
68
+ heatmap_img = Image.fromarray(heatmap_rgb)
69
+ heatmap_img = heatmap_img.resize(image.size, Image.Resampling.LANCZOS)
70
+
71
+ # Blend images
72
+ original_rgba = image.convert('RGBA')
73
+ heatmap_rgba = heatmap_img.convert('RGBA')
74
+ blended = Image.blend(original_rgba, heatmap_rgba, alpha)
75
+
76
+ return blended.convert('RGB')
77
+
78
+ def create_comparison_figure(original_image, explanation_images, explanation_titles):
79
+ """
80
+ Create a comparison figure showing original image and multiple explanations.
81
+
82
+ Args:
83
+ original_image: PIL Image
84
+ explanation_images: List of explanation images
85
+ explanation_titles: List of titles for each explanation
86
+
87
+ Returns:
88
+ matplotlib.figure.Figure: Comparison figure
89
+ """
90
+ num_explanations = len(explanation_images)
91
+ fig, axes = plt.subplots(1, num_explanations + 1, figsize=(4 * (num_explanations + 1), 4))
92
+
93
+ # Plot original image
94
+ axes[0].imshow(original_image)
95
+ axes[0].set_title('Original Image', fontweight='bold')
96
+ axes[0].axis('off')
97
+
98
+ # Plot explanations
99
+ for i, (exp_img, title) in enumerate(zip(explanation_images, explanation_titles)):
100
+ axes[i + 1].imshow(exp_img)
101
+ axes[i + 1].set_title(title, fontweight='bold')
102
+ axes[i + 1].axis('off')
103
+
104
+ plt.tight_layout()
105
+ return fig
106
+
107
+ def tensor_to_image(tensor):
108
+ """
109
+ Convert PyTorch tensor to PIL Image.
110
+
111
+ Args:
112
+ tensor: PyTorch tensor of shape (C, H, W) or (B, C, H, W)
113
+
114
+ Returns:
115
+ PIL.Image: Converted image
116
+ """
117
+ if tensor.dim() == 4:
118
+ tensor = tensor.squeeze(0)
119
+
120
+ # Denormalize if needed and convert to numpy
121
+ tensor = tensor.cpu().detach()
122
+ if tensor.min() < 0 or tensor.max() > 1:
123
+ # Assume it's normalized, denormalize to [0, 1]
124
+ tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
125
+
126
+ numpy_image = tensor.permute(1, 2, 0).numpy()
127
+ numpy_image = (numpy_image * 255).astype(np.uint8)
128
+
129
+ return Image.fromarray(numpy_image)
130
+
131
+ def get_top_predictions_dict(probs, labels, top_k=5):
132
+ """
133
+ Convert top predictions to dictionary for Gradio Label component.
134
+
135
+ Args:
136
+ probs: Array of probabilities
137
+ labels: List of label names
138
+ top_k: Number of top predictions to include
139
+
140
+ Returns:
141
+ dict: Dictionary of {label: probability} for top-k predictions
142
+ """
143
+ return {label: float(prob) for label, prob in zip(labels[:top_k], probs[:top_k])}
tests/test_phase1_complete.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # test_phase1_complete.py
2
+
3
+ import sys
4
+ import os
5
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
6
+
7
+ from model_loader import load_model_and_processor, SUPPORTED_MODELS
8
+ from predictor import predict_image, create_prediction_plot
9
+ from explainer import explain_attention, explain_gradcam, explain_gradient_shap
10
+ from utils import preprocess_image, create_comparison_figure, get_top_predictions_dict
11
+ from PIL import Image
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+
15
+ def test_phase1_complete():
16
+ """
17
+ Complete Phase 1 Test - Tests all components together.
18
+ """
19
+ print("πŸ§ͺ ViT Auditing Toolkit - Phase 1 Complete Test")
20
+ print("=" * 50)
21
+
22
+ try:
23
+ # Test 1: Model Loading
24
+ print("1. Testing Model Loading...")
25
+ model, processor = load_model_and_processor()
26
+ print(f" βœ… Loaded: {SUPPORTED_MODELS['ViT-Base']}")
27
+
28
+ # Test 2: Create test image using utils
29
+ print("2. Testing Image Preprocessing...")
30
+ # Create a more realistic test image
31
+ test_image = Image.new('RGB', (300, 200), color=(150, 75, 75))
32
+ # Add different colored regions
33
+ for x in range(50, 150):
34
+ for y in range(50, 150):
35
+ test_image.putpixel((x, y), (75, 150, 75)) # Green rectangle
36
+ for x in range(180, 280):
37
+ for y in range(30, 100):
38
+ test_image.putpixel((x, y), (75, 75, 150)) # Blue rectangle
39
+
40
+ # Preprocess using utils
41
+ processed_image = preprocess_image(test_image, target_size=224)
42
+ print(f" βœ… Original size: {test_image.size}, Processed: {processed_image.size}")
43
+
44
+ # Test 3: Prediction Pipeline
45
+ print("3. Testing Prediction Pipeline...")
46
+ probs, indices, labels = predict_image(processed_image, model, processor, top_k=5)
47
+ pred_fig = create_prediction_plot(probs, labels)
48
+
49
+ # Test utils function
50
+ pred_dict = get_top_predictions_dict(probs, labels)
51
+ print(f" βœ… Top prediction: {labels[0]} ({probs[0]:.2%})")
52
+
53
+ # Test 4: Attention Explanation
54
+ print("4. Testing Attention Visualization...")
55
+ attention_fig = explain_attention(model, processor, processed_image, layer_index=6, head_index=0)
56
+ print(" βœ… Attention visualization generated")
57
+
58
+ # Test 5: GradCAM Explanation
59
+ print("5. Testing GradCAM...")
60
+ gradcam_fig, gradcam_overlay = explain_gradcam(model, processor, processed_image)
61
+ print(" βœ… GradCAM visualization generated")
62
+
63
+ # Test 6: GradientSHAP Explanation
64
+ print("6. Testing GradientSHAP...")
65
+ shap_fig = explain_gradient_shap(model, processor, processed_image, n_samples=3)
66
+ print(" βœ… GradientSHAP visualization generated")
67
+
68
+ # Test 7: Utils - Comparison Figure
69
+ print("7. Testing Utils - Comparison Figure...")
70
+ comparison_fig = create_comparison_figure(
71
+ processed_image,
72
+ [gradcam_overlay],
73
+ ['GradCAM Overlay']
74
+ )
75
+ print(" βœ… Comparison figure generated")
76
+
77
+ # Display Results
78
+ print("\nπŸ“Š DISPLAYING RESULTS:")
79
+ print("=" * 30)
80
+
81
+ # Show prediction results
82
+ plt.figure(pred_fig.number)
83
+ plt.suptitle("1. Model Predictions", fontweight='bold', y=1.02)
84
+ plt.show()
85
+
86
+ # Show attention results
87
+ plt.figure(attention_fig.number)
88
+ plt.suptitle("2. Attention Visualization", fontweight='bold', y=1.02)
89
+ plt.show()
90
+
91
+ # Show GradCAM results
92
+ plt.figure(gradcam_fig.number)
93
+ plt.suptitle("3. GradCAM Explanation", fontweight='bold', y=1.02)
94
+ plt.show()
95
+
96
+ # Show SHAP results
97
+ plt.figure(shap_fig.number)
98
+ plt.suptitle("4. GradientSHAP Explanation", fontweight='bold', y=1.02)
99
+ plt.show()
100
+
101
+ # Show comparison
102
+ plt.figure(comparison_fig.number)
103
+ plt.suptitle("5. Comparison View", fontweight='bold', y=1.02)
104
+ plt.show()
105
+
106
+ # Summary
107
+ print("\nπŸŽ‰ PHASE 1 COMPLETE SUMMARY:")
108
+ print("=" * 35)
109
+ print("βœ… Model Loading & Preprocessing")
110
+ print("βœ… Prediction Pipeline with Visualization")
111
+ print("βœ… Attention Visualization")
112
+ print("βœ… GradCAM Explanations")
113
+ print("βœ… GradientSHAP Explanations")
114
+ print("βœ… Utility Functions")
115
+ print(f"βœ… All components integrated successfully!")
116
+ print("\nπŸš€ Ready for Phase 2: Dashboard Integration!")
117
+
118
+ return True
119
+
120
+ except Exception as e:
121
+ print(f"\n❌ Phase 1 Test Failed: {e}")
122
+ import traceback
123
+ traceback.print_exc()
124
+ return False
125
+
126
+ def test_individual_components():
127
+ """
128
+ Test individual components for debugging.
129
+ """
130
+ print("\nπŸ”§ Individual Component Tests:")
131
+ print("-" * 30)
132
+
133
+ try:
134
+ # Test model loading
135
+ model, processor = load_model_and_processor()
136
+ print("βœ… Model loading: PASS")
137
+
138
+ # Test image creation
139
+ test_img = Image.new('RGB', (224, 224), color='red')
140
+ print("βœ… Image creation: PASS")
141
+
142
+ # Test prediction
143
+ probs, indices, labels = predict_image(test_img, model, processor)
144
+ print("βœ… Prediction: PASS")
145
+
146
+ # Test attention
147
+ attn_fig = explain_attention(model, processor, test_img)
148
+ print("βœ… Attention: PASS")
149
+
150
+ # Test GradCAM
151
+ gc_fig, gc_img = explain_gradcam(model, processor, test_img)
152
+ print("βœ… GradCAM: PASS")
153
+
154
+ # Test SHAP
155
+ shap_fig = explain_gradient_shap(model, processor, test_img, n_samples=2)
156
+ print("βœ… GradientSHAP: PASS")
157
+
158
+ # Test utils
159
+ from utils import normalize_heatmap
160
+ test_heatmap = np.random.rand(10, 10)
161
+ normalized = normalize_heatmap(test_heatmap)
162
+ print("βœ… Utils: PASS")
163
+
164
+ print("\nπŸŽ‰ All individual components working!")
165
+
166
+ except Exception as e:
167
+ print(f"❌ Component test failed: {e}")
168
+
169
+ if __name__ == "__main__":
170
+ # Run complete test
171
+ success = test_phase1_complete()
172
+
173
+ if success:
174
+ # Run quick individual tests
175
+ test_individual_components()
176
+ else:
177
+ print("\n⚠️ Running individual component tests for debugging...")
178
+ test_individual_components()