Sefat33 commited on
Commit
3b6e0e2
Β·
verified Β·
1 Parent(s): aa68fcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -130
app.py CHANGED
@@ -9,35 +9,8 @@ from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
9
  import os
10
  from io import BytesIO
11
  import base64
12
- st.markdown("""
13
- <style>
14
- .equal-cols {
15
- display: flex;
16
- gap: 1.5rem;
17
- }
18
- .equal-cols > div {
19
- flex: 1;
20
- display: flex;
21
- flex-direction: column;
22
- }
23
- .lime-image {
24
- width: 100%;
25
- border-radius: 10px;
26
- }
27
- .overlay-box {
28
- background-color: rgba(255, 255, 255, 0.85);
29
- padding: 1rem;
30
- border-radius: 10px;
31
- overflow-y: auto;
32
- color: #333;
33
- font-size: 16px;
34
- line-height: 1.5;
35
- height: 100%;
36
- }
37
- </style>
38
- """, unsafe_allow_html=True)
39
 
40
- # --- Fix deserialization issues for BatchNorm and DepthwiseConv2D ---
41
  original_bn = BatchNormalization.from_config
42
  BatchNormalization.from_config = classmethod(
43
  lambda cls, config, *a, **k: original_bn(
@@ -49,7 +22,7 @@ DepthwiseConv2D.from_config = classmethod(
49
  lambda cls, config, *a, **k: original_dw({k: v for k, v in config.items() if k != "groups"}, *a, **k)
50
  )
51
 
52
- # --- Background Setup ---
53
  def set_background(image_path):
54
  with open(image_path, "rb") as f:
55
  encoded = base64.b64encode(f.read()).decode()
@@ -81,14 +54,8 @@ set_background("5858.jpg")
81
  # --- Constants ---
82
  IMG_SIZE = (224, 224)
83
  CLASS_NAMES = [
84
- 'Normal',
85
- 'Diabetic Retinopathy',
86
- 'Glaucoma',
87
- 'Cataract',
88
- 'Age-related Macular Degeneration (AMD)',
89
- 'Hypertension',
90
- 'Myopia',
91
- 'Others'
92
  ]
93
  LIME_EXPLAINER = lime_image.LimeImageExplainer()
94
 
@@ -100,14 +67,25 @@ def load_model():
100
  st.error(f"Model folder '{model_path}' not found.")
101
  st.stop()
102
  try:
103
- # Load TensorFlow SavedModel as an inference-only Keras layer
104
  model = tf.keras.Sequential([TFSMLayer(model_path, call_endpoint="serving_default")])
105
  return model
106
  except Exception as e:
107
  st.error(f"Error loading model: {e}")
108
  st.stop()
109
 
110
- # --- Preprocessing ---
 
 
 
 
 
 
 
 
 
 
 
 
111
  def preprocess_with_steps(img):
112
  h, w = img.shape[:2]
113
  center, radius = (w // 2, h // 2), min(w, h) // 2
@@ -115,22 +93,18 @@ def preprocess_with_steps(img):
115
  dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
116
  mask = (dist <= radius).astype(np.uint8)
117
 
118
- # Apply mask and replace black background with white
119
  circ = img.copy()
120
- white_background = np.ones_like(circ, dtype=np.uint8) * 255 # White background
121
- circ = np.where(mask[:, :, np.newaxis] == 1, circ, white_background)
122
 
123
- # CLAHE
124
  lab = cv2.cvtColor(circ, cv2.COLOR_RGB2LAB)
125
  cl = cv2.createCLAHE(clipLimit=2.0).apply(lab[:, :, 0])
126
  merged = cv2.merge((cl, lab[:, :, 1], lab[:, :, 2]))
127
  clahe_img = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
128
 
129
- # Sharpen + Resize
130
  sharp = cv2.addWeighted(clahe_img, 4, cv2.GaussianBlur(clahe_img, (0, 0), 10), -4, 128)
131
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
132
 
133
- # Visualization
134
  fig, axs = plt.subplots(1, 4, figsize=(16, 4))
135
  for ax, image, title in zip(
136
  axs, [img, circ, clahe_img, resized],
@@ -139,76 +113,22 @@ def preprocess_with_steps(img):
139
  ax.imshow(image)
140
  ax.set_title(title)
141
  ax.axis("off")
142
-
143
- plt.tight_layout()
144
  st.pyplot(fig)
145
  plt.close(fig)
146
  return resized
147
 
148
-
149
  # --- Reasoning Text ---
150
  explanation_text = {
151
- 'Normal': """βœ… **Normal** <br>
152
- - 🟒 Clear retina, no lesions <br>
153
- - 🩺 Blood vessels normal <br>
154
- - πŸ‘ Healthy optic disc & macula <br>
155
- βœ”οΈ No signs of retinal disease.""",
156
-
157
- 'Diabetic Retinopathy': """πŸ’‰ **Diabetic Retinopathy** <br>
158
- - πŸ”Ά Red spots / hemorrhages <br>
159
- - 🩸 Leaking or swollen vessels <br>
160
- - πŸ‘ Macula possibly thickened <br>
161
- ⚠️ Indicative of diabetes-related damage.""",
162
-
163
- 'Glaucoma': """πŸ‘ **Glaucoma** <br>
164
- - πŸ”΄ Thinned nerve fiber layer <br>
165
- - πŸ’‰ Cupping in optic disc <br>
166
- - πŸ‘ Risk of peripheral vision loss <br>
167
- πŸ”΄ May need long-term eye pressure control.""",
168
-
169
- 'Cataract': """🌫 **Cataract** <br>
170
- - 🌫 Cloudy or hazy image <br>
171
- - πŸ‘ Disc/macula not clearly visible <br>
172
- - πŸ” Overall low contrast <br>
173
- ⚠️ Likely due to lens opacity.""",
174
-
175
- 'Age-related Macular Degeneration (AMD)': """πŸ§“ **AMD** <br>
176
- - πŸ”΄ Yellow drusen near macula <br>
177
- - πŸ‘ Center vision affected <br>
178
- - 🩺 Degenerative macula changes <br>
179
- ⚠️ Early to moderate AMD signs.""",
180
-
181
- 'Hypertension': """⚠️ **Hypertension** <br>
182
- - πŸ”Ά Bright lesions or hemorrhages <br>
183
- - 🩸 Twisted/narrowed vessels <br>
184
- - πŸ‘ Star or flame-like patterns <br>
185
- ⚠️ Vascular damage from high BP.""",
186
-
187
- 'Myopia': """πŸ‘“ **Myopia** <br>
188
- - πŸ”΅ Elongated eyeball signs <br>
189
- - 🩺 Slight disc tilting <br>
190
- - πŸ‘ Possible peripapillary atrophy <br>
191
- ℹ️ Common in nearsighted eyes.""",
192
-
193
- 'Others': """πŸ”Ž **Others** <br>
194
- - βšͺ Unusual or unclassified patterns <br>
195
- - 🩸 Irregular vascular changes <br>
196
- - πŸ‘ Disc or macula abnormalities <br>
197
- ❓ Possibly rare or overlapping conditions."""
198
  }
199
- # --- Prediction ---
200
- def predict(images, model):
201
- images = np.array(images)
202
- preds = model.predict(images, verbose=0)
203
- if isinstance(preds, dict): # Handle dict output (SavedModel case)
204
- for v in preds.values():
205
- if isinstance(v, (np.ndarray, list)):
206
- return np.array(v)
207
- return np.array(list(preds.values())[0])
208
- else:
209
- return preds
210
 
211
- # --- LIME Display ---
212
  # --- LIME Display ---
213
  def show_lime(img, model, pred_idx, pred_label, all_probs):
214
  with st.spinner("🟑 Generating LIME explanation..."):
@@ -224,44 +144,28 @@ def show_lime(img, model, pred_idx, pred_label, all_probs):
224
  )
225
  lime_img = mark_boundaries(temp, mask)
226
 
227
- # Save to buffer once
228
  buf = BytesIO()
229
  plt.imsave(buf, lime_img, format="png")
230
  buf.seek(0)
231
  lime_data = buf.getvalue()
232
 
233
- # Custom equal-height container
234
- st.markdown('<div class="equal-cols">', unsafe_allow_html=True)
235
-
236
- # Left column with image
237
- with st.container():
238
  st.markdown("### πŸ“ LIME Explanation")
239
- st.image(lime_data, use_container_width=True, output_format="PNG")
240
  st.download_button(
241
  "πŸ“₯ Download LIME Image",
242
  lime_data,
243
  file_name=f"{pred_label}_LIME.png",
244
  mime="image/png"
245
  )
246
-
247
- # Right column with explanation
248
- with st.container():
249
  st.markdown(
250
- f"<div class='overlay-box'>"
251
- f"<h3>🧠 Model's Reasoning</h3>"
252
- f"{explanation_text.get(pred_label, 'No explanation available.')}"
253
- "</div>",
254
  unsafe_allow_html=True
255
  )
256
 
257
- st.markdown('</div>', unsafe_allow_html=True)
258
-
259
-
260
-
261
-
262
-
263
-
264
- # --- Main App UI ---
265
  st.set_page_config(page_title="πŸ‘ Retina Classifier with LIME", layout="wide")
266
  st.title("πŸ‘ Retina Disease Classifier with LIME Explanation")
267
 
@@ -294,4 +198,4 @@ if uploaded_files and selected_filename:
294
  st.success(f"βœ… Prediction: **{pred_label}** ({confidence:.2f}%)")
295
  show_lime(preprocessed, model, pred_idx, pred_label, preds)
296
  else:
297
- st.info("Upload retinal images from the sidebar to get started.")
 
9
  import os
10
  from io import BytesIO
11
  import base64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # --- Fix deserialization issues ---
14
  original_bn = BatchNormalization.from_config
15
  BatchNormalization.from_config = classmethod(
16
  lambda cls, config, *a, **k: original_bn(
 
22
  lambda cls, config, *a, **k: original_dw({k: v for k, v in config.items() if k != "groups"}, *a, **k)
23
  )
24
 
25
+ # --- Set Background ---
26
  def set_background(image_path):
27
  with open(image_path, "rb") as f:
28
  encoded = base64.b64encode(f.read()).decode()
 
54
  # --- Constants ---
55
  IMG_SIZE = (224, 224)
56
  CLASS_NAMES = [
57
+ 'Normal', 'Diabetic Retinopathy', 'Glaucoma', 'Cataract',
58
+ 'Age-related Macular Degeneration (AMD)', 'Hypertension', 'Myopia', 'Others'
 
 
 
 
 
 
59
  ]
60
  LIME_EXPLAINER = lime_image.LimeImageExplainer()
61
 
 
67
  st.error(f"Model folder '{model_path}' not found.")
68
  st.stop()
69
  try:
 
70
  model = tf.keras.Sequential([TFSMLayer(model_path, call_endpoint="serving_default")])
71
  return model
72
  except Exception as e:
73
  st.error(f"Error loading model: {e}")
74
  st.stop()
75
 
76
+ # --- Prediction ---
77
+ def predict(images, model):
78
+ images = np.array(images)
79
+ preds = model.predict(images, verbose=0)
80
+ if isinstance(preds, dict):
81
+ for v in preds.values():
82
+ if isinstance(v, (np.ndarray, list)):
83
+ return np.array(v)
84
+ return np.array(list(preds.values())[0])
85
+ else:
86
+ return preds
87
+
88
+ # --- Preprocessing Steps ---
89
  def preprocess_with_steps(img):
90
  h, w = img.shape[:2]
91
  center, radius = (w // 2, h // 2), min(w, h) // 2
 
93
  dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
94
  mask = (dist <= radius).astype(np.uint8)
95
 
 
96
  circ = img.copy()
97
+ white_bg = np.ones_like(circ, dtype=np.uint8) * 255
98
+ circ = np.where(mask[:, :, np.newaxis] == 1, circ, white_bg)
99
 
 
100
  lab = cv2.cvtColor(circ, cv2.COLOR_RGB2LAB)
101
  cl = cv2.createCLAHE(clipLimit=2.0).apply(lab[:, :, 0])
102
  merged = cv2.merge((cl, lab[:, :, 1], lab[:, :, 2]))
103
  clahe_img = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
104
 
 
105
  sharp = cv2.addWeighted(clahe_img, 4, cv2.GaussianBlur(clahe_img, (0, 0), 10), -4, 128)
106
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
107
 
 
108
  fig, axs = plt.subplots(1, 4, figsize=(16, 4))
109
  for ax, image, title in zip(
110
  axs, [img, circ, clahe_img, resized],
 
113
  ax.imshow(image)
114
  ax.set_title(title)
115
  ax.axis("off")
 
 
116
  st.pyplot(fig)
117
  plt.close(fig)
118
  return resized
119
 
 
120
  # --- Reasoning Text ---
121
  explanation_text = {
122
+ 'Normal': """βœ… **Normal**<br>- 🟒 Clear retina, no lesions<br>- 🩺 Blood vessels normal<br>- πŸ‘ Healthy optic disc & macula<br>βœ”οΈ No signs of retinal disease.""",
123
+ 'Diabetic Retinopathy': """πŸ’‰ **Diabetic Retinopathy**<br>- πŸ”Ά Red spots / hemorrhages<br>- 🩸 Leaking or swollen vessels<br>- πŸ‘ Macula possibly thickened<br>⚠️ Diabetes-related damage.""",
124
+ 'Glaucoma': """πŸ‘ **Glaucoma**<br>- πŸ”΄ Thinned nerve fiber layer<br>- πŸ’‰ Cupping in optic disc<br>- πŸ‘ Risk of peripheral vision loss<br>πŸ”΄ Long-term eye pressure control needed.""",
125
+ 'Cataract': """🌫 **Cataract**<br>- 🌫 Cloudy or hazy image<br>- πŸ‘ Disc/macula not clearly visible<br>- πŸ” Overall low contrast<br>⚠️ Likely due to lens opacity.""",
126
+ 'Age-related Macular Degeneration (AMD)': """πŸ§“ **AMD**<br>- πŸ”΄ Yellow drusen near macula<br>- πŸ‘ Center vision affected<br>- 🩺 Degenerative macula changes<br>⚠️ Early/moderate AMD signs.""",
127
+ 'Hypertension': """⚠️ **Hypertension**<br>- πŸ”Ά Bright lesions or hemorrhages<br>- 🩸 Twisted/narrowed vessels<br>- πŸ‘ Star or flame-like patterns<br>⚠️ Vascular damage from high BP.""",
128
+ 'Myopia': """πŸ‘“ **Myopia**<br>- πŸ”΅ Elongated eyeball signs<br>- 🩺 Slight disc tilting<br>- πŸ‘ Possible peripapillary atrophy<br>ℹ️ Common in nearsighted eyes.""",
129
+ 'Others': """πŸ”Ž **Others**<br>- βšͺ Unusual or unclassified patterns<br>- 🩸 Irregular vascular changes<br>- πŸ‘ Disc or macula abnormalities<br>❓ Possibly rare or overlapping conditions."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  }
 
 
 
 
 
 
 
 
 
 
 
131
 
 
132
  # --- LIME Display ---
133
  def show_lime(img, model, pred_idx, pred_label, all_probs):
134
  with st.spinner("🟑 Generating LIME explanation..."):
 
144
  )
145
  lime_img = mark_boundaries(temp, mask)
146
 
 
147
  buf = BytesIO()
148
  plt.imsave(buf, lime_img, format="png")
149
  buf.seek(0)
150
  lime_data = buf.getvalue()
151
 
152
+ col1, col2 = st.columns(2)
153
+ with col1:
 
 
 
154
  st.markdown("### πŸ“ LIME Explanation")
155
+ st.image(lime_data, width=224, output_format="PNG") # πŸ‘ˆ Small LIME image
156
  st.download_button(
157
  "πŸ“₯ Download LIME Image",
158
  lime_data,
159
  file_name=f"{pred_label}_LIME.png",
160
  mime="image/png"
161
  )
162
+ with col2:
 
 
163
  st.markdown(
164
+ f"<div class='overlay'>{explanation_text.get(pred_label, 'No explanation available.')}</div>",
 
 
 
165
  unsafe_allow_html=True
166
  )
167
 
168
+ # --- Streamlit App UI ---
 
 
 
 
 
 
 
169
  st.set_page_config(page_title="πŸ‘ Retina Classifier with LIME", layout="wide")
170
  st.title("πŸ‘ Retina Disease Classifier with LIME Explanation")
171
 
 
198
  st.success(f"βœ… Prediction: **{pred_label}** ({confidence:.2f}%)")
199
  show_lime(preprocessed, model, pred_idx, pred_label, preds)
200
  else:
201
+ st.info("πŸ“€ Upload a retinal image from the sidebar to get started.")