#!/usr/bin/env python3 """ app.py - Gradio demo for DRIL OCT Classification ------------------------------------------------- Performs 5-fold ensemble inference using RETFound fine-tuned checkpoints. The model probabilities from all 5 fold checkpoints are averaged, matching the evaluation methodology used in training. Models available (weights on Google Drive): - RETFound Conservative (top-4 blocks unfrozen, 5-fold ensemble) - RETFound Moderate (top-8 blocks unfrozen, 5-fold ensemble) - RETFound Baseline (head-only fine-tuning, 5-fold ensemble) To run locally: pip install -r requirements.txt python app.py Environment variables: CHECKPOINT_DIR - path to folder containing .pth files (default: ./checkpoints) RETFOUND_DIR - path to cloned RETFound_MAE repo (default: ./RETFound_MAE) """ import os import sys import glob import numpy as np import gradio as gr import torch import torch.nn as nn import torchvision.transforms as transforms from PIL import Image # --------------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------------- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMG_SIZE = 224 CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "./checkpoints") # --------------------------------------------------------------------------- # Model registry # Each entry: display name -> list of per-fold checkpoint filenames (ensembled) # All filenames match exactly what is stored in the Google Drive folder. 
# ---------------------------------------------------------------------------
MODEL_REGISTRY = {
    "RETFound - Conservative (5-fold ensemble)": {
        "folds": [
            "retfound_v2_conservative_fold1.pth",
            "retfound_v2_conservative_fold2.pth",
            "retfound_v2_conservative_fold3.pth",
            "retfound_v2_conservative_fold4.pth",
            "retfound_v2_conservative_fold5.pth",
        ]
    },
    "RETFound - Moderate (5-fold ensemble)": {
        "folds": [
            "retfound_v2_moderate_fold1.pth",
            "retfound_v2_moderate_fold2.pth",
            "retfound_v2_moderate_fold3.pth",
            "retfound_v2_moderate_fold4.pth",
            "retfound_v2_moderate_fold5.pth",
        ]
    },
    "RETFound - Baseline CV (5-fold ensemble)": {
        "folds": [
            "retfound_v2_cv_fold1_best.pth",
            "retfound_v2_cv_fold2_best.pth",
            "retfound_v2_cv_fold3_best.pth",
            "retfound_v2_cv_fold4_best.pth",
            "retfound_v2_cv_fold5_best.pth",
        ]
    },
}

# Strategy pre-selected in the UI dropdown; must be a key of MODEL_REGISTRY.
DEFAULT_MODEL = "RETFound - Conservative (5-fold ensemble)"

# ---------------------------------------------------------------------------
# Auto-download weights from Drive if missing (gdown required)
# ---------------------------------------------------------------------------
# Checkpoint filename -> Google Drive file ID used by maybe_download().
# Every fold file listed in MODEL_REGISTRY has an entry here.
# NOTE: "RETFound_oct.pth" is registered but not referenced by MODEL_REGISTRY.
DRIVE_FILE_IDS = {
    "retfound_v2_conservative_fold1.pth": "110xefhcDD01YMGFcZ-6Hcv3zLm731vgu",
    "retfound_v2_conservative_fold2.pth": "1gEzlq-LF7R7pNnd1Ud5sePtnRx0XI-mt",
    "retfound_v2_conservative_fold3.pth": "1TRR0DuDHj99_qGC8KSbt50KLselMv1ti",
    "retfound_v2_conservative_fold4.pth": "1huVy9EpLqa88MU3O5kfYrDLwGnm4TWOH",
    "retfound_v2_conservative_fold5.pth": "1U0MwwuOji3P8psTjzwkNQqL81bXtm50d",
    "retfound_v2_moderate_fold1.pth": "1y-xO33wRQAlgNrioYoSfOx019bNv30y-",
    "retfound_v2_moderate_fold2.pth": "1r6f-EmdZnRgdGm4W9RaAO_dSRoqQkj_0",
    "retfound_v2_moderate_fold3.pth": "1Mak5FuHl2jAZMS2NglR7T0gl09r0bXdR",
    "retfound_v2_moderate_fold4.pth": "1qFE1CB3x96U1PakP0OAzKTUBUKfrqiZO",
    "retfound_v2_moderate_fold5.pth": "1afUmpapz1dryl43rqVCE5QSEBLbaDGaw",
    "retfound_v2_cv_fold1_best.pth": "1qUnrX9LJ6DF2ysG67rjiQzo4XtzJc5WW",
    "retfound_v2_cv_fold2_best.pth": "1oEvZA2oQSXaxMbi7_-fnF5w9BDI7l_lZ",
    "retfound_v2_cv_fold3_best.pth": "18w5XtZKA2HmSn0TqtDMa3yLC_rKYtHnb",
    "retfound_v2_cv_fold4_best.pth": "1Jnoj-W6oQqp2l7GIYuqQ5e_WZRsII4Da",
    "retfound_v2_cv_fold5_best.pth": "1d9stMnIjfcJoqeraFHFKKj8XquTV28KV",
    "RETFound_oct.pth": "1v2v5XGMr7ipCyESE1jnASzLdANJDSLXv",
}


