File size: 6,711 Bytes
a6bbc26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c00307f
a6bbc26
c00307f
a6bbc26
 
 
 
 
c00307f
a6bbc26
c00307f
a6bbc26
 
c00307f
a6bbc26
 
 
 
 
 
 
 
 
 
 
 
 
c00307f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6bbc26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
# Must be set before importing keras: selects the JAX backend, which the
# Grad-CAM gradient computation below (jax.grad) depends on.
os.environ["KERAS_BACKEND"] = "jax"

import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import keras
import keras_hub
import numpy as np
import jax
from keras import ops
from PIL import Image

# Global variables for models
# Populated once by initialize_models() at startup and read by the
# request-handling functions below.
model = None                   # full Xception ImageClassifier (preset)
last_conv_layer_model = None   # input image -> last conv layer activations
classifier_model = None        # last conv activations -> class probabilities

def initialize_models():
    """Load the pretrained Xception classifier and build the two Grad-CAM
    sub-models.

    Populates the module-level globals:
      * ``model`` — the full ImageClassifier (softmax outputs).
      * ``last_conv_layer_model`` — maps an input image to the activations
        of the final convolutional layer (``block14_sepconv2_act``).
      * ``classifier_model`` — maps those activations through the pooling
        and prediction head to class probabilities.
    """
    global model, last_conv_layer_model, classifier_model

    # Pretrained Xception on ImageNet, emitting softmax probabilities.
    model = keras_hub.models.ImageClassifier.from_preset(
        "xception_41_imagenet",
        activation="softmax",
    )

    # Sub-model 1: image -> last conv layer activations.
    conv_layer = model.backbone.get_layer("block14_sepconv2_act")
    last_conv_layer_model = keras.Model(model.inputs, conv_layer.output)

    # Sub-model 2: last conv activations -> final class predictions,
    # reusing the classifier's own pooling and prediction layers.
    head = conv_layer.output
    head = model.get_layer("pooler")(head)
    head = model.get_layer("predictions")(head)
    classifier_model = keras.Model(conv_layer.output, head)

def loss_fn(last_conv_layer_output):
    """Return the scalar score of the top-predicted class.

    Runs the classifier head on the last-conv-layer activations and picks
    out the probability of whichever class scored highest; jax.grad
    differentiates this scalar with respect to the activations.
    """
    predictions = classifier_model(last_conv_layer_output)
    winning_index = ops.argmax(predictions[0])
    return predictions[:, winning_index][0]

# Gradient of the top-class score with respect to loss_fn's first argument,
# i.e. the last-conv-layer activations — the core quantity Grad-CAM needs.
grad_fn = jax.grad(loss_fn)

def get_top_class_gradients(img_array):
    """Return (gradients, activations) for the top predicted class.

    ``activations`` are the last-conv-layer outputs for ``img_array``;
    ``gradients`` is d(top-class score)/d(activations) via ``grad_fn``.
    """
    activations = last_conv_layer_model(img_array)
    return grad_fn(activations), activations

