Upload 8 files
Browse files- .gitattributes +1 -0
- coatnet_retina_app/README.md +7 -0
- coatnet_retina_app/app.py +38 -0
- coatnet_retina_app/logs/feedback_log.csv +0 -0
- coatnet_retina_app/model.keras.keras +3 -0
- coatnet_retina_app/preprocessing/preprocess.py +40 -0
- coatnet_retina_app/requirements.txt +9 -0
- coatnet_retina_app/utils/gradcam.py +28 -0
- coatnet_retina_app/utils/lime_explainer.py +23 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
coatnet_retina_app/model.keras.keras filter=lfs diff=lfs merge=lfs -text
|
coatnet_retina_app/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# 🧠 Retinal Disease Classifier with CoAtNet
|
| 3 |
+
|
| 4 |
+
Upload a retinal image to classify it into one of 8 disease categories using a fine-tuned CoAtNet model.
|
| 5 |
+
Includes Grad-CAM and LIME explanations.
|
| 6 |
+
|
| 7 |
+
Built with Streamlit • Deployable on Hugging Face Spaces.
|
coatnet_retina_app/app.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image
import os

from utils.gradcam import generate_gradcam
from utils.lime_explainer import explain_with_lime
from preprocessing.preprocess import preprocess_image

# The repository ships the weights file as "model.keras.keras" (double
# extension), while the code originally loaded "model.keras".  Try both
# names so the app works regardless of which one is present.
_MODEL_CANDIDATES = ("model.keras", "model.keras.keras")


@st.cache_resource
def load_model():
    """Load the CoAtNet classifier once per session (cached by Streamlit)."""
    for path in _MODEL_CANDIDATES:
        if os.path.exists(path):
            return tf.keras.models.load_model(path)
    # Fall back to the original hard-coded name so a missing file still
    # produces the familiar, explicit "file not found" error.
    return tf.keras.models.load_model("model.keras")


model = load_model()

# NOTE(review): this order must match the label order used at training
# time — verify against the training notebook.
class_names = ["Normal", "Diabetes", "Glaucoma", "Cataract", "AMD", "Hypertension", "Myopia", "Others"]

