Spaces:
Paused
Paused
Zhen Ye commited on
Commit ·
21c29ae
1
Parent(s): 97b3a45
refactor(gsam2): make SAM2 detector-agnostic
Browse files- app.py +7 -7
- frontend/index.html +3 -3
- frontend/js/main.js +2 -2
- inference.py +6 -1
- jobs/background.py +3 -4
- models/segmenters/grounded_sam2.py +50 -68
- models/segmenters/model_loader.py +5 -5
- utils/roofline.py +6 -6
app.py
CHANGED
|
@@ -248,7 +248,7 @@ async def detect_endpoint(
|
|
| 248 |
mode: str = Form(...),
|
| 249 |
queries: str = Form(""),
|
| 250 |
detector: str = Form("hf_yolov8"),
|
| 251 |
-
segmenter: str = Form("
|
| 252 |
enable_depth: bool = Form(False),
|
| 253 |
enable_gpt: bool = Form(True),
|
| 254 |
):
|
|
@@ -260,7 +260,7 @@ async def detect_endpoint(
|
|
| 260 |
mode: Detection mode (object_detection, segmentation, drone_detection)
|
| 261 |
queries: Comma-separated object classes for object_detection mode
|
| 262 |
detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
|
| 263 |
-
segmenter: Segmentation model to use (
|
| 264 |
enable_depth: Whether to run legacy depth estimation (default: False)
|
| 265 |
drone_detection uses the dedicated drone_yolo model.
|
| 266 |
|
|
@@ -302,6 +302,7 @@ async def detect_endpoint(
|
|
| 302 |
output_path,
|
| 303 |
query_list,
|
| 304 |
segmenter_name=segmenter,
|
|
|
|
| 305 |
num_maskmem=7,
|
| 306 |
)
|
| 307 |
except ValueError as exc:
|
|
@@ -402,7 +403,7 @@ async def detect_async_endpoint(
|
|
| 402 |
mode: str = Form(...),
|
| 403 |
queries: str = Form(""),
|
| 404 |
detector: str = Form("hf_yolov8"),
|
| 405 |
-
segmenter: str = Form("
|
| 406 |
depth_estimator: str = Form("depth"),
|
| 407 |
depth_scale: float = Form(25.0),
|
| 408 |
enable_depth: bool = Form(False),
|
|
@@ -491,7 +492,6 @@ async def detect_async_endpoint(
|
|
| 491 |
)
|
| 492 |
cv2.imwrite(str(first_frame_path), processed_frame)
|
| 493 |
# GPT and depth are now handled in the async pipeline (enrichment thread)
|
| 494 |
-
depth_map = None
|
| 495 |
first_frame_gpt_results = None
|
| 496 |
except Exception:
|
| 497 |
logging.exception("First-frame processing failed.")
|
|
@@ -910,7 +910,7 @@ async def chat_threat_endpoint(
|
|
| 910 |
async def benchmark_endpoint(
|
| 911 |
video: UploadFile = File(...),
|
| 912 |
queries: str = Form("person,car,truck"),
|
| 913 |
-
segmenter: str = Form("
|
| 914 |
step: int = Form(60),
|
| 915 |
num_maskmem: Optional[int] = Form(None),
|
| 916 |
):
|
|
@@ -1036,7 +1036,7 @@ async def benchmark_profile(
|
|
| 1036 |
video: UploadFile = File(...),
|
| 1037 |
mode: str = Form("detection"),
|
| 1038 |
detector: str = Form("hf_yolov8"),
|
| 1039 |
-
segmenter: str = Form("
|
| 1040 |
queries: str = Form("person,car,truck"),
|
| 1041 |
max_frames: int = Form(100),
|
| 1042 |
warmup_frames: int = Form(5),
|
|
@@ -1102,7 +1102,7 @@ async def benchmark_analysis(
|
|
| 1102 |
video: UploadFile = File(...),
|
| 1103 |
mode: str = Form("detection"),
|
| 1104 |
detector: str = Form("hf_yolov8"),
|
| 1105 |
-
segmenter: str = Form("
|
| 1106 |
queries: str = Form("person,car,truck"),
|
| 1107 |
max_frames: int = Form(100),
|
| 1108 |
warmup_frames: int = Form(5),
|
|
|
|
| 248 |
mode: str = Form(...),
|
| 249 |
queries: str = Form(""),
|
| 250 |
detector: str = Form("hf_yolov8"),
|
| 251 |
+
segmenter: str = Form("GSAM2-L"),
|
| 252 |
enable_depth: bool = Form(False),
|
| 253 |
enable_gpt: bool = Form(True),
|
| 254 |
):
|
|
|
|
| 260 |
mode: Detection mode (object_detection, segmentation, drone_detection)
|
| 261 |
queries: Comma-separated object classes for object_detection mode
|
| 262 |
detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
|
| 263 |
+
segmenter: Segmentation model to use (GSAM2-S, GSAM2-B, GSAM2-L)
|
| 264 |
enable_depth: Whether to run legacy depth estimation (default: False)
|
| 265 |
drone_detection uses the dedicated drone_yolo model.
|
| 266 |
|
|
|
|
| 302 |
output_path,
|
| 303 |
query_list,
|
| 304 |
segmenter_name=segmenter,
|
| 305 |
+
detector_name="grounding_dino",
|
| 306 |
num_maskmem=7,
|
| 307 |
)
|
| 308 |
except ValueError as exc:
|
|
|
|
| 403 |
mode: str = Form(...),
|
| 404 |
queries: str = Form(""),
|
| 405 |
detector: str = Form("hf_yolov8"),
|
| 406 |
+
segmenter: str = Form("GSAM2-L"),
|
| 407 |
depth_estimator: str = Form("depth"),
|
| 408 |
depth_scale: float = Form(25.0),
|
| 409 |
enable_depth: bool = Form(False),
|
|
|
|
| 492 |
)
|
| 493 |
cv2.imwrite(str(first_frame_path), processed_frame)
|
| 494 |
# GPT and depth are now handled in the async pipeline (enrichment thread)
|
|
|
|
| 495 |
first_frame_gpt_results = None
|
| 496 |
except Exception:
|
| 497 |
logging.exception("First-frame processing failed.")
|
|
|
|
| 910 |
async def benchmark_endpoint(
|
| 911 |
video: UploadFile = File(...),
|
| 912 |
queries: str = Form("person,car,truck"),
|
| 913 |
+
segmenter: str = Form("GSAM2-L"),
|
| 914 |
step: int = Form(60),
|
| 915 |
num_maskmem: Optional[int] = Form(None),
|
| 916 |
):
|
|
|
|
| 1036 |
video: UploadFile = File(...),
|
| 1037 |
mode: str = Form("detection"),
|
| 1038 |
detector: str = Form("hf_yolov8"),
|
| 1039 |
+
segmenter: str = Form("GSAM2-L"),
|
| 1040 |
queries: str = Form("person,car,truck"),
|
| 1041 |
max_frames: int = Form(100),
|
| 1042 |
warmup_frames: int = Form(5),
|
|
|
|
| 1102 |
video: UploadFile = File(...),
|
| 1103 |
mode: str = Form("detection"),
|
| 1104 |
detector: str = Form("hf_yolov8"),
|
| 1105 |
+
segmenter: str = Form("GSAM2-L"),
|
| 1106 |
queries: str = Form("person,car,truck"),
|
| 1107 |
max_frames: int = Form(100),
|
| 1108 |
warmup_frames: int = Form(5),
|
frontend/index.html
CHANGED
|
@@ -75,9 +75,9 @@
|
|
| 75 |
<option value="grounding_dino" data-kind="object">Large</option>
|
| 76 |
</optgroup>
|
| 77 |
<optgroup label="Segmentation Models">
|
| 78 |
-
<option value="
|
| 79 |
-
<option value="
|
| 80 |
-
<option value="
|
| 81 |
</optgroup>
|
| 82 |
<optgroup label="Drone Detection Models">
|
| 83 |
<option value="drone_yolo" data-kind="drone">Drone</option>
|
|
|
|
| 75 |
<option value="grounding_dino" data-kind="object">Large</option>
|
| 76 |
</optgroup>
|
| 77 |
<optgroup label="Segmentation Models">
|
| 78 |
+
<option value="GSAM2-L" data-kind="segmentation">GSAM2-L</option>
|
| 79 |
+
<option value="GSAM2-B" data-kind="segmentation">GSAM2-B</option>
|
| 80 |
+
<option value="GSAM2-S" data-kind="segmentation">GSAM2-S</option>
|
| 81 |
</optgroup>
|
| 82 |
<optgroup label="Drone Detection Models">
|
| 83 |
<option value="drone_yolo" data-kind="drone">Drone</option>
|
frontend/js/main.js
CHANGED
|
@@ -363,11 +363,11 @@ document.addEventListener("DOMContentLoaded", () => {
|
|
| 363 |
} else if (kind === "drone") {
|
| 364 |
mode = "drone_detection";
|
| 365 |
detectorParam = selectedValue;
|
| 366 |
-
segmenterParam = "
|
| 367 |
} else {
|
| 368 |
mode = "object_detection";
|
| 369 |
detectorParam = selectedValue;
|
| 370 |
-
segmenterParam = "
|
| 371 |
}
|
| 372 |
|
| 373 |
const form = new FormData();
|
|
|
|
| 363 |
} else if (kind === "drone") {
|
| 364 |
mode = "drone_detection";
|
| 365 |
detectorParam = selectedValue;
|
| 366 |
+
segmenterParam = "GSAM2-L";
|
| 367 |
} else {
|
| 368 |
mode = "object_detection";
|
| 369 |
detectorParam = selectedValue;
|
| 370 |
+
segmenterParam = "GSAM2-L";
|
| 371 |
}
|
| 372 |
|
| 373 |
const form = new FormData();
|
inference.py
CHANGED
|
@@ -1631,6 +1631,7 @@ def run_grounded_sam2_tracking(
|
|
| 1631 |
_perf_metrics: Optional[Dict[str, float]] = None,
|
| 1632 |
_perf_lock=None,
|
| 1633 |
num_maskmem: Optional[int] = None,
|
|
|
|
| 1634 |
) -> str:
|
| 1635 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1636 |
|
|
@@ -1645,7 +1646,7 @@ def run_grounded_sam2_tracking(
|
|
| 1645 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1646 |
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects
|
| 1647 |
|
| 1648 |
-
active_segmenter = segmenter_name or "
|
| 1649 |
logging.info(
|
| 1650 |
"Grounded-SAM-2 tracking: segmenter=%s, queries=%s, step=%d",
|
| 1651 |
active_segmenter, queries, step,
|
|
@@ -2120,6 +2121,8 @@ def run_grounded_sam2_tracking(
|
|
| 2120 |
# ---------- Single-GPU fallback ----------
|
| 2121 |
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 2122 |
_seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
|
|
|
|
|
|
| 2123 |
|
| 2124 |
if _perf_metrics is not None:
|
| 2125 |
_t_load = time.perf_counter()
|
|
@@ -2176,6 +2179,8 @@ def run_grounded_sam2_tracking(
|
|
| 2176 |
segmenters = []
|
| 2177 |
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
| 2178 |
_seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
|
|
|
|
|
|
| 2179 |
futs = [
|
| 2180 |
pool.submit(
|
| 2181 |
load_segmenter_on_device,
|
|
|
|
| 1631 |
_perf_metrics: Optional[Dict[str, float]] = None,
|
| 1632 |
_perf_lock=None,
|
| 1633 |
num_maskmem: Optional[int] = None,
|
| 1634 |
+
detector_name: Optional[str] = None,
|
| 1635 |
) -> str:
|
| 1636 |
"""Run Grounded-SAM-2 video tracking pipeline.
|
| 1637 |
|
|
|
|
| 1646 |
from utils.video import extract_frames_to_jpeg_dir
|
| 1647 |
from models.segmenters.grounded_sam2 import MaskDictionary, ObjectInfo, LazyFrameObjects
|
| 1648 |
|
| 1649 |
+
active_segmenter = segmenter_name or "GSAM2-L"
|
| 1650 |
logging.info(
|
| 1651 |
"Grounded-SAM-2 tracking: segmenter=%s, queries=%s, step=%d",
|
| 1652 |
active_segmenter, queries, step,
|
|
|
|
| 2121 |
# ---------- Single-GPU fallback ----------
|
| 2122 |
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 2123 |
_seg_kw = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 2124 |
+
if detector_name is not None:
|
| 2125 |
+
_seg_kw["detector_name"] = detector_name
|
| 2126 |
|
| 2127 |
if _perf_metrics is not None:
|
| 2128 |
_t_load = time.perf_counter()
|
|
|
|
| 2179 |
segmenters = []
|
| 2180 |
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
|
| 2181 |
_seg_kw_multi = {"num_maskmem": num_maskmem} if num_maskmem is not None else {}
|
| 2182 |
+
if detector_name is not None:
|
| 2183 |
+
_seg_kw_multi["detector_name"] = detector_name
|
| 2184 |
futs = [
|
| 2185 |
pool.submit(
|
| 2186 |
load_segmenter_on_device,
|
jobs/background.py
CHANGED
|
@@ -2,12 +2,10 @@ import asyncio
|
|
| 2 |
import logging
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
from jobs.models import JobStatus
|
| 8 |
-
from jobs.storage import get_job_storage
|
| 9 |
from jobs.streaming import create_stream, remove_stream
|
| 10 |
-
from inference import run_inference, run_grounded_sam2_tracking
|
| 11 |
|
| 12 |
|
| 13 |
async def process_video_async(job_id: str) -> None:
|
|
@@ -41,6 +39,7 @@ async def process_video_async(job_id: str) -> None:
|
|
| 41 |
mission_spec=job.mission_spec,
|
| 42 |
first_frame_gpt_results=job.first_frame_gpt_results,
|
| 43 |
num_maskmem=7,
|
|
|
|
| 44 |
)
|
| 45 |
else:
|
| 46 |
detections_list = None
|
|
|
|
| 2 |
import logging
|
| 3 |
from datetime import datetime
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from jobs.models import JobStatus
|
| 6 |
+
from jobs.storage import get_job_storage
|
| 7 |
from jobs.streaming import create_stream, remove_stream
|
| 8 |
+
from inference import run_inference, run_grounded_sam2_tracking
|
| 9 |
|
| 10 |
|
| 11 |
async def process_video_async(job_id: str) -> None:
|
|
|
|
| 39 |
mission_spec=job.mission_spec,
|
| 40 |
first_frame_gpt_results=job.first_frame_gpt_results,
|
| 41 |
num_maskmem=7,
|
| 42 |
+
detector_name=job.detector_name,
|
| 43 |
)
|
| 44 |
else:
|
| 45 |
detections_list = None
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""Grounded-SAM-2 segmenter with continuous-ID video tracking.
|
| 2 |
|
| 3 |
-
Combines
|
| 4 |
predictor to produce temporally consistent segmentation masks with
|
| 5 |
persistent object IDs across an entire video.
|
| 6 |
|
|
@@ -13,7 +13,7 @@ import logging
|
|
| 13 |
import time
|
| 14 |
from contextlib import nullcontext
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
-
from typing import Any, Callable, Dict, List, Optional,
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
@@ -308,15 +308,26 @@ _SAM2_HF_MODELS = {
|
|
| 308 |
}
|
| 309 |
|
| 310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
# ---------------------------------------------------------------------------
|
| 312 |
# Grounded-SAM-2 Segmenter
|
| 313 |
# ---------------------------------------------------------------------------
|
| 314 |
|
| 315 |
class GroundedSAM2Segmenter(Segmenter):
|
| 316 |
-
"""SAM2 video segmenter driven by
|
| 317 |
|
| 318 |
-
|
| 319 |
-
For
|
|
|
|
| 320 |
predictor for temporal mask propagation with continuous object IDs.
|
| 321 |
"""
|
| 322 |
|
|
@@ -330,12 +341,15 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 330 |
step: int = 20,
|
| 331 |
iou_threshold: float = 0.5,
|
| 332 |
num_maskmem: Optional[int] = None,
|
|
|
|
| 333 |
):
|
| 334 |
self.model_size = model_size
|
| 335 |
self.step = step
|
| 336 |
self.iou_threshold = iou_threshold
|
| 337 |
self.num_maskmem = num_maskmem # None = use default (7)
|
| 338 |
-
self.
|
|
|
|
|
|
|
| 339 |
|
| 340 |
if device:
|
| 341 |
self.device = device
|
|
@@ -345,7 +359,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 345 |
# Lazy-loaded model handles
|
| 346 |
self._video_predictor = None
|
| 347 |
self._image_predictor = None
|
| 348 |
-
self.
|
| 349 |
self._models_loaded = False
|
| 350 |
|
| 351 |
# -- Lazy loading -------------------------------------------------------
|
|
@@ -388,10 +402,11 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 388 |
self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
|
| 389 |
logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)
|
| 390 |
|
| 391 |
-
#
|
| 392 |
-
from models.
|
| 393 |
|
| 394 |
-
|
|
|
|
| 395 |
|
| 396 |
self._models_loaded = True
|
| 397 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
|
@@ -476,13 +491,13 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 476 |
def predict(
|
| 477 |
self, frame: np.ndarray, text_prompts: Optional[list] = None
|
| 478 |
) -> SegmentationResult:
|
| 479 |
-
"""Run
|
| 480 |
self._ensure_models_loaded()
|
| 481 |
|
| 482 |
prompts = text_prompts or ["object"]
|
| 483 |
|
| 484 |
-
# Run
|
| 485 |
-
det = self.
|
| 486 |
if det.boxes is None or len(det.boxes) == 0:
|
| 487 |
return SegmentationResult(
|
| 488 |
masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
|
|
@@ -539,11 +554,11 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 539 |
image: "Image",
|
| 540 |
text_prompts: List[str],
|
| 541 |
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], List[str]]:
|
| 542 |
-
"""Run
|
| 543 |
|
| 544 |
Args:
|
| 545 |
image: PIL Image in RGB mode.
|
| 546 |
-
text_prompts: Text queries for
|
| 547 |
|
| 548 |
Returns:
|
| 549 |
``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)``
|
|
@@ -554,26 +569,12 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 554 |
self._ensure_models_loaded()
|
| 555 |
_pm = getattr(self, '_perf_metrics', None)
|
| 556 |
|
| 557 |
-
prompt = self._gdino_detector._build_prompt(text_prompts)
|
| 558 |
-
gdino_processor = self._gdino_detector.processor
|
| 559 |
-
gdino_model = self._gdino_detector.model
|
| 560 |
-
|
| 561 |
if _pm is not None:
|
| 562 |
_t0 = time.perf_counter()
|
| 563 |
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
)
|
| 567 |
-
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 568 |
-
|
| 569 |
-
with torch.no_grad():
|
| 570 |
-
outputs = gdino_model(**inputs)
|
| 571 |
-
|
| 572 |
-
results = self._gdino_detector._post_process(
|
| 573 |
-
outputs,
|
| 574 |
-
inputs["input_ids"],
|
| 575 |
-
target_sizes=[image.size[::-1]],
|
| 576 |
-
)
|
| 577 |
|
| 578 |
if _pm is not None:
|
| 579 |
_pl = getattr(self, '_perf_lock', None)
|
|
@@ -583,21 +584,18 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 583 |
else:
|
| 584 |
_pm["gdino_total_ms"] += _d
|
| 585 |
|
| 586 |
-
|
| 587 |
-
det_labels = results[0].get("text_labels") or results[0].get("labels", [])
|
| 588 |
-
if torch.is_tensor(det_labels):
|
| 589 |
-
det_labels = det_labels.detach().cpu().tolist()
|
| 590 |
-
det_labels = [str(l) for l in det_labels]
|
| 591 |
-
|
| 592 |
-
if input_boxes.shape[0] == 0:
|
| 593 |
return None, None, []
|
| 594 |
|
|
|
|
|
|
|
|
|
|
| 595 |
# SAM2 image predictor
|
| 596 |
if _pm is not None:
|
| 597 |
_t1 = time.perf_counter()
|
| 598 |
|
| 599 |
self._image_predictor.set_image(np.array(image))
|
| 600 |
-
masks,
|
| 601 |
|
| 602 |
if _pm is not None:
|
| 603 |
_pl = getattr(self, '_perf_lock', None)
|
|
@@ -721,7 +719,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 721 |
Args:
|
| 722 |
frame_dir: Directory containing JPEG frames.
|
| 723 |
frame_names: Sorted list of frame filenames.
|
| 724 |
-
text_prompts: Text queries for
|
| 725 |
on_segment: Optional callback invoked after each segment completes.
|
| 726 |
Receives ``{frame_idx: {obj_id: ObjectInfo}}`` for the segment.
|
| 727 |
|
|
@@ -735,11 +733,6 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 735 |
|
| 736 |
device = self.device
|
| 737 |
step = self.step
|
| 738 |
-
prompt = self._gdino_detector._build_prompt(text_prompts)
|
| 739 |
-
|
| 740 |
-
# HF processor for Grounding DINO (reuse from our detector)
|
| 741 |
-
gdino_processor = self._gdino_detector.processor
|
| 742 |
-
gdino_model = self._gdino_detector.model
|
| 743 |
|
| 744 |
total_frames = len(frame_names)
|
| 745 |
logging.info(
|
|
@@ -783,24 +776,12 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 783 |
|
| 784 |
mask_dict = MaskDictionary()
|
| 785 |
|
| 786 |
-
# --
|
| 787 |
if _pm is not None:
|
| 788 |
_t_gd = time.perf_counter()
|
| 789 |
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
)
|
| 793 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 794 |
-
|
| 795 |
-
with torch.no_grad():
|
| 796 |
-
outputs = gdino_model(**inputs)
|
| 797 |
-
|
| 798 |
-
# Use GDINO detector's _post_process for transformers version compat
|
| 799 |
-
results = self._gdino_detector._post_process(
|
| 800 |
-
outputs,
|
| 801 |
-
inputs["input_ids"],
|
| 802 |
-
target_sizes=[image.size[::-1]],
|
| 803 |
-
)
|
| 804 |
|
| 805 |
if _pm is not None:
|
| 806 |
_pl = getattr(self, '_perf_lock', None)
|
|
@@ -810,13 +791,14 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 810 |
else:
|
| 811 |
_pm["gdino_total_ms"] += _d
|
| 812 |
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
|
|
|
| 818 |
|
| 819 |
-
if input_boxes
|
| 820 |
logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
|
| 821 |
# Fill empty results for this segment
|
| 822 |
seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
|
@@ -842,7 +824,7 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 842 |
_t_si = time.perf_counter()
|
| 843 |
|
| 844 |
self._image_predictor.set_image(np.array(image))
|
| 845 |
-
masks,
|
| 846 |
|
| 847 |
if _pm is not None:
|
| 848 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 1 |
"""Grounded-SAM-2 segmenter with continuous-ID video tracking.
|
| 2 |
|
| 3 |
+
Combines an object detector (open-vocabulary or closed-set) with SAM2's video
|
| 4 |
predictor to produce temporally consistent segmentation masks with
|
| 5 |
persistent object IDs across an entire video.
|
| 6 |
|
|
|
|
| 13 |
import time
|
| 14 |
from contextlib import nullcontext
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
import torch
|
|
|
|
| 308 |
}
|
| 309 |
|
| 310 |
|
| 311 |
+
def _det_label_names(det) -> List[str]:
|
| 312 |
+
"""Extract string labels from a DetectionResult, with fallback."""
|
| 313 |
+
num_boxes = len(det.boxes) if det.boxes is not None else 0
|
| 314 |
+
if det.label_names is not None and len(det.label_names) > 0:
|
| 315 |
+
return list(det.label_names)
|
| 316 |
+
if det.labels is not None and len(det.labels) > 0:
|
| 317 |
+
return [str(l) for l in det.labels]
|
| 318 |
+
return ["object"] * num_boxes
|
| 319 |
+
|
| 320 |
+
|
| 321 |
# ---------------------------------------------------------------------------
|
| 322 |
# Grounded-SAM-2 Segmenter
|
| 323 |
# ---------------------------------------------------------------------------
|
| 324 |
|
| 325 |
class GroundedSAM2Segmenter(Segmenter):
|
| 326 |
+
"""SAM2 video segmenter driven by an injected object detector.
|
| 327 |
|
| 328 |
+
Any ``ObjectDetector`` can be used (defaults to Grounding DINO).
|
| 329 |
+
For single-frame mode (``predict``), uses detector + SAM2 image predictor.
|
| 330 |
+
For video mode (``process_video``), uses detector on keyframes + SAM2 video
|
| 331 |
predictor for temporal mask propagation with continuous object IDs.
|
| 332 |
"""
|
| 333 |
|
|
|
|
| 341 |
step: int = 20,
|
| 342 |
iou_threshold: float = 0.5,
|
| 343 |
num_maskmem: Optional[int] = None,
|
| 344 |
+
detector_name: Optional[str] = None,
|
| 345 |
):
|
| 346 |
self.model_size = model_size
|
| 347 |
self.step = step
|
| 348 |
self.iou_threshold = iou_threshold
|
| 349 |
self.num_maskmem = num_maskmem # None = use default (7)
|
| 350 |
+
self._detector_name = detector_name # None = "grounding_dino"
|
| 351 |
+
_size_suffix = {"small": "S", "base": "B", "large": "L"}
|
| 352 |
+
self.name = f"GSAM2-{_size_suffix[model_size]}"
|
| 353 |
|
| 354 |
if device:
|
| 355 |
self.device = device
|
|
|
|
| 359 |
# Lazy-loaded model handles
|
| 360 |
self._video_predictor = None
|
| 361 |
self._image_predictor = None
|
| 362 |
+
self._detector = None
|
| 363 |
self._models_loaded = False
|
| 364 |
|
| 365 |
# -- Lazy loading -------------------------------------------------------
|
|
|
|
| 402 |
self._patch_num_maskmem(self._video_predictor, self.num_maskmem)
|
| 403 |
logging.info("Patched video predictor num_maskmem → %d", self.num_maskmem)
|
| 404 |
|
| 405 |
+
# Load detector by name (defaults to Grounding DINO)
|
| 406 |
+
from models.model_loader import load_detector_on_device
|
| 407 |
|
| 408 |
+
det_name = self._detector_name or "grounding_dino"
|
| 409 |
+
self._detector = load_detector_on_device(det_name, self.device)
|
| 410 |
|
| 411 |
self._models_loaded = True
|
| 412 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
|
|
|
| 491 |
def predict(
|
| 492 |
self, frame: np.ndarray, text_prompts: Optional[list] = None
|
| 493 |
) -> SegmentationResult:
|
| 494 |
+
"""Run detector + SAM2 image predictor on a single frame."""
|
| 495 |
self._ensure_models_loaded()
|
| 496 |
|
| 497 |
prompts = text_prompts or ["object"]
|
| 498 |
|
| 499 |
+
# Run detector to get boxes
|
| 500 |
+
det = self._detector.predict(frame, prompts)
|
| 501 |
if det.boxes is None or len(det.boxes) == 0:
|
| 502 |
return SegmentationResult(
|
| 503 |
masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
|
|
|
|
| 554 |
image: "Image",
|
| 555 |
text_prompts: List[str],
|
| 556 |
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], List[str]]:
|
| 557 |
+
"""Run detector + SAM2 image predictor on a single keyframe.
|
| 558 |
|
| 559 |
Args:
|
| 560 |
image: PIL Image in RGB mode.
|
| 561 |
+
text_prompts: Text queries for the detector.
|
| 562 |
|
| 563 |
Returns:
|
| 564 |
``(masks, boxes, labels)`` where *masks* is an ``(N, H, W)``
|
|
|
|
| 569 |
self._ensure_models_loaded()
|
| 570 |
_pm = getattr(self, '_perf_metrics', None)
|
| 571 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
if _pm is not None:
|
| 573 |
_t0 = time.perf_counter()
|
| 574 |
|
| 575 |
+
# Convert PIL RGB → numpy BGR for detector.predict()
|
| 576 |
+
frame_bgr = np.array(image)[:, :, ::-1].copy()
|
| 577 |
+
det = self._detector.predict(frame_bgr, text_prompts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
|
| 579 |
if _pm is not None:
|
| 580 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 584 |
else:
|
| 585 |
_pm["gdino_total_ms"] += _d
|
| 586 |
|
| 587 |
+
if det.boxes is None or len(det.boxes) == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
return None, None, []
|
| 589 |
|
| 590 |
+
input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
|
| 591 |
+
det_labels = _det_label_names(det)
|
| 592 |
+
|
| 593 |
# SAM2 image predictor
|
| 594 |
if _pm is not None:
|
| 595 |
_t1 = time.perf_counter()
|
| 596 |
|
| 597 |
self._image_predictor.set_image(np.array(image))
|
| 598 |
+
masks, _ = self._predict_masks_gpu(input_boxes)
|
| 599 |
|
| 600 |
if _pm is not None:
|
| 601 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 719 |
Args:
|
| 720 |
frame_dir: Directory containing JPEG frames.
|
| 721 |
frame_names: Sorted list of frame filenames.
|
| 722 |
+
text_prompts: Text queries for the detector.
|
| 723 |
on_segment: Optional callback invoked after each segment completes.
|
| 724 |
Receives ``{frame_idx: {obj_id: ObjectInfo}}`` for the segment.
|
| 725 |
|
|
|
|
| 733 |
|
| 734 |
device = self.device
|
| 735 |
step = self.step
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
|
| 737 |
total_frames = len(frame_names)
|
| 738 |
logging.info(
|
|
|
|
| 776 |
|
| 777 |
mask_dict = MaskDictionary()
|
| 778 |
|
| 779 |
+
# -- Detector on keyframe --
|
| 780 |
if _pm is not None:
|
| 781 |
_t_gd = time.perf_counter()
|
| 782 |
|
| 783 |
+
frame_bgr = np.array(image)[:, :, ::-1].copy()
|
| 784 |
+
det = self._detector.predict(frame_bgr, text_prompts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 785 |
|
| 786 |
if _pm is not None:
|
| 787 |
_pl = getattr(self, '_perf_lock', None)
|
|
|
|
| 791 |
else:
|
| 792 |
_pm["gdino_total_ms"] += _d
|
| 793 |
|
| 794 |
+
if det.boxes is None or len(det.boxes) == 0:
|
| 795 |
+
input_boxes = torch.zeros((0, 4), device=device)
|
| 796 |
+
det_labels = []
|
| 797 |
+
else:
|
| 798 |
+
input_boxes = torch.tensor(det.boxes, device=device, dtype=torch.float32)
|
| 799 |
+
det_labels = _det_label_names(det)
|
| 800 |
|
| 801 |
+
if len(input_boxes) == 0:
|
| 802 |
logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
|
| 803 |
# Fill empty results for this segment
|
| 804 |
seg_results: Dict[int, Dict[int, ObjectInfo]] = {}
|
|
|
|
| 824 |
_t_si = time.perf_counter()
|
| 825 |
|
| 826 |
self._image_predictor.set_image(np.array(image))
|
| 827 |
+
masks, _ = self._predict_masks_gpu(input_boxes)
|
| 828 |
|
| 829 |
if _pm is not None:
|
| 830 |
_pl = getattr(self, '_perf_lock', None)
|
models/segmenters/model_loader.py
CHANGED
|
@@ -5,12 +5,12 @@ from typing import Callable, Dict, Optional
|
|
| 5 |
from .base import Segmenter
|
| 6 |
from .grounded_sam2 import GroundedSAM2Segmenter
|
| 7 |
|
| 8 |
-
DEFAULT_SEGMENTER = "
|
| 9 |
|
| 10 |
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
|
| 11 |
-
"
|
| 12 |
-
"
|
| 13 |
-
"
|
| 14 |
}
|
| 15 |
|
| 16 |
|
|
@@ -37,7 +37,7 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
|
|
| 37 |
Load a segmenter by name.
|
| 38 |
|
| 39 |
Args:
|
| 40 |
-
name: Segmenter name (default:
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
Cached segmenter instance
|
|
|
|
| 5 |
from .base import Segmenter
|
| 6 |
from .grounded_sam2 import GroundedSAM2Segmenter
|
| 7 |
|
| 8 |
+
DEFAULT_SEGMENTER = "GSAM2-L"
|
| 9 |
|
| 10 |
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
|
| 11 |
+
"GSAM2-S": lambda **kw: GroundedSAM2Segmenter(model_size="small", **kw),
|
| 12 |
+
"GSAM2-B": lambda **kw: GroundedSAM2Segmenter(model_size="base", **kw),
|
| 13 |
+
"GSAM2-L": lambda **kw: GroundedSAM2Segmenter(model_size="large", **kw),
|
| 14 |
}
|
| 15 |
|
| 16 |
|
|
|
|
| 37 |
Load a segmenter by name.
|
| 38 |
|
| 39 |
Args:
|
| 40 |
+
name: Segmenter name (default: GSAM2-L)
|
| 41 |
|
| 42 |
Returns:
|
| 43 |
Cached segmenter instance
|
utils/roofline.py
CHANGED
|
@@ -21,9 +21,9 @@ _MODEL_FLOPS: Dict[str, float] = {
|
|
| 21 |
"drone_yolo": 78.9, # Same arch as YOLOv8m
|
| 22 |
|
| 23 |
# Segmentation models (GFLOPs per keyframe)
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
-
"
|
| 27 |
"gsam2_tiny": 24.0, # SAM2 tiny encoder
|
| 28 |
}
|
| 29 |
|
|
@@ -34,9 +34,9 @@ _MODEL_BYTES: Dict[str, float] = {
|
|
| 34 |
"detr_resnet50": 166.0,
|
| 35 |
"grounding_dino": 340.0,
|
| 36 |
"drone_yolo": 52.0,
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
"gsam2_tiny": 46.0,
|
| 41 |
}
|
| 42 |
|
|
|
|
| 21 |
"drone_yolo": 78.9, # Same arch as YOLOv8m
|
| 22 |
|
| 23 |
# Segmentation models (GFLOPs per keyframe)
|
| 24 |
+
"GSAM2-S": 48.0, # SAM2 small encoder
|
| 25 |
+
"GSAM2-B": 96.0, # SAM2 base encoder
|
| 26 |
+
"GSAM2-L": 200.0, # SAM2 large encoder
|
| 27 |
"gsam2_tiny": 24.0, # SAM2 tiny encoder
|
| 28 |
}
|
| 29 |
|
|
|
|
| 34 |
"detr_resnet50": 166.0,
|
| 35 |
"grounding_dino": 340.0,
|
| 36 |
"drone_yolo": 52.0,
|
| 37 |
+
"GSAM2-S": 92.0,
|
| 38 |
+
"GSAM2-B": 180.0,
|
| 39 |
+
"GSAM2-L": 400.0,
|
| 40 |
"gsam2_tiny": 46.0,
|
| 41 |
}
|
| 42 |
|