# OcuScan / app.py
# rocky250's picture
# Update app.py
# a360a8d verified
import os
import gc
import hashlib
import numpy as np
import gradio as gr
from PIL import Image, ImageFilter, ImageDraw
# One row per disease class: (label, single-letter code, UI colour, description).
# The public lists/dict below are derived from this table so they can never
# drift out of index alignment with each other.
_DISEASE_TABLE = (
    (
        "Normal", "N", "#22c55e",
        "No pathological findings detected. Retinal structures appear within normal limits.",
    ),
    (
        "Diabetes", "D", "#f59e0b",
        "Diabetic retinopathy detected. Microaneurysms or retinal hemorrhages may be present.",
    ),
    (
        "Glaucoma", "G", "#3b82f6",
        "Glaucomatous optic neuropathy suspected. Optic disc cupping may be elevated.",
    ),
    (
        "Cataract", "C", "#8b5cf6",
        "Lens opacity detected. Light scattering may reduce visual acuity.",
    ),
    (
        "Age-related Macular Degeneration", "A", "#ec4899",
        "AMD features detected. Drusen or macular changes may be present.",
    ),
    (
        "Hypertension", "H", "#ef4444",
        "Hypertensive retinopathy detected. AV nicking or flame hemorrhages may be visible.",
    ),
    (
        "Pathological Myopia", "M", "#06b6d4",
        "Pathological myopia detected. Posterior staphyloma or myopic degeneration suspected.",
    ),
    (
        "Other", "O", "#6b7280",
        "Other retinal abnormality detected. Further specialist evaluation recommended.",
    ),
)

# Index-aligned views: position i in each list refers to the same class.
DISEASE_LABELS = [row[0] for row in _DISEASE_TABLE]
DISEASE_CODES = [row[1] for row in _DISEASE_TABLE]
DISEASE_COLORS = [row[2] for row in _DISEASE_TABLE]
# Label -> one-sentence description shown on the result card.
DISEASE_DESC = {row[0]: row[3] for row in _DISEASE_TABLE}
# Directory that contains this script.
_HERE = os.path.dirname(__file__)
# Checkpoint filenames that are probed, in priority order.
_WEIGHT_FILES = ("pytorch_model.bin", "best_model.pth")

# Candidate checkpoint paths, tried first to last: relative to the working
# directory, next to this script, then under the fixed /home/user/app path.
_MODEL_CANDIDATES = [
    *_WEIGHT_FILES,
    *(os.path.join(_HERE, fname) for fname in _WEIGHT_FILES),
    "/home/user/app/pytorch_model.bin",
    "/home/user/app/best_model.pth",
]

# Process-wide cache written by load_model().
_CACHED_MODEL = None  # model object, or the "SIMULATION" sentinel, once resolved
_MODEL_LOADED = False  # True only after real weights were loaded
_LOAD_ERROR = None  # text of the exception that forced the fallback, if any
def _pil_resize(img_array, size):
    """Resize an image array to a ``size`` x ``size`` square.

    The array is cast to ``uint8``, resized by Pillow with bilinear
    resampling, and handed back as a NumPy array.
    """
    as_uint8 = img_array.astype(np.uint8)
    resized = Image.fromarray(as_uint8).resize((size, size), Image.BILINEAR)
    return np.array(resized)
def _gaussian_blur_np(arr, radius=15):
    """Gaussian-blur a float array with Pillow and return it as float32.

    The input is scaled by 255 and cast to ``uint8`` for filtering (the caller
    in this file passes values in [0, 1)), then scaled back by 1/255.
    """
    quantised = (arr * 255).astype(np.uint8)
    blurred = Image.fromarray(quantised).filter(ImageFilter.GaussianBlur(radius=radius))
    return np.array(blurred).astype(np.float32) / 255.0
def _colormap_jet(gray):
    """Map intensities to a jet-like RGB ramp.

    Values are clipped to [0, 1]. Each channel is a triangular ramp
    ``clip(1.5 - |4*v - centre|, 0, 1)`` with centres 3 (red), 2 (green) and
    1 (blue). The result has a trailing axis of length 3 and dtype float32.
    """
    scaled = np.clip(gray, 0.0, 1.0) * 4
    channels = [
        np.clip(1.5 - np.abs(scaled - centre), 0, 1)
        for centre in (3, 2, 1)  # red, green, blue
    ]
    return np.stack(channels, axis=-1).astype(np.float32)
