Sefat33 committed on
Commit
99ce1ef
·
verified ·
1 Parent(s): 1e07c8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -88
app.py CHANGED
@@ -4,47 +4,26 @@ import cv2
4
  import tensorflow as tf
5
  import streamlit as st
6
  import matplotlib.pyplot as plt
7
- import matplotlib.cm as cm
8
  from lime import lime_image
9
  from skimage.segmentation import mark_boundaries
10
  from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
11
 
12
  # --- Fix deserialization issues ---
13
- original_bn_from_config = BatchNormalization.from_config
14
- def patched_bn_from_config(cls, config, *args, **kwargs):
15
- if "axis" in config and isinstance(config["axis"], (list, tuple)):
16
- config["axis"] = config["axis"][0]
17
- return original_bn_from_config(config, *args, **kwargs)
18
- BatchNormalization.from_config = classmethod(patched_bn_from_config)
19
-
20
- original_dwconv_from_config = DepthwiseConv2D.from_config
21
- def patched_dwconv_from_config(cls, config, *args, **kwargs):
22
- if "groups" in config:
23
- config.pop("groups")
24
- return original_dwconv_from_config(config, *args, **kwargs)
25
- DepthwiseConv2D.from_config = classmethod(patched_dwconv_from_config)
26
 
27
  # --- Constants ---
28
  IMG_SIZE = (224, 224)
29
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
30
 
31
- # --- Load model ---
32
  @st.cache_resource
33
  def load_model():
34
  model_path = "Model"
35
- if not os.path.exists(model_path):
36
- st.error(f"❌ Model folder '{model_path}' not found!")
37
- st.stop()
38
- try:
39
- model = tf.keras.Sequential([
40
- TFSMLayer(model_path, call_endpoint="serving_default")
41
- ])
42
- return model
43
- except Exception as e:
44
- st.error(f"❌ Error loading model: {str(e)}")
45
- st.stop()
46
-
47
- # --- Preprocessing function ---
48
  def preprocess_image(img):
49
  h, w = img.shape[:2]
