# app.py
"""Gradio app that runs a fold ensemble of DeepSeeNet checkpoints on one image.

The app loads every fold checkpoint found under a directory, averages the
per-fold softmax probabilities, and shows the 5-class result, a grouped
("bucketed") view of the same probabilities, and the per-fold breakdown.
"""
import argparse
import warnings
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn.functional as F

from augmentations import get_val_transforms
from model import DeepSeeNet

# Class index -> human-readable diabetic-retinopathy grade.
ID_TO_LABEL = {
    0: "No DR",
    1: "Mild NPDR",
    2: "Moderate NPDR",
    3: "Severe NPDR",
    4: "PDR",
}

# Grouping modes offered in the UI; keys must match get_grouped_probs().
BUCKET_MODE_DESCRIPTIONS = {
    "severity": "5-class DR severity",
    "referral": "Non-referable vs referable DR",
    "any_dr": "No DR vs any DR",
    "stage": "No DR vs NPDR vs PDR",
    "high_risk": "Non-severe vs severe/PDR",
}

# Column layout of the per-fold table, shared by the empty placeholder and
# the populated frame so both always have the same schema.
FOLD_TABLE_COLUMNS = [
    "fold",
    "prediction",
    "probability",
    "p0_no_dr",
    "p1_mild",
    "p2_moderate",
    "p3_severe",
    "p4_pdr",
]


class AlbumentationsTransform:
    """Adapter that lets a keyword-style transform be called with one image.

    The wrapped callable is invoked as ``transform(image=<ndarray>)`` and is
    expected to return a mapping with an ``"image"`` entry.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        """Convert *image* to an ndarray, apply the transform, return its image."""
        return self.transform(image=np.asarray(image))["image"]


def unwrap_logits(output):
    """Return the first element of a tuple/list model output, else the output."""
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def find_fold_checkpoints(checkpoint_dir, checkpoint_name="best_model_only.pt"):
    """Locate fold checkpoint files under *checkpoint_dir*.

    Looks for ``*_fold*/<checkpoint_name>`` first and falls back to a
    recursive search for ``<checkpoint_name>`` anywhere below the directory.

    Returns:
        Sorted list of ``Path`` objects, never empty.

    Raises:
        FileNotFoundError: If the directory does not exist or no checkpoint
            with the given name is found.
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")

    paths = sorted(checkpoint_dir.glob(f"*_fold*/{checkpoint_name}"))
    if not paths:
        # fallback: allow directly passing a folder containing checkpoints
        paths = sorted(checkpoint_dir.glob(f"**/{checkpoint_name}"))
    if not paths:
        raise FileNotFoundError(
            f"No fold checkpoints named '{checkpoint_name}' found under {checkpoint_dir}"
        )
    return paths


def load_single_model(checkpoint_path, backbone, image_size, device):
    """Load one checkpoint into a ``DeepSeeNet`` in eval mode on *device*.

    Args:
        checkpoint_path: Path of the checkpoint file.
        backbone: Fallback backbone name, used when the checkpoint's saved
            args do not provide one.
        image_size: Fallback image size, used the same way.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        ``(model, backbone, image_size, ckpt)`` where *backbone* and
        *image_size* are the resolved values and *ckpt* is the loaded
        checkpoint object.
    """
    checkpoint_path = Path(checkpoint_path)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device)

    # The training args may be missing, stored as None, or stored as an
    # argparse.Namespace rather than a dict; normalise to a dict so the
    # lookups below work in every case.
    saved_args = ckpt.get("args") or {}
    if not isinstance(saved_args, dict):
        saved_args = vars(saved_args)

    backbone = saved_args.get("backbone", backbone)
    image_size = int(saved_args.get("image_size", image_size))

    model = DeepSeeNet(
        n_classes=len(ID_TO_LABEL),
        backbone=backbone,
        pretrained=False,
        freeze_backbone=False,
    )
    model.load_state_dict(ckpt["model"], strict=True)
    model.to(device)
    model.eval()
    return model, backbone, image_size, ckpt


