#!/usr/bin/env python3
"""
Simple Gradio app for testing an EyeQ QC model.

Example
-------
python app_eyeq.py \
    --checkpoint ./checkpoints/eyeq_vit_base/best.pt

Then open the printed local URL in your browser.
"""

import argparse
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

import timm

# Fallback class-id -> label mapping, used only when the checkpoint does
# not carry its own `id_to_label` dictionary.
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}


def build_transform(img_size: int):
    """Return the eval-time preprocessing pipeline.

    Resizes to a square `img_size` x `img_size`, converts to a float
    tensor, and normalizes with the standard ImageNet mean/std.
    """
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])


def load_model(checkpoint_path: str, device: torch.device):
    """Build a timm classifier from a checkpoint and move it to `device`.

    The checkpoint is expected to contain at least a "model" state dict;
    it may optionally carry an "args" dict (with "model" and "img_size")
    and an "id_to_label" mapping, each of which falls back to sensible
    defaults when absent.

    Returns
    -------
    (model, tfm, id_to_label, model_name, img_size) where `model` is in
    eval mode on `device` and `tfm` is the matching preprocessing.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    args = ckpt.get("args", {})
    model_name = args.get("model", "vit_base_patch16_224")
    img_size = int(args.get("img_size", 224))

    id_to_label = ckpt.get("id_to_label", ID_TO_LABEL)
    # JSON round-trips can turn integer keys into strings; normalize back.
    id_to_label = {int(k): v for k, v in id_to_label.items()}

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=len(id_to_label),
    )
    model.load_state_dict(ckpt["model"], strict=True)
    model.to(device)
    model.eval()

    tfm = build_transform(img_size)
    return model, tfm, id_to_label, model_name, img_size


def get_eyeq_class_ids(id_to_label):
    """Return class IDs for Good, Usable, Reject.

    Falls back to the standard EyeQ ordering (0, 1, 2) if the checkpoint
    does not store string labels in the expected form.
    """
    # Case-insensitive label lookup so "good"/"Good"/"GOOD" all resolve.
    label_to_id = {str(v).lower(): int(k) for k, v in id_to_label.items()}
    good_id = label_to_id.get("good", 0)
    usable_id = label_to_id.get("usable", 1)
    reject_id = label_to_id.get("reject", 2)
    return good_id, usable_id, reject_id


def soft_eyeq_decision(probs, id_to_label, reject_threshold=0.60, reject_margin=0.15):
    """Apply a conservative Reject rule.

    Reject is only returned when:
      1. P(Reject) >= reject_threshold, and
      2. P(Reject) beats the best non-Reject class by reject_margin.
    Otherwise, the prediction is forced to Good vs Usable.

    Parameters
    ----------
    probs : indexable of class probabilities (e.g. 1-D numpy array).
    id_to_label : dict mapping class id -> label string.

    Returns
    -------
    (pred_id, pred_label, decision_text)
    """
    good_id, usable_id, reject_id = get_eyeq_class_ids(id_to_label)
    prob_good = float(probs[good_id])
    prob_usable = float(probs[usable_id])
    prob_reject = float(probs[reject_id])

    # Best competitor among the two non-Reject classes.
    best_non_reject_id = good_id if prob_good >= prob_usable else usable_id
    best_non_reject_prob = max(prob_good, prob_usable)

    if (
        prob_reject >= reject_threshold
        and (prob_reject - best_non_reject_prob) >= reject_margin
    ):
        pred_id = reject_id
        decision = "Soft rule: Reject threshold and margin were both satisfied."
    else:
        pred_id = best_non_reject_id
        decision = "Soft rule: Reject was not confident enough, so prediction was forced to Good/Usable."

    return pred_id, id_to_label[pred_id], decision


def update_margin_slider(reject_threshold, reject_margin):
    """Keep reject_margin within a sensible range for the current threshold."""
    # The margin can never usefully exceed the threshold itself (Reject
    # probability is at least `threshold` when the rule fires), nor 0.50.
    max_margin = min(0.50, float(reject_threshold))
    reject_margin = min(float(reject_margin), max_margin)
    return gr.update(
        maximum=max_margin,
        value=reject_margin,
    )


@torch.no_grad()
def predict_quality(
    image: Image.Image,
    model,
    tfm,
    id_to_label,
    device,
    reject_threshold=0.60,
    reject_margin=0.15,
):
    """Run one image through the QC model and apply the soft Reject rule.

    Returns
    -------
    (pred_label, prob_dict, detail_text); when `image` is None (e.g. the
    user cleared the upload widget), returns (None, {}, prompt_text).
    """
    if image is None:
        return None, {}, "Upload an image to run QC."

    image = image.convert("RGB")
    x = tfm(image).unsqueeze(0).to(device)
    logits = model(x)
    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()

    raw_pred_id = int(np.argmax(probs))
    raw_pred_label = id_to_label[raw_pred_id]

    soft_pred_id, soft_pred_label, decision = soft_eyeq_decision(
        probs=probs,
        id_to_label=id_to_label,
        reject_threshold=reject_threshold,
        reject_margin=reject_margin,
    )

    prob_dict = {
        id_to_label[i]: float(probs[i])
        for i in range(len(probs))
    }

    detail = (
        f"Raw argmax: {raw_pred_label}\n"
        f"Soft decision: {soft_pred_label}\n"
        f"Reject threshold: {reject_threshold:.2f} | Reject margin: {reject_margin:.2f}\n"
        f"{decision}"
    )
    return soft_pred_label, prob_dict, detail


def make_app(checkpoint_path: str):
    """Build the Gradio Blocks UI around a loaded checkpoint."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tfm, id_to_label, model_name, img_size = load_model(checkpoint_path, device)

    def run(image, reject_threshold, reject_margin):
        # Thin closure binding the loaded model/transform to the UI callback.
        pred_label, prob_dict, detail = predict_quality(
            image=image,
            model=model,
            tfm=tfm,
            id_to_label=id_to_label,
            device=device,
            reject_threshold=reject_threshold,
            reject_margin=reject_margin,
        )
        return pred_label, prob_dict, detail

    with gr.Blocks(title="EyeQ CFP Quality Control") as demo:
        gr.Markdown("# EyeQ CFP Quality Control")
        gr.Markdown(
            f"Model: `{model_name}`  \n"
            f"Input size: `{img_size} × {img_size}`  \n"
            f"Device: `{device}`  \n"
            f"Checkpoint: `{checkpoint_path}`"
        )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Input CFP",
                    type="pil",
                    height=520,
                )
                with gr.Accordion("Soft Reject rule", open=True):
                    reject_threshold = gr.Slider(
                        minimum=0.40,
                        maximum=0.95,
                        value=0.60,
                        step=0.01,
                        label="Reject threshold",
                        info="Minimum Reject probability required before an image can be called Reject.",
                    )
                    reject_margin = gr.Slider(
                        minimum=0.00,
                        maximum=0.50,
                        value=0.15,
                        step=0.01,
                        label="Reject margin",
                        info="Reject must beat both Good and Usable by at least this much.",
                    )
                run_button = gr.Button("Run QC", variant="primary")

            with gr.Column(scale=1):
                pred_output = gr.Label(label="Predicted quality")
                prob_output = gr.Label(label="Class probabilities", num_top_classes=3)
                decision_output = gr.Textbox(
                    label="Decision details",
                    lines=4,
                    interactive=False,
                )

        run_inputs = [image_input, reject_threshold, reject_margin]
        run_outputs = [pred_output, prob_output, decision_output]

        run_button.click(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
        # Re-run automatically on a new upload or when either slider moves.
        image_input.change(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
        # Changing the threshold first clamps the margin slider, then re-runs.
        reject_threshold.change(
            fn=update_margin_slider,
            inputs=[reject_threshold, reject_margin],
            outputs=reject_margin,
        ).then(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
        reject_margin.change(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )

    return demo


def parse_args():
    """Parse command-line options for the demo server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str,
                        default="./checkpoints/eyeq_vit_base/eyeq_deploy.pt")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def main():
    """Entry point: validate the checkpoint path and launch the UI."""
    args = parse_args()
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    demo = make_app(str(checkpoint_path))
    # BUGFIX: these were previously commented out, so --host/--port/--share
    # were parsed but silently ignored and launch() always used defaults.
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main()