"""Gradio demo app for retina blood-vessel segmentation.

Loads a trained segmentation checkpoint and serves an interactive UI where a
color fundus photograph (CFP) is segmented into a vessel probability map,
a thresholded binary mask, and a red overlay on the input image.
"""

import argparse
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image

from augmentations import IMAGENET_MEAN, IMAGENET_STD
from models import build_model

# Mutable module-level state shared between CLI startup and Gradio callbacks
# (model, device, image size, and display metadata).
APP_STATE = {}


def load_model(args, device):
    """Build the model, load checkpoint weights, and return it in eval mode on *device*.

    Args:
        args: Parsed CLI namespace (model, image_size, backbone, base_channels,
            dropout, checkpoint).
        device: ``torch.device`` to place the model on.

    Returns:
        The loaded model, in eval mode.
    """
    model = build_model(
        model_name=args.model,
        num_classes=1,
        in_channels=3,
        image_size=args.image_size,
        backbone=args.backbone,
        pretrained=False,
        base_channels=args.base_channels,
        dropout=args.dropout,
    )
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    # Accept both full training checkpoints ({"model_state_dict": ...}) and
    # bare state dicts.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    return model


def preprocess_image(image, image_size):
    """Convert an input image into a normalized model tensor.

    Handles PIL images, grayscale arrays, and RGBA arrays. The image is
    resized to ``image_size`` x ``image_size`` and normalized with ImageNet
    statistics.

    Returns:
        Tuple of (1x3xHxW float tensor, original RGB uint8-like array).
    """
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
    else:
        image = np.array(image)
    if image.ndim == 2:
        # Grayscale -> 3-channel RGB.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.shape[-1] == 4:
        # Drop the alpha channel.
        image = image[..., :3]
    original_rgb = image.copy()
    resized = cv2.resize(
        image,
        (image_size, image_size),
        interpolation=cv2.INTER_LINEAR,
    )
    resized = resized.astype(np.float32) / 255.0
    mean = np.array(IMAGENET_MEAN, dtype=np.float32).reshape(1, 1, 3)
    std = np.array(IMAGENET_STD, dtype=np.float32).reshape(1, 1, 3)
    resized = (resized - mean) / std
    # HWC -> 1xCxHxW for the model.
    tensor = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0).float()
    return tensor, original_rgb


def overlay_mask(image_rgb, mask, alpha=0.45):
    """Blend a red layer onto *image_rgb* where *mask* is positive.

    Args:
        image_rgb: HxWx3 RGB image.
        mask: HxW mask in [0, 1]; nonzero pixels are tinted red.
        alpha: Blend strength of the red tint.

    Returns:
        HxWx3 uint8 overlay image.
    """
    image_rgb = image_rgb.astype(np.uint8)
    red = np.zeros_like(image_rgb)
    red[..., 0] = 255
    mask_3ch = mask[..., None]
    overlay = image_rgb * (1 - alpha * mask_3ch) + red * (alpha * mask_3ch)
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)
    return overlay


def run_inference(image, threshold):
    """Run the loaded model on *image* and return (original, prob map, binary mask).

    The probability map is resized back to the original image resolution
    before thresholding.
    """
    tensor, original_rgb = preprocess_image(
        image=image,
        image_size=APP_STATE["image_size"],
    )
    tensor = tensor.to(APP_STATE["device"])
    with torch.no_grad():
        logits = APP_STATE["model"](tensor)
        probs = torch.sigmoid(logits)
    # .detach() is unnecessary inside no_grad; take the single-class map.
    prob_map = probs[0, 0].cpu().numpy()
    original_h, original_w = original_rgb.shape[:2]
    prob_map = cv2.resize(
        prob_map,
        (original_w, original_h),
        interpolation=cv2.INTER_LINEAR,
    )
    pred_mask = (prob_map >= threshold).astype(np.float32)
    return original_rgb, prob_map, pred_mask


