import numpy as np
import cv2
import tensorflow as tf
import streamlit as st
import matplotlib.pyplot as plt
from lime import lime_image
from skimage.segmentation import mark_boundaries
from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
import os
from io import BytesIO
import base64
# --- Global CSS ---
# Injects app-wide styling as raw HTML.
# NOTE(review): the payload is only a newline, so nothing visible is rendered
# here; restore the intended CSS or drop this call.
_GLOBAL_CSS = "\n"
st.markdown(
    _GLOBAL_CSS,
    unsafe_allow_html=True,
)
# --- Fix deserialization issues ---
# Normalise layer configs before Keras rebuilds the layers: a list-valued
# BatchNormalization "axis" is collapsed to its first element, and a
# DepthwiseConv2D "groups" key is dropped.
# NOTE(review): presumably these work around configs written by a different
# Keras version, and the model below is loaded through TFSMLayer; confirm that
# these from_config hooks are still exercised at all.
def _normalize_bn_config(config):
    """Return ``config`` with a list-valued ``axis`` replaced by its first element."""
    axis = config.get("axis")
    if isinstance(axis, list):
        return {**config, "axis": axis[0]}
    return config


def _strip_groups(config):
    """Return a copy of ``config`` without the ``groups`` key."""
    return {key: value for key, value in config.items() if key != "groups"}


def _patch_from_config(layer_cls, fix_config):
    """Make ``layer_cls.from_config`` pass every config through ``fix_config`` first.

    Streamlit re-executes this script on every interaction while imported
    classes stay cached in ``sys.modules``. An unguarded patch would wrap the
    previous wrapper on each rerun, so the call chain (and recursion depth)
    would grow without bound. A marker stored in the class's own ``__dict__``
    makes this function idempotent without being fooled by inherited markers.

    Args:
        layer_cls: Layer class whose ``from_config`` classmethod is replaced.
        fix_config: Callable mapping a config dict to the config to use.
    """
    if layer_cls.__dict__.get("_from_config_patched", False):
        return
    bound = layer_cls.from_config
    # Call the underlying function with the *actual* class so subclasses are
    # rebuilt as themselves rather than as ``layer_cls``.
    original = getattr(bound, "__func__", None)

    def from_config(cls, config, *args, **kwargs):
        fixed = fix_config(config)
        if original is None:
            return bound(fixed, *args, **kwargs)
        return original(cls, fixed, *args, **kwargs)

    layer_cls.from_config = classmethod(from_config)
    layer_cls._from_config_patched = True


_patch_from_config(BatchNormalization, _normalize_bn_config)
_patch_from_config(DepthwiseConv2D, _strip_groups)
# --- Background styling ---
def set_background():
    """Inject the app's fixed (non-dynamic) background styling as raw HTML.

    NOTE(review): the markup is currently only a newline, so this call has no
    visible effect.
    """
    markup = "\n"
    st.markdown(markup, unsafe_allow_html=True)


# Executed on every script run (Streamlit reruns the whole file).
set_background()
# --- Constants ---
# Target size passed to cv2.resize when preparing the model input.
IMG_SIZE = (224, 224)

# Human-readable labels; the app maps argmax index i of the model output to
# CLASS_NAMES[i]. NOTE(review): order must match the training label order.
CLASS_NAMES = [
    'Normal',
    'Diabetic Retinopathy',
    'Glaucoma',
    'Cataract',
    'Age-related Macular Degeneration (AMD)',
    'Hypertension',
    'Myopia',
    'Others',
]

# Single LIME explainer instance reused by show_lime().
LIME_EXPLAINER = lime_image.LimeImageExplainer()
# --- Model loading ---
@st.cache_resource
def load_model():
    """Wrap the SavedModel in the ``Model`` folder as a Keras model.

    The folder is resolved relative to the current working directory and
    called through its ``serving_default`` endpoint. If the folder is missing
    or loading fails, an error is shown and the Streamlit run is stopped.
    ``st.cache_resource`` keeps the result across reruns.
    """
    saved_model_dir = "Model"
    if not os.path.exists(saved_model_dir):
        st.error(f"🚨 Model folder '{saved_model_dir}' not found.")
        st.stop()
    try:
        serving_layer = TFSMLayer(saved_model_dir, call_endpoint="serving_default")
        return tf.keras.Sequential([serving_layer])
    except Exception as err:
        st.error(f"🚨 Error loading model: {err}")
        st.stop()