def generate_heatmap(image):
    """
    Generate a Grad-CAM-style class activation heatmap for an uploaded image.

    Args:
        image: PIL Image or numpy array (RGB, RGBA, or grayscale), or None.

    Returns:
        tuple: (superimposed PIL image or None, prediction text)
    """
    if image is None:
        return None, "Please upload an image."

    # Normalize the input to a 3-channel RGB numpy array. Gradio delivers a
    # PIL image (type="pil"), but RGBA PNGs or grayscale uploads would
    # otherwise break the channel-wise superimposition below.
    if isinstance(image, Image.Image):
        img = np.array(image.convert("RGB"))
    else:
        img = np.asarray(image)
        if img.ndim == 2:
            # Grayscale -> replicate into 3 channels.
            img = np.stack([img] * 3, axis=-1)
        elif img.shape[-1] == 4:
            # Drop the alpha channel.
            img = img[..., :3]

    # Add batch dimension for the model.
    img_array = np.expand_dims(img, axis=0)

    # Class probabilities for the report shown next to the heatmap.
    preds = model.predict(img_array, verbose=0)
    decoded_preds = keras_hub.utils.decode_imagenet_predictions(preds)

    prediction_text = "Top 5 Predictions:\n\n"
    for i, (description, score) in enumerate(decoded_preds[0][:5], 1):
        prediction_text += f"{i}. {description}: {score:.2%}\n"

    # Preprocess exactly as the classifier expects before taking gradients.
    img_array = model.preprocessor(img_array)

    # Gradients of the top class w.r.t. the last conv layer, plus the
    # activations themselves.
    grads, last_conv_layer_output = get_top_class_gradients(img_array)
    grads = ops.convert_to_numpy(grads)
    last_conv_layer_output = ops.convert_to_numpy(last_conv_layer_output)

    # Channel importance = gradient averaged over batch and spatial dims.
    pooled_grads = np.mean(grads, axis=(0, 1, 2))

    # Weight each activation channel by its importance (broadcast replaces
    # the per-channel Python loop; same result).
    weighted = last_conv_layer_output[0] * pooled_grads

    # Collapse channels, keep positive evidence only, normalize to [0, 1].
    heatmap = np.maximum(np.mean(weighted, axis=-1), 0)
    peak = np.max(heatmap)
    if peak > 0:
        # Guard: an all-zero map would otherwise divide by zero -> NaNs.
        heatmap /= peak

    # Rescale heatmap to 0-255 for colormap indexing.
    heatmap = np.uint8(255 * heatmap)

    # Colorize with the jet colormap. plt.get_cmap is the supported
    # accessor; matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Upsample the low-resolution heatmap to the original image size.
    jet_heatmap = keras.utils.array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = keras.utils.img_to_array(jet_heatmap)

    # Blend: 40% heatmap over the original image.
    superimposed_img = jet_heatmap * 0.4 + img
    superimposed_img = keras.utils.array_to_img(superimposed_img)

    return superimposed_img, prediction_text

# Initialize models when the script loads (module import time), so the
# first user request does not pay the model-download/build cost.
print("Initializing models... this may take a moment.")
initialize_models()
print("Models initialized!")

# Create Gradio interface
# Declarative UI layout: left column = input image + examples + legend,
# right column = heatmap output + top-5 predictions textbox.
with gr.Blocks(title="Class Activation Heatmap Visualizer") as demo:
    gr.Markdown(
        """
        # Class Activation Heatmap Visualizer
        
        Upload an image or choose one of the examples to see what parts of the image the neural network focuses on when making predictions.
        The heatmap shows which regions of the image are most important for the top predicted class.

        Code adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-heatmaps-of-class-activation
        
        **Model:** Xception trained on ImageNet (1,000 classes)
        """
    )
    
    with gr.Row():
        with gr.Column():
            # type="pil" means generate_heatmap receives a PIL.Image.
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                height=400
            )
            submit_btn = gr.Button("Generate Heatmap", variant="primary", size="lg")

            # Example images (paths relative to the app's working directory).
            gr.Examples(
                examples=[
                    ["images/elephant.jpg"],
                    ["images/dog.jpg"],
                    ["images/F1_car.jpg"],
                    ["images/multiple_animals.jpg"],
                    ["images/osprey.jpeg"]
                ],
                inputs=input_image,
                label="Try an example:"
            )

            gr.Markdown(
                """
                ### How to interpret the heatmap:
                - **Red/Yellow regions**: Areas the model focuses on most for its prediction
                - **Blue/Purple regions**: Areas the model considers less important
                """
            )
            
        with gr.Column():
            output_image = gr.Image(
                label="Heatmap Visualization",
                type="pil",
                height=400
            )
            prediction_text = gr.Textbox(
                label="Predictions",
                lines=7,
                interactive=False
            )
    
    # Connect the button to the function: one input (image), two outputs
    # (heatmap image, prediction text) matching generate_heatmap's return.
    submit_btn.click(
        fn=generate_heatmap,
        inputs=input_image,
        outputs=[output_image, prediction_text]
    )

# Launch the app only when run as a script (not when imported);
# share=False keeps the server local (no public Gradio tunnel).
if __name__ == "__main__":
    demo.launch(share=False)