# EyePACS / app.py
# Hou
# add src
# ebbe758
# Raw
# History Blame Contribute Delete
# 12.8 kB
# app.py
import argparse
from pathlib import Path
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn.functional as F
from augmentations import get_val_transforms
from model import DeepSeeNet
# Maps the integer grade (the class index used throughout this file) to its
# display label. The tuple position is the grade: 0, 1, 2, 3, 4.
ID_TO_LABEL = dict(
    enumerate(
        (
            "No DR",
            "Mild NPDR",
            "Moderate NPDR",
            "Severe NPDR",
            "PDR",
        )
    )
)
# Human-readable description for each probability-grouping mode. The keys are
# the mode identifiers accepted by get_grouped_probs and offered in the UI.
BUCKET_MODE_DESCRIPTIONS = dict(
    severity="5-class DR severity",
    referral="Non-referable vs referable DR",
    any_dr="No DR vs any DR",
    stage="No DR vs NPDR vs PDR",
    high_risk="Non-severe vs severe/PDR",
)
class AlbumentationsTransform:
    """Adapter that lets a keyword-style image pipeline be called positionally.

    The wrapped ``transform`` is invoked as ``transform(image=<ndarray>)`` and
    must return a mapping with an ``"image"`` entry; that entry is returned.
    """

    def __init__(self, transform):
        # Callable taking ``image=...`` and returning a dict-like result.
        self.transform = transform

    def __call__(self, image):
        """Convert ``image`` to a NumPy array, run the pipeline, return its ``"image"`` output."""
        pixels = np.asarray(image)
        result = self.transform(image=pixels)
        return result["image"]
def unwrap_logits(output):
    """Return ``output[0]`` when ``output`` is a tuple or list, else ``output`` unchanged.

    Lets callers treat models that return a sequence (presumably logits first,
    followed by auxiliary outputs) the same as models that return one value.
    """
    return output[0] if isinstance(output, (tuple, list)) else output
def find_fold_checkpoints(checkpoint_dir, checkpoint_name="best_model_only.pt"):
    """Locate checkpoint files named ``checkpoint_name`` under ``checkpoint_dir``.

    Two layouts are tried in order and the first one that yields any match wins:

    1. ``<checkpoint_dir>/*_fold*/<checkpoint_name>`` (one folder per fold);
    2. ``<checkpoint_dir>/**/<checkpoint_name>`` (recursive fallback, so a
       folder that directly contains checkpoints can be passed as well).

    Returns:
        A sorted list of ``pathlib.Path`` objects.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` does not exist, or if neither
            layout produces a match.
    """
    root = Path(checkpoint_dir)
    if not root.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {root}")

    search_patterns = (
        f"*_fold*/{checkpoint_name}",
        f"**/{checkpoint_name}",
    )
    for pattern in search_patterns:
        matches = sorted(root.glob(pattern))
        if matches:
            return matches

    raise FileNotFoundError(
        f"No fold checkpoints named '{checkpoint_name}' found under {root}"
    )
def load_single_model(checkpoint_path, backbone, image_size, device):
    """Load one checkpoint into a ``DeepSeeNet`` in eval mode on ``device``.

    The checkpoint is read as a mapping with a required ``"model"`` entry (the
    state dict) and an optional ``"args"`` mapping. ``"backbone"`` and
    ``"image_size"`` found in ``"args"`` override the fallback arguments.

    Returns:
        ``(model, backbone, image_size, ckpt)`` where ``backbone`` and
        ``image_size`` are the values actually used and ``ckpt`` is the loaded
        checkpoint mapping.
    """
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    ckpt = torch.load(Path(checkpoint_path), map_location=device)

    # Settings stored with the checkpoint take precedence over the fallbacks.
    stored_args = ckpt.get("args", {})
    resolved_backbone = stored_args.get("backbone", backbone)
    resolved_size = int(stored_args.get("image_size", image_size))

    net = DeepSeeNet(
        n_classes=5,
        backbone=resolved_backbone,
        pretrained=False,  # weights come from the checkpoint below
        freeze_backbone=False,
    )
    net.load_state_dict(ckpt["model"], strict=True)
    net.to(device)
    net.eval()

    return net, resolved_backbone, resolved_size, ckpt
