Nikhithapotnuru commited on
Commit
0ec8502
·
verified ·
1 Parent(s): 3c86acf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -5,29 +5,29 @@ from keras.models import Model
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  from tensorflow import keras
8
- from keras.layers import Conv2D, AveragePooling2D
9
  from PIL import Image
10
 
11
  m = MinMaxScaler()
12
- model = keras.models.load_model('cnn_lenet.keras')
13
-
14
- # Sidebar
15
- option = st.sidebar.selectbox("Datasets", ["Select dataset", "Hand Written Digit Dataset"])
16
 
 
17
  uploaded_file = st.sidebar.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
18
 
19
- if option == "Hand Written Digit Dataset":
20
- (x_train, y_train), (x_test, y_test) = mnist.load_data()
21
- st.write("✅ Successfully Loaded the MNIST Dataset")
22
 
23
- # Preprocess uploaded image or use MNIST sample
24
  if uploaded_file is not None:
25
  image = Image.open(uploaded_file).convert("L") # grayscale
26
  image = image.resize((28, 28)) # resize to model input
27
  img_array = np.array(image) / 255.0
28
  img_array = img_array.reshape(1, 28, 28, 1) # reshape for CNN
 
 
29
  else:
30
- img_array = x_train[0:1].reshape(1, 28, 28, 1) # fallback to MNIST sample
 
 
31
 
32
  # Button to visualize
33
  if st.button("Visualize Layers"):
@@ -37,7 +37,10 @@ if st.button("Visualize Layers"):
37
  st.write("Filters")
38
  fig, axs = plt.subplots(6, 1, figsize=(8, 6))
39
  for i in range(6):
40
- axs[i].imshow(m.fit_transform(model.layers[0].weights[0][:, :, :, i][:, :, 0]), cmap="gray")
 
 
 
41
  axs[i].axis("off")
42
  st.pyplot(fig)
43
 
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  from tensorflow import keras
 
8
  from PIL import Image
9
 
10
  m = MinMaxScaler()
11
+ model = keras.models.load_model("cnn_lenet.keras")
 
 
 
12
 
13
+ # Sidebar upload
14
  uploaded_file = st.sidebar.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
15
 
16
+ # Load MNIST for fallback
17
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
 
18
 
19
+ # Preprocess uploaded image or fallback
20
  if uploaded_file is not None:
21
  image = Image.open(uploaded_file).convert("L") # grayscale
22
  image = image.resize((28, 28)) # resize to model input
23
  img_array = np.array(image) / 255.0
24
  img_array = img_array.reshape(1, 28, 28, 1) # reshape for CNN
25
+ st.success("✅ Successfully uploaded image")
26
+ st.image(image, caption="Uploaded Image (28x28)", use_column_width=True)
27
  else:
28
+ img_array = x_train[0:1].reshape(1, 28, 28, 1)
29
+ st.info("ℹ️ No image uploaded, using MNIST sample instead")
30
+ st.image(x_train[0], caption="MNIST Sample", use_column_width=True)
31
 
32
  # Button to visualize
33
  if st.button("Visualize Layers"):
 
37
  st.write("Filters")
38
  fig, axs = plt.subplots(6, 1, figsize=(8, 6))
39
  for i in range(6):
40
+ axs[i].imshow(
41
+ m.fit_transform(model.layers[0].weights[0][:, :, :, i][:, :, 0]),
42
+ cmap="gray"
43
+ )
44
  axs[i].axis("off")
45
  st.pyplot(fig)
46