Sefat33 committed on
Commit
42534cf
·
verified ·
1 Parent(s): 024990b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -87
app.py CHANGED
@@ -7,7 +7,6 @@ import matplotlib.pyplot as plt
7
  import matplotlib.cm as cm
8
  from lime import lime_image
9
  from skimage.segmentation import mark_boundaries
10
- from keras.models import load_model as keras_load_model
11
  from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
12
 
13
  # --- Fix deserialization issues ---
@@ -28,25 +27,18 @@ DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
28
  # --- Constants ---
29
  IMG_SIZE = (224, 224)
30
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
31
- gradcam_enabled = False # Default to False
32
 
33
- # --- Load model auto-detect ---
34
  @st.cache_resource
35
- def load_model_auto(path="Model"):
36
- global gradcam_enabled
37
- if not os.path.exists(path):
38
- st.error(f"❌ Model path '{path}' not found!")
39
  st.stop()
40
-
41
  try:
42
- if os.path.isdir(path):
43
- model = tf.keras.Sequential([TFSMLayer(path, call_endpoint="serving_default")])
44
- gradcam_enabled = False
45
- elif path.endswith(('.keras', '.h5')):
46
- model = keras_load_model(path)
47
- gradcam_enabled = True
48
- else:
49
- raise ValueError("Unsupported model format. Use .keras, .h5 or SavedModel folder.")
50
  return model
51
  except Exception as e:
52
  st.error(f"❌ Error loading model: {str(e)}")
@@ -74,15 +66,16 @@ def preprocess_and_show_steps(img):
74
 
75
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
76
 
 
77
  fig, axs = plt.subplots(1, 4, figsize=(20, 5))
78
  axs[0].imshow(img)
79
- axs[0].set_title("Original")
80
  axs[1].imshow(circ)
81
- axs[1].set_title("Circular Crop")
82
  axs[2].imshow(clahe_img)
83
- axs[2].set_title("CLAHE")
84
  axs[3].imshow(resized)
85
- axs[3].set_title("Sharpened + Resize")
86
  for ax in axs:
87
  ax.axis("off")
88
  st.pyplot(fig)
@@ -90,26 +83,6 @@ def preprocess_and_show_steps(img):
90
 
91
  return resized
92
 
93
- # --- Grad-CAM ---
94
- def find_last_conv_layer(model):
95
- for layer in reversed(model.layers):
96
- if isinstance(layer, tf.keras.layers.Conv2D):
97
- return layer.name
98
- return None
99
-
100
- def generate_gradcam(model, img_array, class_index, layer_name):
101
- grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(layer_name).output, model.output])
102
- with tf.GradientTape() as tape:
103
- conv_outputs, predictions = grad_model(img_array)
104
- loss = predictions[:, class_index]
105
- grads = tape.gradient(loss, conv_outputs)
106
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
107
- conv_outputs = conv_outputs[0]
108
- heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
109
- heatmap = tf.squeeze(heatmap)
110
- heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
111
- return heatmap.numpy()
112
-
113
  # --- LIME Explainer ---
114
  explainer = lime_image.LimeImageExplainer()
115
  def predict_fn(images):
@@ -131,77 +104,55 @@ explanation_text = {
131
  'Others': "Non-specific features detected, marked as Others."
132
  }
133
 
134
- # --- Display explanations ---
135
- def display_explanations(img, input_array, pred_label, pred_idx):
136
- overlayed = None
137
- if gradcam_enabled:
138
- try:
139
- conv_layer = find_last_conv_layer(model)
140
- if conv_layer:
141
- heatmap = generate_gradcam(model, input_array, pred_idx, conv_layer)
142
- heatmap = cv2.resize(heatmap, IMG_SIZE)
143
- heatmap = cv2.GaussianBlur(np.uint8(255 * heatmap), (7, 7), 0)
144
- heatmap_rgb = np.uint8(cm.jet(heatmap / 255.0)[..., :3] * 255)
145
- overlayed = cv2.addWeighted(np.uint8(img * 255), 0.5, heatmap_rgb, 0.5, 0)
146
- except Exception as e:
147
- st.warning(f"⚠️ Grad-CAM failed: {e}")
148
- else:
149
- st.info("ℹ️ Grad-CAM is disabled for this model.")
150
-
151
- explanation = explainer.explain_instance(img, classifier_fn=predict_fn, top_labels=1, hide_color=0, num_samples=1000)
152
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
153
 
