Sefat33 committed on
Commit
ee4901c
·
verified ·
1 Parent(s): 99ce1ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -92
app.py CHANGED
@@ -9,137 +9,134 @@ from skimage.segmentation import mark_boundaries
9
  from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
10
 
11
  # --- Fix deserialization issues ---
12
- BatchNormalization.from_config = classmethod(lambda cls, config, *args, **kwargs: BatchNormalization.from_config.__func__(cls, {**config, "axis": config["axis"][0] if isinstance(config["axis"], (list, tuple)) else config["axis"]}))
13
- DepthwiseConv2D.from_config = classmethod(lambda cls, config, *args, **kwargs: DepthwiseConv2D.from_config.__func__(cls, {k: v for k, v in config.items() if k != "groups"}))
 
 
 
14
 
15
  # --- Constants ---
16
  IMG_SIZE = (224, 224)
17
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
 
18
 
 
19
  @st.cache_resource
20
  def load_model():
21
  model_path = "Model"
22
- model = tf.keras.Sequential([
23
- TFSMLayer(model_path, call_endpoint="serving_default")
24
- ])
25
- return model
26
-
27
- def preprocess_image(img):
 
 
 
 
 
 
28
  h, w = img.shape[:2]
29
- center = (w // 2, h // 2)
30
- radius = min(center[0], center[1])
31
  Y, X = np.ogrid[:h, :w]
32
  dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
33
  mask = dist <= radius
34
  circ = cv2.bitwise_and(img, img, mask=mask.astype(np.uint8))
35
 
36
  lab = cv2.cvtColor(circ, cv2.COLOR_RGB2LAB)
37
- l, a, b = cv2.split(lab)
38
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
39
- cl = clahe.apply(l)
40
- merged = cv2.merge((cl, a, b))
41
  clahe_img = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
42
 
43
- blur = cv2.GaussianBlur(clahe_img, (0, 0), 10)
44
- sharp = cv2.addWeighted(clahe_img, 4, blur, -4, 128)
45
-
46
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
47
- return resized, [img, circ, clahe_img, resized]
48
-
49
- explanation_text = {
50
- 'Normal': "Model predicted Normal based on healthy optic disc and macula.",
51
- 'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
52
- 'Glaucoma': "Detected increased cupping in the optic disc indicating Glaucoma.",
53
- 'Cataract': "Image blur indicated potential Cataract.",
54
- 'AMD': "Degeneration signs in macula indicate AMD.",
55
- 'Hypertension': "Blood vessel narrowing/hemorrhages indicate Hypertension.",
56
- 'Myopia': "Tilted disc and fundus shape suggest Myopia.",
57
- 'Others': "Non-specific features detected, marked as Others."
58
- }
59
-
60
- explainer = lime_image.LimeImageExplainer()
61
-
62
- def predict_fn(images):
63
- preds = model.predict(np.array(images))
64
- return list(preds.values())[0] if isinstance(preds, dict) else preds
65
 
66
- def show_preprocessing_steps(stages):
67
- titles = ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"]
68
  fig, axs = plt.subplots(1, 4, figsize=(20, 5))
69
- for i, img in enumerate(stages):
70
- axs[i].imshow(img)
71
- axs[i].set_title(titles[i])
72
- axs[i].axis('off')
 
73
  st.pyplot(fig)
74
- plt.close()
75
 
76
- def show_lime_explanation(img, pred_idx, pred_label):
77
- with st.spinner("🟡 LIME Explanation is Loading..."):
78
- explanation = explainer.explain_instance(
 
 
 
 
 
 
 
79
  image=img,
80
- classifier_fn=predict_fn,
81
  top_labels=1,
82
  hide_color=0,
83
  num_samples=1000
84
  )
85
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
86
- fig, ax = plt.subplots(figsize=(6, 5))
 
87
  ax.imshow(mark_boundaries(temp, mask))
88
- ax.axis('off')
89
- ax.set_title(f"LIME: {pred_label}")
90
  st.pyplot(fig)
91
- plt.close()
92
 
93
  # --- Streamlit UI ---
94
  st.set_page_config(page_title="🧠 Retina Classifier - Multi Image LIME", layout="wide")
95
  st.title("🧠 Retina Disease Classifier with LIME Explanation")
96
 
97
  model = load_model()
98
- uploaded_files = st.file_uploader("Upload one or more retinal images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  if uploaded_files:
101
- filenames = [file.name for file in uploaded_files]
102
- selected_file = st.selectbox("Select image to analyze:", filenames)
103
-
104
- # Show individual image analysis
105
- for file in uploaded_files:
106
- if file.name == selected_file:
107
- file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
108
- img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
109
- rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
110
- processed, stages = preprocess_image(rgb)
111
- show_preprocessing_steps(stages)
112
-
113
- input_tensor = np.expand_dims(processed, axis=0)
114
- preds = predict_fn(input_tensor)
115
- pred_idx = np.argmax(preds)
116
- pred_label = CLASS_NAMES[pred_idx]
117
- confidence = np.max(preds) * 100
118
- st.success(f"✅ Prediction: **{pred_label}** ({confidence:.2f}%)")
119
- show_lime_explanation(processed, pred_idx, pred_label)
120
- break
121
-
122
- st.markdown("---")
123
- st.subheader("📊 LIME Explanations for All Uploaded Images")
124
- cols = st.columns(len(uploaded_files))
125
  for i, file in enumerate(uploaded_files):
126
- file.seek(0)
127
- file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
128
- img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
129
- rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
130
- processed, _ = preprocess_image(rgb)
131
- input_tensor = np.expand_dims(processed, axis=0)
132
- preds = predict_fn(input_tensor)
133
  pred_idx = np.argmax(preds)
134
  pred_label = CLASS_NAMES[pred_idx]
135
 
136
- explanation = explainer.explain_instance(
137
- image=processed,
138
- classifier_fn=predict_fn,
139
- top_labels=1,
140
- hide_color=0,
141
- num_samples=1000
142
- )
143
- temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
144
- with cols[i]:
145
- st.image(mark_boundaries(temp, mask), caption=f"{file.name}\n({pred_label})", use_column_width=True)
 
 
9
  from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
10
 
11
  # --- Fix deserialization issues ---
12
# --- Fix deserialization issues ---
# Legacy SavedModel configs can be incompatible with current Keras:
# BatchNormalization may be saved with a list/tuple "axis", and
# DepthwiseConv2D may carry a "groups" key it no longer accepts.
# Wrap each layer's from_config to sanitize the config before delegating.

def _sanitize_bn_config(config):
    """Return a BatchNormalization config whose "axis" is a scalar."""
    axis = config.get("axis")
    if isinstance(axis, (list, tuple)):
        # Also handle tuple axis, not just list (legacy exports use both).
        config = {**config, "axis": axis[0]}
    return config

def _sanitize_dw_config(config):
    """Return a DepthwiseConv2D config without the unsupported "groups" key."""
    return {key: value for key, value in config.items() if key != "groups"}

_original_bn_from_config = BatchNormalization.from_config
BatchNormalization.from_config = classmethod(
    lambda cls, config, *args, **kwargs: _original_bn_from_config(
        _sanitize_bn_config(config), *args, **kwargs
    )
)

_original_dw_from_config = DepthwiseConv2D.from_config
DepthwiseConv2D.from_config = classmethod(
    lambda cls, config, *args, **kwargs: _original_dw_from_config(
        _sanitize_dw_config(config), *args, **kwargs
    )
)
17
 
18
# --- Constants ---
IMG_SIZE = (224, 224)  # input resolution fed to the classifier
# Order must match the model's output index order (pred_idx indexes this list).
CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
# Single shared LIME explainer used by every explanation call in the app.
LIME_EXPLAINER = lime_image.LimeImageExplainer()
22
 
23
# --- Load model from TFSMLayer ---
@st.cache_resource
def load_model():
    """Load the exported SavedModel from the local "Model" folder.

    Wraps the SavedModel's "serving_default" endpoint in a Sequential
    model. Halts the Streamlit app with an error message when the
    folder is missing or the model cannot be loaded.
    """
    model_path = "Model"
    if not os.path.exists(model_path):
        st.error(f"Model folder '{model_path}' not found.")
        st.stop()
    try:
        serving_layer = TFSMLayer(model_path, call_endpoint="serving_default")
        model = tf.keras.Sequential([serving_layer])
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()
    return model
36
+
37
# --- Preprocessing with Visualization ---
def preprocess_with_steps(img):
    """Preprocess an RGB fundus image and render the intermediate stages.

    Pipeline: circular crop around the image centre -> CLAHE on the LAB
    L-channel -> unsharp masking -> resize to IMG_SIZE scaled to [0, 1].
    The four stages are drawn side by side via st.pyplot.

    Parameters:
        img: RGB image array of shape (H, W, 3) — assumed uint8; TODO confirm
             against callers (both call sites decode via cv2.imdecode).

    Returns:
        Float array of shape (IMG_SIZE[1], IMG_SIZE[0], 3) with values in [0, 1].
    """
    h, w = img.shape[:2]
    center, radius = (w // 2, h // 2), min(w, h) // 2

    # Circular mask centred on the image: keep the fundus disc, zero the corners.
    Y, X = np.ogrid[:h, :w]
    dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
    mask = dist <= radius
    circ = cv2.bitwise_and(img, img, mask=mask.astype(np.uint8))

    # Contrast enhancement on the luminance channel only (CLAHE in LAB space).
    lab = cv2.cvtColor(circ, cv2.COLOR_RGB2LAB)
    cl = cv2.createCLAHE(clipLimit=2.0).apply(lab[:, :, 0])
    merged = cv2.merge((cl, lab[:, :, 1], lab[:, :, 2]))
    clahe_img = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)

    # Unsharp masking: subtract a heavy Gaussian blur to emphasize edges/vessels.
    sharp = cv2.addWeighted(clahe_img, 4, cv2.GaussianBlur(clahe_img, (0, 0), 10), -4, 128)

    resized = cv2.resize(sharp, IMG_SIZE) / 255.0

    # Show the four stages side by side.
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    stages = [img, circ, clahe_img, resized]
    titles = ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"]
    for ax, image, title in zip(axs, stages, titles):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    st.pyplot(fig)
    # Close the figure so it is not leaked across Streamlit reruns
    # (the previous revision of this function closed it; the rewrite dropped it).
    plt.close(fig)

    return resized
62
 
63
# --- Prediction Function ---
def predict(images, model):
    """Run the model on a batch of images and return the raw prediction array.

    Parameters:
        images: array-like batch of preprocessed images.
        model: object exposing .predict(batch, verbose=0) (the loaded
               Sequential/TFSMLayer wrapper).

    Returns:
        The model's output array. When the model returns a dict of named
        outputs (as serving endpoints do), the first output is unwrapped.
    """
    batch = np.array(images)
    raw = model.predict(batch, verbose=0)
    if isinstance(raw, dict):
        # Serving-endpoint output: {output_name: array} — take the first value.
        return list(raw.values())[0]
    return raw
68
+
69
# --- LIME Visualization ---
def show_lime(img, model, pred_idx, pred_label):
    """Render a LIME explanation for one preprocessed image in Streamlit.

    Parameters:
        img: preprocessed RGB image, float values in [0, 1].
        model: loaded model passed through to predict().
        pred_idx: index of the predicted class (label explained by LIME).
        pred_label: human-readable class name used in the plot title.
    """
    with st.spinner("🟡 LIME explanation is loading..."):
        explanation = LIME_EXPLAINER.explain_instance(
            image=img,
            classifier_fn=lambda imgs: predict(imgs, model),
            top_labels=1,
            hide_color=0,
            num_samples=1000
        )
    temp, mask = explanation.get_image_and_mask(
        label=pred_idx, positive_only=True, num_features=10, hide_rest=False
    )

    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    ax.imshow(mark_boundaries(temp, mask))
    ax.set_title(f"LIME Explanation: {pred_label}")
    ax.axis("off")
    st.pyplot(fig)
    # Close the figure to avoid leaking it across Streamlit reruns
    # (the previous revision closed it; the rewrite dropped the close).
    plt.close(fig)
 
86
 
87
# --- Streamlit UI ---
st.set_page_config(page_title="🧠 Retina Classifier - Multi Image LIME", layout="wide")
st.title("🧠 Retina Disease Classifier with LIME Explanation")

model = load_model()

with st.sidebar:
    uploaded_files = st.file_uploader("📂 Upload retinal images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
    selected_filename = None
    if uploaded_files:
        filenames = [f.name for f in uploaded_files]
        selected_filename = st.selectbox("🎯 Select an image to explain", filenames)


def _decode_rgb(file):
    """Decode an uploaded image file to an RGB array; return None on failure."""
    # Rewind first: the same UploadedFile may already have been consumed by an
    # earlier section (without this, the second read returns empty bytes and
    # cv2.imdecode yields None — this seek was present before the rewrite).
    file.seek(0)
    bgr = cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# -- Predict & Display for Selected Image --
if uploaded_files and selected_filename:
    file = next(f for f in uploaded_files if f.name == selected_filename)
    rgb = _decode_rgb(file)
    if rgb is None:
        st.error(f"Could not decode image '{file.name}'.")
    else:
        st.subheader("🔍 Preprocessing Steps")
        preprocessed = preprocess_with_steps(rgb)
        input_tensor = np.expand_dims(preprocessed, axis=0)

        preds = predict(input_tensor, model)
        pred_idx = np.argmax(preds)
        pred_label = CLASS_NAMES[pred_idx]
        confidence = np.max(preds) * 100

        st.success(f"✅ Prediction: **{pred_label}** ({confidence:.2f}%)")
        show_lime(preprocessed, model, pred_idx, pred_label)

# -- Show LIME for all images --
if uploaded_files:
    st.markdown("## 🧪 LIME Explanations for All Images")
    cols = st.columns(min(4, len(uploaded_files)))

    for i, file in enumerate(uploaded_files):
        rgb = _decode_rgb(file)  # rewinds the stream before reading
        if rgb is None:
            continue  # skip undecodable files instead of crashing the page
        # NOTE(review): this overview deliberately(?) uses a plain resize
        # instead of the full crop/CLAHE/sharpen pipeline — confirm intent.
        img = cv2.resize(rgb, IMG_SIZE) / 255.0
        input_tensor = np.expand_dims(img, axis=0)

        preds = predict(input_tensor, model)
        pred_idx = np.argmax(preds)
        pred_label = CLASS_NAMES[pred_idx]

        with cols[i % len(cols)]:
            st.markdown(f"**{file.name}**<br>🧠 *{pred_label}*", unsafe_allow_html=True)
            explanation = LIME_EXPLAINER.explain_instance(
                image=img,
                classifier_fn=lambda imgs: predict(imgs, model),
                top_labels=1,
                hide_color=0,
                num_samples=1000
            )
            temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
            # NOTE(review): use_column_width is deprecated in newer Streamlit
            # (use_container_width) — left unchanged to preserve behavior.
            st.image(mark_boundaries(temp, mask), use_column_width=True)