File size: 4,556 Bytes
ccef021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from typing import Optional, Tuple

import torch

from lib import TestParam, Testcase, TestcaseForDecode, KVScope

def _merge_two_lse(lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int) -> torch.Tensor:
    """
    Fuse two log-sum-exp terms into one, i.e. log(exp(lse0) + exp(lse1)).

    If `lse1` is None there is nothing to merge and `lse0` itself is returned
    (the same tensor object, not a copy). Otherwise `lse0` is viewed as
    [s_q, h_q], `lse1` is broadcast to [s_q, h_q], and a new [s_q, h_q]
    tensor is returned.
    """
    if lse1 is None:
        return lse0

    primary = lse0.view(s_q, h_q)
    secondary = lse1.broadcast_to(s_q, h_q)
    # Put both terms on a fresh leading axis and reduce over that axis
    paired = torch.stack((primary, secondary), dim=0)   # [2, s_q, h_q]
    return torch.logsumexp(paired, dim=0)
        
def ref_sparse_attn_fwd(p: TestParam, t: Testcase) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Reference sparse attention forward pass written in plain PyTorch.

    Each query row attends only to the KV rows listed in its row of
    `t.indices` (presumably shaped [s_q, 1, topk]; dim 1 is squeezed away).
    A slot is ignored when its index is negative, is >= p.s_kv, or - if
    `t.topk_length` is given - its position is at or beyond that row's length.
    `t.attn_sink`, when present, is folded into the softmax denominator used
    for the output, but not into the returned lse.

    Returns:
    - o: [s_q, h_q, dv], `o_fp32` converted to bfloat16
    - o_fp32: [s_q, h_q, dv]
    - max_logits: [s_q, h_q] (-inf where a query has no valid slot)
    - lse: [s_q, h_q], computed without the sink (+inf where a query has no valid slot)
    """
    neg_inf = float("-inf")
    pos_inf = float("+inf")

    # Private copy: entries are rewritten below and t.indices must stay intact
    sel = t.indices.clone().squeeze(1)
    if t.topk_length is not None:
        slot_ids = torch.arange(p.topk, device=t.topk_length.device).unsqueeze(0).broadcast_to(p.s_q, p.topk)
        beyond_length = slot_ids >= t.topk_length.unsqueeze(1)   # [s_q, topk]
        sel.masked_fill_(beyond_length, -1)
    dead_slots = (sel < 0) | (sel >= p.s_kv)    # [s_q, topk]
    # Point dead slots at row 0 so the gather below stays in range; their
    # logits are forced to -inf afterwards
    sel.masked_fill_(dead_slots, 0)

    q_fp32 = t.q.float()
    kv_rows = t.kv.index_select(dim=0, index=sel.flatten())
    kv_fp32 = kv_rows.reshape(p.s_q, p.topk, p.d_qk).float()   # [s_q, topk, d_qk]

    logits = q_fp32 @ kv_fp32.transpose(1, 2)   # [s_q, h_q, topk]
    logits *= t.sm_scale
    logits.masked_fill_(dead_slots.unsqueeze(1), neg_inf)

    row_lse = torch.logsumexp(logits, dim=-1)   # [s_q, h_q]
    max_logits = logits.max(dim=-1).values   # [s_q, h_q]

    # Denominator for the output; _merge_two_lse may hand back row_lse itself
    # when there is no sink
    denom_lse = _merge_two_lse(row_lse, t.attn_sink, p.s_q, p.h_q)
    if not torch.is_inference_mode_enabled():
        denom_lse = denom_lse.clone()
    # A -inf denominator becomes +inf so that exp(logits - denom) is 0 instead of NaN
    denom_lse.masked_fill_(denom_lse == neg_inf, pos_inf)
    probs = torch.exp(logits - denom_lse.unsqueeze(-1))
    out_fp32 = probs @ kv_fp32[..., :p.d_v]   # [s_q, h_q, dv]

    # Queries with no valid slot report an lse of +inf
    row_lse.masked_fill_(row_lse == neg_inf, pos_inf)
    return (out_fp32.to(torch.bfloat16), out_fp32, max_logits, row_lse)


def ref_sparse_attn_decode(
    p: TestParam,
    t: TestcaseForDecode
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    A reference implementation of sparse decoding attention in PyTorch

    Each query token attends to the KV-cache rows selected by
    `t.kv_scope.indices_in_kvcache` and, if `t.extra_kv_scope` is given, to
    the rows selected by that second scope as well; the two selections are
    concatenated along the top-k axis. A slot is ignored when its index is -1
    or when its position is at or beyond the scope's `topk_length` (if set).

    Only p.h_kv == 1 is supported, and p.decode must be set.

    Returns:
    - output: [b, s_q, h_q, d_v] in bfloat16; all zeros for a query with no
      attendable slot
    - lse: [b, h_q, s_q] (transposed view), computed without the attention
      sink; +inf for a query with no attendable slot
    """
    assert p.h_kv == 1
    assert p.decode is not None
    b = p.decode.b

    def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gather the KV rows one scope selects and flag the slots to ignore.

        Returns (gathered_kv [b, s_q, topk, d_qk], invalid_mask [b, s_q, topk]).
        """
        assert kv_scope.indices_in_kvcache is not None
        topk = kv_scope.indices_in_kvcache.size(-1)
        indices_in_kv_cache_fixed = torch.clamp_min(kv_scope.indices_in_kvcache, 0) # Otherwise torch.index_select will complain
        gathered_kv = kv_scope.blocked_k.view(-1, p.d_qk).index_select(0, indices_in_kv_cache_fixed.view(-1)).view(b, p.s_q, topk, p.d_qk)  # [b, s_q, topk, d]
        invalid_mask = kv_scope.indices_in_kvcache == -1
        if kv_scope.topk_length is not None:
            # Build the slot positions on topk_length's device so the
            # comparison does not depend on the process-wide default device
            slot_ids = torch.arange(topk, device=kv_scope.topk_length.device).view(1, 1, topk)
            invalid_mask |= slot_ids.broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1)
        return gathered_kv, invalid_mask

    gathered_kv, invalid_mask = process_kv_scope(t.kv_scope)
    if t.extra_kv_scope is not None:
        gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope)
        gathered_kv = torch.cat([gathered_kv, gathered_kv1], dim=2)  # [b, s_q, topk+extra_topk, d]
        invalid_mask = torch.cat([invalid_mask, invalid_mask1], dim=2)   # [b, s_q, topk+extra_topk]

    gathered_kv = gathered_kv.view(b*p.s_q, -1, p.d_qk).float()
    # Zero out NaNs in the gathered rows: a masked slot gets probability 0,
    # and 0 * NaN would otherwise turn the output into NaN
    gathered_kv[torch.isnan(gathered_kv)] = 0.0
    q = t.q.float().view(b*p.s_q, p.h_q, p.d_qk)
    attn_weight = q @ gathered_kv.transpose(-1, -2)  # [b*s_q, h_q, topk+extra_topk]
    attn_weight *= t.sm_scale
    attn_weight[invalid_mask.view(b*p.s_q, 1, -1).broadcast_to(b*p.s_q, p.h_q, invalid_mask.size(-1))] = float("-inf")
    lse = attn_weight.logsumexp(dim=-1)  # [b*s_q, h_q]
    attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1))
    output = attn_weight @ gathered_kv[..., :p.d_v]    # [b*s_q, h_q, d_v]
    output = output.view(b, p.s_q, p.h_q, p.d_v)
    lse = lse.view(b, p.s_q, p.h_q)

    # Attention sink: rescale by exp(lse) / (exp(lse) + exp(sink)), written
    # as 1 / (1 + exp(sink - lse)). The returned lse is left unchanged.
    if t.attn_sink is not None:
        output *= (1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse))).unsqueeze(-1)

    # Correct for q tokens which have no attendable k (their rows are NaN at
    # this point because lse is -inf)
    lonely_q_mask = (lse == float("-inf"))
    output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0
    lse[lonely_q_mask] = float("+inf")

    return output.to(torch.bfloat16), lse.transpose(1, 2)