emiraran committed on
Commit
43b6e40
·
verified ·
1 Parent(s): 1eb8b57

Update gradcam_utils.py

Browse files
Files changed (1) hide show
  1. gradcam_utils.py +255 -187
gradcam_utils.py CHANGED
@@ -1,187 +1,255 @@
1
- """
2
- Grad-CAM Implementation for Chest X-Ray Classification
3
- ========================================================
4
-
5
- Visualizes which regions of the X-ray the model focuses on when making predictions.
6
-
7
- Reference: Selvaraju et al. (2017) - Grad-CAM: Visual Explanations from Deep Networks
8
- """
9
-
10
- import tensorflow as tf
11
- import numpy as np
12
- import cv2
13
- from PIL import Image
14
-
15
-
16
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """
    Generate a Grad-CAM heatmap for a given image and prediction.

    Args:
        img_array: Preprocessed image batch passed directly to the model.
            Only the first sample of the batch is visualised.
        model: Trained Keras model.
        last_conv_layer_name: Name of the convolutional layer whose
            activations are used for the heatmap.
        pred_index: Target class index (if None, the arg-max class of the
            first sample is used).

    Returns:
        heatmap: 2-D numpy array scaled so its maximum is 1. It is all
            zeros when no location has a positive contribution.

    Raises:
        ValueError: If no gradient flows from the class score to the
            requested layer.
    """
    # Map the input image to both the conv-layer activations and the
    # predictions. `model.inputs` is already a list, so it is passed as-is
    # rather than wrapped in another list.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )

    # Record the forward pass so the class score can be differentiated with
    # respect to the conv-layer activations.
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        raise ValueError(
            f"No gradient between the model output and layer "
            f"'{last_conv_layer_name}'"
        )

    # One importance weight per feature-map channel.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weighted sum of the channels of the first sample. Only the trailing
    # singleton axis is squeezed so 1-pixel-wide feature maps stay 2-D.
    weighted = last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.maximum(tf.squeeze(weighted, axis=-1), 0).numpy()

    # Normalise to [0, 1]; skip the division when nothing is positive so the
    # result is zeros instead of NaN.
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap
58
-
59
-
60
def overlay_heatmap_on_image(img, heatmap, alpha=0.4, colormap=cv2.COLORMAP_JET):
    """
    Overlay a Grad-CAM heatmap on the original image.

    Args:
        img: Original PIL Image or numpy array (grayscale or RGB).
        heatmap: Grad-CAM heatmap (0-1 range).
        alpha: Weight of the heatmap in the blend (0-1).
        colormap: OpenCV colormap constant (default: JET).

    Returns:
        superimposed_img: PIL Image with the heatmap blended in.
    """
    # Convert PIL to numpy if needed
    if isinstance(img, Image.Image):
        img = np.array(img)
    else:
        img = np.asarray(img)

    # A 2-D (grayscale) image cannot broadcast against the 3-channel
    # heatmap, so replicate it into three channels first.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)

    # Float images in [0, 1] must be brought to the 0-255 scale of the
    # colour-mapped heatmap before blending.
    if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
        img = 255 * img

    # Resize heatmap to match image size; NaNs are zeroed and values clipped
    # so the uint8 conversion below cannot wrap around.
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap_resized = np.clip(heatmap_resized, 0.0, 1.0)

    # Colour-map the heatmap; OpenCV returns BGR, so convert to RGB.
    heatmap_colored = np.uint8(255 * heatmap_resized)
    heatmap_colored = cv2.applyColorMap(heatmap_colored, colormap)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Superimpose the heatmap on the original image
    superimposed_img = heatmap_colored * alpha + img * (1 - alpha)
    superimposed_img = np.uint8(np.clip(superimposed_img, 0, 255))

    return Image.fromarray(superimposed_img)
