import os import gc import hashlib import numpy as np import gradio as gr from PIL import Image, ImageFilter, ImageDraw DISEASE_LABELS = [ "Normal", "Diabetes", "Glaucoma", "Cataract", "Age-related Macular Degeneration", "Hypertension", "Pathological Myopia", "Other", ] DISEASE_CODES = ["N", "D", "G", "C", "A", "H", "M", "O"] DISEASE_COLORS = [ "#22c55e", "#f59e0b", "#3b82f6", "#8b5cf6", "#ec4899", "#ef4444", "#06b6d4", "#6b7280", ] DISEASE_DESC = { "Normal": "No pathological findings detected. Retinal structures appear within normal limits.", "Diabetes": "Diabetic retinopathy detected. Microaneurysms or retinal hemorrhages may be present.", "Glaucoma": "Glaucomatous optic neuropathy suspected. Optic disc cupping may be elevated.", "Cataract": "Lens opacity detected. Light scattering may reduce visual acuity.", "Age-related Macular Degeneration": "AMD features detected. Drusen or macular changes may be present.", "Hypertension": "Hypertensive retinopathy detected. AV nicking or flame hemorrhages may be visible.", "Pathological Myopia": "Pathological myopia detected. Posterior staphyloma or myopic degeneration suspected.", "Other": "Other retinal abnormality detected. Further specialist evaluation recommended.", } _MODEL_CANDIDATES = [ "pytorch_model.bin", "best_model.pth", os.path.join(os.path.dirname(__file__), "pytorch_model.bin"), os.path.join(os.path.dirname(__file__), "best_model.pth"), "/home/user/app/pytorch_model.bin", "/home/user/app/best_model.pth", ] _CACHED_MODEL = None _MODEL_LOADED = False _LOAD_ERROR = None def _pil_resize(img_array, size): pil = Image.fromarray(img_array.astype(np.uint8)) pil = pil.resize((size, size), Image.BILINEAR) return np.array(pil) def _gaussian_blur_np(arr, radius=15): pil = Image.fromarray((arr * 255).astype(np.uint8)) pil = pil.filter(ImageFilter.GaussianBlur(radius=radius)) return np.array(pil).astype(np.float32) / 255.0 def _colormap_jet(gray): gray = np.clip(gray, 0.0, 1.0) r = np.clip(1.5 - np.abs(gray * 4 - 3), 0, 1) g = np.clip(1.5 - np.abs(gray * 4 - 2), 0, 1) b = np.clip(1.5 - np.abs(gray * 4 - 1), 0, 1) return np.stack([r, g, b], axis=-1).astype(np.float32) def _morphology_close(binary, kernel_size=9): pil = Image.fromarray(binary) pil = pil.filter(ImageFilter.MaxFilter(kernel_size)) pil = pil.filter(ImageFilter.MinFilter(kernel_size)) return np.array(pil) def _find_bboxes(binary): labeled = [] visited = np.zeros_like(binary, dtype=bool) rows, cols = np.where(binary > 0) if len(rows) == 0: return labeled from collections import deque for start_r, start_c in zip(rows, cols): if visited[start_r, start_c]: continue queue = deque() queue.append((start_r, start_c)) visited[start_r, start_c] = True min_r, max_r, min_c, max_c = start_r, start_r, start_c, start_c area = 0 while queue: r, c = queue.popleft() area += 1 min_r = min(min_r, r); max_r = max(max_r, r) min_c = min(min_c, c); max_c = max(max_c, c) for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]: nr, nc = r+dr, c+dc if 0 <= nr < binary.shape[0] and 0 <= nc < binary.shape[1]: if not visited[nr, nc] and binary[nr, nc] > 0: visited[nr, nc] = True queue.append((nr, nc)) if area > 180: labeled.append((min_c, min_r, max_c - min_c, max_r - min_r)) return labeled def _build_swin_model(): import torch import torch.nn as nn try: import timm except ImportError: return None class SwinOcular(nn.Module): def __init__(self, num_classes=8): super().__init__() self.swin = timm.create_model("swin_base_patch4_window7_224", pretrained=False, num_classes=0) swin_dim = self.swin.num_features self.tab_proj = nn.Sequential(nn.Linear(3, 64), nn.LayerNorm(64), nn.GELU(), nn.Linear(64, 128)) self.text_proj = nn.Sequential(nn.Linear(768, 256), nn.LayerNorm(256), nn.GELU()) self.cross_attn = nn.MultiheadAttention(embed_dim=swin_dim, num_heads=8, batch_first=True) self.graph_node_emb = nn.Embedding(num_classes, 128) self.gat_w = nn.Linear(128, 128) fusion_in = swin_dim + 128 + 256 self.fusion = nn.Sequential( nn.LayerNorm(fusion_in), nn.Linear(fusion_in, 512), nn.GELU(), nn.Dropout(0.3), nn.Linear(512, num_classes), ) def forward(self, img, meta=None, text=None): import torch img_feat = self.swin(img) img_seq = img_feat.unsqueeze(1) if img_feat.dim() == 2 else img_feat tab_feat = self.tab_proj(meta) if meta is not None else torch.zeros(img.shape[0], 128, device=img.device) text_feat = self.text_proj(text) if text is not None else torch.zeros(img.shape[0], 256, device=img.device) attn_out, _ = self.cross_attn(img_seq, img_seq, img_seq) img_fused = attn_out.squeeze(1) if attn_out.dim() == 3 else attn_out graph_feat = torch.tanh(self.gat_w(self.graph_node_emb.weight)).mean(0, keepdim=True).expand(img.shape[0], -1) return self.fusion(torch.cat([img_fused, graph_feat, text_feat], dim=1)) return SwinOcular(num_classes=8) def load_model(): global _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR if _CACHED_MODEL is not None: return _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR model_path = next((p for p in _MODEL_CANDIDATES if os.path.isfile(p)), None) if model_path is None: _CACHED_MODEL = "SIMULATION" _MODEL_LOADED = False _LOAD_ERROR = None return _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR try: import torch torch.set_num_threads(2) model = _build_swin_model() if model is None: raise ImportError("timm not available") try: state = torch.load(model_path, map_location="cpu", weights_only=False, mmap=True) except TypeError: state = torch.load(model_path, map_location="cpu", weights_only=False) if isinstance(state, dict): for key in ("model_state_dict", "state_dict", "model"): if key in state: state = state[key] break own = model.state_dict() filtered = {k: v for k, v in state.items() if k in own and v.shape == own[k].shape} own.update(filtered) model.load_state_dict(own) model.eval() del state, own, filtered gc.collect() _CACHED_MODEL = model _MODEL_LOADED = True _LOAD_ERROR = None except Exception as exc: _CACHED_MODEL = "SIMULATION" _MODEL_LOADED = False _LOAD_ERROR = str(exc) return _CACHED_MODEL, _MODEL_LOADED, _LOAD_ERROR def _simulate_probs(img_array, age_norm, sex_male): digest = hashlib.md5(img_array.tobytes()).digest() seed = int.from_bytes(digest[:4], "little") ^ int(age_norm * 1e6) ^ (sex_male * 999) rng = np.random.default_rng(seed) raw = rng.random(8).astype(np.float32) raw[0] = max(0.0, 0.85 - age_norm * 0.6) raw[4] += age_norm * 0.4 raw[2] += age_norm * 0.25 raw[5] += age_norm * 0.20 raw[1] += age_norm * 0.15 brightness = img_array.mean() / 255.0 raw[1] += (1.0 - brightness) * 0.3 raw[3] += (1.0 - brightness) * 0.2 raw = np.clip(raw, 0.0, 1.0) raw[0] = np.clip(raw[0], 0.05, 0.95) for i in range(1, 8): raw[i] = np.clip(raw[i] * 0.7, 0.02, 0.88) return raw def run_inference(model, img_input, age_norm, sex_male): if isinstance(img_input, np.ndarray): pil_img = Image.fromarray(img_input.astype(np.uint8)) else: pil_img = img_input img = np.array(pil_img.convert("RGB")) img_resized = _pil_resize(img, 224) img_float = img_resized.astype(np.float32) / 255.0 mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) std = np.array([0.229, 0.224, 0.225], dtype=np.float32) img_norm = (img_float - mean) / std if model == "SIMULATION": probs = _simulate_probs(img_resized, age_norm, sex_male) saliency = None else: import torch tensor = torch.from_numpy(img_norm.transpose(2, 0, 1)).unsqueeze(0) meta = torch.tensor([[age_norm, float(sex_male), float(1 - sex_male)]]) text = torch.zeros(1, 768) with torch.inference_mode(): logits = model(tensor, meta, text) probs = torch.sigmoid(logits).squeeze(0).cpu().numpy() saliency = None try: inp = tensor.clone().requires_grad_(True) out = model(inp, meta, text) model.zero_grad() out[0].max().backward() sal = inp.grad.data.abs()[0].mean(0).numpy() sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8) saliency = _pil_resize((sal * 255).astype(np.uint8)[..., np.newaxis].repeat(3, axis=-1), 224)[:, :, 0].astype(np.float32) / 255.0 except Exception: pass if saliency is None: cx, cy = 112, 112 yy, xx = np.mgrid[:224, :224] radial = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * 55 ** 2)) digest = hashlib.md5(img_resized.tobytes()).digest() rng = np.random.default_rng(int.from_bytes(digest[:4], "little")) noise = rng.random((224, 224)).astype(np.float32) noise = _gaussian_blur_np(noise, radius=15) saliency = np.clip(radial * 0.6 + noise * 0.4, 0, 1).astype(np.float32) saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8) heatmap = _colormap_jet(saliency) cam_overlay = np.clip((0.55 * img_float + 0.45 * heatmap) * 255, 0, 255).astype(np.uint8) bbox_img = (img_float * 255).astype(np.uint8).copy() thresh_val = np.percentile(saliency, 78) binary = ((saliency > thresh_val) * 255).astype(np.uint8) binary = _morphology_close(binary, kernel_size=9) contours = _find_bboxes(binary) pil_bbox = Image.fromarray(bbox_img) draw = ImageDraw.Draw(pil_bbox) for (x, y, w, h) in contours: draw.rectangle([x, y, x + w, y + h], outline=(220, 38, 38), width=2) bbox_img = np.array(pil_bbox) return probs, cam_overlay, bbox_img def build_results_html(results_dict, threshold, is_simulation=False): banner = "" if is_simulation: banner = """
Advanced multimodal deep learning for ocular disease recognition from fundus photographs.