# src/explainer.py
import captum
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from captum.attr import GradientShap, LayerGradCam
from captum.attr import visualization as viz
from PIL import Image


class ViTWrapper(torch.nn.Module):
    """
    Wrapper class to make Hugging Face ViT compatible with Captum.

    Captum expects a plain ``nn.Module`` whose forward returns a tensor;
    this adapter feeds the input as ``pixel_values`` and returns the raw
    logits tensor instead of the Hugging Face output object.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Hugging Face vision models expect the pixel_values keyword.
        outputs = self.model(pixel_values=x)
        return outputs.logits


class AttentionHook:
    """Forward hook that captures attention weights from a ViT layer."""

    def __init__(self):
        # Latest attention tensor captured by __call__, or None.
        self.attention_weights = None

    def __call__(self, module, input, output):
        # ViT self-attention returns (context, attention_probs) only when
        # the forward pass runs with output_attentions=True; otherwise the
        # output tuple has a single element and nothing can be captured.
        if len(output) >= 2:
            self.attention_weights = output[1]  # attention weights
        else:
            self.attention_weights = None


def explain_attention(model, processor, image, layer_index=6, head_index=0):
    """
    Extract and visualize attention weights using a forward hook.

    Args:
        model: Hugging Face ViT/DeiT image-classification model.
        processor: matching image processor producing ``pixel_values``.
        image: PIL image to explain.
        layer_index: encoder layer whose attention is visualized.
        head_index: attention head to display.

    Returns:
        matplotlib Figure showing the patch-to-patch attention matrix, or
        an error-placeholder figure if anything fails.
    """
    try:
        device = next(model.parameters()).device

        # Preprocess image
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Register hook to capture attention (supported for ViT/DeiT only)
        hook = AttentionHook()

        # Only support attention visualization for ViT-like architectures
        if hasattr(model, "vit"):
            try:
                # For standard ViT structure
                target_layer = model.vit.encoder.layer[layer_index].attention.attention
                handle = target_layer.register_forward_hook(hook)
            except Exception:
                try:
                    # Alternative structure
                    target_layer = model.vit.encoder.layers[layer_index].attention.attention
                    handle = target_layer.register_forward_hook(hook)
                except Exception:
                    raise ValueError(
                        f"Could not access layer {layer_index} for attention hook"
                    )
        else:
            raise ValueError(
                "Attention visualization currently supports ViT/DeiT models only. "
                "Please select a ViT model or use GradCAM/GradientSHAP."
            )

        # Forward pass to capture attention. output_attentions=True is
        # required: without it the ViT self-attention module returns only
        # the context tensor and the hook never sees attention probs.
        # NOTE(review): with SDPA/flash attention implementations the probs
        # may still be unavailable -- the explicit error below covers that.
        with torch.no_grad():
            _ = model(**inputs, output_attentions=True)

        # Remove hook
        handle.remove()

        if hook.attention_weights is None:
            raise ValueError("No attention weights captured by hook")

        # Captured shape: (batch, heads, seq_len, seq_len)
        attention_weights = hook.attention_weights
        attention_map = attention_weights[0, head_index]  # (seq_len, seq_len)

        # Remove CLS token attention to other tokens
        patch_attention = attention_map[1:, 1:]  # drop CLS row and column

        # Create visualization
        fig, ax = plt.subplots(figsize=(8, 6))

        # Display attention matrix
        im = ax.imshow(patch_attention.cpu().numpy(), cmap="viridis", aspect="auto")
        ax.set_title(
            f"Attention Map - Layer {layer_index}, Head {head_index}",
            fontsize=14,
            fontweight="bold",
        )
        ax.set_xlabel("Key Patches")
        ax.set_ylabel("Query Patches")

        # Add colorbar
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        return fig

    except Exception as e:
        print(f"Error in attention visualization: {str(e)}")
        # Return a simple error plot so the caller always gets a figure.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(
            0.5,
            0.5,
            f"Attention visualization failed:\n{str(e)}",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=10,
        )
        ax.set_title("Attention Visualization Error")
        return fig


def explain_gradcam(model, processor, image, target_layer_index=-2):
    """
    Generate GradCAM heatmap for the predicted class.

    Args:
        model: Hugging Face image-classification model.
        processor: matching image processor producing ``pixel_values``.
        image: PIL image to explain.
        target_layer_index: index used when picking the GradCAM target layer
            (see ``_select_gradcam_target_layer``).

    Returns:
        Tuple ``(fig, overlay_image)`` where ``fig`` is a 3-panel matplotlib
        figure (original / heatmap / overlay) and ``overlay_image`` is a PIL
        RGB image blending heatmap and original. On failure, returns an
        error-placeholder figure and the unmodified input image.
    """
    try:
        device = next(model.parameters()).device

        # Preprocess image
        inputs = processor(images=image, return_tensors="pt")
        input_tensor = inputs["pixel_values"].to(device)

        # Get prediction (explain the argmax class)
        with torch.no_grad():
            outputs = model(input_tensor)
        predicted_class = outputs.logits.argmax(dim=1).item()

        # Get the target layer adaptively across architectures
        target_layer = _select_gradcam_target_layer(model, target_layer_index)

        # Create wrapped model for Captum compatibility
        wrapped_model = ViTWrapper(model)

        # Initialize GradCAM with wrapped model
        gradcam = LayerGradCam(wrapped_model, target_layer)

        # Generate attribution - handle tuple output
        attribution = gradcam.attribute(input_tensor, target=predicted_class)

        # Handle tuple output by taking the first element
        if isinstance(attribution, tuple):
            attribution = attribution[0]

        # If attribution has channel dimension, aggregate over channels
        if isinstance(attribution, torch.Tensor):
            att = attribution.detach().cpu()
            if att.dim() == 4:  # (B, C, H, W)
                att = att.sum(dim=1)  # (B, H, W)
            att = att.squeeze(0)  # (H, W)
            attribution = att.numpy()
        else:
            attribution = np.array(attribution)

        # Normalize attribution to [0, 1] (guard the constant-map case)
        if attribution.max() > attribution.min():
            attribution = (attribution - attribution.min()) / (
                attribution.max() - attribution.min()
            )
        else:
            attribution = np.zeros_like(attribution)

        # Resize heatmap to match original image
        original_size = image.size
        heatmap = Image.fromarray((np.clip(attribution, 0, 1) * 255).astype(np.uint8))
        heatmap = heatmap.resize(original_size, Image.Resampling.LANCZOS)
        heatmap = np.array(heatmap)

        # Create visualization figure
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        ax1.imshow(image)
        ax1.set_title("Original Image")
        ax1.axis("off")

        # Heatmap
        ax2.imshow(heatmap, cmap="hot")
        ax2.set_title("GradCAM Heatmap")
        ax2.axis("off")

        # Overlay
        ax3.imshow(image)
        ax3.imshow(heatmap, cmap="hot", alpha=0.5)
        ax3.set_title("Overlay")
        ax3.axis("off")

        plt.tight_layout()

        # Create overlay image for dashboard
        heatmap_rgb = (plt.cm.hot(heatmap / 255.0)[:, :, :3] * 255).astype(np.uint8)
        overlay_img = Image.fromarray(heatmap_rgb)
        overlay_img = overlay_img.resize(original_size, Image.Resampling.LANCZOS)

        # Blend with original
        original_rgba = image.convert("RGBA")
        overlay_rgba = overlay_img.convert("RGBA")
        blended = Image.blend(original_rgba, overlay_rgba, alpha=0.5)

        return fig, blended.convert("RGB")

    except Exception as e:
        print(f"Error in GradCAM: {str(e)}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(
            0.5,
            0.5,
            f"GradCAM failed:\n{str(e)}",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=10,
        )
        ax.set_title("GradCAM Error")
        return fig, image


def _select_gradcam_target_layer(model, target_layer_index):
    """Best-effort selection of a target layer for GradCAM across architectures."""
    # 1) ViT / DeiT: use attention layer as before
    if hasattr(model, "vit"):
        try:
            return model.vit.encoder.layer[target_layer_index].attention.attention
        except Exception:
            return model.vit.encoder.layers[target_layer_index].attention.attention

    # 2) ResNet (HF): try final bottleneck/conv in layer4
    if hasattr(model, "resnet"):
        res = model.resnet
        try:
            blk = res.layer4[-1]
            # Prefer the last conv if exists
            for attr in ["conv3", "conv2", "conv1"]:
                if hasattr(blk, attr):
                    return getattr(blk, attr)
            return blk
        except Exception:
            pass

    # 3) Swin (HF): try last attention block; fallback to patch embedding conv
    if hasattr(model, "swin"):
        try:
            # Common pattern: encoder.layers[-1].blocks[-1].attention
            return model.swin.encoder.layers[-1].blocks[-1].attention
        except Exception:
            try:
                return model.swin.embeddings.patch_embeddings.projection
            except Exception:
                pass

    # 4) Generic fallback: last Conv2d found in the model
    last_conv = None
    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            last_conv = m
    if last_conv is not None:
        return last_conv

    # As a final fallback, just return the model (may not work with GradCAM,
    # but avoids attribute errors)
    return model


def explain_gradient_shap(model, processor, image, n_samples=5):
    """
    Generate GradientSHAP explanations.

    Args:
        model: Hugging Face image-classification model.
        processor: matching image processor producing ``pixel_values``.
        image: PIL image to explain.
        n_samples: number of randomized samples GradientShap draws.

    Returns:
        matplotlib Figure with original image, attribution map, and overlay,
        or an error-placeholder figure if anything fails.
    """
    try:
        device = next(model.parameters()).device

        # Preprocess image
        inputs = processor(images=image, return_tensors="pt")
        input_tensor = inputs["pixel_values"].to(device)

        # Get prediction (explain the argmax class)
        with torch.no_grad():
            outputs = model(input_tensor)
        predicted_class = outputs.logits.argmax(dim=1).item()

        # Create baseline (black image)
        baseline = torch.zeros_like(input_tensor)

        # Create wrapped model for Captum compatibility
        wrapped_model = ViTWrapper(model)

        # Initialize GradientSHAP with wrapped model
        gradient_shap = GradientShap(wrapped_model)

        # Generate attribution
        attribution = gradient_shap.attribute(
            input_tensor, baselines=baseline, n_samples=n_samples, target=predicted_class
        )

        # Summarize attribution across channels -> (H, W) numpy map
        attribution = attribution.squeeze().sum(dim=0).cpu().detach().numpy()

        # Normalize to [0, 1] (guard the constant-map case)
        if attribution.max() > attribution.min():
            attribution = (attribution - attribution.min()) / (
                attribution.max() - attribution.min()
            )
        else:
            attribution = np.zeros_like(attribution)

        # Create visualization
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

        # Original image
        ax1.imshow(image)
        ax1.set_title("Original Image")
        ax1.axis("off")

        # SHAP attribution
        im = ax2.imshow(attribution, cmap="coolwarm")
        ax2.set_title("GradientSHAP Attribution")
        ax2.axis("off")
        plt.colorbar(im, ax=ax2)

        # Overlay
        ax3.imshow(image, alpha=0.7)
        im_overlay = ax3.imshow(attribution, cmap="coolwarm", alpha=0.5)
        ax3.set_title("Attribution Overlay")
        ax3.axis("off")
        plt.colorbar(im_overlay, ax=ax3)

        plt.tight_layout()
        return fig

    except Exception as e:
        print(f"Error in GradientSHAP: {str(e)}")
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(
            0.5,
            0.5,
            f"GradientSHAP failed:\n{str(e)}",
            ha="center",
            va="center",
            transform=ax.transAxes,
            fontsize=10,
        )
        ax.set_title("GradientSHAP Error")
        return fig