|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area
|
|
|
|
|
from .utils import ( |
|
|
MaskCaptionData, |
|
|
area_from_rle, |
|
|
batch_iterator, |
|
|
batched_mask_to_box, |
|
|
box_xyxy_to_xywh, |
|
|
build_all_layer_point_grids, |
|
|
calculate_stability_score, |
|
|
coco_encode_rle, |
|
|
generate_crop_boxes, |
|
|
is_box_near_crop_edge, |
|
|
mask_to_rle_pytorch, |
|
|
remove_small_regions, |
|
|
rle_to_mask, |
|
|
uncrop_boxes_xyxy, |
|
|
uncrop_masks, |
|
|
uncrop_points, |
|
|
) |
|
|
|
|
|
|
|
|
class ScaAutomaticMaskCaptionGenerator:
    """Generate masks *and* captions for a whole image from point prompts.

    Adapted from segment-anything's automatic mask generator: a grid of point
    prompts is run through ``model.generate`` (inputs built by ``processor``),
    low-quality and duplicate masks are filtered out, and every surviving mask
    keeps the caption decoded for it.
    """

    # Logger for per-batch timing diagnostics (emitted at DEBUG level).
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        model: torch.nn.Module,
        processor: Callable,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Copy from segment-anything

        Using a SAM-style captioning model, generates masks and captions for
        the entire image. Generates a grid of point prompts over the image,
        then filters low quality and duplicate masks. The default settings are
        chosen for SAM with a ViT-H backbone.

        Arguments:
          model (torch.nn.Module): The model used for mask and caption
            prediction. It must expose ``device``, ``dtype`` and a
            ``generate(...)`` method.
          processor (Callable): Builds model inputs from an image or from
            point prompts. It must also provide ``tokenizer`` (for
            ``eos_token_id`` and ``batch_decode``) and ``post_process_masks``.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to
            binarize the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this
            overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by
            crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in
            the list is used in the nth crop layer. Exclusive with
            points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area
            smaller than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be
            'binary_mask', 'uncompressed_rle', or 'coco_rle'. 'coco_rle'
            requires pycocotools. For large resolutions, 'binary_mask' may
            consume large amounts of memory.
        """
        # Exactly one way of specifying the point grid is allowed.
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grids must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grids be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        # These imports are unused here. They make a missing optional
        # dependency (see the docstring: pycocotools / opencv) fail at
        # construction time.
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # noqa: F401

        self.model = model
        self.processor = processor
        self.device = self.model.device
        self.dtype = self.model.dtype

        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks and captions for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
           list(dict(str, any)): A list over records for masks. Each record is
             a dict containing the following keys:
               segmentation (dict(str, any) or np.ndarray): The mask. If
                 output_mode='binary_mask', is an array of shape HW. Otherwise,
                 is a dictionary containing the RLE.
               bbox (list(float)): The box around the mask, in XYWH format.
               area (int): The area in pixels of the mask.
               predicted_iou (float): The model's own prediction of the mask's
                 quality. This is filtered by the pred_iou_thresh parameter.
               point_coords (list(list(float))): The point coordinates input
                 to the model to generate this mask.
               stability_score (float): A measure of the mask's quality. This
                 is filtered on using the stability_score_thresh parameter.
               crop_box (list(float)): The crop of the image used to generate
                 the mask, given in XYWH format.
               caption: The caption text decoded by ``processor.tokenizer``
                 and associated with this mask.
        """
        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
                "caption": mask_data["captions"][idx],
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskCaptionData:
        """Run every crop of ``image`` and merge the per-crop results.

        When more than one crop is used, duplicates across crops are removed
        with NMS. The returned data has been converted with ``to_numpy()``.
        """
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)

        # Iterate over image crops
        data = MaskCaptionData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Score = 1 / crop area, so NMS prefers masks from smaller crops.
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskCaptionData:
        """Generate, deduplicate and uncrop masks for a single image crop.

        Arguments:
          image (np.ndarray): The full HWC image.
          crop_box (list(int)): The crop as (x0, y0, x1, y1) in image pixels.
          crop_layer_idx (int): Index into ``self.point_grids`` for this crop.
          orig_size (tuple): (height, width) of the full image.
        """
        # Crop the image and compute its processor encoding once per crop.
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]

        input_image_encoding = self.processor(cropped_im, return_tensors="pt")

        # Scale the normalized point grid to crop pixels; [None, ::-1] turns
        # (h, w) into a (1, 2) row of [w, h] so points are in (x, y) order.
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskCaptionData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(input_image_encoding, points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        input_image_encoding: dict,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskCaptionData:
        """Run the model on one batch of point prompts within a crop.

        Arguments:
          input_image_encoding (dict): Processor output for the cropped image.
            Its "original_sizes" entry is reused when encoding the prompts.
          points (np.ndarray): Point prompts in crop pixel coordinates, one
            (x, y) pair per row.
          im_size (tuple): Size of the cropped image. Not used by this
            implementation.
          crop_box (list(int)): The crop as (x0, y0, x1, y1) in image pixels.
          orig_size (tuple): (height, width) of the full image.

        Returns:
          MaskCaptionData holding "rles", "boxes", "iou_preds", "points",
          "captions" and "stability_score" for the masks that pass filtering.
        """
        orig_h, orig_w = orig_size

        # Encode the prompts: add leading and inner axes, (N, 2) -> (1, N, 1, 2).
        sca_input_format_points = points[None, :, None, :]
        prompt_encoding = self.processor(
            input_points=torch.tensor(sca_input_format_points),
            input_labels=None,
            input_boxes=None,
            original_sizes=input_image_encoding["original_sizes"],
            return_tensors="pt",
        )
        inputs = {
            **input_image_encoding,
            **prompt_encoding,
        }
        # Move tensors to the model's device. Only float32 tensors are cast to
        # the model dtype; other dtypes (e.g. integer sizes) are preserved.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.device, self.dtype if v.dtype == torch.float32 else v.dtype)

        multimask_output = True
        tic = time.perf_counter()
        with torch.inference_mode():
            model_outputs = self.model.generate(
                **inputs,
                multimask_output=multimask_output,
                pad_token_id=self.processor.tokenizer.eos_token_id,
                num_beams=3,
            )
        toc = time.perf_counter()
        # Log at DEBUG instead of printing, so library users are not spammed
        # with one line per point batch.
        self._logger.debug("Time taken: %.4f ms", (toc - tic) * 1000)

        # sequences unpacks as (batch_size, num_masks, num_text_heads, num_tokens);
        # only the token length is needed to flatten it for decoding.
        _, _, _, num_tokens = model_outputs.sequences.shape

        masks = self.processor.post_process_masks(
            model_outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"],
            binarize=False,
        )
        iou_scores = model_outputs.iou_scores
        captions = self.processor.tokenizer.batch_decode(
            model_outputs.sequences.reshape(-1, num_tokens), skip_special_tokens=True
        )

        # Keep the results of the first (and only) image passed to the processor.
        masks = masks[0]
        iou_scores = iou_scores[0]

        # Flatten (prompt, mask head) into one mask axis, in prompt-major order.
        num_mask_heads = masks.shape[1]
        # NOTE(review): each decoded caption is repeated once per mask head.
        # This assumes one caption per point prompt (num_text_heads == 1);
        # confirm against the model's output layout.
        data = MaskCaptionData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_scores.flatten(0, 1),
            points=torch.as_tensor(points.repeat(num_mask_heads, axis=0)),
            captions=[caption for caption in captions for _ in range(num_mask_heads)],
        )
        del masks

        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Logit threshold used for binarization (plays the role of SAM's
        # predictor.model.mask_threshold).
        mask_threshold = 0.0
        data["stability_score"] = calculate_stability_score(
            data["masks"], mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(mask_data: MaskCaptionData, min_area: int, nms_thresh: float) -> MaskCaptionData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Score 1.0 for untouched masks and 0.0 for altered ones, so NMS
            # prefers masks that did not need postprocessing.
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs (and boxes) for kept masks that were altered.
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
|
|
|