# app.py — Grad-CAM class activation heatmap visualizer (Gradio demo).
import os
os.environ["KERAS_BACKEND"] = "jax"
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import keras
import keras_hub
import numpy as np
import jax
from keras import ops
from PIL import Image
# Module-level model handles, populated once by initialize_models() at startup.
model = None                   # full Xception ImageClassifier (preset with softmax)
last_conv_layer_model = None   # sub-model: input image -> last conv layer activations
classifier_model = None        # sub-model: last conv activations -> class predictions
def initialize_models():
    """Load the pretrained Xception classifier and build the two Grad-CAM
    sub-models: a conv-feature extractor and a classification head."""
    global model, last_conv_layer_model, classifier_model

    # Pretrained Xception with softmax outputs over the 1,000 ImageNet classes.
    model = keras_hub.models.ImageClassifier.from_preset(
        "xception_41_imagenet",
        activation="softmax",
    )

    # Sub-model 1: input image -> activations of the last convolutional layer.
    conv_layer = model.backbone.get_layer("block14_sepconv2_act")
    last_conv_layer_model = keras.Model(model.inputs, conv_layer.output)

    # Sub-model 2: last-conv activations -> final class predictions (the head).
    head_input = conv_layer.output
    head_output = head_input
    for name in ("pooler", "predictions"):
        head_output = model.get_layer(name)(head_output)
    classifier_model = keras.Model(head_input, head_output)
def loss_fn(last_conv_layer_output):
    """Scalar loss for autodiff: score of the top-predicted class."""
    preds = classifier_model(last_conv_layer_output)
    best_class = ops.argmax(preds[0])
    # Score of the winning class for the (single) batch element.
    return preds[:, best_class][0]


# Gradient of the top-class score w.r.t. the last conv layer's activations.
grad_fn = jax.grad(loss_fn)
def get_top_class_gradients(img_array):
    """Return (grads, activations) for the last conv layer.

    `grads` is the gradient of the top predicted class score with respect
    to the last convolutional layer's activations for `img_array`.
    """
    activations = last_conv_layer_model(img_array)
    return grad_fn(activations), activations
def generate_heatmap(image):
    """
    Generate a Grad-CAM class activation heatmap for an uploaded image.

    Args:
        image: PIL Image or numpy array as provided by Gradio.

    Returns:
        tuple: (superimposed PIL image, prediction text), or
               (None, message) when no image was provided.
    """
    if image is None:
        return None, "Please upload an image."

    # Normalize input to an RGB numpy array. convert("RGB") drops an alpha
    # channel and expands grayscale, so the 3-channel pipeline below is safe.
    if isinstance(image, Image.Image):
        img = np.array(image.convert("RGB"))
    else:
        img = np.asarray(image)

    # Add batch dimension for the model.
    img_array = np.expand_dims(img, axis=0)

    # Predictions (the preset applies softmax).
    preds = model.predict(img_array, verbose=0)
    decoded_preds = keras_hub.utils.decode_imagenet_predictions(preds)

    # Format the top-5 predictions as display text.
    prediction_text = "Top 5 Predictions:\n\n"
    for i, (description, score) in enumerate(decoded_preds[0][:5], 1):
        prediction_text += f"{i}. {description}: {score:.2%}\n"

    # Preprocess exactly as the classifier expects before taking gradients.
    img_array = model.preprocessor(img_array)

    # Gradients of the top class w.r.t. the last conv layer's activations.
    grads, last_conv_layer_output = get_top_class_gradients(img_array)
    grads = ops.convert_to_numpy(grads)
    last_conv_layer_output = ops.convert_to_numpy(last_conv_layer_output)

    # Channel importance: mean gradient over batch and spatial axes.
    pooled_grads = np.mean(grads, axis=(0, 1, 2))

    # Weight each channel by its importance — a single NumPy broadcast
    # instead of a per-channel Python loop.
    weighted = last_conv_layer_output[0] * pooled_grads

    # Collapse channels; keep only positive (class-supporting) contributions.
    heatmap = np.maximum(np.mean(weighted, axis=-1), 0)

    # Normalize to [0, 1]; guard against an all-zero heatmap, which would
    # otherwise divide by zero and fill the map with NaNs.
    max_val = np.max(heatmap)
    if max_val > 0:
        heatmap /= max_val

    # Rescale heatmap to 0-255 indices into the colormap table.
    heatmap = np.uint8(255 * heatmap)

    # Apply the jet colormap. cm.get_cmap() was deprecated in matplotlib 3.7
    # and removed in 3.9; plt.get_cmap() is the supported accessor.
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Convert to an image and resize to match the original. PIL sizes are
    # (width, height), hence the swapped shape indices.
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Blend the heatmap over the original image.
    superimposed_img = jet_heatmap * 0.4 + img
    superimposed_img = keras.utils.array_to_img(superimposed_img)

    return superimposed_img, prediction_text
# Initialize models when the script loads (runs once, before the UI is built).
print("Initializing models... this may take a moment.")
initialize_models()
print("Models initialized!")
# Create Gradio interface: image input + examples on the left,
# heatmap output + prediction text on the right.
with gr.Blocks(title="Class Activation Heatmap Visualizer") as demo:
    gr.Markdown(
        """
        # Class Activation Heatmap Visualizer
        Upload an image or choose one of the examples to see what parts of the image the neural network focuses on when making predictions.
        The heatmap shows which regions of the image are most important for the top predicted class.
        Code adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-heatmaps-of-class-activation
        **Model:** Xception trained on ImageNet (1,000 classes)
        """
    )
    with gr.Row():
        with gr.Column():
            # Input side: upload widget, trigger button, and bundled examples.
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )
            submit_btn = gr.Button("Generate Heatmap", variant="primary", size="lg")
            # Example images (paths relative to the app's working directory).
            gr.Examples(
                examples=[
                    ["images/elephant.jpg"],
                    ["images/dog.jpg"],
                    ["images/F1_car.jpg"],
                    ["images/multiple_animals.jpg"],
                    ["images/osprey.jpeg"]
                ],
                inputs=input_image,
                label="Try an example:"
            )
            gr.Markdown(
                """
                ### How to interpret the heatmap:
                - **Red/Yellow regions**: Areas the model focuses on most for its prediction
                - **Blue/Purple regions**: Areas the model considers less important
                """
            )
        with gr.Column():
            # Output side: heatmap overlay and the top-5 prediction text.
            output_image = gr.Image(
                label="Heatmap Visualization",
                type="pil",
                height=400
            )
            prediction_text = gr.Textbox(
                label="Predictions",
                lines=7,
                interactive=False
            )
    # Connect the button to the heatmap-generation function.
    submit_btn.click(
        fn=generate_heatmap,
        inputs=input_image,
        outputs=[output_image, prediction_text]
    )
# Launch the app locally (no public share link).
if __name__ == "__main__":
    demo.launch(share=False)