Sefat33 committed on
Commit
8530127
Β·
verified Β·
1 Parent(s): c2dae3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -10,10 +10,16 @@ from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
10
 
11
  # --- Fix deserialization issues ---
12
  original_bn = BatchNormalization.from_config
13
- BatchNormalization.from_config = classmethod(lambda cls, config, *a, **k: original_bn(config if not isinstance(config.get("axis"), list) else {**config, "axis": config["axis"][0]}, *a, **k))
 
 
 
 
14
 
15
  original_dw = DepthwiseConv2D.from_config
16
- DepthwiseConv2D.from_config = classmethod(lambda cls, config, *a, **k: original_dw({k: v for k, v in config.items() if k != "groups"}, *a, **k))
 
 
17
 
18
  # --- Constants ---
19
  IMG_SIZE = (224, 224)
@@ -52,8 +58,11 @@ def preprocess_with_steps(img):
52
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
53
 
54
  fig, axs = plt.subplots(1, 4, figsize=(20, 5))
55
- for ax, image, title in zip(axs, [img, circ, clahe_img, resized],
56
- ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"]):
 
 
 
57
  ax.imshow(image)
58
  ax.set_title(title)
59
  ax.axis("off")
@@ -74,9 +83,11 @@ def show_lime(img, model, pred_idx, pred_label):
74
  classifier_fn=lambda imgs: predict(imgs, model),
75
  top_labels=1,
76
  hide_color=0,
77
- num_samples=100
 
 
 
78
  )
79
- temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
80
 
81
  fig, ax = plt.subplots(1, 1, figsize=(6, 5))
82
  ax.imshow(mark_boundaries(temp, mask))
@@ -91,24 +102,22 @@ st.title("🧠 Retina Disease Classifier with LIME Explanation")
91
  model = load_model()
92
 
93
  with st.sidebar:
94
- uploaded_files = st.file_uploader("πŸ“‚ Upload retinal images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
 
 
95
  selected_filename = None
96
  if uploaded_files:
97
  filenames = [f.name for f in uploaded_files]
98
- selected_filename = st.selectbox("🎯 Select an image to preprocess and predict", filenames)
99
 
100
- # --- Process selected image ---
101
  if uploaded_files and selected_filename:
102
  file = next(f for f in uploaded_files if f.name == selected_filename)
103
-
104
- # Read bytes once and reset pointer for later use
105
- file_bytes = file.read()
106
  file.seek(0)
107
-
108
- bgr = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
109
  rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
110
 
111
- st.subheader(f"πŸ” Preprocessing & Prediction for: {selected_filename}")
112
  preprocessed = preprocess_with_steps(rgb)
113
  input_tensor = np.expand_dims(preprocessed, axis=0)
114
 
@@ -118,19 +127,15 @@ if uploaded_files and selected_filename:
118
  confidence = np.max(preds) * 100
119
 
120
  st.success(f"βœ… Prediction: **{pred_label}** ({confidence:.2f}%)")
121
-
122
  show_lime(preprocessed, model, pred_idx, pred_label)
123
 
124
- # --- Show LIME for all images ---
125
  if uploaded_files:
126
  st.markdown("## πŸ§ͺ LIME Explanations for All Images")
127
  cols = st.columns(min(4, len(uploaded_files)))
128
  for i, file in enumerate(uploaded_files):
129
- # Read bytes once and reset pointer
130
- file_bytes = file.read()
131
  file.seek(0)
132
-
133
- bgr = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_COLOR)
134
  rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
135
  img = cv2.resize(rgb, IMG_SIZE) / 255.0
136
  input_tensor = np.expand_dims(img, axis=0)
@@ -141,12 +146,19 @@ if uploaded_files:
141
 
142
  with cols[i % len(cols)]:
143
  st.markdown(f"**{file.name}**<br>🧠 *{pred_label}*", unsafe_allow_html=True)
 
144
  explanation = LIME_EXPLAINER.explain_instance(
145
  image=img,
146
  classifier_fn=lambda imgs: predict(imgs, model),
147
  top_labels=1,
148
  hide_color=0,
149
- num_samples=1000
 
 
 
150
  )
151
- temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
152
- st.image(mark_boundaries(temp, mask), use_column_width=True)
 
 
 
 
10
 
