Sefat33 commited on
Commit
6dce6d5
·
verified ·
1 Parent(s): 0f46f4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -67
app.py CHANGED
@@ -3,14 +3,11 @@ import numpy as np
3
  import cv2
4
  import tensorflow as tf
5
  import streamlit as st
6
- from keras.layers import BatchNormalization, DepthwiseConv2D, Input
7
- from keras.models import Model
8
- from keras.saving import register_keras_serializable
9
- from keras.layers import TFSMLayer
10
  import matplotlib.pyplot as plt
11
  import matplotlib.cm as cm
12
  from lime import lime_image
13
  from skimage.segmentation import mark_boundaries
 
14
 
15
  # --- Fix deserialization issues ---
16
  original_bn_from_config = BatchNormalization.from_config
@@ -27,75 +24,49 @@ def patched_dwconv_from_config(cls, config, *args, **kwargs):
27
  return original_dwconv_from_config(config, *args, **kwargs)
28
  DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
29
 
30
- @register_keras_serializable(package='Custom', name='Functional')
31
- class Functional(tf.keras.models.Model): pass
32
-
33
- @register_keras_serializable(package='Custom', name='TFOpLambda')
34
- class CustomTFOpLambda(tf.keras.layers.Layer):
35
- def __init__(self, name=None, trainable=False, dtype=None, function=None, **kwargs):
36
- super().__init__(name=name, trainable=trainable, dtype=dtype, **kwargs)
37
- self.function = function
38
- def call(self, inputs): return inputs
39
- def get_config(self):
40
- config = super().get_config()
41
- config.update({"function": self.function})
42
- return config
43
-
44
  # --- Constants ---
45
  IMG_SIZE = (224, 224)
46
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
47
 
48
- # --- Load model with TFSMLayer ---
49
  @st.cache_resource
50
  def load_model():
51
- model_path = "Model" # Your SavedModel directory path
52
  if not os.path.exists(model_path):
53
- st.error(f"❌ Model directory '{model_path}' not found!")
54
  st.stop()
55
  try:
56
- tfsm_layer = TFSMLayer(model_path, call_endpoint="serving_default")
57
- inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
58
- outputs = tfsm_layer(inputs)
59
- model = Model(inputs=inputs, outputs=outputs)
60
  return model
61
  except Exception as e:
62
  st.error(f"❌ Error loading model: {str(e)}")
63
  st.stop()
64
 
65
- # --- Preprocessing functions ---
66
- def crop_circle(img):
 
67
  h, w = img.shape[:2]
68
- center = (w // 2, h // 2)
69
  radius = min(center[0], center[1])
70
  Y, X = np.ogrid[:h, :w]
71
  dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
72
  mask = dist <= radius
73
- return cv2.bitwise_and(img, img, mask=mask.astype(np.uint8))
74
-
75
- def apply_clahe(img):
76
- lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
77
  l, a, b = cv2.split(lab)
78
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
79
  cl = clahe.apply(l)
80
  merged = cv2.merge((cl,a,b))
81
- return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
82
-
83
- def sharpen_image(img, sigma=10):
84
- blur = cv2.GaussianBlur(img, (0,0), sigma)
85
- return cv2.addWeighted(img, 4, blur, -4, 128)
86
-
87
- def resize_normalize(img):
88
- img = cv2.resize(img, IMG_SIZE)
89
- return img / 255.0
90
-
91
- def preprocess_image(img):
92
- circ = crop_circle(img)
93
- clahe = apply_clahe(circ)
94
- sharp = sharpen_image(clahe)
95
- resized = resize_normalize(sharp)
96
  return resized
97
 
98
- # --- Find last Conv layer for Grad-CAM ---
99
  def find_last_conv_layer(model):
100
  for layer in reversed(model.layers):
101
  if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name:
@@ -109,7 +80,7 @@ def generate_gradcam(model, img_array, class_index, layer_name):
109
  conv_outputs, predictions = grad_model(img_array)
110
  loss = predictions[:, class_index]
111
  grads = tape.gradient(loss, conv_outputs)
112
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
113
  heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
114
  heatmap = tf.squeeze(heatmap)
115
  heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-10)