st.title("🧠 Retinal Disease Classifier (CoAtNet)")
uploaded_file = st.file_uploader("Upload a retinal image", type=["jpg", "png", "jpeg"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    # use_column_width is deprecated in newer Streamlit releases
    # (use_container_width replaces it) — kept for compatibility with the
    # unpinned requirement; revisit when pinning streamlit.
    st.image(image, caption="Input Image", use_column_width=True)

    # Model expects a single (224, 224, 3) float32 image in [0, 1].
    img_array = preprocess_image(image)
    pred = model.predict(np.expand_dims(img_array, axis=0))
    predicted_class = class_names[np.argmax(pred)]
    st.success(f"✅ Predicted Disease: **{predicted_class}**")

    st.subheader("🔍 Grad-CAM Explanation")
    gradcam = generate_gradcam(model, img_array)
    st.image(gradcam, caption="Grad-CAM", use_column_width=True)

    st.subheader("🔍 LIME Explanation")
    lime_expl = explain_with_lime(model, img_array)
    st.image(lime_expl, caption="LIME Explanation", use_column_width=True)
|
coatnet_retina_app/logs/feedback_log.csv
ADDED
|
File without changes
|
coatnet_retina_app/model.keras.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ce970139e4fd81f52f923a9b302276a97f99dd9b3b6a85861ec3fcabbe09ef6d
|
| 3 |
+
size 133360977
|
coatnet_retina_app/preprocessing/preprocess.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def crop_circle(img):
    """Zero out every pixel outside the largest centered circle.

    Fundus photographs are circular; this masks away the rectangular
    border.  Works on grayscale (H, W) or color (H, W, C) arrays.

    Fix: the original mutated the caller's array in place; we now copy
    first so the input is left untouched.

    Args:
        img: numpy image array, 2-D or 3-D.

    Returns:
        A new array of the same shape/dtype with out-of-circle pixels set to 0.
    """
    h, w = img.shape[:2]
    center = (w // 2, h // 2)
    # Largest circle that fits: limited by the shorter half-dimension.
    radius = min(center[0], center[1])
    Y, X = np.ogrid[:h, :w]
    dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
    mask = dist <= radius
    if img.ndim == 3:
        # Broadcast the 2-D mask across the channel axis.
        mask = np.stack([mask] * img.shape[2], axis=-1)
    out = img.copy()
    out[~mask] = 0
    return out
|
| 16 |
+
|
| 17 |
+
def apply_clahe(img):
    """Boost local contrast with CLAHE applied to the L channel in LAB space.

    Converting to LAB lets us equalize lightness without shifting colors.

    Args:
        img: RGB uint8 image array.

    Returns:
        Contrast-enhanced RGB uint8 image of the same shape.
    """
    lab_img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
    lightness, chan_a, chan_b = cv2.split(lab_img)
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lightness = equalizer.apply(lightness)
    recombined = cv2.merge((lightness, chan_a, chan_b))
    return cv2.cvtColor(recombined, cv2.COLOR_LAB2RGB)
|
| 24 |
+
|
| 25 |
+
def sharpen_image(img, sigma=10):
    """Unsharp-mask sharpening: amplify the difference from a Gaussian blur.

    Computes 4*img - 4*blur + 128 (saturating arithmetic), a common
    retinal-image enhancement.

    Args:
        img: image array.
        sigma: Gaussian blur standard deviation (kernel size auto-derived).

    Returns:
        Sharpened image of the same shape/dtype.
    """
    smoothed = cv2.GaussianBlur(img, (0, 0), sigma)
    return cv2.addWeighted(img, 4, smoothed, -4, 128)
|
| 28 |
+
|
| 29 |
+
def resize_normalize(img, size=(224, 224)):
    """Resize to `size` and scale pixel values into [0, 1].

    Args:
        img: image array (any dtype cv2.resize accepts).
        size: target (width, height).

    Returns:
        Resized float array with values divided by 255.
    """
    scaled = cv2.resize(img, size)
    return scaled / 255.0
|
| 33 |
+
|
| 34 |
+
def preprocess_image(image):
    """Run the full preprocessing pipeline on a PIL image.

    Pipeline: circular crop -> CLAHE contrast -> unsharp sharpening ->
    resize + [0, 1] normalization.

    Args:
        image: PIL.Image (RGB expected).

    Returns:
        float32 array of shape (224, 224, 3) with values in [0, 1]
        (before sharpening clipping effects).
    """
    arr = np.array(image)
    for step in (crop_circle, apply_clahe, sharpen_image, resize_normalize):
        arr = step(arr)
    return arr.astype(np.float32)
|
coatnet_retina_app/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
streamlit
|
| 3 |
+
tensorflow==2.11.0
|
| 4 |
+
opencv-python-headless
|
| 5 |
+
lime
|
| 6 |
+
scikit-image
|
| 7 |
+
numpy
|
| 8 |
+
matplotlib
|
| 9 |
+
Pillow
|
coatnet_retina_app/utils/gradcam.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import cv2
|
| 5 |
+
import matplotlib.cm as cm
|
| 6 |
+
|
| 7 |
+
def find_last_conv_layer(model):
    """Return the name of the deepest Conv2D (or MHSA-output) layer.

    Walks the layer list back-to-front so the layer closest to the head
    is found first — that is the one Grad-CAM should hook.

    Raises:
        ValueError: if the model has no matching layer.
    """
    for candidate in model.layers[::-1]:
        is_conv = isinstance(candidate, tf.keras.layers.Conv2D)
        if is_conv or 'mhsa_output' in candidate.name:
            return candidate.name
    raise ValueError("No suitable conv layer found.")
|
| 12 |
+
|
| 13 |
+
def generate_gradcam(model, img_array):
    """Compute a Grad-CAM heatmap overlaid on the input image.

    Fix: the original normalized by reduce_max of the raw heatmap with no
    guard — if all activations are non-positive the division produces
    NaN/inf and a corrupted image.  We now ReLU first and only divide
    when the max is positive.

    Args:
        model: compiled Keras classifier.
        img_array: (H, W, 3) float image in [0, 1]  # assumes 224x224 — TODO confirm

    Returns:
        uint8 BGR-colormapped overlay image (224, 224, 3).
    """
    layer_name = find_last_conv_layer(model)
    grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(layer_name).output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(tf.expand_dims(img_array, axis=0))
        # Score of the top predicted class drives the gradients.
        loss = predictions[:, tf.argmax(predictions[0])]
    grads = tape.gradient(loss, conv_outputs)
    # Channel-wise importance weights: mean gradient over batch + spatial dims.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = conv_outputs[0] @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.math.reduce_max(heatmap)
    if max_val > 0:
        heatmap = heatmap / max_val
    heatmap = cv2.resize(heatmap.numpy(), (224, 224))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # Blend 50/50 with the (de-normalized) input image.
    superimposed_img = cv2.addWeighted(np.uint8(img_array * 255), 0.5, heatmap, 0.5, 0)
    return superimposed_img
|
coatnet_retina_app/utils/lime_explainer.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
from lime import lime_image
|
| 4 |
+
from skimage.segmentation import mark_boundaries
|
| 5 |
+
|
| 6 |
+
explainer = lime_image.LimeImageExplainer()
|
| 7 |
+
|
| 8 |
+
def explain_with_lime(model, img_array):
    """Return the input image with LIME superpixel boundaries drawn on it.

    Perturbs the image 1000 times, fits a local surrogate model, and
    highlights the 10 superpixels most supportive of the top predicted
    class.

    Args:
        model: Keras classifier with a `.predict` method.
        img_array: preprocessed image array the model accepts.

    Returns:
        Float image array with segment boundaries marked.
    """
    def batch_predict(batch):
        return model.predict(np.array(batch), verbose=0)

    explanation = explainer.explain_instance(
        image=img_array,
        classifier_fn=batch_predict,
        top_labels=1,
        hide_color=0,
        num_samples=1000
    )
    image_out, seg_mask = explanation.get_image_and_mask(
        label=explanation.top_labels[0],
        positive_only=True,
        num_features=10,
        hide_rest=False
    )
    return mark_boundaries(image_out, seg_mask)
|