154
- fig, axs = plt.subplots(1, 3 if overlayed is not None else 2, figsize=(15, 5))
155
  axs[0].imshow(img)
156
- axs[0].set_title("Original")
157
  axs[1].imshow(mark_boundaries(temp, mask))
158
- axs[1].set_title("LIME Explanation")
159
- if overlayed is not None:
160
- axs[2].imshow(overlayed)
161
- axs[2].set_title("Grad-CAM")
162
  for ax in axs:
163
  ax.axis('off')
164
 
165
  summary = explanation_text.get(pred_label, "Model detected features matching this class.")
166
  plt.figtext(0.5, 0.01, summary, wrap=True, ha='center', fontsize=10)
167
- st.pyplot(fig)
168
- plt.close()
169
-
170
- # --- Probability chart ---
171
- def plot_probabilities(probs, class_names):
172
- fig, ax = plt.subplots(figsize=(8, 4))
173
- bars = ax.barh(class_names, probs * 100, color='skyblue')
174
- ax.set_xlim(0, 100)
175
- ax.set_xlabel("Confidence (%)")
176
- ax.set_title("Prediction Probabilities")
177
- for bar, prob in zip(bars, probs):
178
- ax.text(prob * 100 + 1, bar.get_y() + bar.get_height()/2, f"{prob*100:.1f}%", va='center')
179
  st.pyplot(fig)
180
  plt.close()
181
 
182
  # --- Streamlit UI ---
183
- st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
184
- st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")
185
 
186
- # Load model
187
- model = load_model_auto("Model") # Folder or .keras/.h5 file
188
 
189
- # Upload image
190
- uploaded_file = st.file_uploader("📤 Upload a retinal image", type=["jpg", "jpeg", "png"])
191
  if uploaded_file:
192
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
193
  bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
194
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
195
 
196
  processed_img = preprocess_and_show_steps(rgb_img)
197
- input_tensor = np.expand_dims(processed_img, axis=0)
198
 
199
- preds = predict_fn(input_tensor)
 
 
 
200
  pred_idx = np.argmax(preds)
201
  pred_label = CLASS_NAMES[pred_idx]
202
  confidence = np.max(preds) * 100
203
 
204
- st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")
205
- plot_probabilities(preds[0], CLASS_NAMES)
206
 
207
- display_explanations(processed_img, input_tensor, pred_label, pred_idx)
 
7
  import matplotlib.cm as cm
8
  from lime import lime_image
9
  from skimage.segmentation import mark_boundaries
 
10
  from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
11
 
12
  # --- Fix deserialization issues ---
 
27
  # --- Constants ---
28
  IMG_SIZE = (224, 224)
29
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
 
30
 
31
# --- Load model using TFSMLayer ---
@st.cache_resource
def load_model():
    """Load the TensorFlow SavedModel from the local ``Model`` folder.

    The SavedModel is wrapped in a ``TFSMLayer`` inside a one-layer
    ``tf.keras.Sequential`` so it can be used for inference like a Keras
    model (training/Grad-CAM access to inner layers is not available).

    Returns:
        tf.keras.Sequential: the inference-only wrapper model.

    On any failure the Streamlit app is halted with ``st.stop()`` so the
    rest of the script never sees a ``None`` model.
    """
    model_path = "Model"  # Folder path to TF SavedModel
    if not os.path.exists(model_path):
        st.error(f"❌ Model folder '{model_path}' not found!")
        st.stop()
    try:
        model = tf.keras.Sequential([
            TFSMLayer(model_path, call_endpoint="serving_default")
        ])
        return model
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        # Bug fix: without stopping here the function falls through and
        # returns None, which later fails at model.predict() with a
        # confusing AttributeError instead of this clear error message.
        st.stop()
 
