| |
| """ |
| app.py - Gradio demo for DRIL OCT Classification |
| ------------------------------------------------- |
| Performs 5-fold ensemble inference using RETFound fine-tuned checkpoints. |
| The model probabilities from all 5 fold checkpoints are averaged, matching |
| the evaluation methodology used in training. |
| |
| Models available (weights on Google Drive): |
| - RETFound Conservative (top-4 blocks unfrozen, 5-fold ensemble) |
| - RETFound Moderate (top-8 blocks unfrozen, 5-fold ensemble) |
| - RETFound Baseline (head-only fine-tuning, 5-fold ensemble) |
| |
| To run locally: |
| pip install -r requirements.txt |
| python app.py |
| |
| Environment variables: |
| CHECKPOINT_DIR - path to folder containing .pth files (default: ./checkpoints) |
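| RETFOUND_DIR - path to cloned RETFound_MAE repo (default: ./RETFound_MAE); note: this script never reads it and instead imports the vendored models_vit.py located next to app.py |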
| """ |
|
|
| import os |
| import sys |
| import glob |
| import numpy as np |
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from PIL import Image |
|
|
| |
| |
| |
|
|
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Side length (in pixels) that the transforms resize every input image to.
IMG_SIZE = 224
# Checkpoint folder; overridable through the CHECKPOINT_DIR environment variable.
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR", "./checkpoints")
|
|
| |
| |
| |
| |
| |
|
|
# Maps each dropdown label to the checkpoint filenames of its five folds
# (fold numbers run from 1 to 5, in that order).
MODEL_REGISTRY = {
    "RETFound - Conservative (5-fold ensemble)": {
        "folds": [f"retfound_v2_conservative_fold{k}.pth" for k in range(1, 6)],
    },
    "RETFound - Moderate (5-fold ensemble)": {
        "folds": [f"retfound_v2_moderate_fold{k}.pth" for k in range(1, 6)],
    },
    "RETFound - Baseline CV (5-fold ensemble)": {
        "folds": [f"retfound_v2_cv_fold{k}_best.pth" for k in range(1, 6)],
    },
}


# Registry key used as the initial value of the model dropdown.
DEFAULT_MODEL = "RETFound - Conservative (5-fold ensemble)"
|
|
| |
| |
| |
|
|
# Google Drive file IDs for every downloadable checkpoint, keyed by filename.
# IDs are grouped per fine-tuning strategy; within a group the n-th ID belongs
# to fold n (counting from 1).
DRIVE_FILE_IDS = {
    **{
        f"retfound_v2_conservative_fold{fold}.pth": drive_id
        for fold, drive_id in enumerate(
            (
                "110xefhcDD01YMGFcZ-6Hcv3zLm731vgu",
                "1gEzlq-LF7R7pNnd1Ud5sePtnRx0XI-mt",
                "1TRR0DuDHj99_qGC8KSbt50KLselMv1ti",
                "1huVy9EpLqa88MU3O5kfYrDLwGnm4TWOH",
                "1U0MwwuOji3P8psTjzwkNQqL81bXtm50d",
            ),
            start=1,
        )
    },
    **{
        f"retfound_v2_moderate_fold{fold}.pth": drive_id
        for fold, drive_id in enumerate(
            (
                "1y-xO33wRQAlgNrioYoSfOx019bNv30y-",
                "1r6f-EmdZnRgdGm4W9RaAO_dSRoqQkj_0",
                "1Mak5FuHl2jAZMS2NglR7T0gl09r0bXdR",
                "1qFE1CB3x96U1PakP0OAzKTUBUKfrqiZO",
                "1afUmpapz1dryl43rqVCE5QSEBLbaDGaw",
            ),
            start=1,
        )
    },
    **{
        f"retfound_v2_cv_fold{fold}_best.pth": drive_id
        for fold, drive_id in enumerate(
            (
                "1qUnrX9LJ6DF2ysG67rjiQzo4XtzJc5WW",
                "1oEvZA2oQSXaxMbi7_-fnF5w9BDI7l_lZ",
                "18w5XtZKA2HmSn0TqtDMa3yLC_rKYtHnb",
                "1Jnoj-W6oQqp2l7GIYuqQ5e_WZRsII4Da",
                "1d9stMnIjfcJoqeraFHFKKj8XquTV28KV",
            ),
            start=1,
        )
    },
    # Registered for download but not listed in any MODEL_REGISTRY entry.
    "RETFound_oct.pth": "1v2v5XGMr7ipCyESE1jnASzLdANJDSLXv",
}
|
|
|
|
def maybe_download(filename: str) -> str:
    """
    Return the local path to a checkpoint, downloading it from Drive if missing.

    Args:
        filename: Checkpoint file name, looked up inside CHECKPOINT_DIR and,
            when absent there, in DRIVE_FILE_IDS.

    Returns:
        The path of the checkpoint inside CHECKPOINT_DIR.

    Raises:
        FileNotFoundError: The file is missing locally and either has no Drive
            file ID registered, or is still missing after the download attempt.
        ImportError: A download is required but the 'gdown' package is not
            installed.
    """
    out_path = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.exists(out_path):
        return out_path

    file_id = DRIVE_FILE_IDS.get(filename)
    if file_id is None:
        raise FileNotFoundError(
            f"{filename} not found in {CHECKPOINT_DIR} and has no Drive file ID registered."
        )

    # gdown is imported lazily so it is only required when a file must be fetched.
    try:
        import gdown
    except ImportError as exc:
        raise ImportError(
            "gdown is required for auto-download. "
            "Install it with: pip install gdown"
        ) from exc

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"[auto-download] {filename} ...")
    gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )

    # NOTE(review): depending on the installed gdown version a failed download
    # may be reported without raising - confirm. Verify the file really exists
    # so the caller fails here with a clear message instead of later in
    # torch.load with a path that does not exist.
    if not os.path.exists(out_path):
        raise FileNotFoundError(
            f"Download of {filename} from Google Drive failed; "
            f"{out_path} was not created."
        )
    return out_path
|
|
|
|
| |
| |
| |
|
|
def _norm():
    """Build the per-channel Normalize transform shared by all pipelines.

    The constants equal the widely used ImageNet RGB mean / std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)
|
|
|
|
def get_val_transform():
    """Return the single-view pipeline: resize to IMG_SIZE, to tensor, normalise."""
    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        _norm(),
    ]
    return transforms.Compose(steps)
|
|
|
|
def get_tta_transforms():
    """
    Return the four test-time-augmentation pipelines, in this order:
    plain view, horizontal flip, vertical flip, rotation.

    Every pipeline resizes to IMG_SIZE first and ends with ToTensor plus the
    shared Normalize step. The flips use p=1.0 and the rotation range is
    (10, 10), so each augmentation is always applied with the same parameters.
    """
    normalize = _norm()
    # None stands for the un-augmented base view.
    augmentations = (
        None,
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.RandomRotation(degrees=(10, 10)),
    )

    pipelines = []
    for augmentation in augmentations:
        steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
        if augmentation is not None:
            steps.append(augmentation)
        steps.extend([transforms.ToTensor(), normalize])
        pipelines.append(transforms.Compose(steps))
    return pipelines
|
|
|
|
| |
| |
| |
|
|
def _get_retfound_arch():
    """
    Return the RETFound_mae model factory from the vendored models_vit.py.

    models_vit.py is expected to live in the same directory as this file; that
    directory is put at the front of sys.path (if not already present) so the
    import resolves without a runtime git-clone.

    Raises:
        RuntimeError: models_vit cannot be imported, or it was imported but
            does not define RETFound_mae. The original exception is attached
            as the cause.
    """
    app_dir = os.path.dirname(os.path.abspath(__file__))
    if app_dir not in sys.path:
        sys.path.insert(0, app_dir)

    try:
        import models_vit
    except ImportError as e:
        raise RuntimeError(
            f"Could not import models_vit: {e}\n"
            "Ensure models_vit.py is present in the same directory as app.py."
        ) from e

    # Report a missing factory with the same exception type and an actionable
    # message instead of letting a bare KeyError escape.
    try:
        return models_vit.__dict__["RETFound_mae"]
    except KeyError as e:
        raise RuntimeError(
            "models_vit was imported but does not define RETFound_mae.\n"
            "Ensure the models_vit.py next to app.py provides RETFound_mae."
        ) from e
|
|
|
|
def build_retfound(ckpt_path: str) -> nn.Module:
    """
    Instantiate a RETFound model and restore fine-tuned weights into it.

    Args:
        ckpt_path: Path to a checkpoint whose loaded content is passed directly
            to load_state_dict with strict=True.

    Returns:
        The model, switched to eval mode and moved to DEVICE.
    """
    factory = _get_retfound_arch()
    net = factory(num_classes=1, drop_path_rate=0.2, global_pool=True)

    # Weights are read onto the CPU first; the model is moved to DEVICE below.
    # NOTE(review): torch.load unpickles the file, and checkpoints may have been
    # downloaded from Google Drive - consider weights_only=True if the installed
    # torch version supports it and the files are plain state dicts (confirm).
    weights = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(weights, strict=True)

    net.eval()
    return net.to(DEVICE)
|
|
|
|
| |
| |
| |
|
|
# Fold models already loaded, keyed by MODEL_REGISTRY name; filled by get_ensemble().
_ensemble_cache: dict[str, list[nn.Module]] = {}


def get_ensemble(model_name: str) -> list[nn.Module]:
    """
    Return the fold models for the selected strategy, loading them on first use.

    Args:
        model_name: A key of MODEL_REGISTRY.

    Returns:
        One model per fold file, in registry order. The list is cached, so
        later calls with the same name return the same list object.

    Raises:
        KeyError: model_name is neither cached nor present in MODEL_REGISTRY.
    """
    cached = _ensemble_cache.get(model_name)
    if cached is not None:
        return cached

    loaded = []
    for ckpt_name in MODEL_REGISTRY[model_name]["folds"]:
        local_path = maybe_download(ckpt_name)
        print(f" Loading {ckpt_name} ...")
        loaded.append(build_retfound(local_path))

    # Cache only after every fold loaded successfully.
    _ensemble_cache[model_name] = loaded
    return loaded
|
|
|
|
| |
| |
| |
|
|
def _infer_single(model: nn.Module, tensor: torch.Tensor) -> float:
    """
    Run one forward pass and return the sigmoid of the output as a float.

    The model output is reshaped to (batch, 1) before the sigmoid. Because the
    result is converted with .item(), the input must hold exactly one sample.
    """
    batch_size = tensor.size(0)
    logits = model(tensor).reshape(batch_size, 1)
    probability = torch.sigmoid(logits)
    return probability.item()
|
|
|
|
def predict(image: Image.Image, model_name: str, use_tta: bool):
    """
    Run 5-fold ensemble inference on a PIL image.

    Each fold model scores every view of the image; the view scores are
    averaged per fold, and the fold averages are averaged into the final
    probability. The label is "DRIL" when that probability is >= 0.5.

    Args:
        image: Input image; converted to RGB before processing.
        model_name: Key of MODEL_REGISTRY selecting the fine-tuning strategy.
        use_tta: When true, use the views from get_tta_transforms(); otherwise
            only the single view from get_val_transform().

    Returns:
        label (str) - "DRIL" or "No-DRIL"
        dril_prob (float) - raw DRIL probability (0-1)
        fold_probs (list[float])- per-fold DRIL probabilities
        img_np (np.ndarray) - RGB preview of the input, resized to IMG_SIZE
    """
    image = image.convert("RGB")
    ensemble = get_ensemble(model_name)

    tfms = get_tta_transforms() if use_tta else [get_val_transform()]

    fold_probs = []
    with torch.no_grad():
        # Preprocess each view once and reuse it for every fold model. The
        # views depend only on the image, so rebuilding them inside the model
        # loop repeated the same resize/flip/rotate work for each fold.
        views = [t(image).unsqueeze(0).to(DEVICE) for t in tfms]

        for model in ensemble:
            tta_probs = [_infer_single(model, view) for view in views]
            fold_probs.append(float(np.mean(tta_probs)))

    dril_prob = float(np.mean(fold_probs))
    label = "DRIL" if dril_prob >= 0.5 else "No-DRIL"
    img_np = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
    return label, dril_prob, fold_probs, img_np
|
|
|
|
| |
| |
| |
|
|
def run_inference(pil_image, model_name: str, use_tta: bool):
    """
    Gradio callback: classify the uploaded image and format a text report.

    Returns a (report_text, preview_image) pair. The preview is None when no
    image was supplied or when inference raised; in the latter case the text
    carries the exception message instead of a prediction.
    """
    if pil_image is None:
        return "No image provided.", None

    # UI boundary: any failure is shown to the user as text rather than raised.
    try:
        label, dril_prob, fold_probs, img_np = predict(pil_image, model_name, use_tta)
    except Exception as e:
        return f"Error during inference:\n{e}", None

    nodril_prob = 1.0 - dril_prob
    # Confidence is the probability of whichever class was predicted.
    if label == "DRIL":
        confidence = dril_prob
    else:
        confidence = nodril_prob

    fold_report = "\n".join(
        f" Fold {fold_no}: {fold_prob*100:.1f}% DRIL"
        for fold_no, fold_prob in enumerate(fold_probs, start=1)
    )

    # The empty entry produces the blank line before the per-fold section.
    report_sections = [
        f"Prediction : {label}",
        f"Confidence : {confidence*100:.1f}%",
        f"DRIL prob : {dril_prob:.4f} | No-DRIL prob: {nodril_prob:.4f}",
        "",
        "Per-fold probabilities (DRIL):",
        fold_report,
    ]
    return "\n".join(report_sections), img_np
|
|
|
|
# Gradio UI definition. `demo` is the Blocks app launched at the bottom of the file.
with gr.Blocks(title="DRIL OCT Classification") as demo:
    # Page header and a short description of the task and the models.
    gr.Markdown(
"""
# DRIL OCT Classification
Classify macular OCT B-scan images as **DRIL** (Disruption of Retinal Inner Layers)
or **No-DRIL** using a 5-fold ensemble of RETFound foundation models fine-tuned on
a private OCT dataset (429 DRIL / 394 No-DRIL cases).
"""
    )

    with gr.Row():
        # First column: inputs (image upload, strategy choice, TTA toggle, button).
        with gr.Column(scale=1):
            # type="pil" so the callback receives a PIL image.
            input_image = gr.Image(type="pil", label="Upload OCT Image")
            # Choices are the MODEL_REGISTRY keys, so the selected value can be
            # passed to run_inference as model_name unchanged.
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model (fine-tuning strategy)"
            )
            tta_checkbox = gr.Checkbox(
                value=True,
                label="Test-Time Augmentation (4-view TTA)"
            )
            classify_btn = gr.Button("Classify", variant="primary")

        # Second column: outputs (text report and preview of the input image).
        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result", lines=10)
            output_image = gr.Image(label="Input Preview", type="numpy")

    # Wire the button to run_inference; the inputs list matches its parameter
    # order and the outputs list matches the (text, image) pair it returns.
    classify_btn.click(
        fn=run_inference,
        inputs=[input_image, model_selector, tta_checkbox],
        outputs=[result_text, output_image]
    )

    # Usage notes and the research-only disclaimer.
    gr.Markdown(
"""
**Notes**
- The first inference call will load all 5 fold checkpoints into memory (~350 MB per strategy).
Subsequent calls on the same strategy are fast.
- Conservative strategy: top-4 ViT blocks unfrozen. Moderate: top-8 blocks. Baseline: head only.
- Optimal threshold for binary decision was determined by Youden-J on the validation set;
the demo uses the fixed 0.5 threshold for simplicity.

**Disclaimer:** This tool is for research purposes only and must not be used for clinical decisions.
"""
    )


# Start the app only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|