Sefat33 commited on
Commit
1e07c8e
Β·
verified Β·
1 Parent(s): 5839dfa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -64
app.py CHANGED
@@ -28,10 +28,10 @@ DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
28
  IMG_SIZE = (224, 224)
29
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
30
 
31
- # --- Load model using TFSMLayer ---
32
  @st.cache_resource
33
  def load_model():
34
- model_path = "Model" # Folder path to TF SavedModel
35
  if not os.path.exists(model_path):
36
  st.error(f"❌ Model folder '{model_path}' not found!")
37
  st.stop()
@@ -44,8 +44,8 @@ def load_model():
44
  st.error(f"❌ Error loading model: {str(e)}")
45
  st.stop()
46
 
47
- # --- Preprocessing with visualization ---
48
- def preprocess_and_show_steps(img):
49
  h, w = img.shape[:2]
50
  center = (w // 2, h // 2)
51
  radius = min(center[0], center[1])
@@ -65,26 +65,11 @@ def preprocess_and_show_steps(img):
65
  sharp = cv2.addWeighted(clahe_img, 4, blur, -4, 128)
66
 
67
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
 
68
 
69
- # Show preprocessing stages
70
- fig, axs = plt.subplots(1, 4, figsize=(20, 5))
71
- axs[0].imshow(img)
72
- axs[0].set_title("Original Image")
73
- axs[1].imshow(circ)
74
- axs[1].set_title("After Circular Crop")
75
- axs[2].imshow(clahe_img)
76
- axs[2].set_title("After CLAHE")
77
- axs[3].imshow(resized)
78
- axs[3].set_title("Sharpen + Resize")
79
- for ax in axs:
80
- ax.axis("off")
81
- st.pyplot(fig)
82
- plt.close()
83
-
84
- return resized
85
-
86
- # --- LIME Explainer ---
87
  explainer = lime_image.LimeImageExplainer()
 
88
  def predict_fn(images):
89
  images = np.array(images)
90
  preds = model.predict(images, verbose=0)
@@ -92,66 +77,82 @@ def predict_fn(images):
92
  preds = list(preds.values())[0]
93
  return preds
94
 
95
- # --- Explanation texts ---
96
  explanation_text = {
97
- 'Normal': "Model predicted Normal based on healthy optic disc and macula.",
98
- 'Diabetes': "Detected retinal blood vessel changes suggestive of Diabetes.",
99
- 'Glaucoma': "Detected increased cupping in the optic disc indicating Glaucoma.",
100
- 'Cataract': "Image blur indicated potential Cataract.",
101
- 'AMD': "Degeneration signs in macula indicate AMD.",
102
- 'Hypertension': "Blood vessel narrowing/hemorrhages indicate Hypertension.",
103
- 'Myopia': "Tilted disc and fundus shape suggest Myopia.",
104
- 'Others': "Non-specific features detected, marked as Others."
105
  }
106
 
107
- # --- Display LIME only ---
108
- def display_lime_visualization(img, true_label, pred_label, pred_idx):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  with st.spinner("🟑 LIME Explanation is Loading..."):
110
  explanation = explainer.explain_instance(
111
- image=img,
112
  classifier_fn=predict_fn,
113
  top_labels=1,
114
  hide_color=0,
115
  num_samples=1000
116
  )
117
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
118
-
119
  fig, axs = plt.subplots(1, 2, figsize=(12, 5))
120
- axs[0].imshow(img)
121
- axs[0].set_title(f"Original\nTrue: {true_label}")
122
  axs[1].imshow(mark_boundaries(temp, mask))
123
- axs[1].set_title(f"LIME Explanation\nPred: {pred_label}")
124
  for ax in axs:
125
- ax.axis('off')
126
-
127
- summary = explanation_text.get(pred_label, "Model detected features matching this class.")
128
- plt.figtext(0.5, 0.01, summary, wrap=True, ha='center', fontsize=10)
129
- plt.tight_layout(rect=[0, 0.03, 1, 1])
130
  st.pyplot(fig)
131
  plt.close()
132
 
133
- # --- Streamlit UI ---
134
- st.set_page_config(page_title="🧠 Retina Disease Classifier with LIME", layout="centered")
135
  st.title("🧠 Retina Disease Classifier with LIME Explanation")
136
 
137
  model = load_model()
138
 
139
- uploaded_file = st.file_uploader("Upload a retinal image", type=["jpg", "jpeg", "png"])
140
- if uploaded_file:
141
- file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
142
- bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
143
- rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
144
-
145
- processed_img = preprocess_and_show_steps(rgb_img)
146
-
147
- input_tensor = np.expand_dims(processed_img, axis=0)
148
- preds = model.predict(input_tensor)
149
- if isinstance(preds, dict):
150
- preds = list(preds.values())[0]
151
- pred_idx = np.argmax(preds)
152
- pred_label = CLASS_NAMES[pred_idx]
153
- confidence = np.max(preds) * 100
154
-
155
- st.success(f"βœ… Prediction: **{pred_label}** with confidence **{confidence:.2f}%**")
156
-
157
- display_lime_visualization(processed_img, "Uploaded Image", pred_label, pred_idx)
 
 
 
 
 
28
  IMG_SIZE = (224, 224)
29
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
30
 
31
+ # --- Load model ---
32
  @st.cache_resource
33
  def load_model():
34
+ model_path = "Model"
35
  if not os.path.exists(model_path):
36
  st.error(f"❌ Model folder '{model_path}' not found!")
37
  st.stop()
 
44
  st.error(f"❌ Error loading model: {str(e)}")
45
  st.stop()
46
 
47
+ # --- Preprocessing function ---
48
+ def preprocess_image(img):
49
  h, w = img.shape[:2]
50
  center = (w // 2, h // 2)
51
  radius = min(center[0], center[1])
 
65
  sharp = cv2.addWeighted(clahe_img, 4, blur, -4, 128)
66
 
67
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
68
+ return resized, [img, circ, clahe_img, resized]
69
 
70
+ # --- LIME explainer ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  explainer = lime_image.LimeImageExplainer()
72
+
73
  def predict_fn(images):
74
  images = np.array(images)
75
  preds = model.predict(images, verbose=0)
 
77
  preds = list(preds.values())[0]
78
  return preds
79
 
80
+ # --- Explanation text ---
81
  explanation_text = {
82
+ 'Normal': "Healthy optic disc and macula.",
83
+ 'Diabetes': "Retinal vessel changes suggest Diabetes.",
84
+ 'Glaucoma': "Optic disc cupping detected.",
85
+ 'Cataract': "Blurring suggests Cataract.",
86
+ 'AMD': "Degeneration signs in macula.",
87
+ 'Hypertension': "Hemorrhages suggest Hypertension.",
88
+ 'Myopia': "Fundus tilt suggests Myopia.",
89
+ 'Others': "Non-specific features detected."
90
  }
91
 
92
+ # --- Display results ---
93
+ def display_all_results(image_name, orig_img, processed_img, stages, pred_label, confidence, pred_idx):
94
+ st.header(f"πŸ–ΌοΈ Image: `{image_name}`")
95
+
96
+ # Show preprocessing
97
+ st.subheader("πŸ” Preprocessing Steps")
98
+ fig, axs = plt.subplots(1, 4, figsize=(20, 5))
99
+ titles = ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"]
100
+ for i, (img, title) in enumerate(zip(stages, titles)):
101
+ axs[i].imshow(img)
102
+ axs[i].set_title(title)
103
+ axs[i].axis('off')
104
+ st.pyplot(fig)
105
+ plt.close()
106
+
107
+ st.success(f"βœ… Prediction: **{pred_label}** with confidence **{confidence:.2f}%**")
108
+
109
+ # LIME
110
  with st.spinner("🟑 LIME Explanation is Loading..."):
111
  explanation = explainer.explain_instance(
112
+ image=processed_img,
113
  classifier_fn=predict_fn,
114
  top_labels=1,
115
  hide_color=0,
116
  num_samples=1000
117
  )
118
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
 
119
  fig, axs = plt.subplots(1, 2, figsize=(12, 5))
120
+ axs[0].imshow(processed_img)
121
+ axs[0].set_title("Processed Image")
122
  axs[1].imshow(mark_boundaries(temp, mask))
123
+ axs[1].set_title("LIME Explanation")
124
  for ax in axs:
125
+ ax.axis("off")
126
+ plt.figtext(0.5, 0.01, explanation_text.get(pred_label, ""), ha="center", fontsize=10)
 
 
 
127
  st.pyplot(fig)
128
  plt.close()
129
 
130
+ # --- App UI ---
131
+ st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
132
  st.title("🧠 Retina Disease Classifier with LIME Explanation")
133
 
134
  model = load_model()
135
 
136
+ uploaded_files = st.file_uploader(
137
+ "Upload one or more retinal images",
138
+ type=["jpg", "jpeg", "png"],
139
+ accept_multiple_files=True
140
+ )
141
+
142
+ if uploaded_files:
143
+ for uploaded_file in uploaded_files:
144
+ file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
145
+ bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
146
+ rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
147
+
148
+ processed_img, stages = preprocess_image(rgb_img)
149
+ input_tensor = np.expand_dims(processed_img, axis=0)
150
+
151
+ preds = model.predict(input_tensor)
152
+ if isinstance(preds, dict):
153
+ preds = list(preds.values())[0]
154
+ pred_idx = np.argmax(preds)
155
+ pred_label = CLASS_NAMES[pred_idx]
156
+ confidence = np.max(preds) * 100
157
+
158
+ display_all_results(uploaded_file.name, rgb_img, processed_img, stages, pred_label, confidence, pred_idx)