import os

# Keras selects its backend at import time, so the backend has to be chosen
# here, before the `import keras` further down in this file.
os.environ.update(KERAS_BACKEND="jax")
| |
|
| | import gradio as gr |
| | import matplotlib.pyplot as plt |
| | import matplotlib.cm as cm |
| | import keras |
| | import keras_hub |
| | import numpy as np |
| | import jax |
| | from keras import ops |
| | from PIL import Image |
| |
|
| | |
# Module-level model handles; all start unset and are populated by
# initialize_models():
#   model                 -- full ImageNet classifier
#   last_conv_layer_model -- inputs -> last conv block activations
#   classifier_model      -- last conv block activations -> class scores
model = last_conv_layer_model = classifier_model = None
| |
|
def initialize_models():
    """Load the classifier and split it into the two Grad-CAM sub-models.

    Assigns three module-level globals:

    * ``model``: the ``xception_41_imagenet`` keras_hub image classifier
      with a softmax output activation.
    * ``last_conv_layer_model``: maps the model's inputs to the output of the
      backbone layer ``block14_sepconv2_act``.
    * ``classifier_model``: maps that layer's output to class scores by
      re-applying the classifier's ``pooler`` and ``predictions`` layers.
    """
    global model, last_conv_layer_model, classifier_model

    model = keras_hub.models.ImageClassifier.from_preset(
        "xception_41_imagenet",
        activation="softmax",
    )

    # The heatmap is computed from this layer's activations.
    conv_layer = model.backbone.get_layer("block14_sepconv2_act")
    conv_activations = conv_layer.output
    last_conv_layer_model = keras.Model(model.inputs, conv_activations)

    # Re-wire the classification head so it can be fed the conv activations
    # directly; gradients of the class score are then taken w.r.t. them.
    head_output = conv_activations
    for head_layer_name in ("pooler", "predictions"):
        head_output = model.get_layer(head_layer_name)(head_output)
    classifier_model = keras.Model(conv_activations, head_output)
| |
|
def loss_fn(last_conv_layer_output):
    """Return the score of the top predicted class for the first batch item.

    Kept as a standalone function of the conv activations so that
    ``jax.grad`` can differentiate it with respect to that argument.
    """
    class_scores = classifier_model(last_conv_layer_output)
    # The winning class is chosen from the first sample only.
    winner = ops.argmax(class_scores[0])
    winning_column = class_scores[:, winner]
    return winning_column[0]
| |
|
| | |
# Gradient of loss_fn w.r.t. its first (and only) argument: the conv activations.
grad_fn = jax.grad(loss_fn, argnums=0)
| |
|
def get_top_class_gradients(img_array):
    """Run the conv sub-model and differentiate the top class score.

    Args:
        img_array: Batch to feed ``last_conv_layer_model`` (the caller in this
            file passes it through ``model.preprocessor`` first).

    Returns:
        tuple: ``(grads, activations)``, where ``grads`` is the gradient of the
        top class score with respect to ``activations``, the last conv layer
        output.
    """
    activations = last_conv_layer_model(img_array)
    return grad_fn(activations), activations
| |
|
def _to_rgb_array(image):
    """Return *image* as a 3-channel (H, W, 3) numpy array.

    PIL images are converted to RGB mode (handles RGBA, palette and grayscale
    uploads). For numpy input, grayscale arrays (2-D or single-channel) are
    replicated to three channels and a 4th (alpha) channel is dropped, so the
    model and the overlay always see exactly three channels.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))
    img = np.asarray(image)
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]
    return img


def _format_top_predictions(preds):
    """Format the first sample's top-5 ImageNet predictions as numbered lines."""
    decoded = keras_hub.utils.decode_imagenet_predictions(preds)
    ranked = "".join(
        f"{rank}. {description}: {score:.2%}\n"
        for rank, (description, score) in enumerate(decoded[0][:5], 1)
    )
    return "Top 5 Predictions:\n\n" + ranked


