Update app.py
Browse files
app.py
CHANGED
|
@@ -3,14 +3,11 @@ import numpy as np
|
|
| 3 |
import cv2
|
| 4 |
import tensorflow as tf
|
| 5 |
import streamlit as st
|
| 6 |
-
from keras.layers import BatchNormalization, DepthwiseConv2D, Input
|
| 7 |
-
from keras.models import Model
|
| 8 |
-
from keras.saving import register_keras_serializable
|
| 9 |
-
from keras.layers import TFSMLayer
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import matplotlib.cm as cm
|
| 12 |
from lime import lime_image
|
| 13 |
from skimage.segmentation import mark_boundaries
|
|
|
|
| 14 |
|
| 15 |
# --- Fix deserialization issues ---
|
| 16 |
original_bn_from_config = BatchNormalization.from_config
|
|
@@ -27,75 +24,49 @@ def patched_dwconv_from_config(cls, config, *args, **kwargs):
|
|
| 27 |
return original_dwconv_from_config(config, *args, **kwargs)
|
| 28 |
DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
|
| 29 |
|
| 30 |
-
@register_keras_serializable(package='Custom', name='Functional')
|
| 31 |
-
class Functional(tf.keras.models.Model): pass
|
| 32 |
-
|
| 33 |
-
@register_keras_serializable(package='Custom', name='TFOpLambda')
|
| 34 |
-
class CustomTFOpLambda(tf.keras.layers.Layer):
|
| 35 |
-
def __init__(self, name=None, trainable=False, dtype=None, function=None, **kwargs):
|
| 36 |
-
super().__init__(name=name, trainable=trainable, dtype=dtype, **kwargs)
|
| 37 |
-
self.function = function
|
| 38 |
-
def call(self, inputs): return inputs
|
| 39 |
-
def get_config(self):
|
| 40 |
-
config = super().get_config()
|
| 41 |
-
config.update({"function": self.function})
|
| 42 |
-
return config
|
| 43 |
-
|
| 44 |
# --- Constants ---
|
| 45 |
IMG_SIZE = (224, 224)
|
| 46 |
CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
|
| 47 |
|
| 48 |
-
# --- Load model
|
| 49 |
@st.cache_resource
|
| 50 |
def load_model():
|
| 51 |
-
model_path = "Model" #
|
| 52 |
if not os.path.exists(model_path):
|
| 53 |
-
st.error(f"❌ Model directory '{model_path}' not found!")
|
| 54 |
st.stop()
|
| 55 |
try:
|
| 56 |
-
|
| 57 |
-
inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
|
| 58 |
-
outputs = tfsm_layer(inputs)
|
| 59 |
-
model = Model(inputs=inputs, outputs=outputs)
|
| 60 |
return model
|
| 61 |
except Exception as e:
|
| 62 |
st.error(f"❌ Error loading model: {str(e)}")
|
| 63 |
st.stop()
|
| 64 |
|
| 65 |
-
# --- Preprocessing
|
| 66 |
-
def
|
|
|
|
| 67 |
h, w = img.shape[:2]
|
| 68 |
-
center = (w
|
| 69 |
radius = min(center[0], center[1])
|
| 70 |
Y, X = np.ogrid[:h, :w]
|
| 71 |
dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
|
| 72 |
mask = dist <= radius
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
|
| 77 |
l, a, b = cv2.split(lab)
|
| 78 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
|
| 79 |
cl = clahe.apply(l)
|
| 80 |
merged = cv2.merge((cl,a,b))
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def resize_normalize(img):
|
| 88 |
-
img = cv2.resize(img, IMG_SIZE)
|
| 89 |
-
return img / 255.0
|
| 90 |
-
|
| 91 |
-
def preprocess_image(img):
|
| 92 |
-
circ = crop_circle(img)
|
| 93 |
-
clahe = apply_clahe(circ)
|
| 94 |
-
sharp = sharpen_image(clahe)
|
| 95 |
-
resized = resize_normalize(sharp)
|
| 96 |
return resized
|
| 97 |
|
| 98 |
-
# --- Find last
|
| 99 |
def find_last_conv_layer(model):
|
| 100 |
for layer in reversed(model.layers):
|
| 101 |
if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name:
|
|
@@ -109,7 +80,7 @@ def generate_gradcam(model, img_array, class_index, layer_name):
|
|
| 109 |
conv_outputs, predictions = grad_model(img_array)
|
| 110 |
loss = predictions[:, class_index]
|
| 111 |
grads = tape.gradient(loss, conv_outputs)
|
| 112 |
-
pooled_grads = tf.reduce_mean(grads, axis=(0,
|
| 113 |
heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
|
| 114 |
heatmap = tf.squeeze(heatmap)
|
| 115 |
heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-10)
|
|
@@ -124,7 +95,7 @@ def predict_fn(images):
|
|
| 124 |
preds = list(preds.values())[0]
|
| 125 |
return preds
|
| 126 |
|
| 127 |
-
# --- Explanation
|
| 128 |
explanation_text = {
|
| 129 |
'Normal': "Model predicted Normal based on healthy optic disc and macula.",
|
| 130 |
'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
|
|
@@ -136,22 +107,23 @@ explanation_text = {
|
|
| 136 |
'Others': "Non-specific features detected, marked as Others."
|
| 137 |
}
|
| 138 |
|
| 139 |
-
# --- Visualization
|
| 140 |
def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_name):
|
| 141 |
input_array = np.expand_dims(img, axis=0)
|
| 142 |
|
| 143 |
# Grad-CAM heatmap
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
| 155 |
|
| 156 |
# LIME explanation
|
| 157 |
explanation = explainer.explain_instance(
|
|
@@ -160,7 +132,7 @@ def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_
|
|
| 160 |
)
|
| 161 |
temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
|
| 162 |
|
| 163 |
-
# Plot
|
| 164 |
cols = 3 if overlayed is not None else 2
|
| 165 |
fig, axs = plt.subplots(1, cols, figsize=(15, 5))
|
| 166 |
axs[0].imshow(img)
|
|
@@ -183,16 +155,15 @@ def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_
|
|
| 183 |
st.pyplot(fig)
|
| 184 |
plt.close()
|
| 185 |
|
| 186 |
-
# --- Streamlit
|
| 187 |
st.set_page_config(page_title="🧠 Retina Disease Classifier with Grad-CAM & LIME", layout="centered")
|
| 188 |
st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")
|
| 189 |
|
| 190 |
model = load_model()
|
| 191 |
|
| 192 |
-
# Try find last conv layer, disable Grad-CAM if not found
|
| 193 |
try:
|
| 194 |
last_conv_layer_name = find_last_conv_layer(model)
|
| 195 |
-
except
|
| 196 |
last_conv_layer_name = None
|
| 197 |
st.warning("⚠️ No Conv2D layer found; Grad-CAM will be disabled.")
|
| 198 |
|
|
@@ -214,4 +185,4 @@ if uploaded_file:
|
|
| 214 |
|
| 215 |
st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")
|
| 216 |
|
| 217 |
-
display_combined_visualization(processed_img, "
|
|
|
|
| 3 |
import cv2
|
| 4 |
import tensorflow as tf
|
| 5 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import matplotlib.pyplot as plt
|
| 7 |
import matplotlib.cm as cm
|
| 8 |
from lime import lime_image
|
| 9 |
from skimage.segmentation import mark_boundaries
|
| 10 |
+
from keras.layers import BatchNormalization, DepthwiseConv2D
|
| 11 |
|
| 12 |
# --- Fix deserialization issues ---
|
| 13 |
original_bn_from_config = BatchNormalization.from_config
|
|
|
|
| 24 |
return original_dwconv_from_config(config, *args, **kwargs)
|
| 25 |
DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# --- Constants ---
# Input resolution expected by the classifier (width, height for cv2.resize).
IMG_SIZE = (224, 224)
# Output labels, ordered to match the model's prediction indices.
CLASS_NAMES = [
    'Normal',
    'Diabetes',
    'Glaucoma',
    'Cataract',
    'AMD',
    'Hypertension',
    'Myopia',
    'Others',
]
|
| 30 |
|
| 31 |
+
# --- Load model function ---
@st.cache_resource
def load_model():
    """Load the saved Keras model from disk, cached across Streamlit reruns.

    Halts the app with an error message if the model path is missing or
    deserialization fails.
    """
    model_path = "Model"  # adjust path to your model folder or file
    if not os.path.exists(model_path):
        st.error(f"❌ Model directory or file '{model_path}' not found!")
        st.stop()
    try:
        return tf.keras.models.load_model(model_path)
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.stop()
|
| 44 |
|
| 45 |
+
# --- Preprocessing ---
def _crop_circle(img):
    """Zero out pixels outside the largest centered circle (isolates the fundus disc)."""
    h, w = img.shape[:2]
    center = (w//2, h//2)
    radius = min(center[0], center[1])
    Y, X = np.ogrid[:h, :w]
    dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
    mask = dist <= radius
    return cv2.bitwise_and(img, img, mask=mask.astype(np.uint8))

def _apply_clahe(img):
    """Boost local contrast via CLAHE on the luminance (L) channel in LAB space."""
    lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    cl = clahe.apply(l)
    merged = cv2.merge((cl,a,b))
    return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)

def _sharpen(img):
    """Unsharp-mask style sharpening: subtract a heavy Gaussian blur, re-center at 128."""
    blur = cv2.GaussianBlur(img, (0,0), 10)
    return cv2.addWeighted(img, 4, blur, -4, 128)

def preprocess_image(img):
    """Full preprocessing pipeline for a fundus image.

    Steps: circular crop -> CLAHE contrast boost -> sharpen -> resize to
    IMG_SIZE and scale to [0, 1].

    Args:
        img: RGB image as a numpy array (H, W, 3); assumed uint8 — TODO confirm
            against the upload decoding path.

    Returns:
        Float array of shape (IMG_SIZE[1], IMG_SIZE[0], 3) with values in [0, 1].
    """
    circ = _crop_circle(img)
    clahe_img = _apply_clahe(circ)
    sharp = _sharpen(clahe_img)
    # Resize + normalize
    return cv2.resize(sharp, IMG_SIZE) / 255.0
|
| 68 |
|
| 69 |
+
# --- Find last Conv2D or MHSA output layer ---
|
| 70 |
def find_last_conv_layer(model):
|
| 71 |
for layer in reversed(model.layers):
|
| 72 |
if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name:
|
|
|
|
| 80 |
conv_outputs, predictions = grad_model(img_array)
|
| 81 |
loss = predictions[:, class_index]
|
| 82 |
grads = tape.gradient(loss, conv_outputs)
|
| 83 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0,1,2))
|
| 84 |
heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
|
| 85 |
heatmap = tf.squeeze(heatmap)
|
| 86 |
heatmap = tf.maximum(heatmap, 0) / (tf.math.reduce_max(heatmap) + 1e-10)
|
|
|
|
| 95 |
preds = list(preds.values())[0]
|
| 96 |
return preds
|
| 97 |
|
| 98 |
+
# --- Explanation texts ---
|
| 99 |
explanation_text = {
|
| 100 |
'Normal': "Model predicted Normal based on healthy optic disc and macula.",
|
| 101 |
'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
|
|
|
|
| 107 |
'Others': "Non-specific features detected, marked as Others."
|
| 108 |
}
|
| 109 |
|
| 110 |
+
# --- Visualization ---
|
| 111 |
def display_combined_visualization(img, true_label, pred_label, pred_idx, layer_name):
|
| 112 |
input_array = np.expand_dims(img, axis=0)
|
| 113 |
|
| 114 |
# Grad-CAM heatmap
|
| 115 |
+
overlayed = None
|
| 116 |
+
if layer_name is not None:
|
| 117 |
+
try:
|
| 118 |
+
heatmap = generate_gradcam(model, input_array, pred_idx, layer_name)
|
| 119 |
+
heatmap = cv2.resize(heatmap, IMG_SIZE)
|
| 120 |
+
heatmap = np.uint8(255 * heatmap)
|
| 121 |
+
heatmap = cv2.GaussianBlur(heatmap, (7, 7), 0)
|
| 122 |
+
heatmap_rgb = cm.jet(heatmap / 255.0)[..., :3]
|
| 123 |
+
heatmap_rgb = np.uint8(heatmap_rgb * 255)
|
| 124 |
+
overlayed = cv2.addWeighted(np.uint8(img * 255), 0.5, heatmap_rgb, 0.5, 0)
|
| 125 |
+
except Exception as e:
|
| 126 |
+
st.warning(f"⚠️ Grad-CAM generation failed: {e}")
|
| 127 |
|
| 128 |
# LIME explanation
|
| 129 |
explanation = explainer.explain_instance(
|
|
|
|
| 132 |
)
|
| 133 |
temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
|
| 134 |
|
| 135 |
+
# Plot
|
| 136 |
cols = 3 if overlayed is not None else 2
|
| 137 |
fig, axs = plt.subplots(1, cols, figsize=(15, 5))
|
| 138 |
axs[0].imshow(img)
|
|
|
|
| 155 |
st.pyplot(fig)
|
| 156 |
plt.close()
|
| 157 |
|
| 158 |
+
# --- Streamlit App ---
st.set_page_config(
    page_title="🧠 Retina Disease Classifier with Grad-CAM & LIME",
    layout="centered",
)
st.title("🧠 Retina Disease Classifier with Grad-CAM & LIME")

model = load_model()

# Probe for a Grad-CAM target layer; degrade gracefully when none is found.
try:
    last_conv_layer_name = find_last_conv_layer(model)
except Exception:
    last_conv_layer_name = None
    st.warning("⚠️ No Conv2D layer found; Grad-CAM will be disabled.")
|
| 169 |
|
|
|
|
| 185 |
|
| 186 |
st.success(f"Prediction: **{pred_label}** with confidence {confidence:.2f}%")
|
| 187 |
|
| 188 |
+
display_combined_visualization(processed_img, "Uploaded Image", pred_label, pred_idx, last_conv_layer_name)
|