# --- Prediction ---
def predict(images, model):
    """Run ``model`` on a batch and return one output array.

    ``images`` is converted with ``np.array`` before ``model.predict``.

    When the model returns a dict (e.g. a SavedModel serving signature), the
    first value that is an ndarray or list is returned as an ndarray. If no
    value qualifies, the first value is converted instead. Non-dict outputs
    are returned unchanged.

    Raises:
        IndexError: if the model returns an empty dict.
    """
    batch = np.array(images)
    outputs = model.predict(batch, verbose=0)
    if not isinstance(outputs, dict):
        return outputs
    candidates = list(outputs.values())
    array_like = [value for value in candidates if isinstance(value, (np.ndarray, list))]
    chosen = array_like[0] if array_like else candidates[0]
    return np.array(chosen)
# --- Preprocessing ---
def _circular_crop(img):
    """Keep the centred disc of radius min(h, w) // 2 and paint the rest white."""
    height, width = img.shape[:2]
    cx, cy = width // 2, height // 2
    rows, cols = np.ogrid[:height, :width]
    inside = np.sqrt((cols - cx) ** 2 + (rows - cy) ** 2) <= min(width, height) // 2
    background = np.ones_like(img, dtype=np.uint8) * 255
    return np.where(inside[..., np.newaxis], img, background)


def _clahe_rgb(rgb):
    """Apply CLAHE (clip limit 2.0) to the L channel of an RGB image via LAB space."""
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB))
    equalized = cv2.createCLAHE(clipLimit=2.0).apply(lightness)
    return cv2.cvtColor(cv2.merge((equalized, chan_a, chan_b)), cv2.COLOR_LAB2RGB)


def preprocess_with_steps(img):
    """Prepare a fundus image for the model and plot every intermediate stage.

    Pipeline:
        1. Circular crop (white outside the centred disc).
        2. CLAHE on the lightness channel.
        3. Sharpen: 4 * image - 4 * GaussianBlur(sigma=10) + 128.
        4. Resize to IMG_SIZE and divide by 255.

    The four stages are drawn side by side with ``st.pyplot``.

    Args:
        img: RGB image; assumed H x W x 3 uint8 as decoded by cv2 in this app.
            TODO confirm for other callers.

    Returns:
        The resized, sharpened image divided by 255 (the model input).
    """
    cropped = _circular_crop(img)
    enhanced = _clahe_rgb(cropped)
    blurred = cv2.GaussianBlur(enhanced, (0, 0), 10)
    sharpened = cv2.addWeighted(enhanced, 4, blurred, -4, 128)
    model_input = cv2.resize(sharpened, IMG_SIZE) / 255.0

    stages = (
        ("Original", img),
        ("Circular Crop", cropped),
        ("CLAHE", enhanced),
        ("Sharpen + Resize", model_input),
    )
    fig, axes = plt.subplots(1, len(stages), figsize=(16, 4))
    fig.patch.set_facecolor('white')  # fixed white background
    for axis, (title, picture) in zip(axes, stages):
        axis.imshow(picture)
        axis.set_title(title, fontsize=14, fontweight='bold', color='#1e40af')
        axis.axis("off")
    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # release the figure; Streamlit has already rendered it
    return model_input
