import gradio as gr import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image import numpy as np # ======================================= # CONFIGURATION # ======================================= DEVICE = torch.device("cpu") # Hugging Face Free Tier = CPU only CLASS_NAMES = ["Non-Glaucoma", "Glaucoma"] MODEL_PATH = "model_fold_0.pth" # ======================================= # LOAD MODEL # ======================================= try: model = models.resnet18(weights=None) model.fc = nn.Linear(model.fc.in_features, 2) state_dict = torch.load(MODEL_PATH, map_location=DEVICE) model.load_state_dict(state_dict) model.to(DEVICE) model.eval() except Exception as e: raise RuntimeError(f"Failed to load model from {MODEL_PATH}\nError: {str(e)}") # ======================================= # IMAGE PREPROCESSING # ======================================= transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # ======================================= # PREDICTION FUNCTION # ======================================= def predict_fundus(image): if image is None: return "Please upload a retinal fundus image to begin.", None try: img_pil = Image.fromarray(image).convert("RGB") img_tensor = transform(img_pil).unsqueeze(0).to(DEVICE) with torch.no_grad(): output = model(img_tensor) probs = torch.softmax(output, dim=1)[0].cpu().numpy() pred_idx = int(np.argmax(probs)) confidence = float(probs[pred_idx]) label = CLASS_NAMES[pred_idx] result_text = f""" ### Analysis Result **Prediction:** {label} **Confidence:** {confidence:.1%} **Non-Glaucoma Probability:** {probs[0]:.1%} **Glaucoma Probability:** {probs[1]:.1%} --- ⚠ This tool is for research and educational purposes only. It must not be used for clinical diagnosis or medical decision-making. """.strip() img_display = np.array(img_pil.resize((400, 400))) return result_text, img_display except Exception as e: return f"Error during analysis: {str(e)}", None # ======================================= # PROFESSIONAL HIGH-CONTRAST CSS # ======================================= custom_css = """ body { font-family: 'Segoe UI', sans-serif; background: #ffffff; color: #111827; } .gradio-container { max-width: 1100px !important; margin: auto; } h1 { color: #1e3a8a !important; font-weight: 700 !important; text-align: center; } h3 { color: #1f2937 !important; font-weight: 600 !important; } .markdown { color: #111827 !important; } .upload-zone { border: 2px dashed #64748b; border-radius: 12px; padding: 20px; background: white; } .result-panel { background: white; border-radius: 12px; box-shadow: 0 4px 15px rgba(0,0,0,0.08); padding: 24px; min-height: 380px; } .note { font-size: 0.95em; color: #374151; margin-top: 16px; } """ # ======================================= # GRADIO INTERFACE # ======================================= with gr.Blocks(theme=gr.themes.Default(), css=custom_css) as demo: gr.Markdown(""" # Glaucoma Screening – Fundus Image Analysis Upload a retinal fundus photograph to receive an AI-based probability assessment. """) with gr.Row(equal_height=True): with gr.Column(scale=5): gr.Markdown("### Upload Fundus Image") input_image = gr.Image( type="numpy", label="", elem_classes=["upload-zone"], height=480, image_mode="RGB" ) analyze_btn = gr.Button("Analyze Image", variant="primary") with gr.Column(scale=5): gr.Markdown("### Analysis Result") output_text = gr.Markdown( value="Upload an image and click Analyze to begin.", elem_classes=["result-panel"] ) output_image = gr.Image( label="Uploaded Image (Resized)", type="numpy", height=400, interactive=False ) gr.Markdown("""