def predict(image, threshold, alpha):
    """Gradio callback: produce (overlay, probability visualization, binary mask)."""
    if image is None:
        return None, None, None
    original_rgb, prob_map, pred_mask = run_inference(image, threshold)
    overlay = overlay_mask(original_rgb, pred_mask, alpha=alpha)
    prob_vis = (prob_map * 255).clip(0, 255).astype(np.uint8)
    mask_vis = (pred_mask * 255).astype(np.uint8)
    return overlay, prob_vis, mask_vis


def build_app():
    """Construct and return the Gradio Blocks UI (reads APP_STATE for metadata)."""
    css = """
    #input_image { height: 430px !important; }
    #input_image img { object-fit: contain !important; max-height: 430px !important; }
    #overlay_output { height: 200px !important; }
    #overlay_output img { object-fit: contain !important; max-height: 200px !important; }
    #prob_output { height: 200px !important; }
    #prob_output img { object-fit: contain !important; max-height: 200px !important; }
    #mask_output { height: 430px !important; }
    #mask_output img { object-fit: contain !important; max-height: 430px !important; }
    """
    with gr.Blocks(title="Retina Vessel Segmentation", css=css) as demo:
        gr.Markdown("# Retina Vessel Segmentation")
        gr.Markdown(
            f"Model: `{APP_STATE['model_name']}` | "
            f"Backbone: `{APP_STATE['backbone']}` | "
            f"Image size: `{APP_STATE['image_size']}`"
        )
        with gr.Row(equal_height=False):
            # Gradio's `scale` must be an integer; 5:6 preserves the intended
            # 1:1.2 width ratio between the two columns.
            with gr.Column(scale=5):
                input_image = gr.Image(
                    type="pil",
                    label="Input CFP Image",
                    elem_id="input_image",
                    height=430,
                )
                threshold = gr.Slider(
                    minimum=0.05,
                    maximum=0.95,
                    value=0.5,
                    step=0.05,
                    label="Prediction Threshold",
                )
                alpha = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.45,
                    step=0.05,
                    label="Overlay Alpha",
                )
                run_button = gr.Button("Segment")
            with gr.Column(scale=6):
                with gr.Row():
                    overlay_output = gr.Image(
                        type="numpy",
                        label="Overlay",
                        elem_id="overlay_output",
                        height=200,
                    )
                    prob_output = gr.Image(
                        type="numpy",
                        label="Probability Map",
                        elem_id="prob_output",
                        height=200,
                    )
                mask_output = gr.Image(
                    type="numpy",
                    label="Binary Mask",
                    elem_id="mask_output",
                    height=430,
                )
        # All three triggers run the same prediction; wire them in one loop
        # instead of three duplicated bindings.
        for trigger in (run_button.click, threshold.change, alpha.change):
            trigger(
                fn=predict,
                inputs=[input_image, threshold, alpha],
                outputs=[overlay_output, prob_output, mask_output],
            )
    return demo


def parse_args():
    """Parse command-line arguments for the demo server."""
    parser = argparse.ArgumentParser(description="Gradio app for retina vessel segmentation.")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/fives_resunet/best.pt")
    parser.add_argument("--image-size", type=int, default=1024)
    parser.add_argument("--model", type=str, default="resunet", choices=["resunet", "deeplabv3", "vit"])
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--base-channels", type=int, default=32)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--server-name", type=str, default="127.0.0.1")
    parser.add_argument("--server-port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    device = args.device
    # Fall back to CPU when CUDA was requested but is unavailable.
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    APP_STATE["device"] = torch.device(device)
    APP_STATE["image_size"] = args.image_size
    APP_STATE["model_name"] = args.model
    APP_STATE["backbone"] = args.backbone
    APP_STATE["model"] = load_model(
        args=args,
        device=APP_STATE["device"],
    )
    print(f"Loaded checkpoint: {checkpoint_path}")
    print(f"Device: {APP_STATE['device']}")
    print(f"Model: {APP_STATE['model_name']}")
    print(f"Backbone: {APP_STATE['backbone']}")
    print(f"Image size: {APP_STATE['image_size']}")
    demo = build_app()
    # Bug fix: these arguments were commented out, so --server-name,
    # --server-port, and --share were parsed but silently ignored.
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
    )