import gradio as gr
from transformers import SamModel, SamProcessor
from PIL import Image, ImageDraw
import torch
import numpy as np

# Load SAM (ViT-H checkpoint) once at module import time.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Mutable session state shared across Gradio callbacks.
# NOTE(review): module-level state is shared by ALL concurrent users of this
# app; a per-session gr.State would be needed for multi-user deployments.
global_state = {
    "image": None,   # PIL image currently being segmented
    "clicks": [],    # list of (x, y, label) tuples; label 1 = foreground, 0 = background
    "bbox": None,    # (x0, y0, x1, y1) prompt box in original-image pixels, or None
}


# Helper to apply mask overlay
def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
    """Blend a binary mask onto `image` as a translucent colored overlay.

    Args:
        image: RGB PIL image to draw on.
        mask: array with nonzero pixels marking the masked region; extra
            singleton dimensions are squeezed away. Values may be bool/{0,1}
            or already in the 0-255 range.
        color: RGB overlay color.

    Returns:
        A new PIL image with the mask region tinted at 50% opacity.
    """
    mask = np.squeeze(mask)
    if mask.max() <= 1:
        # Boolean / {0, 1} mask -> scale up to the full 8-bit range.
        mask = (mask * 255).astype(np.uint8)
    else:
        # Already 0-255, but ensure a PIL-compatible dtype.
        mask = mask.astype(np.uint8)
    mask_img = Image.fromarray(mask).convert("L")
    color_mask = Image.new("RGB", image.size, color)
    blended = Image.composite(color_mask, image, mask_img)
    return Image.blend(image, blended, alpha=0.5)


# Set image
def upload_image(img):
    """Store the uploaded image and clear any previous point/box prompts."""
    global_state["image"] = img
    global_state["clicks"] = []
    global_state["bbox"] = None
    return "Image uploaded. Now click or draw a box.", img


# Handle point clicks
def on_click(evt: gr.SelectData):
    """Record a click on the input image as a foreground point prompt."""
    if global_state["image"] is None:
        return gr.update(), "Please upload an image first."
    # Default to foreground click (label=1); background points would use 0.
    global_state["clicks"].append((evt.index[0], evt.index[1], 1))
    return gr.update(), f"Point added: ({evt.index[0]}, {evt.index[1]})"


# Handle bounding box input
def set_bbox(x0: int, y0: int, x1: int, y1: int):
    """Store a bounding-box prompt in original-image pixel coordinates."""
    global_state["bbox"] = (x0, y0, x1, y1)
    return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"


# Run segmentation
def run_segmentation():
    """Run SAM with the accumulated point/box prompts.

    Returns:
        Tuple of (overlay image or None, status message).
    """
    if global_state["image"] is None:
        return None, "Please upload an image first."
    image = global_state["image"]

    # BUG FIX: prompts must go through the processor so it rescales them from
    # original-image pixels to the model's resized input (longest side 1024).
    # The previous code built tensors from raw pixel coordinates and injected
    # them straight into the model inputs, so prompts landed in the wrong
    # place for any image not already at the model's input resolution (it
    # also double-nested the points into a [1, 1, N, 2] tensor).
    prompt_kwargs = {}
    if global_state["clicks"]:
        # SamProcessor expects points as [batch, n_points, 2] and labels as
        # [batch, n_points] nested lists.
        prompt_kwargs["input_points"] = [[[x, y] for (x, y, _) in global_state["clicks"]]]
        prompt_kwargs["input_labels"] = [[lbl for (_, _, lbl) in global_state["clicks"]]]
    if global_state["bbox"]:
        # Boxes: [batch, n_boxes, 4] as (x0, y0, x1, y1).
        prompt_kwargs["input_boxes"] = [[list(global_state["bbox"])]]

    inputs = processor(image, return_tensors="pt", **prompt_kwargs).to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Upsample the low-res mask logits back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]
    final_mask = masks[0].squeeze().numpy().astype(np.uint8)  # (H, W), values {0, 1}
    overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)
    return overlayed, "Segmentation complete."


# Reset
def reset_all():
    """Clear the stored image and all prompts; blank out the UI outputs."""
    global_state["image"] = None
    global_state["clicks"] = []
    global_state["bbox"] = None
    return None, None, "State reset."
# Gradio UI with gr.Blocks() as demo: gr.Markdown("Interactive Pathology Segmentation with SAM") with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="Upload Pathology Image") upload_status = gr.Textbox(label="Status") upload_btn = gr.Button("Upload & Set Image") click_output = gr.Textbox(label="Click Info") run_btn = gr.Button("Run Segmentation") reset_btn = gr.Button("Reset") bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)") set_bbox_btn = gr.Button("Set Bounding Box") examples = gr.Examples( examples=[ ["images/image.webp"], ["images/image2.webp"], ["images/image3.webp"] ], inputs=[image_input], label="Example Pathology Images" ) with gr.Column(): image_output = gr.Image(type="pil", label="Segmentation Output") # Handlers upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output]) image_input.select(on_click, outputs=[image_output, click_output]) run_btn.click(run_segmentation, outputs=[image_output, upload_status]) reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status]) set_bbox_btn.click( lambda coords: set_bbox(*map(int, coords.split(","))), inputs=bbox_coords, outputs=click_output ) demo.launch()