Instructions to use kernels-community/metal-flash-sdpa with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/metal-flash-sdpa with Kernels:
# !pip install kernels from kernels import get_kernel kernel = get_kernel("kernels-community/metal-flash-sdpa") - Notebooks
- Google Colab
- Kaggle
Add flash_attn_func + harden MPS dispatch for transformers compatibility
#3
by ArthurZ HF Staff - opened
torch-ext/metal_flash_sdpa/_custom_ops.py
CHANGED
|
@@ -20,7 +20,7 @@ def flash_attention_varlen(
|
|
| 20 |
) -> None:
|
| 21 |
"""
|
| 22 |
Flash Attention with variable-length sequences.
|
| 23 |
-
|
| 24 |
Args:
|
| 25 |
out: Output tensor of shape [total_q_tokens, num_heads, head_dim]
|
| 26 |
query: Query tensor of shape [total_q_tokens, num_heads, head_dim]
|
|
@@ -33,7 +33,7 @@ def flash_attention_varlen(
|
|
| 33 |
do_causal: Whether to apply causal masking
|
| 34 |
scale: Attention scale factor (default: 1/sqrt(head_dim))
|
| 35 |
softcapping: Softcapping value (default: 1.0, must be 1.0 for this implementation)
|
| 36 |
-
|
| 37 |
Note:
|
| 38 |
- cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
|
| 39 |
- Supported head dimensions: 32, 64, 72, 80, 96, 128
|
|
@@ -41,7 +41,7 @@ def flash_attention_varlen(
|
|
| 41 |
"""
|
| 42 |
if scale is None:
|
| 43 |
scale = query.shape[-1] ** -0.5
|
| 44 |
-
|
| 45 |
ops.flash_attention_varlen(
|
| 46 |
out,
|
| 47 |
query,
|
|
@@ -56,6 +56,45 @@ def flash_attention_varlen(
|
|
| 56 |
softcapping,
|
| 57 |
)
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def flash_attn_varlen_func(
|
| 60 |
q: torch.Tensor,
|
| 61 |
k: torch.Tensor,
|
|
@@ -74,7 +113,7 @@ def flash_attn_varlen_func(
|
|
| 74 |
) -> torch.Tensor:
|
| 75 |
"""
|
| 76 |
Flash Attention function with API compatible with the original Flash Attention.
|
| 77 |
-
|
| 78 |
Note: This implementation does not support:
|
| 79 |
- dropout
|
| 80 |
- window attention
|
|
@@ -89,10 +128,18 @@ def flash_attn_varlen_func(
|
|
| 89 |
raise NotImplementedError("ALiBi is not supported")
|
| 90 |
if return_attn_probs:
|
| 91 |
raise NotImplementedError("Returning attention probabilities is not supported")
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
# Create output tensor
|
| 94 |
out = torch.empty_like(q)
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# Call the kernel
|
| 97 |
flash_attention_varlen(
|
| 98 |
out=out,
|
|
@@ -107,7 +154,7 @@ def flash_attn_varlen_func(
|
|
| 107 |
scale=softmax_scale,
|
| 108 |
softcapping=1.0,
|
| 109 |
)
|
| 110 |
-
|
| 111 |
return out
|
| 112 |
|
| 113 |
|
|
|
|
| 20 |
) -> None:
|
| 21 |
"""
|
| 22 |
Flash Attention with variable-length sequences.
|
| 23 |
+
|
| 24 |
Args:
|
| 25 |
out: Output tensor of shape [total_q_tokens, num_heads, head_dim]
|
| 26 |
query: Query tensor of shape [total_q_tokens, num_heads, head_dim]
|
|
|
|
| 33 |
do_causal: Whether to apply causal masking
|
| 34 |
scale: Attention scale factor (default: 1/sqrt(head_dim))
|
| 35 |
softcapping: Softcapping value (default: 1.0, must be 1.0 for this implementation)
|
| 36 |
+
|
| 37 |
Note:
|
| 38 |
- cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
|
| 39 |
- Supported head dimensions: 32, 64, 72, 80, 96, 128
|
|
|
|
| 41 |
"""
|
| 42 |
if scale is None:
|
| 43 |
scale = query.shape[-1] ** -0.5
|
| 44 |
+
|
| 45 |
ops.flash_attention_varlen(
|
| 46 |
out,
|
| 47 |
query,
|
|
|
|
| 56 |
softcapping,
|
| 57 |
)
|
| 58 |
|
| 59 |
+
|
| 60 |
+
def _prepare_varlen_inputs(q, k, v, cu_seqlens_q, cu_seqlens_k):
|
| 61 |
+
"""Normalize Q/K/V and cumulative-length tensors before dispatching to the Metal kernel.
|
| 62 |
+
|
| 63 |
+
The kernel reads its inputs as flat strided buffers and expects:
|
| 64 |
+
- contiguous Q/K/V (decode-time cached K/V from PyTorch is typically a transposed view)
|
| 65 |
+
- int32 cumulative sequence lengths (Metal pointer type)
|
| 66 |
+
- distinct buffers for `cu_seqlens_q` and `cu_seqlens_k` (passing the same tensor in both
|
| 67 |
+
slots aliases the same Metal argument and yields incorrect attention scores)
|
| 68 |
+
|
| 69 |
+
Caller-provided ``cu_seqlens_q``/``cu_seqlens_k`` are preserved as-is; this function only
|
| 70 |
+
performs the dtype / contiguity / aliasing fixups the kernel needs. ``cu_seqlens`` are never
|
| 71 |
+
synthesized here — that responsibility stays with the caller (continuous batching, padding-free
|
| 72 |
+
training, etc.).
|
| 73 |
+
"""
|
| 74 |
+
needs_sync = False
|
| 75 |
+
if not q.is_contiguous():
|
| 76 |
+
q = q.contiguous()
|
| 77 |
+
needs_sync = needs_sync or q.is_mps
|
| 78 |
+
if not k.is_contiguous():
|
| 79 |
+
k = k.contiguous()
|
| 80 |
+
needs_sync = needs_sync or k.is_mps
|
| 81 |
+
if not v.is_contiguous():
|
| 82 |
+
v = v.contiguous()
|
| 83 |
+
needs_sync = needs_sync or v.is_mps
|
| 84 |
+
if cu_seqlens_q.dtype != torch.int32:
|
| 85 |
+
cu_seqlens_q = cu_seqlens_q.to(torch.int32)
|
| 86 |
+
needs_sync = needs_sync or cu_seqlens_q.is_mps
|
| 87 |
+
if cu_seqlens_k.dtype != torch.int32:
|
| 88 |
+
cu_seqlens_k = cu_seqlens_k.to(torch.int32)
|
| 89 |
+
needs_sync = needs_sync or cu_seqlens_k.is_mps
|
| 90 |
+
if cu_seqlens_k.data_ptr() == cu_seqlens_q.data_ptr():
|
| 91 |
+
cu_seqlens_k = cu_seqlens_k.clone()
|
| 92 |
+
needs_sync = needs_sync or cu_seqlens_k.is_mps
|
| 93 |
+
if needs_sync and torch.backends.mps.is_available():
|
| 94 |
+
torch.mps.synchronize()
|
| 95 |
+
return q, k, v, cu_seqlens_q, cu_seqlens_k
|
| 96 |
+
|
| 97 |
+
|
| 98 |
def flash_attn_varlen_func(
|
| 99 |
q: torch.Tensor,
|
| 100 |
k: torch.Tensor,
|
|
|
|
| 113 |
) -> torch.Tensor:
|
| 114 |
"""
|
| 115 |
Flash Attention function with API compatible with the original Flash Attention.
|
| 116 |
+
|
| 117 |
Note: This implementation does not support:
|
| 118 |
- dropout
|
| 119 |
- window attention
|
|
|
|
| 128 |
raise NotImplementedError("ALiBi is not supported")
|
| 129 |
if return_attn_probs:
|
| 130 |
raise NotImplementedError("Returning attention probabilities is not supported")
|
| 131 |
+
|
| 132 |
+
q, k, v, cu_seqlens_q, cu_seqlens_k = _prepare_varlen_inputs(q, k, v, cu_seqlens_q, cu_seqlens_k)
|
| 133 |
+
|
| 134 |
# Create output tensor
|
| 135 |
out = torch.empty_like(q)
|
| 136 |
+
|
| 137 |
+
# Flush any pending Metal encoder before launching the custom kernel; without this, a
|
| 138 |
+
# preceding op (e.g. `.contiguous()` on a transposed cache view) leaves an encoder open
|
| 139 |
+
# and the kernel trips ``"A command encoder is already encoding to this command buffer"``.
|
| 140 |
+
if q.is_mps and torch.backends.mps.is_available():
|
| 141 |
+
torch.mps.synchronize()
|
| 142 |
+
|
| 143 |
# Call the kernel
|
| 144 |
flash_attention_varlen(
|
| 145 |
out=out,
|
|
|
|
| 154 |
scale=softmax_scale,
|
| 155 |
softcapping=1.0,
|
| 156 |
)
|
| 157 |
+
|
| 158 |
return out
|
| 159 |
|
| 160 |
|