# LocateAnything-3B / kernel_utils / range_attention.py
# Author: ShihaoW
# Commit fb9f8b1 (verified): "Refine LA Flash plan documentation"
# File size: 14.4 kB
"""Sparse LocateAnything attention implemented with FlashAttention varlen.
The public API accepts flattened query/key/value tensors:
q: [total_q, num_q_heads, head_dim]
k: [total_k, num_kv_heads, head_dim]
v: [total_k, num_kv_heads, head_dim]
and a Magi-style range plan:
q_ranges: [num_ranges, 2]
k_ranges: [num_key_segments, 2]
segment_offsets: [num_query_groups + 1]
attn_type_map:
0 = full attention over the listed key segment(s)
1 = bottom-right causal attention
For LocateAnything hybrid MTP decode, batch_utils represents the window as a
causal prefix plus full-attention sparse window segments. This module packs
those visible KV segments and calls FlashAttention varlen, avoiding dense masks.
"""
from __future__ import annotations
import os
from typing import Optional
import torch
# Memoized ``flash_attn.flash_attn_varlen_func``; populated by
# _load_flash_attn_varlen() after the first successful import.
_FLASH_ATTN_VARLEN = None
# Exception raised by the first failed ``flash_attn`` import. It is re-raised
# by _load_flash_attn_varlen() on later calls so the import is only attempted once.
_FLASH_ATTN_ERROR: Optional[BaseException] = None
def _env_enabled(name: str, default: str = "auto") -> bool:
    """Interpret the environment variable ``name`` as an on/off switch.

    An unset variable falls back to ``default``. Matching ignores case and
    surrounding whitespace. The empty string, ``auto``, ``1``, ``on``,
    ``true``, ``yes`` and ``force`` count as enabled; anything else counts
    as disabled.
    """
    truthy = ("", "auto", "1", "on", "true", "yes", "force")
    raw = os.environ.get(name, default)
    normalized = raw.strip().lower()
    return normalized in truthy
def is_available() -> bool:
    """Report whether ``flash_attn.flash_attn_varlen_func`` can be loaded.

    Any ``Exception`` raised while loading is treated as "not available".
    """
    try:
        _load_flash_attn_varlen()
    except Exception:
        return False
    return True
def _flash_fastpath_enabled() -> bool:
    """Return whether the caller-packed ``cu_seqlens`` fast path may be used.

    Controlled by the ``LA_FLASH_FASTPATH`` environment variable; defaults to
    enabled ("auto").
    """
    env_var = "LA_FLASH_FASTPATH"
    return _env_enabled(env_var, default="auto")
def _flash_segment_fastpath_enabled() -> bool:
    """Return whether the per-segment FlashAttention path may be used.

    Controlled by the ``LA_FLASH_SEGMENT_FASTPATH`` environment variable;
    defaults to enabled ("auto").
    """
    env_var = "LA_FLASH_SEGMENT_FASTPATH"
    return _env_enabled(env_var, default="auto")
def _load_flash_attn_varlen():
    """Import and memoize ``flash_attn.flash_attn_varlen_func``.

    A successful import is cached in ``_FLASH_ATTN_VARLEN``. A failed import is
    cached in ``_FLASH_ATTN_ERROR`` and re-raised on every later call, so the
    import is attempted only once.

    Returns:
        The ``flash_attn_varlen_func`` callable.

    Raises:
        Exception: whatever the ``flash_attn`` import raised (typically
            ``ImportError``).
    """
    global _FLASH_ATTN_VARLEN, _FLASH_ATTN_ERROR
    if _FLASH_ATTN_VARLEN is not None:
        return _FLASH_ATTN_VARLEN
    if _FLASH_ATTN_ERROR is not None:
        # Re-raising the same exception object would keep prepending frames
        # to its __traceback__ on every call (keeping those frames and their
        # locals alive); start from a fresh traceback instead.
        raise _FLASH_ATTN_ERROR.with_traceback(None)
    try:
        from flash_attn import flash_attn_varlen_func
    except Exception as exc:
        # Cache only genuine import failures. KeyboardInterrupt/SystemExit must
        # propagate without poisoning the cache: is_available() only catches
        # Exception, so a cached BaseException would escape from it forever.
        _FLASH_ATTN_ERROR = exc
        raise
    _FLASH_ATTN_VARLEN = flash_attn_varlen_func
    return _FLASH_ATTN_VARLEN