# Static clinical notes rendered beside the LIME image, keyed by the exact
# CLASS_NAMES strings; show_lime() looks them up with
# explanation_text.get(pred_label, ...).
# NOTE(review): these are generic descriptions of each condition, not findings
# derived from the analysed image.
explanation_text = {
'Normal': """
✅ Normal Retina
- 🟢 Clear retinal structure - No pathological lesions detected
- 🩺 Healthy blood vessels - Normal caliber and branching pattern
- 👁 Intact optic disc & macula - Proper anatomical structure
- ✅ No disease indicators - Excellent retinal health
""",
'Diabetic Retinopathy': """
⚠️ Diabetic Retinopathy
- 🔴 Microhemorrhages - Red spots indicating vessel damage
- 🩸 Vascular leakage - Fluid accumulation in retinal tissue
- 👁 Macular involvement - Possible diabetic macular edema
- 🔬 Requires monitoring - Regular ophthalmologic follow-up needed
""",
'Glaucoma': """
👁 Glaucoma
- 🔴 Optic nerve damage - Thinning of nerve fiber layer
- ⚪ Increased cup-to-disc ratio - Optic disc cupping
- 📉 Visual field risk - Potential peripheral vision loss
- 💊 Pressure management - IOP control essential
""",
'Cataract': """
🌫️ Cataract
- ☁️ Lens opacity - Clouding affecting image clarity
- 🔍 Reduced contrast - Decreased retinal detail visibility
- 👁 Fundus visualization - Limited view of posterior structures
- 🏥 Surgical consideration - May benefit from cataract extraction
""",
'Age-related Macular Degeneration (AMD)': """
🧓 Age-related Macular Degeneration
- 🟡 Drusen deposits - Yellow spots near macular region
- 👁 Central vision impact - Macula-specific changes
- 📈 Progressive condition - Age-related degenerative process
- 🔬 Monitoring required - Regular assessment for progression
""",
'Hypertension': """
🩸 Hypertensive Retinopathy
- ⭐ Cotton wool spots - Nerve fiber layer infarcts
- 🔴 Flame hemorrhages - Superficial retinal bleeding
- 🩸 Arteriovenous nicking - Vessel caliber changes
- 💊 BP management - Systemic hypertension control needed
""",
'Myopia': """
👓 Myopic Changes
- 🔵 Axial elongation signs - Elongated eyeball morphology
- ⚪ Peripapillary atrophy - Tissue thinning around optic disc
- 📐 Disc tilting - Oblique optic disc orientation
- 👁 Refractive changes - Associated with high myopia
""",
'Others': """
🔍 Unclassified Findings
- ❓ Atypical presentation - Unusual retinal patterns
- 🔬 Further evaluation - Additional testing recommended
- 🩺 Specialist referral - Ophthalmologist consultation advised
- 📋 Comprehensive exam - Complete ocular assessment needed
"""
}
# --- LIME explanation display ---
def show_lime(img, model, pred_idx, pred_label, all_probs):
    """Render a LIME superpixel explanation next to the clinical notes for a prediction.

    Args:
        img: Preprocessed image (no batch axis) passed to LIME as the instance
            to explain.
        model: Model forwarded to ``predict`` as LIME's classifier function.
        pred_idx: Index of the predicted class; the explanation is computed for
            exactly this label.
        pred_label: Class name used for the notes lookup and download filename.
        all_probs: Model output for ``img``. Currently unused; kept for
            interface compatibility.
    """
    label = int(pred_idx)
    with st.spinner("🔬 Generating LIME explanation..."):
        # Request the predicted label explicitly. With top_labels=1 LIME explains
        # whichever class *it* ranks first, and get_image_and_mask raises
        # KeyError if that differs from pred_idx.
        explanation = LIME_EXPLAINER.explain_instance(
            image=img,
            classifier_fn=lambda imgs: predict(imgs, model),
            labels=(label,),
            top_labels=None,
            hide_color=0,
            num_samples=200,
        )
        temp, mask = explanation.get_image_and_mask(
            label=label, positive_only=True, num_features=10, hide_rest=False
        )
        lime_img = mark_boundaries(temp, mask)

    # Encode once as PNG; the same bytes feed both the preview and the download.
    buf = BytesIO()
    plt.imsave(buf, lime_img, format="png")
    lime_data = buf.getvalue()

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("""
🔬 LIME Explanation
""", unsafe_allow_html=True)
        st.image(lime_data, width=280, output_format="PNG")
        st.download_button(
            "📥 Download LIME Analysis",
            lime_data,
            file_name=f"{pred_label}_LIME_Analysis.png",
            mime="image/png"
        )
    with col2:
        # The fallback literal was previously split across two lines, which is a
        # SyntaxError (unterminated string literal).
        st.markdown(explanation_text.get(pred_label, "No explanation available."), unsafe_allow_html=True)
