# CFPVesselSeg / app.py
# Gradio demo for retinal vessel segmentation on color fundus photographs (CFP).
import argparse
from pathlib import Path
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from augmentations import IMAGENET_MEAN, IMAGENET_STD
from models import build_model
# Module-level registry populated in the __main__ block before the UI starts:
# holds the loaded model, target device, image size, and display metadata.
APP_STATE = {}
def load_model(args, device):
    """Build the segmentation network, restore checkpoint weights, and set eval mode.

    Args:
        args: parsed CLI namespace supplying model/backbone/checkpoint options.
        device: torch.device to move the model onto.

    Returns:
        The restored model, on ``device``, in eval mode.
    """
    net = build_model(
        model_name=args.model,
        num_classes=1,
        in_channels=3,
        image_size=args.image_size,
        backbone=args.backbone,
        pretrained=False,
        base_channels=args.base_channels,
        dropout=args.dropout,
    )
    ckpt = torch.load(args.checkpoint, map_location="cpu")
    # Training checkpoints wrap the weights under "model_state_dict";
    # a bare state dict is accepted as-is.
    weights = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
    net.load_state_dict(weights, strict=True)
    net.to(device)
    net.eval()
    return net
def preprocess_image(image, image_size):
    """Convert an input image into a normalized model tensor.

    Accepts a PIL image or an array. Grayscale input is expanded to three
    channels and any alpha channel is dropped. The image is resized to a
    ``image_size`` x ``image_size`` square and normalized with ImageNet
    statistics.

    Returns:
        (tensor, original_rgb): a (1, 3, H, W) float32 tensor ready for the
        model, and the untouched full-resolution RGB array for visualization.
    """
    if isinstance(image, Image.Image):
        arr = np.array(image.convert("RGB"))
    else:
        arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale -> replicate into 3 RGB channels.
        arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    if arr.shape[-1] == 4:
        # Drop the alpha channel.
        arr = arr[..., :3]
    original_rgb = arr.copy()
    square = cv2.resize(arr, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    scaled = square.astype(np.float32) / 255.0
    mean = np.asarray(IMAGENET_MEAN, dtype=np.float32).reshape(1, 1, 3)
    std = np.asarray(IMAGENET_STD, dtype=np.float32).reshape(1, 1, 3)
    normalized = (scaled - mean) / std
    tensor = torch.from_numpy(normalized).permute(2, 0, 1).unsqueeze(0).float()
    return tensor, original_rgb
def overlay_mask(image_rgb, mask, alpha=0.45):
    """Blend a red highlight onto ``image_rgb`` wherever ``mask`` is set.

    Args:
        image_rgb: (H, W, 3) RGB array (cast to uint8).
        mask: (H, W) array in [0, 1]; 1 marks segmented pixels.
        alpha: blend strength of the red highlight.

    Returns:
        (H, W, 3) uint8 array with the mask painted in semi-transparent red.
    """
    base = image_rgb.astype(np.uint8)
    solid_red = np.zeros_like(base)
    solid_red[..., 0] = 255
    # Broadcast the mask to per-channel blend weights.
    weight = alpha * mask[..., None]
    blended = base * (1 - weight) + solid_red * weight
    return np.clip(blended, 0, 255).astype(np.uint8)
def run_inference(image, threshold):
    """Segment vessels in ``image`` and binarize the probabilities.

    Returns:
        (original_rgb, prob_map, pred_mask): the full-resolution RGB input,
        the per-pixel vessel probability map resized back to the input
        resolution, and the float32 binary mask from thresholding.
    """
    tensor, original_rgb = preprocess_image(
        image=image,
        image_size=APP_STATE["image_size"],
    )
    with torch.no_grad():
        logits = APP_STATE["model"](tensor.to(APP_STATE["device"]))
    prob_map = torch.sigmoid(logits)[0, 0].detach().cpu().numpy()
    height, width = original_rgb.shape[:2]
    # Resize the probability map to the original resolution before
    # thresholding so the mask aligns with the input pixels.
    prob_map = cv2.resize(prob_map, (width, height), interpolation=cv2.INTER_LINEAR)
    pred_mask = (prob_map >= threshold).astype(np.float32)
    return original_rgb, prob_map, pred_mask
def predict(image, threshold, alpha):
    """Gradio callback producing the overlay, probability map, and binary mask views."""
    if image is None:
        # No image uploaded yet; clear all three output panels.
        return None, None, None
    rgb, prob_map, mask = run_inference(image, threshold)
    blended = overlay_mask(rgb, mask, alpha=alpha)
    prob_vis = (prob_map * 255).clip(0, 255).astype(np.uint8)
    mask_vis = (mask * 255).astype(np.uint8)
    return blended, prob_vis, mask_vis
def build_app():
    """Construct the Gradio Blocks UI for the segmentation demo.

    APP_STATE must already be populated (model metadata is shown in the
    header and the prediction callback reads the loaded model from it).
    """
    css = """
#input_image {
height: 430px !important;
}
#input_image img {
object-fit: contain !important;
max-height: 430px !important;
}
#overlay_output {
height: 200px !important;
}
#overlay_output img {
object-fit: contain !important;
max-height: 200px !important;
}
#prob_output {
height: 200px !important;
}
#prob_output img {
object-fit: contain !important;
max-height: 200px !important;
}
#mask_output {
height: 430px !important;
}
#mask_output img {
object-fit: contain !important;
max-height: 430px !important;
}
"""
    with gr.Blocks(title="Retina Vessel Segmentation", css=css) as demo:
        gr.Markdown("# Retina Vessel Segmentation")
        gr.Markdown(
            f"Model: `{APP_STATE['model_name']}` | "
            f"Backbone: `{APP_STATE['backbone']}` | "
            f"Image size: `{APP_STATE['image_size']}`"
        )
        with gr.Row(equal_height=False):
            # Left column: input image plus the two tuning sliders.
            with gr.Column(scale=1):
                input_image = gr.Image(
                    type="pil", label="Input CFP Image", elem_id="input_image", height=430
                )
                threshold = gr.Slider(
                    minimum=0.05, maximum=0.95, value=0.5, step=0.05,
                    label="Prediction Threshold",
                )
                alpha = gr.Slider(
                    minimum=0.1, maximum=0.9, value=0.45, step=0.05,
                    label="Overlay Alpha",
                )
                run_button = gr.Button("Segment")
            # Right column: overlay + probability map on top, mask below.
            with gr.Column(scale=1.2):
                with gr.Row():
                    overlay_output = gr.Image(
                        type="numpy", label="Overlay", elem_id="overlay_output", height=200
                    )
                    prob_output = gr.Image(
                        type="numpy", label="Probability Map", elem_id="prob_output", height=200
                    )
                mask_output = gr.Image(
                    type="numpy", label="Binary Mask", elem_id="mask_output", height=430
                )
        # The button and both sliders all re-run the same prediction pipeline.
        handler = dict(
            fn=predict,
            inputs=[input_image, threshold, alpha],
            outputs=[overlay_output, prob_output, mask_output],
        )
        run_button.click(**handler)
        threshold.change(**handler)
        alpha.change(**handler)
    return demo
def parse_args(argv=None):
    """Parse command-line options for the app.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, preserving the original
            behavior; passing an explicit list makes the function testable.

    Returns:
        argparse.Namespace with checkpoint, model, and server settings.
    """
    parser = argparse.ArgumentParser(description="Gradio app for retina vessel segmentation.")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/fives_resunet/best.pt")
    parser.add_argument("--image-size", type=int, default=1024)
    parser.add_argument("--model", type=str, default="resunet", choices=["resunet", "deeplabv3", "vit"])
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--base-channels", type=int, default=32)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--server-name", type=str, default="127.0.0.1")
    parser.add_argument("--server-port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()
    device = args.device
    # Fall back to CPU when CUDA was requested but is not available.
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # Populate the shared state the UI and callbacks read from.
    APP_STATE["device"] = torch.device(device)
    APP_STATE["image_size"] = args.image_size
    APP_STATE["model_name"] = args.model
    APP_STATE["backbone"] = args.backbone
    APP_STATE["model"] = load_model(
        args=args,
        device=APP_STATE["device"],
    )
    print(f"Loaded checkpoint: {checkpoint_path}")
    print(f"Device: {APP_STATE['device']}")
    print(f"Model: {APP_STATE['model_name']}")
    print(f"Backbone: {APP_STATE['backbone']}")
    print(f"Image size: {APP_STATE['image_size']}")
    demo = build_app()
    # Fix: the CLI exposes --server-name/--server-port/--share, but the
    # launch kwargs were commented out, so those options were silently
    # ignored. Pass them through so the flags take effect.
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
    )