@@ -124,7 +95,7 @@ def predict_fn(images):
124
  preds = list(preds.values())[0]
125
  return preds
126
 
127
- # --- Explanation text ---
128
  explanation_text = {
129
  'Normal': "Model predicted Normal based on healthy optic disc and macula.",
130
  'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
@@ -136,22 +107,23 @@ explanation_text = {
136
  'Others': "Non-specific features detected, marked as Others."
137
  }
138
 
139
- # --- Visualization function ---
140
  def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_name):
141
  input_array = np.expand_dims(img, axis=0)
142
 
143
  # Grad-CAM heatmap
144
- try:
145
- heatmap = generate_gradcam(model, input_array, pred_idx, layer_name)
146
- heatmap = cv2.resize(heatmap, IMG_SIZE)
147
- heatmap = np.uint8(255 * heatmap)
148
- heatmap = cv2.GaussianBlur(heatmap, (7, 7), 0)
149
- heatmap_rgb = cm.jet(heatmap / 255.0)[..., :3]
150
- heatmap_rgb = np.uint8(heatmap_rgb * 255)
151
- overlayed = cv2.addWeighted(np.uint8(img * 255), 0.5, heatmap_rgb, 0.5, 0)
152
- except Exception as e:
153
- overlayed = None
154
- st.warning(f"⚠️ Grad-CAM generation failed: {e}")
 
155
 
156
  # LIME explanation
157
  explanation = explainer.explain_instance(
@@ -160,7 +132,7 @@ def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_
160
  )
161
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
162
 
163
- # Plot side-by-side
164
  cols = 3 if overlayed is not None else 2
165
  fig, axs = plt.subplots(1, cols, figsize=(15, 5))
166
  axs[0].imshow(img)
@@ -183,16 +155,15 @@ def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_
183
  st.pyplot(fig)
184
  plt.close()
185
 
186
- # --- Streamlit UI ---
187
  st.set_page_config(page_title="🧠 Retina Disease Classifier with Grad-CAM & LIME", layout="centered")
188
  st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")
189
 
190
  model = load_model()
191
 
192
- # Try find last conv layer, disable Grad-CAM if not found
193
  try:
194
  last_conv_layer_name = find_last_conv_layer(model)
195
- except ValueError:
196
  last_conv_layer_name = None
197
  st.warning("⚠️ No Conv2D layer found; Grad-CAM will be disabled.")
198
 
@@ -214,4 +185,4 @@ if uploaded_file:
214
 
215
  st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")
216
 
217
- display_combined_visualization(processed_img, "Unknown (Uploaded)", pred_label, pred_idx, last_conv_layer_name)
 
3
  import cv2
4
  import tensorflow as tf
5
  import streamlit as st
 
 
 
 
6
  import matplotlib.pyplot as plt
7
  import matplotlib.cm as cm
8
  from lime import lime_image
9
  from skimage.segmentation import mark_boundaries
10
+ from keras.layers import BatchNormalization, DepthwiseConv2D
11
 
12
  # --- Fix deserialization issues ---
13
  original_bn_from_config = BatchNormalization.from_config
 
24
  return original_dwconv_from_config(config, *args, **kwargs)
25
  DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # --- Constants ---
28
  IMG_SIZE = (224, 224)
29
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
30
 
31
+ # --- Load model function ---
32
  @st.cache_resource
33
  def load_model():
34
+ model_path = "Model" # adjust path to your model folder or file
35
  if not os.path.exists(model_path):
36
+ st.error(f"❌ Model directory or file '{model_path}' not found!")
37
  st.stop()
38
  try:
39
+ model = tf.keras.models.load_model(model_path)
 
 
 
40
  return model
41
  except Exception as e:
42
  st.error(f"❌ Error loading model: {str(e)}")
43
  st.stop()
44
 