def load_ensemble(checkpoint_dir, checkpoint_name, backbone, image_size, device):
    """Load every fold checkpoint and build the shared preprocessing transform.

    Each checkpoint's resolved backbone / image size is passed on as the
    fallback for the next one. The transform is built from the image size
    resolved for the last checkpoint; a warning is emitted when the folds do
    not all resolve to the same size.

    Returns:
        ``(models, transform, backbone, image_size, checkpoint_paths, ckpts)``.
    """
    checkpoint_paths = find_fold_checkpoints(
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=checkpoint_name,
    )

    models = []
    ckpts = []
    seen_image_sizes = set()
    resolved_backbone = backbone
    resolved_image_size = image_size

    for path in checkpoint_paths:
        model, resolved_backbone, resolved_image_size, ckpt = load_single_model(
            checkpoint_path=path,
            backbone=resolved_backbone,
            image_size=resolved_image_size,
            device=device,
        )
        models.append(model)
        ckpts.append(ckpt)
        seen_image_sizes.add(resolved_image_size)

    if len(seen_image_sizes) > 1:
        # All folds share one transform, so a mismatch means some folds are
        # fed images at a size other than the one stored in their checkpoint.
        warnings.warn(
            f"Fold checkpoints resolve to different image sizes "
            f"{sorted(seen_image_sizes)}; the shared preprocessing uses "
            f"{resolved_image_size}.",
            stacklevel=2,
        )

    transform = AlbumentationsTransform(get_val_transforms(resolved_image_size))
    return models, transform, resolved_backbone, resolved_image_size, checkpoint_paths, ckpts


def get_grouped_probs(probs, mode):
    """Collapse five class probabilities into the buckets of *mode*.

    Args:
        probs: Sequence of exactly five probabilities, indexed by grade.
        mode: One of the keys of ``BUCKET_MODE_DESCRIPTIONS``.

    Returns:
        Dict mapping bucket label to its summed probability (as ``float``).

    Raises:
        ValueError: If *mode* is not a known bucket mode.
    """
    p0, p1, p2, p3, p4 = probs

    if mode == "severity":
        return {
            "No DR": float(p0),
            "Mild NPDR": float(p1),
            "Moderate NPDR": float(p2),
            "Severe NPDR": float(p3),
            "PDR": float(p4),
        }
    if mode == "referral":
        return {
            "Non-referable DR": float(p0 + p1),
            "Referable DR": float(p2 + p3 + p4),
        }
    if mode == "any_dr":
        return {
            "No DR": float(p0),
            "Any DR": float(p1 + p2 + p3 + p4),
        }
    if mode == "stage":
        return {
            "No DR": float(p0),
            "NPDR": float(p1 + p2 + p3),
            "PDR": float(p4),
        }
    if mode == "high_risk":
        return {
            "Non-severe / non-PDR": float(p0 + p1 + p2),
            "Severe NPDR or PDR": float(p3 + p4),
        }
    raise ValueError(f"Unknown bucket mode: {mode}")


def grouped_probs_to_fixed_dataframe(prob_dict, max_rows=5):
    """Render grouped probabilities as a table with exactly *max_rows* rows.

    Rows are ordered by descending probability and padded with blank rows so
    the table height does not change between bucket modes.

    Returns:
        ``pandas.DataFrame`` with string columns ``label`` and ``probability``
        (probabilities formatted to four decimals).
    """
    # Sort on the raw values, not the rounded strings: the sort is stable, so
    # the first row is always the same label that max() over the dict picks.
    ranked = sorted(prob_dict.items(), key=lambda item: item[1], reverse=True)
    rows = [
        {"label": label, "probability": f"{prob:.4f}"}
        for label, prob in ranked[:max_rows]
    ]
    rows.extend(
        {"label": "", "probability": ""} for _ in range(max_rows - len(rows))
    )
    return pd.DataFrame(rows, columns=["label", "probability"])


def fold_probs_to_dataframe(fold_probs, checkpoint_paths):
    """Build the per-fold table: one row per fold with its top class and probs.

    Args:
        fold_probs: Iterable of per-fold probability vectors (five entries).
        checkpoint_paths: Checkpoint paths in the same order; each fold is
            named after its checkpoint's parent directory.

    Returns:
        ``pandas.DataFrame`` with the columns in ``FOLD_TABLE_COLUMNS``.
    """
    rows = []
    for path, probs in zip(checkpoint_paths, fold_probs):
        pred_grade = int(np.argmax(probs))
        rows.append(
            {
                "fold": Path(path).parent.name,
                "prediction": ID_TO_LABEL[pred_grade],
                "probability": f"{float(probs[pred_grade]):.4f}",
                "p0_no_dr": f"{float(probs[0]):.4f}",
                "p1_mild": f"{float(probs[1]):.4f}",
                "p2_moderate": f"{float(probs[2]):.4f}",
                "p3_severe": f"{float(probs[3]):.4f}",
                "p4_pdr": f"{float(probs[4]):.4f}",
            }
        )
    return pd.DataFrame(rows, columns=FOLD_TABLE_COLUMNS)


