Spaces:
Sleeping
Sleeping
| # app.py | |
| import argparse | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| from augmentations import get_val_transforms | |
| from model import DeepSeeNet | |
# Model output index -> 5-grade diabetic-retinopathy severity label.
ID_TO_LABEL = dict(
    enumerate(
        [
            "No DR",
            "Mild NPDR",
            "Moderate NPDR",
            "Severe NPDR",
            "PDR",
        ]
    )
)

# Human-readable description for each probability-grouping mode; the keys
# are exactly the mode names handled by get_grouped_probs().
BUCKET_MODE_DESCRIPTIONS = dict(
    severity="5-class DR severity",
    referral="Non-referable vs referable DR",
    any_dr="No DR vs any DR",
    stage="No DR vs NPDR vs PDR",
    high_risk="Non-severe vs severe/PDR",
)
class AlbumentationsTransform:
    """Adapt a keyword-style image pipeline to a plain ``callable(image)``.

    The wrapped pipeline is called as ``transform(image=<ndarray>)`` and must
    return a mapping whose ``"image"`` entry is the processed result; that
    entry is what this wrapper returns.
    """

    def __init__(self, transform):
        # Pipeline to delegate to (e.g. the result of get_val_transforms()).
        self.transform = transform

    def __call__(self, image):
        # PIL images (or array-likes) are converted to a NumPy array first.
        as_array = np.asarray(image)
        augmented = self.transform(image=as_array)
        return augmented["image"]
def unwrap_logits(output):
    """Return the primary logits from a model's forward output.

    If the model returned a tuple or list, its first element is taken;
    any other value is passed through unchanged.
    """
    is_sequence = isinstance(output, (tuple, list))
    return output[0] if is_sequence else output
def find_fold_checkpoints(checkpoint_dir, checkpoint_name="best_model_only.pt"):
    """Locate per-fold checkpoint files under ``checkpoint_dir``.

    First looks for ``*_fold*/<checkpoint_name>`` directly below
    ``checkpoint_dir``. If nothing matches, it falls back to a recursive
    search for ``checkpoint_name`` anywhere underneath. Matches are returned
    sorted by path.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` does not exist, or no file
            named ``checkpoint_name`` is found by either search.
    """
    root = Path(checkpoint_dir)
    if not root.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {root}")

    # Preferred fold layout first, then the recursive fallback (which also
    # allows passing a folder that directly contains the checkpoints).
    search_patterns = (f"*_fold*/{checkpoint_name}", f"**/{checkpoint_name}")
    for pattern in search_patterns:
        matches = sorted(root.glob(pattern))
        if matches:
            return matches

    raise FileNotFoundError(
        f"No fold checkpoints named '{checkpoint_name}' found under {root}"
    )
def load_single_model(checkpoint_path, backbone, image_size, device):
    """Restore one fold's DeepSeeNet from a checkpoint file.

    ``backbone`` and ``image_size`` are fallbacks only. When the checkpoint
    carries an ``"args"`` mapping, its ``backbone`` / ``image_size`` entries
    take priority.

    Returns:
        (model, backbone, image_size, ckpt): the model in eval mode on
        ``device``, the resolved backbone name, the resolved image size as an
        int, and the raw loaded checkpoint.
    """
    # NOTE(review): depending on the torch version's ``weights_only`` default,
    # torch.load may unpickle arbitrary objects. Only load trusted files.
    ckpt = torch.load(Path(checkpoint_path), map_location=device)

    train_args = ckpt.get("args", {})
    resolved_backbone = train_args.get("backbone", backbone)
    resolved_size = int(train_args.get("image_size", image_size))

    net = DeepSeeNet(
        n_classes=5,
        backbone=resolved_backbone,
        pretrained=False,  # weights come from the checkpoint just below
        freeze_backbone=False,
    )
    net.load_state_dict(ckpt["model"], strict=True)
    net.to(device)
    net.eval()
    return net, resolved_backbone, resolved_size, ckpt