def maybe_download(filename: str) -> str:
    """
    Return the local path to a checkpoint, downloading it from Drive if missing.

    The checkpoint is looked up at ``os.path.join(CHECKPOINT_DIR, filename)``.
    If it is absent, it is fetched with ``gdown`` into a sibling ``.part``
    file and only renamed to its final name after the download succeeded, so
    an interrupted download never leaves a truncated file that a later call
    would mistake for a complete checkpoint.

    Args:
        filename: checkpoint file name (a key of DRIVE_FILE_IDS to be
            downloadable).

    Returns:
        Local path of the checkpoint file.

    Raises:
        FileNotFoundError: the file is missing and has no Drive ID registered.
        ImportError: the file must be downloaded but gdown is not installed.
        RuntimeError: gdown reported failure or produced no file.
    """
    out_path = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.exists(out_path):
        return out_path

    file_id = DRIVE_FILE_IDS.get(filename)
    if file_id is None:
        raise FileNotFoundError(
            f"{filename} not found in {CHECKPOINT_DIR} and has no Drive file ID registered."
        )

    try:
        import gdown
    except ImportError as err:
        raise ImportError(
            "gdown is required for auto-download. "
            "Install it with: pip install gdown"
        ) from err

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"[auto-download] {filename} ...")

    # Download next to the final path, then rename: os.replace within one
    # directory means out_path only ever appears fully written.
    tmp_path = out_path + ".part"
    try:
        result = gdown.download(
            f"https://drive.google.com/uc?id={file_id}",
            tmp_path,
            quiet=False,
        )
        # Depending on the gdown version, a failed download may return None
        # instead of raising; treat that (or a missing file) as an error.
        if result is None or not os.path.exists(tmp_path):
            raise RuntimeError(
                f"Download of {filename} (Drive id {file_id}) failed."
            )
        os.replace(tmp_path, out_path)
    finally:
        # Only still present if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return out_path


# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------
def _norm():
    """Per-channel normalisation (standard ImageNet mean/std) applied after ToTensor()."""
    return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])


def _view_transform(*pil_ops):
    """Compose Resize -> *pil_ops -> ToTensor -> Normalize for one input view."""
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *pil_ops,
        transforms.ToTensor(),
        _norm(),
    ])


def get_val_transform():
    """Single-view evaluation transform: resize, to tensor, normalise."""
    return _view_transform()


def get_tta_transforms():
    """
    Return the four test-time-augmentation view transforms.

    Views, in order: identity, horizontal flip, vertical flip, 10-degree
    rotation. All are deterministic (flip probability 1.0, rotation range
    collapsed to a single angle), so repeated calls give identical tensors.
    """
    return [
        _view_transform(),
        _view_transform(transforms.RandomHorizontalFlip(p=1.0)),
        _view_transform(transforms.RandomVerticalFlip(p=1.0)),
        # degrees=(10, 10) turns the "random" rotation into a fixed 10 degrees.
        _view_transform(transforms.RandomRotation(degrees=(10, 10))),
    ]


