File size: 4,705 Bytes
0ad0d30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import contextlib
import functools
import io

import gradio as gr
import keras
import numpy as np
from keras import layers
from PIL import Image
# Load the trained cats-vs-dogs CNN once at import time; the functions below
# all read this module-level ``model``.
MODEL_PATH = "dogs_and_cats_CNN.keras"
model = keras.models.load_model(MODEL_PATH)
def get_model_summary(model):
    """Capture the text ``model.summary()`` writes to stdout and return it.

    ``summary()`` prints instead of returning a value, so stdout is
    temporarily redirected into an in-memory buffer for the duration of
    the call and the buffer's contents are handed back.
    """
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        model.summary()
    return captured.getvalue()
def get_img_array(image, target_size):
    """Resize a PIL image and return it as a batch of one for the model.

    Args:
        image: PIL image in any colour mode (RGB, RGBA, L, P, ...).
        target_size: ``(width, height)`` tuple forwarded to ``Image.resize``.

    Returns:
        Array of shape ``(1, height, width, 3)`` produced by
        ``keras.utils.img_to_array`` (float32 with keras' default settings;
        channels-last assumed — confirm if the keras data format is changed).
    """
    # Normalize the colour mode first: an RGBA / greyscale / palette image
    # would otherwise yield 4 or 1 channels, which a model fed RGB images
    # (Gradio's default image_mode) cannot consume.
    image = image.convert("RGB").resize(target_size)
    array = keras.utils.img_to_array(image)
    # Add the leading batch axis expected by model.predict.
    return np.expand_dims(array, axis=0)
def _classify(img_tensor):
    """Return ``(label, confidence)`` for one preprocessed image batch.

    The model's first output unit is treated as the "Dog" score: above 0.5
    the label is "Dog" with that score as confidence, otherwise "Cat" with
    confidence ``1 - score``.
    """
    dog_score = float(model.predict(img_tensor, verbose=0)[0][0])
    if dog_score > 0.5:
        return "Dog", dog_score
    return "Cat", 1 - dog_score


@functools.lru_cache(maxsize=None)
def _activation_model():
    """Build, once, a model exposing every Conv2D / MaxPooling2D output.

    Returns ``(activation_model, layer_names)``. ``activation_model`` is None
    and ``layer_names`` is empty when ``model`` has no such layers (an empty
    ``outputs`` list cannot be turned into a keras.Model). Cached because the
    module-level ``model`` is loaded once and never reassigned in this file.
    """
    picked = [
        layer
        for layer in model.layers
        if isinstance(layer, (layers.Conv2D, layers.MaxPooling2D))
    ]
    if not picked:
        return None, ()
    activation_model = keras.Model(
        inputs=model.input, outputs=[layer.output for layer in picked]
    )
    return activation_model, tuple(layer.name for layer in picked)


@functools.lru_cache(maxsize=None)
def _cached_model_summary():
    """Return ``get_model_summary(model)``, computed on first use and reused."""
    return get_model_summary(model)


def _activation_grid(layer_activation, images_per_row=16):
    """Tile every channel of one layer's activation into a single 2-D image.

    ``layer_activation`` is indexed as ``[0, y, x, channel]`` (batch of one,
    channels last), matching how the original visualization read it. Each
    channel is standardized, mapped to 0-255 (mean -> 128, one std -> 64),
    clipped, and placed in a grid ``images_per_row`` tiles wide with 1-pixel
    separators. The grid is returned scaled to [0, 1].
    """
    height, width, n_features = layer_activation.shape[1:4]
    # Ceil division so channels in a final, partially filled row are shown
    # instead of being silently dropped.
    n_rows = -(-n_features // images_per_row)
    grid = np.zeros(
        ((height + 1) * n_rows - 1, (width + 1) * images_per_row - 1)
    )
    for channel_index in range(n_features):
        grid_row, grid_col = divmod(channel_index, images_per_row)
        channel_image = layer_activation[0, :, :, channel_index]
        std = channel_image.std()
        # Constant (dead) channels are left unnormalized to avoid dividing
        # by ~0; they render as a flat tile.
        if std > 1e-6:
            channel_image = (channel_image - channel_image.mean()) / std
        channel_image = np.clip(channel_image * 64 + 128, 0, 255).astype("uint8")
        top = grid_row * (height + 1)
        left = grid_col * (width + 1)
        grid[top:top + height, left:left + width] = channel_image
    return grid / 255.0


def predict(image):
    """Classify ``image`` as cat/dog and visualize intermediate activations.

    Args:
        image: PIL image from the Gradio input, or None if nothing was given.

    Returns:
        ``(images, summary_text, prediction_text)`` where ``images`` is a list
        of ``(grid, layer_name)`` pairs for every Conv2D / MaxPooling2D layer,
        ``summary_text`` is the model summary, and ``prediction_text`` is a
        Markdown string with the label and confidence.

    Raises:
        gr.Error: if no image was supplied (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image or pick an example first.")
    img_tensor = get_img_array(image, target_size=(180, 180))

    predicted_class, confidence = _classify(img_tensor)
    prediction_text = f"## **Prediction:** {predicted_class} **Confidence:** {confidence:.2%}"

    activation_model, layer_names = _activation_model()
    images = []
    if activation_model is not None:
        activations = activation_model.predict(img_tensor, verbose=0)
        # NOTE(review): with a single selected layer keras may hand back a bare
        # array rather than a one-element list; wrap it so zip() pairs the
        # name with the whole activation instead of iterating batch rows.
        if isinstance(activations, np.ndarray):
            activations = [activations]
        images = [
            (_activation_grid(layer_activation), layer_name)
            for layer_name, layer_activation in zip(layer_names, activations)
        ]

    return images, _cached_model_summary(), prediction_text
# Gradio interface with examples
with gr.Blocks() as demo:
    gr.Markdown("# CNN Intermediate Activations Visualizer")
    gr.Markdown("Visualizes activations of all convolutional and pooling layers and displays the model summary.")
    gr.Markdown("Model is trained on a subset of kaggle's dogs vs cats dataset: https://www.kaggle.com/c/dogs-vs-cats/data")
    gr.Markdown("Adapted from: https://deeplearningwithpython.io/chapters/chapter10_interpreting-what-convnets-learn/#visualizing-intermediate-activations")
    with gr.Row():
        # Left column: image input, trigger button, and clickable examples.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Analyze")
            # Example images
            gr.Examples(
                examples=[
                    ["images/cat_1.jpg"],
                    ["images/dog.jpg"],
                    ["images/cat_2.jpg"],
                    ["images/cat_and_dog.jpg"],
                ],
                inputs=input_image,
                label="Try an example:",
            )
        # Right column: activation grids, prediction, and model summary.
        with gr.Column():
            output_gallery = gr.Gallery(label="Layer Activations", show_label=True, columns=1)
            output_prediction = gr.Markdown(label="Prediction")
            gr.Markdown("As you go deeper through the neural network, the activations become more abstract and relate more to the class prediction")
            output_summary = gr.Textbox(label="Model Summary", lines=20)
    # Output order must match predict()'s return tuple:
    # (images, summary_text, prediction_text).
    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_gallery, output_summary, output_prediction],
    )

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or other tooling) builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()
|