Spaces:
Paused
Paused
Zhen Ye Claude Opus 4.6 committed on
Commit ·
5749bd6
1
Parent(s): 10eb3c6
revert: remove torch.compile — runtime failures on HF Space
Browse files
Reverts all torch.compile changes (b17bd6d..10eb3c6). The HF Space
container lacks Triton/inductor support, causing 500 errors. The
GPU-resident tensor pipeline from 5aec47c is retained.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -363,16 +363,9 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 363 |
|
| 364 |
# Video predictor (for process_video)
|
| 365 |
self._video_predictor = build_sam2_video_predictor_hf(
|
| 366 |
-
hf_id, device=self.device
|
| 367 |
)
|
| 368 |
|
| 369 |
-
# torch.compile individual components for fused kernels.
|
| 370 |
-
# memory_attention uses dynamic=True (variable memory token count).
|
| 371 |
-
# Wrapped in try/except: falls back to eager if PyTorch < 2.5 or
|
| 372 |
-
# CUDA compiler not available.
|
| 373 |
-
if self.device.startswith("cuda"):
|
| 374 |
-
self._apply_torch_compile()
|
| 375 |
-
|
| 376 |
# Image predictor (for single-frame predict)
|
| 377 |
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
|
| 378 |
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
|
|
@@ -385,52 +378,6 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 385 |
self._models_loaded = True
|
| 386 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
| 387 |
|
| 388 |
-
def _apply_torch_compile(self):
|
| 389 |
-
"""Compile SAM2 sub-modules with torch.compile (max-autotune).
|
| 390 |
-
|
| 391 |
-
Compiles 5 components matching Facebook's official VOS recipe.
|
| 392 |
-
torch.compile wraps succeed immediately; actual Triton/inductor
|
| 393 |
-
compilation happens lazily on first forward pass. We store
|
| 394 |
-
original forwards so propagate_segment can fall back on error.
|
| 395 |
-
"""
|
| 396 |
-
vp = self._video_predictor
|
| 397 |
-
components = [
|
| 398 |
-
("image_encoder", dict(mode="max-autotune", dynamic=False)),
|
| 399 |
-
("memory_encoder", dict(mode="max-autotune", dynamic=False)),
|
| 400 |
-
("memory_attention", dict(mode="max-autotune", dynamic=True)),
|
| 401 |
-
("sam_prompt_encoder", dict(mode="max-autotune", dynamic=False)),
|
| 402 |
-
("sam_mask_decoder", dict(mode="max-autotune", dynamic=False)),
|
| 403 |
-
]
|
| 404 |
-
self._original_forwards: Dict[str, Any] = {}
|
| 405 |
-
compiled = []
|
| 406 |
-
for attr, kwargs in components:
|
| 407 |
-
module = getattr(vp, attr, None)
|
| 408 |
-
if module is None:
|
| 409 |
-
continue
|
| 410 |
-
try:
|
| 411 |
-
self._original_forwards[attr] = module.forward
|
| 412 |
-
module.forward = torch.compile(module.forward, **kwargs)
|
| 413 |
-
compiled.append(attr)
|
| 414 |
-
except Exception as e:
|
| 415 |
-
logging.warning("torch.compile wrapping failed for %s: %s", attr, e)
|
| 416 |
-
if compiled:
|
| 417 |
-
logging.info("torch.compile applied to: %s", ", ".join(compiled))
|
| 418 |
-
self._torch_compiled = True
|
| 419 |
-
else:
|
| 420 |
-
logging.info("torch.compile not available, using eager mode.")
|
| 421 |
-
self._torch_compiled = False
|
| 422 |
-
|
| 423 |
-
def _revert_torch_compile(self):
|
| 424 |
-
"""Revert compiled forwards back to eager originals."""
|
| 425 |
-
vp = self._video_predictor
|
| 426 |
-
for attr, orig_fwd in self._original_forwards.items():
|
| 427 |
-
module = getattr(vp, attr, None)
|
| 428 |
-
if module is not None:
|
| 429 |
-
module.forward = orig_fwd
|
| 430 |
-
self._original_forwards.clear()
|
| 431 |
-
self._torch_compiled = False
|
| 432 |
-
logging.warning("Reverted torch.compile — falling back to eager mode.")
|
| 433 |
-
|
| 434 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 435 |
|
| 436 |
def predict(
|
|
@@ -613,32 +560,9 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 613 |
class_names_list: List[str] = []
|
| 614 |
cursor = 0
|
| 615 |
|
| 616 |
-
|
| 617 |
-
# If inductor/triton fails, revert to eager and restart propagation.
|
| 618 |
-
_generator = self._video_predictor.propagate_in_video(
|
| 619 |
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
| 620 |
-
)
|
| 621 |
-
if getattr(self, '_torch_compiled', False) and not getattr(self, '_compile_verified', False):
|
| 622 |
-
try:
|
| 623 |
-
_first = next(_generator)
|
| 624 |
-
except Exception as e:
|
| 625 |
-
logging.warning("torch.compile runtime error, reverting to eager: %s", e)
|
| 626 |
-
self._revert_torch_compile()
|
| 627 |
-
# Re-init propagation with eager forwards
|
| 628 |
-
self._video_predictor.reset_state(inference_state)
|
| 629 |
-
for obj_id, obj_info in mask_dict.labels.items():
|
| 630 |
-
self._video_predictor.add_new_mask(
|
| 631 |
-
inference_state, start_idx, obj_id, obj_info.mask,
|
| 632 |
-
)
|
| 633 |
-
_generator = self._video_predictor.propagate_in_video(
|
| 634 |
-
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
| 635 |
-
)
|
| 636 |
-
_first = next(_generator)
|
| 637 |
-
import itertools
|
| 638 |
-
self._compile_verified = True
|
| 639 |
-
_generator = itertools.chain([_first], _generator)
|
| 640 |
-
|
| 641 |
-
for out_frame_idx, out_obj_ids, out_mask_logits in _generator:
|
| 642 |
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
|
| 643 |
n = bool_masks.shape[0]
|
| 644 |
|
|
|
|
| 363 |
|
| 364 |
# Video predictor (for process_video)
|
| 365 |
self._video_predictor = build_sam2_video_predictor_hf(
|
| 366 |
+
hf_id, device=self.device
|
| 367 |
)
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
# Image predictor (for single-frame predict)
|
| 370 |
sam2_image_model = build_sam2_hf(hf_id, device=self.device)
|
| 371 |
self._image_predictor = SAM2ImagePredictor(sam2_image_model)
|
|
|
|
| 378 |
self._models_loaded = True
|
| 379 |
logging.info("Grounded-SAM-2 models loaded successfully.")
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 382 |
|
| 383 |
def predict(
|
|
|
|
| 560 |
class_names_list: List[str] = []
|
| 561 |
cursor = 0
|
| 562 |
|
| 563 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
|
|
|
|
|
|
|
| 564 |
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
| 565 |
+
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
|
| 567 |
n = bool_masks.shape[0]
|
| 568 |
|