def load_ensemble(checkpoint_dir, checkpoint_name, backbone, image_size, device):
    """Load every fold checkpoint found under ``checkpoint_dir``.

    Checkpoints are loaded in the order returned by ``find_fold_checkpoints``.
    The backbone and image size resolved for one fold are passed on as the
    fallbacks for the next fold. The values left after the last fold are
    returned and are used to build the validation transform.

    Returns:
        ``(models, transform, backbone, image_size, checkpoint_paths, ckpts)``.
    """
    checkpoint_paths = find_fold_checkpoints(
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=checkpoint_name,
    )

    models, ckpts = [], []
    active_backbone, active_size = backbone, image_size
    for fold_path in checkpoint_paths:
        loaded = load_single_model(
            checkpoint_path=fold_path,
            backbone=active_backbone,
            image_size=active_size,
            device=device,
        )
        fold_model, active_backbone, active_size, fold_ckpt = loaded
        models.append(fold_model)
        ckpts.append(fold_ckpt)

    # One shared preprocessing pipeline, sized from the last resolved image size.
    transform = AlbumentationsTransform(get_val_transforms(active_size))
    return models, transform, active_backbone, active_size, checkpoint_paths, ckpts
def get_grouped_probs(probs, mode):
    """Collapse five per-grade probabilities into the buckets selected by ``mode``.

    Args:
        probs: iterable of exactly five values, in grade order 0 through 4.
        mode: one of ``"severity"``, ``"referral"``, ``"any_dr"``, ``"stage"``
            or ``"high_risk"``.

    Returns:
        A dict mapping bucket label to a Python ``float``. Each bucket is the
        sum of the grades it covers.

    Raises:
        ValueError: if ``mode`` is not one of the supported names.
    """
    no_dr, mild, moderate, severe, pdr = probs

    if mode == "severity":
        buckets = {
            "No DR": no_dr,
            "Mild NPDR": mild,
            "Moderate NPDR": moderate,
            "Severe NPDR": severe,
            "PDR": pdr,
        }
    elif mode == "referral":
        buckets = {
            "Non-referable DR": no_dr + mild,
            "Referable DR": moderate + severe + pdr,
        }
    elif mode == "any_dr":
        buckets = {
            "No DR": no_dr,
            "Any DR": mild + moderate + severe + pdr,
        }
    elif mode == "stage":
        buckets = {
            "No DR": no_dr,
            "NPDR": mild + moderate + severe,
            "PDR": pdr,
        }
    elif mode == "high_risk":
        buckets = {
            "Non-severe / non-PDR": no_dr + mild + moderate,
            "Severe NPDR or PDR": severe + pdr,
        }
    else:
        raise ValueError(f"Unknown bucket mode: {mode}")

    # Convert once at the end so callers always get plain Python floats.
    return {label: float(total) for label, total in buckets.items()}
def grouped_probs_to_fixed_dataframe(prob_dict, max_rows=5):
    """Render grouped probabilities as a table with exactly ``max_rows`` rows.

    Probabilities are formatted to four decimals and rows are ordered by that
    formatted value, highest first. Ties keep the order of ``prob_dict``.
    Shorter tables are padded with blank rows and longer ones are truncated.

    Returns:
        A ``pandas.DataFrame`` with string columns ``label`` and ``probability``.
    """
    columns = ["label", "probability"]

    entries = []
    for name, value in prob_dict.items():
        entries.append({"label": name, "probability": f"{value:.4f}"})

    # Sort on the rounded text so the order matches what is displayed.
    entries.sort(key=lambda entry: float(entry["probability"]), reverse=True)

    missing = max_rows - len(entries)
    if missing > 0:
        entries.extend({"label": "", "probability": ""} for _ in range(missing))

    return pd.DataFrame(entries[:max_rows], columns=columns)
def fold_probs_to_dataframe(fold_probs, checkpoint_paths):
    """Build a per-fold table of the predicted class and the five class probabilities.

    Args:
        fold_probs: iterable of per-fold probability vectors (indices 0 to 4 are read).
        checkpoint_paths: checkpoint path per fold, index-aligned with
            ``fold_probs``; the parent folder name labels the fold.

    Returns:
        A ``pandas.DataFrame`` with one row per fold. All probabilities are
        formatted as four-decimal strings.
    """
    grade_columns = ("p0_no_dr", "p1_mild", "p2_moderate", "p3_severe", "p4_pdr")

    records = []
    for fold_idx, probs in enumerate(fold_probs):
        top_grade = int(np.argmax(probs))
        top_label = ID_TO_LABEL[top_grade]

        record = {
            "fold": Path(checkpoint_paths[fold_idx]).parent.name,
            "prediction": top_label,
            "probability": f"{float(probs[top_grade]):.4f}",
        }
        for grade, column in enumerate(grade_columns):
            record[column] = f"{float(probs[grade]):.4f}"
        records.append(record)

    return pd.DataFrame(records)
