# models/explainability.py
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from typing import Tuple
from PIL import Image
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Hooks a convolutional layer to capture its forward activations and the
    gradient of a class score w.r.t. those activations, then combines them
    into a heatmap showing which spatial regions drove the prediction.
    """

    def __init__(self, model: torch.nn.Module, target_layer: str = None):
        """
        Args:
            model: The neural network to explain.
            target_layer: Dotted module name to hook, as reported by
                ``model.named_modules()`` (usually the last conv layer).
                If None, falls back to ``model.convnext.stages[-1]`` —
                this assumes the project's ConvNeXt-based model; pass a
                name explicitly for other architectures.
        """
        self.model = model
        self.gradients = None     # grad of class score w.r.t. layer output, set by backward hook
        self.activations = None   # layer output from the forward pass, set by forward hook

        # Auto-detect target layer if not specified.
        if target_layer is None:
            # Project-specific default: last ConvNeXt stage.
            self.target_layer = model.convnext.stages[-1]
        else:
            self.target_layer = dict(model.named_modules())[target_layer]

        # Keep the handles so the hooks can be detached later; the original
        # discarded them, leaking the hooks on the model for its lifetime.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self) -> None:
        """Detach the forward/backward hooks from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        """Forward hook: cache the target layer's activations."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: cache the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(
        self,
        image: torch.Tensor,
        target_class: int = None
    ) -> np.ndarray:
        """
        Generate a Class Activation Map for one image.

        Args:
            image: Input image batch of shape [1, 3, H, W].
            target_class: Class index to explain (None = predicted class).

        Returns:
            cam: Activation map [h, w] (target layer's spatial size),
                normalized to the range 0-1.
        """
        self.model.eval()

        # Forward pass (populates self.activations via the forward hook).
        output = self.model(image)

        # Default to the model's own prediction.
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass on the chosen class score (populates self.gradients).
        self.model.zero_grad()
        output[0, target_class].backward()

        gradients = self.gradients[0]      # [C, h, w]
        activations = self.activations[0]  # [C, h, w]

        # Channel weights: global average pooling of the gradients.
        weights = gradients.mean(dim=(1, 2))  # [C]

        # Vectorized weighted sum over channels. This stays on the
        # activations' device — the original accumulated into a fresh CPU
        # tensor, which raised a device-mismatch error for GPU inputs.
        cam = (weights[:, None, None] * activations).sum(dim=0)

        # Keep only positive influence, then normalize to 0-1.
        cam = F.relu(cam)
        cam = cam.cpu().numpy()
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        return cam

    def overlay_cam_on_image(
        self,
        image: np.ndarray,  # [H, W, 3] RGB, uint8-range values
        cam: np.ndarray,    # [h, w] in 0-1
        alpha: float = 0.5,
        colormap: int = None
    ) -> np.ndarray:
        """
        Overlay the CAM heatmap on the original image.

        Args:
            image: Original RGB image [H, W, 3].
            cam: Normalized activation map [h, w]; resized to (H, W).
            alpha: Heatmap blend weight in [0, 1].
            colormap: OpenCV colormap id; defaults to cv2.COLORMAP_JET.
                (Resolved at call time so the class can be defined without
                touching the cv2 module.)

        Returns:
            overlay: [H, W, 3] RGB uint8 image with the heatmap blended in.
        """
        if colormap is None:
            colormap = cv2.COLORMAP_JET
        H, W = image.shape[:2]

        # Upscale the CAM to the image resolution. cv2.resize takes (W, H).
        cam_resized = cv2.resize(cam, (W, H))

        # Colorize; applyColorMap emits BGR, so convert back to RGB.
        heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), colormap)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Alpha-blend heatmap over the original image.
        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)
        return overlay
class AttentionVisualizer:
    """Visualize MedSigLIP attention maps."""

    def __init__(self, model):
        self.model = model

    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
        """
        Extract attention maps from MedSigLIP.

        Returns:
            attention: [num_heads, H, W] attention weights
        """
        # Run the model once (no gradients needed) so it caches its features.
        with torch.no_grad():
            _ = self.model(image)

        # Last-layer attention cached on the model by the forward pass.
        # Expected shape: [batch, num_heads, seq_len, seq_len].
        attention = self.model.medsiglip_features

        # NOTE(review): averaging across heads and recovering the spatial
        # attention is model-dependent and has not been implemented yet —
        # this still returns a random 14x14 placeholder. Adapt this to the
        # specific MedSigLIP architecture before relying on the output.
        return np.random.rand(14, 14)  # Placeholder

    def overlay_attention(
        self,
        image: np.ndarray,
        attention: np.ndarray,
        alpha: float = 0.6
    ) -> np.ndarray:
        """Overlay attention map on image."""
        height, width = image.shape[:2]

        # Upscale the attention map to the image resolution (cv2 wants (W, H)).
        scaled = cv2.resize(attention, (width, height))

        # Shift-and-scale normalization into [0, 1], guarding the flat case.
        scaled = scaled - scaled.min()
        peak = scaled.max()
        if peak > 0:
            scaled = scaled / peak

        # Colorize with viridis; applyColorMap emits BGR, so flip to RGB.
        colored = cv2.applyColorMap(np.uint8(255 * scaled), cv2.COLORMAP_VIRIDIS)
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

        # Alpha-blend the heatmap over the original image.
        blended = alpha * colored + (1 - alpha) * image
        return blended.astype(np.uint8)