def _morphology_close(binary, kernel_size=9):
    """Run a max filter and then a min filter of the same size over a mask.

    Max-then-min is a morphological closing: small gaps between foreground
    pixels are filled before the region is shrunk back.
    """
    mask = Image.fromarray(binary)
    for rank_filter in (ImageFilter.MaxFilter, ImageFilter.MinFilter):
        mask = mask.filter(rank_filter(kernel_size))
    return np.array(mask)
def _find_bboxes(binary, min_area=180):
    """Return bounding boxes of the 4-connected foreground regions of a mask.

    Args:
        binary: 2-D array; every element greater than zero is foreground.
        min_area: A region is reported only when it contains strictly more
            than this many pixels. Defaults to 180.

    Returns:
        A list of ``(x, y, w, h)`` tuples of plain ints, ordered by the
        row-major position of each region's first pixel. ``(x, y)`` is the
        top-left pixel; ``w`` and ``h`` are ``max - min`` of the column/row
        indices, i.e. the offset to the last pixel (one less than the pixel
        extent), which is what ``draw.rectangle([x, y, x + w, y + h])`` needs.
    """
    from collections import deque

    mask = np.asarray(binary) > 0
    height, width = mask.shape
    visited = np.zeros((height, width), dtype=bool)
    boxes = []

    seed_rows, seed_cols = np.nonzero(mask)
    # .tolist() yields Python ints, which keeps the flood fill and the
    # returned coordinates free of NumPy scalar types.
    for seed in zip(seed_rows.tolist(), seed_cols.tolist()):
        if visited[seed]:
            continue
        visited[seed] = True
        queue = deque([seed])
        min_r = max_r = seed[0]
        min_c = max_c = seed[1]
        area = 0

        # Breadth-first flood fill over the 4-neighbourhood.
        while queue:
            r, c = queue.popleft()
            area += 1
            min_r, max_r = min(min_r, r), max(max_r, r)
            min_c, max_c = min(min_c, c), max(max_c, c)
            for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                if not (0 <= nr < height and 0 <= nc < width):
                    continue
                if mask[nr, nc] and not visited[nr, nc]:
                    visited[nr, nc] = True
                    queue.append((nr, nc))

        if area > min_area:
            boxes.append((min_c, min_r, max_c - min_c, max_r - min_r))
    return boxes
def _build_swin_model():
    """Construct an untrained ``SwinOcular`` network with 8 outputs.

    timm is treated as optional: if it cannot be imported the function
    returns ``None`` instead of raising. Sub-module attribute names and their
    creation order are part of the checkpoint contract (they determine the
    ``state_dict`` keys), so they must not be renamed or reordered.
    """
    import torch
    import torch.nn as nn

    try:
        import timm
    except ImportError:
        return None

    # Widths of the auxiliary feature branches.
    tab_dim, text_dim, graph_dim = 128, 256, 128

    class SwinOcular(nn.Module):
        """Swin backbone fused with text and label-embedding features."""

        def __init__(self, num_classes=8):
            super().__init__()
            # num_classes=0 asks timm for the backbone without a classifier head.
            self.swin = timm.create_model(
                "swin_base_patch4_window7_224", pretrained=False, num_classes=0
            )
            swin_dim = self.swin.num_features
            self.tab_proj = nn.Sequential(
                nn.Linear(3, 64), nn.LayerNorm(64), nn.GELU(), nn.Linear(64, tab_dim)
            )
            self.text_proj = nn.Sequential(
                nn.Linear(768, text_dim), nn.LayerNorm(text_dim), nn.GELU()
            )
            self.cross_attn = nn.MultiheadAttention(
                embed_dim=swin_dim, num_heads=8, batch_first=True
            )
            self.graph_node_emb = nn.Embedding(num_classes, graph_dim)
            self.gat_w = nn.Linear(graph_dim, graph_dim)
            fusion_in = swin_dim + graph_dim + text_dim
            self.fusion = nn.Sequential(
                nn.LayerNorm(fusion_in),
                nn.Linear(fusion_in, 512),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(512, num_classes),
            )

        def forward(self, img, meta=None, text=None):
            batch = img.shape[0]
            img_feat = self.swin(img)
            # A 2-D feature tensor is given a length-1 sequence axis for attention.
            img_seq = img_feat if img_feat.dim() != 2 else img_feat.unsqueeze(1)

            # NOTE(review): the tabular branch is evaluated but its output is
            # not part of the fusion input below -- confirm whether that
            # matches how the checkpoint was trained before changing it.
            if meta is None:
                _tab_feat = torch.zeros(batch, tab_dim, device=img.device)
            else:
                _tab_feat = self.tab_proj(meta)

            if text is None:
                text_feat = torch.zeros(batch, text_dim, device=img.device)
            else:
                text_feat = self.text_proj(text)

            attn_out, _ = self.cross_attn(img_seq, img_seq, img_seq)
            img_fused = attn_out.squeeze(1) if attn_out.dim() == 3 else attn_out

            # Mean of the transformed class embeddings, repeated per sample.
            node_feat = torch.tanh(self.gat_w(self.graph_node_emb.weight))
            graph_feat = node_feat.mean(0, keepdim=True).expand(batch, -1)

            fused = torch.cat([img_fused, graph_feat, text_feat], dim=1)
            return self.fusion(fused)

    return SwinOcular(num_classes=8)