def _compute_heatmap(grads, conv_output):
    """Combine conv activations and their gradients into a heatmap in [0, 1].

    Both arguments are numpy arrays; indexing follows the channels-last
    (batch, H, W, C) layout the original per-channel loop assumed.
    """
    # One importance weight per channel: gradients averaged over batch and space.
    pooled_grads = np.mean(grads, axis=(0, 1, 2))
    # Weight every channel of the first sample (broadcast over the last axis),
    # then average the weighted channels into a single 2-D map.
    heatmap = np.mean(conv_output[0] * pooled_grads, axis=-1)
    heatmap = np.maximum(heatmap, 0)
    max_value = np.max(heatmap)
    # An all-zero map (no positive evidence) would otherwise divide by zero
    # and produce NaNs, which corrupt the later uint8 conversion.
    if max_value > 0:
        heatmap = heatmap / max_value
    return heatmap


def _overlay_heatmap(img, heatmap, alpha=0.4):
    """Colorize *heatmap* with the jet colormap and blend it onto *img*."""
    heatmap = np.uint8(255 * heatmap)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot's remains.
    jet_colors = plt.get_cmap("jet")(np.arange(256))[:, :3]
    jet_heatmap = keras.utils.array_to_img(jet_colors[heatmap])
    # PIL's resize takes (width, height).
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)
    superimposed_img = jet_heatmap * alpha + img
    return keras.utils.array_to_img(superimposed_img)


def generate_heatmap(image):
    """
    Generate class activation heatmap for an uploaded image.

    Args:
        image: PIL Image or numpy array (any channel count is normalized to RGB).

    Returns:
        tuple: (superimposed_img, prediction_text). When ``image`` is None,
        returns ``(None, "Please upload an image.")``.
    """
    if image is None:
        return None, "Please upload an image."

    img = _to_rgb_array(image)
    img_array = np.expand_dims(img, axis=0)

    preds = model.predict(img_array, verbose=0)
    prediction_text = _format_top_predictions(preds)

    # The gradient pass is fed the batch after the model's own preprocessor.
    grads, conv_output = get_top_class_gradients(model.preprocessor(img_array))
    heatmap = _compute_heatmap(
        ops.convert_to_numpy(grads), ops.convert_to_numpy(conv_output)
    )

    return _overlay_heatmap(img, heatmap), prediction_text
| |
|
| | |
def _load_models_at_startup():
    """Build the models once at import time, reporting progress on stdout."""
    print("Initializing models... this may take a moment.")
    initialize_models()
    print("Models initialized!")


_load_models_at_startup()
| |
|
| | |
# Gradio UI: input image + examples on the left, heatmap and top-5
# predictions on the right. The heatmap is generated on button click.
with gr.Blocks(title="Class Activation Heatmap Visualizer") as demo:
    gr.Markdown(
        """
        # Class Activation Heatmap Visualizer

        Upload an image or choose one of the examples to see what parts of the image the neural network focuses on when making predictions.
        The heatmap shows which regions of the image are most important for the top predicted class.

        Code adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-heatmaps-of-class-activation

        **Model:** Xception trained on ImageNet (1,000 classes)
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400,
            )
            submit_btn = gr.Button("Generate Heatmap", variant="primary", size="lg")

            # Clicking an example only fills input_image; the user still
            # presses the button to run the model.
            gr.Examples(
                examples=[
                    ["images/elephant.jpg"],
                    ["images/dog.jpg"],
                    ["images/F1_car.jpg"],
                    ["images/multiple_animals.jpg"],
                    ["images/osprey.jpeg"],
                ],
                inputs=input_image,
                label="Try an example:",
            )

            # Legend for the jet colormap used in the overlay: its low end is
            # dark blue (there is no purple in jet), its high end is red.
            gr.Markdown(
                """
                ### How to interpret the heatmap:
                - **Red/Yellow regions**: Areas the model focuses on most for its prediction
                - **Blue regions**: Areas the model considers less important
                """
            )

        with gr.Column():
            output_image = gr.Image(
                label="Heatmap Visualization",
                type="pil",
                height=400,
            )
            # 7 lines = header + blank line + 5 ranked predictions.
            prediction_text = gr.Textbox(
                label="Predictions",
                lines=7,
                interactive=False,
            )

    submit_btn.click(
        fn=generate_heatmap,
        inputs=input_image,
        outputs=[output_image, prediction_text],
    )
| |
|
| | |
def main():
    """Serve the Gradio app locally, without creating a public share link."""
    demo.launch(share=False)


if __name__ == "__main__":
    main()
| |
|