| import gradio as gr |
| from transformers import SamModel, SamProcessor |
| from PIL import Image, ImageDraw |
| import torch |
| import numpy as np |
|
|
| |
# --- Model setup -------------------------------------------------------------
# Load the SAM (Segment Anything) checkpoint once at import time and move it
# to the best available device. eval() switches off training-time behavior.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model.eval()
model.to(device)
|
|
| |
# Mutable module-level session state shared by all Gradio callbacks.
#   image  - currently loaded PIL image (None until upload)
#   clicks - list of (x, y, label) point prompts; label 1 = foreground
#   bbox   - optional (x0, y0, x1, y1) box prompt
global_state = {"image": None, "clicks": [], "bbox": None}
|
|
| |
def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
    """Blend a binary segmentation mask onto *image* as a translucent tint.

    Args:
        image: RGB PIL image to draw on.
        mask: Binary (0/1 or bool) array of shape (H, W). Leading singleton
            axes — e.g. the (1, 1, H, W) shape SAM post-processing emits —
            are squeezed away before conversion.
        color: RGB tuple used for the highlighted region.

    Returns:
        A new PIL image with the masked region tinted at 50% opacity.
    """
    # SAM masks arrive with extra leading axes and bool dtype; collapse to a
    # 2-D uint8 array so Image.fromarray accepts it.
    mask2d = np.squeeze(np.asarray(mask))
    mask_img = Image.fromarray((mask2d > 0).astype(np.uint8) * 255).convert("L")
    color_mask = Image.new("RGB", image.size, color)
    # composite() takes `color_mask` where the mask is white, `image` elsewhere.
    mask_rgb = Image.composite(color_mask, image, mask_img)
    return Image.blend(image, mask_rgb, alpha=0.5)
|
|
| |
def upload_image(img):
    """Store a freshly uploaded image and discard any stale prompts."""
    global_state.update(image=img, clicks=[], bbox=None)
    return "Image uploaded. Now click or draw a box.", img
|
|
| |
def on_click(evt: gr.SelectData):
    """Record the clicked pixel as a foreground point prompt.

    Returns an unchanged image component plus a status string.
    """
    if global_state["image"] is None:
        return gr.update(), "Please upload an image first."

    x, y = evt.index[0], evt.index[1]
    # Label 1 marks the point as foreground for SAM.
    global_state["clicks"].append((x, y, 1))
    return gr.update(), f"Point added: ({x}, {y})"
|
|
| |
def set_bbox(x0: int, y0: int, x1: int, y1: int):
    """Store a box prompt, normalized to (x_min, y_min, x_max, y_max) order.

    SAM expects boxes as top-left / bottom-right corners, so reversed corner
    input is corrected here rather than producing a degenerate box downstream.

    Returns:
        A human-readable status string echoing the stored coordinates.
    """
    x0, x1 = min(x0, x1), max(x0, x1)
    y0, y1 = min(y0, y1), max(y0, y1)
    global_state["bbox"] = (x0, y0, x1, y1)
    return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"
|
|
| |
def run_segmentation():
    """Run SAM on the stored image using any accumulated point/box prompts.

    Returns:
        (overlay, status): the input image with the predicted mask blended in
        red plus a status string, or (None, error message) when no image has
        been uploaded yet.
    """
    if global_state["image"] is None:
        return None, "Please upload an image first."

    image = global_state["image"]

    # Prompts must go through the processor (not be injected into `inputs`
    # afterwards) so pixel coordinates are rescaled to the model's resized
    # input resolution along with the image.
    prompt_kwargs = {}
    if global_state["clicks"]:
        prompt_kwargs["input_points"] = [[[x, y] for (x, y, _) in global_state["clicks"]]]
        prompt_kwargs["input_labels"] = [[label for (_, _, label) in global_state["clicks"]]]
    if global_state["bbox"]:
        prompt_kwargs["input_boxes"] = [[list(global_state["bbox"])]]

    inputs = processor(image, return_tensors="pt", **prompt_kwargs).to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # post_process_masks upsamples masks back to the original image size and
    # returns one tensor per image in the batch.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0]

    # Collapse the leading (prompt, mask) singleton axes to a 2-D mask.
    final_mask = masks[0].squeeze().numpy()
    overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)

    return overlayed, "Segmentation complete."
|
|
| |
def reset_all():
    """Clear the image and all prompts; blank both image panes."""
    for key, blank in (("image", None), ("clicks", []), ("bbox", None)):
        global_state[key] = blank
    return None, None, "State reset."
|
|
| |
# --- UI layout ---------------------------------------------------------------
# NOTE(review): all three example rows point at the same image — presumably
# placeholders; replace with distinct pathology slides.
_EXAMPLE_URL = (
    "https://www.webpathology.com/_next/image?url=https%3A%2F%2Fd3cyex60hhnlth"
    ".cloudfront.net%2Ffit-in%2F650x650%2Ffilters%3Aformat(webp)%2Fcase"
    "%2Fdetail_images%2Fc354_detail.jpg&w=750&q=75"
)

with gr.Blocks() as demo:
    gr.Markdown("Interactive Pathology Segmentation with SAM")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Pathology Image")
            upload_status = gr.Textbox(label="Status")
            upload_btn = gr.Button("Upload & Set Image")
            click_output = gr.Textbox(label="Click Info")
            run_btn = gr.Button("Run Segmentation")
            reset_btn = gr.Button("Reset")
            bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
            set_bbox_btn = gr.Button("Set Bounding Box")
            examples = gr.Examples(
                examples=[[_EXAMPLE_URL], [_EXAMPLE_URL], [_EXAMPLE_URL]],
                inputs=[image_input],
                label="Example Pathology Images",
            )
        with gr.Column():
            image_output = gr.Image(type="pil", label="Segmentation Output")

    # Wire callbacks to their components.
    upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
    image_input.select(on_click, outputs=[image_output, click_output])
    run_btn.click(run_segmentation, outputs=[image_output, upload_status])
    reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
    set_bbox_btn.click(
        # Parse "x0,y0,x1,y1" from the textbox; ValueError surfaces on bad input.
        lambda coords: set_bbox(*map(int, coords.split(","))),
        inputs=bbox_coords,
        outputs=click_output,
    )


if __name__ == "__main__":
    # Guard the launch so the module can be imported (e.g. for testing)
    # without starting a server.
    demo.launch()