Zhen Ye committed on
Commit
6268ac2
·
1 Parent(s): d74c718

feat(segmentation): add YSAM2 variants and mission-aware detector mapping

Browse files
app.py CHANGED
@@ -61,6 +61,7 @@ from utils.threat_chat import chat_about_threats
61
  from utils.relevance import evaluate_relevance
62
  from utils.enrichment import run_enrichment
63
  from utils.schemas import AssessmentStatus
 
64
  from utils.mission_parser import parse_mission_text, build_broad_queries, MissionParseError
65
 
66
  logging.basicConfig(level=logging.INFO)
@@ -260,7 +261,7 @@ async def detect_endpoint(
260
  mode: Detection mode (object_detection, segmentation, drone_detection)
261
  queries: Comma-separated object classes for object_detection mode
262
  detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
263
- segmenter: Segmentation model to use (GSAM2-S, GSAM2-B, GSAM2-L)
264
  enable_depth: Whether to run legacy depth estimation (default: False)
265
  drone_detection uses the dedicated drone_yolo model.
266
 
@@ -302,7 +303,6 @@ async def detect_endpoint(
302
  output_path,
303
  query_list,
304
  segmenter_name=segmenter,
305
- detector_name="grounding_dino",
306
  num_maskmem=7,
307
  )
308
  except ValueError as exc:
@@ -439,16 +439,23 @@ async def detect_async_endpoint(
439
  mission_mode = "LEGACY"
440
 
441
  detector_name = detector
 
442
  if mode == "drone_detection":
443
  detector_name = "drone_yolo"
 
444
  elif mode == "segmentation":
445
- # Grounded-SAM2 uses Grounding DINO (open-vocabulary) for keyframe prompts.
446
- detector_name = "grounding_dino"
 
 
 
 
 
447
 
448
  if queries.strip():
449
  try:
450
- mission_spec = parse_mission_text(queries.strip(), detector_name, video_path=str(input_path))
451
- query_list = build_broad_queries(detector_name, mission_spec)
452
  mission_mode = "MISSION"
453
  logging.info(
454
  "Mission parsed: mode=%s classes=%s broad_queries=%s domain=%s(%s)",
 
61
  from utils.relevance import evaluate_relevance
62
  from utils.enrichment import run_enrichment
63
  from utils.schemas import AssessmentStatus
64
+ from models.segmenters.model_loader import get_segmenter_detector
65
  from utils.mission_parser import parse_mission_text, build_broad_queries, MissionParseError
66
 
67
  logging.basicConfig(level=logging.INFO)
 
261
  mode: Detection mode (object_detection, segmentation, drone_detection)
262
  queries: Comma-separated object classes for object_detection mode
263
  detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
264
+ segmenter: Segmentation model to use (GSAM2-S/B/L, YSAM2-S/B/L)
265
  enable_depth: Whether to run legacy depth estimation (default: False)
266
  drone_detection uses the dedicated drone_yolo model.
267
 
 
303
  output_path,
304
  query_list,
305
  segmenter_name=segmenter,
 
306
  num_maskmem=7,
307
  )
308
  except ValueError as exc:
 
439
  mission_mode = "LEGACY"
440
 
441
  detector_name = detector
442
+ mission_detector = detector # detector key used for mission query parsing
443
  if mode == "drone_detection":
444
  detector_name = "drone_yolo"
445
+ mission_detector = "drone_yolo"
446
  elif mode == "segmentation":
447
+ # Segmenter registry owns detector selection (GSAM2→GDINO, YSAM2→YOLO).
448
+ # detector_name=None so the job doesn't forward it (avoids duplicate kwarg).
449
+ try:
450
+ mission_detector = get_segmenter_detector(segmenter)
451
+ except ValueError as exc:
452
+ raise HTTPException(status_code=400, detail=str(exc))
453
+ detector_name = None
454
 
455
  if queries.strip():
456
  try:
457
+ mission_spec = parse_mission_text(queries.strip(), mission_detector, video_path=str(input_path))
458
+ query_list = build_broad_queries(mission_detector, mission_spec)
459
  mission_mode = "MISSION"
460
  logging.info(
461
  "Mission parsed: mode=%s classes=%s broad_queries=%s domain=%s(%s)",
frontend/index.html CHANGED
@@ -78,6 +78,9 @@
78
  <option value="GSAM2-L" data-kind="segmentation">GSAM2-L</option>
79
  <option value="GSAM2-B" data-kind="segmentation">GSAM2-B</option>
80
  <option value="GSAM2-S" data-kind="segmentation">GSAM2-S</option>
 
 
 
81
  </optgroup>
82
  <optgroup label="Drone Detection Models">
83
  <option value="drone_yolo" data-kind="drone">Drone</option>
 
78
  <option value="GSAM2-L" data-kind="segmentation">GSAM2-L</option>
79
  <option value="GSAM2-B" data-kind="segmentation">GSAM2-B</option>
80
  <option value="GSAM2-S" data-kind="segmentation">GSAM2-S</option>
81
+ <option value="YSAM2-L" data-kind="segmentation">YSAM2-L (Fast)</option>
82
+ <option value="YSAM2-B" data-kind="segmentation">YSAM2-B (Fast)</option>
83
+ <option value="YSAM2-S" data-kind="segmentation">YSAM2-S (Fast)</option>
84
  </optgroup>
85
  <optgroup label="Drone Detection Models">
86
  <option value="drone_yolo" data-kind="drone">Drone</option>
models/segmenters/grounded_sam2.py CHANGED
@@ -349,7 +349,9 @@ class GroundedSAM2Segmenter(Segmenter):
349
  self.num_maskmem = num_maskmem # None = use default (7)
350
  self._detector_name = detector_name # None = "grounding_dino"
351
  _size_suffix = {"small": "S", "base": "B", "large": "L"}
352
- self.name = f"GSAM2-{_size_suffix[model_size]}"
 
 
353
 
354
  if device:
355
  self.device = device
 
349
  self.num_maskmem = num_maskmem # None = use default (7)
350
  self._detector_name = detector_name # None = "grounding_dino"
351
  _size_suffix = {"small": "S", "base": "B", "large": "L"}
352
+ _det_prefix = {"hf_yolov8": "YSAM2"}
353
+ _prefix = _det_prefix.get(detector_name, "GSAM2")
354
+ self.name = f"{_prefix}-{_size_suffix[model_size]}"
355
 
356
  if device:
357
  self.device = device
models/segmenters/model_loader.py CHANGED
@@ -1,19 +1,56 @@
1
  import os
2
  from functools import lru_cache
3
- from typing import Callable, Dict, Optional
4
 
5
  from .base import Segmenter
6
  from .grounded_sam2 import GroundedSAM2Segmenter
7
 
8
  DEFAULT_SEGMENTER = "GSAM2-L"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  _REGISTRY: Dict[str, Callable[..., Segmenter]] = {
11
- "GSAM2-S": lambda **kw: GroundedSAM2Segmenter(model_size="small", **kw),
12
- "GSAM2-B": lambda **kw: GroundedSAM2Segmenter(model_size="base", **kw),
13
- "GSAM2-L": lambda **kw: GroundedSAM2Segmenter(model_size="large", **kw),
14
  }
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def _create_segmenter(name: str, **kwargs) -> Segmenter:
18
  """Create a new segmenter instance."""
19
  try:
 
1
  import os
2
  from functools import lru_cache
3
+ from typing import Callable, Dict, Optional, Tuple
4
 
5
  from .base import Segmenter
6
  from .grounded_sam2 import GroundedSAM2Segmenter
7
 
8
  DEFAULT_SEGMENTER = "GSAM2-L"
9
+ _DEFAULT_SEGMENTER_DETECTOR = "grounding_dino"
10
+
11
+ _SEGMENTER_SPECS: Dict[str, Tuple[str, Optional[str]]] = {
12
+ "GSAM2-S": ("small", None),
13
+ "GSAM2-B": ("base", None),
14
+ "GSAM2-L": ("large", None),
15
+ "YSAM2-S": ("small", "hf_yolov8"),
16
+ "YSAM2-B": ("base", "hf_yolov8"),
17
+ "YSAM2-L": ("large", "hf_yolov8"),
18
+ }
19
+
20
+
21
+ def _build_factory(model_size: str, detector_name: Optional[str]) -> Callable[..., Segmenter]:
22
+ def _factory(**kw) -> Segmenter:
23
+ if detector_name is None:
24
+ return GroundedSAM2Segmenter(model_size=model_size, **kw)
25
+ # YSAM2 keys own detector selection; drop accidental caller overrides.
26
+ kw.pop("detector_name", None)
27
+ return GroundedSAM2Segmenter(
28
+ model_size=model_size,
29
+ detector_name=detector_name,
30
+ **kw,
31
+ )
32
+
33
+ return _factory
34
+
35
 
36
  _REGISTRY: Dict[str, Callable[..., Segmenter]] = {
37
+ name: _build_factory(model_size, detector_name)
38
+ for name, (model_size, detector_name) in _SEGMENTER_SPECS.items()
 
39
  }
40
 
41
 
42
+ def get_segmenter_detector(segmenter_name: str) -> str:
43
+ """Return the detector key associated with a segmenter (for mission parsing)."""
44
+ spec = _SEGMENTER_SPECS.get(segmenter_name)
45
+ if spec is None:
46
+ available = ", ".join(sorted(_REGISTRY))
47
+ raise ValueError(
48
+ f"Unknown segmenter '{segmenter_name}'. Available: {available}"
49
+ )
50
+ detector_name = spec[1]
51
+ return detector_name or _DEFAULT_SEGMENTER_DETECTOR
52
+
53
+
54
  def _create_segmenter(name: str, **kwargs) -> Segmenter:
55
  """Create a new segmenter instance."""
56
  try:
utils/roofline.py CHANGED
@@ -24,6 +24,10 @@ _MODEL_FLOPS: Dict[str, float] = {
24
  "GSAM2-S": 48.0, # SAM2 small encoder
25
  "GSAM2-B": 96.0, # SAM2 base encoder
26
  "GSAM2-L": 200.0, # SAM2 large encoder
 
 
 
 
27
  "gsam2_tiny": 24.0, # SAM2 tiny encoder
28
  }
29
 
@@ -37,6 +41,9 @@ _MODEL_BYTES: Dict[str, float] = {
37
  "GSAM2-S": 92.0,
38
  "GSAM2-B": 180.0,
39
  "GSAM2-L": 400.0,
 
 
 
40
  "gsam2_tiny": 46.0,
41
  }
42
 
 
24
  "GSAM2-S": 48.0, # SAM2 small encoder
25
  "GSAM2-B": 96.0, # SAM2 base encoder
26
  "GSAM2-L": 200.0, # SAM2 large encoder
27
+ # YSAM2 uses the same SAM2 backbone; detector differences are reflected in timing.
28
+ "YSAM2-S": 48.0,
29
+ "YSAM2-B": 96.0,
30
+ "YSAM2-L": 200.0,
31
  "gsam2_tiny": 24.0, # SAM2 tiny encoder
32
  }
33
 
 
41
  "GSAM2-S": 92.0,
42
  "GSAM2-B": 180.0,
43
  "GSAM2-L": 400.0,
44
+ "YSAM2-S": 92.0,
45
+ "YSAM2-B": 180.0,
46
+ "YSAM2-L": 400.0,
47
  "gsam2_tiny": 46.0,
48
  }
49