def ensemble_predict(models, x):
    """Run every model on *x* and average the softmax probabilities.

    Only the first item of the batch is used, so *x* is expected to hold a
    single image (the caller adds the batch dimension with ``unsqueeze(0)``).

    Returns:
        ``(ensemble_probs, fold_probs)``: the mean over folds and the stacked
        per-fold probabilities (one row per model) as NumPy arrays.
    """
    fold_probs = []
    with torch.no_grad():
        for model in models:
            logits = unwrap_logits(model(x))
            probs = F.softmax(logits, dim=1)[0].detach().cpu().numpy()
            fold_probs.append(probs)

    fold_probs = np.stack(fold_probs, axis=0)
    ensemble_probs = fold_probs.mean(axis=0)
    return ensemble_probs, fold_probs


def make_prediction_fn(models, transform, checkpoint_paths, device):
    """Return the Gradio callback bound to the loaded ensemble.

    The callback takes ``(image, bucket_mode)`` and returns the 4-tuple
    ``(raw_probs, bucket_df, fold_df, summary)`` matching the app's outputs.
    """

    def predict(image, bucket_mode):
        if image is None:
            # Nothing uploaded yet: clear the label and show blank tables of
            # the same shape as the populated ones.
            empty_df = pd.DataFrame(
                [{"label": "", "probability": ""} for _ in range(5)],
                columns=["label", "probability"],
            )
            empty_fold_df = pd.DataFrame(columns=FOLD_TABLE_COLUMNS)
            return (
                None,
                empty_df,
                empty_fold_df,
                "Upload a fundus image to run ensemble inference.",
            )

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        # Add the batch dimension expected by the models.
        x = transform(image).unsqueeze(0).to(device)
        probs, fold_probs = ensemble_predict(models, x)

        pred_grade = int(np.argmax(probs))
        pred_label = ID_TO_LABEL[pred_grade]
        pred_prob = float(probs[pred_grade])

        raw_probs = {label: float(probs[grade]) for grade, label in ID_TO_LABEL.items()}

        grouped_probs = get_grouped_probs(probs, bucket_mode)
        bucket_label = max(grouped_probs, key=grouped_probs.get)
        bucket_prob = grouped_probs[bucket_label]
        bucket_df = grouped_probs_to_fixed_dataframe(grouped_probs, max_rows=5)

        fold_df = fold_probs_to_dataframe(fold_probs, checkpoint_paths)

        # Fraction of folds whose own top class equals the ensemble's.
        fold_preds = [ID_TO_LABEL[int(np.argmax(p))] for p in fold_probs]
        n_agree = fold_preds.count(pred_label)
        agreement = n_agree / len(fold_preds)

        summary = (
            f"Ensemble prediction: {pred_label} "
            f"(grade {pred_grade}, probability {pred_prob:.3f})\n\n"
            f"Probability grouping: {BUCKET_MODE_DESCRIPTIONS[bucket_mode]}\n"
            f"Grouped result: {bucket_label} "
            f"(probability {bucket_prob:.3f})\n\n"
            f"Fold agreement with ensemble class: {agreement:.1%} "
            f"({n_agree}/{len(fold_preds)} folds)"
        )
        return raw_probs, bucket_df, fold_df, summary

    return predict


