# app.py — Gradio demo: glaucoma screening from retinal fundus photographs.
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
# =======================================
# CONFIGURATION
# =======================================
DEVICE = torch.device("cpu")  # Hugging Face Free Tier = CPU only
CLASS_NAMES = ["Non-Glaucoma", "Glaucoma"]  # index must match the training label order
MODEL_PATH = "model_fold_0.pth"  # fine-tuned checkpoint shipped with the Space

# =======================================
# LOAD MODEL
# =======================================
# ResNet-18 backbone with a fresh 2-class head. weights=None: all parameters
# come from the fine-tuned checkpoint, so no ImageNet download is needed.
try:
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, 2)
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()  # inference mode: freezes dropout / batch-norm statistics
except Exception as e:
    # Chain with `from e` so the true cause (missing file, shape mismatch,
    # corrupt checkpoint) stays visible in the traceback instead of being
    # flattened into the message string alone.
    raise RuntimeError(f"Failed to load model from {MODEL_PATH}\nError: {str(e)}") from e
# =======================================
# IMAGE PREPROCESSING
# =======================================
# Resize -> 224x224 center crop -> tensor -> normalize with the standard
# ImageNet channel statistics (what the ResNet-18 backbone was trained on).
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)
# =======================================
# PREDICTION FUNCTION
# =======================================
def predict_fundus(image):
    """Classify a fundus photograph and format the result for the UI.

    Args:
        image: HxWxC numpy array from the Gradio image widget, or None when
            the user has not uploaded anything yet.

    Returns:
        Tuple of (markdown result text, display image array). The display
        image is None when no prediction was made.
    """
    if image is None:
        return "Please upload a retinal fundus image to begin.", None
    try:
        # Gradio hands us a numpy array; round-trip through PIL so the same
        # preprocessing pipeline used in training can be applied.
        rgb_image = Image.fromarray(image).convert("RGB")
        batch = transform(rgb_image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits = model(batch)
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
        top_idx = int(np.argmax(probs))
        label = CLASS_NAMES[top_idx]
        confidence = float(probs[top_idx])
        result_text = f"""
### Analysis Result
**Prediction:** {label}
**Confidence:** {confidence:.1%}
**Non-Glaucoma Probability:** {probs[0]:.1%}
**Glaucoma Probability:** {probs[1]:.1%}
---
⚠ This tool is for research and educational purposes only.
It must not be used for clinical diagnosis or medical decision-making.
""".strip()
        preview = np.array(rgb_image.resize((400, 400)))
        return result_text, preview
    except Exception as e:
        # Surface the failure in the UI rather than crashing the Space.
        return f"Error during analysis: {str(e)}", None
# =======================================
# PROFESSIONAL HIGH-CONTRAST CSS
# =======================================
custom_css = """
body {
font-family: 'Segoe UI', sans-serif;
background: #ffffff;
color: #111827;
}
.gradio-container {
max-width: 1100px !important;
margin: auto;
}
h1 {
color: #1e3a8a !important;
font-weight: 700 !important;
text-align: center;
}
h3 {
color: #1f2937 !important;
font-weight: 600 !important;
}
.markdown {
color: #111827 !important;
}
.upload-zone {
border: 2px dashed #64748b;
border-radius: 12px;
padding: 20px;
background: white;
}
.result-panel {
background: white;
border-radius: 12px;
box-shadow: 0 4px 15px rgba(0,0,0,0.08);
padding: 24px;
min-height: 380px;
}
.note {
font-size: 0.95em;
color: #374151;
margin-top: 16px;
}
"""
# =======================================
# GRADIO INTERFACE
# =======================================
# Two-column layout: upload + button on the left, markdown result and a
# resized preview of the analyzed image on the right.
with gr.Blocks(theme=gr.themes.Default(), css=custom_css) as demo:
    gr.Markdown("""
    # Glaucoma Screening – Fundus Image Analysis
    Upload a retinal fundus photograph to receive an AI-based probability assessment.
    """)
    with gr.Row(equal_height=True):
        with gr.Column(scale=5):
            gr.Markdown("### Upload Fundus Image")
            # type="numpy" matches what predict_fundus expects as input.
            input_image = gr.Image(
                type="numpy",
                label="",
                elem_classes=["upload-zone"],
                height=480,
                image_mode="RGB"
            )
            analyze_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column(scale=5):
            gr.Markdown("### Analysis Result")
            # Placeholder markdown is replaced by predict_fundus output on click.
            output_text = gr.Markdown(
                value="Upload an image and click Analyze to begin.",
                elem_classes=["result-panel"]
            )
            output_image = gr.Image(
                label="Uploaded Image (Resized)",
                type="numpy",
                height=400,
                interactive=False
            )
    gr.Markdown("""
    <div class="note">
    <strong>Important:</strong> This is an experimental AI model trained on limited data.
    Results should be interpreted cautiously and verified by a qualified ophthalmologist.
    </div>
    """, elem_classes=["note"])
    # Wire the button: one input image in, (markdown, image) pair out.
    analyze_btn.click(
        fn=predict_fundus,
        inputs=input_image,
        outputs=[output_text, output_image]
    )

# =======================================
# LAUNCH (HF Compatible)
# =======================================
demo.launch()