Sefat33 commited on
Commit
7b260ca
Β·
verified Β·
1 Parent(s): 6d175e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -40
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- # Set Protocol Buffers implementation to Python (must be before TensorFlow imports)
3
  os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
4
 
5
  import numpy as np
@@ -10,8 +9,11 @@ from keras_cv_attention_models.coatnet import CoAtNet0
10
  from keras.layers import BatchNormalization, DepthwiseConv2D, Input, TFSMLayer
11
  from keras.models import Model
12
  from keras.saving import register_keras_serializable
 
 
 
13
 
14
- # --- Fix BatchNormalization axis bug during deserialization ---
15
  original_bn_from_config = BatchNormalization.from_config
16
  def patched_bn_from_config(cls, config, *args, **kwargs):
17
  if "axis" in config and isinstance(config["axis"], (list, tuple)):
@@ -19,7 +21,6 @@ def patched_bn_from_config(cls, config, *args, **kwargs):
19
  return original_bn_from_config(config, *args, **kwargs)
20
  BatchNormalization.from_config = classmethod(patched_bn_from_config)
21
 
22
- # --- Fix DepthwiseConv2D deserialization by removing unsupported 'groups' kwarg ---
23
  original_dwconv_from_config = DepthwiseConv2D.from_config
24
  def patched_dwconv_from_config(cls, config, *args, **kwargs):
25
  if "groups" in config:
@@ -27,12 +28,10 @@ def patched_dwconv_from_config(cls, config, *args, **kwargs):
27
  return original_dwconv_from_config(config, *args, **kwargs)
28
  DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
29
 
30
- # --- Register Functional Model for deserialization ---
31
  @register_keras_serializable(package='Custom', name='Functional')
32
  class Functional(tf.keras.models.Model):
33
  pass
34
 
35
- # --- Register a minimal stub for TFOpLambda layer ---
36
  @register_keras_serializable(package='Custom', name='TFOpLambda')
37
  class CustomTFOpLambda(tf.keras.layers.Layer):
38
  def __init__(self, name=None, trainable=False, dtype=None, function=None, **kwargs):
@@ -50,20 +49,15 @@ CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertensio
50
 
51
  @st.cache_resource
52
  def load_model():
53
- model_path = "Model" # SavedModel directory
54
-
55
  if not os.path.exists(model_path):
56
- st.error(f"❌ Model directory '{model_path}' not found! Please ensure the folder is in the app directory.")
57
  st.stop()
58
-
59
- st.info("πŸ“₯ Loading model from SavedModel directory using TFSMLayer...")
60
-
61
  try:
62
  tfsm_layer = TFSMLayer(model_path, call_endpoint="serving_default")
63
  inputs = Input(shape=(224, 224, 3))
64
  outputs = tfsm_layer(inputs)
65
  model = Model(inputs=inputs, outputs=outputs)
66
- st.success("βœ… Model loaded successfully!")
67
  return model
68
  except Exception as e:
69
  st.error(f"❌ Error loading model: {str(e)}")
@@ -76,11 +70,7 @@ def crop_circle(img):
76
  Y, X = np.ogrid[:h, :w]
77
  dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
78
  mask = dist <= radius
79
- if img.ndim == 3:
80
- mask = np.stack([mask]*3, axis=-1)
81
- img = img.copy()
82
- img[~mask] = 0
83
- return img
84
 
85
  def apply_clahe(img):
86
  lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
@@ -92,51 +82,79 @@ def apply_clahe(img):
92
 
93
  def sharpen_image(img, sigma=10):
94
  blur = cv2.GaussianBlur(img, (0,0), sigma)
95
- sharpened = cv2.addWeighted(img, 4, blur, -4, 128)
96
- return sharpened
97
 
98
  def resize_normalize(img):
99
  img = cv2.resize(img, IMG_SIZE)
100
- img = img / 255.0
101
- return img
102
 
103
  def preprocess_image(img):
