Zhen Ye Claude Opus 4.6 committed on
Commit
1c6c619
·
1 Parent(s): 5749bd6

feat: add num_maskmem parameter to /benchmark endpoint

Browse files

Thread num_maskmem override through the full pipeline so benchmark
runs can test reduced memory bank sizes (e.g. 3 instead of default 7).
Patches the SAM2 video predictor's maskmem_tpos_enc at runtime.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

app.py CHANGED
@@ -32,6 +32,7 @@ import uuid
32
  from contextlib import asynccontextmanager
33
  from datetime import timedelta
34
  from pathlib import Path
 
35
 
36
  import cv2
37
  import numpy as np
@@ -857,6 +858,7 @@ async def benchmark_endpoint(
857
  queries: str = Form("person,car,truck"),
858
  segmenter: str = Form("gsam2_large"),
859
  step: int = Form(20),
 
860
  ):
861
  """Run instrumented GSAM2 pipeline and return latency breakdown JSON.
862
 
@@ -898,6 +900,7 @@ async def benchmark_endpoint(
898
  enable_gpt=False,
899
  _perf_metrics=metrics,
900
  _perf_lock=lock,
 
901
  )
902
 
903
  # Read frame count and fps from output video
@@ -915,6 +918,7 @@ async def benchmark_endpoint(
915
  "total_frames": total_frames,
916
  "fps": fps,
917
  "num_gpus": num_gpus,
 
918
  "metrics": metrics,
919
  })
920
 
 
32
  from contextlib import asynccontextmanager
33
  from datetime import timedelta
34
  from pathlib import Path
35
+ from typing import Optional
36
 
37
  import cv2
38
  import numpy as np
 
858
  queries: str = Form("person,car,truck"),
859
  segmenter: str = Form("gsam2_large"),
860
  step: int = Form(20),
861
+ num_maskmem: Optional[int] = Form(None),
862
  ):
863
  """Run instrumented GSAM2 pipeline and return latency breakdown JSON.
864
 
 
900
  enable_gpt=False,
901
  _perf_metrics=metrics,
902
  _perf_lock=lock,
903
+ num_maskmem=num_maskmem,
904
  )
905
 
906
  # Read frame count and fps from output video
 
918
  "total_frames": total_frames,
919
  "fps": fps,
920
  "num_gpus": num_gpus,
921
+ "num_maskmem": num_maskmem if num_maskmem is not None else 7,
922
  "metrics": metrics,
923
  })
924
 
inference.py CHANGED
@@ -1631,6 +1631,7 @@ def run_grounded_sam2_tracking(
1631
  first_frame_gpt_results: Optional[Dict[str, Any]] = None,
1632
  _perf_metrics: Optional[Dict[str, float]] = None,
1633
  _perf_lock=None,
 
1634
  ) -> str:
1635
  """Run Grounded-SAM-2 video tracking pipeline.
1636
 
@@ -1679,7 +1680,8 @@ def run_grounded_sam2_tracking(
1679
  if num_gpus <= 1:
1680
  # ---------- Single-GPU fallback ----------
1681
  device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
1682
- segmenter = load_segmenter_on_device(active_segmenter, device_str)
 
1683
  _check_cancellation(job_id)
1684
 
1685
  if _perf_metrics is not None:
@@ -1710,11 +1712,13 @@ def run_grounded_sam2_tracking(
1710
  # Phase 1: Load one segmenter per GPU (parallel)
1711
  segmenters = []
1712
  with ThreadPoolExecutor(max_workers=num_gpus) as pool:
 
1713
  futs = [
1714
  pool.submit(
1715
  load_segmenter_on_device,
1716
  active_segmenter,
1717
  f"cuda:{i}",
 
1718
  )
1719
  for i in range(num_gpus)
1720
  ]
 
1631
  first_frame_gpt_results: Optional[Dict[str, Any]] = None,
1632
  _perf_metrics: Optional[Dict[str, float]] = None,
1633
  _perf_lock=None,
1634
+ num_maskmem: Optional[int] = None,
1635
  ) -> str:
1636
  """Run Grounded-SAM-2 video tracking pipeline.
1637
 
 
1680
  if num_gpus <= 1:
1681
  # ---------- Single-GPU fallback ----------
1682
  device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
1683
+ _seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
1684
+ segmenter = load_segmenter_on_device(active_segmenter, device_str, **_seg_kw)
1685
  _check_cancellation(job_id)
1686
 
1687
  if _perf_metrics is not None:
 
1712
  # Phase 1: Load one segmenter per GPU (parallel)
1713
  segmenters = []
1714
  with ThreadPoolExecutor(max_workers=num_gpus) as pool:
1715
+ _seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
1716
  futs = [
1717
  pool.submit(
1718
  load_segmenter_on_device,
1719
  active_segmenter,
1720
  f"cuda:{i}",
1721
+ **_seg_kw_multi,
1722
  )
1723
  for i in range(num_gpus)
1724
  ]
