jhj0517 committed · Commit 62faa17 · Parent(s): a503e15

Update sam2
segment-anything-2/sam2/sam2_video_predictor.py
CHANGED
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from collections import OrderedDict
 
 import torch
@@ -44,11 +45,13 @@ class SAM2VideoPredictor(SAM2Base):
         async_loading_frames=False,
     ):
         """Initialize a inference state."""
+        compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
             async_loading_frames=async_loading_frames,
+            compute_device=compute_device,
         )
         inference_state = {}
         inference_state["images"] = images
@@ -64,11 +67,11 @@ class SAM2VideoPredictor(SAM2Base):
         # the original video height and width, used for resizing final output scores
         inference_state["video_height"] = video_height
         inference_state["video_width"] = video_width
-        inference_state["device"] = torch.device("cuda")
+        inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
         else:
-            inference_state["storage_device"] = torch.device("cuda")
+            inference_state["storage_device"] = compute_device
         # inputs on each frame
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
@@ -103,6 +106,23 @@ class SAM2VideoPredictor(SAM2Base):
         self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
         return inference_state
 
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+            model_id (str): The Hugging Face repository ID.
+            **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+            (SAM2VideoPredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_video_predictor_hf
+
+        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+        return sam_model
+
     def _obj_id_to_idx(self, inference_state, obj_id):
         """Map client-side object id to model-side object index."""
         obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
@@ -146,29 +166,66 @@ class SAM2VideoPredictor(SAM2Base):
         return len(inference_state["obj_idx_to_id"])
 
     @torch.inference_mode()
-    def add_new_points(
+    def add_new_points_or_box(
         self,
         inference_state,
         frame_idx,
         obj_id,
-        points,
-        labels,
+        points=None,
+        labels=None,
         clear_old_points=True,
         normalize_coords=True,
+        box=None,
     ):
         """Add new points to a frame."""
         obj_idx = self._obj_id_to_idx(inference_state, obj_id)
         point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
         mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
 
-        if not isinstance(points, torch.Tensor):
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
             points = torch.tensor(points, dtype=torch.float32)
-        if not isinstance(labels, torch.Tensor):
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
             labels = torch.tensor(labels, dtype=torch.int32)
         if points.dim() == 2:
             points = points.unsqueeze(0)  # add batch dimension
         if labels.dim() == 1:
             labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
         if normalize_coords:
             video_H = inference_state["video_height"]
             video_W = inference_state["video_width"]
@@ -215,7 +272,8 @@ class SAM2VideoPredictor(SAM2Base):
             prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
 
         if prev_out is not None and prev_out["pred_masks"] is not None:
-            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            device = inference_state["device"]
+            prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
             # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
             prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
             current_out, _ = self._run_single_frame_inference(
@@ -251,6 +309,10 @@ class SAM2VideoPredictor(SAM2Base):
         )
         return frame_idx, obj_ids, video_res_masks
 
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
     @torch.inference_mode()
     def add_new_mask(
         self,
@@ -531,7 +593,7 @@ class SAM2VideoPredictor(SAM2Base):
         storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
         # Find all the frames that contain temporary outputs for any objects
         # (these should be the frames that have just received clicks for mask inputs
-        # via `add_new_points` or `add_new_mask`)
+        # via `add_new_points_or_box` or `add_new_mask`)
         temp_frame_inds = set()
         for obj_temp_output_dict in temp_output_dict_per_obj.values():
             temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
@@ -734,7 +796,8 @@ class SAM2VideoPredictor(SAM2Base):
         )
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
-            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            device = inference_state["device"]
+            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
@@ -895,4 +958,4 @@ class SAM2VideoPredictor(SAM2Base):
         for t in range(frame_idx_begin, frame_idx_end + 1):
             non_cond_frame_outputs.pop(t, None)
         for obj_output_dict in inference_state["output_dict_per_obj"].values():
-            obj_output_dict["non_cond_frame_outputs"].pop(t, None)
\ No newline at end of file
+            obj_output_dict["non_cond_frame_outputs"].pop(t, None)
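`add_new_points_or_box` generalizes the old `add_new_points`: a `box` prompt is encoded as two leading points with labels 2 and 3 (the top-left and bottom-right corners), consistent with how SAM 2 is trained, and a box is only accepted with `clear_old_points=True`, i.e. before any point prompt. A sketch continuing from the `state` above, with made-up coordinates and object ID:

```python
# Prompt object 1 on frame 0 with a box in (x0, y0, x1, y1) video coordinates.
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    box=[100.0, 150.0, 300.0, 400.0],
)

# Refine with a positive click (label 1); clear_old_points=False keeps the
# box prompt, which must always precede any point prompts.
frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=[[200.0, 275.0]],
    labels=[1],
    clear_old_points=False,
)
```

The deprecated `add_new_points` shim simply forwards `*args, **kwargs` to the new method, so existing callers keep working unchanged.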
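Since `inference_state["device"]` and `storage_device` now come from the model, and `_get_image_feature` moves frames with `.to(device)` rather than `.cuda()`, the whole video pipeline can follow the predictor to a non-CUDA device. A sketch of device selection, assuming `self.device` tracks the module's parameters (as in `SAM2Base`) so that `.to(...)` relocates the model:

```python
import torch

from sam2.sam2_video_predictor import SAM2VideoPredictor

# Pick whichever accelerator is available; CUDA is no longer assumed.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large").to(device)

# Frames, backbone features, and cached mask logits all land on `device`.
state = predictor.init_state(video_path="./video_frames")
```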
segment-anything-2/sam2/utils/misc.py
CHANGED
@@ -106,7 +106,15 @@ class AsyncVideoFrameLoader:
     A list of video frames to be load asynchronously without blocking session start.
     """
 
-    def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
         self.img_paths = img_paths
         self.image_size = image_size
         self.offload_video_to_cpu = offload_video_to_cpu
@@ -119,6 +127,7 @@ class AsyncVideoFrameLoader:
         # video_height and video_width be filled when loading the first image
         self.video_height = None
         self.video_width = None
+        self.compute_device = compute_device
 
         # load the first frame to fill video_height and video_width and also
         # to cache it (since it's most likely where the user will click)
@@ -152,7 +161,7 @@ class AsyncVideoFrameLoader:
         img -= self.img_mean
         img /= self.img_std
         if not self.offload_video_to_cpu:
-            img = img.cuda(non_blocking=True)
+            img = img.to(self.compute_device, non_blocking=True)
         self.images[index] = img
         return img
 
@@ -167,6 +176,7 @@ def load_video_frames(
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
     async_loading_frames=False,
+    compute_device=torch.device("cuda"),
 ):
     """
     Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
@@ -179,7 +189,15 @@ def load_video_frames(
     if isinstance(video_path, str) and os.path.isdir(video_path):
         jpg_folder = video_path
     else:
-        raise NotImplementedError("Only JPEG frames are supported at this moment")
+        raise NotImplementedError(
+            "Only JPEG frames are supported at this moment. For video files, you may use "
+            "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
+            "```\n"
+            "ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
+            "```\n"
+            "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
+            "ffmpeg to start the JPEG file from 00000.jpg."
+        )
 
     frame_names = [
         p
@@ -196,7 +214,12 @@ def load_video_frames(
 
     if async_loading_frames:
         lazy_images = AsyncVideoFrameLoader(
-            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+            img_paths,
+            image_size,
+            offload_video_to_cpu,
+            img_mean,
+            img_std,
+            compute_device,
         )
         return lazy_images, lazy_images.video_height, lazy_images.video_width
 
@@ -204,9 +227,9 @@ def load_video_frames(
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
         images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
     if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        images = images.to(compute_device)
+        img_mean = img_mean.to(compute_device)
+        img_std = img_std.to(compute_device)
     # normalize by mean and std
     images -= img_mean
     images /= img_std
@@ -220,10 +243,25 @@ def fill_holes_in_mask_scores(mask, max_area):
     # Holes are those connected components in background with area <= self.max_area
     # (background regions are those with mask scores <= 0)
    assert max_area > 0, "max_area must be positive"
-    labels, areas = get_connected_components(mask <= 0)
-    is_hole = (labels > 0) & (areas <= max_area)
-    # We fill holes with a small positive mask score (0.1) to change them to foreground.
-    mask = torch.where(is_hole, 0.1, mask)
+
+    input_mask = mask
+    try:
+        labels, areas = get_connected_components(mask <= 0)
+        is_hole = (labels > 0) & (areas <= max_area)
+        # We fill holes with a small positive mask score (0.1) to change them to foreground.
+        mask = torch.where(is_hole, 0.1, mask)
+    except Exception as e:
+        # Skip the post-processing step on removing small holes if the CUDA kernel fails
+        warnings.warn(
+            f"{e}\n\nSkipping the post-processing step due to the error above. You can "
+            "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
+            "functionality may be limited (which doesn't affect the results in most cases; see "
+            "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
+            category=UserWarning,
+            stacklevel=2,
+        )
+        mask = input_mask
+
     return mask
 
 
@@ -235,4 +273,4 @@ def concat_points(old_point_inputs, new_points, new_labels):
     points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
     labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
 
-    return {"point_coords": points, "point_labels": labels}
\ No newline at end of file
+    return {"point_coords": points, "point_labels": labels}
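`fill_holes_in_mask_scores` now wraps the connected-components step in a try/except: when the CUDA kernel from the optional extension is unavailable (e.g. a CPU-only install), it emits a `UserWarning` and returns the input mask unchanged instead of crashing. A toy check with an illustrative 8x8 logits tensor:

```python
import torch

from sam2.utils.misc import fill_holes_in_mask_scores

# Toy mask logits: positive = foreground, non-positive = background.
mask = torch.full((1, 1, 8, 8), 0.5)
mask[0, 0, 3:5, 3:5] = -1.0  # a small 2x2 background "hole"

filled = fill_holes_in_mask_scores(mask, max_area=8)
# With the CUDA extension available, the hole is raised to +0.1 (foreground);
# without it, a UserWarning is emitted and `filled` equals the input.
```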