def _coalesce_query_groups(q_ranges, k_ranges, attn_type_map):
    """Group consecutive range entries that share the same query span and mask type.

    Entry ``i`` of the flat plan pairs ``q_ranges[i]`` with ``k_ranges[i]`` and
    ``attn_type_map[i]``. Each run of consecutive entries with an identical
    query span and attention type becomes one query group. The key segments of
    group ``g`` are ``k_ranges[segment_offsets[g]:segment_offsets[g + 1]]``.
    Identical spans that are not adjacent stay in separate groups.

    Args:
        q_ranges: ``[num_ranges, 2]`` tensor of ``[start, end)`` query spans.
        k_ranges: key spans indexed like ``q_ranges``; returned unchanged.
        attn_type_map: one attention type per range entry (0 = full,
            1 = bottom-right causal).

    Returns:
        ``(group_q_ranges, k_ranges, segment_offsets, group_attn_type_map,
        max_q_len, max_k_len)``. The new tensors are int32 on
        ``q_ranges.device``. ``max_k_len`` is the longest single key range.

    Raises:
        RuntimeError: if ``attn_type_map`` does not have exactly one entry per
            query range, or if it contains a type other than 0/1.
    """
    if q_ranges.numel() == 0:
        segment_offsets = torch.zeros((1,), dtype=torch.int32, device=q_ranges.device)
        return q_ranges, k_ranges, segment_offsets, attn_type_map, 0, 0
    num_ranges = int(q_ranges.shape[0])
    q_list = q_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    t_list = attn_type_map.detach().to(device="cpu", dtype=torch.int32).reshape(-1).tolist()
    if len(t_list) != num_ranges:
        # zip() would silently truncate: the trailing ranges would never be
        # grouped while the final offset still spans them, attributing their
        # key segments to the wrong query group.
        raise RuntimeError(
            "LA Flash range plan is inconsistent: "
            f"got {num_ranges} q_ranges but {len(t_list)} attn_type_map entries."
        )

    grouped_q = []
    grouped_t = []
    offsets = []
    max_q_len = 0
    previous = None
    for idx, (qr, attn_type) in enumerate(zip(q_list, t_list)):
        attn_type = int(attn_type)
        if attn_type not in (0, 1):
            raise RuntimeError(
                "LA Flash path only supports FlashAttention-compatible attn_type 0/1. "
                f"Got attn_type={attn_type}; regenerate a type 0/1 range plan."
            )
        current = (int(qr[0]), int(qr[1]), attn_type)
        if current == previous:
            # Same span and type as the previous entry: extend that group.
            continue
        offsets.append(idx)
        grouped_q.append([current[0], current[1]])
        grouped_t.append(attn_type)
        max_q_len = max(max_q_len, current[1] - current[0])
        previous = current
    offsets.append(num_ranges)

    k_list = k_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    max_k_len = max((int(end) - int(start) for start, end in k_list), default=0)
    device = q_ranges.device
    return (
        torch.tensor(grouped_q, dtype=torch.int32, device=device).contiguous(),
        k_ranges,
        torch.tensor(offsets, dtype=torch.int32, device=device).contiguous(),
        torch.tensor(grouped_t, dtype=torch.int32, device=device).contiguous(),
        int(max_q_len),
        int(max_k_len),
    )
