File size: 5,576 Bytes
86f402d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# models/explainability.py

import torch
import torch.nn.functional as F
import numpy as np
import cv2
from typing import Tuple
from PIL import Image

class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Hooks a target convolutional layer, then uses the globally-pooled
    gradients of a class score as channel weights over that layer's
    activations to show which image regions drive the prediction.
    """

    def __init__(self, model: torch.nn.Module, target_layer: str = None):
        """
        Args:
            model: The neural network to explain.
            target_layer: Module name (a key of ``dict(model.named_modules())``)
                to compute the CAM on — usually the last conv layer. When None,
                falls back to ``model.convnext.stages[-1]`` (assumes the
                project's ConvNeXt backbone — TODO confirm for other models).
        """
        self.model = model
        self.gradients = None    # set by the backward hook on each backward()
        self.activations = None  # set by the forward hook on each forward()

        # Auto-detect target layer if not specified.
        if target_layer is None:
            # Default: last ConvNeXt stage of the project backbone.
            self.target_layer = model.convnext.stages[-1]
        else:
            self.target_layer = dict(model.named_modules())[target_layer]

        # Keep the hook handles so they can be detached later; the previous
        # version discarded them, leaking the hooks for the model's lifetime.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self) -> None:
        """Detach the registered hooks (call when done with this GradCAM)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        """Forward hook: cache the target layer's output activations."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: cache the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def generate_cam(
        self,
        image: torch.Tensor,
        target_class: int = None
    ) -> np.ndarray:
        """
        Generate a Class Activation Map for one image.

        Args:
            image: Input image [1, 3, H, W].
            target_class: Class index to generate the CAM for
                (None = the model's predicted class).

        Returns:
            cam: Activation map [h, w] (target layer's spatial size),
                normalized to 0-1.
        """
        self.model.eval()

        # Forward pass fills self.activations via the forward hook.
        output = self.model(image)

        # Use the predicted class if none was requested.
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Zero gradients, then backprop the chosen class score; the backward
        # hook fills self.gradients.
        self.model.zero_grad()
        output[0, target_class].backward()

        gradients = self.gradients[0]      # [C, h, w]
        activations = self.activations[0]  # [C, h, w]

        # Channel weights: global average pooling of the gradients.
        weights = gradients.mean(dim=(1, 2))  # [C]

        # Weighted sum over channels in one broadcasted op on the
        # activations' own device. (The previous version accumulated into a
        # CPU tensor in a Python loop, which crashed for CUDA models.)
        cam = F.relu((weights[:, None, None] * activations).sum(dim=0))

        # Normalize to 0-1 (guard against an all-zero map after ReLU).
        cam = cam.cpu().numpy()
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        return cam

    def overlay_cam_on_image(
        self,
        image: np.ndarray,  # [H, W, 3] RGB, uint8-range values
        cam: np.ndarray,    # [h, w], values in 0-1
        alpha: float = 0.5,
        colormap: int = None
    ) -> np.ndarray:
        """
        Overlay the CAM heatmap on the original image.

        Args:
            image: Original RGB image [H, W, 3].
            cam: Activation map [h, w] in 0-1 (resized to the image here).
            alpha: Heatmap blend weight (0 = image only, 1 = heatmap only).
            colormap: OpenCV colormap id; None selects cv2.COLORMAP_JET.
                Resolved lazily so importing this class does not require
                evaluating a cv2 attribute at class-definition time.

        Returns:
            overlay: [H, W, 3] RGB image with the heatmap blended in.
        """
        if colormap is None:
            colormap = cv2.COLORMAP_JET

        H, W = image.shape[:2]

        # Resize CAM to the image size.
        cam_resized = cv2.resize(cam, (W, H))

        # Convert the 0-1 map to a color heatmap (OpenCV returns BGR).
        heatmap = cv2.applyColorMap(
            np.uint8(255 * cam_resized),
            colormap
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

        # Blend with the original image.
        overlay = (alpha * heatmap + (1 - alpha) * image).astype(np.uint8)

        return overlay

class AttentionVisualizer:
    """Visualize MedSigLIP attention maps."""

    def __init__(self, model):
        self.model = model

    def get_attention_maps(self, image: torch.Tensor) -> np.ndarray:
        """
        Extract attention maps from MedSigLIP.

        Returns:
            attention: [num_heads, H, W] attention weights
        """
        # One inference-only forward pass; we only need the features the
        # model caches as a side effect, not the prediction itself.
        with torch.no_grad():
            _ = self.model(image)

        # Last-layer attention cached by the model during the forward pass.
        # Expected shape: [batch, num_heads, seq_len, seq_len].
        attention = self.model.medsiglip_features

        # NOTE(review): turning these features into a spatial attention map
        # is model-dependent and not implemented yet — the value below is a
        # stand-in to be replaced for the concrete MedSigLIP architecture.
        return np.random.rand(14, 14)  # Placeholder

    def overlay_attention(
        self,
        image: np.ndarray,
        attention: np.ndarray,
        alpha: float = 0.6
    ) -> np.ndarray:
        """Overlay attention map on image"""
        height, width = image.shape[0], image.shape[1]

        # Stretch the attention map to the full image resolution.
        resized = cv2.resize(attention, (width, height))

        # Shift to zero minimum, then scale to unit maximum when nonzero.
        resized = resized - resized.min()
        peak = resized.max()
        if peak > 0:
            resized = resized / peak

        # Map the normalized weights through a colormap (OpenCV gives BGR,
        # so flip to RGB to match the input image).
        colored = cv2.applyColorMap(
            np.uint8(255 * resized),
            cv2.COLORMAP_VIRIDIS
        )
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

        # Alpha-blend heatmap over the original image.
        return (alpha * colored + (1 - alpha) * image).astype(np.uint8)