LA Flash Utils

This folder contains the sparse attention utilities used when LA_FLASH_ATTN=la_flash is set. The release path runs FlashAttention varlen kernels over LocateAnything range plans; it neither ships nor builds a local C++/CUDA extension.

Features

  • Supports batched LocateAnything hybrid MTP inference on A100, RTX 4090, and H100.
  • Consumes Magi-style q_ranges, k_ranges, segment_offsets, and attn_type_map plans generated by batch_utils.hybrid_runtime.
  • Uses FlashAttention varlen for packed causal/full plans.
  • Packs LocateAnything MTP full-window key segments before calling FlashAttention, avoiding dense [B,H,Q,K] masks.
  • Supports log-sum-exp merging for compatible non-packed multi-segment plans.

Attention Types

The release path intentionally supports only FlashAttention-compatible plan types:

| Value | Meaning |
| --- | --- |
| 0 | Full attention over the listed key segment or packed key segments. |
| 1 | Bottom-right causal attention. |
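To make the two plan types concrete, here is a minimal sketch of the mask each value implies inside a single query-range/key-range pair. The helper names `allowed` and `dense_mask` are hypothetical and not part of this package. The causal rule follows the bottom-right alignment that FlashAttention uses when the query and key lengths differ. The release path never builds such a mask; this is a reference for reasoning only.

```python
def allowed(attn_type: int, i: int, j: int, q_len: int, k_len: int) -> bool:
    """Return True if local query i may attend to local key j.

    i and j are offsets inside one (query range, key range) pair.
    """
    if attn_type == 0:
        # Full attention: every query sees every key in the range.
        return True
    if attn_type == 1:
        # Bottom-right causal: the diagonal ends at the last query / last key,
        # so the final query row sees all keys.
        return j <= i + (k_len - q_len)
    raise ValueError(f"unsupported attn_type: {attn_type}")


def dense_mask(attn_type: int, q_len: int, k_len: int) -> list[list[bool]]:
    """Reference-only dense mask for one range pair (hypothetical helper)."""
    return [
        [allowed(attn_type, i, j, q_len, k_len) for j in range(k_len)]
        for i in range(q_len)
    ]


# Two queries over four keys: the first query cannot see the last key,
# the second query sees everything.
causal = dense_mask(1, q_len=2, k_len=4)
full = dense_mask(0, q_len=2, k_len=4)
```

With equal query and key lengths, the type 1 rule reduces to the usual lower-triangular causal mask.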

How It Works

batch_utils.hybrid_runtime builds sparse range plans for the text decoder. Each plan describes which query token intervals attend to which key/value token intervals. kernel_utils.range_attention executes those plans with FlashAttention instead of materializing dense SDPA masks.

The runtime follows three paths:

  • Packed simple plans: when each query range maps to one contiguous key/value range, LA Flash flattens the selected ranges, builds FlashAttention cu_seqlens_q / cu_seqlens_k, and calls flash_attn_varlen_func directly.
  • Packed MTP full-window plans: for hybrid MTP decode, multiple full key/value windows for the same query block are concatenated into one packed key/value sequence before the FlashAttention call. This keeps the sparse memory profile without constructing a [B,H,Q,K] attention mask.
  • Compatible multi-segment plans: when a query range attends to multiple segments that cannot be packed as one sequence, each segment is evaluated with FlashAttention and the partial outputs are merged with the standard log-sum-exp softmax composition.
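The packing step shared by the first two paths can be sketched in plain Python. This is an illustration under stated assumptions, not the code in range_attention.py: ranges are assumed to be half-open `(start, end)` token intervals, and `cu_seqlens`, `gather`, and `pack_plan` are hypothetical helper names. Each query range is paired with a list of key windows; a simple plan has one window per query range, while an MTP full-window plan has several that are concatenated.

```python
from itertools import accumulate

Range = tuple[int, int]  # assumed convention: half-open [start, end)


def cu_seqlens(lengths: list[int]) -> list[int]:
    """Cumulative boundaries: entry i is where packed sequence i starts."""
    return [0, *accumulate(lengths)]


def gather(ranges: list[Range]) -> list[int]:
    """Token indices selected by the ranges, in packed order."""
    return [t for start, end in ranges for t in range(start, end)]


def pack_plan(q_ranges: list[Range], k_windows: list[list[Range]]) -> dict:
    """Pair q_ranges[i] with the key windows k_windows[i] and pack both sides."""
    q_lens = [end - start for start, end in q_ranges]
    # All windows of one query range are concatenated into one key sequence.
    k_lens = [sum(end - start for start, end in ws) for ws in k_windows]
    return {
        "q_idx": gather(q_ranges),
        "k_idx": [t for ws in k_windows for t in gather(ws)],
        "cu_seqlens_q": cu_seqlens(q_lens),
        "cu_seqlens_k": cu_seqlens(k_lens),
        "max_seqlen_q": max(q_lens),
        "max_seqlen_k": max(k_lens),
    }


# First query range has one key window (simple plan); the second has two
# full windows that are packed back to back (MTP-style plan).
plan = pack_plan(
    q_ranges=[(0, 3), (3, 5)],
    k_windows=[[(0, 3)], [(0, 2), (8, 12)]],
)
```

In a tensor implementation the index lists would select rows of Q, K, and V, and the cumulative lengths would be int32 tensors passed to flash_attn_varlen_func together with max_seqlen_q and max_seqlen_k. Memory then scales with the number of selected tokens rather than with Q × K.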

The output tensor shape and dtype match the decoder attention output expected by the model. This path is inference-oriented and depends on FlashAttention's forward kernels; it is not a custom autograd training backend.
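The log-sum-exp composition used for compatible multi-segment plans can be checked on a toy example. The sketch below handles a single query vector with plain Python lists. `attend` and `merge` are hypothetical names, and the softmax scale is omitted for brevity. It shows that merging per-segment outputs by their log-sum-exp values reproduces attention over the concatenated keys.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query; returns (output, log-sum-exp)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)]
    return out, m + math.log(z)


def merge(partials):
    """Combine (output, lse) pairs computed over disjoint key segments."""
    lses = [lse for _, lse in partials]
    m = max(lses)
    total = m + math.log(sum(math.exp(lse - m) for lse in lses))
    dim = len(partials[0][0])
    # Each segment is reweighted by its share of the total softmax mass.
    merged = [
        sum(math.exp(lse - total) * out[d] for out, lse in partials)
        for d in range(dim)
    ]
    return merged, total


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

reference, reference_lse = attend(q, keys, values)
merged, merged_lse = merge([
    attend(q, keys[:2], values[:2]),  # segment 1
    attend(q, keys[2:], values[2:]),  # segment 2
])
```

Because `merge` also returns the combined log-sum-exp, more than two segments can be folded in one at a time or all at once with the same result.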

Runtime Knobs

| Variable | Default | Meaning |
| --- | --- | --- |
| LA_FLASH_ATTN | sdpa | Set to la_flash to enable this backend through batch_utils. |
| LA_FLASH_FASTPATH | auto | Use FlashAttention varlen for packed simple plans. |
| LA_FLASH_SEGMENT_FASTPATH | auto | Use FlashAttention varlen for multi-segment sparse plans. Full segments are packed first; other compatible segments use LSE merging. |
| LA_FLASH_PLAN_STATS | 0 | Record sparse plan statistics in inference summaries. |
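As a rough illustration of how these variables and their defaults could be read, the sketch below resolves them from the environment. `read_knobs` and `la_flash_enabled` are hypothetical helpers, not functions from batch_utils. The sketch returns raw strings only, since the accepted values beyond those listed in the table are not documented here.

```python
import os

# Defaults copied from the table above.
DEFAULTS = {
    "LA_FLASH_ATTN": "sdpa",
    "LA_FLASH_FASTPATH": "auto",
    "LA_FLASH_SEGMENT_FASTPATH": "auto",
    "LA_FLASH_PLAN_STATS": "0",
}


def read_knobs(env=None) -> dict:
    """Return each knob's raw string value, falling back to its default."""
    env = os.environ if env is None else env
    return {name: env.get(name, default) for name, default in DEFAULTS.items()}


def la_flash_enabled(env=None) -> bool:
    """True only when LA_FLASH_ATTN is explicitly set to la_flash."""
    return read_knobs(env)["LA_FLASH_ATTN"] == "la_flash"
```

Passing a plain dict instead of the process environment makes the behaviour easy to check, for example `la_flash_enabled({"LA_FLASH_ATTN": "la_flash"})`.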

Notes

Dense prefill and stock worker-style generation should keep LA_FLASH_DENSE_BACKEND=sdpa; LA Flash is used for sparse range plans produced by batch_utils.

This package is for inference and evaluation. Training remains on the MagiAttention backend; the batched sparse-plan decode runtime does not support the labels training path.

Source Layout

  • range_attention.py: FlashAttention varlen dispatch, sparse KV packing, LSE merge fallback, and availability checks.