def _flash_lse_to_tq_h(lse, total_q, q_lengths=None):
    """Normalize a FlashAttention softmax LSE to a float32 ``[total_q, num_heads]`` tensor.

    Accepted layouts:
      * 2-D ``[num_heads, total_q]`` (the unpadded varlen layout) or
        ``[total_q, num_heads]``;
      * 3-D padded ``[batch, num_heads, max_seqlen_q]``, unpadded using
        ``q_lengths`` (one query length per packed sequence, in packing order).

    Returns ``None`` when ``lse`` is ``None`` or its layout cannot be matched.
    The caller can then abandon the merge path instead of using a misaligned
    LSE.
    """
    if lse is None:
        return None
    if lse.dim() == 2:
        # Try the heads-major layout first. When num_heads == total_q both
        # tests match, and flash_attn_varlen_func documents its LSE as
        # (nheads, total_q), so that reading must win.
        if lse.shape[1] == total_q:
            return lse.transpose(0, 1).contiguous().float()
        if lse.shape[0] == total_q:
            return lse.float()
        return None
    if lse.dim() == 3 and q_lengths is not None and lse.shape[0] == len(q_lengths):
        if len(q_lengths) == 0 or lse.shape[1] == 0:
            # Nothing to unpad (torch.cat would reject an empty list).
            return None
        chunks = []
        for idx, q_len in enumerate(q_lengths):
            q_len = int(q_len)
            if q_len > lse.shape[2]:
                return None
            chunks.append(lse[idx, :, :q_len].transpose(0, 1))
        merged = torch.cat(chunks, dim=0).float()
        return merged if merged.shape[0] == total_q else None
    return None
def _make_cu_seqlens(lengths, device):
    """Build the int32 cumulative-length (``cu_seqlens``) vector for a varlen call.

    ``[l0, l1, ...]`` becomes ``[0, l0, l0 + l1, ...]`` on ``device``.
    """
    boundaries = [0]
    for length in lengths:
        boundaries.append(boundaries[-1] + length)
    return torch.tensor(boundaries, device=device, dtype=torch.int32)
def _parse_segment_groups(
    k_ranges,
    segment_offsets,
    group_q_ranges,
    group_attn_type_map,
    total_q,
    total_k,
):
    """Expand a grouped plan into a list of ``(q_start, q_end, attn_type, segments)`` tuples.

    ``segments`` lists the ``(k_start, k_end)`` ranges
    ``k_ranges[segment_offsets[g]:segment_offsets[g + 1]]`` of group ``g``.

    Returns ``None`` if any group:
      * uses an attn_type other than 0/1;
      * owns no key segment or has an empty query span;
      * references rows outside ``q``/``k``.
    """
    gq_list = group_q_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    kr_list = k_ranges.detach().to(device="cpu", dtype=torch.int32).contiguous().tolist()
    seg_list = segment_offsets.detach().to(device="cpu", dtype=torch.int32).reshape(-1).tolist()
    type_list = group_attn_type_map.detach().to(device="cpu", dtype=torch.int32).reshape(-1).tolist()
    groups = []
    for group_idx, (q_start, q_end) in enumerate(gq_list):
        attn_type = int(type_list[group_idx])
        if attn_type not in (0, 1):
            return None
        seg_start = int(seg_list[group_idx])
        seg_end = int(seg_list[group_idx + 1])
        if seg_end <= seg_start or q_end <= q_start:
            return None
        if q_start < 0 or q_end > total_q or seg_end > len(kr_list):
            # Tensor slicing would silently clamp while cu_seqlens keep the
            # requested lengths, misaligning every packed sequence.
            return None
        segments = [(int(a), int(b)) for a, b in kr_list[seg_start:seg_end]]
        if any(a < 0 or b > total_k for a, b in segments):
            return None
        groups.append((int(q_start), int(q_end), attn_type, segments))
    return groups


def _groups_cover_rows(groups, total_q):
    """Return True when the union of the groups' query spans is exactly ``[0, total_q)``."""
    reach = 0
    for q_start, q_end in sorted((group[0], group[1]) for group in groups):
        if q_start > reach:
            return False
        reach = max(reach, q_end)
    return reach >= total_q


def _groups_have_disjoint_queries(groups):
    """Return True when no query row belongs to more than one group."""
    spans = sorted((group[0], group[1]) for group in groups)
    return all(prev[1] <= nxt[0] for prev, nxt in zip(spans, spans[1:]))


