Spaces:
Build error
Build error
Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Imports
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
|
| 5 |
+
import tensorflow as tf
|
| 6 |
+
from tensorflow import keras
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
from app_utils import *
|
| 11 |
+
|
| 12 |
+
# The functions (except main) are taken straight from Keras Example
|
| 13 |
+
def compute_loss(feature_extractor, input_image, filter_index):
|
| 14 |
+
activation = feature_extractor(input_image)
|
| 15 |
+
# We avoid border artifacts by only involving non-border pixels in the loss.
|
| 16 |
+
filter_activation = activation[:, 2:-2, 2:-2, filter_index]
|
| 17 |
+
return tf.reduce_mean(filter_activation)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@tf.function
|
| 21 |
+
def gradient_ascent_step(feature_extractor, img, filter_index, learning_rate):
|
| 22 |
+
with tf.GradientTape() as tape:
|
| 23 |
+
tape.watch(img)
|
| 24 |
+
loss = compute_loss(feature_extractor, img, filter_index)
|
| 25 |
+
# Compute gradients.
|
| 26 |
+
grads = tape.gradient(loss, img)
|
| 27 |
+
# Normalize gradients.
|
| 28 |
+
grads = tf.math.l2_normalize(grads)
|
| 29 |
+
img += learning_rate * grads
|
| 30 |
+
return loss, img
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def initialize_image():
|
| 34 |
+
# We start from a gray image with some random noise
|
| 35 |
+
img = tf.random.uniform((1, IMG_WIDTH, IMG_HEIGHT, 3))
|
| 36 |
+
# ResNet50V2 expects inputs in the range [-1, +1].
|
| 37 |
+
# Here we scale our random inputs to [-0.125, +0.125]
|
| 38 |
+
return (img - 0.5) * 0.25
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def visualize_filter(feature_extractor, filter_index):
|
| 42 |
+
# We run gradient ascent for 20 steps
|
| 43 |
+
img = initialize_image()
|
| 44 |
+
for _ in range(ITERATIONS):
|
| 45 |
+
loss, img = gradient_ascent_step(
|
| 46 |
+
feature_extractor, img, filter_index, LEARNING_RATE
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# Decode the resulting input image
|
| 50 |
+
img = deprocess_image(img[0].numpy())
|
| 51 |
+
return loss, img
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def deprocess_image(img):
|
| 55 |
+
# Normalize array: center on 0., ensure variance is 0.15
|
| 56 |
+
img -= img.mean()
|
| 57 |
+
img /= img.std() + 1e-5
|
| 58 |
+
img *= 0.15
|
| 59 |
+
|
| 60 |
+
# Center crop
|
| 61 |
+
img = img[25:-25, 25:-25, :]
|
| 62 |
+
|
| 63 |
+
# Clip to [0, 1]
|
| 64 |
+
img += 0.5
|
| 65 |
+
img = np.clip(img, 0, 1)
|
| 66 |
+
|
| 67 |
+
# Convert to RGB array
|
| 68 |
+
img *= 255
|
| 69 |
+
img = np.clip(img, 0, 255).astype("uint8")
|
| 70 |
+
return img
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# The visualization function
|
| 74 |
+
def main():
|
| 75 |
+
# Model selector
|
| 76 |
+
mn_option = st.selectbox("Select the model for visualization -", AVAILABLE_MODELS)
|
| 77 |
+
|
| 78 |
+
# Check to not load the model for ever layer change
|
| 79 |
+
if mn_option != st.session_state.model_name:
|
| 80 |
+
model = getattr(keras.applications, mn_option)(
|
| 81 |
+
weights="imagenet", include_top=False
|
| 82 |
+
)
|
| 83 |
+
st.session_state.layer_list = ["<select layer>"] + [
|
| 84 |
+
layer.name for layer in model.layers
|
| 85 |
+
]
|
| 86 |
+
st.session_state.model = model
|
| 87 |
+
st.session_state.model_name = mn_option
|
| 88 |
+
|
| 89 |
+
# Layer selector, saves the feature selector in case 64 filters are to be seen
|
| 90 |
+
if st.session_state.model_name:
|
| 91 |
+
ln_option = st.selectbox(
|
| 92 |
+
"Select the target layer (best to pick somewhere in the middle of the model) -",
|
| 93 |
+
st.session_state.layer_list,
|
| 94 |
+
)
|
| 95 |
+
if ln_option != "<select layer>":
|
| 96 |
+
if ln_option != st.session_state.layer_name:
|
| 97 |
+
layer = st.session_state.model.get_layer(name=ln_option)
|
| 98 |
+
st.session_state.feat_extract = keras.Model(
|
| 99 |
+
inputs=st.session_state.model.inputs, outputs=layer.output
|
| 100 |
+
)
|
| 101 |
+
st.session_state.layer_name = ln_option
|
| 102 |
+
|
| 103 |
+
# Filter index selector
|
| 104 |
+
if st.session_state.layer_name:
|
| 105 |
+
filter_select = st.selectbox("Visualize -", VIS_OPTION.keys())
|
| 106 |
+
|
| 107 |
+
if VIS_OPTION[filter_select] == 0:
|
| 108 |
+
loss, img = visualize_filter(st.session_state.feat_extract, 0)
|
| 109 |
+
st.image(img)
|
| 110 |
+
else:
|
| 111 |
+
st.warning(":exclamation: Calculating the gradients can take a while..")
|
| 112 |
+
prog_bar = st.progress(0)
|
| 113 |
+
fig, axis = plt.subplots(nrows=8, ncols=8, figsize=(14, 14))
|
| 114 |
+
for filter_index, ax in enumerate(axis.ravel()):
|
| 115 |
+
prog_bar.progress((filter_index + 1) / 64)
|
| 116 |
+
if filter_index < 65:
|
| 117 |
+
loss, img = visualize_filter(
|
| 118 |
+
st.session_state.feat_extract, filter_index
|
| 119 |
+
)
|
| 120 |
+
ax.imshow(img)
|
| 121 |
+
ax.set_title(filter_index + 1)
|
| 122 |
+
ax.set_axis_off()
|
| 123 |
+
else:
|
| 124 |
+
ax.set_axis_off()
|
| 125 |
+
|
| 126 |
+
st.write(fig)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
|
| 131 |
+
with open("model_names.txt", "r") as op:
|
| 132 |
+
AVAILABLE_MODELS = [i.strip() for i in op.readlines()]
|
| 133 |
+
|
| 134 |
+
st.set_page_config(layout="wide")
|
| 135 |
+
|
| 136 |
+
st.title(title)
|
| 137 |
+
st.write(info_text)
|
| 138 |
+
st.info(f"{credits}\n\n{replicate}\n\n{vit_info}")
|
| 139 |
+
|
| 140 |
+
main()
|