# CNN_Explainer / app.py
# Scraped page header: author shubham680, commit 659e17f ("Update app.py"), 2.71 kB.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import keras
from sklearn.preprocessing import MinMaxScaler
from keras.models import Model
from keras.layers import Conv2D
import cv2
import tensorflow as tf
@st.cache_resource
def _load_model(path="model.keras"):
    """Load the Keras model stored at *path*, cached across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, the model file would be read and rebuilt each time.
    """
    return keras.models.load_model(path)


# Trained CNN whose convolution filters and feature maps are visualised below.
model = _load_model()

# Optional user image; stays None until the user uploads a jpg/jpeg/png file.
uploaded_img = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Selectbox labels. Their position in this list is used as an index into the
# model's Conv2D layers further down.
options = ['1st Convolution', '2nd Convolution', '3rd Convolution']
selected_option = st.selectbox('Choose an option:', options)
# All Conv2D layers of the model, in the order model.layers lists them.
conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]

# Position of the chosen option; 0 selects the first Conv2D layer.
layer_ind = options.index(selected_option)
if layer_ind >= len(conv_layers):
    # The option list is fixed at three entries; a model with fewer Conv2D
    # layers would otherwise fail with an IndexError below.
    st.error(
        f"The model has only {len(conv_layers)} Conv2D layer(s); "
        f"'{selected_option}' is not available."
    )
    st.stop()

# Keras stores a Conv2D kernel as (kh, kw, in_channels, filters).
# The first 6 filters (or fewer, if the layer has fewer) are shown,
# each through input channel 0 only.
kernel = conv_layers[layer_ind].get_weights()[0]
n_shown = min(6, kernel.shape[-1])

fig, filter_axes = plt.subplots(2, 3, figsize=(12, 4))
for j, ax in enumerate(filter_axes.flat):
    ax.axis('off')
    if j >= n_shown:
        continue
    weights = kernel[:, :, 0, j]
    # Min-max scale the whole 2-D kernel at once. MinMaxScaler would scale
    # each column independently and distort the filter's spatial pattern.
    w_min = weights.min()
    w_range = weights.max() - w_min
    norm_weights = (weights - w_min) / w_range if w_range > 0 else np.zeros_like(weights)
    ax.imshow(norm_weights, cmap='gray', vmin=0, vmax=1)
    ax.set_title(f"Filters {j+1}")
fig.tight_layout()
st.pyplot(fig)
# Release the pyplot-managed figure; otherwise one leaks per rerun.
plt.close(fig)
if uploaded_img is not None:
    # Decode the upload as a single-channel image. cv2.imdecode returns None
    # for bytes it cannot decode (for example a corrupt or mislabelled file).
    file_bytes = np.frombuffer(uploaded_img.getvalue(), np.uint8)
    img = cv2.imdecode(file_bytes, cv2.IMREAD_GRAYSCALE)
    if img is None:
        st.error("Could not decode the uploaded file as an image.")
        st.stop()

    img_resized = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    # Float copy scaled to [0, 1], used for display.
    img_norm = img_resized.astype('float32') / 255.0
    # Model input: a batch of one 28x28 single-channel image with raw 0-255
    # values, unchanged from before.
    # NOTE(review): confirm whether the model was trained on [0, 1] inputs;
    # if so, feed img_norm here instead.
    input_img = img_resized.reshape(1, 28, 28, 1)
    st.image(img_norm, caption="Uploaded Image (Resized to 28x28)", use_container_width=True)

    # Sub-model that ends at the chosen Conv2D layer and returns its activations.
    selected_layer = conv_layers[layer_ind]
    func_model = Model(inputs=model.layers[0].input, outputs=selected_layer.output)
    fm = func_model.predict(input_img, verbose=0)
    fm = fm[0]  # drop the batch axis -> (H, W, channels)

    # Plot every channel the layer actually produces, instead of a hard-coded
    # count. Up to 16 maps go on two rows (6 -> 2x3, 16 -> 2x8 as before);
    # wider layers use 12 per row.
    n_maps = fm.shape[-1]
    ncols = -(-n_maps // 2) if n_maps <= 16 else 12
    nrows = -(-n_maps // ncols)
    fig1, fm_axes = plt.subplots(
        nrows, ncols, figsize=(max(12, 2 * ncols), 2 * nrows), squeeze=False
    )
    for i, ax in enumerate(fm_axes.flat):
        ax.axis('off')
        if i < n_maps:
            ax.imshow(fm[:, :, i], cmap="gray")
            ax.set_title(f"Feature Map {i+1}")
    fig1.tight_layout()
    st.pyplot(fig1)
    # Release the pyplot-managed figure; otherwise one leaks per rerun.
    plt.close(fig1)