# --- Confidence display ---
def confidence_icon(confidence):
    """Return the status icon for a confidence percentage.

    ``confidence >= 80`` gives "🎯", ``>= 60`` gives "⚠️", and anything lower
    gives "🔍".
    """
    if confidence >= 80:
        return "🎯"
    if confidence >= 60:
        return "⚠️"
    return "🔍"


def show_confidence(confidence, pred_label):
    """Render the predicted label and its confidence (a percentage) as markdown.

    The previous version also computed an unused ``level`` string; that dead
    code has been removed and the icon choice lives in ``confidence_icon``.
    """
    icon = confidence_icon(confidence)
    st.markdown(f"""
{icon} Diagnosis: {pred_label}
Confidence: {confidence:.1f}%
""", unsafe_allow_html=True)
# --- Page configuration, header and model ---
# NOTE(review): st.markdown has already been called earlier in this script.
# Depending on the Streamlit version, set_page_config must be the first
# Streamlit command and raises StreamlitAPIException otherwise. Confirm against
# the pinned version, or move this call to the top of the file.
st.set_page_config(
    page_title="👁️ Retina AI Classifier",
    layout="wide",
    initial_sidebar_state="expanded",
)

_HEADER_MARKDOWN = """
👁️ Retina Disease Classifier
AI-Powered Retinal Analysis with LIME Explainability
"""
st.markdown(_HEADER_MARKDOWN, unsafe_allow_html=True)

# Cached by st.cache_resource; stops the run if the model cannot be loaded.
model = load_model()
# --- Sidebar: image upload and selection ---
with st.sidebar:
    st.markdown("\n", unsafe_allow_html=True)
    uploaded_files = st.file_uploader(
        "Choose retinal images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        help="Upload high-quality fundus photographs",
    )
    # Stays None until at least one file has been uploaded.
    selected_filename = None
    if uploaded_files:
        st.markdown("\n", unsafe_allow_html=True)
        upload_names = [upload.name for upload in uploaded_files]
        selected_filename = st.selectbox(
            "Choose image for analysis",
            upload_names,
            help="Select which image to analyze with LIME",
        )
# --- Main content area ---
if uploaded_files and selected_filename:
    # Selectbox options come from uploaded_files, so a match always exists;
    # with duplicate names the first upload wins.
    file = next(f for f in uploaded_files if f.name == selected_filename)
    file.seek(0)  # rewind in case the buffer was read earlier
    bgr = cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        # cv2.imdecode returns None for corrupt or unsupported data; without
        # this check cvtColor fails with an opaque cv2.error.
        st.error(f"🚨 Could not decode '{selected_filename}' as an image.")
        st.stop()
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    st.markdown("""
🔬 Image Preprocessing Pipeline
Standardized preprocessing steps for optimal AI analysis
""", unsafe_allow_html=True)
    preprocessed = preprocess_with_steps(rgb)
    input_tensor = np.expand_dims(preprocessed, axis=0)  # add batch axis

    # Batch of one, so the flat argmax over the output is the class index.
    preds = predict(input_tensor, model)
    pred_idx = int(np.argmax(preds))
    pred_label = CLASS_NAMES[pred_idx]
    confidence = np.max(preds) * 100

    show_confidence(confidence, pred_label)

    st.markdown("""
🧠 AI Explanation & Clinical Insights
Understanding how AI identified the diagnosis with medical context
""", unsafe_allow_html=True)
    show_lime(preprocessed, model, pred_idx, pred_label, preds)
else:
    # Landing screen shown until an image is uploaded and selected.
    st.markdown("""
Welcome to the Retina AI Classifier
Upload retinal fundus images to begin AI-powered analysis
Drag and drop your images or use the sidebar to get started
""", unsafe_allow_html=True)
    st.markdown("""
🔬
AI-Powered Analysis
Advanced deep learning models trained on thousands of retinal images
👁️
8 Conditions Detected
Normal, Diabetic Retinopathy, Glaucoma, Cataract, AMD, Hypertension, Myopia, Others
🔍
LIME Explanations
Visual explanations showing which areas influenced the AI's decision
🏥
Clinical Grade
Designed for healthcare professionals with detailed medical insights
""", unsafe_allow_html=True)