# ---------------------------------------------------------------------------
# RETFound model builder
# ---------------------------------------------------------------------------
def _get_retfound_arch():
    """
    Return the ``RETFound_mae`` constructor from the vendored ``models_vit.py``.

    ``models_vit.py`` is expected next to app.py; that directory is put on
    ``sys.path`` so the import works regardless of the working directory.
    No runtime git-clone is needed.

    Raises:
        RuntimeError: models_vit cannot be imported or does not define
            RETFound_mae.
    """
    app_dir = os.path.dirname(os.path.abspath(__file__))
    if app_dir not in sys.path:
        sys.path.insert(0, app_dir)
    try:
        import models_vit
    except ImportError as e:
        raise RuntimeError(
            f"Could not import models_vit: {e}\n"
            "Ensure models_vit.py is present in the same directory as app.py."
        ) from e
    arch = getattr(models_vit, "RETFound_mae", None)
    if arch is None:
        raise RuntimeError("models_vit.py does not define RETFound_mae.")
    return arch


def build_retfound(ckpt_path: str) -> nn.Module:
    """
    Build a RETFound model and load a fine-tuned checkpoint into it.

    Args:
        ckpt_path: path to a .pth file containing a plain ``state_dict``.

    Returns:
        The model in eval mode, moved to DEVICE.
    """
    arch = _get_retfound_arch()
    # One output logit; predict() turns it into a DRIL probability via sigmoid.
    model = arch(num_classes=1, drop_path_rate=0.2, global_pool=True)
    # SECURITY NOTE(review): torch.load unpickles the file, and checkpoints may
    # be auto-downloaded from Google Drive. On torch versions whose default is
    # not weights_only=True, consider passing weights_only=True explicitly.
    state = torch.load(ckpt_path, map_location="cpu")
    # Fine-tuned checkpoints saved with model.state_dict() directly.
    model.load_state_dict(state, strict=True)
    model.eval()
    return model.to(DEVICE)


# ---------------------------------------------------------------------------
# Ensemble model cache  { model_name -> list[nn.Module] }
# Entries are never evicted: every strategy that has been selected stays
# loaded on DEVICE for the lifetime of the process.
# ---------------------------------------------------------------------------
_ensemble_cache: dict[str, list[nn.Module]] = {}


def get_ensemble(model_name: str) -> list[nn.Module]:
    """
    Load (and cache) every fold model registered for the selected strategy.

    Missing checkpoints are fetched through maybe_download(). Nothing is
    cached unless all folds load successfully.

    Raises:
        KeyError: model_name is not a key of MODEL_REGISTRY.
    """
    cached = _ensemble_cache.get(model_name)
    if cached is not None:
        return cached

    models = []
    for fname in MODEL_REGISTRY[model_name]["folds"]:
        ckpt_path = maybe_download(fname)
        print(f"  Loading {fname} ...")
        models.append(build_retfound(ckpt_path))

    _ensemble_cache[model_name] = models
    return models


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _infer_single(model: nn.Module, tensor: torch.Tensor) -> float:
    """Run a forward pass on a batch of exactly one image; return its sigmoid probability."""
    out = model(tensor)
    out = out.reshape(tensor.size(0), 1)
    return torch.sigmoid(out).item()


def _infer_batch(model: nn.Module, batch: torch.Tensor) -> list[float]:
    """Run one forward pass over a batch of views; return one sigmoid probability per view."""
    out = model(batch).reshape(batch.size(0))
    return torch.sigmoid(out).cpu().tolist()


