Sefat33 commited on
Commit
a2da10f
·
verified ·
1 Parent(s): d1b8d0a

Upload 8 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ coatnet_retina_app/model.keras.keras filter=lfs diff=lfs merge=lfs -text
coatnet_retina_app/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+
2
+ # 🧠 Retinal Disease Classifier with CoAtNet
3
+
4
+ Upload a retinal image to classify it into 8 diseases using a fine-tuned CoAtNet model.
5
+ Includes Grad-CAM and LIME explanations.
6
+
7
+ Built with Streamlit • Deployable on Hugging Face Spaces.
coatnet_retina_app/app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import streamlit as st
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+
8
+ from utils.gradcam import generate_gradcam
9
+ from utils.lime_explainer import explain_with_lime
10
+ from preprocessing.preprocess import preprocess_image
11
+
12
+ @st.cache_resource
13
+ def load_model():
14
+ return tf.keras.models.load_model("model.keras")
15
+
16
+ model = load_model()
17
+
18
+ class_names = ["Normal", "Diabetes", "Glaucoma", "Cataract", "AMD", "Hypertension", "Myopia", "Others"]
19
+
20
+ st.title("🧠 Retinal Disease Classifier (CoAtNet)")
21
+ uploaded_file = st.file_uploader("Upload a retinal image", type=["jpg", "png", "jpeg"])
22
+
23
+ if uploaded_file:
24
+ image = Image.open(uploaded_file).convert("RGB")
25
+ st.image(image, caption="Input Image", use_column_width=True)
26
+
27
+ img_array = preprocess_image(image)
28
+ pred = model.predict(np.expand_dims(img_array, axis=0))
29
+ predicted_class = class_names[np.argmax(pred)]
30
+ st.success(f"✅ Predicted Disease: **{predicted_class}**")
31
+
32
+ st.subheader("🔍 Grad-CAM Explanation")
33
+ gradcam = generate_gradcam(model, img_array)
34
+ st.image(gradcam, caption="Grad-CAM", use_column_width=True)
35
+
36
+ st.subheader("🔍 LIME Explanation")
37
+ lime_expl = explain_with_lime(model, img_array)
38
+ st.image(lime_expl, caption="LIME Explanation", use_column_width=True)
coatnet_retina_app/logs/feedback_log.csv ADDED
File without changes
coatnet_retina_app/model.keras.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce970139e4fd81f52f923a9b302276a97f99dd9b3b6a85861ec3fcabbe09ef6d
3
+ size 133360977
coatnet_retina_app/preprocessing/preprocess.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ import numpy as np
4
+
5
+ def crop_circle(img):
6
+ h, w = img.shape[:2]
7
+ center = (w // 2, h // 2)
8
+ radius = min(center[0], center[1])
9
+ Y, X = np.ogrid[:h, :w]
10
+ dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
11
+ mask = dist <= radius
12
+ if img.ndim == 3:
13
+ mask = np.stack([mask] * 3, axis=-1)
14
+ img[~mask] = 0
15
+ return img
16
+
17
+ def apply_clahe(img):
18
+ lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
19
+ l, a, b = cv2.split(lab)
20
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
21
+ cl = clahe.apply(l)
22
+ merged = cv2.merge((cl, a, b))
23
+ return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
24
+
25
+ def sharpen_image(img, sigma=10):
26
+ blur = cv2.GaussianBlur(img, (0, 0), sigma)
27
+ return cv2.addWeighted(img, 4, blur, -4, 128)
28
+
29
+ def resize_normalize(img, size=(224, 224)):
30
+ img = cv2.resize(img, size)
31
+ img = img / 255.0
32
+ return img
33
+
34
+ def preprocess_image(image):
35
+ img = np.array(image)
36
+ img = crop_circle(img)
37
+ img = apply_clahe(img)
38
+ img = sharpen_image(img)
39
+ img = resize_normalize(img)
40
+ return img.astype(np.float32)
coatnet_retina_app/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ streamlit
3
+ tensorflow==2.11.0
4
+ opencv-python-headless
5
+ lime
6
+ scikit-image
7
+ numpy
8
+ matplotlib
9
+ Pillow
coatnet_retina_app/utils/gradcam.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import cv2
5
+ import matplotlib.cm as cm
6
+
7
+ def find_last_conv_layer(model):
8
+ for layer in reversed(model.layers):
9
+ if isinstance(layer, tf.keras.layers.Conv2D) or 'mhsa_output' in layer.name:
10
+ return layer.name
11
+ raise ValueError("No suitable conv layer found.")
12
+
13
+ def generate_gradcam(model, img_array):
14
+ layer_name = find_last_conv_layer(model)
15
+ grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(layer_name).output, model.output])
16
+ with tf.GradientTape() as tape:
17
+ conv_outputs, predictions = grad_model(tf.expand_dims(img_array, axis=0))
18
+ loss = predictions[:, tf.argmax(predictions[0])]
19
+ grads = tape.gradient(loss, conv_outputs)
20
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
21
+ heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
22
+ heatmap = tf.squeeze(heatmap)
23
+ heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
24
+ heatmap = cv2.resize(heatmap.numpy(), (224, 224))
25
+ heatmap = np.uint8(255 * heatmap)
26
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
27
+ superimposed_img = cv2.addWeighted(np.uint8(img_array * 255), 0.5, heatmap, 0.5, 0)
28
+ return superimposed_img
coatnet_retina_app/utils/lime_explainer.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from lime import lime_image
4
+ from skimage.segmentation import mark_boundaries
5
+
6
+ explainer = lime_image.LimeImageExplainer()
7
+
8
+ def explain_with_lime(model, img_array):
9
+ def predict_fn(images): return model.predict(np.array(images), verbose=0)
10
+ explanation = explainer.explain_instance(
11
+ image=img_array,
12
+ classifier_fn=predict_fn,
13
+ top_labels=1,
14
+ hide_color=0,
15
+ num_samples=1000
16
+ )
17
+ temp, mask = explanation.get_image_and_mask(
18
+ label=explanation.top_labels[0],
19
+ positive_only=True,
20
+ num_features=10,
21
+ hide_rest=False
22
+ )
23
+ return mark_boundaries(temp, mask)