"""Streamlit app that visualizes Conv2D filters and feature maps of a Keras model.

Upload an ``.h5`` Keras model, then choose either to view the kernels of a
Conv2D layer (one input channel at a time) or to upload an image and view the
activations that a chosen Conv2D layer produces for it.
"""
import math
import os
import sys
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from PIL import Image
from tensorflow.keras import layers, models

# Base size of one subplot; grid figures scale it by the number of rows and
# columns so every panel keeps roughly the default figure size.
fig_size = plt.rcParams['figure.figsize']

# Render single-channel arrays (kernels, activations) in grayscale.
plt.rcParams['image.cmap'] = 'gray'

# PIL mode an uploaded image is converted to, keyed by the model's input
# channel count.
_PIL_MODE_FOR_CHANNELS = {1: "L", 3: "RGB", 4: "RGBA"}


def _grid_shape(count):
    """Return ``(nrows, ncols)`` for a near-square grid with room for *count* panels.

    Columns are ``ceil(sqrt(count))`` and rows only as many as needed, so e.g.
    5 panels use a 2x3 grid instead of 3x3. The grid is always at least 1x1.
    """
    ncols = max(1, math.ceil(math.sqrt(count)))
    nrows = max(1, math.ceil(count / ncols))
    return nrows, ncols


def _plot_grid(panels, title_prefix):
    """Draw each 2-D array in *panels* in its own subplot and return the figure.

    Args:
        panels: sequence of 2-D arrays (e.g. an ndarray whose first axis
            indexes the panels).
        title_prefix: panel ``i`` is titled ``f"{title_prefix} {i + 1}"``.

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    count = len(panels)
    nrows, ncols = _grid_shape(count)
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid; without it
    # a single panel yields a bare Axes object, which has no ``flatten``.
    fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                             figsize=(fig_size[0] * ncols, fig_size[1] * nrows))
    flat_axes = axes.flatten()
    for index, (axis, panel) in enumerate(zip(flat_axes, panels)):
        axis.imshow(panel)
        axis.set(xticklabels=[], yticklabels=[],
                 title=f"{title_prefix} {index + 1}")
    # Remove the grid cells left over after the last panel.
    for axis in flat_axes[count:]:
        axis.remove()
    fig.tight_layout()
    return fig


def _show_figure(fig):
    """Render *fig* in the app, then close it so reruns don't accumulate figures."""
    st.pyplot(fig)
    plt.close(fig)


def _load_uploaded_model(uploaded_file):
    """Load a Keras model from an uploaded ``.h5`` file.

    The uploaded bytes are written to a temporary ``.h5`` file that
    ``models.load_model`` reads; the temporary directory is deleted once
    loading returns.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        path = os.path.join(tempdir, "temp.h5")
        with open(path, mode='wb') as f:
            f.write(uploaded_file.getvalue())
        return models.load_model(path)


def _prepare_image(pil_image, input_shape):
    """Turn a PIL image into a single-image batch for the model.

    Args:
        pil_image: the uploaded image.
        input_shape: the model's ``input_shape``. When it is a 4-tuple it is
            assumed to be channels-last ``(batch, height, width, channels)``
            (the Keras default; TODO confirm for the models you use). The
            image is then converted to the matching PIL mode (1 -> "L",
            3 -> "RGB", 4 -> "RGBA") and resized to ``(width, height)`` when
            both are fixed.

    Returns:
        A numpy array of shape ``(1, H, W, C)``. Pixel values are left
        unscaled, as before; NOTE(review): add the model's own preprocessing
        here if it was trained on normalized inputs.
    """
    if isinstance(input_shape, tuple) and len(input_shape) == 4:
        _, height, width, channels = input_shape
        mode = _PIL_MODE_FOR_CHANNELS.get(channels)
        if mode is not None and pil_image.mode != mode:
            pil_image = pil_image.convert(mode)
        if height and width and pil_image.size != (width, height):
            pil_image = pil_image.resize((width, height))
    array = np.asarray(pil_image)
    # Grayscale images come back 2-D; add the trailing channel axis.
    if array.ndim == 2:
        array = array[..., np.newaxis]
    # Prepend the batch axis.
    return array[np.newaxis, ...]


st.title("Filters and Feature Maps Visualization")

uploaded_model = st.file_uploader(label="Upload model", type=["h5"])
if not uploaded_model:
    # Nothing to show until a model is uploaded; end this script run.
    st.stop()

# NOTE: Streamlit reruns the whole script on every widget change, so the model
# is reloaded each time; consider caching it if loading is slow.
model = _load_uploaded_model(uploaded_model)

viz_option = st.selectbox("What would you like to visualize?",
                          options=["Filters", "Feature Maps"])

# Indices (into model.layers) of every Conv2D layer.
conv_indices = [index for index, layer in enumerate(model.layers)
                if isinstance(layer, layers.Conv2D)]
if not conv_indices:
    # Without this guard selectbox returns None and model.layers[None] raises.
    st.warning("The uploaded model has no Conv2D layers to visualize.")
    st.stop()

if viz_option.lower() == "filters":
    layer_index = st.selectbox(
        "Select a layer to see its filters", options=conv_indices,
        format_func=lambda index: f"{index}: {model.layers[index].name}")

    # Conv2D kernels are indexed (kernel_h, kernel_w, in_channels, filters).
    weights = model.layers[layer_index].get_weights()[0]
    num_filters = weights.shape[-1]
    num_channels = weights.shape[-2]

    st.write(
        f"This layer has {num_filters} filters and {num_channels} channels per filter.")

    channel_index = st.selectbox(
        "Which channel would you like to view?", options=range(1, num_channels + 1))

    # Move the filter axis first so each panel is one filter's 2-D kernel
    # slice for the chosen (1-based) input channel.
    kernels = np.moveaxis(weights[:, :, channel_index - 1, :], -1, 0)
    _show_figure(_plot_grid(kernels, "Filter"))
else:
    img_file = st.file_uploader(label="Upload image", type=['jpg', 'jpeg', 'png'])

    if img_file:
        pil_image = Image.open(img_file)
        st.image(pil_image)

        batch = _prepare_image(pil_image, model.input_shape)

        layer_index = st.selectbox(
            "Feature Map at which layer?", options=conv_indices,
            format_func=lambda index: f"{index}: {model.layers[index].name}")

        # Sub-model sharing the loaded model's layers whose output is the
        # chosen layer's activations.
        temp_model = models.Model(
            inputs=model.inputs, outputs=model.layers[layer_index].output)
        # Index away only the batch axis; np.squeeze would also drop a height,
        # width or channel axis of size 1 and break the per-channel slicing.
        feature_maps = temp_model.predict(batch)[0]

        # One panel per output channel (channels-last layout assumed).
        channel_maps = np.moveaxis(feature_maps, -1, 0)
        _show_figure(_plot_grid(channel_maps, "Channel"))