Update sam2/modeling/sam/transformer.py
Browse files
sam2/modeling/sam/transformer.py
CHANGED
|
@@ -16,6 +16,7 @@ from torch import Tensor, nn
|
|
| 16 |
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 17 |
from sam2.modeling.sam2_utils import MLP
|
| 18 |
from sam2.utils.misc import get_sdp_backends
|
|
|
|
| 19 |
|
| 20 |
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 21 |
# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|
|
|
|
| 16 |
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
|
| 17 |
from sam2.modeling.sam2_utils import MLP
|
| 18 |
from sam2.utils.misc import get_sdp_backends
|
| 19 |
+
import flash_attn_interface
|
| 20 |
|
| 21 |
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 22 |
# OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
|