Add flash_attn_func + harden MPS dispatch for transformers compatibility
#3
by ArthurZ
Add a dense `flash_attn_func` API and harden `flash_attn_varlen_func` for MPS:
- `flash_attn_func`: dense `(B, S, H, D)` Q/K/V → `(B, S, H, D)` out; synthesizes `cu_seqlens` and dispatches to varlen. Lets `transformers` (and any FA-style caller) use the kernel without writing padding-free boilerplate (sketched below).
- `.contiguous()` Q/K/V → decode-time cached K/V is a transposed view; without this the kernel reads garbage and produces nonsense tokens.
- Auto-cast `cu_seqlens_*` to `int32` and defensively clone `cu_seqlens_k` when it aliases `cu_seqlens_q` (Metal treats them as the same argument otherwise).
- `torch.mps.synchronize()` once the prep ops have scheduled anything → otherwise the open compute encoder trips `command encoder is already encoding to this command buffer` when the kernel launches its own.
- Build dense `cu_seqlens` on CPU then copy to device (the `arange * seq_len` on-device multiply is the same encoder trigger as above).
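
A minimal sketch of how these pieces fit together, assuming an FA-style varlen entry point `flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ...)` over packed `(total_tokens, H, D)` tensors; names and the exact signature are assumptions, not the shipped API:

```python
import torch

# from metal_flash_sdpa import flash_attn_varlen_func  # hypothetical import


def _prep_cu_seqlens(cu_q: torch.Tensor, cu_k: torch.Tensor):
    # Metal binds buffers by identity: if cu_seqlens_k aliases cu_seqlens_q
    # it is treated as the same argument, so cast both to int32 and clone
    # the aliased one.
    cu_q = cu_q.to(torch.int32)
    cu_k = cu_k.to(torch.int32)
    if cu_k.data_ptr() == cu_q.data_ptr():
        cu_k = cu_k.clone()
    return cu_q, cu_k


def flash_attn_func(q, k, v, softmax_scale=None, causal=False):
    """Dense (B, S, H, D) front-end that dispatches to the varlen kernel."""
    # Decode-time cached K/V is typically a transposed view; the kernel
    # needs contiguous memory or it reads garbage.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    b, sq, hq, d = q.shape
    _, sk, hk, _ = k.shape

    # Build cu_seqlens on CPU, then copy: an on-device `arange * seq_len`
    # would open a compute encoder and trip the "already encoding" error.
    cu_q = (torch.arange(b + 1, dtype=torch.int32) * sq).to(q.device)
    cu_k = cu_q if sk == sq else (torch.arange(b + 1, dtype=torch.int32) * sk).to(k.device)
    cu_q, cu_k = _prep_cu_seqlens(cu_q, cu_k)

    # Flush pending prep work before the kernel opens its own encoder.
    if q.device.type == "mps":
        torch.mps.synchronize()

    out = flash_attn_varlen_func(
        q.view(b * sq, hq, d), k.view(b * sk, hk, d), v.view(b * sk, hk, d),
        cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
        max_seqlen_q=sq, max_seqlen_k=sk,
        causal=causal, softmax_scale=softmax_scale,
    )
    return out.view(b, sq, hq, d)
```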
Verified vs SDPA on MPS/fp16:
- prefill (Q=K=20, B=2, GQA): max abs diff 6.1e-5
- decode (Q=1, K=21, non-contig K/V): max abs diff 1.5e-5
- aliased `cu_seqlens`: bit-exact
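
For reference, the prefill check looks roughly like this (plain MHA with causal masking assumed to keep the sketch short; the reported run used GQA):

```python
import torch
import torch.nn.functional as F

B, S, H, D = 2, 20, 8, 64
q, k, v = (torch.randn(B, S, H, D, device="mps", dtype=torch.float16)
           for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)

# SDPA expects (B, H, S, D); transpose into and out of its layout.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

print((out - ref).abs().max().item())  # ~6e-5 in the run above
```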
End-to-end with Qwen2.5-0.5B-Instruct greedy generation on MPS: output token ids
match SDPA exactly with no caller-side workarounds.
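
A sketch of that comparison; `"metal-flash-sdpa"` is a placeholder label, since how the kernel is actually selected depends on the companion transformers PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(name)
ids = tok("Hello", return_tensors="pt").input_ids.to("mps")

outs = {}
for impl in ("sdpa", "metal-flash-sdpa"):  # second string is illustrative
    model = AutoModelForCausalLM.from_pretrained(
        name, torch_dtype=torch.float16, attn_implementation=impl
    ).to("mps")
    outs[impl] = model.generate(ids, max_new_tokens=32, do_sample=False)

assert torch.equal(outs["sdpa"], outs["metal-flash-sdpa"])
```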
Companion transformers PR: https://github.com/huggingface/transformers/pull/45974
(The `transformers` fallback is still useful for varlen-only third-party kernels, but
with this change it is no longer required for metal-flash-sdpa specifically.)