# LA Flash Utils
This folder contains the sparse attention utilities used when
`LA_FLASH_ATTN=la_flash` is set. The release path runs FlashAttention varlen
over LocateAnything range plans; it neither includes nor builds a local
C++/CUDA extension.
## Features
- Supports batched LocateAnything hybrid MTP inference on A100, RTX 4090, and H100.
- Consumes Magi-style `q_ranges`, `k_ranges`, `segment_offsets`, and
`attn_type_map` plans generated by `batch_utils.hybrid_runtime`.
- Uses FlashAttention varlen for packed causal/full plans.
- Packs LocateAnything MTP full-window key segments before calling
FlashAttention, avoiding dense `[B,H,Q,K]` masks.
- Supports log-sum-exp merging for compatible non-packed multi-segment plans.
## Attention Types
The release path intentionally supports only FlashAttention-compatible plan
types:
| Value | Meaning |
| --- | --- |
| `0` | Full attention over the listed key segment or packed key segments. |
| `1` | Bottom-right causal attention. |
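
To make "bottom-right causal" concrete, here is a minimal pure-Python sketch of the mask that convention implies when the query and key lengths differ. It is a reference illustration only: the helper name `bottom_right_causal_mask` is hypothetical, and the release path never materializes such a mask, since FlashAttention applies the rule inside the kernel.

```python
from __future__ import annotations


def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> list[list[bool]]:
    """Return mask[i][j] == True where query i may attend to key j.

    With bottom-right alignment, the last query lines up with the last key,
    so query i sees keys j <= i + (seqlen_k - seqlen_q).
    """
    offset = seqlen_k - seqlen_q
    return [
        [j <= i + offset for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]


# Two new query tokens decoding against five cached keys.
for row in bottom_right_causal_mask(2, 5):
    print("".join("1" if keep else "0" for keep in row))
# prints:
# 11110
# 11111
```

When `seqlen_q == seqlen_k` this reduces to the usual lower-triangular causal mask.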
## How It Works
`batch_utils.hybrid_runtime` builds sparse range plans for the text decoder.
Each plan describes which query token intervals attend to which key/value token
intervals. `kernel_utils.range_attention` executes those plans with
FlashAttention instead of materializing dense SDPA masks.
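
As an illustration of what a range plan looks like in FlashAttention terms, the sketch below converts per-pair query and key ranges into the cumulative sequence offsets that `flash_attn_varlen_func` takes as `cu_seqlens_q` and `cu_seqlens_k`. It assumes ranges are half-open `(start, end)` token intervals and covers only the simple case of one contiguous key range per query range; the helper names are hypothetical and are not taken from `range_attention.py`. In real use the offsets would be `int32` tensors on the GPU rather than Python lists.

```python
from __future__ import annotations

from itertools import accumulate

Range = tuple[int, int]  # assumed half-open: [start, end)


def ranges_to_cu_seqlens(ranges: list[Range]) -> list[int]:
    """Cumulative lengths with a leading 0, one entry per range boundary."""
    lengths = [end - start for start, end in ranges]
    if any(n < 0 for n in lengths):
        raise ValueError("each range must satisfy start <= end")
    return [0, *accumulate(lengths)]


def flatten_indices(ranges: list[Range]) -> list[int]:
    """Token indices to gather so the selected ranges become one packed sequence."""
    return [t for start, end in ranges for t in range(start, end)]


# Hypothetical plan: two query ranges, each attending to one key range.
q_ranges = [(0, 4), (4, 6)]
k_ranges = [(0, 4), (10, 16)]

cu_seqlens_q = ranges_to_cu_seqlens(q_ranges)    # [0, 4, 6]
cu_seqlens_k = ranges_to_cu_seqlens(k_ranges)    # [0, 4, 10]
max_seqlen_q = max(e - s for s, e in q_ranges)   # 4
max_seqlen_k = max(e - s for s, e in k_ranges)   # 6
k_gather = flatten_indices(k_ranges)             # rows of K/V to pack
```
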
The runtime dispatches each plan to one of three paths:
- **Packed simple plans:** when each query range maps to one contiguous
key/value range, LA Flash flattens the selected ranges, builds FlashAttention
`cu_seqlens_q` / `cu_seqlens_k`, and calls `flash_attn_varlen_func` directly.
- **Packed MTP full-window plans:** for hybrid MTP decode, multiple full
key/value windows for the same query block are concatenated into one packed
key/value sequence before the FlashAttention call. This keeps the sparse
memory profile without constructing a `[B,H,Q,K]` attention mask.
- **Compatible multi-segment plans:** when a query range attends to multiple
segments that cannot be packed as one sequence, each segment is evaluated with
FlashAttention and the partial outputs are merged with the standard
log-sum-exp softmax composition.
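
The log-sum-exp composition used by the third path can be checked on a toy example. The sketch below is plain Python for a single query vector, with the softmax scale omitted for brevity; `attend` and `merge` are hypothetical names, not functions from this package. Each segment yields a normalized output plus its log-sum-exp, and merging reweights the partial outputs by `exp(lse_i - lse_total)`, which reproduces attention over the concatenated keys.

```python
from __future__ import annotations

import math

Vec = list[float]


def attend(q: Vec, keys: list[Vec], values: list[Vec]) -> tuple[Vec, float]:
    """Softmax attention of one query over one segment: (output, lse)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, top + math.log(denom)


def merge(partials: list[tuple[Vec, float]]) -> tuple[Vec, float]:
    """Combine per-segment (output, lse) pairs into one softmax result."""
    lse_max = max(lse for _, lse in partials)
    coeffs = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(coeffs)
    dim = len(partials[0][0])
    out = [sum(c * o[d] for c, (o, _) in zip(coeffs, partials)) / total for d in range(dim)]
    return out, lse_max + math.log(total)


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

# Reference: attend to all four keys at once.
full_out, full_lse = attend(q, keys, values)

# Split into two segments, attend separately, then merge.
merged_out, merged_lse = merge([
    attend(q, keys[:1], values[:1]),
    attend(q, keys[1:], values[1:]),
])

assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(full_out, merged_out))
assert math.isclose(full_lse, merged_lse, rel_tol=1e-9)
```
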
The output tensor shape and dtype match the decoder attention output expected
by the model. This path is inference-oriented and depends on FlashAttention's
forward kernels; it is not a custom autograd training backend.
## Runtime Knobs
| Variable | Default | Meaning |
| --- | --- | --- |
| `LA_FLASH_ATTN` | `sdpa` | Set to `la_flash` to enable this backend through `batch_utils`. |
| `LA_FLASH_FASTPATH` | `auto` | Use FlashAttention varlen for packed simple plans. |
| `LA_FLASH_SEGMENT_FASTPATH` | `auto` | Use FlashAttention varlen for multi-segment sparse plans. Full segments are packed first; other compatible segments use LSE merging. |
| `LA_FLASH_PLAN_STATS` | `0` | Record sparse plan statistics in inference summaries. |
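
For scripting, the knobs above can be collected with their documented defaults as in the sketch below. This is an assumption about usage, not the package's own parsing: `read_la_flash_knobs` is a hypothetical helper, and how `batch_utils` actually reads or validates these variables is not shown in this document.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

# Defaults copied from the table above.
DEFAULTS = {
    "LA_FLASH_ATTN": "sdpa",
    "LA_FLASH_FASTPATH": "auto",
    "LA_FLASH_SEGMENT_FASTPATH": "auto",
    "LA_FLASH_PLAN_STATS": "0",
}


def read_la_flash_knobs(env: Mapping[str, str] | None = None) -> dict[str, str]:
    """Return each knob's value, falling back to the documented default."""
    source = os.environ if env is None else env
    return {name: source.get(name, default) for name, default in DEFAULTS.items()}


# Example: enable the backend and plan statistics, keep the rest at defaults.
knobs = read_la_flash_knobs({"LA_FLASH_ATTN": "la_flash", "LA_FLASH_PLAN_STATS": "1"})
print(knobs)
```

Passing an explicit mapping keeps the example independent of the real process environment; calling `read_la_flash_knobs()` with no argument reads `os.environ`.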
## Notes
Dense prefill and stock worker-style generation should keep
`LA_FLASH_DENSE_BACKEND=sdpa`; LA Flash is used for sparse range plans
produced by `batch_utils`.
This package is for inference and evaluation. Training remains on the
MagiAttention backend; the batched sparse-plan decode runtime does not support
the `labels` training path.
## Source Layout
- `range_attention.py`: FlashAttention varlen dispatch, sparse KV packing, LSE
merge fallback, and availability checks.