Sefat33 committed on
Commit
175fcdd
·
verified ·
1 Parent(s): 7e3aa57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -98
app.py CHANGED
@@ -8,7 +8,6 @@ from keras.models import Model
8
  from keras.saving import register_keras_serializable
9
  from keras.layers import TFSMLayer
10
  import matplotlib.pyplot as plt
11
- import matplotlib.cm as cm
12
  from lime import lime_image
13
  from skimage.segmentation import mark_boundaries
14
 
@@ -41,28 +40,28 @@ class CustomTFOpLambda(tf.keras.layers.Layer):
41
  config.update({"function": self.function})
42
  return config
43
 
44
- # Constants
45
  IMG_SIZE = (224, 224)
46
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
47
 
48
# --- Load the TF SavedModel via TFSMLayer and wrap it as a Keras Model ---
@st.cache_resource
def load_model():
    """Build a functional Keras model around the exported SavedModel.

    Cached by Streamlit so loading happens once per session. Stops the app
    with an on-page error if the directory is missing or loading fails.
    """
    model_path = "Model"  # Adjust path to your SavedModel directory
    if not os.path.exists(model_path):
        st.error(f"❌ Model directory '{model_path}' not found!")
        st.stop()
    try:
        serving_layer = TFSMLayer(model_path, call_endpoint="serving_default")
        inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
        return Model(inputs=inputs, outputs=serving_layer(inputs))
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.stop()
64
 
65
- # Preprocessing functions
66
  def crop_circle(img):
67
  h, w = img.shape[:2]
