# ViT-Auditing-Toolkit / src/explainer.py
# Author: Dyuti Dasmahapatra
# Phase 1 - core ViT auditing toolkit implementation
# src/explainer.py
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import captum
from captum.attr import LayerGradCam, GradientShap
from captum.attr import visualization as viz
import torch.nn.functional as F
class ViTWrapper(torch.nn.Module):
    """
    Adapter that exposes a Hugging Face ViT model as a plain tensor-in /
    tensor-out module so Captum attribution methods can consume it.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on `x` and return only the raw logits tensor."""
        # HF image models take their input under the `pixel_values` keyword
        # and return an output object; Captum needs the bare logits.
        return self.model(pixel_values=x).logits
class AttentionHook:
    """Forward hook that records the attention tensor emitted by a ViT
    self-attention module during a forward pass."""

    def __init__(self):
        # Most recently captured attention tensor; None until a forward
        # pass produces one (or if the module output carries no weights).
        self.attention_weights = None

    def __call__(self, module, input, output):
        # HF ViT self-attention modules return a tuple whose second element
        # holds the attention weights when they are available.
        self.attention_weights = output[1] if len(output) >= 2 else None
def explain_attention(model, processor, image, layer_index=6, head_index=0):
    """
    Extract and visualize attention weights of one ViT layer/head via a hook.

    Args:
        model: Hugging Face ViT classification model.
        processor: matching image processor / feature extractor.
        image: PIL image to explain.
        layer_index: encoder layer whose attention is captured.
        head_index: attention head to display.

    Returns:
        A matplotlib Figure showing the patch-to-patch attention matrix, or
        an error figure describing the failure (this function never raises).
    """
    try:
        device = next(model.parameters()).device

        # Preprocess image and move all tensors to the model's device.
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Resolve the self-attention submodule; narrow excepts so unrelated
        # errors propagate to the outer handler instead of being retried.
        try:
            # Standard HF ViT structure: model.vit.encoder.layer[i]
            target_layer = model.vit.encoder.layer[layer_index].attention.attention
        except (AttributeError, IndexError):
            try:
                # Alternative naming used by some ports: ...encoder.layers[i]
                target_layer = model.vit.encoder.layers[layer_index].attention.attention
            except (AttributeError, IndexError):
                raise ValueError(f"Could not access layer {layer_index} for attention hook")

        # Register hook to capture attention weights during the forward pass.
        hook = AttentionHook()
        handle = target_layer.register_forward_hook(hook)
        try:
            with torch.no_grad():
                # output_attentions=True makes the self-attention module also
                # return its attention probabilities; without it the hooked
                # module yields a length-1 tuple and nothing is captured.
                _ = model(**inputs, output_attentions=True)
        finally:
            # Always detach the hook, even if inference fails.
            handle.remove()

        if hook.attention_weights is None:
            raise ValueError("No attention weights captured by hook")

        # Shape: (batch, heads, seq_len, seq_len)
        attention_weights = hook.attention_weights
        attention_map = attention_weights[0, head_index]  # (seq_len, seq_len)

        # Drop the CLS token row and column so only patch-to-patch
        # attention remains.
        patch_attention = attention_map[1:, 1:]

        # Render the attention matrix as a heatmap.
        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(patch_attention.cpu().numpy(), cmap='viridis', aspect='auto')
        ax.set_title(f'Attention Map - Layer {layer_index}, Head {head_index}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Key Patches')
        ax.set_ylabel('Query Patches')
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        return fig
    except Exception as e:
        # Boundary handler: surface the failure inside a figure so the
        # dashboard always has something to render.
        print(f"Error in attention visualization: {str(e)}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Attention visualization failed:\n{str(e)}",
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title('Attention Visualization Error')
        return fig
def explain_gradcam(model, processor, image, target_layer_index=-2):
    """
    Generate a GradCAM heatmap for the model's predicted class.

    Args:
        model: Hugging Face ViT classification model.
        processor: matching image processor.
        image: PIL image to explain.
        target_layer_index: encoder layer used as the GradCAM target
            (negative values index from the end).

    Returns:
        (fig, overlay) where fig is a three-panel matplotlib Figure
        (original, heatmap, overlay) and overlay is a PIL RGB image blending
        the heatmap onto the input. On failure, returns an error figure and
        the unmodified input image (this function never raises).
    """
    try:
        device = next(model.parameters()).device

        # Preprocess the image into the model's expected tensor format.
        inputs = processor(images=image, return_tensors="pt")
        input_tensor = inputs['pixel_values'].to(device)

        # Use the top predicted class as the attribution target.
        with torch.no_grad():
            outputs = model(input_tensor)
            predicted_class = outputs.logits.argmax(dim=1).item()

        # Resolve the target layer; narrow excepts so unrelated errors
        # propagate to the outer handler instead of being silently retried.
        try:
            # Standard HF ViT structure: model.vit.encoder.layer[i]
            target_layer = model.vit.encoder.layer[target_layer_index].attention.attention
        except (AttributeError, IndexError):
            # Alternative naming used by some ports: ...encoder.layers[i]
            target_layer = model.vit.encoder.layers[target_layer_index].attention.attention

        # Captum needs a plain tensor-returning module, hence the wrapper.
        wrapped_model = ViTWrapper(model)
        gradcam = LayerGradCam(wrapped_model, target_layer)

        attribution = gradcam.attribute(input_tensor, target=predicted_class)
        # Some Captum versions return a tuple of attributions; use the first.
        if isinstance(attribution, tuple):
            attribution = attribution[0]

        attribution = attribution.squeeze().cpu().detach().numpy()

        # Min-max normalize to [0, 1]; a constant map becomes all zeros.
        if attribution.max() > attribution.min():
            attribution = (attribution - attribution.min()) / (attribution.max() - attribution.min())
        else:
            attribution = np.zeros_like(attribution)

        # Upsample the coarse attribution map to the original image size.
        original_size = image.size
        heatmap = Image.fromarray((attribution * 255).astype(np.uint8))
        heatmap = heatmap.resize(original_size, Image.Resampling.LANCZOS)
        heatmap = np.array(heatmap)

        # Three-panel figure: input, heatmap, overlay.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
        ax1.imshow(image)
        ax1.set_title('Original Image')
        ax1.axis('off')
        ax2.imshow(heatmap, cmap='hot')
        ax2.set_title('GradCAM Heatmap')
        ax2.axis('off')
        ax3.imshow(image)
        ax3.imshow(heatmap, cmap='hot', alpha=0.5)
        ax3.set_title('Overlay')
        ax3.axis('off')
        plt.tight_layout()

        # Standalone blended overlay image for the dashboard. The heatmap
        # array was already resized to original_size above, so the colorized
        # image needs no further resizing before blending.
        heatmap_rgb = (plt.cm.hot(heatmap / 255.0)[:, :, :3] * 255).astype(np.uint8)
        overlay_img = Image.fromarray(heatmap_rgb)
        blended = Image.blend(image.convert('RGBA'), overlay_img.convert('RGBA'), alpha=0.5)
        return fig, blended.convert('RGB')
    except Exception as e:
        # Boundary handler: report the failure in-figure and fall back to
        # the original image so the caller always gets a usable pair.
        print(f"Error in GradCAM: {str(e)}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"GradCAM failed:\n{str(e)}",
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title('GradCAM Error')
        return fig, image
def explain_gradient_shap(model, processor, image, n_samples=5):
    """
    Generate GradientSHAP explanations.

    Attributes the model's predicted class to input pixels by sampling
    gradients along paths from a black baseline, then renders the resulting
    saliency map next to the original image. Returns a matplotlib Figure;
    on failure returns an error figure instead of raising.
    """
    try:
        device = next(model.parameters()).device

        # Preprocess and move the pixel batch to the model's device.
        pixel_batch = processor(images=image, return_tensors="pt")['pixel_values'].to(device)

        # Use the top predicted class as the attribution target.
        with torch.no_grad():
            predicted_class = model(pixel_batch).logits.argmax(dim=1).item()

        # A black image serves as the reference baseline.
        baseline = torch.zeros_like(pixel_batch)

        # Captum needs a module that returns raw logits, hence the wrapper.
        shap_method = GradientShap(ViTWrapper(model))
        attr = shap_method.attribute(
            pixel_batch,
            baselines=baseline,
            n_samples=n_samples,
            target=predicted_class,
        )

        # Collapse the channel axis into a single per-pixel saliency map.
        saliency = attr.squeeze().sum(dim=0).cpu().detach().numpy()

        # Min-max normalize to [0, 1]; a flat map becomes all zeros.
        lo, hi = saliency.min(), saliency.max()
        saliency = (saliency - lo) / (hi - lo) if hi > lo else np.zeros_like(saliency)

        # Three-panel figure: input, attribution, overlay.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        ax1.imshow(image)
        ax1.set_title('Original Image')
        ax1.axis('off')

        shap_im = ax2.imshow(saliency, cmap='coolwarm')
        ax2.set_title('GradientSHAP Attribution')
        ax2.axis('off')
        plt.colorbar(shap_im, ax=ax2)

        ax3.imshow(image, alpha=0.7)
        overlay_im = ax3.imshow(saliency, cmap='coolwarm', alpha=0.5)
        ax3.set_title('Attribution Overlay')
        ax3.axis('off')
        plt.colorbar(overlay_im, ax=ax3)

        plt.tight_layout()
        return fig
    except Exception as e:
        # Boundary handler: surface the failure inside a figure so the
        # dashboard always has something to render.
        print(f"Error in GradientSHAP: {str(e)}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"GradientSHAP failed:\n{str(e)}",
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title('GradientSHAP Error')
        return fig