104
- img = crop_circle(img)
105
- img = apply_clahe(img)
106
- img = sharpen_image(img)
107
- img = resize_normalize(img)
108
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # --- Streamlit UI ---
111
  st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
112
  st.title("🧠 Retina Disease Classifier")
113
- st.markdown("Upload a retinal image and get the predicted disease class using a CoAtNet model.")
114
 
115
  model = load_model()
116
-
117
  uploaded_file = st.file_uploader("πŸ“€ Upload a retinal image", type=["jpg", "jpeg", "png"])
118
 
119
- if uploaded_file is not None:
120
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
121
  bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
122
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
123
- st.image(rgb_img, caption="Original Image", use_column_width=True)
124
 
125
- preprocessed = preprocess_image(rgb_img)
126
- input_tensor = np.expand_dims(preprocessed, axis=0)
 
 
 
 
127
 
 
128
  preds = model.predict(input_tensor)
129
-
130
- # Handle dict output from TFSMLayer
131
- if isinstance(preds, dict):
132
- preds = list(preds.values())[0]
133
-
134
  pred_idx = np.argmax(preds)
135
  pred_label = CLASS_NAMES[pred_idx]
136
  confidence = np.max(preds) * 100
137
 
138
- st.success(f"βœ… **Prediction:** {pred_label}")
139
  st.info(f"πŸ” Confidence: {confidence:.2f}%")
140
 
141
- st.subheader("πŸ§ͺ Preprocessed Input to Model")
142
- st.image((preprocessed * 255).astype(np.uint8), caption="Preprocessed Image", use_column_width=True)
 
1
  import os
 
2
  os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
3
 
4
  import numpy as np
 
9
  from keras.layers import BatchNormalization, DepthwiseConv2D, Input, TFSMLayer
10
  from keras.models import Model
11
  from keras.saving import register_keras_serializable
12
+ import matplotlib.pyplot as plt
13
+ from lime import lime_image
14
+ from skimage.segmentation import mark_boundaries
15
 
16
+ # --- Fix deserialization issues ---
17
  original_bn_from_config = BatchNormalization.from_config
18
  def patched_bn_from_config(cls, config, *args, **kwargs):
19
  if "axis" in config and isinstance(config["axis"], (list, tuple)):
 
21
  return original_bn_from_config(config, *args, **kwargs)
22
  BatchNormalization.from_config = classmethod(patched_bn_from_config)
23
 
 
24
  original_dwconv_from_config = DepthwiseConv2D.from_config
25
  def patched_dwconv_from_config(cls, config, *args, **kwargs):
26
  if "groups" in config:
 
28
  return original_dwconv_from_config(config, *args, **kwargs)
29
  DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
30
 
 
31
  @register_keras_serializable(package='Custom', name='Functional')
32
  class Functional(tf.keras.models.Model):
33
  pass
34
 
 
35
  @register_keras_serializable(package='Custom', name='TFOpLambda')
36
  class CustomTFOpLambda(tf.keras.layers.Layer):
37
  def __init__(self, name=None, trainable=False, dtype=None, function=None, **kwargs):
 
49
 
50
  @st.cache_resource
51
  def load_model():
52
+ model_path = "Model"
 
53
  if not os.path.exists(model_path):
54
+ st.error(f"❌ Model directory '{model_path}' not found!")
55
  st.stop()
 
 
 
56
  try:
57
  tfsm_layer = TFSMLayer(model_path, call_endpoint="serving_default")
58
  inputs = Input(shape=(224, 224, 3))
59
  outputs = tfsm_layer(inputs)
60
  model = Model(inputs=inputs, outputs=outputs)
 
61
  return model
62
  except Exception as e:
63
  st.error(f"❌ Error loading model: {str(e)}")
 
70
  Y, X = np.ogrid[:h, :w]
71
  dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
72
  mask = dist <= radius
73
+ return cv2.bitwise_and(img, img, mask=mask.astype(np.uint8))
 
 
 
 
74
 
75
  def apply_clahe(img):
76
  lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
 