68
  center = (w // 2, h // 2)
@@ -93,113 +92,93 @@ def preprocess_image(img):
93
  clahe = apply_clahe(circ)
94
  sharp = sharpen_image(clahe)
95
  resized = resize_normalize(sharp)
96
- return resized
97
 
98
# --- Find Last Conv Layer for Grad-CAM ---
def find_last_conv_layer(model):
    """Return the name of the deepest Conv2D (or attention-output) layer.

    Walks the layer list from the end and accepts either a Conv2D instance
    or a layer whose name contains the custom 'mhsa_output' pattern.
    """
    found = next(
        (layer.name for layer in reversed(model.layers)
         if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name),
        None,
    )
    if found is None:
        raise ValueError("No suitable conv layer found.")
    return found
105
-
106
# --- Grad-CAM heatmap generation ---
def generate_gradcam(model, img_array, class_index, layer_name):
    """Compute a [0, 1]-normalized Grad-CAM heatmap for one class.

    Weights the target layer's activation maps by their spatially pooled
    gradients, ReLUs the result, and scales it by its maximum.
    """
    grad_model = tf.keras.models.Model(
        [model.inputs],
        [model.get_layer(layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_maps, preds = grad_model(img_array)
        score = preds[:, class_index]
    grads = tape.gradient(score, conv_maps)
    channel_weights = tf.reduce_mean(grads, axis=(0, 1, 2))
    cam = tf.squeeze(conv_maps[0] @ channel_weights[..., tf.newaxis])
    # +1e-10 guards against division by zero when the map is all zeros.
    cam = tf.maximum(cam, 0) / (tf.math.reduce_max(cam) + 1e-10)
    return cam.numpy()
118
-
119
# --- LIME explainer (module-level, shared across uploads) ---
explainer = lime_image.LimeImageExplainer()

def predict_fn(images):
    """Classifier function for LIME: batch of images -> class scores."""
    batch = np.array(images)
    out = model.predict(batch, verbose=0)
    # TFSMLayer-wrapped models may return a dict of named outputs;
    # unwrap the first (and only) entry in that case.
    if isinstance(out, dict):
        out = next(iter(out.values()))
    return out
128
-
129
# --- Human-readable rationale shown under the visualizations, one per class ---
explanation_text = dict(
    Normal="Model predicted Normal based on healthy optic disc and macula.",
    Diabetes="Detected retinal blood vessel changes suggestive of Diabetes.",
    Glaucoma="Detected increased cupping in the optic disc indicating Glaucoma.",
    Cataract="Image blur indicated potential Cataract.",
    AMD="Degeneration signs in macula indicate AMD.",
    Hypertension="Blood vessel narrowing/hemorrhages indicate Hypertension.",
    Myopia="Tilted disc and fundus shape suggest Myopia.",
    Others="Non-specific features detected, marked as Others.",
)
140
-
141
# --- Side-by-side Grad-CAM + LIME visualization rendered into Streamlit ---
def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_name):
    """Render original image, Grad-CAM overlay, and LIME mask in one figure.

    `img` is the preprocessed image fed to the model (float RGB in [0, 1],
    presumably — confirm against preprocess_image).
    """
    batch = np.expand_dims(img, axis=0)

    # Grad-CAM: resize, smooth, colorize with the jet colormap, then blend.
    cam = generate_gradcam(model, batch, pred_idx, layer_name)
    cam = cv2.resize(cam, IMG_SIZE)
    cam = np.uint8(255 * cam)
    cam = cv2.GaussianBlur(cam, (7, 7), 0)
    cam_rgb = np.uint8(cm.jet(cam / 255.0)[..., :3] * 255)
    overlayed = cv2.addWeighted(np.uint8(img * 255), 0.5, cam_rgb, 0.5, 0)

    # LIME: perturb the image and highlight the most influential superpixels.
    explanation = explainer.explain_instance(
        image=img, classifier_fn=predict_fn,
        top_labels=1, hide_color=0, num_samples=1000,
    )
    temp, mask = explanation.get_image_and_mask(
        label=pred_idx, positive_only=True, num_features=10, hide_rest=False,
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (img, f"Original\nTrue: {true_label}"),
        (overlayed, f"Grad-CAM\nPred: {pred_label}"),
        (mark_boundaries(temp, mask), f"LIME\nPred: {pred_label}"),
    ]
    for ax, (panel, title) in zip(axs, panels):
        ax.imshow(panel)
        ax.set_title(title, fontsize=11)
        ax.axis('off')
    summary = explanation_text.get(pred_label, "Model detected features matching this class.")
    plt.figtext(0.5, 0.01, summary, wrap=True, ha='center', fontsize=10)
    plt.tight_layout(rect=[0, 0.03, 1, 1])
    st.pyplot(fig)
    plt.close()
176
-
177
# --- Streamlit app UI ---
st.set_page_config(page_title="🧠 Retina Disease Classifier with Grad-CAM & LIME", layout="centered")
st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")

model = load_model()
last_conv_layer_name = find_last_conv_layer(model)

uploaded_file = st.file_uploader("Upload a retinal image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # cv2.imdecode returns None (it does not raise) for corrupt or
    # unsupported data; without this guard cvtColor crashes opaquely.
    if bgr_img is None:
        st.error("❌ Could not decode the uploaded file as an image.")
        st.stop()
    rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

    # Preprocess image
    processed_img = preprocess_image(rgb_img)

    # Predict
    input_tensor = np.expand_dims(processed_img, axis=0)
    preds = model.predict(input_tensor)
    # TFSMLayer-wrapped models may return a dict of named outputs.
    if isinstance(preds, dict):
        preds = list(preds.values())[0]
    pred_idx = np.argmax(preds)
    pred_label = CLASS_NAMES[pred_idx]
    confidence = np.max(preds) * 100

    st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")

    # Show Grad-CAM and LIME visualizations
    display_combined_visualization(processed_img, "Unknown (Uploaded)", pred_label, pred_idx, last_conv_layer_name)
 
8
  from keras.saving import register_keras_serializable
9
  from keras.layers import TFSMLayer
10
  import matplotlib.pyplot as plt
 
11
  from lime import lime_image
12
  from skimage.segmentation import mark_boundaries
13
 
 
40
  config.update({"function": self.function})
41
  return config
42
 
43
# --- Constants ---
IMG_SIZE = (224, 224)  # square model input size, also used for heatmap resizing
CLASS_NAMES = [
    'Normal', 'Diabetes', 'Glaucoma', 'Cataract',
    'AMD', 'Hypertension', 'Myopia', 'Others',
]  # index order must match the model's output units
46
 
47
# --- Load SavedModel as TFSMLayer wrapped model ---
@st.cache_resource
def load_model():
    """Load the TF SavedModel via TFSMLayer and wrap it in a Keras Model.

    Cached by Streamlit so the model loads once per session. Stops the app
    with an on-page error if the directory is missing or loading fails.
    """
    model_path = "Model"  # Your SavedModel directory path
    if not os.path.exists(model_path):
        st.error(f"❌ Model directory '{model_path}' not found!")
        st.stop()
    try:
        tfsm_layer = TFSMLayer(model_path, call_endpoint="serving_default")
        # Use the shared IMG_SIZE constant instead of hard-coding 224 twice,
        # so the input resolution stays consistent with preprocessing.
        inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
        outputs = tfsm_layer(inputs)
        model = Model(inputs=inputs, outputs=outputs)
        return model
    except Exception as e:
        st.error(f"❌ Error loading model with TFSMLayer: {str(e)}")
        st.stop()
63
 
64
+ # --- Preprocessing functions ---
65
  def crop_circle(img):
66
  h, w = img.shape[:2]
67
  center = (w // 2, h // 2)
 
92
  clahe = apply_clahe(circ)
93
  sharp = sharpen_image(clahe)
94
  resized = resize_normalize(sharp)
95
+ return circ, clahe, sharp, resized
96
 
97
# --- Grad-CAM ---
def show_gradcam(model, img, class_idx):
    """Render a Grad-CAM overlay for `class_idx` into the Streamlit page.

    `img` is the preprocessed image (float RGB scaled to [0, 1] at IMG_SIZE,
    presumably — confirm against preprocess_image).
    """
    # Attribute against the deepest Conv2D layer in the model.
    last_conv_layer = None
    for layer in reversed(model.layers):
        if isinstance(layer, tf.keras.layers.Conv2D):
            last_conv_layer = layer.name
            break
    if last_conv_layer is None:
        st.warning("⚠️ No Conv2D layer found for Grad-CAM.")
        return

    grad_model = Model(model.inputs, [model.get_layer(last_conv_layer).output, model.output])
    img_tensor = tf.convert_to_tensor(img[np.newaxis, ...])

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_tensor)
        # TFSMLayer-wrapped models may return a dict of named outputs;
        # indexing a dict with [:, class_idx] would crash.
        if isinstance(predictions, dict):
            predictions = list(predictions.values())[0]
        loss = predictions[:, class_idx]
    grads = tape.gradient(loss, conv_outputs)

    # Canonical Grad-CAM: weight each activation map by its spatially pooled
    # gradient. (Averaging the raw gradients alone ignores the activations
    # and does not localize the class evidence.)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    cam = tf.reduce_sum(conv_outputs[0] * pooled_grads, axis=-1).numpy()

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, IMG_SIZE)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-10)

    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap returns BGR; convert so the blend matches the RGB image
    # and st.image (which expects RGB) shows true colors.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(np.uint8(img * 255), 0.6, heatmap, 0.4, 0)

    st.subheader("🔥 Grad-CAM")
    st.image(overlay, use_container_width=True)
127
+
128
# --- LIME ---
def show_lime(model, img, class_idx):
    """Render a LIME superpixel explanation for `class_idx` into the page.

    `img` is the preprocessed image (float RGB in [0, 1], presumably —
    confirm against preprocess_image); LIME is given the uint8 version.
    """
    explainer = lime_image.LimeImageExplainer()

    def predict_fn(images):
        images = np.array(images)
        # verbose=0: LIME invokes this for many perturbation batches, so
        # Keras progress bars would flood the log/console.
        preds = model.predict(images, verbose=0)
        # TFSMLayer-wrapped models may return a dict of named outputs.
        if isinstance(preds, dict):
            preds = list(preds.values())[0]
        return preds

    explanation = explainer.explain_instance(
        np.uint8(img * 255), predict_fn,
        top_labels=1, hide_color=0, num_samples=1000,
    )
    lime_img, mask = explanation.get_image_and_mask(class_idx, positive_only=True, hide_rest=False)

    st.subheader("🟢 LIME Explanation")
    st.image(mark_boundaries(lime_img, mask), use_container_width=True)
144
+
145
# --- Streamlit UI ---
st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
st.title("🧠 Retina Disease Classifier")

model = load_model()

uploaded_file = st.file_uploader("📤 Upload a retinal image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # cv2.imdecode returns None (it does not raise) for corrupt or
    # unsupported data; without this guard cvtColor crashes opaquely.
    if bgr_img is None:
        st.error("❌ Could not decode the uploaded file as an image.")
        st.stop()
    rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

    # Show every stage of the preprocessing pipeline next to the original.
    circ, clahe, sharp, final = preprocess_image(rgb_img)

    st.subheader("🧪 Preprocessing Pipeline (Left ➝ Right)")
    steps = [
        ("📷 Original", rgb_img),
        ("🔵 Circular Crop", circ),
        ("⚪ CLAHE", clahe),
        ("🟣 Sharpened", sharp),
        # `final` is float-normalized, so rescale for display.
        ("📏 Resized", (final * 255).astype(np.uint8)),
    ]
    cols = st.columns(len(steps))
    for col, (label, img) in zip(cols, steps):
        col.image(img, caption=label, use_container_width=True)

    # Classify the fully preprocessed image.
    input_tensor = np.expand_dims(final, axis=0)
    preds = model.predict(input_tensor)
    # TFSMLayer-wrapped models may return a dict of named outputs.
    if isinstance(preds, dict):
        preds = list(preds.values())[0]

    pred_idx = np.argmax(preds)
    pred_label = CLASS_NAMES[pred_idx]
    confidence = np.max(preds) * 100

    st.success(f"✅ Prediction: **{pred_label}**")
    st.info(f"🔍 Confidence: {confidence:.2f}%")

    # Explainability views for the predicted class.
    show_gradcam(model, final, pred_idx)
    show_lime(model, final, pred_idx)