Zhen Ye committed
Commit 21c29ae · 1 Parent(s): 97b3a45

refactor(gsam2): make SAM2 detector-agnostic

app.py CHANGED
@@ -248,7 +248,7 @@ async def detect_endpoint(
     mode: str = Form(...),
     queries: str = Form(""),
     detector: str = Form("hf_yolov8"),
-    segmenter: str = Form("gsam2_large"),
+    segmenter: str = Form("GSAM2-L"),
     enable_depth: bool = Form(False),
     enable_gpt: bool = Form(True),
 ):
@@ -260,7 +260,7 @@ async def detect_endpoint(
         mode: Detection mode (object_detection, segmentation, drone_detection)
         queries: Comma-separated object classes for object_detection mode
         detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
-        segmenter: Segmentation model to use (gsam2_small, gsam2_base, gsam2_large)
+        segmenter: Segmentation model to use (GSAM2-S, GSAM2-B, GSAM2-L)
         enable_depth: Whether to run legacy depth estimation (default: False)
         drone_detection uses the dedicated drone_yolo model.

@@ -302,6 +302,7 @@ async def detect_endpoint(
             output_path,
             query_list,
             segmenter_name=segmenter,
+            detector_name="grounding_dino",
             num_maskmem=7,
         )
     except ValueError as exc:
@@ -402,7 +403,7 @@ async def detect_async_endpoint(
     mode: str = Form(...),
     queries: str = Form(""),
     detector: str = Form("hf_yolov8"),
-    segmenter: str = Form("gsam2_large"),
+    segmenter: str = Form("GSAM2-L"),
     depth_estimator: str = Form("depth"),
     depth_scale: float = Form(25.0),
     enable_depth: bool = Form(False),
@@ -491,7 +492,6 @@ async def detect_async_endpoint(
         )
         cv2.imwrite(str(first_frame_path), processed_frame)
         # GPT and depth are now handled in the async pipeline (enrichment thread)
-        depth_map = None
         first_frame_gpt_results = None
     except Exception:
         logging.exception("First-frame processing failed.")
@@ -910,7 +910,7 @@ async def chat_threat_endpoint(
 async def benchmark_endpoint(
     video: UploadFile = File(...),
     queries: str = Form("person,car,truck"),
-    segmenter: str = Form("gsam2_large"),
+    segmenter: str = Form("GSAM2-L"),
     step: int = Form(60),
     num_maskmem: Optional[int] = Form(None),
 ):
@@ -1036,7 +1036,7 @@ async def benchmark_profile(
     video: UploadFile = File(...),
     mode: str = Form("detection"),
     detector: str = Form("hf_yolov8"),
-    segmenter: str = Form("gsam2_large"),
+    segmenter: str = Form("GSAM2-L"),
     queries: str = Form("person,car,truck"),
     max_frames: int = Form(100),
     warmup_frames: int = Form(5),
@@ -1102,7 +1102,7 @@ async def benchmark_analysis(
     video: UploadFile = File(...),
     mode: str = Form("detection"),
     detector: str = Form("hf_yolov8"),
-    segmenter: str = Form("gsam2_large"),
+    segmenter: str = Form("GSAM2-L"),
     queries: str = Form("person,car,truck"),
     max_frames: int = Form(100),
     warmup_frames: int = Form(5),
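
A quick client sketch against the renamed form field. The route path, port, and upload field name are assumptions not visible in this diff; only the form parameters above are confirmed:

    import requests

    # Hedged example: exercises the new "GSAM2-L" identifier end to end.
    resp = requests.post(
        "http://localhost:8000/detect",           # assumed route for detect_endpoint
        files={"video": open("clip.mp4", "rb")},  # assumed upload field name
        data={
            "mode": "segmentation",
            "queries": "person,car",
            "segmenter": "GSAM2-L",  # was "gsam2_large" before this commit
        },
    )
    print(resp.json())

Note that in this synchronous path the handler pins detector_name="grounding_dino", so the segmenter choice only selects the SAM2 encoder size.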
frontend/index.html CHANGED
@@ -75,9 +75,9 @@
           <option value="grounding_dino" data-kind="object">Large</option>
         </optgroup>
         <optgroup label="Segmentation Models">
-          <option value="gsam2_large" data-kind="segmentation">SAM2 Large</option>
-          <option value="gsam2_base" data-kind="segmentation">SAM2 Base+</option>
-          <option value="gsam2_small" data-kind="segmentation">SAM2 Small</option>
+          <option value="GSAM2-L" data-kind="segmentation">GSAM2-L</option>
+          <option value="GSAM2-B" data-kind="segmentation">GSAM2-B</option>
+          <option value="GSAM2-S" data-kind="segmentation">GSAM2-S</option>
         </optgroup>
         <optgroup label="Drone Detection Models">
           <option value="drone_yolo" data-kind="drone">Drone</option>
frontend/js/main.js CHANGED
@@ -363,11 +363,11 @@ document.addEventListener("DOMContentLoaded", () => {
     } else if (kind === "drone") {
       mode = "drone_detection";
       detectorParam = selectedValue;
-      segmenterParam = "gsam2_large";
+      segmenterParam = "GSAM2-L";
     } else {
       mode = "object_detection";
       detectorParam = selectedValue;
-      segmenterParam = "gsam2_large";
+      segmenterParam = "GSAM2-L";
     }

     const form = new FormData();
inference.py CHANGED
@@ -1631,6 +1631,7 @@ def run_grounded_sam2_tracking(
     _perf_metrics: Optional[Dict[str, float]] = None,
     _perf_lock=None,
     num_maskmem: Optional[int] = None,
+    detector_name: Optional[str] = None,
 ) -> str:
     """Run Grounded-SAM-2 video tracking pipeline.

@@ -1645,7 +1646,7 @@ def run_grounded_sam2_tracking(
     from utils.video import extract_frames_to_jpeg_dir
     from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects

-    active_segmenter = segmenter_name or "gsam2_large"
+    active_segmenter = segmenter_name or "GSAM2-L"
     logging.info(
         "Grounded-SAM-2 tracking: segmenter=%s, queries=%s, step=%d",
         active_segmenter, queries, step,
@@ -2120,6 +2121,8 @@ def run_grounded_sam2_tracking(
     # ---------- Single-GPU fallback ----------
     device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
     _seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
+    if detector_name is not None:
+        _seg_kw["detector_name"] = detector_name

     if _perf_metrics is not None:
         _t_load = time.perf_counter()
@@ -2176,6 +2179,8 @@ def run_grounded_sam2_tracking(
     segmenters = []
     with ThreadPoolExecutor(max_workers=num_gpus) as pool:
         _seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
+        if detector_name is not None:
+            _seg_kw_multi["detector_name"] = detector_name
         futs = [
             pool.submit(
                 load_segmenter_on_device,
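
The two _seg_kw blocks above keep the factory call backward compatible: detector_name enters the kwargs only when a caller sets it. A minimal sketch of the pattern (the full signature of load_segmenter_on_device is not shown in this diff, so the call shape is an assumption):

    # Forward optional knobs only when explicitly set, so segmenter factories
    # that predate this commit never receive an unexpected keyword.
    seg_kw = {}
    if num_maskmem is not None:
        seg_kw["num_maskmem"] = num_maskmem      # SAM2 memory-bank size override
    if detector_name is not None:
        seg_kw["detector_name"] = detector_name  # detector backing the SAM2 prompts

    segmenter = load_segmenter_on_device(active_segmenter, device_str, **seg_kw)  # assumed call shape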
jobs/background.py CHANGED
@@ -2,12 +2,10 @@ import asyncio
 import logging
 from datetime import datetime

-import torch
-
 from jobs.models import JobStatus
-from jobs.storage import get_job_storage, get_depth_output_path, get_first_frame_depth_path
+from jobs.storage import get_job_storage
 from jobs.streaming import create_stream, remove_stream
-from inference import run_inference, run_grounded_sam2_tracking, run_depth_inference
+from inference import run_inference, run_grounded_sam2_tracking


 async def process_video_async(job_id: str) -> None:
@@ -41,6 +39,7 @@ async def process_video_async(job_id: str) -> None:
                 mission_spec=job.mission_spec,
                 first_frame_gpt_results=job.first_frame_gpt_results,
                 num_maskmem=7,
+                detector_name=job.detector_name,
             )
         else:
             detections_list = None
models/segmenters/grounded_sam2.py CHANGED
@@ -1,6 +1,6 @@
 """Grounded-SAM-2 segmenter with continuous-ID video tracking.

-Combines Grounding DINO (open-vocabulary detection) with SAM2's video
+Combines an object detector (open-vocabulary or closed-set) with SAM2's video
 predictor to produce temporally consistent segmentation masks with
 persistent object IDs across an entire video.

@@ -13,7 +13,7 @@ import logging
 import time
 from contextlib import nullcontext
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -308,15 +308,26 @@ _SAM2_HF_MODELS = {
 }


+def _det_label_names(det) -> List[str]:
+    """Extract string labels from a DetectionResult, with fallback."""
+    num_boxes = len(det.boxes) if det.boxes is not None else 0
+    if det.label_names is not None and len(det.label_names) > 0:
+        return list(det.label_names)
+    if det.labels is not None and len(det.labels) > 0:
+        return [str(l) for l in det.labels]
+    return ["object"] * num_boxes
+
+
 # ---------------------------------------------------------------------------
 # Grounded-SAM-2 Segmenter
 # ---------------------------------------------------------------------------

 class GroundedSAM2Segmenter(Segmenter):
-    """SAM2 video segmenter driven by Grounding DINO detections.
+    """SAM2 video segmenter driven by an injected object detector.

-    For single-frame mode (``predict``), uses GDINO + SAM2 image predictor.
-    For video mode (``process_video``), uses GDINO on keyframes + SAM2 video
+    Any ``ObjectDetector`` can be used (defaults to Grounding DINO).
+    For single-frame mode (``predict``), uses detector + SAM2 image predictor.
+    For video mode (``process_video``), uses detector on keyframes + SAM2 video
     predictor for temporal mask propagation with continuous object IDs.
     """

@@ -330,12 +341,15 @@ class GroundedSAM2Segmenter(Segmenter):
         step: int = 20,
         iou_threshold: float = 0.5,
         num_maskmem: Optional[int] = None,
+        detector_name: Optional[str] = None,
     ):
         self.model_size = model_size
         self.step = step
         self.iou_threshold = iou_threshold
         self.num_maskmem = num_maskmem  # None = use default (7)
-        self.name = f"gsam2_{model_size}"
+        self._detector_name = detector_name  # None = "grounding_dino"
+        _size_suffix = {"small": "S", "base": "B", "large": "L"}
+        self.name = f"GSAM2-{_size_suffix[model_size]}"

         if device:
             self.device = device
@@ -345,7 +359,7 @@ class GroundedSAM2Segmenter(Segmenter):
         # Lazy-loaded model handles
         self._video_predictor = None
         self._image_predictor = None
-        self._gdino_detector = None
+        self._detector = None
         self._models_loaded = False

     # -- Lazy loading -------------------------------------------------------
@@ -388,10 +402,11 @@ class GroundedSAM2Segmenter(Segmenter):
             self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
             logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)

-        # Reuse existing Grounding DINO detector from our codebase
-        from models.detectors.grounding_dino import GroundingDinoDetector
+        # Load detector by name (defaults to Grounding DINO)
+        from models.model_loader import load_detector_on_device

-        self._gdino_detector = GroundingDinoDetector(device=self.device)
+        det_name = self._detector_name or "grounding_dino"
+        self._detector = load_detector_on_device(det_name, self.device)

         self._models_loaded = True
         logging.info("Grounded-SAM-2 models loaded successfully.")
@@ -476,13 +491,13 @@ class GroundedSAM2Segmenter(Segmenter):
     def predict(
         self, frame: np.ndarray, text_prompts: Optional[list] = None
     ) -> SegmentationResult:
-        """Run GDINO + SAM2 image predictor on a single frame."""
+        """Run detector + SAM2 image predictor on a single frame."""
         self._ensure_models_loaded()

         prompts = text_prompts or ["object"]

-        # Run Grounding DINO to get boxes
-        det = self._gdino_detector.predict(frame, prompts)
+        # Run detector to get boxes
+        det = self._detector.predict(frame, prompts)
         if det.boxes is None or len(det.boxes) == 0:
             return SegmentationResult(
                 masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
@@ -539,11 +554,11 @@ class GroundedSAM2Segmenter(Segmenter):
         image: "Image",
         text_prompts: List[str],
     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], List[str]]:
-        """Run GDINO + SAM2 image predictor on a single keyframe.
+        """Run detector + SAM2 image predictor on a single keyframe.

         Args:
             image: PIL Image in RGB mode.
-            text_prompts: Text queries for Grounding DINO.
+            text_prompts: Text queries for the detector.

         Returns:
             ``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)``
@@ -554,26 +569,12 @@ class GroundedSAM2Segmenter(Segmenter):
         self._ensure_models_loaded()
         _pm = getattr(self, '_perf_metrics', None)

-        prompt = self._gdino_detector._build_prompt(text_prompts)
-        gdino_processor = self._gdino_detector.processor
-        gdino_model = self._gdino_detector.model
-
         if _pm is not None:
             _t0 = time.perf_counter()

-        inputs = gdino_processor(
-            images=image, text=prompt, return_tensors="pt"
-        )
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            outputs = gdino_model(**inputs)
-
-        results = self._gdino_detector._post_process(
-            outputs,
-            inputs["input_ids"],
-            target_sizes=[image.size[::-1]],
-        )
+        # Convert PIL RGB → numpy BGR for detector.predict()
+        frame_bgr = np.array(image)[:, :, ::-1].copy()
+        det = self._detector.predict(frame_bgr, text_prompts)

         if _pm is not None:
             _pl = getattr(self, '_perf_lock', None)
@@ -583,21 +584,18 @@ class GroundedSAM2Segmenter(Segmenter):
             else:
                 _pm["gdino_total_ms"] += _d

-        input_boxes = results[0]["boxes"]
-        det_labels = results[0].get("text_labels") or results[0].get("labels", [])
-        if torch.is_tensor(det_labels):
-            det_labels = det_labels.detach().cpu().tolist()
-        det_labels = [str(l) for l in det_labels]
-
-        if input_boxes.shape[0] == 0:
+        if det.boxes is None or len(det.boxes) == 0:
             return None, None, []

+        input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
+        det_labels = _det_label_names(det)
+
         # SAM2 image predictor
         if _pm is not None:
             _t1 = time.perf_counter()

         self._image_predictor.set_image(np.array(image))
-        masks, scores = self._predict_masks_gpu(input_boxes)
+        masks, _ = self._predict_masks_gpu(input_boxes)

         if _pm is not None:
             _pl = getattr(self, '_perf_lock', None)
@@ -721,7 +719,7 @@ class GroundedSAM2Segmenter(Segmenter):
         Args:
             frame_dir: Directory containing JPEG frames.
             frame_names: Sorted list of frame filenames.
-            text_prompts: Text queries for Grounding DINO.
+            text_prompts: Text queries for the detector.
             on_segment: Optional callback invoked after each segment completes.
                 Receives ``{frame_idx: {obj_id: ObjectInfo}}`` for the segment.

@@ -735,11 +733,6 @@ class GroundedSAM2Segmenter(Segmenter):

         device = self.device
         step = self.step
-        prompt = self._gdino_detector._build_prompt(text_prompts)
-
-        # HF processor for Grounding DINO (reuse from our detector)
-        gdino_processor = self._gdino_detector.processor
-        gdino_model = self._gdino_detector.model

         total_frames = len(frame_names)
         logging.info(
@@ -783,24 +776,12 @@ class GroundedSAM2Segmenter(Segmenter):

             mask_dict = MaskDictionary()

-            # -- Grounding DINO detection on keyframe --
+            # -- Detector on keyframe --
             if _pm is not None:
                 _t_gd = time.perf_counter()

-            inputs = gdino_processor(
-                images=image, text=prompt, return_tensors="pt"
-            )
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-
-            with torch.no_grad():
-                outputs = gdino_model(**inputs)
-
-            # Use GDINO detector's _post_process for transformers version compat
-            results = self._gdino_detector._post_process(
-                outputs,
-                inputs["input_ids"],
-                target_sizes=[image.size[::-1]],
-            )
+            frame_bgr = np.array(image)[:, :, ::-1].copy()
+            det = self._detector.predict(frame_bgr, text_prompts)

             if _pm is not None:
                 _pl = getattr(self, '_perf_lock', None)
@@ -810,13 +791,14 @@ class GroundedSAM2Segmenter(Segmenter):
             else:
                 _pm["gdino_total_ms"] += _d

-            input_boxes = results[0]["boxes"]
-            det_labels = results[0].get("text_labels") or results[0].get("labels", [])
-            if torch.is_tensor(det_labels):
-                det_labels = det_labels.detach().cpu().tolist()
-            det_labels = [str(l) for l in det_labels]
+            if det.boxes is None or len(det.boxes) == 0:
+                input_boxes = torch.zeros((0, 4), device=device)
+                det_labels = []
+            else:
+                input_boxes = torch.tensor(det.boxes, device=device, dtype=torch.float32)
+                det_labels = _det_label_names(det)

-            if input_boxes.shape[0] == 0:
+            if len(input_boxes) == 0:
                 logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
                 # Fill empty results for this segment
                 seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
@@ -842,7 +824,7 @@ class GroundedSAM2Segmenter(Segmenter):
                 _t_si = time.perf_counter()

             self._image_predictor.set_image(np.array(image))
-            masks, scores = self._predict_masks_gpu(input_boxes)
+            masks, _ = self._predict_masks_gpu(input_boxes)

             if _pm is not None:
                 _pl = getattr(self, '_perf_lock', None)
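
For illustration, a hedged usage sketch of the new injection point; detector names other than "grounding_dino" are assumed to be resolvable by models.model_loader.load_detector_on_device:

    import numpy as np
    from models.segmenters.grounded_sam2 import GroundedSAM2Segmenter

    # Closed-set YOLO boxes can now prompt SAM2 instead of Grounding DINO boxes.
    seg = GroundedSAM2Segmenter(model_size="large", detector_name="hf_yolov8")
    print(seg.name)  # "GSAM2-L"

    frame_bgr = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
    result = seg.predict(frame_bgr, text_prompts=["person", "car"])  # prompts forwarded to the detector

One subtlety the refactor preserves: SAM2's image predictor still receives the RGB array, while the detector sees BGR, hence the explicit channel flip on keyframes.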
models/segmenters/model_loader.py CHANGED
@@ -5,12 +5,12 @@ from typing import Callable, Dict, Optional
 from .base import Segmenter
 from .grounded_sam2 import GroundedSAM2Segmenter

-DEFAULT_SEGMENTER = "gsam2_large"
+DEFAULT_SEGMENTER = "GSAM2-L"

 _REGISTRY: Dict[str, Callable[..., Segmenter]] = {
-    "gsam2_small": lambda **kw: GroundedSAM2Segmenter(model_size="small", **kw),
-    "gsam2_base": lambda **kw: GroundedSAM2Segmenter(model_size="base", **kw),
-    "gsam2_large": lambda **kw: GroundedSAM2Segmenter(model_size="large", **kw),
+    "GSAM2-S": lambda **kw: GroundedSAM2Segmenter(model_size="small", **kw),
+    "GSAM2-B": lambda **kw: GroundedSAM2Segmenter(model_size="base", **kw),
+    "GSAM2-L": lambda **kw: GroundedSAM2Segmenter(model_size="large", **kw),
 }

@@ -37,7 +37,7 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
     Load a segmenter by name.

     Args:
-        name: Segmenter name (default: gsam2_large)
+        name: Segmenter name (default: GSAM2-L)

     Returns:
         Cached segmenter instance
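
Lookup goes through the same factory interface after the key rename; a short sketch using only what this diff confirms:

    from models.segmenters.model_loader import load_segmenter

    seg = load_segmenter()             # DEFAULT_SEGMENTER instance, i.e. "GSAM2-L"
    seg_b = load_segmenter("GSAM2-B")  # GroundedSAM2Segmenter(model_size="base")
    assert load_segmenter("GSAM2-B") is seg_b  # per the docstring, instances are cached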
utils/roofline.py CHANGED
@@ -21,9 +21,9 @@ _MODEL_FLOPS: Dict[str, float] = {
     "drone_yolo": 78.9,    # Same arch as YOLOv8m

     # Segmentation models (GFLOPs per keyframe)
-    "gsam2_small": 48.0,   # SAM2 small encoder
-    "gsam2_base": 96.0,    # SAM2 base encoder
-    "gsam2_large": 200.0,  # SAM2 large encoder
+    "GSAM2-S": 48.0,       # SAM2 small encoder
+    "GSAM2-B": 96.0,       # SAM2 base encoder
+    "GSAM2-L": 200.0,      # SAM2 large encoder
     "gsam2_tiny": 24.0,    # SAM2 tiny encoder
 }
@@ -34,9 +34,9 @@ _MODEL_BYTES: Dict[str, float] = {
     "detr_resnet50": 166.0,
     "grounding_dino": 340.0,
     "drone_yolo": 52.0,
-    "gsam2_small": 92.0,
-    "gsam2_base": 180.0,
-    "gsam2_large": 400.0,
+    "GSAM2-S": 92.0,
+    "GSAM2-B": 180.0,
+    "GSAM2-L": 400.0,
     "gsam2_tiny": 46.0,
 }
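
These tables feed a standard roofline estimate. A sketch of the arithmetic, with the units as an assumption (_MODEL_FLOPS in GFLOPs per keyframe as its comments say, _MODEL_BYTES in MB of traffic per keyframe):

    # Arithmetic intensity = FLOPs / bytes moved; compare it against the machine's
    # ridge point (peak FLOPs / memory bandwidth) to judge compute- vs memory-bound.
    flops = 200.0 * 1e9    # _MODEL_FLOPS["GSAM2-L"], GFLOPs -> FLOPs
    traffic = 400.0 * 1e6  # _MODEL_BYTES["GSAM2-L"], MB -> bytes (assumed unit)
    intensity = flops / traffic
    print(f"GSAM2-L: {intensity:.0f} FLOPs/byte")  # 500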