Sefat33 commited on
Commit
7e3aa57
Β·
verified Β·
1 Parent(s): 5d50d92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -77
app.py CHANGED
@@ -8,6 +8,7 @@ from keras.models import Model
8
  from keras.saving import register_keras_serializable
9
  from keras.layers import TFSMLayer
10
  import matplotlib.pyplot as plt
 
11
  from lime import lime_image
12
  from skimage.segmentation import mark_boundaries
13
 
@@ -40,28 +41,28 @@ class CustomTFOpLambda(tf.keras.layers.Layer):
40
  config.update({"function": self.function})
41
  return config
42
 
43
- # --- Constants ---
44
  IMG_SIZE = (224, 224)
45
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
46
 
47
- # --- Load SavedModel as TFSMLayer wrapped model ---
48
  @st.cache_resource
49
  def load_model():
50
- model_path = "Model" # Your SavedModel directory path
51
  if not os.path.exists(model_path):
52
  st.error(f"❌ Model directory '{model_path}' not found!")
53
  st.stop()
54
  try:
55
  tfsm_layer = TFSMLayer(model_path, call_endpoint="serving_default")
56
- inputs = Input(shape=(224, 224, 3))
57
  outputs = tfsm_layer(inputs)
58
  model = Model(inputs=inputs, outputs=outputs)
59
  return model
60
  except Exception as e:
61
- st.error(f"❌ Error loading model with TFSMLayer: {str(e)}")
62
  st.stop()
63
 
64
- # --- Preprocessing functions ---
65
  def crop_circle(img):
66
  h, w = img.shape[:2]
67
  center = (w // 2, h // 2)
@@ -92,93 +93,113 @@ def preprocess_image(img):
92
  clahe = apply_clahe(circ)
93
  sharp = sharpen_image(clahe)
94
  resized = resize_normalize(sharp)
95
- return circ, clahe, sharp, resized
96
 
97
- # --- Grad-CAM ---
98
- def show_gradcam(model, img, class_idx):
99
- last_conv_layer = None
100
  for layer in reversed(model.layers):
101
- if isinstance(layer, tf.keras.layers.Conv2D):
102
- last_conv_layer = layer.name
103
- break
104
- if last_conv_layer is None:
105
- st.warning("⚠️ No Conv2D layer found for Grad-CAM.")
106
- return
107
-
108
- grad_model = Model(model.inputs, [model.get_layer(last_conv_layer).output, model.output])
109
- img_tensor = tf.convert_to_tensor(img[np.newaxis, ...])
110
-
111
  with tf.GradientTape() as tape:
112
- conv_outputs, predictions = grad_model(img_tensor)
113
- loss = predictions[:, class_idx]
114
- grads = tape.gradient(loss, conv_outputs)[0]
115
- cam = tf.reduce_mean(grads, axis=-1).numpy()
116
-
117
- cam = np.maximum(cam, 0)
118
- cam = cv2.resize(cam, IMG_SIZE)
119
- cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-10)
120
-
121
- heatmap = np.uint8(255 * cam)
122
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
123
- overlay = cv2.addWeighted(np.uint8(img * 255), 0.6, heatmap, 0.4, 0)
124
-
125
- st.subheader("πŸ”₯ Grad-CAM")
126
- st.image(overlay, use_container_width=True)
127
-
128
- # --- LIME ---
129
- def show_lime(model, img, class_idx):
130
- explainer = lime_image.LimeImageExplainer()
131
-
132
- def predict_fn(images):
133
- images = np.array(images)
134
- preds = model.predict(images)
135
- if isinstance(preds, dict):
136
- preds = list(preds.values())[0]
137
- return preds
138
-
139
- explanation = explainer.explain_instance(np.uint8(img * 255), predict_fn, top_labels=1, hide_color=0, num_samples=1000)
140
- lime_img, mask = explanation.get_image_and_mask(class_idx, positive_only=True, hide_rest=False)
141
-
142
- st.subheader("🟒 LIME Explanation")
143
- st.image(mark_boundaries(lime_img, mask), use_container_width=True)
144
-
145
- # --- Streamlit UI ---
146
- st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
147
- st.title("🧠 Retina Disease Classifier")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  model = load_model()
 
