Kernels

Add flash_attn_func + harden MPS dispatch for transformers compatibility

#3
by ArthurZ HF Staff - opened
torch-ext/metal_flash_sdpa/_custom_ops.py CHANGED
@@ -20,7 +20,7 @@ def flash_attention_varlen(
20
  ) -> None:
21
  """
22
  Flash Attention with variable-length sequences.
23
-
24
  Args:
25
  out: Output tensor of shape [total_q_tokens, num_heads, head_dim]
26
  query: Query tensor of shape [total_q_tokens, num_heads, head_dim]
@@ -33,7 +33,7 @@ def flash_attention_varlen(
33
  do_causal: Whether to apply causal masking
34
  scale: Attention scale factor (default: 1/sqrt(head_dim))
35
  softcapping: Softcapping value (default: 1.0, must be 1.0 for this implementation)
36
-
37
  Note:
38
  - cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
39
  - Supported head dimensions: 32, 64, 72, 80, 96, 128
@@ -41,7 +41,7 @@ def flash_attention_varlen(
41
  """
42
  if scale is None:
43
  scale = query.shape[-1] ** -0.5
44
-
45
  ops.flash_attention_varlen(
46
  out,
47
  query,
@@ -56,6 +56,45 @@ def flash_attention_varlen(
56
  softcapping,
57
  )
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def flash_attn_varlen_func(
60
  q: torch.Tensor,
61
  k: torch.Tensor,
@@ -74,7 +113,7 @@ def flash_attn_varlen_func(
74
  ) -> torch.Tensor:
75
  """
76
  Flash Attention function with API compatible with the original Flash Attention.
77
-
78
  Note: This implementation does not support:
79
  - dropout
80
  - window attention
@@ -89,10 +128,18 @@ def flash_attn_varlen_func(
89
  raise NotImplementedError("ALiBi is not supported")
90
  if return_attn_probs:
91
  raise NotImplementedError("Returning attention probabilities is not supported")
92
-
 
 
93
  # Create output tensor
94
  out = torch.empty_like(q)
95
-
 
 
 
 
 
 
96
  # Call the kernel
97
  flash_attention_varlen(
98
  out=out,
@@ -107,7 +154,7 @@ def flash_attn_varlen_func(
107
  scale=softmax_scale,
108
  softcapping=1.0,
109
  )
110
-
111
  return out
112
 
113
 
 
20
  ) -> None:
21
  """
22
  Flash Attention with variable-length sequences.
23
+
24
  Args:
25
  out: Output tensor of shape [total_q_tokens, num_heads, head_dim]
26
  query: Query tensor of shape [total_q_tokens, num_heads, head_dim]
 
33
  do_causal: Whether to apply causal masking
34
  scale: Attention scale factor (default: 1/sqrt(head_dim))
35
  softcapping: Softcapping value (default: 1.0, must be 1.0 for this implementation)
36
+
37
  Note:
38
  - cu_seqlens_q and cu_seqlens_k must have dtype torch.int32 for Metal compatibility
39
  - Supported head dimensions: 32, 64, 72, 80, 96, 128
 
41
  """
42
  if scale is None:
43
  scale = query.shape[-1] ** -0.5
44
+
45
  ops.flash_attention_varlen(
46
  out,
47
  query,
 
56
  softcapping,
57
  )
58
 
59
+
60
+ def _prepare_varlen_inputs(q, k, v, cu_seqlens_q, cu_seqlens_k):
61
+ """Normalize Q/K/V and cumulative-length tensors before dispatching to the Metal kernel.
62
+
63
+ The kernel reads its inputs as flat strided buffers and expects:
64
+ - contiguous Q/K/V (decode-time cached K/V from PyTorch is typically a transposed view)
65
+ - int32 cumulative sequence lengths (Metal pointer type)
66
+ - distinct buffers for `cu_seqlens_q` and `cu_seqlens_k` (passing the same tensor in both
67
+ slots aliases the same Metal argument and yields incorrect attention scores)
68
+
69
+ Caller-provided ``cu_seqlens_q``/``cu_seqlens_k`` are preserved as-is; this function only
70
+ performs the dtype / contiguity / aliasing fixups the kernel needs. ``cu_seqlens`` are never
71
+ synthesized here — that responsibility stays with the caller (continuous batching, padding-free
72
+ training, etc.).
73
+ """
74
+ needs_sync = False
75
+ if not q.is_contiguous():
76
+ q = q.contiguous()
77
+ needs_sync = needs_sync or q.is_mps
78
+ if not k.is_contiguous():
79
+ k = k.contiguous()
80
+ needs_sync = needs_sync or k.is_mps
81
+ if not v.is_contiguous():
82
+ v = v.contiguous()
83
+ needs_sync = needs_sync or v.is_mps
84
+ if cu_seqlens_q.dtype != torch.int32:
85
+ cu_seqlens_q = cu_seqlens_q.to(torch.int32)
86
+ needs_sync = needs_sync or cu_seqlens_q.is_mps
87
+ if cu_seqlens_k.dtype != torch.int32:
88
+ cu_seqlens_k = cu_seqlens_k.to(torch.int32)
89
+ needs_sync = needs_sync or cu_seqlens_k.is_mps
90
+ if cu_seqlens_k.data_ptr() == cu_seqlens_q.data_ptr():
91
+ cu_seqlens_k = cu_seqlens_k.clone()
92
+ needs_sync = needs_sync or cu_seqlens_k.is_mps
93
+ if needs_sync and torch.backends.mps.is_available():
94
+ torch.mps.synchronize()
95
+ return q, k, v, cu_seqlens_q, cu_seqlens_k
96
+
97
+
98
  def flash_attn_varlen_func(
99
  q: torch.Tensor,
100
  k: torch.Tensor,
 
113
  ) -> torch.Tensor:
114
  """
115
  Flash Attention function with API compatible with the original Flash Attention.
116
+
117
  Note: This implementation does not support:
118
  - dropout
119
  - window attention
 
128
  raise NotImplementedError("ALiBi is not supported")
129
  if return_attn_probs:
130
  raise NotImplementedError("Returning attention probabilities is not supported")
131
+
132
+ q, k, v, cu_seqlens_q, cu_seqlens_k = _prepare_varlen_inputs(q, k, v, cu_seqlens_q, cu_seqlens_k)
133
+
134
  # Create output tensor
135
  out = torch.empty_like(q)
136
+
137
+ # Flush any pending Metal encoder before launching the custom kernel; without this, a
138
+ # preceding op (e.g. `.contiguous()` on a transposed cache view) leaves an encoder open
139
+ # and the kernel trips ``"A command encoder is already encoding to this command buffer"``.
140
+ if q.is_mps and torch.backends.mps.is_available():
141
+ torch.mps.synchronize()
142
+
143
  # Call the kernel
144
  flash_attention_varlen(
145
  out=out,
 
154
  scale=softmax_scale,
155
  softcapping=1.0,
156
  )
157
+
158
  return out
159
 
160