"""Gradio demo that classifies retinal fundus images with an ONNX model.

Loads ``eye_disease_model.onnx`` and ``class_indices.json`` from the working
directory at import time, then serves a two-column UI: an image upload on the
left and a Markdown report of the prediction on the right.
"""

import json
import os

import gradio as gr
import numpy as np
import onnxruntime as ort
from PIL import Image

# ── Load ONNX model — ONNX Runtime avoids needing TensorFlow at runtime
# (author's estimate: ~200MB RAM vs ~2GB for TensorFlow) ──────────────────
MODEL_PATH = "eye_disease_model.onnx"
CLASS_PATH = "class_indices.json"

IMG_SIZE = 224  # side length the image is resized to before inference
BAR_WIDTH = 20  # cells per text probability bar (100% / 20 = 5% per cell)


def _load_class_names(path):
    """Return a ``{str(index): class_name}`` mapping read from the JSON at *path*.

    Accepts ``{"0": "cataract", ...}``, which is the layout ``predict`` indexes.
    It also accepts the inverted Keras ``class_indices`` layout
    ``{"cataract": 0, ...}``, which is flipped to index -> name.

    Raises:
        FileNotFoundError: if *path* does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(path, encoding="utf-8") as fh:
        raw = json.load(fh)
    if raw and all(isinstance(v, int) for v in raw.values()):
        return {str(v): k for k, v in raw.items()}
    return {str(k): v for k, v in raw.items()}


if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

idx_map = _load_class_names(CLASS_PATH)

LABEL_MAP = {
    "cataract": "🔵 Cataract",
    "diabetic_retinopathy": "🔴 Diabetic Retinopathy",
    "glaucoma": "🟡 Glaucoma",
    "normal": "🟢 Normal",
}

INFO_MAP = {
    "cataract": "Clouding of the eye lens. Treatable with surgery.",
    "diabetic_retinopathy": "Retinal damage caused by diabetes. Early detection prevents blindness.",
    "glaucoma": "Optic nerve damage from high eye pressure. Can cause permanent vision loss.",
    "normal": "No disease detected. Keep up with regular eye checkups!",
}


def _render_bar(percent):
    """Return a BAR_WIDTH-cell text bar for *percent*, clamped to [0, BAR_WIDTH] cells."""
    filled = min(max(int(percent / (100 / BAR_WIDTH)), 0), BAR_WIDTH)
    return "█" * filled + "░" * (BAR_WIDTH - filled)


def predict(image):
    """Classify a retinal image and return a Markdown report.

    Args:
        image: Pixel array from the ``gr.Image(type="numpy")`` component, or
            ``None`` when no image is present.

    Returns:
        Markdown with the top class, its confidence, a short description, and
        a text bar chart of every class score. If *image* is ``None``, returns
        a prompt asking the user to upload an image instead.
    """
    if image is None:
        return "⚠️ Please upload a retinal image."

    # convert("RGB") normalises grayscale / RGBA uploads to 3 channels.
    img = Image.fromarray(image).convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    # NOTE(review): assumes the exported model expects NHWC float32 scaled to
    # [0, 1]; confirm against the training-time preprocessing.
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = np.expand_dims(arr, axis=0)  # (1, 224, 224, 3)

    # First output tensor, first (only) batch item.
    preds = session.run(None, {input_name: arr})[0][0]
    # NOTE(review): scores are treated as probabilities (e.g. a softmax head).
    top_idx = int(np.argmax(preds))
    top_cls = idx_map.get(str(top_idx), f"class_{top_idx}")
    conf = float(preds[top_idx]) * 100

    parts = [
        f"## {LABEL_MAP.get(top_cls, top_cls)}\n",
        f"**Confidence: {conf:.1f}%**\n\n",
        f"📋 {INFO_MAP.get(top_cls, '')}\n\n",
        "---\n### All Probabilities\n\n",
    ]
    for i, score in enumerate(preds):
        cls = idx_map.get(str(i), f"class_{i}")
        prob = float(score) * 100
        label = LABEL_MAP.get(cls, cls)
        parts.append(f"`{label:<32}` {_render_bar(prob)} {prob:.1f}%\n\n")
    parts.append(
        "\n---\n⚠️ *For educational purposes only. "
        "Always consult an ophthalmologist.*"
    )
    return "".join(parts)


with gr.Blocks(title="👁️ Eye Disease Detector", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 👁️ Eye Disease Detection")
    gr.Markdown(
        "Upload a **retinal fundus image** — detects: "
        "Cataract · Diabetic Retinopathy · Glaucoma · Normal\n\n"
        "*MobileNetV2 Transfer Learning — No TensorFlow needed at runtime*"
    )
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="Upload Retinal Image", type="numpy")
            btn = gr.Button("🔍 Analyse", variant="primary")
        with gr.Column():
            output = gr.Markdown()

    # Predict on button press and also whenever the image value changes
    # (presumably None when cleared, which predict() handles).
    btn.click(fn=predict, inputs=img_input, outputs=output)
    img_input.change(fn=predict, inputs=img_input, outputs=output)

if __name__ == "__main__":
    demo.launch()