50
  center = (w // 2, h // 2)
@@ -67,92 +46,100 @@ def preprocess_image(img):
67
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
68
  return resized, [img, circ, clahe_img, resized]
69
 
70
- # --- LIME explainer ---
71
- explainer = lime_image.LimeImageExplainer()
72
-
73
- def predict_fn(images):
74
- images = np.array(images)
75
- preds = model.predict(images, verbose=0)
76
- if isinstance(preds, dict):
77
- preds = list(preds.values())[0]
78
- return preds
79
-
80
- # --- Explanation text ---
81
  explanation_text = {
82
- 'Normal': "Healthy optic disc and macula.",
83
- 'Diabetes': "Retinal vessel changes suggest Diabetes.",
84
- 'Glaucoma': "Optic disc cupping detected.",
85
- 'Cataract': "Blurring suggests Cataract.",
86
- 'AMD': "Degeneration signs in macula.",
87
- 'Hypertension': "Hemorrhages suggest Hypertension.",
88
- 'Myopia': "Fundus tilt suggests Myopia.",
89
- 'Others': "Non-specific features detected."
90
  }
91
 
92
- # --- Display results ---
93
- def display_all_results(image_name, orig_img, processed_img, stages, pred_label, confidence, pred_idx):
94
- st.header(f"🖼️ Image: `{image_name}`")
95
 
96
- # Show preprocessing
97
- st.subheader("🔍 Preprocessing Steps")
98
- fig, axs = plt.subplots(1, 4, figsize=(20, 5))
 
 
99
  titles = ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"]
100
- for i, (img, title) in enumerate(zip(stages, titles)):
 
101
  axs[i].imshow(img)
102
- axs[i].set_title(title)
103
  axs[i].axis('off')
104
  st.pyplot(fig)
105
  plt.close()
106
 
107
- st.success(f"✅ Prediction: **{pred_label}** with confidence **{confidence:.2f}%**")
108
-
109
- # LIME
110
  with st.spinner("🟡 LIME Explanation is Loading..."):
111
  explanation = explainer.explain_instance(
112
- image=processed_img,
113
  classifier_fn=predict_fn,
114
  top_labels=1,
115
  hide_color=0,
116
  num_samples=1000
117
  )
118
  temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
119
- fig, axs = plt.subplots(1, 2, figsize=(12, 5))
120
- axs[0].imshow(processed_img)
121
- axs[0].set_title("Processed Image")
122
- axs[1].imshow(mark_boundaries(temp, mask))
123
- axs[1].set_title("LIME Explanation")
124
- for ax in axs:
125
- ax.axis("off")
126
- plt.figtext(0.5, 0.01, explanation_text.get(pred_label, ""), ha="center", fontsize=10)
127
  st.pyplot(fig)
128
  plt.close()
129
 
130
- # --- App UI ---
131
- st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
132
  st.title("🧠 Retina Disease Classifier with LIME Explanation")
133
 
134
  model = load_model()
135
-
136
- uploaded_files = st.file_uploader(
137
- "Upload one or more retinal images",
138
- type=["jpg", "jpeg", "png"],
139
- accept_multiple_files=True
140
- )
141
 
142
  if uploaded_files:
143
- for uploaded_file in uploaded_files:
144
- file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
145
- bgr_img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
146
- rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
147
-
148
- processed_img, stages = preprocess_image(rgb_img)
149
- input_tensor = np.expand_dims(processed_img, axis=0)
150
-
151
- preds = model.predict(input_tensor)
152
- if isinstance(preds, dict):
153
- preds = list(preds.values())[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  pred_idx = np.argmax(preds)
155
  pred_label = CLASS_NAMES[pred_idx]
156
- confidence = np.max(preds) * 100
157
 
158
- display_all_results(uploaded_file.name, rgb_img, processed_img, stages, pred_label, confidence, pred_idx)
 
 
 
 
 
 
 
 
 
 
4
  import tensorflow as tf
5
  import streamlit as st
6
  import matplotlib.pyplot as plt
 
7
  from lime import lime_image
8
  from skimage.segmentation import mark_boundaries
9
  from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
10
 
11
  # --- Fix deserialization issues ---
12
+ BatchNormalization.from_config = classmethod(lambda cls, config, *args, **kwargs: BatchNormalization.from_config.__func__(cls, {**config, "axis": config["axis"][0] if isinstance(config["axis"], (list, tuple)) else config["axis"]}))
13
+ DepthwiseConv2D.from_config = classmethod(lambda cls, config, *args, **kwargs: DepthwiseConv2D.from_config.__func__(cls, {k: v for k, v in config.items() if k != "groups"}))
 
 
 
 
 
 
 
 
 
 
 
14
 
15
# --- Constants ---
IMG_SIZE = (224, 224)  # model input resolution (width, height) for cv2.resize
# Index order must match the model's output head; np.argmax indexes into this.
CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
18
 
 
19
@st.cache_resource
def load_model():
    """Load the SavedModel in ./Model as an inference-only Keras model.

    Cached by Streamlit so deserialization happens once per session.
    Stops the app with a visible error message instead of a raw
    traceback when the folder is missing or loading fails.
    """
    import os  # local import: only needed for the existence check

    model_path = "Model"
    if not os.path.exists(model_path):
        st.error(f"❌ Model folder '{model_path}' not found!")
        st.stop()
    try:
        model = tf.keras.Sequential([
            TFSMLayer(model_path, call_endpoint="serving_default")
        ])
        return model
    except Exception as e:
        st.error(f"❌ Error loading model: {str(e)}")
        st.stop()
 
 
 
 
 
 
 
 
27
  def preprocess_image(img):
28
  h, w = img.shape[:2]
29
  center = (w // 2, h // 2)
 
46
  resized = cv2.resize(sharp, IMG_SIZE) / 255.0
47
  return resized, [img, circ, clahe_img, resized]
48
 
 
 
 
 
 
 
 
 
 
 
 
49
# Human-readable rationale shown alongside each LIME figure, keyed by class.
explanation_text = dict(
    Normal="Model predicted Normal based on healthy optic disc and macula.",
    Diabetes="Detected retinal blood vessel changes suggestive of Diabetes.",
    Glaucoma="Detected increased cupping in the optic disc indicating Glaucoma.",
    Cataract="Image blur indicated potential Cataract.",
    AMD="Degeneration signs in macula indicate AMD.",
    Hypertension="Blood vessel narrowing/hemorrhages indicate Hypertension.",
    Myopia="Tilted disc and fundus shape suggest Myopia.",
    Others="Non-specific features detected, marked as Others.",
)
59
 
60
# Single shared LIME explainer; explain_instance is invoked per image below.
explainer = lime_image.LimeImageExplainer()
 
 
61
 
62
def predict_fn(images):
    """Batch-predict class scores; LIME's classifier_fn and the app's own scorer.

    Parameters
    ----------
    images : array-like, shape (batch, H, W, 3)
        Preprocessed RGB images scaled to [0, 1].

    Returns
    -------
    Per-class score array. TFSMLayer serving signatures return a
    single-entry dict, which is unwrapped to its only value.
    """
    # verbose=0: LIME calls this ~1000 times per image — without it,
    # Keras prints a progress bar for every call and floods the log.
    preds = model.predict(np.array(images), verbose=0)
    return list(preds.values())[0] if isinstance(preds, dict) else preds
65
+
66
def show_preprocessing_steps(stages):
    """Render the four preprocessing stages side by side in Streamlit.

    `stages` is the 4-image list produced by preprocess_image:
    original, circular crop, CLAHE, sharpen+resize.
    """
    stage_titles = ["Original", "Circular Crop", "CLAHE", "Sharpen + Resize"]
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for axis, stage_img, stage_title in zip(axes, stages, stage_titles):
        axis.imshow(stage_img)
        axis.set_title(stage_title)
        axis.axis('off')
    st.pyplot(fig)
    plt.close()
75
 
76
def show_lime_explanation(img, pred_idx, pred_label):
    """Compute and display a LIME superpixel explanation for one image.

    `img` is the preprocessed [0, 1] RGB image; `pred_idx`/`pred_label`
    identify the predicted class whose evidence is highlighted.
    """
    with st.spinner("🟡 LIME Explanation is Loading..."):
        lime_result = explainer.explain_instance(
            image=img,
            classifier_fn=predict_fn,
            top_labels=1,
            hide_color=0,
            num_samples=1000,
        )
    boundary_img, seg_mask = lime_result.get_image_and_mask(
        label=pred_idx,
        positive_only=True,
        num_features=10,
        hide_rest=False,
    )
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.imshow(mark_boundaries(boundary_img, seg_mask))
    ax.axis('off')
    ax.set_title(f"LIME: {pred_label}")
    st.pyplot(fig)
    plt.close()
92
 
93
# --- Streamlit UI ---
# set_page_config must run before any other Streamlit output call.
st.set_page_config(page_title="🧠 Retina Classifier - Multi Image LIME", layout="wide")
st.title("🧠 Retina Disease Classifier with LIME Explanation")

# Cached with @st.cache_resource, so Streamlit reruns reuse the loaded model.
model = load_model()
uploaded_files = st.file_uploader("Upload one or more retinal images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
 
 
 
 
 
99
 
100
if uploaded_files:
    # Let the user pick one upload for the detailed walkthrough.
    filenames = [file.name for file in uploaded_files]
    selected_file = st.selectbox("Select image to analyze:", filenames)

    # Show individual image analysis
    # NOTE(review): matches on name, so duplicate filenames resolve to the
    # first upload with that name.
    for file in uploaded_files:
        if file.name == selected_file:
            # Decode the upload bytes with OpenCV (BGR), then convert to RGB
            # for display and model input.
            file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
            img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            processed, stages = preprocess_image(rgb)
            show_preprocessing_steps(stages)

            # Single-image batch for prediction.
            input_tensor = np.expand_dims(processed, axis=0)
            preds = predict_fn(input_tensor)
            pred_idx = np.argmax(preds)
            pred_label = CLASS_NAMES[pred_idx]
            confidence = np.max(preds) * 100
            st.success(f"✅ Prediction: **{pred_label}** ({confidence:.2f}%)")
            show_lime_explanation(processed, pred_idx, pred_label)
            break

    st.markdown("---")
    st.subheader("📊 LIME Explanations for All Uploaded Images")
    # One column per upload; each gets its own LIME overlay.
    cols = st.columns(len(uploaded_files))
    for i, file in enumerate(uploaded_files):
        # The selected file was already read above — rewind before re-reading.
        file.seek(0)
        file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
        img = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        processed, _ = preprocess_image(rgb)
        input_tensor = np.expand_dims(processed, axis=0)
        preds = predict_fn(input_tensor)
        pred_idx = np.argmax(preds)
        pred_label = CLASS_NAMES[pred_idx]

        # Per-image LIME run (num_samples=1000 model calls each) — this grid
        # is the slow part of the page.
        explanation = explainer.explain_instance(
            image=processed,
            classifier_fn=predict_fn,
            top_labels=1,
            hide_color=0,
            num_samples=1000
        )
        temp, mask = explanation.get_image_and_mask(label=pred_idx, positive_only=True, num_features=10, hide_rest=False)
        with cols[i]:
            st.image(mark_boundaries(temp, mask), caption=f"{file.name}\n({pred_label})", use_column_width=True)