Spaces:
Build error
Build error
| # Imports | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import streamlit as st | |
| from app_utils import * | |
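# The helper functions below (everything except main) are adapted from the Keras
# code example "Visualizing what convnets learn"; unlike the example, the feature
# extractor is passed in explicitly instead of being a module-level global.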
def compute_loss(feature_extractor, input_image, filter_index):
    """Return the mean activation of a single filter for ``input_image``.

    Args:
        feature_extractor: callable (a ``keras.Model`` in this app) mapping the
            image batch to a rank-4 activation tensor; assumed channels-last,
            i.e. (batch, rows, cols, channels) — TODO confirm for all models.
        input_image: image batch fed to ``feature_extractor``.
        filter_index: channel whose activation is averaged.

    Returns:
        Scalar tensor: mean of the selected channel over the interior region.
    """
    # Drop a 2-pixel frame on each spatial side so that border artifacts do
    # not dominate the objective.
    margin = 2
    interior = feature_extractor(input_image)[:, margin:-margin, margin:-margin, filter_index]
    return tf.reduce_mean(interior)
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
    """Perform one gradient-ascent update of ``img`` on the filter objective.

    Args:
        feature_extractor: model mapping the image batch to layer activations.
        img: current image tensor being optimized.
        filter_index: channel whose mean activation is maximized.
        learning_rate: step size applied to the L2-normalized gradient.

    Returns:
        (loss, img): the loss evaluated *before* the update and the new image.
    """
    with tf.GradientTape() as tape:
        # img is a plain tensor, not a Variable, so it must be watched explicitly.
        tape.watch(img)
        loss = compute_loss(feature_extractor, img, filter_index)
    # Normalizing makes the step size independent of the gradient magnitude.
    step = learning_rate * tf.math.l2_normalize(tape.gradient(loss, img))
    return loss, img + step
def initialize_image():
    """Return a random starting image for gradient ascent.

    Returns:
        Tensor of shape (1, IMG_HEIGHT, IMG_WIDTH, 3) with values drawn
        uniformly from [-0.125, +0.125): a near-zero image with small noise.
    """
    # With Keras' default channels_last format, image batches are laid out as
    # (batch, height, width, channels); height must come before width,
    # otherwise the image is transposed whenever IMG_WIDTH != IMG_HEIGHT.
    img = tf.random.uniform((1, IMG_HEIGHT, IMG_WIDTH, 3))
    # tf.random.uniform samples [0, 1); shift/scale to [-0.125, +0.125) so the
    # search starts close to zero. The selectable models may expect different
    # input ranges, so this is only a neutral starting point, not a
    # model-specific preprocessing step.
    return (img - 0.5) * 0.25
def visualize_filter(feature_extractor, filter_index, iterations=None, learning_rate=None):
    """Synthesize an image that maximizes one filter's mean activation.

    Args:
        feature_extractor: model mapping an image batch to the target layer's output.
        filter_index: channel of that output to maximize.
        iterations: number of gradient-ascent steps; defaults to ``ITERATIONS``.
        learning_rate: step size per update; defaults to ``LEARNING_RATE``.

    Returns:
        (loss, img): the loss returned by the last gradient-ascent step
        (``None`` if no step ran) and the decoded image as a uint8 array.
    """
    # Resolve defaults at call time so the module-level settings still apply.
    if iterations is None:
        iterations = ITERATIONS
    if learning_rate is None:
        learning_rate = LEARNING_RATE

    img = initialize_image()
    # Pre-bind loss so a non-positive iteration count does not raise
    # UnboundLocalError at the return statement.
    loss = None
    for _ in range(iterations):
        loss, img = gradient_ascent_step(
            feature_extractor, img, filter_index, learning_rate
        )
    # Decode the optimized tensor (first and only batch item) for display.
    img = deprocess_image(img[0].numpy())
    return loss, img
def deprocess_image(img, crop=25):
    """Convert an optimized image array into a displayable uint8 image.

    The array is normalized to zero mean and standard deviation 0.15,
    center-cropped, re-centered at 0.5, clipped to [0, 1] and scaled to
    [0, 255].

    Args:
        img: 3-D array; the first two axes are cropped (presumably height and
            width — confirm against callers). The input is not modified.
        crop: pixels removed from each border of the first two axes;
            ``0`` (or less) disables cropping.

    Returns:
        uint8 array of shape (rows - 2*crop, cols - 2*crop, channels) when
        ``crop > 0``, otherwise the input shape.
    """
    img = np.asarray(img)
    # Work on a floating-point copy: in-place math on the caller's array would
    # clobber it (``Tensor.numpy()`` may share memory with the tensor), and
    # in-place float ops fail on integer arrays. float32 inputs stay float32.
    img = img.astype(np.result_type(img.dtype, np.float32))

    # Normalize: center on 0 and set the standard deviation to 0.15.
    img -= img.mean()
    img /= img.std() + 1e-5
    img *= 0.15

    # Center crop; guarded because a 0 crop would slice to an empty array.
    if crop > 0:
        img = img[crop:-crop, crop:-crop, :]

    # Shift to [0, 1] and convert to 8-bit RGB. After clipping to [0, 1] the
    # scaled values are already within [0, 255].
    img += 0.5
    img = np.clip(img, 0, 1)
    img *= 255
    return img.astype("uint8")
# The visualization function and its private helpers


def _load_model(model_name):
    """Instantiate ``keras.applications.<model_name>`` and cache it in session state."""
    model = getattr(keras.applications, model_name)(
        weights="imagenet", include_top=False
    )
    st.session_state.layer_list = ["<select layer>"] + [
        layer.name for layer in model.layers
    ]
    st.session_state.model = model
    st.session_state.model_name = model_name
    # A layer chosen for the previous model is meaningless for the new one.
    # Without this reset, picking a layer name shared by both models skipped
    # rebuilding the extractor (showing the old model's filters), and a stale
    # name missing from the new model made get_layer() raise.
    st.session_state.layer_name = None
    st.session_state.feat_extract = None


def _select_layer():
    """Show the layer selector; rebuild the feature extractor when it changes."""
    ln_option = st.selectbox(
        "Select the target layer (best to pick somewhere in the middle of the model) -",
        st.session_state.layer_list,
    )
    # Building the extractor once per layer change lets the 64-filter grid
    # reuse it instead of recreating it for every filter.
    if ln_option != "<select layer>" and ln_option != st.session_state.layer_name:
        layer = st.session_state.model.get_layer(name=ln_option)
        st.session_state.feat_extract = keras.Model(
            inputs=st.session_state.model.inputs, outputs=layer.output
        )
        st.session_state.layer_name = ln_option


def _show_filter_grid(warn_ph, layer_ph):
    """Visualize up to 64 filters of the selected layer in an 8x8 grid."""
    feat_extract = st.session_state.feat_extract
    # The extractor's output is the selected layer's output; its last
    # dimension is the filter count (assumes channels-last). This replaces
    # layer.get_output_at(0), which no longer exists in Keras 3.
    num_filters = feat_extract.output.shape[-1]
    num_shown = min(num_filters, 64)

    warn_ph.warning(
        ":exclamation: Calculating the gradients can take a while.."
    )
    if num_filters < 64:
        layer_ph.info(
            f"{st.session_state.layer_name} has only {num_filters} filters, visualizing only those filters.."
        )

    prog_bar = st.progress(0)
    fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
    cells = axis.ravel()
    for filter_index, ax in enumerate(cells[:num_shown]):
        prog_bar.progress((filter_index + 1) / num_shown)
        _, img = visualize_filter(feat_extract, filter_index)
        ax.imshow(img)
        ax.set_title(filter_index + 1)
        ax.set_axis_off()
    # Hide the empty cells when the layer has fewer than 64 filters.
    for ax in cells[num_shown:]:
        ax.set_axis_off()

    st.pyplot(fig)
    # Release the figure; otherwise pyplot keeps every figure alive across
    # Streamlit reruns.
    plt.close(fig)
    warn_ph.empty()


def main():
    """Render the app: choose a model, then a layer, then visualize its filters."""
    # Initialize states
    initialize_states()

    # Model selector
    mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)
    # Only (re)load the model when the selection changes, not on every layer change.
    if mn_option != st.session_state.model_name:
        _load_model(mn_option)

    # Layer selector
    if st.session_state.model_name:
        _select_layer()

    # Filter visualization only makes sense once a layer has been chosen.
    if not st.session_state.layer_name:
        return

    warn_ph = st.empty()
    layer_ph = st.empty()
    filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())
    # VIS_OPTION (from app_utils) maps the label to a mode: 0 renders only
    # filter 0; any other value renders the filter grid.
    if VIS_OPTION[filter_select] == 0:
        _, img = visualize_filter(st.session_state.feat_extract, 0)
        st.image(img)
    else:
        _show_filter_grid(warn_ph, layer_ph)
def read_model_names(path="model_names.txt"):
    """Return the stripped, non-blank lines of ``path`` as a list of model names.

    Each name is later resolved with ``getattr(keras.applications, name)``, so
    blank lines (e.g. an empty line at the end of the file) must be skipped;
    otherwise they become ``""`` entries in the model selector.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [line.strip() for line in handle if line.strip()]


if __name__ == "__main__":
    AVAILABLE_MODELS = read_model_names()
    # set_page_config must be the first Streamlit command in the script.
    st.set_page_config(layout="wide")
    st.title(title)
    st.write(info_text)
    st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
    st.write(self_credit)
    main()