Spaces:
Paused
Paused
Zhen Ye Claude Opus 4.6 committed on
Commit ·
4c36e1e
1
Parent(s): 63684e4
fix: add runtime fallback for torch.compile inductor/triton failures
Browse files
torch.compile wrapping succeeds immediately but actual compilation
happens lazily on first forward pass. Now catches runtime errors on
the first propagate_in_video yield, reverts all compiled forwards
to eager originals, and retries the propagation transparently.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
models/segmenters/grounded_sam2.py
CHANGED
|
@@ -389,7 +389,9 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 389 |
"""Compile SAM2 sub-modules with torch.compile (max-autotune).
|
| 390 |
|
| 391 |
Compiles 5 components matching Facebook's official VOS recipe.
|
| 392 |
-
|
|
|
|
|
|
|
| 393 |
"""
|
| 394 |
vp = self._video_predictor
|
| 395 |
components = [
|
|
@@ -399,20 +401,35 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 399 |
("sam_prompt_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 400 |
("sam_mask_decoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 401 |
]
|
|
|
|
| 402 |
compiled = []
|
| 403 |
for attr, kwargs in components:
|
| 404 |
module = getattr(vp, attr, None)
|
| 405 |
if module is None:
|
| 406 |
continue
|
| 407 |
try:
|
|
|
|
| 408 |
module.forward = torch.compile(module.forward, **kwargs)
|
| 409 |
compiled.append(attr)
|
| 410 |
except Exception as e:
|
| 411 |
-
logging.warning("torch.compile failed for %s: %s", attr, e)
|
| 412 |
if compiled:
|
| 413 |
logging.info("torch.compile applied to: %s", ", ".join(compiled))
|
|
|
|
| 414 |
else:
|
| 415 |
logging.info("torch.compile not available, using eager mode.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 418 |
|
|
@@ -596,9 +613,32 @@ class GroundedSAM2Segmenter(Segmenter):
|
|
| 596 |
class_names_list: List[str] = []
|
| 597 |
cursor = 0
|
| 598 |
|
| 599 |
-
|
|
|
|
|
|
|
| 600 |
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
| 601 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
|
| 603 |
n = bool_masks.shape[0]
|
| 604 |
|
|
|
|
| 389 |
"""Compile SAM2 sub-modules with torch.compile (max-autotune).
|
| 390 |
|
| 391 |
Compiles 5 components matching Facebook's official VOS recipe.
|
| 392 |
+
torch.compile wraps succeed immediately; actual Triton/inductor
|
| 393 |
+
compilation happens lazily on first forward pass. We store
|
| 394 |
+
original forwards so propagate_segment can fall back on error.
|
| 395 |
"""
|
| 396 |
vp = self._video_predictor
|
| 397 |
components = [
|
|
|
|
| 401 |
("sam_prompt_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 402 |
("sam_mask_decoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
|
| 403 |
]
|
| 404 |
+
self._original_forwards: Dict[str, Any] = {}
|
| 405 |
compiled = []
|
| 406 |
for attr, kwargs in components:
|
| 407 |
module = getattr(vp, attr, None)
|
| 408 |
if module is None:
|
| 409 |
continue
|
| 410 |
try:
|
| 411 |
+
self._original_forwards[attr] = module.forward
|
| 412 |
module.forward = torch.compile(module.forward, **kwargs)
|
| 413 |
compiled.append(attr)
|
| 414 |
except Exception as e:
|
| 415 |
+
logging.warning("torch.compile wrapping failed for %s: %s", attr, e)
|
| 416 |
if compiled:
|
| 417 |
logging.info("torch.compile applied to: %s", ", ".join(compiled))
|
| 418 |
+
self._torch_compiled = True
|
| 419 |
else:
|
| 420 |
logging.info("torch.compile not available, using eager mode.")
|
| 421 |
+
self._torch_compiled = False
|
| 422 |
+
|
| 423 |
+
def _revert_torch_compile(self):
|
| 424 |
+
"""Revert compiled forwards back to eager originals."""
|
| 425 |
+
vp = self._video_predictor
|
| 426 |
+
for attr, orig_fwd in self._original_forwards.items():
|
| 427 |
+
module = getattr(vp, attr, None)
|
| 428 |
+
if module is not None:
|
| 429 |
+
module.forward = orig_fwd
|
| 430 |
+
self._original_forwards.clear()
|
| 431 |
+
self._torch_compiled = False
|
| 432 |
+
logging.warning("Reverted torch.compile — falling back to eager mode.")
|
| 433 |
|
| 434 |
# -- Single-frame interface (Segmenter.predict) -------------------------
|
| 435 |
|
|
|
|
| 613 |
class_names_list: List[str] = []
|
| 614 |
cursor = 0
|
| 615 |
|
| 616 |
+
# Wrap generator to catch torch.compile runtime failures on first frame.
|
| 617 |
+
# If inductor/triton fails, revert to eager and restart propagation.
|
| 618 |
+
_generator = self._video_predictor.propagate_in_video(
|
| 619 |
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
| 620 |
+
)
|
| 621 |
+
if getattr(self, '_torch_compiled', False) and not getattr(self, '_compile_verified', False):
|
| 622 |
+
try:
|
| 623 |
+
_first = next(_generator)
|
| 624 |
+
except Exception as e:
|
| 625 |
+
logging.warning("torch.compile runtime error, reverting to eager: %s", e)
|
| 626 |
+
self._revert_torch_compile()
|
| 627 |
+
# Re-init propagation with eager forwards
|
| 628 |
+
self._video_predictor.reset_state(inference_state)
|
| 629 |
+
for obj_id, obj_info in mask_dict.labels.items():
|
| 630 |
+
self._video_predictor.add_new_mask(
|
| 631 |
+
inference_state, start_idx, obj_id, obj_info.mask,
|
| 632 |
+
)
|
| 633 |
+
_generator = self._video_predictor.propagate_in_video(
|
| 634 |
+
inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
|
| 635 |
+
)
|
| 636 |
+
_first = next(_generator)
|
| 637 |
+
import itertools
|
| 638 |
+
self._compile_verified = True
|
| 639 |
+
_generator = itertools.chain([_first], _generator)
|
| 640 |
+
|
| 641 |
+
for out_frame_idx, out_obj_ids, out_mask_logits in _generator:
|
| 642 |
bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
|
| 643 |
n = bool_masks.shape[0]
|
| 644 |
|