def ensemble_predict(models, x):
    """Run ``x`` through every model and average the resulting class probabilities.

    For each model the output is unwrapped with ``unwrap_logits``, softmaxed
    over ``dim=1``, and only index 0 of the result is kept.
    (Assumes ``x`` is a batch holding a single sample; confirm against callers.)

    Returns:
        ``(ensemble_probs, fold_probs)``: the mean over models, and the stacked
        per-model probabilities with one row per model, both as NumPy arrays.
    """
    with torch.no_grad():
        per_model = [
            F.softmax(unwrap_logits(net(x)), dim=1)[0].detach().cpu().numpy()
            for net in models
        ]

    stacked = np.stack(per_model, axis=0)
    return stacked.mean(axis=0), stacked
def make_prediction_fn(models, transform, checkpoint_paths, device):
    """Create the ``predict(image, bucket_mode)`` callback used by the UI.

    The returned function closes over the loaded ``models``, the preprocessing
    ``transform``, the ``checkpoint_paths`` (for fold names) and ``device``.
    """

    def predict(image, bucket_mode):
        """Run ensemble inference on one image.

        Returns a 4-tuple: raw 5-class probabilities (label to float), the
        grouped-probability table, the per-fold table, and a text summary.
        When ``image`` is ``None`` the tuple holds ``None``, blank tables and
        a prompt message instead.
        """
        if image is None:
            # Nothing uploaded yet: return blank outputs with the usual table shapes.
            blank_buckets = pd.DataFrame(
                [{"label": "", "probability": ""} for _ in range(5)],
                columns=["label", "probability"],
            )
            blank_folds = pd.DataFrame(
                columns=[
                    "fold",
                    "prediction",
                    "probability",
                    "p0_no_dr",
                    "p1_mild",
                    "p2_moderate",
                    "p3_severe",
                    "p4_pdr",
                ]
            )
            return (
                None,
                blank_buckets,
                blank_folds,
                "Upload a fundus image to run ensemble inference.",
            )

        # Accept either a PIL image or an array-like that PIL can wrap.
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        rgb_image = pil_image.convert("RGB")
        batch = transform(rgb_image).unsqueeze(0).to(device)

        probs, fold_probs = ensemble_predict(models, batch)

        pred_grade = int(np.argmax(probs))
        pred_label = ID_TO_LABEL[pred_grade]
        pred_prob = float(probs[pred_grade])
        raw_probs = {label: float(probs[grade]) for grade, label in ID_TO_LABEL.items()}

        grouped = get_grouped_probs(probs, bucket_mode)
        bucket_label = max(grouped, key=grouped.get)
        bucket_prob = grouped[bucket_label]
        bucket_df = grouped_probs_to_fixed_dataframe(grouped, max_rows=5)
        fold_df = fold_probs_to_dataframe(fold_probs, checkpoint_paths)

        # How many individual folds vote for the same class as the ensemble mean.
        fold_votes = [ID_TO_LABEL[int(np.argmax(p))] for p in fold_probs]
        n_agree = fold_votes.count(pred_label)
        n_folds = len(fold_votes)
        agreement = n_agree / n_folds

        summary = (
            f"Ensemble prediction: {pred_label} "
            f"(grade {pred_grade}, probability {pred_prob:.3f})\n\n"
            f"Probability grouping: {BUCKET_MODE_DESCRIPTIONS[bucket_mode]}\n"
            f"Grouped result: {bucket_label} "
            f"(probability {bucket_prob:.3f})\n\n"
            f"Fold agreement with ensemble class: {agreement:.1%} "
            f"({n_agree}/{n_folds} folds)"
        )
        return raw_probs, bucket_df, fold_df, summary

    return predict
