# Author: datakid
# Commit: 68218ae — "Update app.py" (verified)
import gradio as gr
import numpy as np
import json
import os
from PIL import Image
import onnxruntime as ort
# ── Load ONNX model — only ~200MB RAM (vs 2GB for TensorFlow) ─
MODEL_PATH = "eye_disease_model.onnx"
CLASS_PATH = "class_indices.json"

# Fail fast with a clear message before doing any expensive loading.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
if not os.path.exists(CLASS_PATH):
    raise FileNotFoundError(f"Class index file not found: {CLASS_PATH}")

# CPU-only inference session; predict() feeds the model's first input.
session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# predict() looks classes up as idx_map[str(index)], i.e. it needs an
# index -> class-name mapping with string keys (JSON object keys are
# always strings).
with open(CLASS_PATH, encoding="utf-8") as f:
    idx_map = json.load(f)
# A class-name -> index mapping (the orientation of Keras' `class_indices`)
# is inverted so either layout of the JSON file works.
if isinstance(idx_map, dict) and idx_map and all(
    isinstance(v, int) for v in idx_map.values()
):
    idx_map = {str(v): k for k, v in idx_map.items()}
# Side length (pixels) that uploaded images are resized to before inference.
IMG_SIZE = 224

# Display text per class. Keys must match the class names that idx_map
# (loaded from class_indices.json) yields; LABEL_MAP and INFO_MAP are looked
# up with the same key, so they should share the same key set.
LABEL_MAP = {
    "cataract": "🔵 Cataract",
    "diabetic_retinopathy": "🔴 Diabetic Retinopathy",
    "glaucoma": "🟡 Glaucoma",
    "normal": "🟢 Normal",
}

# One-sentence explanation shown under the top prediction.
INFO_MAP = {
    "cataract": "Clouding of the eye lens. Treatable with surgery.",
    "diabetic_retinopathy": "Retinal damage caused by diabetes. Early detection prevents blindness.",
    "glaucoma": "Optic nerve damage from high eye pressure. Can cause permanent vision loss.",
    "normal": "No disease detected. Keep up with regular eye checkups!",
}
def _label(cls):
    """Return the display label for class name *cls*, falling back to the raw name."""
    return LABEL_MAP.get(cls, cls)


def _bar(prob):
    """Render a percentage as a 20-cell text bar (one filled cell per 5 %)."""
    # Clamp so out-of-range scores (e.g. raw logits) cannot yield a negative
    # or over-long bar; for scores in [0, 100] this is a no-op.
    filled = min(max(int(prob / 5), 0), 20)
    return "█" * filled + "░" * (20 - filled)


def predict(image):
    """Classify a retinal fundus image and return a Markdown report.

    Args:
        image: The uploaded image as a NumPy array (the Gradio input is
            ``gr.Image(type="numpy")``), or ``None`` when nothing is uploaded
            or the input was cleared.

    Returns:
        A Markdown string with the top class, its confidence, a short
        description and a text bar chart of every class score, or a warning
        message when ``image`` is ``None``.

    Note:
        Scores are shown as percentages as-is, which assumes the model's first
        output is already a probability distribution (softmax head) — TODO
        confirm against the exported ONNX model.
    """
    if image is None:
        return "⚠️ Please upload a retinal image."

    # Convert to 3-channel RGB (so grayscale/RGBA inputs also work), resize to
    # the model's input size, scale pixels to [0, 1] and add a batch axis:
    # shape (1, IMG_SIZE, IMG_SIZE, 3).
    img = Image.fromarray(image).convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    batch = np.expand_dims(np.asarray(img, dtype=np.float32) / 255.0, axis=0)

    # First output tensor, first (only) item of the batch: one score per class.
    preds = session.run(None, {input_name: batch})[0][0]

    top_idx = int(np.argmax(preds))
    top_cls = idx_map[str(top_idx)]
    conf = float(preds[top_idx]) * 100

    parts = [
        f"## {_label(top_cls)}\n",
        f"**Confidence: {conf:.1f}%**\n\n",
    ]
    info = INFO_MAP.get(top_cls)
    if info:
        parts.append(f"📋 {info}\n\n")
    parts.append("---\n### All Probabilities\n\n")
    for i, score in enumerate(preds):
        prob = float(score) * 100
        # Pad the label inside the code span so the bars line up.
        parts.append(f"`{_label(idx_map[str(i)]):<32}` {_bar(prob)} {prob:.1f}%\n\n")
    parts.append("\n---\n⚠️ *For educational purposes only. Always consult an ophthalmologist.*")
    return "".join(parts)
# Two-column layout: image upload + button on the left, Markdown report on
# the right. predict() runs both on button click and whenever the image
# input changes (including when it is cleared, which yields the warning).
with gr.Blocks(title="👁️ Eye Disease Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 👁️ Eye Disease Detection")
    gr.Markdown(
        "Upload a **retinal fundus image** — detects: "
        "Cataract · Diabetic Retinopathy · Glaucoma · Normal\n\n"
        "*MobileNetV2 Transfer Learning — No TensorFlow needed at runtime*"
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Upload Retinal Image", type="numpy")
            btn = gr.Button("🔍 Analyse", variant="primary")
        with gr.Column():
            output = gr.Markdown()
    # Event listeners must be registered inside the Blocks context.
    btn.click(fn=predict, inputs=img_input, outputs=output)
    img_input.change(fn=predict, inputs=img_input, outputs=output)

# Only start the server when run as a script (`python app.py`), so the module
# can be imported (e.g. by Gradio reload mode or tests) without launching.
if __name__ == "__main__":
    demo.launch()