# LA Flash Utils

This folder contains the sparse attention utilities used by `LA_FLASH_ATTN=la_flash`. The release path is implemented with FlashAttention varlen over LocateAnything range plans. It does not include or build a local C++/CUDA extension.

## Features

- Supports batched LocateAnything hybrid MTP inference on A100, RTX 4090, and H100.
- Consumes Magi-style `q_ranges`, `k_ranges`, `segment_offsets`, and `attn_type_map` plans generated by `batch_utils.hybrid_runtime`.
- Uses FlashAttention varlen for packed causal/full plans.
- Packs LocateAnything MTP full-window key segments before calling FlashAttention, avoiding dense `[B,H,Q,K]` masks.
- Supports log-sum-exp merging for compatible non-packed multi-segment plans.

## Attention Types

The release path intentionally supports only FlashAttention-compatible plan types:

| Value | Meaning |
| --- | --- |
| `0` | Full attention over the listed key segment or packed key segments. |
| `1` | Bottom-right causal attention. |

## How It Works

`batch_utils.hybrid_runtime` builds sparse range plans for the text decoder. Each plan describes which query token intervals attend to which key/value token intervals. `kernel_utils.range_attention` executes those plans with FlashAttention instead of materializing dense SDPA masks.

The runtime follows three paths:

- **Packed simple plans:** when each query range maps to one contiguous key/value range, LA Flash flattens the selected ranges, builds FlashAttention `cu_seqlens_q` / `cu_seqlens_k`, and calls `flash_attn_varlen_func` directly.
- **Packed MTP full-window plans:** for hybrid MTP decode, multiple full key/value windows for the same query block are concatenated into one packed key/value sequence before the FlashAttention call. This keeps the sparse memory profile without constructing a `[B,H,Q,K]` attention mask.
- **Compatible multi-segment plans:** when a query range attends to multiple segments that cannot be packed as one sequence, each segment is evaluated with FlashAttention and the partial outputs are merged with the standard log-sum-exp softmax composition.

The output tensor shape and dtype match the decoder attention output expected by the model. This path is inference-oriented and depends on FlashAttention's forward kernels; it is not a custom autograd training backend.

## Runtime Knobs

| Variable | Default | Meaning |
| --- | --- | --- |
| `LA_FLASH_ATTN` | `sdpa` | Set to `la_flash` to enable this backend through `batch_utils`. |
| `LA_FLASH_FASTPATH` | `auto` | Use FlashAttention varlen for packed simple plans. |
| `LA_FLASH_SEGMENT_FASTPATH` | `auto` | Use FlashAttention varlen for multi-segment sparse plans. Full segments are packed first; other compatible segments use LSE merging. |
| `LA_FLASH_PLAN_STATS` | `0` | Record sparse plan statistics in inference summaries. |

## Notes

Dense prefill and stock worker-style generation should keep `LA_FLASH_DENSE_BACKEND=sdpa`; LA Flash is used for sparse range plans produced by `batch_utils`.

This package is for inference and evaluation. Training remains on the MagiAttention backend; the batched sparse-plan decode runtime does not support the `labels` training path.

## Source Layout

- `range_attention.py`: FlashAttention varlen dispatch, sparse KV packing, LSE merge fallback, and availability checks.
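
## Example: LSE Merge Sketch

The log-sum-exp merge mentioned for compatible multi-segment plans can be illustrated with a small, dependency-free sketch. This is not the code in `range_attention.py`: the helper names `attend_segment` and `merge_lse` are hypothetical, the math is done in pure Python for a single query vector, and how the real backend obtains per-segment log-sum-exp values from FlashAttention is not shown. The sketch only demonstrates why attention computed separately over disjoint key/value segments can be recombined exactly.

```python
import math


def attend_segment(q, keys, values, scale):
    """Softmax attention of one query over one key/value segment.

    Returns (out, lse), where lse is the log-sum-exp of the scaled scores.
    Hypothetical helper for illustration only.
    """
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, m + math.log(denom)


def merge_lse(partials):
    """Merge (out, lse) pairs computed over disjoint key segments.

    Each partial output is reweighted by exp(lse_i - lse_total), which is
    the share of the full softmax mass that fell into segment i.
    """
    lses = [lse for _, lse in partials]
    m = max(lses)
    total = m + math.log(sum(math.exp(l - m) for l in lses))
    dim = len(partials[0][0])
    merged = [0.0] * dim
    for out, lse in partials:
        w = math.exp(lse - total)
        for d in range(dim):
            merged[d] += w * out[d]
    return merged, total


# Toy data: one query, four keys split into two segments.
q = [0.5, -1.0, 0.25]
keys = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0], [0.5, 0.5, 0.5], [-1.0, 2.0, 0.0]]
values = [[1.0, 2.0], [3.0, -1.0], [0.0, 0.5], [2.0, 2.0]]
scale = 1.0 / math.sqrt(len(q))

full_out, full_lse = attend_segment(q, keys, values, scale)
parts = [
    attend_segment(q, keys[:1], values[:1], scale),
    attend_segment(q, keys[1:], values[1:], scale),
]
merged_out, merged_lse = merge_lse(parts)
```

In this toy setup, `merged_out` and `merged_lse` agree with `full_out` and `full_lse` up to floating-point error, which is the property that lets each segment be evaluated independently and combined afterwards.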