# src/explainer.py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

import captum
from captum.attr import GradientShap, LayerGradCam
from captum.attr import visualization as viz
class ViTWrapper(torch.nn.Module):
    """Adapter exposing a Hugging Face ViT as a tensor-in / tensor-out module.

    Captum attribution methods expect ``forward`` to return a raw tensor,
    while Hugging Face models return a structured output object; this
    wrapper forwards the input as ``pixel_values`` and unwraps ``.logits``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on *x* and return only the logits tensor."""
        return self.model(pixel_values=x).logits
class AttentionHook:
    """Forward hook that records attention weights from a ViT attention module.

    HF ViT self-attention layers return a tuple whose second element holds
    the attention probabilities (when attention output is enabled); a
    shorter tuple leaves ``attention_weights`` as ``None``.
    """

    def __init__(self):
        # Populated by __call__ during a forward pass; None until then.
        self.attention_weights = None

    def __call__(self, module, input, output):
        # Keep the second tuple element when present, otherwise record None.
        self.attention_weights = output[1] if len(output) >= 2 else None
def explain_attention(model, processor, image, layer_index=6, head_index=0):
    """
    Extract and visualize attention weights from one ViT layer/head.

    Args:
        model: Hugging Face ViT classification model.
        processor: Matching image processor (produces ``pixel_values``).
        image: PIL image to explain.
        layer_index: Encoder layer whose self-attention is captured.
        head_index: Attention head to display.

    Returns:
        A matplotlib figure showing the patch-to-patch attention matrix,
        or an error figure if anything fails.
    """
    try:
        device = next(model.parameters()).device

        # Preprocess the image into model inputs on the model's device.
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Resolve the self-attention module to hook. Catch only structural
        # failures instead of a bare except, so real errors surface.
        hook = AttentionHook()
        try:
            # Standard HF ViT structure.
            target_layer = model.vit.encoder.layer[layer_index].attention.attention
        except (AttributeError, IndexError):
            try:
                # Alternative structure used by some checkpoints.
                target_layer = model.vit.encoder.layers[layer_index].attention.attention
            except (AttributeError, IndexError):
                raise ValueError(f"Could not access layer {layer_index} for attention hook")
        handle = target_layer.register_forward_hook(hook)

        # Forward pass to capture attention. ``output_attentions=True`` is
        # required (bug fix): without it the HF self-attention module returns
        # only the context tensor, so the hook would never see the weights.
        try:
            with torch.no_grad():
                _ = model(**inputs, output_attentions=True)
        finally:
            # Always detach the hook, even if the forward pass raises.
            handle.remove()

        if hook.attention_weights is None:
            raise ValueError("No attention weights captured by hook")

        # Get attention weights
        attention_weights = hook.attention_weights  # Shape: (batch, heads, seq_len, seq_len)
        attention_map = attention_weights[0, head_index]  # Shape: (seq_len, seq_len)

        # Drop the CLS token row and column so only patch-to-patch
        # attention remains.
        patch_attention = attention_map[1:, 1:]

        # Create visualization
        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(patch_attention.cpu().numpy(), cmap='viridis', aspect='auto')
        ax.set_title(f'Attention Map - Layer {layer_index}, Head {head_index}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Key Patches')
        ax.set_ylabel('Query Patches')
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        return fig
    except Exception as e:
        print(f"Error in attention visualization: {str(e)}")
        # Return a simple error plot so the caller always receives a figure.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Attention visualization failed:\n{str(e)}",
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title('Attention Visualization Error')
        return fig