def load_model():
    """Resolve the classifier once and cache the outcome in module globals.

    The first existing path in ``_MODEL_CANDIDATES`` is loaded into the
    network built by ``_build_swin_model``. Any failure (no checkpoint file,
    missing torch/timm, unreadable or incompatible checkpoint) falls back to
    the ``"SIMULATION"`` sentinel instead of raising.

    Returns:
        A ``(model, loaded, error)`` tuple. ``model`` is the network in eval
        mode or the string ``"SIMULATION"``; ``loaded`` is True only when
        checkpoint tensors were actually copied into the network; ``error``
        is the text of the exception that forced the fallback, or ``None``.
    """
    global _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR

    if _CACHED_MODEL is not None:
        return _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR

    model_path = next((p for p in _MODEL_CANDIDATES if os.path.isfile(p)), None)
    if model_path is None:
        _CACHED_MODEL = "SIMULATION"
        _MODEL_LOADED = False
        _LOAD_ERROR = None
        return _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR

    try:
        import torch

        torch.set_num_threads(2)
        model = _build_swin_model()
        if model is None:
            raise ImportError("timm not available")

        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # That is only acceptable because the checkpoint is a file shipped
        # with the app; never point these paths at user-supplied uploads.
        try:
            state = torch.load(
                model_path, map_location="cpu", weights_only=False, mmap=True
            )
        except (TypeError, RuntimeError, ValueError):
            # TypeError: this torch has no ``mmap`` argument.
            # RuntimeError: the file is in the legacy (non-zip) format, which
            # cannot be memory-mapped. Retry with a plain load either way.
            state = torch.load(model_path, map_location="cpu", weights_only=False)

        # Training scripts often wrap the weights in a bookkeeping dict.
        if isinstance(state, dict):
            for key in ("model_state_dict", "state_dict", "model"):
                if key in state:
                    state = state[key]
                    break
        if not hasattr(state, "items"):
            raise TypeError(
                f"checkpoint holds {type(state).__name__}, expected a state dict"
            )

        own = model.state_dict()
        filtered = {}
        for name, tensor in state.items():
            # Checkpoints saved from nn.DataParallel prefix keys with "module.".
            if name not in own and name.startswith("module."):
                name = name[len("module."):]
            if name in own and getattr(tensor, "shape", None) == own[name].shape:
                filtered[name] = tensor
        if not filtered:
            # Without this guard a completely mismatched checkpoint would
            # leave the network randomly initialised yet reported as real.
            raise RuntimeError(
                "checkpoint contains no tensors matching the model architecture"
            )

        own.update(filtered)
        model.load_state_dict(own)
        model.eval()
        # Drop the extra copies of the weights before serving requests.
        del state, own, filtered
        gc.collect()

        _CACHED_MODEL = model
        _MODEL_LOADED = True
        _LOAD_ERROR = None
    except Exception as exc:
        # Deliberate catch-all: the app must still start in demo mode.
        _CACHED_MODEL = "SIMULATION"
        _MODEL_LOADED = False
        _LOAD_ERROR = str(exc)

    return _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR
