Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Simple Gradio app for testing an EyeQ QC model. | |
| Example | |
| ------- | |
| python app_eyeq.py \ | |
| --checkpoint ./checkpoints/eyeq_vit_base/best.pt | |
| Then open the printed local URL in your browser. | |
| """ | |
| import argparse | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| import timm | |
# Fallback class-id -> label mapping (standard EyeQ ordering), used when the
# checkpoint does not carry its own `id_to_label` dictionary.
ID_TO_LABEL = {0: "Good", 1: "Usable", 2: "Reject"}
def build_transform(img_size: int):
    """Build the evaluation-time preprocessing pipeline.

    Resizes to a square ``img_size`` x ``img_size``, converts to a tensor,
    and applies the standard ImageNet channel normalization.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
def load_model(checkpoint_path: str, device: torch.device):
    """Restore a timm classifier from a training checkpoint.

    Reads the model name, input size, and label mapping stored in the
    checkpoint (falling back to ``vit_base_patch16_224`` / 224 / the module
    default ``ID_TO_LABEL``), then rebuilds the network and loads weights.

    Returns
    -------
    (model, transform, id_to_label, model_name, img_size) with the model on
    ``device`` and in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects by default —
    # only load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location="cpu")
    train_args = state.get("args", {})
    model_name = train_args.get("model", "vit_base_patch16_224")
    img_size = int(train_args.get("img_size", 224))

    # Keys may have round-tripped through JSON as strings; normalize to int.
    raw_labels = state.get("id_to_label", ID_TO_LABEL)
    id_to_label = {int(class_id): name for class_id, name in raw_labels.items()}

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=len(id_to_label),
    )
    model.load_state_dict(state["model"], strict=True)
    model.to(device)
    model.eval()
    return model, build_transform(img_size), id_to_label, model_name, img_size
def get_eyeq_class_ids(id_to_label):
    """Map the checkpoint's label dictionary to ``(good_id, usable_id, reject_id)``.

    Label matching is case-insensitive; any class whose label text is not
    found falls back to the conventional EyeQ ordering (Good=0, Usable=1,
    Reject=2).
    """
    lookup = {str(name).lower(): int(class_id) for class_id, name in id_to_label.items()}
    defaults = (("good", 0), ("usable", 1), ("reject", 2))
    return tuple(lookup.get(name, fallback) for name, fallback in defaults)
def soft_eyeq_decision(probs, id_to_label, reject_threshold=0.60, reject_margin=0.15):
    """Post-process class probabilities with a conservative Reject rule.

    The prediction may only be Reject when BOTH conditions hold:
      1. ``P(Reject) >= reject_threshold``, and
      2. ``P(Reject)`` exceeds the best of {Good, Usable} by at least
         ``reject_margin``.
    Otherwise the decision is forced to whichever of Good/Usable is more
    likely (ties go to Good).

    Returns ``(pred_id, pred_label, decision_text)``.
    """
    good_id, usable_id, reject_id = get_eyeq_class_ids(id_to_label)
    p_good = float(probs[good_id])
    p_usable = float(probs[usable_id])
    p_reject = float(probs[reject_id])

    # Best competitor among the two acceptable classes; >= keeps the
    # original tie-break toward Good.
    if p_good >= p_usable:
        best_id, best_prob = good_id, p_good
    else:
        best_id, best_prob = usable_id, p_usable

    confident_reject = (
        p_reject >= reject_threshold
        and (p_reject - best_prob) >= reject_margin
    )
    if confident_reject:
        pred_id = reject_id
        decision = "Soft rule: Reject threshold and margin were both satisfied."
    else:
        pred_id = best_id
        decision = "Soft rule: Reject was not confident enough, so prediction was forced to Good/Usable."
    return pred_id, id_to_label[pred_id], decision
def update_margin_slider(reject_threshold, reject_margin):
    """Clamp the Reject-margin slider to the current threshold.

    A margin larger than the threshold can never be satisfied, so the
    slider's maximum (capped at 0.50, its full range) and its current value
    are both shrunk whenever the threshold drops below them.
    """
    cap = min(0.50, float(reject_threshold))
    clamped_margin = min(float(reject_margin), cap)
    return gr.update(
        maximum=cap,
        value=clamped_margin,
    )
def predict_quality(
    image: Image.Image,
    model,
    tfm,
    id_to_label,
    device,
    reject_threshold=0.60,
    reject_margin=0.15,
):
    """Run the QC model on one image and apply the soft Reject rule.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Input fundus photo; ``None`` (e.g. a cleared upload widget)
        short-circuits with a prompt message.
    model, tfm, id_to_label, device
        Artifacts produced by ``load_model``.
    reject_threshold, reject_margin : float
        Knobs for ``soft_eyeq_decision``.

    Returns
    -------
    (pred_label, prob_dict, detail_text)
        Soft-rule label, per-class probability dict, and a human-readable
        summary of both the raw argmax and the soft decision.
    """
    if image is None:
        return None, {}, "Upload an image to run QC."
    image = image.convert("RGB")
    x = tfm(image).unsqueeze(0).to(device)
    # Fix: run the forward pass without building an autograd graph — the
    # original tracked gradients on every inference call, wasting memory.
    with torch.no_grad():
        logits = model(x)
    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()
    raw_pred_id = int(np.argmax(probs))
    raw_pred_label = id_to_label[raw_pred_id]
    soft_pred_id, soft_pred_label, decision = soft_eyeq_decision(
        probs=probs,
        id_to_label=id_to_label,
        reject_threshold=reject_threshold,
        reject_margin=reject_margin,
    )
    prob_dict = {
        id_to_label[i]: float(probs[i])
        for i in range(len(probs))
    }
    detail = (
        f"Raw argmax: {raw_pred_label}\n"
        f"Soft decision: {soft_pred_label}\n"
        f"Reject threshold: {reject_threshold:.2f} | Reject margin: {reject_margin:.2f}\n"
        f"{decision}"
    )
    return soft_pred_label, prob_dict, detail
def make_app(checkpoint_path: str):
    """Build the Gradio Blocks UI around a loaded EyeQ QC checkpoint.

    The model is loaded once at startup; the same inference callback is
    wired to the Run button, to image changes, and to both sliders so the
    outputs stay in sync with every control.
    """
    # Prefer GPU when available; predict_quality moves inputs to this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tfm, id_to_label, model_name, img_size = load_model(checkpoint_path, device)

    def run(image, reject_threshold, reject_margin):
        # Thin closure: Gradio passes only UI values; the model artifacts
        # are captured from the enclosing scope.
        pred_label, prob_dict, detail = predict_quality(
            image=image,
            model=model,
            tfm=tfm,
            id_to_label=id_to_label,
            device=device,
            reject_threshold=reject_threshold,
            reject_margin=reject_margin,
        )
        return pred_label, prob_dict, detail

    with gr.Blocks(title="EyeQ CFP Quality Control") as demo:
        gr.Markdown("# EyeQ CFP Quality Control")
        # Static run summary (model/device/checkpoint) shown above the controls.
        gr.Markdown(
            f"Model: `{model_name}` \n"
            f"Input size: `{img_size} × {img_size}` \n"
            f"Device: `{device}` \n"
            f"Checkpoint: `{checkpoint_path}`"
        )
        with gr.Row():
            # Left column: image input, soft-rule controls, Run button.
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Input CFP",
                    type="pil",
                    height=520,
                )
                with gr.Accordion("Soft Reject rule", open=True):
                    reject_threshold = gr.Slider(
                        minimum=0.40,
                        maximum=0.95,
                        value=0.60,
                        step=0.01,
                        label="Reject threshold",
                        info="Minimum Reject probability required before an image can be called Reject.",
                    )
                    reject_margin = gr.Slider(
                        minimum=0.00,
                        maximum=0.50,
                        value=0.15,
                        step=0.01,
                        label="Reject margin",
                        info="Reject must beat both Good and Usable by at least this much.",
                    )
                run_button = gr.Button("Run QC", variant="primary")
            # Right column: prediction, probabilities, and decision trace.
            with gr.Column(scale=1):
                pred_output = gr.Label(label="Predicted quality")
                prob_output = gr.Label(label="Class probabilities", num_top_classes=3)
                decision_output = gr.Textbox(
                    label="Decision details",
                    lines=4,
                    interactive=False,
                )
        run_inputs = [image_input, reject_threshold, reject_margin]
        run_outputs = [pred_output, prob_output, decision_output]
        # Explicit Run button.
        run_button.click(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
        # Re-run automatically when a new image is uploaded (or cleared).
        image_input.change(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
        # Moving the threshold first clamps the margin slider, then re-runs
        # inference with the updated values.
        reject_threshold.change(
            fn=update_margin_slider,
            inputs=[reject_threshold, reject_margin],
            outputs=reject_margin,
        ).then(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
        reject_margin.change(
            fn=run,
            inputs=run_inputs,
            outputs=run_outputs,
        )
    return demo
def parse_args():
    """Parse the demo's command-line options.

    Defaults point at the standard deploy checkpoint and the usual Gradio
    host/port (0.0.0.0:7860); ``--share`` is an opt-in flag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="./checkpoints/eyeq_vit_base/eyeq_deploy.pt",
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()
def main():
    """CLI entry point: validate the checkpoint path and launch the UI.

    Raises
    ------
    FileNotFoundError
        If ``--checkpoint`` does not point at an existing file.
    """
    args = parse_args()
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    demo = make_app(str(checkpoint_path))
    # Fix: --host/--port/--share were parsed but the corresponding launch()
    # kwargs were commented out, so the CLI flags silently did nothing.
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )
| if __name__ == "__main__": | |
| main() | |