# PathoSeg / app.py
import gradio as gr
from transformers import SamModel, SamProcessor
from PIL import Image, ImageDraw
import torch
import numpy as np
# Load SAM
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
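
# Note: sam-vit-huge is a large checkpoint; on a CPU-only Space the smaller
# "facebook/sam-vit-base" (same SamModel/SamProcessor API) is a reasonable
# swap, trading some mask quality for memory and load time.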
# Global interaction state shared across all handlers
global_state = {
    "image": None,
    "clicks": [],  # list of (x, y, label) tuples; label 1 = foreground, 0 = background
    "bbox": None,  # (x0, y0, x1, y1) in original-image pixel coordinates
}
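# Note: module-level state like this is shared by every visitor to the Space;
# keeping prompts per session would need gr.State (or similar) instead.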
# Helper to blend a binary mask over an image as a translucent colored overlay
def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
    if mask.ndim == 3:
        mask = mask.squeeze()
    if mask.max() <= 1:
        mask = (mask * 255).astype(np.uint8)  # scale a 0/1 mask to 0/255 for PIL
    mask_img = Image.fromarray(mask).convert("L")
    color_mask = Image.new("RGB", image.size, color)
    # Paint `color` wherever the mask is set, then blend at 50% opacity
    blended = Image.composite(color_mask, image, mask_img)
    return Image.blend(image, blended, alpha=0.5)
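# Example: apply_mask_overlay(img.convert("RGB"), mask) tints the masked
# pixels red at 50% opacity and leaves the rest of the image untouched.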
# Set image
def upload_image(img):
global_state["image"] = img
global_state["clicks"] = []
global_state["bbox"] = None
return "Image uploaded. Now click or draw a box.", img
# Handle point clicks
def on_click(evt: gr.SelectData):
    if global_state["image"] is None:
        return gr.update(), "Please upload an image first."
    # Record the click as a foreground prompt (label=1)
    x, y = evt.index
    global_state["clicks"].append((x, y, 1))
    # Draw all collected points on a copy of the image as visual feedback
    preview = global_state["image"].convert("RGB").copy()
    draw = ImageDraw.Draw(preview)
    for px, py, _ in global_state["clicks"]:
        draw.ellipse((px - 4, py - 4, px + 4, py + 4), fill=(0, 255, 0))
    return preview, f"Point added: ({x}, {y})"
# Handle bounding box input
def set_bbox(x0: int, y0: int, x1: int, y1: int):
global_state["bbox"] = (x0, y0, x1, y1)
return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
# Run segmentation
def run_segmentation():
    if global_state["image"] is None:
        return None, "Please upload an image first."
    image = global_state["image"].convert("RGB")
    # Pass prompts through the processor so point/box coordinates are rescaled
    # to the model's resized input; feeding raw pixel coordinates straight to
    # the model would silently misplace the prompts.
    input_points = None
    input_labels = None
    input_boxes = None
    if global_state["clicks"]:
        input_points = [[[x, y] for (x, y, _) in global_state["clicks"]]]  # [1, N, 2]
        input_labels = [[l for (_, _, l) in global_state["clicks"]]]  # [1, N]
    if global_state["bbox"]:
        input_boxes = [[list(global_state["bbox"])]]  # [1, 1, 4]
    inputs = processor(
        image,
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)
    # Upscale the low-resolution predicted mask back to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]
    final_mask = masks[0, 0].numpy().astype(np.uint8)  # (H, W), values 0/1
    overlayed = apply_mask_overlay(image, final_mask)
    return overlayed, "Segmentation complete."
# Reset
def reset_all():
global_state["image"] = None
global_state["clicks"] = []
global_state["bbox"] = None
return None, None, "State reset."
# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("Interactive Pathology Segmentation with SAM")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="Upload Pathology Image")
upload_status = gr.Textbox(label="Status")
upload_btn = gr.Button("Upload & Set Image")
click_output = gr.Textbox(label="Click Info")
run_btn = gr.Button("Run Segmentation")
reset_btn = gr.Button("Reset")
bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
set_bbox_btn = gr.Button("Set Bounding Box")
examples = gr.Examples(
examples=[
["images/image.webp"],
["images/image2.webp"],
["images/image3.webp"]
],
inputs=[image_input],
label="Example Pathology Images"
)
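        # Selecting an example only populates the input component; the user
        # still has to press "Upload & Set Image" to store it in global_state.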
with gr.Column():
image_output = gr.Image(type="pil", label="Segmentation Output")
# Handlers
upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
image_input.select(on_click, outputs=[image_output, click_output])
run_btn.click(run_segmentation, outputs=[image_output, upload_status])
reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
    set_bbox_btn.click(
        # Strip spaces so inputs like "10, 20, 300, 400" also parse
        lambda coords: set_bbox(*map(int, coords.replace(" ", "").split(","))),
        inputs=bbox_coords,
        outputs=click_output,
    )
demo.launch()