Add flash_attn_func + harden MPS dispatch for transformers compatibility

#3 by ArthurZ

Add a dense flash_attn_func API and harden flash_attn_varlen_func for MPS:

  • flash_attn_func: dense (B, S, H, D) Q/K/V → (B, S, H, D) out; synthesizes
    cu_seqlens and dispatches to varlen. Lets transformers (and any FA-style
    caller) use the kernel without writing padding-free boilerplate (see the
    sketches after this list).
  • Force Q/K/V .contiguous(): decode-time cached K/V is a transposed view;
    without this the kernel reads garbage and produces nonsense tokens.
  • Auto-cast cu_seqlens_* to int32 and defensively clone cu_seqlens_k when it
    aliases cu_seqlens_q (Metal otherwise binds them as a single argument).
  • Call torch.mps.synchronize() whenever the prep ops scheduled any work;
    otherwise the still-open compute encoder trips "command encoder is already
    encoding to this command buffer" when the kernel launches its own.
  • Build dense cu_seqlens on CPU, then copy to device (the on-device
    arange * seq_len multiply triggers the same open-encoder failure as above).
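For the curious, here is a minimal sketch of what the dense entry point amounts to. The import path, keyword names, and the softmax_scale default are illustrative assumptions, not the merged code:

```python
import math
import torch

# Hypothetical import path; flash_attn_varlen_func is the repo's
# existing varlen entry point.
from flash_sdpa import flash_attn_varlen_func

def flash_attn_func(q, k, v, softmax_scale=None, causal=False):
    """Dense (B, S, H, D) front end that lowers to the varlen kernel."""
    b, s_q, h_q, d = q.shape
    s_k, h_kv = k.shape[1], k.shape[2]

    # Decode-time cached K/V arrives as a transposed view; the Metal kernel
    # assumes dense row-major input, so force contiguity up front.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    # Synthesize cumulative sequence lengths on CPU as int32 (an on-device
    # arange * seq_len would open a compute encoder of its own), then copy.
    cu_q = torch.arange(0, (b + 1) * s_q, s_q, dtype=torch.int32).to(q.device)
    cu_k = torch.arange(0, (b + 1) * s_k, s_k, dtype=torch.int32).to(q.device)

    # Flush whatever the prep ops scheduled before the kernel encodes.
    torch.mps.synchronize()

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)

    out = flash_attn_varlen_func(
        q.view(b * s_q, h_q, d),
        k.view(b * s_k, h_kv, d),
        v.view(b * s_k, h_kv, d),
        cu_seqlens_q=cu_q,
        cu_seqlens_k=cu_k,
        max_seqlen_q=s_q,
        max_seqlen_k=s_k,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    return out.view(b, s_q, h_q, d)
```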
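The varlen-side hardening (third bullet) boils down to roughly this; _prepare_cu_seqlens is a hypothetical helper name for illustration:

```python
import torch

def _prepare_cu_seqlens(cu_seqlens_q, cu_seqlens_k):
    """Hypothetical helper: normalize offsets before binding Metal buffers."""
    # Metal-side offsets are 32-bit; callers often hand over int64.
    # .to() is a no-op when the dtype already matches, so aliasing can
    # survive this step.
    cu_seqlens_q = cu_seqlens_q.to(torch.int32)
    cu_seqlens_k = cu_seqlens_k.to(torch.int32)

    # If both names still share one storage (e.g. the caller passed the same
    # tensor twice and no cast copied it), Metal would bind them as a single
    # argument; clone K's offsets so the two buffers stay distinct.
    if cu_seqlens_k.data_ptr() == cu_seqlens_q.data_ptr():
        cu_seqlens_k = cu_seqlens_k.clone()
    return cu_seqlens_q, cu_seqlens_k
```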

Verified vs SDPA on MPS/fp16:

  • prefill (Q=K=20, B=2, GQA): max abs diff 6.1e-5
  • decode (Q=1, K=21, non-contig K/V): max abs diff 1.5e-5
  • aliased cu_seqlens: bit-exact
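The prefill check is essentially a max-abs-diff comparison against torch SDPA. The shapes below match the prefill case; the head counts, seed, and causal flag are illustrative assumptions, and flash_attn_func is the wrapper sketched above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, h_q, h_kv, d = 2, 20, 8, 2, 64  # B=2, Q=K=20, GQA (head counts assumed)
q = torch.randn(b, s, h_q, d, dtype=torch.float16, device="mps")
k = torch.randn(b, s, h_kv, d, dtype=torch.float16, device="mps")
v = torch.randn(b, s, h_kv, d, dtype=torch.float16, device="mps")

out = flash_attn_func(q, k, v, causal=True)

# SDPA reference wants (B, H, S, D) and dense heads, so repeat KV for GQA.
rep = h_q // h_kv
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2),
    k.repeat_interleave(rep, dim=2).transpose(1, 2),
    v.repeat_interleave(rep, dim=2).transpose(1, 2),
    is_causal=True,
).transpose(1, 2)

print((out - ref).abs().max().item())  # ~6.1e-5 reported for the prefill case
```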

End-to-end with Qwen2.5-0.5B-Instruct greedy generation on MPS: output token ids
match SDPA exactly with no caller-side workarounds.
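The greedy-generation parity check looks roughly like the following; the hub id passed to attn_implementation is an assumption about how the kernel is published, resolved via transformers' kernels integration:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok("The capital of France is", return_tensors="pt").to("mps")

outs = {}
# "kernels-community/metal-flash-sdpa" is an assumed hub id for this kernel.
for impl in ("sdpa", "kernels-community/metal-flash-sdpa"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, attn_implementation=impl
    ).to("mps")
    outs[impl] = model.generate(**inputs, max_new_tokens=32, do_sample=False)

# Greedy decode should be token-for-token identical across the two paths.
assert torch.equal(outs["sdpa"], outs["kernels-community/metal-flash-sdpa"])
```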

Companion transformers PR: https://github.com/huggingface/transformers/pull/45974
(The transformers fallback is still useful for varlen-only third-party kernels,
but with this change it is no longer required for metal-flash-sdpa specifically.)
