AI-RESEARCHER-2024's picture
update files
6ab979f verified
#!/usr/bin/env python3
"""
app.py - Gradio demo for DRIL OCT Classification
-------------------------------------------------
Performs 5-fold ensemble inference using RETFound fine-tuned checkpoints.
The model probabilities from all 5 fold checkpoints are averaged, matching
the evaluation methodology used in training.
Models available (weights on Google Drive):
- RETFound Conservative (top-4 blocks unfrozen, 5-fold ensemble)
- RETFound Moderate (top-8 blocks unfrozen, 5-fold ensemble)
- RETFound Baseline (head-only fine-tuning, 5-fold ensemble)
To run locally:
pip install -r requirements.txt
python app.py
Environment variables:
CHECKPOINT_DIR - path to folder containing .pth files (default: ./checkpoints)
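RETFOUND_DIR - path to cloned RETFound_MAE repo (default: ./RETFound_MAE).
NOTE: not read anywhere in this file; the architecture is imported from the
vendored models_vit.py that sits next to app.py.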
"""
import os
import sys
import glob
import numpy as np
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Run on the GPU when torch reports one as available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Side length, in pixels, of the square image handed to the model.
IMG_SIZE = 224
# Folder that holds the .pth checkpoints; overridable through the environment.
CHECKPOINT_DIR = os.getenv("CHECKPOINT_DIR", "./checkpoints")
# ---------------------------------------------------------------------------
# Model registry
# Each entry: display name -> list of per-fold checkpoint filenames (ensembled)
# All filenames match exactly what is stored in the Google Drive folder.
# ---------------------------------------------------------------------------
# Display name -> {"folds": [checkpoint filename for fold 1..5]}.
# The per-fold filenames differ only in the fold number, so each list is
# generated from its filename pattern instead of being spelled out.
MODEL_REGISTRY = {
    "RETFound - Conservative (5-fold ensemble)": {
        "folds": [f"retfound_v2_conservative_fold{k}.pth" for k in range(1, 6)],
    },
    "RETFound - Moderate (5-fold ensemble)": {
        "folds": [f"retfound_v2_moderate_fold{k}.pth" for k in range(1, 6)],
    },
    "RETFound - Baseline CV (5-fold ensemble)": {
        "folds": [f"retfound_v2_cv_fold{k}_best.pth" for k in range(1, 6)],
    },
}
# Entry pre-selected in the model dropdown.
DEFAULT_MODEL = "RETFound - Conservative (5-fold ensemble)"
# ---------------------------------------------------------------------------
# Auto-download weights from Drive if missing (gdown required)
# ---------------------------------------------------------------------------
# Checkpoint filename -> Google Drive file ID.
# maybe_download() looks a filename up here when the file is not already
# present in CHECKPOINT_DIR.
DRIVE_FILE_IDS = {
    # Folds listed under "RETFound - Conservative" in MODEL_REGISTRY.
    "retfound_v2_conservative_fold1.pth": "110xefhcDD01YMGFcZ-6Hcv3zLm731vgu",
    "retfound_v2_conservative_fold2.pth": "1gEzlq-LF7R7pNnd1Ud5sePtnRx0XI-mt",
    "retfound_v2_conservative_fold3.pth": "1TRR0DuDHj99_qGC8KSbt50KLselMv1ti",
    "retfound_v2_conservative_fold4.pth": "1huVy9EpLqa88MU3O5kfYrDLwGnm4TWOH",
    "retfound_v2_conservative_fold5.pth": "1U0MwwuOji3P8psTjzwkNQqL81bXtm50d",
    # Folds listed under "RETFound - Moderate" in MODEL_REGISTRY.
    "retfound_v2_moderate_fold1.pth": "1y-xO33wRQAlgNrioYoSfOx019bNv30y-",
    "retfound_v2_moderate_fold2.pth": "1r6f-EmdZnRgdGm4W9RaAO_dSRoqQkj_0",
    "retfound_v2_moderate_fold3.pth": "1Mak5FuHl2jAZMS2NglR7T0gl09r0bXdR",
    "retfound_v2_moderate_fold4.pth": "1qFE1CB3x96U1PakP0OAzKTUBUKfrqiZO",
    "retfound_v2_moderate_fold5.pth": "1afUmpapz1dryl43rqVCE5QSEBLbaDGaw",
    # Folds listed under "RETFound - Baseline CV" in MODEL_REGISTRY.
    "retfound_v2_cv_fold1_best.pth": "1qUnrX9LJ6DF2ysG67rjiQzo4XtzJc5WW",
    "retfound_v2_cv_fold2_best.pth": "1oEvZA2oQSXaxMbi7_-fnF5w9BDI7l_lZ",
    "retfound_v2_cv_fold3_best.pth": "18w5XtZKA2HmSn0TqtDMa3yLC_rKYtHnb",
    "retfound_v2_cv_fold4_best.pth": "1Jnoj-W6oQqp2l7GIYuqQ5e_WZRsII4Da",
    "retfound_v2_cv_fold5_best.pth": "1d9stMnIjfcJoqeraFHFKKj8XquTV28KV",
    # NOTE(review): not referenced by MODEL_REGISTRY in this file; presumably
    # the base pretrained RETFound OCT weights -- confirm before removing.
    "RETFound_oct.pth": "1v2v5XGMr7ipCyESE1jnASzLdANJDSLXv",
}
def maybe_download(filename: str) -> str:
    """
    Return the local path to a checkpoint, downloading it from Drive if missing.

    Requires the 'gdown' package when a download is actually needed.

    Args:
        filename: Checkpoint file name, resolved relative to CHECKPOINT_DIR.

    Returns:
        Path of the checkpoint file inside CHECKPOINT_DIR.

    Raises:
        FileNotFoundError: the file is absent locally and no Drive file ID is
            registered for it in DRIVE_FILE_IDS.
        ImportError: a download is needed but 'gdown' is not installed.
        RuntimeError: the download ended without producing the file.
    """
    out_path = os.path.join(CHECKPOINT_DIR, filename)
    if os.path.exists(out_path):
        return out_path

    file_id = DRIVE_FILE_IDS.get(filename)
    if file_id is None:
        raise FileNotFoundError(
            f"{filename} not found in {CHECKPOINT_DIR} and has no Drive file ID registered."
        )

    # gdown is imported lazily so the app still starts without it when all
    # checkpoints are already on disk.
    try:
        import gdown
    except ImportError as err:
        raise ImportError(
            "gdown is required for auto-download. "
            "Install it with: pip install gdown"
        ) from err

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    print(f"[auto-download] {filename} ...")
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )
    # gdown can report failure by returning None instead of raising. Fail here
    # with a clear message rather than handing back a path that does not exist.
    if downloaded is None or not os.path.exists(out_path):
        raise RuntimeError(
            f"Download of {filename} (Drive file ID {file_id}) failed; "
            f"{out_path} was not created."
        )
    return out_path
# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------
def _norm():
    """Build the per-channel normalization step shared by all pipelines.

    The constants are the values commonly used as ImageNet RGB statistics.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)
def get_val_transform():
    """Return the plain preprocessing pipeline: resize, to-tensor, normalize."""
    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        _norm(),
    ]
    return transforms.Compose(steps)
def get_tta_transforms():
    """Return the four test-time-augmentation pipelines.

    In order: unmodified view, horizontal flip, vertical flip, and a rotation
    whose range is pinned to (10, 10) degrees. Both flips use p=1.0, so they
    are always applied. All four pipelines share one Normalize instance.
    """
    normalize = _norm()

    def _pipeline(*augmentations):
        # Every view is resize -> [augmentation] -> to-tensor -> normalize.
        return transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            *augmentations,
            transforms.ToTensor(),
            normalize,
        ])

    return [
        _pipeline(),
        _pipeline(transforms.RandomHorizontalFlip(p=1.0)),
        _pipeline(transforms.RandomVerticalFlip(p=1.0)),
        _pipeline(transforms.RandomRotation(degrees=(10, 10))),
    ]
# ---------------------------------------------------------------------------
# RETFound model builder
# ---------------------------------------------------------------------------
def _get_retfound_arch():
    """
    Return the RETFound_mae model factory from the vendored models_vit.py.

    models_vit.py is expected to live in the same directory as app.py, so no
    runtime git-clone is needed.

    Returns:
        The object named RETFound_mae exported by models_vit.

    Raises:
        RuntimeError: models_vit cannot be imported, or it does not define
            RETFound_mae. The underlying exception is attached as __cause__.
    """
    # Ensure the directory containing this file is on sys.path so that
    # models_vit is importable regardless of the current working directory.
    app_dir = os.path.dirname(os.path.abspath(__file__))
    if app_dir not in sys.path:
        sys.path.insert(0, app_dir)

    # Keep the try body to the import alone so unrelated errors are not
    # misreported as an import problem.
    try:
        import models_vit
    except ImportError as e:
        raise RuntimeError(
            f"Could not import models_vit: {e}\n"
            "Ensure models_vit.py is present in the same directory as app.py."
        ) from e

    try:
        return models_vit.RETFound_mae
    except AttributeError as e:
        raise RuntimeError(
            "models_vit was imported but does not define RETFound_mae."
        ) from e
def build_retfound(ckpt_path: str) -> nn.Module:
    """Instantiate RETFound and populate it from a fine-tuned checkpoint.

    Args:
        ckpt_path: Path to a .pth file holding a bare state dict.

    Returns:
        The model in eval mode, moved to DEVICE.
    """
    make_model = _get_retfound_arch()
    net = make_model(num_classes=1, drop_path_rate=0.2, global_pool=True)

    # The fine-tuned checkpoints were written with model.state_dict() directly,
    # so the loaded object is passed to load_state_dict without unwrapping.
    # NOTE(review): torch.load unpickles the file, and these checkpoints may
    # have been downloaded; consider weights_only=True if the deployed torch
    # version supports it.
    weights = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(weights, strict=True)

    net.eval()
    return net.to(DEVICE)
# ---------------------------------------------------------------------------
# Ensemble model cache { model_name -> list[nn.Module] }
# ---------------------------------------------------------------------------
# Loaded ensembles, keyed by the display name used in MODEL_REGISTRY.
_ensemble_cache: dict[str, list[nn.Module]] = {}


def get_ensemble(model_name: str) -> list[nn.Module]:
    """Return the fold models for *model_name*, loading them on first use.

    Every checkpoint listed under the strategy's "folds" entry is fetched
    (if needed) and built; the resulting list is stored in the module-level
    cache only after all folds have loaded.
    """
    cached = _ensemble_cache.get(model_name)
    if cached is not None:
        return cached

    loaded = []
    for ckpt_name in MODEL_REGISTRY[model_name]["folds"]:
        local_path = maybe_download(ckpt_name)
        print(f" Loading {ckpt_name} ...")
        loaded.append(build_retfound(local_path))

    _ensemble_cache[model_name] = loaded
    return loaded
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _infer_single(model: nn.Module, tensor: torch.Tensor) -> float:
    """Forward *tensor* through *model* and return the sigmoid output as a float.

    The raw output is reshaped to (batch, 1); the final .item() call means the
    batch must contain exactly one sample.
    """
    logits = model(tensor)
    logits = logits.reshape(tensor.shape[0], 1)
    probability = torch.sigmoid(logits)
    return probability.item()
def predict(image: Image.Image, model_name: str, use_tta: bool):
    """
    Run ensemble inference on a PIL image.

    Each fold model scores every view of the image; the per-view probabilities
    are averaged into one probability per fold, and the per-fold probabilities
    are averaged into the final DRIL probability.

    Args:
        image: Input image; converted to RGB before preprocessing.
        model_name: Key into MODEL_REGISTRY selecting the ensemble.
        use_tta: If True, score the views from get_tta_transforms();
            otherwise score the single view from get_val_transform().

    Returns:
        label (str) - "DRIL" if dril_prob >= 0.5 else "No-DRIL"
        dril_prob (float) - mean DRIL probability across folds (0-1)
        fold_probs (list[float]) - per-fold DRIL probabilities
        img_np (np.ndarray) - RGB preview resized to IMG_SIZE x IMG_SIZE
    """
    image = image.convert("RGB")
    ensemble = get_ensemble(model_name)
    tfms = get_tta_transforms() if use_tta else [get_val_transform()]

    # The views depend only on the image and the (fixed-parameter) transforms,
    # not on the fold model, so preprocess and move them to the device once
    # instead of once per fold.
    views = [t(image).unsqueeze(0).to(DEVICE) for t in tfms]

    fold_probs = []
    with torch.no_grad():
        for model in ensemble:
            tta_probs = [_infer_single(model, view) for view in views]
            fold_probs.append(float(np.mean(tta_probs)))

    dril_prob = float(np.mean(fold_probs))
    label = "DRIL" if dril_prob >= 0.5 else "No-DRIL"
    img_np = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
    return label, dril_prob, fold_probs, img_np
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
def run_inference(pil_image, model_name: str, use_tta: bool):
    """Gradio callback: classify the uploaded image and format a text report.

    Returns a (text, preview) pair. The preview is None when no image was
    supplied or when inference raised; in the latter case the text carries the
    exception message instead of a report.
    """
    if pil_image is None:
        return "No image provided.", None

    # UI boundary: any failure is shown to the user rather than propagated.
    try:
        label, dril_prob, fold_probs, img_np = predict(pil_image, model_name, use_tta)
    except Exception as e:
        return f"Error during inference:\n{e}", None

    nodril_prob = 1.0 - dril_prob
    # Confidence is the probability of whichever class was predicted.
    if label == "DRIL":
        confidence = dril_prob
    else:
        confidence = nodril_prob

    per_fold = [
        f" Fold {i+1}: {p*100:.1f}% DRIL" for i, p in enumerate(fold_probs)
    ]
    fold_lines = "\n".join(per_fold)

    report = (
        f"Prediction : {label}\n"
        f"Confidence : {confidence*100:.1f}%\n"
        f"DRIL prob : {dril_prob:.4f} | No-DRIL prob: {nodril_prob:.4f}\n"
        f"\nPer-fold probabilities (DRIL):\n{fold_lines}"
    )
    return report, img_np
# Gradio UI definition. `demo` is the Blocks app launched by the main guard.
with gr.Blocks(title="DRIL OCT Classification") as demo:
    # Page header: task description and a summary of the training data.
    gr.Markdown(
        """
# DRIL OCT Classification
Classify macular OCT B-scan images as **DRIL** (Disruption of Retinal Inner Layers)
or **No-DRIL** using a 5-fold ensemble of RETFound foundation models fine-tuned on
a private OCT dataset (429 DRIL / 394 No-DRIL cases).
"""
    )
    with gr.Row():
        # First column: image upload and inference options.
        with gr.Column(scale=1):
            # type="pil" so run_inference receives a PIL image (or None).
            input_image = gr.Image(type="pil", label="Upload OCT Image")
            # Choices are the MODEL_REGISTRY keys, so the selected value can
            # be passed straight through as model_name.
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model (fine-tuning strategy)"
            )
            # Maps to use_tta; enabled by default.
            tta_checkbox = gr.Checkbox(
                value=True,
                label="Test-Time Augmentation (4-view TTA)"
            )
            classify_btn = gr.Button("Classify", variant="primary")
        # Second column: textual report and the resized input preview.
        with gr.Column(scale=1):
            result_text = gr.Textbox(label="Result", lines=10)
            output_image = gr.Image(label="Input Preview", type="numpy")
    # Wire the button: the three inputs map positionally onto
    # run_inference(pil_image, model_name, use_tta); its two return values
    # fill the textbox and the preview image.
    classify_btn.click(
        fn=run_inference,
        inputs=[input_image, model_selector, tta_checkbox],
        outputs=[result_text, output_image]
    )
    # Footer: usage notes and the research-only disclaimer.
    gr.Markdown(
        """
**Notes**
- The first inference call will load all 5 fold checkpoints into memory (~350 MB per strategy).
Subsequent calls on the same strategy are fast.
- Conservative strategy: top-4 ViT blocks unfrozen. Moderate: top-8 blocks. Baseline: head only.
- Optimal threshold for binary decision was determined by Youden-J on the validation set;
the demo uses the fixed 0.5 threshold for simplicity.
**Disclaimer:** This tool is for research purposes only and must not be used for clinical decisions.
"""
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()