150
 
151
- uploaded_file = st.file_uploader("πŸ“€ Upload a retinal image", type=["jpg", "jpeg", "png"])
152
  if uploaded_file:
153
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
154
  bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
155
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
156
 
157
- circ, clahe, sharp, final = preprocess_image(rgb_img)
158
-
159
- st.subheader("πŸ§ͺ Preprocessing Pipeline (Left ➝ Right)")
160
- steps = [
161
- ("πŸ“· Original", rgb_img),
162
- ("πŸ”΅ Circular Crop", circ),
163
- ("βšͺ CLAHE", clahe),
164
- ("🟣 Sharpened", sharp),
165
- ("πŸ“ Resized", (final * 255).astype(np.uint8))
166
- ]
167
- cols = st.columns(len(steps))
168
- for col, (label, img) in zip(cols, steps):
169
- col.image(img, caption=label, use_container_width=True)
170
-
171
- input_tensor = np.expand_dims(final, axis=0)
172
  preds = model.predict(input_tensor)
173
  if isinstance(preds, dict):
174
  preds = list(preds.values())[0]
175
-
176
  pred_idx = np.argmax(preds)
177
  pred_label = CLASS_NAMES[pred_idx]
178
  confidence = np.max(preds) * 100
179
 
180
- st.success(f"βœ… Prediction: **{pred_label}**")
181
- st.info(f"πŸ” Confidence: {confidence:.2f}%")
182
 
183
- show_gradcam(model, final, pred_idx)
184
- show_lime(model, final, pred_idx)
 
8
  from keras.saving import register_keras_serializable
9
  from keras.layers import TFSMLayer
10
  import matplotlib.pyplot as plt
11
+ import matplotlib.cm as cm
12
  from lime import lime_image
13
  from skimage.segmentation import mark_boundaries
14
 
 
41
  config.update({"function": self.function})
42
  return config
43
 
44
+ # Constants
45
  IMG_SIZE = (224, 224)
46
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
47
 
48
+ # Load model using TFSMLayer
49
  @st.cache_resource
50
  def load_model():
51
+ model_path = "Model" # Adjust path to your SavedModel directory
52
  if not os.path.exists(model_path):
53
  st.error(f"❌ Model directory '{model_path}' not found!")
54
  st.stop()
55
  try:
56
  tfsm_layer = TFSMLayer(model_path, call_endpoint="serving_default")
57
+ inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
58
  outputs = tfsm_layer(inputs)
59
  model = Model(inputs=inputs, outputs=outputs)
60
  return model
61
  except Exception as e:
62
+ st.error(f"❌ Error loading model: {str(e)}")
63
  st.stop()
64
 
65
+ # Preprocessing functions
66
  def crop_circle(img):
67
  h, w = img.shape[:2]
68
  center = (w // 2, h // 2)
 
93
  clahe = apply_clahe(circ)
94
  sharp = sharpen_image(clahe)
95
  resized = resize_normalize(sharp)
96
+ return resized
97
 
98
+ # --- Find Last Conv Layer for Grad-CAM ---
99
+ def find_last_conv_layer(model):
 
100
  for layer in reversed(model.layers):
101
+ # Check if Conv2D or custom attention output layer name pattern
102
+ if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name:
103
+ return layer.name
104
+ raise ValueError("No suitable conv layer found.")
105
+
106
+ # Grad-CAM generation
107
+ def generate_gradcam(model, img_array, class_index, layer_name):
108
+ grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(layer_name).output, model.output])
 
 
109
  with tf.GradientTape() as tape:
