import gradio as gr
from transformers import SamModel, SamProcessor
from PIL import Image, ImageDraw
import torch
import numpy as np
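
# Load the SAM checkpoint and its processor once at start-up and move the model
# to the GPU when one is available.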
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
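
# In-memory demo state: the current image, the accumulated click prompts as
# (x, y, label) tuples (label 1 = foreground), and an optional box prompt.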
global_state = {
    "image": None,
    "clicks": [],
    "bbox": None,
}
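

# Blend a binary mask over the image as a semi-transparent colored overlay.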
def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
    if mask.ndim == 3:
        mask = mask.squeeze()
    if mask.max() <= 1:
        mask = (mask * 255).astype(np.uint8)
    mask_img = Image.fromarray(mask).convert("L")
    color_mask = Image.new("RGB", image.size, color)
    blended = Image.composite(color_mask, image, mask_img)
    return Image.blend(image, blended, alpha=0.5)
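

# Store a freshly uploaded image and clear any prompts left over from a previous run.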
def upload_image(img):
    global_state["image"] = img
    global_state["clicks"] = []
    global_state["bbox"] = None
    return "Image uploaded. Now click or draw a box.", img
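

# Record a foreground click prompt and redraw the accumulated points on a copy
# of the image so the user can see where the prompts are.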
def on_click(evt: gr.SelectData):
    if global_state["image"] is None:
        return gr.update(), "Please upload an image first."

    x, y = evt.index
    global_state["clicks"].append((x, y, 1))  # label 1 = foreground point
    preview = global_state["image"].convert("RGB")
    draw = ImageDraw.Draw(preview)
    for px, py, _ in global_state["clicks"]:
        draw.ellipse((px - 4, py - 4, px + 4, py + 4), fill="red")
    return preview, f"Point added: ({x}, {y})"
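

# Store a manual bounding-box prompt given as (x0, y0, x1, y1) in image pixels.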
def set_bbox(x0: int, y0: int, x1: int, y1: int):
    global_state["bbox"] = (x0, y0, x1, y1)
    return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
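

# Run SAM on the stored image. Point and box prompts are passed through the
# processor so their pixel coordinates are rescaled to match the resized model
# input; the predicted mask is then upsampled back to the original image size
# and blended over it.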
def run_segmentation():
    if global_state["image"] is None:
        return None, "Please upload an image first."

    image = global_state["image"].convert("RGB")

    # Collect prompts in the nested-list format expected by SamProcessor
    # (one inner list per image in the batch).
    input_points = None
    input_labels = None
    input_boxes = None
    if global_state["clicks"]:
        input_points = [[[x, y] for (x, y, _) in global_state["clicks"]]]
        input_labels = [[l for (_, _, l) in global_state["clicks"]]]
    if global_state["bbox"]:
        x0, y0, x1, y1 = global_state["bbox"]
        input_boxes = [[[x0, y0, x1, y1]]]

    inputs = processor(
        image,
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Upsample the low-resolution prediction back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]

    final_mask = masks[0].numpy().astype(np.uint8)
    overlayed = apply_mask_overlay(image, final_mask)

    return overlayed, "Segmentation complete."
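

# Clear the stored image and all prompts.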
def reset_all():
    global_state["image"] = None
    global_state["clicks"] = []
    global_state["bbox"] = None
    return None, None, "State reset."
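

# Build the Gradio UI: image upload and prompt controls on the left,
# segmentation output on the right.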
with gr.Blocks() as demo:
    gr.Markdown("# Interactive Pathology Segmentation with SAM")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Pathology Image")
            upload_status = gr.Textbox(label="Status")
            upload_btn = gr.Button("Upload & Set Image")
            click_output = gr.Textbox(label="Click Info")
            run_btn = gr.Button("Run Segmentation")
            reset_btn = gr.Button("Reset")
            bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
            set_bbox_btn = gr.Button("Set Bounding Box")
            examples = gr.Examples(
                examples=[
                    ["images/image.webp"],
                    ["images/image2.webp"],
                    ["images/image3.webp"],
                ],
                inputs=[image_input],
                label="Example Pathology Images",
            )
        with gr.Column():
            image_output = gr.Image(type="pil", label="Segmentation Output")

    # Event wiring: clicks on the input image add point prompts; the buttons
    # handle upload, segmentation, reset, and manual box entry.
    upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
    image_input.select(on_click, outputs=[image_output, click_output])
    run_btn.click(run_segmentation, outputs=[image_output, upload_status])
    reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
    set_bbox_btn.click(
        lambda coords: set_bbox(*map(int, coords.split(","))),
        inputs=bbox_coords,
        outputs=click_output,
    )

demo.launch()