def predict(image: Image.Image, model_name: str, use_tta: bool):
    """
    Run 5-fold ensemble inference on a PIL image.

    Each fold's DRIL probability is the mean over the input views (4 TTA
    views, or the single validation view); the final probability is the mean
    over folds, thresholded at 0.5.

    Returns:
        label (str)               - "DRIL" or "No-DRIL"
        dril_prob (float)         - raw DRIL probability (0-1)
        fold_probs (list[float])  - per-fold DRIL probabilities
        img_np (np.ndarray)       - RGB preview of the input
    """
    image = image.convert("RGB")
    ensemble = get_ensemble(model_name)
    tfms = get_tta_transforms() if use_tta else [get_val_transform()]

    # The view transforms are deterministic and independent of the model, so
    # build the stacked batch of views once rather than once per fold model.
    # A single batched forward per fold may differ from per-view passes only
    # by floating-point rounding.
    batch = torch.stack([t(image) for t in tfms]).to(DEVICE)

    fold_probs = []
    with torch.no_grad():
        for model in ensemble:
            view_probs = _infer_batch(model, batch)
            fold_probs.append(float(np.mean(view_probs)))

    dril_prob = float(np.mean(fold_probs))
    label = "DRIL" if dril_prob >= 0.5 else "No-DRIL"

    img_np = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
    return label, dril_prob, fold_probs, img_np


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
def run_inference(pil_image, model_name: str, use_tta: bool):
    """
    Gradio click handler.

    Returns:
        (result_text, preview) where preview is the resized RGB input as a
        numpy array, or None when no image was given or inference failed.
    """
    if pil_image is None:
        return "No image provided.", None

    try:
        label, dril_prob, fold_probs, img_np = predict(pil_image, model_name, use_tta)
    except Exception as e:
        # UI boundary: show the error in the textbox, but also keep the full
        # traceback in the server log so failures remain debuggable.
        import traceback
        traceback.print_exc()
        return f"Error during inference:\n{e}", None

    nodril_prob = 1.0 - dril_prob
    confidence = dril_prob if label == "DRIL" else nodril_prob

    fold_lines = "\n".join(
        f" Fold {i+1}: {p*100:.1f}% DRIL" for i, p in enumerate(fold_probs)
    )

    result_text = (
        f"Prediction : {label}\n"
        f"Confidence : {confidence*100:.1f}%\n"
        f"DRIL prob  : {dril_prob:.4f} | No-DRIL prob: {nodril_prob:.4f}\n"
        f"\nPer-fold probabilities (DRIL):\n{fold_lines}"
    )
    return result_text, img_np


with gr.Blocks(title="DRIL OCT Classification") as demo:
    gr.Markdown(
        """
# DRIL OCT Classification

Classify macular OCT B-scan images as **DRIL** (Disruption of Retinal Inner Layers)
or **No-DRIL** using a 5-fold ensemble of RETFound foundation models fine-tuned
on a private OCT dataset (429 DRIL / 394 No-DRIL cases).
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload OCT Image")
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model (fine-tuning strategy)"
            )
            tta_checkbox = gr.Checkbox(
                value=True,
                label="Test-Time Augmentation (4-view TTA)"
            )
            classify_btn = gr.Button("Classify", variant="primary")

        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result", lines=10)
            output_image = gr.Image(label="Input Preview", type="numpy")

    classify_btn.click(
        fn=run_inference,
        inputs=[input_image, model_selector, tta_checkbox],
        outputs=[result_text, output_image]
    )

    gr.Markdown(
        """
**Notes**
- The first inference call will load all 5 fold checkpoints into memory
  (~350 MB per strategy). Subsequent calls on the same strategy are fast.
- Conservative strategy: top-4 ViT blocks unfrozen. Moderate: top-8 blocks. Baseline: head only.
- Optimal threshold for binary decision was determined by Youden-J on the
  validation set; the demo uses the fixed 0.5 threshold for simplicity.

**Disclaimer:** This tool is for research purposes only and must not be used
for clinical decisions.
"""
    )


if __name__ == "__main__":
    demo.launch()