45
+ # --- Preprocessing ---
46
+ def preprocess_image(img):
47
+ # Crop circular mask
48
  h, w = img.shape[:2]
49
+ center = (w//2, h//2)
50
  radius = min(center[0], center[1])
51
  Y, X = np.ogrid[:h, :w]
52
  dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
53
  mask = dist <= radius
54
+ circ = cv2.bitwise_and(img, img, mask=mask.astype(np.uint8))
55
+ # CLAHE
56
+ lab = cv2.cvtColor(circ, cv2.COLOR_RGB2LAB)
 
57
  l, a, b = cv2.split(lab)
58
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
59
  cl = clahe.apply(l)
60
  merged = cv2.merge((cl,a,b))
61
+ clahe_img = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
62
+ # Sharpen
63
+ blur = cv2.GaussianBlur(clahe_img, (0,0), 10)
64
+ sharp = cv2.addWeighted(clahe_img, 4, blur, -4, 128)
65
+ # Resize + normalize
66
+ resized = cv2.resize(sharp, IMG_SIZE) / 255.0
 
 
 
 
 
 
 
 
 
67
  return resized
68
 
69
+ # --- Find last Conv2D or MHSA output layer ---
70
  def find_last_conv_layer(model):
71
  for layer in reversed(model.layers):
72
  if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name:
 
80
  conv_outputs, predictions = grad_model(img_array)
81
  loss = predictions[:, class_index]
82
  grads = tape.gradient(loss, conv_outputs)
83
+ pooled_grads = tf.reduce_mean(grads, axis=(0,1,2))
84
  heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
85
  heatmap = tf.squeeze(heatmap)
86
  heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-10)
 
95
  preds = list(preds.values())[0]
96
  return preds
97
 
98
+ # --- Explanation texts ---
99
  explanation_text = {
100
  'Normal': "Model predicted Normal based on healthy optic disc and macula.",
101
  'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
 
107
  'Others': "Non-specific features detected, marked as Others."
108
  }
109
 
110
+ # --- Visualization ---
111
  def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_name):
112
  input_array = np.expand_dims(img, axis=0)
113
 
114
  # Grad-CAM heatmap
115
+ overlayed = None
116
+ if layer_name is not None:
117
+ try:
118
+ heatmap = generate_gradcam(model, input_array, pred_idx, layer_name)
119
+ heatmap = cv2.resize(heatmap, IMG_SIZE)
120
+ heatmap = np.uint8(255 * heatmap)
121
+ heatmap = cv2.GaussianBlur(heatmap, (7, 7), 0)
122
+ heatmap_rgb = cm.jet(heatmap / 255.0)[..., :3]
123
+ heatmap_rgb = np.uint8(heatmap_rgb * 255)
124
+ overlayed = cv2.addWeighted(np.uint8(img * 255), 0.5, heatmap_rgb, 0.5, 0)
125
+ except Exception as e:
126
+ st.warning(f"⚠️ Grad-CAM generation failed: {e}")
127
 
128
  # LIME explanation
129
  explanation = explainer.explain_instance(
 
132
  )
133
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
134
 
135
+ # Plot
136
  cols = 3 if overlayed is not None else 2
137
  fig, axs = plt.subplots(1, cols, figsize=(15, 5))
138
  axs[0].imshow(img)
 
155
  st.pyplot(fig)
156
  plt.close()
157
 
158
+ # --- Streamlit App ---
159
  st.set_page_config(page_title="🧠 Retina Disease Classifier with Grad-CAM & LIME", layout="centered")
160
  st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")
161
 
162
  model = load_model()
163
 
 
164
  try:
165
  last_conv_layer_name = find_last_conv_layer(model)
166
+ except Exception:
167
  last_conv_layer_name = None
168
  st.warning("⚠️ No Conv2D layer found; Grad-CAM will be disabled.")
169
 
 
185
 
186
  st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")
187
 
188
+ display_combined_visualization(processed_img, "Uploaded Image", pred_label, pred_idx, last_conv_layer_name)