John Ho
committed on
Commit
·
aaa1b00
1
Parent(s):
579e65b
fixed bug when returning masks for multiple objects
Browse files- app.py +13 -3
- samv2_handler.py +94 -50
- visualizer.py +100 -0
app.py
CHANGED
|
@@ -127,6 +127,7 @@ def process_video(
|
|
| 127 |
masks: Union[list, str],
|
| 128 |
drop_masks: bool = False,
|
| 129 |
ref_frame_idx: int = 0,
|
|
|
|
| 130 |
):
|
| 131 |
"""
|
| 132 |
SAM2 Video Segmentation
|
|
@@ -153,7 +154,7 @@ def process_video(
|
|
| 153 |
device="cuda",
|
| 154 |
do_tidy_up=True,
|
| 155 |
drop_mask=drop_masks,
|
| 156 |
-
async_frame_load=
|
| 157 |
ref_frame_idx=ref_frame_idx,
|
| 158 |
)
|
| 159 |
|
|
@@ -202,12 +203,21 @@ with gr.Blocks() as demo:
|
|
| 202 |
JSON list of base64 encoded masks, e.g.: ["b'iVBORw0KGgoAAAANSUhEUgAABDgAAAeAAQAAAAADGtqnAAAXz...'",...]
|
| 203 |
""",
|
| 204 |
),
|
| 205 |
-
gr.Checkbox(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
gr.Number(
|
| 207 |
-
label="
|
|
|
|
| 208 |
value=0,
|
| 209 |
precision=0,
|
| 210 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
],
|
| 212 |
outputs=gr.JSON(label="Output JSON"),
|
| 213 |
title="SAM2 for Videos",
|
|
|
|
| 127 |
masks: Union[list, str],
|
| 128 |
drop_masks: bool = False,
|
| 129 |
ref_frame_idx: int = 0,
|
| 130 |
+
async_frame_load: bool = True,
|
| 131 |
):
|
| 132 |
"""
|
| 133 |
SAM2 Video Segmentation
|
|
|
|
| 154 |
device="cuda",
|
| 155 |
do_tidy_up=True,
|
| 156 |
drop_mask=drop_masks,
|
| 157 |
+
async_frame_load=async_frame_load,
|
| 158 |
ref_frame_idx=ref_frame_idx,
|
| 159 |
)
|
| 160 |
|
|
|
|
| 203 |
JSON list of base64 encoded masks, e.g.: ["b'iVBORw0KGgoAAAANSUhEUgAABDgAAAeAAQAAAAADGtqnAAAXz...'",...]
|
| 204 |
""",
|
| 205 |
),
|
| 206 |
+
gr.Checkbox(
|
| 207 |
+
label="Drop Masks",
|
| 208 |
+
info="remove base64 encoded masks from result JSON",
|
| 209 |
+
value=True,
|
| 210 |
+
),
|
| 211 |
gr.Number(
|
| 212 |
+
label="Reference Frame Index",
|
| 213 |
+
info="frame index for the provided object masks",
|
| 214 |
value=0,
|
| 215 |
precision=0,
|
| 216 |
),
|
| 217 |
+
gr.Checkbox(
|
| 218 |
+
label="async frame load",
|
| 219 |
+
info="start inference in parallel to frame loading",
|
| 220 |
+
),
|
| 221 |
],
|
| 222 |
outputs=gr.JSON(label="Output JSON"),
|
| 223 |
title="SAM2 for Videos",
|
samv2_handler.py
CHANGED
|
@@ -9,8 +9,10 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
| 9 |
from sam2.utils.misc import variant_to_config_mapping
|
| 10 |
from sam2.utils.visualization import show_masks
|
| 11 |
from ffmpeg_extractor import extract_frames, logger
|
| 12 |
-
from
|
|
|
|
| 13 |
from toolbox.mask_encoding import b64_mask_encode
|
|
|
|
| 14 |
|
| 15 |
variant_checkpoints_mapping = {
|
| 16 |
"tiny": "checkpoints/sam2_hiera_tiny.pt",
|
|
@@ -32,23 +34,6 @@ class point_xy(BaseModel):
|
|
| 32 |
y: Union[int, float]
|
| 33 |
|
| 34 |
|
| 35 |
-
def mask_to_xyxy(mask: np.ndarray) -> tuple:
|
| 36 |
-
"""Convert a binary mask of shape (h, w) to
|
| 37 |
-
xyxy bounding box format (top-left and bottom-right coordinates).
|
| 38 |
-
"""
|
| 39 |
-
ys, xs = np.where(mask)
|
| 40 |
-
if len(xs) == 0 or len(ys) == 0:
|
| 41 |
-
logger.warning("mask_to_xyxy: No object found in the mask")
|
| 42 |
-
return None
|
| 43 |
-
x_min = np.min(xs)
|
| 44 |
-
y_min = np.min(ys)
|
| 45 |
-
x_max = np.max(xs)
|
| 46 |
-
y_max = np.max(ys)
|
| 47 |
-
xyxy = (x_min, y_min, x_max, y_max)
|
| 48 |
-
xyxy = tuple([int(i) for i in xyxy])
|
| 49 |
-
return xyxy
|
| 50 |
-
|
| 51 |
-
|
| 52 |
def load_sam_image_model(
|
| 53 |
# variant: Literal[*variant_checkpoints_mapping.keys()],
|
| 54 |
variant: Literal["tiny", "small", "base_plus", "large"],
|
|
@@ -96,7 +81,8 @@ def run_sam_im_inference(
|
|
| 96 |
point_labels
|
| 97 |
), f"{len(points)} points provided but {len(point_labels)} labels given."
|
| 98 |
|
| 99 |
-
#
|
|
|
|
| 100 |
has_multi = False
|
| 101 |
if points and bboxes:
|
| 102 |
has_multi = True
|
|
@@ -129,7 +115,7 @@ def run_sam_im_inference(
|
|
| 129 |
box=box_coords,
|
| 130 |
point_coords=point_coords,
|
| 131 |
point_labels=point_labels,
|
| 132 |
-
multimask_output=has_multi,
|
| 133 |
)
|
| 134 |
# mask here is of shape (X, h, w) of np array, X = number of masks
|
| 135 |
|
|
@@ -138,11 +124,16 @@ def run_sam_im_inference(
|
|
| 138 |
else:
|
| 139 |
output_masks = []
|
| 140 |
for i, mask in enumerate(masks):
|
| 141 |
-
if mask.ndim > 2: # shape (
|
| 142 |
-
|
| 143 |
-
mask
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
else:
|
|
|
|
| 146 |
output_masks.append(mask.squeeze().astype(np.uint8))
|
| 147 |
return (
|
| 148 |
[b64_mask_encode(m).decode("ascii") for m in output_masks]
|
|
@@ -151,6 +142,48 @@ def run_sam_im_inference(
|
|
| 151 |
)
|
| 152 |
|
| 153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
def run_sam_video_inference(
|
| 155 |
model: Any,
|
| 156 |
video_path: str,
|
|
@@ -166,7 +199,6 @@ def run_sam_video_inference(
|
|
| 166 |
# put video frames into directory
|
| 167 |
# TODO:
|
| 168 |
# change frame size
|
| 169 |
-
# async frame load
|
| 170 |
l_frames_fp = extract_frames(
|
| 171 |
video_path,
|
| 172 |
fps=sample_fps,
|
|
@@ -176,43 +208,55 @@ def run_sam_video_inference(
|
|
| 176 |
)
|
| 177 |
vframes_dir = os.path.dirname(l_frames_fp[0])
|
| 178 |
vinfo = VidInfo(video_path)
|
|
|
|
| 179 |
w = vinfo["frame_width"]
|
| 180 |
h = vinfo["frame_height"]
|
| 181 |
|
| 182 |
inference_state = model.init_state(
|
| 183 |
video_path=vframes_dir, device=device, async_loading_frames=async_frame_load
|
| 184 |
)
|
| 185 |
-
for
|
| 186 |
-
model.add_new_mask(
|
| 187 |
inference_state=inference_state,
|
| 188 |
frame_idx=ref_frame_idx,
|
| 189 |
-
obj_id=
|
| 190 |
mask=mask,
|
| 191 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
masks_generator = model.propagate_in_video(inference_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
det = { # miro's detections format for videos
|
| 205 |
-
"frame": i,
|
| 206 |
-
"track_id": id,
|
| 207 |
-
"x": x0 / w,
|
| 208 |
-
"y": y0 / h,
|
| 209 |
-
"w": (x1 - x0) / w,
|
| 210 |
-
"h": (y1 - y0) / h,
|
| 211 |
-
"conf": 1,
|
| 212 |
-
}
|
| 213 |
-
if not drop_mask:
|
| 214 |
-
det["mask_b64"] = b64_mask_encode(mask).decode("ascii")
|
| 215 |
-
detections.append(det)
|
| 216 |
|
| 217 |
if do_tidy_up:
|
| 218 |
# remove vframes_dir
|
|
|
|
| 9 |
from sam2.utils.misc import variant_to_config_mapping
|
| 10 |
from sam2.utils.visualization import show_masks
|
| 11 |
from ffmpeg_extractor import extract_frames, logger
|
| 12 |
+
from visualizer import annotate_masks, mask_to_xyxy
|
| 13 |
+
from toolbox.vid_utils import VidInfo, VidReader
|
| 14 |
from toolbox.mask_encoding import b64_mask_encode
|
| 15 |
+
from toolbox.img_utils import get_pil_im
|
| 16 |
|
| 17 |
variant_checkpoints_mapping = {
|
| 18 |
"tiny": "checkpoints/sam2_hiera_tiny.pt",
|
|
|
|
| 34 |
y: Union[int, float]
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def load_sam_image_model(
|
| 38 |
# variant: Literal[*variant_checkpoints_mapping.keys()],
|
| 39 |
variant: Literal["tiny", "small", "base_plus", "large"],
|
|
|
|
| 81 |
point_labels
|
| 82 |
), f"{len(points)} points provided but {len(point_labels)} labels given."
|
| 83 |
|
| 84 |
+
# multimask_output actually will provide 3 masks for each segmentation (see https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)
|
| 85 |
+
# so should also be set to False
|
| 86 |
has_multi = False
|
| 87 |
if points and bboxes:
|
| 88 |
has_multi = True
|
|
|
|
| 115 |
box=box_coords,
|
| 116 |
point_coords=point_coords,
|
| 117 |
point_labels=point_labels,
|
| 118 |
+
multimask_output=False, # has_multi,
|
| 119 |
)
|
| 120 |
# mask here is of shape (X, h, w) of np array, X = number of masks
|
| 121 |
|
|
|
|
| 124 |
else:
|
| 125 |
output_masks = []
|
| 126 |
for i, mask in enumerate(masks):
|
| 127 |
+
if mask.ndim > 2: # shape (1, h, w)
|
| 128 |
+
# logger.debug(f"found mask of shape {mask.shape}")
|
| 129 |
+
output_masks.append(mask.squeeze().astype(np.uint8))
|
| 130 |
+
|
| 131 |
+
# when multimask_output = True the mask is shape (3,h,w)
|
| 132 |
+
# mask = np.transpose(mask, (1, 2, 0)) # shape (h,w,3)
|
| 133 |
+
# mask = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
|
| 134 |
+
# output_masks.append(np.array(mask))
|
| 135 |
else:
|
| 136 |
+
# logger.debug(f"found mask of shape {mask.shape}")
|
| 137 |
output_masks.append(mask.squeeze().astype(np.uint8))
|
| 138 |
return (
|
| 139 |
[b64_mask_encode(m).decode("ascii") for m in output_masks]
|
|
|
|
| 142 |
)
|
| 143 |
|
| 144 |
|
| 145 |
+
def unpack_masks(
    masks_generator,
    frame_wh: tuple,
    drop_mask: bool = False,
):
    """Collect detections in Miro's video format from a SAM2 mask generator.

    Args:
        masks_generator: iterable yielding ``(frame_idx, tracker_ids, mask_logits)``
            triples, e.g. the generator returned by ``model.propagate_in_video``.
            ``mask_logits`` must support ``> 0.0`` followed by ``.cpu().numpy()``.
        frame_wh: ``(width, height)`` used to normalise the box coordinates.
        drop_mask: when True, the ``mask_b64`` entry is left out of each detection.

    Returns:
        A list of dicts with keys ``frame``, ``track_id``, ``x``, ``y``, ``w``,
        ``h`` and ``conf`` (plus ``mask_b64`` unless ``drop_mask`` is set).
        An object whose mask is empty on a given frame yields no detection
        for that frame.
    """
    frame_w, frame_h = frame_wh
    results = []
    for frame_idx, tracker_ids, mask_logits in masks_generator:
        # positive logits are treated as foreground pixels
        frame_masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)

        for track_id, raw_mask in zip(tracker_ids, frame_masks):
            obj_mask = raw_mask.squeeze().astype(np.uint8)
            box = mask_to_xyxy(obj_mask)
            if not box:
                # empty mask: the object is not visible on this frame
                continue

            left, top, right, bottom = box
            record = {  # miro's detections format for videos
                "frame": frame_idx,
                "track_id": track_id,
                "x": left / frame_w,
                "y": top / frame_h,
                "w": (right - left) / frame_w,
                "h": (bottom - top) / frame_h,
                "conf": 1,
            }
            if not drop_mask:
                record["mask_b64"] = b64_mask_encode(obj_mask).decode("ascii")
            results.append(record)
    return results
|
| 185 |
+
|
| 186 |
+
|
| 187 |
def run_sam_video_inference(
|
| 188 |
model: Any,
|
| 189 |
video_path: str,
|
|
|
|
| 199 |
# put video frames into directory
|
| 200 |
# TODO:
|
| 201 |
# change frame size
|
|
|
|
| 202 |
l_frames_fp = extract_frames(
|
| 203 |
video_path,
|
| 204 |
fps=sample_fps,
|
|
|
|
| 208 |
)
|
| 209 |
vframes_dir = os.path.dirname(l_frames_fp[0])
|
| 210 |
vinfo = VidInfo(video_path)
|
| 211 |
+
vr = VidReader(video_path, use_imageio=True)
|
| 212 |
w = vinfo["frame_width"]
|
| 213 |
h = vinfo["frame_height"]
|
| 214 |
|
| 215 |
inference_state = model.init_state(
|
| 216 |
video_path=vframes_dir, device=device, async_loading_frames=async_frame_load
|
| 217 |
)
|
| 218 |
+
for mask_idx, mask in enumerate(masks):
|
| 219 |
+
_, object_ids, mask_logits = model.add_new_mask(
|
| 220 |
inference_state=inference_state,
|
| 221 |
frame_idx=ref_frame_idx,
|
| 222 |
+
obj_id=mask_idx,
|
| 223 |
mask=mask,
|
| 224 |
)
|
| 225 |
+
# debug
|
| 226 |
+
logger.debug(
|
| 227 |
+
f"adding mask {mask_idx} of shape {mask.shape} for frame {ref_frame_idx}, xyxy: {mask_to_xyxy(mask)}"
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# debug init state
|
| 231 |
+
logger.debug(f"model initiated with mask_logits of shape {mask_logits.shape}")
|
| 232 |
+
logger.debug(f"model initiated with object_ids of len {len(object_ids)}")
|
| 233 |
+
init_masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
|
| 234 |
+
init_masks = [m.squeeze() for m in init_masks]
|
| 235 |
+
ref_frame_im = get_pil_im(np.array(vr.get_data(ref_frame_idx)))
|
| 236 |
+
init_masks_im_fp = os.path.join(vframes_dir, f"model_init_masks.jpg")
|
| 237 |
+
input_masks_im_fp = os.path.join(vframes_dir, f"input_masks.jpg")
|
| 238 |
+
annotate_masks(ref_frame_im, init_masks).save(init_masks_im_fp)
|
| 239 |
+
annotate_masks(ref_frame_im, masks).save(input_masks_im_fp)
|
| 240 |
+
logger.debug(f"masks received by model visualized at {init_masks_im_fp}")
|
| 241 |
+
logger.debug(f"masks provided to model visualized at {input_masks_im_fp}")
|
| 242 |
+
|
| 243 |
masks_generator = model.propagate_in_video(inference_state)
|
| 244 |
+
detections = unpack_masks(
|
| 245 |
+
masks_generator,
|
| 246 |
+
drop_mask=drop_mask,
|
| 247 |
+
frame_wh=(w, h),
|
| 248 |
+
)
|
| 249 |
|
| 250 |
+
if ref_frame_idx != 0:
|
| 251 |
+
logger.debug(f"propagating in reverse now from {ref_frame_idx}")
|
| 252 |
+
# there's no need to reset state
|
| 253 |
+
# model.reset_state(inference_state)
|
| 254 |
+
masks_generator = model.propagate_in_video(inference_state, reverse=True)
|
| 255 |
+
detections += unpack_masks(
|
| 256 |
+
masks_generator,
|
| 257 |
+
drop_mask=drop_mask,
|
| 258 |
+
frame_wh=(w, h),
|
| 259 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
if do_tidy_up:
|
| 262 |
# remove vframes_dir
|
visualizer.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image, ImageColor
|
| 2 |
+
import matplotlib.colors as mcolors
|
| 3 |
+
import numpy as np
|
| 4 |
+
from toolbox.mask_encoding import b64_mask_decode
|
| 5 |
+
from toolbox.img_utils import im_draw_bbox, im_draw_point, im_color_mask
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def mask_to_xyxy(mask: np.ndarray, verbose: bool = False) -> tuple:
    """Convert a binary mask of shape (h, w) to an xyxy bounding box.

    Args:
        mask: 2-D array; every non-zero element counts as part of the object.
        verbose: when True, log a warning if the mask contains no object.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as plain Python ints, i.e. the
        top-left and bottom-right pixel coordinates (both inclusive), or
        ``None`` when the mask is empty.
    """
    ys, xs = np.where(mask)
    # np.where returns index arrays of equal length, so checking one is enough
    if xs.size == 0:
        if verbose:
            # this module defines no module-level ``logger``; resolve one
            # locally so the warning path cannot raise NameError
            import logging

            logging.getLogger(__name__).warning(
                "mask_to_xyxy: No object found in the mask"
            )
        return None
    # cast numpy integers to built-in ints so the result is JSON friendly
    return (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def annotate_detections(
    im: Image.Image,
    l_obj: list,
    color_key: str = "class",
    bbox_width: int = 1,
    label_key: str = "object_id",
    color_dict: dict = None,
):
    """Draw every object in ``l_obj`` onto ``im`` and return the result.

    Each object is a dict. If it has a ``"boundingBox"`` entry it is drawn
    with ``im_draw_bbox``; otherwise its ``"point"`` entry is drawn with
    ``im_draw_point``. The entry is unpacked as keyword arguments for the
    drawing helper.

    Args:
        im: image to annotate; passed straight to the drawing helpers.
        l_obj: list of object dicts as described above.
        color_key: dict key whose value decides an object's colour.
        bbox_width: line width forwarded to the drawing helpers.
        label_key: dict key used as the caption; falsy disables captions.
        color_dict: optional mapping from ``obj[color_key]`` to a colour. When
            omitted or empty, colours are taken from matplotlib's XKCD palette.

    Returns:
        Whatever the last drawing helper returned (``im`` itself when
        ``l_obj`` is empty).

    Raises:
        KeyError: if an object lacks ``color_key``, a truthy ``label_key``,
            or both ``"boundingBox"`` and ``"point"``.
    """
    # palette entries are (name, hex colour) pairs
    palette = list(mcolors.XKCD_COLORS.items())
    # First-seen order keeps the key -> colour assignment stable between runs
    # (iterating a set of strings is not), and the dict gives O(1) lookups.
    key_order = {
        key: idx
        for idx, key in enumerate(
            dict.fromkeys(o[color_key] for o in l_obj if color_key in o)
        )
    }

    for obj in l_obj:
        key = obj[color_key]
        if color_dict:
            bbox_color = color_dict[key]
        else:
            # wrap around instead of failing when keys outnumber the palette
            bbox_color = palette[key_order[key] % len(palette)][1]
        caption = str(obj[label_key]) if label_key else None

        if "boundingBox" in obj:
            im = im_draw_bbox(
                im,
                color=bbox_color,
                width=bbox_width,
                caption=caption,
                **obj["boundingBox"],
                use_bbv=True,
            )
        else:
            im = im_draw_point(
                im,
                **obj["point"],
                width=bbox_width,
                caption=caption,
                color=bbox_color,
            )
    return im
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def annotate_masks(
    im: Image.Image, masks: list, mask_alpha: float = 0.9, bbox_width: int = 3
) -> Image.Image:
    """Overlay coloured masks and their bounding boxes on ``im``.

    Args:
        im: image to annotate; it is converted with ``np.array`` before drawing.
        masks: list of masks, each either an array or a base64 string that
            ``b64_mask_decode`` can decode.
        mask_alpha: alpha forwarded to ``im_color_mask``.
        bbox_width: line width forwarded to ``annotate_detections``.

    Returns:
        The value returned by ``annotate_detections`` for the coloured image
        (NOTE(review): annotated as ``Image.Image``; confirm against
        ``im_draw_bbox``).

    Masks with no foreground pixel are skipped rather than raising, and the
    remaining masks keep their original list index as ``object_id``.
    """
    decoded = [
        b64_mask_decode(m).astype(np.uint8) if isinstance(m, str) else m for m in masks
    ]
    # build the palette once; entries are (name, hex colour) pairs
    palette = list(mcolors.XKCD_COLORS.items())

    segs = []
    color_by_id = {}
    ann_im = np.array(im)
    for i, m in enumerate(decoded):
        xyxy = mask_to_xyxy(m)
        if xyxy is None:
            # empty mask: nothing to colour and no box to draw
            continue
        x0, y0, x1, y1 = xyxy
        # wrap around instead of failing when masks outnumber the palette
        hex_color = palette[i % len(palette)][1]
        color_by_id[i] = hex_color
        segs.append(
            {
                "object_id": i,
                "boundingBox": {"x0": x0, "y0": y0, "x1": x1, "y1": y1},
            }
        )
        ann_im = im_color_mask(
            ann_im,
            mask_array=m,
            alpha=mask_alpha,
            rbg_tup=ImageColor.getrgb(hex_color),
        )

    # pass the colours explicitly so each box matches its mask overlay
    ann_im = annotate_detections(
        ann_im,
        l_obj=segs,
        color_key="object_id",
        label_key="object_id",
        bbox_width=bbox_width,
        color_dict=color_by_id,
    )
    return ann_im
|