90
-
91
-
92
def generate_gradcam_for_disease(image, model, disease_name, label_encoder,
                                 last_conv_layer_name='top_conv', img_size=224):
    """
    Produce the Grad-CAM overlay and raw heatmap for one named disease.

    Args:
        image: PIL Image
        model: Trained model
        disease_name: Disease whose class activation is visualised
        label_encoder: Mapping from disease name to class index
        last_conv_layer_name: Conv layer used for Grad-CAM
        img_size: Side length the image is resized to before inference

    Returns:
        overlaid_image: PIL Image with the Grad-CAM overlay
        heatmap: Raw heatmap array
    """
    # Square RGB copy: used both as model input and as overlay background.
    rgb_square = image.convert('RGB').resize((img_size, img_size))

    # Scale to [0, 1] and add the leading batch axis expected by the model.
    pixels = np.array(rgb_square) / 255.0
    batch = pixels[np.newaxis, ...].astype(np.float32)

    # Translate the disease name into its class index.
    target_class = label_encoder[disease_name]

    heatmap = make_gradcam_heatmap(batch, model, last_conv_layer_name, target_class)
    return overlay_heatmap_on_image(rgb_square, heatmap, alpha=0.4), heatmap
124
-
125
-
126
def generate_gradcam_for_top_predictions(image, model, predictions, label_encoder,
                                         top_k=3, last_conv_layer_name='top_conv'):
    """
    Build Grad-CAM overlays for the K most probable diseases.

    Args:
        image: PIL Image
        model: Trained model
        predictions: List of dicts, each with 'disease' and 'probability' keys
        label_encoder: Mapping from disease name to class index
        top_k: How many of the highest-probability entries to visualise
        last_conv_layer_name: Conv layer used for Grad-CAM

    Returns:
        gradcam_images: List of (disease_name, overlaid_image, probability) tuples
    """
    # Rank a copy so the caller's list is left untouched.
    ranked = list(predictions)
    ranked.sort(key=lambda entry: entry['probability'], reverse=True)

    # The helper returns (overlay, heatmap); only the overlay is kept.
    return [
        (
            entry['disease'],
            generate_gradcam_for_disease(
                image, model, entry['disease'], label_encoder, last_conv_layer_name
            )[0],
            entry['probability'],
        )
        for entry in ranked[:top_k]
    ]
