# flash-mla / tests/ref.py
# (Uploaded by medmekk via huggingface_hub, commit ccef021.)
from typing import Optional, Tuple
import torch
from lib import TestParam, Testcase, TestcaseForDecode, KVScope
def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor:
if lse1 is None:
return lse0
else:
return torch.logsumexp(
torch.stack([
lse0.view(s_q, h_q),
lse1.broadcast_to(s_q, h_q)
], dim=0),
dim=0
)
def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reference sparse-attention forward pass in plain PyTorch.

    Each query token attends to the ``p.topk`` rows of ``t.kv`` listed in
    ``t.indices``. A slot is ignored when its index is negative, is
    ``>= p.s_kv``, or (when ``t.topk_length`` is given) sits at a position
    ``>= t.topk_length`` for that query. When ``t.attn_sink`` is given, it is
    merged into the softmax normalizer used for ``o`` but NOT into the
    returned ``lse``.

    Returns:
        - o: [s_q, h_q, dv], bfloat16 copy of o_fp32
        - o_fp32: [s_q, h_q, dv], float32; 0 for query heads with no valid slot
        - max_logits: [s_q, h_q]; -inf for query heads with no valid slot
        - lse: [s_q, h_q], log-sum-exp of the scaled logits; +inf for query
          heads with no valid slot
    """
    neg_inf = float("-inf")
    pos_inf = float("+inf")

    # Work on a copy: the indices are overwritten below.
    idx = t.indices.clone().squeeze(1)  # [s_q, topk]
    if t.topk_length is not None:
        slot = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk)
        idx[slot >= t.topk_length.unsqueeze(1)] = -1
    invalid = (idx < 0) | (idx >= p.s_kv)  # [s_q, topk]
    # Point invalid slots at row 0 so index_select stays in range; their
    # logits are forced to -inf below, so the gathered row never contributes.
    idx[invalid] = 0

    kv_rows = t.kv.index_select(dim=0, index=idx.flatten())
    kv_rows = kv_rows.reshape(p.s_q, p.topk, p.d_qk).float()  # [s_q, topk, d_qk]

    logits = t.q.float() @ kv_rows.transpose(1, 2)  # [s_q, h_q, topk]
    logits *= t.sm_scale
    logits[invalid.unsqueeze(1).broadcast_to(logits.shape)] = neg_inf

    lse = torch.logsumexp(logits, dim=-1)  # [s_q, h_q]
    max_logits = logits.max(dim=-1).values  # [s_q, h_q]

    norm = _merge_two_lse(lse, t.attn_sink, p.s_q, p.h_q)
    if not torch.is_inference_mode_enabled():
        # Without a sink, `norm` is `lse` itself; outside inference mode,
        # take a private copy before the in-place edit below.
        norm = norm.clone()
    # A +inf normalizer makes exp(logits - norm) == 0, so query heads with
    # nothing to attend to produce an all-zero output.
    norm[norm == neg_inf] = pos_inf
    probs = torch.exp(logits - norm.unsqueeze(-1))  # [s_q, h_q, topk]
    out = probs @ kv_rows[..., :p.d_v]  # [s_q, h_q, dv]

    # Query heads with no valid slot report +inf as their lse.
    lse[lse == neg_inf] = pos_inf
    return out.to(torch.bfloat16), out, max_logits, lse
def ref_sparse_attn_decode(
    p: TestParam,
    t: TestcaseForDecode
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation of sparse decoding attention in PyTorch.

    Each query token attends to the KV-cache rows selected by
    ``t.kv_scope.indices_in_kvcache`` and, when ``t.extra_kv_scope`` is given,
    additionally to the rows selected by that scope. A slot is ignored when
    its index is negative or, if the scope has a ``topk_length``, when its
    position is ``>= topk_length``. When ``t.attn_sink`` is given, the output
    is rescaled as if exp(attn_sink) were added to the softmax denominator;
    the returned lse does NOT include the sink.

    Returns:
        - output: [b, s_q, h_q, d_v], bfloat16; all zeros for query heads
          with no attendable key.
        - lse: [b, h_q, s_q], float32 log-sum-exp of the scaled logits;
          +inf for query heads with no attendable key.
    """
    assert p.h_kv == 1
    assert p.decode is not None
    b = p.decode.b

    def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gather the KV rows selected by one scope and build its invalid-slot mask.

        Returns:
            - gathered_kv: [b, s_q, topk, d_qk]
            - invalid_mask: same shape as the scope's indices, True where the
              slot must be ignored
        """
        assert kv_scope.indices_in_kvcache is not None
        indices = kv_scope.indices_in_kvcache
        topk = indices.size(-1)
        # Negative indices are placeholders: clamp them so index_select does
        # not complain, then mask every one of them out below (not just -1,
        # otherwise e.g. -2 would silently attend to row 0).
        indices_fixed = torch.clamp_min(indices, 0)
        gathered_kv = (
            kv_scope.blocked_k.view(-1, p.d_qk)
            .index_select(0, indices_fixed.view(-1))
            .view(b, p.s_q, topk, p.d_qk)
        )  # [b, s_q, topk, d_qk]
        invalid_mask = indices < 0
        if kv_scope.topk_length is not None:
            # Build the position ramp on topk_length's device so the
            # comparison never mixes a CPU tensor with an accelerator tensor.
            positions = torch.arange(0, topk, device=kv_scope.topk_length.device).view(1, 1, topk)
            invalid_mask |= positions.broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1)
        return gathered_kv, invalid_mask

    gathered_kv, invalid_mask = process_kv_scope(t.kv_scope)
    if t.extra_kv_scope is not None:
        gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope)
        gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2)  # [b, s_q, topk+extra_topk, d_qk]
        invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2)  # [b, s_q, topk+extra_topk]

    gathered_kv = gathered_kv.view(b * p.s_q, -1, p.d_qk).float()
    # Zero out NaNs in the gathered rows (presumably unused cache entries):
    # a masked slot gets weight 0, but 0 * NaN would still poison the output.
    gathered_kv[torch.isnan(gathered_kv)] = 0.0

    q = t.q.float().view(b * p.s_q, p.h_q, p.d_qk)
    attn_weight = q @ gathered_kv.transpose(-1, -2)  # [b*s_q, h_q, topk+extra_topk]
    attn_weight *= t.sm_scale
    attn_weight[invalid_mask.view(b * p.s_q, 1, -1).broadcast_to(b * p.s_q, p.h_q, invalid_mask.size(-1))] = float("-inf")
    lse = attn_weight.logsumexp(dim=-1)  # [b*s_q, h_q]
    attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1))
    output = attn_weight @ gathered_kv[..., :p.d_v]  # [b*s_q, h_q, d_v]
    output = output.view(b, p.s_q, p.h_q, p.d_v)
    lse = lse.view(b, p.s_q, p.h_q)

    # Attention sink: scale by exp(lse) / (exp(lse) + exp(sink)).
    if t.attn_sink is not None:
        output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse))).unsqueeze(-1)

    # Query heads with no attendable key have lse == -inf, which made the
    # softmax above NaN; report them as zero output and +inf lse instead.
    lonely_q_mask = (lse == float("-inf"))
    output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0
    lse[lonely_q_mask] = float("+inf")
    return output.to(torch.bfloat16), lse.transpose(1, 2)