Nikhithapotnuru commited on
Commit
3744477
·
verified ·
1 Parent(s): bc14c70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -7
app.py CHANGED
@@ -1,7 +1,63 @@
1
- import streamlit as st
2
- import keras
3
- import matplotlib.pyplot as plt
4
- from keras.datasets import mnist
5
- from sklearn.preprocessing import MinMaxScaler
6
- from keras.models import Sequential,Model
7
- from keras.layers import Dense,ConV2D,MaxPooling2D,AveragePooling2D,InputLayer,Flatten
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from tensorflow.keras.models import Model, load_model
5
+ from tensorflow.keras.layers import Conv2D
6
+ from PIL import Image
7
+ import cv2
8
+
9
+ st.title("CNN Layer Visualization (Upload Model & Image)")
10
+
11
+ col1, col2 = st.columns([1, 2])
12
+
13
+ with col1:
14
+ st.header("Upload Files")
15
+
16
+ # Upload model file
17
+ model_file = st.file_uploader("Upload a Keras model (.h5 or .keras)", type=["h5", "keras"])
18
+
19
+ # Upload image file
20
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
21
+
22
+ if model_file is not None and uploaded_file is not None:
23
+ # Load model
24
+ model = load_model(model_file)
25
+
26
+ # Show uploaded image
27
+ image = Image.open(uploaded_file).convert("RGB")
28
+ st.image(image, caption="Uploaded Image", use_column_width=True)
29
+
30
+ # Preprocess image
31
+ img_array = np.array(image)
32
+ img_resized = cv2.resize(img_array, (128, 128)) / 255.0
33
+ img_input = np.expand_dims(img_resized, axis=0)
34
+
35
+ # Collect conv layer outputs
36
+ conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]
37
+ outputs = [layer.output for layer in conv_layers]
38
+ feature_model = Model(inputs=model.input, outputs=outputs)
39
+ feature_maps = feature_model.predict(img_input)
40
+
41
+ with col2:
42
+ st.header("Layer Output Visualizer")
43
+ if model_file is not None and uploaded_file is not None:
44
+ conv_layer_names = [layer.name for layer in model.layers if isinstance(layer, Conv2D)]
45
+
46
+ # Loop through each conv layer
47
+ for fmap, lname in zip(feature_maps, conv_layer_names):
48
+ num_filters = fmap.shape[-1]
49
+ cols = 8
50
+ rows = num_filters // cols if num_filters % cols == 0 else num_filters // cols + 1
51
+
52
+ fig, axes = plt.subplots(rows, cols, figsize=(15, 15))
53
+ fig.suptitle(f"Layer: {lname}", fontsize=16)
54
+
55
+ for i in range(rows * cols):
56
+ if i < num_filters:
57
+ ax = axes[i // cols, i % cols]
58
+ ax.imshow(fmap[0, :, :, i], cmap="viridis")
59
+ ax.axis("off")
60
+ else:
61
+ axes[i // cols, i % cols].axis("off")
62
+
63
+ st.pyplot(fig)