models/segmenters/grounded_sam2.py CHANGED
@@ -318,10 +318,12 @@ class GroundedSAM2Segmenter(Segmenter):
318
  device: Optional[str] = None,
319
  step: int = 20,
320
  iou_threshold: float = 0.5,
 
321
  ):
322
  self.model_size = model_size
323
  self.step = step
324
  self.iou_threshold = iou_threshold
 
325
  self.name = f"gsam2_{model_size}"
326
 
327
  if device:
@@ -370,6 +372,11 @@ class GroundedSAM2Segmenter(Segmenter):
370
  sam2_image_model = build_sam2_hf(hf_id, device=self.device)
371
  self._image_predictor = SAM2ImagePredictor(sam2_image_model)
372
 
 
 
 
 
 
373
  # Reuse existing Grounding DINO detector from our codebase
374
  from models.detectors.grounding_dino import GroundingDinoDetector
375
 
@@ -378,6 +385,42 @@ class GroundedSAM2Segmenter(Segmenter):
378
  self._models_loaded = True
379
  logging.info("Grounded-SAM-2 models loaded successfully.")
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  # -- Single-frame interface (Segmenter.predict) -------------------------
382
 
383
  def predict(
 
318
  device: Optional[str] = None,
319
  step: int = 20,
320
  iou_threshold: float = 0.5,
321
+ num_maskmem: Optional[int] = None,
322
  ):
323
  self.model_size = model_size
324
  self.step = step
325
  self.iou_threshold = iou_threshold
326
+ self.num_maskmem = num_maskmem # None = use default (7)
327
  self.name = f"gsam2_{model_size}"
328
 
329
  if device:
 
372
  sam2_image_model = build_sam2_hf(hf_id, device=self.device)
373
  self._image_predictor = SAM2ImagePredictor(sam2_image_model)
374
 
375
+ # Override num_maskmem if requested
376
+ if self.num_maskmem is not None:
377
+ self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
378
+ logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)
379
+
380
  # Reuse existing Grounding DINO detector from our codebase
381
  from models.detectors.grounding_dino import GroundingDinoDetector
382
 
 
385
  self._models_loaded = True
386
  logging.info("Grounded-SAM-2 models loaded successfully.")
387
 
388
+ @staticmethod
389
+ def _patch_num_maskmem(predictor, num_maskmem: int):
390
+ """Override num_maskmem on a loaded SAM2 video predictor at runtime.
391
+
392
+ Slices the temporal positional encoding parameter to match the new
393
+ memory size so the model runs without shape mismatches.
394
+ """
395
+ import torch.nn as nn
396
+
397
+ # The underlying model may be predictor itself or predictor.model
398
+ model = getattr(predictor, "model", predictor)
399
+ old = getattr(model, "num_maskmem", None)
400
+ if old is None:
401
+ logging.warning("Cannot patch num_maskmem: attribute not found on model")
402
+ return
403
+ if num_maskmem == old:
404
+ return
405
+ model.num_maskmem = num_maskmem
406
+ # Slice or pad maskmem_tpos_enc (shape: [num_maskmem, 1, 1, mem_dim])
407
+ if hasattr(model, "maskmem_tpos_enc") and model.maskmem_tpos_enc is not None:
408
+ old_enc = model.maskmem_tpos_enc
409
+ if num_maskmem <= old_enc.shape[0]:
410
+ model.maskmem_tpos_enc = nn.Parameter(
411
+ old_enc[:num_maskmem].clone()
412
+ )
413
+ else:
414
+ # Pad with zeros for the extra slots
415
+ pad = torch.zeros(
416
+ num_maskmem - old_enc.shape[0], *old_enc.shape[1:],
417
+ device=old_enc.device, dtype=old_enc.dtype,
418
+ )
419
+ model.maskmem_tpos_enc = nn.Parameter(
420
+ torch.cat([old_enc, pad], dim=0)
421
+ )
422
+ logging.info("num_maskmem changed from %d to %d", old, num_maskmem)
423
+
424
  # -- Single-frame interface (Segmenter.predict) -------------------------
425
 
426
  def predict(
models/segmenters/model_loader.py CHANGED
@@ -46,6 +46,6 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
46
  return _get_cached_segmenter(segmenter_name)
47
 
48
 
49
- def load_segmenter_on_device(name: str, device: str) -> Segmenter:
50
  """Create a new segmenter instance on the specified device (no caching)."""
51
- return _create_segmenter(name, device=device)
 
46
  return _get_cached_segmenter(segmenter_name)
47
 
48
 
49
+ def load_segmenter_on_device(name: str, device: str, **kwargs) -> Segmenter:
50
  """Create a new segmenter instance on the specified device (no caching)."""
51
+ return _create_segmenter(name, device=device, **kwargs)