def _flash_packed_groups(flash_attn_varlen, q, k, v, groups, softmax_scale):
    """Compute pairwise-disjoint query groups with one varlen call per mask type.

    Preconditions, checked by the caller:
      * every query row belongs to exactly one group;
      * every causal group has a single key segment.
    A full-attention group with several segments attends over the
    concatenation of its segments, which equals attending over them jointly.
    """
    merged = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
    for attn_type in (0, 1):
        q_slices, k_slices, v_slices = [], [], []
        q_lengths, k_lengths, targets = [], [], []
        for q_start, q_end, group_type, segments in groups:
            if group_type != attn_type:
                continue
            q_slices.append(q[q_start:q_end])
            if attn_type == 0 and len(segments) > 1:
                k_slices.append(torch.cat([k[start:end] for start, end in segments], dim=0))
                v_slices.append(torch.cat([v[start:end] for start, end in segments], dim=0))
                k_lengths.append(sum(max(end - start, 0) for start, end in segments))
            else:
                k_start, k_end = segments[0]
                k_slices.append(k[k_start:k_end])
                v_slices.append(v[k_start:k_end])
                k_lengths.append(max(k_end - k_start, 0))
            q_lengths.append(q_end - q_start)
            targets.append((q_start, q_end))
        if not q_slices:
            continue
        out_pass = flash_attn_varlen(
            torch.cat(q_slices, dim=0).contiguous(),
            torch.cat(k_slices, dim=0).contiguous(),
            torch.cat(v_slices, dim=0).contiguous(),
            _make_cu_seqlens(q_lengths, q.device),
            _make_cu_seqlens(k_lengths, q.device),
            int(max(q_lengths)),
            int(max(k_lengths)),
            dropout_p=0.0,
            softmax_scale=float(softmax_scale),
            causal=bool(attn_type == 1),
        )
        if isinstance(out_pass, tuple):
            out_pass = out_pass[0]
        cursor = 0
        for q_start, q_end in targets:
            q_len = q_end - q_start
            merged[q_start:q_end] = out_pass[cursor:cursor + q_len]
            cursor += q_len
    return merged


