# CNN_Explainer / app.py
# Committer: Nikhithapotnuru
# Last commit: "Update app.py" (8d5b066, verified)
"""CNN Explainer: a Streamlit app that shows the first-layer filters of a
pre-trained CNN (``cnn_lenet.keras``) and the feature maps it produces for the
first MNIST training digit.
"""
import matplotlib.pyplot as plt
import streamlit as st
from keras.datasets import mnist
from keras.layers import (
    Conv2D,
    MaxPooling2D,
    AveragePooling2D,
    InputLayer,
    Dense,
    Flatten,
)
from keras.models import Sequential, Model
from sklearn.preprocessing import MinMaxScaler
from tensorflow import keras

m = MinMaxScaler()

# Column title for each feature-map panel, paired with the index of the model
# layer whose output it shows.
# NOTE(review): assumes layers 0-4 of cnn_lenet.keras are
# Conv2D, AveragePooling2D, Conv2D, AveragePooling2D, Conv2D -- confirm
# against the saved model.
FEATURE_MAP_LAYERS = [
    ("Conv2D Layer-1", 0),
    ("Average Pooling Layer-1", 1),
    ("Conv2D Layer-2", 2),
    ("Average Pooling Layer-2", 3),
    ("Conv2D Layer-3", 4),
]


@st.cache_resource
def _load_model():
    """Load the pre-trained CNN once per server process, not on every rerun."""
    return keras.models.load_model('cnn_lenet.keras')


@st.cache_resource
def _load_mnist():
    """Load MNIST once; the cached arrays are shared, so treat them as read-only."""
    return mnist.load_data()


def _show_stack(title, images):
    """Write ``title``, then draw each 2-D array in ``images`` as a vertical
    stack of grayscale plots in the current Streamlit container."""
    st.write(title)
    # squeeze=False keeps ``axes`` 2-D even when there is a single image.
    fig, axes = plt.subplots(len(images), 1, figsize=(8, 6), squeeze=False)
    for ax, image in zip(axes[:, 0], images):
        ax.imshow(image, cmap="gray")
        ax.axis("off")
    st.pyplot(fig)
    # st.pyplot(fig) neither clears nor closes a passed figure; close it so
    # repeated clicks do not pile figures up in pyplot's global registry.
    plt.close(fig)


model = _load_model()

option = st.sidebar.selectbox("Datasets", ["Select dataset", "Hand Writen Digit Dataset"])
if option == "Hand Writen Digit Dataset":
    (x_train, y_train), (x_test, y_test) = _load_mnist()
    st.write("Successfully Load the Dataset")
    # NOTE: despite its label, this button does not train; it visualizes the
    # already-trained model.
    if st.button("Train"):
        # First-layer kernel; the original indexing ([:, :, :, i][:, :, 0])
        # implies layout (kh, kw, in_channels, out_channels).
        kernel = model.layers[0].get_weights()[0]
        filters = []
        for i in range(kernel.shape[-1]):
            k = kernel[:, :, 0, i]
            # Scale the whole kernel as one feature: fit_transform on the 2-D
            # slice would rescale each column independently and distort it.
            filters.append(m.fit_transform(k.reshape(-1, 1)).reshape(k.shape))

        # One forward pass through a multi-output probe model, instead of a
        # separate predict() call for every channel that gets plotted.
        probe = Model(
            inputs=model.inputs[0],
            outputs=[model.layers[idx].output for _, idx in FEATURE_MAP_LAYERS],
        )
        # NOTE(review): raw uint8 pixels (0-255) are fed as-is, as before;
        # confirm whether the model was trained on rescaled inputs.
        sample = x_train[0:1].reshape(1, 28, 28, 1)
        activations = probe.predict(sample)

        columns = st.columns(1 + len(FEATURE_MAP_LAYERS))
        with columns[0]:
            _show_stack("Filters", filters)
        for column, (title, _), act in zip(columns[1:], FEATURE_MAP_LAYERS, activations):
            with column:
                # act has shape (1, H, W, channels); plot every channel.
                _show_stack(title, [act[0, :, :, c] for c in range(act.shape[-1])])