Zhen Ye commited on
Commit
3fde4e4
·
1 Parent(s): 04d4562

Remove SAM3 and standardize segmentation on Grounded-SAM2

Browse files
Dockerfile CHANGED
@@ -19,7 +19,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
19
  && rm -rf /var/lib/apt/lists/* \
20
  && pip install --no-cache-dir --upgrade pip \
21
  && pip install --no-cache-dir -r requirements.txt \
22
- && python -c "import transformers; print('transformers', transformers.__version__); print('has Sam3Model', hasattr(transformers, 'Sam3Model'))"
23
 
24
  COPY --chown=user . .
25
 
 
19
  && rm -rf /var/lib/apt/lists/* \
20
  && pip install --no-cache-dir --upgrade pip \
21
  && pip install --no-cache-dir -r requirements.txt \
22
+ && python -c "import transformers; print('transformers', transformers.__version__); print('has Sam2Model', hasattr(transformers, 'Sam2Model'))"
23
 
24
  COPY --chown=user . .
25
 
app.py CHANGED
@@ -41,7 +41,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, Redirect
41
  from fastapi.staticfiles import StaticFiles
42
  import uvicorn
43
 
44
- from inference import process_first_frame, run_inference, run_segmentation
45
  from models.depth_estimators.model_loader import list_depth_estimators
46
  from jobs.background import process_video_async
47
  from jobs.models import JobInfo, JobStatus
@@ -268,7 +268,7 @@ async def detect_endpoint(
268
  mode: str = Form(...),
269
  queries: str = Form(""),
270
  detector: str = Form("hf_yolov8"),
271
- segmenter: str = Form("sam3"),
272
  enable_depth: bool = Form(False),
273
  enable_gpt: bool = Form(True),
274
  ):
@@ -280,7 +280,7 @@ async def detect_endpoint(
280
  mode: Detection mode (object_detection, segmentation, drone_detection)
281
  queries: Comma-separated object classes for object_detection mode
282
  detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
283
- segmenter: Segmentation model to use (sam3)
284
  enable_depth: Whether to run legacy depth estimation (default: False)
285
  drone_detection uses the dedicated drone_yolo model.
286
 
@@ -317,7 +317,7 @@ async def detect_endpoint(
317
  query_list = ["object"]
318
 
319
  try:
320
- output_path = run_segmentation(
321
  input_path,
322
  output_path,
323
  query_list,
@@ -421,7 +421,7 @@ async def detect_async_endpoint(
421
  mode: str = Form(...),
422
  queries: str = Form(""),
423
  detector: str = Form("hf_yolov8"),
424
- segmenter: str = Form("sam3"),
425
  depth_estimator: str = Form("depth"),
426
  depth_scale: float = Form(25.0),
427
  enable_depth: bool = Form(False),
 
41
  from fastapi.staticfiles import StaticFiles
42
  import uvicorn
43
 
44
+ from inference import process_first_frame, run_inference, run_grounded_sam2_tracking
45
  from models.depth_estimators.model_loader import list_depth_estimators
46
  from jobs.background import process_video_async
47
  from jobs.models import JobInfo, JobStatus
 
268
  mode: str = Form(...),
269
  queries: str = Form(""),
270
  detector: str = Form("hf_yolov8"),
271
+ segmenter: str = Form("gsam2_large"),
272
  enable_depth: bool = Form(False),
273
  enable_gpt: bool = Form(True),
274
  ):
 
280
  mode: Detection mode (object_detection, segmentation, drone_detection)
281
  queries: Comma-separated object classes for object_detection mode
282
  detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
283
+ segmenter: Segmentation model to use (gsam2_small, gsam2_base, gsam2_large)
284
  enable_depth: Whether to run legacy depth estimation (default: False)
285
  drone_detection uses the dedicated drone_yolo model.
286
 
 
317
  query_list = ["object"]
318
 
319
  try:
320
+ output_path = run_grounded_sam2_tracking(
321
  input_path,
322
  output_path,
323
  query_list,
 
421
  mode: str = Form(...),
422
  queries: str = Form(""),
423
  detector: str = Form("hf_yolov8"),
424
+ segmenter: str = Form("gsam2_large"),
425
  depth_estimator: str = Form("depth"),
426
  depth_scale: float = Form(25.0),
427
  enable_depth: bool = Form(False),
frontend/index.html CHANGED
@@ -75,7 +75,9 @@
75
  <option value="grounding_dino" data-kind="object">Large</option>
76
  </optgroup>
77
  <optgroup label="Segmentation Models">
78
- <option value="sam3" data-kind="segmentation">Segmentor</option>
 
 
79
  </optgroup>
80
  <optgroup label="Drone Detection Models">
81
  <option value="drone_yolo" data-kind="drone">Drone</option>
@@ -293,4 +295,4 @@
293
 
294
  </body>
295
 
296
- </html>
 
75
  <option value="grounding_dino" data-kind="object">Large</option>
76
  </optgroup>
77
  <optgroup label="Segmentation Models">
78
+ <option value="gsam2_large" data-kind="segmentation">SAM2 Large</option>
79
+ <option value="gsam2_base" data-kind="segmentation">SAM2 Base+</option>
80
+ <option value="gsam2_small" data-kind="segmentation">SAM2 Small</option>
81
  </optgroup>
82
  <optgroup label="Drone Detection Models">
83
  <option value="drone_yolo" data-kind="drone">Drone</option>
 
295
 
296
  </body>
297
 
298
+ </html>
frontend/js/LaserPerception_original.js CHANGED
@@ -701,7 +701,9 @@
701
  "hf_yolov8",
702
  "detr_resnet50",
703
  "grounding_dino",
704
- "sam3",
 
 
705
  "drone_yolo",
706
 
707
  ]);
@@ -900,7 +902,7 @@
900
  form.append("detector", detector);
901
  }
902
  if (mode === "segmentation") {
903
- form.append("segmenter", "sam3");
904
  }
905
  // drone_detection uses drone_yolo automatically
906
 
 
701
  "hf_yolov8",
702
  "detr_resnet50",
703
  "grounding_dino",
704
+ "gsam2_small",
705
+ "gsam2_base",
706
+ "gsam2_large",
707
  "drone_yolo",
708
 
709
  ]);
 
902
  form.append("detector", detector);
903
  }
904
  if (mode === "segmentation") {
905
+ form.append("segmenter", detector || "gsam2_large");
906
  }
907
  // drone_detection uses drone_yolo automatically
908
 
frontend/js/main.js CHANGED
@@ -339,16 +339,35 @@ document.addEventListener("DOMContentLoaded", () => {
339
  }
340
 
341
  try {
342
- const mode = detectorSelect ? detectorSelect.value : "hf_yolov8";
 
 
343
  const queries = missionText ? missionText.value.trim() : "";
344
  const enableGPT = $("#enableGPTToggle")?.checked || false;
345
  const enableDepth = false; // depth mode disabled
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  const form = new FormData();
348
  form.append("video", state.videoFile);
349
- form.append("mode", "object_detection");
350
  if (queries) form.append("queries", queries);
351
- form.append("detector", mode);
 
352
  form.append("enable_gpt", enableGPT ? "true" : "false");
353
  form.append("enable_depth", enableDepth ? "true" : "false");
354
 
 
339
  }
340
 
341
  try {
342
+ const selectedOption = detectorSelect ? detectorSelect.options[detectorSelect.selectedIndex] : null;
343
+ const selectedValue = detectorSelect ? detectorSelect.value : "hf_yolov8";
344
+ const kind = selectedOption ? selectedOption.getAttribute("data-kind") : "object";
345
  const queries = missionText ? missionText.value.trim() : "";
346
  const enableGPT = $("#enableGPTToggle")?.checked || false;
347
  const enableDepth = false; // depth mode disabled
348
 
349
+ // Determine mode and model parameter from data-kind attribute
350
+ let mode, detectorParam, segmenterParam;
351
+ if (kind === "segmentation") {
352
+ mode = "segmentation";
353
+ segmenterParam = selectedValue;
354
+ detectorParam = "hf_yolov8"; // default, unused for segmentation
355
+ } else if (kind === "drone") {
356
+ mode = "drone_detection";
357
+ detectorParam = selectedValue;
358
+ segmenterParam = "gsam2_large";
359
+ } else {
360
+ mode = "object_detection";
361
+ detectorParam = selectedValue;
362
+ segmenterParam = "gsam2_large";
363
+ }
364
+
365
  const form = new FormData();
366
  form.append("video", state.videoFile);
367
+ form.append("mode", mode);
368
  if (queries) form.append("queries", queries);
369
+ form.append("detector", detectorParam);
370
+ form.append("segmenter", segmenterParam);
371
  form.append("enable_gpt", enableGPT ? "true" : "false");
372
  form.append("enable_depth", enableDepth ? "true" : "false");
373
 
inference.py CHANGED
@@ -1380,7 +1380,7 @@ def run_segmentation(
1380
  if max_frames is not None:
1381
  total_frames = min(total_frames, max_frames)
1382
 
1383
- active_segmenter = segmenter_name or "sam3"
1384
  logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
1385
 
1386
  # 2. Load Segmenters (Parallel)
@@ -1586,6 +1586,134 @@ def run_segmentation(
1586
 
1587
 
1588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1589
  def run_depth_inference(
1590
  input_video_path: str,
1591
  output_video_path: str,
 
1380
  if max_frames is not None:
1381
  total_frames = min(total_frames, max_frames)
1382
 
1383
+ active_segmenter = segmenter_name or "gsam2_large"
1384
  logging.info("Using segmenter: %s with queries: %s", active_segmenter, queries)
1385
 
1386
  # 2. Load Segmenters (Parallel)
 
1586
 
1587
 
1588
 
1589
+ def run_grounded_sam2_tracking(
1590
+ input_video_path: str,
1591
+ output_video_path: str,
1592
+ queries: List[str],
1593
+ max_frames: Optional[int] = None,
1594
+ segmenter_name: Optional[str] = None,
1595
+ job_id: Optional[str] = None,
1596
+ stream_queue: Optional[Queue] = None,
1597
+ step: int = 20,
1598
+ ) -> str:
1599
+ """Run Grounded-SAM-2 video tracking pipeline.
1600
+
1601
+ Unlike per-frame segmentation, this extracts all frames to JPEG,
1602
+ runs SAM2 video predictor for temporal mask propagation, then
1603
+ renders the results back into a video.
1604
+ """
1605
+ import shutil
1606
+
1607
+ from utils.video import extract_frames_to_jpeg_dir
1608
+ from models.segmenters.model_loader import load_segmenter as _load_seg
1609
+
1610
+ active_segmenter = segmenter_name or "gsam2_large"
1611
+ logging.info(
1612
+ "Grounded-SAM-2 tracking: segmenter=%s, queries=%s, step=%d",
1613
+ active_segmenter, queries, step,
1614
+ )
1615
+
1616
+ # 1. Extract frames to JPEG directory
1617
+ frame_dir = tempfile.mkdtemp(prefix="gsam2_frames_")
1618
+ try:
1619
+ frame_names, fps, width, height = extract_frames_to_jpeg_dir(
1620
+ input_video_path, frame_dir, max_frames=max_frames,
1621
+ )
1622
+ total_frames = len(frame_names)
1623
+ logging.info("Extracted %d frames to %s", total_frames, frame_dir)
1624
+
1625
+ # 2. Load segmenter
1626
+ segmenter = _load_seg(active_segmenter)
1627
+
1628
+ # 3. Run tracking pipeline
1629
+ _check_cancellation(job_id)
1630
+ tracking_results = segmenter.process_video(frame_dir, frame_names, queries)
1631
+
1632
+ # 4. Render results into output video
1633
+ _check_cancellation(job_id)
1634
+ import os as _os
1635
+
1636
+ with StreamingVideoWriter(output_video_path, fps, width, height) as writer:
1637
+ for frame_idx in range(total_frames):
1638
+ _check_cancellation(job_id)
1639
+
1640
+ # Read original frame
1641
+ frame_path = _os.path.join(frame_dir, frame_names[frame_idx])
1642
+ frame = cv2.imread(frame_path)
1643
+ if frame is None:
1644
+ logging.warning("Failed to read frame %d, skipping", frame_idx)
1645
+ continue
1646
+
1647
+ frame_objects = tracking_results.get(frame_idx, {})
1648
+
1649
+ if frame_objects:
1650
+ # Collect masks, boxes, and labels for rendering
1651
+ masks_list = []
1652
+ boxes_list = []
1653
+ label_list = []
1654
+
1655
+ for obj_id, obj_info in frame_objects.items():
1656
+ mask = obj_info.mask
1657
+ if mask is not None:
1658
+ if isinstance(mask, torch.Tensor):
1659
+ mask_np = mask.cpu().numpy().astype(bool)
1660
+ else:
1661
+ mask_np = np.asarray(mask).astype(bool)
1662
+ # Resize mask if needed
1663
+ if mask_np.shape[:2] != (height, width):
1664
+ mask_np = cv2.resize(
1665
+ mask_np.astype(np.uint8),
1666
+ (width, height),
1667
+ interpolation=cv2.INTER_NEAREST,
1668
+ ).astype(bool)
1669
+ masks_list.append(mask_np)
1670
+
1671
+ label = f"{obj_info.instance_id} {obj_info.class_name}"
1672
+ label_list.append(label)
1673
+
1674
+ if obj_info.x1 or obj_info.y1 or obj_info.x2 or obj_info.y2:
1675
+ boxes_list.append([obj_info.x1, obj_info.y1, obj_info.x2, obj_info.y2])
1676
+
1677
+ # Draw masks
1678
+ if masks_list:
1679
+ masks_array = np.stack(masks_list)
1680
+ frame = draw_masks(frame, masks_array, labels=label_list)
1681
+
1682
+ # Draw boxes
1683
+ if boxes_list:
1684
+ boxes_array = np.array(boxes_list)
1685
+ frame = draw_boxes(frame, boxes_array, label_names=label_list)
1686
+
1687
+ writer.write(frame)
1688
+
1689
+ # Stream frame if requested
1690
+ if stream_queue:
1691
+ try:
1692
+ from jobs.streaming import publish_frame as _pub
1693
+ if job_id:
1694
+ _pub(job_id, frame)
1695
+ else:
1696
+ stream_queue.put(frame, timeout=0.01)
1697
+ except Exception:
1698
+ pass
1699
+
1700
+ if frame_idx % 30 == 0:
1701
+ logging.info(
1702
+ "Rendered frame %d / %d", frame_idx, total_frames
1703
+ )
1704
+
1705
+ logging.info("Grounded-SAM-2 output written to: %s", output_video_path)
1706
+ return output_video_path
1707
+
1708
+ finally:
1709
+ # Cleanup temp frame directory
1710
+ try:
1711
+ shutil.rmtree(frame_dir)
1712
+ logging.info("Cleaned up temp frame dir: %s", frame_dir)
1713
+ except Exception:
1714
+ logging.warning("Failed to clean up temp frame dir: %s", frame_dir)
1715
+
1716
+
1717
  def run_depth_inference(
1718
  input_video_path: str,
1719
  output_video_path: str,
jobs/background.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
  from jobs.models import JobStatus
8
  from jobs.storage import get_job_storage, get_depth_output_path, get_first_frame_depth_path
9
  from jobs.streaming import create_stream, remove_stream
10
- from inference import run_inference, run_segmentation, run_depth_inference
11
 
12
 
13
  async def process_video_async(job_id: str) -> None:
@@ -28,7 +28,7 @@ async def process_video_async(job_id: str) -> None:
28
  # Run detection or segmentation first
29
  if job.mode == "segmentation":
30
  detection_path = await asyncio.to_thread(
31
- run_segmentation,
32
  job.input_video_path,
33
  job.output_video_path,
34
  job.queries,
 
7
  from jobs.models import JobStatus
8
  from jobs.storage import get_job_storage, get_depth_output_path, get_first_frame_depth_path
9
  from jobs.streaming import create_stream, remove_stream
10
+ from inference import run_inference, run_grounded_sam2_tracking, run_depth_inference
11
 
12
 
13
  async def process_video_async(job_id: str) -> None:
 
28
  # Run detection or segmentation first
29
  if job.mode == "segmentation":
30
  detection_path = await asyncio.to_thread(
31
+ run_grounded_sam2_tracking,
32
  job.input_video_path,
33
  job.output_video_path,
34
  job.queries,
models/segmenters/__init__.py CHANGED
@@ -1,10 +1,10 @@
1
  from .base import Segmenter, SegmentationResult
2
  from .model_loader import load_segmenter
3
- from .sam3 import SAM3Segmenter
4
 
5
  __all__ = [
6
  "Segmenter",
7
  "SegmentationResult",
8
  "load_segmenter",
9
- "SAM3Segmenter",
10
  ]
 
1
  from .base import Segmenter, SegmentationResult
2
  from .model_loader import load_segmenter
3
+ from .grounded_sam2 import GroundedSAM2Segmenter
4
 
5
  __all__ = [
6
  "Segmenter",
7
  "SegmentationResult",
8
  "load_segmenter",
9
+ "GroundedSAM2Segmenter",
10
  ]
models/segmenters/grounded_sam2.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Grounded-SAM-2 segmenter with continuous-ID video tracking.
2
+
3
+ Combines Grounding DINO (open-vocabulary detection) with SAM2's video
4
+ predictor to produce temporally consistent segmentation masks with
5
+ persistent object IDs across an entire video.
6
+
7
+ Reference implementation:
8
+ Grounded-SAM-2/grounded_sam2_tracking_demo_with_continuous_id.py
9
+ """
10
+
11
+ import copy
12
+ import logging
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ from PIL import Image
19
+
20
+ from .base import Segmenter, SegmentationResult
21
+
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Data structures (mirrors Grounded-SAM-2 reference utilities)
25
+ # ---------------------------------------------------------------------------
26
+
27
+ @dataclass
28
+ class ObjectInfo:
29
+ """Per-object tracking info for a single frame."""
30
+ instance_id: int = 0
31
+ mask: Any = None # torch.Tensor bool (H, W)
32
+ class_name: str = ""
33
+ x1: int = 0
34
+ y1: int = 0
35
+ x2: int = 0
36
+ y2: int = 0
37
+
38
+ def update_box(self):
39
+ """Derive bounding box from mask."""
40
+ if self.mask is None:
41
+ return
42
+ nonzero = torch.nonzero(self.mask)
43
+ if nonzero.size(0) == 0:
44
+ return
45
+ y_min, x_min = torch.min(nonzero, dim=0)[0]
46
+ y_max, x_max = torch.max(nonzero, dim=0)[0]
47
+ self.x1 = x_min.item()
48
+ self.y1 = y_min.item()
49
+ self.x2 = x_max.item()
50
+ self.y2 = y_max.item()
51
+
52
+
53
+ @dataclass
54
+ class MaskDictionary:
55
+ """Tracks object masks across frames with IoU-based ID matching."""
56
+ mask_height: int = 0
57
+ mask_width: int = 0
58
+ labels: Dict[int, ObjectInfo] = field(default_factory=dict)
59
+
60
+ def add_new_frame_annotation(
61
+ self,
62
+ mask_list: torch.Tensor,
63
+ box_list: torch.Tensor,
64
+ label_list: list,
65
+ ):
66
+ mask_img = torch.zeros(mask_list.shape[-2:])
67
+ anno = {}
68
+ for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
69
+ final_index = idx + 1
70
+ mask_img[mask == True] = final_index # noqa: E712
71
+ anno[final_index] = ObjectInfo(
72
+ instance_id=final_index,
73
+ mask=mask,
74
+ class_name=str(label),
75
+ x1=int(box[0]),
76
+ y1=int(box[1]),
77
+ x2=int(box[2]),
78
+ y2=int(box[3]),
79
+ )
80
+ self.mask_height = mask_img.shape[0]
81
+ self.mask_width = mask_img.shape[1]
82
+ self.labels = anno
83
+
84
+ def update_masks(
85
+ self,
86
+ tracking_dict: "MaskDictionary",
87
+ iou_threshold: float = 0.8,
88
+ objects_count: int = 0,
89
+ ) -> int:
90
+ """Match current detections against tracked objects via IoU."""
91
+ updated = {}
92
+ for _seg_id, seg_info in self.labels.items():
93
+ if seg_info.mask is None or seg_info.mask.sum() == 0:
94
+ continue
95
+ matched_id = 0
96
+ for _obj_id, obj_info in tracking_dict.labels.items():
97
+ iou = self._iou(seg_info.mask, obj_info.mask)
98
+ if iou > iou_threshold:
99
+ matched_id = obj_info.instance_id
100
+ break
101
+ if not matched_id:
102
+ objects_count += 1
103
+ matched_id = objects_count
104
+ new_info = ObjectInfo(
105
+ instance_id=matched_id,
106
+ mask=seg_info.mask,
107
+ class_name=seg_info.class_name,
108
+ )
109
+ updated[matched_id] = new_info
110
+ self.labels = updated
111
+ return objects_count
112
+
113
+ def get_target_class_name(self, instance_id: int) -> str:
114
+ info = self.labels.get(instance_id)
115
+ return info.class_name if info else ""
116
+
117
+ @staticmethod
118
+ def _iou(m1: torch.Tensor, m2: torch.Tensor) -> float:
119
+ m1f = m1.to(torch.float32)
120
+ m2f = m2.to(torch.float32)
121
+ inter = (m1f * m2f).sum()
122
+ union = m1f.sum() + m2f.sum() - inter
123
+ if union == 0:
124
+ return 0.0
125
+ return float(inter / union)
126
+
127
+
128
+ # ---------------------------------------------------------------------------
129
+ # SAM2 HuggingFace model IDs per size
130
+ # ---------------------------------------------------------------------------
131
+
132
+ _SAM2_HF_MODELS = {
133
+ "small": "facebook/sam2.1-hiera-small",
134
+ "base": "facebook/sam2.1-hiera-base-plus",
135
+ "large": "facebook/sam2.1-hiera-large",
136
+ }
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Grounded-SAM-2 Segmenter
141
+ # ---------------------------------------------------------------------------
142
+
143
+ class GroundedSAM2Segmenter(Segmenter):
144
+ """SAM2 video segmenter driven by Grounding DINO detections.
145
+
146
+ For single-frame mode (``predict``), uses GDINO + SAM2 image predictor.
147
+ For video mode (``process_video``), uses GDINO on keyframes + SAM2 video
148
+ predictor for temporal mask propagation with continuous object IDs.
149
+ """
150
+
151
+ supports_batch = False
152
+ max_batch_size = 1
153
+
154
+ def __init__(
155
+ self,
156
+ model_size: str = "large",
157
+ device: Optional[str] = None,
158
+ step: int = 20,
159
+ iou_threshold: float = 0.8,
160
+ ):
161
+ self.model_size = model_size
162
+ self.step = step
163
+ self.iou_threshold = iou_threshold
164
+ self.name = f"gsam2_{model_size}"
165
+
166
+ if device:
167
+ self.device = device
168
+ else:
169
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
170
+
171
+ # Lazy-loaded model handles
172
+ self._video_predictor = None
173
+ self._image_predictor = None
174
+ self._gdino_detector = None
175
+ self._models_loaded = False
176
+
177
+ # -- Lazy loading -------------------------------------------------------
178
+
179
+ def _ensure_models_loaded(self):
180
+ if self._models_loaded:
181
+ return
182
+
183
+ hf_id = _SAM2_HF_MODELS[self.model_size]
184
+ logging.info(
185
+ "Loading Grounded-SAM-2 (%s) on device %s ...", hf_id, self.device
186
+ )
187
+
188
+ # Enable TF32 on Ampere+ GPUs
189
+ if torch.cuda.is_available():
190
+ try:
191
+ props = torch.cuda.get_device_properties(
192
+ int(self.device.split(":")[-1]) if ":" in self.device else 0
193
+ )
194
+ if props.major >= 8:
195
+ torch.backends.cuda.matmul.allow_tf32 = True
196
+ torch.backends.cudnn.allow_tf32 = True
197
+ except Exception:
198
+ pass
199
+
200
+ from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf
201
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
202
+
203
+ # Video predictor (for process_video)
204
+ self._video_predictor = build_sam2_video_predictor_hf(
205
+ hf_id, device=self.device
206
+ )
207
+
208
+ # Image predictor (for single-frame predict)
209
+ sam2_image_model = build_sam2_hf(hf_id, device=self.device)
210
+ self._image_predictor = SAM2ImagePredictor(sam2_image_model)
211
+
212
+ # Reuse existing Grounding DINO detector from our codebase
213
+ from models.detectors.grounding_dino import GroundingDinoDetector
214
+
215
+ self._gdino_detector = GroundingDinoDetector(device=self.device)
216
+
217
+ self._models_loaded = True
218
+ logging.info("Grounded-SAM-2 models loaded successfully.")
219
+
220
+ # -- Single-frame interface (Segmenter.predict) -------------------------
221
+
222
+ def predict(
223
+ self, frame: np.ndarray, text_prompts: Optional[list] = None
224
+ ) -> SegmentationResult:
225
+ """Run GDINO + SAM2 image predictor on a single frame."""
226
+ self._ensure_models_loaded()
227
+
228
+ prompts = text_prompts or ["object"]
229
+
230
+ # Run Grounding DINO to get boxes
231
+ det = self._gdino_detector.predict(frame, prompts)
232
+ if det.boxes is None or len(det.boxes) == 0:
233
+ return SegmentationResult(
234
+ masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
235
+ scores=None,
236
+ boxes=None,
237
+ )
238
+
239
+ # SAM2 image predictor expects RGB
240
+ import cv2 as _cv2
241
+ frame_rgb = _cv2.cvtColor(frame, _cv2.COLOR_BGR2RGB)
242
+
243
+ with torch.autocast(device_type=self.device.split(":")[0], dtype=torch.bfloat16):
244
+ self._image_predictor.set_image(frame_rgb)
245
+ input_boxes = torch.tensor(det.boxes, device=self.device, dtype=torch.float32)
246
+ masks, scores, _ = self._image_predictor.predict(
247
+ point_coords=None,
248
+ point_labels=None,
249
+ box=input_boxes,
250
+ multimask_output=False,
251
+ )
252
+
253
+ # Normalize mask shape to (N, H, W)
254
+ if masks.ndim == 2:
255
+ masks = masks[None]
256
+ elif masks.ndim == 4:
257
+ masks = masks.squeeze(1)
258
+
259
+ if isinstance(masks, torch.Tensor):
260
+ masks_np = masks.cpu().numpy().astype(bool)
261
+ else:
262
+ masks_np = np.asarray(masks).astype(bool)
263
+
264
+ scores_np = None
265
+ if scores is not None:
266
+ if isinstance(scores, torch.Tensor):
267
+ scores_np = scores.cpu().numpy().flatten()
268
+ else:
269
+ scores_np = np.asarray(scores).flatten()
270
+
271
+ return SegmentationResult(
272
+ masks=masks_np,
273
+ scores=scores_np,
274
+ boxes=det.boxes,
275
+ )
276
+
277
+ # -- Video-level tracking interface -------------------------------------
278
+
279
+ def process_video(
280
+ self,
281
+ frame_dir: str,
282
+ frame_names: List[str],
283
+ text_prompts: List[str],
284
+ ) -> Dict[int, Dict[int, ObjectInfo]]:
285
+ """Run full Grounded-SAM-2 tracking pipeline on extracted JPEG frames.
286
+
287
+ Args:
288
+ frame_dir: Directory containing JPEG frames.
289
+ frame_names: Sorted list of frame filenames.
290
+ text_prompts: Text queries for Grounding DINO.
291
+
292
+ Returns:
293
+ Dict mapping frame_idx -> {obj_id: ObjectInfo} with masks,
294
+ bboxes, and class names for every frame.
295
+ """
296
+ import os
297
+
298
+ self._ensure_models_loaded()
299
+
300
+ device = self.device
301
+ step = self.step
302
+ prompt = self._gdino_detector._build_prompt(text_prompts)
303
+
304
+ # HF processor for Grounding DINO (reuse from our detector)
305
+ gdino_processor = self._gdino_detector.processor
306
+ gdino_model = self._gdino_detector.model
307
+
308
+ total_frames = len(frame_names)
309
+ logging.info(
310
+ "Grounded-SAM-2 tracking: %d frames, step=%d, queries=%s",
311
+ total_frames, step, text_prompts,
312
+ )
313
+
314
+ # Init SAM2 video predictor state
315
+ with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
316
+ inference_state = self._video_predictor.init_state(
317
+ video_path=frame_dir,
318
+ offload_video_to_cpu=True,
319
+ async_loading_frames=True,
320
+ )
321
+
322
+ sam2_masks = MaskDictionary()
323
+ objects_count = 0
324
+ all_results: Dict[int, Dict[int, ObjectInfo]] = {}
325
+
326
+ for start_idx in range(0, total_frames, step):
327
+ logging.info("Processing keyframe %d / %d", start_idx, total_frames)
328
+
329
+ img_path = os.path.join(frame_dir, frame_names[start_idx])
330
+ image = Image.open(img_path).convert("RGB")
331
+
332
+ mask_dict = MaskDictionary()
333
+
334
+ # -- Grounding DINO detection on keyframe --
335
+ inputs = gdino_processor(
336
+ images=image, text=prompt, return_tensors="pt"
337
+ )
338
+ inputs = {k: v.to(device) for k, v in inputs.items()}
339
+
340
+ with torch.no_grad():
341
+ outputs = gdino_model(**inputs)
342
+
343
+ results = gdino_processor.post_process_grounded_object_detection(
344
+ outputs,
345
+ inputs["input_ids"],
346
+ threshold=0.25,
347
+ text_threshold=0.25,
348
+ target_sizes=[image.size[::-1]],
349
+ )
350
+
351
+ input_boxes = results[0]["boxes"]
352
+ det_labels = results[0].get("text_labels") or results[0].get("labels", [])
353
+ if torch.is_tensor(det_labels):
354
+ det_labels = det_labels.detach().cpu().tolist()
355
+ det_labels = [str(l) for l in det_labels]
356
+
357
+ if input_boxes.shape[0] == 0:
358
+ logging.info("No detections on keyframe %d, propagating previous masks", start_idx)
359
+ # Fill empty results for this segment
360
+ for fi in range(start_idx, min(start_idx + step, total_frames)):
361
+ if fi not in all_results:
362
+ # Carry forward last known masks
363
+ all_results[fi] = {
364
+ k: ObjectInfo(
365
+ instance_id=v.instance_id,
366
+ mask=v.mask,
367
+ class_name=v.class_name,
368
+ x1=v.x1, y1=v.y1, x2=v.x2, y2=v.y2,
369
+ )
370
+ for k, v in sam2_masks.labels.items()
371
+ } if sam2_masks.labels else {}
372
+ continue
373
+
374
+ # -- SAM2 image predictor on keyframe --
375
+ with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
376
+ self._image_predictor.set_image(np.array(image))
377
+ masks, scores, logits = self._image_predictor.predict(
378
+ point_coords=None,
379
+ point_labels=None,
380
+ box=input_boxes,
381
+ multimask_output=False,
382
+ )
383
+
384
+ # Normalize mask dims
385
+ if masks.ndim == 2:
386
+ masks = masks[None]
387
+ scores = scores[None]
388
+ logits = logits[None]
389
+ elif masks.ndim == 4:
390
+ masks = masks.squeeze(1)
391
+
392
+ mask_dict.add_new_frame_annotation(
393
+ mask_list=torch.tensor(masks).to(device),
394
+ box_list=torch.tensor(input_boxes.cpu().numpy() if torch.is_tensor(input_boxes) else input_boxes),
395
+ label_list=det_labels,
396
+ )
397
+
398
+ # -- IoU matching to maintain persistent IDs --
399
+ objects_count = mask_dict.update_masks(
400
+ tracking_dict=sam2_masks,
401
+ iou_threshold=self.iou_threshold,
402
+ objects_count=objects_count,
403
+ )
404
+
405
+ if len(mask_dict.labels) == 0:
406
+ for fi in range(start_idx, min(start_idx + step, total_frames)):
407
+ all_results[fi] = {}
408
+ continue
409
+
410
+ # -- SAM2 video predictor: propagate masks --
411
+ with torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16):
412
+ self._video_predictor.reset_state(inference_state)
413
+
414
+ for obj_id, obj_info in mask_dict.labels.items():
415
+ self._video_predictor.add_new_mask(
416
+ inference_state,
417
+ start_idx,
418
+ obj_id,
419
+ obj_info.mask,
420
+ )
421
+
422
+ for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
423
+ inference_state,
424
+ max_frame_num_to_track=step,
425
+ start_frame_idx=start_idx,
426
+ ):
427
+ frame_objects: Dict[int, ObjectInfo] = {}
428
+ for i, out_obj_id in enumerate(out_obj_ids):
429
+ out_mask = (out_mask_logits[i] > 0.0)
430
+ info = ObjectInfo(
431
+ instance_id=out_obj_id,
432
+ mask=out_mask[0],
433
+ class_name=mask_dict.get_target_class_name(out_obj_id),
434
+ )
435
+ info.update_box()
436
+ frame_objects[out_obj_id] = info
437
+
438
+ all_results[out_frame_idx] = frame_objects
439
+ # Keep latest frame masks for next segment's IoU matching
440
+ sam2_masks = MaskDictionary()
441
+ sam2_masks.labels = copy.deepcopy(frame_objects)
442
+ if frame_objects:
443
+ first_info = next(iter(frame_objects.values()))
444
+ if first_info.mask is not None:
445
+ sam2_masks.mask_height = first_info.mask.shape[-2] if first_info.mask.ndim >= 2 else 0
446
+ sam2_masks.mask_width = first_info.mask.shape[-1] if first_info.mask.ndim >= 2 else 0
447
+
448
+ logging.info(
449
+ "Grounded-SAM-2 tracking complete: %d frames, %d tracked objects",
450
+ len(all_results), objects_count,
451
+ )
452
+ return all_results
models/segmenters/model_loader.py CHANGED
@@ -3,12 +3,14 @@ from functools import lru_cache
3
  from typing import Callable, Dict, Optional
4
 
5
  from .base import Segmenter
6
- from .sam3 import SAM3Segmenter
7
 
8
- DEFAULT_SEGMENTER = "sam3"
9
 
10
- _REGISTRY: Dict[str, Callable[[], Segmenter]] = {
11
- "sam3": SAM3Segmenter,
 
 
12
  }
13
 
14
 
@@ -35,7 +37,7 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
35
  Load a segmenter by name.
36
 
37
  Args:
38
- name: Segmenter name (default: sam3)
39
 
40
  Returns:
41
  Cached segmenter instance
@@ -46,7 +48,4 @@ def load_segmenter(name: Optional[str] = None) -> Segmenter:
46
 
47
  def load_segmenter_on_device(name: str, device: str) -> Segmenter:
48
  """Create a new segmenter instance on the specified device (no caching)."""
49
- # bypass cache by calling private creator directly
50
- # Note: _create_segmenter calls factory() which needs to accept device now.
51
- # We need to update _create_segmenter to pass kwargs too.
52
  return _create_segmenter(name, device=device)
 
3
  from typing import Callable, Dict, Optional
4
 
5
  from .base import Segmenter
6
+ from .grounded_sam2 import GroundedSAM2Segmenter
7
 
8
+ DEFAULT_SEGMENTER = "gsam2_large"
9
 
10
+ _REGISTRY: Dict[str, Callable[..., Segmenter]] = {
11
+ "gsam2_small": lambda **kw: GroundedSAM2Segmenter(model_size="small", **kw),
12
+ "gsam2_base": lambda **kw: GroundedSAM2Segmenter(model_size="base", **kw),
13
+ "gsam2_large": lambda **kw: GroundedSAM2Segmenter(model_size="large", **kw),
14
  }
15
 
16
 
 
37
  Load a segmenter by name.
38
 
39
  Args:
40
+ name: Segmenter name (default: gsam2_large)
41
 
42
  Returns:
43
  Cached segmenter instance
 
48
 
49
  def load_segmenter_on_device(name: str, device: str) -> Segmenter:
50
  """Create a new segmenter instance on the specified device (no caching)."""
 
 
 
51
  return _create_segmenter(name, device=device)
models/segmenters/sam3.py DELETED
@@ -1,284 +0,0 @@
1
- import logging
2
- from typing import Optional, Sequence
3
-
4
- import numpy as np
5
- import torch
6
- from PIL import Image
7
- from transformers import Sam3Model, Sam3Processor
8
-
9
- from .base import Segmenter, SegmentationResult
10
-
11
-
12
- class SAM3Segmenter(Segmenter):
13
- """
14
- SAM3 (Segment Anything Model 3) segmenter.
15
-
16
- Performs automatic instance segmentation on images without prompts.
17
- Uses facebook/sam3 model from HuggingFace.
18
- """
19
-
20
- name = "sam3"
21
-
22
- def __init__(
23
- self,
24
- model_id: str = "facebook/sam3",
25
- device: Optional[str] = None,
26
- threshold: float = 0.5,
27
- mask_threshold: float = 0.5,
28
- ):
29
- """
30
- Initialize SAM3 segmenter.
31
-
32
- Args:
33
- model_id: HuggingFace model ID
34
- device: Device to run on (cuda/cpu), auto-detected if None
35
- threshold: Confidence threshold for filtering instances
36
- mask_threshold: Threshold for binarizing masks
37
- """
38
- self.device = device or (
39
- "cuda" if torch.cuda.is_available() else "cpu"
40
- )
41
- self.threshold = threshold
42
- self.mask_threshold = mask_threshold
43
-
44
- logging.info(
45
- "Loading SAM3 model %s on device %s", model_id, self.device
46
- )
47
-
48
- try:
49
- self.model = Sam3Model.from_pretrained(model_id).to(self.device)
50
- self.processor = Sam3Processor.from_pretrained(model_id)
51
- self.model.eval()
52
- except Exception:
53
- logging.exception("Failed to load SAM3 model")
54
- raise
55
-
56
- logging.info("SAM3 model loaded successfully")
57
-
58
- supports_batch = True
59
- max_batch_size = 8
60
-
61
- def _parse_single_result(self, results, frame_shape) -> SegmentationResult:
62
- # Extract results
63
- masks = results.get("masks", [])
64
- scores = results.get("scores", None)
65
- boxes = results.get("boxes", None)
66
-
67
- # Convert to numpy arrays
68
- if len(masks) > 0:
69
- # Stack masks: list of (H, W) -> (N, H, W)
70
- masks_array = np.stack([m.cpu().numpy() for m in masks])
71
- else:
72
- # No objects detected
73
- masks_array = np.zeros(
74
- (0, frame_shape[0], frame_shape[1]), dtype=bool
75
- )
76
-
77
- scores_array = (
78
- scores.cpu().numpy() if scores is not None else None
79
- )
80
- boxes_array = (
81
- boxes.cpu().numpy() if boxes is not None else None
82
- )
83
-
84
- return SegmentationResult(
85
- masks=masks_array,
86
- scores=scores_array,
87
- boxes=boxes_array,
88
- )
89
-
90
- def _expand_inputs_if_needed(self, inputs):
91
- """
92
- Helper to expand vision inputs (pixel_values or vision_embeds) to match text prompts.
93
- Handles:
94
- 1. 1 image, N texts (Expand 1 -> N)
95
- 2. N images, N*M texts (Expand N -> N*M)
96
- """
97
- pixel_values = inputs.get("pixel_values")
98
- input_ids = inputs.get("input_ids")
99
-
100
- if (
101
- pixel_values is not None
102
- and input_ids is not None
103
- ):
104
- img_batch = pixel_values.shape[0]
105
- text_batch = input_ids.shape[0]
106
-
107
- should_expand = False
108
- expansion_factor = 1
109
-
110
- if img_batch == 1 and text_batch > 1:
111
- should_expand = True
112
- expansion_factor = text_batch
113
- elif img_batch > 1 and text_batch > img_batch and text_batch % img_batch == 0:
114
- should_expand = True
115
- expansion_factor = text_batch // img_batch
116
-
117
- if should_expand:
118
- logging.debug(f"Expanding SAM3 vision inputs from {img_batch} to {text_batch} (factor {expansion_factor}) using embeddings reuse.")
119
-
120
- # 1. Compute vision embeddings once for original images
121
- with torch.no_grad():
122
- vision_outputs = self.model.get_vision_features(
123
- pixel_values=pixel_values
124
- )
125
-
126
-
127
- # Iterate over keys to expand
128
- keys_to_expand = list(vision_outputs.keys())
129
- for key in keys_to_expand:
130
- value = getattr(vision_outputs, key, None)
131
- if value is None:
132
- # Try getItem
133
- try:
134
- value = vision_outputs[key]
135
- except:
136
- continue
137
-
138
- new_value = None
139
- if isinstance(value, torch.Tensor):
140
- # Ensure we only expand the batch dimension (dim 0)
141
- if value.shape[0] == img_batch:
142
- new_value = value.repeat_interleave(expansion_factor, dim=0)
143
- elif isinstance(value, (list, tuple)):
144
- new_list = []
145
- valid_expansion = False
146
- for i, v in enumerate(value):
147
- if isinstance(v, torch.Tensor) and v.shape[0] == img_batch:
148
- new_list.append(v.repeat_interleave(expansion_factor, dim=0))
149
- valid_expansion = True
150
- else:
151
- new_list.append(v)
152
-
153
- if valid_expansion:
154
- # Preserve type
155
- new_value = type(value)(new_list)
156
-
157
- if new_value is not None:
158
- # Update dict item if possible
159
- try:
160
- vision_outputs[key] = new_value
161
- except:
162
- pass
163
- # Update attribute explicitly if it exists
164
- if hasattr(vision_outputs, key):
165
- setattr(vision_outputs, key, new_value)
166
-
167
-
168
- # 3. Update inputs for model call
169
- inputs["vision_embeds"] = vision_outputs
170
- del inputs["pixel_values"] # Mutually exclusive with vision_embeds
171
-
172
- # 4. Expand other metadata
173
- if "original_sizes" in inputs and inputs["original_sizes"].shape[0] == img_batch:
174
- inputs["original_sizes"] = inputs["original_sizes"].repeat_interleave(expansion_factor, dim=0)
175
-
176
- if "reshape_input_sizes" in inputs and inputs["reshape_input_sizes"].shape[0] == img_batch:
177
- inputs["reshape_input_sizes"] = inputs["reshape_input_sizes"].repeat_interleave(expansion_factor, dim=0)
178
-
179
- def predict(self, frame: np.ndarray, text_prompts: Optional[list] = None) -> SegmentationResult:
180
- """
181
- Run SAM3 segmentation on a frame.
182
-
183
- Args:
184
- frame: Input image (HxWx3 numpy array in RGB)
185
- text_prompts: List of text prompts for segmentation
186
-
187
- Returns:
188
- SegmentationResult with instance masks
189
- """
190
- # Convert numpy array to PIL Image
191
- if frame.dtype == np.uint8:
192
- pil_image = Image.fromarray(frame)
193
- else:
194
- # Normalize to 0-255 if needed
195
- frame_uint8 = (frame * 255).astype(np.uint8)
196
- pil_image = Image.fromarray(frame_uint8)
197
-
198
- # Use default prompts if none provided
199
- if not text_prompts:
200
- text_prompts = ["object"]
201
-
202
- # Process image with text prompts
203
- inputs = self.processor(
204
- images=pil_image, text=text_prompts, return_tensors="pt"
205
- ).to(self.device)
206
-
207
- # Handle batch expansion
208
- self._expand_inputs_if_needed(inputs)
209
-
210
-
211
- # Run inference
212
- try:
213
- if "pixel_values" in inputs:
214
- logging.debug(f"SAM3 Input pixel_values shape: {inputs['pixel_values'].shape}")
215
- with torch.no_grad():
216
- outputs = self.model(**inputs)
217
- except RuntimeError as e:
218
- logging.error(f"RuntimeError during SAM3 inference: {e}")
219
- logging.error(f"Input keys: {inputs.keys()}")
220
- if 'pixel_values' in inputs:
221
- logging.error(f"Pixel values shape: {inputs['pixel_values'].shape}")
222
- # Re-raise to let user know
223
- raise
224
-
225
- # Post-process to get instance masks
226
- try:
227
- results = self.processor.post_process_instance_segmentation(
228
- outputs,
229
- threshold=self.threshold,
230
- mask_threshold=self.mask_threshold,
231
- target_sizes=inputs.get("original_sizes").tolist(),
232
- )[0]
233
- return self._parse_single_result(results, frame.shape)
234
-
235
- except Exception:
236
- logging.exception("SAM3 post-processing failed")
237
- # Return empty result
238
- return SegmentationResult(
239
- masks=np.zeros((0, frame.shape[0], frame.shape[1]), dtype=bool),
240
- scores=None,
241
- boxes=None,
242
- )
243
-
244
- def predict_batch(self, frames: Sequence[np.ndarray], text_prompts: Optional[list] = None) -> Sequence[SegmentationResult]:
245
- pil_images = []
246
- for f in frames:
247
- if f.dtype == np.uint8:
248
- pil_images.append(Image.fromarray(f))
249
- else:
250
- f_uint8 = (f * 255).astype(np.uint8)
251
- pil_images.append(Image.fromarray(f_uint8))
252
-
253
- prompts = text_prompts or ["object"]
254
-
255
- # Flatten prompts for all images: [img1_p1, img1_p2, img2_p1, img2_p2, ...]
256
- flattened_prompts = []
257
- for _ in frames:
258
- flattened_prompts.extend(prompts)
259
-
260
- inputs = self.processor(images=pil_images, text=flattened_prompts, return_tensors="pt").to(self.device)
261
-
262
- # Handle batch expansion
263
- self._expand_inputs_if_needed(inputs)
264
-
265
- with torch.no_grad():
266
- outputs = self.model(**inputs)
267
-
268
- try:
269
- results_list = self.processor.post_process_instance_segmentation(
270
- outputs,
271
- threshold=self.threshold,
272
- mask_threshold=self.mask_threshold,
273
- target_sizes=inputs.get("original_sizes").tolist(),
274
- )
275
- return [self._parse_single_result(r, f.shape) for r, f in zip(results_list, frames)]
276
- except Exception:
277
- logging.exception("SAM3 batch post-processing failed")
278
- return [
279
- SegmentationResult(
280
- masks=np.zeros((0, f.shape[0], f.shape[1]), dtype=bool),
281
- scores=None,
282
- boxes=None
283
- ) for f in frames
284
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -10,3 +10,6 @@ ultralytics
10
  python-dotenv
11
  einops
12
  sentence-transformers
 
 
 
 
10
  python-dotenv
11
  einops
12
  sentence-transformers
13
+ SAM-2 @ git+https://github.com/facebookresearch/sam2.git
14
+ hydra-core>=1.3.2
15
+ iopath>=0.1.10
utils/video.py CHANGED
@@ -9,6 +9,51 @@ import cv2
9
  import numpy as np
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def extract_frames(video_path: str) -> Tuple[List[np.ndarray], float, int, int]:
13
  cap = cv2.VideoCapture(video_path)
14
  if not cap.isOpened():
 
9
  import numpy as np
10
 
11
 
12
+ def extract_frames_to_jpeg_dir(
13
+ video_path: str,
14
+ output_dir: str,
15
+ max_frames: int = None,
16
+ ) -> Tuple[List[str], float, int, int]:
17
+ """Extract video frames as numbered JPEG files for SAM2 video predictor.
18
+
19
+ Args:
20
+ video_path: Path to input video.
21
+ output_dir: Directory to write JPEG files into.
22
+ max_frames: Optional cap on number of frames to extract.
23
+
24
+ Returns:
25
+ (frame_names, fps, width, height) where *frame_names* is a sorted
26
+ list of filenames like ``000000.jpg``, ``000001.jpg``, etc.
27
+ """
28
+ os.makedirs(output_dir, exist_ok=True)
29
+ cap = cv2.VideoCapture(video_path)
30
+ if not cap.isOpened():
31
+ raise ValueError(f"Unable to open video: {video_path}")
32
+
33
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
34
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
+
37
+ frame_names: List[str] = []
38
+ idx = 0
39
+ while True:
40
+ if max_frames is not None and idx >= max_frames:
41
+ break
42
+ success, frame = cap.read()
43
+ if not success:
44
+ break
45
+ fname = f"{idx:06d}.jpg"
46
+ cv2.imwrite(os.path.join(output_dir, fname), frame)
47
+ frame_names.append(fname)
48
+ idx += 1
49
+
50
+ cap.release()
51
+ if not frame_names:
52
+ raise ValueError("Video decode produced zero frames.")
53
+
54
+ return frame_names, fps, width, height
55
+
56
+
57
  def extract_frames(video_path: str) -> Tuple[List[np.ndarray], float, int, int]:
58
  cap = cv2.VideoCapture(video_path)
59
  if not cap.isOpened():