Zhen Ye Claude Opus 4.6 committed on
Commit
63684e4
·
1 Parent(s): b17bd6d

fix: manual torch.compile with graceful fallback for SAM2 components

Browse files

Replace vos_optimized=True (caused 500 error) with manual per-component
torch.compile matching Facebook's official recipe. Compiles image_encoder,
memory_encoder, memory_attention (dynamic=True), sam_prompt_encoder, and
sam_mask_decoder with max-autotune. Falls back to eager mode silently if
PyTorch version or CUDA compiler is unavailable.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. models/segmenters/grounded_sam2.py +37 -4
models/segmenters/grounded_sam2.py CHANGED
@@ -362,13 +362,17 @@ class GroundedSAM2Segmenter(Segmenter):
362
  from sam2.sam2_image_predictor import SAM2ImagePredictor
363
 
364
  # Video predictor (for process_video)
365
- # vos_optimized=True enables SAM2VideoPredictorVOS which compiles
366
- # image_encoder, memory_encoder, memory_attention, sam_prompt_encoder,
367
- # and sam_mask_decoder with torch.compile(mode="max-autotune").
368
  self._video_predictor = build_sam2_video_predictor_hf(
369
- hf_id, device=self.device, vos_optimized=True,
370
  )
371
 
 
 
 
 
 
 
 
372
  # Image predictor (for single-frame predict)
373
  sam2_image_model = build_sam2_hf(hf_id, device=self.device)
374
  self._image_predictor = SAM2ImagePredictor(sam2_image_model)
@@ -381,6 +385,35 @@ class GroundedSAM2Segmenter(Segmenter):
381
  self._models_loaded = True
382
  logging.info("Grounded-SAM-2 models loaded successfully.")
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  # -- Single-frame interface (Segmenter.predict) -------------------------
385
 
386
  def predict(
 
362
  from sam2.sam2_image_predictor import SAM2ImagePredictor
363
 
364
  # Video predictor (for process_video)
 
 
 
365
  self._video_predictor = build_sam2_video_predictor_hf(
366
+ hf_id, device=self.device,
367
  )
368
 
369
+ # torch.compile individual components for fused kernels.
370
+ # memory_attention uses dynamic=True (variable memory token count).
371
+ # Wrapped in try/except: falls back to eager if PyTorch < 2.5 or
372
+ # CUDA compiler not available.
373
+ if self.device.startswith("cuda"):
374
+ self._apply_torch_compile()
375
+
376
  # Image predictor (for single-frame predict)
377
  sam2_image_model = build_sam2_hf(hf_id, device=self.device)
378
  self._image_predictor = SAM2ImagePredictor(sam2_image_model)
 
385
  self._models_loaded = True
386
  logging.info("Grounded-SAM-2 models loaded successfully.")
387
 
388
+ def _apply_torch_compile(self):
389
+ """Compile SAM2 sub-modules with torch.compile (max-autotune).
390
+
391
+ Compiles 5 components matching Facebook's official VOS recipe.
392
+ Falls back silently to eager mode on any compilation error.
393
+ """
394
+ vp = self._video_predictor
395
+ components = [
396
+ ("image_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
397
+ ("memory_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
398
+ ("memory_attention", dict(mode="max-autotune", fullgraph=True, dynamic=True)),
399
+ ("sam_prompt_encoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
400
+ ("sam_mask_decoder", dict(mode="max-autotune", fullgraph=True, dynamic=False)),
401
+ ]
402
+ compiled = []
403
+ for attr, kwargs in components:
404
+ module = getattr(vp, attr, None)
405
+ if module is None:
406
+ continue
407
+ try:
408
+ module.forward = torch.compile(module.forward, **kwargs)
409
+ compiled.append(attr)
410
+ except Exception as e:
411
+ logging.warning("torch.compile failed for %s: %s", attr, e)
412
+ if compiled:
413
+ logging.info("torch.compile applied to: %s", ", ".join(compiled))
414
+ else:
415
+ logging.info("torch.compile not available, using eager mode.")
416
+
417
  # -- Single-frame interface (Segmenter.predict) -------------------------
418
 
419
  def predict(