Spaces:
Running
Running
Zhen Ye
Claude Opus 4.6
committed on
Commit
·
1c6c619
1
Parent(s):
5749bd6
feat: add num_maskmem parameter to /benchmark endpoint
Browse files
Thread num_maskmem override through the full pipeline so benchmark
runs can test reduced memory bank sizes (e.g. 3 instead of default 7).
Patches the SAM2 video predictor's maskmem_tpos_enc at runtime.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app.py +4 -0
- inference.py +5 -1
- models/segmenters/grounded_sam2.py +43 -0
- models/segmenters/model_loader.py +2 -2
app.py
CHANGED
|
@@ -32,6 +32,7 @@ import uuid
|
|
| 32 |
from contextlib import asynccontextmanager
|
| 33 |
from datetime import timedelta
|
| 34 |
from pathlib import Path
|
|
|
|
| 35 |
|
| 36 |
import cv2
|
| 37 |
import numpy as np
|
|
@@ -857,6 +858,7 @@ async def benchmark_endpoint(
|
|
| 857 |
queries: str = Form("person,car,truck"),
|
| 858 |
segmenter: str = Form("gsam2_large"),
|
| 859 |
step: int = Form(20),
|
|
|
|
| 860 |
):
|
| 861 |
"""Run instrumented GSAM2 pipeline and return latency breakdown JSON.
|
| 862 |
|
|
@@ -898,6 +900,7 @@ async def benchmark_endpoint(
|
|
| 898 |
enable_gpt=False,
|
| 899 |
_perf_metrics=metrics,
|
| 900 |
_perf_lock=lock,
|
|
|
|
| 901 |
)
|
| 902 |
|
| 903 |
# Read frame count and fps from output video
|
|
@@ -915,6 +918,7 @@ async def benchmark_endpoint(
|
|
| 915 |
"total_frames": total_frames,
|
| 916 |
"fps": fps,
|
| 917 |
"num_gpus": num_gpus,
|
|
|
|
| 918 |
"metrics": metrics,
|
| 919 |
})
|
| 920 |
|
|
|
|
| 32 |
from contextlib import asynccontextmanager
|
| 33 |
from datetime import timedelta
|
| 34 |
from pathlib import Path
|
| 35 |
+
from typing import Optional
|
| 36 |
|
| 37 |
import cv2
|
| 38 |
import numpy as np
|
|
|
|
| 858 |
queries: str = Form("person,car,truck"),
|
| 859 |
segmenter: str = Form("gsam2_large"),
|
| 860 |
step: int = Form(20),
|
| 861 |
+
num_maskmem: Optional[int] = Form(None),
|
| 862 |
):
|
| 863 |
"""Run instrumented GSAM2 pipeline and return latency breakdown JSON.
|
| 864 |
|
|
|
|
| 900 |
enable_gpt=False,
|
| 901 |
_perf_metrics=metrics,
|
| 902 |
_perf_lock=lock,
|
| 903 |
+
num_maskmem=num_maskmem,
|
| 904 |
)
|
| 905 |
|
| 906 |
# Read frame count and fps from output video
|
|
|
|
| 918 |
"total_frames": total_frames,
|
| 919 |
"fps": fps,
|
| 920 |
"num_gpus": num_gpus,
|
| 921 |
+
"num_maskmem": num_maskmem if num_maskmem is not None else 7,
|
| 922 |
"metrics": metrics,
|
| 923 |
})
|
| 924 |
|
inference.py
CHANGED
|
@@ -1631,6 +1631,7 @@ def run_grounded_sam2_tracking(
|
|
| 1631 |
first_frame_gpt_results: Optional[Dict[str, Any]] = None,
|
| 1632 |
_perf_metrics: Optional[Dict[str, float]] = None,
|
| 1633 |
_perf_lock=None,
|
|
|
|
| 1634 |
) -> str:
|
| 1635 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1636 |
|
|
@@ -1679,7 +1680,8 @@ def run_grounded_sam2_tracking(
|
|
| 1679 |
if num_gpus <= 1:
|
| 1680 |
# ---------- Single-GPU fallback ----------
|
| 1681 |
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1682 |
-
|
|
|
|
| 1683 |
_check_cancellation(job_id)
|
| 1684 |
|
| 1685 |
if _perf_metrics is not None:
|
|
@@ -1710,11 +1712,13 @@ def run_grounded_sam2_tracking(
|
|
| 1710 |
# Phase 1: Load one segmenter per GPU (parallel)
|
| 1711 |
segmenters = []
|
| 1712 |
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
|
|
|
| 1713 |
futs = [
|
| 1714 |
pool.submit(
|
| 1715 |
load_segmenter_on_device,
|
| 1716 |
active_segmenter,
|
| 1717 |
f"cuda:{i}",
|
|
|
|
| 1718 |
)
|
| 1719 |
for i in range(num_gpus)
|
| 1720 |
]
|
|
|
|
| 1631 |
first_frame_gpt_results: Optional[Dict[str, Any]] = None,
|
| 1632 |
_perf_metrics: Optional[Dict[str, float]] = None,
|
| 1633 |
_perf_lock=None,
|
| 1634 |
+
num_maskmem: Optional[int] = None,
|
| 1635 |
) -> str:
|
| 1636 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1637 |
|
|
|
|
| 1680 |
if num_gpus <= 1:
|
| 1681 |
# ---------- Single-GPU fallback ----------
|
| 1682 |
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1683 |
+
_seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 1684 |
+
segmenter = load_segmenter_on_device(active_segmenter, device_str, **_seg_kw)
|
| 1685 |
_check_cancellation(job_id)
|
| 1686 |
|
| 1687 |
if _perf_metrics is not None:
|
|
|
|
| 1712 |
# Phase 1: Load one segmenter per GPU (parallel)
|
| 1713 |
segmenters = []
|
| 1714 |
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
| 1715 |
+
_seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 1716 |
futs = [
|
| 1717 |
pool.submit(
|
| 1718 |
load_segmenter_on_device,
|
| 1719 |
active_segmenter,
|
| 1720 |
f"cuda:{i}",
|
| 1721 |
+
**_seg_kw_multi,
|
| 1722 |
)
|
| 1723 |
for i in range(num_gpus)
|
| 1724 |
]
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -318,10 +318,12 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 318 |
device: Optional[str] = None,
|
| 319 |
step: int = 20,
|
| 320 |
iou_threshold: float = 0.5,
|
|
|
|
| 321 |
):
|
| 322 |
self.model_size = model_size
|
| 323 |
self.step = step
|
| 324 |
self.iou_threshold = iou_threshold
|
|
|
|
| 325 |
self.name = f"gsam2_{model_size}"
|
| 326 |
|
| 327 |
if device:
|
|
@@ -370,6 +372,11 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 370 |
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
|
| 371 |
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
# Reuse existing Grounding DINO detector from our codebase
|
| 374 |
from models.detectors.grounding_dino import GroundingDinoDetector
|
| 375 |
|
|
@@ -378,6 +385,42 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 378 |
self._models_loaded = True
|
| 379 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 382 |
|
| 383 |
def predict(
|
|
|
|
| 318 |
device: Optional[str] = None,
|
| 319 |
step: int = 20,
|
| 320 |
iou_threshold: float = 0.5,
|
| 321 |
+
num_maskmem: Optional[int] = None,
|
| 322 |
):
|
| 323 |
self.model_size = model_size
|
| 324 |
self.step = step
|
| 325 |
self.iou_threshold = iou_threshold
|
| 326 |
+
self.num_maskmem = num_maskmem # None = use default (7)
|
| 327 |
self.name = f"gsam2_{model_size}"
|
| 328 |
|
| 329 |
if device:
|
|
|
|
| 372 |
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
|
| 373 |
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
|
| 374 |
|
| 375 |
+
# Override num_maskmem if requested
|
| 376 |
+
if self.num_maskmem is not None:
|
| 377 |
+
self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
|
| 378 |
+
logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)
|
| 379 |
+
|
| 380 |
# Reuse existing Grounding DINO detector from our codebase
|
| 381 |
from models.detectors.grounding_dino import GroundingDinoDetector
|
| 382 |
|
|
|
|
| 385 |
self._models_loaded = True
|
| 386 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
| 387 |
|
| 388 |
+
@staticmethod
|
| 389 |
+
def _patch_num_maskmem(predictor, num_maskmem: int):
|
| 390 |
+
"""Override num_maskmem on a loaded SAM2 video predictor at runtime.
|
| 391 |
+
|
| 392 |
+
Slices the temporal positional encoding parameter to match the new
|
| 393 |
+
memory size so the model runs without shape mismatches.
|
| 394 |
+
"""
|
| 395 |
+
import torch.nn as nn
|
| 396 |
+
|
| 397 |
+
# The underlying model may be predictor itself or predictor.model
|
| 398 |
+
model = getattr(predictor, "model", predictor)
|
| 399 |
+
old = getattr(model, "num_maskmem", None)
|
| 400 |
+
if old is None:
|
| 401 |
+
logging.warning("Cannot patch num_maskmem: attribute not found on model")
|
| 402 |
+
return
|
| 403 |
+
if num_maskmem == old:
|
| 404 |
+
return
|
| 405 |
+
model.num_maskmem = num_maskmem
|
| 406 |
+
# Slice or pad maskmem_tpos_enc (shape: [num_maskmem, 1, 1, mem_dim])
|
| 407 |
+
if hasattr(model, "maskmem_tpos_enc") and model.maskmem_tpos_enc is not None:
|
| 408 |
+
old_enc = model.maskmem_tpos_enc
|
| 409 |
+
if num_maskmem <= old_enc.shape[0]:
|
| 410 |
+
model.maskmem_tpos_enc = nn.Parameter(
|
| 411 |
+
old_enc[:num_maskmem].clone()
|
| 412 |
+
)
|
| 413 |
+
else:
|
| 414 |
+
# Pad with zeros for the extra slots
|
| 415 |
+
pad = torch.zeros(
|
| 416 |
+
num_maskmem - old_enc.shape[0], *old_enc.shape[1:],
|
| 417 |
+
device=old_enc.device, dtype=old_enc.dtype,
|
| 418 |
+
)
|
| 419 |
+
model.maskmem_tpos_enc = nn.Parameter(
|
| 420 |
+
torch.cat([old_enc, pad], dim=0)
|
| 421 |
+
)
|
| 422 |
+
logging.info("num_maskmem changed from %d to %d", old, num_maskmem)
|
| 423 |
+
|
| 424 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 425 |
|
| 426 |
def predict(
|
models/segmenters/model_loader.py
CHANGED
|
@@ -46,6 +46,6 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
|
|
| 46 |
return _get_cached_segmenter(segmenter_name)
|
| 47 |
|
| 48 |
|
| 49 |
-
def load_segmenter_on_device(name: str, device: str) -> Segmenter:
|
| 50 |
"""Create a new segmenter instance on the specified device (no caching)."""
|
| 51 |
-
return _create_segmenter(name, device=device)
|
|
|
|
| 46 |
return _get_cached_segmenter(segmenter_name)
|
| 47 |
|
| 48 |
|
| 49 |
+
def load_segmenter_on_device(name: str, device: str, **kwargs) -> Segmenter:
|
| 50 |
"""Create a new segmenter instance on the specified device (no caching)."""
|
| 51 |
+
return _create_segmenter(name, device=device, **kwargs)
|