110
+ conv_outputs, predictions = grad_model(img_array)
111
+ loss = predictions[:, class_index]
112
+ grads = tape.gradient(loss, conv_outputs)
113
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
114
+ heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
115
+ heatmap = tf.squeeze(heatmap)
116
+ heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-10)
117
+ return heatmap.numpy()
118
+
119
+ # LIME explainer
120
+ explainer = lime_image.LimeImageExplainer()
121
+ def predict_fn(images):
122
+ images = np.array(images)
123
+ preds = model.predict(images, verbose=0)
124
+ # If output is dict (due to TFSMLayer), extract predictions properly
125
+ if isinstance(preds, dict):
126
+ preds = list(preds.values())[0]
127
+ return preds
128
+
129
+ # Explanation text per class
130
+ explanation_text = {
131
+ 'Normal': "Model predicted Normal based on healthy optic disc and macula.",
132
+ 'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
133
+ 'Glaucoma': "Detected increased cupping in the optic disc indicating Glaucoma.",
134
+ 'Cataract': "Image blur indicated potential Cataract.",
135
+ 'AMD': "Degeneration signs in macula indicate AMD.",
136
+ 'Hypertension': "Blood vessel narrowing/hemorrhages indicate Hypertension.",
137
+ 'Myopia': "Tilted disc and fundus shape suggest Myopia.",
138
+ 'Others': "Non-specific features detected, marked as Others."
139
+ }
140
+
141
+ # Visualization in Streamlit
142
+ def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_name):
143
+ input_array = np.expand_dims(img, axis=0)
144
+
145
+ # Grad-CAM heatmap
146
+ heatmap = generate_gradcam(model, input_array, pred_idx, layer_name)
147
+ heatmap = cv2.resize(heatmap, IMG_SIZE)
148
+ heatmap = np.uint8(255 * heatmap)
149
+ heatmap = cv2.GaussianBlur(heatmap, (7, 7), 0)
150
+ heatmap_rgb = cm.jet(heatmap / 255.0)[..., :3]
151
+ heatmap_rgb = np.uint8(heatmap_rgb * 255)
152
+ overlayed = cv2.addWeighted(np.uint8(img * 255), 0.5, heatmap_rgb, 0.5, 0)
153
+
154
+ # LIME explanation
155
+ explanation = explainer.explain_instance(
156
+ image=img, classifier_fn=predict_fn,
157
+ top_labels=1, hide_color=0, num_samples=1000
158
+ )
159
+ temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
160
+
161
+ # Plot side by side
162
+ fig, axs = plt.subplots(1, 3, figsize=(15, 5))
163
+ axs[0].imshow(img)
164
+ axs[0].set_title(f"Original\nTrue: {true_label}", fontsize=11)
165
+ axs[1].imshow(overlayed)
166
+ axs[1].set_title(f"Grad-CAM\nPred: {pred_label}", fontsize=11)
167
+ axs[2].imshow(mark_boundaries(temp, mask))
168
+ axs[2].set_title(f"LIME\nPred: {pred_label}", fontsize=11)
169
+ for ax in axs:
170
+ ax.axis('off')
171
+ summary = explanation_text.get(pred_label, "Model detected features matching this class.")
172
+ plt.figtext(0.5, 0.01, summary, wrap=True, ha='center', fontsize=10)
173
+ plt.tight_layout(rect=[0, 0.03, 1, 1])
174
+ st.pyplot(fig)
175
+ plt.close()
176
+
177
+ # Streamlit app UI
178
+ st.set_page_config(page_title="🧠 Retina Disease Classifier with Grad-CAM & LIME", layout="centered")
179
+ st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")
180
 
181
  model = load_model()
182
+ last_conv_layer_name = find_last_conv_layer(model)
183
 
184
+ uploaded_file = st.file_uploader("Upload a retinal image", type=["jpg", "jpeg", "png"])
185
  if uploaded_file:
186
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
187
  bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
188
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
189
 
190
+ # Preprocess image
191
+ processed_img = preprocess_image(rgb_img)
192
+
193
+ # Predict
194
+ input_tensor = np.expand_dims(processed_img, axis=0)
 
 
 
 
 
 
 
 
 
 
195
  preds = model.predict(input_tensor)
196
  if isinstance(preds, dict):
197
  preds = list(preds.values())[0]
 
198
  pred_idx = np.argmax(preds)
199
  pred_label = CLASS_NAMES[pred_idx]
200
  confidence = np.max(preds) * 100
201
 
202
+ st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")
 
203
 
204
+ # Show Grad-CAM and LIME visualizations
205
+ display_combined_visualization(processed_img, "Unknown (Uploaded)", pred_label, pred_idx, last_conv_layer_name)