def load_ensemble(checkpoint_dir, checkpoint_name, backbone, image_size, device):
    """Load every fold checkpoint under ``checkpoint_dir`` as one ensemble.

    The backbone / image size resolved for one fold is used as the fallback
    for the next. A checkpoint without saved args therefore inherits the
    previous fold's values, or the caller's values for the first fold.

    Returns:
        (models, transform, backbone, image_size, checkpoint_paths, ckpts).
        ``transform`` is the single validation preprocessing shared by all
        folds; ``backbone`` / ``image_size`` are the values resolved for the
        last fold.

    Raises:
        FileNotFoundError: propagated from find_fold_checkpoints.
        ValueError: if the folds resolve to different image sizes. One shared
            preprocessing transform would then feed some folds input of the
            wrong size.
    """
    checkpoint_paths = find_fold_checkpoints(
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=checkpoint_name,
    )

    models = []
    ckpts = []
    fold_image_sizes = []
    resolved_backbone = backbone
    resolved_image_size = image_size
    for path in checkpoint_paths:
        model, resolved_backbone, resolved_image_size, ckpt = load_single_model(
            checkpoint_path=path,
            backbone=resolved_backbone,
            image_size=resolved_image_size,
            device=device,
        )
        models.append(model)
        ckpts.append(ckpt)
        fold_image_sizes.append(resolved_image_size)

    # A single transform is built below, so every fold must expect the same
    # input size; otherwise ensemble predictions would be silently skewed.
    if len(set(fold_image_sizes)) > 1:
        per_fold = ", ".join(
            f"{Path(p).parent.name}={size}"
            for p, size in zip(checkpoint_paths, fold_image_sizes)
        )
        raise ValueError(
            "Fold checkpoints disagree on image_size "
            f"({per_fold}); cannot share one preprocessing transform."
        )

    transform = AlbumentationsTransform(get_val_transforms(resolved_image_size))
    return models, transform, resolved_backbone, resolved_image_size, checkpoint_paths, ckpts
def get_grouped_probs(probs, mode):
    """Collapse 5-class DR probabilities into the buckets of ``mode``.

    Args:
        probs: exactly five probabilities ordered by grade (No DR, Mild NPDR,
            Moderate NPDR, Severe NPDR, PDR).
        mode: one of "severity", "referral", "any_dr", "stage", "high_risk".

    Returns:
        dict mapping bucket label -> float. Each value is the sum of the
        member grades' probabilities.

    Raises:
        ValueError: if ``probs`` does not unpack into five values, or if
            ``mode`` is not a known grouping.
    """
    no_dr, mild, moderate, severe, pdr = probs

    # Ordered (mode, builder) pairs; builders are lazy so only the requested
    # grouping is computed.
    builders = (
        ("severity", lambda: {
            "No DR": float(no_dr),
            "Mild NPDR": float(mild),
            "Moderate NPDR": float(moderate),
            "Severe NPDR": float(severe),
            "PDR": float(pdr),
        }),
        ("referral", lambda: {
            "Non-referable DR": float(no_dr + mild),
            "Referable DR": float(moderate + severe + pdr),
        }),
        ("any_dr", lambda: {
            "No DR": float(no_dr),
            "Any DR": float(mild + moderate + severe + pdr),
        }),
        ("stage", lambda: {
            "No DR": float(no_dr),
            "NPDR": float(mild + moderate + severe),
            "PDR": float(pdr),
        }),
        ("high_risk", lambda: {
            "Non-severe / non-PDR": float(no_dr + mild + moderate),
            "Severe NPDR or PDR": float(severe + pdr),
        }),
    )
    for name, build in builders:
        if mode == name:
            return build()
    raise ValueError(f"Unknown bucket mode: {mode}")
def grouped_probs_to_fixed_dataframe(prob_dict, max_rows=5):
    """Render bucket probabilities as a fixed-height ``label``/``probability`` table.

    Rows are ordered by descending probability, with ties keeping
    ``prob_dict`` order. Probabilities are formatted to 4 decimals. The table
    is then padded with blank rows, or truncated, to exactly ``max_rows`` rows
    so a fixed-row-count table keeps a stable shape.
    """
    # Rank on the raw values. Ranking on the 4-decimal strings would treat
    # e.g. 0.12341 and 0.12344 as a tie and could mis-order them.
    ranked = sorted(prob_dict.items(), key=lambda item: float(item[1]), reverse=True)
    rows = [{"label": label, "probability": f"{prob:.4f}"} for label, prob in ranked]
    rows.extend({"label": "", "probability": ""} for _ in range(max_rows - len(rows)))
    return pd.DataFrame(rows[:max_rows], columns=["label", "probability"])
def fold_probs_to_dataframe(fold_probs, checkpoint_paths):
    """Build a per-fold table of predicted labels and formatted probabilities.

    Args:
        fold_probs: one 5-element probability vector per fold.
        checkpoint_paths: checkpoint path for each fold, in the same order.
            The ``fold`` column shows each file's parent directory name.

    Returns:
        DataFrame with columns ``fold``, ``prediction``, ``probability`` (of
        the predicted class) and ``p0_no_dr``..``p4_pdr``. Every probability
        is formatted to 4 decimals. The columns are present even when
        ``fold_probs`` is empty.
    """
    grade_columns = ["p0_no_dr", "p1_mild", "p2_moderate", "p3_severe", "p4_pdr"]
    rows = []
    for i, probs in enumerate(fold_probs):
        pred_grade = int(np.argmax(probs))
        row = {
            "fold": Path(checkpoint_paths[i]).parent.name,
            "prediction": ID_TO_LABEL[pred_grade],
            "probability": f"{float(probs[pred_grade]):.4f}",
        }
        for grade, column in enumerate(grade_columns):
            row[column] = f"{float(probs[grade]):.4f}"
        rows.append(row)
    # Explicit columns so an empty result still has the expected header.
    return pd.DataFrame(rows, columns=["fold", "prediction", "probability", *grade_columns])