11
# --- Fix deserialization issues ---
# Some saved model configs store BatchNormalization's "axis" as a list
# (e.g. [3]); the current Keras from_config expects a plain int, so
# unwrap single-element lists before delegating to the original.
original_bn = BatchNormalization.from_config
BatchNormalization.from_config = classmethod(
    lambda cls, config, *args, **kwargs: original_bn(
        config
        if not isinstance(config.get("axis"), list)
        else {**config, "axis": config["axis"][0]},
        *args,
        **kwargs,
    )
)

# DepthwiseConv2D configs from older versions may carry a "groups" key
# that the current constructor rejects; strip it before delegating.
# NOTE: renamed the comprehension variables (was `k`/`v`) so they no
# longer shadow the lambda's **kwargs name.
original_dw = DepthwiseConv2D.from_config
DepthwiseConv2D.from_config = classmethod(
    lambda cls, config, *args, **kwargs: original_dw(
        {key: value for key, value in config.items() if key != "groups"},
        *args,
        **kwargs,
    )
)

# --- Constants ---
IMG_SIZE = (224, 224)  # model input resolution (width, height)
 
58
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
59
 
60
  fig, axs = plt.subplots(1, 4, figsize=(20, 5))
61
+ for ax, image, title in zip(
62
+ axs,
63
+ [img, circ, clahe_img, resized],
64
+ ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"],
65
+ ):
66
  ax.imshow(image)
67
  ax.set_title(title)
68
  ax.axis("off")
 
83
  classifier_fn=lambda imgs: predict(imgs, model),
84
  top_labels=1,
85
  hide_color=0,
86
+ num_samples=50, # reduced samples for speed
87
+ )
88
+ temp, mask = explanation.get_image_and_mask(
89
+ label=pred_idx, positive_only=True, num_features=10, hide_rest=False
90
  )
 
91
 
92
  fig, ax = plt.subplots(1, 1, figsize=(6, 5))
93
  ax.imshow(mark_boundaries(temp, mask))
 
102
# Load the (cached) classifier once for the whole session.
model = load_model()

# --- Sidebar: file upload and image selection ---
with st.sidebar:
    uploaded_files = st.file_uploader(
        "📂 Upload retinal images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    # Offer a picker only once at least one image has been uploaded;
    # otherwise no image is selected.
    filenames = [f.name for f in uploaded_files] if uploaded_files else []
    selected_filename = (
        st.selectbox("🎯 Select an image to explain", filenames) if filenames else None
    )
112
 
113
+ # -- Predict & Display for Selected Image --
114
  if uploaded_files and selected_filename:
115
  file = next(f for f in uploaded_files if f.name == selected_filename)
 
 
 
116
  file.seek(0)
117
+ bgr = cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_COLOR)
 
118
  rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
119
 
120
+ st.subheader("πŸ” Preprocessing Steps")
121
  preprocessed = preprocess_with_steps(rgb)
122
  input_tensor = np.expand_dims(preprocessed, axis=0)
123
 
 
127
  confidence = np.max(preds) * 100
128
 
129
  st.success(f"βœ… Prediction: **{pred_label}** ({confidence:.2f}%)")
 
130
  show_lime(preprocessed, model, pred_idx, pred_label)
131
 
132
+ # -- Show LIME for all images with reduced size side-by-side --
133
  if uploaded_files:
134
  st.markdown("## πŸ§ͺ LIME Explanations for All Images")
135
  cols = st.columns(min(4, len(uploaded_files)))
136
  for i, file in enumerate(uploaded_files):
 
 
137
  file.seek(0)
138
+ bgr = cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_COLOR)
 
139
  rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
140
  img = cv2.resize(rgb, IMG_SIZE) / 255.0
141
  input_tensor = np.expand_dims(img, axis=0)
 
146
 
147
  with cols[i % len(cols)]:
148
  st.markdown(f"**{file.name}**<br>🧠 *{pred_label}*", unsafe_allow_html=True)
149
+
150
  explanation = LIME_EXPLAINER.explain_instance(
151
  image=img,
152
  classifier_fn=lambda imgs: predict(imgs, model),
153
  top_labels=1,
154
  hide_color=0,
155
+ num_samples=50,
156
+ )
157
+ temp, mask = explanation.get_image_and_mask(
158
+ label=pred_idx, positive_only=True, num_features=10, hide_rest=False
159
  )
160
+ fig, ax = plt.subplots(figsize=(4, 3))
161
+ ax.imshow(mark_boundaries(temp, mask))
162
+ ax.axis("off")
163
+ st.pyplot(fig, use_container_width=False)
164
+ plt.close(fig)