Spaces:
Running
Running
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Model | |
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """
    Generate a Grad-CAM heatmap showing which spatial regions drive a prediction.

    Args:
        img_array: Preprocessed image batch, presumably shaped (1, H, W, 3)
            such as (1, 224, 224, 3). Intended for a batch of one; the heatmap
            is built from the first element's feature map.
        model: Trained Keras model.
        last_conv_layer_name: Name of a layer of ``model`` whose output is a
            4D feature map (batch, h, w, channels).
        pred_index: Index of the class to explain. ``None`` selects the class
            with the highest score for the first batch element.

    Returns:
        A 2D float array of shape (h, w) (the layer's spatial size) with values
        in [0, 1], or ``None`` if anything fails (the error is printed).
        When no location contributes positively to the class score, the map
        is all zeros rather than NaN.
    """
    try:
        # Model mapping the input image to (feature map, predictions).
        # `model.inputs` is already a list; wrapping it in another list nests
        # the input structure, which newer Keras versions warn about / reject.
        grad_model = Model(
            inputs=model.inputs,
            outputs=[model.get_layer(last_conv_layer_name).output, model.output],
        )

        # Record the forward pass so the chosen class score can be
        # differentiated with respect to the feature map.
        with tf.GradientTape() as tape:
            last_conv_layer_output, preds = grad_model(img_array)
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]

        grads = tape.gradient(class_channel, last_conv_layer_output)
        if grads is None:
            # GradientTape returns None when the target does not depend on the
            # source; fail with a clear message instead of an opaque TF error.
            raise ValueError(
                f"class score does not depend on layer '{last_conv_layer_name}'"
            )

        # Per-channel importance: gradient averaged over batch and spatial axes.
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

        # Importance-weighted sum of the first image's feature-map channels.
        # Squeeze only the trailing axis so a 1x1 feature map stays 2D.
        heatmap = last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis]
        heatmap = tf.squeeze(heatmap, axis=-1)

        # Keep only positive evidence (ReLU), then scale to [0, 1].
        # divide_no_nan returns zeros instead of NaN when the maximum is 0.
        heatmap = tf.maximum(heatmap, 0)
        heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
        return heatmap.numpy()
    except Exception as e:
        # Best-effort visualization: report and let callers handle None.
        print(f"Grad-CAM error: {e}")
        return None
def _output_rank(layer):
    """Return the rank of ``layer``'s output, or None if it cannot be determined."""
    try:
        return len(layer.output.shape)
    except (AttributeError, TypeError, ValueError, RuntimeError):
        # No output attribute, layer not built / used in several places,
        # or a shape of unknown rank: treat as "unknown".
        return None


def find_last_conv_layer(model):
    """
    Automatically find the last convolutional layer in the model.

    Only the model's own ``model.layers`` are searched. Layers whose output is
    known not to be 4D are skipped, because Grad-CAM needs a spatial feature
    map. Layers whose output rank cannot be determined are still considered.

    Search order (scanning from the last layer backwards):
      1. a layer whose *name* contains "conv" (case-insensitive). Common
         final-layer names such as 'block5_conv3', 'conv5_block3_3_conv',
         'top_conv' and 'conv_7b' are matched here;
      2. otherwise a layer whose *class name* contains "conv", which catches
         convolution layers given custom names.

    Returns:
        The layer's name, or None if no suitable layer exists.
    """
    candidates = [layer for layer in model.layers if _output_rank(layer) in (4, None)]

    for layer in reversed(candidates):
        if "conv" in layer.name.lower():
            return layer.name

    for layer in reversed(candidates):
        if "conv" in type(layer).__name__.lower():
            return layer.name

    return None
def create_real_attention_heatmap(img, model, predictions, target_size=(224, 224)):
    """
    Create a real attention heatmap using Grad-CAM.

    Args:
        img: Input image. Presumably a PIL ``Image``; it must support
            ``.resize((width, height))`` and conversion via ``np.array``.
            TODO confirm with the caller.
        model: Trained Keras model. Assumed to take RGB input of
            ``target_size`` scaled to [0, 1]; verify against the training
            preprocessing.
        predictions: Model scores for ``img``. ``np.argmax`` over the
            flattened array selects the class to explain.
        target_size: (height, width) the image is resized to before being fed
            to the model. This is also the shape of the returned heatmap.
            Defaults to (224, 224).

    Returns:
        A 2D float array of shape ``target_size`` with values in [0, 1], or
        ``None`` on failure (the reason is printed).
    """
    try:
        height, width = target_size
        # PIL-style resize takes (width, height).
        img_array = np.array(img.resize((width, height)), dtype=np.float32)

        # Coerce to exactly 3 channels so the model always receives RGB.
        if img_array.ndim == 2:
            # Grayscale (H, W): replicate into 3 channels.
            img_array = np.stack([img_array] * 3, axis=-1)
        elif img_array.shape[-1] == 1:
            # Single channel (H, W, 1): replicate.
            img_array = np.repeat(img_array, 3, axis=-1)
        elif img_array.shape[-1] == 2:
            # Presumably grayscale + alpha: keep the gray channel, drop alpha.
            img_array = np.repeat(img_array[..., :1], 3, axis=-1)
        elif img_array.shape[-1] == 4:
            # Presumably RGBA: drop the alpha channel.
            img_array = img_array[..., :3]

        # Scale to [0, 1] and add a batch axis -> (1, H, W, 3).
        img_array = np.expand_dims(img_array, axis=0) / 255.0

        last_conv_layer_name = find_last_conv_layer(model)
        if last_conv_layer_name is None:
            print("Could not find convolutional layer for Grad-CAM")
            return None
        print(f"Using layer: {last_conv_layer_name}")

        heatmap = make_gradcam_heatmap(
            img_array,
            model,
            last_conv_layer_name,
            pred_index=np.argmax(predictions),
        )
        if heatmap is None:
            return None

        # Upsample the coarse feature-map heatmap to the model input size
        # (not the original image size).
        return tf.image.resize(
            heatmap[..., tf.newaxis],
            (height, width),
        ).numpy()[:, :, 0]
    except Exception as e:
        # Best-effort visualization: report and let callers handle None.
        print(f"Real attention heatmap error: {e}")
        return None