def _simulate_probs(img_array, age_norm, sex_male):
    """Produce deterministic placeholder scores for the eight classes.

    The random generator is seeded from the MD5 of the image bytes XOR-ed
    with the age and sex inputs, so identical inputs always give identical
    scores. The scores are then biased by age and by image darkness. They
    carry no diagnostic meaning.

    Returns:
        A float32 array of length 8; index 0 lies in [0.05, 0.95] and the
        remaining entries in [0.02, 0.88].
    """
    digest = hashlib.md5(img_array.tobytes()).digest()
    seed = int.from_bytes(digest[:4], "little")
    seed ^= int(age_norm * 1e6)
    seed ^= sex_male * 999
    scores = np.random.default_rng(seed).random(8).astype(np.float32)

    # Index 0 ("Normal") is not random: it falls linearly with age.
    scores[0] = max(0.0, 0.85 - age_norm * 0.6)

    # Age raises selected classes; (class index, weight) in application order.
    for idx, weight in ((4, 0.4), (2, 0.25), (5, 0.20), (1, 0.15)):
        scores[idx] += age_norm * weight

    # Darker images raise classes 1 and 3.
    darkness = 1.0 - img_array.mean() / 255.0
    scores[1] += darkness * 0.3
    scores[3] += darkness * 0.2

    scores = np.clip(scores, 0.0, 1.0)
    scores[0] = np.clip(scores[0], 0.05, 0.95)
    # Element-wise on purpose: keeps the scalar arithmetic of each entry.
    for idx in range(1, 8):
        scores[idx] = np.clip(scores[idx] * 0.7, 0.02, 0.88)
    return scores
def _gradient_saliency(model, tensor, meta, text):
    """Return a 224x224 float32 map of input-gradient magnitude for the top logit."""
    import torch

    inp = tensor.clone().requires_grad_(True)
    logits = model(inp, meta, text)
    # autograd.grad differentiates with respect to the input only. Calling
    # backward() instead would also fill .grad on every model parameter and
    # keep those buffers allocated after the request finishes.
    (grad,) = torch.autograd.grad(logits[0].max(), inp)
    sal = grad.abs()[0].mean(0).numpy()
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)
    sal_rgb = (sal * 255).astype(np.uint8)[..., np.newaxis].repeat(3, axis=-1)
    return _pil_resize(sal_rgb, 224)[:, :, 0].astype(np.float32) / 255.0


def _synthetic_saliency(img_resized):
    """Return a centre-weighted placeholder map seeded by the image bytes."""
    yy, xx = np.mgrid[:224, :224]
    radial = np.exp(-((xx - 112) ** 2 + (yy - 112) ** 2) / (2 * 55 ** 2))
    digest = hashlib.md5(img_resized.tobytes()).digest()
    rng = np.random.default_rng(int.from_bytes(digest[:4], "little"))
    noise = _gaussian_blur_np(rng.random((224, 224)).astype(np.float32), radius=15)
    return np.clip(radial * 0.6 + noise * 0.4, 0, 1).astype(np.float32)