159
-
160
-
161
def get_last_conv_layer_name(model):
    """
    Automatically find the last convolutional layer in the model.

    A short list of well-known layer names is tried first; if none exists,
    the model's layers are scanned from the end for a Conv2D instance.

    Args:
        model: Keras model

    Returns:
        layer_name: Name of last conv layer

    Raises:
        ValueError: If no known name matches and no Conv2D layer is found.
    """
    # Try common names first
    common_names = ['top_conv', 'block7a_project_conv', 'conv_head']
    for name in common_names:
        try:
            model.get_layer(name)
        except ValueError:
            # Keras signals an unknown layer name with ValueError; anything
            # else (e.g. KeyboardInterrupt) must propagate.
            continue
        return name

    # Search backwards for Conv2D layer
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            return layer.name

    raise ValueError("No convolutional layer found in model!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Improved Grad-CAM Implementation for Medical Images
3
+ ====================================================
4
+ Fixed version with better visualization and noise reduction
5
+ """
6
+
7
+ import tensorflow as tf
8
+ import numpy as np
9
+ import cv2
10
+ from PIL import Image
11
+
12
+
13
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """
    Generate a Grad-CAM heatmap with light noise reduction.

    Args:
        img_array: Preprocessed image batch passed directly to the model.
            Only the first sample of the batch is visualised.
        model: Trained Keras model.
        last_conv_layer_name: Name of the convolutional layer whose
            activations are used for the heatmap.
        pred_index: Target class index (if None, the arg-max class of the
            first sample is used).

    Returns:
        heatmap: 2-D float32 numpy array in the 0-1 range. It is normalised
            to a maximum of 1 before a 3x3 Gaussian blur is applied.

    Raises:
        ValueError: If no gradient flows from the class score to the
            requested layer.
    """
    # Gradient model: input -> (conv activations, predictions).
    # `model.inputs` is already a list, so it is not wrapped again.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)

        if pred_index is None:
            pred_index = tf.argmax(predictions[0])

        # Score of the target class
        class_channel = predictions[:, pred_index]

    # Gradient of the class score w.r.t. the conv activations
    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        raise ValueError(
            f"No gradient between the model output and layer "
            f"'{last_conv_layer_name}'"
        )

    # Global average pooling of gradients: one weight per channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()
    feature_maps = conv_outputs[0].numpy()

    # Weight every channel by its importance and average over channels.
    # Broadcasting over the last axis replaces the per-channel loop.
    heatmap = np.mean(feature_maps * pooled_grads, axis=-1)

    # ReLU: keep only positive influence. float32 is a dtype that
    # cv2.GaussianBlur accepts.
    heatmap = np.maximum(heatmap, 0).astype(np.float32)

    # Normalize between 0 and 1 (left as zeros if nothing is positive)
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak

    # Slight Gaussian blur to reduce noise
    heatmap = cv2.GaussianBlur(heatmap, (3, 3), 0)

    return heatmap
70
+
71
+
72
def overlay_heatmap_on_image(img, heatmap, alpha=0.5, colormap=cv2.COLORMAP_JET,
                             threshold=0.2):
    """
    Overlay a Grad-CAM heatmap on the original image, masking weak activations.

    Args:
        img: Original PIL Image or numpy array (grayscale, RGB or RGBA).
        heatmap: Grad-CAM heatmap (0-1 range).
        alpha: Weight of the heatmap in the blend (default: 0.5).
        colormap: OpenCV colormap constant (default: JET).
        threshold: Heatmap values below this are treated as noise and the
            original pixel is left untouched there (default: 0.2).

    Returns:
        superimposed_img: PIL Image with heatmap overlay
    """
    # Convert PIL to numpy if needed
    if isinstance(img, Image.Image):
        img = np.array(img)
    else:
        img = np.asarray(img)

    # Ensure a 3-channel image so it lines up with the coloured heatmap:
    # grayscale is replicated, an alpha channel is dropped.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :3]

    # Bring float images in [0, 1] to the 0-255 range. The dtype check keeps
    # a near-black uint8 image (max <= 1) from being rescaled by mistake.
    if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
        img = np.uint8(255 * img)

    # Resize heatmap to match image size. NaNs are zeroed and values clipped
    # so the uint8 conversion below cannot wrap around.
    heatmap = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    heatmap_resized = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap_resized = np.clip(heatmap_resized, 0.0, 1.0)

    # Remove very weak activations (noise reduction)
    heatmap_resized[heatmap_resized < threshold] = 0

    # Colour-map the heatmap; OpenCV returns BGR, so convert to RGB.
    heatmap_colored = np.uint8(255 * heatmap_resized)
    heatmap_colored = cv2.applyColorMap(heatmap_colored, colormap)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Blend only where the heatmap survived the threshold; elsewhere keep
    # the original pixel.
    mask = heatmap_resized > 0
    blended = heatmap_colored * alpha + img * (1 - alpha)
    superimposed_img = np.where(mask[..., np.newaxis], blended, img)

    superimposed_img = np.uint8(np.clip(superimposed_img, 0, 255))

    return Image.fromarray(superimposed_img)
123
+
124
+
125
def get_last_conv_layer_name(model):
    """
    Find the last convolutional layer to use for Grad-CAM.

    Lookup order: a list of EfficientNet-style layer names, then the last
    Conv2D among the model's direct layers, then the last Conv2D inside any
    nested sub-model.

    Args:
        model: Keras model

    Returns:
        layer_name: Name of last conv layer

    Raises:
        ValueError: If no convolutional layer can be found.
    """
    # EfficientNetB0 specific layer names (in order of preference)
    efficientnet_layers = [
        'top_conv',
        'block7a_project_conv',
        'block6d_project_conv',
        'block6c_project_conv',
        'conv_head'
    ]

    # Try EfficientNet specific layers first
    for layer_name in efficientnet_layers:
        try:
            model.get_layer(layer_name)
        except ValueError:
            # Keras signals an unknown layer name with ValueError; any other
            # exception (e.g. KeyboardInterrupt) must propagate.
            continue
        print(f"✅ Found Grad-CAM layer: {layer_name}")
        return layer_name

    # Fallback: search for last Conv2D layer
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            print(f"✅ Using fallback Conv2D layer: {layer.name}")
            return layer.name

    # Last resort: search in nested models.
    # NOTE(review): the returned name belongs to a sub-model; confirm callers
    # can resolve it, since it is not a direct layer of `model`.
    for layer in reversed(model.layers):
        if hasattr(layer, 'layers'):
            for sublayer in reversed(layer.layers):
                if isinstance(sublayer, tf.keras.layers.Conv2D):
                    print(f"✅ Using nested Conv2D layer: {sublayer.name}")
                    return sublayer.name

    raise ValueError("❌ No convolutional layer found in model!")
168
+
169
+
170
def create_gradcam_comparison(original_img, heatmap, predictions, disease_name):
    """
    Create a side-by-side comparison with original, heatmap, and overlay.

    Args:
        original_img: Original PIL Image or numpy array
        heatmap: Grad-CAM heatmap (0-1 range)
        predictions: Model predictions (currently unused; kept for API
            compatibility)
        disease_name: Name of disease being visualized (currently unused;
            kept for API compatibility)

    Returns:
        comparison_img: PIL Image with 3-panel comparison
            (original | heatmap | overlay)
    """
    # Normalise the original to an RGB uint8 array up front so all three
    # panels share shape and dtype when stacked.
    if isinstance(original_img, Image.Image):
        original_np = np.array(original_img.convert('RGB'))
    else:
        original_np = np.asarray(original_img)
        if original_np.ndim == 2:
            original_np = np.stack([original_np] * 3, axis=-1)
        elif original_np.ndim == 3 and original_np.shape[-1] == 4:
            original_np = original_np[..., :3]
        # Float images in [0, 1] are brought to the 0-255 range.
        if np.issubdtype(original_np.dtype, np.floating) and original_np.max() <= 1.0:
            original_np = 255 * original_np
        original_np = np.uint8(np.clip(original_np, 0, 255))

    # Resize heatmap to the image size
    heatmap_resized = cv2.resize(heatmap, (original_np.shape[1], original_np.shape[0]))

    # Create colored heatmap panel (OpenCV returns BGR, so convert to RGB)
    heatmap_colored = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
    heatmap_colored = cv2.applyColorMap(heatmap_colored, cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)

    # Create overlay panel from the already-normalised image
    overlay = overlay_heatmap_on_image(original_np, heatmap, alpha=0.5)
    overlay_np = np.array(overlay)

    # Stack horizontally
    comparison = np.hstack([original_np, heatmap_colored, overlay_np])

    return Image.fromarray(comparison)
209
+
210
+
211
def generate_multi_disease_gradcam(image, model, predictions, all_diseases,
                                   last_conv_layer_name, top_k=3, img_size=224):
    """
    Generate Grad-CAM visualizations for the top-K predicted diseases.

    Args:
        image: Input PIL Image or numpy array
        model: Trained model
        predictions: Prediction probabilities for all diseases; a flat
            sequence or a single-row array such as (1, n_classes)
        all_diseases: List of disease names, indexed by class id
        last_conv_layer_name: Name of last conv layer
        top_k: Number of top predictions to visualize
        img_size: Image size for model input

    Returns:
        gradcam_results: List of (disease_name, probability, gradcam_image) tuples,
            ordered from most to least probable
    """
    # Preprocess image
    if isinstance(image, np.ndarray):
        pixels = image
        # A float array in [0, 1] would truncate to all zeros under a plain
        # uint8 cast, so scale it to 0-255 first.
        if np.issubdtype(pixels.dtype, np.floating) and pixels.max() <= 1.0:
            pixels = 255 * pixels
        img_pil = Image.fromarray(pixels.astype('uint8'))
    else:
        img_pil = image

    img_resized = img_pil.convert('RGB').resize((img_size, img_size))
    img_array = np.array(img_resized) / 255.0
    img_array = np.expand_dims(img_array, axis=0).astype(np.float32)

    # Flatten so a (1, n_classes) model output ranks classes, not rows.
    probabilities = np.asarray(predictions, dtype=float).ravel()

    # Get top K diseases, highest probability first
    top_indices = np.argsort(probabilities)[::-1][:top_k]

    results = []

    for idx in top_indices:
        # Plain Python int for list indexing and for the class index.
        class_idx = int(idx)
        disease_name = all_diseases[class_idx]
        probability = float(probabilities[class_idx])

        # Generate heatmap
        heatmap = make_gradcam_heatmap(img_array, model, last_conv_layer_name, class_idx)

        # Create overlay on the resized RGB image
        gradcam_img = overlay_heatmap_on_image(img_resized, heatmap, alpha=0.5)

        results.append((disease_name, probability, gradcam_img))

    return results