alex
committed on
Commit
·
ae726e0
1
Parent(s):
a8172af
simplify
Browse files
sam2/modeling/sam/transformer.py
CHANGED
|
@@ -27,8 +27,10 @@ def _can_use_flash_attn(q: torch.Tensor) -> bool:
|
|
| 27 |
# FlashAttention works on CUDA with fp16/bf16 and (usually) Ampere+ GPUs
|
| 28 |
if not q.is_cuda:
|
| 29 |
return False
|
| 30 |
-
major, _ = torch.cuda.get_device_capability(q.device)
|
| 31 |
-
return q.dtype in (torch.float16, torch.bfloat16) and major >= 8 # A100/RTX30+ typically
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
class TwoWayTransformer(nn.Module):
|
|
|
|
| 27 |
# FlashAttention works on CUDA with fp16/bf16 and (usually) Ampere+ GPUs
|
| 28 |
if not q.is_cuda:
|
| 29 |
return False
|
| 30 |
+
# major, _ = torch.cuda.get_device_capability(q.device)
|
| 31 |
+
# return q.dtype in (torch.float16, torch.bfloat16) and major >= 8 # A100/RTX30+ typically
|
| 32 |
+
|
| 33 |
+
return True
|
| 34 |
|
| 35 |
|
| 36 |
class TwoWayTransformer(nn.Module):
|