def run_inference(model, img_input, age_norm, sex_male):
    """Score one fundus image and render its two visualisations.

    Args:
        model: The network returned by ``load_model`` or the string
            ``"SIMULATION"``.
        img_input: A NumPy image array or a PIL image.
        age_norm: Patient age divided by 100.
        sex_male: 1 for male, 0 otherwise.

    Returns:
        ``(probs, cam_overlay, bbox_img)``: eight per-class scores, the
        224x224 RGB heat-map overlay, and the 224x224 RGB image with red
        rectangles around the most salient regions.
    """
    if isinstance(img_input, np.ndarray):
        pil_img = Image.fromarray(img_input.astype(np.uint8))
    else:
        pil_img = img_input
    img = np.array(pil_img.convert("RGB"))
    img_resized = _pil_resize(img, 224)
    img_float = img_resized.astype(np.float32) / 255.0

    saliency = None
    if model == "SIMULATION":
        probs = _simulate_probs(img_resized, age_norm, sex_male)
    else:
        import torch

        # Per-channel mean/std normalisation (the usual ImageNet constants).
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        img_norm = (img_float - mean) / std

        tensor = torch.from_numpy(img_norm.transpose(2, 0, 1)).unsqueeze(0)
        meta = torch.tensor([[age_norm, float(sex_male), float(1 - sex_male)]])
        text = torch.zeros(1, 768)
        with torch.inference_mode():
            logits = model(tensor, meta, text)
        probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

        try:
            saliency = _gradient_saliency(model, tensor, meta, text)
        except Exception as exc:
            # Best effort: the scores are still valid, so fall back to the
            # placeholder map below, but leave a trace of what went wrong.
            import warnings

            warnings.warn(
                f"gradient saliency failed ({exc}); using synthetic map",
                RuntimeWarning,
            )

    if saliency is None:
        saliency = _synthetic_saliency(img_resized)
    saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)

    heatmap = _colormap_jet(saliency)
    cam_overlay = np.clip((0.55 * img_float + 0.45 * heatmap) * 255, 0, 255).astype(np.uint8)

    # Box the top ~22% most salient pixels after closing small gaps.
    thresh_val = np.percentile(saliency, 78)
    binary = ((saliency > thresh_val) * 255).astype(np.uint8)
    binary = _morphology_close(binary, kernel_size=9)

    # Draw on a copy of the resized uint8 image itself; rebuilding it from
    # img_float * 255 can truncate pixel values down by one.
    canvas = Image.fromarray(img_resized.copy())
    draw = ImageDraw.Draw(canvas)
    for (x, y, w, h) in _find_bboxes(binary):
        draw.rectangle([x, y, x + w, y + h], outline=(220, 38, 38), width=2)
    bbox_img = np.array(canvas)

    return probs, cam_overlay, bbox_img
# Banner shown above the results when no real model produced them.
_SIM_BANNER_HTML = """
<div style="background:#fffbeb;border:1px solid #fde68a;border-radius:10px;
padding:10px 16px;font-size:12px;color:#78350f;margin-bottom:16px;
display:flex;align-items:center;gap:8px;">
<span style="font-size:16px;"></span>
<span><strong>Demo Mode:</strong> No trained model found — results are
<em>simulated for UI demonstration only</em> and have no clinical meaning.</span>
</div>
"""

# Opens one per-eye section; the section is closed with a bare "</div>".
_EYE_HEADER_TMPL = """
<div style="margin-bottom:28px;">
<div style="display:inline-block;background:{pill_bg};border:1px solid {pill_border};
border-radius:100px;padding:4px 16px;font-size:11px;font-weight:700;
color:{pill_color};text-transform:uppercase;letter-spacing:0.08em;margin-bottom:14px;">
{eye_name}
</div>
"""

# One card per disease class.
_CARD_TMPL = """
<div style="background:{card_bg};border:1px solid {card_border};
border-left:4px solid {left_color};border-radius:10px;
padding:11px 14px;margin-bottom:8px;">
<div style="font-size:14px;font-weight:700;color:#1e293b;margin-bottom:3px;">
{label}
<span style="font-size:11px;color:#94a3b8;font-weight:400;">({code})</span>
{flag_badge}
</div>
<div style="font-size:11px;color:#475569;margin-bottom:8px;line-height:1.55;">{desc}</div>
<div style="background:#e2e8f0;border-radius:100px;height:8px;overflow:hidden;margin-bottom:5px;">
<div style="width:{bar_pct}%;height:100%;background:{color};border-radius:100px;transition:width 0.4s;"></div>
</div>
<div style="font-size:14px;font-weight:700;color:{pct_color};">{bar_pct}%</div>
</div>
"""

# Appended to a section in which no non-"Normal" class reached the threshold.
_NO_FINDINGS_HTML = """
<div style="background:#f0fdf4;border:1px solid #bbf7d0;border-radius:10px;
padding:14px;text-align:center;margin-bottom:8px;">
<div style="font-size:14px;font-weight:700;color:#166534;">
No disease detected above threshold
</div>
<div style="font-size:12px;color:#4ade80;margin-top:4px;">
All scores are below the detection threshold
</div>
</div>
"""

_FLAG_BADGE_HTML = (
    '<span style="background:#ef4444;color:#fff;font-size:10px;'
    'font-weight:700;padding:2px 8px;border-radius:4px;margin-left:8px;">'
    "FLAGGED</span>"
)


