Spaces:
Paused
Paused
Zhen Ye Claude Opus 4.6 committed on
Commit ·
63684e4
1
Parent(s): b17bd6d
fix: manual torch.compile with graceful fallback for SAM2 components
Browse files
Replace vos_optimized=True (caused 500 error) with manual per-component
torch.compile matching Facebook's official recipe. Compiles image_encoder,
memory_encoder, memory_attention (dynamic=True), sam_prompt_encoder, and
sam_mask_decoder with max-autotune. Falls back to eager mode silently if
PyTorch version or CUDA compiler is unavailable.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -362,13 +362,17 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 362 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 363 |
|
| 364 |
# Video predictor (for process_video)
|
| 365 |
-
# vos_optimized=True enables SAM2VideoPredictorVOS which compiles
|
| 366 |
-
# image_encoder, memory_encoder, memory_attention, sam_prompt_encoder,
|
| 367 |
-
# and sam_mask_decoder with torch.compile(mode="max-autotune").
|
| 368 |
self._video_predictor = build_sam2_video_predictor_hf(
|
| 369 |
-
hf_id, device=self.device,
|
| 370 |
)
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
# Image predictor (for single-frame predict)
|
| 373 |
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
|
| 374 |
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
|
|
@@ -381,6 +385,35 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 381 |
self._models_loaded = True
|
| 382 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 385 |
|
| 386 |
def predict(
|
|
|
|
| 362 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 363 |
|
| 364 |
# Video predictor (for process_video)
|
|
|
|
|
|
|
|
|
|
| 365 |
self._video_predictor = build_sam2_video_predictor_hf(
|
| 366 |
+
hf_id, device=self.device,
|
| 367 |
)
|
| 368 |
|
| 369 |
+
# torch.compile individual components for fused kernels.
|
| 370 |
+
# memory_attention uses dynamic=True (variable memory token count).
|
| 371 |
+
# Wrapped in try/except: falls back to eager if PyTorch < 2.5 or
|
| 372 |
+
# CUDA compiler not available.
|
| 373 |
+
if self.device.startswith("cuda"):
|
| 374 |
+
self._apply_torch_compile()
|
| 375 |
+
|
| 376 |
# Image predictor (for single-frame predict)
|
| 377 |
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
|
| 378 |
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
|
|
|
|
| 385 |
self._models_loaded = True
|
| 386 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
| 387 |
|
| 388 |
+
def _apply_torch_compile(self):
|
| 389 |
+
"""Compile SAM2 sub-modules with torch.compile (max-autotune).
|
| 390 |
+
|
| 391 |
+
Compiles 5 components matching Facebook's official VOS recipe.
|
| 392 |
+
Falls back silently to eager mode on any compilation error.
|
| 393 |
+
"""
|
| 394 |
+
vp = self._video_predictor
|
| 395 |
+
components = [
|
| 396 |
+
("image_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 397 |
+
("memory_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 398 |
+
("memory_attention", dict(mode="max-autotune", fullgraph=True, dynamic=True)),
|
| 399 |
+
("sam_prompt_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 400 |
+
("sam_mask_decoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 401 |
+
]
|
| 402 |
+
compiled = []
|
| 403 |
+
for attr, kwargs in components:
|
| 404 |
+
module = getattr(vp, attr, None)
|
| 405 |
+
if module is None:
|
| 406 |
+
continue
|
| 407 |
+
try:
|
| 408 |
+
module.forward = torch.compile(module.forward, **kwargs)
|
| 409 |
+
compiled.append(attr)
|
| 410 |
+
except Exception as e:
|
| 411 |
+
logging.warning("torch.compile failed for %s: %s", attr, e)
|
| 412 |
+
if compiled:
|
| 413 |
+
logging.info("torch.compile applied to: %s", ", ".join(compiled))
|
| 414 |
+
else:
|
| 415 |
+
logging.info("torch.compile not available, using eager mode.")
|
| 416 |
+
|
| 417 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 418 |
|
| 419 |
def predict(
|