diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_A_bwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_A_bwd.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e2fc6773053cb204df033bd9c19a51080f6fb69
--- /dev/null
+++ b/fla3/ops/generalized_delta_rule/dplr/chunk_A_bwd.py
@@ -0,0 +1,365 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from ....ops.utils import prepare_chunk_indices
+from ....ops.utils.op import exp, gather
+from ....utils import check_shared_mem, is_gather_supported, use_cuda_graph
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({}, 
num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BK', 'BT', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_intra( + q, + k, + a, + b, + gi, + ge, + dAqk, + dAqb, + dAak, + dAab, + dq, + dk, + da, + db, + dqg, + dkg, + dag, + dbg, + dgk, + dgk_offset, + cu_seqlens, + chunk_indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32) + + if i_t * BT >= T: + return + + # offset calculation + ge += (bos*H + i_h) * K + gi += (bos*H + i_h) * K + q += (bos*H + i_h) * K + a += (bos*H + i_h) * K + b += (bos*H + i_h) * K + k += (bos*H + i_h) * K + dq += (bos*H + i_h) * K + dk += (bos*H + i_h) * K + da += (bos*H + i_h) * K + db += (bos*H + i_h) * K + dqg += (bos*H + i_h) * K + dag += (bos*H + i_h) * K + dkg += (bos*H + i_h) * K + dbg += (bos*H + i_h) * K + dgk += (bos*H + i_h) * K + dgk_offset += (bos*H + i_h) * K + dAqk += (bos*H + i_h) * BT + dAqb += (bos*H + i_h) * BT + dAak += (bos*H + i_h) * BT + dAab += (bos*H + i_h) * BT + + stride_qk = H*K + stride_A = H*BT + + p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_ge = tl.load(p_ge, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + b_da = tl.zeros([BC, BK], dtype=tl.float32) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + b_db = tl.zeros([BC, BK], dtype=tl.float32) + # intra chunk gradient calculation + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + o_i = tl.arange(0, BC) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAab = tl.load(p_dAab, boundary_check=(0, 1)) + b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)) + b_dAak = tl.load(p_dAak, boundary_check=(0, 1)) + + # inter chunk gradient calculation + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + # intra chunk gradient calculation + for j in range(0, min(BC, T - i_t * BT)): + # trick to index the block + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, 
dtype=tl.int16)
+            col_idx = tl.full([BC, 1], j, dtype=tl.int16)
+            row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
+            # [1, BK]
+            b_kj = gather(b_k, row_idx, axis=0)
+            b_bj = gather(b_b, row_idx, axis=0)
+            b_gij = gather(b_gi, row_idx, axis=0)
+            b_gej = gather(b_ge, row_idx, axis=0)
+            b_qj = gather(b_q, row_idx, axis=0)
+            b_aj = gather(b_a, row_idx, axis=0)
+            # [BC, 1]
+            b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
+            b_dAab_j = gather(b_dAab, col_idx, axis=1)
+            b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
+            b_dAak_j = gather(b_dAak, col_idx, axis=1)
+            # [1, BC] -> [BC, 1]
+            b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
+        else:
+            mask_idx = tl.arange(0, BC) == j
+            b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
+            b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
+            b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
+            b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
+            b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
+            b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
+            b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
+            b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
+            b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
+            b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
+            b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
+            b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
+            # [1, BK] b_qj, b_aj
+            b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
+            b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
+
+        m_e = o_i[:, None] > j
+        m_i = o_i[:, None] >= j
+        tmp1 = exp(b_gi - b_gij)
+        tmp2 = exp(b_ge - b_gij)
+        b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
+        b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
+        b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
+        b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)
+
+        m_i = o_i[:, None] <= j
+        m_e = o_i[:, None] < j
+        tmp1 = exp(b_gij - b_gi)
+        tmp2 = exp(b_gej - b_gi)
+        b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
+        b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
+        b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
+        b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
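+    # Everything after this loop is inter-chunk: dqg/dkg/dag/dbg carry gradients
+    # from the chunk-level recurrence and are folded in below with their cumulative
+    # gate weights; dgk then collects the per-element gate gradient
+    # q*dq + a*da - k*dk - b*db.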
+ + # post processing + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge) + b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale + tmp = exp(b_gn[None, :] - b_gi) + b_dk += tl.load(p_dkg, boundary_check=(0, 1)).to(tl.float32) * tmp + b_db += tl.load(p_dbg, boundary_check=(0, 1)).to(tl.float32) * tmp + tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + b_dgk = (b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b).to(tl.float32) + b_dgk_offset = b_da * b_a + tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in [32, 64] + ], + key=['BK', 'BT', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_dgk_kernel( + dgk, + dgk_offset, + dgk_last, + dgk_output, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = (i_b * NT + i_t).to(tl.int32) + bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32) + + stride_qk = H * K + dgk += (bos * H + i_h) * K + dgk_offset += (bos * H + i_h) * K + dgk_last += (i_tg * H + i_h) * K + dgk_output += (bos * H + i_h) * K + p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK + m_k = tl.arange(0, BK) + i_k * BK < K + b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0) + p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 
0)) + p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_dgk = tl.load(p_dgk, boundary_check=(0, 1)) + b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1)) + # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32) + # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False) + b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True) + b_dgk_cumsum += b_dgk_last[None, :] + b_dgk_cumsum -= b_dgk_offset + p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + dAqk: torch.Tensor, + dAqb: torch.Tensor, + dAak: torch.Tensor, + dAab: torch.Tensor, + dqg: torch.Tensor, + dkg: torch.Tensor, + dag: torch.Tensor, + dbg: torch.Tensor, + dgk_last: torch.Tensor, + scale: float = 1.0, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + da = torch.empty_like(a) + db = torch.empty_like(b) + dgk = torch.empty_like(gi, dtype=torch.float) + dgk_offset = torch.empty_like(gi, dtype=torch.float) + + grid = (NK, NT, B * H) + chunk_dplr_bwd_kernel_intra[grid]( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + dAqk=dAqk, + dAqb=dAqb, + dAak=dAak, + dAab=dAab, + dq=dq, + dk=dk, + dgk=dgk, + dgk_offset=dgk_offset, + dqg=dqg, + dkg=dkg, + dag=dag, + dbg=dbg, + da=da, + db=db, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BT, + BK=BK, + GATHER_SUPPORTED=is_gather_supported + ) + + dgk_output = torch.empty_like(dgk) + + def grid(meta): return (NT, triton.cdiv(K, meta['BK']), B * H) + chunk_dplr_bwd_dgk_kernel[grid]( + dgk=dgk, + dgk_offset=dgk_offset, + dgk_last=dgk_last, + dgk_output=dgk_output, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + ) + return dq, dk, da, db, dgk_output diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..86e8cec2d47980d2ff26f7e904bbe39f0697fa07 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 
3, 4] + ], + key=['BT', 'BK', 'BV', "V"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dhu( + qg, + bg, + w, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)) + + mask_k = tl.arange(0, BK) < K + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BT] + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + # [BT, BK] + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype)) + tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype)) + b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype)) + last_idx = min((i_t + 1) * BT, T) - 1 + bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k) + b_dh *= exp(bg_last)[:, None] + b_dh += b_dh_tmp + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dhu( + qg: torch.Tensor, + bg: torch.Tensor, + w: torch.Tensor, + gk: torch.Tensor, + h0: torch.Tensor, + dht: Optional[torch.Tensor], + do: torch.Tensor, + dv: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *qg.shape, do.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support 
head dimension being larger than 256." + # H100 + if check_shared_mem('hopper', qg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', qg.device.index): # A100 + BV = 32 + BC = 32 + else: # Etc: 4090 + BV = 16 + BC = 16 + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + BC = min(BT, BC) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = qg.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.zeros_like(dv) + + grid = (NK, NV, N * H) + chunk_dplr_bwd_kernel_dhu[grid]( + qg=qg, + bg=bg, + w=w, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + ) + return dh, dh0, dv2 diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..66f5e823be6c20bfe6683d489cabde3b3816be7e --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BK_LIST + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_o( + qg, + v, + v_new, + A_qk, + A_qb, + h, + o, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_qg, b_h) + + p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, 
BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) + b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1)) + b_Aqk = tl.where(m_s, b_Aqk, 0) + b_Aqb = tl.where(m_s, b_Aqb, 0) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_o( + qg: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + A_qk: torch.Tensor, + A_qb: torch.Tensor, + h: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *qg.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_fwd_kernel_o[grid]( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, + h=h, + o=o, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..49400c1f7f0f6880ef98022e01dc156c00a6d0bf --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -0,0 +1,273 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils.op import exp +from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [16, 32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_dplr_delta_rule_fwd_kernel( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_k = tl.arange(0, 
BK) + o_v = i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[None, :] & mask_v[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + tmp = tl.sum(b_h * b_a[None, :], axis=1) + b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None]) + b_o = tl.sum(b_h * b_q[None, :], axis=1) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * H*K + p_k += (-1 if REVERSE else 1) * H*K + p_a += (-1 if REVERSE else 1) * H*K + p_b += (-1 if REVERSE else 1) * H*K + p_gk += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_o += (-1 if REVERSE else 1) * H*V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_dplr_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + + h0 = initial_state + if output_final_state: + ht = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + ht = None + o = torch.empty_like(v) + + def grid(meta): return (triton.cdiv(V, meta['BV']), N * H) + fused_recurrent_dplr_delta_rule_fwd_kernel[grid]( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + REVERSE=reverse, + ) + return o, ht + + +class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + ): + o, ht = fused_recurrent_dplr_delta_rule_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + gk=gk, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + raise NotImplementedError( + 
"Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. " + "This kernel is only for inference. " + "For training, please use `chunk_dplr_delta_rule`." + ) + + +def fused_recurrent_dplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner. + + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + a (torch.Tensor): + a of shape `[B, T, H, K]`. + b (torch.Tensor): + b of shape `[B, T, H, K]`. + gk (torch.Tensor): + gk of shape `[B, T, H, K]`. decay term in log space! + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: 1. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (Optional[torch.Tensor]): + Cumulative sequence lengths of shape `[N + 1]` used for variable-length training, + consistent with the FlashAttention API. + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." 
+            )
+    if scale is None:
+        scale = q.shape[-1] ** -0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        a,
+        b,
+        gk,
+        scale,
+        initial_state,
+        output_final_state,
+        reverse,
+        cu_seqlens,
+    )
+    return o, final_state
diff --git a/fla3/ops/generalized_delta_rule/dplr/naive.py b/fla3/ops/generalized_delta_rule/dplr/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6ac253673e5361a375286347253f7d4e6f7a2f3
--- /dev/null
+++ b/fla3/ops/generalized_delta_rule/dplr/naive.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from einops import rearrange
+
+# S_t = S_{t-1} @ (diag(exp(gk_t)) + alpha_t beta_t^T) + v_t k_t^T
+# q, k, alpha, beta [B, H, L, D_K]
+# v [B, H, L, D_V]
+
+
+def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
+    orig_dtype = q.dtype
+    b, h, l, d_k = q.shape
+    q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk])
+    d_v = v.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+
+    if initial_state is not None:
+        S += initial_state
+
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i]
+        _alpha = alpha[:, :, i].clone()
+        _beta = beta[:, :, i].clone()
+        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
+        S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    S = None if output_final_state is False else S
+    return o.to(orig_dtype), S
+
+
+def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    q = q * (d_k ** -0.5)
+    assert l % chunk_size == 0
+
+    S = k.new_zeros(b, h, d_k, d_v).to(q)
+    if initial_state is not None:
+        S += initial_state
+
+    # note that the diagonal is masked.
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
+    q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), [q, k, v, alpha, beta, gk])
+
+    gk_cumsum = gk.cumsum(-2)
+
+    # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
+    A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
+    A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
+    A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
+    A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
+
+    for i in range(chunk_size):
+        alpha_i = alpha[:, :, :, i, None]
+        q_i = q[:, :, :, i, None]
+        gk_i = gk_cumsum[:, :, :, i, None]
+        mask = (torch.arange(chunk_size) <= i).to(q.device)
+        attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
+        A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
+        A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
+        mask = (torch.arange(chunk_size) < i).to(q.device)
+        # shift by one.
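+        # "shift by one": the a-terms only see strictly earlier positions j < i,
+        # and the decay between j and i excludes gk_i itself, hence the extra
+        # `- gk[:, :, :, i, None]` in the exponent below.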
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone() + A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone() + + A_ab = A_ab + for i in range(1, chunk_size): + A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2) + + A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device) + u = A_ab @ (A_ak @ v) + w = A_ab @ ((gk_cumsum-gk).exp() * alpha) + + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] + v2_i = u_i + w_i @ S + + o_1 = A_qk[:, :, i] @ v_i + o_2 = A_qb[:, :, i] @ v2_i + o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S + o[:, :, i] = o_1 + o_2 + o_3 + decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp() + S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \ + (beta_i * decay).transpose(-1, -2) @ v2_i + S = None if output_final_state is False else S + return rearrange(o, 'b h n c d -> b h (n c) d'), S diff --git a/fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..6855e7bfdac154365e2faf3a91d204caf3c6f647 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph + +# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_bwd_kernel( + A_ab_inv, + A_ak, + ag, + v, + dw, + du, + dv, + dv0, + dag, + dAak, + dAab, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_ab_inv_t = tl.load(p_Aab_inv_t, 
boundary_check=(0, 1)) + b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1)) + b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0) + b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0) + b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty) + b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v)) + b_dv0 = tl.load(p_dv0, boundary_check=(0, 1)) + b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :] + b_dA_tmp = tl.where(m_i, b_dA_tmp, 0) + b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp) + b_dA_ak = tl.where(m_i, b_dA_ak, 0) + tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1)) + b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t) + + for i_k in range(tl.cdiv(K, BK)): + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag)) + b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw) + tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1)) + + # if we know dL/dA^(-1), for dL/dA, we can use the following formula: + # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T + # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1. + # denote A = I - lower(A_ab), B = A^-1 + # in the backward pass. 
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T + # dL/dA_ab = lower(B^T @ dL/dB @ B^T) + b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0) + b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv) + b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t) + b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0) + tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1)) + + +def chunk_dplr_bwd_wy( + A_ab_inv: torch.Tensor, + A_ak: torch.Tensor, + v: torch.Tensor, + ag: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + dv0: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du]) + B, T, H, K, V = *dw.shape, du.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32) + + dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float) + dA_ak = torch.empty_like(A_ak, dtype=torch.float) + dv = torch.empty_like(v) + dag = torch.empty_like(ag) + + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + A_ab_inv=A_ab_inv, + A_ak=A_ak, + ag=ag, + v=v, + dw=dw, + du=du, + dv=dv, + dv0=dv0, + dag=dag, + dAak=dA_ak, + dAab=dA_ab, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dA_ab, dA_ak, dv, dag diff --git a/fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf14bd34ebde04d9e1a46784aa80dc6d72bd4fd --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import gather +from ....utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk32( + A_ab, + A_ab_inv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, # placeholder, do not delete + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A_ab = tl.load(p_Aab, boundary_check=(0, 1)) + b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0) + for i in range(1, BT): + 
mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i) + b_A_ab = tl.where(mask[:, None], b_a, b_A_ab) + b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk64( + A_ab, + A_ab_inv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + IS_VARLEN: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr = is_gather_supported +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + + b_A = tl.load(p_A1, boundary_check=(0, 1)) + b_A2 = tl.load(p_A2, boundary_check=(0, 1)) + b_A3 = tl.load(p_A3, boundary_check=(0, 1)) + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + if GATHER_SUPPORTED: + row_idx = tl.full([1, BC], i, dtype=tl.int16) + # [1, BK] -> [BK] + b_a = tl.sum(gather(b_A, row_idx, axis=0), 0) + b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0) + else: + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + mask = tl.arange(0, BC) == i + # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A) + # tl.debug_barrier() + tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), 
boundary_check=(0, 1)) + tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # causal mask + tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def wu_fwd_kernel( + w, + u, + ag, + v, + A_ab_inv, + A_ak, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_s = tl.arange(0, BT) + + p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1)) + b_Aak = tl.load(p_A_ak, boundary_check=(0, 1)) + b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0) + b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0) + # let's use tf32 here + b_Aak = tl.dot(b_Aab_inv, b_Aak) + # (SY 01/04) should be bf16 or tf32? To verify. 
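+    # "rtne" = round-to-nearest-even for the fp32 -> low-precision downcast, so
+    # both operands of the subsequent tl.dot are rounded consistently.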
+    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
+    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_ag = tl.load(p_ag, boundary_check=(0, 1))
+        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
+        tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def wu_fwd(
+    ag: torch.Tensor,
+    v: torch.Tensor,
+    A_ak: torch.Tensor,
+    A_ab_inv: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor],
+    chunk_size: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    B, T, H, K, V = *ag.shape, v.shape[-1]
+    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+
+    w = torch.empty_like(ag)
+    u = torch.empty_like(v)
+    wu_fwd_kernel[(NT, B * H)](
+        ag=ag,
+        v=v,
+        A_ak=A_ak,
+        A_ab_inv=A_ab_inv,
+        w=w,
+        u=u,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+    )
+    return w, u
+
+
+def prepare_wy_repr_fwd(
+    ag: torch.Tensor,
+    v: torch.Tensor,
+    A_ak: torch.Tensor,
+    A_ab: torch.Tensor,
+    cu_seqlens: Optional[torch.LongTensor],
+    chunk_size: int = 64
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    B, T, H, _ = ag.shape
+    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    BC = min(BT, 32)
+    fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32
+    A_ab_inv = torch.empty_like(A_ab)
+    fwd_fn[(NT, B * H)](
+        A_ab=A_ab,
+        A_ab_inv=A_ab_inv,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        BT=BT,
+        BC=BC,
+    )
+    w, u = wu_fwd(
+        ag=ag,
+        v=v,
+        A_ak=A_ak,
+        A_ab_inv=A_ab_inv,
+        cu_seqlens=cu_seqlens,
+        chunk_size=BT
+    )
+    return w, u, A_ab_inv
+
+
+fwd_prepare_wy_repr = prepare_wy_repr_fwd
+
+fwd_wu = wu_fwd
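Aside for readers: the chunk-32/chunk-64 kernels above invert `(I - strictly_lower(A_ab))` by forward substitution, and the chunk-64 variant stitches two sub-blocks together. A minimal PyTorch sketch of the same identity (the helper name `naive_ab_inverse` is hypothetical, not part of this patch) can serve as a reference check:

```python
import torch

def naive_ab_inverse(A_ab: torch.Tensor) -> torch.Tensor:
    """Reference for the inverse computed by prepare_wy_repr_fwd:
    returns (I - strictly_lower(A_ab))^{-1} via forward substitution."""
    BT = A_ab.shape[-1]
    L = torch.tril(A_ab, diagonal=-1)
    N = L.clone()
    # row i only needs the already-finalized rows j < i, mirroring the
    # per-row loop in prepare_wy_repr_fwd_kernel_chunk32
    for i in range(1, BT):
        N[..., i, :i] = L[..., i, :i] + (L[..., i, :, None] * N[..., :, :i]).sum(-2)
    return N + torch.eye(BT, dtype=N.dtype, device=N.device)

# sanity check against a dense inverse
A = 0.1 * torch.randn(2, 4, 16, 16, dtype=torch.float64)
I = torch.eye(16, dtype=torch.float64)
ref = torch.linalg.inv(I - torch.tril(A, diagonal=-1))
assert torch.allclose(naive_ab_inverse(A), ref, atol=1e-8)
```

The chunk-64 kernel reaches the same result blockwise: with M = I - L split into 2x2 blocks, the lower-left block of M^{-1} is (I - L22)^{-1} L21 (I - L11)^{-1}, which is exactly the `b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)` line above.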
b/fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-310.pyc differ
diff --git a/fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8069d0acbc1791cd4ee1d5301583ffd26ada38ad
Binary files /dev/null and b/fla3/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-312.pyc differ
diff --git a/fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc b/fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0ce52616b8c03acb62fbfec4a069228503b9edb
Binary files /dev/null and b/fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc differ
diff --git a/fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py b/fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8bbc526e3c8a53c4abb1dc44fafec3847f6a81
--- /dev/null
+++ b/fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py
@@ -0,0 +1,452 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ....utils import input_guard
+
+
+@triton.heuristics({
+    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+        for BV in [32, 64]
+        for num_warps in [2, 4, 8, 16]
+        for num_stages in [2, 3, 4]
+    ],
+    key=["BK"],
+)
+@triton.jit
+def fused_recurrent_fwd_kernel(
+    q,  # query [B, T, H, K]
+    k,  # key [B, T, H, K]
+    v,  # value [B, T, H, V]
+    a,  # a [B, T, H, K]
+    b,  # b [B, T, H, K]
+    o,  # output [B, T, H, V]
+    ha,  # tmp variable [B, T, H, V] for storing intermediate results of (h * a[None, :]).sum(0)
+    h0,  # initial hidden state [B, H, K, V]
+    ht,  # final hidden state [B, H, K, V]
+    cu_seqlens,  # varlen cu_seqlens
+    scale,  # K ** -0.5
+    H,  # n_heads
+    T,  # seq_len
+    K: tl.constexpr,  # K
+    V: tl.constexpr,  # V
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
+    IS_VARLEN: tl.constexpr,
+):
+    i_v, i_nh = tl.program_id(0), tl.program_id(1)
+    i_n, i_h = i_nh // H, i_nh % H
+
+    if IS_VARLEN:
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+
+    p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+    p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+    p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+
+    mask_k = tl.arange(0, BK) < K
+    mask_v = (i_v * BV + tl.arange(0, BV)) < V
+    mask_h = mask_k[None, :] & mask_v[:, None]
+
+    b_h = tl.zeros([BV, BK], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
+        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+    for _ in range(0, T):
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
+        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
+        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
+        # (h * a).sum over K, stored in `ha` for reuse in the backward pass
+        tmp = tl.sum(b_h * b_a[None, :], axis=1)
+        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
+        b_o = b_h * b_q[None, :]
+        b_o = tl.sum(b_o, axis=1)
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
+        p_q += K*H
+        p_k += K*H
+        p_o += V*H
+        p_v += V*H
+        p_ha += V*H
+        p_a += K*H
+        p_b += K*H
+
+    if STORE_FINAL_STATE:
+        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+@triton.heuristics({
+    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+    'USE_DHT': lambda args: args['dht'] is not None,
+    'USE_DH0': lambda args: args['dh0'] is not None,
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16]
+        for num_stages in [2, 3]
+    ],
+    key=["BK", "BV"],
+)
+@triton.jit
+def fused_recurrent_bwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: head dim
+    # NV: number of splits in the V dimension. NK: number of splits in the K dimension
+    q,  # query [B, T, H, K]
+    k,  # key [B, T, H, K]
+    v,  # value [B, T, H, V]
+    a,  # a [B, T, H, K]
+    b,  # b [B, T, H, K]
+    ha,  # ha [B, T, H, V]
+    dht,  # gradient of final state [B, H, K, V]
+    dh0,  # gradient of initial state [B, H, K, V]
+    do,  # gradient of output [B, T, H, V]
+    dq,  # gradient of query [NV, B, T, H, K]
+    dk,  # gradient of key [NV, B, T, H, K]
+    dv,  # gradient of value [NK, B, T, H, V]
+    da,  # gradient of a [NV, B, T, H, K]
+    db,  # gradient of b [NV, B, T, H, K]
+    dha,  # gradient of ha [NK, B, T, H, V]
+    h0,  # initial state [B, H, K, V]
+    scale,  # K ** -0.5
+    cu_seqlens,  # cu_seqlens
+    B,  # batch_size
+    H,  # n_heads
+    T,  # seq_len
+    K: tl.constexpr,  # K
+    V: tl.constexpr,  # V
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state h0
+    USE_DH0: tl.constexpr,  # whether to use dh0
+    USE_DHT: tl.constexpr,  # whether to use dht
+    IS_VARLEN: tl.constexpr,
+):
+    i_v, i_nh = tl.program_id(0), tl.program_id(1)
+    i_n, i_h = i_nh // H, i_nh % H
+    dk += i_v * B * H * K * T
+    db += i_v * B * H * K * T
+    dq += i_v * B * H * K * T
+    da += i_v * B * H * K * T
+    if IS_VARLEN:
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+    mask_k = tl.arange(0, BK) < K
+    mask_v = (tl.arange(0, BV) + i_v * BV) < V
+
+    q += (bos * H + i_h) * K
+    k += (bos * H + i_h) * K
+    v += (bos * H + i_h) * V + i_v * BV
+    ha += (bos * H + i_h) * V + i_v * BV
+    a += (bos * H + i_h) * K
+    b += (bos * H + i_h) * K
+    do += (bos * H + i_h) * V + i_v * BV
+    dq += (bos * H + i_h) * K
+    dk += (bos * H + i_h) * K
+    dv += (bos * H + i_h) * V + i_v * BV
+    da += (bos * H + i_h) * K
+    db += (bos * H + i_h) * K
+    dha += (bos * H + i_h) * V + i_v * BV
+
+    p_q = q + tl.arange(0, BK) + (T - 1) * H*K
+    p_k = k + tl.arange(0, BK) + (T - 1) * H*K
+    p_v = v + tl.arange(0, BV) + (T - 1) * H*V
+    p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V
+    p_a = a + tl.arange(0, BK) + (T - 1) * H*K
+    p_b = b + tl.arange(0, BK) + (T - 1) * H*K
+    p_do = do + tl.arange(0, BV) + (T - 1) * H*V
+    p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K
+    p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V
+    p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V
+    p_db = db + tl.arange(0, BK) + (T - 1) * H*K
+    p_da = da + tl.arange(0, BK) + (T - 1) * H*K
+    p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K
+
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_DHT:
+        p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
+        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
+
+    for _ in range(T):
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
+        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
+        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
+        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
+
+        b_dh += b_q[:, None] * b_do[None, :]
+        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
+        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
+        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
+        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)
+
+        b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
+        tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty),
mask=mask_v) + b_db = tl.sum(b_dh * b_ha[None, :], axis=1) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k) + + b_dh += b_dha[None, :] * b_a[:, None] + p_do -= H*V + p_q -= H*K + p_k -= H*K + p_v -= H*V + p_dk -= H*K + p_dv -= H*V + p_b -= H*K + p_db -= H*K + p_a -= H*K + p_dha -= H*V + p_ha -= H*V + + if USE_DH0: + p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :]) + + tl.debug_barrier() + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_k[:, None] & mask_v[None, :] + p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + p_k = k + tl.arange(0, BK) + p_v = v + tl.arange(0, BV) + p_ha = ha + tl.arange(0, BV) + p_do = do + tl.arange(0, BV) + p_dha = dha + tl.arange(0, BV) + p_da = da + tl.arange(0, BK) + p_dq = dq + tl.arange(0, BK) + p_b = b + tl.arange(0, BK) + + for i in range(0, T): + b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32) + d_a = tl.sum(b_dha[None, :] * b_h, axis=1) + tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32) + b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :] + _d_q = b_h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += H*K + p_do += H*V + p_v += H*V + p_da += H*K + p_dha += H*V + p_ha += H*V + p_dq += H*K + p_b += H*K + + +class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None + ): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK = triton.next_power_of_2(K) + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32) + else: + final_state = None + + ha = torch.empty_like(v, dtype=torch.float32) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + N * H + ) + o = torch.empty_like(v) + fused_recurrent_fwd_kernel[grid]( + q=q, + k=k, + v=v, + a=a, + b=b, + o=o, + ha=ha, + h0=initial_state, + ht=final_state, + scale=scale, + cu_seqlens=cu_seqlens, + H=H, + T=T, + K=K, + V=V, + BK=BK, + ) + ctx.save_for_backward(q, k, v, a, b, ha, initial_state) + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + q, k, v, a, b, ha, initial_state = ctx.saved_tensors + B, T, H, K, V = *q.shape, v.shape[-1] + N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + scale = ctx.scale + + dq = q.new_empty(NV, *q.shape) + dk = k.new_empty(NV, *k.shape) + da = a.new_empty(NV, *a.shape) + db = b.new_empty(NV, *b.shape) + dv = torch.empty_like(v) + dha = torch.empty_like(ha) + grid = (NV, N * H) + + if 
initial_state is not None and initial_state.requires_grad:
+            dh0 = torch.empty_like(initial_state, dtype=torch.float32)
+        else:
+            dh0 = None
+
+        fused_recurrent_bwd_kernel[grid](
+            q=q,
+            k=k,
+            v=v,
+            a=a,
+            b=b,
+            ha=ha,
+            dht=dht,
+            dh0=dh0,
+            do=do,
+            dq=dq,
+            dk=dk,
+            dv=dv,
+            da=da,
+            db=db,
+            dha=dha,
+            h0=initial_state,
+            scale=scale,
+            cu_seqlens=ctx.cu_seqlens,
+            B=B,
+            H=H,
+            T=T,
+            K=K,
+            V=V,
+            BK=BK,
+            BV=BV,
+        )
+        dq = dq.sum(0)
+        dk = dk.sum(0)
+        da = da.sum(0)
+        db = db.sum(0)
+        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None
+
+
+def fused_recurrent_iplr_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    This function computes the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.
+
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]`
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]`
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]`
+        a (torch.Tensor):
+            `a` terms of shape `[B, T, H, K]`
+        b (torch.Tensor):
+            `b` terms of shape `[B, T, H, K]`
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+
+    """
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+ ) + if scale is None: + scale = q.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + scale, + initial_state, + output_final_state, + cu_seqlens + ) + return o, final_state diff --git a/fla3/ops/generalized_delta_rule/iplr/wy_fast.py b/fla3/ops/generalized_delta_rule/iplr/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..e895a8191b7ce6503db674c480ab7238b60ccc7b --- /dev/null +++ b/fla3/ops/generalized_delta_rule/iplr/wy_fast.py @@ -0,0 +1,300 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk32( + a, + b, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, # dummy placeholder + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_A += tl.dot(b_a, b_b) + + b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk64( + a, + b, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + 
bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_A2 = tl.zeros([BC, BC], dtype=tl.float32) + b_A3 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + b_a1 = tl.load(p_a1, boundary_check=(0, 1)) + b_a2 = tl.load(p_a2, boundary_check=(0, 1)) + b_b1 = tl.load(p_b1, boundary_check=(0, 1)) + b_b2 = tl.load(p_b2, boundary_check=(0, 1)) + b_A += tl.dot(b_a1, b_b1, allow_tf32=False) + b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False) + b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False) + + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False) + + p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1)) + # causal mask + tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS + ], + key=['BT', 'BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def wu_fwd_kernel( + w, + u, + a, + k, + v, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * 
BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_Aak = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_w = tl.dot(b_A, b_a) + b_Aak += tl.dot(b_a, tl.trans(b_k)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0) + b_Aak = b_Aak.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty) + b_u = tl.dot(b_A, b_v) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +def prepare_wy_repr_fwd( + a: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K = a.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BC = min(BT, 32) + BK = min(triton.next_power_of_2(K), 64) + + A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype) + fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32 + + fwd_fn[(NT, B * H)]( + a=a, + b=b, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BK=BK, + BC=BC, + ) + w, u = wu_fwd( + a=a, + v=v, + k=k, + A=A, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return w, u, A + + +def wu_fwd( + a: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *a.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + u = torch.empty_like(v) + w = torch.empty_like(a) + wu_fwd_kernel[(NT, B*H)]( + a=a, + v=v, + w=w, + u=u, + A=A, + k=k, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +fwd_prepare_wy_repr = prepare_wy_repr_fwd + +fwd_wu = wu_fwd diff --git a/fla3/ops/gla/__pycache__/chunk.cpython-310.pyc b/fla3/ops/gla/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7f4a8a6c42643f4786c6010d31993123de176d3 Binary files /dev/null and b/fla3/ops/gla/__pycache__/chunk.cpython-310.pyc differ diff --git 
a/fla3/ops/gla/__pycache__/chunk.cpython-312.pyc b/fla3/ops/gla/__pycache__/chunk.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13d955898952bd0b86bb9c786bfcfb37b578d43b
Binary files /dev/null and b/fla3/ops/gla/__pycache__/chunk.cpython-312.pyc differ
diff --git a/fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc b/fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ed3d5f0edd0ef581e092c81c92ccd7fd2e50aff
Binary files /dev/null and b/fla3/ops/gla/__pycache__/fused_chunk.cpython-310.pyc differ
diff --git a/fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc b/fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1323aab28051891d8bf687033ecd2bf9f1c394c
Binary files /dev/null and b/fla3/ops/gla/__pycache__/fused_chunk.cpython-312.pyc differ
diff --git a/fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..acaf4bc215d9162f5ea27b732b93d71431f979d0
Binary files /dev/null and b/fla3/ops/gla/__pycache__/fused_recurrent.cpython-310.pyc differ
diff --git a/fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..762552d60666912e8d916e15b7936f4fc19feb24
Binary files /dev/null and b/fla3/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc differ
diff --git a/fla3/ops/gla/chunk.py b/fla3/ops/gla/chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..d55e072435ece33364d909df81615669340f06eb
--- /dev/null
+++ b/fla3/ops/gla/chunk.py
@@ -0,0 +1,1300 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
+from ...ops.utils import prepare_chunk_indices
+from ...ops.utils.cumsum import chunk_local_cumsum
+from ...ops.utils.op import exp, safe_exp
+from ...utils import check_shared_mem, input_guard
+
+BK_LIST = [32, 64] if check_shared_mem() else [16, 32]
+BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32]
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
+        for BK in [32, 64]
+        for num_warps in [1, 2, 4, 8]
+        for num_stages in [2, 3, 4]
+    ],
+    key=["BC"]
+)
+@triton.jit(do_not_specialize=['T'])
+def chunk_gla_fwd_A_kernel_intra_sub_inter(
+    q,
+    k,
+    g,
+    A,
+    cu_seqlens,
+    chunk_indices,
+    scale,
+    T,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    NC: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+    i_i, i_j = i_c // NC, i_c % NC
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    if i_t * BT + i_i * BC >= T:
+        return
+    if i_i <= i_j:
+        return
+
+    b_A = tl.zeros([BC, BC], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K, BK)):
+        o_k
= i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = b_q * exp(b_g - b_gn[None, :]) * scale + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[:, None] - b_gk) + # [BC, BC] using tf32 to improve precision here. + b_A += tl.dot(b_qg, b_kg) + + p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["BK", "BT"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_intra( + q, + k, + g, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_j = i_i + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i >= j, b_A * scale, 0.) 
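+        # column j of the intra-sub-chunk A: the causal mask keeps only the rows
+        # with query index i >= j before the scaled scores are written out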
+ + tl.store(A + o_A + j, b_A, mask=m_A) + p_k += H*K + p_gk += H*K + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC', 'BK'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_intra_split( + q, + k, + g, + A, + cu_seqlens, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_tc // NC, i_tc % NC + i_j = i_i + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + + o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_A = tl.zeros([BC], dtype=tl.float32) + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A += tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i >= j, b_A * scale, 0.) 
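+        # each K-split writes its partial scores at offset i_k * all; the merge
+        # kernel below sums the NK partial blocks into the final intra-chunk A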
+ tl.store(A + o_A + j, b_A, mask=m_A) + p_k += H*K + p_gk += H*K + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( + A, + A2, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + NK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_c * BC >= T: + return + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(0, NK): + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + b_A += tl.load(p_A, boundary_check=(0, 1)) + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * 
BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BK', 'NC', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_intra( + q, + k, + g, + dA, + dq, + dk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + if i_t * BT + i_i * BC >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * exp(b_gn[None, :] - b_gk)) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg) + b_dq *= exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC + p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.) 
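+        # advance the per-step key/gate pointers by one time step, i.e. a stride
+        # of H*K elements in the flattened [B, T, H, K] layout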
+ p_kj += H*K + p_gkj += H*K + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gq = tl.load(p_gq, boundary_check=(0, 1)) + b_qg = b_q * safe_exp(b_gq - b_gn[None, :]) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dk += tl.dot(b_dA, b_qg) + b_dk *= exp(b_gn[None, :] - b_gk) + o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) + p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j * H*BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.) 
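+        # same pattern as the dq loop above with the mask flipped (i <= j):
+        # key position i collects gradient from every query at or after it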
+ p_qj += H*K + p_gqj += H*K + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BV', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_dA( + v, + do, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dA += tl.dot(b_do, b_v) + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA = tl.where(m_s, b_dA * scale, 0.) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_dv( + k, + g, + A, + do, + dh, + dv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # (SY 09/17) important to disallow tf32 here to maintain a good precision. 
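+    # intra-chunk term first: dv = A^T @ do (A is loaded transposed above);
+    # the inter-chunk contribution from dh is accumulated in the K-loop below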
+ b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BV] + # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_inter( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK,], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_dgk *= exp(b_gn) + b_dq *= scale + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gk) + + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = 
tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + # tl.debug_barrier() + b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] + # Buggy due to strange triton compiler issue. + # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) + # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] + p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_intra_gk( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + + A = q.new_empty(B, T, H, BT, dtype=torch.float) + grid = (NT, NC * NC, B * H) + chunk_gla_fwd_A_kernel_intra_sub_inter[grid]( + q, + k, + g, + A, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + ) + + grid = (NT, NC, B * H) + # load the entire [BC, K] blocks into SRAM at once + if K <= 256: + BK = triton.next_power_of_2(K) + chunk_gla_fwd_A_kernel_intra_sub_intra[grid]( + q, + k, + g, + A, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + # split then merge + else: + BK = min(128, triton.next_power_of_2(K)) + NK = triton.cdiv(K, BK) + A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float) + + grid = (NK, NT * NC, B * H) + chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( + q, + k, + g, + A_intra, + cu_seqlens, + chunk_indices, + scale, + T=T, + B=B, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + + grid = (NT, NC, B * H) + chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid]( + A_intra, + A, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + BC=BC, + NK=NK, + ) + return A + + +def chunk_gla_fwd_o_gk( + q: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K, V = *q.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + def grid(meta): return 
(triton.cdiv(V, meta['BV']), NT, B * H) + chunk_gla_fwd_kernel_o[grid]( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o + + +def chunk_gla_bwd_dA( + v: torch.Tensor, + do: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BV = min(64, triton.next_power_of_2(V)) + + dA = v.new_empty(B, T, H, BT, dtype=torch.float) + grid = (NT, B * H) + chunk_gla_bwd_kernel_dA[grid]( + v, + do, + dA, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + ) + return dA + + +def chunk_gla_bwd_dv( + k: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dv = torch.empty_like(do) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_gla_bwd_kernel_dv[grid]( + k, + g, + A, + do, + dh, + dv, + cu_seqlens, + chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dv + + +def chunk_gla_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + dA: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + grid = (NK, NT * NC, B * H) + chunk_gla_bwd_kernel_intra[grid]( + q, + k, + g, + dA, + dq, + dk, + cu_seqlens, + chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + return dq, dk + + +def chunk_gla_bwd_dqkg( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dg = torch.empty_like(g) + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) + chunk_gla_bwd_kernel_inter[grid]( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dq2, dk2, dg + + +def chunk_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + g_cumsum: Optional[torch.Tensor], + scale: float, + initial_state: 
torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + T = q.shape[1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + if g_cumsum is None: + g_cumsum = chunk_local_cumsum(g, BT, cu_seqlens=cu_seqlens) + + h, ht = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=g_cumsum, + gv=None, + h0=initial_state, + output_final_state=output_final_state, + states_in_fp32=False, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + + # the intra A is kept in fp32 + # the computation has very marginal effect on the entire throughput + A = chunk_gla_fwd_intra_gk( + q=q, + k=k, + g=g_cumsum, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + o = chunk_gla_fwd_o_gk( + q=q, + v=v, + g=g_cumsum, + A=A, + h=h, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + return g_cumsum, A, h, ht, o + + +def chunk_gla_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + g_cumsum: Optional[torch.Tensor], + scale: float, + initial_state: torch.Tensor, + h: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + T = q.shape[1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + if g_cumsum is None: + g_cumsum = chunk_local_cumsum(g, BT, cu_seqlens=cu_seqlens) + + if h is None: + h, _ = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=g_cumsum, + gv=None, + h0=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=BT, + states_in_fp32=True + ) + dh, dh0 = chunk_bwd_dh( + q=q, + k=k, + v=v, + g=None, + gk=g_cumsum, + gv=None, + do=do, + h0=initial_state, + dht=dht, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + states_in_fp32=True + ) + + dv = chunk_gla_bwd_dv( + k=k, + g=g_cumsum, + A=A, + do=do, + dh=dh, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + + # dq dk in fp32 + dA = chunk_gla_bwd_dA( + v=v, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + dq, dk = chunk_gla_bwd_dqk_intra( + q=q, + k=k, + g=g_cumsum, + dA=dA, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + dq, dk, dg = chunk_gla_bwd_dqkg( + q=q, + k=k, + v=v, + h=h, + g=g_cumsum, + do=do, + dh=dh, + dq=dq, + dk=dk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + return dq, dk, dv, dg, dh0 + + +class ChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q, + k, + v, + g, + scale, + initial_state, + output_final_state, + cu_seqlens, + ): + T = q.shape[1] + chunk_size = min(64, max(16, triton.next_power_of_2(T))) + + g_cumsum, A, h, ht, o = chunk_gla_fwd( + q=q, + k=k, + v=v, + g=g, + g_cumsum=None, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + # recompute g_cumsum in bwd pass + if g.dtype != torch.float: + g_cumsum = None + else: + g = None + ctx.save_for_backward(q, k, v, g, g_cumsum, initial_state, A) + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + return o, ht + + @staticmethod + @input_guard + def backward(ctx, do, dht): + q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors + chunk_size, scale, cu_seqlens = ctx.chunk_size, ctx.scale, ctx.cu_seqlens + dq, dk, dv, dg, dh0 = chunk_gla_bwd( + q=q, + k=k, + v=v, + g=g, + g_cumsum=g_cumsum, + scale=scale, + h=None, + A=A, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + 
chunk_size=chunk_size + ) + return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None + + +@torch.compiler.disable +def chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + Forget gates of shape `[B, T, H, K]`. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import chunk_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = chunk_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." 
+ ) + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens) + return o, final_state diff --git a/fla3/ops/gla/fused_recurrent.py b/fla3/ops/gla/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..346c47b5d50acf4befefce88d85865e52a256ca2 --- /dev/null +++ b/fla3/ops/gla/fused_recurrent.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.common.fused_recurrent import fused_recurrent + + +def fused_recurrent_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + gk (torch.Tensor): + Forget gates of shape `[B, T, H, K]`. + gv (torch.Tensor): + Forget gates of shape `[B, T, H, V]` applied to values. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import fused_recurrent_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + >>> assert o.allclose(o_var.view(o.shape)) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = fused_recurrent( + q=q, + k=k, + v=v, + g=None, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + return o, final_state diff --git a/fla3/ops/gla/naive.py b/fla3/ops/gla/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..8c0b843a5595e0e444dd8e902c4f789b1975b958 --- /dev/null +++ b/fla3/ops/gla/naive.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def ceildiv(a, b): + return -(a // -b) + + +def naive_recurrent_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False +): + dtype = q.dtype + q, k, v, gk = map(lambda x: x.transpose(1, 2).float(), (q, k, v, gk)) + B, H, T, K, V = *q.shape, v.shape[-1] + o = torch.zeros_like(v) + scale = K ** -0.5 + + h = q.new_zeros(B, H, K, V, dtype=torch.float32) + if initial_state is not None: + h += initial_state.float() + + for i in range(T): + q_i = q[:, :, i] * scale + k_i = k[:, :, i] + v_i = v[:, :, i] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o[:, :, i] = (q_i[..., None] * h).sum(-2) + + if not output_final_state: + h = None + return o.transpose(1, 2).to(dtype), h diff --git a/fla3/ops/gsa/__init__.py b/fla3/ops/gsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8a88014ddfc3143e67d3a48c38a54b75d7f3d6 --- /dev/null +++ b/fla3/ops/gsa/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_gsa +from .fused_recurrent import fused_recurrent_gsa + +__all__ = [ + 'chunk_gsa', + 'fused_recurrent_gsa' +] diff --git a/fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc b/fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecfefc966a3e01bb6b66fc76c8c11b515b5ecd22 Binary files /dev/null and b/fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc b/fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18fedd57efb28ff357a67a4b98ff08f180fb7320 Binary files /dev/null and b/fla3/ops/gsa/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc b/fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a571209bb748b13fe66441237c7048b94fa3c3 Binary files /dev/null and b/fla3/ops/gsa/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc b/fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d27e32b225788667976ffa076384b77c0d94b677 Binary files /dev/null and b/fla3/ops/gsa/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..ac9332ccb6310d278170b0582bbc291716e2372f Binary files /dev/null and b/fla3/ops/gsa/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..976e990f48750dda907ba245d5ebcbe28c21d94f Binary files /dev/null and b/fla3/ops/gsa/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/gsa/chunk.py b/fla3/ops/gsa/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..c46ba51bf8ddfc4c8dc25854817c4aafcac27421 --- /dev/null +++ b/fla3/ops/gsa/chunk.py @@ -0,0 +1,1136 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h +from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd +from fla.ops.utils import prepare_chunk_indices +from fla.ops.utils.cumsum import chunk_local_cumsum +from fla.ops.utils.op import exp, safe_exp +from fla.ops.utils.softmax import softmax_bwd, softmax_fwd +from fla.utils import input_guard + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gsa_fwd_k_kernel_inter( + q, + k, + h, + g, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BT] + b_A += tl.dot(b_q, b_k) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * 
HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o = b_o * exp(b_g) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_gsa_fwd_k_kernel_intra( + v, + g, + o, + A, + cu_seqlens, + chunk_indices, + T, + HQ: tl.constexpr, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + i_t, i_i = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + + if i_t * BT + i_i * BC > T: + return + + p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_vg) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_o *= exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v + p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + # [BC, BV] + b_vg = b_v[None, :] * exp(b_g - b_gv[None, :]) + # avoid 0 * inf = inf + b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.) 
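+ # At this point b_o holds the full intra-chunk contribution for this
+ # sub-chunk: the i_j loop covered preceding sub-chunks (rescaled through
+ # b_gn), and the j loop above added the causal part of the current
+ # sub-chunk. The load-add-store below accumulates it onto the inter-chunk
+ # output already written by chunk_gsa_fwd_k_kernel_inter.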
+ p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + b_o += tl.load(p_o, boundary_check=(0, 1)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [2, 4, 8] + ], + key=["BT"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gsa_bwd_k_kernel_dA( + v, + g, + do, + dA, + chunk_indices, + cu_seqlens, + scale, + T, + B: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + + if i_t * BT + i_i * BC > T: + return + + p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0)) + + # [BC, BC] + b_dA = tl.zeros([BC, BC], dtype=tl.float32) + if i_i > i_j: + p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) + p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1)) + p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v + p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0.) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_vg) + elif i_i == i_j: + p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v + p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + m_v = o_v < V + + o_i = tl.arange(0, BC) + # [BC, BC] + m_dA = o_i[:, None] >= o_i[None, :] + for j in range(0, min(BC, T - i_t * BT - i_j * BC)): + # [BV,] + b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32) + b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32) + # [BC,] + b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1) + b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA) + + p_v += H*V + p_gv += H*V + b_dA = tl.where(m_dA, b_dA, 0.) 
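+ # dA carries a leading NV axis (one slice per value block) that the host
+ # wrapper later reduces with dA.sum(0); sub-chunk pairs with i_i < i_j are
+ # left at zero since only the (block-)lower triangle of the intra-chunk
+ # attention matrix receives gradient.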
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gsa_bwd_k_kernel_dqkvg( + q, + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dg, + dgv, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + B: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + all = B * T + + o_i = tl.arange(0, BT) + o_t = min(i_t * BT + BT, T) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k)) + b_A = tl.where(m_s, b_A, 0.) 
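+ # A is recomputed from q and k here rather than saved from the forward
+ # pass (only Av is kept around by the autograd function), trading one extra
+ # GEMM per block for lower activation memory; like dv and dgv it has a
+ # leading NK axis that the wrapper reduces with A.sum(0).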
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + o_v = i_v * BV + tl.arange(0, BV) + p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + m_v = o_v < V + + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_gv = exp(b_gn[None, :] - b_g) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_g) * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BV] + b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn) + + b_dh = b_dh.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_k.dtype)) + b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh)) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh) * b_gv + # [BV] + b_dg += tl.sum(b_dv * b_v, 0) + + if i_k == 0: + b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :] + else: + b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :] + + tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_gsa_bwd_k_kernel_intra_dvg( + v, + g, + o, + A, + do, + dv, + dg, + cu_seqlens, + chunk_indices, + T, + HQ: tl.constexpr, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr, + NG: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + i_t, i_i = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = 
tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_v = i_v * BV + tl.arange(0, BV) + m_v = o_v < V + + if i_t * BT + i_i * BC > T: + return + + p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v + # [BV,] + b_gn = tl.load(p_gn, mask=m_v, other=0) + # [BC, BV] + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0)) + # [BC, BV] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :]) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + # [BC, BV] + b_dv += tl.dot(b_A, b_do.to(b_A.dtype)) + b_dv *= exp(b_gn[None, :] - b_gv) + + o_i = tl.arange(0, BC) + o_c = i_i * BC + tl.arange(0, BC) + + p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v + p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c + p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_A = tl.load(p_A) + # [BV,] + b_g = tl.load(p_g, mask=m_v, other=0) + b_do = tl.load(p_do, mask=m_v, other=0) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.) 
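+ # m_i keeps only rows t <= j, so gradient flows from query position j back
+ # to earlier value positions within the sub-chunk; masking before the
+ # accumulation also guards against 0 * inf products when gate differences
+ # blow up under exp, mirroring the trick used in the forward kernel.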
+ + p_g += H * V + p_A += HQ * BT + p_do += HQ * V + p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0)) + + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32) + b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32) + b_dg = b_o * b_do - b_v * b_dv + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gsa_fwd_v( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float = 1., + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + _, A, h, ht, o = chunk_gla_fwd( + q=q, + k=k, + v=v, + g=None, + g_cumsum=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return A, h, ht, o + + +def chunk_gsa_fwd_k( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + h0: Optional[torch.Tensor] = None, + output_final_state: bool = False, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BV = min(64, triton.next_power_of_2(V)) + HQ = q.shape[2] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NG = HQ // H + + h, ht = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=None, + gv=g, + h0=h0, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=BT, + states_in_fp32=False + ) + o = v.new_empty(B, T, HQ, V) + A = q.new_empty(B, T, HQ, BT) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ) + chunk_gsa_fwd_k_kernel_inter[grid]( + q, + k, + h, + g, + o, + A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + ) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ) + chunk_gsa_fwd_k_kernel_intra[grid]( + v, + g, + o, + A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + HQ=HQ, + H=H, + V=V, + BT=BT, + BC=BC, + BV=BV, + NC=NC, + NG=NG, + num_warps=4, + num_stages=2 + ) + return A, h, ht, o + + +def chunk_gsa_bwd_v( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + h0: torch.Tensor, + h: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + dg: torch.Tensor, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + dq, dk, dv, dg, dh0 = chunk_gla_bwd( + q=q, + k=k, + v=v, 
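+ # GSA's value pass is a GLA layer in disguise (the softmax weights act as
+ # queries, s as keys, gates on the slot dimension), so its backward simply
+ # delegates to chunk_gla_bwd, with the already-cumulated gates passed via
+ # g_cumsum.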
+ g=None, + g_cumsum=g, + scale=scale, + initial_state=h0, + h=h, + A=A, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return dq, dk, dv, dg, dh0 + + +def chunk_gsa_bwd_k( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + h: torch.Tensor, + h0: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + dg: torch.Tensor, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + HQ = q.shape[2] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + NG = HQ // H + + if h is None: + h, _ = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=None, + gv=g, + h0=h0, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=BT, + states_in_fp32=False + ) + dh, dh0 = chunk_bwd_dh( + q=q, + k=k, + v=v, + g=None, + gk=None, + gv=g, + do=do, + h0=h0, + dht=dht, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + states_in_fp32=True + ) + dA = q.new_empty(NV, B, T, HQ, BT) + grid = (NV, NT * NC * NC, B * HQ) + chunk_gsa_bwd_k_kernel_dA[grid]( + v, + g, + do, + dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + HQ=HQ, + H=H, + V=V, + BT=BT, + BC=BC, + BV=BV, + NC=NC, + NG=NG, + ) + dA = dA.sum(0, dtype=dA.dtype) + + A = do.new_empty(NK, B, T, HQ, BT) + dq = torch.empty_like(q) + dk = k.new_empty(B, T, HQ, K) + dv = v.new_empty(NK, B, T, HQ, V) + dgv = g.new_empty(NK, B, T, HQ, V, dtype=torch.float) + grid = (NK, NT, B * HQ) + chunk_gsa_bwd_k_kernel_dqkvg[grid]( + q, + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dg, + dgv, + dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + NG=NG, + ) + A = A.sum(0, dtype=A.dtype) + dv = dv.sum(0, dtype=dv.dtype) + dgv = dgv.sum(0, dtype=dgv.dtype) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ) + chunk_gsa_bwd_k_kernel_intra_dvg[grid]( + v, + g, + o, + A, + do, + dv, + dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + HQ=HQ, + H=H, + V=V, + BT=BT, + BC=BC, + BV=BV, + NC=NC, + NG=NG, + num_warps=4, + num_stages=2 + ) + dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, cu_seqlens=cu_seqlens)) + + return dq, dk, dv, dg, dh0 + + +def chunk_gsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_final_state: bool = False, + scale: float = 1., + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hk0, hv0 = None, None + if initial_state is not None: + hk0, hv0 = initial_state + Ak, hk, hkt, ok = chunk_gsa_fwd_k( + q=q, + k=k, + v=s, + g=g, + h0=hk0, + output_final_state=output_final_state, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + + # p is kept in fp32 for safe softmax backward + p = softmax_fwd(ok, dtype=torch.float) + + qv = p.to(q.dtype) + Av, hv, hvt, ov = chunk_gsa_fwd_v( + q=qv, + k=s, + v=v, + g=g, + scale=1., + 
initial_state=hv0, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return Ak, hk, hkt, ok, p, Av, hv, hvt, ov + + +def chunk_gsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + ok: torch.Tensor, + p: torch.Tensor, + A: Tuple[torch.Tensor, torch.Tensor], + h: Tuple[torch.Tensor, torch.Tensor], + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]], + scale: float, + do: torch.Tensor, + dht: Tuple[torch.Tensor, torch.Tensor], + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + hk0, hv0 = None, None + if initial_state is not None: + hk0, hv0 = initial_state + + _, Av = A + hk, hv = h + dhkt, dhvt = dht + + qv = p.to(q.dtype) + dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v( + q=qv, + k=s, + v=v, + g=g, + h0=hv0, + h=hv, + A=Av, + do=do, + dht=dhvt, + dg=None, + scale=1., + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + + # softmax gradient, equivalent to: + # dok = qv * (dqv - (qv * dqv).sum(-1, True)) + dok = softmax_bwd(p, dqv, dtype=ok.dtype) + + dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k( + q=q, + k=k, + v=s, + g=g, + h0=hk0, + h=hk, + o=ok, + do=dok, + dht=dhkt, + dg=dg, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + + ds = dsv.add_(dsk) + if q.shape[1] != k.shape[1]: + dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg)) + dg = dg.to(s.dtype) + return dq, dk, dv, ds, dg, dhk0, dhv0 + + +class ChunkGSAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + scale: float, + hk0: Optional[torch.Tensor], + hv0: Optional[torch.Tensor], + output_final_state: bool, + checkpoint_level: int, + cu_seqlens: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + T = q.shape[1] + chunk_size = min(64, max(16, triton.next_power_of_2(T))) + + g_org, g = g, chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens) + Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd( + q=q, + k=k, + v=v, + s=s, + g=g, + initial_state=(hk0, hv0), + output_final_state=output_final_state, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + + if checkpoint_level >= 1: + del g + g = g_org + if checkpoint_level > 1: + del hk + del hv + hk, hv = None, None + else: + hk0, hv0 = None, None + + ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv) + ctx.checkpoint_level = checkpoint_level + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + ctx.chunk_size = chunk_size + return ov, hkt, hvt + + @staticmethod + @input_guard + def backward(ctx, dov, dhkt=None, dhvt=None): + q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors + scale = ctx.scale + cu_seqlens = ctx.cu_seqlens + chunk_size = ctx.chunk_size + + if ctx.checkpoint_level >= 1: + g = chunk_local_cumsum(g, chunk_size, cu_seqlens=cu_seqlens) + dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd( + q=q, + k=k, + v=v, + s=s, + g=g, + ok=ok, + p=p, + A=(None, Av), + h=(hk, hv), + initial_state=(hk0, hv0), + scale=scale, + do=dov, + dht=(dhkt, dhvt), + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None + + +@torch.compiler.disable +def chunk_gsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: 
Optional[Tuple[torch.Tensor]] = None,
+ output_final_state: Optional[bool] = False,
+ checkpoint_level: Optional[int] = 2,
+ cu_seqlens: Optional[torch.LongTensor] = None,
+ head_first: Optional[bool] = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+ GQA is performed if `H` is not equal to `HQ`.
+ v (torch.Tensor):
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ s (torch.Tensor):
+ slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`.
+ g (torch.Tensor):
+ Forget gates of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`, applied to keys.
+ If not provided, this function is equivalent to vanilla ABC.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[Tuple[torch.Tensor]]):
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
+ Default: `False`.
+ checkpoint_level (Optional[int]):
+ Checkpointing level; higher values save more memory at the cost of more recomputation in the backward pass.
+ Default: `2`:
+ - Level `0`: no memory saved, no recomputation.
+ - Level `1`: recompute the fp32 cumulative values during backward.
+ - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+ head_first (Optional[bool]):
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+ Default: `False`.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ final_state (Tuple[torch.Tensor]):
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
+ `None` otherwise.
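+
+ .. note::
+ For GQA (`HQ > H`), each key/value head is shared by a group of `HQ // H` query
+ heads directly inside the kernels, so `k`, `v` and `s` need not be repeated beforehand.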
+
+ Examples::
+ >>> import torch
+ >>> import torch.nn.functional as F
+ >>> from einops import rearrange
+ >>> from fla.ops.gsa import chunk_gsa
+ # inputs with equal lengths
+ >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
+ >>> q = torch.randn(B, T, H, K, device='cuda')
+ >>> k = torch.randn(B, T, H, K, device='cuda')
+ >>> v = torch.randn(B, T, H, V, device='cuda')
+ >>> s = torch.randn(B, T, H, M, device='cuda')
+ >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
+ >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
+ >>> o, (hk, hv) = chunk_gsa(
+ q, k, v, s, g,
+ initial_state=h0,
+ output_final_state=True
+ )
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+ >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> o_var, (hk_var, hv_var) = chunk_gsa(
+ q, k, v, s, g,
+ initial_state=h0,
+ output_final_state=True,
+ cu_seqlens=cu_seqlens
+ )
+ >>> assert o.allclose(o_var.view(o.shape))
+ >>> assert hk.allclose(hk_var)
+ >>> assert hv.allclose(hv_var)
+ """
+ if head_first:
+ warnings.warn(
+ "head_first is deprecated and will be removed in a future version. "
+ "Please use head_first=False for now instead.",
+ DeprecationWarning
+ )
+ q, k, v, s, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, s, g))
+ if not head_first and q.shape[1] < q.shape[2]:
+ warnings.warn(
+ f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+ "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+ "when head_first=False was specified. "
+ "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+ )
+ if cu_seqlens is not None:
+ if q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+ f"Please flatten variable-length inputs before processing."
+ )
+ if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
+ raise ValueError(
+ f"The number of initial states is expected to be equal to the number of input sequences, "
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}."
+ )
+ assert checkpoint_level in [0, 1, 2]
+ if g is None:
+ # TODO: these three steps take a huge amount of time and ought to be optimized
+ z = s.float().logcumsumexp(1)
+ g = torch.cat((z[:, :1], z[:, :-1]), 1) - z
+ s = torch.exp(s - z).to(k.dtype)
+ if scale is None:
+ scale = q.shape[-1] ** -0.5
+
+ hk0, hv0 = None, None
+ if initial_state is not None:
+ hk0, hv0 = initial_state
+ o, *final_state = ChunkGSAFunction.apply(
+ q,
+ k,
+ v,
+ s,
+ g,
+ scale,
+ hk0,
+ hv0,
+ output_final_state,
+ checkpoint_level,
+ cu_seqlens
+ )
+ if head_first:
+ o = rearrange(o, 'b h t ...
-> b t h ...') + return o, final_state diff --git a/fla3/ops/gsa/fused_recurrent.py b/fla3/ops/gsa/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..6febf9932b7510bf106f6e8507b32e1519813daa --- /dev/null +++ b/fla3/ops/gsa/fused_recurrent.py @@ -0,0 +1,525 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.fused_recurrent import fused_recurrent_bwd_kernel, fused_recurrent_fwd_kernel +from fla.ops.utils import chunk_global_cumsum +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit +def fused_recurrent_gsa_inference_kernel( + q, + k, + v, + s, + g, + o, + hk0, + hv0, + hkt, + hvt, + scale, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_bh = tl.program_id(0) + i_bg = i_bh // NG + + b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32) + b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32) + b_g = exp(b_g) + + b_ok = tl.zeros([M], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + + p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None] + # [BK,] + mask_k = o_k < K + # [M, BK] + mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :] + # [M, BK] + b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32) + # [BK,] + b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale + b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32) + b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None] + b_ok += tl.sum(b_hk * b_q[None, :], axis=1) + + if i_bh % NG == 0: + p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None] + tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk) + + b_qv = tl.softmax(b_ok) + for i_v in range(tl.cdiv(V, BV)): + o_v = i_v * BV + tl.arange(0, BV) + + p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] + # [BV,] + mask_v = o_v < V + # [BV, M] + mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :] + # [BV, M] + b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32) + # [BV,] + b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32) + b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None] + b_ov = tl.sum(b_hv * b_qv[None, :], axis=1) + + tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v) + + if i_bh % NG == 0: + p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] + tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv) + + +def fused_recurrent_gsa_inference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_final_state: bool = False, + scale: float = 1., +) -> torch.Tensor: + B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + HQ = q.shape[2] + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NG = HQ // H + + if initial_state != (None, None) and initial_state is not None: + hk0, hv0 = initial_state + else: + hk0, hv0 = q.new_zeros(B, H, K, M, dtype=torch.float), q.new_zeros(B, H, M, V, dtype=torch.float) + + hkt, hvt = None, None + if output_final_state: + if NG == 1: + hkt, hvt = hk0, hv0 + else: + hkt, hvt = q.new_empty(B, H, K, M, 
dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float) + + o = v.new_empty(B, T, HQ, V) + grid = (B * HQ,) + fused_recurrent_gsa_inference_kernel[grid]( + q, + k, + v, + s, + g, + o, + hk0, + hv0, + hkt, + hvt, + scale=scale, + K=K, + V=V, + M=M, + BK=BK, + BV=BV, + NG=NG + ) + return o, (hkt, hvt) + + +def fused_recurrent_gsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_final_state: bool = False, + scale: float = 1., + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: + B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + HQ = q.shape[2] + if HQ != H: + raise ValueError("GQA not supported yet.") + + BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + + hk0, hv0 = None, None + if initial_state != (None, None) and initial_state is not None: + hk0, hv0 = initial_state + hkt, hvt = None, None + if output_final_state: + hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float) + + ok = q.new_empty(NK, *s.shape, dtype=torch.float) + gk, gv = None, g + grid = (NM, NK, N * H) + fused_recurrent_fwd_kernel[grid]( + q=q, + k=k, + v=s, + g=None, + gk=gk, + gv=gv, + o=ok, + h0=hk0, + ht=hkt, + cu_seqlens=cu_seqlens, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=M, + BK=BK, + BV=BM, + USE_G=False, + USE_GK=False, + USE_GV=True, + REVERSE=reverse + ) + ok = ok.sum(0) + + qv = ok.softmax(-1, dtype=torch.float) + ov = q.new_empty(NM, *v.shape, dtype=torch.float) + gk, gv = g, None + grid = (NV, NM, N * H) + fused_recurrent_fwd_kernel[grid]( + q=qv, + k=s, + v=v, + g=None, + gk=gk, + gv=gv, + o=ov, + h0=hv0, + ht=hvt, + cu_seqlens=cu_seqlens, + scale=1., + B=B, + T=T, + H=H, + K=M, + V=V, + BK=BM, + BV=BV, + USE_G=False, + USE_GK=True, + USE_GV=False, + REVERSE=reverse, + ) + ov = ov.sum(0) + return ok, hkt, qv, ov, hvt + + +def fused_recurrent_gsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + qv: torch.Tensor, + hk0: Optional[torch.Tensor] = None, + hv0: Optional[torch.Tensor] = None, + ok: Optional[torch.Tensor] = None, + do: Optional[torch.Tensor] = None, + dhkt: Optional[torch.Tensor] = None, + dhvt: Optional[torch.Tensor] = None, + scale: float = 1., + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor]: + B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + + dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float) + dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float) + dv = q.new_empty(NM, B, T, H, V, dtype=torch.float) + dhk0 = torch.empty_like(hk0)if hk0 is not None else None + dhv0 = torch.empty_like(hv0)if hv0 is not None else None + + gk, gv = g, None + grid = (NV, NM, N * H) + fused_recurrent_bwd_kernel[grid]( + q=qv, + k=s, + v=v, + g=None, + gk=gk, + gv=gv, + h0=hv0, + do=do, + dq=dqv, + dk=dsv, + dv=dv, + dht=dhvt, + dh0=dhv0, + cu_seqlens=cu_seqlens, + scale=1., + B=B, + T=T, + H=H, + K=M, + V=V, + BK=BM, + BV=BV, + USE_G=False, + USE_GK=True, + USE_GV=False, + REVERSE=reverse, + ) + dqv = 
dqv.sum(0) + dsv = dsv.sum(0) + dv = dv.sum(0) + dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens) + + dok = qv * (dqv - (qv * dqv).sum(-1, True)) + dq = q.new_empty(NM, B, T, H, K, dtype=torch.float) + dk = q.new_empty(NM, B, T, H, K, dtype=torch.float) + dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float) + gk, gv = None, g + grid = (NM, NK, N * H) + fused_recurrent_bwd_kernel[grid]( + q=q, + k=k, + v=s, + g=None, + gk=gk, + gv=gv, + h0=hk0, + do=dok, + dq=dq, + dk=dk, + dv=dsk, + dht=dhkt, + dh0=dhk0, + cu_seqlens=cu_seqlens, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=M, + BK=BK, + BV=BM, + USE_G=False, + USE_GK=False, + USE_GV=True, + REVERSE=reverse, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dsk = dsk.sum(0) + + dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens) + + ds = dsk.add_(dsv) + dg = dgk.add_(dgv) + + return dq, dk, dv, ds, dg, dhk0, dhv0 + + +class FusedRecurrentGSAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + hk0: Optional[torch.Tensor] = None, + hv0: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: + T = q.shape[1] + if T == 1 and not q.requires_grad: + o, (hkt, hvt) = fused_recurrent_gsa_inference( + q=q, + k=k, + v=v, + s=s, + g=g, + initial_state=(hk0, hv0), + output_final_state=output_final_state, + scale=scale, + ) + return o, hkt, hvt + ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd( + q=q, + k=k, + v=v, + s=s, + g=g, + initial_state=(hk0, hv0), + output_final_state=output_final_state, + scale=scale, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok) + ctx.scale = scale + ctx.reverse = reverse + ctx.cu_seqlens = cu_seqlens + return ov.to(q.dtype), hkt, hvt + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dhkt=None, dhvt=None): + q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors + scale = ctx.scale + reverse = ctx.reverse + cu_seqlens = ctx.cu_seqlens + + # not supported yet. + if dhkt is not None or dhvt is not None: + if g is not None: + assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd( + q=q, + k=k, + v=v, + s=s, + g=g, + qv=qv, + hk0=hk0, + hv0=hv0, + ok=ok, + do=do, + dhkt=dhkt, + dhvt=dhvt, + scale=scale, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None + + +def fused_recurrent_gsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: Optional[bool] = False, + reverse: Optional[bool] = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + s (torch.Tensor): + slot representations of shape `[B, T, H, M]`. 
+ g (torch.Tensor):
+ Forget gates of shape `[B, T, H, M]` applied to keys.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[Tuple[torch.Tensor]]):
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
+ Default: `False`.
+ reverse (Optional[bool]):
+ If `True`, process the input sequence in reverse order. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]`.
+ final_state (Tuple[torch.Tensor]):
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
+ `None` otherwise.
+
+ Examples::
+ >>> import torch
+ >>> import torch.nn.functional as F
+ >>> from einops import rearrange
+ >>> from fla.ops.gsa import fused_recurrent_gsa
+ # inputs with equal lengths
+ >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
+ >>> q = torch.randn(B, T, H, K, device='cuda')
+ >>> k = torch.randn(B, T, H, K, device='cuda')
+ >>> v = torch.randn(B, T, H, V, device='cuda')
+ >>> s = torch.randn(B, T, H, M, device='cuda')
+ >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
+ >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
+ >>> o, (hk, hv) = fused_recurrent_gsa(
+ q, k, v, s, g,
+ initial_state=h0,
+ output_final_state=True
+ )
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+ >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa(
+ q, k, v, s, g,
+ initial_state=h0,
+ output_final_state=True,
+ cu_seqlens=cu_seqlens
+ )
+ >>> assert o.allclose(o_var.view(o.shape))
+ >>> assert hk.allclose(hk_var)
+ >>> assert hv.allclose(hv_var)
+ """
+ if cu_seqlens is not None:
+ if q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+ f"Please flatten variable-length inputs before processing."
+ )
+ if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
+ raise ValueError(
+ f"The number of initial states is expected to be equal to the number of input sequences, "
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}."
+ )
+ if scale is None:
+ scale = k.shape[-1] ** -0.5
+ if initial_state is None:
+ initial_state = (None, None)
+ o, *final_state = FusedRecurrentGSAFunction.apply(
+ q,
+ k,
+ v,
+ s,
+ g,
+ scale,
+ *initial_state,
+ output_final_state,
+ reverse,
+ cu_seqlens,
+ )
+ return o, final_state
diff --git a/fla3/ops/gsa/naive.py b/fla3/ops/gsa/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..486c4a7b569dcc804773e17b18369c9526711324
--- /dev/null
+++ b/fla3/ops/gsa/naive.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional, Tuple
+
+import torch
+from einops import repeat
+
+
+def naive_recurrent_gsa(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ s: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ scale: Optional[float] = None,
+ initial_state: Optional[torch.Tensor] = None,
+ output_final_state: Optional[bool] = False
+) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+ dtype = q.dtype
+ # g may be None (vanilla ABC), so transpose it only when it is provided
+ q, k, v, s = map(lambda x: x.transpose(1, 2).contiguous().float(), (q, k, v, s))
+ if g is not None:
+ g = g.transpose(1, 2).contiguous().float()
+
+ NG = q.shape[1]//k.shape[1]
+ # [batch_size, n_heads, seq_len, n_slots]
+ if g is None:
+ z = s.float().logcumsumexp(2)
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
+ s = torch.exp(s - z)
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
+ if initial_state is not None:
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
+
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
+
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
+ ok = torch.zeros_like(s)
+
+ if scale is None:
+ scale = q.shape[-1] ** -0.5
+
+ final_state = None
+ if initial_state is not None:
+ hk += initial_state[0]
+
+ for i in range(T):
+ q_i = q[:, :, i] * scale
+ k_i = k[:, :, i]
+ v_i = s[:, :, i]
+ g_i = g[:, :, i].exp()
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
+
+ qv = ok.softmax(-1)
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
+ ov = torch.zeros_like(v)
+ if initial_state is not None:
+ hv += initial_state[1]
+
+ for i in range(T):
+ q_i = qv[:, :, i]
+ k_i = s[:, :, i]
+ v_i = v[:, :, i]
+ g_i = g[:, :, i].exp()
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
+
+ if output_final_state:
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
+ ov = ov.transpose(1, 2).contiguous()
+ return ov.to(dtype), final_state
diff --git a/fla3/ops/hgrn/__init__.py b/fla3/ops/hgrn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2012c3c15f125271df225ce755ed3b2dbe01a83
--- /dev/null
+++ b/fla3/ops/hgrn/__init__.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+from .chunk import chunk_hgrn
+from .fused_recurrent import fused_recurrent_hgrn
+
+__all__ = [
+ 'chunk_hgrn',
+ 'fused_recurrent_hgrn'
+]
diff --git a/fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc b/fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f253916b66fb89c1ad0f445ae072c41b274f7cb1
Binary files /dev/null and b/fla3/ops/hgrn/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fla3/ops/hgrn/__pycache__/__init__.cpython-312.pyc b/fla3/ops/hgrn/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..64f1ab8a9891333c396da8e4b15c8862a17869a9
Binary files /dev/null and b/fla3/ops/hgrn/__pycache__/__init__.cpython-312.pyc
differ diff --git a/fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc b/fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f337bfe92207d64b61578cfebcd8045b8c8ce291 Binary files /dev/null and b/fla3/ops/hgrn/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/hgrn/__pycache__/chunk.cpython-312.pyc b/fla3/ops/hgrn/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..652edf6922baf6d52eb0e5c6bdf3d3041cb3da3d Binary files /dev/null and b/fla3/ops/hgrn/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64b70b1fa8f59014bba4f768df4ea090eced4ca4 Binary files /dev/null and b/fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..496a22451158548ada6ad702979c555f45017a60 Binary files /dev/null and b/fla3/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/hgrn/chunk.py b/fla3/ops/hgrn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6847622ebfb071230720b7ae6669f5412a42470b --- /dev/null +++ b/fla3/ops/hgrn/chunk.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# this function implements the chunkwise form of HGRN, inspired by +# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html) +# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan + +# from tests on H800, with B, D = 16, 128, we see that the chunk can be greatly faster than the recurrent: +# +# Performance: +# seq_len chunk recurrent chunk_bwd recurrent_bwd +# 0 128.0 0.039360 0.061056 0.312160 0.205008 +# 1 256.0 0.045824 0.123712 0.308784 0.297696 +# 2 512.0 0.058688 0.241952 0.310720 0.626528 +# 3 1024.0 0.088288 0.476992 0.313184 1.333152 +# 4 2048.0 0.169472 0.943264 0.452464 2.724864 +# 5 4096.0 0.329920 1.886144 0.881600 5.551520 +# 6 8192.0 0.647872 3.755040 1.740496 11.117184 +# 7 16384.0 1.272064 7.520576 3.446608 22.362528 + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_fwd_kernel_h( + x, + g, + gc, + o, + h0, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_b * T * D + i_t * BT * D + o_d + p_g = g + i_b * T * D + i_t * BT * D + o_d + p_gc = gc + 
i_b * T * D + i_t * BT * D + o_d + p_o = o + i_b * T * D + i_t * BT * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + b_gc = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + if i_t == 0: + b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32) + for i in range(0, BT): + mask_t = mask & ((i_t * BT + i) < T) + b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32) + b_h = exp(b_g) * b_h + b_x + b_gc = b_gc + b_g + tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t) + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t) + + p_x += D + p_g += D + p_gc += D + p_o += D + + +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_fwd_kernel_o( + gc, + o, + s_b, + s_t, + s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(1, tl.cdiv(T, BT)): + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32) + # [BT, BD] + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_o = b_o + exp(b_gc) * b_h0[None, :] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_bwd_kernel_h( + g, + gc, + dx, + do, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + BC = min(BT, T - i_t * BT) + NT = tl.num_programs(1) + + p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d + + if i_t == NT - 1: + b_gc = tl.zeros([BD], dtype=tl.float32) + else: + b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32) + b_dh = tl.zeros([BD], dtype=tl.float32) + for _ in range(BC - 1, -1, -1): + tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask) + + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + + b_gc = b_gc + b_g + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * exp(b_g) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + + p_g -= D + p_gc -= D + p_dx -= D + p_do -= D + + +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_bwd_kernel_o( + g, + gc, + o, + dx, + dg, + s_b, + s_t, + s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, 
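+ # The two backward kernels mirror the forward split: `chunk_hgrn_bwd_kernel_h` runs a
+ # reverse scan within each chunk, producing local `dx` and reversed cumulative gates `gc`,
+ # while `chunk_hgrn_bwd_kernel_o` stitches chunks together by propagating the boundary
+ # gradient through `exp(gc)` and forms `dg_t = h_{t-1} * dx_t * exp(g_t)`, reading `h_{t-1}`
+ # from the output block shifted back by one step (`p_o` at offset `i_t * BT - 1`).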
D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + mask_t = mask & ((i_t + 1) * BT < T) + b_ht = tl.load(dx + i_b * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32) + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32) + + b_dx = b_dx + exp(b_gc) * b_ht[None, :] + b_dg = b_o * b_dx * exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkHGRNFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, T, D = x.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + o = torch.empty_like(x, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B) + chunk_hgrn_fwd_kernel_h[grid]( + x, g, gc, o, initial_state, + T=T, D=D, BT=BT, + USE_INITIAL_STATE=initial_state is not None + ) + def grid(meta): return (triton.cdiv(D, meta['BD']), B) + chunk_hgrn_fwd_kernel_o[grid]( + gc, o, + o.stride(-3), o.stride(-2), o.stride(-1), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps + ) + final_state = None + if output_final_state: + final_state = o[:, -1].clone() + o = o.to(x.dtype) + ctx.save_for_backward(g, o, initial_state) + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + B, T, D = do.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + dx = torch.empty_like(o, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B) + chunk_hgrn_bwd_kernel_h[grid]( + g, gc, dx, do, + T=T, D=D, BT=BT + ) + + dg = torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), B) + chunk_hgrn_bwd_kernel_o[grid]( + g, gc, o, dx, dg, + o.stride(-3), o.stride(-2), o.stride(-1), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps + ) + if initial_state is not None: + dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype) + + return dx.to(o.dtype), dg, None, None + + +@torch.compiler.disable +def chunk_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state) diff --git a/fla3/ops/hgrn/fused_recurrent.py b/fla3/ops/hgrn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..e5857482c82fe4ed2ce83ef2928d4f384f11253a --- /dev/null +++ b/fla3/ops/hgrn/fused_recurrent.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: 
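+ # A minimal equivalence sketch for the chunkwise kernels above against the sequential
+ # reference added below in fla3/ops/hgrn/naive.py (a CUDA device and the tolerance are
+ # illustrative assumptions):
+ # >>> import torch
+ # >>> import torch.nn.functional as F
+ # >>> from fla.ops.hgrn import chunk_hgrn
+ # >>> from fla.ops.hgrn.naive import naive_recurrent_hgrn
+ # >>> x = torch.randn(2, 512, 128, device='cuda')
+ # >>> g = F.logsigmoid(torch.randn(2, 512, 128, device='cuda'))
+ # >>> o_chunk, _ = chunk_hgrn(x, g)
+ # >>> o_ref, _ = naive_recurrent_hgrn(x, g)
+ # >>> torch.allclose(o_chunk, o_ref, atol=1e-4)
+ # True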
args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_hgrn_fwd_kernel( + x, + g, + o, + h0, + ht, + cu_seqlens, + T, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_d, i_n = tl.program_id(0), tl.program_id(1) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + bos * D + o_d + p_g = g + bos * D + o_d + p_o = o + bos * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_n * D + o_d + b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) + for _ in range(0, T): + b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_h = exp(b_g) * b_h + b_x + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) + + p_x += D + p_g += D + p_o += D + + if STORE_FINAL_STATE: + p_ht = ht + i_n * D + o_d + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_hgrn_bwd_kernel( + g, + o, + h0, + dx, + dg, + do, + dht, + dh0, + cu_seqlens, + T, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_d, i_n = tl.program_id(0), tl.program_id(1) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_g = g + (bos + T - 1) * D + o_d + p_o = o + (bos + T - 2) * D + o_d + p_dx = dx + (bos + T - 1) * D + o_d + p_dg = dg + (bos + T - 1) * D + o_d + p_do = do + (bos + T - 1) * D + o_d + + b_dh = tl.zeros([BD], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_n * D + o_d + b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32) + + for i in range(T - 1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + if i > 0: + b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) + elif USE_INITIAL_STATE: + b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32) + else: + b_o = tl.zeros([BD], dtype=tl.float32) + + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * exp(b_g) + b_dg = b_dh * b_o + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) + + p_g -= D + p_o -= D + p_dx -= D + p_dg -= D + p_do -= D + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_n * D + o_d + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask) + + +def fused_recurrent_hgrn_fwd( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = 
None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, D = x.shape + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + o = torch.empty_like(x) + final_state = x.new_empty(N, D) if output_final_state else None + + def grid(meta): return (triton.cdiv(D, meta['BD']), N) + fused_recurrent_hgrn_fwd_kernel[grid]( + x=x, + g=g, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + T=T, + D=D + ) + return o, final_state + + +def fused_recurrent_hgrn_bwd( + g: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor = None, + initial_state: torch.Tensor = None, + cu_seqlens: Optional[torch.LongTensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, D = do.shape + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + dx = torch.empty_like(o, dtype=torch.float) + dg = torch.empty_like(g, dtype=torch.float) + dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None + def grid(meta): return (triton.cdiv(D, meta['BD']), N) + fused_recurrent_hgrn_bwd_kernel[grid]( + g=g, + o=o, + h0=initial_state, + dx=dx, + dg=dg, + do=do, + dht=dht, + dh0=dh0, + cu_seqlens=cu_seqlens, + T=T, + D=D + ) + return dx, dg, dh0 + + +class FusedRecurrentHGRNFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None + ): + o, ht = fused_recurrent_hgrn_fwd( + x=x, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens + ) + ctx.save_for_backward(g, o, initial_state) + ctx.cu_seqlens = cu_seqlens + return o, ht + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + cu_seqlens = ctx.cu_seqlens + + dx, dg, dh0 = fused_recurrent_hgrn_bwd( + g=g, + o=o, + do=do, + dht=dht, + initial_state=initial_state, + cu_seqlens=cu_seqlens + ) + return dx, dg, dh0, None, None + + +@torch.compiler.disable +def fused_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + x (torch.Tensor): + Inputs of shape `[B, T, D]`. + g (torch.Tensor): + Forget gates of shape `[B, T, D]`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, D]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, D]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, D]`. + final_state (torch.Tensor): + Final state of shape `[N, D]` if `output_final_state=True` else `None`.
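+
+ Note:
+ `g` holds log-space forget gates: the kernel computes `h_t = exp(g_t) * h_{t-1} + x_t`,
+ so gates produced via `F.logsigmoid` (as in the examples below) keep `exp(g_t)` within `(0, 1)`.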
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.hgrn import fused_recurrent_hgrn + # inputs with equal lengths + >>> B, T, D = 4, 2048, 512 + >>> x = torch.randn(B, T, D, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda')) + >>> h0 = torch.randn(B, D, device='cuda') + >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + return FusedRecurrentHGRNFunction.apply( + x, + g, + initial_state, + output_final_state, + cu_seqlens + ) diff --git a/fla3/ops/hgrn/naive.py b/fla3/ops/hgrn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcddc1967b31c5181d330704c7b5ff2127e9d68 --- /dev/null +++ b/fla3/ops/hgrn/naive.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, T, D = x.shape + + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state + + for i in range(T): + h = g[:, i].exp() * h + x[:, i] + o[:, i] = h + + if output_final_state: + final_state = h + return o.to(dtype), final_state + + +def naive_chunk_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False, + chunk_size: int = 64 +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, T, D = x.shape + + gc = g.view(B, chunk_size, D).cumsum(-2).view_as(g) + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state + + for i in range(0, T, chunk_size): + hp = h + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + for j in range(i, i + chunk_size): + h = g[:, j].exp() * h + x[:, j] + o[:, j] = hp * gc[:, j].exp() + h + h = o[:, j].clone() + + if output_final_state: + final_state = h + return o.to(dtype), final_state diff --git a/fla3/ops/lightning_attn/__init__.py b/fla3/ops/lightning_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c28c3af59f61d32cbb68a63926ac67fa2bb73447 --- /dev/null +++ b/fla3/ops/lightning_attn/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_lightning_attn +from .fused_recurrent import fused_recurrent_lightning_attn + +__all__ = [ + 'chunk_lightning_attn', + 'fused_recurrent_lightning_attn' +] diff --git a/fla3/ops/lightning_attn/__pycache__/__init__.cpython-310.pyc b/fla3/ops/lightning_attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3586cc2c3f57196e6cb58e2ab32dc63e161ecb7d Binary files /dev/null and 
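A state-passing sketch for the HGRN ops above (illustrative shapes; a CUDA device is assumed): running a sequence in two halves while threading `initial_state` should reproduce a single full pass.

>>> import torch
>>> import torch.nn.functional as F
>>> from fla.ops.hgrn import fused_recurrent_hgrn
>>> B, T, D = 2, 1024, 128
>>> x = torch.randn(B, T, D, device='cuda')
>>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
>>> o, ht = fused_recurrent_hgrn(x, g, output_final_state=True)
>>> o1, h1 = fused_recurrent_hgrn(x[:, :T//2], g[:, :T//2], output_final_state=True)
>>> o2, h2 = fused_recurrent_hgrn(x[:, T//2:], g[:, T//2:], initial_state=h1, output_final_state=True)
>>> torch.allclose(torch.cat([o1, o2], 1), o, atol=1e-4) and torch.allclose(h2, ht, atol=1e-4)
True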
b/fla3/ops/lightning_attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc b/fla3/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1720e14badf1ac6407887bfb13a3f18bae6e4ad2 Binary files /dev/null and b/fla3/ops/lightning_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/lightning_attn/__pycache__/chunk.cpython-310.pyc b/fla3/ops/lightning_attn/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba04e9a13024276f1ec85896b0c1b1974bedc0e5 Binary files /dev/null and b/fla3/ops/lightning_attn/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/lightning_attn/__pycache__/chunk.cpython-312.pyc b/fla3/ops/lightning_attn/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..863c777423206aed854b542d2514c290e921acbb Binary files /dev/null and b/fla3/ops/lightning_attn/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5123bc5c95270639922fbd4ef46bb1a5e41e8ba9 Binary files /dev/null and b/fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02148f5fad929d123235282b23b362702367561d Binary files /dev/null and b/fla3/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/lightning_attn/chunk.py b/fla3/ops/lightning_attn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fbea144d6c76c86ba33c362b6eddcbb083e2d7 --- /dev/null +++ b/fla3/ops/lightning_attn/chunk.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.chunk import chunk_simple_gla + + +@torch.compiler.disable +def chunk_lightning_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + num_layers: int, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + layer_idx (int): + The index of the current layer. + num_layers (int): + The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. 
Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + H = q.shape[1] if head_first else q.shape[2] + s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float) + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + return chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + head_first=head_first, + cu_seqlens=cu_seqlens + ) diff --git a/fla3/ops/lightning_attn/fused_recurrent.py b/fla3/ops/lightning_attn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..8db06e33ee0650ad52b65be0c9d5c8612f929c46 --- /dev/null +++ b/fla3/ops/lightning_attn/fused_recurrent.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla + + +def fused_recurrent_lightning_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + num_layers: int, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + layer_idx (int): + The index of the current layer. + num_layers (int): + The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
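+
+ Examples (a minimal usage sketch; sizes are illustrative and a CUDA device is assumed)::
+ >>> import torch
+ >>> from fla.ops.lightning_attn import fused_recurrent_lightning_attn
+ >>> B, T, H, K, V = 2, 1024, 4, 64, 64
+ >>> q = torch.randn(B, T, H, K, device='cuda')
+ >>> k = torch.randn(B, T, H, K, device='cuda')
+ >>> v = torch.randn(B, T, H, V, device='cuda')
+ >>> o, ht = fused_recurrent_lightning_attn(q, k, v, layer_idx=0, num_layers=12, output_final_state=True)
+ >>> o.shape, ht.shape
+ (torch.Size([2, 1024, 4, 64]), torch.Size([2, 4, 64, 64]))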
+ """ + H = q.shape[1] if head_first else q.shape[2] + s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float) + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + return fused_recurrent_simple_gla( + q=q, + k=k, + v=v, + g=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + head_first=head_first + ) diff --git a/fla3/ops/linear_attn/__init__.py b/fla3/ops/linear_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a981054aaf9ab98b30ac08fa525bde73e68e7e4 --- /dev/null +++ b/fla3/ops/linear_attn/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_linear_attn +from .fused_chunk import fused_chunk_linear_attn +from .fused_recurrent import fused_recurrent_linear_attn + +__all__ = [ + 'chunk_linear_attn', + 'fused_chunk_linear_attn', + 'fused_recurrent_linear_attn' +] diff --git a/fla3/ops/linear_attn/__pycache__/__init__.cpython-310.pyc b/fla3/ops/linear_attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37637925d44ce8bc99ae6d86ac458f207a436e23 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/__init__.cpython-312.pyc b/fla3/ops/linear_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90af510be0ae08cb04fc4d3b38107357ca47ef73 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/chunk.cpython-310.pyc b/fla3/ops/linear_attn/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a4070ee0de8e415448b9ffbf96e8b6ddc3d9697 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/chunk.cpython-312.pyc b/fla3/ops/linear_attn/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e43bf83a7e7bd19d7751010c7428cb8b035938c5 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/fused_chunk.cpython-310.pyc b/fla3/ops/linear_attn/__pycache__/fused_chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b5a731e4ef1cc0e471352487c8d2bd48ec86a5c Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/fused_chunk.cpython-310.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc b/fla3/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc62a7e760ec1611ee2ce68dafa0b7078282fbd7 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/fused_chunk.cpython-312.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/linear_attn/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..effb2559d2a9eeb88b4459c17c74d885619896ee Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc 
b/fla3/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a597d4ca3b2a4922a45f11dbcec4d74339a4c883 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/linear_attn/__pycache__/utils.cpython-310.pyc b/fla3/ops/linear_attn/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..854e81b3ed03d07e215867304341fce6c6131958 Binary files /dev/null and b/fla3/ops/linear_attn/__pycache__/utils.cpython-310.pyc differ diff --git a/fla3/ops/linear_attn/chunk.py b/fla3/ops/linear_attn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..02b2d51fc237ba76e3fac5b43f240aaca86bcf20 --- /dev/null +++ b/fla3/ops/linear_attn/chunk.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch + +from fla.ops.linear_attn.utils import normalize_output +from fla.ops.simple_gla import chunk_simple_gla + + +@torch.compiler.disable +def chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + normalize: bool = True, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + scale (Optional[int]): + Scale factor for the linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + + if scale is None: + scale = k.shape[-1] ** -0.5 + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + if not head_first: + if q.shape[1] < q.shape[2]: + raise DeprecationWarning( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." 
+ ) + o, final_state = chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=None, + initial_state=initial_state, + output_final_state=output_final_state + ) + if normalize: + o = normalize_output(q * scale, k, o) + if head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla3/ops/linear_attn/fused_chunk.py b/fla3/ops/linear_attn/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..62054d51895dfcb2427a87e0e4869567e5a7289c --- /dev/null +++ b/fla3/ops/linear_attn/fused_chunk.py @@ -0,0 +1,363 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from packaging import version + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [4] + for num_stages in [1] + ], + key=['B', 'H', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_linear_attn_fwd_kernel( + q, + k, + v, + o, + h0, + ht, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr(q + (i_b * T*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (i_b * T*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (i_b * T*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_k*B*T*H + i_b*T*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i_t == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, 
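+ # A minimal usage sketch for `chunk_linear_attn` above (sizes illustrative; a CUDA device
+ # is assumed; `normalize=False` returns the raw linear-attention outputs — the
+ # `normalize_output` denominator is described with fla3/ops/linear_attn/utils.py below):
+ # >>> import torch
+ # >>> from fla.ops.linear_attn import chunk_linear_attn
+ # >>> q = torch.randn(2, 1024, 4, 64, device='cuda')
+ # >>> k = torch.randn(2, 1024, 4, 64, device='cuda')
+ # >>> v = torch.randn(2, 1024, 4, 64, device='cuda')
+ # >>> o, ht = chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+ # >>> o.shape, ht.shape
+ # (torch.Size([2, 1024, 4, 64]), torch.Size([2, 4, 64, 64]))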
num_stages=num_stages) + for num_warps in [4] + for num_stages in [1] + ], + key=['B', 'H', 'K', 'V', 'BK', 'BV'], +) +@triton.jit +def fused_chunk_linear_attn_bwd_kernel( + q, + k, + v, + do, + dq, + dk, + dv, + h0, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + o_i = tl.arange(0, BT) + + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + (i_b * T*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (i_b * T*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (i_b * T*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_v*B*T*H+i_b*T*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, BK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [BV, BK] + if CHECK and i_t == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i_t in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + (i_b * T*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, T - i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + (i_b * T*H + i_h) * K, (T, K), (H*K, 1), (T - i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (i_b * T*H + i_h) * V, (T, V), (H*V, 1), (T - i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (i_b * T*H + i_h) * V, (T, V), (H*V, 1), (T - i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v*B*T*H+i_b*T*H+i_h) * K, (T, K), (H*K, 1), (T - i_t*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*B*T*H+i_b*T*H+i_h) * V, (T, V), (H*V, 1), (T - i_t*BT, i_v*BV), (BT, BV), (1, 0)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + if CHECK and i_t == 1: + b_dk += 
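+ # This kernel makes two sweeps over the sequence: a forward loop rebuilds the running
+ # state `b_h = sum_s v_s k_s^T` to produce `dq`; after `tl.debug_barrier()`, a backward
+ # loop accumulates the reverse state `b_dh = sum_s q_s do_s^T`, from which `dk` and `dv`
+ # are formed. The `CHECK` branches are functionally identical and exist only as a
+ # workaround for the Triton < 2.2 compiler issue warned about in
+ # `FusedChunkLinearAttentionFunction.forward` below.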
tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(64, max(16, triton.next_power_of_2(T))) + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + o = q.new_empty(NK, *v.shape) + final_state = q.new_empty(B, H, K, V, dtype=torch.float) if output_final_state else None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've added some initial condition checks to resolve this, sadly at the cost of some speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_linear_attn_fwd_kernel[grid]( + q, + k, + v, + o, + initial_state, + final_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + CHECK=CHECK + ) + o = o.sum(0) if NK > 1 else o[0] + + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, T, H, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = min(64, max(16, triton.next_power_of_2(T))) + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, *q.shape) + dk = q.new_empty(NV, *k.shape) + dv = q.new_empty(NK, *v.shape) + grid = (NV, NK, B * H) + + fused_chunk_linear_attn_bwd_kernel[grid]( + q, + k, + v, + do, + dq, + dk, + dv, + initial_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + CHECK=ctx.CHECK + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def fused_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + scale (Optional[float]): + Scale factor for linear attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + if scale is None: + scale = q.shape[-1] ** -0.5 + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + if not head_first: + if q.shape[1] < q.shape[2]: + raise DeprecationWarning( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + if head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla3/ops/linear_attn/fused_recurrent.py b/fla3/ops/linear_attn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..e8560d8949856ff251ad7763f49d1393248f4fa0 --- /dev/null +++ b/fla3/ops/linear_attn/fused_recurrent.py @@ -0,0 +1,276 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_linear_attn_fwd_kernel( + q, + k, + v, + o, + h0, + ht, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * T*V + i_v * BV + tl.arange(0, BV) + p_o = o + (i_bh + i_k * B * H) * T*V + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + + b_h += b_k[None, :] * b_v[:, 
None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_linear_attn_bwd_kernel( + q, + k, + v, + do, + dq, + dk, + dv, + h0, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * T*V + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * T*V + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * T*K + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + b_h += b_k[:, None] * b_v[None, :] + _d_q = b_h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dq += K + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_dk = dk + (i_bh + i_v * B * H) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * b_v[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + + +class FusedRecurrentLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, q, k, v, scale, initial_state=None, output_final_state=False): + B, H, T, K = q.shape + V = v.shape[-1] + + BK, BV = min(K, 32), min(V, 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + o = q.new_empty(NK, B, H, T, V) + final_state = q.new_empty(B, H, K, V) if output_final_state else None + + grid = (NV, NK, B * H) + fused_recurrent_linear_attn_fwd_kernel[grid]( + q, + k, + v, + 
o, + initial_state, + final_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K = q.shape + V = v.shape[-1] + scale = ctx.scale + + BK, BV = min(K, 32), min(V, 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_recurrent_linear_attn_bwd_kernel[grid]( + q, + k, + v, + do, + dq, + dk, + dv, + initial_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None, None + + +def fused_recurrent_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + if not head_first: + if q.shape[1] < q.shape[2]: + raise DeprecationWarning( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." 
+ ) + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, final_state = FusedRecurrentLinearAttentionFunction.apply( + q, + k, + v, + scale, + initial_state, + output_final_state + ) + if normalize: + o = normalize_output(q * scale, k, o) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla3/ops/linear_attn/naive.py b/fla3/ops/linear_attn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..5adfb64cec3dbdc0ad443dff16d642c4c66ff358 --- /dev/null +++ b/fla3/ops/linear_attn/naive.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from fla.ops.linear_attn.utils import normalize_output + + +def naive_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = (( + q @ k.transpose(-1, -2)).masked_fill_( + torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), + 0 + )) @ v + o = inter + intra + if normalize: + o = normalize_output(q * scale, k, o) + return rearrange(o, 'b h n c d -> b (n c) h d') diff --git a/fla3/ops/linear_attn/utils.py b/fla3/ops/linear_attn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b444376833f5d512af6fc2db387db75a43a92e5d --- /dev/null +++ b/fla3/ops/linear_attn/utils.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +import torch + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.cumsum(-2) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-10) diff --git a/fla3/ops/nsa/__init__.py b/fla3/ops/nsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..941a1be41e1650961af0d28e64837421826ffab2 --- /dev/null +++ b/fla3/ops/nsa/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .naive import naive_nsa +from .parallel import parallel_nsa + +__all__ = [ + 'naive_nsa', + 'parallel_nsa' +] diff --git a/fla3/ops/nsa/__pycache__/__init__.cpython-310.pyc b/fla3/ops/nsa/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7290ed3604995965ccec228ad9a0503533de5e26 Binary files /dev/null and b/fla3/ops/nsa/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/nsa/__pycache__/__init__.cpython-312.pyc b/fla3/ops/nsa/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7179a80c7c06124a6d6e17e6825161d788df195 Binary files /dev/null and b/fla3/ops/nsa/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/nsa/__pycache__/compression.cpython-310.pyc b/fla3/ops/nsa/__pycache__/compression.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac0c35a9937796b9ffe8a9be8802bb43bba65f42 Binary files /dev/null and b/fla3/ops/nsa/__pycache__/compression.cpython-310.pyc differ diff --git a/fla3/ops/nsa/__pycache__/naive.cpython-310.pyc b/fla3/ops/nsa/__pycache__/naive.cpython-310.pyc new file mode 100644 index 
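`normalize_output` above implements the classic linear-attention denominator: each `o_t` is divided by `z_t = q_t · (k_1 + ... + k_t) + 1e-10`, with the cumulative sum taken over dim `-2` (the time axis in a head-first `[B, H, T, ...]` layout). A small CPU check against an explicit loop, with illustrative shapes:

>>> import torch
>>> from fla.ops.linear_attn.utils import normalize_output
>>> B, H, T, K, V = 1, 2, 4, 8, 3
>>> q, k, o = torch.randn(B, H, T, K), torch.randn(B, H, T, K), torch.randn(B, H, T, V)
>>> ref = torch.stack([
...     o[:, :, t] / ((q[:, :, t] * k[:, :, :t + 1].sum(2)).sum(-1, keepdim=True) + 1e-10)
...     for t in range(T)
... ], 2)
>>> torch.allclose(normalize_output(q, k, o), ref)
True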
0000000000000000000000000000000000000000..7d1a8774c69626b4e384e306cd60106c8ab1e7ff Binary files /dev/null and b/fla3/ops/nsa/__pycache__/naive.cpython-310.pyc differ diff --git a/fla3/ops/nsa/__pycache__/naive.cpython-312.pyc b/fla3/ops/nsa/__pycache__/naive.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0352fc3456ab2da96ab5e9c7aed5a19b5d97547a Binary files /dev/null and b/fla3/ops/nsa/__pycache__/naive.cpython-312.pyc differ diff --git a/fla3/ops/nsa/__pycache__/parallel.cpython-310.pyc b/fla3/ops/nsa/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7762523482b6ebe8efceccffc2b9aad54f4ec354 Binary files /dev/null and b/fla3/ops/nsa/__pycache__/parallel.cpython-310.pyc differ diff --git a/fla3/ops/nsa/__pycache__/parallel.cpython-312.pyc b/fla3/ops/nsa/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884122c59821c60c1554e9c91c25d822d6820685 Binary files /dev/null and b/fla3/ops/nsa/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla3/ops/nsa/__pycache__/utils.cpython-310.pyc b/fla3/ops/nsa/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01a87811ded73f1364bafeed57c73c74a34dd9da Binary files /dev/null and b/fla3/ops/nsa/__pycache__/utils.cpython-310.pyc differ diff --git a/fla3/ops/nsa/compression.py b/fla3/ops/nsa/compression.py new file mode 100644 index 0000000000000000000000000000000000000000..d696d21725c178c57736d8d37ec766f25775c04e --- /dev/null +++ b/fla3/ops/nsa/compression.py @@ -0,0 +1,534 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.attn.parallel import parallel_attn_bwd_preprocess +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_token_indices +from fla.ops.utils.op import exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit +def parallel_nsa_compression_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + cu_seqlens, + token_indices, + chunk_offsets, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + # the number of compression representations required to iterate 
over + # incomplete compression blocks are not included + NC = (i_t + 1) // BS + + p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + # [G, BV] + b_o = tl.zeros([G, BV], dtype=tl.float32) + # max scores for the current block + b_m = tl.full([G], float('-inf'), dtype=tl.float32) + # lse = log(acc) + m + b_acc = tl.zeros([G], dtype=tl.float32) + + for i_c in range(0, NC, BC): + o_c = i_c + tl.arange(0, BC) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf')) + + # [G] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [G, BC] + b_p = exp(b_s - b_m[:, None]) + # [G] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + + # [G, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + if NC == 0: + b_lse = tl.zeros([G], dtype=tl.float32) + else: + b_o = b_o / b_acc[:, None] + b_lse = b_m + log(b_acc) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + if i_v == 0: + tl.store(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G), b_lse.to(lse.dtype.element_ty)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_compression_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + cu_seqlens, + token_indices, + chunk_offsets, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + q += (bos + i_t) * HQ*K + do += (bos + i_t) * HQ*V + lse += (bos + i_t) * HQ + delta += (bos + i_t) * HQ + dq += (i_v * B * T + bos + i_t) * HQ*K + + p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + i_h * G + tl.arange(0, G) + p_delta = delta + i_h * G + tl.arange(0, G) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + # the number of compression representations required to iterate over + # incomplete compression blocks are not included + NC = (i_t + 1) // BS + + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + + # [G, BK] + b_dq = tl.zeros([G, BK], dtype=tl.float32) + for i_c in range(0, NC, 
BC): + o_c = i_c + tl.arange(0, BC) + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_p = exp(b_s - b_lse[:, None]) + b_p = tl.where((o_c < NC)[None, :], b_p, 0) + + # [G, BV] @ [BV, BC] -> [G, BC] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [G, BC] @ [BC, BK] -> [G, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_compression_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + cu_seqlens, + chunk_indices, + chunk_offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0)) + + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + + for i in range(i_c * BC * BS, T): + o_c = i_c * BC + tl.arange(0, BC) + + p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + # [BC, G] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + b_p = tl.where((i >= max(0, (o_c + 1) * BS - 1))[:, None], b_p, 0) + # [BC, G] @ [G, BV] -> [BC, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BC, BV] @ [BV, G] -> [BC, G] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BC, G] + b_ds = 
b_p * (b_dp - b_delta[None, :]) + # [BC, G] @ [G, BK] -> [BC, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_nsa_compression_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_size: int, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, HQ, K, V = *q.shape, v.shape[-1] + H = k.shape[2] + G = HQ // H + BC = BS = block_size + if check_shared_mem('hopper', q.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BS) if cu_seqlens is not None else None + + grid = (T, NV, B * H) + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_nsa_compression_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + cu_seqlens=cu_seqlens, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_nsa_compression_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + block_size: int = 64, + scale: float = None, + cu_seqlens: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, HQ, K, V = *q.shape, v.shape[-1] + H = k.shape[2] + G = HQ // H + BC = BS = block_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + if cu_seqlens is not None: + chunk_indices, chunk_offsets = prepare_chunk_indices(cu_seqlens, BS), prepare_chunk_offsets(cu_seqlens, BS) + NC = len(chunk_indices) + else: + chunk_indices, chunk_offsets = None, None + NC = triton.cdiv(triton.cdiv(T, BS), BC) + + delta = parallel_attn_bwd_preprocess(o, do) + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_compression_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + scale=scale, + cu_seqlens=cu_seqlens, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV + ) + dq = dq.sum(0) + + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NC, B * H) + parallel_nsa_compression_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV + ) + dk = dk.sum(0) + return dq, dk, dv + + +class ParallelNSACompressionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + block_size, + scale, + cu_seqlens + ): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the cu_seqlens of tokens in 
each sequence + # for example, if the passed `cu_seqlens` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(cu_seqlens) if cu_seqlens is not None else None + + o, lse = parallel_nsa_compression_fwd( + q=q, + k=k, + v=v, + block_size=block_size, + scale=scale, + cu_seqlens=cu_seqlens, + token_indices=token_indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.cu_seqlens = cu_seqlens + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype), lse + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, *args): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_nsa_compression_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + block_size=ctx.block_size, + scale=ctx.scale, + cu_seqlens=ctx.cu_seqlens, + token_indices=ctx.token_indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None + + +def parallel_nsa_compression( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_size: int = 64, + scale: float = None, + cu_seqlens: Optional[torch.LongTensor] = None +): + if scale is None: + scale = k.shape[-1] ** -0.5 + return ParallelNSACompressionFunction.apply( + q, + k, + v, + block_size, + scale, + cu_seqlens + ) diff --git a/fla3/ops/nsa/naive.py b/fla3/ops/nsa/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb49c4ef1621bfd43d29ed10a480314985d7465 --- /dev/null +++ b/fla3/ops/nsa/naive.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +from einops import rearrange, repeat + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_size: int = 64, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_size (int): + Selected block size. Default: 64. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if head_first: + q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t ... 
-> b t h ...'), (q, k, v, block_indices)) + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices)) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o = torch.zeros_like(v) + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([ + block_indices.new_tensor(range(0, B*T, T)), block_indices.new_tensor([B*T]) + ]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, i_b = q[i], k[i], v[i], block_indices[i] + else: + T = cu_seqlens[i+1] - cu_seqlens[i] + q_b, k_b, v_b, i_b = map(lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], (q, k, v, block_indices)) + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [S*BS, HQ] + i_i = i_b[i_q] + # [S*BS, HQ, -1] + k_i, v_i = map(lambda x: x.gather(0, i_i.clamp(0, T-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(i_i > i_q, float('-inf')).softmax(0) + if not varlen: + o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + else: + o[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') + return o.to(dtype) diff --git a/fla3/ops/nsa/parallel.py b/fla3/ops/nsa/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5d7ebb4734f40a0c0cbf87c4bd0a7a2b1a91d3 --- /dev/null +++ b/fla3/ops/nsa/parallel.py @@ -0,0 +1,881 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.attn.parallel import parallel_attn_bwd_preprocess +from fla.ops.nsa.compression import parallel_nsa_compression +from fla.ops.nsa.utils import _bitonic_merge +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices +from fla.ops.utils.op import exp, log +from fla.ops.utils.pooling import mean_pooling +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func +except ImportError: + warnings.warn( + "Flash Attention is not installed. 
Please install it via `pip install flash-attn --no-build-isolation`", + category=ImportWarning + ) + flash_attn_func = None + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK'], +) +@triton.jit +def parallel_nsa_kernel_topk( + q, + k, + lse, + scale, + block_indices, + cu_seqlens, + token_indices, + chunk_offsets, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + S: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + # the number of compression representations required to iterate over + # incomplete compression blocks are not included + NC = (i_t + 1) // BS + ################################ + # 1. lse computation + ################################ + if lse is not None: + b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)) + else: + # max scores for the current block + b_m = tl.full([G], float('-inf'), dtype=tl.float32) + # lse = log(acc) + m + b_acc = tl.zeros([G], dtype=tl.float32) + for i_c in range(0, NC, BC): + o_c = i_c + tl.arange(0, BC) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf')) + + # [G] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [G, BC] + b_p = exp(b_s - b_m[:, None]) + # [G] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + + b_mp = b_m + if NC == 0: + b_lse = tl.zeros([G], dtype=tl.float32) + else: + b_lse = b_m + log(b_acc) + + ################################ + # 2. 
topk selection + ################################ + # [BC] + b_i = tl.full([BC], -1, dtype=tl.float32) + o_i = tl.zeros([BC], dtype=tl.int32) + m_i = tl.arange(0, BC) < BC//2 + for i_c in range(0, i_t // BS + 1, BC): + o_c = i_c + tl.arange(0, BC) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf')) + # [G, BC] + b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), exp(b_s - b_lse[:, None])) + # the importance scores of the current block + # [BC] + b_i, b_ip = tl.sum(b_p, 0), b_i + o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i + + n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0]) + for i in tl.static_range(1, n_dims): + b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims) + + if i_c != 0: + b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims) + b_i_new = b_ip * m_i + b_i * (1 - m_i) + o_i_new = o_ip * m_i + o_i * (1 - m_i) + b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims) + else: + b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims) + + m_top = tl.arange(0, BC//S) == 0 + b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0) + + p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,)) + tl.store(p_b, b_top.to(p_b.dtype.element_ty)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + block_indices, + block_counts, + cu_seqlens, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H*S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o = tl.zeros([G, BV], dtype=tl.float32) + + b_m = tl.full([G], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + 
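            # NOTE: a selected block is attended only if it is valid (i_s >= 0)
+            # and causal (its start i_s does not exceed the query position i_t).
+            # Each admitted block's K/V tiles are folded into a running softmax:
+            # b_m tracks the row-wise max, b_acc the normalizer and b_o the
+            # unnormalized output, all rescaled by exp(m_old - m_new) whenever
+            # the max is updated (the standard online-softmax trick).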
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [G, BS] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf')) + + # [G] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [G, BS] + b_p = exp(b_s - b_m[:, None]) + # [G] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [G, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + b_o = b_o / b_acc[:, None] + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) +}) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr +): + i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_s = i_hs // S, i_hs % S + + b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s) + if USE_BLOCK_COUNTS: + b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h) + else: + b_m = b_i * BS <= i_t + + if b_i < NS and b_i >= 0: + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + block_indices, + block_counts, + cu_seqlens, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += (bos + i_t) * HQ*K + do += (bos + i_t) * HQ*V + lse += (bos + i_t) * HQ + delta += (bos + i_t) * HQ + dq += (i_v * B * T + bos + i_t) * HQ*K + block_indices += (bos + i_t) * H*S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + + p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + i_h * G + tl.arange(0, G) + p_delta = delta + i_h * G + tl.arange(0, G) + + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse 
= tl.load(p_lse) + b_delta = tl.load(p_delta) + + # [G, BK] + b_dq = tl.zeros([G, BK], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [G, BS] + b_s = tl.dot(b_q, b_k) + b_p = exp(b_s - b_lse[:, None]) + b_p = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p, 0) + + # [G, BV] @ [BV, BS] -> [G, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [G, BS] @ [BS, BK] -> [G, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + b_dq *= scale + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + block_mask, + cu_seqlens, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BS, BK], dtype=tl.float32) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BS, BV], dtype=tl.float32) + + for i in range(i_s * BS, T): + b_m = tl.load(block_mask + (bos + i) * H*M + i_h * M + i_s) + if b_m: + p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + # [BS, G] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + b_p = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BS, G] + b_ds = b_p * 
(b_dp - b_delta[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_nsa_topk( + q: torch.Tensor, + k: torch.Tensor, + lse: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + scale: float = None, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> torch.LongTensor: + B, T, HQ, K = q.shape + H = k.shape[2] + G = HQ // H + S = block_counts if isinstance(block_counts, int) else block_counts.max().item() + S = triton.next_power_of_2(S) + # here we set BC = BS, but beware that they are actually decoupled + BC = BS = block_size + BK = triton.next_power_of_2(K) + + block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device) + token_indices = prepare_token_indices(cu_seqlens) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BS) if cu_seqlens is not None else None + grid = (T, B * H) + parallel_nsa_kernel_topk[grid]( + q=q, + k=k, + lse=lse, + scale=scale, + block_indices=block_indices, + cu_seqlens=cu_seqlens, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + S=S, + BC=BC, + BS=BS, + BK=BK + ) + return block_indices + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + if check_shared_mem('hopper', q.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + cu_seqlens=cu_seqlens, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_nsa_block_mask( + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + cu_seqlens: torch.LongTensor, + block_size: int, +): + B, T, H, S = block_indices.shape + BS = block_size + if cu_seqlens is not None: + NS = triton.cdiv(prepare_lens(cu_seqlens).max().item(), BS) + else: + NS = triton.cdiv(T, BS) + block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) + + parallel_nsa_kernel_mask[(T, B, H*S)]( + block_indices=block_indices, + block_counts=block_counts, + block_mask=block_mask, + T=T, + H=H, + S=S, + BS=BS, + NS=NS + ) + return block_mask + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + scale: float = None, + cu_seqlens: Optional[torch.LongTensor] = None, + token_indices: 
Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + delta = parallel_attn_bwd_preprocess(o, do) + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + block_indices=block_indices, + block_counts=block_counts, + cu_seqlens=cu_seqlens, + token_indices=token_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + BK=BK, + BV=BV + ) + dq = dq.sum(0) + + if cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BS) + NS = len(chunk_indices) + else: + chunk_indices = None + NS = triton.cdiv(T, BS) + + # [B, T, H, M] + block_mask = parallel_nsa_block_mask(block_indices, block_counts, cu_seqlens, block_size) + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NS, B * H) + parallel_nsa_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + block_mask=block_mask, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + M=block_mask.shape[-1], + BS=BS, + BK=BK, + BV=BV + ) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the cu_seqlens of tokens in each sequence + # for example, if the passed `cu_seqlens` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(cu_seqlens) if cu_seqlens is not None else None + + o, lse = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + cu_seqlens=cu_seqlens, + token_indices=token_indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.cu_seqlens = cu_seqlens + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + scale=ctx.scale, + cu_seqlens=ctx.cu_seqlens, + token_indices=ctx.token_indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cmp: Optional[torch.Tensor] = None, + g_slc: Optional[torch.Tensor] = None, + g_swa: Optional[torch.Tensor] = None, + block_indices: Optional[torch.LongTensor] = None, + block_counts: Union[torch.LongTensor, int] = 16, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: 
Optional[torch.LongTensor] = None,
+    head_first: bool = False
+) -> torch.Tensor:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g_cmp (torch.Tensor):
+            Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        g_slc (torch.Tensor):
+            Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        g_swa (torch.Tensor):
+            Gate score for sliding-window attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        block_indices (torch.LongTensor):
+            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
+            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
+            If `g_cmp` is provided, the passed `block_indices` will be ignored.
+        block_counts (Optional[Union[torch.LongTensor, int]]):
+            Number of selected blocks for each query.
+            If a tensor of shape `[B, T, H]` (`[B, H, T]` if `head_first=True`) is provided,
+            each query token may select its own number of blocks; if an int is provided,
+            all query tokens select the same number of blocks.
+            If not provided, it will default to 16.
+        block_size (int):
+            Selected block size. Default: 64.
+        window_size (int):
+            Sliding window size. Default: 0.
+        scale (Optional[float]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
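+
+    Example:
+        A minimal illustrative call, not a canonical recipe (shapes and gate
+        values are made up; assumes the `fla3` package is importable and a CUDA
+        device is available, since the underlying kernels are Triton-based):
+            >>> import torch
+            >>> from fla3.ops.nsa import parallel_nsa
+            >>> B, T, HQ, H, K, V = 1, 256, 16, 1, 64, 64
+            >>> q = torch.randn(B, T, HQ, K, dtype=torch.bfloat16, device='cuda')
+            >>> k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+            >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+            >>> g_cmp, g_slc = (torch.rand(B, T, HQ, device='cuda') for _ in range(2))
+            >>> o = parallel_nsa(q, k, v, g_cmp=g_cmp, g_slc=g_slc, block_counts=16, block_size=64)
+            >>> o.shape
+            torch.Size([1, 256, 16, 64])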
+ """ + assert block_counts is not None, "block counts must be provided for selection" + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa)) + if not isinstance(block_counts, int): + block_counts = rearrange(block_counts, 'b h t -> b t h') + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens) + o_cmp, lse_cmp = None, None + if g_cmp is not None: + o_cmp, lse_cmp = parallel_nsa_compression( + q=q, + k=k_cmp, + v=v_cmp, + block_size=block_size, + scale=scale, + cu_seqlens=cu_seqlens + ) + if block_indices is not None: + warnings.warn("`block_indices` will be ignored when `g_cmp` is provided") + block_indices = parallel_nsa_topk( + q=q, + k=k_cmp, + lse=lse_cmp, + block_counts=block_counts, + block_size=block_size, + scale=scale, + cu_seqlens=cu_seqlens + ) + o = o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens) + if g_slc is not None: + o = o_slc * g_slc.unsqueeze(-1) + if o_cmp is not None: + o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1)) + if window_size > 0: + if cu_seqlens is not None: + max_seqlen = q.shape[1] + o_swa = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(window_size-1, 0) + ).unsqueeze(0) + else: + o_swa = flash_attn_func( + q, k, v, + causal=True, + window_size=(window_size-1, 0) + ) + o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1)) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla3/ops/nsa/utils.py b/fla3/ops/nsa/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73e54138b750a280c4f8edd04ca36ffb3f58705f --- /dev/null +++ b/fla3/ops/nsa/utils.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Implements argsort based on bitonic sort. 
+# [What is bitonic sort?](https://en.wikipedia.org/wiki/Bitonic_sorter) + +# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396 + + +import triton +import triton.language as tl + +from fla.ops.utils.op import log2 + + +@triton.jit +def _compare_and_swap( + x, + ids, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = tl.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = tl.arange(0, 2)[None, :, None] + left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype) + right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(y.dtype) + left = tl.reshape(left, x.shape) + right = tl.reshape(right, x.shape) + # idx + y_idx = tl.reshape(ids, shape) + left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape) + right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape) + left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype) + right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype) + # actual compare-and-swap + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + cond = (left > right) != flip + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids)) + return ret.to(x.dtype, bitcast=True), new_ids + + +@triton.jit +def _bitonic_merge( + x, + ids, + stage: tl.constexpr, + order: tl.constexpr, + n_dims: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) + return x, ids + + +@triton.jit +def argsort( + x, + ids, + dim: tl.constexpr = None, + descending: tl.constexpr = tl.core.CONSTEXPR_0, +): + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) + return x, ids diff --git a/fla3/ops/path_attn/__init__.py b/fla3/ops/path_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9959ce2d1bb575bad4971a16dcfe8bf58f78a62f --- /dev/null +++ b/fla3/ops/path_attn/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_path_attention + +__all__ = [ + 'parallel_path_attention' +] diff --git a/fla3/ops/path_attn/__pycache__/__init__.cpython-310.pyc b/fla3/ops/path_attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92655d7328eec916f5647f131bfb9f65c00f1f03 Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc b/fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87d09f3fa9e9a6f4066af1a91c11fb6f86c0648f Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/cumprod_householder_bwd.cpython-310.pyc b/fla3/ops/path_attn/__pycache__/cumprod_householder_bwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da760beb1af84c67e7ca1a36ab07b9a0a4f2b757 Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/cumprod_householder_bwd.cpython-310.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/cumprod_householder_fwd.cpython-310.pyc b/fla3/ops/path_attn/__pycache__/cumprod_householder_fwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46d6cec08eda67e74d6cc155fa85b835ccbca9a2 Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/cumprod_householder_fwd.cpython-310.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc b/fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f32df9d58a88a8e10abb2975b43787710272c643 Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_bwd_prepare.cpython-310.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc b/fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2e74419a273559657283ea92b9c35d4d8678c08 Binary files /dev/null and 
b/fla3/ops/path_attn/__pycache__/intra_chunk_preprocess_fwd.cpython-310.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/parallel.cpython-310.pyc b/fla3/ops/path_attn/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..783fae1c1cead73ac935a84e3fd2c68a5238095d Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/parallel.cpython-310.pyc differ diff --git a/fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc b/fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..104bf1270b82e45d0081fb4b49fbb0d8f0a99ed6 Binary files /dev/null and b/fla3/ops/path_attn/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla3/ops/path_attn/cumprod_householder_bwd.py b/fla3/ops/path_attn/cumprod_householder_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..959917bc97457ebafb6c999e3d8578d72bcffc31 --- /dev/null +++ b/fla3/ops/path_attn/cumprod_householder_bwd.py @@ -0,0 +1,117 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def chunk_cumprod_householder_bwd_kernel( + h, hc_suffix, dhc_whole, dh, + k, dk, dk_new, + cu_seqlens, split_indices, chunk_offsets, split_offsets, + BT: tl.constexpr, # previous small chunk size + K: tl.constexpr, + BK: tl.constexpr, + T: tl.constexpr, + S: tl.constexpr, + G: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_ss, i_hq = tl.program_id(0), tl.program_id(1) + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + boh_large = tl.load(split_offsets + i_n).to(tl.int32) + else: + NS = tl.cdiv(T, S) + i_n, i_s = i_ss // NS, i_ss % NS + bos, eos = i_n * T, i_n * T + T + boh = i_n * tl.cdiv(T, BT) + boh_large = i_n * tl.cdiv(T, S) + + # offset calculations + h += ((boh + tl.cdiv(i_s * S, BT)) * H + i_h) * K * K + hc_suffix += ((boh + tl.cdiv(i_s * S, BT)) * H + i_h) * K * K + k += (bos * H + i_h) * K + + dh += ((boh + tl.cdiv(i_s * S, BT)) * HQ + i_hq) * K * K + dhc_whole += ((boh_large + i_s) * HQ + i_hq) * K * K + dk += (bos * HQ + i_hq) * K + dk_new += (bos * HQ + i_hq) * K + + stride_hq = HQ * K * K + stride_h = H * K * K + NT_small = tl.cdiv(min(S, T-i_s*S), BT) + p_dhc_whole = tl.make_block_ptr(dhc_whole, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_dhc = tl.zeros([BK, BK], dtype=tl.float32) + b_dhc += tl.load(p_dhc_whole, boundary_check=(0, 1)) + + # calculate dh + for i_t_small in range(0, NT_small): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_s*S + i_t_small*BT, 0), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (HQ*K, 1), (i_s*S + i_t_small*BT, 0), (BT, BK), (1, 0)) + p_dk_new = tl.make_block_ptr(dk_new, (T, K), (HQ*K, 1), (i_s*S + i_t_small*BT, 0), (BT, BK), (1, 0)) + p_hc = tl.make_block_ptr(hc_suffix + i_t_small * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_t_small * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_t_small * stride_hq, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_k = 
tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_hc = tl.load(p_hc, boundary_check=(0, 1)) + b_dk_new = b_dk - tl.dot(b_dk.to(b_hc.dtype), b_hc) + tl.store(p_dk_new, b_dk_new.to(dk.dtype.element_ty), boundary_check=(0, 1)) + b_dh = b_dhc - tl.dot(tl.trans(b_hc), b_dhc.to(b_hc.dtype)) + tl.store(p_dh, b_dh.to(dh.dtype.element_ty), boundary_check=(0, 1)) + b_dhc = b_dhc - tl.dot(b_dhc.to(b_h.dtype), tl.trans(b_h)) + b_dhc -= tl.dot(tl.trans(b_dk).to(b_k.dtype), b_k) + + +def chunk_cumprod_householder_bwd_fn( + h: torch.Tensor, + hc_suffix: torch.Tensor, + dhc_whole: torch.Tensor, + k: torch.Tensor, + dk: torch.Tensor, + S: int, # split size, aka large chunk size + BT: int, # small chunk size + cu_seqlens: torch.Tensor = None, +): + B, T, HQ, K = dk.shape + H = k.shape[2] + G = HQ // H + + split_indices = prepare_chunk_indices(cu_seqlens, S) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) if cu_seqlens is not None else None + split_offsets = prepare_chunk_offsets(cu_seqlens, S) if cu_seqlens is not None else None + + if cu_seqlens is None: + N = B + NS = N * triton.cdiv(T, S) + else: + N = len(cu_seqlens) - 1 + NS = split_offsets[-1].item() + + grid = (NS, HQ) + dh = torch.empty(hc_suffix.shape[0], HQ, K, K, device=dk.device, dtype=torch.float32) + dk_new = torch.empty_like(dk) + + chunk_cumprod_householder_bwd_kernel[grid]( + h=h, hc_suffix=hc_suffix, dhc_whole=dhc_whole, dh=dh, + k=k, dk=dk, dk_new=dk_new, + cu_seqlens=cu_seqlens, + split_indices=split_indices, chunk_offsets=chunk_offsets, split_offsets=split_offsets, + BT=BT, K=K, G=G, H=H, HQ=HQ, BK=K, + T=T, S=S + ) + return dh, dk_new diff --git a/fla3/ops/path_attn/cumprod_householder_fwd.py b/fla3/ops/path_attn/cumprod_householder_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..426905d2605b6712e7c36bf88a1d01bdbbce92c1 --- /dev/null +++ b/fla3/ops/path_attn/cumprod_householder_fwd.py @@ -0,0 +1,167 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit +def chunk_cumprod_householder_fwd_kernel( + q, + q_new, + k, + k_new, + h, + hc_suffix, + hc_prefix, + hc_whole, + cu_seqlens, + split_indices, + chunk_offsets, + split_offsets, + BT: tl.constexpr, # small chunk size + K: tl.constexpr, + G: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + BK: tl.constexpr, + T: tl.constexpr, + S: tl.constexpr, # split size, aka large chunk size + IS_VARLEN: tl.constexpr, +): + i_ss, i_hq = tl.program_id(0), tl.program_id(1) + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + boh_large = tl.load(split_offsets + i_n).to(tl.int32) + else: + NS = tl.cdiv(T, S) + i_n, i_s = i_ss // NS, i_ss % NS + bos, eos = i_n * T, i_n * T + T + + boh = i_n * tl.cdiv(T, BT) + boh_large = i_n * tl.cdiv(T, S) + + NT_small = tl.cdiv(min(S, T-i_s*S), BT) + stride_h = H*K*K + + # offset calculations + h += ((boh + tl.cdiv(i_s * S, BT)) * H + i_h) * K * K + hc_suffix += ((boh + tl.cdiv(i_s * S, BT)) * H + i_h) * K * K + hc_prefix += ((boh + tl.cdiv(i_s * S, BT)) * 
H + i_h) * K * K + hc_whole += ((boh_large + i_s) * H + i_h) * K * K + + q += (bos * HQ + i_hq) * K + q_new += (bos * HQ + i_hq) * K + k += (bos * H + i_h) * K + k_new += (bos * H + i_h) * K + + # Initialize h and load first chunk + p_h = tl.make_block_ptr(h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.zeros([BK, BK], dtype=tl.float32) + b_h += tl.load(p_h, boundary_check=(0, 1)) + # Load and store first q chunk + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (i_s * S, 0), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + p_q_new = tl.make_block_ptr(q_new, (T, K), (HQ*K, 1), (i_s * S, 0), (BT, BK), (1, 0)) + tl.store(p_q_new, b_q.to(q_new.dtype.element_ty), boundary_check=(0, 1)) + + p_hc_prefix = tl.make_block_ptr(hc_prefix, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_hc_prefix, tl.zeros([BK, BK], dtype=tl.float32).to(p_hc_prefix.dtype.element_ty), boundary_check=(0, 1)) + + # Process remaining chunks + + for i_t_small in range(1, NT_small): + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (i_s * S + i_t_small * BT, 0), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q - tl.dot(b_q, b_h.to(b_q.dtype))).to(b_q.dtype) + p_q_new = tl.make_block_ptr(q_new, (T, K), (HQ*K, 1), + (i_s * S + i_t_small * BT, 0), (BT, BK), (1, 0)) + tl.store(p_q_new, b_q.to(q_new.dtype.element_ty), boundary_check=(0, 1)) + if i_hq % G == 0: + p_hc_prefix = tl.make_block_ptr(hc_prefix + i_t_small * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_hc_prefix, b_h.to(hc_prefix.dtype.element_ty), boundary_check=(0, 1)) + p_h_new = tl.make_block_ptr(h + i_t_small * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h_new = tl.load(p_h_new, boundary_check=(0, 1)) + b_h = b_h + b_h_new - tl.dot(b_h_new, b_h.to(b_h_new.dtype)) + + tl.debug_barrier() + + if i_hq % G == 0: + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_s * S + (NT_small - 1) * BT, 0), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_k_new = tl.make_block_ptr(k_new, (T, K), (H*K, 1), (i_s * S + (NT_small - 1) * BT, 0), (BT, BK), (1, 0)) + tl.store(p_k_new, b_k.to(k_new.dtype.element_ty), boundary_check=(0, 1)) + p_hc_suffix = tl.make_block_ptr(hc_suffix + (NT_small - 1) * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_hc_suffix, tl.zeros([BK, BK], dtype=tl.float32).to(p_hc_suffix.dtype.element_ty), boundary_check=(0, 1)) + + p_h = tl.make_block_ptr(h + (NT_small - 1) * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.zeros([BK, BK], dtype=tl.float32) + b_h += tl.load(p_h, boundary_check=(0, 1)) + + for i_t_small in range(NT_small-2, -1, -1): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_s * S + i_t_small * BT, 0), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = (b_k - tl.dot(b_k, tl.trans(b_h).to(b_k.dtype))).to(b_k.dtype) + p_k_new = tl.make_block_ptr(k_new, (T, K), (H*K, 1), (i_s * S + i_t_small * BT, 0), (BT, BK), (1, 0)) + tl.store(p_k_new, b_k.to(k_new.dtype.element_ty), boundary_check=(0, 1)) + p_hc_suffix = tl.make_block_ptr(hc_suffix + i_t_small * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_hc_suffix, b_h.to(hc_suffix.dtype.element_ty), boundary_check=(0, 1)) + p_h_new = tl.make_block_ptr(h + i_t_small * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h_new = tl.load(p_h_new, boundary_check=(0, 1)) + b_h = b_h + b_h_new - tl.dot(b_h.to(b_h_new.dtype), b_h_new) + + p_hc_whole = tl.make_block_ptr(hc_whole, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_hc_whole, b_h.to(hc_whole.dtype.element_ty), boundary_check=(0, 1))
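+ # Two passes over this split's small chunks: the forward loop rewrites q by the running
+ # prefix of (I - H) factors and records that prefix in hc_prefix; after the barrier, the
+ # backward loop rewrites k by the running suffix and records it in hc_suffix. The update
+ # b_h <- b_h + b_h_new - b_h_new @ b_h composes the factors, since
+ # (I - H_new) @ (I - H_cum) = I - (H_cum + H_new - H_new @ H_cum); the product over the
+ # whole split finally lands in hc_whole.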
+ + +def chunk_cumprod_householder_fwd_fn( + q: torch.Tensor, + k: torch.Tensor, + h: torch.Tensor, + S: int, # split size, aka large chunk size + BT: int, # small chunk size + cu_seqlens: torch.Tensor = None, +): + B, T, HQ, K = q.shape + H = k.shape[2] + G = HQ // H + + split_indices = prepare_chunk_indices(cu_seqlens, S) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) if cu_seqlens is not None else None + split_offsets = prepare_chunk_offsets(cu_seqlens, S) if cu_seqlens is not None else None + + if cu_seqlens is None: + N = B + NS = N * triton.cdiv(T, S) + NT = N * triton.cdiv(T, BT) + else: + N = len(cu_seqlens) - 1 + NS = split_offsets[-1].item() + NT = chunk_offsets[-1].item() + + grid = (NS, HQ) + + hc_suffix = torch.empty((NT, H, K, K), device=q.device, dtype=q.dtype) + hc_prefix = torch.empty((NT, H, K, K), device=q.device, dtype=q.dtype) + hc_whole = torch.empty((NS, H, K, K), device=q.device, dtype=q.dtype) + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + + chunk_cumprod_householder_fwd_kernel[grid]( + q=q, q_new=q_new, k=k, k_new=k_new, h=h, hc_suffix=hc_suffix, hc_prefix=hc_prefix, hc_whole=hc_whole, + cu_seqlens=cu_seqlens, + split_indices=split_indices, chunk_offsets=chunk_offsets, split_offsets=split_offsets, + BT=BT, K=K, G=G, H=H, HQ=HQ, BK=K, + T=T, S=S + ) + return q_new, k_new, hc_suffix, hc_prefix, hc_whole diff --git a/fla3/ops/path_attn/intra_chunk_preprocess_bwd.py b/fla3/ops/path_attn/intra_chunk_preprocess_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..801a5fd9805d87ad97b65ca9f95e40629234e928 --- /dev/null +++ b/fla3/ops/path_attn/intra_chunk_preprocess_bwd.py @@ -0,0 +1,140 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['offsets'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def intra_chunk_preprocess_bwd_kernel( + q, k, w, beta, + AT, + dA_local, dq, dq_new, dk, dk_new, dw, dbeta, dh, T, + offsets, indices, chunk_offsets, + HQ: tl.constexpr, G: tl.constexpr, H: tl.constexpr, + K: tl.constexpr, BT: tl.constexpr, BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw_beta = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_dT = tl.zeros([BT, BT], dtype=tl.float32) + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (K*H, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (K*H, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_beta = tl.make_block_ptr(beta + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) + p_T = tl.make_block_ptr(AT + (bos * H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_beta = tl.load(p_beta,
boundary_check=(0, )) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_T = tl.load(p_T, boundary_check=(0, 1)).to(b_k.dtype) + b_w_beta = (b_w * b_beta[:, None]).to(b_w.dtype) + + o_i = tl.arange(0, BT) + b_qw = tl.where(o_i[:, None] >= o_i[None, :], tl.dot(b_q, tl.trans(b_w)), 0).to(b_q.dtype) + b_wbk = tl.where(o_i[:, None] > o_i[None, :], tl.dot(b_w_beta, tl.trans(b_k)), 0).to(b_k.dtype) + b_Twb = tl.dot(b_T, b_w_beta).to(b_w.dtype) + b_Twbk = tl.dot(b_T, b_wbk).to(b_w.dtype) + + p_dA_local = tl.make_block_ptr(dA_local + (bos * HQ + i_hq) * BT, (T, BT), (BT*HQ, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_dA_local = tl.load(p_dA_local, boundary_check=(0, 1)) + + # gradients through the Twb and qw branches + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_dq = tl.load(p_dq, boundary_check=(0, 1)).to(b_w.dtype) + + p_dh = tl.make_block_ptr(dh + ((boh + i_t) * HQ + i_hq)*K*K, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)).to(b_w.dtype) + b_dw += tl.dot(b_Twb, tl.trans(b_dh)) + b_dqw = -tl.dot(b_dA_local, tl.trans(b_Twbk)) - tl.dot(b_dq.to(b_Twb.dtype), tl.trans(b_Twb)) + b_dTwb = (-tl.dot(tl.trans(b_qw), b_dq) + tl.dot(b_w, b_dh)).to(b_w.dtype) + b_dT += tl.dot(b_dTwb, tl.trans(b_w_beta)) + b_dw_beta += tl.dot(tl.trans(b_T), b_dTwb) + + b_dqw = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dqw, 0) + b_dq += tl.dot(b_dA_local.to(b_k.dtype), b_k) + b_dq += tl.dot(b_dqw.to(b_w.dtype), b_w) + b_dw += tl.dot(tl.trans(b_dqw.to(b_q.dtype)), b_q) + p_dq_new = tl.make_block_ptr(dq_new + (bos * HQ + i_hq) * K, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_dq_new, b_dq.to(dq_new.dtype.element_ty), boundary_check=(0, 1)) + + # Twbk part + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_dTwbk = -tl.dot(tl.trans(b_qw), b_dA_local.to(b_qw.dtype)) - tl.dot(b_w, tl.trans(b_dk.to(b_w.dtype))) + b_dw -= tl.dot(b_Twbk, b_dk.to(b_w.dtype)) + b_dT += tl.dot(b_dTwbk.to(b_wbk.dtype), tl.trans(b_wbk)) + b_dwbk = tl.where(o_i[:, None] > o_i[None, :], tl.dot(tl.trans(b_T), b_dTwbk.to(b_T.dtype)), 0).to(b_w.dtype) + b_dw_beta += tl.dot(b_dwbk, b_k) + + b_dk += tl.dot(tl.trans(b_dwbk), b_w_beta) + b_dk += tl.dot(tl.trans(b_dA_local), b_q) + p_dk_new = tl.make_block_ptr(dk_new + (bos * HQ + i_hq) * K, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_dk_new, b_dk.to(dk_new.dtype.element_ty), boundary_check=(0, 1)) + + # gradient through the matrix inverse + p_T = tl.make_block_ptr(AT + (bos * H + i_h) * BT, (BT, T), (1, BT*H), (0, i_t * BT), (BT, BT), (0, 1)) + b_Tt = tl.load(p_T, boundary_check=(0, 1)) + b_dT = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dT, 0).to(b_w.dtype) + b_dT = tl.dot(b_Tt, b_dT).to(b_w.dtype) + b_dT = tl.dot(b_dT, b_Tt) + b_dT = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dT, 0).to(b_k.dtype) + + b_dw_beta += tl.dot(b_dT, b_w) + b_dw += tl.dot(tl.trans(b_dT), b_w_beta) + b_dw += b_dw_beta * b_beta[:, None] + b_dbeta = tl.sum(b_dw_beta * b_w, axis=1) + + p_dw = tl.make_block_ptr(dw + (bos * HQ + i_hq) * K, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_dw, b_dw.to(dw.dtype.element_ty), boundary_check=(0, 1)) + p_dbeta = tl.make_block_ptr(dbeta + (bos * HQ + i_hq), (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + tl.store(p_dbeta, b_dbeta.to(dbeta.dtype.element_ty),
boundary_check=(0, )) + + +def intra_chunk_preprocess_bwd_fn(q, k, w, beta, + dq, dk, dh, dA_local, + A, L, D, do, scale, cu_seqlens=None): + BT = A.shape[-1] + HQ = q.shape[-2] + B, T, H, K = k.shape + G = HQ//H + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + grid = (NT, B*HQ) + # better precision because h would be of norm smaller than 1 anyways + + dbeta = torch.empty(B, T, HQ, device=q.device, dtype=k.dtype if G == 1 else torch.float32) + dw = torch.empty(B, T, HQ, K, device=q.device, dtype=k.dtype if G == 1 else torch.float32) + dk_new = torch.empty_like(dk, dtype=k.dtype if G == 1 else torch.float32) # float32 reduction + dq_new = torch.empty_like(dq, dtype=q.dtype) + + intra_chunk_preprocess_bwd_kernel[grid]( + q=q, k=k, w=w, beta=beta, + AT=A, + dA_local=dA_local, dq=dq, dq_new=dq_new, dk=dk, dk_new=dk_new, dw=dw, dbeta=dbeta, dh=dh, T=T, + offsets=cu_seqlens, indices=indices, chunk_offsets=chunk_offsets, + HQ=HQ, G=G, H=H, + K=K, BT=BT, BK=triton.next_power_of_2(K), + ) + return dq_new, dk_new, dbeta, dw diff --git a/fla3/ops/path_attn/intra_chunk_preprocess_bwd_prepare.py b/fla3/ops/path_attn/intra_chunk_preprocess_bwd_prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2c621d135c0dc37013b311f3c5e0cc7ccb9fdc --- /dev/null +++ b/fla3/ops/path_attn/intra_chunk_preprocess_bwd_prepare.py @@ -0,0 +1,199 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + "USE_GATE": lambda args: args['g_cumsum'] is not None, + "IS_VARLEN": lambda args: args['offsets'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_transform_qk_bwd_kernel_prepare( + q, + k, + v, + w, + beta, + g_cumsum, + L, + D, + h, + q_new, + k_new, + AT, + dA_local, + dv, + do, + dg_cumsum, + scale, + indices, # varlen helper + offsets, # varlen helper + chunk_offsets, # varlen helper + T, + G: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_GATE: tl.constexpr +): + i_t, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + sm_scale = scale * 1.44269504 + # offset calculations + dA_local += (bos*HQ + i_hq) * BT + AT += (bos*H + i_h) * BT + q += (bos*HQ + i_hq) * K + q_new += (bos*HQ + i_hq) * K + k += (bos*H + i_h) * K + k_new += (bos*H + i_h) * K + w += (bos*H + i_h) * K + v += (bos*H + i_h) * V + do += (bos*HQ + i_hq) * V + dv += (bos*HQ + i_hq) * V + beta += (bos*H + i_h) + h += ((boh + i_t) * H + i_h) * K * K + if USE_GATE: + g_cumsum += (bos*HQ + i_hq) + dg_cumsum += (bos*HQ + i_hq) + L += (bos*HQ + i_hq) + D += (bos*HQ + i_hq) + + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_t * BT), (BK, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (T, K), (H*K, 1), (i_t * BT, 0), 
(BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kt = tl.load(p_k, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + p_T = tl.make_block_ptr(AT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_T = tl.load(p_T, boundary_check=(0, 1)).to(b_q.dtype) + + o_i = tl.arange(0, BT) + m_t = o_i[:, None] >= o_i[None, :] + p_beta = tl.make_block_ptr(beta, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + b_w_beta = (b_w * b_beta[:, None]).to(b_w.dtype) + + b_Twb = tl.dot(b_T.to(b_w_beta.dtype), b_w_beta).to(b_w_beta.dtype) + + b_qw = tl.where(m_t, tl.dot(b_q, tl.trans(b_w)), 0).to(b_q.dtype) + b_qwT = tl.dot(b_qw, b_T).to(b_q.dtype) + b_wbk = tl.where(o_i[:, None] > o_i[None, :], tl.dot(b_w_beta, b_kt), 0).to(b_w.dtype) + b_A = tl.where(m_t, tl.dot(b_q, b_kt) - tl.dot(b_qwT, b_wbk), 0) + + b_q = b_q - tl.dot(b_qwT, b_w_beta) + p_q_new = tl.make_block_ptr(q_new, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, K), (1, 0)) + tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + if i_hq % G == 0: + b_h = tl.dot(tl.trans(b_w), b_Twb) + p_h = tl.make_block_ptr(h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + b_T_wbk = tl.dot(b_T, b_wbk).to(b_w.dtype) + p_k_new = tl.make_block_ptr(k_new, (K, T), (1, K*H), (0, i_t * BT), (BK, BT), (0, 1)) + tl.store(p_k_new, (b_kt - tl.dot(tl.trans(b_w), b_T_wbk)).to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + + if USE_GATE: + p_g_cumsum = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_g_cumsum = tl.load(p_g_cumsum, boundary_check=(0, )) + b_A = b_A + (b_g_cumsum[:, None] - b_g_cumsum[None, :]) + b_A = tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], b_A, float("-inf")) # avoid nan + + p_l = tl.make_block_ptr(L, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_l = tl.load(p_l, boundary_check=(0, )) + p_delta = tl.make_block_ptr(D, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + delta = tl.load(p_delta, boundary_check=(0, )) + + b_A_softmax = tl.exp2(tl.where(o_i[:, None] >= o_i[None, :], b_A * sm_scale - b_l[:, None], float("-inf"))) + p_do = tl.make_block_ptr(do, (T, V), (HQ*V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(tl.trans(b_A_softmax.to(b_do.dtype)), b_do) + p_dv = tl.make_block_ptr(dv, (T, V), (HQ*V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (0, i_t * BT), (BV, BT), (0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dp = tl.dot(b_do, b_v) + + b_dA = ((b_dp - delta[:, None]) * b_A_softmax * scale) + b_dgq = tl.sum(b_dA, axis=1) - tl.sum(b_dA, axis=0) + b_dA = b_dA.to(b_v.dtype) + + if USE_GATE: + p_dg = tl.make_block_ptr(dg_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + tl.store(p_dg, b_dgq.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + p_dA = tl.make_block_ptr(dA_local, (T, BT), (BT*HQ, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +def intra_chunk_preprocess_bwd_prepare_fn(q, k, v, w, beta, g_cumsum, A, L, D, do, scale, cu_seqlens=None): + BT = A.shape[-1] + HQ = q.shape[-2] + B, T, H, K = k.shape + G = HQ//H + + V = v.shape[-1] + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + chunk_offsets = 
prepare_chunk_offsets(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + grid = (NT, B*HQ) + # better precision because h would be of norm smaller than 1 anyways + h = torch.empty(B, NT, H, K, K, dtype=q.dtype, device=q.device) + dA_local = torch.empty(B, T, HQ, BT, dtype=q.dtype, device=q.device) + dv = torch.empty(B, T, HQ, V, device=q.device, dtype=torch.float32) + dg_cumsum = torch.empty_like(g_cumsum) if g_cumsum is not None else None + + chunk_transform_qk_bwd_kernel_prepare[grid]( + q=q, + k=k, + v=v, + w=w, + beta=beta, + g_cumsum=g_cumsum, + AT=A, + dA_local=dA_local, + dv=dv, + dg_cumsum=dg_cumsum, + do=do, + L=L, + D=D, + h=h, + q_new=q_new, + k_new=k_new, + scale=scale, + offsets=cu_seqlens, + indices=indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + G=G, + HQ=HQ, + K=K, + V=V, + BK=triton.next_power_of_2(K), + BV=triton.next_power_of_2(V), + BT=BT, + ) + return q_new, k_new, h, dA_local, dv, dg_cumsum diff --git a/fla3/ops/path_attn/intra_chunk_preprocess_fwd.py b/fla3/ops/path_attn/intra_chunk_preprocess_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2ab677c31a16bf175327fd96549953d143b72d --- /dev/null +++ b/fla3/ops/path_attn/intra_chunk_preprocess_fwd.py @@ -0,0 +1,175 @@ + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + "USE_G": lambda args: args['g_cumsum'] is not None, + "IS_VARLEN": lambda args: args['offsets'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def intra_chunk_preprocess_fwd_kernel( + q, + k, + v, + w, + beta, + g_cumsum, + o, + A, + L, + M, + h, + q_new, + k_new, + # A_local, + scale, + indices, # varlen helper + offsets, # varlen helper + chunk_offsets, # varlen helper + T, + H: tl.constexpr, + G: tl.constexpr, + HQ: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr +): + i_t, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + sm_scale = scale * 1.44269504 + + # offset calculations + A += (bos*H + i_h) * BT + q += (bos*HQ + i_hq) * K + q_new += (bos*HQ + i_hq) * K + k += (bos*H + i_h) * K + k_new += (bos*H + i_h) * K + w += (bos*H + i_h) * K + v += (bos*H + i_h) * V + o += (bos*HQ + i_hq) * V + beta += (bos*H + i_h) + h += ((boh + i_t) * H + i_h) * K * K + if USE_G: + g_cumsum += (bos*HQ + i_hq) + L += (bos*HQ + i_hq) + M += (bos*HQ + i_hq) + + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_t * BT), (BK, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kt = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + p_T = tl.make_block_ptr(A, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_T = 
tl.load(p_T, boundary_check=(0, 1)).to(b_q.dtype) + + o_i = tl.arange(0, BT) + m_t = o_i[:, None] >= o_i[None, :] + p_beta = tl.make_block_ptr(beta, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + b_w_beta = (b_w * b_beta[:, None]).to(b_w.dtype) + + b_qw = tl.where(m_t, tl.dot(b_q, tl.trans(b_w)), 0).to(b_q.dtype) + b_qwT = tl.dot(b_qw, b_T).to(b_q.dtype) + b_wbk = tl.where(o_i[:, None] > o_i[None, :], tl.dot(b_w_beta, b_kt), 0).to(b_w.dtype) + b_A = tl.where(m_t, tl.dot(b_q, b_kt) - tl.dot(b_qwT, b_wbk), 0) + + b_q = b_q - tl.dot(b_qwT, b_w_beta) + p_q_new = tl.make_block_ptr(q_new, (T, K), (K*HQ, 1), (i_t * BT, 0), (BT, K), (1, 0)) + tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + if i_hq % G == 0: + b_Twb = tl.dot(b_T.to(b_w_beta.dtype), b_w_beta).to(b_w_beta.dtype) + b_h = tl.dot(tl.trans(b_w), b_Twb) + p_h = tl.make_block_ptr(h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + b_T_wbk = tl.dot(b_T, b_wbk).to(b_w.dtype) + p_k_new = tl.make_block_ptr(k_new, (K, T), (1, K*H), (0, i_t * BT), (BK, BT), (0, 1)) + tl.store(p_k_new, (b_kt - tl.dot(tl.trans(b_w), b_T_wbk)).to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + p_g_cumsum = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_g_cumsum = tl.load(p_g_cumsum, boundary_check=(0, )) + b_A = b_A + (b_g_cumsum[:, None] - b_g_cumsum[None, :]) + b_A = tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], b_A, float("-inf")) # avoid nan + + b_qkT_softmax = tl.where(o_i[:, None] >= o_i[None, :], b_A * sm_scale, float("-inf")) + m_i = tl.max(b_qkT_softmax, 1) + b_qkT_softmax = tl.math.exp2(b_qkT_softmax - m_i[:, None]) + l_i = tl.sum(b_qkT_softmax, 1) + b_o = tl.dot(b_qkT_softmax.to(b_v.dtype), b_v) + p_o = tl.make_block_ptr(o, (T, V), (V*HQ, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_l = tl.make_block_ptr(L, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + p_m = tl.make_block_ptr(M, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + tl.store(p_m, m_i.to(p_m.dtype.element_ty), boundary_check=(0,)) + tl.store(p_l, l_i.to(p_l.dtype.element_ty), boundary_check=(0,)) + + +def intra_chunk_preprocess_fwd_fn(q, k, v, w, beta, g_cumsum, A, scale, BT, cu_seqlens): + HQ = q.shape[-2] + B, T, H, K = k.shape + V = v.shape[-1] + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + o = torch.empty(B, T, HQ, V, device=q.device, dtype=q.dtype) + + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + grid = (NT, B*HQ) + L = torch.empty(B, T, HQ, dtype=torch.float32, device=q.device) + M = torch.empty(B, T, HQ, dtype=torch.float32, device=q.device) + h = torch.empty(B, NT, H, K, K, dtype=q.dtype, device=q.device) + G = HQ//H + intra_chunk_preprocess_fwd_kernel[grid]( + q=q, + k=k, + v=v, + w=w, + beta=beta, + g_cumsum=g_cumsum, + o=o, + A=A, + L=L, + M=M, + h=h, + q_new=q_new, + k_new=k_new, + scale=scale, + offsets=cu_seqlens, + indices=indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + G=G, + HQ=HQ, + K=K, + V=V, + BK=triton.next_power_of_2(K), + BV=triton.next_power_of_2(V), + BT=BT, + ) + return q_new, k_new, h, o, L, M diff --git a/fla3/ops/path_attn/parallel.py b/fla3/ops/path_attn/parallel.py new file mode 100644 index 
0000000000000000000000000000000000000000..aaf8f8466291458683ea04694301e53f059c577b --- /dev/null +++ b/fla3/ops/path_attn/parallel.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +from einops import reduce + +from fla.ops.attn.parallel import parallel_attn_bwd_preprocess +from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from fla.ops.path_attn.cumprod_householder_bwd import chunk_cumprod_householder_bwd_fn +from fla.ops.path_attn.cumprod_householder_fwd import chunk_cumprod_householder_fwd_fn +from fla.ops.path_attn.intra_chunk_preprocess_bwd import intra_chunk_preprocess_bwd_fn +from fla.ops.path_attn.intra_chunk_preprocess_bwd_prepare import intra_chunk_preprocess_bwd_prepare_fn +from fla.ops.path_attn.intra_chunk_preprocess_fwd import intra_chunk_preprocess_fwd_fn +from fla.ops.path_attn.parallel_path_bwd_inter_dkv import parallel_path_bwd_dkv_fn +from fla.ops.path_attn.parallel_path_bwd_inter_dqh import parallel_path_bwd_dq_fn +from fla.ops.path_attn.parallel_path_bwd_intra import parallel_path_bwd_intra_chunk_fn +from fla.ops.path_attn.parallel_path_fwd import parallel_path_fwd_fn +from fla.ops.path_attn.prepare_k_cache import prepare_k_cache_fn +from fla.ops.utils.cumsum import chunk_global_cumsum +from fla.ops.utils.solve_tril import solve_tril +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +class ParallelPATHAttentionFunction(torch.autograd.Function): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, w, beta, g, scale, cu_seqlens, use_cache=False): + g_cumsum = chunk_global_cumsum(g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) if g is not None else None + BS = 64 + BT = 64 + A, _ = chunk_scaled_dot_kkt_fwd( + k=w, + beta=beta, + cu_seqlens=cu_seqlens, + chunk_size=BS, + output_dtype=torch.float32 + ) + A = solve_tril( + A=A, + cu_seqlens=cu_seqlens, + output_dtype=k.dtype + ) + q_new, k_new, h, o, L, M = intra_chunk_preprocess_fwd_fn( + q=q, + k=k, + v=v, + w=w, + beta=beta, + g_cumsum=g_cumsum, + A=A, + scale=scale, + BT=BS, + cu_seqlens=cu_seqlens, + ) + o, L = parallel_path_fwd_fn( + q=q_new, + k=k_new, + v=v, + L=L, + h=h, + M=M, + o=o, + g_cumsum=g_cumsum, + scale=scale, + cu_seqlens=cu_seqlens, + BT=BT, + BS=BS, + ) + k_cache = prepare_k_cache_fn(k=k_new, h=h, cu_seqlens=cu_seqlens, BS=BS, use_cache=use_cache) + ctx.save_for_backward(q, k, v, w, g_cumsum, o, beta, L) + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + return o, k_cache + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dk_new): + q, k, v, w, g_cumsum, o, beta, L = ctx.saved_tensors + BT = 64 + BS = 64 + S = 512 + cu_seqlens = ctx.cu_seqlens + A, _ = chunk_scaled_dot_kkt_fwd( + k=w, + beta=beta, + cu_seqlens=cu_seqlens, + chunk_size=BS, + output_dtype=torch.float32 + ) + A = solve_tril( + A=A, + cu_seqlens=cu_seqlens, + output_dtype=k.dtype + ) + delta = parallel_attn_bwd_preprocess(o, do) + q_new, k_new, h, dA_local, dv, dg_cumsum = intra_chunk_preprocess_bwd_prepare_fn( + q=q, + k=k, + v=v, + w=w, + beta=beta, + g_cumsum=g_cumsum, + A=A, + L=L, + D=delta, + do=do, + scale=ctx.scale, + cu_seqlens=cu_seqlens, + ) + q_new_large, k_new_large, hc_suffix, hc_prefix, hc_whole = chunk_cumprod_householder_fwd_fn( + q=q_new, k=k_new, h=h, S=S, BT=BS, cu_seqlens=cu_seqlens + ) + dq, dhc_whole, dg_cumsum = parallel_path_bwd_dq_fn( + q=q_new_large, k=k_new_large, v=v, g_cumsum=g_cumsum, do=do, 
dg_cumsum=dg_cumsum, + hc_whole=hc_whole, scale=ctx.scale, L=L, D=delta, + cu_seqlens=cu_seqlens, + S=S, BT=BT, BS=BS + ) + dk, dv, dg_cumsum = parallel_path_bwd_dkv_fn( + q=q_new_large, k=k_new_large, v=v, g_cumsum=g_cumsum, do=do, dv=dv, dg_cumsum=dg_cumsum, + hc_whole=hc_whole, scale=ctx.scale, L=L, D=delta, + cu_seqlens=cu_seqlens, + S=S, BT=BT, BS=BS + ) + dh, dk = chunk_cumprod_householder_bwd_fn( + h=h, hc_suffix=hc_suffix, + k=k_new, dk=dk, dhc_whole=dhc_whole, + cu_seqlens=cu_seqlens, S=S, BT=BS + ) + dq, dk, dv, dh, dg_cumsum = parallel_path_bwd_intra_chunk_fn( + q=q_new, k=k_new, v=v, g_cumsum=g_cumsum, h=h, + L=L, D=delta, scale=ctx.scale, + dq=dq, dk=dk, dv=dv, dh=dh, do=do, dg_cumsum=dg_cumsum, + cu_seqlens=cu_seqlens, + S=S, BT=BT + ) + dq, dk, dbeta, dw = intra_chunk_preprocess_bwd_fn( + q=q, k=k, w=w, beta=beta, + dq=dq, dk=dk, dh=dh, dA_local=dA_local, + A=A, L=L, D=delta, do=do, scale=ctx.scale, cu_seqlens=cu_seqlens + ) + G = q.shape[-2] // k.shape[-2] + if G > 1: + assert dk.dtype == dv.dtype == dw.dtype == dbeta.dtype == torch.float32, 'reduction requires float32' + dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dv = reduce(dv, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dw = reduce(dw, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dbeta = reduce(dbeta, 'b t (h g) -> b t h', g=G, reduction='sum') + if dg_cumsum is not None: + dg_cumsum = chunk_global_cumsum(dg_cumsum, cu_seqlens=cu_seqlens, reverse=True) + return (dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dw.to(w.dtype), + dbeta.to(beta.dtype), + dg_cumsum.to(g_cumsum.dtype) if g_cumsum is not None else None, + None, None, None) + + +@torch.compiler.disable +def parallel_path_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + beta: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, + use_cache: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` + k (torch.Tensor): + keys of shape `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, T, H, V]` + w (torch.Tensor): + weights of shape `[B, T, H, K]` + beta (torch.Tensor): + beta of shape `[B, T, H]` + g (torch.Tensor): + g of shape `[B, T, HQ]` + scale (float): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + use_cache (bool): + Whether to transform and cache the key values for decoding. Default: `False`. + Returns: + o (torch.Tensor): + output of shape `[B, T, HQ, V]` + k_cache (torch.Tensor): + k_cache of shape `[B, T, H, K]` + """ + if scale is None: + scale = k.shape[-1]**-0.5 + assert q.shape[-1] in [16, 32, 64], "only support head_dim in [16, 32, 64] for now. Stay tuned!" + assert v.shape[-1] in [16, 32, 64], "only support head_dim in [16, 32, 64] for now. Stay tuned!" + assert q.shape[-1] == k.shape[-1], 'q, k should have the same head_dim.' + assert k.shape == w.shape, 'k, w should have the same shape.'
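+ # GQA layout (checked below): q carries HQ = G * H query heads, while k/v/w/beta carry H heads; g, when provided, follows q.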
+ assert beta.shape[:3] == k.shape[:3], 'beta should match k in batch size, sequence length and number of heads' + if g is not None: + assert g.shape[:3] == q.shape[:3], 'g should match q in batch size, sequence length and number of heads' + assert q.shape[-2] % k.shape[-2] == 0, 'the number of query heads should be divisible by the number of key heads' + o, k_cache = ParallelPATHAttentionFunction.apply(q, k, v, w, beta, g, scale, cu_seqlens, use_cache) + return o, k_cache
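+# A minimal smoke-test sketch (an illustration, not part of the original kernels): the
+# shapes, the bf16 dtype and the parameterization of w (unit norm) and beta (in (0, 2)),
+# mirroring the usual Householder setup, are assumptions; a CUDA device is required.
+if __name__ == '__main__':
+    import torch.nn.functional as F
+    B, T, H, HQ, K, V = 2, 256, 2, 4, 64, 64
+    q = torch.randn(B, T, HQ, K, device='cuda', dtype=torch.bfloat16)
+    k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
+    v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
+    # unit-norm Householder directions and beta in (0, 2)
+    w = F.normalize(torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16), p=2, dim=-1)
+    beta = 2 * torch.sigmoid(torch.randn(B, T, H, device='cuda', dtype=torch.bfloat16))
+    o, k_cache = parallel_path_attention(q=q, k=k, v=v, w=w, beta=beta, use_cache=True)
+    assert o.shape == (B, T, HQ, V) and k_cache.shape == (B, T, H, K)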
diff --git a/fla3/ops/path_attn/parallel_path_bwd_inter_dkv.py b/fla3/ops/path_attn/parallel_path_bwd_inter_dkv.py new file mode 100644 index 0000000000000000000000000000000000000000..13fc6c01c688cd088a2eaef4a664ed80068d2b12 --- /dev/null +++ b/fla3/ops/path_attn/parallel_path_bwd_inter_dkv.py @@ -0,0 +1,190 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_GATE': lambda args: args['g_cumsum'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def parallel_path_bwd_dkv_kernel( + q, k, v, g_cumsum, + hc_whole, scale, L, D, + dk, dv, do, dg_cumsum, + cu_seqlens, indices, split_offsets, + T, + G: tl.constexpr, HQ: tl.constexpr, H: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, + BV: tl.constexpr, S: tl.constexpr, + IS_VARLEN: tl.constexpr, USE_GATE: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + boh_large = tl.load(split_offsets + i_n).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + boh_large = i_n * tl.cdiv(T, S) + + # offset calculations + q += (bos * HQ + i_hq) * K + do += (bos * HQ + i_hq) * V + dk += (bos * HQ + i_hq) * K + dv += (bos * HQ + i_hq) * V + L += (bos * HQ + i_hq) + D += (bos * HQ + i_hq) + + k += (bos * H + i_h) * K # GQA when H!=HQ + v += (bos * H + i_h) * V # GQA when H!=HQ + hc_whole += (boh_large * H + i_h) * K * K + + if USE_GATE: + g_cumsum += (bos * HQ + i_hq) + dg_cumsum += (bos * HQ + i_hq) + + # constants + stride_h = H * K * K + sm_scale = scale * 1.44269504 + + # load key/value blocks for this chunk + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_k_origin = tl.load(p_k, boundary_check=(0, 1)) + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + if USE_GATE: + b_g_cumsum_k = tl.zeros([BT,], dtype=tl.float32) + p_g_cumsum_k = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_g_cumsum_k += tl.load(p_g_cumsum_k, boundary_check=(0, )) + b_dg_cumsum_k = tl.zeros([BT,], dtype=tl.float32) + else: + b_g_cumsum_k = None + b_dg_cumsum_k = None + + b_dk = tl.zeros([BT, K], dtype=tl.float32) + b_dv = tl.zeros([BT, V], dtype=tl.float32) + idx_i = (i_t * BT // S).to(tl.int32) + + last_chunk_start = tl.floor(T/S).to(tl.int32) * S + + if i_t * BT < last_chunk_start: + # handle the trailing partial split (when T % S != 0) first + if T % S != 0: + idx_j = (last_chunk_start // S) + b_k_accum = tl.zeros([BT, BK], dtype=tl.float32) + b_k_accum += b_k_origin + for i in range(idx_i+1, idx_j): + p_h = tl.make_block_ptr(hc_whole + i * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_k_accum = (b_k_accum - tl.dot(b_k_accum.to(b_h.dtype), tl.trans(b_h))) + b_k = b_k_accum.to(b_k_origin.dtype) + + for offset in range(tl.ceil(T/BS).to(tl.int32) * BS - BS, last_chunk_start-BS, -BS): + p_delta = tl.make_block_ptr(D, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + p_l = tl.make_block_ptr(L, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + b_delta = tl.load(p_delta, boundary_check=(0, )) + b_l = tl.load(p_l, boundary_check=(0, )) + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (offset, 0), (BS, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_A = tl.dot(b_q, tl.trans(b_k)) + if USE_GATE: + p_g_cumsum_q = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + b_g_cumsum_q = tl.load(p_g_cumsum_q, boundary_check=(0, )) + b_A = b_A + b_g_cumsum_q[:, None] - b_g_cumsum_k[None, :] + b_A = tl.where((offset + tl.arange(0, BS) < T)[:, None], b_A, float("-inf")) # avoid nan + b_A_softmax = tl.math.exp2(b_A * sm_scale - b_l[:, None]) + p_do = tl.make_block_ptr(do, (T, V), (HQ*V, 1), (offset, 0), (BS, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(tl.trans(b_A_softmax.to(b_do.dtype)), b_do) + b_dp = tl.dot(b_do, tl.trans(b_v)) + b_dA = ((b_dp - b_delta[:, None]) * b_A_softmax * scale) + if USE_GATE: + b_dg_cumsum_k -= tl.sum(b_dA, axis=0) + b_dA = b_dA.to(b_v.dtype) + b_dk += tl.dot(tl.trans(b_dA), b_q) + + for offset_outer in range(last_chunk_start, i_t * BT + S, -S): + idx_j = (offset_outer // S) - 1 + b_k_accum = tl.zeros([BT, BK], dtype=tl.float32) + b_k_accum += b_k_origin + for i in range(idx_i+1, idx_j): + p_h = tl.make_block_ptr(hc_whole + i * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_k_accum = (b_k_accum - tl.dot(b_k_accum.to(b_h.dtype), tl.trans(b_h))) + b_k = b_k_accum.to(b_k_origin.dtype) + + p_h = tl.make_block_ptr(hc_whole + (idx_j) * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dk = b_dk - tl.dot(b_dk.to(b_h.dtype), b_h) + + for offset in range(offset_outer - BS, offset_outer-S-BS, -BS): + p_delta = tl.make_block_ptr(D, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + p_l = tl.make_block_ptr(L, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + b_delta = tl.load(p_delta, boundary_check=(0, )) + b_l = tl.load(p_l, boundary_check=(0, )) + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (offset, 0), (BS, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_A = tl.dot(b_q, tl.trans(b_k)) + if USE_GATE: + p_g_cumsum_q = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + b_g_cumsum_q = tl.load(p_g_cumsum_q, boundary_check=(0, )) + b_A = b_A + b_g_cumsum_q[:, None] - b_g_cumsum_k[None, :] + b_A = tl.where((offset + tl.arange(0, BS) < T)[:, None], b_A, float("-inf")) # avoid nan + b_A_softmax = tl.math.exp2(b_A * sm_scale - b_l[:, None]) + p_do = tl.make_block_ptr(do, (T, V), (HQ*V, 1), (offset, 0), (BS, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(tl.trans(b_A_softmax.to(b_do.dtype)), b_do) + b_dp = tl.dot(b_do, tl.trans(b_v)) + + b_dA = ((b_dp - b_delta[:, None]) * b_A_softmax * scale) + if USE_GATE: + b_dg_cumsum_k -= tl.sum(b_dA, axis=0) + b_dA = b_dA.to(b_v.dtype) + b_dk += tl.dot(tl.trans(b_dA), b_q) + + p_dk = tl.make_block_ptr(dk, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1)) + tl.atomic_add(dv + (i_t * BT + tl.arange(0, BT))[:, None] * HQ*V + tl.arange(0, V)[None, :], b_dv, sem='relaxed') +
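+ # like dv above, dg_cumsum already carries contributions from earlier backward stages,
+ # so the gate gradient is accumulated in place with a relaxed atomic add rather than stored.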
if USE_GATE: + tl.atomic_add(dg_cumsum + (i_t * BT + tl.arange(0, BT)) * HQ, b_dg_cumsum_k, sem='relaxed') + + +def parallel_path_bwd_dkv_fn( + q, k, v, g_cumsum, do, dv, dg_cumsum, + hc_whole, scale, L, D, + cu_seqlens, + S, BT, BS +): + B, T, HQ, K = q.shape + V = v.shape[-1] + H = k.shape[-2] + G = HQ // H + + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + split_offsets = prepare_chunk_offsets(cu_seqlens, S) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + # should be NS + if cu_seqlens is not None: + assert split_offsets[-1] == hc_whole.shape[0] + dk = torch.empty_like(q, dtype=torch.float32) # for later reduction use + + parallel_path_bwd_dkv_kernel[(NT, B*HQ)]( + q=q, k=k, v=v, g_cumsum=g_cumsum, + hc_whole=hc_whole, scale=scale, L=L, D=D, + dk=dk, dv=dv, do=do, dg_cumsum=dg_cumsum, + cu_seqlens=cu_seqlens, indices=indices, split_offsets=split_offsets, + T=T, S=S, BT=BT, BS=BS, + G=G, HQ=HQ, H=H, K=K, V=V, + BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V), + ) + return dk, dv, dg_cumsum diff --git a/fla3/ops/path_attn/parallel_path_bwd_inter_dqh.py b/fla3/ops/path_attn/parallel_path_bwd_inter_dqh.py new file mode 100644 index 0000000000000000000000000000000000000000..a3aff629b6c2b68aa6e38d6926fbf4ae6f90379f --- /dev/null +++ b/fla3/ops/path_attn/parallel_path_bwd_inter_dqh.py @@ -0,0 +1,165 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_GATE': lambda args: args['g_cumsum'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def parallel_path_bwd_dq_kernel( + q, k, v, g_cumsum, + hc_whole, scale, L, D, + dq, do, dhc_whole, dg_cumsum, + cu_seqlens, indices, split_offsets, # varlen specific + T, + G: tl.constexpr, HQ: tl.constexpr, H: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, + BV: tl.constexpr, + S: tl.constexpr, # aka larger chunk size + IS_VARLEN: tl.constexpr, + USE_GATE: tl.constexpr, +): + i_t, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + boh_large = tl.load(split_offsets + i_n).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + boh_large = i_n * tl.cdiv(T, S) + + # offset calculations + q += (bos * HQ + i_hq) * K + dq += (bos * HQ + i_hq) * K + k += (bos * H + i_h) * K # GQA when H!=HQ + v += (bos * H + i_h) * V # GQA when H!=HQ + do += (bos * HQ + i_hq) * V + hc_whole += (boh_large * H + i_h) * K * K + dhc_whole += (boh_large * HQ + i_hq) * K * K + L += (bos * HQ + i_hq) + D += (bos * HQ + i_hq) + if USE_GATE: + g_cumsum += (bos * HQ + i_hq) + dg_cumsum += (bos * HQ + i_hq) + + # if i_t * BT < S: + # p_dq = tl.make_block_ptr(dq, (T, K), (K * HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + # tl.store(p_dq, tl.zeros([BT, BK], dtype=tl.float32).to(dq.dtype.element_ty), boundary_check=(0, 1)) + # if USE_GATE: + # p_dg = tl.make_block_ptr(dg_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + # tl.store(p_dg, tl.zeros([BT,], dtype=tl.float32).to(p_dg.dtype.element_ty), boundary_check=(0, )) + # return + + # constants + stride_h = H * K * K + stride_hq = HQ 
* K * K + sm_scale = scale * 1.44269504 + + # load query + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_q_origin = tl.load(p_q, boundary_check=(0, 1)) + p_do = tl.make_block_ptr(do, (T, V), (HQ*V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + p_l = tl.make_block_ptr(L, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + p_d = tl.make_block_ptr(D, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_l = tl.load(p_l, boundary_check=(0, )) + b_delta = tl.load(p_d, boundary_check=(0, )) + + if USE_GATE: + b_g_cumsum_q = tl.zeros([BT,], dtype=tl.float32) + p_g_cumsum_q = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_g_cumsum_q += tl.load(p_g_cumsum_q, boundary_check=(0, )) + b_dg_cumsum_q = tl.zeros([BT,], dtype=tl.float32) + else: + b_g_cumsum_q = None + b_dg_cumsum_q = None + + idx_i = i_t * BT // S + curr_end = (tl.floor(i_t * BT / S).to(tl.int32) * S).to(tl.int32) + b_dq = tl.zeros([BT, K], dtype=tl.float32) + + for offset_outer in range(0, curr_end, S): + idx_j = offset_outer // S + b_q_accum = tl.zeros([BT, BK], dtype=tl.float32) + b_q_accum += b_q_origin + for i in range(idx_i-1, idx_j, -1): + p_h = tl.make_block_ptr(hc_whole + i*stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_q_accum = (b_q_accum - tl.dot(b_q_accum.to(b_h.dtype), b_h)) + b_q = b_q_accum.to(b_q_origin.dtype) + b_dh = -tl.dot(tl.trans(b_q), b_dq.to(b_q.dtype)) + + tl.atomic_add(dhc_whole + idx_j * stride_hq + tl.arange(0, K) + [:, None] * K + tl.arange(0, K)[None, :], b_dh, sem='relaxed') + p_h = tl.make_block_ptr(hc_whole + idx_j * stride_h, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dq = b_dq - tl.dot(b_dq.to(b_h.dtype), tl.trans(b_h)) + + for offset in range(offset_outer, min(offset_outer+S, i_t*BT), BS): + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (offset, 0), (BS, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A = tl.dot(b_q, tl.trans(b_k)) + if USE_GATE: + p_g_cumsum_k = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (offset, ), (BS, ), (0, )) + b_g_cumsum_k = tl.load(p_g_cumsum_k, boundary_check=(0, )) + b_A = b_A + b_g_cumsum_q[:, None] - b_g_cumsum_k[None, :] + b_A = tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], b_A, float("-inf")) # avoid nan + b_A_softmax = tl.math.exp2(b_A * sm_scale - b_l[:, None]) + p_v = tl.make_block_ptr(v, (V, T), (1, V*H), (0, offset), (BV, BS), (0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dp = tl.dot(b_do, b_v) + b_dA = ((b_dp - b_delta[:, None]) * b_A_softmax * scale) + b_dq += tl.dot(b_dA.to(b_k.dtype), b_k) + if USE_GATE: + b_dg_cumsum_q += tl.sum(b_dA, axis=1) + + p_dq = tl.make_block_ptr(dq, (T, K), (K * HQ, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(dq.dtype.element_ty), boundary_check=(0, 1)) + if USE_GATE: + tl.atomic_add(dg_cumsum + (i_t * BT + tl.arange(0, BT)) * HQ, b_dg_cumsum_q, sem='relaxed') + + +def parallel_path_bwd_dq_fn( + q, k, v, g_cumsum, do, dg_cumsum, + hc_whole, scale, L, D, + cu_seqlens, + S, BT, BS +): + B, T, HQ, K = q.shape + V = v.shape[-1] + H = k.shape[-2] + G = HQ // H + + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + split_offsets = prepare_chunk_offsets(cu_seqlens, S) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + + # should be NS + if cu_seqlens is not None: + assert split_offsets[-1] == hc_whole.shape[0] + dq =
torch.empty_like(q, dtype=torch.float32) # for later reduction use + + # [NS, HQ, K, K] instead of [NS, H, K, K] + # atomic add must be initialized to 0 + dhc_whole = torch.zeros(hc_whole.shape[0], HQ, K, K, dtype=torch.float32, device=q.device) + + parallel_path_bwd_dq_kernel[(NT, B*HQ)]( + q=q, k=k, v=v, g_cumsum=g_cumsum, + hc_whole=hc_whole, scale=scale, L=L, D=D, + dq=dq, do=do, dhc_whole=dhc_whole, dg_cumsum=dg_cumsum, + cu_seqlens=cu_seqlens, indices=indices, split_offsets=split_offsets, + T=T, S=S, BT=BT, BS=BS, + G=G, HQ=HQ, H=H, K=K, V=V, + BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V), + ) + return dq, dhc_whole, dg_cumsum diff --git a/fla3/ops/path_attn/parallel_path_bwd_intra.py b/fla3/ops/path_attn/parallel_path_bwd_intra.py new file mode 100644 index 0000000000000000000000000000000000000000..3f44bde9bad629342642e8c66335b7e721c03059 --- /dev/null +++ b/fla3/ops/path_attn/parallel_path_bwd_intra.py @@ -0,0 +1,163 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['offsets'] is not None, + 'USE_GATE': lambda args: args['g_cumsum'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def parallel_path_bwd_intra_chunk_kernel( + q, k, v, g_cumsum, + h, L, D, + dq, dq_new, dk, dv, dh, do, dg_cumsum, + offsets, indices, chunk_offsets, + T, scale, + G: tl.constexpr, HQ: tl.constexpr, H: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + BT: tl.constexpr, S: tl.constexpr, + IS_VARLEN: tl.constexpr, USE_GATE: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + boh = i_n * tl.cdiv(T, BT) + + # offset calculations + k += (bos * H + i_h) * K # GQA when H!=HQ + v += (bos * H + i_h) * V # GQA when H!=HQ + h += (boh * H + i_h) * K * K + + q += (bos * HQ + i_hq) * K + dq += (bos * HQ + i_hq) * K + dq_new += (bos * HQ + i_hq) * K + dk += (bos * HQ + i_hq) * K + dv += (bos * HQ + i_hq) * V + do += (bos * HQ + i_hq) * V + dh += (boh * HQ + i_hq) * K * K + L += (bos * HQ + i_hq) + D += (bos * HQ + i_hq) + if USE_GATE: + g_cumsum += (bos * HQ + i_hq) + dg_cumsum += (bos * HQ + i_hq) + + # constants + sm_scale = scale * 1.44269504 + + p_do = tl.make_block_ptr(do, (T, V), (HQ*V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + p_delta = tl.make_block_ptr(D, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_delta = tl.load(p_delta, boundary_check=(0, )) + p_l = tl.make_block_ptr(L, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_l = tl.load(p_l, boundary_check=(0, )) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + p_dq = tl.make_block_ptr(dq, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + p_q = tl.make_block_ptr(q, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + + if USE_GATE: + p_gq_cumsum = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (i_t * BT, ), (BT, ), (0, )) + b_gq_cumsum = tl.load(p_gq_cumsum, boundary_check=(0, )) + b_dgq = tl.zeros([BT, ], dtype=tl.float32) 
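+ # b_dgq accumulates the query-side gate gradient for this block across the inner loops;
+ # it is flushed with a single relaxed atomic add after the loop.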
+ else: + b_dgq = None + + curr_start = (tl.floor(i_t * BT / S).to(tl.int32) * S).to(tl.int32) + + for offset in range(curr_start, i_t * BT, BT): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (offset, 0), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q_tmp = tl.zeros([BT, BK], dtype=tl.float32) + b_q_tmp += b_q + + for i_t_small in range(i_t * BT - BT, offset, -BT): + p_h = tl.make_block_ptr(h + tl.cdiv(i_t_small, BT) * H*K*K, (K, K), (K, 1), (0, 0), (BK, BK), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_q_tmp -= tl.dot(b_q_tmp.to(b_h.dtype), b_h) + + b_q2 = b_q_tmp.to(b_k.dtype) + b_dh = -tl.dot(tl.trans(b_q2), b_dq.to(b_q2.dtype)) + tl.atomic_add(dh + tl.cdiv(offset, BT) * HQ*K*K + tl.arange(0, K) + [:, None] * K + tl.arange(0, K)[None, :], b_dh, sem='relaxed') + + b_A = tl.dot(b_q2, tl.trans(b_k)) + if USE_GATE: + p_gk_cumsum = tl.make_block_ptr(g_cumsum, (T, ), (HQ, ), (offset, ), (BT, ), (0, )) + b_gk_cumsum = tl.load(p_gk_cumsum, boundary_check=(0, )) + b_A = b_A + b_gq_cumsum[:, None] - b_gk_cumsum[None, :] + b_A = tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], b_A, float("-inf")) # avoid nan + + b_A_softmax = tl.math.exp2(b_A * sm_scale - b_l[:, None]) + b_dv = tl.dot(tl.trans(b_A_softmax.to(b_do.dtype)), b_do) + + tl.atomic_add(dv + ((offset + tl.arange(0, BT)) * HQ * V) + [:, None] + tl.arange(0, BV)[None, :], b_dv.to(dv.dtype.element_ty), sem='relaxed') + + p_v = tl.make_block_ptr(v, (T, V), (V*H, 1), (offset, 0), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dp = tl.dot(b_do, tl.trans(b_v)) + b_dA = ((b_dp - b_delta[:, None]) * b_A_softmax * scale) + if USE_GATE: + b_dgk = -tl.sum(b_dA, axis=0) + tl.atomic_add(dg_cumsum + (offset + tl.arange(0, BT)) * HQ, b_dgk, sem='relaxed') + b_dgq += tl.sum(b_dA, axis=1) + + b_dA = b_dA.to(b_k.dtype) + p_h = tl.make_block_ptr(h + tl.cdiv(offset, BT) * H*K*K, (K, K), (1, K), (0, 0), (BK, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dk = tl.dot(tl.trans(b_dA), b_q2) + tl.atomic_add(dk + (offset + tl.arange(0, BT))[:, None] * HQ*K + tl.arange(0, + BK)[None, :], b_dk, sem='relaxed') + b_dq -= tl.dot(b_dq.to(b_h.dtype), b_h) + b_dq += tl.dot(b_dA, b_k) + + p_dq_new = tl.make_block_ptr(dq_new, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_dq_new, b_dq.to(dq_new.dtype.element_ty), boundary_check=(0, 1)) + if USE_GATE: + tl.atomic_add(dg_cumsum + (i_t * BT + tl.arange(0, BT)) * HQ, b_dgq, sem='relaxed') + + +def parallel_path_bwd_intra_chunk_fn( + q, k, v, g_cumsum, h, + dq, dk, dv, dg_cumsum, dh, do, + scale, L, D, + cu_seqlens, + S, BT +): + assert dk.dtype == dv.dtype == dh.dtype == torch.float32, 'atomic_add requires float32' + B, T, HQ, K = q.shape + assert dk.shape == dq.shape + + V = v.shape[-1] + H = k.shape[-2] + G = HQ // H + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + dq_new = torch.empty_like(dq, dtype=q.dtype) + parallel_path_bwd_intra_chunk_kernel[(NT, B*HQ)]( + q=q, k=k, v=v, g_cumsum=g_cumsum, + h=h, L=L, D=D, + dq=dq, dq_new=dq_new, dk=dk, dv=dv, dh=dh, do=do, dg_cumsum=dg_cumsum, + offsets=cu_seqlens, indices=indices, chunk_offsets=chunk_offsets, + T=T, S=S, BT=BT, scale=scale, + G=G, HQ=HQ, H=H, K=K, V=V, + BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V), + ) + return dq_new, dk, dv, dh, dg_cumsum diff --git 
a/fla3/ops/path_attn/prepare_k_cache.py b/fla3/ops/path_attn/prepare_k_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9dad21eff84c114eb38788a59fa0eb18368935 --- /dev/null +++ b/fla3/ops/path_attn/prepare_k_cache.py @@ -0,0 +1,76 @@ +import torch +import triton +import triton.language as tl + +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['offsets'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def parallel_path_fwd_kernel_prepare_k_cache( + k, k_new, h, + offsets, indices, chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, BK: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # offset calculations + k += (bos * H + i_h) * K # GQA when H!=HQ + k_new += (bos * H + i_h) * K # GQA when H!=HQ + h += (boh * H + i_h) * K * K + # constants + stride_h = H * K * K + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_k = tl.zeros([BT, BK], dtype=tl.float32) + b_k += tl.load(p_k, boundary_check=(0, 1)) + for k_block_idx in range(i_t + 1, tl.cdiv(T, BT)): + p_h = tl.make_block_ptr(h + k_block_idx * stride_h, (K, K), (1, K), (0, 0), (BK, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_k_minus = tl.dot(b_k.to(b_h.dtype), b_h) + b_k = b_k - b_k_minus + p_k_new = tl.make_block_ptr(k_new, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + + +def prepare_k_cache_fn(k, h, cu_seqlens, BS, use_cache=False): + if not use_cache: + return None + else: + B, T, H, K = k.shape + k_new = torch.empty_like(k) + indices = prepare_chunk_indices(cu_seqlens, BS) if cu_seqlens is not None else None + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BS) if cu_seqlens is not None else None + NT = triton.cdiv(T, BS) if cu_seqlens is None else len(indices) + grid = (NT, B * H) + parallel_path_fwd_kernel_prepare_k_cache[grid]( + k=k, + k_new=k_new, + h=h, + offsets=cu_seqlens, + indices=indices, + chunk_offsets=chunk_offsets, + H=H, + T=T, + K=K, + BT=BS, + BK=triton.next_power_of_2(K) + ) + return k_new diff --git a/fla3/ops/rebased/__init__.py b/fla3/ops/rebased/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec6a0cb31f7f635aa528cad753d5e19196a2028 --- /dev/null +++ b/fla3/ops/rebased/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_rebased + +__all__ = [ + 'parallel_rebased' +] diff --git a/fla3/ops/rebased/__pycache__/__init__.cpython-310.pyc b/fla3/ops/rebased/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1749251d04bf4f188f8de568bbf90b1e275415e8 Binary files /dev/null and b/fla3/ops/rebased/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/rebased/__pycache__/parallel.cpython-310.pyc b/fla3/ops/rebased/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e998a0844234cde9d931c10385c7274510246f93 Binary files
/dev/null and b/fla3/ops/rebased/__pycache__/parallel.cpython-310.pyc differ diff --git a/fla3/ops/rebased/naive.py b/fla3/ops/rebased/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a70242eb3c2dcc5918503af8b03a15b5740e4c2a --- /dev/null +++ b/fla3/ops/rebased/naive.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch + + +def naive_parallel_rebased( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, +) -> torch.Tensor: + if scale is None: + scale = q.shape[-1] ** -0.5 + q = q * scale + attn = q @ k.transpose(-2, -1) + attn = attn ** 2 + attn.masked_fill_(~torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o diff --git a/fla3/ops/rebased/parallel.py b/fla3/ops/rebased/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..a1660c92a24f86751b90ee235e5927dddbabe223 --- /dev/null +++ b/fla3/ops/rebased/parallel.py @@ -0,0 +1,466 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + +# Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models +# https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py + + +@triton.jit(do_not_specialize=['T']) +def parallel_rebased_fwd_kernel( + q, + k, + v, + o, + z, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v*BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c*BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, i_c*BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. 
masks required + for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c*BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c*BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit(do_not_specialize=['T']) +def _parallel_rebased_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + i_h, + q, + k, + v, + do, + dz, + dq, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k*BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c*BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c*BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c*BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, i_c*BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. 
masks required + for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit(do_not_specialize=['T']) +def _parallel_rebased_bwd_dkv( + i_bh, + i_c, + i_k, + i_v, + i_h, + q, + k, + v, + do, + dz, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + # [BK, BTS] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BV, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BTL, BTS] + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale + b_s2 = b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) 
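+ # dk is staged into one scratch copy per V-split (offset by i_v) and dv into
+ # one per K-split (offset by i_k), so concurrent (i_k, i_v) programs never
+ # write to the same rows; the host-side backward() then reduces the extra
+ # leading dimension with .sum(0).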
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit(do_not_specialize=['T']) +def parallel_rebased_bwd_kernel( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_rebased_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + i_h, + q, + k, + v, + do, + dz, + dq, + scale, + B=B, + H=H, + T=T, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV + ) + tl.debug_barrier() + _parallel_rebased_bwd_dkv( + i_bh, + i_c, + i_k, + i_v, + i_h, + q, + k, + v, + do, + dz, + dk, + dv, + scale, + B=B, + H=H, + T=T, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not." + + o = torch.empty(NK, B, H, T, V, device=q.device) + z = torch.empty(NK, B, H, T, device=q.device) + parallel_rebased_fwd_kernel[grid]( + q, + k, + v, + o, + z, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + + parallel_rebased_bwd_kernel[grid]( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +def parallel_rebased( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + eps: float = 1e-5, + use_scale: bool = True, + use_normalize: bool = True, + return_both: bool = False, + head_first: bool = False +): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 
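+ # Since the kernel squares the scores, scaling q by K**-0.5 here rescales
+ # the final (q k^T)^2 weights by 1/K overall; with use_scale=False the
+ # squared scores are left unscaled.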
+ if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, z = ParallelBasedFunction.apply(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + eps) + if not head_first: + o = o.transpose(1, 2) + return o.to(q.dtype) diff --git a/fla3/ops/retention/__init__.py b/fla3/ops/retention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a38ab43c9982c9751bb9db146b9d9fe05663964a --- /dev/null +++ b/fla3/ops/retention/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_retention +from .fused_chunk import fused_chunk_retention +from .fused_recurrent import fused_recurrent_retention +from .parallel import parallel_retention + +__all__ = [ + 'chunk_retention', + 'fused_chunk_retention', + 'parallel_retention', + 'fused_recurrent_retention' +] diff --git a/fla3/ops/retention/__pycache__/__init__.cpython-310.pyc b/fla3/ops/retention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ebe60dd1d508b87d4412448c360a8aca2dcdaac Binary files /dev/null and b/fla3/ops/retention/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/retention/__pycache__/__init__.cpython-312.pyc b/fla3/ops/retention/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a4ed9f877894a3b9f80535fa73fc75b8b0a4f4 Binary files /dev/null and b/fla3/ops/retention/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/retention/__pycache__/chunk.cpython-310.pyc b/fla3/ops/retention/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..672c76e023c8a21cabb3dae8e195210c7bc27fcb Binary files /dev/null and b/fla3/ops/retention/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/retention/__pycache__/chunk.cpython-312.pyc b/fla3/ops/retention/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46c525b963d8125f55d10877ed5b83b2869c38ac Binary files /dev/null and b/fla3/ops/retention/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/retention/__pycache__/fused_chunk.cpython-312.pyc b/fla3/ops/retention/__pycache__/fused_chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f99bd09b49f9af70dfbc483073e723cf53c33a3f Binary files /dev/null and b/fla3/ops/retention/__pycache__/fused_chunk.cpython-312.pyc differ diff --git a/fla3/ops/retention/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/retention/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65802967cac332cf1aefb49f98da1e26d9a7cabd Binary files /dev/null and b/fla3/ops/retention/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/retention/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/retention/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a406c9e3e3b947fbefae2498442d9cabba1bf5ba Binary files /dev/null and b/fla3/ops/retention/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/retention/__pycache__/parallel.cpython-310.pyc b/fla3/ops/retention/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..085bfbca742a7078222a254116c62b9a65647043 Binary files /dev/null and b/fla3/ops/retention/__pycache__/parallel.cpython-310.pyc differ diff --git 
a/fla3/ops/retention/chunk.py b/fla3/ops/retention/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5d771e61356e82803510e6b42a64b1c88de5b5 --- /dev/null +++ b/fla3/ops/retention/chunk.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from fla.ops.simple_gla.chunk import chunk_simple_gla + + +@torch.compiler.disable +def chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + """ + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + o, final_state = chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens + ) + if head_first: + o = rearrange(o, 'b t h ... 
-> b h t ...') + return o, final_state diff --git a/fla3/ops/retention/fused_chunk.py b/fla3/ops/retention/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb56bbe8e6be1c94ed6423c2139dea44df50e6e --- /dev/null +++ b/fla3/ops/retention/fused_chunk.py @@ -0,0 +1,363 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from packaging import version + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_retention_fwd_kernel( + q, + k, + v, + o, + h0, + ht, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0)) + + # d_b: overall decay for the entire chunk + # d_o: cumulative decay from the start of the chunk + # d_h: cumulative decay from the end of the chunk + d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_k*B*H+i_bh).to(tl.int64) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + NT = tl.cdiv(T, BT) + for i in range(0, NT): + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + if i == NT - 1 and (T % BT) != 0: + d_b = tl.math.exp2((T % BT) * b_b) + d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b) + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_retention_bwd_kernel( + q, + k, 
+ v, + do, + dq, + dk, + dv, + h0, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0)) + d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b) + d_b = tl.math.exp2(BT * b_b) + + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + # [BT, K] + b_dq = tl.dot(b_ds, b_k, allow_tf32=False) + # [V, K] + if CHECK and i == 0: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + d_s = tl.trans(d_s) + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H).to(tl.int64) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H).to(tl.int64) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [K, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s + # [BT, BK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += 
tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] + b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None] + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None] + b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkRetentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, H, T, K, V = *k.shape, v.shape[-1] + + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + o = q.new_empty(NK, B, H, T, V) + + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False) + else: + final_state = None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_retention_fwd_kernel[grid]( + q, + k, + v, + o, + initial_state, + final_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = K ** -0.5 + + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_retention_bwd_kernel[grid]( + q, + k, + v, + do, + dq, + dk, + dv, + initial_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def fused_chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, 
K]` + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if not head_first: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + o, final_state = FusedChunkRetentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla3/ops/retention/fused_recurrent.py b/fla3/ops/retention/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..b470814b31691e252d50a39b7cdc8d23fc835065 --- /dev/null +++ b/fla3/ops/retention/fused_recurrent.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla + + +def fused_recurrent_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + o, final_state = fused_recurrent_simple_gla( + q=q, + k=k, + v=v, + g=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + return o, final_state diff --git a/fla3/ops/retention/naive.py b/fla3/ops/retention/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..15611bf649779d2d956d2ab390b7d72dbb12201d --- /dev/null +++ b/fla3/ops/retention/naive.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +import torch + + +def naive_retention(q, k, v): + orig_type = q.dtype + q, k, v = q.float(), k.float(), v.float() + _, n_heads, seq_len, d_head = q.shape + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. 
- q.new_tensor(range(n_heads), dtype=torch.float))).log2() + n = q.new_tensor(range(seq_len), dtype=torch.float) + n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) + s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) + o = torch.einsum('bhqk,bhkd->bhqd', s, v) + return o.to(orig_type) diff --git a/fla3/ops/retention/parallel.py b/fla3/ops/retention/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..8e83339386de49aac5b02fab932487fb9f3eaf67 --- /dev/null +++ b/fla3/ops/retention/parallel.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from fla.ops.simple_gla.parallel import parallel_simple_gla + + +def parallel_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + output_attentions: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + output_attentions (bool): + Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + attn (torch.Tensor): + Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None` + """ + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(q.shape[2]), dtype=torch.float))).log() + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]) + + o, attn = parallel_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=g, + output_attentions=output_attentions, + cu_seqlens=cu_seqlens + ) + if head_first: + o = rearrange(o, 'b t h ... 
-> b h t ...') + return o, attn diff --git a/fla3/ops/rwkv4/__init__.py b/fla3/ops/rwkv4/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49de2cf83aeec67069b67e0972cfccef8a81383a --- /dev/null +++ b/fla3/ops/rwkv4/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .fused_recurrent import fused_recurrent_rwkv4 + +__all__ = [ + 'fused_recurrent_rwkv4' +] diff --git a/fla3/ops/rwkv4/fused_recurrent.py b/fla3/ops/rwkv4/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..670427619d1eb455d8e8b50cf0bb446df841cd8e --- /dev/null +++ b/fla3/ops/rwkv4/fused_recurrent.py @@ -0,0 +1,472 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Any, cast + +import torch +import triton +import triton.language as tl +from torch import Tensor +from torch.autograd.function import Function, FunctionCtx, once_differentiable + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +def get_block_size_c(chans: int) -> int: + if chans < 32: + return 32 + if chans < 64: + return 64 + return 128 + + +@triton.jit +def fused_recurrent_rwkv4_forward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_c, + # WKV + wkv_ptr, + wkv_s_b, + wkv_s_t, + wkv_s_c, + # Output state + state_out_ptr, + state_out_s_b, + state_out_s_abe, + state_out_s_t, + state_out_s_c, + # Params + chans, + tsz, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. + wkv_ptr = wkv_ptr + b_idx * wkv_s_b + alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b + beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe + eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe + + # Loads parameters. 
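+ # The running state is (alpha, beta, eps): alpha and beta are the numerator
+ # and denominator of the WKV weighted average, both stored relative to a
+ # shared max-exponent eps, so the true sums are exp(eps) * alpha and
+ # exp(eps) * beta. Subtracting tau = max(u + k_t, eps) in the loop below
+ # keeps every exp() argument non-positive and thus overflow-safe.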
+ alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32) + + for t in range(tsz): + kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps) + e1a = exp(eps - tau) + e2a = exp(ukt - tau) + wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a) + tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask) + + w_eps = w + eps + eps = tl.maximum(w_eps, kt) + e1b = exp(w_eps - eps) + e2b = exp(kt - eps) + alpha = e1b * alpha + e2b * vt + beta = e1b * beta + e2b + tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask) + tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask) + tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask) + + +def fused_recurrent_rwkv4_forward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, +) -> tuple[Tensor, Tensor]: + (bsz, tsz, chans) = k.shape + + # New tensors to output. + wkvs = k.new_empty(bsz, tsz, chans) + state_out = k.new_empty(bsz, 3, tsz, chans) + + # Constants. + block_size_c = get_block_size_c(chans) + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_forward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(3), + # WKV + wkvs, + wkvs.stride(0), + wkvs.stride(1), + wkvs.stride(2), + # Output state + state_out, + state_out.stride(0), + state_out.stride(1), + state_out.stride(2), + state_out.stride(3), + # Params + chans, + tsz, + BLOCK_SIZE_C=block_size_c, + ) + + state_out = torch.cat((state, state_out), dim=2) + + return wkvs, state_out + + +@triton.jit +def fused_recurrent_rwkv4_backward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_t, + state_s_c, + # WKV grad + gwkv_ptr, + gwkv_s_b, + gwkv_s_t, + gwkv_s_c, + # Output state grad + gstate_out_ptr, + gstate_out_s_b, + gstate_out_s_abe, + gstate_out_s_c, + # W grad + gw_ptr, + gw_s_b, + gw_s_c, + # U grad + gu_ptr, + gu_s_b, + gu_s_c, + # K grad + gk_ptr, + gk_s_b, + gk_s_t, + gk_s_c, + # V grad + gv_ptr, + gv_s_b, + gv_s_t, + gv_s_c, + # State grad + gstate_ptr, + gstate_s_b, + gstate_s_abe, + gstate_s_c, + # Params + tsz, + chans, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. 
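+ # The backward pass replays the recurrence in reverse time order, re-reading
+ # the per-step (alpha, beta, eps) states saved by the forward kernel; gw and
+ # gu are accumulated in registers across the whole sequence and written out
+ # once per (batch, channel-block) at the end.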
+ gk_ptr = gk_ptr + b_idx * gk_s_b + gv_ptr = gv_ptr + b_idx * gv_s_b + + # Pointers to gradients which were recieved by the function. + gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b + galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe + geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe + + # Loads parameters. + galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32) + + # Gradient accumulators. + gw = tl.zeros_like(w) + gu = tl.zeros_like(u) + + alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + for t in range(tsz): + tc = tsz - t - 1 + + kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32) + + alpha_curr = alpha_prev + beta_curr = beta_prev + eps_curr = eps_prev + + alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps_prev) + e1 = exp(eps_prev - tau) + e2 = exp(ukt - tau) + + euke = exp(ukt + eps_prev - 2 * tau) + + denom = e1 * beta_prev + e2 + denom_sq = denom * denom + + gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32) + + # Backpropagates wkv gradients. + guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq + gu += guk + gk = guk + gv = gwkvt * e2 / denom + + galpha_wkv = gwkvt * e1 / denom + gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq + geps_wkv_denom = e1 * beta_prev + e2 + geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom) + + e1 = exp(w + eps_prev - eps_curr) + e2 = exp(kt - eps_curr) + + # Backpropagates alpha gradients. + galpha_we = galpha * e1 * alpha_prev + gw += galpha_we + gk += galpha * e2 * vt + gv += galpha * e2 + geps += galpha * -alpha_curr + + # Backpropagates beta gradients. + gbeta_we = gbeta * e1 * beta_prev + gw += gbeta_we + gk += gbeta * e2 + geps += gbeta * -beta_curr + + # Backpropagates epsilon gradients. + geps_mask = w + eps_prev > kt + geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps)) + gw += geps_we + gk += tl.where(geps_mask, tl.zeros_like(geps), geps) + + # Stores the gradients for k and v. + tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask) + tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask) + + # Computes new gradients for alpha and beta. + galpha = galpha * e1 + galpha_wkv + gbeta = gbeta * e1 + gbeta_wkv + geps = galpha_we + gbeta_we + geps_we + geps_wkv + + # Stores final gradients for alpha and beta. 
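+ # These flow out as the gradient of the initial state. Note that gw is
+ # stored pre-multiplied by w: the forward pass reparametrizes w = -exp(w_raw),
+ # so by the chain rule dL/dw_raw = dL/dw * (-exp(w_raw)) = gw * w.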
+ galpha_ptr = gstate_ptr + b_idx * gstate_s_b + gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe + geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe + tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask) + tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask) + tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask) + + # Stores final gradients for w and u. + tl.store(gw_ptr + gw_s_b * b_idx + gw_s_c * cs, gw*w, mask=cmask) + tl.store(gu_ptr + gu_s_b * b_idx + gu_s_c * cs, gu, mask=cmask) + + +def fused_recurrent_rwkv4_backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + grad_wkv: Tensor, + grad_state: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + bsz, tsz, chans = k.shape + + gw = w.new_empty(bsz, chans, dtype=torch.float) # New tensors to output. + gu = u.new_empty(bsz, chans, dtype=torch.float) + gk = torch.empty_like(k) + gv = torch.empty_like(v) + gstate = k.new_empty(bsz, 3, 1, chans) + + block_size_c = get_block_size_c(chans) # Constants. + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_backward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + # WKV grad + grad_wkv, + grad_wkv.stride(0), + grad_wkv.stride(1), + grad_wkv.stride(2), + # Output state grad + grad_state, + grad_state.stride(0), + grad_state.stride(1), + grad_state.stride(3), + # W grad + gw, + gw.stride(0), + gw.stride(1), + # U grad + gu, + gu.stride(0), + gu.stride(1), + # K grad + gk, + gk.stride(0), + gk.stride(1), + gk.stride(2), + # V grad + gv, + gv.stride(0), + gv.stride(1), + gv.stride(2), + # State grad + gstate, + gstate.stride(0), + gstate.stride(1), + gstate.stride(3), + # Params + tsz, + chans, + BLOCK_SIZE_C=block_size_c, + ) + + return gw.sum(0), gu.sum(0), gk, gv, gstate + + +class FusedRecurrentRWKV4Function(Function): + + @staticmethod + @input_guard + def forward( + ctx: FunctionCtx, + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + ) -> tuple[Tensor, Tensor]: + ctx.w_dtype = w.dtype + w = -torch.exp(w.float()) + wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state) + ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1]) + return wkv, state_out[:, :, -1:] + + @staticmethod + @once_differentiable + @input_guard + def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors) + gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate) + return gw.to(ctx.w_dtype), gu.to(u), gk.to(k), gv.to(v), gstate.to(state) + + +def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]: + return FusedRecurrentRWKV4Function.apply(w, u, k, v, state) diff --git a/fla3/ops/rwkv6/__init__.py b/fla3/ops/rwkv6/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c7c218eb873a1a2115b5587530fe55f29a9d02 --- /dev/null +++ b/fla3/ops/rwkv6/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_rwkv6 +from .fused_recurrent import fused_recurrent_rwkv6 + +__all__ = [ + 'chunk_rwkv6', + 'fused_recurrent_rwkv6' +] diff --git 
a/fla3/ops/rwkv6/__pycache__/__init__.cpython-310.pyc b/fla3/ops/rwkv6/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e1ba11c41623aaca9d4563b1ca2b5671a5a34a2 Binary files /dev/null and b/fla3/ops/rwkv6/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/rwkv6/__pycache__/__init__.cpython-312.pyc b/fla3/ops/rwkv6/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e4aaf72176874046579b1f557062454aa5fccb3 Binary files /dev/null and b/fla3/ops/rwkv6/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/rwkv6/__pycache__/chunk.cpython-310.pyc b/fla3/ops/rwkv6/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f2d3afa602cbc86f907f82ecb42f15c22020508 Binary files /dev/null and b/fla3/ops/rwkv6/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/rwkv6/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/rwkv6/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e769a51b9d2b8691b1c4f50c0974b72542c3380e Binary files /dev/null and b/fla3/ops/rwkv6/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/rwkv6/chunk.py b/fla3/ops/rwkv6/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..590807e2aa28517907bfa83ee5559721c0107abe --- /dev/null +++ b/fla3/ops/rwkv6/chunk.py @@ -0,0 +1,1304 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.common.chunk_h import chunk_fwd_h +from fla.ops.gla.chunk import chunk_gla_bwd_dA, chunk_gla_bwd_dv, chunk_gla_fwd_o_gk +from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BS': BS}, num_warps=num_warps, num_stages=num_stages) + for BS in [16, 32, 64] + for num_warps in [4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['S', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_cumsum_kernel( + s, + oi, + oe, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + m_i = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32) + m_e = tl.where(o_i[:, None] > o_i[None, :], 1., 0.).to(tl.float32) + + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_oi = tl.make_block_ptr(oi + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_oe = 
tl.make_block_ptr(oe + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_oi = tl.dot(m_i, b_s) + b_oe = tl.dot(m_e, b_s) + tl.store(p_oi, b_oi.to(p_oi.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_oe, b_oe.to(p_oe.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def chunk_rwkv6_fwd_cumsum( + g: torch.Tensor, + chunk_size: int, + cu_seqlens: Optional[torch.Tensor] = None, +) -> torch.Tensor: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + gi, ge = torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + # keep cummulative normalizer in fp32 + chunk_rwkv6_fwd_cumsum_kernel[grid]( + g, + gi, + ge, + cu_seqlens, + chunk_indices, + T=T, + H=H, + S=S, + BT=BT, + ) + return gi, ge + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_inter( + q, + k, + gi, # cumulative decay inclusive + ge, # cumulative decay exclusive + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_i, i_j = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + if i_i <= i_j: + return + + m_i = i_t * BT + i_i * BC + tl.arange(0, BC) < T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gq = tl.where(m_i[:, None] & m_k, tl.load(p_gq, boundary_check=(0, 1)), float('-inf')) + b_qg = b_q * exp(b_gq - b_gn[None, :]) * scale + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[:, None] - b_gk) + # [BC, BC] using tf32 to improve precision here. 
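+ # The inter-sub-chunk scores factorize as
+ #   A[i, j] = scale * (q_i * exp(ge_i - g_n)) . (k_j * exp(g_n - gi_j)),
+ # where g_n is the inclusive cumulative decay at the last position before
+ # this sub-chunk. Anchoring both exponents at g_n keeps them <= 0 (the
+ # RWKV6 log-decays are non-positive), so neither factor can overflow.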
+ b_A += tl.dot(b_qg, b_kg) + + p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BK', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra( + q, + k, + gi, + ge, + u, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_j = i_i + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_qj = q + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_kj = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = gi + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_kj[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i > j, b_A * scale, 0.) 
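+ # Diagonal entries get the RWKV6 "bonus" instead of a decayed score: each
+ # token attends to itself through the per-channel u, i.e.
+ # A[j, j] = scale * sum(q_j * k_j * u).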
+ b_A = tl.where(o_i != j, b_A, tl.sum(b_qj * b_kj * b_u * scale)) + tl.store(A + o_A + j, b_A, mask=m_A) + p_qj += H*K + p_kj += H*K + p_gk += H*K + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC', 'BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split( + q, + k, + gi, + ge, + u, + A, + cu_seqlens, + chunk_indices, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_tc // NC, i_tc % NC + i_j = i_i + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + + o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_qj = q + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_kj = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = gi + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (i_k * BK), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_kj[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i > j, b_A * scale, 0.) 
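+ # Same u-bonus diagonal handling as the unsplit kernel above; this variant
+ # covers a single K-split, and the partial A buffers it writes are summed
+ # over all NK splits by the merge kernel that follows.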
+ b_A = tl.where(o_i != j, b_A, tl.sum(b_qj * b_kj * b_u * scale)) + tl.store(A + o_A + j, b_A, mask=m_A) + p_qj += H*K + p_kj += H*K + p_gk += H*K + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge( + A, + A2, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + NK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_c * BC >= T: + return + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(0, NK): + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + b_A += tl.load(p_A, boundary_check=(0, 1)) + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_bwd_kernel_dh( + q, + gi, + ge, + do, + dh, + dht, + dh0, + cu_seqlens, + chunk_offsets, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + last_idx = min(i_t * BT + BT, T) - 1 + # [BK, BT] + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * 
BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + p_gk = tl.make_block_ptr(ge + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gi + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk) * scale).to(b_q.dtype) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + b_dh += tl.dot(b_q, b_do) + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BK', 'NC', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_bwd_kernel_intra( + q, + k, + gi, + ge, + dA, + dq, + dk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + if i_t * BT + i_i * BC >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_ge = tl.make_block_ptr(ge + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_ge = tl.load(p_ge, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h*K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(gi+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[None, :] - b_gk) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg) + b_dq *= exp(b_ge - b_gn[None, :]) + + o_i = tl.arange(0, BC) + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC + p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gkj = gi + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, mask=m_k, 
other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] > j + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_ge - b_gkj[None, :]), 0.) + p_kj += H*K + p_gkj += H*K + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(gi + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = gi + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h*K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gq = tl.where(m_j[:, None] & m_k, tl.load(p_gq, boundary_check=(0, 1)), float('-inf')) + b_qg = b_q * exp(b_gq - b_gn[None, :]) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dk += tl.dot(b_dA, b_qg) + b_dk *= exp(b_gn[None, :] - b_gk) + o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) + p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gqj = ge + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j * H*BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] < j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.) 
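+        # advance the q/gq row pointers by one timestep (H*K elements)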
+ p_qj += H*K + p_gqj += H*K + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_bwd_kernel_inter( + q, + k, + v, + h, + gi, + ge, + u, + do, + dh, + dA, + dq, + dk, + dq2, + dk2, + dg, + du, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_gk = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = gi + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK,], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_dgk *= exp(b_gn) + b_dq *= scale + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)) + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gi) + + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA_dig = dA + ((bos + i_t * BT + o_i) * H + i_h) * BT + o_i + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + b_dg = b_dg - 
tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] - b_q * b_dq + # [BT,] + b_dA_dig = tl.load(p_dA_dig, mask=(i_t * BT + o_i) < T, other=0) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (i_k * BK,), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + # scale is already applied to b_dA_diag + b_dq += (b_dA_dig[:, None] * b_u[None, :] * b_k) + b_dk += (b_dA_dig[:, None] * b_u[None, :] * b_q) + b_du = tl.sum(b_dA_dig[:, None] * b_q * b_k, axis=0) + p_du = tl.make_block_ptr(du + (i_tg * H + i_h) * K, (K,), (1,), (i_k * BK,), (BK,), (0,)) + tl.store(p_du, b_du, boundary_check=(0,)) + + p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_rwkv6_fwd_intra( + q: torch.Tensor, + k: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + u: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + + A = q.new_empty(B, T, H, BT, dtype=torch.float) + grid = (NT, NC * NC, B * H) + chunk_rwkv6_fwd_A_kernel_intra_sub_inter[grid]( + q, + k, + gi, + ge, + A, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + ) + + grid = (NT, NC, B * H) + # load the entire [BC, K] blocks into SRAM at once + if K <= 256: + BK = triton.next_power_of_2(K) + chunk_rwkv6_fwd_A_kernel_intra_sub_intra[grid]( + q, + k, + gi, + ge, + u, + A, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + ) + # split then merge + else: + BK = min(128, triton.next_power_of_2(K)) + NK = triton.cdiv(K, BK) + A_intra = q.new_empty(NK, B, T, H, BC, dtype=torch.float) + + grid = (NK, NT * NC, B * H) + chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split[grid]( + q, + k, + gi, + ge, + u, + A_intra, + cu_seqlens, + chunk_indices, + scale, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + + grid = (NT, NC, B * H) + chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge[grid]( + A_intra, + A, + cu_seqlens, + chunk_indices, + B=B, + T=T, + H=H, + BT=BT, + BC=BC, + NK=NK, + ) + return A + + +def chunk_rwkv6_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.Tensor] = None, + chunk_size: int = 64, + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + # N: the actual number of sequences in the batch with either equal or variable lengths + # NG: number of groups in GQA + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is None: + N, NT, 
chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT = len(cu_seqlens) - 1, len(chunk_indices) + chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT) + NG = HQ // H + + dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_rwkv6_bwd_kernel_dh[grid]( + q=q, + gi=gi, + ge=ge, + do=do, + dh=dh, + dht=dht, + dh0=dh0, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + ) + return dh, dh0 + + +def chunk_rwkv6_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + dA: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + grid = (NK, NT * NC, B * H) + chunk_rwkv6_bwd_kernel_intra[grid]( + q, + k, + gi, + ge, + dA, + dq, + dk, + cu_seqlens, + chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + return dq, dk + + +def chunk_rwkv6_bwd_dqkgu( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + u: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dA: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + dg = torch.empty_like(g) + du = u.new_empty(B * NT, H, K, dtype=torch.float) + def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) + chunk_rwkv6_bwd_kernel_inter[grid]( + q, + k, + v, + h, + gi, + ge, + u, + do, + dh, + dA, + dq, + dk, + dq2, + dk2, + dg, + du, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + du = du.sum(0) + return dq2, dk2, dg, du + + +def chunk_rwkv6_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) + h, ht = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=gi, + gv=None, + h0=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + states_in_fp32=True + ) + # the intra A is kept in fp32 + # the computation has very marginal effect on the entire throughput + A = chunk_rwkv6_fwd_intra( + q=q, + k=k, + gi=gi, + ge=ge, + u=u, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + + o = chunk_gla_fwd_o_gk( 
+ q=q, + v=v, + g=ge, + A=A, + h=h, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return A, h, ht, o + + +def chunk_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) + h, _ = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=gi, + gv=None, + h0=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + states_in_fp32=True + ) + dh, dh0 = chunk_rwkv6_bwd_dh( + q=q, + k=k, + v=v, + gi=gi, + ge=ge, + do=do, + h0=initial_state, + dht=dht, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + states_in_fp32=True + ) + + # dq dk in fp32 + dA = chunk_gla_bwd_dA( + v=v, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + dv = chunk_gla_bwd_dv( + k=k, + g=gi, + A=A, + do=do, + dh=dh, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + dq, dk = chunk_rwkv6_bwd_dqk_intra( + q=q, + k=k, + gi=gi, + ge=ge, + dA=dA, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + dq, dk, dg, du = chunk_rwkv6_bwd_dqkgu( + q=q, + k=k, + v=v, + h=h, + g=g, + gi=gi, + ge=ge, + u=u, + do=do, + dh=dh, + dA=dA, + dq=dq, + dk=dk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return dq, dk, dv, dg, du, dh0 + + +class ChunkRWKV6Function(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + g, + u, + scale, + initial_state, + output_final_state, + cu_seqlens, + ): + T = q.shape[1] + if check_shared_mem(): + chunk_size = min(32, max(32, triton.next_power_of_2(T))) + else: + chunk_size = min(64, max(32, triton.next_power_of_2(T))) + + A, h, ht, o = chunk_rwkv6_fwd( + q=q, + k=k, + v=v, + g=g, + u=u, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + + ctx.save_for_backward(q, k, v, g, initial_state, A, u) + + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, g, initial_state, A, u = ctx.saved_tensors + chunk_size, scale, cu_seqlens = ctx.chunk_size, ctx.scale, ctx.cu_seqlens + dq, dk, dv, dg, du, dh0 = chunk_rwkv6_bwd( + q=q, + k=k, + v=v, + g=g, + u=u, + scale=scale, + initial_state=initial_state, + A=A, + do=do, + dht=dht, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None + + +@torch.compiler.disable +def chunk_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. 
+        w (torch.Tensor):
+            Forget gates of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` applied to keys.
+        u (torch.Tensor):
+            bonus representations of shape `[H, K]`.
+        scale (Optional[int]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (Optional[torch.Tensor]):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.rwkv6 import chunk_rwkv6
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> r = torch.randn(B, T, H, K, device='cuda')
+        >>> k = torch.randn(B, T, H, K, device='cuda')
+        >>> v = torch.randn(B, T, H, V, device='cuda')
+        >>> w = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
+        >>> u = torch.randn(H, K, device='cuda')
+        >>> h0 = torch.randn(B, H, K, V, device='cuda')
+        >>> o, ht = chunk_rwkv6(
+            r, k, v, w, u,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> r, k, v, w = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (r, k, v, w))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = r.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = chunk_rwkv6(
+            r, k, v, w, u,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+        >>> assert o.allclose(o_var.view(o.shape))
+        >>> assert ht.allclose(ht_var)
+    """
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead."
+        )
+    if not head_first and r.shape[1] < r.shape[2]:
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        )
+    if cu_seqlens is not None:
+        if r.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+ ) + if scale is None: + scale = r.shape[-1] ** -0.5 + o, final_state = ChunkRWKV6Function.apply( + r, + k, + v, + w, + u, + scale, + initial_state, + output_final_state, + cu_seqlens, + ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') + return o, final_state diff --git a/fla3/ops/rwkv6/chunk_naive.py b/fla3/ops/rwkv6/chunk_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2ac664f5079a20eabe9b11c19c1cff6755c658 --- /dev/null +++ b/fla3/ops/rwkv6/chunk_naive.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def naive_chunk_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + chunk_size: int = 32 +): + assert q.shape[-2] % chunk_size == 0 + orig_dtype = q.dtype + num_chunk = q.shape[-2] // chunk_size + u = u.unsqueeze(0) + + q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) + + w_cumsum = w.cumsum(-2) + + kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() + wkv = kw.transpose(-1, -2) @ v + + wkv_new = torch.zeros_like(wkv) + + for i in range(num_chunk - 1): + wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] + + o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) + + o_intra = torch.zeros_like(o_inter) + for i in range(chunk_size): + attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) + mask = (torch.arange(0, chunk_size) < i).to(attn.device) + attn.masked_fill_(~mask, 0) + intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) + intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] + o_intra[:, :, :, i] = intra_inter_o + intra_intra_o + o = o_inter + o_intra + return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) diff --git a/fla3/ops/rwkv6/fused_recurrent.py b/fla3/ops/rwkv6/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..bf61a3e0793a18e6c710644a28fdf8862bdbe96d --- /dev/null +++ b/fla3/ops/rwkv6/fused_recurrent.py @@ -0,0 +1,677 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_fwd_kernel( + q, # query [B, H, T, K]/[B, T, H, K] + k, # key [B, H, T, K]/[B, T, H, K] + v, # value [B, H, T, V]/[B, T, H, V] + w, # log gate [B, H, T]/[B, T, H] or None + u, # bonus [B, H, K] + o, # output [NK, B, H, T, V]/[NK, B, T, H, V] + h0, # initial hidden state [B, H, K, V] + ht, # final hidden state [B, H, K, V] + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, # whether to reverse the recurrence + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + 
STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_u = u + i_h * K + o_k + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32) + b_kv = b_k[:, None] * b_v[None, :] + b_o = tl.sum((b_h + b_kv * b_u[:, None]) * b_q[:, None], 0) + b_h = b_h * exp(b_w)[:, None] + b_kv + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * H*K + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_w += (-1 if REVERSE else 1) * H*K + p_o += (-1 if REVERSE else 1) * H*V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + ], + key=['BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_bwd_kernel_dq( + k, # key [B, H, T, V]/[B, T, H, V] + v, # value [B, H, T, V]/[B, T, H, V] + w, # log gate [B, H, T]/[B, T, H] + u, # bonus [B, H, K] + do, # gradient of output [B, H, T, V]/[B, T, H, V] + dq, # gradient of query [NV, B, H, T, K]/[NV, B, T, H, K] + dq1, # gradient of query_aux [NV, B, H, T, K]/[NV, B, T, H, K] + h0, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_w = w + (bos + ((T-1) if REVERSE else 
0)) * H*K + i_h * K + o_k + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_dq1 = dq1 + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_u = u + i_h * K + o_k + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_kv = b_k[:, None] * b_v[None, :] + + b_hq = b_h * b_do[None, :] + b_dq = tl.sum(b_hq + b_kv * b_u[:, None] * b_do[None, :], 1) * scale + b_dq1 = tl.sum(b_hq, 1) + b_h = b_h * exp(b_w)[:, None] + b_h += b_kv + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k) + tl.store(p_dq1, b_dq1.to(p_dq1.dtype.element_ty), mask=mask_k) + + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_w += (-1 if REVERSE else 1) * H*K + p_do += (-1 if REVERSE else 1) * H*V + p_dq += (-1 if REVERSE else 1) * H*K + p_dq1 += (-1 if REVERSE else 1) * H*K + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + ], + key=['BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_bwd_kernel_dkv( + q, # query [B, H, T, K]/[B, T, H, K] + k, # key [B, H, T, V]/[B, T, H, V] + v, # value [B, H, T, V]/[B, T, H, V] + w, # log gate [B, H, T]/[B, T, H] + u, # bonus [B, H, K] + do, # gradient of output [B, H, T, V]/[B, T, H, V] + dk, # gradient of key [NV, B, H, T, K]/[NK, B, T, H, K] + dk1, # gradient of key_aux [NV, B, H, T, K]/[NK, B, T, H, K] + dv, # gradient of value [NK, B, H, T, V]/[NV, B, T, H, V] + dh0, # gradient of initial hidden state [N, H, K, V] + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_w = w + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_do = do + (bos + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_dk = dk + ((i_v * all + bos) + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dk1 = dk1 + ((i_v * all + bos) + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dv = dv + ((i_k * all + bos) + ((T-1) if not 
REVERSE else 0)) * H*V + i_h * V + o_v + p_u = u + i_h * K + o_k + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for _ in range(T - 1, -1, -1): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_dkv = b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], 1) + tl.store(p_dk1, b_dk.to(p_dk1.dtype.element_ty), mask=mask_k) + b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], 1) + b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], 0) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) + b_dh *= exp(b_w)[:, None] + b_dh += b_dkv + + p_q += (-1 if not REVERSE else 1) * H*K + p_k += (-1 if not REVERSE else 1) * H*K + p_v += (-1 if not REVERSE else 1) * H*V + p_w += (-1 if not REVERSE else 1) * H*K + p_do += (-1 if not REVERSE else 1) * H*V + p_dk += (-1 if not REVERSE else 1) * H*K + p_dk1 += (-1 if not REVERSE else 1) * H*K + p_dv += (-1 if not REVERSE else 1) * H*V + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT, 'BK': BK}, num_warps=num_warps) + for BT in [16, 32, 64] + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + ], + key=['K'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_bwd_kernel_dw( + q, + k, + dq, + dk, + dw, + cu_seqlens, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_k, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + NT = tl.cdiv(T, BT) + + o_i = tl.arange(0, BT) + m_i = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) if not REVERSE else tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) 
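+    # m_i is a triangular ones mask (direction depends on REVERSE); the
+    # tl.dot(m_i, b_dw) below performs an inclusive prefix sum of the per-step
+    # dw contributions within a block, while b_z accumulates across blocks.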
+ + b_z = tl.zeros([BK], dtype=tl.float32) + + i_t = 0 if not REVERSE else NT - 1 + for _ in range(NT): + p_q = tl.make_block_ptr(q + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T-1, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T-1, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_dq = tl.load(p_dq, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_dk = tl.load(p_dk, boundary_check=(0, 1)).to(tl.float32) + b_dw = (b_q * b_dq * scale) - b_k * b_dk + b_c = b_z[None, :] + tl.dot(m_i, b_dw, allow_tf32=False) + tl.store(p_dw, b_c.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + if i_t >= 0: + b_z += tl.sum(b_dw, 0) + + i_t += (1 if not REVERSE else -1) + + +def fused_recurrent_rwkv6_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + h0 = initial_state + ht = q.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + o = q.new_empty(NK, *v.shape, dtype=torch.float) + + grid = (NV, NK, N * H) + fused_recurrent_rwkv6_fwd_kernel[grid]( + q, + k, + v, + w, + u, + o, + h0, + ht, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + REVERSE=reverse, + ) + o = o.sum(0) + return o, ht + + +def fused_recurrent_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + do: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, *q.shape, dtype=torch.float) + dq1 = torch.empty_like(dq) + + grid = (NV, NK, N * H) + fused_recurrent_rwkv6_bwd_kernel_dq[grid]( + k, + v, + w, + u, + do, + dq, + dq1, + initial_state, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + REVERSE=reverse, + ) + dq = dq.sum(0) + dq1 = dq1.sum(0) + + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dk = q.new_empty(NV, *k.shape, dtype=torch.float) + dk1 = q.new_empty(NV, *k.shape, dtype=torch.float) + dv = q.new_empty(NK, *v.shape, dtype=torch.float) + + dh0 = torch.empty_like(initial_state) if initial_state is not None else None + grid = (NV, NK, N * H) + fused_recurrent_rwkv6_bwd_kernel_dkv[grid]( + q, + k, + v, + w, + u, + do, + dk, + dk1, + dv, + dh0, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, 
+ BK=BK, + BV=BV, + REVERSE=reverse, + ) + dk = dk.sum(0) + dk1 = dk1.sum(0) + dv = dv.sum(0) + + dw = torch.empty_like(w) + def grid(meta): return (triton.cdiv(meta['K'], meta['BK']), N * H) + fused_recurrent_rwkv6_bwd_kernel_dw[grid]( + q, + k, + dq1, + dk1, + dw, + cu_seqlens, + scale, + T=T, + H=H, + K=K, + REVERSE=not reverse, + ) + du = (do.float() * v).sum(-1, True, dtype=torch.float) * q * k * scale + du = du.sum((0, 1)) + return dq, dk, dv, dw, du, dh0 + + +class FusedRecurrentRWKV6Function(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + ): + o, ht = fused_recurrent_rwkv6_fwd( + q=q, + k=k, + v=v, + w=w, + u=u, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q, k, v, w, u, initial_state) + ctx.scale = scale + ctx.reverse = reverse + ctx.cu_seqlens = cu_seqlens + return o.to(v), ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, w, u, initial_state = ctx.saved_tensors + + dq, dk, dv, dw, du, dh0 = fused_recurrent_rwkv6_bwd( + q=q, + k=k, + v=v, + w=w, + u=u, + do=do, + scale=ctx.scale, + initial_state=initial_state, + reverse=ctx.reverse, + cu_seqlens=ctx.cu_seqlens, + ) + dh0 = dh0.to(initial_state) if dh0 is not None else dh0 + return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0, None, None, None + + +def fused_recurrent_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + reception of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + Alias: q, query in linear attention. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + w (torch.Tensor): + data-dependent decays of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` in log space! Alias: g. + u (torch.Tensor): + bonus of shape `[H, K]` + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. 
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (Optional[torch.Tensor]):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.rwkv6 import fused_recurrent_rwkv6
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = torch.randn(B, T, H, K, device='cuda')
+        >>> v = torch.randn(B, T, H, V, device='cuda')
+        >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
+        >>> u = torch.randn(H, K, device='cuda')
+        >>> h0 = torch.randn(B, H, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_rwkv6(
+            q, k, v, g, u,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_rwkv6(
+            q, k, v, g, u,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+        >>> assert o.allclose(o_var.view(o.shape))
+        >>> assert ht.allclose(ht_var)
+    """
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead."
+        )
+    if not head_first and r.shape[1] < r.shape[2]:
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        )
+    if cu_seqlens is not None:
+        if r.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    o, final_state = FusedRecurrentRWKV6Function.apply(
+        r,
+        k,
+        v,
+        w,
+        u,
+        scale,
+        initial_state,
+        output_final_state,
+        reverse,
+        cu_seqlens,
+    )
+    if head_first:
+        o = rearrange(o, 'b t h ...
-> b h t ...') + return o, final_state diff --git a/fla3/ops/rwkv6/recurrent_naive.py b/fla3/ops/rwkv6/recurrent_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2268759b5d4ce7f9be1be1f9c2e1a2f2a8e6c3 --- /dev/null +++ b/fla3/ops/rwkv6/recurrent_naive.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +): + orig_dtype = q.dtype + B, H, T, K, V = *q.shape, v.shape[-1] + q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + + if scale is None: + scale = K ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(T): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] + o[:, :, i] = o_i.sum(-2) + h = h * w_i[..., None] + kv_i + ht = h if output_final_state else None + return o.to(orig_dtype), ht + + +@torch.no_grad +@torch.jit.script +def naive_recurrent_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + initial_state: Optional[torch.Tensor] = None +): + q, k, v, w, u, o, do = (x.to(dtype=torch.float32) for x in (q, k, v, w, u, o, do)) + B, H, T, K, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1] + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + dq = torch.zeros_like(q) + dq_aux = torch.zeros_like(q) + + if initial_state is not None: + h += initial_state + + for i in range(T): + k_i = k[:, :, i] + v_i = v[:, :, i] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h_i = (h + u[None, ..., None] * kv_i) + dq_i = (do[:, :, i, None, :] * h_i).sum(-1) + dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) + dq[:, :, i] = dq_i + dq_aux[:, :, i] = dq_aux_i + h = h * w_i[..., None] + kv_i + + du = torch.zeros_like(u) + dh = torch.zeros_like(h) + dk = torch.zeros_like(k) + dk_aux = torch.zeros_like(k) + dv = torch.zeros_like(v) + + for i in range(T - 1, -1, -1): + d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] + k_i = k[:, :, i] + v_i = v[:, :, i] + du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) + du += du_i.sum(0) + dk_i = (dh * v_i[..., None, :]).sum(-1) + dk_aux[:, :, i] = dk_i + dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) + dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) + dv_i += (dh * k_i[..., None]).sum(-2) + + dk[:, :, i] = dk_i + dv[:, :, i] = dv_i + dh = dh * w[:, :, i, :, None].exp() + d_kv_i + + # dw = q * dq_aux - k * dk_aux + dw = torch.zeros_like(w) + for i in range(T - 2, -1, -1): + dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] + + return dq, dk, dv, dw, du, dh diff --git a/fla3/ops/rwkv7/RWKV7(Goose).md b/fla3/ops/rwkv7/RWKV7(Goose).md new file mode 100644 index 0000000000000000000000000000000000000000..ba16def2b2e15687d7636fce6a3c79fdb3bda5b0 --- /dev/null +++ b/fla3/ops/rwkv7/RWKV7(Goose).md @@ -0,0 +1,567 @@ +# RWKV7 (Goose) Mechanism: Mathematical Derivation + +Zhiyuan Li + +## Introduction to RWKV-7 Architecture + +RWKV-7 employs **Dynamic State 
Evolution** that transcends the fundamental $TC^0$ expressivity limitations of attention/linear attention paradigms. RWKV-7 possesses $NC^1$ expressivity, allowing it to solve many problems that attention mechanisms cannot.
+
+In simple terms, traditional attention mechanisms (like the Transformer's QKV-softmax attention) store multiple $\{k, v\}$ key/value vector pairs and match queries ($q$, aliased as $r$ in RWKV) against the keys to retrieve the corresponding values.
+
+RWKV-7 takes a different approach: rather than directly storing $\{k, v\}$ pairs, it dynamically updates a state by learning relationships between keys and values from context. This updated state then processes new input queries ($q$, or $r$ in RWKV terminology) to produce outputs[^1].
+
+[^1]: For a more detailed explanation of this approach, see the original article by the RWKV author: https://mp.weixin.qq.com/s/kC_Z3vuQ5B4PiRwZVeIvHQ
+
+Specifically, RWKV-7 maintains an internal model $v \approx S k$ (equivalently $v \approx k S^\top$ in row-vector form). It aims to fit a simple objective: for given vector sequences $\{k_t\}$ and $\{v_t\}$, use the state $S$ to map each $k_i$ to its $v_i$, making the prediction $\hat{v} = S k$ as close as possible to the target $v$.
+
+To achieve this, RWKV-7 automatically simulates dynamic gradient descent on the L2 loss $L = \frac{1}{2}\lVert v - S k \rVert^2$ during inference, continuously training its internal model $v \approx S k$.
+
+The gradient is:
+
+$$\frac{\partial L}{\partial S} = (S k - v)\, k^\top = S k k^\top - v k^\top$$
+
+Therefore, the gradient descent update (with weight decay factors $d_t = \exp(-\exp(w_t))$ and learning rates $\eta_t$) is:
+
+$$S_t = S_{t-1} \cdot \text{Diag}(d_t) - \eta_t \cdot (S_{t-1} k_t k_t^\top - v_t k_t^\top)$$
+
+This simplifies to:
+
+$$S_t = S_{t-1} \cdot \text{Diag}(d_t) - \eta_t \cdot S_{t-1} k_t k_t^\top + \eta_t \cdot v_t k_t^\top$$
+
+$$S_t = S_{t-1} \cdot (\text{Diag}(d_t) - \eta_t \cdot k_t k_t^\top) + \eta_t \cdot v_t k_t^\top$$
+
+In the full RWKV-7 implementation, this gradient descent update is generalized by replacing the terms as follows:
+
+- $\text{Diag}(d_t)$ becomes $D_t$ (the diagonal decay matrix)
+- The rank-one term $-\eta_t \cdot k_t k_t^\top$ is generalized to $\alpha_t \beta_t^\top$, where:
+  - $\alpha_t$ can be initialized as $-\eta_t \cdot k_t$
+  - $\beta_t$ can be initialized as $k_t$
+- The term $\eta_t \cdot v_t k_t^\top$ becomes $v_t k_t^\top$, with the scaling $\eta_t$ absorbed into $k_t$
+
+This leads to the final recurrence equation[^2]:
+
+[^2]: For a more detailed explanation, see the Triton code. Note: in the optimized Triton implementation, `w` is already the log of the decay factor, so only one exponential operation is needed. https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/rwkv7/fused_recurrent.py#L94
+
+$$S_t = S_{t-1} \cdot D_t + S_{t-1} \cdot \alpha_t \beta_t^\top + v_t k_t^\top \in \mathbb{R}^{d_v \times d_k}$$
+
+This formulation allows more flexibility in how the state evolves while maintaining the core gradient descent learning dynamics.
+
+The output at each timestep is computed as:
+
+$$o_t = S_t \cdot q_t$$
+
+where $q_t \in \mathbb{R}^{d_k}$ is the query vector (named $r$ in RWKV terminology), typically scaled by a factor of $\frac{1}{\sqrt{d_k}}$. This formulation allows RWKV-7 to continuously adapt its internal representation based on context, transcending the limitations of traditional attention mechanisms.
+
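+To make the correspondence concrete, here is a minimal PyTorch sketch (illustrative only: the toy dimensions, the learning rate `eta`, and the random seed are arbitrary choices, not part of the implementation) verifying that one step of the generalized recurrence with $\alpha_t = -\eta_t k_t$ and $\beta_t = k_t$ reproduces the plain gradient-descent update:
+
+```python
+import torch
+
+d_k, d_v, eta = 4, 3, 0.1
+torch.manual_seed(0)
+S = torch.randn(d_v, d_k, dtype=torch.double)  # state: maps keys to values
+k = torch.randn(d_k, dtype=torch.double)
+v = torch.randn(d_v, dtype=torch.double)
+d = torch.exp(-torch.exp(torch.randn(d_k, dtype=torch.double)))  # decay in (0, 1)
+
+# one gradient-descent step on L = 1/2 * ||v - S k||^2, with weight decay
+grad = torch.outer(S @ k - v, k)               # dL/dS = (S k - v) k^T
+S_gd = S @ torch.diag(d) - eta * grad
+
+# generalized form S D + S (alpha beta^T) + v k~^T, with eta absorbed into k~
+alpha, beta = -eta * k, k
+S_gen = S @ torch.diag(d) + S @ torch.outer(alpha, beta) + torch.outer(v, eta * k)
+
+assert torch.allclose(S_gd, S_gen)
+```
+
+The identity holds because $S \alpha \beta^\top = -\eta\,(S k)\,k^\top$ cancels exactly the model term of the gradient, leaving the decayed state plus the $\eta\, v k^\top$ correction.
+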
+## 1. Forward Pass Recurrence Equation
+
+In the implementation, the state update is defined as follows. For each batch index (bi) and head (hi), at time step t:
+
+```python
+w_t = torch.exp(-torch.exp(w[bi, hi, t]))  # shape [K]
+sa = (state[bi, hi] * a_t[None, :]).sum(dim=1)  # shape [V]
+state[bi, hi] = w_t[None, :] * state[bi, hi] + sa[:, None] * b_t[None, :] + k_t[None, :] * v_t[:, None]
+```
+
+Here state[bi, hi] has shape [V, K], representing a state matrix that maps K-dimensional keys to V-dimensional values.
+
+## 2. Backward Pass Derivation
+
+### 2.1 Gradient of Loss w.r.t. State
+
+For time step t, if L is the loss function, dstate_curr = ∂L/∂S_t is the gradient with respect to the state just after the update at step t:
+
+```
+dstate_curr = dstate[bi, hi] + q_t[None, :] * doutput[bi, hi, t][:, None]
+```
+
+This combines the gradient propagated back from future time steps, dstate[bi, hi], with the gradient coming from the output at the current step.
+
+### 2.2 Gradient w.r.t. Query q_t
+
+```
+dq[bi, hi, t] = torch.matmul(doutput[bi, hi, t], curr_state) * scale
+```
+
+### 2.3 Gradient w.r.t. Decay Parameter w_t
+
+For the gradient of w_t, we need to consider how it enters the state update through the `w_t[None, :] * state[bi, hi]` component.
+
+First, compute the derivative of L with respect to w_t:
+
+```
+∂L/∂w_t[k] = ∑_v (dstate_curr[v,k] * prev_state[v,k])
+```
+
+This sums over the v dimension for each position k, yielding a vector of shape [K].
+
+Then, compute the derivative of w_t with respect to w (recall w_t = exp(-exp(w))):
+
+```
+∂w_t[k]/∂w[k] = -exp(w[k]) * exp(-exp(w[k])) = -exp(w[k]) * w_t[k]
+```
+
+Finally, apply the chain rule:
+
+```
+∂L/∂w[k] = ∂L/∂w_t[k] * ∂w_t[k]/∂w[k]
+         = (∑_v dstate_curr[v,k] * prev_state[v,k]) * (-exp(w[k]) * w_t[k])
+```
+
+In code, this is expressed as (validated numerically in the sketch after Section 2.6):
+
+```python
+dw[bi, hi, t] += -torch.sum(dstate_curr * prev_state, dim=0) * torch.exp(w[bi, hi, t]) * w_t
+```
+
+Or equivalently:
+
+```python
+dw[bi, hi, t] += -torch.sum(dstate_curr * prev_state, dim=0) * torch.exp(w[bi, hi, t]) * torch.exp(-torch.exp(w[bi, hi, t]))
+```
+
+### 2.4 Gradient w.r.t. k_t and v_t
+
+For the `k_t[None, :] * v_t[:, None]` component:
+
+```python
+dk[bi, hi, t] += torch.sum(dstate_curr * v_t[:, None], dim=0)
+dv[bi, hi, t] += torch.sum(dstate_curr * k_t[None, :], dim=1)
+```
+
+### 2.5 Gradient w.r.t. α_t and β_t (a_t and b_t in code)
+
+For the `sa[:, None] * b_t[None, :]` component, where `sa = (state[bi, hi] * a_t[None, :]).sum(dim=1)`:
+
+```python
+db[bi, hi, t] += torch.sum(dstate_curr * sa[:, None], dim=0)
+dsa = torch.sum(dstate_curr * b_t[None, :], dim=1)
+da[bi, hi, t] += torch.sum(prev_state * dsa[:, None], dim=0)
+```
+
+### 2.6 Gradient w.r.t. Previous State S_{t-1}
+
+Finally, we compute the gradient of the previous state for backpropagation:
+
+```python
+dstate_from_sa = a_t[None, :] * dsa[:, None]
+dstate_from_decay = dstate_curr * w_t[None, :]
+dstate[bi, hi] = dstate_from_sa + dstate_from_decay
+```
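+
+The chain-rule formulas above are easy to validate numerically. The following sketch (illustrative only, separate from the full listing below) runs a single state update on random data and checks the manual `dw` of Section 2.3 against `torch.autograd`:
+
+```python
+import torch
+
+torch.manual_seed(0)
+K, V = 4, 3
+prev_state = torch.randn(V, K, dtype=torch.float64)
+a_t = torch.randn(K, dtype=torch.float64)
+b_t = torch.randn(K, dtype=torch.float64)
+k_t = torch.randn(K, dtype=torch.float64)
+v_t = torch.randn(V, dtype=torch.float64)
+w = torch.randn(K, dtype=torch.float64, requires_grad=True)
+
+w_t = torch.exp(-torch.exp(w))
+sa = (prev_state * a_t[None, :]).sum(dim=1)
+state = w_t[None, :] * prev_state + sa[:, None] * b_t[None, :] + k_t[None, :] * v_t[:, None]
+state.sum().backward()  # loss = state.sum(), so dstate_curr is all ones
+
+dstate_curr = torch.ones(V, K, dtype=torch.float64)
+dw_manual = -torch.sum(dstate_curr * prev_state, dim=0) \
+    * torch.exp(w.detach()) * torch.exp(-torch.exp(w.detach()))
+assert torch.allclose(w.grad, dw_manual)
+```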
+
+```python
+# -*- coding: utf-8 -*-
+
+from typing import Optional, Tuple
+
+import torch
+
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+
+
+def naive_recurrent_rwkv7(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    w: torch.Tensor,
+    a: torch.Tensor,  # Dynamic learning rate modulator
+    b: torch.Tensor,  # State update modulator
+    scale: float = 1.0,
+    initial_state: Optional[torch.Tensor] = None,
+    output_final_state: bool = True,
+):
+    """
+    Naive recurrent implementation of the RWKV-7 (Goose) attention mechanism.
+    Modified from Bo's code:
+    https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo.py#L170
+
+    Args:
+        q, k, v: Query, Key, and Value tensors
+        w: Time decay weights
+        a: Dynamic learning rate modulator, influences the in-context learning rate
+        b: State update modulator, directly participates in state update calculation
+        scale: Scaling factor for attention scores
+        initial_state: Initial state for the recurrent computation
+        output_final_state: Whether to output the final state
+
+    Returns:
+        Attention output and optionally the final state
+    """
+    torch_dtype = q.dtype if q.dtype in [torch.float64, torch.float] else torch.float
+    orig_dtype = q.dtype
+    B, H, L, N, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1]
+    q, k, v, w, a, b = (x.to(dtype=torch_dtype) for x in (q, k, v, w, a, b))
+    # q, k, a, b, w have shape (B, H, L, N); v has shape (B, H, L, V)
+    state = torch.zeros(B, H, V, N, dtype=torch_dtype, device=q.device)
+    o = torch.zeros_like(v)
+
+    if scale == -1.0:
+        scale = N ** -0.5
+
+    if initial_state is not None:
+        state += initial_state.to(dtype=torch_dtype)
+
+    for t in range(L):
+        q_t = q[:, :, t] * scale
+        k_t = k[:, :, t]
+        v_t = v[:, :, t]
+        a_t = a[:, :, t]
+        b_t = b[:, :, t]
+
+        # from Bo's code
+        sab = torch.einsum('bhik,bhk,bhj->bhij', state, a_t, b_t)
+        state = state * torch.exp(-torch.exp(w[:, :, t, None, :])) + sab + torch.einsum('bhj,bhi->bhij', k_t, v_t)
+        o[:, :, t] = torch.einsum('bhj,bhij->bhi', q_t, state)
+
+    if not output_final_state:
+        ht = None
+    elif initial_state is not None:
+        ht = state.to(initial_state.dtype)
+    else:
+        ht = state.to(orig_dtype)
+
+    return o.to(orig_dtype), ht
+
+
+def naive_recurrent_rwkv7_2(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    w: torch.Tensor,
+    a: torch.Tensor,  # Dynamic learning rate modulator
+    b: torch.Tensor,  # State update modulator
+    scale: float = 1.0,
+    initial_state: Optional[torch.Tensor] = None,
+    output_final_state: bool = True,
+):
+    """
+    Naive recurrent implementation of RWKV-7 (Goose) attention mechanism.
+ + Args: + q, k, v: Query, Key, and Value tensors + w: Time decay weights + a: Dynamic learning rate modulator, influences the in-context learning rate + b: State update modulator, directly participates in state update calculation + scale: Scaling factor for attention scores + initial_state: Initial state for the recurrent computation + output_final_state: Whether to output the final state + + Returns: + Attention output and optionally the final state + """ + torch_dtype = q.dtype if q.dtype in [torch.float64, torch.float] else torch.float + orig_dtype = q.dtype + B, H, L, N, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1] + q, k, v, w, a, b = (x.to(dtype=torch_dtype) for x in (q, k, v, w, a, b)) + # q, k, v, a, b, w, + # shape: (B, H, L, D), (B, H, L, D), (B, H, T, V), (B, H, L, D), (B, H, L, D), (B, H, L, D) + state = torch.zeros(B, H, V, N, dtype=torch_dtype, device=q.device) + o = torch.zeros_like(v) + + if scale == -1.0: + scale = N ** -0.5 + + if initial_state is not None: + state += initial_state.to(dtype=torch_dtype) + + for t in range(L): + for bi in range(B): + for hi in range(H): + q_t = q[bi, hi, t] * scale + k_t = k[bi, hi, t] + v_t = v[bi, hi, t] + a_t = a[bi, hi, t] + b_t = b[bi, hi, t] + w_t = torch.exp(-torch.exp(w[bi, hi, t])) + + # h: [V, K], a_t [K] -> [1, K] + # sa: [V] + sa = (state[bi, hi] * a_t[None, :]).sum(dim=1) + + state[bi, hi] = w_t[None, :] * state[bi, hi] + sa[:, None] * b_t[None, :] + k_t[None, :] * v_t[:, None] + y = (state[bi, hi] * q_t[None, :]).sum(dim=1) + + o[bi, hi, t] = y + + ht = state if output_final_state else None + return o.to(orig_dtype), ht + + +@torch.no_grad() +def naive_recurrent_rwkv7_2_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + doutput: torch.Tensor, + dh_t: Optional[torch.Tensor] = None, + scale: float = 1.0, + dtype: Optional[torch.dtype] = None +): + """ + Backward pass for the naive_recurrent_rwkv7_2 implementation. + + Args: + q, k, v, w, a, b: Original forward pass inputs + doutput: Gradient of the loss with respect to the output + dh_t: Gradient of the loss with respect to the final state (if any) + scale: Scaling factor used in the forward pass + dtype: Optional dtype for computation + + Returns: + Gradients with respect to all inputs + """ + torch_dtype = q.dtype if q.dtype in [torch.float64, torch.float] else torch.float + q, k, v, w, a, b, doutput = (x.to(dtype=torch_dtype) for x in (q, k, v, w, a, b, doutput)) + if dh_t is not None: + dh_t = dh_t.to(dtype=torch_dtype) + + B, H, L, N, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1] + + # Initialize gradients + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + dw = torch.zeros_like(w) + da = torch.zeros_like(a) + db = torch.zeros_like(b) + + # Initialize state gradients + dstate = torch.zeros(B, H, V, N, dtype=torch_dtype, device=q.device) + if dh_t is not None: + dstate += dh_t + + if scale == -1.0: + scale = N ** -0.5 + + # First rebuild all states from forward pass + states = [] + state = torch.zeros(B, H, V, N, dtype=torch_dtype, device=q.device) + states.append(state.clone()) + + # In practice, we don't recompute all states from the beginning. + # Instead, we use checkpointing: we save states at regular intervals (e.g., every 16 tokens) + # during the forward pass, then reconstruct intermediate states during the backward pass + # by working backwards from the nearest checkpoint. 
+ # + # For example, to get state[t-1] from state[t]: + # state[t-1] = (state[t] - (sa * b_t + k_t * v_t)) / w_t + # + # This approach balances memory usage and computational efficiency: + # - Reduces memory by not storing every state + # - Maintains numerical stability by limiting the number of backward steps from each checkpoint + # - Allows efficient gradient computation without recomputing the entire sequence + for t in range(L): + for bi in range(B): + for hi in range(H): + q_t = q[bi, hi, t] * scale + k_t = k[bi, hi, t] + v_t = v[bi, hi, t] + a_t = a[bi, hi, t] + b_t = b[bi, hi, t] + w_t = torch.exp(-torch.exp(w[bi, hi, t])) + + sa = (state[bi, hi] * a_t[None, :]).sum(dim=1) + + state[bi, hi] = w_t[None, :] * state[bi, hi] + sa[:, None] * b_t[None, :] + k_t[None, :] * v_t[:, None] + states.append(state.clone()) + + # Backward pass through time + for t in range(L-1, -1, -1): + for bi in range(B): + for hi in range(H): + q_t = q[bi, hi, t] * scale + k_t = k[bi, hi, t] + v_t = v[bi, hi, t] + a_t = a[bi, hi, t] + b_t = b[bi, hi, t] + w_scalar = w[bi, hi, t] + w_exp = torch.exp(w_scalar) + w_t = torch.exp(-w_exp) + + curr_state = states[t+1][bi, hi] # State after update [V, K] + prev_state = states[t][bi, hi] # State before update [V, K] + + dq[bi, hi, t] += torch.matmul(doutput[bi, hi, t], curr_state) * scale + + dstate_from_out = q_t[None, :] * doutput[bi, hi, t][:, None] # [V, K] + + dstate_curr = dstate[bi, hi] + dstate_from_out + + sa = (prev_state * a_t[None, :]).sum(dim=1) # [V] + + # state[bi, hi] = w_t[None, :] * prev_state + ... + dw[bi, hi, t] += -torch.sum(dstate_curr * prev_state, dim=0) * \ + w_t * w_exp + + # k_t[None, :] * v_t[:, None] -> [V, K] + dk[bi, hi, t] += torch.sum(dstate_curr * v_t[:, None], dim=0) + dv[bi, hi, t] += torch.sum(dstate_curr * k_t[None, :], dim=1) + + # sa[:, None] * b_t[None, :] -> [V, K] + db[bi, hi, t] += torch.sum(dstate_curr * sa[:, None], dim=0) + dsa = torch.sum(dstate_curr * b_t[None, :], dim=1) # [V] + + # sa = (prev_state * a_t[None, :]).sum(dim=1) + da[bi, hi, t] += torch.sum(prev_state * dsa[:, None], dim=0) + dstate_from_sa = a_t[None, :] * dsa[:, None] # [V, K] + + # w_t[None, :] * prev_state + dstate_from_decay = dstate_curr * w_t[None, :] # [V, K] + + dstate[bi, hi] = dstate_from_sa + dstate_from_decay + + return dq, dk, dv, dw, da, db, dstate + + +class NativeRecurrentRWKV7Function(torch.autograd.Function): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, w, a, b, scale, initial_state, + training: bool = True, dtype: Optional[torch.dtype] = None, + state_ckpt_interval: int = 16): + o, ht = naive_recurrent_rwkv7_2(q, k, v, w, a, b, scale=scale, initial_state=initial_state) + if training: + ctx.save_for_backward(q, k, v, w, a, b) + ctx.scale = scale + ctx.dtype = dtype + ctx.ckpt_interval = state_ckpt_interval + ctx.use_initial_state = initial_state is not None + return o, ht + + @staticmethod + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, w, a, b = ctx.saved_tensors + dq, dk, dv, dw, da, db, dh = naive_recurrent_rwkv7_2_bwd( + q, k, v, w, a, b, do, dht, ctx.scale, dtype=ctx.dtype) + dh = dh if ctx.use_initial_state else None + return dq, dk, dv, dw, da, db, None, dh, None, None + + +def recurrent_rwkv7( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = 1.0, + initial_state: torch.Tensor = None, + output_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) 
-> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        q (torch.Tensor):
+            q (aka r in RWKV terminology) of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+        k (torch.Tensor):
+            k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+        v (torch.Tensor):
+            v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+        a (torch.Tensor):
+            a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+        b (torch.Tensor):
+            b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
+        w (torch.Tensor):
+            decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`;
+            the kernel applies the log decay internally as log_w = -torch.exp(w).
+        scale (float):
+            scale of the attention. If -1.0, defaults to `1 / sqrt(K)`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `True`.
+            (This naive reference always materializes the final state.)
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API. Not supported by this naive reference.
+        head_first (bool):
+            whether inputs are laid out head-first. This naive reference only supports
+            `head_first=True`.
+    """
+    assert cu_seqlens is None
+    assert head_first is True
+    assert w is not None
+    if scale == -1.0:
+        scale = q.shape[-1] ** -0.5
+    o, final_state = NativeRecurrentRWKV7Function.apply(q, k, v, w, a, b, scale, initial_state)
+
+    return o, final_state
+
+
+def test_autograd_function():
+    """Test the custom autograd function implementation"""
+    # Set random seed for reproducibility
+    torch.manual_seed(42)
+
+    # Define test dimensions
+    B, H, T, D = 1, 1, 64, 64
+    device = 'cpu'
+    dtype = torch.float64
+
+    # Create random test inputs
+    q = torch.empty(B, H, T, D, device=device).uniform_(-1, 1).to(dtype=dtype).requires_grad_(True)
+    k = torch.empty(B, H, T, D, device=device).uniform_(-1, 1).to(dtype=dtype).requires_grad_(True)
+    v = torch.empty(B, H, T, D, device=device).uniform_(-1, 1).to(dtype=dtype).requires_grad_(True)
+    w = torch.empty(B, H, T, D, device=device).uniform_(-8, -6).to(dtype=dtype).requires_grad_(True)
+
+    kk = torch.empty(B, H, T, D, device=device).uniform_(-1, 1)
+    kk = torch.nn.functional.normalize(kk, dim=-1).to(dtype=dtype)
+
+    a = -kk.clone().requires_grad_(True)  # -kk
+    a_scale = torch.empty(B, H, T, D, device=device).uniform_(0, 0.1).to(dtype=dtype)
+    b = (kk * a_scale).requires_grad_(True)  # kk*a
+
+    # Create initial state (here K = V = D)
+    initial_state = torch.zeros(B, H, D, D, device=device, dtype=torch.float64)
+
+    # Clone inputs for the two paths we're testing
+    def clone_inputs():
+        return tuple(x.clone().detach().requires_grad_(True) for x in (q, k, v, w, a, b))
+
+    q1, k1, v1, w1, a1, b1 = clone_inputs()
+    q2, k2, v2, w2, a2, b2 = clone_inputs()
+
+    # Path 1: naive implementation with autograd
+    output1, state1 = naive_recurrent_rwkv7(q1, k1, v1, w1, a1, b1, initial_state=initial_state.clone())
+
+    # Path 2: custom autograd Function with the handwritten backward
+    output2, state2 = recurrent_rwkv7(q2, k2, v2, w2, a2, b2, 1.0, initial_state.clone())
+
+    # Check forward
pass equivalence + output_diff = torch.max(torch.abs(output1 - output2)).item() + state_diff = torch.max(torch.abs(state1 - state2)).item() + + print(f"\nAutograd Function test (forward):") + print(f" Max output difference: {output_diff:.6e}") + print(f" Max state difference: {state_diff:.6e}") + + # Create loss function to test backward pass + def compute_loss(output, state): + return output.sum() # + state.sum() + + # # # Compute loss and gradients for both paths + loss1 = compute_loss(output1, state1) + loss1.backward() + + loss2 = compute_loss(output2, state2) + loss2.backward() + + # # Compare gradients + grad_diffs = { + 'q': torch.max(torch.abs(q1.grad - q2.grad)).item(), + 'k': torch.max(torch.abs(k1.grad - k2.grad)).item(), + 'v': torch.max(torch.abs(v1.grad - v2.grad)).item(), + 'w': torch.max(torch.abs(w1.grad - w2.grad)).item(), + 'a': torch.max(torch.abs(a1.grad - a2.grad)).item(), + 'b': torch.max(torch.abs(b1.grad - b2.grad)).item(), + } + + print(f"\nAutograd Function test (backward):") + for param, diff in grad_diffs.items(): + print(f" Max {param} gradient difference: {diff:.6e}") + + +test_autograd_function() + +``` diff --git a/fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc b/fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..953d374b2e7857f02e3c197841d2197dd8033ed1 Binary files /dev/null and b/fla3/ops/rwkv7/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc b/fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9389c4706df4130e0e304c2b7a6d67c840aed167 Binary files /dev/null and b/fla3/ops/rwkv7/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc b/fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b762b4bc1c937035d9a6810b46db3870054d61fb Binary files /dev/null and b/fla3/ops/rwkv7/__pycache__/fused_addcmul.cpython-310.pyc differ diff --git a/fla3/ops/rwkv7/channel_mixing.py b/fla3/ops/rwkv7/channel_mixing.py new file mode 100644 index 0000000000000000000000000000000000000000..991ea426f086859b8eaf0623f5acf059f9bddc5c --- /dev/null +++ b/fla3/ops/rwkv7/channel_mixing.py @@ -0,0 +1,323 @@ +import logging + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_pytorch_version, input_guard, use_cuda_graph + +logger = logging.getLogger(__name__) + +if not check_pytorch_version('2.4'): + logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations') + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': block_size}) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] + ], + key=['hidden_dim'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def rwkv_seq_mix_kernel( + x_ptr, + x_prev_ptr, + mix_k_ptr, + output_ptr, + batch_size: tl.constexpr, + token_length, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr +): + block_start = tl.program_id(0) * BLOCK_SIZE + block_idx = block_start + tl.arange(0, BLOCK_SIZE)[:] + + total_seq_dim = token_length * hidden_dim + batch_idx = block_idx // total_seq_dim + seq_and_feat = block_idx % total_seq_dim + seq_idx = seq_and_feat // hidden_dim + feat_idx = seq_and_feat % hidden_dim + + is_valid = (batch_idx < batch_size) & (seq_idx < token_length) + + x_idx = batch_idx * 
total_seq_dim + seq_idx * hidden_dim + feat_idx
+
+    curr_x = tl.load(x_ptr + x_idx, mask=is_valid, other=0.0).to(tl.float32)
+    k_value = tl.load(mix_k_ptr + feat_idx).to(tl.float32)
+
+    is_first = seq_idx < 1
+    prev_state_idx = batch_idx * hidden_dim + feat_idx
+    prev_state = tl.load(x_prev_ptr + prev_state_idx,
+                         mask=(is_first & is_valid),
+                         other=0.0).to(tl.float32)
+
+    prev_x_idx = x_idx - hidden_dim
+    prev_x = tl.load(x_ptr + prev_x_idx,
+                     mask=(~is_first & is_valid),
+                     other=0.0).to(tl.float32)
+
+    prev_value = tl.where(is_first, prev_state, prev_x)
+    state_diff = prev_value - curr_x
+    mixed = state_diff * k_value
+    result = tl.cast(curr_x + mixed, dtype=output_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
+    tl.store(output_ptr + x_idx, result, mask=is_valid)
+
+
+@triton.jit
+def rwkv_channel_mixing_pow_and_relu(
+    in_ptr,
+    out_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr
+):
+    """Fused ReLU and square operation: x = ReLU(x)^2"""
+    xoffset = tl.program_id(0) * BLOCK_SIZE
+    xindex = xoffset + tl.arange(0, BLOCK_SIZE)
+    # Guard the tail block so out-of-bounds lanes are inert when the element
+    # count is not a multiple of BLOCK_SIZE.
+    xmask = xindex < n_elements
+    x = tl.load(in_ptr + xindex, mask=xmask, other=0.0)
+    x = tl.maximum(x, 0.0).to(tl.float32)
+    x = tl.cast(x * x, dtype=out_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
+    tl.store(out_ptr + xindex, x, mask=xmask)
+
+
+def rwkv_mix_torch(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor):
+    if x_prev.dim() == 2:
+        x_prev = x_prev.unsqueeze(1)  # (batch_size, 1, hidden_dim)
+    xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x
+    k = x.addcmul(xx, x_k)
+    return k
+
+
+def rwkv_relu_and_square_torch(x: torch.Tensor):
+    return torch.relu(x) ** 2
+
+
+def rwkv_mix_fwd(x, x_prev, x_k):
+    has_batch = x.dim() == 3
+
+    if has_batch:
+        batch_size, token_length, hidden_dim = x.shape
+    else:
+        token_length, hidden_dim = x.shape
+        batch_size = 1
+        x = x.unsqueeze(0)
+        x_prev = x_prev.unsqueeze(0)
+
+    total_elements = batch_size * token_length * hidden_dim
+
+    # The kernel indexes linearly, so make the input contiguous before
+    # allocating the output to keep their layouts in sync.
+    x = x.contiguous()
+    output = torch.empty_like(x)
+
+    def grid(meta): return (
+        (total_elements + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],  # grid_0
+        1,  # grid_1
+        1   # grid_2
+    )
+
+    rwkv_seq_mix_kernel[grid](
+        x,
+        x_prev.contiguous(),
+        x_k.squeeze(),
+        output,
+        batch_size=batch_size,
+        token_length=token_length,
+        hidden_dim=hidden_dim,
+    )
+    if not has_batch:
+        output = output.squeeze(0)
+    return output
+
+
+def rwkv_relu_and_square_fwd(x: torch.Tensor, inplace: bool = True):
+    """
+    Triton implementation of RWKV's ReLU-and-square operation
+    Args:
+        x: Input tensor
+        inplace: If True, overwrite `x` with the result
+    Returns:
+        Tensor after ReLU and square operations
+    """
+    x = x.contiguous()
+    output = x if inplace else torch.empty_like(x)
+
+    def grid(meta): return (
+        (output.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],  # grid_0
+        1,  # grid_1
+        1   # grid_2
+    )
+    rwkv_channel_mixing_pow_and_relu[grid](
+        x,
+        output,
+        output.numel(),
+        BLOCK_SIZE=4096,
+    )
+
+    return output
+
+
+@triton.jit
+def relu_square_bwd_kernel(
+    out_ptr,
+    forward_input_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr
+):
+    """ReLU(x)^2 backward kernel
+    grad_input = grad_output * 2 * x if x > 0 else 0
+    """
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    x = tl.load(forward_input_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+    grad = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+
+    # d/dx ReLU(x)^2 = 2 * ReLU(x)
+    x = tl.maximum(x, 0.0)
+
+    grad_input = grad * 2 * x
+
+    tl.store(out_ptr + offsets, grad_input.to(out_ptr.dtype.element_ty), mask=mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_SIZE': block_size})
+        for block_size in [128, 256, 512, 1024, 2048, 4096, 8192]
+    ],
+    key=['hidden_dim'],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit
+def rwkv_mix_bwd_kernel(
+    dk1_ptr0,
+    xk_ptr,
+    dx_ptr,
+    dx_prev_ptr,
+    batch_size,
+    token_length,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(0)
+    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    batch_idx = offsets // (token_length * hidden_dim)
+    seq_feat = offsets % (token_length * hidden_dim)
+    seq_idx = seq_feat // hidden_dim
+    feat_idx = seq_feat % hidden_dim
+
+    is_valid = offsets < (batch_size * token_length * hidden_dim)
+
+    dk1 = tl.load(dk1_ptr0 + offsets, mask=is_valid)
+    xk = tl.load(xk_ptr + feat_idx, mask=is_valid)
+    prod = dk1 * xk
+
+    mask_next = seq_idx < (token_length - 1)
+    next_offset = offsets + hidden_dim
+    dk1_next = tl.load(dk1_ptr0 + next_offset, mask=mask_next & is_valid, other=0.0)
+    prod_next = dk1_next * xk
+    dx_val = dk1 - prod + tl.where(mask_next, prod_next, 0.0)
+    dx_val = tl.cast(dx_val, dtype=dx_ptr.dtype.element_ty, fp_downcast_rounding='rtne')
+    tl.store(dx_ptr + offsets, dx_val, mask=is_valid)
+
+    dx_prev_offset = batch_idx * hidden_dim + feat_idx
+    is_first_step = seq_idx == 0
+
+    # Only the first token of each sequence contributes to dx_prev; keep the
+    # validity mask as well so tail lanes never write out of bounds.
+    tl.store(
+        dx_prev_ptr + dx_prev_offset,
+        tl.cast(prod, dtype=dx_prev_ptr.dtype.element_ty),
+        mask=is_first_step & is_valid
+    )
+
+
+@torch.compile(fullgraph=True)
+def compute_x_k_grad(dk1, x, x_prev):
+    """
+    Args:
+        dk1: (batch*seq_len, hidden_dim)
+        x: (batch, seq_len, hidden_dim)
+        x_prev: (batch, hidden_dim) or (batch, 1, hidden_dim)
+    """
+    if x_prev.dim() == 2:
+        x_prev = x_prev.unsqueeze(1)  # (batch, 1, hidden_dim)
+    xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x  # (batch, seq_len, hidden_dim)
+
+    # (hidden_dim,) --> (1, 1, hidden_dim)
+    grad_x_k = (dk1 * xx.reshape(-1, x.shape[2])).sum(dim=0).view(1, 1, -1)
+    return grad_x_k
+
+
+def rwkv_channel_mixing_bwd(grad_output, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, inplace=True):
+    batch_size = x.shape[0] if x.dim() == 3 else 1
+    seq_len, n_embd = x.shape[-2], x.shape[-1]
+
+    dV = k.transpose(-2, -1) @ grad_output
+    dk = grad_output @ value_weight.transpose(-2, -1)
+
+    BLOCK_SIZE = 4096
+    grid = ((dk.numel() + BLOCK_SIZE - 1) // BLOCK_SIZE,)
+    relu_square_bwd_kernel[grid](
+        dk,
+        k1_K,
+        dk.numel(),
+        BLOCK_SIZE=BLOCK_SIZE
+    )
+
+    dK = k1.transpose(-2, -1) @ dk
+    dk1 = dk @ key_weight.transpose(-2, -1)
+    dk1 = dk1.view(-1, n_embd).contiguous()
+
+    dk_reduced = compute_x_k_grad(dk1, x, x_prev)
+    dx_prev = torch.empty_like(x_prev) if not inplace else x_prev
+    dx = torch.empty_like(x) if not inplace else x
+
+    def grid(meta): return ((batch_size * seq_len * n_embd + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], 1, 1)
+    rwkv_mix_bwd_kernel[grid](
+        dk1,
+        x_k.squeeze(),
+        dx,
+        dx_prev,
+        batch_size,
+        seq_len,
+        n_embd,
+    )
+    # dx_prev has the shape of x_prev, i.e. (batch_size, n_embd)
+    return dx, dx_prev, dk_reduced, dK, dV
+
+
+class Rwkv7ChannelMixing(torch.autograd.Function):
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(ctx, x, x_prev, x_k, key_weight, value_weight, inplace: bool = True):
+        k1 = rwkv_mix_fwd(x, x_prev, x_k)
+        k1_K = k1 @ key_weight
+        k = rwkv_relu_and_square_fwd(k1_K, inplace=True)
+        ctx.save_for_backward(x, x_prev, x_k, key_weight, value_weight)
+        ctx.inplace = inplace
+        return k @ value_weight
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_bwd
+    def backward(ctx, dkv):
+        x, x_prev, x_k, key_weight, value_weight = ctx.saved_tensors
+        k1 = rwkv_mix_fwd(x, x_prev, x_k)
+        k1_K = k1 @ key_weight
+        k = rwkv_relu_and_square_fwd(k1_K, inplace=False)
+        dx, dx_prev, dk_reduced, dK, dV = rwkv_channel_mixing_bwd(
+            dkv, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, ctx.inplace)
+        return dx, dx_prev, dk_reduced, dK, dV, None
+
+
+def channel_mixing_rwkv7(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor,
+                         key_weight: torch.Tensor, value_weight: torch.Tensor, inplace: bool = True):
+    assert x.dim() == 3
+
+    # x is (batch, seq_len, hidden); the cached state for the next chunk is the
+    # last token of each sequence, i.e. x[:, -1], not x[-1, :] (the last batch element).
+    return Rwkv7ChannelMixing.apply(x, x_prev, x_k, key_weight, value_weight, inplace), x[:, -1]
+
+
+def channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight):
+    k1 = rwkv_mix_torch(x, x_prev, x_k)
+    k1_K = k1 @ key_weight
+    k = rwkv_relu_and_square_torch(k1_K)
+    return k @ value_weight, x[:, -1]
diff --git a/fla3/ops/simple_gla/__pycache__/__init__.cpython-310.pyc b/fla3/ops/simple_gla/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abe1f9cc98109caabd5b5040bc854735eacc9fbd
Binary files /dev/null and b/fla3/ops/simple_gla/__pycache__/__init__.cpython-310.pyc differ
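A usage sketch for the channel-mixing path above (illustrative only: it assumes a CUDA device with Triton available, that the `fla3` package is importable as laid out in this diff, and the usual RWKV 4x FFN expansion for the weight shapes):

```python
import torch

from fla3.ops.rwkv7.channel_mixing import channel_mixing_rwkv7, channel_mixing_rwkv7_torch

B, T, D = 2, 16, 64
x = torch.randn(B, T, D, device='cuda')
x_prev = torch.randn(B, D, device='cuda')        # last token of the previous chunk
x_k = 0.1 * torch.randn(1, 1, D, device='cuda')  # token-shift mixing coefficients
key_weight = torch.randn(D, 4 * D, device='cuda') / D ** 0.5
value_weight = torch.randn(4 * D, D, device='cuda') / (4 * D) ** 0.5

out_triton, last_triton = channel_mixing_rwkv7(x, x_prev, x_k, key_weight, value_weight)
out_torch, last_torch = channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight)
print(torch.allclose(out_triton, out_torch, atol=1e-4),
      torch.allclose(last_triton, last_torch))
```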