def build_results_html(results_dict, threshold, is_simulation=False):
    """Render per-eye confidence cards as one HTML fragment.

    Args:
        results_dict: Maps a section title (e.g. ``"Left Eye"``) to a
            sequence of scores indexed like ``DISEASE_LABELS``. Titles that
            contain ``"Combined"`` get the green pill styling.
        threshold: A score at or above this value flags its class.
        is_simulation: When True, a demo-mode banner is placed at the top.

    Returns:
        The HTML string. Section titles are interpolated verbatim, so callers
        must pass trusted text.
    """
    parts = ['<div style="font-family:Segoe UI,Arial,sans-serif;padding:4px;">']
    if is_simulation:
        parts.append(_SIM_BANNER_HTML)

    for eye_name, probs in results_dict.items():
        if "Combined" in eye_name:
            pill_bg, pill_border, pill_color = (
                "rgba(5,150,105,0.12)", "rgba(5,150,105,0.3)", "#065f46",
            )
        else:
            pill_bg, pill_border, pill_color = (
                "rgba(139,92,246,0.12)", "rgba(139,92,246,0.3)", "#6d28d9",
            )
        parts.append(
            _EYE_HEADER_TMPL.format(
                pill_bg=pill_bg,
                pill_border=pill_border,
                pill_color=pill_color,
                eye_name=eye_name,
            )
        )

        detected_any = False
        for i, (label, code) in enumerate(zip(DISEASE_LABELS, DISEASE_CODES)):
            conf = float(probs[i])
            color = DISEASE_COLORS[i]
            flagged = conf >= threshold
            is_normal = label == "Normal"

            if flagged and not is_normal:
                # A disease class at or above the threshold: red alert card.
                detected_any = True
                card_bg, card_border = "#fff8f8", "#fecaca"
                left_color, pct_color = "#ef4444", "#dc2626"
                flag_badge = _FLAG_BADGE_HTML
            elif flagged:
                # "Normal" at or above the threshold: green reassurance card.
                card_bg, card_border = "#f0fdf4", "#bbf7d0"
                left_color, pct_color = "#22c55e", "#16a34a"
                flag_badge = ""
            else:
                card_bg, card_border = "#f8fafc", "#e2e8f0"
                left_color, pct_color = color, "#334155"
                flag_badge = ""

            # Truncated (not rounded) so the label never overstates a score;
            # clamped so an out-of-range score cannot break the bar's CSS.
            bar_pct = max(0, min(100, int(conf * 100)))

            parts.append(
                _CARD_TMPL.format(
                    card_bg=card_bg,
                    card_border=card_border,
                    left_color=left_color,
                    label=label,
                    code=code,
                    flag_badge=flag_badge,
                    desc=DISEASE_DESC[label],
                    bar_pct=bar_pct,
                    color=color,
                    pct_color=pct_color,
                )
            )

        if not detected_any:
            parts.append(_NO_FINDINGS_HTML)
        parts.append("</div>")

    parts.append("</div>")
    return "".join(parts)
# Dark-grey 224x224 RGB image shown in any output slot that has no result.
PLACEHOLDER = np.full((224, 224, 3), 40, dtype=np.uint8)

_UPLOAD_PROMPT_HTML = """
<div style="font-family:sans-serif;padding:20px;background:#fffbeb;
border:1px solid #fde68a;border-radius:12px;text-align:center;margin:8px 0;">
<div style="font-size:15px;font-weight:700;color:#92400e;">
Please upload at least one fundus image to begin analysis.
</div>
</div>
"""

_INFERENCE_ERROR_TMPL = """
<div style="font-family:sans-serif;padding:20px;background:#fef2f2;
border:1px solid #fecaca;border-radius:12px;margin:8px 0;">
<div style="font-size:15px;font-weight:700;color:#dc2626;margin-bottom:6px;">
Inference Error — {name}
</div>
<div style="font-size:13px;color:#7f1d1d;">{detail}</div>
</div>
"""

# Shown when a checkpoint file exists but could not be loaded.
_LOAD_FAILURE_TMPL = """
<div style="font-family:sans-serif;padding:10px 16px;background:#fef2f2;
border:1px solid #fecaca;border-radius:10px;font-size:12px;color:#7f1d1d;
margin-bottom:12px;">
<strong>Model weights could not be loaded:</strong> {detail}
</div>
"""


