Zhen Ye Claude Opus 4.6 committed on
Commit
4c36e1e
·
1 Parent(s): 63684e4

fix: add runtime fallback for torch.compile inductor/triton failures

Browse files

torch.compile wrapping succeeds immediately but actual compilation
happens lazily on first forward pass. Now catches runtime errors on
the first propagate_in_video yield, reverts all compiled forwards
to eager originals, and retries the propagation transparently.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. models/segmenters/grounded_sam2.py +44 -4
models/segmenters/grounded_sam2.py CHANGED
@@ -389,7 +389,9 @@ class GroundedSAM2Segmenter(Segmenter):
389
  """Compile SAM2 sub-modules with torch.compile (max-autotune).
390
 
391
  Compiles 5 components matching Facebook's official VOS recipe.
392
- Falls back silently to eager mode on any compilation error.
 
 
393
  """
394
  vp = self._video_predictor
395
  components = [
@@ -399,20 +401,35 @@ class GroundedSAM2Segmenter(Segmenter):
399
  ("sam_prompt_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
400
  ("sam_mask_decoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
401
  ]
 
402
  compiled = []
403
  for attr, kwargs in components:
404
  module = getattr(vp, attr, None)
405
  if module is None:
406
  continue
407
  try:
 
408
  module.forward = torch.compile(module.forward, **kwargs)
409
  compiled.append(attr)
410
  except Exception as e:
411
- logging.warning("torch.compile failed for %s: %s", attr, e)
412
  if compiled:
413
  logging.info("torch.compile applied to: %s", ", ".join(compiled))
 
414
  else:
415
  logging.info("torch.compile not available, using eager mode.")
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
  # -- Single-frame interface (Segmenter.predict) -------------------------
418
 
@@ -596,9 +613,32 @@ class GroundedSAM2Segmenter(Segmenter):
596
  class_names_list: List[str] = []
597
  cursor = 0
598
 
599
- for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
 
 
600
  inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
601
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
603
  n = bool_masks.shape[0]
604
 
 
389
  """Compile SAM2 sub-modules with torch.compile (max-autotune).
390
 
391
  Compiles 5 components matching Facebook's official VOS recipe.
392
+ torch.compile wraps succeed immediately; actual Triton/inductor
393
+ compilation happens lazily on first forward pass. We store
394
+ original forwards so propagate_segment can fall back on error.
395
  """
396
  vp = self._video_predictor
397
  components = [
 
401
  ("sam_prompt_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
402
  ("sam_mask_decoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
403
  ]
404
+ self._original_forwards: Dict[str, Any] = {}
405
  compiled = []
406
  for attr, kwargs in components:
407
  module = getattr(vp, attr, None)
408
  if module is None:
409
  continue
410
  try:
411
+ self._original_forwards[attr] = module.forward
412
  module.forward = torch.compile(module.forward, **kwargs)
413
  compiled.append(attr)
414
  except Exception as e:
415
+ logging.warning("torch.compile wrapping failed for %s: %s", attr, e)
416
  if compiled:
417
  logging.info("torch.compile applied to: %s", ", ".join(compiled))
418
+ self._torch_compiled = True
419
  else:
420
  logging.info("torch.compile not available, using eager mode.")
421
+ self._torch_compiled = False
422
+
423
+ def _revert_torch_compile(self):
424
+ """Revert compiled forwards back to eager originals."""
425
+ vp = self._video_predictor
426
+ for attr, orig_fwd in self._original_forwards.items():
427
+ module = getattr(vp, attr, None)
428
+ if module is not None:
429
+ module.forward = orig_fwd
430
+ self._original_forwards.clear()
431
+ self._torch_compiled = False
432
+ logging.warning("Reverted torch.compile — falling back to eager mode.")
433
 
434
  # -- Single-frame interface (Segmenter.predict) -------------------------
435
 
 
613
  class_names_list: List[str] = []
614
  cursor = 0
615
 
616
+ # Wrap generator to catch torch.compile runtime failures on first frame.
617
+ # If inductor/triton fails, revert to eager and restart propagation.
618
+ _generator = self._video_predictor.propagate_in_video(
619
  inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
620
+ )
621
+ if getattr(self, '_torch_compiled', False) and not getattr(self, '_compile_verified', False):
622
+ try:
623
+ _first = next(_generator)
624
+ except Exception as e:
625
+ logging.warning("torch.compile runtime error, reverting to eager: %s", e)
626
+ self._revert_torch_compile()
627
+ # Re-init propagation with eager forwards
628
+ self._video_predictor.reset_state(inference_state)
629
+ for obj_id, obj_info in mask_dict.labels.items():
630
+ self._video_predictor.add_new_mask(
631
+ inference_state, start_idx, obj_id, obj_info.mask,
632
+ )
633
+ _generator = self._video_predictor.propagate_in_video(
634
+ inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
635
+ )
636
+ _first = next(_generator)
637
+ import itertools
638
+ self._compile_verified = True
639
+ _generator = itertools.chain([_first], _generator)
640
+
641
+ for out_frame_idx, out_obj_ids, out_mask_logits in _generator:
642
  bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
643
  n = bool_masks.shape[0]
644