File size: 4,037 Bytes
7c2c068
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import triton
import triton.language as tl
import math
import pdb

@triton.jit
def _fwd_kernel(
    Q, K, V, Out, Lse,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_ob, stride_oh, stride_om,
    nheads,
    seqlen_q, seqlen_k, seqlen_q_rounded,
    headdim,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Attention forward pass using a streaming (online) softmax.

    Program axis 0 selects a tile of BLOCK_M query rows; program axis 1 is a
    flattened (batch, head) index that is split using ``nheads``.  The kernel
    walks over the keys/values in tiles of BLOCK_N, keeping a running row
    maximum ``m_i``, a running softmax denominator ``l_i`` and an
    un-normalised output accumulator ``acc``; the accumulator is divided by
    ``l_i`` once at the end and written to ``Out``.

    Q, K, V and Out are addressed with explicit batch / head / sequence
    strides; the head dimension is indexed with an implicit stride of 1, so
    callers must pass tensors whose innermost dimension is contiguous.

    When IS_CAUSAL is set, query row ``i`` only attends to key columns
    ``j <= i`` (indices counted from the start of both sequences).

    ``Lse`` and ``seqlen_q_rounded`` are accepted but not read or written by
    this kernel.
    """
    start_m = tl.program_id(0)
    # Promote to int64 so the batch/head offsets below cannot overflow int32.
    off_hb = tl.program_id(1).to(tl.int64)
    off_b = off_hb // nheads
    off_h = off_hb % nheads

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # BLOCK_M may overshoot seqlen_q and BLOCK_HEADDIM may overshoot headdim,
    # so every load/store must be masked along both axes.
    row_mask = offs_m < seqlen_q
    d_mask = offs_d < headdim

    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    # K is addressed transposed (headdim x BLOCK_N) so tl.dot(q, k) yields Q @ K^T.
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[None, :] * stride_kn + offs_d[:, None])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])

    # l_i starts at 1 so that a row which never sees a key block (seqlen_k == 0)
    # divides by 1 instead of 0; once a block is processed the initial value is
    # wiped out because alpha = exp(-inf - m_new) = 0.
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    q = tl.load(q_ptrs, mask=row_mask[:, None] & d_mask[None, :], other=0.0)

    # In the causal case, key tiles that lie entirely past the last query row
    # of this program are fully masked and can be skipped.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        col_mask = (start_n + offs_n) < seqlen_k

        k = tl.load(k_ptrs, mask=d_mask[:, None] & col_mask[None, :], other=0.0)
        # Scale first, mask second: multiplying -inf by a zero or negative
        # scale would otherwise produce NaN / +inf in the padded columns.
        qk = tl.dot(q, k).to(tl.float32) * softmax_scale
        qk += tl.where(col_mask[None, :], 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # Online-softmax update: rebase everything onto the new running max.
        m_i_new = tl.maximum(tl.max(qk, 1), m_i)
        p = tl.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)

        alpha = tl.exp(m_i - m_i_new)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]

        # other=0.0 keeps padded rows finite; p is 0 there, but 0 * garbage
        # could still be NaN if the load were left undefined.
        v = tl.load(v_ptrs, mask=col_mask[:, None] & d_mask[None, :], other=0.0)
        # Match V's element type so tl.dot works for fp16, bf16 and fp32 alike.
        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_i_new

        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vn

    # Normalise once at the end instead of on every iteration.
    acc = acc / l_i[:, None]

    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    tl.store(out_ptrs, acc.to(Out.type.element_ty), mask=row_mask[:, None] & d_mask[None, :])


def flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
    """Launch the Triton attention forward kernel and return the output.

    Args:
        q: query tensor of shape (batch, nheads, seqlen_q, headdim).
        k: key tensor of shape (batch, nheads, seqlen_k, headdim).
        v: value tensor with the same shape as ``k``.
        causal: if true, query position ``i`` only attends to key positions
            ``j <= i``.
        softmax_scale: multiplier applied to ``q @ k^T`` before the softmax.
            ``None`` selects ``1 / sqrt(headdim)``; an explicit ``0.0`` is
            honoured rather than replaced by the default.

    Returns:
        A tensor shaped like ``q`` holding the attention output.  If ``q``
        has no elements an (empty) tensor is returned without launching the
        kernel.

    Raises:
        ValueError: if the inputs are not 4-D, their shapes are inconsistent,
            or their dtypes differ.
    """
    if q.dim() != 4 or k.dim() != 4 or v.dim() != 4:
        raise ValueError("q, k and v must be 4-D (batch, nheads, seqlen, headdim)")
    batch, nheads, seqlen_q, d = q.shape
    seqlen_k = k.shape[2]
    if k.shape != v.shape:
        raise ValueError(
            f"k and v must have the same shape, got {tuple(k.shape)} and {tuple(v.shape)}"
        )
    if k.shape[0] != batch or k.shape[1] != nheads or k.shape[3] != d:
        raise ValueError(
            f"q {tuple(q.shape)} and k {tuple(k.shape)} must agree on batch, nheads and headdim"
        )
    if not (q.dtype == k.dtype == v.dtype):
        raise ValueError(
            f"q, k and v must share a dtype, got {q.dtype}, {k.dtype} and {v.dtype}"
        )

    # Nothing to compute; also avoids a zero-sized launch grid and the
    # division by zero in the default scale when headdim == 0.
    if q.numel() == 0:
        return torch.empty_like(q)

    # `or` would silently discard an explicit 0.0, so test for None instead.
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)

    # The kernel indexes the head dimension with an implicit stride of 1.
    q, k, v = (t if t.stride(-1) == 1 else t.contiguous() for t in (q, k, v))

    BLOCK = 128
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    num_warps = 4 if d <= 64 else 8

    # Lse is padded to a whole number of query tiles.  The kernel currently
    # does not write to it, so its contents are left uninitialised.
    seqlen_q_rounded = triton.cdiv(seqlen_q, BLOCK) * BLOCK
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    # One program per (query tile, batch * head) pair.
    grid = (triton.cdiv(seqlen_q, BLOCK), batch * nheads)
    _fwd_kernel[grid](
        q, k, v, o, lse,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        o.stride(0), o.stride(1), o.stride(2),
        nheads, seqlen_q, seqlen_k,
        seqlen_q_rounded, d,
        causal,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o