"""Streamlit app that classifies retinal fundus images and explains the result.

For the selected upload the app shows each preprocessing stage (circular crop,
CLAHE, sharpening, resize), predicts one of ``CLASS_NAMES`` with a SavedModel
loaded from ``./Model``, and renders a LIME explanation for the predicted class.
"""

import base64  # noqa: F401  -- not used below; kept so no existing import is dropped
import os
from io import BytesIO

import cv2
import keras
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf  # noqa: F401  -- not referenced directly; TFSMLayer relies on TensorFlow
from keras.layers import BatchNormalization, DepthwiseConv2D, TFSMLayer
from lime import lime_image
from skimage.segmentation import mark_boundaries

# st.set_page_config must be the first Streamlit command the script runs;
# older Streamlit releases raise StreamlitAPIException when any st.* output
# call (such as the st.markdown below) comes first.
st.set_page_config(
    page_title="👁️ Retina AI Classifier",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Global stylesheet injection (content kept verbatim).
st.markdown(
    """ """,
    unsafe_allow_html=True,
)


# --- Keras deserialization compatibility patches ----------------------------
def _patch_from_config(layer_cls, fix_config):
    """Make ``layer_cls.from_config`` pass every config through ``fix_config``.

    Streamlit re-executes this script on every interaction, while imported
    classes such as ``layer_cls`` live for the whole process. The wrapper is
    therefore tagged and installed only once. Re-wrapping on each rerun would
    lengthen the call chain every time, and (because each rerun runs in a fresh
    module namespace) keep every previous run's globals alive via the closures.

    Args:
        layer_cls: Keras layer class whose ``from_config`` is replaced.
        fix_config: Callable mapping a config dict to the dict that is handed
            to the original ``from_config``.
    """
    current = layer_cls.from_config
    if getattr(current, "_retina_config_patch", False):
        return  # already patched by an earlier run in this process

    def patched(cls, config, *args, **kwargs):
        # As in the original patch, delegate to the from_config bound to
        # ``layer_cls`` itself; ``cls`` is not forwarded.
        return current(fix_config(config), *args, **kwargs)

    patched._retina_config_patch = True
    layer_cls.from_config = classmethod(patched)


def _unwrap_list_axis(config):
    """Return ``config`` with a list-valued ``axis`` replaced by its first element."""
    axis = config.get("axis")
    if isinstance(axis, list):
        return {**config, "axis": axis[0]}
    return config


def _drop_groups(config):
    """Return a copy of ``config`` without the ``groups`` key."""
    # NOTE(review): presumably configs saved by another Keras version carry a
    # `groups` entry the installed DepthwiseConv2D rejects — confirm.
    return {key: value for key, value in config.items() if key != "groups"}


_patch_from_config(BatchNormalization, _unwrap_list_axis)
_patch_from_config(DepthwiseConv2D, _drop_groups)


# --- Background -------------------------------------------------------------
def set_background():
    """Inject the app's static background styling (markup kept verbatim)."""
    st.markdown(""" """, unsafe_allow_html=True)


set_background()

# --- Constants --------------------------------------------------------------
IMG_SIZE = (224, 224)  # dsize passed to cv2.resize, i.e. (width, height)
# Entry i names output column i of the model (used via CLASS_NAMES[argmax]);
# presumably this order matches the training labels — confirm.
CLASS_NAMES = [
    'Normal', 'Diabetic Retinopathy', 'Glaucoma', 'Cataract',
    'Age-related Macular Degeneration (AMD)', 'Hypertension', 'Myopia', 'Others'
]
LIME_EXPLAINER = lime_image.LimeImageExplainer()


# --- Model loading ----------------------------------------------------------
@st.cache_resource
def load_model():
    """Wrap the SavedModel in ``./Model`` as a Keras model (cached across reruns).

    The model's ``serving_default`` signature is exposed through a single
    ``TFSMLayer``. If the folder is missing or loading fails, an error is
    shown and the script is stopped via ``st.stop()``.
    """
    model_path = "Model"
    if not os.path.exists(model_path):
        st.error(f"🚨 Model folder '{model_path}' not found.")
        st.stop()
    try:
        # Use the same Keras package the TFSMLayer comes from; tf.keras can be
        # a different (legacy) Keras that rejects layers from standalone keras.
        return keras.Sequential(
            [TFSMLayer(model_path, call_endpoint="serving_default")]
        )
    except Exception as e:  # top-level boundary: report and stop the app
        st.error(f"🚨 Error loading model: {e}")
        st.stop()


# --- Prediction -------------------------------------------------------------
def predict(images, model):
    """Run ``model`` on a batch of images and return its output as an array.

    Used both for the main prediction and as LIME's ``classifier_fn``.

    Args:
        images: Batch of images (anything ``np.array`` accepts).
        model: Model exposing ``predict(batch, verbose=0)``.

    Returns:
        The prediction array. If the model returns a dict of named outputs,
        the first value that is an ``ndarray``/``list`` is returned (falling
        back to the first value).

    Raises:
        ValueError: If the model returns an empty dict.
    """
    preds = model.predict(np.array(images), verbose=0)
    if not isinstance(preds, dict):
        return preds
    outputs = list(preds.values())
    if not outputs:
        raise ValueError("Model returned an empty output dict.")
    for value in outputs:
        if isinstance(value, (np.ndarray, list)):
            return np.array(value)
    return np.array(outputs[0])


# --- Preprocessing ----------------------------------------------------------
def preprocess_with_steps(img):
    """Preprocess an RGB fundus image, plot every stage, and return the model input.

    Stages: circular crop (pixels outside the inscribed circle painted white)
    -> CLAHE on the LAB lightness channel -> sharpening computed as
    ``4*img - 4*GaussianBlur(img, sigma=10) + 128`` -> resize to ``IMG_SIZE``
    and scale to [0, 1]. The four stages are rendered side by side.

    Args:
        img: ``uint8`` RGB image of shape (H, W, 3), as decoded by the app.

    Returns:
        Float array of shape (IMG_SIZE[1], IMG_SIZE[0], 3) with values in [0, 1].
    """
    height, width = img.shape[:2]
    center_x, center_y = width // 2, height // 2
    radius = min(width, height) // 2

    # Keep only the largest centered circle; everything else becomes white.
    yy, xx = np.ogrid[:height, :width]
    dist = np.sqrt((xx - center_x) ** 2 + (yy - center_y) ** 2)
    mask = (dist <= radius).astype(np.uint8)
    white_bg = np.ones_like(img, dtype=np.uint8) * 255
    circ = np.where(mask[:, :, np.newaxis] == 1, img, white_bg)

    # Contrast-equalize lightness only, so colour channels are untouched.
    lab = cv2.cvtColor(circ, cv2.COLOR_RGB2LAB)
    lightness = cv2.createCLAHE(clipLimit=2.0).apply(lab[:, :, 0])
    clahe_img = cv2.cvtColor(
        cv2.merge((lightness, lab[:, :, 1], lab[:, :, 2])), cv2.COLOR_LAB2RGB
    )

    # Subtract a heavy blur to emphasise local detail; uint8 output saturates.
    sharp = cv2.addWeighted(
        clahe_img, 4, cv2.GaussianBlur(clahe_img, (0, 0), 10), -4, 128
    )
    resized = cv2.resize(sharp, IMG_SIZE) / 255.0

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    fig.patch.set_facecolor('white')
    stages = (
        (img, "Original"),
        (circ, "Circular Crop"),
        (clahe_img, "CLAHE"),
        (resized, "Sharpen + Resize"),
    )
    for ax, (image, title) in zip(axs, stages):
        ax.imshow(image)
        ax.set_title(title, fontsize=14, fontweight='bold', color='#1e40af')
        ax.axis("off")
    plt.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # release the figure; pyplot keeps references otherwise
    return resized


# Per-class text shown beside the LIME image; keys must match CLASS_NAMES.
explanation_text = {
    'Normal': "\n\n✅ Normal Retina\n\n",
    'Diabetic Retinopathy': "\n\n⚠️ Diabetic Retinopathy\n\n",
    'Glaucoma': "\n\n👁 Glaucoma\n\n",
    'Cataract': "\n\n🌫️ Cataract\n\n",
    'Age-related Macular Degeneration (AMD)': "\n\n🧓 Age-related Macular Degeneration\n\n",
    'Hypertension': "\n\n🩸 Hypertensive Retinopathy\n\n",
    'Myopia': "\n\n👓 Myopic Changes\n\n",
    'Others': "\n\n🔍 Unclassified Findings\n\n",
}


# --- LIME explanation -------------------------------------------------------
def show_lime(img, model, pred_idx, pred_label, all_probs):
    """Explain the predicted class with LIME and render it next to its description.

    Args:
        img: Preprocessed image passed to LIME (output of preprocess_with_steps).
        model: Model used by ``predict`` for LIME's perturbed samples.
        pred_idx: Index of the predicted class; this exact label is explained.
        pred_label: Class name, used for the description and download filename.
        all_probs: Unused; kept for interface compatibility.
    """
    with st.spinner("🔬 Generating LIME explanation..."):
        # Ask LIME for pred_idx explicitly. With top_labels=1 LIME picks its
        # own argmax, and get_image_and_mask(label=pred_idx) raises KeyError
        # whenever that differs from pred_idx (e.g. on near-ties).
        explanation = LIME_EXPLAINER.explain_instance(
            image=img,
            classifier_fn=lambda imgs: predict(imgs, model),
            labels=(pred_idx,),
            top_labels=None,
            hide_color=0,
            num_samples=200,
        )
        temp, mask = explanation.get_image_and_mask(
            label=pred_idx, positive_only=True, num_features=10, hide_rest=False
        )
        lime_img = mark_boundaries(temp, mask)
        buf = BytesIO()
        plt.imsave(buf, lime_img, format="png")
        lime_data = buf.getvalue()

    col1, col2 = st.columns(2)
    with col1:
        st.markdown("\n\n🔬 LIME Explanation\n\n", unsafe_allow_html=True)
        st.image(lime_data, width=280, output_format="PNG")
        st.download_button(
            "📥 Download LIME Analysis",
            lime_data,
            file_name=f"{pred_label}_LIME_Analysis.png",
            mime="image/png",
        )
    with col2:
        st.markdown(
            explanation_text.get(pred_label, "\n\nNo explanation available.\n\n"),
            unsafe_allow_html=True,
        )


# --- Confidence display -----------------------------------------------------
def show_confidence(confidence, pred_label):
    """Render the diagnosis with an icon chosen by confidence (a percentage)."""
    if confidence >= 80:
        icon = "🎯"
    elif confidence >= 60:
        icon = "⚠️"
    else:
        icon = "🔍"
    st.markdown(
        f"\n\n{icon} Diagnosis: {pred_label}\n\nConfidence: {confidence:.1f}%\n\n",
        unsafe_allow_html=True,
    )


# --- Page content -----------------------------------------------------------
st.markdown(
    "\n\n👁️ Retina Disease Classifier\n\n"
    "AI-Powered Retinal Analysis with LIME Explainability\n\n",
    unsafe_allow_html=True,
)

model = load_model()

with st.sidebar:
    st.markdown(""" """, unsafe_allow_html=True)
    uploaded_files = st.file_uploader(
        "Choose retinal images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
        help="Upload high-quality fundus photographs",
    )
    selected_filename = None
    if uploaded_files:
        st.markdown(""" """, unsafe_allow_html=True)
        filenames = [f.name for f in uploaded_files]
        selected_filename = st.selectbox(
            "Choose image for analysis",
            filenames,
            help="Select which image to analyze with LIME",
        )

if uploaded_files and selected_filename:
    # With duplicate filenames the first matching upload is analysed.
    uploaded = next(f for f in uploaded_files if f.name == selected_filename)
    raw_bytes = uploaded.getvalue()
    # imdecode returns None for bytes it cannot decode; empty uploads are
    # treated the same way instead of being handed to OpenCV.
    bgr = (
        cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
        if raw_bytes
        else None
    )
    if bgr is None:
        st.error(f"🚨 Could not read '{selected_filename}' as an image.")
        st.stop()
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    st.markdown(
        "\n\n🔬 Image Preprocessing Pipeline\n\n"
        "Standardized preprocessing steps for optimal AI analysis\n\n",
        unsafe_allow_html=True,
    )
    preprocessed = preprocess_with_steps(rgb)
    input_tensor = np.expand_dims(preprocessed, axis=0)  # batch of one

    preds = predict(input_tensor, model)
    pred_idx = int(np.argmax(preds))
    pred_label = CLASS_NAMES[pred_idx]
    confidence = float(np.max(preds)) * 100

    show_confidence(confidence, pred_label)

    st.markdown(
        "\n\n🧠 AI Explanation & Clinical Insights\n\n"
        "Understanding how AI identified the diagnosis with medical context\n\n",
        unsafe_allow_html=True,
    )
    show_lime(preprocessed, model, pred_idx, pred_label, preds)
else:
    st.markdown(
        "\n\nWelcome to the Retina AI Classifier\n\n"
        "Upload retinal fundus images to begin AI-powered analysis\n\n"
        "Drag and drop your images or use the sidebar to get started\n\n",
        unsafe_allow_html=True,
    )
    st.markdown(
        "\n"
        "🔬\n"
        "AI-Powered Analysis\n"
        "Advanced deep learning models trained on thousands of retinal images\n"
        "👁️\n"
        "8 Conditions Detected\n"
        "Normal, Diabetic Retinopathy, Glaucoma, Cataract, AMD, Hypertension, Myopia, Others\n"
        "🔍\n"
        "LIME Explanations\n"
        "Visual explanations showing which areas influenced the AI's decision\n"
        "🏥\n"
        "Clinical Grade\n"
        "Designed for healthcare professionals with detailed medical insights\n",
        unsafe_allow_html=True,
    )