File size: 6,711 Bytes
a6bbc26 c00307f a6bbc26 c00307f a6bbc26 c00307f a6bbc26 c00307f a6bbc26 c00307f a6bbc26 c00307f a6bbc26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import os
os.environ["KERAS_BACKEND"] = "jax"
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import keras
import keras_hub
import numpy as np
import jax
from keras import ops
from PIL import Image
# Global variables for models
# All three are populated exactly once by initialize_models() at startup.
model = None  # full Xception ImageClassifier (backbone + head)
last_conv_layer_model = None  # maps input image -> last conv layer activations
classifier_model = None  # maps last conv activations -> final class predictions
def initialize_models():
    """Load the pretrained classifier and build the two Grad-CAM helper models.

    Populates the module-level ``model``, ``last_conv_layer_model`` and
    ``classifier_model`` globals. Intended to be called once at startup.
    """
    global model, last_conv_layer_model, classifier_model

    # Pretrained Xception classifier (ImageNet) with softmax outputs.
    model = keras_hub.models.ImageClassifier.from_preset(
        "xception_41_imagenet",
        activation="softmax",
    )

    # First half: input image -> activations of the last convolutional layer.
    conv_layer = model.backbone.get_layer("block14_sepconv2_act")
    last_conv_layer_model = keras.Model(model.inputs, conv_layer.output)

    # Second half: those activations -> final class predictions, built by
    # replaying the classification head layers on top of the conv output.
    head = conv_layer.output
    for layer_name in ("pooler", "predictions"):
        head = model.get_layer(layer_name)(head)
    classifier_model = keras.Model(conv_layer.output, head)
def loss_fn(last_conv_layer_output):
    """Scalar Grad-CAM loss: the top-1 class score of the first batch element.

    Runs the classification head on the conv-layer activations and returns the
    probability of whichever class the head ranks highest, as a scalar suitable
    for jax.grad.
    """
    preds = classifier_model(last_conv_layer_output)
    best_class = ops.argmax(preds[0])
    return preds[:, best_class][0]

# Gradient of the top-class score w.r.t. the conv-layer activations.
grad_fn = jax.grad(loss_fn)
def get_top_class_gradients(img_array):
    """Return (gradients, activations) of the last conv layer for *img_array*.

    The gradients are of the top predicted class's score with respect to the
    last convolutional layer's output; the activations are that output itself.
    """
    conv_output = last_conv_layer_model(img_array)
    return grad_fn(conv_output), conv_output
def generate_heatmap(image):
    """
    Generate a Grad-CAM class activation heatmap for an uploaded image.

    Args:
        image: PIL Image or numpy array (H, W, 3 expected for arrays).

    Returns:
        tuple: (superimposed PIL image or None, prediction text)
    """
    if image is None:
        return None, "Please upload an image."

    # Normalize the input to an RGB numpy array. Uploads may arrive as RGBA
    # (PNG) or grayscale, which would break the 3-channel pipeline below.
    if isinstance(image, Image.Image):
        img = np.array(image.convert("RGB"))
    else:
        img = np.asarray(image)

    # Add a batch dimension for the model.
    img_array = np.expand_dims(img, axis=0)

    # Top-5 predictions for the text panel.
    preds = model.predict(img_array, verbose=0)
    decoded_preds = keras_hub.utils.decode_imagenet_predictions(preds)
    prediction_text = "Top 5 Predictions:\n\n"
    for i, (description, score) in enumerate(decoded_preds[0][:5], 1):
        prediction_text += f"{i}. {description}: {score:.2%}\n"

    # Preprocess, then get gradients of the top class w.r.t. the last conv layer.
    img_array = model.preprocessor(img_array)
    grads, last_conv_layer_output = get_top_class_gradients(img_array)
    grads = ops.convert_to_numpy(grads)
    last_conv_layer_output = ops.convert_to_numpy(last_conv_layer_output)

    # Channel importance = mean gradient over batch + spatial axes; weight each
    # channel by it (broadcasting replaces the original per-channel loop).
    pooled_grads = np.mean(grads, axis=(0, 1, 2))
    weighted = last_conv_layer_output[0] * pooled_grads

    # Heatmap = channel mean, ReLU, then normalize to [0, 1].
    heatmap = np.maximum(np.mean(weighted, axis=-1), 0)
    max_val = np.max(heatmap)
    if max_val > 0:  # guard: a flat/all-zero heatmap would divide by zero
        heatmap /= max_val

    # Rescale heatmap to 0-255 integer indices into the colormap LUT.
    heatmap = np.uint8(255 * heatmap)

    # Colorize with the jet colormap. matplotlib.cm.get_cmap was removed in
    # matplotlib 3.9; plt.get_cmap is the supported equivalent.
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Convert to image and resize to match the original image's size.
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Superimpose heatmap on the original image (40% heatmap opacity).
    superimposed_img = jet_heatmap * 0.4 + img
    superimposed_img = keras.utils.array_to_img(superimposed_img)

    return superimposed_img, prediction_text
# Initialize models when the script loads (not lazily on first request),
# so the first user interaction is not blocked by model downloads.
print("Initializing models... this may take a moment.")
initialize_models()
print("Models initialized!")
# Create Gradio interface: a two-column layout with input + examples on the
# left and the heatmap visualization + top-5 predictions on the right.
with gr.Blocks(title="Class Activation Heatmap Visualizer") as demo:
    # Page header and usage instructions.
    gr.Markdown(
        """
        # Class Activation Heatmap Visualizer
        Upload an image or choose one of the examples to see what parts of the image the neural network focuses on when making predictions.
        The heatmap shows which regions of the image are most important for the top predicted class.
        Code adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-heatmaps-of-class-activation
        **Model:** Xception trained on ImageNet (1,000 classes)
        """
    )
    with gr.Row():
        # Left column: image input, trigger button, examples, and legend.
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )
            submit_btn = gr.Button("Generate Heatmap", variant="primary", size="lg")
            # Example images (paths are relative to the app's working directory).
            gr.Examples(
                examples=[
                    ["images/elephant.jpg"],
                    ["images/dog.jpg"],
                    ["images/F1_car.jpg"],
                    ["images/multiple_animals.jpg"],
                    ["images/osprey.jpeg"]
                ],
                inputs=input_image,
                label="Try an example:"
            )
            gr.Markdown(
                """
                ### How to interpret the heatmap:
                - **Red/Yellow regions**: Areas the model focuses on most for its prediction
                - **Blue/Purple regions**: Areas the model considers less important
                """
            )
        # Right column: heatmap output and prediction text.
        with gr.Column():
            output_image = gr.Image(
                label="Heatmap Visualization",
                type="pil",
                height=400
            )
            prediction_text = gr.Textbox(
                label="Predictions",
                lines=7,
                interactive=False
            )
    # Connect the button to the heatmap-generation function.
    submit_btn.click(
        fn=generate_heatmap,
        inputs=input_image,
        outputs=[output_image, prediction_text]
    )
# Launch the app only when run as a script (share=False: no public tunnel).
if __name__ == "__main__":
    demo.launch(share=False)
|