Zhen Ye committed on
Commit
032b60f
·
1 Parent(s): 3fde4e4

Refine Grounded-SAM2 tracking behavior and video frame handling

Browse files
inference.py CHANGED
@@ -1641,8 +1641,8 @@ def run_grounded_sam2_tracking(
1641
  frame_path = _os.path.join(frame_dir, frame_names[frame_idx])
1642
  frame = cv2.imread(frame_path)
1643
  if frame is None:
1644
- logging.warning("Failed to read frame %d, skipping", frame_idx)
1645
- continue
1646
 
1647
  frame_objects = tracking_results.get(frame_idx, {})
1648
 
@@ -1671,7 +1671,8 @@ def run_grounded_sam2_tracking(
1671
  label = f"{obj_info.instance_id} {obj_info.class_name}"
1672
  label_list.append(label)
1673
 
1674
- if obj_info.x1 or obj_info.y1 or obj_info.x2 or obj_info.y2:
 
1675
  boxes_list.append([obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2])
1676
 
1677
  # Draw masks
 
1641
  frame_path = _os.path.join(frame_dir, frame_names[frame_idx])
1642
  frame = cv2.imread(frame_path)
1643
  if frame is None:
1644
+ logging.warning("Failed to read frame %d, writing blank", frame_idx)
1645
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
1646
 
1647
  frame_objects = tracking_results.get(frame_idx, {})
1648
 
 
1671
  label = f"{obj_info.instance_id} {obj_info.class_name}"
1672
  label_list.append(label)
1673
 
1674
+ has_box = not (obj_info.x1 == 0 and obj_info.y1 == 0 and obj_info.x2 == 0 and obj_info.y2 == 0)
1675
+ if has_box:
1676
  boxes_list.append([obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2])
1677
 
1678
  # Draw masks
models/segmenters/grounded_sam2.py CHANGED
@@ -10,6 +10,7 @@ Reference implementation:
10
 
11
  import copy
12
  import logging
 
13
  from dataclasses import dataclass, field
14
  from typing import Any, Dict, List, Optional, Sequence, Tuple
15
 
@@ -84,7 +85,7 @@ class MaskDictionary:
84
  def update_masks(
85
  self,
86
  tracking_dict: "MaskDictionary",
87
- iou_threshold: float = 0.8,
88
  objects_count: int = 0,
89
  ) -> int:
90
  """Match current detections against tracked objects via IoU."""
@@ -156,7 +157,7 @@ class GroundedSAM2Segmenter(Segmenter):
156
  model_size: str = "large",
157
  device: Optional[str] = None,
158
  step: int = 20,
159
- iou_threshold: float = 0.8,
160
  ):
161
  self.model_size = model_size
162
  self.step = step
@@ -240,7 +241,9 @@ class GroundedSAM2Segmenter(Segmenter):
240
  import cv2 as _cv2
241
  frame_rgb = _cv2.cvtColor(frame, _cv2.COLOR_BGR2RGB)
242
 
243
- with torch.autocast(device_type=self.device.split(":")[0], dtype=torch.bfloat16):
 
 
244
  self._image_predictor.set_image(frame_rgb)
245
  input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
246
  masks, scores, _ = self._image_predictor.predict(
@@ -311,68 +314,70 @@ class GroundedSAM2Segmenter(Segmenter):
311
  total_frames, step, text_prompts,
312
  )
313
 
314
- # Init SAM2 video predictor state
315
- with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
316
- inference_state = self._video_predictor.init_state(
317
- video_path=frame_dir,
318
- offload_video_to_cpu=True,
319
- async_loading_frames=True,
320
- )
321
 
322
  sam2_masks = MaskDictionary()
323
  objects_count = 0
324
  all_results: Dict[int, Dict[int, ObjectInfo]] = {}
325
 
326
- for start_idx in range(0, total_frames, step):
327
- logging.info("Processing keyframe %d / %d", start_idx, total_frames)
 
 
 
 
 
328
 
329
- img_path = os.path.join(frame_dir, frame_names[start_idx])
330
- image = Image.open(img_path).convert("RGB")
331
 
332
- mask_dict = MaskDictionary()
 
333
 
334
- # -- Grounding DINO detection on keyframe --
335
- inputs = gdino_processor(
336
- images=image, text=prompt, return_tensors="pt"
337
- )
338
- inputs = {k: v.to(device) for k, v in inputs.items()}
339
 
340
- with torch.no_grad():
341
- outputs = gdino_model(**inputs)
 
 
 
342
 
343
- results = gdino_processor.post_process_grounded_object_detection(
344
- outputs,
345
- inputs["input_ids"],
346
- threshold=0.25,
347
- text_threshold=0.25,
348
- target_sizes=[image.size[::-1]],
349
- )
350
 
351
- input_boxes = results[0]["boxes"]
352
- det_labels = results[0].get("text_labels") or results[0].get("labels", [])
353
- if torch.is_tensor(det_labels):
354
- det_labels = det_labels.detach().cpu().tolist()
355
- det_labels = [str(l) for l in det_labels]
356
-
357
- if input_boxes.shape[0] == 0:
358
- logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
359
- # Fill empty results for this segment
360
- for fi in range(start_idx, min(start_idx + step, total_frames)):
361
- if fi not in all_results:
362
- # Carry forward last known masks
363
- all_results[fi] = {
364
- k: ObjectInfo(
365
- instance_id=v.instance_id,
366
- mask=v.mask,
367
- class_name=v.class_name,
368
- x1=v.x1, y1=v.y1, x2=v.x2, y2=v.y2,
369
- )
370
- for k, v in sam2_masks.labels.items()
371
- } if sam2_masks.labels else {}
372
- continue
373
 
374
- # -- SAM2 image predictor on keyframe --
375
- with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  self._image_predictor.set_image(np.array(image))
377
  masks, scores, logits = self._image_predictor.predict(
378
  point_coords=None,
@@ -381,34 +386,33 @@ class GroundedSAM2Segmenter(Segmenter):
381
  multimask_output=False,
382
  )
383
 
384
- # Normalize mask dims
385
- if masks.ndim == 2:
386
- masks = masks[None]
387
- scores = scores[None]
388
- logits = logits[None]
389
- elif masks.ndim == 4:
390
- masks = masks.squeeze(1)
391
-
392
- mask_dict.add_new_frame_annotation(
393
- mask_list=torch.tensor(masks).to(device),
394
- box_list=torch.tensor(input_boxes.cpu().numpy() if torch.is_tensor(input_boxes) else input_boxes),
395
- label_list=det_labels,
396
- )
397
 
398
- # -- IoU matching to maintain persistent IDs --
399
- objects_count = mask_dict.update_masks(
400
- tracking_dict=sam2_masks,
401
- iou_threshold=self.iou_threshold,
402
- objects_count=objects_count,
403
- )
404
 
405
- if len(mask_dict.labels) == 0:
406
- for fi in range(start_idx, min(start_idx + step, total_frames)):
407
- all_results[fi] = {}
408
- continue
409
 
410
- # -- SAM2 video predictor: propagate masks --
411
- with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
412
  self._video_predictor.reset_state(inference_state)
413
 
414
  for obj_id, obj_info in mask_dict.labels.items():
 
10
 
11
  import copy
12
  import logging
13
+ from contextlib import nullcontext
14
  from dataclasses import dataclass, field
15
  from typing import Any, Dict, List, Optional, Sequence, Tuple
16
 
 
85
  def update_masks(
86
  self,
87
  tracking_dict: "MaskDictionary",
88
+ iou_threshold: float = 0.5,
89
  objects_count: int = 0,
90
  ) -> int:
91
  """Match current detections against tracked objects via IoU."""
 
157
  model_size: str = "large",
158
  device: Optional[str] = None,
159
  step: int = 20,
160
+ iou_threshold: float = 0.5,
161
  ):