def ensemble_predict(models, x):
    """Run every fold model on ``x`` and average their softmax outputs.

    Only the first sample of the batch (index ``[0]``) is kept, so ``x`` is
    expected to hold a single image.

    Returns:
        (ensemble_probs, fold_probs): the mean probability vector across
        models, and the per-model probability vectors stacked one row per model.
    """
    def _softmax_of(net):
        logits = unwrap_logits(net(x))
        return F.softmax(logits, dim=1)[0].detach().cpu().numpy()

    with torch.no_grad():
        per_model = [_softmax_of(net) for net in models]

    stacked = np.stack(per_model, axis=0)
    return stacked.mean(axis=0), stacked
def make_prediction_fn(models, transform, checkpoint_paths, device):
    """Create the Gradio callback that runs the fold ensemble on one image.

    The returned ``predict(image, bucket_mode)`` yields a 4-tuple:
    raw 5-class probabilities (label -> float), the grouped-probability table,
    the per-fold table and a text summary. When ``image`` is None it returns
    placeholder values instead of running the models.
    """
    fold_columns = [
        "fold",
        "prediction",
        "probability",
        "p0_no_dr",
        "p1_mild",
        "p2_moderate",
        "p3_severe",
        "p4_pdr",
    ]

    def _placeholder_outputs():
        # Blank-but-shaped outputs so the tables keep their layout.
        blank_rows = [{"label": "", "probability": ""} for _ in range(5)]
        return (
            None,
            pd.DataFrame(blank_rows, columns=["label", "probability"]),
            pd.DataFrame(columns=fold_columns),
            "Upload a fundus image to run ensemble inference.",
        )

    def predict(image, bucket_mode):
        if image is None:
            return _placeholder_outputs()

        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        batch = transform(pil_image.convert("RGB")).unsqueeze(0).to(device)
        ensemble_probs, per_fold_probs = ensemble_predict(models, batch)

        top_grade = int(np.argmax(ensemble_probs))
        top_label = ID_TO_LABEL[top_grade]
        top_prob = float(ensemble_probs[top_grade])
        label_probs = {ID_TO_LABEL[k]: float(ensemble_probs[k]) for k in range(5)}

        grouped = get_grouped_probs(ensemble_probs, bucket_mode)
        group_label = max(grouped, key=grouped.get)
        group_prob = grouped[group_label]

        bucket_df = grouped_probs_to_fixed_dataframe(grouped, max_rows=5)
        fold_df = fold_probs_to_dataframe(per_fold_probs, checkpoint_paths)

        # How many folds' own argmax matches the ensemble's class.
        fold_labels = [ID_TO_LABEL[int(np.argmax(row))] for row in per_fold_probs]
        agreeing = fold_labels.count(top_label)
        total = len(fold_labels)

        summary_lines = [
            f"Ensemble prediction: {top_label} "
            f"(grade {top_grade}, probability {top_prob:.3f})",
            "",
            f"Probability grouping: {BUCKET_MODE_DESCRIPTIONS[bucket_mode]}",
            f"Grouped result: {group_label} (probability {group_prob:.3f})",
            "",
            f"Fold agreement with ensemble class: {agreeing / total:.1%} "
            f"({agreeing}/{total} folds)",
        ]
        return label_probs, bucket_df, fold_df, "\n".join(summary_lines)

    return predict
