File size: 4,731 Bytes
fadee54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f9046d
 
 
 
 
fadee54
0f9046d
 
fadee54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f9046d
 
 
 
 
 
 
 
 
fadee54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f9046d
fadee54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf3078d
441250a
 
fadee54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95d468b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import gradio as gr
from transformers import SamModel, SamProcessor
from PIL import Image, ImageDraw
import torch
import numpy as np

# Load SAM (ViT-H checkpoint) once at import time; downloads the weights on
# first run, so importing this module can take a while on a fresh machine.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")
model.eval()  # inference only — disables dropout / training-mode behavior
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Module-level UI state shared by all the Gradio handlers below.
# NOTE(review): this single dict is shared across *all* sessions/users of the
# app — concurrent users would clobber each other's prompts; confirm the app
# is intended for single-user use.
global_state = {
    "image": None,
    "clicks": [],  # List of tuples: (x, y, label) where label=1 fg, 0 bg
    "bbox": None,  # (x0, y0, x1, y1)
}

def apply_mask_overlay(image: Image.Image, mask: np.ndarray, color=(255, 0, 0)) -> Image.Image:
    """Blend a binary segmentation mask onto *image* as a translucent overlay.

    Args:
        image: RGB image to draw on; its size must match the mask's (H, W).
        mask: Mask array; any singleton dimensions are squeezed away (SAM
            emits (1, 1, H, W)). Values in [0, 1] are scaled up to [0, 255].
        color: RGB overlay color for masked pixels.

    Returns:
        A new image where masked pixels are a 50/50 blend of *image* and *color*.
    """
    # Drop ALL singleton dims, not just one level — the old `ndim == 3` check
    # missed 4-D masks straight out of SAM's post-processing.
    mask = np.squeeze(mask)
    if mask.max() <= 1:
        mask = mask * 255
    # PIL's fromarray needs uint8; cast unconditionally so bool/float masks
    # already scaled to 0-255 don't crash.
    mask = np.asarray(mask, dtype=np.uint8)
    mask_img = Image.fromarray(mask).convert("L")
    color_mask = Image.new("RGB", image.size, color)
    # Where the mask is opaque take the solid color, elsewhere keep the image…
    blended = Image.composite(color_mask, image, mask_img)
    # …then fade the whole overlay to 50% opacity.
    return Image.blend(image, blended, alpha=0.5)

# Set image
def upload_image(img):
    """Store a freshly uploaded image and clear any stale prompts.

    Returns a status message and the image to echo into the output panel.
    """
    global_state.update({"image": img, "clicks": [], "bbox": None})
    return "Image uploaded. Now click or draw a box.", img

# Handle point clicks
def on_click(evt: gr.SelectData):
    """Record a click on the input image as a foreground point prompt."""
    if global_state["image"] is None:
        return gr.update(), "Please upload an image first."

    px, py = evt.index[0], evt.index[1]
    global_state["clicks"].append((px, py, 1))  # label 1 = foreground
    return gr.update(), f"Point added: ({px}, {py})"

# Handle bounding box input
def set_bbox(x0: int, y0: int, x1: int, y1: int):
    """Store a manual bounding-box prompt as (x0, y0, x1, y1) pixel coords."""
    box = (x0, y0, x1, y1)
    global_state["bbox"] = box
    return f"Bounding box set: ({x0}, {y0}, {x1}, {y1})"

# Run segmentation
def run_segmentation():
    """Run SAM on the stored image using any accumulated click/box prompts.

    Returns:
        (overlay image | None, status message). Returns (None, message) when
        no image has been set yet.
    """
    if global_state["image"] is None:
        return None, "Please upload an image first."

    image = global_state["image"]

    # Build prompts in the nested-list layouts SamProcessor expects and pass
    # them THROUGH the processor, so it rescales the coordinates from
    # original-image space to the model's input resolution. (The previous
    # code built raw tensors and injected them after preprocessing, which
    # skipped that rescaling and misplaced every prompt.)
    input_points = None
    input_labels = None
    input_boxes = None
    if global_state["clicks"]:
        # (batch=1, point_batch=1, num_points, 2)
        input_points = [[[[x, y] for (x, y, _) in global_state["clicks"]]]]
        # (batch=1, point_batch=1, num_points)
        input_labels = [[[label for (_, _, label) in global_state["clicks"]]]]
    if global_state["bbox"]:
        x0, y0, x1, y1 = global_state["bbox"]
        # (batch=1, num_boxes=1, 4)
        input_boxes = [[[x0, y0, x1, y1]]]

    inputs = processor(
        image,
        input_points=input_points,
        input_labels=input_labels,
        input_boxes=input_boxes,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Upscale the low-res predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0]

    # Squeeze singleton dims so the overlay helper receives a 2-D (H, W) mask.
    final_mask = masks[0].numpy().squeeze().astype(np.uint8)
    overlayed = apply_mask_overlay(image.convert("RGB"), final_mask)

    return overlayed, "Segmentation complete."

# Reset
def reset_all():
    """Clear the stored image and all prompts; blank both output widgets."""
    for key, blank in (("image", None), ("clicks", []), ("bbox", None)):
        global_state[key] = blank
    return None, None, "State reset."

# Gradio UI: left column holds the input image and prompt controls, right
# column shows the segmentation overlay.
with gr.Blocks() as demo:
    gr.Markdown("Interactive Pathology Segmentation with SAM")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Pathology Image")
            upload_status = gr.Textbox(label="Status")
            # NOTE(review): the image only enters global_state when this button
            # is pressed — picking an example or dropping a file alone does not
            # register it; confirm this two-step flow is intentional.
            upload_btn = gr.Button("Upload & Set Image")
            click_output = gr.Textbox(label="Click Info")
            run_btn = gr.Button("Run Segmentation")
            reset_btn = gr.Button("Reset")
            bbox_coords = gr.Textbox(label="Manual BBox (x0,y0,x1,y1)")
            set_bbox_btn = gr.Button("Set Bounding Box")
            examples = gr.Examples(
                examples=[
                    ["images/image.webp"],
                    ["images/image2.webp"],
                    ["images/image3.webp"]
                ],
                inputs=[image_input],
                label="Example Pathology Images"
            )
        with gr.Column():
            image_output = gr.Image(type="pil", label="Segmentation Output")

    # Handlers
    upload_btn.click(upload_image, inputs=image_input, outputs=[upload_status, image_output])
    # Clicks on the input image add foreground points; on_click returns
    # gr.update() for the first output, so the image component is untouched.
    image_input.select(on_click, outputs=[image_output, click_output])
    run_btn.click(run_segmentation, outputs=[image_output, upload_status])
    reset_btn.click(reset_all, outputs=[image_output, click_output, upload_status])
    # Parses "x0,y0,x1,y1" from the textbox; malformed input raises ValueError
    # inside the handler rather than surfacing a friendly message.
    set_bbox_btn.click(
        lambda coords: set_bbox(*map(int, coords.split(","))),
        inputs=bbox_coords,
        outputs=click_output
    )

demo.launch()