|
|
| import math |
| import torch |
| import triton |
| import triton.language as tl |
|
|
@triton.heuristics(
    {
        "EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_prep_kv_kernel(
    K,
    V,
    PARAM_MU,
    PARAM_PHI,
    ChunkMask,
    Out_RFA_K,
    Out_RFA_V,
    softmax_scale,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_mu_h,
    stride_phi_h,
    stride_mb, stride_mn,
    stride_ok_b, stride_ok_h, stride_ok_c,
    stride_ov_b, stride_ov_h, stride_ov_c,
    nheads,
    seqlen,
    nchunks,
    headdim,
    CACHE_KEY_SEQLEN,
    CACHE_KEY_NCHUNKS,
    CHUNKS_PER_BLOCK: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Pool K and V into one vector per chunk using softmax weights.

    Program (i, j) handles sequence positions
    [i * BLOCK_N, (i + 1) * BLOCK_N) of the (batch, head) pair j, viewed as
    CHUNKS_PER_BLOCK chunks of CHUNK_SIZE positions each.

    For every chunk:
      * Out_RFA_K = sum_m softmax_m(k_m . mu) * k_m
      * Out_RFA_V = sum_m softmax_m((k_m . phi - 0.5 * |k_m|^2) * softmax_scale) * v_m
    where the softmax runs over the positions m inside the chunk. When
    MASK_TYPE == 1 the ChunkMask value of each position is added to both
    score tensors before the softmax.

    The head-dim offset is added to every pointer without a stride, so the
    last dimension of K, V, PARAM_MU, PARAM_PHI and both outputs must have
    unit stride. CACHE_KEY_SEQLEN / CACHE_KEY_NCHUNKS are not read in the
    body.
    """
    start_n = tl.program_id(0)
    offs_bh = tl.program_id(1)
    offs_h = offs_bh % nheads
    offs_b = offs_bh // nheads

    offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
    offs_m = tl.arange(0, CHUNK_SIZE)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Absolute sequence position of every (chunk, position-in-chunk) slot,
    # in the 3-D layout used for K/V tiles and the 2-D layout used for scores.
    offs_n_3d = (
        start_n * BLOCK_N
        + offs_c[:, None, None] * CHUNK_SIZE
        + offs_m[None, :, None]
    )
    offs_n_2d = (
        start_n * BLOCK_N
        + offs_c[:, None] * CHUNK_SIZE
        + offs_m[None, :]
    )
    offs_d_3d = offs_d[None, None, :]

    k_ptrs = (
        K
        + offs_b * stride_kb
        + offs_h * stride_kh
        + (offs_n_3d * stride_kn + offs_d_3d)
    )
    v_ptrs = (
        V
        + offs_b * stride_vb
        + offs_h * stride_vh
        + (offs_n_3d * stride_vn + offs_d_3d)
    )
    param_mu_ptrs = PARAM_MU + offs_h * stride_mu_h + offs_d_3d
    param_phi_ptrs = PARAM_PHI + offs_h * stride_phi_h + offs_d_3d

    # Scores are scaled by log2(e) so that exp2 can be used in place of exp.
    log2e = 1.4426950408889634

    if MASK_TYPE == 1:
        m_ptrs = ChunkMask + offs_b * stride_mb + (offs_n_2d * stride_mn)

    # ---- load the K tile: [CHUNKS_PER_BLOCK, CHUNK_SIZE, BLOCK_HEADDIM] ----
    if EVEN_N:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d_3d < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n_3d < seqlen, other=0.0)
        else:
            k = tl.load(
                k_ptrs,
                mask=(offs_n_3d < seqlen) & (offs_d_3d < headdim),
                other=0.0,
            )

    # ---- chunk-level K summary ----
    # Only `headdim` entries per head are valid; lanes beyond that must not be
    # read (out-of-bounds for the last head) and must contribute exactly zero.
    if EVEN_HEADDIM:
        param_mu = tl.load(param_mu_ptrs).to(k.dtype)
    else:
        param_mu = tl.load(
            param_mu_ptrs, mask=offs_d_3d < headdim, other=0.0
        ).to(k.dtype)

    rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
    rfa_k_c_w *= log2e
    if MASK_TYPE == 1:
        if EVEN_N:
            chunk_mask = tl.load(m_ptrs).to(tl.float32)
        else:
            chunk_mask = tl.load(
                m_ptrs,
                mask=offs_n_2d < seqlen,
                other=0.0,
            ).to(tl.float32)
        rfa_k_c_w = rfa_k_c_w + chunk_mask

    # Numerically stable softmax over the positions of each chunk.
    rfa_k_c_w = tl.exp2(rfa_k_c_w - tl.max(rfa_k_c_w, axis=-1)[:, None])
    rfa_k_c_w = rfa_k_c_w / tl.sum(rfa_k_c_w, axis=-1)[:, None]
    rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)

    offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
    out_rfa_k_ptrs = (
        Out_RFA_K
        + offs_b * stride_ok_b
        + offs_h * stride_ok_h
        + (offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
    )

    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(out_rfa_k_ptrs, rfa_k_c)
        else:
            tl.store(out_rfa_k_ptrs, rfa_k_c, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_rfa_k_ptrs, rfa_k_c, mask=offs_out_c[:, None] < nchunks)
        else:
            tl.store(
                out_rfa_k_ptrs,
                rfa_k_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
            )

    # ---- chunk-level V summary (weights are still derived from K) ----
    if EVEN_HEADDIM:
        param_phi = tl.load(param_phi_ptrs).to(k.dtype)
    else:
        param_phi = tl.load(
            param_phi_ptrs, mask=offs_d_3d < headdim, other=0.0
        ).to(k.dtype)

    rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
    rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
    rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
    rfa_v_c_w *= log2e * softmax_scale
    if not EVEN_N:
        # Positions past the end of the sequence get -inf so they receive
        # zero softmax weight.
        rfa_v_c_w += tl.where(
            offs_n_2d < seqlen,
            0,
            float("-inf"),
        )

    if MASK_TYPE == 1:
        rfa_v_c_w = rfa_v_c_w + chunk_mask

    if EVEN_N:
        if EVEN_HEADDIM:
            v = tl.load(v_ptrs)
        else:
            v = tl.load(v_ptrs, mask=offs_d_3d < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            v = tl.load(v_ptrs, mask=offs_n_3d < seqlen, other=0.0)
        else:
            v = tl.load(
                v_ptrs,
                mask=(offs_n_3d < seqlen) & (offs_d_3d < headdim),
                other=0.0,
            )

    rfa_v_c_w = tl.exp2(rfa_v_c_w - tl.max(rfa_v_c_w, axis=-1)[:, None])
    rfa_v_c_w = rfa_v_c_w / tl.sum(rfa_v_c_w, axis=-1)[:, None]
    rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)

    out_rfa_v_ptrs = (
        Out_RFA_V
        + offs_b * stride_ov_b
        + offs_h * stride_ov_h
        + (offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
    )
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(out_rfa_v_ptrs, rfa_v_c)
        else:
            tl.store(out_rfa_v_ptrs, rfa_v_c, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_rfa_v_ptrs, rfa_v_c, mask=offs_out_c[:, None] < nchunks)
        else:
            tl.store(
                out_rfa_v_ptrs,
                rfa_v_c,
                mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
            )
|
|
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, chunk_mask, softmax_scale, chunksize):
    """Launch the EVA key/value chunk-pooling kernel.

    Args:
        k: CUDA tensor of shape (batch, nheads, seqlen, head_dim), bf16 or fp32.
        v: CUDA tensor with the same shape and dtype as ``k``.
        param_mu: tensor of shape (1, nheads, 1, 1, head_dim), same dtype as ``k``.
        param_phi: tensor of shape (1, nheads, 1, 1, head_dim), same dtype as ``k``.
        chunk_mask: ``None`` or a CUDA tensor of shape (batch, 1, seqlen, 1)
            with the same dtype as ``k``; its values are added to the
            per-position scores inside the kernel.
        softmax_scale: scale applied to the value-side scores; a falsy value
            selects ``1 / sqrt(head_dim)``.
        chunksize: number of positions pooled into one output row. Must
            divide both ``seqlen`` and the kernel block size (128).

    Returns:
        ``(out_rfa_k, out_rfa_v)``, each of shape
        (batch, nheads, seqlen // chunksize, head_dim).

    Raises:
        AssertionError: if any of the shape, dtype, device or chunk-size
            requirements above is violated.
    """
    # The kernel indexes the last dimension with an implicit stride of 1.
    k, v, param_mu, param_phi = [
        x if x.stride(-1) == 1 else x.contiguous()
        for x in [k, v, param_mu, param_phi]
    ]

    batch, nheads, seqlen, head_dim = k.shape
    assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
    nchunks = seqlen // chunksize

    # Each program processes BLOCK positions split into whole chunks, so
    # chunksize has to be a positive divisor of BLOCK. Spelled out explicitly
    # because `a & b == 0` parses as `(a & b) == 0`, which lets invalid sizes
    # (e.g. 3, or anything larger than BLOCK) slip through.
    BLOCK = 128
    assert 0 < chunksize <= BLOCK and BLOCK % chunksize == 0, "BLOCK must be divisible by chunksize"
    chunks_per_block = BLOCK // chunksize

    assert v.shape == (batch, nheads, seqlen, head_dim)
    assert param_mu.shape == (1, nheads, 1, 1, head_dim)
    assert param_phi.shape == (1, nheads, 1, 1, head_dim)
    assert head_dim <= 128, "We only test head dimensions up to 128"
    assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same type"
    assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if chunk_mask is not None:
        mask_type = 1
        assert chunk_mask.dtype == k.dtype
        assert chunk_mask.is_cuda
        assert chunk_mask.dim() == 4
        assert chunk_mask.shape == (batch, 1, seqlen, 1)
        if chunk_mask.stride(-1) != 1:
            chunk_mask = chunk_mask.contiguous()
    # Only the batch and sequence strides of the mask are used by the kernel.
    mask_strides = (
        (chunk_mask.stride(0), chunk_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
    out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    num_warps = 4 if head_dim <= 64 else 8

    # One program per (BLOCK-sized sequence tile, batch * head) pair.
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
    _fwd_eva_prep_kv_kernel[grid](
        k,
        v,
        param_mu,
        param_phi,
        chunk_mask,
        out_rfa_k,
        out_rfa_v,
        softmax_scale,
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        param_mu.stride(1),
        param_phi.stride(1),
        mask_strides[0], mask_strides[1],
        out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
        out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
        nheads,
        seqlen,
        nchunks,
        head_dim,
        # CACHE_KEY_SEQLEN / CACHE_KEY_NCHUNKS: coarse buckets of the sizes;
        # presumably intended as compilation-cache keys -- TODO confirm.
        seqlen // 32,
        nchunks // 32,
        chunks_per_block,
        chunksize,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return out_rfa_k, out_rfa_v
|
|