162
  self.model_size = model_size
163
  self.step = step
 
241
  import cv2 as _cv2
242
  frame_rgb = _cv2.cvtColor(frame, _cv2.COLOR_BGR2RGB)
243
 
244
+ device_type = self.device.split(":")[0]
245
+ autocast_ctx = torch.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
246
+ with autocast_ctx:
247
  self._image_predictor.set_image(frame_rgb)
248
  input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
249
  masks, scores, _ = self._image_predictor.predict(
 
314
  total_frames, step, text_prompts,
315
  )
316
 
317
+ # Single global autocast context (matches reference implementation)
318
+ device_type = device.split(":")[0]
319
+ autocast_ctx = torch.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
 
 
 
 
320
 
321
  sam2_masks = MaskDictionary()
322
  objects_count = 0
323
  all_results: Dict[int, Dict[int, ObjectInfo]] = {}
324
 
325
+ with autocast_ctx:
326
+ # Init SAM2 video predictor state
327
+ inference_state = self._video_predictor.init_state(
328
+ video_path=frame_dir,
329
+ offload_video_to_cpu=True,
330
+ async_loading_frames=True,
331
+ )
332
 
333
+ for start_idx in range(0, total_frames, step):
334
+ logging.info("Processing keyframe %d / %d", start_idx, total_frames)
335
 
