Zhen Ye Claude Opus 4.6 committed on
Commit
5749bd6
·
1 Parent(s): 10eb3c6

revert: remove torch.compile — runtime failures on HF Space

Browse files

Reverts all torch.compile changes (b17bd6d..10eb3c6). The HF Space
container lacks Triton/inductor support, causing 500 errors. The
GPU-resident tensor pipeline from 5aec47c is retained.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. models/segmenters/grounded_sam2.py +3 -79
models/segmenters/grounded_sam2.py CHANGED
@@ -363,16 +363,9 @@ class GroundedSAM2Segmenter(Segmenter):
363
 
364
  # Video predictor (for process_video)
365
  self._video_predictor = build_sam2_video_predictor_hf(
366
- hf_id, device=self.device,
367
  )
368
 
369
- # torch.compile individual components for fused kernels.
370
- # memory_attention uses dynamic=True (variable memory token count).
371
- # Wrapped in try/except: falls back to eager if PyTorch < 2.5 or
372
- # CUDA compiler not available.
373
- if self.device.startswith("cuda"):
374
- self._apply_torch_compile()
375
-
376
  # Image predictor (for single-frame predict)
377
  sam2_image_model = build_sam2_hf(hf_id, device=self.device)
378
  self._image_predictor = SAM2ImagePredictor(sam2_image_model)
@@ -385,52 +378,6 @@ class GroundedSAM2Segmenter(Segmenter):
385
  self._models_loaded = True
386
  logging.info("Grounded-SAM-2 models loaded successfully.")
387
 
388
- def _apply_torch_compile(self):
389
- """Compile SAM2 sub-modules with torch.compile (max-autotune).
390
-
391
- Compiles 5 components matching Facebook's official VOS recipe.
392
- torch.compile wraps succeed immediately; actual Triton/inductor
393
- compilation happens lazily on first forward pass. We store
394
- original forwards so propagate_segment can fall back on error.
395
- """
396
- vp = self._video_predictor
397
- components = [
398
- ("image_encoder", dict(mode="max-autotune", dynamic=False)),
399
- ("memory_encoder", dict(mode="max-autotune", dynamic=False)),
400
- ("memory_attention", dict(mode="max-autotune", dynamic=True)),
401
- ("sam_prompt_encoder", dict(mode="max-autotune", dynamic=False)),
402
- ("sam_mask_decoder", dict(mode="max-autotune", dynamic=False)),
403
- ]
404
- self._original_forwards: Dict[str, Any] = {}
405
- compiled = []
406
- for attr, kwargs in components:
407
- module = getattr(vp, attr, None)
408
- if module is None:
409
- continue
410
- try:
411
- self._original_forwards[attr] = module.forward
412
- module.forward = torch.compile(module.forward, **kwargs)
413
- compiled.append(attr)
414
- except Exception as e:
415
- logging.warning("torch.compile wrapping failed for %s: %s", attr, e)
416
- if compiled:
417
- logging.info("torch.compile applied to: %s", ", ".join(compiled))
418
- self._torch_compiled = True
419
- else:
420
- logging.info("torch.compile not available, using eager mode.")
421
- self._torch_compiled = False
422
-
423
- def _revert_torch_compile(self):
424
- """Revert compiled forwards back to eager originals."""
425
- vp = self._video_predictor
426
- for attr, orig_fwd in self._original_forwards.items():
427
- module = getattr(vp, attr, None)
428
- if module is not None:
429
- module.forward = orig_fwd
430
- self._original_forwards.clear()
431
- self._torch_compiled = False
432
- logging.warning("Reverted torch.compile — falling back to eager mode.")
433
-
434
  # -- Single-frame interface (Segmenter.predict) -------------------------
435
 
436
  def predict(
@@ -613,32 +560,9 @@ class GroundedSAM2Segmenter(Segmenter):
613
  class_names_list: List[str] = []
614
  cursor = 0
615
 
616
- # Wrap generator to catch torch.compile runtime failures on first frame.
617
- # If inductor/triton fails, revert to eager and restart propagation.
618
- _generator = self._video_predictor.propagate_in_video(
619
  inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
620
- )
621
- if getattr(self, '_torch_compiled', False) and not getattr(self, '_compile_verified', False):
622
- try:
623
- _first = next(_generator)
624
- except Exception as e:
625
- logging.warning("torch.compile runtime error, reverting to eager: %s", e)
626
- self._revert_torch_compile()
627
- # Re-init propagation with eager forwards
628
- self._video_predictor.reset_state(inference_state)
629
- for obj_id, obj_info in mask_dict.labels.items():
630
- self._video_predictor.add_new_mask(
631
- inference_state, start_idx, obj_id, obj_info.mask,
632
- )
633
- _generator = self._video_predictor.propagate_in_video(
634
- inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
635
- )
636
- _first = next(_generator)
637
- import itertools
638
- self._compile_verified = True
639
- _generator = itertools.chain([_first], _generator)
640
-
641
- for out_frame_idx, out_obj_ids, out_mask_logits in _generator:
642
  bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
643
  n = bool_masks.shape[0]
644
 
 
363
 
364
  # Video predictor (for process_video)
365
  self._video_predictor = build_sam2_video_predictor_hf(
366
+ hf_id, device=self.device
367
  )
368
 
 
 
 
 
 
 
 
369
  # Image predictor (for single-frame predict)
370
  sam2_image_model = build_sam2_hf(hf_id, device=self.device)
371
  self._image_predictor = SAM2ImagePredictor(sam2_image_model)
 
378
  self._models_loaded = True
379
  logging.info("Grounded-SAM-2 models loaded successfully.")
380
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  # -- Single-frame interface (Segmenter.predict) -------------------------
382
 
383
  def predict(
 
560
  class_names_list: List[str] = []
561
  cursor = 0
562
 
563
+ for out_frame_idx, out_obj_ids, out_mask_logits in self._video_predictor.propagate_in_video(
 
 
564
  inference_state, max_frame_num_to_track=step, start_frame_idx=start_idx,
565
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  bool_masks = (out_mask_logits[:, 0] > 0.0) # (N, H, W) GPU async
567
  n = bool_masks.shape[0]
568