| | |
| |
|
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
| |
|
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Registers forward/backward hooks on a target convolutional layer and
    combines that layer's activations with the gradients of a chosen class
    score to highlight the image regions driving the prediction.
    """

    def __init__(self, model: torch.nn.Module, target_layer: Optional[str] = None):
        """
        Args:
            model: The neural network to explain.
            target_layer: Dotted module name (as listed by
                ``model.named_modules()``) to hook — usually the last conv
                layer. If None, falls back to ``model.convnext.stages[-1]``,
                which assumes a ConvNeXt-style backbone; TODO confirm for
                other architectures.
        """
        self.model = model
        self.gradients = None      # grad of the class score w.r.t. target activations
        self.activations = None    # forward activations of the target layer

        if target_layer is None:
            # Default: last stage of a ConvNeXt backbone.
            self.target_layer = model.convnext.stages[-1]
        else:
            self.target_layer = dict(model.named_modules())[target_layer]

        # Keep the handles so hooks can be detached later; the original
        # discarded them, leaking hooks whenever several GradCAM objects
        # were created for the same model.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self) -> None:
        """Detach the forward/backward hooks registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        """Forward hook: stash the target layer's output activations."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: stash the gradient flowing out of the target layer."""
        self.gradients = grad_output[0].detach()

    def generate_cam(
        self,
        image: torch.Tensor,
        target_class: Optional[int] = None
    ) -> np.ndarray:
        """
        Generate a Class Activation Map for one image.

        Args:
            image: Input image tensor [1, 3, H, W].
            target_class: Class index to explain (None = predicted class).

        Returns:
            cam: Activation map [h, w] (target layer's spatial size),
                min-max normalized to [0, 1].
        """
        self.model.eval()

        output = self.model(image)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        self.model.zero_grad()

        # Backprop the selected class score; the hooks capture
        # self.activations / self.gradients as side effects.
        output[0, target_class].backward()

        gradients = self.gradients[0]      # [C, h, w]
        activations = self.activations[0]  # [C, h, w]

        # Global-average-pool the gradients: one importance weight per channel.
        weights = gradients.mean(dim=(1, 2))

        # Channel-weighted sum of activation maps, computed on the
        # activations' own device/dtype. The original accumulated into a
        # fresh CPU float32 buffer, which raised a device-mismatch error
        # whenever the model ran on CUDA.
        cam = (weights[:, None, None] * activations).sum(dim=0)

        # Keep only features with a positive influence on the class score.
        cam = F.relu(cam)

        # Min-max normalize to [0, 1]; guard against an all-zero map.
        cam = cam.cpu().numpy()
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        return cam

    def overlay_cam_on_image(
        self,
        image: np.ndarray,
        cam: np.ndarray,
        alpha: float = 0.5,
        colormap: Optional[int] = None
    ) -> np.ndarray:
        """
        Overlay the CAM heatmap on the original image.

        Args:
            image: Original RGB image [H, W, 3], uint8.
            cam: Normalized activation map (any spatial size; resized to fit).
            alpha: Heatmap blending weight in [0, 1].
            colormap: OpenCV colormap id (None = cv2.COLORMAP_JET). Resolved
                lazily so cv2 attributes are not evaluated at import time.

        Returns:
            overlay: [H, W, 3] RGB image with the heatmap blended in.
        """
        if colormap is None:
            colormap = cv2.COLORMAP_JET

        H, W = image.shape[:2]

        # Upsample the CAM to the image resolution.
        cam_resized = cv2.resize(cam, (W, H))

        # Colorize; applyColorMap emits BGR, so convert to RGB to match image.
        heatmap = cv2.applyColorMap(
            np.uint8(255 * cam_resized),
            colormap
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)

        return overlay
| |
|
class AttentionVisualizer:
    """Visualize MedSigLIP attention maps."""

    def __init__(self, model):
        self.model = model

    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
        """
        Extract attention maps from MedSigLIP.

        Returns:
            attention: [num_heads, H, W] attention weights
        """
        # Forward pass so the model populates its cached features.
        with torch.no_grad():
            _ = self.model(image)

        # Cached MedSigLIP features — read but not used below.
        attention = self.model.medsiglip_features

        # NOTE(review): placeholder — returns random noise instead of real
        # attention weights; the features fetched above are discarded.
        # Needs a proper extraction from the model's attention tensors.
        return np.random.rand(14, 14)

    def overlay_attention(
        self,
        image: np.ndarray,
        attention: np.ndarray,
        alpha: float = 0.6
    ) -> np.ndarray:
        """Overlay attention map on image"""
        height, width = image.shape[:2]

        # Upsample the attention map to the image resolution.
        resized = cv2.resize(attention, (width, height))

        # Min-max normalize to [0, 1] before colorizing.
        resized = resized - resized.min()
        peak = resized.max()
        if peak > 0:
            resized = resized / peak

        # Colorize (BGR from OpenCV) and convert back to RGB.
        colored = cv2.applyColorMap(
            np.uint8(255 * resized),
            cv2.COLORMAP_VIRIDIS
        )
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

        blended = alpha * colored + (1 - alpha) * image
        return blended.astype(np.uint8)