Spaces:
Running
Running
| import argparse | |
| from pathlib import Path | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from augmentations import IMAGENET_MEAN, IMAGENET_STD | |
| from models import build_model | |
# Process-wide registry shared between the script entry point and the Gradio
# callbacks. The entry point stores "device", "image_size", "model_name",
# "backbone" and "model" here before the UI is built; the inference and UI
# code read them back by key.
APP_STATE = dict()
def load_model(args, device):
    """Build the segmentation network and restore its weights from a checkpoint.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``image_size``,
            ``backbone``, ``base_channels``, ``dropout`` and ``checkpoint``.
        device: Device the network is moved to after the weights are loaded.

    Returns:
        The network with weights loaded, moved to ``device`` and switched to
        eval mode.
    """
    build_kwargs = {
        "model_name": args.model,
        "num_classes": 1,
        "in_channels": 3,
        "image_size": args.image_size,
        "backbone": args.backbone,
        "pretrained": False,
        "base_channels": args.base_channels,
        "dropout": args.dropout,
    }
    network = build_model(**build_kwargs)

    # Deserialize on CPU first; the network is moved to the target device below.
    payload = torch.load(args.checkpoint, map_location="cpu")

    # A checkpoint is either a wrapper holding the weights under
    # "model_state_dict" or the bare state dict itself.
    weights = payload["model_state_dict"] if "model_state_dict" in payload else payload

    network.load_state_dict(weights, strict=True)
    network.to(device)
    network.eval()
    return network
def preprocess_image(image, image_size):
    """Turn an input image into a normalized model tensor plus an RGB copy.

    Args:
        image: A PIL image or anything ``np.array`` can convert to an array.
        image_size: Side length of the square the image is resized to.

    Returns:
        A ``(tensor, original_rgb)`` pair: ``tensor`` is the resized,
        mean/std-normalized image as a float tensor with a leading batch axis
        and channels moved first; ``original_rgb`` is a copy of the image
        array taken before resizing.
    """
    is_pil = isinstance(image, Image.Image)
    pixels = np.array(image.convert("RGB")) if is_pil else np.array(image)

    # Single-channel 2-D input is expanded to three channels.
    if pixels.ndim == 2:
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
    # A fourth channel is dropped; only the first three are kept.
    if pixels.shape[-1] == 4:
        pixels = pixels[..., :3]

    original_rgb = pixels.copy()

    target_size = (image_size, image_size)
    scaled = cv2.resize(pixels, target_size, interpolation=cv2.INTER_LINEAR)
    scaled = scaled.astype(np.float32) / 255.0

    channel_mean = np.array(IMAGENET_MEAN, dtype=np.float32).reshape(1, 1, 3)
    channel_std = np.array(IMAGENET_STD, dtype=np.float32).reshape(1, 1, 3)
    normalized = (scaled - channel_mean) / channel_std

    # Channels-last array -> channels-first tensor with a batch axis in front.
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    batch = chw.unsqueeze(0).float()
    return batch, original_rgb
def overlay_mask(image_rgb, mask, alpha=0.45):
    """Blend a red highlight into an RGB image wherever the mask is set.

    Args:
        image_rgb: Image array with a trailing channel axis; cast to uint8.
        mask: Per-pixel weights without a channel axis.
        alpha: Blend strength of the red highlight where the mask equals 1.

    Returns:
        The blended image, clipped to [0, 255] and cast to uint8.
    """
    base = image_rgb.astype(np.uint8)

    # Solid red layer: only the first channel is set.
    highlight = np.zeros_like(base)
    highlight[..., 0] = 255

    # Per-pixel blend weight, broadcast across the channel axis.
    weight = alpha * mask[..., None]
    blended = base * (1 - weight) + highlight * weight
    return np.clip(blended, 0, 255).astype(np.uint8)
def run_inference(image, threshold):
    """Run the loaded model on one image and threshold its probability map.

    Reads ``APP_STATE["image_size"]``, ``APP_STATE["device"]`` and
    ``APP_STATE["model"]``.

    Args:
        image: Input image, forwarded to ``preprocess_image``.
        threshold: Probabilities at or above this value count as foreground.

    Returns:
        ``(original_rgb, prob_map, pred_mask)``: the RGB copy returned by
        ``preprocess_image``, the probability map resized to that image's
        height and width, and the thresholded mask as float32 zeros and ones.
    """
    side = APP_STATE["image_size"]
    batch, source_rgb = preprocess_image(image=image, image_size=side)
    batch = batch.to(APP_STATE["device"])

    network = APP_STATE["model"]
    with torch.no_grad():
        probabilities = torch.sigmoid(network(batch))

    # First batch item, first channel, as a NumPy array on the CPU.
    vessel_probs = probabilities[0, 0].detach().cpu().numpy()

    # cv2.resize takes (width, height), the reverse of the array shape order.
    height, width = source_rgb.shape[:2]
    vessel_probs = cv2.resize(
        vessel_probs,
        (width, height),
        interpolation=cv2.INTER_LINEAR,
    )

    binary = (vessel_probs >= threshold).astype(np.float32)
    return source_rgb, vessel_probs, binary