def analyze(left_img, right_img, age, sex, threshold):
    """Gradio callback: score the uploaded eye image(s) and build all outputs.

    Args:
        left_img: Left-eye image, or ``None`` if not uploaded.
        right_img: Right-eye image, or ``None`` if not uploaded.
        age: Patient age in years; divided by 100 before inference.
        sex: ``"Male"`` selects the male flag; any other value clears it.
        threshold: Score at or above which a class is flagged.

    Returns:
        ``(left_cam, left_bbox, right_cam, right_bbox, results_html)``. Image
        slots without a result hold ``PLACEHOLDER``. If inference fails for
        either eye, all four images are placeholders and the HTML is an
        error card.
    """
    import html

    model, real_model, load_err = load_model()
    is_sim = not real_model

    if left_img is None and right_img is None:
        return PLACEHOLDER, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER, _UPLOAD_PROMPT_HTML

    age_norm = float(age) / 100.0
    sex_male = 1 if sex == "Male" else 0

    results_probs = {}
    viz = {}
    for name, img in (("Left Eye", left_img), ("Right Eye", right_img)):
        if img is None:
            continue
        try:
            probs, cam, bbox = run_inference(model, img, age_norm, sex_male)
        except Exception as exc:
            # Top-level UI boundary: report the failure instead of crashing.
            # The message is escaped because it is embedded in HTML.
            err_html = _INFERENCE_ERROR_TMPL.format(
                name=name, detail=html.escape(str(exc))
            )
            return PLACEHOLDER, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER, err_html
        results_probs[name] = probs
        viz[name] = (cam, bbox)

    if len(results_probs) == 2:
        # Per-class maximum over both eyes.
        results_probs["Combined (Worst-Case)"] = np.maximum(
            results_probs["Left Eye"], results_probs["Right Eye"]
        )

    results_html = build_results_html(results_probs, threshold, is_simulation=is_sim)
    if load_err:
        # A checkpoint was found but failed to load; say so rather than
        # letting the demo banner imply that no model exists.
        notice = _LOAD_FAILURE_TMPL.format(detail=html.escape(str(load_err)))
        results_html = notice + results_html

    blank = (PLACEHOLDER, PLACEHOLDER)
    left_cam, left_bbox = viz.get("Left Eye", blank)
    right_cam, right_bbox = viz.get("Right Eye", blank)
    return left_cam, left_bbox, right_cam, right_bbox, results_html
# Page-level style overrides handed to gr.Blocks.
CSS = """
.gradio-container {
background: linear-gradient(135deg, #f0f4ff 0%, #faf5ff 50%, #f0fff4 100%) !important;
font-family: 'Segoe UI', Arial, sans-serif !important;
}
.gr-button { border-radius: 12px !important; }
footer { display: none !important; }
"""

# Hero header at the top of the page.
_HEADER_HTML = """
<div style="text-align:center;padding:32px 20px 20px;background:rgba(255,255,255,0.7);
border-radius:20px;margin-bottom:20px;border:1px solid rgba(139,92,246,0.15);
box-shadow:0 4px 24px rgba(99,102,241,0.07);">
<div style="display:inline-block;background:rgba(139,92,246,0.1);
border:1px solid rgba(139,92,246,0.25);border-radius:100px;
padding:5px 18px;font-size:11px;font-weight:700;color:#7c3aed;
letter-spacing:0.1em;text-transform:uppercase;margin-bottom:14px;">
Retinal Analysis
</div>
<h1 style="font-size:3rem;font-weight:800;margin:0 0 10px;
background:linear-gradient(135deg,#1e1b4b 0%,#7c3aed 40%,#ec4899 70%,#3b82f6 100%);
-webkit-background-clip:text;-webkit-text-fill-color:transparent;
background-clip:text;line-height:1.1;">
Ocular Disease Scan
</h1>
<p style="color:#475569;font-size:1rem;margin:0 auto;max-width:560px;line-height:1.7;">
Advanced multimodal deep learning for ocular disease recognition from fundus photographs.
</p>
<div style="display:flex;gap:10px;justify-content:center;flex-wrap:wrap;margin-top:16px;">
<span style="background:rgba(139,92,246,0.1);border:1px solid rgba(139,92,246,0.25);
border-radius:100px;padding:4px 14px;font-size:12px;font-weight:600;color:#6d28d9;">
8 Disease Classes</span>
<span style="background:rgba(59,130,246,0.1);border:1px solid rgba(59,130,246,0.25);
border-radius:100px;padding:4px 14px;font-size:12px;font-weight:600;color:#1d4ed8;">
Swin Transformer</span>
<span style="background:rgba(16,185,129,0.1);border:1px solid rgba(16,185,129,0.25);
border-radius:100px;padding:4px 14px;font-size:12px;font-weight:600;color:#065f46;">
Grad-CAM Saliency</span>
</div>
</div>
"""