def explain_gradcam(model, processor, image, target_layer_index=-2):
    """
    Generate a GradCAM heatmap for the model's predicted class.

    Args:
        model: Hugging Face ViT classification model.
        processor: Matching image processor.
        image: PIL image to explain.
        target_layer_index: Index of the encoder layer to attribute against.

    Returns:
        Tuple of (matplotlib figure with original/heatmap/overlay panels,
        PIL overlay image). On failure returns (error figure, original image).
    """
    try:
        device = next(model.parameters()).device

        # Preprocess image
        inputs = processor(images=image, return_tensors="pt")
        input_tensor = inputs['pixel_values'].to(device)

        # The attribution target is whichever class the model predicts.
        with torch.no_grad():
            outputs = model(input_tensor)
            predicted_class = outputs.logits.argmax(dim=1).item()

        # Resolve the target layer; catch only structural failures instead
        # of a bare except (bug fix) so unrelated errors are not swallowed.
        try:
            target_layer = model.vit.encoder.layer[target_layer_index].attention.attention
        except (AttributeError, IndexError):
            target_layer = model.vit.encoder.layers[target_layer_index].attention.attention

        # Captum needs a tensor-returning forward, so wrap the HF model.
        wrapped_model = ViTWrapper(model)
        gradcam = LayerGradCam(wrapped_model, target_layer)

        attribution = gradcam.attribute(input_tensor, target=predicted_class)
        # Some layers make Captum return a tuple; keep the first element.
        if isinstance(attribution, tuple):
            attribution = attribution[0]
        attribution = attribution.squeeze().cpu().detach().numpy()

        # Normalize attribution to [0, 1]; a flat map becomes all zeros.
        if attribution.max() > attribution.min():
            attribution = (attribution - attribution.min()) / (attribution.max() - attribution.min())
        else:
            attribution = np.zeros_like(attribution)

        # ViT token layers can yield a 1-D attribution after squeeze, which
        # Image.fromarray cannot render (bug fix). Fold patch tokens onto a
        # square grid, dropping a likely leading CLS token when the count is
        # one past a perfect square. Assumes a square patch grid — TODO
        # confirm for this model.
        if attribution.ndim == 1:
            n_tokens = attribution.shape[0]
            side = int(np.sqrt(n_tokens))
            if side * side != n_tokens:
                attribution = attribution[1:]
                side = int(np.sqrt(attribution.shape[0]))
            attribution = attribution[: side * side].reshape(side, side)

        # Resize heatmap to match the original image dimensions.
        original_size = image.size
        heatmap = Image.fromarray((attribution * 255).astype(np.uint8))
        heatmap = heatmap.resize(original_size, Image.Resampling.LANCZOS)
        heatmap = np.array(heatmap)

        # Three-panel figure: original, heatmap, overlay.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        ax1.imshow(image)
        ax1.set_title('Original Image')
        ax1.axis('off')

        ax2.imshow(heatmap, cmap='hot')
        ax2.set_title('GradCAM Heatmap')
        ax2.axis('off')

        ax3.imshow(image)
        ax3.imshow(heatmap, cmap='hot', alpha=0.5)
        ax3.set_title('Overlay')
        ax3.axis('off')

        plt.tight_layout()

        # Build a standalone overlay image for the dashboard: colorize the
        # heatmap, then alpha-blend it with the original.
        heatmap_rgb = (plt.cm.hot(heatmap / 255.0)[:, :, :3] * 255).astype(np.uint8)
        overlay_img = Image.fromarray(heatmap_rgb)
        overlay_img = overlay_img.resize(original_size, Image.Resampling.LANCZOS)
        original_rgba = image.convert('RGBA')
        overlay_rgba = overlay_img.convert('RGBA')
        blended = Image.blend(original_rgba, overlay_rgba, alpha=0.5)

        return fig, blended.convert('RGB')
    except Exception as e:
        print(f"Error in GradCAM: {str(e)}")
        # Fall back to an error figure plus the untouched original image.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"GradCAM failed:\n{str(e)}",
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title('GradCAM Error')
        return fig, image
def explain_gradient_shap(model, processor, image, n_samples=5):
    """
    Visualize GradientSHAP attributions for the model's predicted class.

    Args:
        model: Hugging Face ViT classification model.
        processor: Matching image processor.
        image: PIL image to explain.
        n_samples: Number of samples GradientSHAP draws.

    Returns:
        A matplotlib figure (original, attribution map, overlay), or an
        error figure if attribution fails.
    """
    try:
        device = next(model.parameters()).device

        # Turn the PIL image into a pixel tensor on the model's device.
        pixel_tensor = processor(images=image, return_tensors="pt")['pixel_values'].to(device)

        # The attribution target is whichever class the model predicts.
        with torch.no_grad():
            predicted_class = model(pixel_tensor).logits.argmax(dim=1).item()

        # All-black baseline; Captum needs the tensor-returning wrapper.
        shap = GradientShap(ViTWrapper(model))
        attr = shap.attribute(
            pixel_tensor,
            baselines=torch.zeros_like(pixel_tensor),
            n_samples=n_samples,
            target=predicted_class,
        )

        # Collapse the channel axis into a single 2-D attribution map.
        attr_map = attr.squeeze().sum(dim=0).cpu().detach().numpy()

        # Min-max normalize; a constant map becomes all zeros.
        span = attr_map.max() - attr_map.min()
        attr_map = (attr_map - attr_map.min()) / span if span > 0 else np.zeros_like(attr_map)

        # Three-panel figure: original, attribution, overlay.
        fig, (ax_orig, ax_attr, ax_blend) = plt.subplots(1, 3, figsize=(15, 5))

        ax_orig.imshow(image)
        ax_orig.set_title('Original Image')
        ax_orig.axis('off')

        shown = ax_attr.imshow(attr_map, cmap='coolwarm')
        ax_attr.set_title('GradientSHAP Attribution')
        ax_attr.axis('off')
        plt.colorbar(shown, ax=ax_attr)

        ax_blend.imshow(image, alpha=0.7)
        overlaid = ax_blend.imshow(attr_map, cmap='coolwarm', alpha=0.5)
        ax_blend.set_title('Attribution Overlay')
        ax_blend.axis('off')
        plt.colorbar(overlaid, ax=ax_blend)

        plt.tight_layout()
        return fig
    except Exception as e:
        print(f"Error in GradientSHAP: {str(e)}")
        # Fall back to an error figure so the caller always gets a figure.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"GradientSHAP failed:\n{str(e)}",
                ha='center', va='center', transform=ax.transAxes, fontsize=10)
        ax.set_title('GradientSHAP Error')
        return fig