| |
| |
| import os |
|
|
| |
| |
|
|
| import numpy as np |
| import torch |
| from torchvision.ops.boxes import batched_nms, box_area |
| import torch.nn.functional as F |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from .modeling import Sam |
| from .predictor import SamPredictor |
| from .utils.amg import ( |
| MaskData, |
| area_from_rle, |
| batch_iterator, |
| batched_mask_to_box, |
| box_xyxy_to_xywh, |
| build_all_layer_point_grids, |
| calculate_stability_score, |
| coco_encode_rle, |
| generate_crop_boxes, |
| is_box_near_crop_edge, |
| mask_to_rle_pytorch, |
| remove_small_regions, |
| rle_to_mask, |
| uncrop_boxes_xyxy, |
| uncrop_masks, |
| uncrop_points, |
| ) |
class SamAutomaticMaskGenerator:
    """Automatic whole-image mask generation with a (modified) SAM model.

    This variant's predictor additionally returns per-mask category
    predictions (``cate_preds``) and feature vectors (``fc_features``),
    which are carried through filtering and used to compute a pairwise
    mask-similarity matrix.
    """

    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculated the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
        """
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            # Defensive; unreachable given the assert above.
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            # Early import so a missing optional dependency fails at
            # construction time instead of after a full generation pass.
            from pycocotools import mask as mask_utils  # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # noqa: F401  # dependency check for postprocess_small_regions

        self.predictor = SamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
        # When set to an image, _process_batch disables all quality filtering
        # and writes one annotated debug figure per predicted mask.
        self.debug_image = None

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
        """
        # Generate masks over the full image (and crops, if configured).
        mask_data = self._generate_masks(image)

        # Optionally clean up small disconnected regions and holes.
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        curr_anns = self.mask_data_to_dict(mask_data)

        return curr_anns

    def mask_data_to_dict(self, mask_data, ret_similarity=True):
        """Convert a MaskData batch into a list of per-mask annotation dicts.

        Encodes segmentations according to ``self.output_mode`` and copies the
        per-mask scalar fields (area, bbox, predicted IoU, category, features,
        prompt points, stability score, crop box) into plain Python values.

        Arguments:
          mask_data (MaskData): Accumulated mask predictions.
          ret_similarity (bool): If True, include each mask's row of the
            pairwise similarity matrix (requires mask_data["similarity"]).

        Returns:
          list(dict): One annotation dict per mask.
        """
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "cate_preds": mask_data["cate_preds"][idx].item(),
                "fc_features": mask_data["fc_features"][idx].tolist(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            if "iou_reasoning" in mask_data._stats.keys():
                # Fixed: this was written as an annotation statement
                # (`ann[...]: value`), which evaluated but discarded the value
                # instead of storing it in the dict.
                ann["iou_reasoning"] = mask_data["iou_reasoning"][idx].item()
            if ret_similarity:
                # Fixed: a trailing comma previously wrapped the list in a
                # one-element tuple.
                ann["similarity"] = mask_data["similarity"][idx].tolist()
            curr_anns.append(ann)
        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        """Generate masks for the image over all configured crop layers.

        Runs _process_crop on every crop box, concatenates the results,
        deduplicates across crops with NMS (preferring smaller crops), and
        attaches the pairwise feature-similarity matrix.
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops and accumulate their masks.
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops; prefer masks from smaller
        # crops (scored by inverse crop area) as they are seen at higher zoom.
        if len(crop_boxes) > 1:
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # single NMS category
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)
        # Pairwise cosine similarity between mask feature vectors; the
        # diagonal (self-similarity) is masked out with -inf.
        fc_features = F.normalize(data["fc_features"], dim=-1)
        similarity = fc_features.mm(fc_features.t())
        similarity.fill_diagonal_(-np.inf)
        data["similarity"] = similarity
        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """Generate masks for a single crop using the layer's point grid.

        Returns a MaskData whose boxes/points/crop_boxes are expressed in
        original-image coordinates.
        """
        # Crop the image and set it as the predictor's current image.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Scale this layer's normalized point grid to crop pixel coordinates.
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches of points.
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # single NMS category
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame.
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        return data

    def process_crop_user_pts(
        self,
        image: np.ndarray,
        crop_box: List[int],
        user_pts: np.ndarray,
        user_pts_label: np.ndarray,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """Generate masks for a single crop from user-supplied point prompts.

        Like _process_crop, but uses explicit points/labels instead of a grid
        and performs no NMS. Also records the crop for debug figures.
        """
        # Crop the image and set it as the predictor's current image.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.debug_image = cropped_im
        self.predictor.set_image(cropped_im)

        # Shift user points from image coordinates into crop coordinates.
        user_pts = crop_points(user_pts, crop_box)

        data = MaskData()
        for (points, points_labels) in batch_iterator(self.points_per_batch, user_pts, user_pts_label):
            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, points_labels)
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_image()

        # Nothing survived filtering: return the empty container as-is.
        if data._stats == {}:
            return data

        # Return to the original image frame.
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        return data

    def process_crop_queries(
        self,
        image: np.ndarray,
        crop_box: List[int],
        queries: dict,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        """Generate masks for a single crop from structured query prompts.

        Each query is a dict with "points" (or None), "labels", and "boxes";
        points and boxes are translated into crop coordinates before being
        fed to the predictor. No NMS is applied.
        """
        # Crop the image and set it as the predictor's current image.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.debug_image = cropped_im
        self.predictor.set_image(cropped_im)

        data = MaskData()
        for query in queries:
            query_pts = crop_points(query["points"], crop_box) if query["points"] is not None else None
            query_labels = query["labels"]
            query_boxes = crop_boxes(query["boxes"], crop_box)

            mask_data = self._process_batch(query_pts, cropped_im_size, crop_box, orig_size, query_labels, boxes=query_boxes)
            data.cat(mask_data)
            del mask_data
        self.predictor.reset_image()

        # Nothing survived filtering: return the empty container as-is.
        if data._stats == {}:
            return data

        # Return to the original image frame.
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
        points_labels: np.ndarray = None,
        mask_input: np.ndarray = None,
        boxes: np.ndarray = None,
    ) -> MaskData:
        """Run the predictor on one batch of prompts and filter the results.

        Arguments:
          points: point prompts in crop coordinates ((B, 2) or batched), or None.
          im_size: (H, W) of the cropped image currently set on the predictor.
          crop_box: XYXY crop in original-image coordinates.
          orig_size: (H, W) of the original image.
          points_labels: optional per-point labels; defaults to all-positive.
          mask_input: optional mask prompt (currently converted but not
            forwarded to the predictor — see NOTE below).
          boxes: optional XYXY box prompts (only forwarded in debug mode).

        Returns:
          MaskData holding RLE masks plus per-mask iou/category predictions,
          feature vectors, prompt points and stability scores, in crop
          coordinates (boxes) until the caller uncrops them.
        """
        orig_h, orig_w = orig_size

        # Transform point prompts into the predictor's input frame.
        if points is not None:
            transformed_points = self.predictor.transform.apply_coords(points, im_size)
            in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
            if len(in_points.shape) == 2:
                # (B, 2) -> (B, 1, 2): one prompt point per mask query.
                in_points = in_points[:, None, :]
        else:
            in_points = None
        if points_labels is None and points is not None:
            # Default: every point is a positive (foreground) prompt.
            in_labels = torch.ones((in_points.shape[0], 1), dtype=torch.int, device=in_points.device)
        elif points_labels is not None:
            # Fixed: use the predictor's device directly; the previous code
            # read in_points.device, which crashed when labels were supplied
            # without points (e.g. box-only queries).
            in_labels = torch.as_tensor(points_labels, dtype=torch.int, device=self.predictor.device)
            if len(in_labels.shape) == 1:
                in_labels = in_labels[:, None]
        else:
            in_labels = None
        if mask_input is not None:
            # NOTE(review): mask_input_torch is never passed to predict_torch
            # below — the mask-prompt path looks unfinished. Confirm intent.
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.predictor.device)
        else:
            mask_input_torch = None
        if boxes is not None:
            in_boxes = torch.as_tensor(boxes, dtype=torch.float, device=self.predictor.device)
            in_boxes = self.predictor.transform.apply_boxes_torch(in_boxes, im_size)
        else:
            in_boxes = None
        if self.debug_image is not None:
            # Debug path: also forward box prompts so they appear in the
            # saved figures.
            masks, iou_preds, cate_preds, fc_features, logits = self.predictor.predict_torch(
                in_points,
                in_labels,
                boxes=in_boxes,
                multimask_output=False,
                return_logits=True,
            )
        else:
            # NOTE(review): box prompts are silently dropped on this path
            # even when `boxes` was provided — confirm this is intentional.
            masks, iou_preds, cate_preds, fc_features, logits = self.predictor.predict_torch(
                in_points,
                in_labels,
                multimask_output=False,
                return_logits=True,
            )
        # Collapse per-mask category logits to a single predicted class id.
        cate_preds = torch.argmax(cate_preds, dim=2)

        # Serialize predictions into a flat MaskData, one row per mask.
        # (A dead reassignment of points_labels was removed here; it was
        # never read after this point.)
        # NOTE(review): this assumes `points` is not None — box-only queries
        # would fail on points.repeat below.
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            cate_preds=cate_preds.flatten(0, 1),
            fc_features=fc_features.flatten(0, 1),
            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
            points_labels=in_labels,
        )
        del masks
        del fc_features

        # Discard predicted-background masks (skipped in debug mode, where
        # everything is kept for inspection).
        if self.debug_image is None:
            keep_mask = data["cate_preds"] > 0
            data.filter(keep_mask)

        # Filter by the model's own predicted mask quality.
        if self.pred_iou_thresh > 0.0 and self.debug_image is None:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score and filter on it.
        data["stability_score"] = calculate_stability_score(
            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0 and self.debug_image is None:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks to binary and compute their bounding boxes.
        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch the crop boundary (but not the image edge).
        if self.debug_image is None:
            keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
            if not torch.all(keep_mask):
                data.filter(keep_mask)

        # Debug mode: dump one annotated figure per predicted mask.
        if self.debug_image is not None:
            masks_ = data["masks"].detach().cpu().numpy()
            points_ = data["points"].detach().cpu().numpy()
            points_labels_ = data["points_labels"].detach().cpu().numpy()
            iou_preds_ = data["iou_preds"].detach().cpu().numpy()
            cate_preds_ = data["cate_preds"].detach().cpu().numpy()
            stab_ = data["stability_score"].detach().cpu().numpy()
            for mask, pts, labs, iou, cate, stab in zip(masks_, points_, points_labels_, iou_preds_, cate_preds_, stab_):
                save_pred_fig(self.debug_image, mask, pts, labs, iou, cate, stab, boxe=boxes)

        # Pad masks back to the original image frame and compress to RLE;
        # the dense masks are dropped to save memory.
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Clean each mask: fill small holes, then drop small islands.
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score 0 to changed masks and 1 to unchanged masks, so NMS
            # prefers the original, untouched masks over edited duplicates.
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates.
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # single NMS category
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that were actually changed.
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]
        mask_data.filter(keep_by_nms)
        return mask_data
|
|
|
|
def crop_points(points: np.ndarray, crop_box: List[int]) -> np.ndarray:
    """Translate point coordinates from image space into crop space.

    Subtracts the crop's top-left corner (x0, y0) from every (x, y) point.
    Works for both (N, 2) arrays and batched (B, N, 2) arrays.
    """
    x0, y0, _, _ = crop_box
    shift = np.array([[x0, y0]])
    # For batched points, add a middle axis so the shift broadcasts per batch.
    if points.ndim == 3:
        shift = shift[:, None, :]
    return points - shift
|
|
def crop_boxes(boxes: np.ndarray, crop_box: List[int]) -> np.ndarray:
    """Translate XYXY boxes from image space into crop space.

    Subtracts the crop's top-left corner from both corners of every box.
    Only the first two crop_box entries are used; the unpack enforces the
    usual 4-element XYXY shape.
    """
    x0, y0, _, _ = crop_box
    return boxes - np.array([x0, y0, x0, y0])
|
|
|
|
| import matplotlib.pyplot as plt |
|
|
def show_mask(mask, ax, color=None):
    """Overlay a single mask on a matplotlib axis as a translucent RGBA image.

    Arguments:
        mask: array whose last two dims are (H, W); nonzero pixels are painted.
        ax: matplotlib axis (or any object with an imshow method) to draw on.
        color: optional RGBA array; defaults to semi-transparent dodger blue.
    """
    rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) if color is None else color
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
|
|
|
|
def show_points(coords, labels, ax, marker_size=150):
    """Scatter prompt points on a matplotlib axis.

    Positive points (label 1) are drawn as green stars, negative points
    (label 0) as red stars, both with a white edge.
    """
    common = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, point_color in ((1, 'green'), (0, 'red')):
        subset = coords[labels == label_value]
        ax.scatter(subset[:, 0], subset[:, 1], color=point_color, **common)
|
|
def show_box(box, ax):
    """Draw an XYXY box on a matplotlib axis as an unfilled green rectangle."""
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    rect = plt.Rectangle((left, top), width, height, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
|
|
def save_pred_fig(image, mask, input_points, input_labels, iou, cate, stab, boxe=None):
    """Render one prediction over *image* and save it as a numbered PNG.

    Draws the mask (blue for category 1, orange otherwise), the prompt
    points, an optional box, and the IoU/stability scores next to the first
    prompt point. Figures are written to data/prompt_engineering_debug,
    named by the current number of files in that directory.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    # Category 1 gets the blue overlay; everything else is orange.
    if cate == 1:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        color = np.array([255 / 255, 144 / 255, 30 / 255, 0.6])
    axis = plt.gca()
    show_mask(mask, axis, color)
    show_points(input_points, input_labels, axis)
    if boxe is not None:
        show_box(boxe, axis)

    # Annotate the scores beside the first prompt point.
    x, y = input_points[0]
    plt.text(
        x,
        y,
        f"IoU: {iou:.2f}, Stab: {stab:.2f}",
        color="white",
        fontsize=8,
        ha="right",
        va="bottom",
        bbox=dict(facecolor="black", alpha=0.5, edgecolor="none"),
    )

    plt.axis('off')
    save_dir = "data/prompt_engineering_debug"
    os.makedirs(save_dir, exist_ok=True)
    num = len(os.listdir(save_dir))
    plt.savefig(os.path.join(save_dir, f"{num}.png"))
    plt.close()