def predict(image, threshold, alpha):
    """Gradio callback: segment an image and build the three output images.

    Args:
        image: Input image, or ``None`` when nothing has been provided.
        threshold: Probability cut-off forwarded to ``run_inference``.
        alpha: Overlay strength forwarded to ``overlay_mask``.

    Returns:
        ``(overlay, prob_vis, mask_vis)`` as uint8 arrays, or three ``None``
        values when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None

    source_rgb, probabilities, binary = run_inference(image, threshold)
    return (
        overlay_mask(source_rgb, binary, alpha=alpha),
        # Scale the [0, 1] maps to the 0-255 range for display.
        (probabilities * 255).clip(0, 255).astype(np.uint8),
        (binary * 255).astype(np.uint8),
    )
def build_app():
    """Assemble the Gradio Blocks interface for the segmentation demo.

    Reads ``APP_STATE["model_name"]``, ``APP_STATE["backbone"]`` and
    ``APP_STATE["image_size"]`` for the header line, so those keys must be
    set before this is called.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    # Single source of truth for panel heights (in pixels), used both for the
    # CSS rules and for the ``height=`` argument of each image component.
    panel_heights = {
        "input_image": 430,
        "overlay_output": 200,
        "prob_output": 200,
        "mask_output": 430,
    }
    css = "\n".join(
        f"#{elem_id} {{\n"
        f"    height: {pixels}px !important;\n"
        f"}}\n"
        f"#{elem_id} img {{\n"
        f"    object-fit: contain !important;\n"
        f"    max-height: {pixels}px !important;\n"
        f"}}"
        for elem_id, pixels in panel_heights.items()
    )

    with gr.Blocks(title="Retina Vessel Segmentation", css=css) as demo:
        gr.Markdown("# Retina Vessel Segmentation")
        gr.Markdown(
            f"Model: `{APP_STATE['model_name']}` | "
            f"Backbone: `{APP_STATE['backbone']}` | "
            f"Image size: `{APP_STATE['image_size']}`"
        )

        with gr.Row(equal_height=False):
            # Gradio expects integer ``scale`` values; 5 and 6 keep the
            # original 1 : 1.2 width ratio between the two columns.
            with gr.Column(scale=5):
                input_image = gr.Image(
                    type="pil",
                    label="Input CFP Image",
                    elem_id="input_image",
                    height=panel_heights["input_image"],
                )
                threshold = gr.Slider(
                    minimum=0.05,
                    maximum=0.95,
                    value=0.5,
                    step=0.05,
                    label="Prediction Threshold",
                )
                alpha = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.45,
                    step=0.05,
                    label="Overlay Alpha",
                )
                run_button = gr.Button("Segment")

            with gr.Column(scale=6):
                with gr.Row():
                    overlay_output = gr.Image(
                        type="numpy",
                        label="Overlay",
                        elem_id="overlay_output",
                        height=panel_heights["overlay_output"],
                    )
                    prob_output = gr.Image(
                        type="numpy",
                        label="Probability Map",
                        elem_id="prob_output",
                        height=panel_heights["prob_output"],
                    )
                # NOTE(review): placed below the row of two small panels;
                # the original nesting could not be confirmed - verify.
                mask_output = gr.Image(
                    type="numpy",
                    label="Binary Mask",
                    elem_id="mask_output",
                    height=panel_heights["mask_output"],
                )

        # The button and both sliders trigger the same prediction call.
        binding = {
            "fn": predict,
            "inputs": [input_image, threshold, alpha],
            "outputs": [overlay_output, prob_output, mask_output],
        }
        for register in (run_button.click, threshold.change, alpha.change):
            register(**binding)

    return demo
def parse_args():
    """Parse the command-line options of the segmentation app.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint``, ``image_size``,
        ``model``, ``backbone``, ``base_channels``, ``dropout``, ``device``,
        ``server_name``, ``server_port`` and ``share``.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--checkpoint", {"type": str, "default": "checkpoints/fives_resunet/best.pt"}),
        ("--image-size", {"type": int, "default": 1024}),
        (
            "--model",
            {"type": str, "default": "resunet", "choices": ["resunet", "deeplabv3", "vit"]},
        ),
        ("--backbone", {"type": str, "default": "resnet50"}),
        ("--base-channels", {"type": int, "default": 32}),
        ("--dropout", {"type": float, "default": 0.0}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--server-name", {"type": str, "default": "127.0.0.1"}),
        ("--server-port", {"type": int, "default": 7860}),
        ("--share", {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser(description="Gradio app for retina vessel segmentation.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse options, load the model into APP_STATE,
    # build the UI and launch it.
    args = parse_args()

    # Fall back to the CPU for any CUDA request ("cuda", "cuda:0", ...) when
    # CUDA is not available, instead of failing later while moving the model.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        print(f"CUDA requested ({device}) but not available; falling back to CPU.")
        device = "cpu"

    # Fail early with a clear message before building the model.
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    APP_STATE["device"] = torch.device(device)
    APP_STATE["image_size"] = args.image_size
    APP_STATE["model_name"] = args.model
    APP_STATE["backbone"] = args.backbone
    APP_STATE["model"] = load_model(
        args=args,
        device=APP_STATE["device"],
    )

    print(f"Loaded checkpoint: {checkpoint_path}")
    print(f"Device: {APP_STATE['device']}")
    print(f"Model: {APP_STATE['model_name']}")
    print(f"Backbone: {APP_STATE['backbone']}")
    print(f"Image size: {APP_STATE['image_size']}")

    demo = build_app()
    # NOTE(review): --server-name, --server-port and --share are parsed but
    # not forwarded to launch(), so Gradio's own defaults apply. Confirm
    # whether that is intended before wiring them back in.
    demo.launch()