# stroke-classification/src/gradcam_utils.py
# Grad-CAM utilities; originally uploaded by bakhili ("Create gradcam_utils.py", commit f26049e, verified).
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """
    Generate a Grad-CAM heatmap showing what the model actually focuses on.

    Args:
        img_array: Preprocessed image array of shape (1, 224, 224, 3).
        model: Trained Keras model.
        last_conv_layer_name: Name of the last convolutional layer.
        pred_index: Index of the class to generate the heatmap for
            (None = use the model's predicted class).

    Returns:
        heatmap: 2D numpy array in [0, 1] showing model attention,
            or None if Grad-CAM fails for any reason.
    """
    try:
        # Create a model that maps the input image to the activations of the
        # last conv layer as well as the output predictions.
        # NOTE: model.inputs is already a list; wrapping it in another list
        # (i.e. [model.inputs]) is rejected by newer Keras versions.
        grad_model = Model(
            inputs=model.inputs,
            outputs=[model.get_layer(last_conv_layer_name).output, model.output]
        )

        # Compute the gradient of the chosen class score for our input image
        # with respect to the activations of the last conv layer.
        with tf.GradientTape() as tape:
            last_conv_layer_output, preds = grad_model(img_array)
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]

        # Gradient of the output neuron (top predicted or chosen class)
        # with regard to the output feature map of the last conv layer.
        grads = tape.gradient(class_channel, last_conv_layer_output)

        # Mean intensity of the gradient over each feature-map channel:
        # "how important is this channel" for the chosen class.
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        # Weight each channel of the feature map by its importance and sum
        # the channels to obtain the class-activation heatmap.
        last_conv_layer_output = last_conv_layer_output[0]
        heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
        heatmap = tf.squeeze(heatmap)

        # Normalize the heatmap to [0, 1] for visualization.
        # divide_no_nan avoids NaNs when the heatmap is all zeros
        # (e.g. dead activations), returning an all-zero map instead.
        heatmap = tf.maximum(heatmap, 0)
        heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
        return heatmap.numpy()

    except Exception as e:
        # Best-effort API: callers treat None as "no heatmap available".
        print(f"Grad-CAM error: {e}")
        return None
def find_last_conv_layer(model):
    """
    Automatically find the last convolutional layer in the model.

    Args:
        model: A Keras-like model exposing ``layers`` (each with a ``name``)
            and ``get_layer(name)`` (raising on unknown names).

    Returns:
        The name (str) of the last layer whose name contains "conv"
        (case-insensitive), falling back to a list of well-known layer
        names for common architectures; None if nothing matches.
    """
    conv_layers = [layer.name for layer in model.layers
                   if 'conv' in layer.name.lower()]
    if conv_layers:
        return conv_layers[-1]

    # Fallback: probe common last-conv layer names (VGG16, ResNet50,
    # EfficientNet, InceptionResNetV2).
    common_names = ['block5_conv3', 'conv5_block3_3_conv', 'top_conv', 'conv_7b']
    for name in common_names:
        try:
            model.get_layer(name)
            return name
        except (ValueError, KeyError):
            # Keras get_layer raises ValueError for unknown names; a bare
            # except here would also swallow KeyboardInterrupt/SystemExit.
            continue
    return None
def create_real_attention_heatmap(img, model, predictions):
    """
    Create a real attention heatmap for an image using Grad-CAM.

    Args:
        img: PIL-like image object (must support ``resize``).
        model: Trained Keras model.
        predictions: 1D array of class probabilities/logits; the argmax
            class is used as the Grad-CAM target.

    Returns:
        2D numpy array of shape (224, 224) with values in [0, 1],
        or None if the heatmap could not be generated.
    """
    try:
        # Preprocess image for Grad-CAM: fixed 224x224 model input size.
        img_resized = img.resize((224, 224))
        img_array = np.array(img_resized, dtype=np.float32)

        # Handle grayscale (H, W) -> (H, W, 3).
        if len(img_array.shape) == 2:
            img_array = np.stack([img_array] * 3, axis=-1)
        # Handle RGBA (H, W, 4) -> drop the alpha channel; otherwise the
        # array would not match the model's 3-channel input.
        elif img_array.shape[-1] == 4:
            img_array = img_array[..., :3]

        # Normalize to [0, 1] and add the batch dimension.
        img_array = np.expand_dims(img_array, axis=0) / 255.0

        # Find the last convolutional layer to attach Grad-CAM to.
        last_conv_layer_name = find_last_conv_layer(model)
        if last_conv_layer_name is None:
            print("Could not find convolutional layer for Grad-CAM")
            return None
        print(f"Using layer: {last_conv_layer_name}")

        # Generate the Grad-CAM heatmap for the top predicted class.
        heatmap = make_gradcam_heatmap(
            img_array,
            model,
            last_conv_layer_name,
            pred_index=np.argmax(predictions)
        )
        if heatmap is None:
            return None

        # Resize the conv-resolution heatmap back to the input image size.
        heatmap_resized = tf.image.resize(
            heatmap[..., tf.newaxis],
            (224, 224)
        ).numpy()[:, :, 0]
        return heatmap_resized

    except Exception as e:
        # Best-effort API: callers treat None as "no heatmap available".
        print(f"Real attention heatmap error: {e}")
        return None