82
 
83
  def sharpen_image(img, sigma=10):
84
  blur = cv2.GaussianBlur(img, (0,0), sigma)
85
+ return cv2.addWeighted(img, 4, blur, -4, 128)
 
86
 
87
  def resize_normalize(img):
88
  img = cv2.resize(img, IMG_SIZE)
89
+ return img / 255.0
 
90
 
91
  def preprocess_image(img):
92
+ circ = crop_circle(img)
93
+ clahe = apply_clahe(circ)
94
+ sharp = sharpen_image(clahe)
95
+ resized = resize_normalize(sharp)
96
+ return circ, clahe, sharp, resized
97
+
98
+ def show_step(title, img):
99
+ st.subheader(title)
100
+ st.image(img, use_column_width=True)
101
+
102
+ def show_gradcam(model, img, class_idx):
103
+ grad_model = Model(model.inputs, [model.layers[-1].output, model.layers[-2].output])
104
+ img_tensor = tf.convert_to_tensor(img[np.newaxis, ...])
105
+ with tf.GradientTape() as tape:
106
+ conv_outputs, predictions = grad_model(img_tensor)
107
+ loss = predictions[:, class_idx]
108
+ grads = tape.gradient(loss, conv_outputs)[0]
109
+ cam = tf.reduce_mean(grads, axis=-1)
110
+ cam = tf.nn.relu(cam)
111
+ cam = cam.numpy()
112
+ cam = cv2.resize(cam, IMG_SIZE)
113
+ cam = (cam - cam.min()) / (cam.max() - cam.min())
114
+ heatmap = np.uint8(255 * cam)
115
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
116
+ overlay = cv2.addWeighted(np.uint8(img * 255), 0.6, heatmap, 0.4, 0)
117
+ st.subheader("πŸ”₯ Grad-CAM")
118
+ st.image(overlay, use_column_width=True)
119
+
120
+ def show_lime(model, img, class_idx):
121
+ explainer = lime_image.LimeImageExplainer()
122
+ def predict_fn(images):
123
+ images = np.array(images)
124
+ return model.predict(images)
125
+ explanation = explainer.explain_instance(np.uint8(img*255), predict_fn, top_labels=1, hide_color=0, num_samples=1000)
126
+ lime_img, mask = explanation.get_image_and_mask(class_idx, positive_only=True, hide_rest=False)
127
+ st.subheader("🟒 LIME Explanation")
128
+ st.image(mark_boundaries(lime_img, mask), use_column_width=True)
129
 
130
  # --- Streamlit UI ---
131
  st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
132
  st.title("🧠 Retina Disease Classifier")
 
133
 
134
  model = load_model()
 
135
  uploaded_file = st.file_uploader("πŸ“€ Upload a retinal image", type=["jpg", "jpeg", "png"])
136
 
137
+ if uploaded_file:
138
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
139
  bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
140
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
 
141
 
142
+ show_step("πŸ“· Original Image", rgb_img)
143
+ circ, clahe, sharp, final = preprocess_image(rgb_img)
144
+ show_step("πŸ”΅ Circular Cropped", circ)
145
+ show_step("βšͺ CLAHE Applied", clahe)
146
+ show_step("🟣 Sharpened", sharp)
147
+ show_step("πŸ“ Final Resized", (final * 255).astype(np.uint8))
148
 
149
+ input_tensor = np.expand_dims(final, axis=0)
150
  preds = model.predict(input_tensor)
151
+ if isinstance(preds, dict): preds = list(preds.values())[0]
 
 
 
 
152
  pred_idx = np.argmax(preds)
153
  pred_label = CLASS_NAMES[pred_idx]
154
  confidence = np.max(preds) * 100
155
 
156
+ st.success(f"βœ… Prediction: **{pred_label}**")
157
  st.info(f"πŸ” Confidence: {confidence:.2f}%")
158
 
159
+ show_gradcam(model, final, pred_idx)
160
+ show_lime(model, final, pred_idx)