# Research-use disclaimer under the run button.
_DISCLAIMER_HTML = """
<div style="background:#fffbeb;border:1px solid #fde68a;border-radius:10px;
padding:10px 14px;font-size:11px;color:#78350f;line-height:1.6;margin-top:10px;">
For research and educational purposes only.
Not a substitute for professional ophthalmic diagnosis.
</div>
"""

_FOOTER_HTML = """
<div style="text-align:center;padding:20px;color:#94a3b8;font-size:12px;
border-top:1px solid rgba(139,92,246,0.1);margin-top:24px;">
OcuScan — For research and educational purposes only
</div>
"""

# NOTE(review): the source this was recovered from had lost its indentation;
# the nesting below (inputs column | outputs column, handler and footer at
# Blocks level) is a reconstruction -- confirm against the deployed layout.
with gr.Blocks(title="OcuScan", css=CSS) as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row(equal_height=False):
        # Left column: patient details, image uploads and the run button.
        with gr.Column(scale=4, min_width=280):
            gr.HTML('<div style="font-size:15px;font-weight:700;color:#1e1b4b;margin-bottom:10px;'
                    'padding-bottom:6px;border-bottom:2px solid rgba(139,92,246,0.2);">Patient Information</div>')
            age = gr.Slider(minimum=1, maximum=100, value=45, step=1, label="Patient Age")
            sex = gr.Radio(choices=["Male", "Female"], value="Male", label="Patient Sex")
            threshold = gr.Slider(minimum=0.1, maximum=0.9, value=0.5, step=0.05, label="Detection Threshold")
            gr.HTML('<div style="font-size:15px;font-weight:700;color:#1e1b4b;margin:16px 0 10px;'
                    'padding-bottom:6px;border-bottom:2px solid rgba(139,92,246,0.2);">Fundus Images</div>')
            left_img = gr.Image(label="Left Eye Fundus", type="numpy", height=200)
            right_img = gr.Image(label="Right Eye Fundus", type="numpy", height=200)
            run_btn = gr.Button("Run Ocular Disease Analysis", variant="primary", size="lg")
            gr.HTML(_DISCLAIMER_HTML)

        # Right column: visualisations and the score cards.
        with gr.Column(scale=7, min_width=400):
            gr.HTML('<div style="font-size:15px;font-weight:700;color:#1e1b4b;margin-bottom:10px;'
                    'padding-bottom:6px;border-bottom:2px solid rgba(139,92,246,0.2);">Saliency Visualisation</div>')
            with gr.Row():
                out_cam1 = gr.Image(label="Left Eye — Saliency Heatmap", type="numpy", height=200)
                out_bbox1 = gr.Image(label="Left Eye — Disease Localisation", type="numpy", height=200)
            with gr.Row():
                out_cam2 = gr.Image(label="Right Eye — Saliency Heatmap", type="numpy", height=200)
                out_bbox2 = gr.Image(label="Right Eye — Disease Localisation", type="numpy", height=200)
            gr.HTML('<div style="font-size:15px;font-weight:700;color:#1e1b4b;margin:16px 0 10px;'
                    'padding-bottom:6px;border-bottom:2px solid rgba(139,92,246,0.2);">Disease Confidence Scores</div>')
            out_html = gr.HTML(
                value=('<div style="font-family:sans-serif;padding:30px;text-align:center;'
                       'color:#94a3b8;background:#f8fafc;border-radius:12px;">'
                       "Upload fundus images and click <strong>Run Analysis</strong> to see results.</div>")
            )

    # Output order must match the tuple returned by analyze().
    run_btn.click(
        fn=analyze,
        inputs=[left_img, right_img, age, sex, threshold],
        outputs=[out_cam1, out_bbox1, out_cam2, out_bbox2, out_html],
    )
    gr.HTML(_FOOTER_HTML)
# Script entry point: serve the UI on all network interfaces ("0.0.0.0") at
# port 7860. Nothing is launched when the module is merely imported.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)