alex commited on
Commit
ae726e0
·
1 Parent(s): a8172af
Files changed (1) hide show
  1. sam2/modeling/sam/transformer.py +4 -2
sam2/modeling/sam/transformer.py CHANGED
@@ -27,8 +27,10 @@ def _can_use_flash_attn(q: torch.Tensor) -> bool:
27
  # FlashAttention works on CUDA with fp16/bf16 and (usually) Ampere+ GPUs
28
  if not q.is_cuda:
29
  return False
30
- major, _ = torch.cuda.get_device_capability(q.device)
31
- return q.dtype in (torch.float16, torch.bfloat16) and major >= 8 # A100/RTX30+ typically
 
 
32
 
33
 
34
  class TwoWayTransformer(nn.Module):
 
27
  # FlashAttention works on CUDA with fp16/bf16 and (usually) Ampere+ GPUs
28
  if not q.is_cuda:
29
  return False
30
+ # major, _ = torch.cuda.get_device_capability(q.device)
31
+ # return q.dtype in (torch.float16, torch.bfloat16) and major >= 8 # A100/RTX30+ typically
32
+
33
+ return True
34
 
35
 
36
  class TwoWayTransformer(nn.Module):