| |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import cv2 |
| from typing import Tuple |
| from PIL import Image |
|
|
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Hooks one layer of the model to capture its forward output and the
    gradient flowing back into that output, then combines the two into a
    heatmap showing which spatial regions drove a class score.

    The hooks stay attached to the model until ``remove_hooks()`` is called.
    """

    def __init__(self, model: torch.nn.Module, target_layer: str = None):
        """
        Args:
            model: The neural network.
            target_layer: Name of the layer (as listed by
                ``model.named_modules()``) to compute the CAM on, usually
                the last conv layer. When None, ``model.convnext.stages[-1]``
                is used, so the model must expose that attribute.

        Raises:
            KeyError: If ``target_layer`` is not a module name of ``model``.
        """
        self.model = model
        self.gradients = None
        self.activations = None

        if target_layer is None:
            self.target_layer = model.convnext.stages[-1]
        else:
            modules = dict(model.named_modules())
            if target_layer not in modules:
                raise KeyError(
                    f"No module named {target_layer!r} in model.named_modules()"
                )
            self.target_layer = modules[target_layer]

        # Keep the handles so the hooks can be detached later; otherwise
        # every GradCAM instance leaves its hooks on the model for good.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self) -> None:
        """Detach the forward/backward hooks registered by this instance."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        """Forward hook: store a detached copy of the layer output."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(
        self,
        image: torch.Tensor,
        target_class: int = None
    ) -> np.ndarray:
        """
        Generate a Class Activation Map for the first sample in the batch.

        Side effects: puts the model in eval mode, zeroes its gradients and
        runs a backward pass (so parameter ``.grad`` fields may be filled).

        Args:
            image: Input image [1, 3, H, W]. Only sample 0 is used.
            target_class: Class index to explain (None = predicted class).
                The model output is indexed as ``output[0, target_class]``.

        Returns:
            cam: 2-D float32 array at the target layer's spatial resolution
                (not the input resolution), min-max scaled to 0-1. All
                zeros when the map is constant after ReLU.

        Raises:
            RuntimeError: If the hooks did not capture both activations and
                gradients during this call.
        """
        self.model.eval()

        # Clear anything captured by an earlier call so a hook that fails
        # to fire can never be masked by stale tensors.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs autograd even if the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(image)

            if target_class is None:
                target_class = output.argmax(dim=1).item()

            self.model.zero_grad()
            output[0, target_class].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations and gradients; "
                "check that the target layer is used in the forward pass "
                "and that hooks have not been removed."
            )

        # Sample 0 only; assumed channel-first [C, h, w] - the mean over
        # dims (1, 2) below treats dim 0 as the channel axis.
        gradients = self.gradients[0].float()
        activations = self.activations[0].float()

        # One weight per channel: the spatially averaged gradient.
        weights = gradients.mean(dim=(1, 2))

        # Weighted channel sum, computed on the activations' own device.
        # (A CPU-allocated accumulator breaks when the model is on GPU.)
        cam = (weights[:, None, None] * activations).sum(dim=0)

        # Keep only regions with a positive influence on the class score.
        cam = F.relu(cam)

        cam = cam.cpu().numpy()
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam

    def overlay_cam_on_image(
        self,
        image: np.ndarray,
        cam: np.ndarray,
        alpha: float = 0.5,
        colormap: int = cv2.COLORMAP_JET
    ) -> np.ndarray:
        """
        Overlay a CAM heatmap on the original image.

        Args:
            image: Image array whose first two dims are (H, W); it is
                blended against an [H, W, 3] RGB heatmap.
                NOTE(review): presumably uint8-range RGB (0-255) - confirm
                against callers; a 0-1 float image would be almost
                invisible in the blend.
            cam: 2-D map with values in 0-1, e.g. from ``generate_cam``.
            alpha: Heatmap weight; the image gets ``1 - alpha``.
            colormap: OpenCV colormap id.

        Returns:
            overlay: [H, W, 3] RGB uint8 image with heatmap.
        """
        H, W = image.shape[:2]

        # cv2.resize takes (width, height).
        cam_resized = cv2.resize(cam, (W, H))

        heatmap = cv2.applyColorMap(
            np.uint8(255 * cam_resized),
            colormap
        )
        # applyColorMap returns BGR; convert so it matches the RGB image.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        blended = alpha * heatmap + (1 - alpha) * image
        # Clip so out-of-range values saturate instead of wrapping in uint8.
        overlay = np.clip(blended, 0, 255).astype(np.uint8)

        return overlay
|
|
class AttentionVisualizer:
    """Visualize MedSigLIP attention maps (extraction is still a placeholder)."""

    def __init__(self, model):
        """
        Args:
            model: Callable model that exposes a ``medsiglip_features``
                attribute (presumably populated by its forward pass -
                confirm against the model definition).
        """
        self.model = model

    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
        """
        Run the model on ``image`` and return a placeholder attention map.

        WARNING: real attention extraction is not implemented. The returned
        array is uniform random noise and carries no information about the
        model or the image; a ``RuntimeWarning`` is emitted on every call
        so the output is not mistaken for genuine attention.

        Args:
            image: Input tensor passed straight to the model.

        Returns:
            attention: Random float array of shape (14, 14) with values
                in [0, 1).
        """
        import warnings

        with torch.no_grad():
            self.model(image)

        # Read the features so a model lacking the attribute fails here.
        # TODO: derive the real attention map from these features.
        _features = self.model.medsiglip_features

        warnings.warn(
            "AttentionVisualizer.get_attention_maps is a placeholder and "
            "returns random values, not real attention weights.",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.random.rand(14, 14)

    def overlay_attention(
        self,
        image: np.ndarray,
        attention: np.ndarray,
        alpha: float = 0.6
    ) -> np.ndarray:
        """
        Overlay an attention map on an image using the VIRIDIS colormap.

        Args:
            image: Image array whose first two dims are (H, W); it is
                blended against an [H, W, 3] RGB heatmap.
                NOTE(review): presumably uint8-range RGB (0-255) - confirm
                against callers.
            attention: 2-D attention map at any resolution; it is resized
                to the image and min-max scaled to 0-1 here.
            alpha: Heatmap weight; the image gets ``1 - alpha``.

        Returns:
            overlay: [H, W, 3] RGB uint8 image with heatmap.
        """
        H, W = image.shape[:2]

        # cv2.resize takes (width, height).
        attention_resized = cv2.resize(attention, (W, H))

        # Min-max scale; a constant map stays all zeros (no divide by zero).
        attention_resized = attention_resized - attention_resized.min()
        peak = attention_resized.max()
        if peak > 0:
            attention_resized = attention_resized / peak

        heatmap = cv2.applyColorMap(
            np.uint8(255 * attention_resized),
            cv2.COLORMAP_VIRIDIS
        )
        # applyColorMap returns BGR; convert so it matches the RGB image.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        blended = alpha * heatmap + (1 - alpha) * image
        # Clip so out-of-range values saturate instead of wrapping in uint8.
        overlay = np.clip(blended, 0, 255).astype(np.uint8)

        return overlay