| | |
| | |
| |
|
| | |
| | |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from segment_anything.modeling import Sam |
| |
|
| | from typing import Optional, Tuple |
| |
|
| | from .utils.transforms import ResizeLongestSide |
| | from .utils.amg import calculate_stability_score |
| |
|
| |
|
class SamPredictor:
    """
    Wraps a SAM model so that image embeddings are computed once and then
    reused for repeated, prompt-driven mask prediction.
    """

    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        # Offset forwarded to calculate_stability_score when stability scores
        # replace the IoU predictions in predict_torch.
        self.stability_score_offset = 1.0
        self.reset_image()

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> Tuple[torch.Tensor, Tuple[int, int], Tuple[int, int]]:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].

        Returns:
          (torch.Tensor): The image embeddings returned by set_torch_image.
          (tuple(int, int)): The (H, W) size of the image after the
            ResizeLongestSide transform.
          (tuple(int, int)): The (H, W) size of the original image.
          The two sizes are what predict/predict_torch expect as
          input_size and original_size.
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        # Flip the channel axis when the caller's color order differs from
        # the one the model was configured with.
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Resize so the long side matches the encoder, then HWC -> 1xCxHxW.
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        input_size = tuple(input_image_torch.shape[-2:])
        original_size = image.shape[:2]

        features = self.set_torch_image(input_image_torch)
        # set_torch_image resets state first, so record the sizes afterwards
        # to keep the stored state describing the current image.
        self.orig_h, self.orig_w = original_size
        self.input_h, self.input_w = input_size

        return features, input_size, original_size

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        The embeddings are returned and also kept on the predictor, so that
        get_image_embedding() and predict/predict_torch with features=None
        can reuse them.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.

        Returns:
          (torch.Tensor): The output of the model's image encoder.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        input_image = self.model.preprocess(transformed_image)
        features = self.model.image_encoder(input_image)

        # Previously the state was reset but never populated, so
        # get_image_embedding() could never succeed.
        self.features = features
        self.is_image_set = True

        return features

    def predict(
        self,
        features: Optional[torch.Tensor],
        input_size: Tuple[int, int],
        original_size: Tuple[int, int],
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        num_multimask_outputs: int = 3,
        return_logits: bool = False,
        use_stability_score: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          features (torch.Tensor or None): Image embeddings as returned by
            set_image / set_torch_image. If None, the embeddings stored by
            the last set_image / set_torch_image call are used.
          input_size (tuple(int, int)): The (H, W) size of the image after
            ResizeLongestSide, as returned by set_image.
          original_size (tuple(int, int)): The (H, W) size of the original
            image, as returned by set_image.
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks. Choices: 1, 3, 4.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          use_stability_score (bool): If true, use stability scores to substitute
            IoU predictions.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.

        Raises:
          RuntimeError: If features is None and no image has been set.
        """
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            # Map pixel coordinates from the original image into the resized frame.
            point_coords = self.transform.apply_coords(point_coords, original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            # Add a leading batch dimension: Nx2 -> 1xNx2, N -> 1xN.
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            # 1xHxW -> 1x1xHxW (batch of one).
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, low_res_masks = self.predict_torch(
            features,
            input_size,
            original_size,
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            num_multimask_outputs,
            return_logits=return_logits,
            use_stability_score=use_stability_score,
        )

        # Drop the batch dimension and move results to host memory.
        masks_np = masks[0].detach().cpu().numpy()
        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    @torch.no_grad()
    def predict_torch(
        self,
        features: Optional[torch.Tensor],
        input_size: Tuple[int, int],
        original_size: Tuple[int, int],
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        num_multimask_outputs: int = 3,
        return_logits: bool = False,
        use_stability_score: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          features (torch.Tensor or None): Image embeddings as returned by
            set_image / set_torch_image. If None, the embeddings stored by
            the last set_image / set_torch_image call are used.
          input_size (tuple(int, int)): The (H, W) size of the image after
            ResizeLongestSide.
          original_size (tuple(int, int)): The (H, W) size of the original image.
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor or None): A low resolution mask input to the
            model, typically coming from a previous prediction iteration. Has
            form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
            iteration of the predict method do not need further transformation.
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks. Choices: 1, 3, 4.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          use_stability_score (bool): If true, use stability scores to substitute
            IoU predictions. Note: defaults to True here, whereas predict
            defaults to False.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.

        Raises:
          RuntimeError: If features is None and no image has been set.
        """
        if features is None:
            # Fall back to the stored embedding instead of passing None on to
            # the mask decoder.
            if not self.is_image_set:
                raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
            features = self.features

        points = (point_coords, point_labels) if point_coords is not None else None

        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            num_multimask_outputs=num_multimask_outputs,
        )

        if use_stability_score:
            iou_predictions = calculate_stability_score(
                low_res_masks, self.model.mask_threshold, self.stability_score_offset
            )

        # Upscale the low-resolution logits back to the original image size.
        masks = self.model.postprocess_masks(low_res_masks, input_size, original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).

        Raises:
          RuntimeError: If no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        """The device of the wrapped model; prompt tensors are created on it."""
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
| |
|