def make_app(checkpoint_dir, checkpoint_name, backbone, image_size, device):
    """Load the ensemble and assemble the Gradio ``Blocks`` interface.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    (
        models,
        transform,
        backbone,
        image_size,
        checkpoint_paths,
        ckpts,
    ) = load_ensemble(
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=checkpoint_name,
        backbone=backbone,
        image_size=image_size,
        device=device,
    )

    predict_fn = make_prediction_fn(
        models=models,
        transform=transform,
        checkpoint_paths=checkpoint_paths,
        device=device,
    )

    # Only consumed by the currently disabled fold/metric info panel below.
    fold_names = [p.parent.name for p in checkpoint_paths]
    best_metrics = [ckpt.get("best_metric", None) for ckpt in ckpts]

    model_info = (
        f"Backbone: `{backbone}` | "
        f"Input: `{image_size}×{image_size}` | "
        f"Folds loaded: `{len(models)}` | "
        f"Device: `{device}`"
    )

    # fold_info = "Loaded checkpoints: " + ", ".join(fold_names)
    #
    # if any(m is not None for m in best_metrics):
    #     metric_info = "Best metrics: " + ", ".join(
    #         "NA" if m is None else f"{float(m):.4f}"
    #         for m in best_metrics
    #     )
    # else:
    #     metric_info = ""

    with gr.Blocks(title="EyePACS DR Ensemble Classifier") as demo:
        gr.Markdown("# EyePACS DR Ensemble Classifier")
        gr.Markdown(model_info)
        # gr.Markdown(fold_info)
        # if metric_info:
        #     gr.Markdown(metric_info)

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(
                    type="pil",
                    label="Fundus image",
                )
                bucket_mode = gr.Dropdown(
                    # Derived from the description table so the two cannot
                    # drift apart.
                    choices=list(BUCKET_MODE_DESCRIPTIONS),
                    value="referral",
                    label="Probability grouping",
                )
                run_btn = gr.Button("Run inference", variant="primary")

            with gr.Column(scale=1):
                raw_probs = gr.Label(
                    num_top_classes=5,
                    label="Ensemble raw 5-class probabilities",
                )
                bucket_table = gr.Dataframe(
                    label="Grouped probabilities",
                    headers=["label", "probability"],
                    row_count=(5, "fixed"),
                    col_count=(2, "fixed"),
                    interactive=False,
                    wrap=True,
                )
                # NOTE(review): the file's indentation was lost, so placing
                # the accordion and summary inside this column is a
                # reconstruction - confirm against the intended layout.
                with gr.Accordion("Per-fold predictions", open=False):
                    fold_table = gr.Dataframe(
                        label="Per-fold probabilities",
                        interactive=False,
                        wrap=True,
                    )
                summary = gr.Textbox(
                    label="Summary",
                    lines=7,
                    interactive=False,
                )

        inputs = [image, bucket_mode]
        outputs = [raw_probs, bucket_table, fold_table, summary]

        # Re-run on explicit click and whenever either input changes.
        run_btn.click(
            fn=predict_fn,
            inputs=inputs,
            outputs=outputs,
        )
        image.change(
            fn=predict_fn,
            inputs=inputs,
            outputs=outputs,
        )
        bucket_mode.change(
            fn=predict_fn,
            inputs=inputs,
            outputs=outputs,
        )

    return demo


def parse_args():
    """Parse the command-line options of the app."""
    parser = argparse.ArgumentParser(
        description="EyePACS DR Gradio ensemble inference app."
    )
    parser.add_argument(
        "--checkpoint-dir",
        default="runs/eyepacs_dr",
        help="Folder containing fold checkpoint folders.",
    )
    parser.add_argument(
        "--checkpoint-name",
        default="best_model_only.pt",
        help="Checkpoint filename inside each fold folder.",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help="Fallback backbone if checkpoint args are unavailable.",
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=1024,
        help="Fallback image size if checkpoint args are unavailable.",
    )
    # The server options default to None so that, when a flag is omitted,
    # Gradio's own launch defaults apply exactly as they did while these
    # options were not forwarded to launch().
    parser.add_argument(
        "--host",
        default=None,
        help="Server host name (e.g. 0.0.0.0). Defaults to Gradio's own default.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=None,
        help="Server port (e.g. 7860). Defaults to Gradio's own default.",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Ask Gradio to create a public share link.",
    )
    return parser.parse_args()


def main():
    """Build the app from the CLI options and launch it."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    demo = make_app(
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_name=args.checkpoint_name,
        backbone=args.backbone,
        image_size=args.image_size,
        device=device,
    )
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        # None (rather than False) leaves Gradio's default sharing behaviour
        # in place when --share is not given.
        share=True if args.share else None,
    )


if __name__ == "__main__":
    main()