66
 
67
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
68
 
69
+ # Show preprocessing stages
70
  fig, axs = plt.subplots(1, 4, figsize=(20, 5))
71
  axs[0].imshow(img)
72
+ axs[0].set_title("Original Image")
73
  axs[1].imshow(circ)
74
+ axs[1].set_title("After Circular Crop")
75
  axs[2].imshow(clahe_img)
76
+ axs[2].set_title("After CLAHE")
77
  axs[3].imshow(resized)
78
+ axs[3].set_title("Sharpen + Resize")
79
  for ax in axs:
80
  ax.axis("off")
81
  st.pyplot(fig)
 
83
 
84
  return resized
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  # --- LIME Explainer ---
87
  explainer = lime_image.LimeImageExplainer()
88
  def predict_fn(images):
 
104
  'Others': "Non-specific features detected, marked as Others."
105
  }
106
 
107
# --- Display LIME only (Grad-CAM not possible with TFSMLayer) ---
def display_lime_visualization(img, true_label, pred_label, pred_idx):
    """Show the input image next to its LIME superpixel explanation.

    Args:
        img: preprocessed RGB image as a float array in [0, 1]
             (assumed — produced by preprocess_and_show_steps; confirm).
        true_label: caption text for the left panel's "True:" line.
        pred_label: predicted class name, used for the right panel title
                    and to look up the human-readable summary text.
        pred_idx: integer class index passed to LIME's mask extraction.

    Renders the figure via st.pyplot; returns nothing.
    """
    st.info("⚠️ Grad-CAM is disabled because the model is loaded as a TFSMLayer (inference-only).")

    # Perturb the image 1000 times and fit a local surrogate model.
    lime_result = explainer.explain_instance(
        image=img,
        classifier_fn=predict_fn,
        top_labels=1,
        hide_color=0,
        num_samples=1000,
    )
    lime_img, lime_mask = lime_result.get_image_and_mask(
        label=pred_idx, positive_only=True, num_features=10, hide_rest=False
    )

    figure, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 5))
    left_ax.imshow(img)
    left_ax.set_title(f"Original\nTrue: {true_label}")
    left_ax.axis('off')
    right_ax.imshow(mark_boundaries(lime_img, lime_mask))
    right_ax.set_title(f"LIME Explanation\nPred: {pred_label}")
    right_ax.axis('off')

    # One-line plain-language summary for the predicted class, with a
    # generic fallback for classes missing from explanation_text.
    summary = explanation_text.get(pred_label, "Model detected features matching this class.")
    plt.figtext(0.5, 0.01, summary, wrap=True, ha='center', fontsize=10)
    plt.tight_layout(rect=[0, 0.03, 1, 1])
    st.pyplot(figure)
    plt.close()
133
 
134
# --- Streamlit UI ---
st.set_page_config(page_title="🧠 Retina Disease Classifier with LIME", layout="centered")
st.title("🧠 Retina Disease Classifier with LIME Explanation")

# Global on purpose: predict_fn (defined above) reads this model.
model = load_model()

upload = st.file_uploader("Upload a retinal image", type=["jpg", "jpeg", "png"])
if upload:
    # Decode the uploaded bytes; OpenCV yields BGR, so convert to RGB.
    byte_arr = np.asarray(bytearray(upload.read()), dtype=np.uint8)
    image_bgr = cv2.imdecode(byte_arr, cv2.IMREAD_COLOR)
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Crop/CLAHE/sharpen pipeline; also renders the intermediate stages.
    prepared = preprocess_and_show_steps(image_rgb)

    batch = np.expand_dims(prepared, axis=0)
    preds = model.predict(batch)
    # A TFSMLayer serving endpoint may return a {output_name: tensor}
    # mapping instead of a bare array — unwrap the single output.
    if isinstance(preds, dict):
        preds = next(iter(preds.values()))

    pred_idx = np.argmax(preds)
    pred_label = CLASS_NAMES[pred_idx]
    confidence = np.max(preds) * 100

    st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")

    display_lime_visualization(prepared, "Uploaded Image", pred_label, pred_idx)