cameron-d commited on
Commit
a6bbc26
·
1 Parent(s): 008ed91

Initial commit

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +204 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["KERAS_BACKEND"] = "jax"
3
+
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.cm as cm
7
+ import keras
8
+ import keras_hub
9
+ import numpy as np
10
+ import jax
11
+ from keras import ops
12
+ from PIL import Image
13
+
14
+ # Global variables for models
15
+ model = None
16
+ last_conv_layer_model = None
17
+ classifier_model = None
18
+
19
+ def initialize_models():
20
+ """Initialize the models once when the app starts."""
21
+ global model, last_conv_layer_model, classifier_model
22
+
23
+ # Load the pretrained Xception model
24
+ model = keras_hub.models.ImageClassifier.from_preset(
25
+ "xception_41_imagenet",
26
+ activation="softmax",
27
+ )
28
+
29
+ # Create a model that maps the input image to the activations of the last convolutional layer
30
+ last_conv_layer_name = "block14_sepconv2_act"
31
+ last_conv_layer = model.backbone.get_layer(last_conv_layer_name)
32
+ last_conv_layer_model = keras.Model(model.inputs, last_conv_layer.output)
33
+
34
+ # Create a model that maps the activations of the last convolutional layer to the final class predictions
35
+ classifier_input = last_conv_layer.output
36
+ x = classifier_input
37
+ for layer_name in ["pooler", "predictions"]:
38
+ x = model.get_layer(layer_name)(x)
39
+ classifier_model = keras.Model(classifier_input, x)
40
+
41
+ def loss_fn(last_conv_layer_output):
42
+ """Defines a separate loss function for gradient computation."""
43
+ preds = classifier_model(last_conv_layer_output)
44
+ top_pred_index = ops.argmax(preds[0])
45
+ top_class_channel = preds[:, top_pred_index]
46
+ return top_class_channel[0]
47
+
48
+ # Create gradient function
49
+ grad_fn = jax.grad(loss_fn)
50
+
51
+ def get_top_class_gradients(img_array):
52
+ """Get gradients of the top predicted class with respect to last conv layer."""
53
+ last_conv_layer_output = last_conv_layer_model(img_array)
54
+ grads = grad_fn(last_conv_layer_output)
55
+ return grads, last_conv_layer_output
56
+
57
+ def generate_heatmap(image):
58
+ """
59
+ Generate class activation heatmap for an uploaded image.
60
+
61
+ Args:
62
+ image: PIL Image or numpy array
63
+
64
+ Returns:
65
+ tuple: (superimposed_img, prediction_text)
66
+ """
67
+ if image is None:
68
+ return None, "Please upload an image."
69
+
70
+ # Convert PIL image to numpy array if needed
71
+ if isinstance(image, Image.Image):
72
+ img = np.array(image)
73
+ else:
74
+ img = image
75
+
76
+ # Prepare image for model (add batch dimension)
77
+ img_array = np.expand_dims(img, axis=0)
78
+
79
+ # Get predictions
80
+ preds = model.predict(img_array, verbose=0)
81
+
82
+ # Decode predictions
83
+ decoded_preds = keras_hub.utils.decode_imagenet_predictions(preds)
84
+
85
+ # Format prediction text
86
+ prediction_text = "Top 5 Predictions:\n\n"
87
+ for i, (description, score) in enumerate(decoded_preds[0][:5], 1):
88
+ prediction_text += f"{i}. {description}: {score:.2%}\n"
89
+
90
+ # Preprocess image
91
+ img_array = model.preprocessor(img_array)
92
+
93
+ # Get gradients and last conv layer output
94
+ grads, last_conv_layer_output = get_top_class_gradients(img_array)
95
+ grads = ops.convert_to_numpy(grads)
96
+ last_conv_layer_output = ops.convert_to_numpy(last_conv_layer_output)
97
+
98
+ # Compute importance of each channel
99
+ pooled_grads = np.mean(grads, axis=(0, 1, 2))
100
+ last_conv_layer_output = last_conv_layer_output[0].copy()
101
+
102
+ # Weight each channel by its importance
103
+ for i in range(pooled_grads.shape[-1]):
104
+ last_conv_layer_output[:, :, i] *= pooled_grads[i]
105
+
106
+ # Create heatmap
107
+ heatmap = np.mean(last_conv_layer_output, axis=-1)
108
+
109
+ # Normalize heatmap
110
+ heatmap = np.maximum(heatmap, 0)
111
+ heatmap /= np.max(heatmap)
112
+
113
+ # Rescale heatmap to 0-255
114
+ heatmap = np.uint8(255 * heatmap)
115
+
116
+ # Apply jet colormap
117
+ jet = cm.get_cmap("jet")
118
+ jet_colors = jet(np.arange(256))[:, :3]
119
+ jet_heatmap = jet_colors[heatmap]
120
+
121
+ # Convert to image and resize to match original
122
+ jet_heatmap = keras.utils.array_to_img(jet_heatmap)
123
+ jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
124
+ jet_heatmap = keras.utils.img_to_array(jet_heatmap)
125
+
126
+ # Superimpose heatmap on original image
127
+ superimposed_img = jet_heatmap * 0.4 + img
128
+ superimposed_img = keras.utils.array_to_img(superimposed_img)
129
+
130
+ return superimposed_img, prediction_text
131
+
132
+ # Initialize models when the script loads
133
+ print("Initializing models... This may take a moment.")
134
+ initialize_models()
135
+ print("Models initialized successfully!")
136
+
137
+ # Create Gradio interface
138
+ with gr.Blocks(title="Class Activation Heatmap Visualizer") as demo:
139
+ gr.Markdown(
140
+ """
141
+ # 🔥 Class Activation Heatmap Visualizer
142
+
143
+ Upload an image to see what parts of the image the neural network focuses on when making predictions.
144
+ The heatmap shows which regions of the image are most important for the top predicted class.
145
+
146
+ Adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-heatmaps-of-class-activation
147
+
148
+ **Model:** Xception trained on ImageNet (1,000 classes)
149
+ """
150
+ )
151
+
152
+ with gr.Row():
153
+ with gr.Column():
154
+ input_image = gr.Image(
155
+ label="Upload Image",
156
+ type="pil",
157
+ height=400
158
+ )
159
+ submit_btn = gr.Button("Generate Heatmap", variant="primary", size="lg")
160
+
161
+ with gr.Column():
162
+ output_image = gr.Image(
163
+ label="Heatmap Visualization",
164
+ type="pil",
165
+ height=400
166
+ )
167
+ prediction_text = gr.Textbox(
168
+ label="Predictions",
169
+ lines=7,
170
+ interactive=False
171
+ )
172
+
173
+ gr.Markdown(
174
+ """
175
+ ### How to interpret the heatmap:
176
+ - **Red/Yellow regions**: Areas the model focuses on most for its prediction
177
+ - **Blue/Purple regions**: Areas the model considers less important
178
+ - The heatmap is overlaid at 40% opacity on your original image
179
+ """
180
+ )
181
+
182
+ # Example images
183
+ gr.Examples(
184
+ examples=[
185
+ ["elephant.jpg"],
186
+ ["dog.jpg"],
187
+ ["F1_car.jpg"],
188
+ ["multiple_animals.jpg"],
189
+ ["osprey.jpeg"],
190
+ ],
191
+ inputs=input_image,
192
+ label="Try an example:"
193
+ )
194
+
195
+ # Connect the button to the function
196
+ submit_btn.click(
197
+ fn=generate_heatmap,
198
+ inputs=input_image,
199
+ outputs=[output_image, prediction_text]
200
+ )
201
+
202
+ # Launch the app
203
+ if __name__ == "__main__":
204
+ demo.launch(share=False)