EyeQ / app.py
farrell236's picture
add src
d0344ce
#!/usr/bin/env python3
"""
Simple Gradio app for testing an EyeQ QC model.
Example
-------
python app.py \
--checkpoint ./checkpoints/eyeq_vit_base/best.pt
Then open the printed local URL in your browser.
"""
import argparse
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import timm
# Default class-index -> label mapping, used by load_model when the checkpoint
# does not provide its own "id_to_label" entry.
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
def build_transform(img_size: int):
    """Build the preprocessing pipeline applied to each input image.

    The image is resized to a square of ``img_size`` pixels per side,
    converted to a tensor, and normalized per channel. The mean/std constants
    match the widely used ImageNet channel statistics.
    """
    target_hw = (img_size, img_size)
    steps = [
        transforms.Resize(target_hw),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
def load_model(checkpoint_path: str, device: torch.device):
    """Load the classifier and its preprocessing from a checkpoint file.

    The checkpoint is read as a mapping with these entries:

    * ``"model"`` (required): state dict loaded with ``strict=True``.
    * ``"args"`` (optional): training settings; only ``model`` (timm
      architecture name) and ``img_size`` are read. May be a dict-like object
      or an attribute container such as ``argparse.Namespace``.
    * ``"id_to_label"`` (optional): class-id -> label mapping; keys are
      coerced to ``int``. Falls back to ``ID_TO_LABEL`` when absent or empty.

    Returns:
        ``(model, tfm, id_to_label, model_name, img_size)`` with the model
        moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: if the checkpoint has no ``"model"`` entry.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu")

    args = ckpt.get("args") or {}
    if not hasattr(args, "get"):
        # Attribute-style containers (e.g. argparse.Namespace) have no .get();
        # expose their attributes as a plain dict instead of crashing.
        args = vars(args)
    model_name = args.get("model", "vit_base_patch16_224")
    img_size = int(args.get("img_size", 224))

    # A missing, None, or empty mapping falls back to the module default.
    stored_labels = ckpt.get("id_to_label") or ID_TO_LABEL
    id_to_label = {int(k): v for k, v in stored_labels.items()}

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=len(id_to_label),
    )
    model.load_state_dict(ckpt["model"], strict=True)
    model.to(device)
    model.eval()

    tfm = build_transform(img_size)
    return model, tfm, id_to_label, model_name, img_size
def get_eyeq_class_ids(id_to_label):
    """Look up the class IDs of the Good, Usable and Reject labels.

    Labels are matched case-insensitively on their string form. Any label
    that is not found falls back to its position in the standard EyeQ
    ordering (Good=0, Usable=1, Reject=2).

    Returns:
        A ``(good_id, usable_id, reject_id)`` tuple of ints.
    """
    name_to_index = {}
    for class_id, label in id_to_label.items():
        name_to_index[str(label).lower()] = int(class_id)

    # (label name, fallback id) in the standard EyeQ ordering.
    standard_order = (("good", 0), ("usable", 1), ("reject", 2))
    good_id, usable_id, reject_id = (
        name_to_index.get(name, fallback) for name, fallback in standard_order
    )
    return good_id, usable_id, reject_id
def soft_eyeq_decision(probs, id_to_label, reject_threshold=0.60, reject_margin=0.15):
    """Pick a class using a conservative rule for Reject.

    Reject is chosen only when both conditions hold:
      1. P(Reject) >= reject_threshold
      2. P(Reject) exceeds the larger of P(Good) / P(Usable) by at least
         reject_margin

    In every other case the more probable of Good and Usable is returned
    (Good wins ties).

    Returns:
        ``(pred_id, pred_label, decision)`` where ``decision`` is a short
        human-readable explanation of which branch was taken.
    """
    good_id, usable_id, reject_id = get_eyeq_class_ids(id_to_label)
    p_good = float(probs[good_id])
    p_usable = float(probs[usable_id])
    p_reject = float(probs[reject_id])

    # Best alternative to Reject; ties go to Good.
    fallback_id = good_id if p_good >= p_usable else usable_id
    fallback_prob = max(p_good, p_usable)

    reject_is_confident = (
        p_reject >= reject_threshold
        and (p_reject - fallback_prob) >= reject_margin
    )
    if not reject_is_confident:
        return (
            fallback_id,
            id_to_label[fallback_id],
            "Soft rule: Reject was not confident enough, so prediction was forced to Good/Usable.",
        )

    return (
        reject_id,
        id_to_label[reject_id],
        "Soft rule: Reject threshold and margin were both satisfied.",
    )
def update_margin_slider(reject_threshold, reject_margin):
    """Clamp the margin slider so it never exceeds the current threshold.

    The slider maximum becomes the smaller of 0.50 and the threshold, and the
    current margin value is lowered to that maximum if it is above it.

    Returns:
        A ``gr.update`` carrying the new ``maximum`` and ``value``.
    """
    threshold = float(reject_threshold)
    margin = float(reject_margin)

    # Upper bound for the margin: the threshold itself, capped at 0.50.
    ceiling = threshold if threshold < 0.50 else 0.50
    clamped = ceiling if ceiling < margin else margin

    return gr.update(maximum=ceiling, value=clamped)
@torch.no_grad()
def predict_quality(
    image: Image.Image,
    model,
    tfm,
    id_to_label,
    device,
    reject_threshold=0.60,
    reject_margin=0.15,
):
    """Classify one image and summarize the result for display.

    The image is converted to RGB, preprocessed with ``tfm``, and passed
    through ``model`` as a batch of one. Softmax probabilities are then fed
    to ``soft_eyeq_decision`` to obtain the final label.

    Returns:
        ``(label, prob_dict, detail)`` where ``prob_dict`` maps every class
        label to its probability and ``detail`` is a multi-line text report.
        When ``image`` is ``None`` the result is ``(None, {}, <prompt text>)``.
    """
    if image is None:
        return None, {}, "Upload an image to run QC."

    # Add a leading batch dimension before moving the tensor to the device.
    batch = tfm(image.convert("RGB")).unsqueeze(0).to(device)
    probs = torch.softmax(model(batch), dim=1)[0].detach().cpu().numpy()

    # Plain argmax is reported alongside the soft decision for comparison.
    raw_pred_label = id_to_label[int(np.argmax(probs))]

    _, soft_pred_label, decision = soft_eyeq_decision(
        probs=probs,
        id_to_label=id_to_label,
        reject_threshold=reject_threshold,
        reject_margin=reject_margin,
    )

    prob_dict = {id_to_label[idx]: float(p) for idx, p in enumerate(probs)}

    detail = (
        f"Raw argmax: {raw_pred_label}\n"
        f"Soft decision: {soft_pred_label}\n"
        f"Reject threshold: {reject_threshold:.2f} | Reject margin: {reject_margin:.2f}\n"
        f"{decision}"
    )
    return soft_pred_label, prob_dict, detail
def make_app(checkpoint_path: str):
    """Build the Gradio Blocks UI around a model loaded from ``checkpoint_path``.

    The model is loaded once (on CUDA when available, otherwise CPU) and
    captured by the ``run`` callback. Prediction is re-run when the button is
    clicked, when the image changes, and when either slider changes; a
    threshold change first re-clamps the margin slider.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tfm, id_to_label, model_name, img_size = load_model(checkpoint_path, device)

    def run(image, reject_threshold, reject_margin):
        # Bind the loaded model/preprocessing; only UI values vary per call.
        return predict_quality(
            image=image,
            model=model,
            tfm=tfm,
            id_to_label=id_to_label,
            device=device,
            reject_threshold=reject_threshold,
            reject_margin=reject_margin,
        )

    with gr.Blocks(title="EyeQ CFP Quality Control") as demo:
        gr.Markdown("# EyeQ CFP Quality Control")
        gr.Markdown(
            f"Model: `{model_name}` \n"
            f"Input size: `{img_size} × {img_size}` \n"
            f"Device: `{device}` \n"
            f"Checkpoint: `{checkpoint_path}`"
        )

        with gr.Row():
            # Left column: image upload, rule settings, and the run button.
            with gr.Column(scale=1):
                image_input = gr.Image(label="Input CFP", type="pil", height=520)

                with gr.Accordion("Soft Reject rule", open=True):
                    reject_threshold = gr.Slider(
                        minimum=0.40,
                        maximum=0.95,
                        value=0.60,
                        step=0.01,
                        label="Reject threshold",
                        info="Minimum Reject probability required before an image can be called Reject.",
                    )
                    reject_margin = gr.Slider(
                        minimum=0.00,
                        maximum=0.50,
                        value=0.15,
                        step=0.01,
                        label="Reject margin",
                        info="Reject must beat both Good and Usable by at least this much.",
                    )

                run_button = gr.Button("Run QC", variant="primary")

            # Right column: prediction, per-class probabilities, explanation.
            with gr.Column(scale=1):
                pred_output = gr.Label(label="Predicted quality")
                prob_output = gr.Label(label="Class probabilities", num_top_classes=3)
                decision_output = gr.Textbox(
                    label="Decision details",
                    lines=4,
                    interactive=False,
                )

        run_inputs = [image_input, reject_threshold, reject_margin]
        run_outputs = [pred_output, prob_output, decision_output]
        rerun = dict(fn=run, inputs=run_inputs, outputs=run_outputs)

        # Registration order is kept: button, image, threshold, margin.
        for register in (run_button.click, image_input.change):
            register(**rerun)

        # Re-clamp the margin slider first, then predict with the new values.
        reject_threshold.change(
            fn=update_margin_slider,
            inputs=[reject_threshold, reject_margin],
            outputs=reject_margin,
        ).then(**rerun)

        reject_margin.change(**rerun)

    return demo
def parse_args():
    """Parse the command-line options of the app.

    Options: ``--checkpoint`` (path to the model checkpoint), ``--host``,
    ``--port`` and the boolean flag ``--share``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # (flag, value type, default) for every option that takes a value.
    valued_options = (
        ("--checkpoint", str, "./checkpoints/eyeq_vit_base/eyeq_deploy.pt"),
        ("--host", str, "0.0.0.0"),
        ("--port", int, 7860),
    )
    for flag, value_type, fallback in valued_options:
        parser.add_argument(flag, type=value_type, default=fallback)

    parser.add_argument("--share", action="store_true")
    return parser.parse_args()
def main():
    """Script entry point: validate the checkpoint path and launch the app.

    Raises:
        FileNotFoundError: if the ``--checkpoint`` path does not exist.
    """
    cli = parse_args()

    ckpt_file = Path(cli.checkpoint)
    if not ckpt_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

    # NOTE(review): --host, --port and --share are parsed but not forwarded
    # here (server_name / server_port / share are not passed), so launch()
    # runs with Gradio's own defaults. Confirm whether that is intended.
    make_app(str(ckpt_file)).launch()


if __name__ == "__main__":
    main()