John Ho committed on
Commit aaa1b00 · 1 Parent(s): 579e65b

fixed bug when returning masks for multiple objects

Files changed (3)
  1. app.py +13 -3
  2. samv2_handler.py +94 -50
  3. visualizer.py +100 -0
app.py CHANGED
@@ -127,6 +127,7 @@ def process_video(
     masks: Union[list, str],
     drop_masks: bool = False,
     ref_frame_idx: int = 0,
+    async_frame_load: bool = True,
 ):
     """
     SAM2 Video Segmentation
@@ -153,7 +154,7 @@ def process_video(
         device="cuda",
         do_tidy_up=True,
         drop_mask=drop_masks,
-        async_frame_load=True,
+        async_frame_load=async_frame_load,
         ref_frame_idx=ref_frame_idx,
     )
 
@@ -202,12 +203,21 @@ with gr.Blocks() as demo:
             JSON list of base64 encoded masks, e.g.: ["b'iVBORw0KGgoAAAANSUhEUgAABDgAAAeAAQAAAAADGtqnAAAXz...'",...]
             """,
         ),
-        gr.Checkbox(label="remove base64 encoded masks from result JSON"),
+        gr.Checkbox(
+            label="Drop Masks",
+            info="remove base64 encoded masks from result JSON",
+            value=True,
+        ),
         gr.Number(
-            label="frame index for the provided object masks",
+            label="Reference Frame Index",
+            info="frame index for the provided object masks",
             value=0,
             precision=0,
         ),
+        gr.Checkbox(
+            label="async frame load",
+            info="start inference in parallel to frame loading",
+        ),
     ],
     outputs=gr.JSON(label="Output JSON"),
     title="SAM2 for Videos",
samv2_handler.py CHANGED
@@ -9,8 +9,10 @@ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.utils.misc import variant_to_config_mapping
 from sam2.utils.visualization import show_masks
 from ffmpeg_extractor import extract_frames, logger
-from toolbox.vid_utils import VidInfo
+from visualizer import annotate_masks, mask_to_xyxy
+from toolbox.vid_utils import VidInfo, VidReader
 from toolbox.mask_encoding import b64_mask_encode
+from toolbox.img_utils import get_pil_im
 
 variant_checkpoints_mapping = {
     "tiny": "checkpoints/sam2_hiera_tiny.pt",
@@ -32,23 +34,6 @@ class point_xy(BaseModel):
     y: Union[int, float]
 
 
-def mask_to_xyxy(mask: np.ndarray) -> tuple:
-    """Convert a binary mask of shape (h, w) to
-    xyxy bounding box format (top-left and bottom-right coordinates).
-    """
-    ys, xs = np.where(mask)
-    if len(xs) == 0 or len(ys) == 0:
-        logger.warning("mask_to_xyxy: No object found in the mask")
-        return None
-    x_min = np.min(xs)
-    y_min = np.min(ys)
-    x_max = np.max(xs)
-    y_max = np.max(ys)
-    xyxy = (x_min, y_min, x_max, y_max)
-    xyxy = tuple([int(i) for i in xyxy])
-    return xyxy
-
-
 def load_sam_image_model(
     # variant: Literal[*variant_checkpoints_mapping.keys()],
     variant: Literal["tiny", "small", "base_plus", "large"],
@@ -96,7 +81,8 @@ def run_sam_im_inference(
         point_labels
     ), f"{len(points)} points provided but {len(point_labels)} labels given."
 
-    # determine multimask_output
+    # multimask_output actually will provide 3 masks for each segmentation (see https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb)
+    # so should also be set to False
     has_multi = False
     if points and bboxes:
         has_multi = True
@@ -129,7 +115,7 @@ def run_sam_im_inference(
         box=box_coords,
         point_coords=point_coords,
         point_labels=point_labels,
-        multimask_output=has_multi,
+        multimask_output=False,  # has_multi,
     )
     # mask here is of shape (X, h, w) of np array, X = number of masks
 
@@ -138,11 +124,16 @@ def run_sam_im_inference(
     else:
         output_masks = []
         for i, mask in enumerate(masks):
-            if mask.ndim > 2:  # shape (3, h, w)
-                mask = np.transpose(mask, (1, 2, 0))  # shape (h,w,3)
-                mask = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
-                output_masks.append(np.array(mask))
+            if mask.ndim > 2:  # shape (1, h, w)
+                # logger.debug(f"found mask of shape {mask.shape}")
+                output_masks.append(mask.squeeze().astype(np.uint8))
+
+                # when multimask_output = True the mask is shape (3,h,w)
+                # mask = np.transpose(mask, (1, 2, 0))  # shape (h,w,3)
+                # mask = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
+                # output_masks.append(np.array(mask))
             else:
+                # logger.debug(f"found mask of shape {mask.shape}")
                 output_masks.append(mask.squeeze().astype(np.uint8))
     return (
         [b64_mask_encode(m).decode("ascii") for m in output_masks]
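
The new comment above is the reason multimask_output is now pinned to False: with multimask_output=True the SAM2 image predictor returns three candidate masks per prompt together with their predicted IoU scores, so downstream code expecting one mask per object suddenly sees a (3, h, w) stack. A short sketch of the two behaviours, assuming the standard SAM2ImagePredictor.predict interface from the linked notebook (predictor, point_coords and point_labels are placeholders from the surrounding function):

import numpy as np

# multimask_output=True: three candidates per prompt, pick the best by score
masks, scores, _ = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,
)                                    # masks.shape == (3, h, w)
best_mask = masks[np.argmax(scores)]

# multimask_output=False (what this commit uses): a single mask per prompt
masks, scores, _ = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=False,
)                                    # masks.shape == (1, h, w)
single_mask = masks.squeeze()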
@@ -151,6 +142,48 @@ def run_sam_im_inference(
     )
 
 
+def unpack_masks(
+    masks_generator,
+    frame_wh: tuple,
+    drop_mask: bool = False,
+):
+    """return a list of detections in Miro's format given a SAM2 mask generator"""
+    w, h = frame_wh
+    detections = []
+    for frame_idx, tracker_ids, mask_logits in masks_generator:
+        masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
+
+        # draw a couple frames for debug purpose
+        # if frame_idx % 15 == 0:
+        #     ann_masks = [m.squeeze() for m in masks if mask_to_xyxy(m.squeeze())]
+        #     if len(ann_masks) > 0:
+        #         annotate_masks(
+        #             get_pil_im(np.array(vr.get_data(frame_idx))),
+        #             masks=ann_masks,
+        #         ).save(os.path.join(vframes_dir, f"{frame_idx}.png"))
+
+        for id, mask in zip(tracker_ids, masks):
+            mask = mask.squeeze().astype(np.uint8)
+            xyxy = mask_to_xyxy(mask)
+            if not xyxy:  # mask is empty
+                # logger.debug(f"track_id {id} is missing mask at frame {frame_idx}")
+                continue
+            x0, y0, x1, y1 = xyxy
+            det = {  # miro's detections format for videos
+                "frame": frame_idx,
+                "track_id": id,
+                "x": x0 / w,
+                "y": y0 / h,
+                "w": (x1 - x0) / w,
+                "h": (y1 - y0) / h,
+                "conf": 1,
+            }
+            if not drop_mask:
+                det["mask_b64"] = b64_mask_encode(mask).decode("ascii")
+            detections.append(det)
+    return detections
+
+
 def run_sam_video_inference(
     model: Any,
     video_path: str,
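
unpack_masks is the heart of the fix: each (frame_idx, tracker_ids, mask_logits) batch yielded by propagate_in_video is split per object, so every track gets its own normalized box and, unless drop_mask is set, its own base64-encoded mask. For a 1920x1080 video one entry of the returned list would look roughly like this (values are made up; decoding assumes b64_mask_decode from toolbox.mask_encoding is the inverse of b64_mask_encode, as its use in visualizer.py suggests):

from toolbox.mask_encoding import b64_mask_decode

det = {
    "frame": 42,                      # sampled-frame index
    "track_id": 0,                    # obj_id given to add_new_mask
    "x": 0.25, "y": 0.40,             # top-left corner, normalized by frame size
    "w": 0.10, "h": 0.15,             # box size, normalized
    "conf": 1,
    "mask_b64": "iVBORw0KGgo...",     # omitted when drop_mask=True
}

# consumer side: back to pixel coordinates and a binary mask
x0, y0 = det["x"] * 1920, det["y"] * 1080
x1, y1 = x0 + det["w"] * 1920, y0 + det["h"] * 1080
mask = b64_mask_decode(det["mask_b64"])   # assumed to return an (h, w) binary array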
@@ -166,7 +199,6 @@ def run_sam_video_inference(
     # put video frames into directory
     # TODO:
     # change frame size
-    # async frame load
     l_frames_fp = extract_frames(
         video_path,
         fps=sample_fps,
@@ -176,43 +208,55 @@
     )
     vframes_dir = os.path.dirname(l_frames_fp[0])
     vinfo = VidInfo(video_path)
+    vr = VidReader(video_path, use_imageio=True)
     w = vinfo["frame_width"]
     h = vinfo["frame_height"]
 
     inference_state = model.init_state(
         video_path=vframes_dir, device=device, async_loading_frames=async_frame_load
     )
-    for i, mask in enumerate(masks):
-        model.add_new_mask(
+    for mask_idx, mask in enumerate(masks):
+        _, object_ids, mask_logits = model.add_new_mask(
             inference_state=inference_state,
             frame_idx=ref_frame_idx,
-            obj_id=i,
+            obj_id=mask_idx,
             mask=mask,
         )
+        # debug
+        logger.debug(
+            f"adding mask {mask_idx} of shape {mask.shape} for frame {ref_frame_idx}, xyxy: {mask_to_xyxy(mask)}"
+        )
+
+    # debug init state
+    logger.debug(f"model initiated with mask_logits of shape {mask_logits.shape}")
+    logger.debug(f"model initiated with object_ids of len {len(object_ids)}")
+    init_masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
+    init_masks = [m.squeeze() for m in init_masks]
+    ref_frame_im = get_pil_im(np.array(vr.get_data(ref_frame_idx)))
+    init_masks_im_fp = os.path.join(vframes_dir, f"model_init_masks.jpg")
+    input_masks_im_fp = os.path.join(vframes_dir, f"input_masks.jpg")
+    annotate_masks(ref_frame_im, init_masks).save(init_masks_im_fp)
+    annotate_masks(ref_frame_im, masks).save(input_masks_im_fp)
+    logger.debug(f"masks received by model visualized at {init_masks_im_fp}")
+    logger.debug(f"masks provided to model visualized at {input_masks_im_fp}")
+
     masks_generator = model.propagate_in_video(inference_state)
+    detections = unpack_masks(
+        masks_generator,
+        drop_mask=drop_mask,
+        frame_wh=(w, h),
+    )
 
-    detections = []
-    for i, tracker_ids, mask_logits in masks_generator:
-        masks = (mask_logits > 0.0).cpu().numpy().astype(np.uint8)
-        for id, mask in zip(tracker_ids, masks):
-            mask = mask.squeeze().astype(np.uint8)
-            xyxy = mask_to_xyxy(mask)
-            if not xyxy:  # mask is empty
-                logger.debug(f"track_id {id} is missing mask at frame {i}")
-                continue
-            x0, y0, x1, y1 = xyxy
-            det = {  # miro's detections format for videos
-                "frame": i,
-                "track_id": id,
-                "x": x0 / w,
-                "y": y0 / h,
-                "w": (x1 - x0) / w,
-                "h": (y1 - y0) / h,
-                "conf": 1,
-            }
-            if not drop_mask:
-                det["mask_b64"] = b64_mask_encode(mask).decode("ascii")
-            detections.append(det)
+    if ref_frame_idx != 0:
+        logger.debug(f"propagating in reverse now from {ref_frame_idx}")
+        # there's no need to reset state
+        # model.reset_state(inference_state)
+        masks_generator = model.propagate_in_video(inference_state, reverse=True)
+        detections += unpack_masks(
+            masks_generator,
+            drop_mask=drop_mask,
+            frame_wh=(w, h),
+        )
 
     if do_tidy_up:
         # remove vframes_dir
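
One consequence of the two-pass propagation: when ref_frame_idx is not 0, the returned list is the forward pass (ref_frame_idx to the end) followed by the reverse pass back towards frame 0, so it is not ordered by frame number, and depending on whether the reverse generator re-emits the reference frame there may be a duplicate entry per track at ref_frame_idx. A small post-processing sketch, not part of the commit, using only the keys shown above:

# order by frame, keep one detection per (frame, track_id)
detections.sort(key=lambda d: (d["frame"], d["track_id"]))
deduped, seen = [], set()
for det in detections:
    key = (det["frame"], det["track_id"])
    if key not in seen:
        seen.add(key)
        deduped.append(det)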
visualizer.py ADDED
@@ -0,0 +1,100 @@
+from PIL import Image, ImageColor
+import matplotlib.colors as mcolors
+import numpy as np
+from toolbox.mask_encoding import b64_mask_decode
+from toolbox.img_utils import im_draw_bbox, im_draw_point, im_color_mask
+
+
+def mask_to_xyxy(mask: np.ndarray, verbose: bool = False) -> tuple:
+    """Convert a binary mask of shape (h, w) to
+    xyxy bounding box format (top-left and bottom-right coordinates).
+    """
+    ys, xs = np.where(mask)
+    if len(xs) == 0 or len(ys) == 0:
+        if verbose:
+            logger.warning("mask_to_xyxy: No object found in the mask")
+        return None
+    x_min = np.min(xs)
+    y_min = np.min(ys)
+    x_max = np.max(xs)
+    y_max = np.max(ys)
+    xyxy = (x_min, y_min, x_max, y_max)
+    xyxy = tuple([int(i) for i in xyxy])
+    return xyxy
+
+
+def annotate_detections(
+    im: Image.Image,
+    l_obj: list,
+    color_key: str = "class",
+    bbox_width: int = 1,
+    label_key: str = "object_id",
+    color_dict: dict = {},
+):
+    # color_list is a list of tuple(name, color_hex)
+    color_list = list(
+        mcolors.XKCD_COLORS.items()
+    )  # list(mcolors.TABLEAU_COLORS.items())
+    unique_color_keys = list(
+        set([o[color_key] for o in l_obj if color_key in o.keys()])
+    )
+
+    for obj in l_obj:
+        color_index = unique_color_keys.index(obj[color_key])
+        bbox_color = (
+            color_dict[obj[color_key]] if color_dict else color_list[color_index][1]
+        )
+        im = (
+            im_draw_bbox(
+                im,
+                color=bbox_color,
+                width=bbox_width,
+                caption=(str(obj[label_key]) if label_key else None),
+                **obj["boundingBox"],
+                use_bbv=True,
+            )
+            if "boundingBox" in obj.keys()
+            else im_draw_point(
+                im,
+                **obj["point"],
+                width=bbox_width,
+                caption=(str(obj[label_key]) if label_key else None),
+                color=bbox_color,
+            )
+        )
+    return im
+
+
+def annotate_masks(
+    im: Image.Image, masks: list, mask_alpha: float = 0.9, bbox_width: int = 3
+) -> Image.Image:
+    """returns an annotated pillow image"""
+    masks = [
+        b64_mask_decode(m).astype(np.uint8) if isinstance(m, str) else m for m in masks
+    ]
+    segs = []
+    for i, m in enumerate(masks):
+        x0, y0, x1, y1 = mask_to_xyxy(m)
+        segs.append(
+            {
+                "object_id": i,
+                "boundingBox": {"x0": x0, "y0": y0, "x1": x1, "y1": y1},
+            }
+        )
+    ann_im = np.array(im)
+    for i, m in enumerate(masks):
+        m_color = list(mcolors.XKCD_COLORS.items())[i]
+        ann_im = im_color_mask(
+            ann_im,
+            mask_array=m,
+            alpha=mask_alpha,
+            rbg_tup=ImageColor.getrgb(m_color[1]),
+        )
+    ann_im = annotate_detections(
+        ann_im,
+        l_obj=segs,
+        color_key="object_id",
+        label_key="object_id",
+        bbox_width=bbox_width,
+    )
+    return ann_im
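
visualizer.py exists mainly so the handler can dump the annotated debug frames referenced above (model_init_masks.jpg / input_masks.jpg). A minimal usage sketch under the same assumptions the module itself makes about the toolbox drawing helpers; the frame path and mask values are placeholders, and calling .save() on the result matches how the handler already uses annotate_masks(...).save(...):

import numpy as np
from PIL import Image
from visualizer import annotate_masks, mask_to_xyxy

frame = Image.open("frame.jpg")                  # placeholder reference frame
mask = np.zeros((frame.height, frame.width), dtype=np.uint8)
mask[100:200, 150:300] = 1                       # fake binary object mask

print(mask_to_xyxy(mask))                        # -> (150, 100, 299, 199)
annotate_masks(frame, [mask]).save("frame_annotated.jpg")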