Spaces:
Paused
Paused
Zhen Ye commited on
Commit ·
6268ac2
1
Parent(s): d74c718
feat(segmentation): add YSAM2 variants and mission-aware detector mapping
Browse files- app.py +13 -6
- frontend/index.html +3 -0
- models/segmenters/grounded_sam2.py +3 -1
- models/segmenters/model_loader.py +41 -4
- utils/roofline.py +7 -0
app.py
CHANGED
|
@@ -61,6 +61,7 @@ from utils.threat_chat import chat_about_threats
|
|
| 61 |
from utils.relevance import evaluate_relevance
|
| 62 |
from utils.enrichment import run_enrichment
|
| 63 |
from utils.schemas import AssessmentStatus
|
|
|
|
| 64 |
from utils.mission_parser import parse_mission_text, build_broad_queries, MissionParseError
|
| 65 |
|
| 66 |
logging.basicConfig(level=logging.INFO)
|
|
@@ -260,7 +261,7 @@ async def detect_endpoint(
|
|
| 260 |
mode: Detection mode (object_detection, segmentation, drone_detection)
|
| 261 |
queries: Comma-separated object classes for object_detection mode
|
| 262 |
detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
|
| 263 |
-
segmenter: Segmentation model to use (GSAM2-S
|
| 264 |
enable_depth: Whether to run legacy depth estimation (default: False)
|
| 265 |
drone_detection uses the dedicated drone_yolo model.
|
| 266 |
|
|
@@ -302,7 +303,6 @@ async def detect_endpoint(
|
|
| 302 |
output_path,
|
| 303 |
query_list,
|
| 304 |
segmenter_name=segmenter,
|
| 305 |
-
detector_name="grounding_dino",
|
| 306 |
num_maskmem=7,
|
| 307 |
)
|
| 308 |
except ValueError as exc:
|
|
@@ -439,16 +439,23 @@ async def detect_async_endpoint(
|
|
| 439 |
mission_mode = "LEGACY"
|
| 440 |
|
| 441 |
detector_name = detector
|
|
|
|
| 442 |
if mode == "drone_detection":
|
| 443 |
detector_name = "drone_yolo"
|
|
|
|
| 444 |
elif mode == "segmentation":
|
| 445 |
-
#
|
| 446 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
|
| 448 |
if queries.strip():
|
| 449 |
try:
|
| 450 |
-
mission_spec = parse_mission_text(queries.strip(),
|
| 451 |
-
query_list = build_broad_queries(
|
| 452 |
mission_mode = "MISSION"
|
| 453 |
logging.info(
|
| 454 |
"Mission parsed: mode=%s classes=%s broad_queries=%s domain=%s(%s)",
|
|
|
|
| 61 |
from utils.relevance import evaluate_relevance
|
| 62 |
from utils.enrichment import run_enrichment
|
| 63 |
from utils.schemas import AssessmentStatus
|
| 64 |
+
from models.segmenters.model_loader import get_segmenter_detector
|
| 65 |
from utils.mission_parser import parse_mission_text, build_broad_queries, MissionParseError
|
| 66 |
|
| 67 |
logging.basicConfig(level=logging.INFO)
|
|
|
|
| 261 |
mode: Detection mode (object_detection, segmentation, drone_detection)
|
| 262 |
queries: Comma-separated object classes for object_detection mode
|
| 263 |
detector: Model to use (hf_yolov8, detr_resnet50, grounding_dino)
|
| 264 |
+
segmenter: Segmentation model to use (GSAM2-S/B/L, YSAM2-S/B/L)
|
| 265 |
enable_depth: Whether to run legacy depth estimation (default: False)
|
| 266 |
drone_detection uses the dedicated drone_yolo model.
|
| 267 |
|
|
|
|
| 303 |
output_path,
|
| 304 |
query_list,
|
| 305 |
segmenter_name=segmenter,
|
|
|
|
| 306 |
num_maskmem=7,
|
| 307 |
)
|
| 308 |
except ValueError as exc:
|
|
|
|
| 439 |
mission_mode = "LEGACY"
|
| 440 |
|
| 441 |
detector_name = detector
|
| 442 |
+
mission_detector = detector # detector key used for mission query parsing
|
| 443 |
if mode == "drone_detection":
|
| 444 |
detector_name = "drone_yolo"
|
| 445 |
+
mission_detector = "drone_yolo"
|
| 446 |
elif mode == "segmentation":
|
| 447 |
+
# Segmenter registry owns detector selection (GSAM2→GDINO, YSAM2→YOLO).
|
| 448 |
+
# detector_name=None so the job doesn't forward it (avoids duplicate kwarg).
|
| 449 |
+
try:
|
| 450 |
+
mission_detector = get_segmenter_detector(segmenter)
|
| 451 |
+
except ValueError as exc:
|
| 452 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 453 |
+
detector_name = None
|
| 454 |
|
| 455 |
if queries.strip():
|
| 456 |
try:
|
| 457 |
+
mission_spec = parse_mission_text(queries.strip(), mission_detector, video_path=str(input_path))
|
| 458 |
+
query_list = build_broad_queries(mission_detector, mission_spec)
|
| 459 |
mission_mode = "MISSION"
|
| 460 |
logging.info(
|
| 461 |
"Mission parsed: mode=%s classes=%s broad_queries=%s domain=%s(%s)",
|
frontend/index.html
CHANGED
|
@@ -78,6 +78,9 @@
|
|
| 78 |
<option value="GSAM2-L" data-kind="segmentation">GSAM2-L</option>
|
| 79 |
<option value="GSAM2-B" data-kind="segmentation">GSAM2-B</option>
|
| 80 |
<option value="GSAM2-S" data-kind="segmentation">GSAM2-S</option>
|
|
|
|
|
|
|
|
|
|
| 81 |
</optgroup>
|
| 82 |
<optgroup label="Drone Detection Models">
|
| 83 |
<option value="drone_yolo" data-kind="drone">Drone</option>
|
|
|
|
| 78 |
<option value="GSAM2-L" data-kind="segmentation">GSAM2-L</option>
|
| 79 |
<option value="GSAM2-B" data-kind="segmentation">GSAM2-B</option>
|
| 80 |
<option value="GSAM2-S" data-kind="segmentation">GSAM2-S</option>
|
| 81 |
+
<option value="YSAM2-L" data-kind="segmentation">YSAM2-L (Fast)</option>
|
| 82 |
+
<option value="YSAM2-B" data-kind="segmentation">YSAM2-B (Fast)</option>
|
| 83 |
+
<option value="YSAM2-S" data-kind="segmentation">YSAM2-S (Fast)</option>
|
| 84 |
</optgroup>
|
| 85 |
<optgroup label="Drone Detection Models">
|
| 86 |
<option value="drone_yolo" data-kind="drone">Drone</option>
|
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -349,7 +349,9 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 349 |
self.num_maskmem = num_maskmem # None = use default (7)
|
| 350 |
self._detector_name = detector_name # None = "grounding_dino"
|
| 351 |
_size_suffix = {"small": "S", "base": "B", "large": "L"}
|
| 352 |
-
|
|
|
|
|
|
|
| 353 |
|
| 354 |
if device:
|
| 355 |
self.device = device
|
|
|
|
| 349 |
self.num_maskmem = num_maskmem # None = use default (7)
|
| 350 |
self._detector_name = detector_name # None = "grounding_dino"
|
| 351 |
_size_suffix = {"small": "S", "base": "B", "large": "L"}
|
| 352 |
+
_det_prefix = {"hf_yolov8": "YSAM2"}
|
| 353 |
+
_prefix = _det_prefix.get(detector_name, "GSAM2")
|
| 354 |
+
self.name = f"{_prefix}-{_size_suffix[model_size]}"
|
| 355 |
|
| 356 |
if device:
|
| 357 |
self.device = device
|
models/segmenters/model_loader.py
CHANGED
|
@@ -1,19 +1,56 @@
|
|
| 1 |
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
-
from typing import Callable, Dict, Optional
|
| 4 |
|
| 5 |
from .base import Segmenter
|
| 6 |
from .grounded_sam2 import GroundedSAM2Segmenter
|
| 7 |
|
| 8 |
DEFAULT_SEGMENTER = "GSAM2-L"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
"GSAM2-L": lambda **kw: GroundedSAM2Segmenter(model_size="large", **kw),
|
| 14 |
}
|
| 15 |
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
def _create_segmenter(name: str, **kwargs) -> Segmenter:
|
| 18 |
"""Create a new segmenter instance."""
|
| 19 |
try:
|
|
|
|
| 1 |
import os
|
| 2 |
from functools import lru_cache
|
| 3 |
+
from typing import Callable, Dict, Optional, Tuple
|
| 4 |
|
| 5 |
from .base import Segmenter
|
| 6 |
from .grounded_sam2 import GroundedSAM2Segmenter
|
| 7 |
|
| 8 |
DEFAULT_SEGMENTER = "GSAM2-L"
|
| 9 |
+
_DEFAULT_SEGMENTER_DETECTOR = "grounding_dino"
|
| 10 |
+
|
| 11 |
+
_SEGMENTER_SPECS: Dict[str, Tuple[str, Optional[str]]] = {
|
| 12 |
+
"GSAM2-S": ("small", None),
|
| 13 |
+
"GSAM2-B": ("base", None),
|
| 14 |
+
"GSAM2-L": ("large", None),
|
| 15 |
+
"YSAM2-S": ("small", "hf_yolov8"),
|
| 16 |
+
"YSAM2-B": ("base", "hf_yolov8"),
|
| 17 |
+
"YSAM2-L": ("large", "hf_yolov8"),
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _build_factory(model_size: str, detector_name: Optional[str]) -> Callable[..., Segmenter]:
|
| 22 |
+
def _factory(**kw) -> Segmenter:
|
| 23 |
+
if detector_name is None:
|
| 24 |
+
return GroundedSAM2Segmenter(model_size=model_size, **kw)
|
| 25 |
+
# YSAM2 keys own detector selection; drop accidental caller overrides.
|
| 26 |
+
kw.pop("detector_name", None)
|
| 27 |
+
return GroundedSAM2Segmenter(
|
| 28 |
+
model_size=model_size,
|
| 29 |
+
detector_name=detector_name,
|
| 30 |
+
**kw,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
return _factory
|
| 34 |
+
|
| 35 |
|
| 36 |
_REGISTRY: Dict[str, Callable[..., Segmenter]] = {
|
| 37 |
+
name: _build_factory(model_size, detector_name)
|
| 38 |
+
for name, (model_size, detector_name) in _SEGMENTER_SPECS.items()
|
|
|
|
| 39 |
}
|
| 40 |
|
| 41 |
|
| 42 |
+
def get_segmenter_detector(segmenter_name: str) -> str:
|
| 43 |
+
"""Return the detector key associated with a segmenter (for mission parsing)."""
|
| 44 |
+
spec = _SEGMENTER_SPECS.get(segmenter_name)
|
| 45 |
+
if spec is None:
|
| 46 |
+
available = ", ".join(sorted(_REGISTRY))
|
| 47 |
+
raise ValueError(
|
| 48 |
+
f"Unknown segmenter '{segmenter_name}'. Available: {available}"
|
| 49 |
+
)
|
| 50 |
+
detector_name = spec[1]
|
| 51 |
+
return detector_name or _DEFAULT_SEGMENTER_DETECTOR
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def _create_segmenter(name: str, **kwargs) -> Segmenter:
|
| 55 |
"""Create a new segmenter instance."""
|
| 56 |
try:
|
utils/roofline.py
CHANGED
|
@@ -24,6 +24,10 @@ _MODEL_FLOPS: Dict[str, float] = {
|
|
| 24 |
"GSAM2-S": 48.0, # SAM2 small encoder
|
| 25 |
"GSAM2-B": 96.0, # SAM2 base encoder
|
| 26 |
"GSAM2-L": 200.0, # SAM2 large encoder
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"gsam2_tiny": 24.0, # SAM2 tiny encoder
|
| 28 |
}
|
| 29 |
|
|
@@ -37,6 +41,9 @@ _MODEL_BYTES: Dict[str, float] = {
|
|
| 37 |
"GSAM2-S": 92.0,
|
| 38 |
"GSAM2-B": 180.0,
|
| 39 |
"GSAM2-L": 400.0,
|
|
|
|
|
|
|
|
|
|
| 40 |
"gsam2_tiny": 46.0,
|
| 41 |
}
|
| 42 |
|
|
|
|
| 24 |
"GSAM2-S": 48.0, # SAM2 small encoder
|
| 25 |
"GSAM2-B": 96.0, # SAM2 base encoder
|
| 26 |
"GSAM2-L": 200.0, # SAM2 large encoder
|
| 27 |
+
# YSAM2 uses the same SAM2 backbone; detector differences are reflected in timing.
|
| 28 |
+
"YSAM2-S": 48.0,
|
| 29 |
+
"YSAM2-B": 96.0,
|
| 30 |
+
"YSAM2-L": 200.0,
|
| 31 |
"gsam2_tiny": 24.0, # SAM2 tiny encoder
|
| 32 |
}
|
| 33 |
|
|
|
|
| 41 |
"GSAM2-S": 92.0,
|
| 42 |
"GSAM2-B": 180.0,
|
| 43 |
"GSAM2-L": 400.0,
|
| 44 |
+
"YSAM2-S": 92.0,
|
| 45 |
+
"YSAM2-B": 180.0,
|
| 46 |
+
"YSAM2-L": 400.0,
|
| 47 |
"gsam2_tiny": 46.0,
|
| 48 |
}
|
| 49 |
|