def make_app(checkpoint_dir, checkpoint_name, backbone, image_size, device):
    """Load the fold ensemble and build the Gradio UI around it.

    Args:
        checkpoint_dir: folder searched for fold checkpoints.
        checkpoint_name: checkpoint filename looked up inside each fold folder.
        backbone: fallback backbone name when a checkpoint stores none.
        image_size: fallback input size when a checkpoint stores none.
        device: torch device the models run on.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    (
        models,
        transform,
        backbone,
        image_size,
        checkpoint_paths,
        _ckpts,  # loaded checkpoint mappings; currently not shown in the UI
    ) = load_ensemble(
        checkpoint_dir=checkpoint_dir,
        checkpoint_name=checkpoint_name,
        backbone=backbone,
        image_size=image_size,
        device=device,
    )
    predict_fn = make_prediction_fn(
        models=models,
        transform=transform,
        checkpoint_paths=checkpoint_paths,
        device=device,
    )

    # NOTE: fold folder names (``path.parent.name``) and each checkpoint's
    # optional ``"best_metric"`` entry are available here if the header should
    # list them; they are deliberately not displayed at the moment.
    model_info = (
        f"Backbone: `{backbone}` | "
        f"Input: `{image_size}×{image_size}` | "
        f"Folds loaded: `{len(models)}` | "
        f"Device: `{device}`"
    )

    with gr.Blocks(title="EyePACS DR Ensemble Classifier") as demo:
        gr.Markdown("# EyePACS DR Ensemble Classifier")
        gr.Markdown(model_info)

        # NOTE(review): the nesting below was reconstructed from a source whose
        # indentation was lost; confirm that the accordion and summary box
        # belong in the right-hand column rather than below the row.
        with gr.Row():
            with gr.Column(scale=1):
                image = gr.Image(
                    type="pil",
                    label="Fundus image",
                )
                bucket_mode = gr.Dropdown(
                    # Derived from the shared table so the choices cannot
                    # drift from the modes that have descriptions.
                    choices=list(BUCKET_MODE_DESCRIPTIONS),
                    value="referral",
                    label="Probability grouping",
                )
                run_btn = gr.Button("Run inference", variant="primary")

            with gr.Column(scale=1):
                raw_probs = gr.Label(
                    num_top_classes=5,
                    label="Ensemble raw 5-class probabilities",
                )
                bucket_table = gr.Dataframe(
                    label="Grouped probabilities",
                    headers=["label", "probability"],
                    row_count=(5, "fixed"),
                    col_count=(2, "fixed"),
                    interactive=False,
                    wrap=True,
                )
                with gr.Accordion("Per-fold predictions", open=False):
                    fold_table = gr.Dataframe(
                        label="Per-fold probabilities",
                        interactive=False,
                        wrap=True,
                    )
                summary = gr.Textbox(
                    label="Summary",
                    lines=7,
                    interactive=False,
                )

        inputs = [image, bucket_mode]
        outputs = [raw_probs, bucket_table, fold_table, summary]

        # The button, a new image, and a new grouping all re-run the same
        # prediction with the same inputs and outputs.
        for register in (run_btn.click, image.change, bucket_mode.change):
            register(
                fn=predict_fn,
                inputs=inputs,
                outputs=outputs,
            )

    return demo
def parse_args():
    """Parse the command-line options of the inference app from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint_dir``, ``checkpoint_name``,
        ``backbone``, ``image_size``, ``host``, ``port`` and ``share``.
    """
    parser = argparse.ArgumentParser(
        description="EyePACS DR Gradio ensemble inference app."
    )
    add = parser.add_argument

    # Where to find the fold checkpoints.
    add(
        "--checkpoint-dir",
        default="runs/eyepacs_dr",
        help="Folder containing fold checkpoint folders.",
    )
    add(
        "--checkpoint-name",
        default="best_model_only.pt",
        help="Checkpoint filename inside each fold folder.",
    )

    # Fallbacks used only when a checkpoint does not store its own settings.
    add(
        "--backbone",
        default="inception_v3",
        help="Fallback backbone if checkpoint args are unavailable.",
    )
    add(
        "--image-size",
        type=int,
        default=1024,
        help="Fallback image size if checkpoint args are unavailable.",
    )

    # Server options.
    add("--host", default="0.0.0.0")
    add("--port", type=int, default=7860)
    add("--share", action="store_true")

    return parser.parse_args()
def main():
    """Entry point: parse options, build the app on the best available device, launch it."""
    cli = parse_args()

    # Use the GPU when one is available, otherwise run on the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    app = make_app(
        checkpoint_dir=cli.checkpoint_dir,
        checkpoint_name=cli.checkpoint_name,
        backbone=cli.backbone,
        image_size=cli.image_size,
        device=device,
    )

    # NOTE(review): --host, --port and --share are parsed but not forwarded
    # here, so they currently have no effect and Gradio's launch defaults
    # apply. Confirm whether that is intentional before wiring them in.
    app.launch(
        # server_name=cli.host,
        # server_port=cli.port,
        # share=cli.share,
    )


if __name__ == "__main__":
    main()