alexnasa commited on
Commit
629614d
·
verified ·
1 Parent(s): 6e73bb7

Update sam2/modeling/sam/transformer.py

Browse files
Files changed (1) hide show
  1. sam2/modeling/sam/transformer.py +1 -0
sam2/modeling/sam/transformer.py CHANGED
@@ -16,6 +16,7 @@ from torch import Tensor, nn
16
  from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
17
  from sam2.modeling.sam2_utils import MLP
18
  from sam2.utils.misc import get_sdp_backends
 
19
 
20
  warnings.simplefilter(action="ignore", category=FutureWarning)
21
  # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
 
16
  from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
17
  from sam2.modeling.sam2_utils import MLP
18
  from sam2.utils.misc import get_sdp_backends
19
+ import flash_attn_interface
20
 
21
  warnings.simplefilter(action="ignore", category=FutureWarning)
22
  # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()