def _flash_lse_merge_segments(flash_attn_varlen, q, k, v, groups, softmax_scale):
    """Compute groups one key segment at a time and combine the partial results by LSE.

    Pass ``i`` issues one varlen call per mask type over the ``i``-th key
    segment of every group. Partial outputs for the same query rows are then
    combined with the stable log-sum-exp rule, so a row may belong to several
    groups.

    Returns ``None`` when FlashAttention does not return a usable LSE, or when
    some query row never receives a non-empty key segment.
    """
    neg_inf = float("-inf")
    merged = torch.zeros((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    merged_lse = torch.full((q.shape[0], q.shape[1]), neg_inf, device=q.device, dtype=torch.float32)
    covered = torch.zeros((q.shape[0],), device=q.device, dtype=torch.bool)
    max_segments = max(len(segments) for _, _, _, segments in groups)
    for segment_idx in range(max_segments):
        for attn_type in (0, 1):
            q_slices, k_slices, v_slices = [], [], []
            q_lengths, k_lengths, targets = [], [], []
            for q_start, q_end, group_type, segments in groups:
                if group_type != attn_type or segment_idx >= len(segments):
                    continue
                k_start, k_end = segments[segment_idx]
                if k_end <= k_start:
                    continue
                q_slices.append(q[q_start:q_end])
                k_slices.append(k[k_start:k_end])
                v_slices.append(v[k_start:k_end])
                q_lengths.append(q_end - q_start)
                k_lengths.append(k_end - k_start)
                targets.append((q_start, q_end))
            if not q_slices:
                continue
            result = flash_attn_varlen(
                torch.cat(q_slices, dim=0).contiguous(),
                torch.cat(k_slices, dim=0).contiguous(),
                torch.cat(v_slices, dim=0).contiguous(),
                _make_cu_seqlens(q_lengths, q.device),
                _make_cu_seqlens(k_lengths, q.device),
                int(max(q_lengths)),
                int(max(k_lengths)),
                dropout_p=0.0,
                softmax_scale=float(softmax_scale),
                causal=bool(attn_type == 1),
                return_attn_probs=True,
            )
            if not isinstance(result, tuple) or len(result) < 2:
                return None
            out_pass = result[0]
            lse_pass = _flash_lse_to_tq_h(result[1], out_pass.shape[0], q_lengths)
            if lse_pass is None:
                return None
            # Some rows see no key in this segment (bottom-right causal with
            # more queries than keys). FlashAttention reports a non-finite LSE
            # for them (+inf). Treat them as contributing nothing; otherwise
            # exp(inf - inf) = NaN poisons the row for every later segment.
            lse_pass = lse_pass.masked_fill(~torch.isfinite(lse_pass), neg_inf)
            cursor = 0
            for q_start, q_end in targets:
                q_len = q_end - q_start
                out_seg = out_pass[cursor:cursor + q_len].float()
                lse_seg = lse_pass[cursor:cursor + q_len]
                old_lse = merged_lse[q_start:q_end]
                new_lse = torch.maximum(old_lse, lse_seg)
                # Where both sides are still -inf, use 0 as the reference so
                # both weights evaluate to exp(-inf) = 0 instead of NaN.
                finite = torch.isfinite(new_lse)
                ref = torch.where(finite, new_lse, torch.zeros_like(new_lse))
                old_w = torch.exp(old_lse - ref)
                seg_w = torch.exp(lse_seg - ref)
                denom = (old_w + seg_w).clamp_min(1e-20)
                merged[q_start:q_end] = (
                    merged[q_start:q_end] * old_w.unsqueeze(-1)
                    + out_seg * seg_w.unsqueeze(-1)
                ) / denom.unsqueeze(-1)
                merged_lse[q_start:q_end] = (ref + torch.log(denom)).masked_fill(~finite, neg_inf)
                covered[q_start:q_end] = True
                cursor += q_len
    if not bool(covered.all().item()):
        return None
    return merged.to(dtype=q.dtype)


def _try_flash_segment_merge(
    q,
    k,
    v,
    k_ranges,
    segment_offsets,
    group_q_ranges,
    group_attn_type_map,
    softmax_scale,
):
    """Evaluate a grouped range plan with FlashAttention varlen.

    Returns a ``[total_q, num_heads, head_dim]`` output in ``q.dtype``, or
    ``None`` when any of these holds:
      * the segment path is disabled via ``LA_FLASH_SEGMENT_FASTPATH``;
      * the dtypes are not matching fp16/bf16;
      * the grouped plan is missing or unsupported;
      * some query row is covered by no group.

    When the query groups are pairwise disjoint and every causal group has a
    single key segment, each mask type is computed with one varlen call.
    Otherwise (e.g. the same query rows appear in both a causal-prefix group
    and a full-attention window group), the partial results per key segment
    are combined through their log-sum-exp.
    """
    if not _flash_segment_fastpath_enabled():
        return None
    if q.dtype not in (torch.float16, torch.bfloat16) or k.dtype != q.dtype or v.dtype != q.dtype:
        return None
    if group_q_ranges is None or segment_offsets is None or group_attn_type_map is None:
        return None
    flash_attn_varlen = _load_flash_attn_varlen()
    groups = _parse_segment_groups(
        k_ranges,
        segment_offsets,
        group_q_ranges,
        group_attn_type_map,
        int(q.shape[0]),
        int(k.shape[0]),
    )
    if not groups:
        return None
    if not _groups_cover_rows(groups, int(q.shape[0])):
        # Neither path can produce an output for rows outside every group.
        return None
    single_pass = _groups_have_disjoint_queries(groups) and all(
        attn_type == 0 or len(segments) == 1 for _, _, attn_type, segments in groups
    )
    if single_pass:
        return _flash_packed_groups(flash_attn_varlen, q, k, v, groups, softmax_scale)
    return _flash_lse_merge_segments(flash_attn_varlen, q, k, v, groups, softmax_scale)
def _max_seqlen_from_cu(cu_seqlens):
    """Return the longest sequence length described by a ``cu_seqlens`` tensor."""
    if cu_seqlens.numel() < 2:
        return 0
    return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())


def range_attention(
    q,
    k,
    v,
    q_ranges,
    k_ranges,
    attn_type_map,
    softmax_scale: float,
    *,
    segment_offsets=None,
    group_q_ranges=None,
    group_attn_type_map=None,
    max_q_len=None,
    max_k_len=None,
    flash_cu_seqlens_q=None,
    flash_cu_seqlens_k=None,
    flash_causal=None,
    disjoint_q_ranges=None,
):
    """Run sparse range attention through FlashAttention varlen.

    Args:
        q: ``[total_q, num_q_heads, head_dim]`` CUDA tensor.
        k, v: ``[total_k, num_kv_heads, head_dim]`` CUDA tensors.
        q_ranges, k_ranges, attn_type_map: flat range plan with one query span,
            key span and mask type (0 = full, 1 = bottom-right causal) per
            entry. The plan is grouped on the fly unless the grouped plan
            below is supplied.
        softmax_scale: scale applied to the attention logits.
        segment_offsets, group_q_ranges, group_attn_type_map: optional
            pre-grouped plan (same format as ``_coalesce_query_groups``).
        max_q_len, max_k_len: optional max sequence lengths for the caller-packed
            fast path. When omitted, they are derived from the cu_seqlens that
            are actually passed to FlashAttention.
        flash_cu_seqlens_q, flash_cu_seqlens_k, flash_causal: when all three are
            given and q/k/v share an fp16/bf16 dtype, q/k/v go straight to a
            single ``flash_attn_varlen_func`` call with these boundaries.
        disjoint_q_ranges: accepted for API compatibility; ignored.

    Returns:
        The attention output, one row per query token.

    Raises:
        RuntimeError: for non-CUDA inputs, an attn_type other than 0/1, or a
            plan that cannot be expressed with FlashAttention varlen.
    """
    del disjoint_q_ranges  # accepted for API compatibility only
    if not q.is_cuda:
        raise RuntimeError("LA Flash range_attention requires CUDA tensors")
    if segment_offsets is None or group_q_ranges is None or group_attn_type_map is None:
        # Group the flat plan (this also rejects attn_types other than 0/1).
        (
            group_q_ranges,
            k_ranges,
            segment_offsets,
            group_attn_type_map,
            _,
            _,
        ) = _coalesce_query_groups(q_ranges, k_ranges, attn_type_map)

    use_fast_path = (
        flash_cu_seqlens_q is not None
        and flash_cu_seqlens_k is not None
        and flash_causal is not None
        and _flash_fastpath_enabled()
        and q.dtype in (torch.float16, torch.bfloat16)
        and k.dtype == q.dtype
        and v.dtype == q.dtype
    )
    if use_fast_path:
        flash_attn_varlen = _load_flash_attn_varlen()
        cu_seqlens_q = flash_cu_seqlens_q.contiguous().to(device=q.device, dtype=torch.int32)
        cu_seqlens_k = flash_cu_seqlens_k.contiguous().to(device=q.device, dtype=torch.int32)
        # The max lengths must describe the sequences in these cu_seqlens.
        # Values taken from the range plan can be smaller (e.g. packed key
        # segments), and an underestimated max_seqlen_q makes FlashAttention
        # skip query rows.
        if max_q_len is None:
            max_q_len = _max_seqlen_from_cu(cu_seqlens_q)
        if max_k_len is None:
            max_k_len = _max_seqlen_from_cu(cu_seqlens_k)
        out = flash_attn_varlen(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            cu_seqlens_q,
            cu_seqlens_k,
            int(max_q_len),
            int(max_k_len),
            dropout_p=0.0,
            softmax_scale=float(softmax_scale),
            causal=bool(flash_causal),
        )
        # Some flash_attn builds return a tuple; unwrap it so the return type
        # matches the segment path.
        if isinstance(out, tuple):
            out = out[0]
        return out

    segment_out = _try_flash_segment_merge(
        q,
        k,
        v,
        k_ranges,
        segment_offsets,
        group_q_ranges,
        group_attn_type_map,
        softmax_scale,
    )
    if segment_out is not None:
        return segment_out
    raise RuntimeError(
        "LA Flash could not express this range plan with FlashAttention varlen. "
        "Only attn_type 0/1 range plans are supported in the release path."
    )