def make_app(
    checkpoint_dir,
    checkpoint_name,
    backbone,
    image_size,
    device,
    show_checkpoint_info=False,
):
    """Build the Gradio Blocks UI around the fold-ensemble DR classifier.

    Args:
        checkpoint_dir: folder searched for fold checkpoints.
        checkpoint_name: checkpoint filename to look for.
        backbone: fallback backbone when a checkpoint has no saved args.
        image_size: fallback input size when a checkpoint has no saved args.
        device: torch device the models are loaded onto.
        show_checkpoint_info: when True, also show the loaded fold folder
            names and, if any checkpoint stores one, each ``best_metric``.

    Returns:
        The ``gr.Blocks`` demo. It is not launched here.
    """
    models, transform, backbone, image_size, checkpoint_paths, ckpts = load_ensemble(
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=checkpoint_name,
        backbone=backbone,
        image_size=image_size,
        device=device,
    )
    predict_fn = make_prediction_fn(
        models=models,
        transform=transform,
        checkpoint_paths=checkpoint_paths,
        device=device,
    )

    model_info = (
        f"Backbone: `{backbone}` | "
        f"Input: `{image_size}×{image_size}` | "
        f"Folds loaded: `{len(models)}` | "
        f"Device: `{device}`"
    )

    # Optional extra header lines (previously dead, commented-out code).
    checkpoint_info = []
    if show_checkpoint_info:
        checkpoint_info.append(
            "Loaded checkpoints: " + ", ".join(p.parent.name for p in checkpoint_paths)
        )
        best_metrics = [ckpt.get("best_metric", None) for ckpt in ckpts]
        if any(m is not None for m in best_metrics):
            checkpoint_info.append(
                "Best metrics: "
                + ", ".join("NA" if m is None else f"{float(m):.4f}" for m in best_metrics)
            )

    with gr.Blocks(title="EyePACS DR Ensemble Classifier") as demo:
        gr.Markdown("# EyePACS DR Ensemble Classifier")
        gr.Markdown(model_info)
        for text in checkpoint_info:
            gr.Markdown(text)

        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(
                    type="pil",
                    label="Fundus image",
                )
                bucket_mode = gr.Dropdown(
                    # Same modes, same order, as the description table.
                    choices=list(BUCKET_MODE_DESCRIPTIONS),
                    value="referral",
                    label="Probability grouping",
                )
                run_btn = gr.Button("Run inference", variant="primary")

            with gr.Column(scale=1):
                raw_probs = gr.Label(
                    num_top_classes=5,
                    label="Ensemble raw 5-class probabilities",
                )
                bucket_table = gr.Dataframe(
                    label="Grouped probabilities",
                    headers=["label", "probability"],
                    row_count=(5, "fixed"),
                    col_count=(2, "fixed"),
                    interactive=False,
                    wrap=True,
                )
                with gr.Accordion("Per-fold predictions", open=False):
                    fold_table = gr.Dataframe(
                        label="Per-fold probabilities",
                        interactive=False,
                        wrap=True,
                    )
                summary = gr.Textbox(
                    label="Summary",
                    lines=7,
                    interactive=False,
                )

        inputs = [image, bucket_mode]
        outputs = [raw_probs, bucket_table, fold_table, summary]
        # Re-run inference on button click, on a new/cleared image and when
        # the grouping mode changes.
        for event in (run_btn.click, image.change, bucket_mode.change):
            event(fn=predict_fn, inputs=inputs, outputs=outputs)

    return demo
def parse_args(argv=None):
    """Parse command-line options for the Gradio ensemble app.

    Args:
        argv: argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with checkpoint_dir, checkpoint_name, backbone,
        image_size, host, port and share. ``host`` and ``port`` are None
        unless given, leaving the choice to the server's own defaults.
    """
    parser = argparse.ArgumentParser(
        description="EyePACS DR Gradio ensemble inference app."
    )
    parser.add_argument(
        "--checkpoint-dir",
        default="runs/eyepacs_dr",
        help="Folder containing fold checkpoint folders.",
    )
    parser.add_argument(
        "--checkpoint-name",
        default="best_model_only.pt",
        help="Checkpoint filename inside each fold folder.",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help="Fallback backbone if checkpoint args are unavailable.",
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=1024,
        help="Fallback image size if checkpoint args are unavailable.",
    )
    # None means "not specified": the app does not force binding to all
    # interfaces or a specific port unless explicitly asked to.
    parser.add_argument(
        "--host",
        default=None,
        help="Server bind address (e.g. 0.0.0.0); omit to use Gradio's default.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=None,
        help="Server port; omit to use Gradio's default.",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Request a public Gradio share link.",
    )
    return parser.parse_args(argv)
def main():
    """Parse CLI arguments, build the ensemble app and launch the server.

    ``--host`` and ``--port`` are forwarded to ``demo.launch`` as
    ``server_name`` / ``server_port``. A None value leaves the choice to
    Gradio. A share link is requested only when ``--share`` is given.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    demo = make_app(
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_name=args.checkpoint_name,
        backbone=args.backbone,
        image_size=args.image_size,
        device=device,
    )

    launch_kwargs = {
        "server_name": args.host,
        "server_port": args.port,
    }
    # Only pass ``share`` when explicitly requested, so Gradio's own default
    # for ``share`` is otherwise left untouched.
    if args.share:
        launch_kwargs["share"] = True
    demo.launch(**launch_kwargs)


# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    main()