336
+ img_path = os.path.join(frame_dir, frame_names[start_idx])
337
+ image = Image.open(img_path).convert("RGB")
338
 
339
+ mask_dict = MaskDictionary()
 
 
 
 
340
 
341
+ # -- Grounding DINO detection on keyframe --
342
+ inputs = gdino_processor(
343
+ images=image, text=prompt, return_tensors="pt"
344
+ )
345
+ inputs = {k: v.to(device) for k, v in inputs.items()}
346
 
347
+ with torch.no_grad():
348
+ outputs = gdino_model(**inputs)
 
 
 
 
 
349
 
350
+ # Use GDINO detector's _post_process for transformers version compat
351
+ results = self._gdino_detector._post_process(
352
+ outputs,
353
+ inputs["input_ids"],
354
+ target_sizes=[image.size[::-1]],
355
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
+ input_boxes = results[0]["boxes"]
358
+ det_labels = results[0].get("text_labels") or results[0].get("labels", [])
359
+ if torch.is_tensor(det_labels):
360
+ det_labels = det_labels.detach().cpu().tolist()
361
+ det_labels = [str(l) for l in det_labels]
362
+
363
+ if input_boxes.shape[0] == 0:
364
+ logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
365
+ # Fill empty results for this segment
366
+ for fi in range(start_idx, min(start_idx + step, total_frames)):
367
+ if fi not in all_results:
368
+ # Carry forward last known masks
369
+ all_results[fi] = {
370
+ k: ObjectInfo(
371
+ instance_id=v.instance_id,
372
+ mask=v.mask,
373
+ class_name=v.class_name,
374
+ x1=v.x1, y1=v.y1, x2=v.x2, y2=v.y2,
375
+ )
376
+ for k, v in sam2_masks.labels.items()
377
+ } if sam2_masks.labels else {}
378
+ continue
379
+
380
+ # -- SAM2 image predictor on keyframe --
381
  self._image_predictor.set_image(np.array(image))
382
  masks, scores, logits = self._image_predictor.predict(
383
  point_coords=None,
 
386
  multimask_output=False,
387
  )
388
 
389
+ # Normalize mask dims
390
+ if masks.ndim == 2:
391
+ masks = masks[None]
392
+ scores = scores[None]
393
+ logits = logits[None]
394
+ elif masks.ndim == 4:
395
+ masks = masks.squeeze(1)
396
+
397
+ mask_dict.add_new_frame_annotation(
398
+ mask_list=torch.tensor(masks).to(device),
399
+ box_list=input_boxes.clone() if torch.is_tensor(input_boxes) else torch.tensor(input_boxes),
400
+ label_list=det_labels,
401
+ )
402
 
403
+ # -- IoU matching to maintain persistent IDs --
404
+ objects_count = mask_dict.update_masks(
405
+ tracking_dict=sam2_masks,
406
+ iou_threshold=self.iou_threshold,
407
+ objects_count=objects_count,
408
+ )
409
 
410
+ if len(mask_dict.labels) == 0:
411
+ for fi in range(start_idx, min(start_idx + step, total_frames)):
412
+ all_results[fi] = {}
413
+ continue
414
 
415
+ # -- SAM2 video predictor: propagate masks --
 
416
  self._video_predictor.reset_state(inference_state)
417
 
418
  for obj_id, obj_info in mask_dict.labels.items():
utils/video.py CHANGED
@@ -43,7 +43,7 @@ def extract_frames_to_jpeg_dir(
43
  if not success:
44
  break
45
  fname = f"{idx:06d}.jpg"
46
- cv2.imwrite(os.path.join(output_dir, fname), frame)
47
  frame_names.append(fname)
48
  idx += 1
49
 
 
43
  if not success:
44
  break
45
  fname = f"{idx:06d}.jpg"
46
+ cv2.imwrite(os.path.join(output_dir, fname), frame, [cv2.IMWRITE_JPEG_QUALITY, 100])
47
  frame_names.append(fname)
48
  idx += 1
49