Spaces:
Running
Running
| import torch | |
| from transformers import SamModel, SamProcessor | |
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Tuple, List | |
| # Use @st.cache_resource to avoid reloading the model on every rerun | |
| def load_sam_model() -> Tuple[SamModel, SamProcessor, str]: | |
| """Loads the SAM model and processor from Hugging Face.""" | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # Using facebook/sam-vit-base as the standard baseline | |
| model_id = "facebook/sam-vit-base" | |
| model = SamModel.from_pretrained(model_id).to(device) | |
| processor = SamProcessor.from_pretrained(model_id) | |
| return model, processor, device | |
| def compute_image_embedding(image: Image.Image) -> torch.Tensor: | |
| """ | |
| Computes and caches the SAM image embedding for a given image. | |
| This is the heavy part of the computation. | |
| """ | |
| model, processor, device = load_sam_model() | |
| # Preprocess the image to get pixel values | |
| inputs = processor(images=image, return_tensors="pt").to(device) | |
| # Compute image embeddings | |
| with torch.no_grad(): | |
| image_embeddings = model.get_image_embeddings(inputs.pixel_values) | |
| return image_embeddings | |
| def predict_mask( | |
| image: Image.Image, | |
| image_embeddings: torch.Tensor, | |
| input_points: List[List[int]], | |
| input_labels: List[int], | |
| ) -> np.ndarray: | |
| """ | |
| Predicts a binary mask given the image embeddings and prompt points. | |
| input_points: list of [x, y] coordinates | |
| input_labels: list of 1 (positive) or 0 (negative) for each point | |
| """ | |
| model, processor, device = load_sam_model() | |
| # Format inputs for the processor | |
| # The processor expects points in the format [[[x1, y1], [x2, y2], ...]] | |
| # and labels in [[1, 0, ...]] for a single batch | |
| points = [input_points] | |
| labels = [input_labels] | |
| # Preprocess prompts | |
| inputs = processor( | |
| images=image, input_points=points, input_labels=labels, return_tensors="pt" | |
| ).to(device) | |
| # Run prediction using the cached embeddings | |
| with torch.no_grad(): | |
| outputs = model( | |
| image_embeddings=image_embeddings, | |
| input_points=inputs.input_points, | |
| input_labels=inputs.input_labels, | |
| multimask_output=False, # We only want the best mask | |
| ) | |
| # Process the predicted mask back to the original image size | |
| # inputs contains original_sizes and reshaped_input_sizes from the processor call | |
| masks = processor.image_processor.post_process_masks( | |
| outputs.pred_masks.cpu(), | |
| inputs["original_sizes"].cpu(), | |
| inputs["reshaped_input_sizes"].cpu(), | |
| ) | |
| # masks is a list of tensors, get the first one and squeeze it to a 2D array | |
| mask = masks[0] | |
| # Squeeze out the batch and channel dimensions if present, but keep spatial dims. | |
| # Usually shape is (1, 1, H, W) or (1, H, W) | |
| if mask.ndim > 2: | |
| mask = mask.squeeze() | |
| # If the image was 1x1, squeeze might have removed all dimensions. | |
| if mask.ndim < 2: | |
| mask = mask.view(masks[0].shape[-2], masks[0].shape[-1]) | |
| mask = mask.numpy() | |
| # The mask is boolean, convert to uint8 for OpenCV (0 and 255) | |
| binary_mask = (mask * 255).astype(np.uint8) | |
| # To prevent "overlap onto a different part of the image", keep only the connected | |
| # components of the mask that actually contain at least one positive input point. | |
| import cv2 | |
| num_labels, labels_img = cv2.connectedComponents(binary_mask) | |
| if num_labels > 1: | |
| # labels_img has values from 0 (background) to num_labels - 1 | |
| filtered_mask = np.zeros_like(binary_mask) | |
| positive_points_labels = set() | |
| for pt, label in zip(input_points, input_labels): | |
| if label == 1: # Positive point | |
| x, y = pt | |
| # Make sure point is within image bounds | |
| if 0 <= y < labels_img.shape[0] and 0 <= x < labels_img.shape[1]: | |
| component_label = labels_img[y, x] | |
| if component_label > 0: # Ignore background | |
| positive_points_labels.add(component_label) | |
| # Keep only the connected components that were clicked | |
| for comp_label in positive_points_labels: | |
| filtered_mask[labels_img == comp_label] = 255 | |
| # If for some reason no positive points fell exactly on a mask component, | |
| # fallback to returning the original mask to prevent returning an empty mask. | |
| if len(positive_points_labels) > 0: | |
| binary_mask = filtered_mask | |
| return binary_mask | |