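"""Interactive pathology segmentation demo built on the Segment Anything Model (SAM).

Workflow: upload an image, add foreground point prompts by clicking on it and/or
enter a bounding box, then run segmentation to overlay the predicted mask.
"""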
import gradio as gr
from transformers import SamModel, SamProcessor
from PIL import Image
import torch
import numpy as np
# Load SAM
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
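# Note: "facebook/sam-vit-huge" is a large checkpoint; "facebook/sam-vit-base" is a
# lighter drop-in alternative if download time or GPU memory is a concern.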
# Global variables
global_state = {
    "image": None,
    "clicks": [],  # List of tuples: (x, y, label) where label=1 fg, 0 bg
    "bbox": None,  # (x0, y0, x1, y1)
}
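# Note: this module-level dict is shared by every browser session; a gr.State
# component would be needed for true per-user state.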
# Helper to apply mask overlay
def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
    if mask.ndim == 3:
        mask = mask.squeeze()
    if mask.max() <= 1:
        mask = (mask * 255).astype(np.uint8)
    mask_img = Image.fromarray(mask).convert("L")
    color_mask = Image.new("RGB", image.size, color)
    blended = Image.composite(color_mask, image, mask_img)
    return Image.blend(image, blended, alpha=0.5)
# Set image
def upload_image(img):
    if img is None:
        return "Please choose an image first.", None
    global_state["image"] = img
    global_state["clicks"] = []
    global_state["bbox"] = None
    return "Image uploaded. Now click or draw a box.", img
# Handle point clicks
def on_click(evt: gr.SelectData):
    if global_state["image"] is None:
        return gr.update(), "Please upload an image first."
    # Default to foreground click
    global_state["clicks"].append((evt.index[0], evt.index[1], 1))
    return gr.update(), f"Point added: ({evt.index[0]}, {evt.index[1]})"
# Handle bounding box input
def set_bbox(x0: int, y0: int, x1: int, y1: int):
    global_state["bbox"] = (x0, y0, x1, y1)
    return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
# Run segmentation
def run_segmentation():
    if global_state["image"] is None:
        return None, "Please upload an image first."
    if not global_state["clicks"] and global_state["bbox"] is None:
        return None, "Add at least one point or a bounding box first."
    image = global_state["image"].convert("RGB")
    # Pass prompts through the processor so point/box coordinates are rescaled
    # together with the image (SAM resizes the longest side to 1024 px).
    input_points = [[[x, y] for (x, y, _) in global_state["clicks"]]] if global_state["clicks"] else None
    input_labels = [[l for (_, _, l) in global_state["clicks"]]] if global_state["clicks"] else None
    input_boxes = [[list(global_state["bbox"])]] if global_state["bbox"] else None
    inputs = processor(
        image,
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
    # Upscale the low-resolution mask back to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0]
    final_mask = masks[0, 0].numpy().astype(np.uint8)  # shape: (H, W)
    overlayed = apply_mask_overlay(image, final_mask)
    return overlayed, "Segmentation complete."
# Reset
def reset_all():
    global_state["image"] = None
    global_state["clicks"] = []
    global_state["bbox"] = None
    return None, None, "State reset."
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Interactive Pathology Segmentation with SAM")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Pathology Image")
            upload_status = gr.Textbox(label="Status")
            upload_btn = gr.Button("Upload & Set Image")
            click_output = gr.Textbox(label="Click Info")
            run_btn = gr.Button("Run Segmentation")
            reset_btn = gr.Button("Reset")
            bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
            set_bbox_btn = gr.Button("Set Bounding Box")
            examples = gr.Examples(
                examples=[
                    ["images/image.webp"],
                    ["images/image2.webp"],
                    ["images/image3.webp"]
                ],
                inputs=[image_input],
                label="Example Pathology Images"
            )
        with gr.Column():
            image_output = gr.Image(type="pil", label="Segmentation Output")

    # Handlers (event listeners must be registered inside the Blocks context)
    upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
    image_input.select(on_click, outputs=[image_output, click_output])
    run_btn.click(run_segmentation, outputs=[image_output, upload_status])
    reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
    set_bbox_btn.click(
        lambda coords: set_bbox(*map(int, coords.split(","))),
        inputs=bbox_coords,
        outputs=click_output
    )

demo.launch()