File size: 4,706 Bytes
79b7634
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdce371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b7634
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from transformers import SamModel, SamProcessor
import streamlit as st
import numpy as np
from PIL import Image
from typing import Tuple, List


# Use @st.cache_resource to avoid reloading the model on every rerun
@st.cache_resource(show_spinner="Loading Segment Anything Model (SAM)...")
def load_sam_model() -> Tuple[SamModel, SamProcessor, str]:
    """Load the SAM checkpoint and its processor from Hugging Face.

    Returns:
        A ``(model, processor, device)`` tuple. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and ``"cpu"`` otherwise; the
        model has already been moved to that device.
    """
    # facebook/sam-vit-base is used as the standard baseline checkpoint.
    checkpoint = "facebook/sam-vit-base"

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    sam_model = SamModel.from_pretrained(checkpoint)
    sam_model = sam_model.to(target_device)
    sam_processor = SamProcessor.from_pretrained(checkpoint)

    return sam_model, sam_processor, target_device


@st.cache_resource(
    show_spinner="Computing Image Embeddings...",
    # Bound the cache: without a limit every distinct image ever processed
    # keeps its embedding tensor (possibly on the GPU) alive for the lifetime
    # of the Streamlit server process. The oldest entries are evicted first.
    max_entries=16,
)
def compute_image_embedding(image: Image.Image) -> torch.Tensor:
    """Compute and cache the SAM image embedding for ``image``.

    This is the expensive part of a SAM prediction, so the result is cached
    by Streamlit and reused across reruns for the same image.

    Args:
        image: The image to embed.

    Returns:
        The tensor returned by ``model.get_image_embeddings``. It is computed
        on the device chosen by ``load_sam_model`` and without gradient
        tracking. The cached object is shared between callers, so it must
        not be modified in place.
    """
    model, processor, device = load_sam_model()

    # Preprocess the image into pixel values on the model's device.
    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only: no autograd graph is needed for the embedding.
    with torch.no_grad():
        image_embeddings = model.get_image_embeddings(inputs.pixel_values)

    return image_embeddings


def predict_mask(
    image: Image.Image,
    image_embeddings: torch.Tensor,
    input_points: List[List[int]],
    input_labels: List[int],
) -> np.ndarray:
    """Predict a binary mask from cached image embeddings and prompt points.

    Args:
        image: The original image. It is passed to the processor again so that
            the prompt points and the output mask can be mapped to and from
            the original image size.
        image_embeddings: Embeddings previously computed for ``image``.
        input_points: List of ``[x, y]`` coordinates in original-image pixels.
            Fractional coordinates are accepted.
        input_labels: One label per point: 1 (positive) or 0 (negative).

    Returns:
        A 2-D ``uint8`` array with values 0 and 255. When at least one
        positive point lands on the predicted mask, only the connected
        components containing such a point are kept; otherwise the
        unfiltered mask is returned.
    """
    model, processor, device = load_sam_model()

    # The processor expects points as [[[x1, y1], [x2, y2], ...]] and labels
    # as [[1, 0, ...]] for a single image, hence the extra list level.
    inputs = processor(
        images=image,
        input_points=[input_points],
        input_labels=[input_labels],
        return_tensors="pt",
    ).to(device)

    # Run the prediction on top of the cached embeddings.
    with torch.no_grad():
        outputs = model(
            image_embeddings=image_embeddings,
            input_points=inputs.input_points,
            input_labels=inputs.input_labels,
            multimask_output=False,  # We only want the best mask
        )

    # Map the predicted mask back to the original image size using the sizes
    # recorded by the processor call above.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks is a list with one tensor per image, usually shaped (1, 1, H, W)
    # or (1, H, W). Peel leading dimensions by indexing rather than squeeze():
    # squeeze() would also drop a spatial dimension of size 1 and would leave
    # a 3-D tensor if any leading dimension were larger than 1.
    mask = masks[0]
    while mask.ndim > 2:
        mask = mask[0]

    # The mask is expected to be boolean; convert to uint8 (0 and 255) for
    # OpenCV.
    binary_mask = (mask.numpy() * 255).astype(np.uint8)

    return _keep_clicked_components(binary_mask, input_points, input_labels)


def _keep_clicked_components(
    binary_mask: np.ndarray,
    input_points: List[List[int]],
    input_labels: List[int],
) -> np.ndarray:
    """Keep only the mask components that contain a positive prompt point.

    This prevents the mask from spilling onto a disconnected part of the
    image. If no positive point falls on a foreground component, the input
    mask is returned unchanged so the caller never gets an unexpectedly
    empty mask.
    """
    import cv2

    num_labels, labels_img = cv2.connectedComponents(binary_mask)
    if num_labels <= 1:
        # Only the background label exists; nothing to filter.
        return binary_mask

    height, width = labels_img.shape
    clicked_components = set()
    for (x, y), label in zip(input_points, input_labels):
        if label != 1:  # Only positive points select components
            continue
        # Click coordinates may be fractional (e.g. scaled from a resized
        # canvas); floor them to the pixel that contains the point so they
        # can be used as array indices.
        col = int(np.floor(x))
        row = int(np.floor(y))
        if 0 <= row < height and 0 <= col < width:
            component = int(labels_img[row, col])
            if component > 0:  # Label 0 is the background
                clicked_components.add(component)

    if not clicked_components:
        return binary_mask

    # Select all clicked components in a single vectorized pass.
    keep = np.isin(labels_img, list(clicked_components))
    return np.where(keep, 255, 0).astype(np.uint8)