diff --git a/fla2/ops/mask_gated_delta_rule_t/wy_fast.py b/fla2/ops/mask_gated_delta_rule_t/wy_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..c051b0e879fd5e66a0fb34db6e2b9f743cfab5ae
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/wy_fast.py
@@ -0,0 +1,541 @@
+# -*- coding: utf-8 -*-
+import pdb
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from typing import Optional
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES"
+# https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
+
+
+@triton.jit
+def safe_exp(x):
+    return tl.exp(tl.where(x <= 0, x, float('-inf')))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    Aw,
+    Au,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_Aw = tl.load(p_Aw, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T * r * r, (T, r, r), (r * r, r, 1), (i_t * BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # BT r 1
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask.to(b_k.dtype)  # BT r d
+            b_kb = tl.reshape(b_kb, (BT * r, BK))
+            b_w = tl.dot(b_Aw, b_kb, allow_tf32=False)  # yields BT*r x BK
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * dk + i_k * BK), (BT * r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+    tl.debug_barrier()
+    b_Aw = None
+    p_Au = tl.make_block_ptr(Au + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_Au = tl.load(p_Au, boundary_check=(0, 1)).to(k.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied here: this V loop needs no mask
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT * r, BV))
+        b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "r"],
+)
+@triton.jit
+def gated_chunk_scaled_dot_kkt_fwd_kernel(
+    k,
+    beta,
+    g_cumsum,
+    mask_ij,
+    A,
+    Ag,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    T,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) block
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T * r * r, (T, r, r), (r * r, r, 1), (i_t * BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # BT r 1
+        ij_mask = b_mask * r_mask[None, None, :]  # row entries of column i_r -> BT x r x r
+
+        for i_k in range(tl.cdiv(dk, BK)):  # load and process k blockwise
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)  # BT x BT
+            b_A += dot[:, :, None, None] * ij_mask[:, None, :, :]  # BT x BT x r x r
+
+    b_A = tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh * T // BT + i_t) * BT * BT * r * r, (BT, BT, r, r), (BT * r * r, r * r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+    p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_g_diff = b_g[:, None] - b_g[None, :]
+    b_g_diff = safe_exp(b_g_diff)
+
+    b_Ag = b_A * b_g_diff[:, :, None, None]  # BT x BT decay applied per block
+    p_Ag = tl.make_block_ptr(Ag + (i_bh * T // BT + i_t) * BT * BT * r * r, (BT, BT, r, r), (BT * r * r, r * r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_Ag, b_Ag.to(p_Ag.dtype.element_ty), boundary_check=(0, 1, 2, 3))
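+# --- Editorial sketch (assumption, not part of the original kernels) ---
+# The decay applied above is exp(g_i - g_j) on the strictly lower triangle,
+# with safe_exp sending any positive argument to exp(-inf) = 0. A minimal
+# PyTorch equivalent for one head, assuming a cumulative log-gate vector g of
+# shape (T,); `decay_matrix` is an illustrative name:
+def decay_matrix(g: torch.Tensor) -> torch.Tensor:
+    diff = g[:, None] - g[None, :]
+    # mirror safe_exp: exp(diff) where diff <= 0, exactly 0 elsewhere
+    d = torch.where(diff <= 0, diff, torch.full_like(diff, float('-inf'))).exp()
+    # keep only i > j, as the kernel zeroes the diagonal and upper triangle
+    return torch.tril(d, diagonal=-1)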
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + i_bh * s_A_bh, (T, BT, r, r), (BT * r * r, r * r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row-by-row forward substitution, writing each row back in place
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16 * r, 16 * r))  # BT*r x BT*r
+    p_Ad = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 16 * r, 0), (16 * r, 16 * r), (1, 0))
+    tl.store(p_Ad, b_A.to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')  # M21 = -L22^-1 L21 L11^-1
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_64x64_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_A31 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_A32 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_A41 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_A42 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_A43 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 32 * r), (16 * r, 16 * r), (1, 0))
+
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+    b_A31 = tl.load(p_A31, boundary_check=(0, 1)).to(tl.float32)
+    b_A32 = tl.load(p_A32, boundary_check=(0, 1)).to(tl.float32)
+    b_A41 = tl.load(p_A41, boundary_check=(0, 1)).to(tl.float32)
+    b_A42 = tl.load(p_A42, boundary_check=(0, 1)).to(tl.float32)
+    b_A43 = tl.load(p_A43, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 64 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 64 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad33 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 64 + 32) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad44 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 64 + 48) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), (i_t * 64 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 16) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai33 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 32 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai44 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 48 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai31 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai32 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai41 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai42 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai43 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 32 * r), (16 * r, 16 * r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32)
+    Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32)
+
+    # off-diagonal blocks of the inverse of a lower block-triangular matrix:
+    # M_{i,j} = -L_ii^-1 (sum_{j<k<i} L_{i,k} M_{k,j} + L_{i,j} M_{j,j})
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    Ai32 = -tl.dot(tl.dot(Ai33, b_A32, input_precision='ieee'), Ai22, input_precision='ieee')
+    Ai43 = -tl.dot(tl.dot(Ai44, b_A43, input_precision='ieee'), Ai33, input_precision='ieee')
+
+    Ai31 = -tl.dot(
+        Ai33,
+        tl.dot(b_A31, Ai11, input_precision='ieee') +
+        tl.dot(b_A32, Ai21, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai42 = -tl.dot(
+        Ai44,
+        tl.dot(b_A42, Ai22, input_precision='ieee') +
+        tl.dot(b_A43, Ai32, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai41 = -tl.dot(
+        Ai44,
+        tl.dot(b_A41, Ai11, input_precision='ieee') +
+        tl.dot(b_A42, Ai21, input_precision='ieee') +
+        tl.dot(b_A43, Ai31, input_precision='ieee'),
+        input_precision='ieee'
+    )
+
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
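+# --- Editorial sketch (assumption, not from the original file) ---
+# The merge kernels above rely on the block inverse of a lower block-triangular
+# matrix: for L = [[L11, 0], [L21, L22]],
+#   L^{-1} = [[L11^{-1}, 0], [-L22^{-1} @ L21 @ L11^{-1}, L22^{-1}]].
+# A tiny PyTorch check of that identity (illustrative names):
+def check_block_inverse(L11: torch.Tensor, L21: torch.Tensor, L22: torch.Tensor) -> bool:
+    I11, I22 = torch.linalg.inv(L11), torch.linalg.inv(L22)
+    I21 = -I22 @ L21 @ I11
+    Z = torch.zeros_like(L21)
+    L = torch.cat([torch.cat([L11, Z.mT], -1), torch.cat([L21, L22], -1)], -2)
+    M = torch.cat([torch.cat([I11, Z.mT], -1), torch.cat([I21, I22], -1)], -2)
+    return torch.allclose(M @ L, torch.eye(L.shape[-1], dtype=L.dtype, device=L.device), atol=1e-5)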
+def gated_chunk_scaled_dot_kkt_fwd(k: torch.Tensor,
+                                   beta: torch.Tensor,
+                                   mask: torch.Tensor,
+                                   g_cumsum: Optional[torch.Tensor] = None,
+                                   BT: int = 32,
+                                   output_dtype: torch.dtype = torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]  # mask: B H T r r
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    A = torch.empty(B * H * NT, BT * BT, r * r, device=k.device, dtype=output_dtype).contiguous()
+    Ag = torch.empty(B * H * NT, BT * BT, r * r, device=k.device, dtype=output_dtype).contiguous()
+    gated_chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+        k, beta, g_cumsum, mask, A, Ag,
+        T * K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A, Ag
+
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT * 16 * r, 16 * r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B * H)](
+        A, Ad,
+        T * BT * r * r,  # s_A_bh
+        T * 16 * r * r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    A = rearrange(A, 'b (t l) (c r) -> b (t c) (l r)', t=BT, c=r).contiguous()  # BT*r x BT*r
+    if BT == 32:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_32x32_inverse_kernel[(NT, B * H)](
+            A, Ad, Ai,
+            T * BT * r * r,  # s_A_bh and s_Ai_bh
+            T * 16 * r * r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+    if BT == 64:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_64x64_inverse_kernel[(NT, B * H)](
+            A, Ad, Ai,
+            T * BT * r * r,  # s_A_bh and s_Ai_bh
+            T * 16 * r * r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+
+def gated_fwd_recompute_w_u(k, v, beta, mask, Aw, Au, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)  # 32
+    BV = min(triton.next_power_of_2(V), 64)
+    gated_fwd_recompute_w_u_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, Aw, Au,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
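+# --- Editorial usage sketch (illustrative, not from the original file) ---
+# How the three helpers above compose; shapes and BT = 32 (the 32x32 merge
+# path) are assumptions, and `g` stands for a cumulative log-gate:
+# B, H, T, D, r, BT = 1, 2, 64, 64, 2, 32
+# k = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
+# v = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
+# beta = torch.rand(B, H, T, device='cuda', dtype=torch.bfloat16)
+# g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda')).cumsum(-1).to(k.dtype)
+# mask = torch.randn(B, H, T, r, r, device='cuda', dtype=k.dtype)
+# A, Ag = gated_chunk_scaled_dot_kkt_fwd(k, beta, mask, g, BT)
+# Aw = solve_tril(A, mask, k, BT, output_dtype=k.dtype)
+# Au = solve_tril(Ag, mask, k, BT, output_dtype=k.dtype)
+# w, u = gated_fwd_recompute_w_u(k, v, beta, mask, Aw, Au, BT)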
diff --git a/fla2/ops/mask_gated_delta_rule_t/wy_fast_test.py b/fla2/ops/mask_gated_delta_rule_t/wy_fast_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..22aba7278db186f6b7139b33d446813078728861
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/wy_fast_test.py
@@ -0,0 +1,676 @@
+# -*- coding: utf-8 -*-
+import pdb
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES"
+# https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
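+# --- Editorial sketch (assumption: plain r = 1, single chunk; not from the original file) ---
+# The kernels below build T_mat = (I + tril(diag(beta) @ K @ K^T, -1))^{-1} and
+# then w = T_mat @ (diag(beta) @ K), u = T_mat @ (diag(beta) @ V); a PyTorch
+# reference with illustrative names:
+def wy_repr_reference(K_mat: torch.Tensor, V_mat: torch.Tensor, beta_vec: torch.Tensor):
+    # K_mat: (BT, dk), V_mat: (BT, dv), beta_vec: (BT,)
+    kb = K_mat * beta_vec[:, None]
+    eye = torch.eye(K_mat.shape[0], dtype=K_mat.dtype, device=K_mat.device)
+    T_mat = torch.linalg.inv(eye + torch.tril(kb @ K_mat.T, diagonal=-1))
+    return T_mat @ kb, T_mat @ (V_mat * beta_vec[:, None])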
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) block
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column-wise read, so these offsets index rows
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # row entries of column i_r
+
+        for i_k in range(tl.cdiv(dk, BK)):  # load and process k blockwise
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, BT):  # at this point the matrix is laid out as BT x r x BT x r
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # extract row i: BT x r x r
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # done as a matrix product: BT x BT x r x r
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # BT x r x r
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row-by-row forward substitution, writing each row back in place
+    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT * r, BT * r))  # BT*r x BT*r
+    p_A = tl.make_block_ptr(A + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))  # the old implementation needed many more multiplications
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # matrix inversion done; next compute the outputs
+    b_A = b_A.to(k.dtype.element_ty)
+
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # BT r d
+            b_kb = tl.reshape(b_kb, (BT * r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # yields BT*r x BK
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * dk + i_k * BK), (BT * r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied here: this V loop needs no mask
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT * r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        # r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # BT r d
+            b_kb = tl.reshape(b_kb, (BT * r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # yields BT*r x BK
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * dk + i_k * BK), (BT * r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied here: this V loop needs no mask
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT * r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
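+# --- Editorial sketch (assumption, not from the original file) ---
+# The row-by-row loop in fwd_prepare_wy_repr_kernel inverts a unit lower-
+# triangular matrix by forward substitution; the r = 1 case in PyTorch:
+def inv_unit_lower(Bm: torch.Tensor) -> torch.Tensor:
+    # Bm: (n, n); returns (I + tril(Bm, -1))^{-1}
+    n = Bm.shape[-1]
+    eye = torch.eye(n, dtype=Bm.dtype, device=Bm.device)
+    L = eye + torch.tril(Bm, diagonal=-1)
+    return torch.linalg.solve_triangular(L, eye, upper=False, unitriangular=True)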
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT * r, BT * r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blockwise over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # r*BT x BV
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # BT r BV
+        b_v_beta = tl.reshape(b_v_beta, (BT * r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # BT*r x BT*r
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # BT*r x BV
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # note: this step differs from the unmasked case
+        b_dv = sum_dv * b_beta[:, None]
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K // r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)  # column i_r
+        rmask = tl.arange(0, r) == i_r  # column i_r
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # BT r d
+            b_k_beta = tl.reshape(b_k_beta, (BT * r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+
+            b_ss = b_dk_beta * b_beta[:, None, None] * b_k[:, None, :]
+            b_ss = tl.reshape(tl.permute(b_ss, (2, 0, 1)), (BT * BK, r))
+            b_ss = tl.sum(b_ss, 0)
+            # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1))
+            b_dmask += (b_ss[:, None] * rmask[None, :]).to(tl.float32)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to dA of the kk^T product: mostly zeros, including the diagonal blocks
+
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r))
+    # now BT x r x BT x r
+
+    for i_r in range(r):  # take only the i_r slice
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)  # column i_r
+        rmask = tl.arange(0, r) == i_r  # column i_r
+        g = tl.sum(tl.where(rmask[None, None, None, :], b_dA, 0), -1)  # BT r BT: extract column i_r
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # BT BT
+        # the corresponding column contribution
+
+        for i_k in range(tl.cdiv(block_k, BK)):  # i_k = 1
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)  # BT x BK
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+            beta_kkt = tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)  # BT BT
+
+            beta_y = beta_kkt[:, None, :] * g
+            beta_y = tl.reshape(tl.permute(beta_y, (2, 0, 1)), (BT * BT, r))
+            betas = tl.sum(beta_y, 0)
+            b_dmask += (betas[:, None] * rmask[None, :]).to(tl.float32)
+
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+    p_dmask = tl.make_block_ptr(dmask + (i_bh * (T // BT) + i_t) * r * r, (r, r), (r, 1), (0, 0), (r, r), (1, 0))
+    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0, 1))
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "r"],
+)
+@triton.jit
+def chunk_scaled_dot_kkt_fwd_kernel(
+    k,
+    beta,
+    mask_ij,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    T,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) block
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column-wise read, so these offsets index rows
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # row entries of column i_r
+
+        for i_k in range(tl.cdiv(dk, BK)):  # load and process k blockwise
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh * T // BT + i_t) * BT * BT * r * r, (BT, BT, r, r), (BT * r * r, r * r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + i_bh * s_A_bh, (T, BT, r, r), (BT * r * r, r * r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row-by-row forward substitution, writing each row back in place
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16 * r, 16 * r))  # BT*r x BT*r
+    p_Ad = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 16 * r, 0), (16 * r, 16 * r), (1, 0))
+    tl.store(p_Ad, b_A.to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh * s_A_bh, (T, 32, r, r), (32 * r * r, r * r, r, 1), (i_t * 32 + 16, 0, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A21 = tl.permute(b_A21, (0, 2, 1, 3))
+    b_A21 = tl.reshape(b_A21, (16 * r, 16 * r))  # BT*r x BT*r
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')  # M21 = -L22^-1 L21 L11^-1
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    A = torch.empty(B * H * NT, BT * BT, r * r, device=k.device, dtype=output_dtype).contiguous()
+    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+        k, beta, mask, A,
+        T * K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A
+
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT * 16 * r, 16 * r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B * H)](
+        A, Ad,
+        T * BT * r * r,  # s_A_bh
+        T * 16 * r * r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    NT = triton.cdiv(T, BT)
+    Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+    merge_16x16_to_32x32_inverse_kernel[(NT, B * H)](
+        A, Ad, Ai,
+        T * BT * r * r,  # s_A_bh and s_Ai_bh
+        T * 16 * r * r,  # s_Ad_bh
+        T, r, BT
+    )
+    return Ai
+
+
+def fwd_prepare_wy_repr2(k, v, beta, mask, BT):
+    A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
+    A = solve_tril(A=A, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+    w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+    return w, u, A
+
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT * BT * r, BT * r, device=k.device, dtype=k.dtype)
+    fwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
+
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)  # 32
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
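+# --- Editorial usage sketch (illustrative; mirrors the commented checks in __main__ below) ---
+# The two construction paths above should agree up to dtype rounding:
+# w1, u1, A1 = fwd_prepare_wy_repr2(k, v, beta, mask, BT)  # kkt kernel + solve_tril + recompute
+# w2, u2, A2 = fwd_prepare_wy_repr(k, v, beta, mask, BT)   # single fused kernel
+# assert (w1 - w2).abs().max() < 1e-1 and (u1 - u2).abs().max() < 1e-1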
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dmask = torch.zeros([B * H * NT, r, r], device=k.device, dtype=k.dtype).contiguous()
+    bwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta, dmask,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask
+
+
+class WYRepresentationPreparation(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, k, v, beta, mask, chunk_size=64):
+        ctx.BT = chunk_size
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, ctx.BT)
+        ctx.save_for_backward(k, v, beta, mask, A)
+        return w, u
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, dw, du):
+        k, v, beta, mask, A = ctx.saved_tensors
+        BT = ctx.BT
+        dk, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT)
+        return dk, dv, dbeta, dmask, None
+
+prepare_wy_repr = WYRepresentationPreparation.apply
+
+
+def naive(k, v, beta, maskij, chunk_size):
+    l_org = k.shape[2]
+    l_new = triton.next_power_of_2(l_org)
+    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new - l_org, :]], dim=2)
+    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new - l_org, :]], dim=2)
+    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new - l_org]], dim=2)
+    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
+    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
+
+    b, h, nt, BT, dk = k.shape
+    dv = v.shape[-1]
+    r = maskij.shape[-1]
+    k_beta = k * beta[..., None]
+    k_beta = rearrange(k_beta, 'b h n t (r k) -> b h n t r k', r=r)
+    k_beta = torch.einsum('b h n t r k,l r -> b h n t l r k', k_beta, maskij)
+    k_beta = rearrange(k_beta, 'b h n t l r k -> b h n t l (r k)')  # l = 1, r*k = original K
+    v_beta = v * beta[..., None]
+    v_beta = v_beta.unsqueeze(-2).expand(-1, -1, -1, -1, r, -1)
+    ki = rearrange(k, 'b h n c (r k) -> b h n r c k', r=r)
+
+    attn = ki @ ki.transpose(-1, -2)
+    attn = torch.tril(attn, diagonal=-1)  # b h n r c c
+    attn = torch.einsum('b h n r t l,c r -> b h n t l c r', attn, maskij)  # b h n (r r) (c c)
+    attn = torch.einsum('b h n t l c r,b h n t -> b h n t l c r', attn, beta)
+
+    o = torch.zeros_like(k_beta)
+    o2 = torch.zeros_like(v_beta)
+
+    o[..., 0, :, :] = k_beta[..., 0, :, :].clone()
+    o2[..., 0, :, :] = v_beta[..., 0, :, :].clone()
+    for i in range(1, chunk_size):
+        o_i = o[..., :i, :, :].clone()  # b h n :t c c
+        o[..., i, :, :] = -(attn[:, :, :, i, :i, :, :] @ o_i).sum(3) + k_beta[..., i, :, :]
+        o2_i = o2[..., :i, :, :].clone()  # one dimension fewer
+        o2[..., i, :, :] = -(attn[:, :, :, i, :i, :, :] @ o2_i).sum(3) + v_beta[..., i, :, :]
+    return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2))
+if __name__ == "__main__":
+    # all computation and checks live here
+    import sys
+    torch.manual_seed(42)
+    sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    seq_len = 128
+    b = 2
+    h = 2
+    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    v = torch.randn(b, h, seq_len, 128)
+    beta = torch.rand(b, h, seq_len).sigmoid()
+    require_grad = True
+    BT = 32
+    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta))
+    r = 4
+    # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous()
+    mask = torch.randn([r, r])
+    mask = mask.cuda().requires_grad_(require_grad).contiguous()
+    # w, u, a0 = fwd_prepare_wy_repr(k, v, beta, mask, 16)
+    # w2, u2 = fwd_recompute_w_u(k, v, beta, mask, a0, 16)
+
+    k2 = rearrange(k, 'b h (n t) (r k) -> b h n r t k', t=BT, r=r)
+    b2 = rearrange(beta, 'b h (n t) -> b h n t', t=BT)
+    a1 = (k2 * b2.unsqueeze(-2).unsqueeze(-1)) @ k2.transpose(-1, -2)  # b h n r t t
+    qq = torch.tril(a1, diagonal=-1)
+    qq = torch.einsum('b h n r t l,c r -> b h n t c l r', qq, mask)
+    sf = rearrange(qq, 'b h n t c l r -> b h n (t c) (l r)')
+    sf = rearrange(sf, 'b h n (t c) (l r) -> b h n t l c r', c=r, r=r)  # this one
+
+    # the elongated block diagonal
+    i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :]))
+    s = sf + i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
+    s = rearrange(s, 'b h n a d c r -> b h n (a c) (d r)')
+    s = torch.linalg.inv(s.float()).to(k)  # matrix inverse  # b h n (t r) (t r)
+
+    # A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32)  # (b h nt) (BT bt) (r r)
+    # Ad = solve_tril(A, mask, k, BT, output_dtype=torch.bfloat16)
+    # s = rearrange(s, 'b h n a c -> (b h n) a c')
+    # print(Ad.shape)
+    # print(s.shape)
+
+    w, u, As = fwd_prepare_wy_repr2(k, v, beta, mask, BT)
+    # w2, u2, Ad2 = fwd_prepare_wy_repr(k, v, beta, mask, BT)
+
+    # print((w2 - w).abs().max())
+    # print((u2 - u).abs().max())
+    # print((As - Ad2).abs().max())
+
+    # print((Ad - s).abs().max())
+    # print(Ad - s)
+
+    # print((As - s).abs().max())
+    # print(As - s)
+    # B*H*NT, BT*r, 16*r
+    # k_exp = torch.einsum('b h n r t k,b h n t -> b h n r t k', k2, b2)
+    # k_exp = torch.einsum('b h n r t k,c r -> b h n r t k c', k_exp, mask)
+    # k_exp = rearrange(k_exp, 'b h n r t k c -> b h n (t c) (r k)')
+    # wc = s_copy @ k_exp
+
+    # v_exp = rearrange(v, 'b h (n t) v -> b h n t v', t=BT)
+    # v_exp = torch.einsum('b h n t v,b h n t -> b h n t v', v_exp, b2)
+    # v_exp = v_exp.unsqueeze(4).expand(-1, -1, -1, -1, r, -1)
+    # v_exp = rearrange(v_exp, 'b h n t r v -> b h n (t r) v')
+    # uc = s_copy @ v_exp
+    # wc, uc = map(lambda x: rearrange(x, 'b h n t r -> b h (n t) r'), (wc, uc))
+    # do = torch.rand_like(wc)
+    # do2 = torch.rand_like(uc)  # b h n t t
+    # o1, o2 = naive(k.clone(), v.clone(), beta.clone(), mask.clone(), BT)  # this reference code is buggy
+    # do = torch.rand_like(o1)
+    # do2 = torch.rand_like(o2)  # b h n t t
+    # if require_grad:
+    #     o1.backward(do, retain_graph=True)
+    #     o2.backward(do2, retain_graph=True)
+    #     k_grad2, v_grad2, beta_grad2, mask_grad2 = k.grad, v.grad, beta.grad, mask.grad
+
+    # w0, u0, s0 = fwd_prepare_wy_repr(k, v, beta, mask, 16)
+    # k_grad, v_grad, beta_grad, mask_grad = bwd_prepare_wy_repr(k, v, beta, mask, s0, do, do2, BT)
+
+    # print((o1 - w0).abs().max())
+    # print((o2 - u0).abs().max())
+    # print((k_grad - k_grad2).abs().max())
+    # print((v_grad - v_grad2).abs().max())
+    # print((beta_grad - beta_grad2).abs().max())
+    # print((mask_grad - mask_grad2).abs().max())
+    # print(mask_grad)
+    # print(mask_grad2)
diff --git a/fla2/ops/rwkv6/chunk.py b/fla2/ops/rwkv6/chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5533c076f1220df4175dab91ada5f62ccbf942c
--- /dev/null
+++ b/fla2/ops/rwkv6/chunk.py
@@ -0,0 +1,931 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from fla.ops.utils import chunk_global_reversed_cumsum
+from fla.utils import contiguous
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BS': 16}, num_warps=2),
+        triton.Config({'BS': 16}, num_warps=4),
+        triton.Config({'BS': 16}, num_warps=8),
+        triton.Config({'BS': 32}, num_warps=2),
+        triton.Config({'BS': 32}, num_warps=4),
+        triton.Config({'BS': 32}, num_warps=8),
+        triton.Config({'BS': 64}, num_warps=2),
+        triton.Config({'BS': 64}, num_warps=4),
+        triton.Config({'BS': 64}, num_warps=8),
+    ],
+    key=['S']
+)
+@triton.jit
+def chunk_rwkv6_fwd_kernel_cum(
+    s,
+    o,
+    o_minus_s,
+    s_s_h,
+    s_s_t,
+    s_s_d,
+    T: tl.constexpr,
+    S: tl.constexpr,
+    BT: tl.constexpr,
+    BS: tl.constexpr
+):
+    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    o_i = tl.arange(0, BT)
+    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
+
+    p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
+    # [BT, BS]
+    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+    b_o = tl.dot(m_s, b_s, allow_tf32=False)
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1))
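+# --- Editorial sketch (assumption, not from the original file) ---
+# The kernel above realizes an inclusive cumulative sum along the chunk as a
+# matmul with a lower-triangular matrix of ones; a PyTorch equivalent:
+def chunk_cumsum_reference(s: torch.Tensor):
+    # s: (BT, BS); returns (o, o_minus_s) exactly as the kernel stores them
+    m = torch.ones(s.shape[0], s.shape[0], dtype=s.dtype, device=s.device).tril()
+    o = m @ s
+    return o, o - s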
+ + p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o_minus_s = tl.make_block_ptr(o_minus_s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_o_minus_s, (b_o - b_s).to(p_o_minus_s.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def post_process_grad( + q, + k, + v, + u, + do, + dk, + dq, + du, + scale, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + H, + T: tl.constexpr, + BT: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_h = i_bh % H + + # Note that BK = tl.next_power_of_2(K), BV = tl.next_power_of_2(V) + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, 0), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_u = tl.load(p_u, boundary_check=(0,)) + + b_vdo = tl.sum(b_v * b_do, axis=1) + b_du = b_vdo[:, None] * b_k * b_q * scale + b_dq = b_vdo[:, None] * b_k * b_u[None, :] * scale + b_dk = b_vdo[:, None] * b_q * b_u[None, :] * scale + + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + tl.store(p_du, b_du.to(p_du.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_fwd_kernel_h( + k, + v, + g, + h, + h0, + ht, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + for i_t in range(NT): + o_t = min(i_t * BT + BT, T) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), 
(s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BT] + b_g = tl.load(p_g, boundary_check=(0, 1)) + if i_t < NT - 1: + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + else: + b_gn = tl.min(b_g, axis=1) + b_h *= tl.exp(b_gn)[:, None] + b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype) + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_fwd_kernel_intra( + q, + k, + g, + gs, + u, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + H, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + DK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + i_h = i_bh % H + n_bh = tl.num_programs(2) + + o_k = i_k * BK + tl.arange(0, BK) + o_q = i_t * BT + i_i * BC + m_k = o_k < K + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BK,] + b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_gs - b_gn[None, :]) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_qg, b_kg, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_q_u = tl.make_block_ptr(q + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + o_i = tl.arange(0, BC) + o_g = i_bh * T * K + (i_t * BT + i_j * BC) * K + o_k + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + p_u = tl.make_block_ptr(u + i_h * 
DK, (DK,), (1,), (i_k * BK), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + b_gk = tl.load(g + o_g + j * K, mask=(m_k & ((i_t * BT + i_j * BC + j) < T)), other=0).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_gs - b_gk[None, :]) * scale, 1) + b_A = tl.where(o_i > j, b_A, 0.) + # self + b_q_u = tl.load(p_q_u, boundary_check=(0,)).to(tl.float32) + b_A_u = tl.sum(b_q_u * b_k * b_u * scale, axis=0) + m_u = tl.arange(0, BC) == j + b_A = tl.where(m_u, b_A_u, b_A) + tl.store(A + o_A + j, b_A.to(A.dtype.element_ty), mask=m_A) + p_k = tl.advance(p_k, (K,)) + p_q_u = tl.advance(p_q_u, (K,)) + + +@triton.jit +def chunk_rwkv6_fwd_kernel_inter( + q, + v, + gs, + h, + o, + A, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * tl.exp(b_gs)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_bwd_kernel_dh( + q, + g, + gs, + do, + dh, + dh0, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + o_t = min(i_t * BT + BT, T) + + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,)) + + # [BK, BT] + b_q = 
tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BK, BV] + b_dh *= tl.exp(b_gn)[:, None] + # [BK, BT] + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_q = (b_q * tl.exp(b_gs)).to(b_q.dtype) + + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_bwd_kernel_inter( + k, + v, + h, + g, + gs, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + o_t = min(i_t * BT + BT, T) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gq = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((o_t - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gq = tl.load(p_gq, boundary_check=(0, 1)) + b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + + b_dq = b_dq * tl.exp(b_gq) + b_dk = b_dk * b_gn + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * 
BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] > o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_rwkv6_bwd_kernel_intra( + q, + k, + g, + gs, + dA, + dq, + dk, + s_k_h, + s_k_t, + s_k_d, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_k = i_k * BK + tl.arange(0, BK) + o_q = i_t * BT + i_i * BC + m_k = o_k < K + + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BK,] + b_gn = tl.load(g + i_bh * T * K + (o_q - 1) * K + o_k, mask=(m_k & (i_i > 0) & (o_q <= T)), other=0) + # [BC, BK] + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg, allow_tf32=False) + b_dq *= tl.exp(b_gs - b_gn[None, :]) + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + b_gkj = tl.load(g + i_bh * T * K + (o_q + j) * K + o_k, mask=(m_k & ((o_q + j) < T)), other=0) + # [BC, BK] + m_i = o_i[:, None] > j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_gs - b_gkj[None, :]), 0.) 
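+ # b_dq now holds the intra-chunk contribution; below it is accumulated on top of the inter-chunk gradient that chunk_rwkv6_bwd_kernel_inter already stored in dq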
+ + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gs = tl.make_block_ptr(gs + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gs = tl.load(p_gs, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_gs - b_gn[None, :])).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False) + b_dk *= tl.exp(b_gn[None, :] - b_gk) + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_gqj = tl.make_block_ptr(gs + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] < j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.) 
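+ # likewise for dk: fold the intra-chunk contribution into the inter-chunk gradient previously stored by chunk_rwkv6_bwd_kernel_inter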
+ + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkRWKV6Function(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level): + q = r # alias + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_rwkv6_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float) + + g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float) + def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H)) + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g_org = g_org.view(B, H, NT, BT, -1) + # g = g_org.cumsum(-2).view(B, H, T, -1) + # gs = g - g_org + chunk_rwkv6_fwd_kernel_cum[grid]( + g_org, g, gs, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=K, BT=BT + ) + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + h0=initial_state if initial_state is not None else None, + ht=final_state if final_state is not None else None + ) + A = q.new_zeros(NK, B, H, T, BT) + grid = (NK, NT * NC * NC, B * H) + chunk_rwkv6_fwd_kernel_intra[grid]( + q, k, g, gs, u, A, + k.stride(1), k.stride(2), k.stride(3), + scale, + H=H, T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, DK=K, + num_warps=num_warps, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + o = torch.empty_like(v) + + grid = (NV, NT, B * H) + chunk_rwkv6_fwd_kernel_inter[grid]( + q, v, gs, h, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + if checkpoint_level == 1: + # level 1: drop the chunk-level hidden states here; backward recomputes them. + # note: the original guard was `checkpoint_level > 1`, which never fires given + # the `checkpoint_level in [0, 1]` assertion, so level 1 saved no memory. + # initial_state is kept so that the recomputation in backward stays correct. + h = None + del g, gs + ctx.save_for_backward(q, k, v, g_org, u, h, initial_state, A) + ctx.BT = BT + ctx.scale = scale + ctx.checkpoint_level = checkpoint_level + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + q, k, v, g, u, h, initial_state, A = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, NT, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H)
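+ # identical launch to the forward pass: rebuilds the per-chunk hidden states h that were dropped when checkpoint_level == 1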
+ chunk_rwkv6_fwd_kernel_h[grid]( + k, v, g, h, h0, ht, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + def bwd_inner(q, g, gs, h0, do, B, H, T, K, V, BT, BK, BV, NT, scale): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + dh0 = torch.empty_like(h0) if h0 is not None else None + grid = (NK, NV, B * H) + chunk_rwkv6_bwd_kernel_dh[grid]( + q, g, gs, do, dh, dh0, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), dh.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=h0 is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return dh, dh0 + + # recompute cumulative log decays. + g_org, g, gs = g, torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float) + def grid(meta): return ((triton.cdiv(meta['S'], meta['BS']), NT, B * H)) + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_rwkv6_fwd_kernel_cum[grid]( + g_org, g, gs, + g.stride(1), g.stride(2), g.stride(3), + T=T, S=K, BT=BT + ) + + # rerun the forward pass to get h if checkpoint_level == 1 + if ctx.checkpoint_level == 1: + h = fwd_inner( + q=q, k=k, v=v, g=g, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + h0=initial_state if initial_state is not None else None, + ht=None + ) + + scale = ctx.scale + # g, gs: torch.float32 + dh, dh0 = bwd_inner( + q.to(torch.float), g, gs, initial_state, do.to(torch.float), + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + scale=scale + ) + dh = dh.to(q) + if initial_state is not None: + dh0 = dh0.to(q) + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dv = v.new_empty(NK, *v.shape) + dA = q.new_zeros(B, H, T, BT) + grid = (NK, NT, B * H) + chunk_rwkv6_bwd_kernel_inter[grid]( + k, v, h, g, gs, A, do, dh, dq, dk, dv, dA, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0, dtype=dv.dtype) + grid = (NK, NT * NC, B * H) + chunk_rwkv6_bwd_kernel_intra[grid]( + q, k, g, gs, dA, dq, dk, + k.stride(1), k.stride(2), k.stride(3), + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + + # TODO: fuse? + dg = (dq * q)[:, :, 1:] - (dk * k)[:, :, 0:-1] + dg = torch.nn.functional.pad(dg, (0, 0, 0, 1, 0, 0, 0, 0), value=0) + dg = chunk_global_reversed_cumsum(dg).to(g) + # equivalent to the following pytorch code.
+ # du = ((do * v).sum(-1)[..., None] * k * q * scale).sum(-2).to(u) + # dq += ((do * v).sum(-1)[..., None] * k * scale * u[:, :, None, :]) + # dk += ((do * v).sum(-1)[..., None] * q * scale * u[:, :, None, :]) + BT = 64 + grid = (triton.cdiv(T, BT), B * H) + du = torch.empty_like(g, dtype=torch.float) + post_process_grad[grid]( + q, k, v, u, do, dk, dq, du, scale, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), H=H, + T=T, BT=BT, K=K, V=V, BK=triton.next_power_of_2(K), BV=triton.next_power_of_2(V), + num_warps=4 + ) + du = du.sum([0, 2]) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None + + +def chunk_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + checkpoint_level: Optional[int] = 0 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + receptance of shape `(B, H, T, K)`. Alias: q, query in linear attention. + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + g (torch.Tensor): + data-dependent decays of shape `(B, H, T, K)` in log space! Alias: w. + u (torch.Tensor): + bonus of shape `(H, K)` + scale (Optional[float]): + Scale factor for the RWKV6 attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + checkpoint_level (Optional[int]): + Checkpointing level; higher values save more memory at the cost of more recomputation during the backward pass. + Default: `0`: + - Level `0`: store forward hidden states for backprop. + - Level `1`: recompute the forward hidden states during backward.
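+ Example (an illustrative sketch, not part of the original source; shapes are arbitrary and a CUDA device is assumed): + >>> import torch + >>> B, H, T, K, V = 2, 4, 128, 64, 64 + >>> r = torch.randn(B, H, T, K, device='cuda') + >>> k = torch.randn(B, H, T, K, device='cuda') + >>> v = torch.randn(B, H, T, V, device='cuda') + >>> g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, K, device='cuda')) + >>> u = torch.randn(H, K, device='cuda') + >>> o, ht = chunk_rwkv6(r, k, v, g, u, output_final_state=True)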
+ """ + assert checkpoint_level in [0, 1] + if scale is None: + scale = r.shape[-1] ** -0.5 + o, final_state = ChunkRWKV6Function.apply(r, k, v, g, u, scale, initial_state, output_final_state, checkpoint_level) + return o, final_state + + +if __name__ == "__main__": + import torch.nn.functional as F + + from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6 + B = 8 + H = 4 + L = 1024 + K = 100 + V = 120 + + torch.manual_seed(0) + dtype = torch.float + q = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True) + k = torch.randn(B, H, L, K).cuda().to(dtype).requires_grad_(True) + v = torch.randn(B, H, L, V).cuda().to(dtype).requires_grad_(True) + w = (-torch.randn(B, H, L, K).exp()).cuda().requires_grad_(True) + u = torch.randn(H, K).cuda().to(dtype).requires_grad_(True) + h0 = torch.randn(B, H, K, V).cuda().to(dtype).requires_grad_(True) + do = torch.rand_like(v).cuda() + o, ht = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True) + o.backward(do) + dq, q.grad = q.grad.clone(), None + dk, k.grad = k.grad.clone(), None + dv, v.grad = v.grad.clone(), None + dw, w.grad = w.grad.clone(), None + du, u.grad = u.grad.clone(), None + dh0, h0.grad = h0.grad.clone(), None + o2, ht2 = chunk_rwkv6(q, k, v, w, u, initial_state=h0, output_final_state=True) + o2.backward(do) + torch.testing.assert_close(o, o2, rtol=0, atol=1e-4) + torch.testing.assert_close(ht, ht2, rtol=0, atol=1e-4) + torch.testing.assert_close(q.grad, dq, rtol=0, atol=1e-4) + torch.testing.assert_close(k.grad, dk, rtol=0, atol=1e-4) + torch.testing.assert_close(v.grad, dv, rtol=0, atol=1e-4) + torch.testing.assert_close(w.grad, dw, rtol=0, atol=1e-4) + torch.testing.assert_close(u.grad, du, rtol=0, atol=2e-4) + torch.testing.assert_close(h0.grad, dh0, rtol=0, atol=2e-4) + + print("All tests passed!") + + @triton.testing.perf_report( + triton.testing.Benchmark( + # argument names to use as an x-axis for the plot + x_names=['T'], + # different possible values for `x_name` + x_vals=[128 * 2 ** i for i in range(0, 8)], + # argument name whose value corresponds to a different line in the plot + line_arg='provider', + # possible values for `line_arg`` + line_vals=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'], + # label name for the lines + line_names=['recurrent', 'chunk', 'recurrent_bwd', 'chunk_bwd'], + # line styles + styles=[('green', '-'), ('blue', '--'), ('red', '-.'), ('cyan', ':'), ('yellow', 'dotted'), ('black', 'dashed')], + ylabel="Execution Time (ms)", # label name for the y-axis + # name for the plot. Used also as a file name for saving the plot. 
+ plot_name="Performance", + args={}, + ) + ) + def benchmark(T, provider): + device = 'cuda' + dtype = torch.bfloat16 + requires_grad = True + B, H, K = 16, 4, 128 + + q = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype) + k = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype) + v = torch.randn(B, H, T, K, device=device, requires_grad=requires_grad, dtype=dtype) + w = F.logsigmoid(torch.randn(B, H, T, K)).to(dtype=dtype, device=device).requires_grad_(True) + u = torch.randn(H, K, device=device, requires_grad=requires_grad, dtype=dtype) + + do = torch.ones_like(q, dtype=dtype) + quantiles = [0.5, 0.2, 0.8] + results = 0, 0, 0 + if provider == 'recurrent': + results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u), quantiles=quantiles) + if provider == 'chunk': + results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u), quantiles=quantiles) + if provider == 'recurrent_bwd': + results = triton.testing.do_bench(lambda: fused_recurrent_rwkv6(q, k, v, w, u) + [0].backward(do), quantiles=quantiles) + if provider == 'chunk_bwd': + results = triton.testing.do_bench(lambda: chunk_rwkv6(q, k, v, w, u)[0].backward(do), quantiles=quantiles) + return results + benchmark.run(print_data=True) diff --git a/fla2/ops/rwkv6/chunk_naive.py b/fla2/ops/rwkv6/chunk_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2ac664f5079a20eabe9b11c19c1cff6755c658 --- /dev/null +++ b/fla2/ops/rwkv6/chunk_naive.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def naive_chunk_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + chunk_size: int = 32 +): + assert q.shape[-2] % chunk_size == 0 + orig_dtype = q.dtype + num_chunk = q.shape[-2] // chunk_size + u = u.unsqueeze(0) + + q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) + + w_cumsum = w.cumsum(-2) + + kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() + wkv = kw.transpose(-1, -2) @ v + + wkv_new = torch.zeros_like(wkv) + + for i in range(num_chunk - 1): + wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] + + o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) + + o_intra = torch.zeros_like(o_inter) + for i in range(chunk_size): + attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) + mask = (torch.arange(0, chunk_size) < i).to(attn.device) + attn.masked_fill_(~mask, 0) + intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) + intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] + o_intra[:, :, :, i] = intra_inter_o + intra_intra_o + o = o_inter + o_intra + return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) diff --git a/fla2/ops/rwkv6/recurrent_fuse.py b/fla2/ops/rwkv6/recurrent_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..baa61ae4e27683d7625a8ca06becbdabd4559688 --- /dev/null +++ b/fla2/ops/rwkv6/recurrent_fuse.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2024, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import chunk_global_reversed_cumsum +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +@triton.jit +def 
fused_recurrent_rwkv6_fwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, K] + v, # value [B, H, T, V] + w, # log gate [B, H, T, K] + u, # bonus [B, H, K] + o, # output [B, H, T, V] + # initial hidden state initialization [B, H, K, V] + h0, + ht, # final hidden state [B, H, K, V] + s_k_h, # stride size: T * K + s_v_h, # stride size: T * V + scale, # K ** -0.5 + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bv[:, None] & mask_bk[None, :] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32) + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32) + b_w = tl.exp(b_w) + b_kv = b_k[None, :] * b_v[:, None] + b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + b_h = b_h * b_w[None, :] + b_h += b_kv + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv) + p_q += -K if REVERSE else K + p_k += -K if REVERSE else K + p_o += -V if REVERSE else V + p_v += -V if REVERSE else V + p_w += -K if REVERSE else K + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_rwkv6_bwd_kernel_dq( + # B: B, H: H, T: T, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + k, # key [B, H, T, V] + v, # value [B, H, T, V] + w, # log gate [B, H, T, K] + u, # bonus [B, H, K] + + do, # gradient of output [B, H, T, V] + dq, # gradient of query [NV, B, H, T, K] + dq_aux, # gradient of query_aux [NV, B, H, T, K] + + # initial hidden state initialization [B, H, K, V] + h0, + + s_k_h, # stride size: T * K + s_v_h, # stride size: T * V + + scale, # K ** -0.5 + B: tl.constexpr, # B + H: tl.constexpr, # H + T: tl.constexpr, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0) + p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_dq_aux = dq_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0) + p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK + + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + mask_kv = mask_bv[:, None] & mask_bk[None, :] + b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32) + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_kv = b_k[None, :] * b_v[:, None] + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32) + b_w = tl.exp(b_w) + h_q = b_h * b_do[:, None] + b_dq = tl.sum(h_q + b_kv * b_u[None, :] * b_do[:, None], axis=0) + b_dq *= scale + b_dq_aux = tl.sum(h_q, axis=0) + b_h = b_h * b_w[None, :] + b_h += b_kv + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_bk) + tl.store(p_dq_aux, b_dq_aux.to(p_dq_aux.dtype.element_ty), mask=mask_bk) + p_k += -K if REVERSE else K + p_do += -V if REVERSE else V + p_v += -V if REVERSE else V + p_w += -K if REVERSE else K + p_dq += -K if REVERSE else K + p_dq_aux += -K if REVERSE else K + + +@triton.jit +def fused_recurrent_rwkv6_bwd_kernel_dkv( + # B: B, H: H, T: T, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, T, K] + k, # key [B, H, T, V] + v, # value [B, H, T, V] + w, # log gate [B, H, T, K] + u, # bonus [B, H, K] + + do, # gradient of output [B, H, T, V] + dk, + dk_aux, + dv, + dh0, + + # initial hidden state initialization [B, H, K, V] + s_k_h, # stride size: T * K + s_v_h, # stride size: T * V + + scale, # K ** -0.5 + B: tl.constexpr, # B + H: tl.constexpr, # H + T: tl.constexpr, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + REVERSE: tl.constexpr, # whether to do autoregressive modeling in the reverse direction +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dk_aux = dk_aux + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T - 1) * V if not REVERSE else 0) + p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T - 1) * K if not REVERSE else 0) + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + mask_kv = mask_bk[:, None] & mask_bv[None, :] + + p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK + b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32) + + for _ in range(T-1, -1, -1): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_dkv = b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + tl.store(p_dk_aux, b_dk.to(p_dk_aux.dtype.element_ty), mask=mask_bk) + b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], axis=1) + b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], axis=0) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + b_dh *= tl.exp(b_w)[:, None] + b_dh += b_dkv + + p_q += K if REVERSE else -K + p_k += K if REVERSE else -K + p_v += V if REVERSE else -V + p_w += K if REVERSE else -K + p_do += V if REVERSE else -V + p_dk += K if REVERSE else -K + p_dk_aux += K if REVERSE else -K + p_dv += V if REVERSE else -V + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_kv) + + +class FusedRecurrentRWKV6Function(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False): + q = r + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = min(triton.next_power_of_2(K), 32), 
min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + + final_state = q.new_empty(B, H, K, V) if output_final_state else None + + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + grid = (NV, NK, B * H) + fused_recurrent_rwkv6_fwd_kernel[grid]( + q, k, v, w, u, o, initial_state, final_state, + k.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + REVERSE=reverse, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, w, u, initial_state) + ctx.scale = scale + ctx.reverse = reverse + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, w, u, initial_state = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + + BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + dq = q.new_empty(NV, B, H, T, K, dtype=torch.float32) + dq_aux = torch.empty_like(dq) + grid = (NV, NK, B * H) + + fused_recurrent_rwkv6_bwd_kernel_dq[grid]( + k, v, w, u, do, dq, dq_aux, initial_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + REVERSE=ctx.reverse, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0).to(q) + dq_aux = dq_aux.sum(0) + + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dk = q.new_empty(NV, B, H, T, K, dtype=torch.float32) + dk_aux = q.new_empty(NV, B, H, T, K, dtype=torch.float32) + dv = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + dh0 = initial_state.new_empty(B, H, K, V) if initial_state is not None else None + grid = (NV, NK, B * H) + fused_recurrent_rwkv6_bwd_kernel_dkv[grid]( + q, k, v, w, u, do, dk, dk_aux, dv, dh0, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages, + USE_INITIAL_STATE=initial_state is not None, + REVERSE=ctx.reverse, + ) + dk = dk.sum(0).to(k) + dv = dv.sum(0).to(v) + dk_aux = dk_aux.sum(0) + + dw = (dq_aux * q * scale)[:, :, 1:] - (dk_aux * k)[:, :, 0:-1] + dw = torch.nn.functional.pad(dw, (0, 0, 0, 1, 0, 0, 0, 0), value=0) + dw = chunk_global_reversed_cumsum(dw).to(w) + + du = ((do * v).sum(-1)[..., None] * k * q * scale).sum([0, -2]).to(u) + return dq, dk, dv, dw, du, None, dh0, None, None + + +def fused_recurrent_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: float = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + receptance of shape `(B, H, T, K)`. Alias: q, query in linear attention. + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + w (torch.Tensor): + data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g. + u (torch.Tensor): + bonus of shape `(H, K)` + scale (float): + Scale factor for the RWKV6 attention scores. + If set to `-1` (the default), it falls back to `1 / sqrt(K)`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`.
+ output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + """ + if scale == -1: + scale = r.shape[-1] ** -0.5 + o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state) + return o, final_state diff --git a/fla2/ops/rwkv6/recurrent_naive.py b/fla2/ops/rwkv6/recurrent_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2268759b5d4ce7f9be1be1f9c2e1a2f2a8e6c3 --- /dev/null +++ b/fla2/ops/rwkv6/recurrent_naive.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +): + orig_dtype = q.dtype + B, H, T, K, V = *q.shape, v.shape[-1] + q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + + if scale is None: + scale = K ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(T): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] + o[:, :, i] = o_i.sum(-2) + h = h * w_i[..., None] + kv_i + ht = h if output_final_state else None + return o.to(orig_dtype), ht + + +@torch.no_grad +@torch.jit.script +def naive_recurrent_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + initial_state: Optional[torch.Tensor] = None +): + q, k, v, w, u, o, do = (x.to(dtype=torch.float32) for x in (q, k, v, w, u, o, do)) + B, H, T, K, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1] + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + dq = torch.zeros_like(q) + dq_aux = torch.zeros_like(q) + + if initial_state is not None: + h += initial_state + + for i in range(T): + k_i = k[:, :, i] + v_i = v[:, :, i] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h_i = (h + u[None, ..., None] * kv_i) + dq_i = (do[:, :, i, None, :] * h_i).sum(-1) + dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) + dq[:, :, i] = dq_i + dq_aux[:, :, i] = dq_aux_i + h = h * w_i[..., None] + kv_i + + du = torch.zeros_like(u) + dh = torch.zeros_like(h) + dk = torch.zeros_like(k) + dk_aux = torch.zeros_like(k) + dv = torch.zeros_like(v) + + for i in range(T - 1, -1, -1): + d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] + k_i = k[:, :, i] + v_i = v[:, :, i] + du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) + du += du_i.sum(0) + dk_i = (dh * v_i[..., None, :]).sum(-1) + dk_aux[:, :, i] = dk_i + dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) + dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) + dv_i += (dh * k_i[..., None]).sum(-2) + + dk[:, :, i] = dk_i + dv[:, :, i] = dv_i + dh = dh * w[:, :, i, :, None].exp() + d_kv_i + + # dw = q * dq_aux - k * dk_aux + dw = torch.zeros_like(w) + for i in range(T - 2, -1, -1): + dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] + + return dq, dk, dv, dw, du, dh diff --git a/fla2/ops/simple_gla/README.md b/fla2/ops/simple_gla/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..72e710a3aa837e4d3543a62fb93de61a714cbe1d --- /dev/null +++ b/fla2/ops/simple_gla/README.md @@ -0,0 +1,5 @@ +- Simple GLA + +Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. + +$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. \ No newline at end of file diff --git a/fla2/ops/simple_gla/__init__.py b/fla2/ops/simple_gla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ac1b8af4c89e76c81e1622842ab6c879881be0de --- /dev/null +++ b/fla2/ops/simple_gla/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_simple_gla + +__all__ = [ + 'chunk_simple_gla' +] diff --git a/fla2/ops/simple_gla/chunk.py b/fla2/ops/simple_gla/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ab9e0615250a3154da38303c7753bc13c6cda2 --- /dev/null +++ b/fla2/ops/simple_gla/chunk.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from fla.ops.utils import chunk_local_cumsum, chunk_global_reversed_cumsum +from fla.ops.common.chunk_h import chunk_fwd_h_fn, chunk_bwd_dh_fn + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_simple_gla_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * tl.exp(b_g)[:, None] + b_s = b_s * tl.exp(b_g[:, None] - b_g[None, :]) + b_s = tl.where(m_s, b_s, 0) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4), + 
triton.Config({}, num_warps=8) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_simple_gla_bwd_kernel_dqkvg( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dv, + dg, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + o_i = tl.arange(0, BT) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if i_t < NT - 1: + b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1) + else: + b_g_last = tl.load(g + i_bh * T + T - 1) + mask = tl.exp(b_g[None, :] - b_g[:, None]) + mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0) + b_s = b_s * mask + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.exp(-b_g + b_g_last)[:, None] + b_dv += tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = b_dq * tl.exp(b_g)[:, None] + b_dk = b_dk * tl.exp(-b_g + b_g_last)[:, None] + b_ds = b_ds * tl.trans(mask) + b_ds = b_ds.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + b_ds = None + b_s = None + b_q = None + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_dg = tl.sum(b_dq * b_q - 
b_dk * b_k.to(tl.float32), axis=1) + p_dg = tl.make_block_ptr(dg + (i_k*n_bh + i_bh) * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +def chunk_fwd_o_fn(h, q, k, v, g, BT, scale): + B, H, T, K, V = *k.shape, v.shape[-1] + o = torch.empty_like(v) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H) + chunk_simple_gla_fwd_kernel_o[grid]( + q, k, v, h, g, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV + ) + return o + + +def chunk_bwd_dqkvg_fn(do, q, k, v, g, h, dh, scale): + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 64 + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NT, NK = triton.cdiv(T, BT), triton.cdiv(K, BK) + grid = (NK, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + dg = torch.empty(NK, B, H, T, dtype=torch.float32, device=g.device) + chunk_simple_gla_bwd_kernel_dqkvg[grid]( + q, k, v, h, g, do, dh, dq, dk, dv, dg, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT + ) + dv = dv.sum(0) + dg = dg.sum(0) + dg = chunk_global_reversed_cumsum(dg) + return dq, dk, dv, dg + + + + +class SimpleGLAFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level=1): + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + g = chunk_local_cumsum(g, BT) + h, final_state = chunk_fwd_h_fn(k=k, v=v, g=g, gk=None, gv=None, BT=BT, h0=initial_state, output_final_state=output_final_state) + o = chunk_fwd_o_fn(h, q, k, v, g, BT, scale) + if checkpoint_level == 1: + h = None + ctx.save_for_backward(q, k, v, h, g, initial_state) + ctx.scale = scale + ctx.BT = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht): + BT, scale = ctx.BT, ctx.scale + q, k, v, h, g, initial_state = ctx.saved_tensors + if h is None: + h, final_state = chunk_fwd_h_fn(k=k, v=v, g=g, gk=None, gv=None, BT=BT, h0=initial_state, output_final_state=False) + dh, dh0 = chunk_bwd_dh_fn(q=q, k=k, v=v, g=g, gk=None, gv=None, do=do, h0=initial_state, dht=dht, BT=BT, scale=scale) + dq, dk, dv, dg = chunk_bwd_dqkvg_fn(do, q, k, v, g, h, dh, scale) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg.to(g.dtype), None, dh0, None, None + + + +def chunk_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, # log decay + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + checkpoint_level: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + g (torch.Tensor): + Forget gates of shape `(B, H, T)` applied to keys. + Compared to GLA, the gating is head-wise instead of elementwise. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. 
+ output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + checkpoint_level (Optional[int]): + Checkpointing level; higher levels save more memory but require more recomputation during the backward pass. + Default: `1` (recommended): + - Level `0`: no memory saved, no recomputation. + - Level `1`: recompute the chunk-level hidden state `h` during the backward pass. + """ + assert checkpoint_level in [0, 1], "checkpoint_level must be 0 or 1" + assert q.dim() == k.dim() == v.dim() == 4, "q, k, v must have 4 dimensions (b, h, l, d)" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + if scale is None: + scale = k.shape[-1] ** -0.5 + g = g.float() + o, final_state = SimpleGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level) + return o, final_state \ No newline at end of file diff --git a/fla2/ops/simple_gla/naive.py b/fla2/ops/simple_gla/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..50f9b9211a9291b5f89ac5fd4424c0846a77abe6 --- /dev/null +++ b/fla2/ops/simple_gla/naive.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None): + if scale is None: + scale = (q.shape[-1] ** -0.5) + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size) + g = g.cumsum(-1) + kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) + S = torch.zeros_like(kv) + + for i in range(1, g.shape[-2]): + S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] + + inter = (q * g[..., None].exp()) @ S + attn = q @ k.transpose(-1, -2) + attn = attn * (g[..., None] - g[..., None, :]).exp() + attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + intra = attn @ v + o = inter + intra + return rearrange(o, 'b h n c d -> b h (n c) d') + + +def torch_simple_gla_recurrent(q, k, v, g, initial_state=None, scale=None): + B, H, T, DK = q.shape + if scale is None: + scale = DK ** -0.5 + q = q * scale + _, _, _, DV = v.shape + if initial_state is None: + S = torch.zeros(B, H, DK, DV).to(q) + else: + S = initial_state + o = torch.zeros(B, H, T, DV).to(q) + for i in range(T): + gate = g[:, :, i].exp() + key = k[:, :, i] + value = v[:, :, i] + kv = key.unsqueeze(-1) * value.unsqueeze(-2) + S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv + q_i = q[:, :, i, :] + o_i = (q_i.unsqueeze(-1) * S).sum(-2) + o[:, :, i] = o_i + return o, S + +if __name__ == '__main__': + torch.set_default_dtype(torch.bfloat16) + B = 4 + H = 4 + L = 100 + DK = 32 + DV = 32 + q = torch.randn(B, H, L, DK) + k = torch.randn(B, H, L, DK) + v = torch.randn(B, H, L, DV) + g = torch.nn.functional.logsigmoid(torch.randn(B, H, L)) + q, k, v, g = map(lambda x: x.cuda().requires_grad_(True), [q, k, v, g]) + from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla + + o, _ = fused_recurrent_simple_gla(q, k, v, g) + do = torch.randn_like(o) + o.backward(do) + q_grad, k_grad, v_grad, g_grad = q.grad, k.grad, v.grad, g.grad + q.grad, k.grad, v.grad, g.grad = None, None, None, None + o2, _ = chunk_simple_gla(q, k, v, g) + o2.backward(do) + q_grad2, k_grad2, v_grad2, g_grad2 = q.grad, k.grad, v.grad, g.grad + +
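For reference, the chunkwise decomposition implemented by the forward kernel and mirrored by `torch_simple_gla` above can be stated compactly. Within a chunk of length $C$, let $b_t = \sum_{s \le t} g_s$ be the local cumulative log-decay (the quantity `chunk_local_cumsum` produces) and let $H$ be the state carried in from preceding chunks; then, up to the `scale` factor,

$$o_t = \underbrace{e^{b_t}\, q_t^{\top} H}_{\text{inter-chunk}} + \underbrace{\sum_{s \le t} e^{b_t - b_s}\,(q_t^{\top} k_s)\, v_s}_{\text{intra-chunk}}, \qquad H \leftarrow e^{b_C} H + \sum_{s=1}^{C} e^{b_C - b_s}\, k_s v_s^{\top}.$$

The intra-chunk term is the causally masked product formed via `b_s` and `m_s` in `chunk_simple_gla_fwd_kernel_o`.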
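A minimal usage sketch of the public entry point above (a sketch only, assuming a CUDA device and that the package is importable as `fla`, matching the imports used in this test block; shapes follow the docstring):

```python
import torch
from fla.ops.simple_gla import chunk_simple_gla

B, H, T, K, V = 2, 4, 256, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# one log-space decay scalar per (batch, head, step); logsigmoid keeps it negative
g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda'))

o, final_state = chunk_simple_gla(q, k, v, g, output_final_state=True)
assert o.shape == (B, H, T, V)
assert final_state.shape == (B, H, K, V)
```

Passing `initial_state=final_state` to a subsequent call continues the recurrence across segments, which is how incremental decoding reuses the `(B, H, K, V)` state.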
print((o-o2).abs().max()) + print((q_grad-q_grad2).abs().max()) + print((k_grad-k_grad2).abs().max()) + print((v_grad-v_grad2).abs().max()) + print((g_grad-g_grad2).abs().max()) + + diff --git a/fla2/ops/simple_gla/recurrent_fuse.py b/fla2/ops/simple_gla/recurrent_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..90f866441e270337e555ba29843c1515910c451e --- /dev/null +++ b/fla2/ops/simple_gla/recurrent_fuse.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple, Optional +import torch +from fla.ops.common.fused_recurrent import fused_recurrent + +def fused_recurrent_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + reverse: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = fused_recurrent(q, k, v, g, None, None, scale, initial_state, output_final_state, reverse) + return o, final_state diff --git a/fla3/__pycache__/__init__.cpython-310.pyc b/fla3/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cce19a81c218915f61b08c64991f7116991ca9d5 Binary files /dev/null and b/fla3/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/__pycache__/__init__.cpython-312.pyc b/fla3/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acd121810ffd870c584ffc3d99423e79c267971 Binary files /dev/null and b/fla3/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/__pycache__/utils.cpython-310.pyc b/fla3/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07f5e6957fd4ce0e37fae6395a143a7bd40303e1 Binary files /dev/null and b/fla3/__pycache__/utils.cpython-310.pyc differ diff --git a/fla3/__pycache__/utils.cpython-312.pyc b/fla3/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e717d4f4c7b3d1a7e09ba9f60d90c8b236503ce2 Binary files /dev/null and b/fla3/__pycache__/utils.cpython-312.pyc differ diff --git a/fla3/layers/__init__.py b/fla3/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e552dba48d97e496d7327d29ce763a852a284180 --- /dev/null +++ b/fla3/layers/__init__.py @@ -0,0 +1,51 @@ +# # -*- coding: utf-8 -*- +# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# from .abc import ABCAttention +# from .attn import Attention +# from .based import BasedLinearAttention +# from .bitattn import BitAttention +# from .delta_net import DeltaNet +# from .forgetting_attn import ForgettingAttention +# from .gated_deltanet import GatedDeltaNet +# from .gated_deltaproduct import GatedDeltaProduct +# from .gla import GatedLinearAttention +# from .gsa import GatedSlotAttention +# from .hgrn import HGRNAttention +# from .hgrn2 import HGRN2Attention +# from .lightnet import LightNetAttention +# from .linear_attn import LinearAttention +# from .mamba import Mamba +# from .mamba2 import Mamba2 +# from .multiscale_retention import MultiScaleRetention +# from .nsa import NativeSparseAttention +# from .path_attn import PaTHAttention +# from .rebased import ReBasedLinearAttention +# from .rwkv6 import RWKV6Attention +# from .rwkv7 import RWKV7Attention + +# __all__ = [ +# 'ABCAttention', +# 'Attention', +# 'BasedLinearAttention', +# 'BitAttention', +# 'DeltaNet', 
+# 'ForgettingAttention', +# 'GatedDeltaNet', +# 'GatedDeltaProduct', +# 'GatedLinearAttention', +# 'GatedSlotAttention', +# 'HGRNAttention', +# 'HGRN2Attention', +# 'LightNetAttention', +# 'LinearAttention', +# 'Mamba', +# 'Mamba2', +# 'MultiScaleRetention', +# 'NativeSparseAttention', +# 'ReBasedLinearAttention', +# 'RWKV6Attention', +# 'RWKV7Attention', +# 'PaTHAttention' +# ] +from .emdeltanet import emdeltanet \ No newline at end of file diff --git a/fla3/layers/__pycache__/__init__.cpython-310.pyc b/fla3/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0550d5bdf09c3cd6869bc5128888336f548bcb5 Binary files /dev/null and b/fla3/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/__init__.cpython-312.pyc b/fla3/layers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6395e8aa1c60a59c74a379d84081ee9cd3144df0 Binary files /dev/null and b/fla3/layers/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/layers/__pycache__/abc.cpython-310.pyc b/fla3/layers/__pycache__/abc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8abe311fda6c49e6d32b444e87ffeba88e390bfc Binary files /dev/null and b/fla3/layers/__pycache__/abc.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/attn.cpython-310.pyc b/fla3/layers/__pycache__/attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2667e8c667d222d4f3d2436ad36c1c82ce2bec61 Binary files /dev/null and b/fla3/layers/__pycache__/attn.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/attn.cpython-312.pyc b/fla3/layers/__pycache__/attn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1802295e09debad201a8513e36acb2652fd6ce35 Binary files /dev/null and b/fla3/layers/__pycache__/attn.cpython-312.pyc differ diff --git a/fla3/layers/__pycache__/based.cpython-310.pyc b/fla3/layers/__pycache__/based.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea620fd036550e268702e5765338b107bf574360 Binary files /dev/null and b/fla3/layers/__pycache__/based.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/bitattn.cpython-310.pyc b/fla3/layers/__pycache__/bitattn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f66a8d90f44ca32cb64c1e563124262d51f2d5a7 Binary files /dev/null and b/fla3/layers/__pycache__/bitattn.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/delta_net.cpython-310.pyc b/fla3/layers/__pycache__/delta_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67fec272db4f067ceeaaa4147e911f93e72000d3 Binary files /dev/null and b/fla3/layers/__pycache__/delta_net.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/delta_net.cpython-312.pyc b/fla3/layers/__pycache__/delta_net.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f195f7e3271c793cb5f6861834044da83e605017 Binary files /dev/null and b/fla3/layers/__pycache__/delta_net.cpython-312.pyc differ diff --git a/fla3/layers/__pycache__/emdeltanet.cpython-310.pyc b/fla3/layers/__pycache__/emdeltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e85d2c340303fbc5dc46288b2f635eabed700146 Binary files /dev/null and b/fla3/layers/__pycache__/emdeltanet.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/emdeltanet.cpython-312.pyc 
b/fla3/layers/__pycache__/emdeltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4362edcde52b20ccc9803d55fe70e3620e534d9 Binary files /dev/null and b/fla3/layers/__pycache__/emdeltanet.cpython-312.pyc differ diff --git a/fla3/layers/__pycache__/forgetting_attn.cpython-310.pyc b/fla3/layers/__pycache__/forgetting_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10e69c2e5fd094f6721f8390cc469b616d1b5d55 Binary files /dev/null and b/fla3/layers/__pycache__/forgetting_attn.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/gated_deltanet.cpython-310.pyc b/fla3/layers/__pycache__/gated_deltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ffccd59e96a6dd34aaf4a937db64e11de1f4ad Binary files /dev/null and b/fla3/layers/__pycache__/gated_deltanet.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/gated_deltanet.cpython-312.pyc b/fla3/layers/__pycache__/gated_deltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1c4e59fcbb40b6e405c13b2ad24d746ada76d9d Binary files /dev/null and b/fla3/layers/__pycache__/gated_deltanet.cpython-312.pyc differ diff --git a/fla3/layers/__pycache__/gated_deltaproduct.cpython-310.pyc b/fla3/layers/__pycache__/gated_deltaproduct.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0096590a9690023e9169cd7af1f778de5a56fc6f Binary files /dev/null and b/fla3/layers/__pycache__/gated_deltaproduct.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/gla.cpython-310.pyc b/fla3/layers/__pycache__/gla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..547b55878fb9f3031c3cb76f74d7a31113c7759c Binary files /dev/null and b/fla3/layers/__pycache__/gla.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/gsa.cpython-310.pyc b/fla3/layers/__pycache__/gsa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e1425a81d0da1dbffda39d48f1f99508066f3eb Binary files /dev/null and b/fla3/layers/__pycache__/gsa.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/hgrn.cpython-310.pyc b/fla3/layers/__pycache__/hgrn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0af3081ce0b762b48079bb982cdc852b906dd79e Binary files /dev/null and b/fla3/layers/__pycache__/hgrn.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/hgrn2.cpython-310.pyc b/fla3/layers/__pycache__/hgrn2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df048966a7c467303d46933f2326a39a1d767b9a Binary files /dev/null and b/fla3/layers/__pycache__/hgrn2.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/lightnet.cpython-310.pyc b/fla3/layers/__pycache__/lightnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aa5461b9a3dd224370ce32c3e2cb6dcd5a0aa64 Binary files /dev/null and b/fla3/layers/__pycache__/lightnet.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/linear_attn.cpython-310.pyc b/fla3/layers/__pycache__/linear_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f53b6e00a993eef5570a56073bd48779bdbb8e73 Binary files /dev/null and b/fla3/layers/__pycache__/linear_attn.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/mamba.cpython-310.pyc b/fla3/layers/__pycache__/mamba.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..478905016b9b66eefac0c6c18f0ce63db1f49bad Binary files /dev/null and b/fla3/layers/__pycache__/mamba.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/mamba2.cpython-310.pyc b/fla3/layers/__pycache__/mamba2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a547d793647bca91d3747410959fe3ea0b2ef01d Binary files /dev/null and b/fla3/layers/__pycache__/mamba2.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/multiscale_retention.cpython-310.pyc b/fla3/layers/__pycache__/multiscale_retention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d8991fa29579f01d3204eea372b37f50aa57e59 Binary files /dev/null and b/fla3/layers/__pycache__/multiscale_retention.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/nsa.cpython-310.pyc b/fla3/layers/__pycache__/nsa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5458db82078b46b411cabe7e4173e3218d708958 Binary files /dev/null and b/fla3/layers/__pycache__/nsa.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/path_attn.cpython-310.pyc b/fla3/layers/__pycache__/path_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c69aba781b8ac71add812277a9e0289c65e7e177 Binary files /dev/null and b/fla3/layers/__pycache__/path_attn.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/rebased.cpython-310.pyc b/fla3/layers/__pycache__/rebased.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c640c888ee60b176a405672a6b8a267302b2da Binary files /dev/null and b/fla3/layers/__pycache__/rebased.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/rwkv6.cpython-310.pyc b/fla3/layers/__pycache__/rwkv6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3370e2a84ac62f200a3d2dcf3be99c9455fecc42 Binary files /dev/null and b/fla3/layers/__pycache__/rwkv6.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/rwkv7.cpython-310.pyc b/fla3/layers/__pycache__/rwkv7.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd325f11081b4e374953a42a2cd28ca08ff10b2e Binary files /dev/null and b/fla3/layers/__pycache__/rwkv7.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/utils.cpython-310.pyc b/fla3/layers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aef522f42f4261825854cd258a7062b86d7c9e2f Binary files /dev/null and b/fla3/layers/__pycache__/utils.cpython-310.pyc differ diff --git a/fla3/layers/__pycache__/utils.cpython-312.pyc b/fla3/layers/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6cc4931d1b59f8e3e19f22792021b8c531c1432 Binary files /dev/null and b/fla3/layers/__pycache__/utils.cpython-312.pyc differ diff --git a/fla3/layers/abc.py b/fla3/layers/abc.py new file mode 100644 index 0000000000000000000000000000000000000000..57aee55800f4b34e164b09ea38183bbf887de096 --- /dev/null +++ b/fla3/layers/abc.py @@ -0,0 +1,217 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution +from fla.modules.activations import swiglu, swish +from fla.ops.abc.chunk 
import chunk_abc + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class ABCAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + num_slots: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_low_rank_dim: int = 16, + gate_logit_normalizer: int = 16, + use_rope: bool = True, + use_input_gate: bool = False, + use_output_gate: bool = True, + use_norm: bool = True, + clamp_min: Optional[float] = -32, + clamp_max: Optional[float] = 32, + layer_idx: Optional[int] = None, + **kwargs + ) -> ABCAttention: + super().__init__() + + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.key_dim = int(self.hidden_size * self.expand_k) + self.value_dim = int(self.hidden_size * self.expand_v) + self.head_k_dim = self.key_dim // self.num_heads + self.head_v_dim = self.value_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.gate_low_rank_dim = gate_low_rank_dim + self.gate_logit_normalizer = gate_logit_normalizer + + self.use_rope = use_rope + self.use_input_gate = use_input_gate + self.use_output_gate = use_output_gate + self.use_norm = use_norm + + if num_slots is None: + num_slots = self.head_k_dim + self.num_slots = num_slots + + self.norm_eps = norm_eps + + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.layer_idx = layer_idx + + if layer_idx is None: + warnings.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class."
+ ) + + self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) + + if use_output_gate: + self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False) + self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') + + if self.use_norm: + if self.use_output_gate: + self.g_norm = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + else: + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + + if self.use_rope: + self.rotary = RotaryEmbedding(self.head_k_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if cu_seqlens is not None: + raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention") + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.use_input_gate: + q, k, v = map(lambda x: swish(x), (q, k, v)) + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + if self.use_rope: + seqlen_offset = 0 + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset) + + s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... 
h m', m=self.num_slots) + s = s.clamp_(self.clamp_min, self.clamp_max) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + o, recurrent_state = chunk_abc( + q=q, + k=k, + v=v, + s=s, + initial_state=recurrent_state, + output_final_state=use_cache, + ) + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + if self.use_norm and not self.use_output_gate: + o = self.g_norm(o) + elif self.use_output_gate: + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.g_norm(o, g) if self.use_norm else swiglu(g, o) + o = rearrange(o, '... h d -> ... (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, seq_len: int = 2048): + return 2 * self.num_slots * self.hidden_size diff --git a/fla3/layers/attn.py b/fla3/layers/attn.py new file mode 100644 index 0000000000000000000000000000000000000000..2f3c1d1b078cf91db33a9c47cb157b6af38973e3 --- /dev/null +++ b/fla3/layers/attn.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.layers.utils import pad_input, unpad_input +from fla.modules import RMSNorm, RotaryEmbedding +from fla.ops.utils.index import prepare_lens_from_mask + +if TYPE_CHECKING: + from fla.models.utils import Cache + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func +except ImportError: + warnings.warn( + "Flash Attention is not installed. 
Please install it via `pip install flash-attn --no-build-isolation`", + category=ImportWarning + ) + flash_attn_func = None + +logger = logging.get_logger(__name__) + + +class Attention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: Optional[int] = None, + layer_idx: int = None + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + + self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... 
h d', d=self.head_dim) + + if self.qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + + # equivalent to cu_seqlens in `flash_attn` + cu_seqlens = kwargs.get('cu_seqlens', None) + + seqlen_offset, max_seqlen = 0, q_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if past_key_values is not None: + cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0 + k_cached, v_cached = past_key_values.update( + attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)), + layer_idx=self.layer_idx, + offset=q_len, + cache_kwargs=dict(window_size=self.window_size) + )['attn_state'] + if cache_has_content: + k, v = k_cached, v_cached + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + + if flash_attn_func is None: + raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") + + # Contains at least one padding token in the sequence + if attention_mask is not None: + q, (k, v), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, v), attention_mask, q_len) + cu_seqlens_q, cu_seqlens_k = cu_seqlens + max_seqlen_q, max_seqlen_k = max_seq_lens + o = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = pad_input(o, indices_q, batch_size, q_len) + elif cu_seqlens is not None: + o = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ).unsqueeze(0) + else: + o = flash_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = o.reshape(batch_size, q_len, -1) + o = self.o_proj(o) + + if not output_attentions: + attentions = None + + return o, attentions, past_key_values diff --git a/fla3/layers/based.py b/fla3/layers/based.py new file mode 100644 index 0000000000000000000000000000000000000000..d21613a7c3ed01f0b9cb19a3b488d3f4b4d240c1 --- /dev/null +++ b/fla3/layers/based.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +""" +Linear attention in Based. 
+https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules.feature_map import TaylorFeatureMap +from fla.ops.based import parallel_based +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn + + +class BasedLinearAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + feature_dim: int = 16, + num_key_value_heads: int = 12, + num_heads: int = 12, + feature_name: str = "taylor_exp", + eps: float = 1e-12, + causal: bool = True, + mode: str = "parallel", + ): + super().__init__() + + self.hidden_size = hidden_size + self.mode = mode + self.feature_name = feature_name + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_key_value_heads + assert self.hidden_size % self.head_dim == 0 + self.causal = causal + + self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Identity() + self.feature_map = TaylorFeatureMap(feature_dim) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v]) + if mode == "fused_chunk": + q, k = self.feature_map(q), self.feature_map(k) + o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) + elif mode == 'chunk': + q, k = self.feature_map(q), self.feature_map(k) + o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_based(q, k, v, scale=1, use_norm=True) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + o = self.dropout(o) + return o + + def forward_reference(self, hidden_states: torch.Tensor, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, t) + y (torch.Tensor): tensor of shape (b, d, t) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, t, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2) + v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h t d -> b t (h d)') + y = self.o_proj(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) diff --git a/fla3/layers/bitattn.py b/fla3/layers/bitattn.py new file mode 100644 index 0000000000000000000000000000000000000000..dea362acc66712bcba841aa0b65fafc34b78c4da --- /dev/null +++ b/fla3/layers/bitattn.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin 
Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.layers.utils import pad_input, unpad_input +from fla.modules import RotaryEmbedding +from fla.modules.fused_bitlinear import FusedBitLinear +from fla.ops.utils.index import prepare_lens_from_mask + +if TYPE_CHECKING: + from fla.models.utils import Cache + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func +except ImportError: + warnings.warn( + "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", + category=ImportWarning + ) + flash_attn_func = None + +logger = logging.get_logger(__name__) + + +class BitAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: Optional[int] = None, + norm_eps: float = 1e-5, + layer_idx: int = None + ): + super().__init__() + + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.hidden_size = hidden_size + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.kv_dim = self.num_kv_heads * self.head_dim + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False) + self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False) + self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False) + self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False) + + self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... 
h d', d=self.head_dim) + + # equivalent to cu_seqlens in `flash_attn` + cu_seqlens = kwargs.get('cu_seqlens', None) + + seqlen_offset, max_seqlen = 0, q_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if past_key_values is not None: + cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0 + k_cached, v_cached = past_key_values.update( + attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)), + layer_idx=self.layer_idx, + offset=q_len, + cache_kwargs=dict(window_size=self.window_size) + )['attn_state'] + if cache_has_content: + k, v = k_cached, v_cached + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + + if flash_attn_func is None: + raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") + + # Contains at least one padding token in the sequence + if attention_mask is not None: + q, (k, v), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, v), attention_mask, q_len) + cu_seqlens_q, cu_seqlens_k = cu_seqlens + max_seqlen_q, max_seqlen_k = max_seq_lens + o = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = pad_input(o, indices_q, batch_size, q_len) + elif cu_seqlens is not None: + o = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ).unsqueeze(0) + else: + o = flash_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = o.reshape(batch_size, q_len, -1) + o = self.o_proj(o) + + if not output_attentions: + attentions = None + + return o, attentions, past_key_values diff --git a/fla3/layers/delta_net.py b/fla3/layers/delta_net.py new file mode 100644 index 0000000000000000000000000000000000000000..9f31bfef9a9224e63aa4c7caeaf2a39db0a2045a --- /dev/null +++ b/fla3/layers/delta_net.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from ..ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + 
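The `DeltaNet` layer defined next wraps the `chunk_delta_rule` and `fused_recurrent_delta_rule` kernels imported above. As a reminder of the semantics those kernels implement, here is a naive single-step sketch (reference semantics only, under the layer's conventions; not the fused Triton path):

```python
import torch

def delta_rule_step(S, q, k, v, beta):
    """One delta-rule step for a single head.

    S: (K, V) fast-weight state; q, k: (K,); v: (V,); beta: scalar in (0, 1).
    The state is corrected toward storing the association k -> v, then read with q.
    """
    v_old = k @ S                              # what the state currently predicts for key k
    S = S + beta * torch.outer(k, v - v_old)   # rank-1 correction (the "delta")
    return S, q @ S                            # updated state and readout

K, V = 4, 4
S = torch.zeros(K, V)
k = torch.nn.functional.normalize(torch.randn(K), dim=-1)  # unit norm, as with qk_norm='l2'
v = torch.randn(V)
S, _ = delta_rule_step(S, k, k, v, beta=torch.tensor(1.0))
assert torch.allclose(k @ S, v, atol=1e-5)  # with beta=1 and a unit key, k now maps exactly to v
```

In the layer below, `beta` comes from `b_proj(hidden_states).sigmoid()`, and `allow_neg_eigval` doubles it so the update's eigenvalues can reach into (-1, 0).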
+class DeltaNet(nn.Module): + r""" + The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa + DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa + + Args: + mode (str, Optional): + Which DeltaNet kernel to use. + Currently available: `chunk` and `fused_recurrent` (`fused_chunk` is deprecated). + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + use_beta (bool, Optional): + Whether to use beta. Default: `True`. + use_gate (bool, Optional): + Whether to use output gate. Default: `False`. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + allow_neg_eigval (bool, Optional): + Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2. + See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537) + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + qk_activation (str, Optional): + The activation function for the query and key. Default: `silu`. + qk_norm (str, Optional): + The normalization method for the query and key. Default: `l2`. + """ + + def __init__( + self, + mode: str = 'chunk', + d_model: int = None, + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 4, + use_beta: bool = True, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + allow_neg_eigval: bool = False, + layer_idx: int = None, + qk_activation: str = 'silu', + qk_norm: str = 'l2', + norm_eps: float = 1e-5, + **kwargs + ) -> DeltaNet: + super().__init__() + + self.mode = mode + self.qk_activation = qk_activation + self.qk_norm = qk_norm + + assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] + assert self.qk_norm in ['l2', 'sum'] + + if d_model is not None: + hidden_size = d_model + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.allow_neg_eigval = allow_neg_eigval + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.layer_idx = layer_idx + if mode == 'fused_chunk': + raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.") + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + self.use_beta = use_beta + if self.use_beta: + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' if qk_activation == 'silu' else None + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' if qk_activation == 'silu' else None + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." + ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + # change to inference mode. + mode = 'fused_recurrent' if q_len <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + if self.qk_activation == 'silu': + q, k = F.silu(q), F.silu(k) + v = F.silu(self.v_proj(hidden_states)) + + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_v_dim) + if self.qk_activation != 'silu': + if self.qk_activation == 'relu': + q, k = q.relu(), k.relu() + elif self.qk_activation == 'elu': + q, k = elu_p1(q), elu_p1(k) + elif self.qk_activation != 'identity': + raise NotImplementedError + + if self.qk_norm == 'sum': + q = sum_norm(q).to(q) + k = sum_norm(k).to(k) + + if self.use_beta: + beta = self.b_proj(hidden_states).sigmoid() + else: + beta = torch.ones_like(q[..., 0]) + + if self.allow_neg_eigval: + beta = beta * 2. + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_delta_rule( + q=q, + k=k, + v=v, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False + ) + elif mode == 'chunk': + o, recurrent_state = chunk_delta_rule( + q=q, + k=k, + v=v, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + + if self.use_gate: + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/fla3/layers/emdeltanet copy.py b/fla3/layers/emdeltanet copy.py new file mode 100644 index 0000000000000000000000000000000000000000..eddd55c879c0e22507a81cc04b50ae7e5fdfcc3f --- /dev/null +++ b/fla3/layers/emdeltanet copy.py @@ -0,0 +1,1734 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from ..modules import FusedRMSNormSwishGate, RMSNorm +from fla.modules import ShortConvolution +import torch.nn.init as init +import math +from ..modules.l2norm import l2_norm_fn +from einops import rearrange +from fla.models.utils import Cache +from transformers.processing_utils import Unpack +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +# from ..modules import RotaryEmbedding +# from ..ops.gla import chunk_gla,fused_recurrent_gla +from ..ops.delta_rule import chunk_delta_rule,fused_chunk_delta_rule,fused_recurrent_delta_rule + +def simple_norm(x): + return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x) + + +# @torch.jit.script +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +# @torch.jit.script +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +# @torch.jit.script +def elu_norm(x): + dtype = x.dtype + x = F.elu(x, 1., False) + 1. + return (x / x.sum(-1, keepdim=True)).to(dtype) + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. 
+ """ + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + +# # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) + +# # dealing with left-padding +# if attention_mask is not None: +# v = v.mul_(attention_mask.unsqueeze(-1)) + +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# logits = torch.matmul(v,self.router_weight)#get b h l r logits +# scores = (logits).softmax(dim=-1,dtype=torch.float32)#b h l r #划分k +# topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k +# masked_scores = torch.zeros_like(scores,device=q.device) +# masked_scores.scatter_(-1, topk_idx, topk_score) +# masked_idx = masked_scores.bool() +# if self.training : +# scores_for_aux = rearrange(scores,'b h l r ->(b h) l r') +# aux_topk = self.top_k +# # always compute aux loss based on the naive greedy topk method +# topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1) +# if True: +# scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1) +# ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device) +# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k) +# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * 0.001 + +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['current_state'] + +# if recurrent_state_kf is not None: +# cum_idx = torch.cumsum(masked_idx,dim=-2) + recurrent_state_kf +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# else: +# cum_idx = torch.cumsum(masked_idx,dim=-2) +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# masked_mem = torch.where(cum_idx==0,-1e10,0) + +# v = torch.einsum('b h l d, b h l r -> b h l r d',v,masked_idx) +# v_e = torch.ones([b,self.num_heads,q_len,1]).to(v) +# v_e = torch.einsum('b h l d, b h l r -> b h l r d',v_e,masked_idx) +# v_exp = torch.concat([v,v_e],dim=-1) +# v_exp = rearrange(v_exp,'b h l r d-> b h l (r d)').to(q) + +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError + +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) + 
+# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# # state = past_key_values[self.layer_idx][-1] if use_cache else None + +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q, k, v_exp, beta, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") + +# o = rearrange(o,'b h l (r d)-> b h l r d',r=self.ratio) +# o_moe_v = o[...,:-1] +# o_moe_k = o[...,-1] +# qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)+(scores-scores.detach())).to(o_moe_v) +# o = torch.einsum('b h l r d,b h l r-> b h l d',o_moe_v,qlogit) + +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# if self.training: +# o = AddAuxiliaryLoss.apply(o,aux_loss) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim, self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state + + +# ####ver2 +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * 
expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) + +# # dealing with left-padding +# if attention_mask is not None: +# v = v.mul_(attention_mask.unsqueeze(-1)) + +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# logits = torch.matmul(v,self.router_weight)#get b h l r logits +# scores = (logits).softmax(dim=-1,dtype=torch.float32)#b h l r #partition into k branches +# topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k +# masked_scores = torch.zeros_like(scores,device=q.device) +# masked_scores.scatter_(-1, topk_idx, topk_score) +# masked_idx = masked_scores.bool() +# if self.training : +# scores_for_aux = rearrange(scores,'b h l r ->(b h) l r') +# aux_topk = self.top_k +# # always compute aux loss based on the naive greedy topk method +# topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1) +# if True: +# scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1) +# ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device) +# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k) +# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * 0.001 + +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['current_state'] + +# if recurrent_state_kf is not None: +# cum_idx = torch.cumsum(masked_idx,dim=-2) + recurrent_state_kf +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# else: +# cum_idx = torch.cumsum(masked_idx,dim=-2) +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# masked_mem = torch.where(cum_idx==0,-1e10,0) + +# v = torch.einsum('b h l d, b h l r -> b h l r d',v,masked_idx) +# v_e = torch.ones([b,self.num_heads,q_len,1]).to(v) +# v_e = torch.einsum('b h l d, b h l r -> b h l r d',v_e,masked_scores) +# v_exp = torch.concat([v,v_e],dim=-1) +# v_exp = rearrange(v_exp,'b h l r d-> b h l (r d)').to(q) + +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError + +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) 
+ +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# # state = past_key_values[self.layer_idx][-1] if use_cache else None + +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q, k, v_exp, beta, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") + +# o = rearrange(o,'b h l (r d)-> b h l r d',r=self.ratio) +# o_moe_v = o[...,:-1] +# o_moe_k = o[...,-1] +# qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)).to(o_moe_v) +# # qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)+(scores-scores.detach())).to(o_moe_v) +# o = torch.einsum('b h l r d,b h l r-> b h l d',o_moe_v,qlogit) + +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# if self.training: +# o = AddAuxiliaryLoss.apply(o,aux_loss) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim, self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state + + +###ver3 +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# 
self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# if attention_mask is not None: +# v = v.mul_(attention_mask.unsqueeze(-1)) + +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# logits = torch.matmul(v,self.router_weight)#get b h l r logits +# scores = (logits).softmax(dim=-1,dtype=torch.float32)#b h l r #partition into k branches +# topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k +# masked_scores = torch.zeros_like(scores,device=q.device) +# masked_scores.scatter_(-1, topk_idx, topk_score) +# masked_idx = masked_scores.bool() +# if self.training : +# scores_for_aux = rearrange(scores,'b h l r ->(b h) l r') +# aux_topk = self.top_k +# # always compute aux loss based on the naive greedy topk method +# topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1) +# if True: +# scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1) +# ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device) +# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k) +# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * 0.001 +# # aux_loss = 0 +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['current_state'] + +# if recurrent_state_kf is not None: +# cum_idx = torch.cumsum(masked_idx,dim=-2) + recurrent_state_kf +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# else: +# cum_idx = torch.cumsum(masked_idx,dim=-2) +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# masked_mem = torch.where(cum_idx==0,-1e10,0) +# # norm_state = torch.where(cum_idx==0,0,1/cum_idx) #the head dimension could be concatenated here#bhlr +# # norm_state = torch.nn.functional.normalize(norm_state,p=1,dim=-1)#p=1 is more reasonable than p=2; part of it explicitly signals the current token count +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError + +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) +# q_exp = q.unsqueeze(-2).expand(-1,-1,-1,self.ratio,-1) +# k_exp = torch.einsum('b h l d,b h l r->b h l r d',k,masked_idx) + +# v = 
torch.einsum('b h l d, b h l r -> b h l r d',v,masked_idx) +# v_e = torch.ones([b,self.num_heads,q_len,1]).to(v) +# v_e = torch.einsum('b h l d, b h l r -> b h l r d',v_e,masked_scores) +# v_exp = torch.concat([v,v_e],dim=-1) +# v_exp = rearrange(v_exp,'b h l r d-> b (h r) l d').to(q).contiguous() +# q_exp = rearrange(q_exp,'b h l r d-> b (h r) l d').to(q).contiguous() +# k_exp = rearrange(k_exp,'b h l r d-> b (h r) l d').to(q).contiguous() #if this proves viable, it still needs a faster fused kernel + +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# beta_exp = beta.unsqueeze(-1).expand(-1,-1,-1,self.ratio) +# beta_exp = rearrange(beta_exp,'b h l r->b (h r) l').contiguous() + +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q_exp, k_exp, v_exp, beta_exp, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") + +# o = rearrange(o,'b (h r) l d-> b h l r d',r=self.ratio) +# o_moe_v = o[...,:-1] +# o_moe_k = o[...,-1] +# qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)).to(o_moe_v) +# o = torch.einsum('b h l r d,b h l r-> b h l d',o_moe_v,qlogit) + +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# if self.training and aux_loss: +# o = AddAuxiliaryLoss.apply(o,aux_loss) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim, self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state + + +####ver4 +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# 
self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet_v4') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. 
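+# # [editor's gloss] inputs shorter than 64 tokens (typically single-step decoding) fall back to the fused recurrent kernel below, since the chunked kernels operate on full chunks.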
+# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# if attention_mask is not None: +# v = v.mul_(attention_mask.unsqueeze(-1)) + +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# logits = torch.matmul(v,self.router_weight)#get b h l r logits +# scores = (logits).softmax(dim=-1,dtype=torch.float32)#b h l r #partition into k branches +# topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k +# masked_scores = torch.zeros_like(scores,device=q.device) +# masked_scores.scatter_(-1, topk_idx, topk_score) +# masked_idx = masked_scores.bool() +# if self.training : +# scores_for_aux = rearrange(scores,'b h l r ->(b h) l r') +# aux_topk = self.top_k +# # always compute aux loss based on the naive greedy topk method +# topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1) +# if True: +# scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1) +# ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device) +# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k) +# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * 0.001 +# # aux_loss = 0 +# recurrent_state_kf,recurrent_state_kf2,recurrent_state_sf = None,None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_kf2,recurrent_state_sf = last_state['current_state'] + +# if recurrent_state_kf is not None: +# cum_idx = torch.cumsum(masked_idx,dim=-2) + recurrent_state_kf +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# else: +# cum_idx = torch.cumsum(masked_idx,dim=-2) +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# masked_mem = torch.where(cum_idx==0,-1e10,0) + +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError +# if self.qk_norm is not None: +# if self.qk_norm 
== 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) +# k_exp2 = torch.einsum('b h l d,b h l r->b h l r d',k,masked_scores).to(q) +# if recurrent_state_kf2 is not None: +# k_exp_sum = torch.cumsum(k_exp2,dim=-2)+recurrent_state_kf2 +# recurrent_state_kf2 = k_exp_sum[:,:,-1,:].unsqueeze(-2) +# else: +# k_exp_sum = torch.cumsum(k_exp2,dim=-2)#bhlrd +# recurrent_state_kf2 = k_exp_sum[:,:,-1,:].unsqueeze(-2) + +# qlamda = (torch.einsum('b h l d,b h l r d-> b h l r',q,k_exp_sum)*(self.head_qk_dim**(-0.5))+masked_mem) +# qlamda = torch.softmax(qlamda,dim=-1) +# q_exp = torch.einsum('b h l d,b h l r->b h l r d',q,qlamda) +# k_exp = torch.einsum('b h l d,b h l r->b h l r d',k,masked_idx) +# v_exp = torch.einsum('b h l d, b h l r -> b h l r d',v,masked_idx) +# v_exp = rearrange(v_exp,'b h l r d-> b (h r) l d').to(q).contiguous() +# q_exp = rearrange(q_exp,'b h l r d-> b (h r) l d').to(q).contiguous() +# k_exp = rearrange(k_exp,'b h l r d-> b (h r) l d').to(q).contiguous() #if this proves viable, it still needs a faster fused kernel +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# beta_exp = beta.unsqueeze(-1).expand(-1,-1,-1,self.ratio) +# beta_exp = rearrange(beta_exp,'b h l r->b (h r) l').contiguous() +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q_exp, k_exp, v_exp, beta_exp, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") +# o = rearrange(o,'b (h r) l d-> b h l r d',r=self.ratio) +# o = torch.sum(o,dim=-2,keepdim=False) +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# if self.training and aux_loss: +# o = AddAuxiliaryLoss.apply(o,aux_loss) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim, self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state + +# # # ##base-deltanet +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, 
+# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int = 2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = 1 #ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print(self.ratio) +# print('branch') +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# import torch.nn.init as init +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. 
+# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# # if attention_mask is not None: +# # indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# # hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# # if attention_mask is not None: +# # if attention_mask.shape[-1] != hidden_states.shape[-2]: +# # attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# # if attention_mask is not None: +# # v = v.mul_(attention_mask.unsqueeze(-1)) +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError +# q,k = map(lambda x: rearrange(x, 'b h l (k d) -> b (h k) l d', k=self.ratio), (q, k)) +# # print(v.shape) +# v = v.unsqueeze(2).expand(-1,-1,self.ratio,-1,-1).reshape(b,self.num_heads*self.ratio,q_len,self.head_v_dim) +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] + +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# beta_exp = beta.unsqueeze(-1).expand(-1,-1,-1,self.ratio) +# beta_exp = rearrange(beta_exp,'b h l r->b (h r) l').contiguous() +# q_exp, k_exp, v_exp = map(lambda x: x.contiguous(), (q, k, v)) +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q_exp, k_exp, v_exp, beta_exp, initial_state=recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode 
`{mode}`.") +# o = rearrange(o,'b (h r) l d-> b h l r d',r=self.ratio) +# o = torch.sum(o,dim=-2,keepdim=False) +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim, self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state + + +##base-deltanet+normfirst +from fla3.ops.generalized_delta_rule import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule +class emdeltanet(nn.Module): + def __init__( + self, + d_model: int = None, + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 4, + mode: str = 'chunk', + chunk_size: int = 64, + use_beta: bool = True, + use_gate: bool = False, + use_output_norm: bool = True, + use_elu: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + qk_activation: str = 'silu', + qk_norm: str = 'l2', + norm_first: bool = False, + norm_eps: float = 1e-6, + ratio :int = 2, + topk : int = 1 , + **kwargs + ) : + super().__init__() + + self.mode = mode + self.qk_activation = qk_activation + self.qk_norm = qk_norm + + assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] + assert self.qk_norm in ['l2', 'sum'] + + if d_model is not None: + hidden_size = d_model + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.chunk_size = chunk_size + self.use_gate = use_gate + self.use_output_norm = use_output_norm + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_qk_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.norm_first = norm_first + self.layer_idx = layer_idx + self.ratio = 2 #ratio + self.top_k = topk + self.silu = nn.SiLU() + print(self.ratio) + print('branch+iplr') + assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + if norm_first: + self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + import torch.nn.init as init + self.use_beta = use_beta + self.use_elu = use_elu + # if self.use_beta: + # self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, + conv_size, + activation='silu' if qk_activation == 'silu' else None) + self.k_conv1d = ShortConvolution(self.key_dim, + conv_size, + activation='silu' if qk_activation == 'silu' else None) + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + # change to inference mode. + mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode + b,q_len,d = hidden_states.shape + if self.norm_first: + hidden_states = self.norm(hidden_states) + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + offset = past_key_values.get_seq_length() + if attention_mask is not None: + if attention_mask.shape[-1] != hidden_states.shape[-2]: + attention_mask = attention_mask[:, -1:] + + if self.use_short_conv: + conv_state_q , conv_state_k , conv_state_v = None,None,None + if last_state is not None: + conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) + k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) + v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + if attention_mask is not None: + v = v.mul_(attention_mask.unsqueeze(-1)) + q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.num_heads), (q, k, v)) + beta = k.sigmoid()  # elementwise gate derived from the key (no separate b_proj in this variant) + q,k,beta = map(lambda x: rearrange(x, 'b l h (k d) -> b l (h k) d', k=self.ratio), (q, k,beta)) + if self.qk_norm is not None: + if self.qk_norm == 'l2': + q = l2_norm_fn(q) + k = l2_norm_fn(k) + elif self.qk_norm == 'sum': + q = sum_norm(q).to(v) + k = sum_norm(k).to(v) + # broadcast v to every branch so it matches the (b, l, h*ratio, d) layout of q/k above + v = v.unsqueeze(3).expand(-1,-1,-1,self.ratio,-1).reshape(b,q_len,self.num_heads*self.ratio,self.head_v_dim) + recurrent_state_kf,recurrent_state_sf = None,None + if last_state is not None: + recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] + + beta_exp = beta.contiguous() + q_exp, k_exp, v_exp = map(lambda x: x.contiguous(), (q, k, v)) + a_coef = -k_exp*beta_exp  # IPLR "erase" direction a_t = -(beta * k_t) + b_coef = k_exp*beta_exp  # IPLR "write" direction b_t = beta * k_t (renamed from `b` to avoid shadowing the batch size) + if mode == 'fused_recurrent': + o, recurrent_state_sf = fused_recurrent_iplr_delta_rule(q_exp, k_exp, v_exp, a_coef,b_coef, initial_state=recurrent_state_sf, output_final_state=use_cache) + elif mode == 'chunk': + o, recurrent_state_sf = chunk_iplr_delta_rule(q_exp, k_exp, v_exp, a_coef,b_coef, initial_state=recurrent_state_sf, output_final_state=use_cache) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + # the kernels follow upstream fla's (B, T, H, D) layout, so the fused head axis is split out along dim 2 + o = rearrange(o,'b l (h r) d-> b l h r d',r=self.ratio) + o = torch.sum(o,dim=-2,keepdim=False)  # merge the ratio branches back into each head; o is already (b, l, h, d) + if past_key_values is not None: + past_key_values.update( + recurrent_state=(recurrent_state_kf,recurrent_state_sf), + conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + if self.use_gate: + g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b l h d -> b l (h d)') + o = self.o_proj(o) + return o, None, past_key_values,None + + # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: + # param = next(self.parameters()) + # state = tuple() + # if self.use_short_conv: + # # for q/k/v each + # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), + # param.new_zeros(batch_size, self.key_dim, self.conv_size), + # param.new_zeros(batch_size, self.value_dim, self.conv_size)) + # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) + # return state + 
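+# Editor's sketch (comment only, assumed shapes): a naive reference for the IPLR recurrence used above, modeled on the reference-style implementation in upstream fla; the Triton kernels may differ in layout and precision. +# def iplr_delta_rule_ref(q, k, v, a, b): +#     # q/k/a/b: (B, T, H, Dk), v: (B, T, H, Dv); running state S: (B, H, Dk, Dv) +#     B, T, H, Dk = q.shape +#     S = q.new_zeros(B, H, Dk, v.shape[-1], dtype=torch.float32) +#     o = torch.empty_like(v) +#     for t in range(T): +#         qt, kt, vt, at, bt = (x[:, t].float() for x in (q, k, v, a, b)) +#         u = torch.einsum('bhk,bhkv->bhv', at, S)  # project the state with a_t +#         S = S + torch.einsum('bhk,bhv->bhkv', kt, vt) + torch.einsum('bhk,bhv->bhkv', bt, u)  # S <- (I + b_t a_t^T) S + k_t v_t^T +#         o[:, t] = torch.einsum('bhk,bhkv->bhv', qt, S).to(v.dtype) +#     return o +# With a = -(beta * k) and b = beta * k as above, each step subtracts a gated projection of the state along k before writing k v^T, i.e. a DeltaNet-style erase-then-write with an elementwise gate.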
+#r-n gla +# from fla.ops.gla import chunk_gla,fused_recurrent_gla#BTHK +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int = 2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = 2 #ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print(self.ratio) +# print('gla') +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# # self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# import torch.nn.init as init +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# # self.k_conv1d = ShortConvolution(self.key_dim, +# # conv_size, +# # activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# 
past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# # if attention_mask is not None: +# # indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# # hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# # if attention_mask is not None: +# # if attention_mask.shape[-1] != hidden_states.shape[-2]: +# # attention_mask = attention_mask[:, -1:] +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# # k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# # k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# # k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# # if attention_mask is not None: +# # v = v.mul_(attention_mask.unsqueeze(-1)) +# q, v = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.num_heads), (q, v)) +# # q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# # k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# # k = sum_norm(k).to(v) +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b l h').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# beta = rearrange(beta, 'b l (h d) -> b l h d', h=self.num_heads) +# g = 1 - beta +# g = torch.clamp_min(g,1e-6) +# g = torch.log(g) +# q_exp, k_exp, v_exp,g = map(lambda x: x.contiguous(), (q, beta, v, g)) +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_gla(q_exp, k_exp, v_exp, g, initial_state=recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# o, recurrent_state_sf = chunk_gla(q_exp, k_exp, v_exp, g,initial_state=recurrent_state_sf, output_final_state=use_cache) +# # elif mode == 'chunk': +# # assert self.chunk_size in [16, 32, 64] +# # o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") +# # o = rearrange(o,'b l h d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# 
conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None,
+#             layer_idx=self.layer_idx,
+#             offset=q_len
+#         )
+#         if self.use_gate:
+#             g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
+#             o = self.o_norm(o, g)
+#         else:
+#             o = self.o_norm(o)
+#         o = rearrange(o, 'b l h d -> b l (h d)')
+#         o = self.o_proj(o)
+#         return o, None, past_key_values, None
+
+#     # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]:
+#     #     param = next(self.parameters())
+#     #     state = tuple()
+#     #     if self.use_short_conv:
+#     #         # for q/k/v each
+#     #         state += (param.new_zeros(batch_size, self.key_dim, self.conv_size),
+#     #                   param.new_zeros(batch_size, self.key_dim, self.conv_size),
+#     #                   param.new_zeros(batch_size, self.value_dim, self.conv_size))
+#     #     state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),)
+#     #     return state
+
diff --git a/fla3/layers/emdeltanet.py b/fla3/layers/emdeltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..20cc5adc839b5a035eafef91dca4a20cdf4069c7 --- /dev/null +++ b/fla3/layers/emdeltanet.py @@ -0,0 +1,1866 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import math
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from einops import rearrange, repeat
+from ..modules import FusedRMSNormSwishGate, RMSNorm
+from fla.modules import ShortConvolution
+from ..modules.l2norm import l2_norm as l2_norm_fn
+from fla.models.utils import Cache
+from transformers.processing_utils import Unpack
+from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
+# from ..modules import RotaryEmbedding
+# from ..ops.gla import chunk_gla, fused_recurrent_gla
+from ..ops.delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule
+
+
+def simple_norm(x):
+    return (F.normalize(x, dim=-1) * x.shape[-1] ** 0.5).to(x)
+
+
+# @torch.jit.script
+def elu_p1(x):
+    return (F.elu(x, 1., False) + 1.).to(x)
+
+
+# @torch.jit.script
+def sum_norm(x):
+    return (x / x.sum(-1, keepdim=True)).to(x)
+
+
+# @torch.jit.script
+def elu_norm(x):
+    dtype = x.dtype
+    x = F.elu(x, 1., False) + 1.
+    return (x / x.sum(-1, keepdim=True)).to(dtype)
+
+
+class AddAuxiliaryLoss(torch.autograd.Function):
+    """
+    Straight-through helper for attaching an auxiliary (aux) loss to a tensor:
+    the forward pass returns `x` unchanged, while the backward pass also
+    propagates the gradient of the aux loss.
+ """ + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + +# # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) + +# # dealing with left-padding +# if attention_mask is not None: +# v = v.mul_(attention_mask.unsqueeze(-1)) + +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# logits = torch.matmul(v,self.router_weight)#get b h l r logits +# scores = (logits).softmax(dim=-1,dtype=torch.float32)#b h l r #划分k +# topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k +# masked_scores = torch.zeros_like(scores,device=q.device) +# masked_scores.scatter_(-1, topk_idx, topk_score) +# masked_idx = masked_scores.bool() +# if self.training : +# scores_for_aux = rearrange(scores,'b h l r ->(b h) l r') +# aux_topk = self.top_k +# # always compute aux loss based on the naive greedy topk method +# topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1) +# if True: +# scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1) +# ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device) +# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k) +# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * 0.001 + +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['current_state'] + +# if recurrent_state_kf is not None: +# cum_idx = torch.cumsum(masked_idx,dim=-2) + recurrent_state_kf +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# else: +# cum_idx = torch.cumsum(masked_idx,dim=-2) +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# masked_mem = torch.where(cum_idx==0,-1e10,0) + +# v = torch.einsum('b h l d, b h l r -> b h l r d',v,masked_idx) +# v_e = torch.ones([b,self.num_heads,q_len,1]).to(v) +# v_e = torch.einsum('b h l d, b h l r -> b h l r d',v_e,masked_idx) +# v_exp = torch.concat([v,v_e],dim=-1) +# v_exp = rearrange(v_exp,'b h l r d-> b h l (r d)').to(q) + +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError + +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) + 
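+# NOTE on the routing bookkeeping above: `cum_idx` is a running count of how many
+# tokens have been routed to each expert slot so far (carried across calls via
+# `recurrent_state_kf`), and `masked_mem` adds -1e10 to slots that were never
+# written, so the later softmax readout ignores empty expert memories. The extra
+# all-ones value channel `v_e` records slot occupancy alongside `v` (later
+# variants store the routing weight instead), letting the layer recover per-slot
+# usage from the delta-rule state at readout time.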
+# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# # state = past_key_values[self.layer_idx][-1] if use_cache else None + +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q, k, v_exp, beta, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") + +# o = rearrange(o,'b h l (r d)-> b h l r d',r=self.ratio) +# o_moe_v = o[...,:-1] +# o_moe_k = o[...,-1] +# qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)+(scores-scores.detach())).to(o_moe_v) +# o = torch.einsum('b h l r d,b h l r-> b h l d',o_moe_v,qlogit) + +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# if self.training: +# o = AddAuxiliaryLoss.apply(o,aux_loss) +# return o, None, past_key_values,None + + + +# ####ver2 use for aaai +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) + +# # dealing with left-padding +# if attention_mask is not None: +# v = v.mul_(attention_mask.unsqueeze(-1)) + +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# logits = torch.matmul(v,self.router_weight)#get b h l r logits +# scores = (logits).softmax(dim=-1,dtype=torch.float32)#b h l r #划分k +# topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k +# masked_scores = torch.zeros_like(scores,device=q.device) +# masked_scores.scatter_(-1, topk_idx, topk_score) +# masked_idx = masked_scores.bool() +# if self.training : +# scores_for_aux = rearrange(scores,'b h l r ->(b h) l r') +# aux_topk = self.top_k +# # always compute aux loss based on the naive greedy topk method +# topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1) +# if True: +# scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1) +# ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device) +# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k) +# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * 0.001 + +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['current_state'] + +# if recurrent_state_kf is not None: +# cum_idx = torch.cumsum(masked_idx,dim=-2) + recurrent_state_kf +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# else: +# cum_idx = torch.cumsum(masked_idx,dim=-2) +# recurrent_state_kf = cum_idx[:,:,-1,:].unsqueeze(-2) +# masked_mem = torch.where(cum_idx==0,-1e10,0) + +# v = torch.einsum('b h l d, b h l r -> b h l r d',v,masked_idx) +# v_e = torch.ones([b,self.num_heads,q_len,1]).to(v) +# v_e = torch.einsum('b h l d, b h l r -> b h l r d',v_e,masked_scores) +# v_exp = torch.concat([v,v_e],dim=-1) +# v_exp = rearrange(v_exp,'b h l r d-> b h l (r d)').to(q) + +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError + +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) 
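+# NOTE: the training block above is a DeepSeek-MoE-style load-balancing auxiliary
+# loss: `ce` counts, per head, how often each of the `ratio` expert slots was
+# selected (normalized by sequence length), and its inner product with the mean
+# router probability per slot (scaled by 0.001) penalizes slots that are both
+# frequently picked and highly weighted, pushing the router toward balanced usage.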
+ +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# # state = past_key_values[self.layer_idx][-1] if use_cache else None + +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q, k, v_exp, beta, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q, k, v_exp, beta, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") + +# o = rearrange(o,'b h l (r d)-> b h l r d',r=self.ratio) +# o_moe_v = o[...,:-1] +# o_moe_k = o[...,-1] +# qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)).to(o_moe_v) +# # qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)+(scores-scores.detach())).to(o_moe_v) +# o = torch.einsum('b h l r d,b h l r-> b h l d',o_moe_v,qlogit) + +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# if self.training: +# o = AddAuxiliaryLoss.apply(o,aux_loss) +# return o, None, past_key_values,None + + + +###ver3 +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int =2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print('emdeltanet') +# print(self.ratio) +# print(self.top_k) +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) + +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + +# self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.head_v_dim,self.ratio))) +# import torch.nn.init as init +# init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# if attention_mask is not None: +# indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0)
+#         last_state = None
+#         if past_key_values is not None and len(past_key_values) > self.layer_idx:
+#             last_state = past_key_values[self.layer_idx]
+#             offset = past_key_values.get_seq_length()
+#         if attention_mask is not None:
+#             if attention_mask.shape[-1] != hidden_states.shape[-2]:
+#                 attention_mask = attention_mask[:, -1:]
+
+#         if self.use_short_conv:
+#             conv_state_q, conv_state_k, conv_state_v = None, None, None
+#             if last_state is not None:
+#                 conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
+#             q = self.q_proj(hidden_states)
+#             k = self.k_proj(hidden_states)
+#             v = self.v_proj(hidden_states)
+#             q, conv_state_q = self.q_conv1d(q, cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens)
+#             k, conv_state_k = self.k_conv1d(k, cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens)
+#             v, conv_state_v = self.v_conv1d(v, cache=conv_state_v, output_final_state=use_cache, cu_seqlens=cu_seqlens)
+#         else:
+#             q = self.q_proj(hidden_states)
+#             k = self.k_proj(hidden_states)
+#             v = self.v_proj(hidden_states)
+#         if attention_mask is not None:
+#             v = v.mul_(attention_mask.unsqueeze(-1))
+
+#         q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))
+#         logits = torch.matmul(v, self.router_weight)  # router logits, shape (b, h, l, r)
+#         scores = logits.softmax(dim=-1, dtype=torch.float32)  # (b, h, l, r); used to partition k across expert slots
+#         topk_score, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)  # (b, h, l, top_k)
+#         masked_scores = torch.zeros_like(scores, device=q.device)
+#         masked_scores.scatter_(-1, topk_idx, topk_score)
+#         masked_idx = masked_scores.bool()
+#         if self.training:
+#             scores_for_aux = rearrange(scores, 'b h l r -> (b h) l r')
+#             aux_topk = self.top_k
+#             # always compute aux loss based on the naive greedy topk method
+#             topk_idx_for_aux_loss = topk_idx.view(b*self.num_heads, -1)
+#             if True:
+#                 scores_for_seq_aux = scores_for_aux.view(b*self.num_heads, q_len, -1)
+#                 ce = torch.zeros(b*self.num_heads, self.ratio, device=hidden_states.device)
+#                 ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b*self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k)
+#                 aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * 0.001
+#         # aux_loss = 0
+#         recurrent_state_kf, recurrent_state_sf = None, None
+#         if last_state is not None:
+#             recurrent_state_kf, recurrent_state_sf = last_state['current_state']
+
+#         if recurrent_state_kf is not None:
+#             cum_idx = torch.cumsum(masked_idx, dim=-2) + recurrent_state_kf
+#             recurrent_state_kf = cum_idx[:, :, -1, :].unsqueeze(-2)
+#         else:
+#             cum_idx = torch.cumsum(masked_idx, dim=-2)
+#             recurrent_state_kf = cum_idx[:, :, -1, :].unsqueeze(-2)
+#         masked_mem = torch.where(cum_idx == 0, -1e10, 0)
+#         # norm_state = torch.where(cum_idx==0,0,1/cum_idx)  # could be concatenated along the head dimension  # (b h l r)
+#         # norm_state = torch.nn.functional.normalize(norm_state,p=1,dim=-1)  # p=1 is more reasonable than p=2; this channel explicitly tells the model the current token count
+#         if self.qk_activation != 'silu':
+#             if self.qk_activation == 'relu':
+#                 q, k = q.relu(), k.relu()
+#             elif self.qk_activation == 'elu':
+#                 q, k = elu_p1(q), elu_p1(k)
+#             elif self.qk_activation == 'identity':
+#                 pass
+#             else:
+#                 raise NotImplementedError
+
+#         if self.qk_norm is not None:
+#             if self.qk_norm == 'l2':
+#                 q = l2_norm_fn(q)
+#                 k = l2_norm_fn(k)
+#             elif self.qk_norm == 'sum':
+#                 q = sum_norm(q).to(v)
+#                 k = sum_norm(k).to(v)
+#         q_exp = q.unsqueeze(-2).expand(-1, -1, -1, self.ratio, -1)
+#         k_exp = torch.einsum('b h l d,b h l r->b h l r d', k, masked_idx)
+
+#         v = torch.einsum('b h l d, b h l r -> b h l r d', v, masked_idx)
+#         v_e = torch.ones([b, self.num_heads, q_len, 1]).to(v)
+#         v_e = torch.einsum('b h l d, b h l r -> b h l r d', v_e, masked_scores)
+#         v_exp = torch.concat([v, v_e], dim=-1)
+#         v_exp = rearrange(v_exp, 'b h l r d-> b (h r) l d').to(q).contiguous()
+#         q_exp = rearrange(q_exp, 'b h l r d-> b (h r) l d').to(q).contiguous()
+#         k_exp = rearrange(k_exp, 'b h l r d-> b (h r) l d').to(q).contiguous()  # if feasible, this expansion still needs a faster fused path
+
+#         if self.use_beta:
+#             beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
+#         else:
+#             beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
+#         beta_exp = beta.unsqueeze(-1).expand(-1, -1, -1, self.ratio)
+#         beta_exp = rearrange(beta_exp, 'b h l r->b (h r) l').contiguous()
+
+#         if mode == 'fused_recurrent':
+#             o, recurrent_state_sf = fused_recurrent_delta_rule(q_exp, k_exp, v_exp, beta_exp, recurrent_state_sf, output_final_state=use_cache)
+#         elif mode == 'fused_chunk':
+#             assert self.chunk_size in [16, 32, 64]
+#             o, recurrent_state_sf = fused_chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache)
+#         elif mode == 'chunk':
+#             assert self.chunk_size in [16, 32, 64]
+#             o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache)
+#         else:
+#             raise NotImplementedError(f"Not supported mode `{mode}`.")
+
+#         o = rearrange(o, 'b (h r) l d-> b h l r d', r=self.ratio)
+#         o_moe_v = o[..., :-1]
+#         o_moe_k = o[..., -1]
+#         qlogit = ((o_moe_k+masked_mem).softmax(dim=-1)).to(o_moe_v)
+#         o = torch.einsum('b h l r d,b h l r-> b h l d', o_moe_v, qlogit)
+
+#         o = rearrange(o, 'b h l d-> b l h d')
+#         if past_key_values is not None:
+#             past_key_values.update(
+#                 recurrent_state=(recurrent_state_kf, recurrent_state_sf),
+#                 conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
+#                 layer_idx=self.layer_idx,
+#                 offset=q_len
+#             )
+#         if self.use_gate:
+#             g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
+#             o = self.o_norm(o, g)
+#         else:
+#             o = self.o_norm(o)
+#         o = rearrange(o, 'b l h d -> b l (h d)')
+#         o = self.o_proj(o)
+#         if self.training and aux_loss:
+#             o = AddAuxiliaryLoss.apply(o, aux_loss)
+#         return o, None, past_key_values, None
+
+
+
+
+###ver4
+class emdeltanet(nn.Module):
+    def __init__(
+        self,
+        d_model: int = None,
+        hidden_size: int = 1024,
+        expand_k: float = 1.0,
+        expand_v: float = 1.0,
+        num_heads: int = 4,
+        mode: str = 'chunk',
+        chunk_size: int = 64,
+        use_beta: bool = True,
+        use_gate: bool = False,
+        use_output_norm: bool = True,
+        use_elu: bool = False,
+        use_short_conv: bool = True,
+        conv_size: int = 4,
+        conv_bias: bool = False,
+        layer_idx: int = None,
+        qk_activation: str = 'silu',
+        qk_norm: str = 'l2',
+        norm_first: bool = False,
+        norm_eps: float = 1e-6,
+        ratio: int = 2,
+        topk: int = 1,
+        **kwargs
+    ):
+        super().__init__()
+
+        self.mode = mode
+        self.qk_activation = qk_activation
+        self.qk_norm = qk_norm
+
+        assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
+        assert self.qk_norm in ['l2', 'sum']
+
+        if d_model is not None:
+            hidden_size = d_model
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.num_heads = num_heads
+        self.chunk_size = chunk_size
+        self.use_gate = use_gate
+        self.use_output_norm = use_output_norm
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.conv_bias = conv_bias
+
+        self.key_dim = int(hidden_size * expand_k)
+        self.value_dim = int(hidden_size * expand_v)
+        self.head_qk_dim = self.key_dim // num_heads
+        self.head_v_dim = self.value_dim // num_heads
+        self.norm_first = norm_first
+        self.layer_idx = layer_idx
+        self.ratio = ratio
+        self.top_k = topk
+        self.silu = nn.SiLU()
+        print('emdeltanet_v4')
+        print(self.ratio)
+        print(self.top_k)
+        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
+        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
+        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
+
+        if norm_first:
+            self.norm = RMSNorm(self.hidden_size, eps=norm_eps)
+
+        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
+        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
+        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
+
+        # per-head router over the `ratio` expert slots
+        self.router_weight = nn.Parameter(torch.empty((self.num_heads, self.head_v_dim, self.ratio)))
+        init.kaiming_uniform_(self.router_weight, a=math.sqrt(5))
+        self.use_beta = use_beta
+        self.use_elu = use_elu
+        if self.use_beta:
+            self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
+        if use_short_conv:
+            self.conv_size = conv_size
+            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
+            self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu' if qk_activation == 'silu' else None)
+            self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
+        if use_gate:
+            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
+            self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
+        else:
+            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
+
+        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
+
+        self.apply(self._initialize_weights)
+
+    def _initialize_weights(self, module: nn.Module):
+        if getattr(module, "_is_hf_initialized", False):
+            return
+        if isinstance(module, nn.Linear):
+            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        module._is_hf_initialized = True
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        # switch to the recurrent kernel for short (decoding-length) inputs
+        mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode
+        b, q_len, d = hidden_states.shape
+        if self.norm_first:
+            hidden_states = self.norm(hidden_states)
+
+        cu_seqlens = kwargs.get('cu_seqlens', None)
+        if attention_mask is not None:
+            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
+            hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
+        last_state = None
+        if past_key_values is not None and len(past_key_values) > self.layer_idx:
+            last_state = past_key_values[self.layer_idx]
+            offset = past_key_values.get_seq_length()
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != hidden_states.shape[-2]:
+                attention_mask = attention_mask[:, -1:]
+
+        if self.use_short_conv:
+            conv_state_q, conv_state_k, conv_state_v = None, None, None
+            if last_state is not None:
+                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
+            q = self.q_proj(hidden_states)
+            k = self.k_proj(hidden_states)
+            v = self.v_proj(hidden_states)
+            q, conv_state_q = self.q_conv1d(q, cache=conv_state_q, output_final_state=use_cache, cu_seqlens=cu_seqlens)
+            k, conv_state_k = self.k_conv1d(k, cache=conv_state_k, output_final_state=use_cache, cu_seqlens=cu_seqlens)
+            v, conv_state_v = self.v_conv1d(v, cache=conv_state_v, output_final_state=use_cache, cu_seqlens=cu_seqlens)
+        else:
+            q = self.q_proj(hidden_states)
+            k = self.k_proj(hidden_states)
+            v = self.v_proj(hidden_states)
+        if attention_mask is not None:
+            v = v.mul_(attention_mask.unsqueeze(-1))
+
+        q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v))
+        logits = torch.matmul(v, self.router_weight)  # router logits, shape (b, h, l, r)
+        scores = logits.softmax(dim=-1, dtype=torch.float32)  # (b, h, l, r); used to partition k across expert slots
+        topk_score, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)  # (b, h, l, top_k)
+        masked_scores = torch.zeros_like(scores)
+        masked_scores.scatter_(-1, topk_idx, topk_score)
+        masked_idx = masked_scores.bool()
+        if self.training:
+            scores_for_aux = rearrange(scores, 'b h l r -> (b h) l r')
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(b * self.num_heads, -1)
+            scores_for_seq_aux = scores_for_aux.view(b * self.num_heads, q_len, -1)
+            ce = torch.zeros(b * self.num_heads, self.ratio, device=hidden_states.device)
+            ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(b * self.num_heads, q_len * aux_topk, device=hidden_states.device)).div_(q_len * aux_topk / self.top_k)
+            aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * 0.001
+            # aux_loss = 0
+        recurrent_state_kf, recurrent_state_kf2, recurrent_state_sf = None, None, None
+        if last_state is not None:
+            # key must match the `recurrent_state=` entry written by past_key_values.update() below
+            recurrent_state_kf, recurrent_state_kf2, recurrent_state_sf = last_state['recurrent_state']
+
+        # running count of how many tokens each expert slot has received so far
+        if recurrent_state_kf is not None:
+            cum_idx = torch.cumsum(masked_idx, dim=-2) + recurrent_state_kf
+        else:
+            cum_idx = torch.cumsum(masked_idx, dim=-2)
+        recurrent_state_kf = cum_idx[:, :, -1, :].unsqueeze(-2)
+        # slots that were never written are pushed out of the later softmax
+        masked_mem = torch.where(cum_idx == 0, -1e10, 0)
+
+        if self.qk_activation != 'silu':
+            if self.qk_activation == 'relu':
+                q, k = q.relu(), k.relu()
+            elif self.qk_activation == 'elu':
+                q, k = elu_p1(q), elu_p1(k)
+            elif self.qk_activation == 'identity':
+                pass
+            else:
+                raise NotImplementedError
+        if self.qk_norm is not None:
+            if self.qk_norm == 'l2':
+                q = l2_norm_fn(q)
+                k = l2_norm_fn(k)
+            elif self.qk_norm == 'sum':
+                q = sum_norm(q).to(v)
+                k = sum_norm(k).to(v)
+        # gate-weighted keys per slot, accumulated over time (dim 2 is the length axis)
+        k_exp2 = torch.einsum('b h l d,b h l r->b h l r d', k, masked_scores).to(q)
+        if recurrent_state_kf2 is not None:
+            k_exp_sum = torch.cumsum(k_exp2, dim=2) + recurrent_state_kf2
+        else:
+            k_exp_sum = torch.cumsum(k_exp2, dim=2)  # (b, h, l, r, d)
+        recurrent_state_kf2 = k_exp_sum[:, :, -1, :].unsqueeze(-2)
+
+        # per-token slot weights: attend the query to each slot's accumulated keys
+        qlamda = (torch.einsum('b h l d,b h l r d-> b h l r', q, k_exp_sum) * (self.head_qk_dim ** (-0.5)) + masked_mem)
+        qlamda = torch.softmax(qlamda, dim=-1)
+        q_exp = torch.einsum('b h l d,b h l r->b h l r d', q, qlamda)
+        k_exp = torch.einsum('b h l d,b h l r->b h l r d', k, masked_idx)
+        v_exp = torch.einsum('b h l d, b h l r -> b h l r d', v, masked_idx)
+        v_exp = rearrange(v_exp, 'b h l r d-> b (h r) l d').to(q).contiguous()
+        q_exp = rearrange(q_exp, 'b h l r d-> b (h r) l d').to(q).contiguous()
+        k_exp = rearrange(k_exp, 'b h l r d-> b (h r) l d').to(q).contiguous()  # if feasible, this expansion still needs a faster fused path
+        if self.use_beta:
+            beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid()
+        else:
+            beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
+        beta_exp = beta.unsqueeze(-1).expand(-1, -1, -1, self.ratio)
+        beta_exp = rearrange(beta_exp, 'b h l r->b (h r) l').contiguous()
+        if mode == 'fused_recurrent':
+            o, recurrent_state_sf = fused_recurrent_delta_rule(q_exp, k_exp, v_exp, beta_exp, initial_state=recurrent_state_sf, output_final_state=use_cache)
+        elif mode == 'fused_chunk':
+            assert self.chunk_size in [16, 32, 64]
+            o, recurrent_state_sf = fused_chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache)
+        elif mode == 'chunk':
+            assert self.chunk_size in [16, 32, 64]
+            o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache)
+        else:
+            raise NotImplementedError(f"Not supported mode `{mode}`.")
+        o = rearrange(o, 'b (h r) l d-> b h l r d', r=self.ratio)
+        o = torch.sum(o, dim=-2, keepdim=False)
+        o = rearrange(o, 'b h l d-> b l h d')
+        if past_key_values is not None:
+            past_key_values.update(
+                # store all three carried states so the unpack above matches
+                recurrent_state=(recurrent_state_kf, recurrent_state_kf2, recurrent_state_sf),
+                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
+                layer_idx=self.layer_idx,
+                offset=q_len
+            )
+        if self.use_gate:
+            g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads)
+            o = self.o_norm(o, g)
+        else:
+            o = self.o_norm(o)
+        o = rearrange(o, 'b l h d -> b l (h d)')
+        o = self.o_proj(o)
+        if self.training and aux_loss:
+            o = AddAuxiliaryLoss.apply(o, aux_loss)
+        return o, None, past_key_values, None
+
+
+# # # ##base-deltanet
+# class emdeltanet(nn.Module):
+#     def __init__(
+#         self,
+#         d_model: int = None,
+#         hidden_size: int = 1024,
+#         expand_k: float = 1.0,
+#         expand_v: float = 1.0,
+#         num_heads: int = 4,
+#         mode: str = 'chunk',
+#         chunk_size: int = 64,
+#         use_beta: bool = True,
+#         use_gate: bool = False,
+#         use_output_norm: bool = True,
+#         use_elu: bool = False,
+#         use_short_conv: bool = True,
+#         conv_size: int = 4,
+#         conv_bias: bool = False,
+#         layer_idx: int = None,
+#         qk_activation: str = 'silu',
+#         qk_norm: str = 'l2',
+#         norm_first: bool = False,
+#         norm_eps: float = 1e-6,
+#         ratio: int = 2,
+#         topk: int = 1,
+#         **kwargs
+#     ):
+#         super().__init__()
+
+#         self.mode = mode
+#         self.qk_activation = qk_activation
+#         self.qk_norm = qk_norm
+
+#         assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
+#         assert self.qk_norm in ['l2', 'sum']
+
+#         if d_model is not None:
+#             hidden_size = d_model
+#         self.hidden_size = hidden_size
+#         self.expand_k = expand_k
+#         self.expand_v = expand_v
+#         self.num_heads = num_heads
+#         self.chunk_size = chunk_size
+#         self.use_gate = use_gate
+#         self.use_output_norm = use_output_norm
+#         self.use_short_conv = use_short_conv
+#         self.conv_size = conv_size
+#         self.conv_bias = conv_bias
+
+#         self.key_dim = 
int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = 1 #ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print(self.ratio) +# print('branch') +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# import torch.nn.init as init +# self.use_beta = use_beta +# self.use_elu = use_elu +# if self.use_beta: +# self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# # if attention_mask is not None: +# # indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# # hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# # if attention_mask is not None: +# # if attention_mask.shape[-1] != hidden_states.shape[-2]: +# # attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# # if attention_mask is not None: +# # v = v.mul_(attention_mask.unsqueeze(-1)) +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, v)) +# if self.qk_activation != 'silu': +# if self.qk_activation == 'relu': +# q, k = q.relu(), k.relu() +# elif self.qk_activation == 'elu': +# q, k = elu_p1(q), elu_p1(k) +# elif self.qk_activation == 'identity': +# pass +# else: +# raise NotImplementedError +# q,k = map(lambda x: rearrange(x, 'b h l (k d) -> b (h k) l d', k=self.ratio), (q, k)) +# # print(v.shape) +# v = v.unsqueeze(2).expand(-1,-1,self.ratio,-1,-1).reshape(b,self.num_heads*self.ratio,q_len,self.head_v_dim) +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] + +# if self.use_beta: +# beta = rearrange(self.b_proj(hidden_states), 'b l h -> b h l').sigmoid() +# else: +# beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) +# beta_exp = beta.unsqueeze(-1).expand(-1,-1,-1,self.ratio) +# beta_exp = rearrange(beta_exp,'b h l r->b (h r) l').contiguous() +# q_exp, k_exp, v_exp = map(lambda x: x.contiguous(), (q, k, v)) +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_delta_rule(q_exp, k_exp, v_exp, beta_exp, initial_state=recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'fused_chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = fused_chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# assert self.chunk_size in [16, 32, 64] +# o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") +# o = rearrange(o,'b (h r) l d-> b h l r d',r=self.ratio) +# o = torch.sum(o,dim=-2,keepdim=False) +# o = rearrange(o,'b h l d-> b l h d') +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# 
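+# NOTE: this baseline widens each head into `ratio` parallel delta-rule branches:
+# the (h) heads become (h r) pseudo-heads whose query/key slices come from
+# 'b h l (k d) -> b (h k) l d' while the value is shared, and the branch outputs
+# are summed back ('b (h r) l d -> b h l r d', then sum over r) before the norm.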
if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# return o, None, past_key_values,None + + + + +# ##base-deltanet+normfirst +# from ..ops.generalized_delta_rule import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio :int = 2, +# topk : int = 1 , +# **kwargs +# ) : +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = 8 #ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print(self.ratio) +# print('branch+iplr+norm_all') +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# import torch.nn.init as init +# self.use_beta = use_beta +# self.use_elu = use_elu +# # if self.use_beta: +# # self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# # if attention_mask is not None: +# # indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# # hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# # if attention_mask is not None: +# # v = v.mul_(attention_mask.unsqueeze(-1)) +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.num_heads), (q, k, v)) +# beta = k.sigmoid() +# if self.qk_norm is not None:#norm-all +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) +# # An open question: normalizing each branch separately vs. all branches jointly will change the eigenvalues of the model's 1 - \lambda*kk^T term. +# q,k,beta = map(lambda x: rearrange(x, 'b l h (k d) -> b l (h k) d', k=self.ratio), (q, k,beta)) +# v = v.unsqueeze(-2).expand(-1,-1,-1,self.ratio,-1).reshape(b,q_len,self.num_heads*self.ratio,self.head_v_dim) +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] +# beta_exp = beta.contiguous() +# q_exp, k_exp, v_exp = map(lambda x: x.contiguous(), (q, k, v)) +# a = -k_exp*beta_exp.contiguous() +# b = k_exp*beta_exp.contiguous() +# gk = torch.zeros_like(q).to(q) +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_dplr_delta_rule(q_exp, k_exp, v_exp, a,b,gk,scale=None,initial_state=recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# o, recurrent_state_sf = chunk_dplr_delta_rule(q_exp, k_exp, v_exp, a,b,gk, initial_state=recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") +# o = rearrange(o,'b l (h r) d-> b l h r d',r=self.ratio) +# o = torch.sum(o,dim=-2,keepdim=False) +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim,
self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state +# # r-n gla +# from fla.ops.gla import chunk_gla,fused_recurrent_gla#BTHK +# ##### An observation: gla-base with an hgrn-like k-decay bound +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio: int = 2, +# topk: int = 1, +# **kwargs +# ): +# super().__init__() + +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm + +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] + +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = 2 #ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print(self.ratio) +# print('gla') +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
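+# # NOTE (sketch of the gating used in the forward pass below; not original code):
+# # beta = sigmoid(k) serves both as the key handed to the GLA kernel and as the
+# # write strength, with the log-decay tied to the same quantity as
+# # g = log(clamp(1 - beta, min=1e-6)) -- a retention-style coupling in which
+# # stronger writes imply faster forgetting.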
+# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# import torch.nn.init as init +# self.use_beta = use_beta +# self.use_elu = use_elu +# # if self.use_beta: +# # self.b_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +# **kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) +# cu_seqlens = kwargs.get('cu_seqlens', None) +# # if attention_mask is not None: +# # indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# # hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# # if attention_mask is not None: +# # if attention_mask.shape[-1] != hidden_states.shape[-2]: +# # attention_mask = attention_mask[:, -1:] +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# # k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# # if attention_mask is not None: +# # v = v.mul_(attention_mask.unsqueeze(-1)) +# q,k,v = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.num_heads), (q, k, v)) +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# # k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# # k = sum_norm(k).to(v) +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] +# beta = k.sigmoid() +# g = 1 - beta +# g = torch.clamp(g,1e-6) +# g = torch.log(g) +# q_exp, k_exp, v_exp,g = map(lambda x: x.contiguous(), (q, beta, v, g)) +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_gla(q_exp, k_exp, v_exp, g, initial_state=recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# o, recurrent_state_sf = chunk_gla(q_exp, k_exp, v_exp, g,initial_state=recurrent_state_sf, output_final_state=use_cache) +# # elif mode == 'chunk': +# # assert self.chunk_size in [16, 32, 64] +# # o, recurrent_state_sf = chunk_delta_rule(q_exp, k_exp, v_exp, beta_exp, self.chunk_size, recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# return o, None, past_key_values,None + +# # def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: +# # param = next(self.parameters()) +# # state = tuple() +# # if self.use_short_conv: +# # # for q/k/v each +# # state += (param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.key_dim, self.conv_size), +# # param.new_zeros(batch_size, self.value_dim, self.conv_size)) +# # state += (param.new_zeros(batch_size, self.num_heads, self.head_qk_dim, (self.ratio*(self.head_v_dim+1))),) +# # return state + + +# #### sparse block convolution +# from ..ops.generalized_delta_rule import
chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule +# class emdeltanet(nn.Module): +# def __init__( +# self, +# d_model: int = None, +# hidden_size: int = 1024, +# expand_k: float = 1.0, +# expand_v: float = 1.0, +# num_heads: int = 4, +# mode: str = 'chunk', +# chunk_size: int = 64, +# use_beta: bool = True, +# use_gate: bool = False, +# use_output_norm: bool = True, +# use_elu: bool = False, +# use_short_conv: bool = True, +# conv_size: int = 4, +# conv_bias: bool = False, +# layer_idx: int = None, +# qk_activation: str = 'silu', +# qk_norm: str = 'l2', +# norm_first: bool = False, +# norm_eps: float = 1e-6, +# ratio: int = 2, +# topk: int = 1, +# **kwargs +# ): +# super().__init__() +# self.mode = mode +# self.qk_activation = qk_activation +# self.qk_norm = qk_norm +# assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] +# assert self.qk_norm in ['l2', 'sum'] +# if d_model is not None: +# hidden_size = d_model +# self.hidden_size = hidden_size +# self.expand_k = expand_k +# self.expand_v = expand_v +# self.num_heads = num_heads +# self.chunk_size = chunk_size +# self.use_gate = use_gate +# self.use_output_norm = use_output_norm +# self.use_short_conv = use_short_conv +# self.conv_size = conv_size +# self.conv_bias = conv_bias + +# self.key_dim = int(hidden_size * expand_k) +# self.value_dim = int(hidden_size * expand_v) +# self.head_qk_dim = self.key_dim // num_heads +# self.head_v_dim = self.value_dim // num_heads +# self.norm_first = norm_first +# self.layer_idx = layer_idx +# self.ratio = 8 #ratio +# self.top_k = topk +# self.silu = nn.SiLU() +# print(self.ratio) +# print('branch+iplr') +# assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." +# assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" +# assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + +# if norm_first: +# self.norm = RMSNorm(self.hidden_size, eps=norm_eps) +# self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) +# self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# import torch.nn.init as init +# self.use_beta = use_beta +# self.use_elu = use_elu +# if use_short_conv: +# self.conv_size = conv_size +# self.q_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.k_conv1d = ShortConvolution(self.key_dim, +# conv_size, +# activation='silu' if qk_activation == 'silu' else None) +# self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') +# if use_gate: +# self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) +# self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) +# else: +# self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) +# self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) +# self.apply(self._initialize_weights) + +# def _initialize_weights(self, module: nn.Module): +# if getattr(module, "_is_hf_initialized", False): +# return +# if isinstance(module, nn.Linear): +# nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) +# if module.bias is not None: +# nn.init.zeros_(module.bias) +# module._is_hf_initialized = True + +# def forward( +# self, +# hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# past_key_values: Optional[Cache] = None, +# use_cache: Optional[bool] = False, +# output_attentions: Optional[bool] = False, +#
**kwargs +# ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: +# # change to inference mode. +# mode = 'fused_recurrent' if hidden_states.shape[1] < 64 else self.mode +# b,q_len,d = hidden_states.shape +# if self.norm_first: +# hidden_states = self.norm(hidden_states) + +# cu_seqlens = kwargs.get('cu_seqlens', None) +# # if attention_mask is not None: +# # indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) +# # hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) +# last_state = None +# if past_key_values is not None and len(past_key_values) > self.layer_idx: +# last_state = past_key_values[self.layer_idx] +# offset = past_key_values.get_seq_length() +# if attention_mask is not None: +# if attention_mask.shape[-1] != hidden_states.shape[-2]: +# attention_mask = attention_mask[:, -1:] + +# if self.use_short_conv: +# conv_state_q , conv_state_k , conv_state_v = None,None,None +# if last_state is not None: +# conv_state_q , conv_state_k , conv_state_v = last_state['conv_state'] +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q,conv_state_q = self.q_conv1d(q, cache= conv_state_q,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# k,conv_state_k = self.k_conv1d(k, cache= conv_state_k,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# v,conv_state_v = self.v_conv1d(v, cache= conv_state_v,output_final_state = use_cache,cu_seqlens =cu_seqlens) +# else: +# q = self.q_proj(hidden_states) +# k = self.k_proj(hidden_states) +# v = self.v_proj(hidden_states) +# q, k, v = map(lambda x: rearrange(x, 'b l (h d) -> b l h d', h=self.num_heads), (q, k, v)) +# beta = k.sigmoid() +# q,k,beta = map(lambda x: rearrange(x, 'b l h (k d) -> b l (h k) d', k=self.ratio), (q, k,beta)) +# if self.qk_norm is not None: +# if self.qk_norm == 'l2': +# q = l2_norm_fn(q) +# k = l2_norm_fn(k) +# elif self.qk_norm == 'sum': +# q = sum_norm(q).to(v) +# k = sum_norm(k).to(v) +# v = v.unsqueeze(-2).expand(-1,-1,-1,self.ratio,-1).reshape(b,q_len,self.num_heads*self.ratio,self.head_v_dim) +# recurrent_state_kf,recurrent_state_sf = None,None +# if last_state is not None: +# recurrent_state_kf,recurrent_state_sf = last_state['recurrent_state'] + +# beta_exp = beta.contiguous() +# q_exp, k_exp, v_exp = map(lambda x: x.contiguous(), (q, k, v)) +# a = -k_exp*beta_exp.contiguous() +# b = k_exp*beta_exp.contiguous() +# gk = torch.zeros_like(q).to(q) +# if mode == 'fused_recurrent': +# o, recurrent_state_sf = fused_recurrent_dplr_delta_rule(q_exp, k_exp, v_exp, a,b,gk,scale=None,initial_state=recurrent_state_sf, output_final_state=use_cache) +# elif mode == 'chunk': +# o, recurrent_state_sf = chunk_dplr_delta_rule(q_exp, k_exp, v_exp, a,b,gk, initial_state=recurrent_state_sf, output_final_state=use_cache) +# else: +# raise NotImplementedError(f"Not supported mode `{mode}`.") +# o = rearrange(o,'b l (h r) d-> b l h r d',r=self.ratio) +# o = torch.sum(o,dim=-2,keepdim=False) +# if past_key_values is not None: +# past_key_values.update( +# recurrent_state=(recurrent_state_kf,recurrent_state_sf), +# conv_state=(conv_state_q,conv_state_k,conv_state_v) if self.use_short_conv else None, +# layer_idx=self.layer_idx, +# offset=q_len +# ) +# if self.use_gate: +# g = rearrange(self.g_proj(hidden_states), 'b l (h d) -> b l h d', h=self.num_heads) +# o = self.o_norm(o, g) +# else: +# o = self.o_norm(o) +# o = rearrange(o, 'b l h d -> b l (h d)') +# o = self.o_proj(o) +# return o, None, 
past_key_values,None + + + + + diff --git a/fla3/layers/emla.py b/fla3/layers/emla.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb113d48bb4d516671df66bcb3a2629a30425f2 --- /dev/null +++ b/fla3/layers/emla.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from fla.modules import RMSNorm +# from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap +# from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn +import torch.nn.init as init +import math +from fla.modules.l2norm import l2norm +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from einops import rearrange + +class emla(nn.Module): + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 8, + + output_norm: str = 'rmsnorm', + elementwise_affine: bool = True, + norm_eps: float = 1e-5, + use_gate: bool = False, + ratio: int = 2, + top_k: int = 1, + **kwargs + ): + super().__init__() + + self.hidden_size = hidden_size + self.mode = mode + self.num_heads = num_heads + self.num_kv_heads = num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + + assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + self.use_gate = use_gate + if use_gate: + self.g_proj = nn.Linear(self.hidden_size,self.value_dim_per_group,False) + if output_norm == 'rmsnorm': + self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps) + elif output_norm == 'identity': + self.norm = nn.Identity() + else: + raise NotImplementedError(f"Not supported output norm `{output_norm}`.") + self.ratio = ratio + self.top_k = top_k + self.gate_fn = nn.functional.silu + self.router_weight = nn.Parameter(torch.empty((self.num_heads,self.ratio,self.head_v_dim))) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.d_conv = 4 + self.conv1d = nn.Conv1d( + in_channels=self.hidden_size, + out_channels=self.hidden_size, + bias=False, + kernel_size=self.d_conv, + groups=self.hidden_size, + padding=self.d_conv - 1, + # **factory_kwargs, + ) + self.reset_parameters() + def reset_parameters(self) -> None: + import torch.nn.init as init + init.kaiming_uniform_(self.router_weight, a=math.sqrt(5)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5) + nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5) + if self.use_gate: + nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5) +
nn.init.xavier_uniform_(self.o_proj.weight, gain=2 ** -2.5) + + def forward(self, hidden_state, conv_states=None, k_f=None, s_f=None, seqlen_offset=None): + x = hidden_state + # x = x.transpose(0, 1).contiguous() + b,l,d = x.shape + x = rearrange(x, 'b l d -> b d l').contiguous() + if self.training: + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation="silu", + ) + elif conv_states is None: + conv_states = nn.functional.pad( + x, (self.d_conv - x.shape[-1], 0) + ) + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation="silu", + ) + else: + x = causal_conv1d_update( + x, + conv_states, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation="silu", + ) + x = rearrange(x, 'b d l -> b l d').contiguous() + q = self.q_proj(x) # query (b l dk) + q = self.gate_fn(q) + k = self.k_proj(x) # key (b l dk) + v = self.v_proj(x) # b l 2*self.head_dv + g = self.g_proj(x) # all get b l d + q = rearrange(q, 'b l (h d) -> b h l d', h = self.num_heads).contiguous() + k = rearrange(k, 'b l (h d) -> b h l d', h = self.num_heads).contiguous() + v = rearrange(v, 'b l (h d) -> b h l d', h = self.num_heads).contiguous() + output,k_f,s_f = self.gated_linear_attention(q, k, v,k_f,s_f) + output = rearrange(output,'b h l d -> b l h d') + output = self.norm(output) + + output = self.gate_fn(g) * (output.view(b,l,d)) + output = self.o_proj(output) + # output = output.transpose(0, 1) + return output,k_f,s_f,conv_states + + def gated_linear_attention(self,q, k, v, past_sum=None,past_state = None): + '''torch qk version''' + b,h,l,d = v.shape #b h l d + dk = q.shape[-1] # h d r + logits = torch.einsum('b h l d, h r d -> b h l r', v, self.router_weight) #get b h l r + scores = logits.softmax(dim=-1) + topk_score , topk_idx = torch.topk(scores,k = self.top_k,dim=-1,sorted=False)#get b,h,l,top_k + if self.top_k>1:#norm + sum_score = topk_score.sum(dim=-1,keepdim=True)+1e-20 + topk_score = topk_score/sum_score + # everything up to this point is the same as in the other variants + masked_scores = torch.zeros_like(scores,device=q.device) + masked_scores.scatter_(-1, topk_idx, topk_score) + masked_idx = masked_scores.bool() + if self.training: + k_exp0 = torch.einsum('b h l d, b h l r-> b h l r d',k,masked_idx.to(k.dtype)) + router_weight_qk = torch.cumsum(k_exp0,dim=-3) + k_exp = torch.einsum('b h l d, b h l r-> b h l r d',k,masked_scores) + norm_k = (l2norm(router_weight_qk)) + qlogit = torch.einsum('b h l d, b h l r d-> b h l r',q,norm_k).softmax(dim=-1) #bhlr + q_exp = torch.einsum('b h l d, b h l r-> b h l r d',q,qlogit) + q_exp = rearrange(q_exp,'b h l r d -> b h l (r d)') + k_exp = rearrange(k_exp,'b h l r d -> b h l (r d)') + qk = q_exp @ k_exp.transpose(-1,-2) * (dk**-0.5) + qk = qk.tril(diagonal=0) + o_moe = qk@v + return o_moe,None,None + else: + if past_sum is None: + k_final = torch.zeros([b,h,self.ratio,dk]).to(q) + else: + k_final = past_sum #bhrd + k_exp0 = torch.einsum('b h l d, b h l r-> b h l r d',k,masked_idx.to(k.dtype)) + router_weight_qk = torch.cumsum(k_exp0,dim=-3)+k_final.unsqueeze(-3) #bhlrd + norm_k = (l2norm(router_weight_qk)) + k_exp = torch.einsum('b h l d, b h l r-> b h l r d',k,masked_scores)#bhlrd + if past_state is None: + s_final = torch.zeros([b,h,self.ratio,dk,d]).to(q)#bhr dk d + else: + s_final = past_state +
qlogit = torch.einsum('b h l d, b h l r d-> b h l r',q,norm_k).softmax(dim=-1) #bhlr + q_exp = torch.einsum('b h l d, b h l r-> b h l r d',q,qlogit) + k_transexp = rearrange(k_exp,'b h l r d-> b h r d l') + final_state = s_final + k_transexp@(v.unsqueeze(-3))#b h r dk d + if past_state is None: + q_exp = rearrange(q_exp,'b h l r d -> b h l (r d)') + k_exp = rearrange(k_exp,'b h l r d -> b h l (r d)') + qk = q_exp @ k_exp.transpose(-1,-2) * (dk**-0.5) + qk = qk.tril(diagonal=0) + o_moe = qk@v + else: + o_moe = torch.einsum('b h l r k,b h r k d-> b h l d',q_exp,s_final)* (dk**-0.5) + q_exp = rearrange(q_exp,'b h l r d -> b h l (r d)') + k_exp = rearrange(k_exp,'b h l r d -> b h l (r d)') + qk = q_exp @ k_exp.transpose(-1,-2) * (dk**-0.5) + qk = qk.tril(diagonal=0) + o_moe += qk@v + return o_moe,router_weight_qk[:,:,-1,:,:],final_state + diff --git a/fla3/layers/forgetting_attn.py b/fla3/layers/forgetting_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..add32d96ffb3b0a5d6394ca156d2f67b2f8ab994 --- /dev/null +++ b/fla3/layers/forgetting_attn.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.layers.utils import pad_input, unpad_input +from fla.modules import GroupNorm +from fla.ops.attn.decoding import attn_decoding_one_step +from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn + +if TYPE_CHECKING: + from fla.models.utils import Cache + +logger = logging.get_logger(__name__) + + +class ForgettingAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + use_output_gate: bool = False, + layer_idx: int = None + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + + self.window_size = window_size + self.use_output_gate = use_output_gate + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True) + + if use_output_gate: + self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + if qk_norm: + self.q_norm = GroupNorm( + num_groups=self.num_heads, + hidden_size=self.hidden_size, + is_rms_norm=True, + ) + self.k_norm = GroupNorm( + num_groups=self.num_kv_heads, + hidden_size=self.kv_dim, + is_rms_norm=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor,
Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.size() + + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + f = F.logsigmoid(self.f_proj(hidden_states).float()) + if self.qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + + cu_seqlens = kwargs.get('cu_seqlens', None) + if past_key_values is not None: + assert cu_seqlens is None, "cu_seqlens should not be provided when past_key_values is not None" + state = past_key_values.update( + attn_state=(k, v, f), + layer_idx=self.layer_idx, + offset=q_len, + cache_kwargs=dict(window_size=self.window_size) + ) + k, v, f = state['attn_state'] + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + + if attention_mask is not None: + q, (k, v, f), indices_q, cu_seqlens, max_seq_lens = unpad_input(q, (k, v, f), attention_mask, q_len, keepdim=True) + _, cu_seqlens_k = cu_seqlens + cu_seqlens = cu_seqlens_k + max_seqlen_q, max_seqlen_k = max_seq_lens + if max_seqlen_q != max_seqlen_k: + assert max_seqlen_q == 1, "only support q_len == 1 for decoding" + o = attn_decoding_one_step(q, k, v, f, cu_seqlens=cu_seqlens) + else: + o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens) + else: + o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices_q, batch_size, q_len) + o = rearrange(o, '... h d -> ... (h d)') + if self.use_output_gate: + o = self.g_proj(hidden_states).sigmoid() * o + o = self.o_proj(o) + return o, None, past_key_values diff --git a/fla3/layers/gated_deltanet.py b/fla3/layers/gated_deltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..e9003fbaf157295fcf53fee9a8d32eccda973976 --- /dev/null +++ b/fla3/layers/gated_deltanet.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from ..layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from ..ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from ..models.utils import Cache + + +@torch.compile +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +@torch.compile +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +class GatedDeltaNet(nn.Module): + """ + The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa + + Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. + + Parameter allocation when use_gate=True: + - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each + - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each + - Others are negligibly small.
+ - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size + NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. + + Parameter allocation when use_gate=False: + - 1 * hidden_size * hidden_size for the q_proj and k_proj each + - 2 * hidden_size * hidden_size for the v_proj and o_proj each + - Others are negligibly small. + - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size + + Args: + hidden_size (int, Optional): + The hidden size of the input. Default: 2048. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + head_dim (int, Optional): + The dimension of each head. Default: 256. + num_heads (int, Optional): + The number of heads. Default: 6. + mode (str, Optional): + Which Gated DeltaNet kernel to use. + Currently available: `chunk` and `fused_recurrent`. + Default: `chunk`. + use_gate (bool, Optional): + Whether to use output gate. Default: `True`. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the normalization layer. Default: 1e-5. + """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + mode: str = 'chunk', + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + norm_eps: float = 1e-5, + **kwargs + ) -> GatedDeltaNet: + super().__init__() + + self.mode = mode + + self.hidden_size = hidden_size + self.expand_v = expand_v + + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.head_dim = head_dim + self.num_heads = num_heads + + self.key_dim = int(self.num_heads * self.head_dim) + self.value_dim = int(self.key_dim * self.expand_v) + self.head_k_dim = head_dim + self.head_v_dim = int(head_dim * self.expand_v) + self.layer_idx = layer_idx + + # Consistency check: Ensure expand_v produces integer values + if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5): + raise ValueError( + f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. " + f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear." + ) + if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5): + raise ValueError( + f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. " + f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated." + ) + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
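+ # Sanity-check sketch of the docstring's parameter-allocation note (illustrative
+ # only; assumes the defaults hidden_size=2048, head_dim=256, num_heads=6,
+ # expand_v=2, use_gate=True):
+ #   key_dim   = 6 * 256  = 1536 = 0.75 * hidden_size  -> q_proj, k_proj
+ #   value_dim = 1536 * 2 = 3072 = 1.50 * hidden_size  -> v_proj, g_proj, o_proj
+ #   total ~= (2 * 0.75 + 3 * 1.5) * hidden_size**2 = 6 * hidden_size**2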
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." + ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + # change to inference mode. + mode = 'fused_recurrent' if q_len <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ...
-> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = F.silu(self.q_proj(hidden_states)) + k = F.silu(self.k_proj(hidden_states)) + v = F.silu(self.v_proj(hidden_states)) + + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + beta = self.b_proj(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + + if self.use_gate: + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/fla3/layers/gated_deltaproduct.py b/fla3/layers/gated_deltaproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..a07af943018ea473b4e1482f6b073c15af9460d4 --- /dev/null +++ b/fla3/layers/gated_deltaproduct.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.ops.delta_rule import chunk_delta_rule +from fla.ops.gated_delta_rule import chunk_gated_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +def elu_p1(x): + return (F.elu(x, 1.0, False) + 1.0).to(x) + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +def interleave_multiple_sequences(*sequences): + """ + Interleave multiple sequences together. 
+ For example, with sequences [A1, A2], [B1, B2], [C1, C2], + returns [A1, B1, C1, A2, B2, C2] + """ + if isinstance(sequences[0], (list, tuple)): + sequences = sequences[0] + + if len(sequences) == 1: + return sequences[0] + + # All sequences should have the same shape + assert all(s.shape == sequences[0].shape for s in sequences) + + # Get the original shape + batch_size, seq_len, *rest = sequences[0].shape + + # Stack sequences along a new dimension + stacked = torch.stack(sequences, dim=2) + + # Reshape to interleave + reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest) + + return reshaped + + +class GatedDeltaProduct(nn.Module): + """ + Generalized version of GatedDoubleDeltaNet that supports arbitrary number of householder transformations. + """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + num_householder: int = 2, # New parameter for number of householder transformations + mode: str = "chunk", + use_gate: bool = True, + use_forget_gate: bool = True, # when true Gated DeltaProduct, when false DeltaProduct + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int | None = None, + norm_eps: float = 1e-5, + allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1] + **kwargs, + ) -> None: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.head_dim = head_dim + self.num_heads = num_heads + self.num_householder = num_householder + self.allow_neg_eigval = allow_neg_eigval + self.use_forget_gate = use_forget_gate + self.key_dim = self.num_heads * self.head_dim + self.value_dim = int(self.key_dim * self.expand_v) + self.head_qk_dim = head_dim + self.head_v_dim = int(head_dim * self.expand_v) + self.layer_idx = layer_idx + self.silu = nn.SiLU() + assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`." 
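+        # Usage sketch for interleave_multiple_sequences above (illustrative
+        # values, not original code): with two (batch=1, seq=2, dim=1) inputs
+        #   a = torch.tensor([[[1.], [2.]]])    # [A1, A2]
+        #   b = torch.tensor([[[10.], [20.]]])  # [B1, B2]
+        # interleave_multiple_sequences(a, b) stacks them along a new per-step
+        # axis and reshapes to (1, 4, 1): [[A1], [B1], [A2], [B2]].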
+ # Create multiple projection layers for each householder transformation + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + + self.k_projs = nn.ModuleList( + [ + nn.Linear(hidden_size, self.key_dim, bias=False) + for _ in range(num_householder) + ] + ) + self.v_projs = nn.ModuleList( + [ + nn.Linear(hidden_size, self.value_dim, bias=False) + for _ in range(num_householder) + ] + ) + self.b_projs = nn.ModuleList( + [ + nn.Linear(hidden_size, self.num_heads, bias=False) + for _ in range(num_householder) + ] + ) + if use_short_conv: + self.q_conv1ds = nn.ModuleList( + [ + ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation="silu", + ) + for _ in range(num_householder) + ] + ) + self.k_conv1ds = nn.ModuleList( + [ + ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation="silu", + ) + for _ in range(num_householder) + ] + ) + self.v_conv1ds = nn.ModuleList( + [ + ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation="silu", + ) + for _ in range(num_householder) + ] + ) + + if self.use_forget_gate: + self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + A_log = torch.log(A) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # Initialize dt parameters + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + self.dt_bias._no_weight_decay = True + + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.k_id = torch.nn.Identity() + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding)." + ) + + mode = ( + "chunk" # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + ) + if self.training: + assert mode == "chunk", "Only chunk mode is supported in training." 
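+        # Note on the interleaving scheme below (a reading of this forward pass,
+        # not original comments): with num_householder = n, the n per-step keys,
+        # values and betas are interleaved so each token occupies n consecutive
+        # slots; the query is zero everywhere except the last slot and the decay
+        # g is zero everywhere except the first, so each token applies n rank-1
+        # delta updates but decays once and is read out once.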
+ + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + # Process each householder transformation + ks, vs, betas = [], [], [] + conv_states = [] + + for i in range(self.num_householder): + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][ + i + ] + conv_mask = ( + attention_mask[:, -hidden_states.shape[1]:] + if attention_mask is not None + else None + ) + + k, conv_state_k = self.k_conv1ds[i]( + x=self.k_projs[i](hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + ) + v, conv_state_v = self.v_conv1ds[i]( + x=self.v_projs[i](hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + ) + conv_states.append((conv_state_q, conv_state_k, conv_state_v)) + else: + k = self.silu(self.k_projs[i](hidden_states)) + v = self.silu(self.v_projs[i](hidden_states)) + + ks.append(k) + vs.append(v) + + beta = self.b_projs[i]( + hidden_states + ).sigmoid() # bs, sequence_length, num_heads + if attention_mask is not None: + beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None]) + if self.allow_neg_eigval: + beta = beta * 2 + betas.append(beta) + + if self.use_short_conv: + q, conv_state_q = self.q_conv1ds[0]( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + ) + else: + q = self.silu(self.q_proj(hidden_states)) + q = interleave_multiple_sequences( + [torch.zeros_like(q)] * (self.num_householder - 1) + [q] + ) + # Interleave all sequences + k = interleave_multiple_sequences(ks) + v = interleave_multiple_sequences(vs) + beta = interleave_multiple_sequences(betas) + + q, k, v = ( + rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v) + ) + + recurrent_state = ( + last_state["recurrent_state"] if last_state is not None else None + ) + offsets = kwargs.get("offsets") + + if mode == "chunk": + if self.use_forget_gate: + g = -self.A_log.float().exp() * F.softplus( + self.a_proj(hidden_states).float() + self.dt_bias + ) + if attention_mask is not None: + g = g.mul(attention_mask[:, -g.shape[-2]:, None]) + + # Interleave g with zeros for non-first transformations + g = interleave_multiple_sequences( + [g] + [torch.zeros_like(g)] * (self.num_householder - 1) + ) + + o, recurrent_state = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=offsets, + use_qk_l2norm_in_kernel=True + ) + else: + o, recurrent_state = chunk_delta_rule( + q=q, + k=k, + v=v, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=offsets, + use_qk_l2norm_in_kernel=True + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + # Take every nth element for n householder transformations + o = o[:, self.num_householder - 1:: self.num_householder, :] + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=conv_states if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[2], + ) + + if self.use_gate: + g = rearrange( + self.g_proj(hidden_states), + "... (h d) -> ... 
h d", + h=self.num_heads, + ) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + + return o, None, past_key_values diff --git a/fla3/layers/gla.py b/fla3/layers/gla.py new file mode 100644 index 0000000000000000000000000000000000000000..31dfc9b5ab51de5a6f5532f697db59dfb39606b0 --- /dev/null +++ b/fla3/layers/gla.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.modules.activations import ACT2FN +from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class GatedLinearAttention(nn.Module): + r""" + The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa + + Args: + mode (str, Optional): + Which GLA kernel to use. + Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 0.5. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + num_kv_heads (int, Optional): + The number of key/value heads, used for MQA. Default: None. + feature_map (str, Optional): + Feature map function applied to queries/keys. Default: None. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `False`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + use_output_gate (bool, Optional): + Whether to use output gate. Default: `True`. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + gate_logit_normalizer (int, Optional): + The normalizer for the gate logits, appied after `logsigmoid`. Default: 16. + gate_low_rank_dim (int, Optional): + The low rank dim for the gate projection. Default: 16. + clamp_min (float, Optional): + The minimum value for the gate logits. Default: None. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. 
+ """ + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + use_output_gate: bool = True, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 16, + gate_low_rank_dim: int = 16, + clamp_min: Optional[float] = None, + fuse_norm: bool = True, + layer_idx: int = None, + ) -> GatedLinearAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.use_output_gate = use_output_gate + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.clamp_min = clamp_min + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + if self.use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True)) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm and use_output_gate: + self.g_norm_swish_gate = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.gate_fn = ACT2FN[gate_fn] + + self.gate_logit_normalizer = gate_logit_normalizer + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected 
attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + gk = self.gk_proj(hidden_states) + + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim) + if self.num_kv_groups > 1: + k, gk = (repeat(x, '... (h d) -> ... (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk)) + v = repeat(v, '... (h d) -> ... (h g) d', g=self.num_kv_groups, d=self.head_v_dim) + else: + k, gk = (rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim) for x in (k, gk)) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + gk = F.logsigmoid(gk) / self.gate_logit_normalizer + + if self.clamp_min is not None: + gk = torch.clamp_min(gk, self.clamp_min) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla( + q=q, + k=k, + v=v, + gk=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_gla( + q=q, + k=k, + v=v, + g=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gla( + q=q, + k=k, + v=v, + g=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + + if self.use_output_gate: + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, '... (h d) -> ... h d', d=self.head_v_dim) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, '... h d -> ... (h d)') + else: + o = rearrange(self.g_norm(o), '... h d -> ... (h d)') + o = o * self.gate_fn(g) + else: + o = rearrange(self.g_norm(o), '... h d -> ... 
(h d)') + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/fla3/layers/gsa.py b/fla3/layers/gsa.py new file mode 100644 index 0000000000000000000000000000000000000000..ae71835711192b22fab7b414643ea65c5b243c5f --- /dev/null +++ b/fla3/layers/gsa.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.modules import RMSNorm, ShortConvolution +from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap +from fla.modules.layernorm import rms_norm_linear +from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class GatedSlotAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1., + expand_v: float = 1., + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + num_slots: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 8, + feature_map: str = 'swish', + use_output_gate: bool = False, + use_norm: bool = True, + layer_idx: Optional[int] = None, + scale: Optional[float] = 1., + **kwargs + ) -> GatedSlotAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.head_k_dim = self.key_dim // self.num_heads + self.head_v_dim = self.value_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.gate_logit_normalizer = gate_logit_normalizer + + self.use_output_gate = use_output_gate + self.use_norm = use_norm + self.scale = scale + + if num_slots is None: + num_slots = self.head_k_dim + self.num_slots = num_slots + + self.layer_idx = layer_idx + + if layer_idx is None: + warnings.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class."
+ ) + + self.register_module('feature_map', None) + if feature_map == 'swish': + self.feature_map = SwishFeatureMap() + elif feature_map == 'relu': + self.feature_map = ReLUFeatureMap() + elif feature_map == 't2r': + self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim) + else: + raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.") + + self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False) + self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + f = self.f_proj(hidden_states) + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + f = rearrange(f, '... (h m) -> ... 
h m', m=self.num_slots) + + if self.feature_map is not None: + q, k = map(lambda x: self.feature_map(x), (q, k)) + v = F.silu(v) + + f = F.logsigmoid(f) / self.gate_logit_normalizer + s = (1 - f.exp()).to(f.dtype) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gsa( + q=q, + k=k, + v=v, + s=s, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + scale=self.scale, + cu_seqlens=cu_seqlens, + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gsa( + q=q, + k=k, + v=v, + s=s, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + scale=self.scale, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + + o = rearrange(o, '... h d -> ... (h d)') + o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/fla3/layers/hgrn.py b/fla3/layers/hgrn.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfa0f946f9b47ec2315605554e2362acef6cd56 --- /dev/null +++ b/fla3/layers/hgrn.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823] + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fla.modules import FusedRMSNormGated, ShortConvolution +from fla.modules.activations import swiglu +from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class HGRNAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> HGRNAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_ratio = expand_ratio + self.input_dim = int(hidden_size * expand_ratio) + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
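The GSA forward above parameterizes its slot gates so that the forget and write strengths are exact complements: with f = logsigmoid(x) / gate_logit_normalizer, exp(f) = sigmoid(x)^(1/normalizer) is a per-slot forget factor in (0, 1), and s = 1 - exp(f) is the matching write strength passed to the kernel alongside g = f. A minimal numerical check (a sketch, not part of the diff; shapes and the normalizer value of 8 are just the defaults above):

```python
# Minimal sketch (not part of the diff): GSA's slot gate pair (s, f).
import torch
import torch.nn.functional as F

gate_logit_normalizer = 8                 # default in GatedSlotAttention above
x = torch.randn(2, 16, 4, 64)             # assumed [batch, seq, heads, slots]
f = F.logsigmoid(x) / gate_logit_normalizer
s = 1 - f.exp()                           # write strength, complement of exp(f)

assert torch.allclose(s + f.exp(), torch.ones_like(s))
# dividing the logits by a larger normalizer pushes exp(f) toward 1,
# i.e. slots forget more slowly
print(f.exp().mean())
```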
+ + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + + self.g_norm = FusedRMSNormGated( + hidden_size=self.input_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_i, conv_state_f = None, None + if last_state is not None: + conv_state_i, conv_state_f = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + i, conv_state_i = self.i_conv1d( + x=self.i_proj(hidden_states), + mask=conv_mask, + cache=conv_state_i, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + f, conv_state_f = self.f_conv1d( + x=self.f_proj(hidden_states), + mask=conv_mask, + cache=conv_state_f, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + i = self.i_proj(hidden_states) + f = self.f_proj(hidden_states) + + # the lower bound for the first layer is zero + if lower_bound is None or self.layer_idx == 0: + i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f) + else: + g = lower_bound + (1 - lower_bound) * f.sigmoid() + i, f = swiglu(i, 1 - g), g.log() + + # dealing with left-padding + if attention_mask is not None: + i = i.mul_(attention_mask[:, -i.shape[-2]:, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + if cu_seqlens is not None: + raise NotImplementedError("Chunk mode does not support variable-length sequences.") + o, recurrent_state = chunk_hgrn( + x=i, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_hgrn( + x=i, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None, + layer_idx=self.layer_idx, + 
offset=i.shape[1] + ) + + o = self.g_norm(o, self.g_proj(hidden_states)) + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.hidden_size + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla3/layers/hgrn2.py b/fla3/layers/hgrn2.py new file mode 100644 index 0000000000000000000000000000000000000000..45be0706b2f14e7645960f31adf3759afc3e48df --- /dev/null +++ b/fla3/layers/hgrn2.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904] + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.modules import RMSNorm, ShortConvolution +from fla.modules.activations import swish +from fla.modules.layernorm import rms_norm_linear +from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class HGRN2Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> HGRN2Attention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + + if expand_ratio is None and num_heads is not None: + expand_ratio = hidden_size // num_heads + elif expand_ratio is not None and num_heads is None: + num_heads = hidden_size // expand_ratio + elif expand_ratio is None and num_heads is None: + raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") + self.num_heads = num_heads + self.expand_ratio = expand_ratio + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.forget_dim = int(self.num_heads * self.expand_ratio) + self.input_dim = hidden_size + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
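The head/dim resolution at the top of HGRN2Attention.__init__ is easiest to read with concrete numbers. A small worked example (a sketch with assumed values, mirroring the logic above):

```python
# Sketch with assumed values: how HGRN2 resolves num_heads and expand_ratio.
hidden_size, expand_ratio, num_heads = 1024, 128, None

if expand_ratio is None and num_heads is not None:
    expand_ratio = hidden_size // num_heads
elif expand_ratio is not None and num_heads is None:
    num_heads = hidden_size // expand_ratio   # 1024 // 128 = 8 heads

forget_dim = num_heads * expand_ratio         # 8 * 128 = 1024
head_f_dim = expand_ratio                     # 128, per-head forget-gate dim
head_i_dim = hidden_size // num_heads         # 128, per-head input dim
print(num_heads, forget_dim, head_f_dim, head_i_dim)  # 8 1024 128 128
```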
+ assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}" + assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}" + + self.head_f_dim = self.expand_ratio + self.head_i_dim = self.hidden_size // num_heads + + self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None) + self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None) + self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + + self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps) + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_f, conv_state_i = None, None, None + if last_state is not None: + conv_state_q, conv_state_f, conv_state_i = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + f, conv_state_f = self.f_conv1d( + x=self.f_proj(hidden_states), + cache=conv_state_f, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + i, conv_state_i = self.i_conv1d( + x=self.i_proj(hidden_states), + cache=conv_state_i, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + f = self.f_proj(hidden_states) + i = self.i_proj(hidden_states) + + q = swish(q) + + # improve precision + f = f.float() + + # the lower bound for the first layer is zero + if lower_bound is None or self.layer_idx == 0: + k, g = 1 - f.sigmoid(), F.logsigmoid(f) + else: + g = lower_bound + (1 - lower_bound) * f.sigmoid() + k, g = 1 - g, g.log() + + q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g)) + i = rearrange(i, '... (h d) -> ... 
h d', d=self.head_i_dim) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla( + q=q, + k=k, + v=i, + gk=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_gla( + q=q, + k=k, + v=i, + g=g, + initial_state=recurrent_state, + output_final_state=use_cache, + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gla( + q=q, + k=k, + v=i, + g=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + + o = rearrange(o, '... h d -> ... (h d)') + o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/fla3/layers/lightnet.py b/fla3/layers/lightnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2e89fa5249b62f83fe2076cdba13f8018a0016e5 --- /dev/null +++ b/fla3/layers/lightnet.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022) + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import FusedRMSNormGated, ShortConvolution +from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear +from fla.ops.gla import chunk_gla, fused_recurrent_gla + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class LightNetAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + gate_low_rank_dim: int = 128, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> LightNetAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + + if expand_ratio is None and num_heads is not None: + expand_ratio = hidden_size // num_heads + elif expand_ratio is not None and num_heads is None: + num_heads = hidden_size // expand_ratio + elif expand_ratio is None and num_heads is None: + raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") + self.num_heads = num_heads + self.expand_ratio = expand_ratio + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.key_dim = int(self.num_heads * self.expand_ratio) + self.value_dim = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_chunk'], f"Not supported mode `{mode}`."
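As a side note on the lower-bound gating in the HGRN2 forward above: g = lower_bound + (1 - lower_bound) * sigmoid(f) keeps the forget gate strictly above the layer's lower bound, and the input gate k = 1 - g stays its exact complement even after g is taken to log space. A minimal check (a sketch, not part of the diff; shapes and the bound value are assumed):

```python
# Sketch (assumed shapes/bound): HGRN2's lower-bounded forget/input gate pair.
import torch

f = torch.randn(2, 16, 1024)              # raw gate logits
lower_bound = torch.full((1024,), 0.1)    # assumed per-channel lower bound

g = lower_bound + (1 - lower_bound) * f.sigmoid()  # forget gate in (0.1, 1)
k, g = 1 - g, g.log()                              # input gate, log-forget gate

assert torch.allclose(k + g.exp(), torch.ones_like(k), atol=1e-6)
assert (g.exp() >= lower_bound - 1e-6).all()       # bound survives the log/exp
```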
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_f_dim = self.expand_ratio + self.head_i_dim = self.hidden_size // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None) + self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None) + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None) + + self.g_proj = nn.Sequential( + nn.Linear(hidden_size, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, hidden_size, bias=False) + ) + self.g_norm = FusedRMSNormGated( + hidden_size=hidden_size, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + + q = F.silu(q) + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_i_dim) + # TODO: these two steps take a huge amount of time and should be optimized + z = k.float().logcumsumexp(1) + + if cu_seqlens is not None: + raise NotImplementedError("LightNet does not support variable-length sequences for now.") + k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla( + q=q, + k=k, + v=v, + gk=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gla( + q=q, + k=k, + v=v, + g=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + o = rms_norm_swish_gate_linear( + rearrange(o, 'b t h d -> b t (h d)'), + self.g_proj(hidden_states), + self.g_norm.weight, + self.g_norm.bias, + self.o_proj.weight, + self.o_proj.bias + ) + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_i_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla3/layers/linear_attn.py b/fla3/layers/linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a6479ce7bbd57ee6d2df5aa98328654e1ec958 --- /dev/null +++ b/fla3/layers/linear_attn.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from fla.modules import RMSNorm +from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn + + +class LinearAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: str = 'elementwise_product', + tie_feature_map_qk: bool = False, + output_norm: str = 'rmsnorm', + norm_q: bool = False, + norm_k: bool = False, + do_feature_map_norm: bool = False, + elementwise_affine: bool = True, + norm_eps: float = 1e-5, + **kwargs + ): + super().__init__() + + self.hidden_size = hidden_size + self.mode = mode + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + + assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
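The two TODO-flagged lines in the LightNet forward above (z = k.float().logcumsumexp(1) and the k, g update) implement a running softmax over keys: exp(k_t - z_t) is the weight of the newest key under the normalizer z_t, and exp(g_t) = exp(z_{t-1} - z_t) is exactly the decay that rescales everything already accumulated. A small check of that invariant (a sketch, not part of the diff; toy shapes assumed):

```python
# Sketch (toy shapes): LightNet's cumulative-softmax key normalization.
import torch

k = torch.randn(1, 8, 1)                      # assumed [batch, seq_len, dim]
z = k.float().logcumsumexp(1)                 # running log-normalizer
k_hat = torch.exp(k - z)                      # normalized key at its own step
g = torch.cat((z[:, :1], z[:, :-1]), 1) - z   # log decay, g[:, 0] == 0

# decaying the old mass and adding the new key keeps total weight at 1
state = torch.zeros(1, 1)
for t in range(k.shape[1]):
    state = g[:, t].exp() * state + k_hat[:, t]
    assert torch.allclose(state, torch.ones_like(state), atol=1e-5)
```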
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.do_feature_map_norm = do_feature_map_norm + + if feature_map == 'hedgehog': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim) + else: + self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 't2r': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim) + else: + self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 'elementwise_product': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim) + else: + self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 'dpfp': + self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 'elu': + def elu(x): + return F.elu(x) + 1 + self.feature_map_q = elu + self.feature_map_k = elu + + elif feature_map == 'relu': + self.feature_map_q = nn.ReLU() + self.feature_map_k = nn.ReLU() + + elif feature_map == 'identity': + self.feature_map_q = nn.Identity() + self.feature_map_k = nn.Identity() + else: + raise NotImplementedError(f"Not supported feature map `{feature_map}`.") + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + + if output_norm == 'rmsnorm': + self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps) + elif output_norm == 'identity': + self.norm = nn.Identity() + else: + raise NotImplementedError(f"Not supported output norm `{output_norm}`.") + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.norm_q = norm_q + self.norm_k = norm_k + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs + ) -> torch.Tensor: + mode = self.mode + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim) + if self.num_kv_groups > 1: + k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups) + v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups) + else: + k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_v_dim) + + q = self.feature_map_q(q) + k = self.feature_map_k(k) + + if self.norm_q: + q = q / (q.sum(-1, True) + 1e-4) + if self.norm_k: + k = k / (k.sum(-1, True) + 1e-4) + + if mode == 'chunk': + o, final_state = chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=self.do_feature_map_norm, + ) + elif mode == 'fused_chunk': + o, final_state = fused_chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=self.do_feature_map_norm, + ) + elif mode == 'fused_recurrent': + o, final_state = fused_recurrent_linear_attn( + q=q, + k=k, + v=v, + normalize=self.do_feature_map_norm, + ) + else: + raise NotImplementedError + o = self.norm(o) + o = rearrange(o, '... h d -> ... (h d)') + o = self.o_proj(o) + return o diff --git a/fla3/layers/mamba.py b/fla3/layers/mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..55879bbcb0478535cbdb7f8720b13336d5c3b09c --- /dev/null +++ b/fla3/layers/mamba.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Dict, Optional + +import torch +import torch.nn as nn +from transformers.utils import logging + +from fla.modules.activations import ACT2FN + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + try: + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + except ImportError: + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + + try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + except ImportError: + causal_conv1d_update, causal_conv1d_fn = None, None + is_fast_path_available = all(( + selective_state_update, + selective_scan_fn, + causal_conv1d_fn, + causal_conv1d_update, + mamba_inner_fn + )) +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.mamba.modeling_mamba import MambaCache + +logger = logging.get_logger(__name__) + + +class Mamba(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__( + self, + hidden_size: int = 2048, + state_size: int = 16, + conv_kernel: int = 4, + use_conv_bias: bool = True, + intermediate_size: int = 2048, + time_step_rank: int = 256, + use_bias: bool = True, + hidden_act: str = "silu", + layer_idx: int = None, + ): + super().__init__() + + self.hidden_size = hidden_size + self.ssm_state_size = state_size + self.conv_kernel_size = conv_kernel + self.use_conv_bias = use_conv_bias + self.intermediate_size = intermediate_size + self.time_step_rank = time_step_rank + self.use_bias = use_bias + + self.conv1d = nn.Conv1d( + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, + bias=use_conv_bias, + kernel_size=conv_kernel, + groups=self.intermediate_size, + padding=conv_kernel - 1, + ) + + self.activation = hidden_act + self.act = ACT2FN[hidden_act] + + self.layer_idx = layer_idx + + # projection of the input hidden states + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) + # time step projection (discretization) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + + self.A_log = nn.Parameter(torch.log(A)) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of " + "`(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the naive implementation. " + "To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **kwargs: Unpack[Dict] + ): + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1, 2) + + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias if self.use_conv_bias else None, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias.float() if self.use_bias else None, + -torch.exp(self.A_log.float()), + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if cache_params is not None and cache_position[0] > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), + cache_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.update_conv_state(self.layer_idx, conv_states, cache_position) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if cache_params is not None and cache_position[0] > 0: + scan_outputs = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(self.layer_idx, ssm_state) + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states + + def slow_forward( + self, + input_states, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **kwargs: Unpack[Dict] + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # 1. Gated MLP's linear projection + # [batch, 2 * intermediate_size, seq_len] + projected_states = self.in_proj(input_states).transpose(1, 2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 2. 
Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + # use `cache_position.shape[0]` to check whether we are in prefill + # stage, it's equivalent to check `cache_position[0] == 0`, which + # breaks dynamo fullgraph constraints + if cache_position.shape[0] == self.conv_kernel_size: + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + + cache_params.update_conv_state(self.layer_idx, conv_state, cache_position) + # [batch, intermediate_size, seq_len] + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + else: + conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + # [batch, intermediate_size, 1] : decoding + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) + else: + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + # [batch, intermediate_size, seq_len] + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(1) + + # 3. State Space Model sequence transformation + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + # [batch, seq_len, intermediate_size] + discrete_time_step = self.dt_proj(time_step) + # [batch, intermediate_size, seq_len] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + # [intermediate_size, ssm_state_size] + A = -torch.exp(self.A_log.float()) + # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) + # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + scan_outputs = [] + for i in range(seq_len): + # [batch, intermediate_size, ssm_state] + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] + # [batch, intermediate_size, 1] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) + scan_outputs.append(scan_output[:, :, 0]) + # [batch, intermediate_size, seq_len] + scan_output = torch.stack(scan_outputs, dim=-1) + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = (scan_output * self.act(gate)) + + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + # 4. 
Final linear projection + # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **kwargs: Unpack[Dict] + ): + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, **kwargs) + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask, **kwargs) diff --git a/fla3/layers/mamba2.py b/fla3/layers/mamba2.py new file mode 100644 index 0000000000000000000000000000000000000000..9b46ecf92e89b818c751e31e2f3ad91b3423d168 --- /dev/null +++ b/fla3/layers/mamba2.py @@ -0,0 +1,586 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from transformers.utils import logging + +from fla.modules.activations import ACT2FN +from fla.modules.layernorm_gated import RMSNormGated + +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + try: + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined + except ImportError: + selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None + try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + except ImportError: + causal_conv1d_update, causal_conv1d_fn = None, None + is_fast_path_available = all(( + selective_state_update, + causal_conv1d_fn, + causal_conv1d_update + )) + +if TYPE_CHECKING: + from fla.models.mamba2.modeling_mamba2 import Mamba2Cache + +logger = logging.get_logger(__name__) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> + # [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +class Mamba2(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__( + self, + num_heads: int, + head_dim: int = 64, + hidden_size: int = 2048, + state_size: int = 128, + expand: int = 2, + n_groups: int = 1, + conv_kernel: int = 4, + use_conv_bias: bool = False, + hidden_act: str = "silu", + rms_norm: bool = True, + chunk_size: int = 256, + time_step_rank: float = 256, + time_step_limit: Tuple[float, float] = (0.0, float("inf")), + time_step_min: float = 0.001, + time_step_max: float = 0.1, + use_bias: bool = True, + norm_eps: float = 1e-5, + layer_idx: int = None, + ) -> Mamba2: + super().__init__() + + self.num_heads = num_heads + self.head_dim = head_dim + self.hidden_size = hidden_size + self.ssm_state_size = state_size + self.expand = expand + self.intermediate_size = int(expand * hidden_size) + self.n_groups = n_groups + + self.conv_kernel_size = conv_kernel + self.use_conv_bias = use_conv_bias + self.activation = hidden_act + self.act = ACT2FN[hidden_act] + + self.rms_norm = rms_norm + self.norm_eps = norm_eps + + self.chunk_size = chunk_size + + self.time_step_rank = int(time_step_rank) + self.time_step_limit = time_step_limit + self.time_step_min = time_step_min + self.time_step_max = time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=use_conv_bias, + kernel_size=conv_kernel, + groups=self.conv_dim, + padding=conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = 
self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=use_bias, + ) + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = RMSNormGated( + self.intermediate_size, eps=self.norm_eps, norm_before_gate=False + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias) + self.use_bias = use_bias + + self.layer_idx = layer_idx + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because one of " + "`(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. " + "Falling back to the naive implementation. " + "To install follow https://github.com/state-spaces/mamba/#installation and " + "https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [ + self.intermediate_size, + groups_time_state_size, + groups_time_state_size, + ], + dim=-1, + ) + + # 3. 
SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.eps, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. 
SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward( + self, + input_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - + 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states.device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
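+            # NOTE (informal sketch, not the kernel): in exact arithmetic, the single-token
+            # update computed by the lines below is, per head,
+            #   dt' = softplus(dt + dt_bias)
+            #   dA  = exp(dt' * A)                       # [bsz, num_heads, head_dim, state_size]
+            #   h_t = dA * h_{t-1} + (dt' * B) * x_t     # state update written back to the cache
+            #   y_t = h_t @ C + D * x_t                  # per-head readout
+            # A minimal fp32 reference for one head (hypothetical helper, for intuition only):
+            #   def ssd_step(h, x, dt, A, B, C, D):      # h: [d, n], x/dt: [d], A: [d, n], B/C: [n]
+            #       dA = torch.exp(dt[:, None] * A)
+            #       h = h * dA + (dt[:, None] * B[None, :]) * x[:, None]
+            #       return h, h @ C + D * x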
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + # Shape: [b*h, d, n] + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
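+            # NOTE (shape summary for this single-step branch, a sketch): hidden_states
+            # enters as [bsz, 1, hidden_size] and y leaves as [bsz, 1, intermediate_size],
+            # with the conv and SSM caches updated in place. A rough decoding loop
+            # (hypothetical names, for illustration only) would look like:
+            #   for t in range(num_new_tokens):
+            #       y = mixer.torch_forward(x[:, t:t+1], cache_params=cache,
+            #                               cache_position=torch.tensor([prompt_len + t]))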
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + # shape: (b, c, l, s, h, n) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) diff --git a/fla3/layers/multiscale_retention.py b/fla3/layers/multiscale_retention.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3973a3b8e6723af25faa79a666d0b21ab25fb3 --- /dev/null +++ b/fla3/layers/multiscale_retention.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from transformers.activations import ACT2FN + +from fla.layers.utils import get_unpad_data, index_first_axis, pad_input +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.modules.rotary import RotaryEmbedding +from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention +from fla.ops.utils.index import prepare_lens_from_mask + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class MultiScaleRetention(nn.Module): + r""" + The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa + + Args: + mode (str, Optional): + Which Retention kernel to use. + Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. 
Default: 2.0.
+        num_heads (int, Optional):
+            The number of heads. Default: 8.
+        num_kv_heads (int, Optional):
+            The number of key/value heads, used for MQA. Default: None.
+        feature_map (str, Optional):
+            Feature map function applied to queries/keys. Default: None.
+        use_short_conv (bool, Optional):
+            Whether to use short convolutions. Default: `False`.
+        conv_size (int, Optional):
+            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
+        conv_bias (bool, Optional):
+            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
+        use_output_gate (bool, Optional):
+            Whether to use output gate. Default: `True`.
+        gate_fn (str, Optional):
+            The activation function for the output gate. Default: `swish`.
+        elementwise_affine (bool, Optional):
+            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
+        norm_eps (float, Optional):
+            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
+        fuse_norm (bool, Optional):
+            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
+        layer_idx (int, Optional):
+            The index of the layer. Default: None.
+    """
+
+    def __init__(
+        self,
+        mode: str = 'chunk',
+        hidden_size: int = 1024,
+        expand_k: float = 1.0,
+        expand_v: float = 2.0,
+        num_heads: int = 8,
+        num_kv_heads: Optional[int] = None,
+        feature_map: Optional[str] = None,
+        use_short_conv: bool = False,
+        conv_size: int = 4,
+        conv_bias: bool = False,
+        use_output_gate: bool = True,
+        gate_fn: str = 'swish',
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-5,
+        fuse_norm: bool = True,
+        layer_idx: int = None,
+        **kwargs
+    ) -> MultiScaleRetention:
+        super().__init__()
+
+        self.mode = mode
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+        self.num_kv_groups = self.num_heads // self.num_kv_heads
+        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
+
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.conv_bias = conv_bias
+        self.use_output_gate = use_output_gate
+
+        self.key_dim = int(hidden_size * expand_k)
+        self.value_dim = int(hidden_size * expand_v)
+        self.key_dim_per_group = self.key_dim // self.num_kv_groups
+        self.value_dim_per_group = self.value_dim // self.num_kv_groups
+        self.layer_idx = layer_idx
+
+        assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not supported mode `{mode}`."
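+        # NOTE on the dimension bookkeeping above (a sketch; example values assumed):
+        #   key_dim   = int(hidden_size * expand_k)    # split across num_heads
+        #   value_dim = int(hidden_size * expand_v)
+        #   *_per_group = *_dim // num_kv_groups       # k/v projections emit per-group dims
+        #                                              # and are repeated to num_heads for MQA/GQA
+        # e.g. hidden_size=1024, expand_k=1.0, expand_v=2.0, num_heads=8, num_kv_heads=4 gives
+        #   head_k_dim=128, head_v_dim=256, key_dim_per_group=512, value_dim_per_group=1024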
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + if self.use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm and use_output_gate: + self.g_norm_swish_gate = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.gate_fn = ACT2FN[gate_fn] + + # TODO: fix this issue + # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180 + # Ideally, we would want to support arbitrary d_head_qk + assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256" + self.rotary = RotaryEmbedding(dim=self.head_k_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.shape + mode = 'fused_recurrent' if q_len <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if attention_mask is not None: + indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:]) + hidden_states = index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices).unsqueeze(0) + + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim) + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + + seqlen_offset, max_seqlen = 0, q_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q_len + seqlen_offset + + if attention_mask is not None and seqlen_offset > 0: + # to deliminate the offsets of padding tokens + seqlen_offset = prepare_lens_from_mask(attention_mask) - q_len + max_seqlen = q_len + seqlen_offset.max().item() + + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if self.num_kv_groups > 1: + k = repeat(k, '... h d -> ... (h g) d', g=self.num_kv_groups) + v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups) + else: + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_retention( + q=q, + k=k, + v=v, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_retention( + q=q, + k=k, + v=v, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'parallel': + o, recurrent_state = parallel_retention( + q=q, + k=k, + v=v, + cu_seqlens=cu_seqlens, + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_retention( + q=q, + k=k, + v=v, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q_len + ) + + if self.use_output_gate: + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, '... (h d) -> ... h d', d=self.head_v_dim) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, '... h d -> ... (h d)') + else: + o = rearrange(self.g_norm(o), '... h d -> ... (h d)') + o = o * self.gate_fn(g) + else: + o = rearrange(self.g_norm(o), '... h d -> ... 
(h d)') + o = self.o_proj(o) + if attention_mask is not None: + o = pad_input(o.squeeze(0), indices, batch_size, q_len) + + return o, None, past_key_values diff --git a/fla3/layers/nsa.py b/fla3/layers/nsa.py new file mode 100644 index 0000000000000000000000000000000000000000..ae38498465a1a999fdcf5973066010bd804ebee2 --- /dev/null +++ b/fla3/layers/nsa.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple, Union + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.utils import logging + +from fla.modules import RotaryEmbedding +from fla.ops.nsa.parallel import parallel_nsa +from fla.ops.utils.index import prepare_lens_from_mask + +if TYPE_CHECKING: + from fla.models.utils import Cache + +logger = logging.get_logger(__name__) + + +class NativeSparseAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 64, + num_kv_heads: Optional[int] = 4, + head_dim: int = 64, + qkv_bias: bool = False, + block_size: Optional[int] = 64, + block_counts: Optional[Union[torch.LongTensor, int]] = 16, + window_size: Optional[int] = 512, + rope_theta: Optional[float] = 10000., + max_position_embeddings: Optional[int] = None, + layer_idx: int = None + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.head_dim = head_dim + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + + self.block_size = block_size + self.block_counts = block_counts + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, seq_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... 
h d', d=3) + g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1) + + cu_seqlens = kwargs.get('cu_seqlens', None) + + seqlen_offset, max_seqlen = 0, seq_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + prepare_lens_from_mask(attention_mask) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if past_key_values is not None: + cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0 + k_cached, v_cached = past_key_values.update( + attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)), + layer_idx=self.layer_idx, + offset=seq_len, + cache_kwargs=dict(window_size=self.window_size) + )['attn_state'] + if cache_has_content: + k, v = k_cached, v_cached + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + + o = parallel_nsa( + q=q, + k=k, + v=v, + g_cmp=g_cmp, + g_slc=g_slc, + g_swa=g_swa, + block_size=self.block_size, + block_counts=self.block_counts, + window_size=self.window_size, + cu_seqlens=cu_seqlens, + ) + o = o.reshape(batch_size, seq_len, -1) + o = self.o_proj(o) + + if not output_attentions: + attentions = None + + return o, attentions, past_key_values diff --git a/fla3/layers/path_attn.py b/fla3/layers/path_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..5a31a0843f525e2ec1279de6b3d44416434433ef --- /dev/null +++ b/fla3/layers/path_attn.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.layers.utils import pad_input, unpad_input +from fla.modules import RMSNorm, ShortConvolution +from fla.modules.l2norm import l2_norm +from fla.ops.attn.decoding import attn_decoding_one_step +from fla.ops.path_attn.parallel import parallel_path_attention + +if TYPE_CHECKING: + from fla.models.utils import Cache + +logger = logging.get_logger(__name__) + + +class PaTHAttention(nn.Module): + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + use_forget_gate: bool = False, + use_qk_norm: bool = False, + use_w_shortconv: bool = True, + layer_idx: int = None, + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + + # We use low-rank parameterization for the w_proj to reduce parameters in MHA settings. 
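+        # NOTE (hedged back-of-envelope motivating the rank-32 factorization below,
+        # assuming hidden_size = kv_dim = 2048):
+        #   full linear  : hidden_size * kv_dim             ~ 4.2M params
+        #   rank-32 pair : hidden_size * 32 + 32 * kv_dim   ~ 0.13M params (~32x fewer)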
+ if self.num_heads == self.num_kv_heads: + self.w_proj = nn.Sequential( + nn.Linear(self.hidden_size, 32, bias=False), + nn.Linear(32, self.kv_dim, bias=False) + ) + # In MQA/GQA settings, key/value heads are shared, so we use a standard linear projection + # which doesn't introduce too many parameters + else: + self.w_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) + + if use_qk_norm: + self.maybe_q_norm = RMSNorm(self.hidden_size) + self.maybe_k_norm = RMSNorm(self.kv_dim) + else: + self.maybe_q_norm = nn.Identity() + self.maybe_k_norm = nn.Identity() + + if use_w_shortconv: + self.w_conv1d = ShortConvolution(self.kv_dim, 3) + self.use_w_shortconv = use_w_shortconv + self.bt_proj = nn.Linear(self.hidden_size, self.num_kv_heads, bias=True) + self.use_forget_gate = use_forget_gate + if use_forget_gate: + self.g_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if use_cache: + assert past_key_values is not None, "past_key_values must be provided when use_cache is True" + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + batch_size, q_len, _ = hidden_states.size() + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + w = self.w_proj(hidden_states) + beta = self.bt_proj(hidden_states).sigmoid() * 2 # allowing negative eigenvalues + g = F.logsigmoid(self.g_proj(hidden_states).float()) if self.use_forget_gate else None + q, k = self.maybe_q_norm(q), self.maybe_k_norm(k) + cu_seqlens = kwargs.get('cu_seqlens', None) + assert not (cu_seqlens is not None and attention_mask is not None), ( + "cu_seqlens should not be provided when attention_mask is not None" + ) + # Training + if attention_mask is None: + assert use_cache is False, "use_cache should be False in training" + if self.use_w_shortconv: + w, _ = self.w_conv1d(w, cache=None, output_final_state=False, cu_seqlens=cu_seqlens) + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + w = rearrange(w, '... (h d) -> ... h d', d=self.head_dim) + w = l2_norm(w) + o, _ = parallel_path_attention(q=q, k=k, v=v, w=w, beta=beta, g=g, cu_seqlens=cu_seqlens) + + # Prefilling or decoding + else: + assert self.training is False, "attention mask is not supported in training. Please use variable length input." + try: + last_state = past_key_values[self.layer_idx] + except KeyError: + last_state = None + # Decoding + if last_state is not None: + if g is not None: + past_k, past_v, past_g = last_state['attn_state'] + else: + past_k, past_v = last_state['attn_state'] + w_conv_state = last_state['conv_state'] + past_k = rearrange(past_k, '... (h d) -> ... h d', d=self.head_dim) + if self.use_w_shortconv: + w, w_conv_state = self.w_conv1d(w, cache=w_conv_state, output_final_state=use_cache, cu_seqlens=cu_seqlens) + w = rearrange(w, '... (h d) -> ... 
h d', d=self.head_dim) + w = l2_norm(w) + + def rank_one_update(k, w, beta): + original_dtype = k.dtype + k = k.float() + w = w.float() + beta = beta.float() + k = k - beta[..., None].float() * (k * w).sum(-1, keepdim=True) * w + return k.to(original_dtype) + + past_k = rank_one_update(past_k, w, beta) + past_k = rearrange(past_k, '... h d -> ... (h d)') + k = torch.cat([past_k, k], dim=1) + v = torch.cat([past_v, v], dim=1) + g = torch.cat([past_g, g], dim=1) if g is not None else None + past_key_values[self.layer_idx]['attn_state'] = (k, v, g) if g is not None else (k, v) + past_key_values.update( + conv_state=w_conv_state, + layer_idx=self.layer_idx, + offset=q_len + ) + if g is not None: + q, (k, v, g), indices_q, cu_seqlens, max_seq_lens = unpad_input( + q, (k, v, g), attention_mask, q_len, keepdim=True) + max_seqlen_q, max_seqlen_k = max_seq_lens + else: + q, (k, v), indices_q, cu_seqlens, max_seq_lens = unpad_input( + q, (k, v), attention_mask, q_len, keepdim=True) + max_seqlen_q, max_seqlen_k = max_seq_lens + _, cu_seqlens = cu_seqlens + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + assert max_seqlen_q == 1, "only support q_len == 1 for decoding" + o = attn_decoding_one_step(q, k, v, g, cu_seqlens=cu_seqlens, do_gate_scale=True) # reduced to fox's decoding + # Prefilling + else: + v_cache = v.clone() + g_cache = g.clone() if g is not None else None + if g is None: + q, (k, v, w, beta), indices_q, cu_seqlens, max_seq_lens = unpad_input( + q, (k, v, w, beta), attention_mask, q_len, keepdim=True) + else: + q, (k, v, w, beta, g), indices_q, cu_seqlens, max_seq_lens = unpad_input( + q, (k, v, w, beta, g), attention_mask, q_len, keepdim=True) + max_seqlen_q, max_seqlen_k = max_seq_lens + assert max_seqlen_q == max_seqlen_k, "max_seqlen_q should be equal to max_seqlen_k in prefilling" + _, cu_seqlens = cu_seqlens + if self.use_w_shortconv: + w, w_conv_state = self.w_conv1d(w, cache=None, output_final_state=use_cache, cu_seqlens=cu_seqlens) + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + w = rearrange(w, '... (h d) -> ... h d', d=self.head_dim) + w = l2_norm(w) + o, k_cache = parallel_path_attention(q=q, k=k, v=v, w=w, beta=beta, g=g, + cu_seqlens=cu_seqlens, use_cache=use_cache) + if use_cache: + k_cache = pad_input(k_cache.squeeze(0), indices_q, batch_size, q_len) + k_cache = rearrange(k_cache, '... h d -> ... (h d)') + past_key_values.update( + attn_state=(k_cache, v_cache, g_cache) if g_cache is not None else (k_cache, v_cache), + conv_state=w_conv_state, + layer_idx=self.layer_idx, + offset=q_len + ) + o = pad_input(o.squeeze(0), indices_q, batch_size, q_len) + o = rearrange(o, '... h d -> ... 
(h d)') + o = self.o_proj(o) + return o, None, past_key_values diff --git a/fla3/layers/rebased.py b/fla3/layers/rebased.py new file mode 100644 index 0000000000000000000000000000000000000000..ebf1c65790051dd8a234ca6ea10df330352f7733 --- /dev/null +++ b/fla3/layers/rebased.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +""" +https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules.feature_map import RebasedFeatureMap +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn +from fla.ops.rebased import parallel_rebased + + +class ReBasedLinearAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 16, + num_heads: int = 16, + use_gamma: Optional[bool] = True, + use_beta: Optional[bool] = True, + normalize: Optional[bool] = True, + causal: bool = True, + eps: float = 1e-5, + mode: str = "parallel", + layer_idx: Optional[int] = None, + **kwargs + ) -> ReBasedLinearAttention: + super().__init__() + self.hidden_size = hidden_size + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel", 'chunk'] + + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_key_value_heads + self.use_gamma = use_gamma + self.use_beta = use_beta + self.normalize = normalize + self.causal = causal + self.eps = eps + self.mode = mode + self.layer_idx = layer_idx + + self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize) + self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Identity() + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... 
h d", d=self.head_dim), [q, k, v]) + q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel')) + if mode == "fused_chunk": + o = fused_chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=True, + scale=1, + ) + elif mode == 'chunk': + o = chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=True, + scale=1, + ) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_rebased( + q=q, + k=k, + v=v, + eps=self.eps, + use_scale=True, + use_normalize=True, + ) + o = self.o_proj(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + def forward_reference( + self, + hidden_states: torch.Tensor, + filters: torch.Tensor = None, + *args, + **kwargs + ): + """ + x (torch.Tensor): tensor of shape (b, d, t) + y (torch.Tensor): tensor of shape (b, d, t) + """ + b, t, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = q.view(b, t, -1, self.feature_dim).transpose(1, 2) + k = k.view(b, t, -1, self.feature_dim).transpose(1, 2) + v = v.view(b, t, -1, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h t d -> b t (h d)') + y = self.o_proj(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) diff --git a/fla3/layers/rwkv6.py b/fla3/layers/rwkv6.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3b7aa7f41095fe7307fc6e730547fd009ee61c --- /dev/null +++ b/fla3/layers/rwkv6.py @@ -0,0 +1,353 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892] + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules import GroupNorm +from fla.modules.activations import ACT2FN +from fla.modules.token_shift import token_shift +from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class RWKV6Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + gate_fn: str = 'swish', + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + fuse_norm: bool = True, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None, + **kwargs + ) -> RWKV6Attention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.x_proj = nn.Sequential( + LerpLinear(hidden_size, proj_low_rank_dim * 5), + nn.Tanh(), + nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False) + ) + self.x_bias = nn.Parameter(torch.zeros(5, hidden_size)) + + self.r_proj = DDLerpLinear(hidden_size, self.key_dim) + self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim) + self.k_proj = DDLerpLinear(hidden_size, self.key_dim) + self.v_proj = DDLerpLinear(hidden_size, self.value_dim) + self.g_proj = DDLerpLinear(hidden_size, self.value_dim) + self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim)) + + # TODO: fuse GroupNorm and output gate + self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.gate_fn = ACT2FN[gate_fn] + + try: + from transformers.modeling_utils import _init_weights + except ImportError: + _init_weights = True + if _init_weights: + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + if isinstance(module, nn.Parameter): + nn.init.xavier_uniform_(module, gain=2 ** -2.5) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." 
+ ) + + batch_size, seq_len, hidden_size = hidden_states.shape + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + if attention_mask is not None: + hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None]) + cu_seqlens = kwargs.get('cu_seqlens', None) + if hidden_states.shape[1] == 1 and last_state is not None: + shifted = last_state['conv_state'].unsqueeze(1) + delta = shifted - hidden_states + elif last_state is None: + delta = token_shift(hidden_states, cu_seqlens) + else: + shifted = self.time_shift(hidden_states) + shifted[:, 0] = last_state['conv_state'] + delta = shifted - hidden_states + + x = self.x_proj[0](hidden_states, delta, cu_seqlens).view(batch_size, seq_len, -1, self.proj_low_rank_dim) + x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1)) + + r, w, k, v, g = x.add_(self.x_bias).unbind(-2) + r = self.r_proj(hidden_states, r, delta, cu_seqlens) + w = self.w_proj(hidden_states, w, delta, cu_seqlens) + k = self.k_proj(hidden_states, k, delta, cu_seqlens) + v = self.v_proj(hidden_states, v, delta, cu_seqlens) + g = self.g_proj(hidden_states, g, delta, cu_seqlens) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k)) + v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim) + w = -torch.exp(w) + u = self.bonus + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_rwkv6( + r=r, + k=k, + v=v, + w=w, + u=u, + scale=1., + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'chunk': + o, recurrent_state = chunk_rwkv6( + r=r, + k=k, + v=v, + w=w, + u=u, + scale=1., + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=hidden_states[:, -1], + layer_idx=self.layer_idx, + offset=r.shape[2] + ) + + o = self.g_norm(rearrange(o, '... h d -> ... 
(h d)')) * self.gate_fn(g) + o = self.o_proj(o) + + return o, None, past_key_values + + +class LoRA(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: int, + bias: Optional[bool] = True, + activation: Optional[str] = 'tanh' + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + self.bias = bias + + if activation is None: + self.activation = nn.Identity() + elif activation == 'sigmoid': + self.activation = nn.Sigmoid() + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'relu': + self.activation = nn.ReLU() + else: + raise ValueError(f"Not supported activation `{activation}`.") + + self.lora = nn.Sequential( + nn.Linear(input_dim, low_rank_dim, bias=False), + self.activation, + nn.Linear(low_rank_dim, output_dim, bias=bias) + ) + try: + from transformers.modeling_utils import _init_weights + except ImportError: + _init_weights = True + if _init_weights: + self.apply(self._initialize_weights) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(" + s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}" + if not self.bias: + s += f", bias={self.bias}" + s += ")" + return s + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + + # Initialize weights to zero as in original code + nn.init.zeros_(self.lora[0].weight) + original_dtype = self.lora[2].weight.dtype + shape = self.lora[2].weight.shape + # Convert to float32 for numerical stability in orthogonal init + weight_fp32 = self.lora[2].weight.float() + + # Calculate gain based on dimensions + gain = math.sqrt(shape[1] / shape[0]) if shape[1] > shape[0] else 1 + + # Apply orthogonal initialization with scaling factor 0.1 + nn.init.orthogonal_(weight_fp32, gain=gain * 0.1) + + # Convert back to original dtype + self.lora[2].weight.data.copy_(weight_fp32.to(original_dtype)) + # Set Lora[2] bias to zero + if self.lora[2].bias is not None: + nn.init.zeros_(self.lora[2].bias) + + module._is_hf_initialized = True + + def set_bias_value(self, value): + """Set bias to a specific value (for v0, w0 etc.)""" + if self.bias and self.lora[2].bias is not None: + if isinstance(value, torch.Tensor): + # Handle tensor values + self.lora[2].bias.data.copy_(value.to(self.lora[2].bias.dtype)) + else: + # Handle scalar values + nn.init.constant_(self.lora[2].bias, value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lora(x) + + +class LerpLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: Optional[int] = None + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + if low_rank_dim is None: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = LoRA(input_dim, output_dim, low_rank_dim) + self.mu = nn.Parameter(torch.zeros(input_dim)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}" + if self.low_rank_dim is not None: + s += f", low_rank_dim={self.low_rank_dim}" + s += ")" + return s + + def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None) -> torch.Tensor: + if delta is None: + delta = token_shift(x, cu_seqlens) + return self.linear(x + delta * self.mu) + + +class 
DDLerpLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: Optional[int] = None + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + if low_rank_dim is None: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = LoRA(input_dim, output_dim, low_rank_dim) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}" + if self.low_rank_dim is not None: + s += f", low_rank_dim={self.low_rank_dim}" + s += ")" + return s + + def forward(self, x: torch.Tensor, mu: torch.Tensor, + delta: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None) -> torch.Tensor: + if delta is None: + delta = token_shift(x, cu_seqlens) + return self.linear(x + delta * mu) diff --git a/fla3/layers/rwkv7.py b/fla3/layers/rwkv7.py new file mode 100644 index 0000000000000000000000000000000000000000..4bebd2c1fa40b4134b6feafcf49585b1fab46dd6 --- /dev/null +++ b/fla3/layers/rwkv7.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.layers.rwkv6 import LoRA +from fla.modules import GroupNorm +from fla.modules.l2norm import l2_norm +from fla.modules.token_shift import token_shift +from fla.ops.rwkv7 import chunk_rwkv7, fused_mul_recurrent_rwkv7 +from fla.ops.rwkv7.fused_addcmul import fused_addcmul_rwkv7 +from fla.ops.rwkv7.fused_k_update import fused_k_rwkv7 + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class RWKV7Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + head_dim: Optional[int] = 64, + num_heads: Optional[int] = None, + decay_low_rank_dim: int = 64, + gate_low_rank_dim: int = 128, + a_low_rank_dim: int = 64, + v_low_rank_dim: int = 16, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None, + fuse_norm: bool = False, + value_dim: int = None, + num_hidden_layers: int = None, + **kwargs + ) -> RWKV7Attention: + super().__init__() + + self.mode = mode + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." 
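+        # NOTE (hedged sketch, not the kernel): the kernels called in forward() realize,
+        # roughly, the RWKV-7 generalized delta-rule recurrence per head
+        #   S_t = S_{t-1} @ (diag(exp(w_t)) + a_t b_t^T) + v_t k_t^T
+        #   o_t = S_t @ r_t
+        # where forward() passes a_t = -kk_t and b_t = kk_t * a_gate_t (see chunk_rwkv7 call).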
+ self.hidden_size = hidden_size + + self.key_dim = hidden_size + self.value_dim = value_dim if value_dim is not None else hidden_size + if head_dim is None and num_heads is None: + raise ValueError("Either `head_dim` or `num_heads` must be specified.") + elif head_dim is not None: + self.head_dim = head_dim + self.num_heads = int(hidden_size // head_dim) + elif num_heads is not None: + self.head_dim = int(hidden_size // num_heads) + self.num_heads = num_heads + self.head_v_dim = int(self.value_dim // self.num_heads) + + self.decay_low_rank_dim = decay_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.a_low_rank_dim = a_low_rank_dim + self.v_low_rank_dim = v_low_rank_dim + self.layer_idx = layer_idx + self.num_hidden_layers = num_hidden_layers + self.fuse_norm = fuse_norm + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.x_r = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_w = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_k = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_v = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_a = nn.Parameter(torch.zeros(1, 1, hidden_size)) + self.x_g = nn.Parameter(torch.zeros(1, 1, hidden_size)) + + self.k_k = nn.Parameter(torch.zeros(self.key_dim)) + self.k_a = nn.Parameter(torch.zeros(self.key_dim)) + self.r_k = nn.Parameter(torch.zeros(self.num_heads, self.head_dim)) + + self.r_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.w_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=decay_low_rank_dim, activation='tanh') + if self.layer_idx != 0: + self.v_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=v_low_rank_dim, activation=None) + self.a_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=a_low_rank_dim, activation=None) + self.g_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False) + + if self.fuse_norm: + self.g_norm = GroupNorm( + num_groups=self.num_heads, + hidden_size=self.value_dim, + elementwise_affine=elementwise_affine, + eps=self.head_dim*norm_eps, + bias=True, + ) + else: + self.g_norm = nn.GroupNorm( + num_groups=self.num_heads, + num_channels=self.value_dim, + eps=self.head_dim*norm_eps, + affine=elementwise_affine + ) + + try: + from transformers.modeling_utils import _init_weights + except ImportError: + _init_weights = True + if _init_weights: + self.apply(self._initialize_weights) + for name, module in self.named_modules(): + module._in_rwkv_module = True + + @torch.compiler.disable + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + + # Initialize only when we're processing the RWKV7Attention module itself + if isinstance(module, RWKV7Attention) and self.layer_idx is not None: + ratio_0_to_1 = self.layer_idx / (self.num_hidden_layers - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (self.layer_idx / self.num_hidden_layers) # 1 to ~0 + + # Create position-based initialization tensor + with torch.no_grad(): + ddd = torch.ones(1, 1, self.hidden_size) + for i in range(self.hidden_size): + ddd[0, 0, i] = i / self.hidden_size + + # Initialize x_* parameters directly + self.x_r.data = (1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)).to(self.x_r.dtype) + self.x_w.data = (1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)).to(self.x_w.dtype) + self.x_k.data = (1.0 - 
(torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1)).to(self.x_k.dtype) + self.x_v.data = (1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1)).to(self.x_v.dtype) + self.x_a.data = (1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)).to(self.x_a.dtype) + self.x_g.data = (1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)).to(self.x_g.dtype) + # Set specific bias values for LoRA modules + # w0 initialization - complex decay speed + decay_speed = torch.ones(self.hidden_size) + for n in range(self.hidden_size): + decay_speed[n] = -7 + 5 * (n / (self.hidden_size - 1)) ** ( + 0.85 + 1.0 * ratio_0_to_1**0.5 + ) + # Initialize k_k, k_a, r_k + nn.init.constant_(self.k_k, 0.85) + nn.init.constant_(self.k_a, 1.0) + nn.init.zeros_(self.r_k) + + self.w_lora.set_bias_value(decay_speed + 0.5) + + # v0 initialization - ones (for non-first layers) + if self.layer_idx != 0: + self.v_lora._initialize_weights(self.v_lora) + self.v_lora.set_bias_value(1.0) + + self.r_proj.weight.data.uniform_(-0.5/(self.hidden_size**0.5), 0.5/(self.hidden_size**0.5)) + self.k_proj.weight.data.uniform_(-0.05/(self.hidden_size**0.5), 0.05/(self.hidden_size**0.5)) + self.v_proj.weight.data.uniform_(-0.5/(self.hidden_size**0.5), 0.5/(self.hidden_size**0.5)) + self.o_proj.weight.data.zero_() + + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + v_first: torch.Tensor = None, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, seq_len, _ = hidden_states.shape + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + if attention_mask is not None: + hidden_states = hidden_states.mul(attention_mask[:, -seq_len:, None]) + cu_seqlens = kwargs.get('cu_seqlens', None) + # delta [batch_size, seq_len, hidden_size] + if last_state is None: + delta = token_shift(hidden_states, cu_seqlens) + recurrent_state = None + elif hidden_states.shape[1] == 1: + shifted = last_state['conv_state'].unsqueeze(1) + delta = shifted - hidden_states + recurrent_state = last_state['recurrent_state'] + else: + shifted = self.time_shift(hidden_states) + shifted[:, 0] = last_state['conv_state'] + delta = shifted - hidden_states + recurrent_state = last_state['recurrent_state'] + + xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(hidden_states, delta, self.x_r, self.x_w, + self.x_k, self.x_v, self.x_a, self.x_g) + + r = self.r_proj(xr) + # Using bf16 for LoRA computation is numerically safe here because: + # 1. After sigmoid activation: + # - Max absolute error (vs float32): 0.003 + # - Mean absolute error: 0.0004 + # 2. Subsequent scaling by -0.6065 will further reduce relative error + # (error scales linearly with constant multiplication) + # 3. 
Final compounded error remains within acceptable bounds for bf16 precision + # Empirical observation confirms bf16 introduces no practical degradation + # (the constant below is -math.exp(-0.5)) + w = -0.6065306597126334 * self.w_lora(xw).sigmoid() + + k = self.k_proj(xk) + v = self.v_proj(xv) + + if self.layer_idx == 0: + v_first = v + else: + v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid()) + a = self.a_lora(xa).sigmoid() + g = self.g_lora(xg) + + if self.fuse_norm: + kk = l2_norm(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim)) + else: + kk = F.normalize(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0) + + # Prefer addcmul over expanded form for numerical stability in bf16: + # 1. Fused Multiply-Add (FMA) in addcmul reduces intermediate rounding: + # - Single op vs original 3 ops (mul, sub, mul) + # - 1 less intermediate value storage (bf16 write->read overhead) + # 2. Mathematically equivalent to k*(1 + (a-1)*self.k_a) + # but with better precision preservation + # 3. Particularly crucial for bf16 where intermediate values easily lose precision + # 4. PyTorch method: k = k.addcmul(k * (a - 1), self.k_a) + k = fused_k_rwkv7(k, a, self.k_a) + + # dealing with left-padding + if attention_mask is not None: + v = v * attention_mask[:, -seq_len:, None] + + r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a)) + v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim) + + if self.training or seq_len >= 64: + # if training, use chunk mode no matter how short the sequence is + # launching the triton kernel for just one token will actually be slower + o, recurrent_state = chunk_rwkv7( + r=r, + w=w, + k=k, + v=v, + a=-kk, + b=kk * a, + scale=1., + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + o, recurrent_state = fused_mul_recurrent_rwkv7( + r=r, + w=w, + k=k, + v=v, + kk=kk, + a=a, + scale=1., + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=hidden_states[:, -1], + layer_idx=self.layer_idx, + offset=r.shape[1] + ) + + if self.fuse_norm: + o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) + else: + o = self.g_norm(rearrange(o, 'b t h d -> (b t) (h d)')).view(batch_size, seq_len, -1) + + o = o + ((r * k * self.r_k).sum(-1, keepdim=True) * v).view(batch_size, seq_len, -1) + o = self.o_proj(o * g) + + return o, None, past_key_values, v_first diff --git a/fla3/layers/simple_gla.py b/fla3/layers/simple_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..02742efc6bffde7852770bfb8997be5cedd98ea6 --- /dev/null +++ b/fla3/layers/simple_gla.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.modules.activations import ACT2FN +from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class SimpleGatedLinearAttention(nn.Module): + r""" + The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). 
# noqa + This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise. + + Args: + mode (str, Optional): + Which GLA kernel to use. + Currently available: `chunk` and `fused_recurrent`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + num_kv_heads (int, Optional): + The number of key/value heads, used for MQA. Default: None. + feature_map (str, Optional): + Feature map function applied to queries/keys. Default: None. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + gate_logit_normalizer (int, Optional): + The normalizer for the gate logits, applied after `logsigmoid`. Default: 16. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. + """ + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1., + expand_v: float = 1., + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 16, + fuse_norm: bool = True, + layer_idx: Optional[int] = None, + ) -> SimpleGatedLinearAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`." 
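+ # A worked example of the dimension bookkeeping above and the divisibility checks that follow (illustrative numbers, not defaults): with hidden_size=1024, expand_k=0.5, expand_v=1.0 and num_heads=4 we get key_dim=512 and value_dim=1024, hence head_k_dim=128 and head_v_dim=256; both asserts hold since 512 % 4 == 0 and 1024 % 4 == 0.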
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads ({num_heads})" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads ({num_heads})" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.gk_proj = nn.Linear(hidden_size, self.num_heads) + + if gate_fn == 'swish' and fuse_norm: + self.g_norm_swish_gate = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.gate_fn = ACT2FN[gate_fn] + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.gate_logit_normalizer = gate_logit_normalizer + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." 
+ ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + gk = self.gk_proj(hidden_states) + + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) + if self.num_kv_groups > 1: + k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v)) + else: + k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v)) + gk = F.logsigmoid(gk) / self.gate_logit_normalizer + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_simple_gla( + q=q, + k=k, + v=v, + gk=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_simple_gla( + q=q, + k=k, + v=v, + gk=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b t h d -> b t (h d)') + else: + o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)') + o = o * self.gate_fn(g) + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_v_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla3/layers/utils.py b/fla3/layers/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a51a698eab84e551efe71768a591aef3e95cc10 --- /dev/null +++ b/fla3/layers/utils.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Code is adapted from flash-attn.bert_padding.py + +from typing 
import Tuple + +import torch +from einops import rearrange, repeat + +from ..ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask +from ..utils import tensor_cache + + +class IndexFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, indices): + ctx.save_for_backward(indices) + assert x.ndim >= 2 + ctx.first_axis_dim, other_shape = x.shape[0], x.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return x[indices] + return torch.gather( + rearrange(x, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, do): + (indices,) = ctx.saved_tensors + assert do.ndim >= 2 + other_shape = do.shape[1:] + do = rearrange(do, "b ... -> b (...)") + dx = torch.zeros( + [ctx.first_axis_dim, do.shape[1]], + device=do.device, + dtype=do.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # dx[indices] = do + dx.scatter_(0, repeat(indices, "z -> z d", d=do.shape[1]), do) + return dx.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert x.ndim >= 2 + y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype) + # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + y[indices] = x + # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x) + return y + + @staticmethod + def backward(ctx, do): + (indices,) = ctx.saved_tensors + # TODO [2022-03-04] For some reason torch.gather is a bit faster than indexing. + dx = do[indices] + # dx = torch.gather(do, 0, repeat(indices, 'z -> z d', d=do.shape[1])) + return dx, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +@tensor_cache +def get_unpad_data( + attention_mask: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Args: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is [batch_size + 1]. + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + lens = prepare_lens_from_mask(attention_mask) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = lens.max().item() + cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask) + return indices, cu_seqlens, max_seqlen_in_batch + + +def unpad_input( + q: torch.Tensor, + states: Tuple[torch.Tensor], + attention_mask: torch.Tensor, + q_len: int, + keepdim: bool = False, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens + even though they belong to different batches. + + + Arguments: + q (`torch.Tensor`): + Query state with padding. Shape: [batch_size, q_len, ...]. + states (`Tuple[torch.Tensor]`): + Attention state with padding. Shape: [batch_size, seq_len, ...]. 
+ attention_mask (`torch.Tensor`): + Boolean or int tensor of shape [batch_size, sequence_length], 1 means valid and 0 means not valid. + q_len (`int`): + Target length. + keepdim (`bool`): + Whether to keep the batch dimension. Default: `False`. + + Return: + q (`torch.Tensor`): + Query state without padding. + Shape: [1, total_target_length, ...] if `keepdim=True` else [total_target_length, ...]. + states (`Tuple[torch.Tensor]`): + Attention state without padding. + Shape: [1, total_source_length, ...] if `keepdim=True` else [total_source_length, ...]. + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`): + The cumulative sequence lengths for the target (query) and source (key, value), + used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is [batch_size + 1]. + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence + i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask) + batch_size, seq_len, *_ = states[0].shape + + state = tuple( + index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k) + for s in states + ) + + if q_len == seq_len: + q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif q_len == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) + indices_q = cu_seqlens_q[:-1] + q = q.squeeze(1) + else: + raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)") + + if keepdim: + q = q.unsqueeze(0) + state = tuple(s.unsqueeze(0) for s in state) + + return ( + q, + state, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def pad_input( + hidden_states: torch.Tensor, + indices: torch.LongTensor, + batch_size: int, + seq_len: int, +) -> torch.Tensor: + """ + Args: + hidden_states ([total_tokens, ...]): + where total_tokens denotes the number of tokens selected by attention_mask. + indices ([total_tokens]): + the indices that represent the non-masked tokens of the original padded input sequence. + batch_size (int): + batch size for the padded sequence. + seq_len (int): + maximum sequence length for the padded sequence. + + Return: + hidden_states of shape [batch_size, seq_len, ...] + """ + output = index_put_first_axis(hidden_states, indices, batch_size * seq_len) + return rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) diff --git a/fla3/models/__init__.py b/fla3/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa8db6d4b32effb6b6aafb4b97457f7a3f9c6a3 --- /dev/null +++ b/fla3/models/__init__.py @@ -0,0 +1,56 @@ +# # -*- coding: utf-8 -*- + +# from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel +# from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel +from .delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel +# from fla.models.forgetting_transformer import ( +# ForgettingTransformerConfig, +# ForgettingTransformerForCausalLM, +# ForgettingTransformerModel +# ) +from .gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel +# from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel +# from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel +# from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel +# from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel +# from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model +# from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel +# from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel +# from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel +# from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model +# from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel +# from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel +# from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel +# from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model +# from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model +# from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel +# from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel + +# __all__ = [ +# 'ABCConfig', 'ABCForCausalLM', 'ABCModel', +# 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel', +# 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', +# 'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel', +# 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel', +# 'GLAConfig', 'GLAForCausalLM', 'GLAModel', +# 'GSAConfig', 'GSAForCausalLM', 'GSAModel', +# 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', +# 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', +# 'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel', +# 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', +# 'MambaConfig', 'MambaForCausalLM', 'MambaModel', +# 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model', +# 'NSAConfig', 'NSAForCausalLM', 'NSAModel', +# 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', +# 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', +# 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model', +# 'SambaConfig', 'SambaForCausalLM', 'SambaModel', +# 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel', +# 'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel', +# 'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel', +# ] +# from .emla import emlaConfig,emlaForCausalLM,emlaModel +# from .emgla import emglaConfig,emglaForCausalLM,emglaModel +from .emdeltanet import 
emdeltanetConfig,emdeltanetForCausalLM,emdeltanetModel +# from .transformer import TransformerConfig,TransformerForCausalLM,TransformerModel + diff --git a/fla3/models/__pycache__/__init__.cpython-310.pyc b/fla3/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dabd5109c441fa922f6daa712138f63099d27e1c Binary files /dev/null and b/fla3/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/__pycache__/__init__.cpython-312.pyc b/fla3/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b6f70f4fe47bcce14481567397320c0bc908c21 Binary files /dev/null and b/fla3/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/models/__pycache__/utils.cpython-310.pyc b/fla3/models/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca56b4a5f5a1e271da41ec9bec8769b45c5edf9 Binary files /dev/null and b/fla3/models/__pycache__/utils.cpython-310.pyc differ diff --git a/fla3/models/__pycache__/utils.cpython-312.pyc b/fla3/models/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..453a6986d26fd8492e80274490f20a824e81e673 Binary files /dev/null and b/fla3/models/__pycache__/utils.cpython-312.pyc differ diff --git a/fla3/models/abc/__init__.py b/fla3/models/abc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7021f22ff0f9781432bd3969473520851f4b553 --- /dev/null +++ b/fla3/models/abc/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.abc.configuration_abc import ABCConfig +from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel + +AutoConfig.register(ABCConfig.model_type, ABCConfig) +AutoModel.register(ABCConfig, ABCModel) +AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM) + + +__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] diff --git a/fla3/models/abc/__pycache__/__init__.cpython-310.pyc b/fla3/models/abc/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835ef94a80de62ccece230f36b7a469ebf62f135 Binary files /dev/null and b/fla3/models/abc/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/abc/__pycache__/configuration_abc.cpython-310.pyc b/fla3/models/abc/__pycache__/configuration_abc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..269a23a0c30053ec1d67326d57d58904a25ec50d Binary files /dev/null and b/fla3/models/abc/__pycache__/configuration_abc.cpython-310.pyc differ diff --git a/fla3/models/abc/__pycache__/modeling_abc.cpython-310.pyc b/fla3/models/abc/__pycache__/modeling_abc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..126195e2711a378e44ec24c7b1a07562289f504e Binary files /dev/null and b/fla3/models/abc/__pycache__/modeling_abc.cpython-310.pyc differ diff --git a/fla3/models/abc/configuration_abc.py b/fla3/models/abc/configuration_abc.py new file mode 100644 index 0000000000000000000000000000000000000000..48b496b9493b007172f32d6edaed03ea0bde25e4 --- /dev/null +++ b/fla3/models/abc/configuration_abc.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class ABCConfig(PretrainedConfig): + + model_type = 'abc' + keys_to_ignore_at_inference = ['past_key_values'] + 
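+ # A minimal loading sketch (illustrative, assuming the Auto* registrations in fla3/models/abc/__init__.py above have been imported): + # config = AutoConfig.for_model('abc', hidden_size=512, num_hidden_layers=2) + # model = AutoModelForCausalLM.from_config(config)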
+ def __init__( + self, + hidden_size: int = 2048, + gate_low_rank_dim: int = 16, + clamp_min: float = -32, + clamp_max: float = 32, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + expand_k: float = 0.5, + expand_v: float = 1, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_rope: bool = True, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.hidden_size = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_rope = use_rope + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) 
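+ # Illustrative spec for the hybrid-attention branch above: passing + # attn = {'layers': [0, 12], 'num_heads': 16} + # is normalized in place to + # {'layers': [0, 12], 'num_heads': 16, 'num_kv_heads': 16, 'qkv_bias': False, 'window_size': None, 'rope_theta': 10000.0} + # so every downstream consumer can rely on all five keys being present.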
+ + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/abc/modeling_abc.py b/fla3/models/abc/modeling_abc.py new file mode 100644 index 0000000000000000000000000000000000000000..455d5b7f10358d1af37698e44544ca53784896a7 --- /dev/null +++ b/fla3/models/abc/modeling_abc.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.abc import ABCAttention +from fla.layers.attn import Attention +from fla.models.abc.configuration_abc import ABCConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as ABCMLP +from fla.modules import RMSNorm + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class ABCBlock(nn.Module): + def __init__(self, config: ABCConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = ABCAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + use_rope=config.use_rope, + clamp_min=config.clamp_min, + clamp_max=config.clamp_max, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = ABCMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + 
output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class ABCPreTrainedModel(PreTrainedModel): + + config_class = ABCConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['ABCBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class ABCModel(ABCPreTrainedModel): + + def __init__(self, config: ABCConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: 
Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`ABCModel` does not support `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ABCModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # keep only the last token of `input_ids` if `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if 
logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/bitnet/__init__.py b/fla3/models/bitnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bede22c64707be1ff17f402c0af6ed9da1ff1aee --- /dev/null +++ b/fla3/models/bitnet/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.bitnet.configuration_bitnet import BitNetConfig +from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel + +AutoConfig.register(BitNetConfig.model_type, BitNetConfig) +AutoModel.register(BitNetConfig, BitNetModel) +AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM) + + +__all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel'] diff --git a/fla3/models/bitnet/__pycache__/__init__.cpython-310.pyc b/fla3/models/bitnet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b325ef7552b89a19920904103a2e8812a6b5bbe8 Binary files /dev/null and b/fla3/models/bitnet/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/bitnet/__pycache__/configuration_bitnet.cpython-310.pyc b/fla3/models/bitnet/__pycache__/configuration_bitnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04eb145ccac2739550b2c7b56a16b74e4cb869c6 Binary files /dev/null and b/fla3/models/bitnet/__pycache__/configuration_bitnet.cpython-310.pyc differ diff --git a/fla3/models/bitnet/__pycache__/modeling_bitnet.cpython-310.pyc b/fla3/models/bitnet/__pycache__/modeling_bitnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf1dfc7524dea5f808ab2753af20b941e6b0a7bc Binary files /dev/null and b/fla3/models/bitnet/__pycache__/modeling_bitnet.cpython-310.pyc differ diff --git a/fla3/models/bitnet/configuration_bitnet.py b/fla3/models/bitnet/configuration_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..740873128a19730e2ad8b8d8351f24cf1ae56604 --- /dev/null +++ b/fla3/models/bitnet/configuration_bitnet.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class BitNetConfig(PretrainedConfig): + + model_type = 'bitnet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: int = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 
10000., + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/bitnet/modeling_bitnet.py b/fla3/models/bitnet/modeling_bitnet.py new file mode 100644 index 0000000000000000000000000000000000000000..89f6ef32cbb7d17981a7bb0580662402e14d19d6 --- /dev/null +++ b/fla3/models/bitnet/modeling_bitnet.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.bitattn import BitAttention +from fla.models.bitnet.configuration_bitnet import BitNetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm +from fla.modules.activations import swiglu +from fla.modules.fused_bitlinear import FusedBitLinear + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class BitNetMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish', + fuse_swiglu: bool = True + ) -> BitNetMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.fuse_swiglu = fuse_swiglu + + if hidden_act != 'swish': + raise ValueError(f'Unsupported hidden_act: {hidden_act}') + + self.gate_proj = 
nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def forward( + self, + x: torch.Tensor, + **kwargs: Unpack[Any] + ) -> torch.Tensor: + gate, y = self.gate_proj(x), self.up_proj(x) + return self.down_proj(swiglu(gate, y)) + + +class BitNetBlock(nn.Module): + + def __init__(self, config: BitNetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.attn = BitAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = BitNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs: Unpack[Any] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class BitNetPreTrainedModel(PreTrainedModel): + + config_class = BitNetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['BitNetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, FusedBitLinear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for 
the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following PyTorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class BitNetModel(BitNetPreTrainedModel): + + def __init__( + self, + config: BitNetConfig + ) -> BitNetModel: + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([BitNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Any] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`BitNetModel` does not support output attention weights now, so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + output_attentions, + use_cache, + **kwargs + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = BitNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # keep only the last token of `input_ids` if `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
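+ # By this point `input_ids` has already been sliced to its last column (shape [batch_size, 1]) whenever `past_key_values` is non-empty, so each decoding step re-embeds only the newly sampled token.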
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if logits_to_keep is not None:
+            model_inputs['logits_to_keep'] = logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Any]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+        if labels is not None:
+            if getattr(self, 'criterion', None) is None:
+                if fuse_linear_and_cross_entropy:
+                    criterion = FusedLinearCrossEntropyLoss()
+                elif self.config.fuse_cross_entropy:
+                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                else:
+                    criterion = nn.CrossEntropyLoss()
+            else:
+                criterion = self.criterion
+
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/fla3/models/delta_net/__init__.py b/fla3/models/delta_net/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b7d5b57c8bcc5cc176d1b5273fc5f2800fefb8
--- /dev/null
+++ b/fla3/models/delta_net/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from .configuration_delta_net import DeltaNetConfig
+from .modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel
+
+AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
+AutoModel.register(DeltaNetConfig, DeltaNetModel)
+AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
+
+__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
diff --git 
a/fla3/models/delta_net/__pycache__/__init__.cpython-310.pyc b/fla3/models/delta_net/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2130d9f2e0fb8c9a0a55e3e997ad31c52edce297 Binary files /dev/null and b/fla3/models/delta_net/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/delta_net/__pycache__/__init__.cpython-312.pyc b/fla3/models/delta_net/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..075e507e7a4c8c89505c8ac921583e71319c0920 Binary files /dev/null and b/fla3/models/delta_net/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/models/delta_net/__pycache__/configuration_delta_net.cpython-310.pyc b/fla3/models/delta_net/__pycache__/configuration_delta_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7617e816773b6c18f9620734e054c0820db3bcd3 Binary files /dev/null and b/fla3/models/delta_net/__pycache__/configuration_delta_net.cpython-310.pyc differ diff --git a/fla3/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc b/fla3/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..689b06edc3671eb21878a2927030698f0baa875b Binary files /dev/null and b/fla3/models/delta_net/__pycache__/configuration_delta_net.cpython-312.pyc differ diff --git a/fla3/models/delta_net/__pycache__/modeling_delta_net.cpython-310.pyc b/fla3/models/delta_net/__pycache__/modeling_delta_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a6ae9a38f40e89b9c3f93e5bca937f577f66200 Binary files /dev/null and b/fla3/models/delta_net/__pycache__/modeling_delta_net.cpython-310.pyc differ diff --git a/fla3/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc b/fla3/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fffe83ebeba5d68082e8be934a818ae0043ad4a9 Binary files /dev/null and b/fla3/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc differ diff --git a/fla3/models/delta_net/configuration_delta_net.py b/fla3/models/delta_net/configuration_delta_net.py new file mode 100644 index 0000000000000000000000000000000000000000..095b0e88cc1544295752acfe391ba84a44abad98 --- /dev/null +++ b/fla3/models/delta_net/configuration_delta_net.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class DeltaNetConfig(PretrainedConfig): + + model_type = 'delta_net' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: 
bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/delta_net/modeling_delta_net.py b/fla3/models/delta_net/modeling_delta_net.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9d74e24e45836cc523a86faa148d0b78f49fd5 --- /dev/null +++ b/fla3/models/delta_net/modeling_delta_net.py @@ -0,0 +1,415 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.delta_net import DeltaNet +from ...models.delta_net.configuration_delta_net import DeltaNetConfig +from ...models.utils import Cache +from ...modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from ...modules import GatedMLP as DeltaNetMLP +from ...modules import RMSNorm + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class DeltaNetBlock(nn.Module): + def __init__(self, config: DeltaNetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + 
max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = DeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_beta=config.use_beta, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + qk_norm=config.qk_norm, + qk_activation=config.qk_activation, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = DeltaNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class DeltaNetPreTrainedModel(PreTrainedModel): + + config_class = DeltaNetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['DeltaNetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
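+            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # A worked example of the rescaling below (illustrative numbers, not taken from a real
+            # config): with num_hidden_layers=24 and num_residuals_per_layer=2 (one attention and one
+            # MLP residual add per block), N = 2 * 24 = 48, so each o_proj/down_proj weight is divided
+            # by sqrt(48) ≈ 6.93 right after the usual kaiming init:
+            #
+            #   p = torch.empty(1024, 1024)
+            #   nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+            #   p /= math.sqrt(2 * 24)   # std shrinks ~6.93x, taming residual-path variance growth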
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            p = None
+            if hasattr(module, 'o_proj'):
+                p = module.o_proj.weight
+            elif hasattr(module, 'down_proj'):
+                p = module.down_proj.weight
+            if p is not None:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                if prenorm_residual_strategy == 'rescale':
+                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                    with torch.no_grad():
+                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+                elif prenorm_residual_strategy == 'zero':
+                    nn.init.zeros_(p)
+                else:
+                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
+
+
+class DeltaNetModel(DeltaNetPreTrainedModel):
+
+    def __init__(self, config: DeltaNetConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if output_attentions:
+            warnings.warn("`DeltaNetModel` does not support `output_attentions` for now, setting it to `False`.")
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+        hidden_states = inputs_embeds
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = Cache.from_legacy_cache(past_key_values)
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing.
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeltaNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/emdeltanet/__init__.py b/fla3/models/emdeltanet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9230484e43ba29d230276e0407b9d97cc9aaa93a --- /dev/null +++ b/fla3/models/emdeltanet/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_emdeltanet import emdeltanetConfig +from .modeling_emdeltanet import emdeltanetForCausalLM, emdeltanetModel + 
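+# Registering with the Auto* factories makes `model_type='emdeltanet'` resolvable by name.
+# Usage sketch (illustrative values, not defaults):
+#   config = emdeltanetConfig(num_hidden_layers=2, hidden_size=256)
+#   model = AutoModelForCausalLM.from_config(config)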
+AutoConfig.register(emdeltanetConfig.model_type, emdeltanetConfig) +AutoModel.register(emdeltanetConfig, emdeltanetModel) +AutoModelForCausalLM.register(emdeltanetConfig, emdeltanetForCausalLM) + +__all__ = ['emdeltanetConfig', 'emdeltanetForCausalLM', 'emdeltanetModel'] diff --git a/fla3/models/emdeltanet/__pycache__/__init__.cpython-310.pyc b/fla3/models/emdeltanet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d183a13009e00c3329016ca89f8b44baddb4d6d Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/__init__.cpython-312.pyc b/fla3/models/emdeltanet/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f984688471b65a33a49b0c84a79a64b9f79587c Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-310.pyc b/fla3/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dc0afc2a5bd27135c9efc38030480bd9af0d5ee Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-310.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc b/fla3/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eb8faeab4e7b7347e3e03b0fa7cd529ac6188c7 Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/configuration_emdeltanet.cpython-312.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/configuration_emgla.cpython-310.pyc b/fla3/models/emdeltanet/__pycache__/configuration_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9c664fd9d4914d4f96f85972368580a45bea817 Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/configuration_emgla.cpython-310.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc b/fla3/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32ceedae4f3bd1f96d15f1099ec4deda4e7af8c0 Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-310.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc b/fla3/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc8f655d8d9369e7da7417b61139cabf483a8e9f Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/modeling_emdeltanet.cpython-312.pyc differ diff --git a/fla3/models/emdeltanet/__pycache__/modeling_emgla.cpython-310.pyc b/fla3/models/emdeltanet/__pycache__/modeling_emgla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b15c2e90086bd98e7bb2c3a4dbc409f774ac8649 Binary files /dev/null and b/fla3/models/emdeltanet/__pycache__/modeling_emgla.cpython-310.pyc differ diff --git a/fla3/models/emdeltanet/configuration_emdeltanet.py b/fla3/models/emdeltanet/configuration_emdeltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..e6554707dcfb9c0e53520f50f9e24c0ef023fef6 --- /dev/null +++ b/fla3/models/emdeltanet/configuration_emdeltanet.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- 
+ +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + + +class emdeltanetConfig(PretrainedConfig): + + model_type = 'emdeltanet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + use_beta: bool = True, + use_output_norm: bool = True, + num_heads: int = 16, + qk_norm: str = 'l2', + qk_activation: str = 'silu', + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 24, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + aux_loss_scale:float = 0.01, + ratio : int = 2, + topk : int = 1, + **kwargs + ): + self.attn_mode = attn_mode + self.aux_loss_scale = aux_loss_scale + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_beta = use_beta + self.use_output_norm = use_output_norm + self.num_heads = num_heads + self.qk_norm = qk_norm + self.qk_activation = qk_activation + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + self.ratio = ratio + self.topk = topk + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) 
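+        # Illustrative hybrid-attention spec (values are examples, not defaults): replace layers 0 and 12
+        # with softmax attention via attn={'layers': [0, 12], 'num_heads': 16}; missing keys fall back to
+        # the fills above (num_kv_heads=num_heads, qkv_bias=False, window_size=None, rope_theta=10000.).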
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
diff --git a/fla3/models/emdeltanet/modeling_emdeltanet.py b/fla3/models/emdeltanet/modeling_emdeltanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..446866ed2324a47916d206ffe6e3809c14f8d0a2
--- /dev/null
+++ b/fla3/models/emdeltanet/modeling_emdeltanet.py
@@ -0,0 +1,534 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from einops import rearrange
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.utils.deprecation import deprecate_kwarg
+
+from ...layers.attn import Attention
+from ...layers.emdeltanet import emdeltanet
+from ...models.emdeltanet.configuration_emdeltanet import emdeltanetConfig
+from ...models.utils import Cache
+from ...modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
+from ...modules import GatedMLP as emdeltanetMLP
+from ...modules import RMSNorm
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+
+class emdeltanetBlock(nn.Module):
+    def __init__(self, config: emdeltanetConfig, layer_idx: int):
+        super().__init__()
+
+        self.config = config
+        self.layer_idx = layer_idx
+        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        if config.attn is not None and layer_idx in config.attn['layers']:
+            self.attn = Attention(
+                hidden_size=config.hidden_size,
+                num_heads=config.attn['num_heads'],
+                num_kv_heads=config.attn['num_kv_heads'],
+                qkv_bias=config.attn['qkv_bias'],
+                window_size=config.attn['window_size'],
+                rope_theta=config.attn['rope_theta'],
+                max_position_embeddings=config.max_position_embeddings,
+                layer_idx=layer_idx
+            )
+        else:
+            self.attn = emdeltanet(
+                mode=config.attn_mode,
+                hidden_size=config.hidden_size,
+                expand_k=config.expand_k,
+                expand_v=config.expand_v,
+                num_heads=config.num_heads,
+                use_gate=config.use_gate,
+                use_short_conv=config.use_short_conv,
+                use_output_norm=config.use_output_norm,
+                conv_size=config.conv_size,
+                norm_eps=config.norm_eps,
+                ratio=config.ratio,
+                topk=config.topk,
+                layer_idx=layer_idx
+            )
+        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+        self.mlp = emdeltanetMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            fuse_swiglu=config.fuse_swiglu
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs: Unpack[Dict]
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values, router_logits = self.attn(
+            hidden_states=hidden_states,
+
attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs
+        )
+        if self.config.fuse_norm:
+            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        else:
+            hidden_states = residual + hidden_states
+            residual = hidden_states
+            hidden_states = self.mlp_norm(hidden_states)
+        hidden_states = self.mlp(hidden_states, **kwargs)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states, attentions, past_key_values, router_logits)
+
+        return outputs
+
+
+class emdeltanetPreTrainedModel(PreTrainedModel):
+
+    config_class = emdeltanetConfig
+    base_model_prefix = 'model'
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['emdeltanetBlock']
+    _supports_cache_class = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+        prenorm_residual_strategy: Optional[str] = 'rescale',
+        num_residuals_per_layer: int = 2,
+    ):
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+        elif hasattr(module, 'reset_parameters'):
+            module.reset_parameters()
+
+        if prenorm_residual_strategy is not None:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            p = None
+            if hasattr(module, 'o_proj'):
+                p = module.o_proj.weight
+            elif hasattr(module, 'down_proj'):
+                p = module.down_proj.weight
+            if p is not None:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                if prenorm_residual_strategy == 'rescale':
+                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                    with torch.no_grad():
+                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+                elif prenorm_residual_strategy == 'zero':
+                    nn.init.zeros_(p)
+                else:
+                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
+
+
+class emdeltanetModel(emdeltanetPreTrainedModel):
+
+    def __init__(self, config: emdeltanetConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([emdeltanetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if output_attentions:
+            warnings.warn("`emdeltanetModel` does not support `output_attentions` for now, setting it to `False`.")
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+        hidden_states = inputs_embeds
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = Cache.from_legacy_cache(past_key_values)
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing.
Setting `use_cache=False`...")
+            use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_attns = () if output_attentions else None
+        all_router_logits = ()
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                hidden_states, attentions, past_key_values, router_logits = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    past_key_values,
+                    use_cache,
+                    output_attentions,
+                    **kwargs
+                )
+            else:
+                hidden_states, attentions, past_key_values, router_logits = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    **kwargs
+                )
+
+            if output_attentions:
+                all_attns += (attentions,)
+            all_router_logits += (router_logits,)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            # router logits are kept as the last element so callers can find them without return_dict
+            return tuple(
+                i for i in [hidden_states, past_key_values, all_hidden_states, all_attns, all_router_logits]
+                if i is not None
+            )
+        return emdeltanetOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_attns,
+            router_logits=all_router_logits
+        )
+
+
+@dataclass
+class emdeltanetOutputWithPast(BaseModelOutputWithPast):
+    router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class emdeltanetCausalLMOutputWithPast(CausalLMOutputWithPast):
+    aux_loss: Optional[torch.FloatTensor] = None
+    router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+def load_balancing_loss_func(
+    gate_logits: Union[torch.Tensor, Tuple],
+    num_memories: int = None,
+    top_k=2,
+    use_layer_wise_balance=False,
+) -> torch.FloatTensor:
+    r"""
+    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+    experts is too unbalanced.
+
+    Args:
+        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):
+            Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_memories].
+        num_memories (`int`, *optional*):
+            Number of experts.
+        top_k (`int`, *optional*, defaults to 2):
+            Number of experts each token is routed to.
+        use_layer_wise_balance (`bool`, *optional*, defaults to `False`):
+            Whether to compute the loss per layer and average, instead of over concatenated logits.
+
+    Returns:
+        The auxiliary loss.
+    """
+    if gate_logits is None or (
+        isinstance(gate_logits, Iterable) and len(gate_logits) == 0
+    ):
+        return 0
+
+    # ✨ Here is the fix for balance loss in Mixtral.
+    # We should calculate the balance loss in a layer-wise manner otherwise it may lead to degenerated solutions.
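+    # Shape walk-through under assumed sizes (illustrative only): with 24 layers, batch 2, seq 128
+    # and num_memories=4, `gate_logits` is a 24-tuple of [2, 128, 4] tensors; each layer picks its
+    # top_k experts, rebuilds dense routing weights of shape [2, 128, 4], and compares the fraction
+    # of tokens routed to each expert against that expert's mean router probability:
+    #
+    #   gate_logits = tuple(torch.randn(2, 128, 4) for _ in range(24))
+    #   aux = load_balancing_loss_func(gate_logits, num_memories=4, top_k=1, use_layer_wise_balance=True)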
+ if use_layer_wise_balance: + if not isinstance(gate_logits, Iterable): + gate_logits = (gate_logits,) + else: + if isinstance(gate_logits, Iterable): + gate_logits = (torch.cat(gate_logits, dim=0),) + else: + gate_logits = (gate_logits,) + + all_balance_losses = [] + + for logits in gate_logits: + if logits.dim() == 4: + logits = rearrange(logits,'b h l r-> (b h) l r') + routing_weights, selected_experts = torch.topk(logits, top_k, dim=-1) + routing_weights = routing_weights.softmax(dim=-1).to(logits.dtype) + routing_weights_full = torch.zeros_like(logits).scatter(-1, selected_experts, routing_weights) + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if selected_experts.dtype != torch.int64: + selected_experts = selected_experts.to(torch.int64) + + if len(selected_experts.shape) == 2: + selected_experts = selected_experts.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(routing_weights_full, axis=-2) + + # ✨ balance loss for this layer + balance_loss = torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert + ) * (num_memories**2) + all_balance_losses.append(balance_loss.reshape(1)) + + all_balance_losses = torch.cat(all_balance_losses).mean() # ✨ + + return all_balance_losses + + +class emdeltanetForCausalLM(emdeltanetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emdeltanetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. 
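+        # With a non-empty cache, earlier positions are already summarized in `past_key_values`
+        # (recurrent state for the linear-attention layers, key/value tensors for softmax layers),
+        # so only the newest token has to be fed back in at each decoding step.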
+        if past_key_values is not None and len(past_key_values) > 0:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and len(past_key_values) == 0:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if logits_to_keep is not None:
+            model_inputs['logits_to_keep'] = logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
+        if labels is not None:
+            if getattr(self, 'criterion', None) is None:
+                if fuse_linear_and_cross_entropy:
+                    criterion = FusedLinearCrossEntropyLoss()
+                elif self.config.fuse_cross_entropy:
+                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
+                else:
+                    criterion = nn.CrossEntropyLoss()
+            else:
+                criterion = self.criterion
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+        valid_router_logits = tuple(
+            router_logits
+            for router_logits in (outputs.router_logits if return_dict else outputs[-1])
+            if router_logits is not None
+        )
+        aux_loss = load_balancing_loss_func(
+            valid_router_logits,
+            self.config.ratio,
+            self.config.topk,
+            use_layer_wise_balance=True,
+        )
+        aux_loss *= self.config.aux_loss_scale
+        # `loss` is None when no labels are given; only fold the auxiliary loss into a real LM loss
+        if loss is not None:
+            loss = loss + aux_loss
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return emdeltanetCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+            aux_loss=aux_loss
+        )
\ No newline at end of file
diff --git a/fla3/models/emla/__init__.py b/fla3/models/emla/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae7ab729f889d66b8ca8b3c7f3a62ec8fa632c01
--- /dev/null
+++ b/fla3/models/emla/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from .configuration_emla import emlaConfig
+from .modeling_emla import emlaForCausalLM, emlaModel
+
+AutoConfig.register(emlaConfig.model_type, emlaConfig)
+AutoModel.register(emlaConfig, emlaModel)
+AutoModelForCausalLM.register(emlaConfig, emlaForCausalLM)
+
+__all__ = ['emlaConfig', 'emlaForCausalLM', 'emlaModel']
diff --git a/fla3/models/emla/configuration_emla.py b/fla3/models/emla/configuration_emla.py
new file mode 100644
index 0000000000000000000000000000000000000000..b98889229c20c89aa38243d994428285aa631706
--- /dev/null
+++ b/fla3/models/emla/configuration_emla.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class emlaConfig(PretrainedConfig):
+
+    model_type = 'emla'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        expand_k: int = 1,
+        expand_v: int = 1,
+        use_gate: bool = False,
+        use_short_conv: bool = True,
+        conv_size: int = 4,
+        use_beta: bool = True,
+        use_output_norm: bool = True,
+        num_heads: int = 16,
+        qk_norm: str = 'l2',
+        qk_activation: str = 'silu',
+        max_position_embeddings: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        num_hidden_layers: int = 24,
+        norm_eps: float = 1e-6,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: int = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.02,
+        fuse_norm: bool = True,
+        fuse_swiglu: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        ratio: int = 2,
+        **kwargs
+    ):
+        self.attn_mode = attn_mode
+        self.hidden_size = hidden_size
+        self.expand_k = expand_k
+        self.expand_v = expand_v
+        self.use_gate = use_gate
+        self.use_short_conv = use_short_conv
+        self.conv_size = conv_size
+        self.use_beta = use_beta
+        self.use_output_norm = use_output_norm
+        self.num_heads = num_heads
+        self.qk_norm = qk_norm
+        self.qk_activation = qk_activation
+        self.max_position_embeddings = max_position_embeddings
+
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.num_hidden_layers = num_hidden_layers
+        self.norm_eps = norm_eps
+        self.attn = attn
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_norm = fuse_norm
+        self.fuse_swiglu = fuse_swiglu
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+        self.ratio = ratio
+
+        if attn is not None:
+            if not isinstance(attn,
Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/emla/modeling_emla.py b/fla3/models/emla/modeling_emla.py new file mode 100644 index 0000000000000000000000000000000000000000..fc3813f59e8246e6dc10cc9dda7556bd4e724213 --- /dev/null +++ b/fla3/models/emla/modeling_emla.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.emla import emla +from fla.models.emla.configuration_emla import emlaConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as emlaMLP +from fla.modules import RMSNorm + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class emlaBlock(nn.Module): + def __init__(self, config: emlaConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = emla( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_beta=config.use_beta, + use_short_conv=config.use_short_conv, + use_output_norm=config.use_output_norm, + conv_size=config.conv_size, + qk_norm=config.qk_norm, + qk_activation=config.qk_activation, + norm_eps=config.norm_eps, + ratio = config.ratio, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = emlaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, 
List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class emlaPreTrainedModel(PreTrainedModel): + + config_class = emlaConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['emlaBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class emlaModel(emlaPreTrainedModel): + + def __init__(self, config: emlaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([emlaBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`emlaModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class emlaForCausalLM(emlaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = emlaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/forgetting_transformer/__init__.py b/fla3/models/forgetting_transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de687aba3fdff3661ca1800a932211fb850e1407 --- /dev/null +++ b/fla3/models/forgetting_transformer/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig +from 
fla.models.forgetting_transformer.modeling_forgetting_transformer import (
+    ForgettingTransformerForCausalLM,
+    ForgettingTransformerModel
+)
+
+AutoConfig.register(ForgettingTransformerConfig.model_type, ForgettingTransformerConfig)
+AutoModel.register(ForgettingTransformerConfig, ForgettingTransformerModel)
+AutoModelForCausalLM.register(ForgettingTransformerConfig, ForgettingTransformerForCausalLM)
+
+
+__all__ = ['ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel']
diff --git a/fla3/models/forgetting_transformer/configuration_forgetting_transformer.py b/fla3/models/forgetting_transformer/configuration_forgetting_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c06fd47c829eccf2828a4b62ff7b14cea1b857
--- /dev/null
+++ b/fla3/models/forgetting_transformer/configuration_forgetting_transformer.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class ForgettingTransformerConfig(PretrainedConfig):
+
+    model_type = 'forgetting_transformer'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_hidden_layers: int = 24,
+        num_heads: int = 32,
+        num_kv_heads: Optional[int] = None,
+        qkv_bias: bool = False,
+        qk_norm: bool = False,
+        window_size: Optional[int] = None,
+        use_output_gate: bool = False,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        initializer_range: float = 0.02,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        use_cache: bool = True,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        fuse_norm: bool = True,
+        fuse_swiglu: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs,
+    ):
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.qkv_bias = qkv_bias
+        self.qk_norm = qk_norm
+        self.window_size = window_size
+        self.use_output_gate = use_output_gate
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size
= intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/forgetting_transformer/modeling_forgetting_transformer.py b/fla3/models/forgetting_transformer/modeling_forgetting_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd2ca0ae7070b56b199c1c6f8b81d91778ee961 --- /dev/null +++ b/fla3/models/forgetting_transformer/modeling_forgetting_transformer.py @@ -0,0 +1,408 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.forgetting_attn import ForgettingAttention +from fla.models.forgetting_transformer.configuration_forgetting_transformer import ForgettingTransformerConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as ForgettingTransformerMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class ForgettingTransformerBlock(nn.Module): + + def __init__(self, config: ForgettingTransformerConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.attn = ForgettingAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + qkv_bias=config.qkv_bias, + qk_norm=config.qk_norm, + window_size=config.window_size, + use_output_gate=config.use_output_gate, + layer_idx=layer_idx + ) + + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = ForgettingTransformerMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs: Unpack[Any] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + 
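+            # Fused path: the residual add and RMSNorm run in one kernel and
+            # the call returns (normed_states, new_residual). Conceptually:
+            #     residual = hidden_states + residual
+            #     hidden_states = rmsnorm(residual)
+            # which matches the unfused else-branch below step for step.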
hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class ForgettingTransformerPreTrainedModel(PreTrainedModel): + + config_class = ForgettingTransformerConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['ForgettingTransformerBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per ForgettingTransformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class ForgettingTransformerModel(ForgettingTransformerPreTrainedModel): + + def __init__( + self, + config: ForgettingTransformerConfig + ) -> ForgettingTransformerModel: + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ + ForgettingTransformerBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + if output_attentions: + warnings.warn( + "`ForgettingTransformerModel` does not support output attention weights now, " + "so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + output_attentions, + use_cache, + **kwargs + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class ForgettingTransformerForCausalLM(ForgettingTransformerPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ForgettingTransformerModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
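+            # During cached decoding `input_ids` is the one-token slice
+            # `input_ids[:, -1:]`, whose stride still reflects the full prompt
+            # tensor. An illustrative check:
+            #     t = torch.arange(12).view(3, 4)[:, -1:]  # stride (4, 1)
+            #     t.contiguous().stride()                  # (1, 1) every step
+            # A constant stride keeps torchdynamo's guards passing, so no
+            # recompile is triggered.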
+ model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) + + loss = None + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/gated_deltanet/__init__.py b/fla3/models/gated_deltanet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b48bf4fe543da44d6bb357cc6145a89074fae98 --- /dev/null +++ b/fla3/models/gated_deltanet/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from .configuration_gated_deltanet import GatedDeltaNetConfig +from .modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel + +AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig) +AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel) +AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM) + +__all__ = 
['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
diff --git a/fla3/models/gated_deltanet/configuration_gated_deltanet.py b/fla3/models/gated_deltanet/configuration_gated_deltanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea3279dc4b706a4c1377cc81c7ee94819deca767
--- /dev/null
+++ b/fla3/models/gated_deltanet/configuration_gated_deltanet.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class GatedDeltaNetConfig(PretrainedConfig):
+    model_type = 'gated_deltanet'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        expand_v: int = 2,
+        use_gate: bool = True,
+        use_short_conv: bool = True,
+        conv_size: int = 4,
+        head_dim: int = 256,
+        num_heads: int = 6,
+        max_position_embeddings: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = "swish",
+        num_hidden_layers: int = 21,
+        norm_eps: float = 1e-6,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: int = 1,
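+        # Token id defaults follow the common Llama-style convention
+        # (bos=1, eos=2); pad_token_id stays unset unless one is supplied.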
eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.head_dim = head_dim + self.num_heads = num_heads + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/gated_deltanet/modeling_gated_deltanet.py b/fla3/models/gated_deltanet/modeling_gated_deltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..9eec6501704b9cad1006ea3a6ffd206ca47d1e8c --- /dev/null +++ b/fla3/models/gated_deltanet/modeling_gated_deltanet.py @@ -0,0 +1,445 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from ...layers.attn import Attention +from ...layers.gated_deltanet import GatedDeltaNet +from ...models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as GatedDeltaNetMLP +from ...modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class GatedDeltaNetBlock(nn.Module): + def __init__(self, config: GatedDeltaNetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + 
qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedDeltaNet( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = GatedDeltaNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GatedDeltaNetPreTrainedModel(PreTrainedModel): + + config_class = GatedDeltaNetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['GatedDeltaNetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, GatedDeltaNet): + + # --- A_log --- + A = torch.empty(module.num_heads, dtype=torch.float32).uniform_(0, 16) + with torch.no_grad(): + if not isinstance(module.A_log, torch.distributed.tensor.DTensor): + module.A_log.copy_(torch.log(A)) + else: + logger.warning_once("`A_log` is a DTensor, skipping initialization") + module.A_log._no_weight_decay = True + + # --- dt_bias --- + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(module.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + if not isinstance(module.dt_bias, torch.distributed.tensor.DTensor): + module.dt_bias.copy_(inv_dt) + else: + logger.warning_once("`dt_bias` is a DTensor, skipping initialization") + # Just to be explicit. 
Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + module.dt_bias._no_weight_decay = True + module.dt_bias._no_reinit = True + + elif isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel): + + def __init__(self, config: GatedDeltaNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if 
output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GatedDeltaNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. 
" + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], 
torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
+            else:
+                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/fla3/models/gated_deltaproduct/__init__.py b/fla3/models/gated_deltaproduct/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea7067f53a23309a9817b6de3bc3eb22480cf753
--- /dev/null
+++ b/fla3/models/gated_deltaproduct/__init__.py
@@ -0,0 +1,14 @@
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
+from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductForCausalLM, GatedDeltaProductModel
+
+AutoConfig.register(GatedDeltaProductConfig.model_type, GatedDeltaProductConfig)
+AutoModel.register(GatedDeltaProductConfig, GatedDeltaProductModel)
+AutoModelForCausalLM.register(GatedDeltaProductConfig, GatedDeltaProductForCausalLM)
+
+__all__ = [
+    "GatedDeltaProductConfig",
+    "GatedDeltaProductForCausalLM",
+    "GatedDeltaProductModel",
+]
diff --git a/fla3/models/gated_deltaproduct/configuration_gated_deltaproduct.py b/fla3/models/gated_deltaproduct/configuration_gated_deltaproduct.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2482cb2b1337787ed1dfee011b8cc338043485
--- /dev/null
+++ b/fla3/models/gated_deltaproduct/configuration_gated_deltaproduct.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class GatedDeltaProductConfig(PretrainedConfig):
+    model_type = 'gated_deltaproduct'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        expand_v: int = 2,
+        use_gate: bool = True,
+        use_short_conv: bool = True,
+        conv_size: int = 4,
+        head_dim: int = 256,
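+        # Sizing note: queries/keys use head_dim per head while values are
+        # widened by expand_v (e.g. head_dim=256, expand_v=2 gives 512-dim
+        # value heads), following the GatedDeltaNet convention above.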
num_heads: int = 6, + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + num_hidden_layers: int = 21, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + use_forget_gate: bool = False, + allow_neg_eigval: bool = False, + num_householder: int = 1, + **kwargs, + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.head_dim = head_dim + self.num_heads = num_heads + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + # DeltaProduct specific + self.allow_neg_eigval = allow_neg_eigval + self.num_householder = num_householder + self.use_forget_gate = use_forget_gate + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) 
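+            # Illustrative hybrid setup (not a shipped default):
+            #     GatedDeltaProductConfig(attn={'layers': [0, 10], 'num_heads': 8})
+            # runs softmax attention on layers 0 and 10 and DeltaProduct
+            # elsewhere; the .get() calls above then fill num_kv_heads=8,
+            # qkv_bias=False, window_size=None and rope_theta=10000.0.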
+ + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/gated_deltaproduct/modeling_gated_deltaproduct.py b/fla3/models/gated_deltaproduct/modeling_gated_deltaproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..03c514f622b7b72c8d3f33308d86a7a4a8e3320c --- /dev/null +++ b/fla3/models/gated_deltaproduct/modeling_gated_deltaproduct.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.gated_deltaproduct import GatedDeltaProduct +from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as GatedDeltaProductMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class GatedDeltaProductBlock(nn.Module): + def __init__(self, config: GatedDeltaProductConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedDeltaProduct( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_v=config.expand_v, + head_dim=config.head_dim, + num_heads=config.num_heads, + use_gate=config.use_gate, + use_forget_gate=config.use_forget_gate, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + norm_eps=config.norm_eps, + allow_neg_eigval=config.allow_neg_eigval, + num_householder=config.num_householder, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = GatedDeltaProductMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, 
past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GatedDeltaProductPreTrainedModel(PreTrainedModel): + + config_class = GatedDeltaProductConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['GatedDeltaProductBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
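The initialization comment quoted here follows the GPT-2/Megatron-LM scheme implemented just below. As a standalone sketch (not the library's code, and assuming a model whose residual-path projections are named `o_proj`/`down_proj`, as in these blocks), the rescaling amounts to:

```python
import math
import torch
import torch.nn as nn

def rescale_residual_projections(model: nn.Module, num_hidden_layers: int,
                                 num_residuals_per_layer: int = 2) -> None:
    # Two residual additions per block (attention + MLP), hence the factor 2.
    scale = math.sqrt(num_residuals_per_layer * num_hidden_layers)
    for module in model.modules():
        for name in ('o_proj', 'down_proj'):  # residual-path output projections
            proj = getattr(module, name, None)
            if isinstance(proj, nn.Linear):
                with torch.no_grad():
                    proj.weight.div_(scale)
```

Dividing by sqrt(2N) keeps the variance of the residual stream roughly depth-independent, which is the motivation quoted from the GPT-2 paper.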
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel): + + def __init__(self, config: GatedDeltaProductConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ + GatedDeltaProductBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GatedDeltaProductModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GatedDeltaProductModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. 
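The comment above states the incremental-decoding contract implemented in the next statement: once the cache is non-empty, only the most recent token id is fed forward. A toy illustration with hypothetical values:

```python
# Toy illustration of the decode-time slicing implemented below: once the
# cache already holds earlier steps, only the newest token id is needed.
import torch

input_ids = torch.tensor([[5, 9, 2, 7]])   # (batch=1, seq=4)
cache_len = 3                              # pretend 3 steps are already cached
if cache_len > 0:
    input_ids = input_ids[:, -1:]          # -> tensor([[7]]), shape (1, 1)
```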
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/gla/__init__.py b/fla3/models/gla/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..edccb515af8f04144308bfcbb72be8e91e714cd7 --- /dev/null +++ b/fla3/models/gla/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gla.configuration_gla import GLAConfig +from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel + +AutoConfig.register(GLAConfig.model_type, GLAConfig) +AutoModel.register(GLAConfig, GLAModel) +AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) + + +__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] diff --git a/fla3/models/gla/__pycache__/__init__.cpython-310.pyc b/fla3/models/gla/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e79ed1941cfad3de6d710c6b8fc5c8adb43cf41 Binary files /dev/null and b/fla3/models/gla/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/gla/__pycache__/configuration_gla.cpython-310.pyc b/fla3/models/gla/__pycache__/configuration_gla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7499679ca016cac2f58b4c682fe15c6e36ffd784 Binary files /dev/null and b/fla3/models/gla/__pycache__/configuration_gla.cpython-310.pyc differ diff --git a/fla3/models/gla/__pycache__/modeling_gla.cpython-310.pyc b/fla3/models/gla/__pycache__/modeling_gla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34de85a879a7352a31491ab52dbbcf44f36a912a Binary files /dev/null and b/fla3/models/gla/__pycache__/modeling_gla.cpython-310.pyc differ diff --git a/fla3/models/gla/configuration_gla.py b/fla3/models/gla/configuration_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..5ddfe436a090c8c6cd61c3efdc42a73013e71360 --- /dev/null +++ b/fla3/models/gla/configuration_gla.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GLAConfig(PretrainedConfig): + + model_type = 'gla' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + expand_k: float = 0.5, + expand_v: float = 1, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + attn_mode: str = "chunk", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + clamp_min: Optional[float] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_gk: bool = True, + use_gv: bool = False, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.attn_mode = attn_mode + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.clamp_min = clamp_min +
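For intuition about `expand_k`/`expand_v` above: under the convention used by GLA-style layers (an assumption here, since the projection code itself is not part of this diff), the key/value projection widths are the hidden size scaled by these factors, so the defaults give half-width keys and full-width values:

```python
# Hedged sketch of the dimensioning implied by the defaults above, assuming
# key_dim = hidden_size * expand_k and value_dim = hidden_size * expand_v.
hidden_size, num_heads = 2048, 4
expand_k, expand_v = 0.5, 1.0

key_dim = int(hidden_size * expand_k)      # 1024
value_dim = int(hidden_size * expand_v)    # 2048
head_k_dim = key_dim // num_heads          # 256 per head
head_v_dim = value_dim // num_heads        # 512 per head
print(key_dim, value_dim, head_k_dim, head_v_dim)
```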
self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_gk = use_gk + self.use_gv = use_gv + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/gla/modeling_gla.py b/fla3/models/gla/modeling_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..662784df3cf66381b203b0dc0ef005483a04d694 --- /dev/null +++ b/fla3/models/gla/modeling_gla.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.gla import GatedLinearAttention +from fla.models.gla.configuration_gla import GLAConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as GLAMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class GLABlock(nn.Module): + def __init__(self, config: GLAConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedLinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + 
elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + clamp_min=config.clamp_min, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = GLAMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GLAPreTrainedModel(PreTrainedModel): + + config_class = GLAConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['GLABlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
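Returning to the `GLABlock` constructor above: layers whose indices appear in `config.attn['layers']` get standard softmax attention, while all other layers get `GatedLinearAttention`. A sketch of the resulting hybrid layout (settings are hypothetical):

```python
# Sketch of the hybrid layer layout implied by the constructor branch above.
num_hidden_layers = 24
attn = {'layers': [0, 12], 'num_heads': 32}   # hypothetical hybrid config

layout = [
    'Attention' if attn is not None and i in attn['layers'] else 'GatedLinearAttention'
    for i in range(num_hidden_layers)
]
print(layout[:3])  # ['Attention', 'GatedLinearAttention', 'GatedLinearAttention']
```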
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class GLAModel(GLAPreTrainedModel): + + def __init__(self, config: GLAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GLAModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GLAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/gsa/__init__.py b/fla3/models/gsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a134f758e0bea0eb844a2db73957936078f889b6 --- /dev/null +++ b/fla3/models/gsa/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.gsa.configuration_gsa import GSAConfig +from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel + +AutoConfig.register(GSAConfig.model_type, GSAConfig) 
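These `register` calls are what make the custom `model_type` strings resolvable through the generic Auto classes. A minimal usage sketch, assuming the package has been imported so the registrations above have run:

```python
# Hedged usage sketch: once the registrations run (e.g. on package import),
# the generic Auto classes can build these architectures by model_type.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model('gla', hidden_size=1024, num_hidden_layers=12)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # GLAForCausalLM
```

The same pattern applies to every model family registered in this diff.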
+AutoModel.register(GSAConfig, GSAModel) +AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM) + + +__all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel'] diff --git a/fla3/models/gsa/__pycache__/__init__.cpython-310.pyc b/fla3/models/gsa/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..564fc432f6adf07fdd5b962e8f7766228c6db421 Binary files /dev/null and b/fla3/models/gsa/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/gsa/__pycache__/configuration_gsa.cpython-310.pyc b/fla3/models/gsa/__pycache__/configuration_gsa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c59180df32319fe1ea25994811531b3fab8e6ac5 Binary files /dev/null and b/fla3/models/gsa/__pycache__/configuration_gsa.cpython-310.pyc differ diff --git a/fla3/models/gsa/__pycache__/modeling_gsa.cpython-310.pyc b/fla3/models/gsa/__pycache__/modeling_gsa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db428bdba2a7219c95ec528ade131b4972354683 Binary files /dev/null and b/fla3/models/gsa/__pycache__/modeling_gsa.cpython-310.pyc differ diff --git a/fla3/models/gsa/configuration_gsa.py b/fla3/models/gsa/configuration_gsa.py new file mode 100644 index 0000000000000000000000000000000000000000..e0379d5c8c6a341b4fe3154d5238fc537ab85776 --- /dev/null +++ b/fla3/models/gsa/configuration_gsa.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class GSAConfig(PretrainedConfig): + + model_type = 'gsa' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + gate_logit_normalizer: Optional[int] = 8, + clamp_min: Optional[float] = None, + clamp_max: Optional[float] = None, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + num_slots: Optional[int] = 64, + use_short_conv: bool = False, + conv_size: int = 4, + expand_k: float = 1, + expand_v: float = 1, + feature_map: str = 'swish', + use_output_gate: bool = False, + use_norm: bool = True, + max_position_embeddings: int = 2048, + hidden_act: str = "swish", + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + initializer_range: float = 0.02, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.hidden_size = hidden_size + self.gate_logit_normalizer = gate_logit_normalizer + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.num_slots = num_slots + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.expand_k = expand_k + self.expand_v = expand_v + self.feature_map = feature_map + self.use_output_gate = use_output_gate + self.use_norm = use_norm + self.max_position_embeddings = max_position_embeddings + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + +
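As with the other configs in this diff, the `attn` validation below both enforces required keys and fills in defaults. A hedged construction example (layer indices and sizes are illustrative only):

```python
# Hedged example: a GSA config with softmax attention at two layers.
from fla.models.gsa import GSAConfig  # exported by the __init__.py above

config = GSAConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    attn={'layers': [0, 6], 'num_heads': 16},  # hypothetical hybrid layout
)
# The validation in __init__ fills the optional keys with defaults:
assert config.attn['num_kv_heads'] == 16
assert config.attn['qkv_bias'] is False
assert config.attn['rope_theta'] == 10000.
```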
self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/gsa/modeling_gsa.py b/fla3/models/gsa/modeling_gsa.py new file mode 100644 index 0000000000000000000000000000000000000000..700c656b34458611aa1859dd12efad4807b6956c --- /dev/null +++ b/fla3/models/gsa/modeling_gsa.py @@ -0,0 +1,420 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.gsa import GatedSlotAttention +from fla.models.gsa.configuration_gsa import GSAConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as GSAMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class GSABlock(nn.Module): + def __init__(self, config: GSAConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = GatedSlotAttention( + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + num_slots=config.num_slots, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + use_norm=config.use_norm, + gate_fn=config.hidden_act, + gate_logit_normalizer=config.gate_logit_normalizer, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = GSAMLP( + 
hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class GSAPreTrainedModel(PreTrainedModel): + + config_class = GSAConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['GSABlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
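One note on `GSABlock.forward` above (the same pattern appears in every block in this diff): when `fuse_norm` is enabled, `self.mlp_norm(hidden_states, residual, True)` is assumed to fuse the residual add into the normalization and return the pre-norm sum as the new residual stream. A plain-PyTorch reference for that contract, as a sketch:

```python
# Hedged reference for the fused add-and-norm contract used in the blocks:
# add the residual, RMS-normalize the sum, and return the raw sum as the
# new residual so the next sub-layer can keep accumulating onto it.
import torch

def add_rmsnorm(x: torch.Tensor, residual: torch.Tensor,
                weight: torch.Tensor, eps: float = 1e-6):
    residual = x + residual                      # new residual stream
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return normed.type_as(x) * weight, residual
```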
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class GSAModel(GSAPreTrainedModel): + + def __init__(self, config: GSAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`GSAModel` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + + super().__init__(config) + self.model = GSAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/hgrn/__init__.py b/fla3/models/hgrn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b29a3dd82da6d64bac6cc887e24295a03de5b23 --- /dev/null +++ b/fla3/models/hgrn/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel + 
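Stepping back to the causal-LM loss above: instead of slicing the logits, the labels are shifted left and padded with `ignore_index`, which keeps logits and labels the same length while still training next-token prediction. A toy check:

```python
# Toy check of the label shift in the forward pass above: after the shift,
# labels[t] is the token at position t+1, and the final position is masked.
import torch

ignore_index = -100
labels = torch.tensor([[10, 11, 12, 13]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[  11,   12,   13, -100]])
```

Note also that the default `logits_to_keep=0` keeps every position, since `hidden_states[:, -0:]` is the whole tensor (`-0 == 0` in Python); `None` is handled by the explicit branch.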
+AutoConfig.register(HGRNConfig.model_type, HGRNConfig) +AutoModel.register(HGRNConfig, HGRNModel) +AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) + + +__all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] diff --git a/fla3/models/hgrn/__pycache__/__init__.cpython-310.pyc b/fla3/models/hgrn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..218472241c4f33cb68769d19d6f7c05bb26da4cc Binary files /dev/null and b/fla3/models/hgrn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/hgrn/__pycache__/configuration_hgrn.cpython-310.pyc b/fla3/models/hgrn/__pycache__/configuration_hgrn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e676b27da6f5b25a4c860c3a8a43b7c75ea7a571 Binary files /dev/null and b/fla3/models/hgrn/__pycache__/configuration_hgrn.cpython-310.pyc differ diff --git a/fla3/models/hgrn/__pycache__/modeling_hgrn.cpython-310.pyc b/fla3/models/hgrn/__pycache__/modeling_hgrn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6a5438cbf809e58a9cc5c692adddef28f0accce Binary files /dev/null and b/fla3/models/hgrn/__pycache__/modeling_hgrn.cpython-310.pyc differ diff --git a/fla3/models/hgrn/configuration_hgrn.py b/fla3/models/hgrn/configuration_hgrn.py new file mode 100644 index 0000000000000000000000000000000000000000..c3350206295d1d8581c1af89c913da4e605cb21f --- /dev/null +++ b/fla3/models/hgrn/configuration_hgrn.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRNConfig(PretrainedConfig): + + model_type = 'hgrn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "fused_recurrent", + hidden_size: int = 2048, + num_hidden_layers: int = 24, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.max_position_embeddings = max_position_embeddings + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.elementwise_affine = elementwise_affine + self.attn = attn + self.norm_eps = norm_eps + self.hidden_act = hidden_act + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise 
ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/hgrn/modeling_hgrn.py b/fla3/models/hgrn/modeling_hgrn.py new file mode 100644 index 0000000000000000000000000000000000000000..9037bdab35e89fefe9f2c8744bd5f0490b612977 --- /dev/null +++ b/fla3/models/hgrn/modeling_hgrn.py @@ -0,0 +1,420 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.hgrn import HGRNAttention +from fla.models.hgrn.configuration_hgrn import HGRNConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as HGRNMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class HGRNBlock(nn.Module): + def __init__(self, config: HGRNConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = HGRNAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = HGRNMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, 
past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class HGRNPreTrainedModel(PreTrainedModel): + + config_class = HGRNConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['HGRNBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class HGRNModel(HGRNPreTrainedModel): + + def __init__(self, config: HGRNConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_lower_bound: + self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size)) + self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`HGRNModel` does not support `output_attentions` for now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + if self.config.use_lower_bound: + lower_bounds = self.lower_bounds.softmax(0) + lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0] + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + lower_bound = lower_bounds[i] if self.config.use_lower_bound else None + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + lower_bound, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class HGRNForCausalLM(HGRNPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HGRNModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs: Unpack[Dict] + ): + # only last token for `input_ids` if the `past_key_values` is not empty. 
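+ # A minimal decoding sketch (`model` and `tok` are assumed names, not defined in this diff):
+ #   out = model.generate(**tok("hi", return_tensors="pt"), max_new_tokens=8)
+ # After the first step the recurrent `Cache` already summarizes the prefix, so the branch
+ # below feeds only the newest token, i.e. `input_ids` of shape [batch_size, 1]. Note the
+ # `inputs_embeds` branch assumes `past_key_values` is an initialized `Cache`, since
+ # `len(past_key_values)` is taken on it.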
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/hgrn2/__init__.py b/fla3/models/hgrn2/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..306b8082220a57091f2e99cd689c011690db0439 --- /dev/null +++ b/fla3/models/hgrn2/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model + +AutoConfig.register(HGRN2Config.model_type, HGRN2Config) +AutoModel.register(HGRN2Config, HGRN2Model) +AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) + + +__all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] diff --git a/fla3/models/hgrn2/configuration_hgrn2.py b/fla3/models/hgrn2/configuration_hgrn2.py new file mode 100644 index 0000000000000000000000000000000000000000..75855eb03e10e31c4c60730b041ad51b2b0169b1 --- /dev/null +++ b/fla3/models/hgrn2/configuration_hgrn2.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class HGRN2Config(PretrainedConfig): + + model_type = 'hgrn2' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + attn_mode: str = "chunk", + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + use_lower_bound: bool = True, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attn_mode = attn_mode + + if expand_ratio is None and num_heads is not None: + expand_ratio = hidden_size // num_heads + elif expand_ratio is not None and num_heads is None: + num_heads = hidden_size // expand_ratio + elif expand_ratio is None and num_heads is None: + raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") + self.num_heads = num_heads + self.expand_ratio = expand_ratio + + self.use_short_conv = use_short_conv + 
self.conv_size = conv_size + self.use_lower_bound = use_lower_bound + self.max_position_embeddings = max_position_embeddings + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/hgrn2/modeling_hgrn2.py b/fla3/models/hgrn2/modeling_hgrn2.py new file mode 100644 index 0000000000000000000000000000000000000000..6b48879147db1b2d7a953fb9855018fbdf040bd9 --- /dev/null +++ b/fla3/models/hgrn2/modeling_hgrn2.py @@ -0,0 +1,421 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.hgrn2 import HGRN2Attention +from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as HGRN2MLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class HGRN2Block(nn.Module): + def __init__(self, config: HGRN2Config, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = HGRN2Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + 
layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = HGRN2MLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class HGRN2PreTrainedModel(PreTrainedModel): + + config_class = HGRN2Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['HGRN2Block'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class HGRN2Model(HGRN2PreTrainedModel): + + def __init__(self, config: HGRN2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_lower_bound: + self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size)) + self.layers = nn.ModuleList([HGRN2Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`HGRN2Model` does not support `output_attentions` for now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + if self.config.use_lower_bound: + lower_bounds = self.lower_bounds.softmax(0) + lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0] + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + lower_bound = lower_bounds[i] if self.config.use_lower_bound else None + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + lower_bound, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + lower_bound=lower_bound, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class HGRN2ForCausalLM(HGRN2PreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = HGRN2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs: Unpack[Dict] + ): + # only last token for `input_ids` if the `past_key_values` is not empty. 
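+ # Illustrative note on `logits_to_keep` (the value used here is an assumption for the
+ # example, not an upstream default): if a caller passes `logits_to_keep=1`, the `forward`
+ # below projects only the final position,
+ #   logits = self.lm_head(hidden_states[:, -1:])
+ # instead of materializing the full [batch_size, seq_len, vocab_size] tensor; with the
+ # default `None` here, the key is simply omitted from `model_inputs`.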
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/lightnet/__init__.py b/fla3/models/lightnet/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..843285d4e95aa7cd8a34f7f48b73cee306a5ec20 --- /dev/null +++ b/fla3/models/lightnet/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.lightnet.configuration_lightnet import LightNetConfig +from fla.models.lightnet.modeling_lightnet import LightNetForCausalLM, LightNetModel + +AutoConfig.register(LightNetConfig.model_type, LightNetConfig) +AutoModel.register(LightNetConfig, LightNetModel) +AutoModelForCausalLM.register(LightNetConfig, LightNetForCausalLM) + + +__all__ = ['LightNetConfig', 'LightNetForCausalLM', 'LightNetModel'] diff --git a/fla3/models/lightnet/configuration_lightnet.py b/fla3/models/lightnet/configuration_lightnet.py new file mode 100644 index 0000000000000000000000000000000000000000..620499892d109f72ebd46fdb3b898d3f7542c1ae --- /dev/null +++ b/fla3/models/lightnet/configuration_lightnet.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class LightNetConfig(PretrainedConfig): + + model_type = 'lightnet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + attn_mode: str = "chunk", + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + gate_low_rank_dim: int = 128, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attn_mode = attn_mode + self.num_heads = num_heads + self.expand_ratio = expand_ratio + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.max_position_embeddings = max_position_embeddings + self.gate_low_rank_dim = gate_low_rank_dim + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + 
self.hidden_act = hidden_act + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/lightnet/modeling_lightnet.py b/fla3/models/lightnet/modeling_lightnet.py new file mode 100644 index 0000000000000000000000000000000000000000..25ed0ff789f8821034ffc16bc6878749b867b235 --- /dev/null +++ b/fla3/models/lightnet/modeling_lightnet.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.lightnet import LightNetAttention +from fla.models.lightnet.configuration_lightnet import LightNetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as LightNetMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class LightNetBlock(nn.Module): + def __init__(self, config: LightNetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = LightNetAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + num_heads=config.num_heads, + expand_ratio=config.expand_ratio, + use_short_conv=config.use_short_conv, + conv_size=config.conv_size, + gate_low_rank_dim=config.gate_low_rank_dim, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = LightNetMLP( + 
hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class LightNetPreTrainedModel(PreTrainedModel): + + config_class = LightNetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['LightNetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class LightNetModel(LightNetPreTrainedModel): + + def __init__(self, config: LightNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`LightNetModel` does not support `output_attentions` for now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LightNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs: Unpack[Dict] + ): + # only last token for `input_ids` if the `past_key_values` is not empty. 
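+ # Batched decoding sketch (the left-padding convention is an assumption of this note,
+ # not enforced by the code): `attention_mask` is forwarded unchanged into `model_inputs`
+ # below, so for prompts of lengths 2 and 3 padded to 3, a mask like
+ #   [[0, 1, 1], [1, 1, 1]]
+ # lets the token-mixing layers skip the pad position while the recurrent state is built.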
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/linear_attn/__init__.py b/fla3/models/linear_attn/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..c4446d4725923bdcf649dd38e400e7a44ee2cae0 --- /dev/null +++ b/fla3/models/linear_attn/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig +from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel + +AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) +AutoModel.register(LinearAttentionConfig, LinearAttentionModel) +AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) + +__all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] diff --git a/fla3/models/linear_attn/configuration_linear_attn.py b/fla3/models/linear_attn/configuration_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..dbeed3aeb4ea3f962c17a9e6537c0052cc48a4e3 --- /dev/null +++ b/fla3/models/linear_attn/configuration_linear_attn.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class LinearAttentionConfig(PretrainedConfig): + + model_type = 'linear_attn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "fused_chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 1, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: str = "elementwise_product", + tie_feature_map_qk: bool = False, + norm_q: bool = False, + norm_k: bool = False, + norm_feature_map: bool = False, + hidden_act: str = "swish", + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = 
hidden_ratio + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.tie_feature_map_qk = tie_feature_map_qk + self.norm_q = norm_q + self.norm_k = norm_k + self.norm_feature_map = norm_feature_map + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/linear_attn/modeling_linear_attn.py b/fla3/models/linear_attn/modeling_linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..73f8454cd285f75fcf56797954df318dedb3e5c8 --- /dev/null +++ b/fla3/models/linear_attn/modeling_linear_attn.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.linear_attn import LinearAttention +from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as LinearAttentionMLP +from fla.modules import RMSNorm + +logger = logging.get_logger(__name__) + + +class LinearAttentionBlock(nn.Module): + def __init__(self, config: LinearAttentionConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = LinearAttention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + 
expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + tie_feature_map_qk=config.tie_feature_map_qk, + norm_q=config.norm_q, + norm_k=config.norm_k, + do_feature_map_norm=config.norm_feature_map, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = LinearAttentionMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + # currently not supported + attentions, past_key_values = None, None + hidden_states = self.attn_norm(hidden_states) + hidden_states = self.attn(hidden_states=hidden_states, **kwargs) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class LinearAttentionPreTrainedModel(PreTrainedModel): + + config_class = LinearAttentionConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['LinearAttentionBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class LinearAttentionModel(LinearAttentionPreTrainedModel): + + def __init__(self, config: LinearAttentionConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LinearAttentionBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`LinearAttentionModel` does not support output attention weights now, " + "so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class LinearAttentionForCausalLM(LinearAttentionPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LinearAttentionModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0 + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/mamba/__init__.py b/fla3/models/mamba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b67cf0a75012a2f71a0f12c53584071bdc456a6b --- /dev/null +++ b/fla3/models/mamba/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.models.mamba.modeling_mamba import MambaBlock, MambaForCausalLM, MambaModel + +AutoConfig.register(MambaConfig.model_type, MambaConfig, True) 
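The loss computation in `LinearAttentionForCausalLM.forward` above (the same convention reappears in `MambaForCausalLM.forward` later in the diff) shifts the labels one step to the left and pads the tail with `ignore_index`, instead of slicing the logits. A minimal sketch (illustration only, not part of the diff; tensor sizes are arbitrary) showing why this is equivalent to the conventional shift:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # ignore_index defaults to -100
logits = torch.randn(2, 5, 11)     # (batch, seq_len, vocab)
labels = torch.randint(0, 11, (2, 5))

# Convention used in the diff: shift labels left, pad the last position with
# ignore_index, so logits and labels keep the same sequence length.
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
loss_a = criterion(logits.view(shifted.numel(), -1), shifted.view(-1))

# Conventional formulation: drop the last logit and the first label.
loss_b = criterion(logits[:, :-1].reshape(-1, 11), labels[:, 1:].reshape(-1))

assert torch.allclose(loss_a, loss_b)
```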
+AutoModel.register(MambaConfig, MambaModel, True) +AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True) + + +__all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] diff --git a/fla3/models/mamba/__pycache__/__init__.cpython-310.pyc b/fla3/models/mamba/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5887da71337d16d48d36b7ce588275492095fc6 Binary files /dev/null and b/fla3/models/mamba/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/mamba/__pycache__/configuration_mamba.cpython-310.pyc b/fla3/models/mamba/__pycache__/configuration_mamba.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c42c5885742fc73b27c2e8e05fc27ed6fcd3944e Binary files /dev/null and b/fla3/models/mamba/__pycache__/configuration_mamba.cpython-310.pyc differ diff --git a/fla3/models/mamba/__pycache__/modeling_mamba.cpython-310.pyc b/fla3/models/mamba/__pycache__/modeling_mamba.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93fdcca191d4298e76965db0dc6b7bf09a338bea Binary files /dev/null and b/fla3/models/mamba/__pycache__/modeling_mamba.cpython-310.pyc differ diff --git a/fla3/models/mamba/configuration_mamba.py b/fla3/models/mamba/configuration_mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..fb861e24cb6d3e02174a4cc8901223cb54fb9b1b --- /dev/null +++ b/fla3/models/mamba/configuration_mamba.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MAMBA configuration""" + +import math + +from transformers.configuration_utils import PretrainedConfig + + +class MambaConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MAMBA + [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*): + Vocabulary size of the Mamba model. Default: 32000. + hidden_size (`int`, *optional*): + Dimensionality of the embeddings and hidden states. Default: 2048. + state_size (`int`, *optional*): + Shape of the state space latents. Default: 16. + num_hidden_layers (`int`, *optional*): + Number of hidden layers in the model. Default: 48. + norm_eps (`float`, *optional*): + The epsilon to use in the layer normalization layers. Default: 1e-5. + pad_token_id (`int`, *optional*): + Padding token id. Default: 0. + bos_token_id (`int`, *optional*): + The id of the beginning of sentence token in the vocabulary. Default: 1. 
+ eos_token_id (`int`, *optional*): + The id of the end of sentence token in the vocabulary. Default: 2. + expand (`int`, *optional*): + Expanding factor used to determine the intermediate size. Default: 2. + conv_kernel (`int`, *optional*): + Size of the convolution kernel. Default: 4. + use_bias (`bool`, *optional*): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block. Default: `False`. + use_conv_bias (`bool`, *optional*): + Whether or not to use bias in the convolution layer of the mixer block. Default: `True`. + hidden_act (`str`, *optional*): + The non-linear activation function (function or string) in the decoder. Default: `"silu"`. + initializer_range (`float`, *optional*): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Default: 0.02. + residual_in_fp32 (`bool`, *optional*): + Whether or not residuals should be in `float32`. + If set to `False` residuals will keep the same `dtype` as the rest of the model. Default: `False`. + time_step_rank (`Union[int,str]`, *optional*): + Rank of the discretization projection matrix. + `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`. Default: `"auto"`. + time_step_scale (`float`, *optional*): + Scale used to scale `dt_proj.bias`. Default: 1.0. + time_step_min (`float`, *optional*): + Minimum `time_step` used to bound `dt_proj.bias`. Default: 0.001. + time_step_max (`float`, *optional*): + Maximum `time_step` used to bound `dt_proj.bias`. Default: 0.1. + time_step_init_scheme (`str`, *optional*): + Init scheme used for `dt_proj.weight`. Should be one of `["random","constant"]`. Default: `"random"`. + time_step_floor (`float`, *optional*): + Minimum clamping value of the `dt_proj.bias` layer initialization. Default: 0.0001. + rescale_prenorm_residual (`bool`, *optional*): + Whether or not to rescale `out_proj` weights when initializing. Default: `False`. + use_cache (`bool`, *optional*): + Whether or not the cache should be used. Default: `True`. 
+ + + Example: + + ```python + >>> from transformers import MambaConfig, MambaModel + + >>> # Initializing a Mamba configuration + >>> configuration = MambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = MambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mamba" + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 2048, + state_size: int = 16, + num_hidden_layers: int = 48, + norm_eps: float = 1e-5, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + expand: int = 2, + conv_kernel: int = 4, + use_bias: bool = False, + use_conv_bias: bool = True, + hidden_act: str = "silu", + initializer_range: float = 0.02, + residual_in_fp32: bool = False, + time_step_rank: str = "auto", + time_step_scale: float = 1.0, + time_step_min: float = 0.001, + time_step_max: float = 0.1, + time_step_init_scheme: str = "random", + time_step_floor: float = 1e-4, + rescale_prenorm_residual: bool = False, + use_cache: bool = True, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.conv_kernel = conv_kernel + self.expand = expand + self.intermediate_size = int(expand * self.hidden_size) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/fla3/models/mamba/modeling_mamba.py b/fla3/models/mamba/modeling_mamba.py new file mode 100644 index 0000000000000000000000000000000000000000..7b255d96a20010a305e23ddfd4438334bf96f63b --- /dev/null +++ b/fla3/models/mamba/modeling_mamba.py @@ -0,0 +1,565 @@ +# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
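For orientation, a small sketch (not part of the diff; it assumes the package resolves as `fla`, as the diff's own imports do) of how `MambaConfig` above computes its derived fields under the defaults:

```python
from fla.models.mamba.configuration_mamba import MambaConfig

config = MambaConfig()  # hidden_size=2048, expand=2, time_step_rank="auto"

# intermediate_size = int(expand * hidden_size) = int(2 * 2048)
assert config.intermediate_size == 4096

# "auto" resolves to math.ceil(hidden_size / 16) = ceil(2048 / 16)
assert config.time_step_rank == 128
```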
+ +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.mamba import Mamba +from fla.models.mamba.configuration_mamba import MambaConfig +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm + +logger = logging.get_logger(__name__) + + +class MambaCache: + """ + Cache for mamba model which does not have attention mechanism and key value states. + + Arguments: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + batch_size (`int`): + The batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): + The default `dtype` to use when initializing the layer. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. Should be the same as the layer. + + Attributes: + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + intermediate_size: (`int`): + Model's intermediate_size taken from config. + ssm_state_size: (`int`): + Model's state_size taken from config. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config + conv_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states + + Example: + + ```python + >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache + + >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") + + >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values + MambaCache() + ``` + """ + + # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. + def __init__( + self, + config: PretrainedConfig, + batch_size: int = None, + dtype: torch.dtype = torch.float16, + device: Optional[Union[torch.device, str]] = None, + max_batch_size: Optional[int] = None, + ): + if max_batch_size is not None: + logger.warning_once( + f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " + "v4.46. Use the more precisely named 'batch_size' argument instead." 
+ ) + self.dtype = dtype + self.batch_size = batch_size or max_batch_size + self.intermediate_size = config.intermediate_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + + self.conv_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.batch_size, + self.intermediate_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states: torch.Tensor = torch.zeros( + config.num_hidden_layers, + self.batch_size, + self.intermediate_size, + self.ssm_state_size, + device=device, + dtype=dtype, + ) + + torch._dynamo.mark_static_address(self.conv_states) + torch._dynamo.mark_static_address(self.ssm_states) + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.mixer = Mamba( + hidden_size=config.hidden_size, + state_size=config.state_size, + conv_kernel=config.conv_kernel, + intermediate_size=config.intermediate_size, + time_step_rank=config.time_step_rank, + use_bias=config.use_bias, + layer_idx=layer_idx + ) + + def forward( + self, + hidden_states, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + if self.residual_in_fp32: + hidden_states = hidden_states.to(dtype=self.norm.weight.dtype) + return hidden_states + + +class MambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = MambaConfig + base_model_prefix = 'backbone' + _no_split_modules = ['Mamba', 'MambaBlock'] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, Mamba): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device)) + module.dt_proj.bias._no_reinit = True + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +class MambaOutput(ModelOutput): + """ + Class for the MAMBA model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class MambaModel(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return MambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = MambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def 
get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, + model_kwargs: Dict[str, Any], + num_new_tokens: int = 1, + **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + if ( + model_kwargs.get("use_cache", True) + and "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + return model_kwargs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[MambaCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + logits_to_keep: Optional[int] = None, + **kwargs, + ): + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if attention_mask is not None: + attention_mask = None + + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'cache_params': cache_params, + 'use_cache': use_cache, + 'cache_position': cache_position, + 'attention_mask': attention_mask, + 'logits_to_keep': logits_to_keep, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[MambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + logits_to_keep: Optional[int] = 0, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, MambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape 
`(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = mamba_outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba_outputs[1:] + return (loss,) + output if loss is not None else output + + return MambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba_outputs.cache_params, + hidden_states=mamba_outputs.hidden_states, + ) diff --git a/fla3/models/mamba2/__init__.py b/fla3/models/mamba2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8ac62a700590e06d1e524979b2f21353aa5188 --- /dev/null +++ b/fla3/models/mamba2/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.mamba2.configuration_mamba2 import Mamba2Config +from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model + +AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True) +AutoModel.register(Mamba2Config, Mamba2Model, True) +AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True) + + +__all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model'] diff --git a/fla3/models/mamba2/__pycache__/__init__.cpython-310.pyc b/fla3/models/mamba2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8774907b299bbfd3419f35529112d7cef648acc Binary files /dev/null and b/fla3/models/mamba2/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/mamba2/__pycache__/configuration_mamba2.cpython-310.pyc b/fla3/models/mamba2/__pycache__/configuration_mamba2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15743d57ed3bb5d762b473d02f25d8638d2f16cc Binary files /dev/null and b/fla3/models/mamba2/__pycache__/configuration_mamba2.cpython-310.pyc differ diff --git a/fla3/models/mamba2/__pycache__/modeling_mamba2.cpython-310.pyc 
b/fla3/models/mamba2/__pycache__/modeling_mamba2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a3af0ede15714cf1ac964179f78cb6dc4540470 Binary files /dev/null and b/fla3/models/mamba2/__pycache__/modeling_mamba2.cpython-310.pyc differ diff --git a/fla3/models/mamba2/configuration_mamba2.py b/fla3/models/mamba2/configuration_mamba2.py new file mode 100644 index 0000000000000000000000000000000000000000..27c9c5b73c21e9a694b129bd8f6f6e70e5feb88e --- /dev/null +++ b/fla3/models/mamba2/configuration_mamba2.py @@ -0,0 +1,167 @@ +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MAMBA2 configuration""" + +import math + +from transformers.configuration_utils import PretrainedConfig + + +class Mamba2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MAMBA2 + [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + head_dim (`int`, *optional*, defaults to 64): + Dimension of each head. + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`Mamba2Model`]. + hidden_size (`int`, *optional*, defaults to 2048): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 128): Shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the model. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + n_groups (`int`, *optional*, defaults to 1): + Number of groups for the evolution matrices of mamba 2. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block. + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. 
+ hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. + If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. + `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`): + Accepted range of time step values. + rescale_prenorm_residual (`bool`, *optional*, defaults to `True`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + rms_norm (`bool`, *optional*, defaults to `True`): + Whether to use RMS norm or not. + chunk_size (`int`, *optional*, defaults to 256): + Size of the chunks that will comprise the sequence. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie word embeddings or not. + """ + + model_type = "mamba2" + + def __init__( + self, + head_dim: int = 64, + vocab_size: int = 32000, + hidden_size: int = 2048, + state_size: int = 128, + num_hidden_layers: int = 48, + norm_eps: float = 1e-5, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + expand: int = 2, + conv_kernel: int = 4, + n_groups: int = 1, + use_bias: bool = False, + use_conv_bias: bool = True, + hidden_act: str = "silu", + initializer_range: float = 0.02, + residual_in_fp32: bool = True, + time_step_rank: str = "auto", + time_step_min: float = 0.001, + time_step_max: float = 0.1, + time_step_floor: float = 1e-4, + time_step_limit=(0.0, float("inf")), + rescale_prenorm_residual: bool = True, + use_cache: bool = True, + rms_norm: bool = True, + chunk_size: int = 256, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.conv_kernel = conv_kernel + self.expand = expand + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = ( + math.ceil(self.hidden_size / 16) + if time_step_rank == "auto" + else time_step_rank + ) + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.n_groups = n_groups + self.head_dim = head_dim + self.num_heads = int(self.expand * self.hidden_size / self.head_dim) + self.rms_norm = rms_norm + 
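+ # Note: with the defaults above (expand=2, hidden_size=2048, head_dim=64), num_heads = int(expand * hidden_size / head_dim) resolves to int(2 * 2048 / 64) = 64.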
self.state_size = state_size + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + self.tie_word_embeddings = tie_word_embeddings + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/mamba2/modeling_mamba2.py b/fla3/models/mamba2/modeling_mamba2.py new file mode 100644 index 0000000000000000000000000000000000000000..6f34c7f0f0d5427862ca2bde46e7806bfbe94cf5 --- /dev/null +++ b/fla3/models/mamba2/modeling_mamba2.py @@ -0,0 +1,562 @@ +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.mamba2 import Mamba2 +from fla.models.mamba2.configuration_mamba2 import Mamba2Config +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm + +logger = logging.get_logger(__name__) + + +class Mamba2Cache: + """ + Arguments: + config: Mamba2Config + batch_size: int + dtype: torch.dtype + device: torch.device + + Attributes: + dtype: (`torch.dtype`): + The default `dtype` used to initializing the cache. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config. + n_groups: (`int`): + Model's number of groups taken from the config - similar to tensor parallel in Transformer. + state_size: (`int`): + Model's SSM state size taken from config. + num_heads: (`int`): + The number of heads used in the linear attention / SSM. + head_dim: (`int`): + The respective dimension of the heads used in the linear attention / SSM. + intermediate_size: (`int`): + Model's intermediate_size based on (expand * hidden_dim) from config. + conv_states: (`torch.Tensor`): + A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` + that holds convolutional states. + ssm_states: (`torch.Tensor`): + A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. 
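+ 
+     Example (an illustrative sketch of the shapes allocated by the constructor below, using default config values; `Mamba2Cache` is imported from its module directly since the package `__init__` does not re-export it): 
+ 
+     ```python 
+     >>> from fla.models.mamba2 import Mamba2Config 
+     >>> from fla.models.mamba2.modeling_mamba2 import Mamba2Cache 
+ 
+     >>> config = Mamba2Config() 
+     >>> cache = Mamba2Cache(config, batch_size=1) 
+     >>> tuple(cache.conv_states.shape)  # intermediate_size + 2 * n_groups * state_size = 4096 + 256 
+     (48, 1, 4352, 4) 
+     >>> tuple(cache.ssm_states.shape)  # (num_layers, batch_size, num_heads, head_dim, state_size) 
+     (48, 1, 64, 64, 128) 
+     ``` 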
+ """ + + def __init__( + self, + config: Mamba2Config, + batch_size: int, + dtype: torch.dtype = torch.float16, + device: Optional[str] = None, + ): + self.dtype = dtype + self.conv_kernel_size = config.conv_kernel + self.n_groups = config.n_groups + self.state_size = config.state_size + self.num_heads = config.num_heads + self.head_dim = config.head_dim + self.intermediate_size = int(config.expand * config.hidden_size) + + self.conv_states = torch.zeros( + config.num_hidden_layers, + batch_size, + self.intermediate_size + 2 * self.n_groups * self.state_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + self.ssm_states = torch.zeros( + config.num_hidden_layers, + batch_size, + self.num_heads, + self.head_dim, + self.state_size, + device=device, + dtype=dtype, + ) + + def update_conv_state( + self, + layer_idx: int, + new_conv_state: torch.Tensor, + cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class Mamba2Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps) + self.mixer = Mamba2( + num_heads=config.num_heads, + head_dim=config.head_dim, + hidden_size=config.hidden_size, + state_size=config.state_size, + expand=config.expand, + n_groups=config.n_groups, + conv_kernel=config.conv_kernel, + use_conv_bias=config.use_conv_bias, + hidden_act=config.hidden_act, + rms_norm=config.rms_norm, + chunk_size=config.chunk_size, + time_step_rank=config.time_step_rank, + time_step_limit=config.time_step_limit, + time_step_min=config.time_step_min, + time_step_max=config.time_step_max, + use_bias=config.use_bias, + norm_eps=config.norm_eps, + layer_idx=layer_idx, + ) + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = residual + hidden_states + if self.residual_in_fp32: + hidden_states = hidden_states.to(dtype=self.norm.weight.dtype) + return hidden_states + + +class Mamba2PreTrainedModel(PreTrainedModel, GenerationMixin): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Mamba2Config + base_model_prefix = "backbone" + _no_split_modules = ["Mamba2Block"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights( + self, + module: nn.Module, + num_residuals_per_layer: int = 1, + ): + """Initialize the weights.""" + if isinstance(module, Mamba2): + + # --- A_log --- + A = torch.arange(1, module.num_heads + 1) + with torch.no_grad(): + if not isinstance(module.A_log, torch.distributed.tensor.DTensor): + module.A_log.copy_(torch.log(A)) + else: + logger.warning_once("`A_log` is a DTensor, skipping initialization") + module.A_log._no_weight_decay = True + + # --- D --- + nn.init.ones_(module.D) + module.D._no_weight_decay = True + + # --- dt_bias --- + dt = torch.exp( + torch.rand(self.config.num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + if not isinstance(module.dt_bias, torch.distributed.tensor.DTensor): + module.dt_bias.copy_(inv_dt) + else: + logger.warning_once("`dt_bias` is a DTensor, skipping initialization") + module.dt_bias._no_reinit = True + + elif isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + # guard against deprecated behavior + if hasattr(module.bias, "_no_reinit"): + raise ValueError("This is not supposed to happen") + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + # p = module.o_proj.weight + # guard against deprecated behavior + raise ValueError("This is not supposed to happen") + elif hasattr(module, 'out_proj'): + p = module.out_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 +class Mamba2Output(ModelOutput): + """ + Class for the MAMBA2 model outputs. 
+ + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2 +class Mamba2CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class Mamba2Model(Mamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Mamba2Output]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, + hidden_states, + cache_params, + cache_position, + attention_mask, + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return Mamba2Output( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class Mamba2ForCausalLM(Mamba2PreTrainedModel): + _tied_weights_keys = [] + + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def 
set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + logits_to_keep: Optional[int] = None, + **kwargs, + ): + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`; you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1][..., None] + + if attention_mask is not None: + attention_mask = None + else: + # we initialize `cache_position` to the full size of `conv_states` at the prefill stage: + # padding is applied when the input is shorter and truncation when it is longer, so it + # always matches the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'attention_mask': attention_mask, + 'cache_params': cache_params, + 'use_cache': use_cache, + 'cache_position': cache_position, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + logits_to_keep: Optional[int] = 0, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, Mamba2CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` + are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Mamba2CausalLMOutput( + loss=loss, + logits=logits, + cache_params=outputs.cache_params, + hidden_states=outputs.hidden_states, + ) diff --git a/fla3/models/nsa/__init__.py b/fla3/models/nsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65b8d8982cfb751a9dc0b15b4c8546ac08bf1b06 --- /dev/null +++ b/fla3/models/nsa/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.nsa.configuration_nsa import NSAConfig +from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel + +AutoConfig.register(NSAConfig.model_type, NSAConfig) +AutoModel.register(NSAConfig, NSAModel) +AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM) + + +__all__ = [ + 'NSAConfig', 'NSAModel', 'NSAForCausalLM', +] diff --git a/fla3/models/nsa/__pycache__/__init__.cpython-310.pyc b/fla3/models/nsa/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d80a7025635b0c6fed46bd64f0d50074735f067 Binary files /dev/null and b/fla3/models/nsa/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/nsa/__pycache__/configuration_nsa.cpython-310.pyc b/fla3/models/nsa/__pycache__/configuration_nsa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb861ab95ee5c6d2ab48b9b0ff94940996da566d Binary files /dev/null and b/fla3/models/nsa/__pycache__/configuration_nsa.cpython-310.pyc differ diff --git a/fla3/models/nsa/__pycache__/modeling_nsa.cpython-310.pyc b/fla3/models/nsa/__pycache__/modeling_nsa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4536d9092f73fc029cf51a50a7565125088c432 Binary files /dev/null and b/fla3/models/nsa/__pycache__/modeling_nsa.cpython-310.pyc differ diff --git a/fla3/models/nsa/configuration_nsa.py
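# Editorial sketch (illustrative, not part of the patch) of the label shift in `forward`
# above: rather than shifting logits, the labels move one step left and the last position
# is filled with `ignore_index`, so position t is supervised by token t+1.
import torch

ignore_index = -100
labels = torch.tensor([[10, 11, 12, 13]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
assert shifted.tolist() == [[11, 12, 13, -100]]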
b/fla3/models/nsa/configuration_nsa.py new file mode 100644 index 0000000000000000000000000000000000000000..90c2fe25f62155c05090c84312a65d034eeec015 --- /dev/null +++ b/fla3/models/nsa/configuration_nsa.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class NSAConfig(PretrainedConfig): + + model_type = 'nsa' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 64, + num_kv_heads: int = 4, + head_dim: int = 32, + qkv_bias: bool = False, + block_size: int = 64, + block_counts: Optional[int] = 16, + window_size: Optional[int] = 512, + rope_theta: Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.qkv_bias = qkv_bias + self.block_size = block_size + self.block_counts = block_counts + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/nsa/modeling_nsa.py b/fla3/models/nsa/modeling_nsa.py new file mode 100644 index 0000000000000000000000000000000000000000..96131feb8cdb0abfa2dfe675d5132727b7afc4a9 --- /dev/null +++ b/fla3/models/nsa/modeling_nsa.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.nsa import NativeSparseAttention +from fla.models.nsa.configuration_nsa import NSAConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as NSAMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class NSABlock(nn.Module): + def __init__(self, config: NSAConfig, layer_idx: int):
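# Editorial back-of-the-envelope calculation from the NSA defaults above (my reading of
# the fields, not an exact cost model from the patch): each query attends to at most
# `block_counts` selected blocks of `block_size` tokens plus a local window, independent
# of the total sequence length.
block_size, block_counts, window_size = 64, 16, 512
budget = block_size * block_counts + window_size
print(budget)  # 1536 attended positions per query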
super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.attn = NativeSparseAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + qkv_bias=config.qkv_bias, + block_size=config.block_size, + block_counts=config.block_counts, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = NSAMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class NSAPreTrainedModel(PreTrainedModel): + + config_class = NSAConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['NSABlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
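# Editorial sketch (illustrative, not part of the patch) of the 'rescale' strategy
# implemented below: residual output projections are re-initialized, then divided by
# sqrt(num_residuals_per_layer * n_layer) so the residual-stream variance stays bounded
# as depth grows.
import math
import torch
import torch.nn as nn

num_residuals_per_layer, num_hidden_layers = 2, 24
down_proj = nn.Linear(512, 128, bias=False)
nn.init.kaiming_uniform_(down_proj.weight, a=math.sqrt(5))
with torch.no_grad():
    down_proj.weight /= math.sqrt(num_residuals_per_layer * num_hidden_layers)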
+ > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class NSAModel(NSAPreTrainedModel): + + def __init__(self, config: NSAConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`NSAModel` does not support `output_attentions` for now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing.
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = NSAModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. 
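# Editorial sketch (toy values, not part of the patch) of the decode-time trimming in
# `prepare_inputs_for_generation` above: once the cache is non-empty, only the newest
# token id is fed back into the model.
import torch

input_ids = torch.tensor([[5, 6, 7, 8]])
past_len = 4  # stand-in for len(past_key_values) > 0
if past_len > 0:
    input_ids = input_ids[:, -1:]
assert input_ids.tolist() == [[8]]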
torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/path_attn/__init__.py b/fla3/models/path_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..238d80e1a69b5088e90b81a8ead33206ebcd2107 --- /dev/null +++ b/fla3/models/path_attn/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.path_attn.configuration_path_attention import PaTHAttentionConfig +from fla.models.path_attn.modeling_path_attention import PaTHAttentionForCausalLM, 
PaTHAttentionModel + +AutoConfig.register(PaTHAttentionConfig.model_type, PaTHAttentionConfig) +AutoModel.register(PaTHAttentionConfig, PaTHAttentionModel) +AutoModelForCausalLM.register(PaTHAttentionConfig, PaTHAttentionForCausalLM) + + +__all__ = ['PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel'] diff --git a/fla3/models/path_attn/__pycache__/__init__.cpython-310.pyc b/fla3/models/path_attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7bfbf5f8e38a3599edaced69b816f5c90d2c088 Binary files /dev/null and b/fla3/models/path_attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/path_attn/__pycache__/configuration_path_attention.cpython-310.pyc b/fla3/models/path_attn/__pycache__/configuration_path_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1830098c503b8fb8dfefff474a009091b1b6f23 Binary files /dev/null and b/fla3/models/path_attn/__pycache__/configuration_path_attention.cpython-310.pyc differ diff --git a/fla3/models/path_attn/__pycache__/modeling_path_attention.cpython-310.pyc b/fla3/models/path_attn/__pycache__/modeling_path_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..370587453ff6dbd9b12dc4f451fd2b6707ccc593 Binary files /dev/null and b/fla3/models/path_attn/__pycache__/modeling_path_attention.cpython-310.pyc differ diff --git a/fla3/models/path_attn/configuration_path_attention.py b/fla3/models/path_attn/configuration_path_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..63f24d3a54927656bb3d31eb92a36a37bcb7536a --- /dev/null +++ b/fla3/models/path_attn/configuration_path_attention.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class PaTHAttentionConfig(PretrainedConfig): + + model_type = 'path_attn' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + use_forget_gate: bool = False, + use_w_shortconv: bool = True, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + self.use_forget_gate = use_forget_gate + self.use_w_shortconv = use_w_shortconv + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git 
a/fla3/models/path_attn/modeling_path_attention.py b/fla3/models/path_attn/modeling_path_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..00f770a8a01384eb68625a0773a1e725e5e0ab40 --- /dev/null +++ b/fla3/models/path_attn/modeling_path_attention.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.path_attn import PaTHAttention +from fla.models.path_attn.configuration_path_attention import PaTHAttentionConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as PaTHAttentionMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class PaTHAttentionBlock(nn.Module): + + def __init__(self, config: PaTHAttentionConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.attn = PaTHAttention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + use_forget_gate=config.use_forget_gate, + use_w_shortconv=config.use_w_shortconv, + layer_idx=layer_idx + ) + + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = PaTHAttentionMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs: Unpack[Any] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class PaTHAttentionPreTrainedModel(PreTrainedModel): + + config_class = PaTHAttentionConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['PaTHAttentionBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, 
**kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per PaTHAttention Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class PaTHAttentionModel(PaTHAttentionPreTrainedModel): + + def __init__( + self, + config: PaTHAttentionConfig + ) -> None: + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([ + PaTHAttentionBlock(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + if output_attentions: + warnings.warn( + "`PaTHAttentionModel` does not support output attention weights yet, " + "so `output_attentions` is set to `False`."
+ ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + output_attentions, + use_cache, + **kwargs + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class PaTHAttentionForCausalLM(PaTHAttentionPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = PaTHAttentionModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: 
Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -logits_to_keep:]) + + loss = None + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, 
+ past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/retnet/__init__.py b/fla3/models/retnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7d9e9da930819a2a6728e3e189090651b82a2e --- /dev/null +++ b/fla3/models/retnet/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel + +AutoConfig.register(RetNetConfig.model_type, RetNetConfig) +AutoModel.register(RetNetConfig, RetNetModel) +AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) + + +__all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] diff --git a/fla3/models/retnet/__pycache__/__init__.cpython-310.pyc b/fla3/models/retnet/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2acabbfb05b96ecf9d563f55c610f0feb7289e50 Binary files /dev/null and b/fla3/models/retnet/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/retnet/__pycache__/configuration_retnet.cpython-310.pyc b/fla3/models/retnet/__pycache__/configuration_retnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..121da08777fe92a2fbdef811abd4d24719182063 Binary files /dev/null and b/fla3/models/retnet/__pycache__/configuration_retnet.cpython-310.pyc differ diff --git a/fla3/models/retnet/__pycache__/modeling_retnet.cpython-310.pyc b/fla3/models/retnet/__pycache__/modeling_retnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99448fed27408e5bd3135df3fa217bed1dcaa308 Binary files /dev/null and b/fla3/models/retnet/__pycache__/modeling_retnet.cpython-310.pyc differ diff --git a/fla3/models/retnet/configuration_retnet.py b/fla3/models/retnet/configuration_retnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0f09217228d97e3dd794cecfffff612c8721b6f0 --- /dev/null +++ b/fla3/models/retnet/configuration_retnet.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RetNetConfig(PretrainedConfig): + + model_type = 'retnet' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: int = 1, + expand_v: int = 2, + hidden_ratio: Optional[int] = 2, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + hidden_act: str = "swish", + use_short_conv: bool = False, + conv_size: int = 4, + use_output_gate: bool = True, + max_position_embeddings: int = 2048, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ) -> None: + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio +
self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.feature_map = feature_map + self.hidden_act = hidden_act + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.use_output_gate = use_output_gate + self.max_position_embeddings = max_position_embeddings + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/retnet/modeling_retnet.py b/fla3/models/retnet/modeling_retnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f4500020dc152d5ce418983722bdad9e3de0287e --- /dev/null +++ b/fla3/models/retnet/modeling_retnet.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.multiscale_retention import MultiScaleRetention +from fla.models.retnet.configuration_retnet import RetNetConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as RetNetMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class RetNetBlock(nn.Module): + def __init__(self, config: RetNetConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = MultiScaleRetention( + mode=config.attn_mode, + hidden_size=config.hidden_size, +
expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + feature_map=config.feature_map, + use_output_gate=config.use_output_gate, + gate_fn=config.hidden_act, + elementwise_affine=config.elementwise_affine, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = RetNetMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class RetNetPreTrainedModel(PreTrainedModel): + + config_class = RetNetConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['RetNetBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + prenorm_residual_strategy: Optional[str] = 'rescale', + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if prenorm_residual_strategy is not None: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
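# [Illustrative aside] With the defaults here (num_hidden_layers=24, num_residuals_per_layer=2,
# i.e. one attention and one MLP residual per block), the rescaling quoted above divides each
# o_proj/down_proj weight by:
import math
scale = math.sqrt(2 * 24)  # sqrt(num_residuals_per_layer * num_hidden_layers), about 6.93
# so residual-branch outputs start ~6.93x smaller, keeping the residual stream O(1) with depth.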
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + if prenorm_residual_strategy == 'rescale': + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + elif prenorm_residual_strategy == 'zero': + nn.init.zeros_(p) + else: + raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}") + + +class RetNetModel(RetNetPreTrainedModel): + + def __init__(self, config: RetNetConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn( + "`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RetNetForCausalLM(RetNetPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RetNetModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'" + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = True, + logits_to_keep: Optional[int] = None, + **kwargs: Unpack[Dict] + ): + # only last token for `inputs_ids` if the `past_key_values` is passed along. 
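# [Sketch, not part of the diff] The truncation below works because the recurrent state in
# `past_key_values` already summarizes all earlier tokens, so decoding feeds one token per
# step (assumes the 'retnet' registration shown earlier has run; sizes arbitrary):
import torch
from transformers import AutoConfig, AutoModelForCausalLM
model = AutoModelForCausalLM.from_config(AutoConfig.for_model('retnet', num_hidden_layers=2, hidden_size=128)).eval()
out = model(input_ids=torch.randint(0, 100, (1, 8)), use_cache=True)  # prefill the state
next_id = out.logits[:, -1:].argmax(-1)
out = model(input_ids=next_id, past_key_values=out.past_key_values, use_cache=True)  # one-token step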
+ if past_key_values is not None: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/rwkv6/__init__.py b/fla3/models/rwkv6/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..2a5902546e88200af55af16acf7c6f85512d72cf --- /dev/null +++ b/fla3/models/rwkv6/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model + +AutoConfig.register(RWKV6Config.model_type, RWKV6Config, True) +AutoModel.register(RWKV6Config, RWKV6Model, True) +AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM, True) + + +__all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] diff --git a/fla3/models/rwkv6/__pycache__/__init__.cpython-310.pyc b/fla3/models/rwkv6/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a32629cba80bfee3ebb27cace1036b69bfe7292b Binary files /dev/null and b/fla3/models/rwkv6/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/rwkv6/__pycache__/configuration_rwkv6.cpython-310.pyc b/fla3/models/rwkv6/__pycache__/configuration_rwkv6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..683e10a6090554ebb6367279e5fc9fa406f30aea Binary files /dev/null and b/fla3/models/rwkv6/__pycache__/configuration_rwkv6.cpython-310.pyc differ diff --git a/fla3/models/rwkv6/__pycache__/modeling_rwkv6.cpython-310.pyc b/fla3/models/rwkv6/__pycache__/modeling_rwkv6.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc4a9e9f1838e1f19af0f1f4ca113d108af535e6 Binary files /dev/null and b/fla3/models/rwkv6/__pycache__/modeling_rwkv6.cpython-310.pyc differ diff --git a/fla3/models/rwkv6/configuration_rwkv6.py b/fla3/models/rwkv6/configuration_rwkv6.py new file mode 100644 index 0000000000000000000000000000000000000000..c74da70f218259cbb39739431aef7cb0fa6cc3b4 --- /dev/null +++ b/fla3/models/rwkv6/configuration_rwkv6.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class RWKV6Config(PretrainedConfig): + + model_type = 'rwkv6' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + expand_k: float = 0.5, + expand_v: int = 1, + hidden_ratio: Optional[float] = 3.5, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + num_heads: int = 4, + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + hidden_act: str = "sqrelu", + max_position_embeddings: int = 2048, + norm_first: bool = True, + norm_bias: bool = True, + norm_eps: float = 1e-5, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.norm_first = norm_first + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.norm_bias = norm_bias + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = 
use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/rwkv6/modeling_rwkv6.py b/fla3/models/rwkv6/modeling_rwkv6.py new file mode 100644 index 0000000000000000000000000000000000000000..94e6b25820122de5d7c8f13f5e994d224c825951 --- /dev/null +++ b/fla3/models/rwkv6/modeling_rwkv6.py @@ -0,0 +1,486 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.rwkv6 import LerpLinear, RWKV6Attention +from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm +from fla.modules.activations import ACT2FN +from fla.modules.token_shift import token_shift + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class RWKV6FeedForward(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'sqrelu', + layer_idx: int = None + ) -> RWKV6FeedForward: + super().__init__() + + self.hidden_size = hidden_size + if hidden_ratio is None: + hidden_ratio = 3.5 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio) + intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.key = LerpLinear(hidden_size, intermediate_size) + self.value = nn.Linear(intermediate_size, hidden_size, bias=False) + self.receptance = LerpLinear(hidden_size, hidden_size) + self.act_fn = ACT2FN[hidden_act] + + self.layer_idx = layer_idx + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + state: Optional[Cache] = None, + **kwargs + ) -> torch.Tensor: + if attention_mask is not None: + x = x.mul_(attention_mask[:, -x.shape[-2]:, None]) + if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None: + shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1) + delta = shifted - x + elif state is not None and 
state[self.layer_idx]['ffn_state'] is not None: + shifted = self.time_shift(x) + if state is not None and state[self.layer_idx]['ffn_state'] is not None: + shifted[:, 0] = state[self.layer_idx]['ffn_state'] + delta = shifted - x + else: + cu_seqlens = kwargs.get('cu_seqlens', None) + delta = token_shift(x, cu_seqlens) + key = self.act_fn(self.key(x, delta)) + value = self.value(key) + receptance = self.receptance(x, delta) + + if state is not None: + # no need to update the offset twice + state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0) + return receptance.sigmoid() * value, state + + +class RWKV6Block(nn.Module): + def __init__(self, config: RWKV6Config, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + if config.norm_first and layer_idx == 0: + self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = RWKV6Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + expand_k=config.expand_k, + expand_v=config.expand_v, + num_heads=config.num_heads, + proj_low_rank_dim=config.proj_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx + ) + self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.ffn = RWKV6FeedForward( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_idx=layer_idx + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states + hidden_states = self.attn_norm(residual) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.ffn_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values) + + return outputs + + +class RWKV6PreTrainedModel(PreTrainedModel): + + config_class = RWKV6Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = 
['RWKV6Block'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Parameter): + nn.init.normal_(module, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RWKV6Model(RWKV6PreTrainedModel): + + def __init__(self, config: RWKV6Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RWKV6Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`RWKV6Model` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = 
use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + **kwargs + ) + else: + hidden_states, attentions, past_key_values = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RWKV6ForCausalLM(RWKV6PreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RWKV6Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. 
" + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(past_key_values) == 0: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if 
fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/rwkv7/__init__.py b/fla3/models/rwkv7/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f132f3fc8de7108242e1accc51e55f4a4e6ed5 --- /dev/null +++ b/fla3/models/rwkv7/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config +from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model + +AutoConfig.register(RWKV7Config.model_type, RWKV7Config, True) +AutoModel.register(RWKV7Config, RWKV7Model, True) +AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM, True) + + +__all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model'] diff --git a/fla3/models/rwkv7/__pycache__/__init__.cpython-310.pyc b/fla3/models/rwkv7/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89add7c863b0baf9c698e48e74f6ab0420d4ee65 Binary files /dev/null and b/fla3/models/rwkv7/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/rwkv7/__pycache__/configuration_rwkv7.cpython-310.pyc b/fla3/models/rwkv7/__pycache__/configuration_rwkv7.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaad9ca9fed9909fc1edbb39fc219ce80c9bd0c1 Binary files /dev/null and b/fla3/models/rwkv7/__pycache__/configuration_rwkv7.cpython-310.pyc differ diff --git a/fla3/models/rwkv7/__pycache__/modeling_rwkv7.cpython-310.pyc b/fla3/models/rwkv7/__pycache__/modeling_rwkv7.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c3cd0780bc49d37e1834a77cb3bb1a8a66da266 Binary files /dev/null and b/fla3/models/rwkv7/__pycache__/modeling_rwkv7.cpython-310.pyc differ diff --git a/fla3/models/rwkv7/configuration_rwkv7.py b/fla3/models/rwkv7/configuration_rwkv7.py new file mode 100644 index 0000000000000000000000000000000000000000..606beb12cd55166a49770971e38a501f6bcb8c65 --- /dev/null +++ b/fla3/models/rwkv7/configuration_rwkv7.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- + +from typing import Dict, List, Optional, Union + +from transformers.configuration_utils import PretrainedConfig + + +class RWKV7Config(PretrainedConfig): + + model_type = 'rwkv7' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + attn_mode: str = "chunk", + hidden_size: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + num_hidden_layers: int = 24, + head_dim: Optional[int] = 64, + num_heads: Optional[int] = None, + decay_low_rank_dim: int = 64, + gate_low_rank_dim: int = 128, + a_low_rank_dim: int = 64, + v_low_rank_dim: int = 16, + hidden_act: str = "sqrelu", + max_position_embeddings: int = 2048, + norm_first: bool = True, + norm_bias: bool = True, + norm_eps: float = 1e-5, + attn: Optional[Dict] = None, + use_cache: bool = True, + pad_token_id: int = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + initializer_range: float = 0.02, + fuse_norm: bool = 
True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + value_dim: Optional[Union[int, List[int]]] = None, + **kwargs + ): + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.norm_first = norm_first + self.num_hidden_layers = num_hidden_layers + + if head_dim is None and num_heads is not None: + head_dim = int(hidden_size // num_heads) + elif head_dim is not None and num_heads is None: + num_heads = int(hidden_size // head_dim) + + if value_dim is None: + value_dim = [hidden_size] * num_hidden_layers + elif isinstance(value_dim, int): + assert value_dim >= hidden_size, "value_dim must be at least hidden_size" + assert value_dim % hidden_size == 0, "value_dim must be divisible by hidden_size" + value_dim = [value_dim] * num_hidden_layers + else: + assert len(value_dim) == num_hidden_layers, "value_dim must have the same length as num_hidden_layers" + for v in value_dim: + assert v >= hidden_size, "value_dim must be at least hidden_size" + assert v % hidden_size == 0, "value_dim must be divisible by hidden_size" + + self.head_dim = head_dim + self.num_heads = num_heads + self.value_dim = value_dim + + self.decay_low_rank_dim = decay_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + self.a_low_rank_dim = a_low_rank_dim + self.v_low_rank_dim = v_low_rank_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.norm_bias = norm_bias + self.norm_eps = norm_eps + self.attn = attn + self.use_cache = use_cache + self.initializer_range = initializer_range + self.fuse_norm = fuse_norm + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + if attn is not None: + if not isinstance(attn, Dict): + raise ValueError("attn must be a dictionary") + if 'layers' not in attn: + raise ValueError("Layer indices must be provided to initialize hybrid attention layers") + if 'num_heads' not in attn: + raise ValueError("Number of heads must be provided to initialize hybrid attention layers") + attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads']) + attn['qkv_bias'] = attn.get('qkv_bias', False) + attn['window_size'] = attn.get('window_size', None) + attn['rope_theta'] = attn.get('rope_theta', 10000.)
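# [Sketch, not part of the diff] How the `value_dim` normalization above plays out:
#   value_dim=None      -> [hidden_size] * num_hidden_layers
#   value_dim=k (int)   -> [k] * num_hidden_layers, after the at-least/divisibility checks
#   value_dim=[...]     -> used as-is; one multiple of hidden_size per layer
cfg = RWKV7Config(hidden_size=2048, num_hidden_layers=2, value_dim=4096)
assert cfg.value_dim == [4096, 4096]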
+ + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/rwkv7/modeling_rwkv7.py b/fla3/models/rwkv7/modeling_rwkv7.py new file mode 100644 index 0000000000000000000000000000000000000000..b0271249a5713f522166405db37d6b6fbcc36165 --- /dev/null +++ b/fla3/models/rwkv7/modeling_rwkv7.py @@ -0,0 +1,581 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.rwkv7 import RWKV7Attention +from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm +from fla.modules.activations import ACT2FN +from fla.modules.token_shift import token_shift + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class RWKV7FeedForward(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'sqrelu', + layer_idx: int = None, + num_hidden_layers: int = None, + ) -> RWKV7FeedForward: + super().__init__() + + self.hidden_size = hidden_size + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio) + intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.x_k = nn.Parameter(torch.zeros(hidden_size)) + + self.key = nn.Linear(hidden_size, intermediate_size, bias=False) + self.value = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + self.layer_idx = layer_idx + self.num_hidden_layers = num_hidden_layers + for name, module in self.named_modules(): + module._in_rwkv_module = True + + def _initialize_weights(self, module: nn.Module): + if isinstance(module, RWKV7FeedForward): + with torch.no_grad(): + ratio_1_to_almost0 = 1.0 - (module.layer_idx / module.num_hidden_layers) # 1 to ~0 + ddd = torch.ones(1, 1, module.hidden_size) + for i in range(module.hidden_size): + ddd[0, 0, i] = i / module.hidden_size + module.x_k.data = 1.0 - torch.pow(ddd, ratio_1_to_almost0**4).squeeze() + + # Initialize key and value weights as in CMix_x070 + module.key.weight.data.uniform_(-0.5/(module.hidden_size**0.5), 0.5/(module.hidden_size**0.5)) + module.value.weight.data.zero_() + + def forward( + self, + x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + state: Optional[Cache] = None, + **kwargs + ) -> torch.Tensor: + if attention_mask is not None: + x = x.mul(attention_mask[:, -x.shape[-2]:, None]) + if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None: + shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1) + delta = shifted - x + elif state is not None and 
state[self.layer_idx]['ffn_state'] is not None: + shifted = self.time_shift(x) + shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1] + delta = shifted - x + else: + cu_seqlens = kwargs.get('cu_seqlens', None) + delta = token_shift(x, cu_seqlens) + if state is not None: + # no need to update the offset twice + state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0) + return self.value(self.act_fn(self.key(x.addcmul(delta, self.x_k)))), state + + +class RWKV7Block(nn.Module): + + def __init__( + self, + config: RWKV7Config, + layer_idx: int + ) -> RWKV7Block: + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + if config.norm_first and layer_idx == 0: + self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + if config.attn is not None and layer_idx in config.attn['layers']: + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.attn = RWKV7Attention( + mode=config.attn_mode, + hidden_size=config.hidden_size, + head_dim=config.head_dim, + num_heads=config.num_heads, + decay_low_rank_dim=config.decay_low_rank_dim, + gate_low_rank_dim=config.gate_low_rank_dim, + a_low_rank_dim=config.a_low_rank_dim, + v_low_rank_dim=config.v_low_rank_dim, + norm_eps=config.norm_eps, + fuse_norm=config.fuse_norm, + layer_idx=layer_idx, + value_dim=config.value_dim[layer_idx], + num_hidden_layers=config.num_hidden_layers + ) + self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + self.ffn = RWKV7FeedForward( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_idx=layer_idx, + num_hidden_layers=config.num_hidden_layers + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + v_first: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states + hidden_states = self.attn_norm(residual) + hidden_states, attentions, past_key_values, v_first = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + v_first=v_first, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.ffn_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states, attentions, past_key_values, v_first) + + return outputs + + +class RWKV7PreTrainedModel(PreTrainedModel): 
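# [Sketch, not part of the diff] In the stateless full-sequence case, the channel mix in
# `RWKV7FeedForward.forward` above reduces to roughly the following (x: [B, T, C]):
import torch
import torch.nn.functional as F
def cmix_sketch(x, x_k, key, value):
    shifted = F.pad(x, (0, 0, 1, -1))                  # token shift: prepend a zero row, drop the last
    k = torch.relu(key(x + (shifted - x) * x_k)) ** 2  # x.addcmul(delta, x_k) followed by 'sqrelu'
    return value(k)                                    # down-projection back to hidden_size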
+ + config_class = RWKV7Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['RWKV7Block'] + _supports_cache_class = True + _skip_keys_device_placement = ["past_key_values"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = True, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, nn.Embedding): + # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/train_temp/src/model.py#L396C12-L399C58 + scale = -1e-4 + nn.init.uniform_(module.weight, a=scale, b=-scale) + elif isinstance(module, nn.Linear) and hasattr(self, 'lm_head') and module is self.lm_head: + # https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/train_temp/src/model.py#L403 + if self.config.vocab_size > self.config.hidden_size: + scale = 0.5 * math.sqrt(self.config.vocab_size / self.config.hidden_size) + else: + scale = 0.5 + nn.init.orthogonal_(module.weight, gain=scale) + # Init Attention parameters + elif isinstance(module, (nn.Linear, nn.Conv1d)) and getattr(module, '_in_rwkv_module', False) is False: + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Parameter): + nn.init.normal_(module, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters') and getattr(module, '_in_rwkv_module', False) is False: + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
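# [Illustrative aside] The RWKV-specific branches above mirror BlinkDL's reference init:
# embeddings uniform in [-1e-4, 1e-4], and an orthogonal LM head whose gain grows with the
# vocab-to-width ratio:
import math
vocab_size, hidden_size = 32000, 2048                  # the RWKV7Config defaults
gain = 0.5 * math.sqrt(vocab_size / hidden_size)       # about 1.98, since vocab_size > hidden_size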
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class RWKV7Model(RWKV7PreTrainedModel): + + def __init__(self, config: RWKV7Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)( + config.hidden_size, + bias=config.norm_bias, + eps=config.norm_eps + ) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def load_state_dict(self, state_dict, strict=True, assign=False): + """ + Override the load_state_dict method to handle migration from version 1 to version 2. + Handles hierarchical keys like 'model.layers.0.attn.x_x'. + """ + # Collect all layer indices from the state_dict keys + layer_indices = set() + for key in state_dict.keys(): + if key.startswith("model.layers."): + # Extract the layer index from the key + try: + layer_idx = int(key.split(".")[2]) # Extract the number after 'model.layers.' + layer_indices.add(layer_idx) + except ValueError: + # Skip keys that don't match the expected format + continue + + # Sort the layer indices to process them in order + sorted_layer_indices = sorted(layer_indices) + + # Migration logic for each layer + for layer_idx in sorted_layer_indices: + layer_prefix = f"model.layers.{layer_idx}" + attn_prefix = f"{layer_prefix}.attn" + + # Check if the layer contains the old 'x_x' parameter + if f"{attn_prefix}.x_x" in state_dict: + logger.info(f"Migrating weights for layer {layer_idx} from RWKV7Attention version 1 to version 2...") + # Extract the x_x parameter + x_x = state_dict[f"{attn_prefix}.x_x"] + with torch.no_grad(): + # Create new parameters for version 2 + state_dict[f"{attn_prefix}.x_r"] = x_x[0].unsqueeze(0).unsqueeze(0) + state_dict[f"{attn_prefix}.x_w"] = x_x[1].unsqueeze(0).unsqueeze(0) + state_dict[f"{attn_prefix}.x_k"] = x_x[2].unsqueeze(0).unsqueeze(0) + state_dict[f"{attn_prefix}.x_v"] = x_x[3].unsqueeze(0).unsqueeze(0) + state_dict[f"{attn_prefix}.x_a"] = x_x[4].unsqueeze(0).unsqueeze(0) + state_dict[f"{attn_prefix}.x_g"] = x_x[5].unsqueeze(0).unsqueeze(0) + + # Call the parent method to load the modified state_dict + try: + super().load_state_dict(state_dict, strict=strict, assign=assign) + except TypeError: + # If the parent method does not support `assign`, fall back to strict loading + logger.warning( + "`assign` parameter is not supported by the parent `load_state_dict` method. " + "Falling back to default behavior." 
+ ) + super().load_state_dict(state_dict, strict=strict) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, BaseModelOutputWithPast]: + if output_attentions: + warnings.warn("`RWKV7Model` does not `output_attentions` now, setting it to `False`.") + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + hidden_states = inputs_embeds + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + + v_first = torch.zeros_like(hidden_states) + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + use_cache, + output_attentions, + v_first, + **kwargs + ) + else: + hidden_states, attentions, past_key_values, v_first = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + v_first=v_first, + **kwargs + ) + + if output_attentions: + all_attns += (attentions,) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = RWKV7Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def generate(self, *args, **kwargs): + try: + return super().generate(*args, **kwargs) + except AttributeError as exception: + if 'past_key_values' in str(exception): + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) + else: + raise exception + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `inputs_ids` if the `past_key_values` is not empty. 
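# [Sketch, not part of the diff] The `load_state_dict` override further up migrates legacy
# checkpoints that stored one stacked `attn.x_x` tensor into six [1, 1, C] mix parameters:
import torch
x_x = torch.randn(6, 2048)                             # hypothetical v1 checkpoint entry, C=2048
x_r, x_w, x_k, x_v, x_a, x_g = (t[None, None] for t in x_x)  # v2 layout, each [1, 1, 2048]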
+ if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0): + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + labels: Optional[torch.LongTensor] = None, + shift_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + has_labels = (labels is not None) or (shift_labels is not None) + if not (fuse_linear_and_cross_entropy and has_labels): + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if has_labels: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + + # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files. 
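+ # Illustrative sketch (not part of the original code): for labels = [[t0, t1, t2, t3]], + # the concatenation below yields shift_labels = [[t1, t2, t3, ignore_index]], + # so position i is trained to predict token i + 1 (standard next-token prediction). 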
+ if shift_labels is None: + shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + shift_labels = shift_labels.to(hidden_states.device) + + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/samba/__init__.py b/fla3/models/samba/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a27a4b4cac782eb4a3e6c35216405d320e2c6507 --- /dev/null +++ b/fla3/models/samba/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.samba.configuration_samba import SambaConfig +from fla.models.samba.modeling_samba import SambaBlock, SambaForCausalLM, SambaModel + +AutoConfig.register(SambaConfig.model_type, SambaConfig, True) +AutoModel.register(SambaConfig, SambaModel, True) +AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, True) + + +__all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock'] diff --git a/fla3/models/samba/__pycache__/__init__.cpython-310.pyc b/fla3/models/samba/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e4b386871a8718a904c124a823d2ce4a1fee978 Binary files /dev/null and b/fla3/models/samba/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/samba/__pycache__/configuration_samba.cpython-310.pyc b/fla3/models/samba/__pycache__/configuration_samba.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6be7bce5d6de69ce3a5b23f22ef68771d3a3e76 Binary files /dev/null and b/fla3/models/samba/__pycache__/configuration_samba.cpython-310.pyc differ diff --git a/fla3/models/samba/__pycache__/modeling_samba.cpython-310.pyc b/fla3/models/samba/__pycache__/modeling_samba.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b7c59cda42213baa0d1e2d2f7baf29ef08aa25 Binary files /dev/null and b/fla3/models/samba/__pycache__/modeling_samba.cpython-310.pyc differ diff --git a/fla3/models/samba/configuration_samba.py b/fla3/models/samba/configuration_samba.py new file mode 100644 index 0000000000000000000000000000000000000000..27311f06a81f0132a409b9dab10b63fc9e19333a --- /dev/null +++ b/fla3/models/samba/configuration_samba.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +import math +from typing import Dict, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class SambaConfig(PretrainedConfig): + + model_type = "samba" + + def __init__( + self, + hidden_size: int = 2304, + state_size: int = 16, + num_hidden_layers: int = 18, + norm_eps: float = 1e-5, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + expand: int = 2, + conv_kernel: int = 4, + use_bias: bool = False, + use_conv_bias: bool = True, + hidden_act: str = "swish", + initializer_range: float = 0.02, + residual_in_fp32: bool = False, + time_step_rank: str = "auto", + time_step_scale: float = 1.0, + time_step_min: float = 0.001, + time_step_max: float = 0.1, + time_step_init_scheme: str = "random", 
time_step_floor: float = 1e-4, + max_position_embeddings: int = 2048, + attn: Optional[Dict] = { + 'layers': (1, 3, 5, 7, 9, 11, 13, 15, 17), + 'num_heads': 18, + 'num_kv_heads': 18, + 'qkv_bias': False, + 'window_size': 2048, + 'rope_theta': 10000. + }, + hidden_ratio: Optional[int] = 4, + rescale_prenorm_residual: bool = False, + use_cache: bool = True, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + tie_word_embeddings: bool = False, + **kwargs, + ): + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.norm_eps = norm_eps + self.conv_kernel = conv_kernel + self.expand = expand + self.intermediate_size = int(expand * self.hidden_size) + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor + self.max_position_embeddings = max_position_embeddings + self.attn = attn + self.hidden_ratio = hidden_ratio + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/fla3/models/samba/modeling_samba.py b/fla3/models/samba/modeling_samba.py new file mode 100644 index 0000000000000000000000000000000000000000..9f61e6bfe04542b92614aedc8c274881b76ac114 --- /dev/null +++ b/fla3/models/samba/modeling_samba.py @@ -0,0 +1,422 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.generation import GenerationMixin +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.layers.mamba import Mamba +from fla.models.mamba.modeling_mamba import MambaCache +from fla.models.samba.configuration_samba import SambaConfig +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as SambaMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + +logger = logging.get_logger(__name__) + + +class SambaBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.mixer_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps) + if config.attn is not None and layer_idx in config.attn['layers']: + self.mixer = Attention( + hidden_size=config.hidden_size, + num_heads=config.attn['num_heads'], + num_kv_heads=config.attn['num_kv_heads'], + 
qkv_bias=config.attn['qkv_bias'], + window_size=config.attn['window_size'], + rope_theta=config.attn['rope_theta'], + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + else: + self.mixer = Mamba( + hidden_size=config.hidden_size, + state_size=config.state_size, + conv_kernel=config.conv_kernel, + intermediate_size=config.intermediate_size, + time_step_rank=config.time_step_rank, + use_bias=config.use_bias, + layer_idx=layer_idx + ) + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = SambaMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Tuple[torch.Tensor]] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.mixer_norm(hidden_states) + if isinstance(self.mixer, Mamba): + hidden_states = self.mixer(hidden_states, cache_params=cache_params, **kwargs) + else: + hidden_states, _, cache_params = self.mixer(hidden_states=hidden_states, past_key_values=cache_params, **kwargs) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + return hidden_states + + +class SambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = SambaConfig + base_model_prefix = "backbone" + _no_split_modules = ["SambaBlock"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, Mamba): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_proj.bias.data = nn.Parameter(inv_dt.to(module.dt_proj.bias.device)) + module.dt_proj.bias._no_reinit = True + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_layers) + + +@dataclass +class SambaOutput(ModelOutput): + """ + Class for the Samba model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`MambaCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, + returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[MambaCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class SambaModel(SambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([SambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = RMSNorm(config.hidden_size, eps=config.norm_eps) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + cache_params: Optional[MambaCache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, SambaOutput]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if cache_params is None and use_cache: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if 
self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, + hidden_states, + cache_params, + **kwargs + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + **kwargs + ) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return SambaOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + + +class SambaForCausalLM(SambaPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = SambaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def _update_model_kwargs_for_generation( + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs + ) -> Dict[str, Any]: + model_kwargs["cache_params"] = outputs.get("cache_params", None) + return model_kwargs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids, + cache_params: Optional[MambaCache] = None, + inputs_embeds=None, + attention_mask=None, + use_cache: Optional[bool] = True, + logits_to_keep: Optional[int] = None, + **kwargs: Unpack[Dict] + ): + # only last token for `input_ids` if the state is passed along. + if cache_params is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'cache_params': cache_params, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[MambaCache] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Dict] + ) -> Union[Tuple, SambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100` + are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + **kwargs + ) + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + + loss, logits = None, None + if not fuse_linear_and_cross_entropy or labels is None: + logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:]) + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return SambaCausalLMOutput( + loss=loss, + logits=logits, + cache_params=outputs.cache_params, + hidden_states=outputs.hidden_states, + ) diff --git a/fla3/models/transformer/__init__.py b/fla3/models/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1a82c4ffb298bda4baf05000ca057c3b5a458f --- /dev/null +++ b/fla3/models/transformer/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from fla.models.transformer.configuration_transformer import TransformerConfig +from fla.models.transformer.modeling_transformer import TransformerForCausalLM, TransformerModel + +AutoConfig.register(TransformerConfig.model_type, TransformerConfig) +AutoModel.register(TransformerConfig, TransformerModel) +AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM) + + +__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] diff --git a/fla3/models/transformer/__pycache__/__init__.cpython-310.pyc b/fla3/models/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654e970bc804c3d8ca4f63f6b64061b78627ac3c Binary files /dev/null and b/fla3/models/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/models/transformer/__pycache__/configuration_transformer.cpython-310.pyc b/fla3/models/transformer/__pycache__/configuration_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ddbcf0fb991d44b33514ae8b6bdd3068ded4e16 Binary files /dev/null and b/fla3/models/transformer/__pycache__/configuration_transformer.cpython-310.pyc differ diff --git a/fla3/models/transformer/__pycache__/modeling_transformer.cpython-310.pyc b/fla3/models/transformer/__pycache__/modeling_transformer.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..270f358ee6e6b49b9b935b38001e37a5c8306940 Binary files /dev/null and b/fla3/models/transformer/__pycache__/modeling_transformer.cpython-310.pyc differ diff --git a/fla3/models/transformer/configuration_transformer.py b/fla3/models/transformer/configuration_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9f7de9e64eda540b2f8079f01601c66b19cdb5 --- /dev/null +++ b/fla3/models/transformer/configuration_transformer.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +from transformers.configuration_utils import PretrainedConfig + + +class TransformerConfig(PretrainedConfig): + + model_type = 'transformer' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + hidden_size: int = 2048, + num_hidden_layers: int = 24, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: int = 2048, + hidden_ratio: Optional[int] = 4, + intermediate_size: Optional[int] = None, + hidden_act: str = "swish", + initializer_range: float = 0.02, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: Optional[int] = None, + bos_token_id: int = 1, + eos_token_id: int = 2, + tie_word_embeddings: bool = False, + fuse_norm: bool = True, + fuse_swiglu: bool = True, + fuse_cross_entropy: bool = True, + vocab_size: int = 32000, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + + self.initializer_range = initializer_range + self.elementwise_affine = elementwise_affine + self.norm_eps = norm_eps + self.use_cache = use_cache + + self.fuse_norm = fuse_norm + self.fuse_swiglu = fuse_swiglu + self.fuse_cross_entropy = fuse_cross_entropy + self.vocab_size = vocab_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/fla3/models/transformer/modeling_transformer.py b/fla3/models/transformer/modeling_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..af76d7648adad49da611c4e1ffae8d45dca7160e --- /dev/null +++ b/fla3/models/transformer/modeling_transformer.py @@ -0,0 +1,405 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +import warnings +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from transformers.generation import GenerationMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.deprecation import deprecate_kwarg + +from fla.layers.attn import Attention +from fla.models.transformer.configuration_transformer import TransformerConfig +from fla.models.utils import Cache +from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss +from fla.modules import GatedMLP as 
TransformerMLP +from fla.modules import RMSNorm + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +logger = logging.get_logger(__name__) + + +class TransformerBlock(nn.Module): + + def __init__(self, config: TransformerConfig, layer_idx: int): + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.attn = Attention( + hidden_size=config.hidden_size, + num_heads=config.num_heads, + num_kv_heads=config.num_kv_heads, + qkv_bias=config.qkv_bias, + qk_norm=config.qk_norm, + window_size=config.window_size, + rope_theta=config.rope_theta, + max_position_embeddings=config.max_position_embeddings, + layer_idx=layer_idx + ) + + self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + self.mlp = TransformerMLP( + hidden_size=config.hidden_size, + hidden_ratio=config.hidden_ratio, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + fuse_swiglu=config.fuse_swiglu + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs: Unpack[Any] + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + hidden_states = self.attn_norm(hidden_states) + hidden_states, attentions, past_key_values = self.attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs + ) + if self.config.fuse_norm: + hidden_states, residual = self.mlp_norm(hidden_states, residual, True) + else: + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.mlp_norm(hidden_states) + hidden_states = self.mlp(hidden_states, **kwargs) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attentions,) + + if use_cache: + outputs += (past_key_values,) + + return outputs + + +class TransformerPreTrainedModel(PreTrainedModel): + + config_class = TransformerConfig + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['TransformerBlock'] + _supports_cache_class = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights( + self, + module: nn.Module, + rescale_prenorm_residual: bool = False, + num_residuals_per_layer: int = 2, + ): + if isinstance(module, (nn.Linear, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif hasattr(module, 'reset_parameters'): + module.reset_parameters() + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + p = None + if hasattr(module, 'o_proj'): + p = module.o_proj.weight + elif hasattr(module, 'down_proj'): + p = module.down_proj.weight + if p is not None: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers) + + +class TransformerModel(TransformerPreTrainedModel): + + def __init__( + self, + config: TransformerConfig + ) -> TransformerModel: + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) + self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps) + + self.gradient_checkpointing = False + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + if output_attentions: + warnings.warn( + "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`." + ) + output_attentions = False + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = Cache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_attns = () if output_attentions else None + next_cache = None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + past_key_values, + output_attentions, + use_cache, + **kwargs + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attns + ) + + +class TransformerForCausalLM(TransformerPreTrainedModel, GenerationMixin): + + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = TransformerModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.criterion = None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embeddings + + def set_input_embeddings(self, value): + self.model.embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: bool = True, + logits_to_keep: Optional[int] = None, + **kwargs + ): + # only last token for `input_ids` if the `past_key_values` is not empty. + if past_key_values is not None and len(past_key_values) > 0: + input_ids = input_ids[:, -1:] + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0): + model_inputs = {'inputs_embeds': inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. + # Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
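+ # Illustrative note (not part of the original code): `input_ids[:, -1:]` is a view whose stride + # follows the parent tensor, e.g. torch.arange(12).view(2, 6)[:, -1:].stride() == (6, 1), whereas + # `.contiguous()` always yields stride (1, 1) for a (B, 1) slice, keeping the torchdynamo stride + # guard stable across decoding steps. 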
+ model_inputs = {'input_ids': input_ids.contiguous()} + + if logits_to_keep is not None: + model_inputs['logits_to_keep'] = logits_to_keep + + model_inputs.update({ + 'past_key_values': past_key_values, + 'use_cache': use_cache, + 'attention_mask': attention_mask, + }) + return model_inputs + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Optional[int] = 0, + **kwargs: Unpack[Any] + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs + ) + + hidden_states = outputs[0] + fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training + logits = None if fuse_linear_and_cross_entropy else self.lm_head( + hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:] + ) + + loss = None + if labels is not None: + if getattr(self, 'criterion', None) is None: + if fuse_linear_and_cross_entropy: + criterion = FusedLinearCrossEntropyLoss() + elif self.config.fuse_cross_entropy: + criterion = FusedCrossEntropyLoss(inplace_backward=True) + else: + criterion = nn.CrossEntropyLoss() + else: + criterion = self.criterion + # Enable model parallelism + labels = labels.to(hidden_states.device) + labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1) + if fuse_linear_and_cross_entropy: + loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias) + else: + loss = criterion(logits.view(labels.numel(), -1), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fla3/models/utils.py b/fla3/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19f7cd0a34c3f3cb024c62622c3d0ffa7d81d8d4 --- /dev/null +++ b/fla3/models/utils.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import transformers + + +class Cache(transformers.cache_utils.Cache): + """ + A cache used for storing hidden states produced by flash linear attention models. + + It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. 
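+ + Example (an illustrative sketch, not part of the original docstring): + + >>> cache = Cache() + >>> state = cache.update(recurrent_state=torch.zeros(2, 64, 64), layer_idx=0, offset=16) + >>> cache.get_seq_length() + 16 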
+ """ + + is_compileable = True + + def __init__( + self, + seen_tokens: int = 0 + ) -> Cache: + super().__init__() + + self.states: List[Dict[str, Any]] = [] + + self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> Dict[str, Any]: + if layer_idx < len(self): + return self.states[layer_idx] + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + for state in self.states: + yield state + + def __len__(self): + return len(self.states) + + def update( + self, + recurrent_state: torch.Tensor = None, + attn_state: Tuple[torch.Tensor] = None, + conv_state: Tuple[torch.Tensor] = None, + ffn_state: torch.Tensor = None, + layer_idx: int = 0, + offset: Optional[int] = 1, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """ + Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`. + + Args: + recurrent_state (`torch.Tensor`, `optional`): + The new recurrent state to cache. + attn_state (`Tuple[torch.Tensor]`, `optional`): + The new attention key/value states to cache. + conv_state (`Tuple[torch.Tensor]`, `optional`): + The new convolution state to cache. + layer_idx (`int`, defaults to 0): + The index of the layer to cache the states for. + offset (`int`, `optional`, defaults to 1): + The number of new tokens being processed. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. + + Return: + Dictionary of the updated state. + """ + + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += offset + + if cache_kwargs is None: + cache_kwargs = {} + if attn_state is not None: + input_size = attn_state[0].shape[1] + window_size = cache_kwargs.get('window_size', None) + if not isinstance(attn_state, Tuple): + raise ValueError("`attn_state` must be a tuple of tensors for key/value states") + if len(self.states) <= layer_idx: + if attn_state is not None: + if window_size is not None and input_size > window_size: + attn_state = [state[:, -window_size:].contiguous() for state in attn_state] + state = dict( + recurrent_state=recurrent_state, + attn_state=attn_state, + conv_state=conv_state, + ffn_state=ffn_state + ) + self.states.append(state) + else: + state = self.states[layer_idx] + if recurrent_state is not None: + state['recurrent_state'] = recurrent_state + if attn_state is not None: + if window_size is not None and state['attn_state'][0].shape[1] == window_size: + for i, (old_state, new_state) in enumerate(zip(state['attn_state'], attn_state)): + # DO NOT allocate new memory if the cache is full + # roll the key/value states to the left by `input_size` + old_state = old_state.roll(-input_size, 1) + # replace the last `input_size` tokens with the new key/value states + old_state[:, -input_size:] = new_state + state['attn_state'][i] = old_state + else: + attn_state = [ + torch.cat([old_state, new_state], 1) + for old_state, new_state in zip(state['attn_state'], attn_state) + ] + state['attn_state'] = attn_state + if conv_state is not None: + state['conv_state'] = conv_state + if ffn_state is not None: + state['ffn_state'] = ffn_state + + return state + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.states) <= layer_idx: + return 0 + return self._seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. Cache does not have a maximum length.""" + return None + + def to_legacy_cache(self) -> Tuple: + return tuple(self.states) + + @classmethod + @torch.compiler.disable + def from_legacy_cache( + cls, + past_key_values: Optional[Tuple] = None, + seen_tokens: int = 0 + ) -> Cache: + """Converts a cache in the legacy cache format into an equivalent `Cache`.""" + + cache = cls(seen_tokens) + if isinstance(past_key_values, (list, tuple)): + for layer_idx in range(len(past_key_values)): + cache.states.append(past_key_values[layer_idx]) + return cache diff --git a/fla3/modules/__init__.py b/fla3/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f06f2b911a1ac9328c26594a5d415470f29fad --- /dev/null +++ b/fla3/modules/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution +from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear +from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss +from fla.modules.fused_kl_div import FusedKLDivLoss +from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss +from fla.modules.fused_norm_gate import ( + FusedLayerNormGated, + FusedLayerNormSwishGate, + FusedLayerNormSwishGateLinear, + FusedRMSNormGated, + FusedRMSNormSwishGate, + FusedRMSNormSwishGateLinear +) +from fla.modules.l2norm import L2Norm +from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear +from fla.modules.mlp import GatedMLP +from fla.modules.rotary import RotaryEmbedding +from fla.modules.token_shift import TokenShift + +__all__ = [ + 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', + 'BitLinear', 'FusedBitLinear', + 'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss', + 'L2Norm', + 'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', + 'FusedLayerNormGated', 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', + 'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', + 'GatedMLP', + 'RotaryEmbedding', + 'TokenShift' +] diff --git a/fla3/modules/__pycache__/__init__.cpython-310.pyc b/fla3/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1b83ae99f8349ab54beb5b28fe2eae194b33b53 Binary files /dev/null and b/fla3/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/__init__.cpython-312.pyc b/fla3/modules/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884d29fe3842c800815bdfdc7fac4ccd3f4f9672 Binary files /dev/null and b/fla3/modules/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/modules/__pycache__/activations.cpython-310.pyc b/fla3/modules/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf6fa00c5a8819fd4d5cf48096442620e5949cb Binary files /dev/null and b/fla3/modules/__pycache__/activations.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/convolution.cpython-310.pyc b/fla3/modules/__pycache__/convolution.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1bd18fd0fcb595a79d1897b241ba7916cbacf46f Binary files /dev/null and b/fla3/modules/__pycache__/convolution.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/feature_map.cpython-310.pyc b/fla3/modules/__pycache__/feature_map.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ca80c813a83a9247cbbcb833bb638b008949e9 Binary files /dev/null and b/fla3/modules/__pycache__/feature_map.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/fused_bitlinear.cpython-310.pyc b/fla3/modules/__pycache__/fused_bitlinear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38bb1d480ca7de1d194656c6934f0e1b27981dc6 Binary files /dev/null and b/fla3/modules/__pycache__/fused_bitlinear.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/fused_cross_entropy.cpython-310.pyc b/fla3/modules/__pycache__/fused_cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7baf24c03cd94326941e0d3f8010f3ff3e9318f4 Binary files /dev/null and b/fla3/modules/__pycache__/fused_cross_entropy.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/fused_kl_div.cpython-310.pyc b/fla3/modules/__pycache__/fused_kl_div.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeaa6e817a295b9ea8358c05cc8ebed752e8ef56 Binary files /dev/null and b/fla3/modules/__pycache__/fused_kl_div.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/fused_linear_cross_entropy.cpython-310.pyc b/fla3/modules/__pycache__/fused_linear_cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06899b7744fa2deb61d9fb2be78e2fb68d0b5272 Binary files /dev/null and b/fla3/modules/__pycache__/fused_linear_cross_entropy.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/fused_norm_gate.cpython-310.pyc b/fla3/modules/__pycache__/fused_norm_gate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48d7e7b800085e4e1e3144b9e8dfe8b0f7a7511 Binary files /dev/null and b/fla3/modules/__pycache__/fused_norm_gate.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/l2norm.cpython-310.pyc b/fla3/modules/__pycache__/l2norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e89a5a4a39dbd06adfe9a16f929238e8c4d2319e Binary files /dev/null and b/fla3/modules/__pycache__/l2norm.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/l2norm.cpython-312.pyc b/fla3/modules/__pycache__/l2norm.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79bc249f96b6c9d53c2c05e43a6df98104328a65 Binary files /dev/null and b/fla3/modules/__pycache__/l2norm.cpython-312.pyc differ diff --git a/fla3/modules/__pycache__/layernorm.cpython-310.pyc b/fla3/modules/__pycache__/layernorm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1c268381a367d893bbe770fe9232cb40ffaabcd Binary files /dev/null and b/fla3/modules/__pycache__/layernorm.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/layernorm_gated.cpython-310.pyc b/fla3/modules/__pycache__/layernorm_gated.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8505f705ad6173612aa92c5662c3db213470085 Binary files /dev/null and b/fla3/modules/__pycache__/layernorm_gated.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/mlp.cpython-310.pyc b/fla3/modules/__pycache__/mlp.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..6f8120164e37bbfb0cb50f4fb32233c00854527a Binary files /dev/null and b/fla3/modules/__pycache__/mlp.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/rotary.cpython-310.pyc b/fla3/modules/__pycache__/rotary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fa8eef850c94996e692b30c33bbcf64a499ebd8 Binary files /dev/null and b/fla3/modules/__pycache__/rotary.cpython-310.pyc differ diff --git a/fla3/modules/__pycache__/token_shift.cpython-310.pyc b/fla3/modules/__pycache__/token_shift.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26abf2622ea28f78b3c107ed43eea55f42c884c0 Binary files /dev/null and b/fla3/modules/__pycache__/token_shift.cpython-310.pyc differ diff --git a/fla3/modules/activations.py b/fla3/modules/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..74a73a3a3595f0338a6771875f0120266675ae1a --- /dev/null +++ b/fla3/modules/activations.py @@ -0,0 +1,471 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Tri Dao, Yu Zhang, Songlin Yang. + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, get_multiprocessor_count, input_guard + +sigmoid_fwd_codestring = """ +template T sigmoid_fwd(T x) { + return 1.0f / (1.0f + ::exp(-float(x))); +} +""" +sigmoid_bwd_codestring = """ +template T sigmoid_bwd(T x, T g) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + return float(g) * x_sigmoid * (1.0f - x_sigmoid); +} +""" + +sigmoid_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(sigmoid_fwd_codestring) +sigmoid_bwd_jit_fn = torch.cuda.jiterator._create_jit_fn(sigmoid_bwd_codestring) + + +@torch.compiler.disable +def sigmoid_fwd(x): + return sigmoid_fwd_jit_fn(x) + + +@torch.compiler.disable +def sigmoid_bwd(x, g): + return sigmoid_bwd_jit_fn(x, g) + + +class SigmoidFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return sigmoid_fwd(x) + + @staticmethod + def backward(ctx, dout): + x, = ctx.saved_tensors + return sigmoid_bwd(x, dout) + + +sigmoid = SigmoidFunction.apply + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def logsigmoid_fwd_kernel( + x, + y, + temperature, + T, + D: tl.constexpr, + B: tl.constexpr +): + i = tl.program_id(0) + o_i = i * B + tl.arange(0, B) + m_i = o_i < T + + b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32) + b_m = tl.minimum(0., b_x) + b_z = 1. + exp(-tl.abs(b_x)) + b_y = (b_m - log(b_z)) / temperature + tl.store(y + o_i, b_y.to(y.dtype.element_ty), mask=m_i) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def logsigmoid_bwd_kernel( + x, + dx, + dy, + temperature, + T, + D: tl.constexpr, + B: tl.constexpr +): + i = tl.program_id(0) + o_i = i * B + tl.arange(0, B) + m_i = o_i < T + + b_x = tl.load(x + o_i, mask=m_i, other=0.).to(tl.float32) + b_dy = tl.load(dy + o_i, mask=m_i, other=0.).to(tl.float32) + b_dx = b_dy * (1. - tl.sigmoid(b_x)) / temperature + tl.store(dx + o_i, b_dx.to(dx.dtype.element_ty), mask=m_i) + + +def logsigmoid_fwd(x: torch.Tensor, temperature: float = 1.) 
-> torch.Tensor: + T, D = x.numel(), x.shape[-1] + B = triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index))) + y = torch.empty_like(x) + logsigmoid_fwd_kernel[(triton.cdiv(T, B),)]( + x=x, + y=y, + temperature=temperature, + T=T, + D=D, + B=B + ) + return y + + +def logsigmoid_bwd(x: torch.Tensor, dy: torch.Tensor, temperature: float = 1.) -> torch.Tensor: + T, D = x.numel(), x.shape[-1] + B = triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index))) + dx = torch.empty_like(x) + logsigmoid_bwd_kernel[(triton.cdiv(T, B),)]( + x=x, + dx=dx, + dy=dy, + temperature=temperature, + T=T, + D=D, + B=B + ) + return dx + + +class LogSigmoidFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, x, temperature): + ctx.save_for_backward(x,) + ctx.temperature = temperature + return logsigmoid_fwd(x, temperature) + + @staticmethod + @input_guard + def backward(ctx, dy): + x, = ctx.saved_tensors + return logsigmoid_bwd(x, dy, ctx.temperature), None + + +def logsigmoid(x: torch.Tensor, temperature: float = 1.) -> torch.Tensor: + return LogSigmoidFunction.apply(x, temperature) + + +swish_fwd_codestring = """ +template <typename T> T swish_fwd(T x) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + return float(x) * x_sigmoid; +} +""" +swish_bwd_codestring = """ +template <typename T> T swish_bwd(T x, T g) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + return float(g) * x_sigmoid * (1.0f - float(x) * x_sigmoid + float(x)); +} +""" + +swish_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swish_fwd_codestring) +swish_bwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swish_bwd_codestring) + + +@torch.compiler.disable +def swish_fwd(x): + return swish_fwd_jit_fn(x) + + +@torch.compiler.disable +def swish_bwd(x, g): + return swish_bwd_jit_fn(x, g) + + +class SwishFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_fwd(x) + + @staticmethod + def backward(ctx, dout): + x, = ctx.saved_tensors + return swish_bwd(x, dout) + + +swish = SwishFunction.apply
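Aside: a minimal usage sketch for the fused `logsigmoid` above (editorial, not part of the patch). It assumes a CUDA device, since the kernels are Triton-based; the `temperature` simply divides the standard log-sigmoid, so PyTorch's eager op serves as a reference.

    import torch
    import torch.nn.functional as F

    if torch.cuda.is_available():
        x = torch.randn(1024, 64, device='cuda', requires_grad=True)
        y = logsigmoid(x, temperature=2.0)
        # Same math as the kernel: (min(0, x) - log(1 + exp(-|x|))) / temperature
        ref = F.logsigmoid(x) / 2.0
        assert torch.allclose(y, ref, atol=1e-5)
        # Backward computes dy * (1 - sigmoid(x)) / temperature
        y.sum().backward()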
+ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 + + +# this function is the tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.compile +def bias_gelu(y, bias): + x = bias + y + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) + + +# gradient of the tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.compile +def bias_gelu_bwd(g, y, bias): + """Assume that y has shape (B, D) and bias has shape (D)""" + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + grad_y = ff * g + return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) + + +class GeLUFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + # bias_gelu_bwd returns the pair (grad_input, grad_bias) + return bias_gelu_bwd(grad_output, input, bias) + + +bias_gelu_impl = GeLUFunction.apply + + +# this function is the tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.compile +def gelu_fwd(x): + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) + + +# gradient of the tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.compile +def gelu_bwd(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * ( + 1 + tanh_out + ) + return (ff * g).to(dtype=x.dtype) + + +class FastGeLUFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return gelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + (input,) = ctx.saved_tensors + tmp = gelu_bwd(grad_output, input) + return tmp + + +fast_gelu_impl = FastGeLUFunction.apply + + +@torch.compile +def relu_bwd(g, x): + return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) + + +@torch.compile +def sqrelu_fwd(x): + r = F.relu(x.float()) + return (r * r).to(dtype=x.dtype) + + +@torch.compile +def sqrelu_bwd(g, x): + return (2.0 * g * F.relu(x.float())).to(dtype=x.dtype) + + +class SquaredReLUFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + return sqrelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + return sqrelu_bwd(grad_output, input) + + +sqrelu = SquaredReLUFunction.apply
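Aside: the constants above (0.79788456 ~ sqrt(2/pi), 0.1070322243 ~ 3 * 0.044715 * sqrt(2/pi)) define the tanh approximation of GELU. The sketch below (editorial; it re-derives the formulas rather than importing the module) checks `gelu_fwd`/`gelu_bwd` against PyTorch's reference tanh-GELU and its autograd gradient.

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    g = torch.randn_like(x)

    ref = F.gelu(x, approximate='tanh')
    ref.backward(g)

    with torch.no_grad():
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        y = x * 0.5 * (1.0 + tanh_out)  # mirrors gelu_fwd
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        dx = ff * g  # mirrors gelu_bwd

    assert torch.allclose(y, ref, atol=1e-6)
    assert torch.allclose(dx, x.grad, atol=1e-6)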
+ +swiglu_fwd_codestring = """ +template <typename T> T swiglu_fwd(T x, T y) { + return float(x) * float(y) / (1.0f + ::exp(-float(x))); +} +""" +swiglu_bwd_codestring = """ +template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); + dy = float(x) * x_sigmoid * float(g); +} +""" + +swiglu_fwdbwd_codestring = """ +template <typename T> T swiglu_fwdbwd(T x, T y, T g, T& dx, T& dy, T& z) { + float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x))); + float x_swish = float(x) * x_sigmoid; + dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y); + dy = x_swish * float(g); + z = x_swish * float(y); +} +""" + + +swiglu_fwd_jit_fn = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring) +swiglu_bwd_jit_fn = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2) +swiglu_fwdbwd_jit_fn = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_fwdbwd_codestring, num_outputs=3) + + +@torch.compiler.disable +def swiglu_fwd(x, y): + return swiglu_fwd_jit_fn(x, y) + + +@torch.compiler.disable +def swiglu_bwd(x, y, g): + return swiglu_bwd_jit_fn(x, y, g) + + +@torch.compiler.disable +def swiglu_fwdbwd(x, y, g): + return swiglu_fwdbwd_jit_fn(x, y, g) + + +@torch.compile +def swiglu_fwd_torch(x, y): + return (F.silu(x.float()) * y).to(x.dtype) + + +@torch.compile +def swiglu_bwd_torch(x, y, g): + dtype = x.dtype + x, y, g = x.float(), y.float(), g.float() + x_sigmoid = x.sigmoid() + x_swish = x * x_sigmoid + dx = x_sigmoid * (1 + x * (1.0 - x_sigmoid)) * g * y + dy = x_swish * g + return dx.to(dtype), dy.to(dtype) + + +@torch.compile +def swiglu_fwdbwd_torch(x, y, g): + dtype = x.dtype + x, y, g = x.float(), y.float(), g.float() + x_sigmoid = x.sigmoid() + x_swish = x * x_sigmoid + dx = x_sigmoid * (1 + x * (1.0 - x_sigmoid)) * g * y + dy = x_swish * g + z = x_swish * y + return dx.to(dtype), dy.to(dtype), z.to(dtype) + + +class SwiGLUFunction(torch.autograd.Function): + r""" + Swish-Gated Linear Unit (SwiGLU) function. + + .. math:: + \text{SwiGLU}(x, y) = swish(x) * y = \frac{x}{1 + \exp(-x)} * y + """ + + @staticmethod + def forward(ctx, x, y): + ctx.save_for_backward(x, y) + if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor): + return swiglu_fwd_torch(x, y) + else: + return swiglu_fwd(x, y) + + @staticmethod + def backward(ctx, dout): + x, y = ctx.saved_tensors + if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor): + return swiglu_bwd_torch(x, y, dout) + else: + return swiglu_bwd(x, y, dout) + + +class SwiGLULinearFunction(torch.autograd.Function): + r""" + Swish-Gated Linear Unit (SwiGLU) function followed by a linear transformation. + + .. math:: + \text{SwiGLULinear}(x, y, W, b) = (swish(x) * y) W + b + + This simple wrapper discards the intermediate results of SwiGLU(x, y) to save memory.
+ """ + + @staticmethod + @autocast_custom_fwd + def forward(ctx, x, y, weight, bias): + with torch.no_grad(): + if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor): + z = swiglu_fwd_torch(x, y) + else: + z = swiglu_fwd(x, y) + out = F.linear(z, weight, bias) + # We don't store z, will be recomputed in the backward pass to save memory + ctx.save_for_backward(x, y, weight) + ctx.linear_bias_is_none = bias is None + return out + + @staticmethod + @autocast_custom_bwd + def backward(ctx, dout, *args): + x, y, weight = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dz = F.linear(dout, weight.t()).view_as(x) + with torch.no_grad(): + if torch.compiler.is_compiling() or isinstance(x, torch.distributed.tensor.DTensor): + dx, dy, z = swiglu_fwdbwd_torch(x, y, dz) + else: + dx, dy, z = swiglu_fwdbwd(x, y, dz) + dlinear_weight = torch.einsum("bo,bi->oi", dout, z.reshape(-1, z.shape[-1])) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + return dx, dy, dlinear_weight, dlinear_bias + + +swiglu = SwiGLUFunction.apply + + +swiglu_linear = SwiGLULinearFunction.apply + + +ACT2FN = { + 'relu': F.relu, + 'sigmoid': sigmoid, + 'logsigmoid': logsigmoid, + 'silu': swish, + 'swish': swish, + 'sqrelu': sqrelu, + 'gelu': fast_gelu_impl, + 'bias_gelu': bias_gelu_impl, +} diff --git a/fla3/modules/convolution.py b/fla3/modules/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..11c1a06a37cc7a6f416a7924a724c1da92a59c98 --- /dev/null +++ b/fla3/modules/convolution.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- + +# from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py + +import math +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + +from fla.modules.activations import ACT2FN +from fla.ops.utils import prepare_sequence_ids +from fla.utils import checkpoint, input_guard + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + + +def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] + + out = y + u + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +@checkpoint +def proj_then_conv1d( + x: torch.Tensor, + proj_weight: torch.Tensor, + conv1d_weight: torch.Tensor, + conv1d_bias: Optional[torch.Tensor] = None, + cache: Optional[torch.Tensor] = None +) -> torch.Tensor: + # We do matmul and transpose BLH -> HBL at the same time + x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2]) + + if causal_conv1d_fn is None: + raise ImportError("`causal_conv1d_fn` is not available. 
Please install `causal-conv1d` first.") + if cache is None: + x = causal_conv1d_fn( + x=x, + weight=rearrange(conv1d_weight, "d 1 w -> d w"), + bias=conv1d_bias, + activation="silu", + ).transpose(1, 2) + else: + assert x.shape[-1] == 1, "Only supports decoding with 1 token at a time for now" + x = x.squeeze(-1) + x = causal_conv1d_update( + x=x, + weight=rearrange(conv1d_weight, "d 1 w -> d w"), + bias=conv1d_bias, + cache=cache, + activation="silu", + ) + return x + + +@triton.jit +def causal_conv1d_varlen_states_fwd_kernel( + x, + cache, + offsets, + D, + W, + BD: tl.constexpr, + BW: tl.constexpr +): + i_d, i_w, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2) + eos = tl.load(offsets + i_n + 1) + bos = tl.maximum(tl.load(offsets + i_n), eos - W) + o_t = eos - (i_w + 1) * BW + tl.arange(0, BW) + o_d = i_d * BD + tl.arange(0, BD) + o_w = W - (i_w + 1) * BW + tl.arange(0, BW) + + b_x = tl.load(x + o_t * D + o_d[:, None], mask=(o_t >= bos) & (o_d[:, None] < D), other=0) + tl.store(cache + i_n * D*W + o_d[:, None] * W + o_w, b_x, mask=(o_d[:, None] < D) & (o_w >= 0)) + + +@input_guard +def causal_conv1d_varlen_states_fwd( + x: torch.Tensor, + cache: torch.Tensor, + cu_seqlens: torch.Tensor, + state_len: int +) -> torch.Tensor: + N, D, W = len(cu_seqlens) - 1, x.shape[-1], state_len + cache = torch.empty(N, D, W, dtype=x.dtype, device=x.device) if cache is None else cache + BD = min(triton.next_power_of_2(D), 256) + BW = min(triton.next_power_of_2(state_len), 16) + grid = (triton.cdiv(D, BD), triton.cdiv(W, BW), N) + with torch.cuda.device(x.device.index): + causal_conv1d_varlen_states_fwd_kernel[grid]( + x=x, + cache=cache, + offsets=cu_seqlens, + D=D, + W=W, + BW=BW, + BD=BD + ) + return cache
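Aside: a pure-PyTorch reference (editorial sketch, runnable on CPU) for what `causal_conv1d_varlen_states_fwd` computes: for each sequence delimited by `cu_seqlens`, the last `W` tokens are written right-aligned into a `[N, D, W]` cache, zero-padded on the left for sequences shorter than `W`.

    import torch

    def varlen_conv_states_ref(x, cu_seqlens, state_len):
        # x: [total_tokens, D]; cu_seqlens: [N + 1] token offsets; returns [N, D, W]
        N, D, W = len(cu_seqlens) - 1, x.shape[-1], state_len
        cache = x.new_zeros(N, D, W)
        for n in range(N):
            bos, eos = int(cu_seqlens[n]), int(cu_seqlens[n + 1])
            seq = x[max(bos, eos - W):eos]            # last (up to) W tokens
            cache[n, :, W - seq.shape[0]:] = seq.t()  # right-aligned
        return cache

    x = torch.randn(7, 4)                  # two sequences of lengths 3 and 4
    cu_seqlens = torch.tensor([0, 3, 7])
    states = varlen_conv_states_ref(x, cu_seqlens, state_len=4)
    print(states.shape)                    # torch.Size([2, 4, 4])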
+ + +class ShortConvolution(nn.Conv1d): + """ + Simple wrapper around `nn.Conv1d` that accepts inputs with the channel dimension last. + """ + + def __init__( + self, + hidden_size: int, + kernel_size: int, + bias: bool = False, + activation: Optional[str] = 'silu', + use_fast_conv1d: Optional[bool] = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__( + in_channels=hidden_size, + out_channels=hidden_size, + kernel_size=kernel_size, + groups=hidden_size, + bias=bias, + padding=kernel_size - 1, + device=device, + dtype=dtype, + ) + + self.hidden_size = hidden_size + self.activation = None + if activation is not None: + assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet." + self.activation = activation + + if causal_conv1d_fn is None: + if use_fast_conv1d: + raise RuntimeError( + "Please either install `causal-conv1d>=1.4.0` to enable the fast causal short convolution CUDA kernel " + "or set `use_fast_conv1d` to False" + ) + else: + warnings.warn( + "The naive PyTorch version is very slow in practice, " + "please run `pip install causal-conv1d>=1.4.0` to install the fast causal short convolution CUDA kernel", + category=ImportWarning + ) + self.use_fast_conv1d = use_fast_conv1d + + def extra_repr(self): + s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' + ', stride={stride}') + if self.padding != (0,) * len(self.padding): + s += ', padding={padding}' + if self.dilation != (1,) * len(self.dilation): + s += ', dilation={dilation}' + if self.output_padding != (0,) * len(self.output_padding): + s += ', output_padding={output_padding}' + if self.groups != 1: + s += ', groups={groups}' + if self.bias is None: + s += ', bias=False' + if self.padding_mode != 'zeros': + s += ', padding_mode={padding_mode}' + if self.activation is not None: + s += ', activation={activation}' + if not self.use_fast_conv1d: + s += ', use_fast_conv1d={use_fast_conv1d}' + return s.format(**self.__dict__)
+ def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + cache: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x (`torch.Tensor`): + Tensor of shape `[B, T, D]`. + If `seq_idx` is provided, `B` must be 1. + mask (`Optional[torch.Tensor]`): + Attention mask dealing with padded positions. + cache (`Optional[torch.Tensor]`): + Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size. + If provided, the cache is updated **inplace**. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, D, W]`. Default: `False`. + cu_seqlens (Optional[torch.LongTensor]): + Cumulative sequence lengths for each batch. Used for varlen. Default: `None`. + Shape: [B+1] + + Returns: + The output tensor of shape `[B, T, D]`, together with the (possibly `None`) updated cache. + """ + + B, T, D, W = *x.shape, self.kernel_size[0] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + if mask is not None: + if cu_seqlens is not None: + raise ValueError("`mask` and `cu_seqlens` cannot be provided at the same time") + x = x.mul_(mask.unsqueeze(-1)) + if output_final_state and cache is None: + cache = x.new_zeros(N, D, W) + # during the decoding phase, we assume the batch is composed of sequences of length 1 + if cache is not None and B * T == N: + return self.step(x, cache, cu_seqlens) + + if cache is not None: + if cu_seqlens is not None: + cache = causal_conv1d_varlen_states_fwd(x, cache, cu_seqlens, W) + else: + cache[:, :, -min(W, T):].copy_(rearrange(x[..., -min(W, T):, :], 'n w d -> n d w')) + + x = rearrange(x, 'b t d -> b d t') + if self.use_fast_conv1d: + # Sequence index for each token. Used for varlen. + # Suppose a batch consists of two sequences with lengths 3 and 4, + # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch. + # NOTE: No need to provide this arg if `cu_seqlens` is passed. + # This arg is just for BC, and will be removed in the future. + # [B, T] + seq_idx = kwargs.get('seq_idx', None) + if cu_seqlens is not None and seq_idx is None: + seq_idx = prepare_sequence_ids(cu_seqlens).to(torch.int32).unsqueeze(0) + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + seq_idx=seq_idx, + ) + else: + if cu_seqlens is not None: + raise ValueError("`cu_seqlens` is not supported for the naive PyTorch version") + x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]] + if self.activation is not None: + x = ACT2FN[self.activation](x) + return rearrange(x, "b d t -> b t d"), cache + + def step( + self, + x: torch.Tensor, + cache: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None + ): + shape = x.shape + x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1) + if self.use_fast_conv1d: + x = causal_conv1d_update( + x=x, + conv_state=cache, + weight=rearrange(self.weight, "d 1 w -> d w"), + bias=self.bias, + activation=self.activation, + ) + else: + dtype = x.dtype + # we follow the fast mode that updates the cache in-place + cache.copy_(cache.roll(shifts=-1, dims=-1)) + cache[:, :, -1] = x + x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1) + if self.bias is not None: + x = x + self.bias + if self.activation is not None: + x = ACT2FN[self.activation](x).to(dtype=dtype) + return x.view(shape), cache + + @property + def state_size(self) -> int: + # nn.Conv1d stores `kernel_size` as a tuple, so index into it + return self.hidden_size * self.kernel_size[0]
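Aside: a usage sketch for `ShortConvolution` (editorial; it deliberately picks `use_fast_conv1d=False` and `activation=None` so it runs on CPU, without `causal-conv1d` or the CUDA-only fused activations).

    import torch

    conv = ShortConvolution(hidden_size=64, kernel_size=4,
                            activation=None, use_fast_conv1d=False)
    x = torch.randn(2, 16, 64)             # [B, T, D], channel dimension last

    # Prefill: process the whole prompt, keep the final conv state [N, D, W].
    y, cache = conv(x, output_final_state=True)
    assert y.shape == (2, 16, 64) and cache.shape == (2, 64, 4)

    # Decode: one token per sequence reuses and updates the cache in place.
    x_t = torch.randn(2, 1, 64)
    y_t, cache = conv(x_t, cache=cache)
    assert y_t.shape == (2, 1, 64)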
+ + +class LongConvolution(nn.Module): + """ + LongConvolution applies a convolution operation on the input tensor using a fixed + filter of length max_len. + The filter is learned during training and is applied using FFT convolution. + Args: + hidden_size (int): The number of expected features in the input and output. + max_len (int): The maximum sequence length. + Returns: + y: [batch_size, seq_len, hidden_size] tensor + """ + + def __init__( + self, + hidden_size: int, + max_len: int, + **kwargs, + ): + """ + Initializes the LongConvolution module. + Args: + hidden_size (int): The number of expected features in the input and output. + max_len (int): The maximum sequence length. + """ + super().__init__() + self.hidden_size = hidden_size + self.filter = nn.Parameter(torch.randn(self.hidden_size, max_len), requires_grad=True) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Applies the LongConvolution operation on the input tensor. + Args: + x: [batch_size, seq_len, hidden_size] tensor + Returns: + y: [batch_size, seq_len, hidden_size] tensor + """ + x = x.transpose(1, 2) + y = fft_conv(x, self.filter, dropout_mask=None, gelu=False) + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) + + +class PositionalEmbedding(nn.Module): + def __init__(self, emb_dim: int, seq_len: int, **kwargs): + """Complex exponential positional embeddings for implicit long convolution filters.""" + super().__init__() + + self.seq_len = seq_len + # The time embedding fed to the filters is normalized so that t_f = 1 + t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 + + if emb_dim > 1: + bands = (emb_dim - 1) // 2 + # To compute the right embeddings we use the "proper" linspace + t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] + w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 + + f = torch.linspace(1e-4, bands - 1, bands)[None, None] + z = torch.exp(-1j * f * w) + z = torch.cat([t, z.real, z.imag], dim=-1) + self.z = nn.Parameter(z, requires_grad=False) + + def forward(self, L): + return self.z[:, :L] + + +class ImplicitLongConvolution(nn.Module): + """ + Long convolution with implicit filter parameterized by an MLP. + + Args: + hidden_size (int): + The number of expected features in the input and output. + max_len (int): + The maximum sequence length. + d_emb (Optional[int]): + The dimension of the positional embeddings. Must be odd and greater than or equal to 3 (time, sine and cosine). + Defaults to 3. + d_hidden (Optional[int]): + The number of features in the hidden layer of the MLP. Defaults to 16. + + Attributes: + pos_emb (`PositionalEmbedding`): The positional embedding layer. + mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter. + + """ + + def __init__( + self, + hidden_size: int, + max_len: int, + d_emb: int = 3, + d_hidden: int = 16, + **kwargs, + ): + """ + Long convolution with implicit filter parameterized by an MLP.
+ + + """ + super().__init__() + self.hidden_size = hidden_size + self.d_emb = d_emb + + assert ( + d_emb % 2 != 0 and d_emb >= 3 + ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)" + self.pos_emb = PositionalEmbedding(d_emb, max_len) + + # final linear layer + self.mlp = nn.Sequential( + nn.Linear(d_emb, d_hidden), + torch.nn.ReLU(), + nn.Linear(d_hidden, hidden_size), + ) + + def filter(self, seq_len: int, *args, **kwargs): + k = self.mlp(self.pos_emb(seq_len)) + + return k.transpose(1, 2) + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + Args: + x: [batch_size, seq_len, hidden_size] tensor + Returns: + y: [batch_size, seq_len, hidden_size] tensor + """ + x = x.transpose(1, 2) + k = self.filter(x.shape[-1]) + y = fft_conv(x, k, dropout_mask=None, gelu=False) + + y = y.transpose(1, 2) + return y.to(dtype=x.dtype) diff --git a/fla3/modules/feature_map.py b/fla3/modules/feature_map.py new file mode 100644 index 0000000000000000000000000000000000000000..6af81e74d3975f67b8df23c1dfa60cd01b5a4950 --- /dev/null +++ b/fla3/modules/feature_map.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from fla.modules.activations import fast_gelu_impl, sigmoid, sqrelu, swish +from fla.modules.layernorm import layer_norm +from fla.utils import checkpoint + + +@checkpoint +def flatten_diag_outer_product(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N) + return z[..., indicies[0], indicies[1]] + + +@checkpoint +def flatten_diag_outer_product_off1(x, y): + z = torch.einsum("...i,...j->...ij", x, y) + N = z.size(-1) + indicies = torch.triu_indices(N, N, 1) + indices2 = torch.arange(0, N) + return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] + + +def is_power_of_2(n): + return (n & (n - 1) == 0) and n != 0 + + +class HedgehogFeatureMap(nn.Module): + + r""" + Hedgehog feature map as introduced in + `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry `_ + """ + + def __init__( + self, + head_dim: int + ) -> HedgehogFeatureMap: + super().__init__() + # Trainable map + self.layer = nn.Linear(head_dim, head_dim) + self.init_weights_() + + def init_weights_(self): + """Initialize trainable map as identity""" + with torch.no_grad(): + identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float) + self.layer.weight.copy_(identity.to(self.layer.weight)) + nn.init.zeros_(self.layer.bias) + + def forward(self, x: torch.Tensor): + x = self.layer(x) # shape b, h, l, d + return torch.cat([2*x, -2*x], dim=-1).softmax(-1) + + +class T2RFeatureMap(nn.Module): + + r""" + Simple linear mapping feature map as in + `Finetuning Pretrained Transformers into RNNs `_ + """ + + def __init__( + self, + head_dim: int, + dot_dim: int = None, + bias: Optional[bool] = False + ) -> T2RFeatureMap: + super().__init__() + # Trainable map + if dot_dim is None: + dot_dim = head_dim + + self.head_dim = head_dim + self.dot_dim = dot_dim + self.bias = bias + + self.layer = nn.Linear(head_dim, dot_dim, bias=bias) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(head_dim={self.head_dim}, dot_dim={self.dot_dim}, bias={self.bias})" + + def forward(self, x: torch.Tensor): + return self.layer(x).relu() + + +class DPFPFeatureMap(nn.Module): + + r""" + Deterministic Parameter-Free Projection (DPFP) feature map in + `Linear Transformers Are 
Secretly Fast Weight Programmers `_ + """ + + def __init__( + self, + head_dim: int, + nu: int = 4 + ) -> DPFPFeatureMap: + super().__init__() + self.nu = nu + + def forward(self, x: torch.Tensor): + x = torch.cat([x.relu(), -x.relu()], dim=-1) + x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1) + x_repeat = torch.cat([x] * self.nu, dim=-1) + return x_repeat * x_rolled + + +class HadamardFeatureMap(nn.Module): + def __init__( + self, + head_dim: int + ) -> HadamardFeatureMap: + super().__init__() + # Trainable map + self.layer1 = nn.Linear(head_dim, head_dim) + self.layer2 = nn.Linear(head_dim, head_dim) + + def forward(self, x: torch.Tensor): + return self.layer1(x) * self.layer2(x) + + +class LearnableOuterProductFeatureMap(nn.Module): + def __init__( + self, + head_dim: int, + feature_dim: int + ) -> LearnableOuterProductFeatureMap: + super().__init__() + # Trainable map + self.layer1 = nn.Linear(head_dim, feature_dim, bias=False) + self.layer2 = nn.Linear(head_dim, feature_dim, bias=False) + self.normalizer = feature_dim ** -0.5 + + def forward(self, x: torch.Tensor): + return flatten_diag_outer_product(self.layer1(x), self.layer2(x)) + + +class LearnablePolySketchNonNegativeFeatureMap(nn.Module): + + def __init__( + self, + head_dim: int, + sketch_size: Optional[int] = None, + degree: Optional[int] = 2 + ) -> LearnablePolySketchNonNegativeFeatureMap: + super().__init__() + + assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2" + + self.head_dim = head_dim + self.sketch_size = sketch_size if sketch_size is not None else head_dim + self.degree = degree + + self.gamma = nn.Parameter(torch.ones(head_dim)) + self.beta = nn.Parameter(torch.zeros(head_dim)) + # NOTE: the sketch layers defined here are quite different from the original paper + # currently we simply use linear layers without any non-linear activations + self.sketches1 = nn.ModuleList([ + nn.Linear(head_dim, sketch_size, bias=False), + *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] + ]) + self.sketches2 = nn.ModuleList([ + nn.Linear(head_dim, sketch_size, bias=False), + *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] + ]) + + def forward(self, x: torch.Tensor): + # Section 2.1 + x = layer_norm(x, self.gamma, self.beta) + # first map the input to sketch size with learnable parameters + x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5 + for i in range(1, int(math.log2(self.degree)) - 1): + x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5 + # do sketch mapping for log2(p) - 1 times in total + # do p=2 mapping to ensure non-negativity + return flatten_diag_outer_product(x, x) + + +class TaylorFeatureMap(nn.Module): + def __init__( + self, + head_dim: int + ) -> TaylorFeatureMap: + super().__init__() + self.head_dim = head_dim + self.r2 = math.sqrt(2) + self.rd = math.sqrt(self.head_dim) + self.rrd = math.sqrt(self.rd) + + def forward(self, x: torch.Tensor): + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1) + + +class RebasedFeatureMap(nn.Module): + + def __init__( + self, + head_dim: int, + use_gamma: Optional[bool] = True, + use_beta: Optional[bool] = True, + normalize: Optional[bool] = True + ) -> RebasedFeatureMap: + super().__init__() + + self.head_dim = head_dim + self.use_gamma = use_gamma 
+ self.use_beta = use_beta + self.normalize = normalize + + self.gamma = None + self.beta = None + if use_gamma: + self.gamma = nn.Parameter(torch.ones(head_dim)) + if use_beta: + self.beta = nn.Parameter(torch.zeros(head_dim)) + + def forward(self, x: torch.Tensor, flatten: Optional[bool] = True): + if self.use_beta and self.use_gamma and self.normalize: + x = layer_norm(x, self.gamma, self.beta) + elif self.normalize: + x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta) + elif self.use_gamma and self.use_beta: + x = torch.addcmul(self.beta, x, self.gamma) + elif self.use_gamma: + x = x.mul(self.gamma) + else: + raise RuntimeError(f"Unsupported combination of `use_gamma`, `use_beta` and `normalize`, " + f"which is currently set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)") + if not flatten: + return x + x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) + # rebased uses learnable parameters to approximate any quadratic function + return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1) + + +class ReLUFeatureMap(nn.Module): + + def __init__( + self, + ) -> ReLUFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return F.relu(x) + + +class SquaredReLUFeatureMap(nn.Module): + + def __init__( + self, + ) -> SquaredReLUFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return sqrelu(x) + + +class GELUFeatureMap(nn.Module): + + def __init__( + self, + ) -> GELUFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return fast_gelu_impl(x) + + +class SwishFeatureMap(nn.Module): + + def __init__( + self, + ) -> SwishFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return swish(x) + + +class SigmoidFeatureMap(nn.Module): + + def __init__( + self, + ) -> SigmoidFeatureMap: + super().__init__() + + def forward(self, x: torch.Tensor): + return sigmoid(x)
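Aside: a sketch of how these feature maps slot into linear attention (editorial; non-causal for brevity, with `T2RFeatureMap` standing in for any of the maps above). Applying phi to queries and keys replaces softmax's exp(q . k) with phi(q) . phi(k), so the context reduces to running sums.

    import torch

    def linear_attention(q, k, v, phi):
        q, k = phi(q), phi(k)                       # [B, H, T, M]
        kv = torch.einsum('bhtm,bhtd->bhmd', k, v)  # sum_t phi(k_t) v_t^T
        z = k.sum(dim=2)                            # normalizer sum_t phi(k_t)
        out = torch.einsum('bhtm,bhmd->bhtd', q, kv)
        den = torch.einsum('bhtm,bhm->bht', q, z).clamp(min=1e-6)
        return out / den.unsqueeze(-1)

    B, H, T, D = 2, 4, 128, 16
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    phi = T2RFeatureMap(head_dim=D, dot_dim=32)
    print(linear_attention(q, k, v, phi).shape)     # torch.Size([2, 4, 128, 16])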
diff --git a/fla3/modules/fused_bitlinear.py b/fla3/modules/fused_bitlinear.py new file mode 100644 index 0000000000000000000000000000000000000000..d05928eaf6b721f29a4d15967cae3a8e014e7c9c --- /dev/null +++ b/fla3/modules/fused_bitlinear.py @@ -0,0 +1,638 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Implementations of BitLinear layer with fused LayerNorm and quantized Linear layer. +# [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) +# [Scalable MatMul-free Language Modeling](https://arxiv.org/abs/2406.02528) + +# Code adapted from https://github.com/ridgerchu/matmulfreellm/ + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.modules.layernorm import RMSNorm +from fla.utils import get_multiprocessor_count, input_guard, require_version + + +def activation_quant(x): + """ + Per-token quantization to 8 bits. No grouping is needed for quantization. + + Args: + x: An activation tensor with shape [n, d]. + + Returns: + A quantized activation tensor with shape [n, d]. + """ + # Compute the scale factor + scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) + # Quantize and then de-quantize the tensor + y = (x * scale).round().clamp_(-128, 127) / scale + return y + + +def weight_quant(w): + """ + Per-tensor quantization to 1.58 bits. No grouping is needed for quantization. + + Args: + w: A weight tensor with shape [d, k]. + + Returns: + A quantized weight tensor with shape [d, k]. + """ + # Compute the scale factor + scale = 1.0 / w.abs().mean().clamp_(min=1e-5) + # Quantize and then de-quantize the tensor + u = (w * scale).round().clamp_(-1, 1) / scale + return u + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +@triton.jit +def layer_norm_fwd_kernel_quant( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + + y = x_hat * w if HAS_WEIGHT else x_hat + if HAS_BIAS: + y = y + b + + # Apply quantization to the output + scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5) + # Quantize and then de-quantize the tensor + y = tl.extra.cuda.libdevice.round(y * scale) + y = tl.maximum(tl.minimum(y, 127), -128) / scale + + # Write output + tl.store(Y + cols, y, mask=mask) + + +def layer_norm_fwd_quant( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + residual: torch.Tensor = None, + out_dtype: torch.dtype = None, + residual_dtype: torch.dtype = None, + is_rms_norm: bool = False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + # Less than 64KB
per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + layer_norm_fwd_kernel_quant[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.heuristics({ + "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +@triton.jit +def layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
+ row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w if HAS_WEIGHT else xhat + if HAS_BIAS: + y = y + b + + # Apply quantization to the output + scale = 127.0 / tl.maximum(tl.max(tl.abs(y), 0), 1e-5) + # Quantize and then de-quantize the tensor + y = tl.extra.cuda.libdevice.round(y * scale) + y = tl.maximum(tl.minimum(y, 127), -128) / scale + + tl.store(Y + cols, y, mask=mask) + wdy = dy + if HAS_WEIGHT: + wdy = dy * w + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_WEIGHT: + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def layer_norm_bwd( + dy: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + mean: torch.Tensor, + rstd: torch.Tensor, + dresidual: torch.Tensor = None, + has_residual: bool = False, + is_rms_norm: bool = False, + x_dtype: torch.dtype = None, + recompute_output: bool = False, +): + M, N = x.shape + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = get_multiprocessor_count(x.device.index) + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) if weight is not None else None + _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program =
math.ceil(M / sm_count) + grid = (sm_count,) + layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + weight is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) if weight is not None else None + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormLinearQuantFn(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) + y, mean, rstd, residual_out = layer_norm_fwd_quant( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = weight_quant(linear_weight).to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @input_guard + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_quant_fn( + x, + 
norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearQuantFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) + + +def rms_norm_linear_quant( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False +): + return layer_norm_linear_quant_fn( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + linear_weight=linear_weight, + linear_bias=linear_bias, + residual=residual, + eps=eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True + ) + + +@require_version("triton>=3.0", "Triton >= 3.0 is required to do online quantization.") +def bit_linear(x, weight, bias=None, norm_weight=None, norm_bias=None, eps=1e-8): + """ + A functional version of BitLinear that applies quantization to activations and weights. + + Args: + x: Input tensor with shape [n, d]. + weight: Weight tensor with shape [out_features, in_features]. + bias: Bias tensor with shape [out_features] (optional). + norm_weight: Weight tensor for RMS normalization with shape [in_features]. + norm_bias: Bias tensor for RMS normalization with shape [in_features]. + eps: A small constant for numerical stability in normalization. + + Returns: + Output tensor with shape [n, out_features]. + """ + return layer_norm_linear_quant_fn( + x, + norm_weight, + norm_bias, + weight, + bias, + is_rms_norm=True + ) + + +class BitLinear(nn.Linear): + """ + A custom linear layer that applies quantization on both activations and weights. + This is primarily for training; kernel optimization is needed for efficiency in deployment. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + norm_eps: float = 1e-8 + ): + """ + Initializes the BitLinear layer. + + Args: + in_features: Size of each input sample. + out_features: Size of each output sample. + bias: If set to False, the layer will not learn an additive bias. Default: False. + """ + # Initialize the superclass nn.Linear with the given parameters + super(BitLinear, self).__init__(in_features, out_features, bias=bias) + + self.norm = RMSNorm(in_features, eps=norm_eps) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({super().extra_repr()}, norm_eps={self.norm.eps})" + + def forward(self, x): + """ + Overrides the forward pass to include quantization. + + Args: + x: An input tensor with shape [n, d]. + + Returns: + An output tensor with shape [n, d]. + """ + # Weight tensor + w = self.weight + + # Apply RMS normalization to the input + x_norm = self.norm(x) + + # Apply quantization to both activations and weights + # Uses Straight-Through Estimator (STE) trick with .detach() for gradient flow + x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() + w_quant = w + (weight_quant(w) - w).detach() + # Perform linear operation with quantized values + y = F.linear(x_quant, w_quant) + + return y
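Aside: the straight-through estimator (STE) round-trip used in `BitLinear.forward`, isolated as a runnable sketch (editorial, plain PyTorch): the forward pass sees the quantized weights while the backward pass treats quantization as the identity.

    import torch

    w = torch.randn(8, 16, requires_grad=True)

    # weight_quant: per-tensor absmean scaling to ternary values, then de-quantization.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    w_q = (w * scale).round().clamp(-1, 1) / scale

    # STE: value of w_q in the forward pass, identity gradient in the backward pass.
    w_ste = w + (w_q - w).detach()

    (w_ste ** 2).sum().backward()
    # The gradient is evaluated at the quantized point but flows to w unchanged.
    assert torch.allclose(w.grad, (2 * w_ste).detach())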
+ """ + + def __init__(self, in_features, out_features, bias=False): + """ + Initializes the BitLinear layer. + + Args: + in_features: Size of each input sample. + out_features: Size of each output sample. + bias: If set to False, the layer will not learn an additive bias. Default: True. + """ + # Initialize the superclass nn.Linear with the given parameters + super(FusedBitLinear, self).__init__(in_features, out_features, bias=bias) + + def forward(self, x): + return layer_norm_linear_quant_fn( + x, + self.norm.weight, + self.norm.bias, + self.weight, + self.bias, + is_rms_norm=True + ) diff --git a/fla3/modules/fused_cross_entropy.py b/fla3/modules/fused_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..f85091f66fe5539d4d6c68ca801b3b51ac8b94e4 --- /dev/null +++ b/fla3/modules/fused_cross_entropy.py @@ -0,0 +1,419 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023, Tri Dao. + +from typing import Any, Tuple + +import torch +import torch.nn as nn +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log +from fla.utils import input_guard + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +@triton.heuristics({ + "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0, +}) +@triton.jit +def cross_entropy_fwd_kernel( + loss_ptr, # data ptrs + lse_ptr, + z_loss_ptr, + logits_ptr, + labels_ptr, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + n_rows, + logits_row_stride, # strides + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, + # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE + SPLIT: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")) + logits = logits.to(tl.float32) * logit_scale + max_logits = tl.max(logits, 0) + if HAS_SMOOTHING: + sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) + lse = log(tl.sum(exp(logits - max_logits), 0)) + max_logits + tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) + if label_idx == ignore_index: + loss = 0.0 + z_loss = 0.0 + else: + label_idx -= class_start_idx + if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min( + n_cols, (col_block_idx + 1) * BLOCK_SIZE + ): + logits_label = tl.load(logits_ptr + label_idx) * logit_scale + if HAS_SMOOTHING: + loss = ( + (lse if not SPLIT else 0.0) + - label_smoothing * sum_logits / total_classes + - (1 - label_smoothing) * logits_label + ) + else: + loss = (lse if not SPLIT else 0.0) - logits_label + else: + # If label is out of bounds, we set the CE loss to 0.0. 
But we still want the label_smoothing loss + if HAS_SMOOTHING: + loss = label_smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes) + else: + loss = 0.0 + if not SPLIT: + z_loss = lse_square_scale * lse * lse + loss += z_loss + else: + z_loss = 0.0 + tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss) + if not SPLIT: + tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss) + + +@triton.heuristics({ + "HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0, +}) +@triton.jit +def cross_entropy_bwd_kernel( + dlogits_ptr, # data ptrs + dloss_ptr, + logits_ptr, + lse_ptr, + labels_ptr, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes + n_cols, # shapes + logits_row_stride, # strides + dlogits_row_stride, + dloss_row_stride, + BLOCK_SIZE: tl.constexpr, + HAS_SMOOTHING: tl.constexpr, +): + row_idx = tl.program_id(0) + col_block_idx = tl.program_id(1) + logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64) + dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) + col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + label_idx = tl.load(labels_ptr + row_idx) + if label_idx != ignore_index: + dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) + else: + dloss = 0.0 + logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to( + tl.float32 + ) * logit_scale + lse = tl.load(lse_ptr + row_idx) + probs = exp(logits - lse) + probs += 2.0 * lse_square_scale * lse * probs + label_idx -= class_start_idx + if HAS_SMOOTHING: + smooth_negative = label_smoothing / total_classes + probs = tl.where(col_offsets == label_idx, probs - (1 - label_smoothing), probs) - smooth_negative + else: + probs = tl.where(col_offsets == label_idx, probs - 1.0, probs) + tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols) + + +def fused_cross_entropy_forward( + logits: torch.Tensor, + target: torch.Tensor, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignore_index: int = -100, + process_group=None, +): + n_rows, n_cols = logits.shape + assert target.shape == (n_rows,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + total_classes = world_size * n_cols + rank = 0 if process_group is None else torch.distributed.get_rank(process_group) + class_start_idx = rank * n_cols + + if logits.stride(-1) != 1: + logits = logits.contiguous() + # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py + MAX_BLOCK_SIZE = 64 * 1024 + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE) + num_warps = ( + 4 + if BLOCK_SIZE < 2048 + else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32)) + ) + # We may split the lse computation across multiple blocks, then do a reduction + # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k) + # where having just one thread block processing more than 64k elements is slow. 
+ split = world_size > 1 or n_cols > MAX_BLOCK_SIZE + n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE + loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,) + losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device) + + cross_entropy_fwd_kernel[(n_rows, n_splits)]( + losses, # data ptrs + lse, + z_losses, + logits, + target, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + total_classes, + class_start_idx, + n_cols, # shapes + n_rows, + logits.stride(0), # strides + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + SPLIT=split + ) + + if split: + # If there's no label_smoothing, if target are in the vocab of this partition, losses contains + # - predicted logit, and 0 otherwise. + # If there's label_smoothing=0.1, for target in the vocab of this partition, losses contains + # -0.9 * predicted logit - 0.1 * sum logit / total_classes. + # For target not in the vocab of this partition, losses contains + # -0.1 * sum logit / total_classes. + if n_splits > 1: + lse = torch.logsumexp(lse, dim=0) + losses = losses.sum(dim=0) + if world_size > 1: + lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + handle_losses.wait() + # After the allreduce, if there's no label_smoothing, the total losses are - predicted_logit, + # we just have to add the (global) lse. + # If there's label_smoothing=0.1, the total losses are + # -0.9 * predicted_logit - 0.1 * sum logit / total_classes. + # Again, we just have to add the (global) lse. + losses += lse + if lse_square_scale != 0.0: + z_losses = lse_square_scale * lse.square() + z_losses.masked_fill_(target == ignore_index, 0.0) + losses += z_losses + else: + z_losses = torch.zeros_like(losses) + losses.masked_fill_(target == ignore_index, 0.0) + + return losses, z_losses, lse, total_classes, class_start_idx + + +class CrossEntropyLossFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + logits, + target, + label_smoothing=0.0, + logit_scale=1.0, + lse_square_scale=0.0, + ignore_index=-100, + inplace_backward=False, + process_group=None, + ): + losses, z_losses, lse, total_classes, class_start_idx = fused_cross_entropy_forward( + logits, + target, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + process_group, + ) + ctx.save_for_backward(logits, lse, target) + ctx.mark_non_differentiable(z_losses) + ctx.label_smoothing = label_smoothing + ctx.logit_scale = logit_scale + ctx.lse_square_scale = lse_square_scale + ctx.ignore_index = ignore_index + ctx.total_classes = total_classes + ctx.class_start_idx = class_start_idx + ctx.inplace_backward = inplace_backward + + return losses, z_losses + + @staticmethod + @input_guard + def backward(ctx, grad_losses, grad_z_losses): + del grad_z_losses # z_losses are only for logging. 
+ + logits, lse, target = ctx.saved_tensors + dlogits = logits if ctx.inplace_backward else torch.empty_like(logits) + n_rows, n_cols = logits.shape + BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024) + num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16) + def grid(META): return (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"])) # noqa + cross_entropy_bwd_kernel[grid]( + dlogits, # data ptrs + grad_losses, + logits, + lse, + target, + ctx.label_smoothing, + ctx.logit_scale, + ctx.lse_square_scale, + ctx.ignore_index, + ctx.total_classes, + ctx.class_start_idx, + n_cols, # shapes + logits.stride(0), # strides + dlogits.stride(0), + grad_losses.stride(0), + BLOCK_SIZE=BLOCK_SIZE, # constants + num_warps=num_warps, + ) + return dlogits, None, None, None, None, None, None, None, None + + +def cross_entropy_loss( + logits: torch.Tensor, + target: torch.Tensor, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + ignore_index=-100, + inplace_backward: bool = False, + process_group=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + logits: [batch, vocab_size] + target: [batch,] + label_smoothing: float + logit_scale: float. + Multiply logits by this scale before calculating the loss. + lse_square_scale: float. + If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + ignore_index: int. + If target == ignore_index, the loss is set to 0.0. + inplace_backward: bool. + If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: + if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + Returns: + losses: [batch,], float + z_losses: [batch,], float + """ + return CrossEntropyLossFunction.apply( + logits, + target, + label_smoothing, + logit_scale, + lse_square_scale, + ignore_index, + inplace_backward, + process_group, + ) + + +class FusedCrossEntropyLoss(nn.Module): + def __init__( + self, + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + lse_square_scale: float = 0.0, + inplace_backward: bool = False, + process_group: Any = None, + return_z_loss: bool = False, + ): + """ + Arguments: + ignore_index: int. If target == ignore_index, the loss is set to 0.0. + label_smoothing: float + lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. + This is also referred to as "z-loss". + inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. + This saves memory. + process_group: if not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss will be aggregated across processes. + return_z_loss: bool. If True, we return the component of the loss contributed by + the lse_square_scale value. This value is only for logging and does not support + backprop. 
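+ Example (illustrative; requires CUDA tensors):
+ >>> loss_fn = FusedCrossEntropyLoss(reduction="mean", return_z_loss=True)
+ >>> logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
+ >>> target = torch.randint(0, 32000, (8,), device="cuda")
+ >>> loss, z_loss = loss_fn(logits, target)
+ >>> loss.backward()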
+ """ + super().__init__() + if reduction not in ["mean", "none", "sum"]: + raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.lse_square_scale = lse_square_scale + self.inplace_backward = inplace_backward + self.process_group = process_group + self.return_z_loss = return_z_loss + + def forward(self, input, target): + """ + Arguments: + input: (batch, vocab_size) + target: (batch,) + Returns: + losses: (batch,) if reduction is 'none', else (1,), dtype float + z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss) + """ + assert input.is_cuda and target.is_cuda, "Only support CUDA tensors" + loss, z_loss = cross_entropy_loss( + input, + target, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + lse_square_scale=self.lse_square_scale, + ignore_index=self.ignore_index, + inplace_backward=self.inplace_backward, + process_group=self.process_group, + ) + if self.reduction == "mean": + loss = loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + loss = loss.sum() + else: + loss = loss + + if not self.return_z_loss: + return loss + + if self.reduction == "mean": + z_loss = z_loss.sum() / (target != self.ignore_index).sum() + elif self.reduction == "sum": + z_loss = z_loss.sum() + else: + z_loss = z_loss + + return loss, z_loss diff --git a/fla3/modules/fused_kl_div.py b/fla3/modules/fused_kl_div.py new file mode 100644 index 0000000000000000000000000000000000000000..5e49269dec9e4c09d058c0ac0d5e6e059c6240b8 --- /dev/null +++ b/fla3/modules/fused_kl_div.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log +from fla.utils import input_guard + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 +# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +@triton.jit +def kl_div_kernel( + logits, + target_logits, + loss, + s_logits, + s_loss, + reduction: tl.constexpr, + N: tl.constexpr, + V: tl.constexpr, + BV: tl.constexpr +): + # https://github.com/triton-lang/triton/issues/1058 + # If N*V is too large, i_n * stride will overflow out of int32, so we convert to int64 + i_n = tl.program_id(0).to(tl.int64) + + logits += i_n * s_logits + target_logits += i_n * s_logits + + # m is the max value. use the notation from the paper + sm = float('-inf') + tm = float('-inf') + # d is the sum. use the notation from the paper + sd, td = 0.0, 0.0 + + NV = tl.cdiv(V, BV) + for iv in range(0, NV): + o_x = iv * BV + tl.arange(0, BV) + # for student + b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf')) + b_sm = tl.max(b_sl) + m_new = tl.maximum(sm, b_sm) + sd = sd * exp(sm - m_new) + tl.sum(exp(b_sl - m_new)) + sm = m_new + # for teacher + b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf')) + b_tm = tl.max(b_tl) + m_new = tl.maximum(tm, b_tm) + td = td * exp(tm - m_new) + tl.sum(exp(b_tl - m_new)) + tm = m_new + + b_loss = 0. 
+ # KL(y_true || y) = exp(y_true) * (log(y_true) - log(y)) + for iv in range(0, NV): + o_x = iv * BV + tl.arange(0, BV) + b_sl = tl.load(logits + o_x, mask=o_x < V, other=float('-inf')) + b_tl = tl.load(target_logits + o_x, mask=o_x < V, other=float('-inf')) + b_sp_log = b_sl - sm - log(sd) + b_tp_log = b_tl - tm - log(td) + b_sp = exp(b_sp_log) + b_tp = exp(b_tp_log) + b_kl = tl.where(o_x < V, b_tp * (b_tp_log - b_sp_log), 0) + b_dl = -b_tp + b_sp + b_loss += tl.sum(b_kl) + if reduction == 'batchmean': + b_dl = b_dl / N + tl.store(logits + o_x, b_dl, mask=o_x < V) + + # Normalize the loss by the number of elements if reduction is 'batchmean' + if reduction == 'batchmean': + b_loss = b_loss / N + + tl.store(loss + i_n * s_loss, b_loss) + + +@triton.jit +def elementwise_mul_kernel( + x, + g, + N: tl.constexpr, + B: tl.constexpr +): + """ + This function multiplies each element of the tensor pointed by x with the value pointed by g. + The multiplication is performed in-place on the tensor pointed by x. + + Parameters: + x: + Pointer to the input tensor. + g: + Pointer to the gradient output value. + N (int): + The number of columns in the input tensor. + B (int): + The block size for Triton operations. + """ + + # Get the program ID and convert it to int64 to avoid overflow + i_x = tl.program_id(0).to(tl.int64) + o_x = i_x * B + tl.arange(0, B) + + # Load the gradient output value + b_g = tl.load(g) + b_x = tl.load(x + o_x, mask=o_x < N) + tl.store(x + o_x, b_x * b_g, mask=o_x < N) + + +def fused_kl_div_forward( + x: torch.Tensor, + target_x: torch.Tensor, + weight: torch.Tensor, + target_weight: torch.Tensor, + reduction: str = 'batchmean' +): + device = x.device + + # ideally, we would like to achieve the same memory consumption as [N, H], + # so the expected chunk size should be: + # NC = ceil(V / H) + # C = ceil(N / NC) + # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048 + N, H, V = *x.shape, weight.shape[0] + BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # TODO: in real cases, we may need to limit the number of chunks NC to + # ensure the precisions of accumulated gradients + NC = min(8, triton.cdiv(V, H)) + C = triton.next_power_of_2(triton.cdiv(N, NC)) + NC = triton.cdiv(N, C) + + dx = torch.zeros_like(x, device=device) + dw = torch.zeros_like(weight, device=device) if weight is not None else None + # we use fp32 for loss accumulator + loss = torch.zeros(N, dtype=torch.float32, device=device) + + for ic in range(NC): + start, end = ic * C, min((ic + 1) * C, N) + # [C, N] + c_sx = x[start:end] + c_tx = target_x[start:end] + # when doing matmul, use the original precision + # [C, V] + c_sl = F.linear(c_sx, weight) + c_tl = F.linear(c_tx, target_weight) + + # unreduced loss + c_loss = loss[start:end] + + # Here we calculate the gradient of c_sx in place so we can save memory. + kl_div_kernel[(c_sx.shape[0],)]( + logits=c_sl, + target_logits=c_tl, + loss=c_loss, + s_logits=c_sl.stride(-2), + s_loss=c_loss.stride(-1), + reduction=reduction, + N=N, + V=V, + BV=BV, + num_warps=32 + ) + + # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V + # thus dx[start: end] should be of shape: C x H + # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only + # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. + # Thus, we need an additional scaling factor of (n_non_ignore/total) to scale the gradients. 
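+ # Chain rule through the fused projection (logits = x @ weight.T): the kernel has
+ # overwritten c_sl with dloss/dlogits in place, so the two matmuls below compute
+ # dx[start:end] = dlogits @ weight (shape [C, H]) and dw += dlogits.T @ c_sx
+ # (shape [V, H]).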
+ # [C, H] + + dx[start:end] = torch.mm(c_sl, weight) + + if weight is not None: + torch.addmm(input=dw, mat1=c_sl.t(), mat2=c_sx, out=dw) + + loss = loss.sum() + return loss, dx, dw + + +def fused_kl_div_backward( + do: torch.Tensor, + dx: torch.Tensor, + dw: torch.Tensor +): + # If cross entropy is the last layer, do is 1.0. Skip the mul to save time + if torch.ne(do, torch.tensor(1.0, device=do.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + N, H = dx.shape + B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + elementwise_mul_kernel[(triton.cdiv(N * H, B),)]( + x=dx, + g=do, + N=N*H, + B=B, + num_warps=32, + ) + + # handle dw + if dw is not None: + V, H = dw.shape + elementwise_mul_kernel[(triton.cdiv(V * H, B),)]( + x=dw, + g=do, + N=V*H, + B=B, + num_warps=32, + ) + + return dx, dw + + +class FusedKLDivLossFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + target_x: torch.Tensor, + weight: torch.Tensor, + target_weight: torch.Tensor, + reduction: str + ): + loss, dx, dw = fused_kl_div_forward( + x=x, + target_x=target_x, + weight=weight, + target_weight=target_weight, + reduction=reduction + ) + ctx.save_for_backward(dx, dw) + return loss + + @staticmethod + @input_guard + def backward(ctx, do): + dx, dw = ctx.saved_tensors + dx, dw = fused_kl_div_backward(do, dx, dw) + return dx, None, dw, None, None + + +def fused_kl_div_loss( + x: torch.Tensor, + target_x: torch.Tensor, + weight: torch.Tensor, + target_weight: torch.Tensor, + reduction: str = 'batchmean' +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x (torch.Tensor): [batch_size * seq_len, hidden_size] + target_x (torch.Tensor): [batch_size * seq_len, hidden_size] + weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + target_weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + reduction: + Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'. + Returns: + loss + """ + return FusedKLDivLossFunction.apply( + x, + target_x, + weight, + target_weight, + reduction + ) + + +class FusedKLDivLoss(nn.Module): + + def __init__( + self, + reduction: str = 'batchmean' + ): + """ + Args: + reduction: + Specifies the reduction to apply to the output: 'batchmean'. Default: 'batchmean'. + """ + super().__init__() + + assert reduction in ['batchmean'], f"reduction: {reduction} is not supported" + + self.reduction = reduction + + def forward( + self, + x: torch.Tensor, + target_x: torch.Tensor, + weight: torch.Tensor, + target_weight: torch.Tensor + ): + """ + Args: + x (torch.Tensor): [batch_size * seq_len, hidden_size] + target_x (torch.Tensor): [batch_size * seq_len, hidden_size] + weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + target_weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. 
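+ Note:
+ With the only supported reduction ('batchmean'), the kernel divides the
+ summed KL by the total number of rows N = batch_size * seq_len.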
+ Returns:
+ loss
+ """
+ loss = fused_kl_div_loss(
+ x=x,
+ target_x=target_x,
+ weight=weight,
+ target_weight=target_weight,
+ reduction=self.reduction
+ )
+ return loss
diff --git a/fla3/modules/fused_linear_cross_entropy.py b/fla3/modules/fused_linear_cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a18566fdbffccdd0d2b3bbe3586fdb38f18720e4
--- /dev/null
+++ b/fla3/modules/fused_linear_cross_entropy.py
@@ -0,0 +1,570 @@
+# -*- coding: utf-8 -*-
+
+# Code adapted from
+# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py
+
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module
+from torch.distributed.tensor.parallel import ParallelStyle
+
+from fla.ops.utils import logsumexp_fwd
+from fla.ops.utils.op import exp
+from fla.utils import input_guard
+
+# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576
+# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+# However, setting the limit as 65536 as in the LayerNorm tutorial is faster because of less register spilling
+# The optimal maximum block size depends on your hardware, your kernel, and your dtype
+MAX_FUSED_SIZE = 65536 // 2
+
+
+@triton.jit
+def cross_entropy_kernel(
+ logits,
+ lse,
+ target,
+ loss,
+ total,
+ ignore_index,
+ label_smoothing: tl.constexpr,
+ logit_scale: tl.constexpr,
+ reduction: tl.constexpr,
+ V: tl.constexpr,
+ BV: tl.constexpr
+):
+ """
+ This kernel computes both the cross entropy loss and the gradient of the input.
+ We only consider hard labels + mean reduction for now.
+ Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
+
+ Args:
+ logits:
+ Pointer to the logits tensor.
+ lse:
+ Pointer to the logsumexp tensor.
+ target:
+ Pointer to the target tensor.
+ loss:
+ Pointer to the tensor in which to store the loss.
+ V (int):
+ The number of columns in the input tensor.
+ total (int):
+ The number of non-ignored classes.
+ ignore_index (int):
+ The index to ignore in the target.
+ label_smoothing (float):
+ The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+ reduction (str):
+ The string for the reduction to apply.
+ BV (int):
+ The block size for the vocab.
+ """
+
+ # https://github.com/triton-lang/triton/issues/1058
+ # If B*T*V is too large, i_n * stride will overflow out of int32, so we convert to int64
+ i_n = tl.program_id(0).to(tl.int64)
+ NV = tl.cdiv(V, BV)
+
+ # 1. Load the target first because if the target is ignore_index, we can return right away
+ b_y = tl.load(target + i_n)
+
+ # 2. Locate the start index
+ logits += i_n * V
+
+ if b_y == ignore_index:
+ # set all x as 0
+ for i in range(0, V, BV):
+ o_v = i + tl.arange(0, BV)
+ tl.store(logits + o_v, 0.0, mask=o_v < V)
+ return
+
+ # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
+ # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
+
+ # 3. [Online softmax] First pass: compute logsumexp
+ # we did this in another kernel
+ b_l = tl.load(logits + b_y) * logit_scale
+ b_lse = tl.load(lse + i_n)
+
+ # 4. Calculate the loss
+ # loss = lse - logits_l
+ b_loss = b_lse - b_l
+
+ # Label smoothing is a general case of normal cross entropy
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+ b_z = 0.0
+ eps = label_smoothing / V
+
+ # We need tl.debug_barrier() as mentioned in
+ # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+ tl.debug_barrier()
+
+ # 5. [Online softmax] Second pass: compute gradients
+ # For 'mean' reduction, gradients are normalized by the number of non-ignored elements:
+ # dx_y = (softmax(x_y) - 1) / N
+ # dx_i = softmax(x_i) / N, i != y
+ # For label smoothing:
+ # dx_i = (softmax(x_i) - label_smoothing / V) / N, i != y
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+ # = dx_i - (1 - label_smoothing) / N
+ for iv in range(0, NV):
+ o_v = iv * BV + tl.arange(0, BV)
+ b_logits = tl.load(logits + o_v, mask=o_v < V, other=float('-inf')) * logit_scale
+ if label_smoothing > 0:
+ # scale X beforehand to avoid overflow
+ b_z += tl.sum(tl.where(o_v < V, -eps * b_logits, 0.0))
+ b_p = (exp(b_logits - b_lse) - eps) * logit_scale
+ if reduction == "mean":
+ b_p = b_p / total
+ tl.store(logits + o_v, b_p, mask=o_v < V)
+
+ tl.debug_barrier()
+
+ # Original loss = H(q, p); with label smoothing regularization it is H(q', p), where (label_smoothing / V) = eps
+ # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+ # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+ # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+ # Refer to H(q', p) in section 7 of the paper:
+ # https://arxiv.org/pdf/1512.00567
+ # pytorch:
+ # https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+ if label_smoothing > 0:
+ b_loss = b_loss * (1 - label_smoothing) + (b_z + label_smoothing * b_lse)
+
+ # 6. Specially handle the i == y case, where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
+ b_l = tl.load(logits + b_y)
+
+ # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+ if reduction == 'mean':
+ b_loss = b_loss / total
+ b_l += (label_smoothing - 1) / total * logit_scale
+ else:
+ b_l += (label_smoothing - 1) * logit_scale
+
+ tl.store(loss + i_n, b_loss)
+ tl.store(logits + b_y, b_l)
+
+
+@triton.jit
+def elementwise_mul_kernel(
+ x,
+ g,
+ N: tl.constexpr,
+ B: tl.constexpr
+):
+ """
+ This function multiplies each element of the tensor pointed to by x with the value pointed to by g.
+ The multiplication is performed in-place on the tensor pointed to by x.
+
+ Parameters:
+ x:
+ Pointer to the input tensor.
+ g:
+ Pointer to the gradient output value.
+ N (int):
+ The number of columns in the input tensor.
+ B (int):
+ The block size for Triton operations.
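+
+ The caller launches a 1D grid of cdiv(N, B) programs; program i scales the
+ contiguous slice x[i*B:(i+1)*B], masking the tail when B does not divide N.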
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + i_x = tl.program_id(0).to(tl.int64) + o_x = i_x * B + tl.arange(0, B) + + # Load the gradient output value + b_g = tl.load(g) + b_x = tl.load(x + o_x, mask=o_x < N) + tl.store(x + o_x, b_x * b_g, mask=o_x < N) + + +def fused_linear_cross_entropy_forward( + x: torch.Tensor, + target: torch.LongTensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + ignore_index: int = -100, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + num_chunks: int = 8, + reduction: str = "mean" +): + device = x.device + # inputs have shape: [N, H] + # materialized activations will have shape: [N, V] + # the increase in memory = [N, V] + # reduction can be achieved by partitioning the number of tokens N into smaller chunks. + + # ideally, we would like to achieve the same memory consumption as [N, H], + # so the expected chunk size should be: + # NC = ceil(V / H) + # C = ceil(N / NC) + # for ex: N = 4096*4, V = 32000, H = 4096 ==> NC = 8, C = ceil(N / NC) = 2048 + N, H, V = *x.shape, weight.shape[0] + BV = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + # TODO: in real cases, we may need to limit the number of chunks NC to + # ensure the precisions of accumulated gradients + NC = min(num_chunks, triton.cdiv(V, H)) + C = triton.next_power_of_2(triton.cdiv(N, NC)) + NC = triton.cdiv(N, C) + + # [N, H] + dx = torch.zeros_like(x, device=device) + # [V, H] + dw = torch.zeros_like(weight, device=device, dtype=torch.float) if weight is not None else None + # [V] + db = torch.zeros_like(bias, device=device, dtype=torch.float) if bias is not None else None + # [N] + loss = torch.zeros(N, device=device, dtype=torch.float) + + total = target.ne(ignore_index).sum().item() + + for ic in range(NC): + start, end = ic * C, min((ic + 1) * C, N) + # [C, N] + c_x = x[start:end] + # when doing matmul, use the original precision + # [C, V] + c_logits = F.linear(c_x, weight, bias) + c_target = target[start:end] + # [C] + # keep lse in fp32 to maintain precision + c_lse = logsumexp_fwd(c_logits, scale=logit_scale, dtype=torch.float) + + # unreduced loss + c_loss = loss[start:end] + + # Here we calculate the gradient of c_logits in place so we can save memory. + cross_entropy_kernel[(c_logits.shape[0],)]( + logits=c_logits, + lse=c_lse, + target=c_target, + loss=c_loss, + total=total, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + logit_scale=logit_scale, + reduction=reduction, + V=V, + BV=BV, + num_warps=32 + ) + + # gradient of logits is computed in-place by the above triton kernel and is of shape: C x V + # thus dx should be of shape: C x H + dx[start:end] = torch.mm(c_logits, weight) + + # keep dw in fp32 to maintain precision + if weight is not None: + dw += c_logits.t() @ c_x + + if bias is not None: + torch.add(input=db, other=c_logits.sum(0), out=db) + + loss = loss.sum() + if dw is not None: + dw = dw.to(weight) + if db is not None: + db = db.to(bias) + return loss, dx, dw, db + + +def fused_linear_cross_entropy_backward( + do: torch.Tensor, + dx: torch.Tensor, + dw: torch.Tensor, + db: torch.Tensor +): + # If cross entropy is the last layer, do is 1.0. Skip the mul to save time + if torch.ne(do, torch.tensor(1.0, device=do.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. 
+ N, H = dx.shape + B = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + elementwise_mul_kernel[(triton.cdiv(N * H, B),)]( + x=dx, + g=do, + N=N*H, + B=B, + num_warps=32, + ) + + # handle dw + if dw is not None: + V, H = dw.shape + elementwise_mul_kernel[(triton.cdiv(V * H, B),)]( + x=dw, + g=do, + N=V*H, + B=B, + num_warps=32, + ) + + if db is not None: + V = db.shape[0] + elementwise_mul_kernel[(triton.cdiv(V, B),)]( + x=db, + g=do, + N=V, + B=B, + num_warps=32, + ) + return dx, dw, db + + +class FusedLinearCrossEntropyFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + target: torch.LongTensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + ignore_index: int = -100, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + num_chunks: int = 8, + reduction: str = "mean" + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the x and target + for the backward pass. + + x (torch.Tensor): [batch_size * seq_len, hidden_size] + target (torch.LongTensor): [batch_size * seq_len] + where each value is in [0, vocab_size). + weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + bias (Optional[torch.Tensor]): [vocab_size] + where `vocab_size` is the number of classes. + ignore_index: + the index to ignore in the target. + label_smoothing: + the amount of smoothing when computing the loss, where 0.0 means no smoothing. + logit_scale: float = 1.0, + A scaling factor applied to the logits. Default: 1.0 + num_chunks: int + The number of chunks to split the input tensor into for processing. + This can help optimize memory usage and computation speed. + Default: 8 + reduction: + Specifies the reduction to apply to the output: 'mean' | 'sum'. + 'mean': the weighted mean of the output is taken, + 'sum': the output will be summed. + Default: 'mean'. + """ + loss, dx, dw, db = fused_linear_cross_entropy_forward( + x, + target, + weight, + bias, + ignore_index, + label_smoothing, + logit_scale, + num_chunks, + reduction + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + dx.detach(), + dw.detach() if weight is not None else None, + db.detach() if bias is not None else None, + ) + return loss + + @staticmethod + @input_guard + def backward(ctx, do): + dx, dw, db = ctx.saved_tensors + dx, dw, db = fused_linear_cross_entropy_backward(do, dx, dw, db) + return dx, None, dw, db, None, None, None, None, None + + +def fused_linear_cross_entropy_loss( + x: torch.Tensor, + target: torch.LongTensor, + weight: torch.Tensor, + bias: torch.Tensor = None, + ignore_index: int = -100, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + num_chunks: int = 8, + reduction: str = "mean" +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x (torch.Tensor): [batch_size * seq_len, hidden_size] + target (torch.LongTensor): [batch_size * seq_len] + where each value is in [0, vocab_size). + weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + bias (Optional[torch.Tensor]): [vocab_size] + where `vocab_size` is the number of classes. + ignore_index: int. 
+ If target == ignore_index, the loss is set to 0.0. + label_smoothing: float + logit_scale: float + A scaling factor applied to the logits. Default: 1.0 + num_chunks: int + The number of chunks to split the input tensor into for processing. + This can help optimize memory usage and computation speed. + Default: 8 + reduction: + Specifies the reduction to apply to the output: 'mean' | 'sum'. + 'mean': the weighted mean of the output is taken, + 'sum': the output will be summed. + Default: 'mean'. + Returns: + losses: [batch,], float + """ + return FusedLinearCrossEntropyFunction.apply( + x, + target, + weight, + bias, + ignore_index, + label_smoothing, + logit_scale, + num_chunks, + reduction + ) + + +class FusedLinearCrossEntropyLoss(nn.Module): + + def __init__( + self, + ignore_index: int = -100, + label_smoothing: float = 0.0, + logit_scale: float = 1.0, + num_chunks: int = 8, + reduction: str = "mean" + ): + """ + Args: + ignore_index: int. + If target == ignore_index, the loss is set to 0.0. + label_smoothing: float + logit_scale: float + A scaling factor applied to the logits. Default: 1.0 + num_chunks: int + The number of chunks to split the input tensor into for processing. + This can help optimize memory usage and computation speed. + Default: 8 + reduction: + Specifies the reduction to apply to the output: 'mean' | 'sum'. + 'mean': the weighted mean of the output is taken, + 'sum': the output will be summed. + Default: 'mean'. + """ + super().__init__() + + assert reduction in ["mean", "sum"], f"reduction: {reduction} is not supported" + + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.logit_scale = logit_scale + self.num_chunks = num_chunks + self.reduction = reduction + + @torch.compiler.disable + def forward( + self, + x: torch.Tensor, + target: torch.LongTensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None + ): + """ + Args: + x (torch.Tensor): [batch_size, seq_len, hidden_size] + target (torch.LongTensor): [batch_size, seq_len] + where each value is in [0, V). + weight (torch.Tensor): [vocab_size, hidden_size] + where `vocab_size` is the number of classes. + bias (Optional[torch.Tensor]): [vocab_size] + where `vocab_size` is the number of classes. 
+ Returns: + loss + """ + loss = fused_linear_cross_entropy_loss( + x.view(-1, x.shape[-1]), + target.view(-1), + weight=weight, + bias=bias, + ignore_index=self.ignore_index, + label_smoothing=self.label_smoothing, + logit_scale=self.logit_scale, + num_chunks=self.num_chunks, + reduction=self.reduction + ) + return loss + + +class LinearLossParallel(ParallelStyle): + def __init__( + self, + *, + sequence_dim: int = 1, + use_local_output: bool = False, + ): + super().__init__() + + self.sequence_sharding = (Shard(sequence_dim),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + x, target, weight, bias = inputs + + if not isinstance(x, DTensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + x = DTensor.from_local(x, device_mesh, sequence_sharding) + if x.placements != sequence_sharding: + x = x.redistribute(placements=sequence_sharding, async_op=True) + if not isinstance(target, DTensor): + target = DTensor.from_local(target, device_mesh, [Replicate()]) + if target.placements != sequence_sharding: + target = target.redistribute(placements=sequence_sharding, async_op=True) + + if not isinstance(weight, DTensor): + weight = DTensor.from_local(weight, device_mesh, [Replicate()]) + if weight.placements != [Replicate()]: + # we replicate the weight/bias in FLCE + weight = weight.redistribute(placements=[Replicate()], async_op=True) + + if bias is not None and not isinstance(bias, DTensor): + bias = DTensor.from_local(bias, device_mesh, [Replicate()]) + if bias is not None and bias.placements != [Replicate()]: + bias = bias.redistribute(placements=[Replicate()], async_op=True) + + return x.to_local(), target.to_local(), weight.to_local(), bias.to_local() if bias is not None else bias + + @staticmethod + def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=None, + input_fn=partial(self._prepare_input_fn, self.sequence_sharding), + output_fn=partial(self._prepare_output_fn, self.use_local_output) + ) diff --git a/fla3/modules/fused_norm_gate.py b/fla3/modules/fused_norm_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..2704af6b5d7e0c5ef9a3b7468d9fe84da044f2f5 --- /dev/null +++ b/fla3/modules/fused_norm_gate.py @@ -0,0 +1,1257 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.utils import get_multiprocessor_count, input_guard + + +@triton.heuristics({ + 'STORE_RESIDUAL_OUT': lambda args: args['residual_out'] is not None, + 'HAS_RESIDUAL': lambda args: args['residual'] is not None, + 'HAS_WEIGHT': lambda args: args['w'] is not None, + 'HAS_BIAS': lambda args: args['b'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [8, 16, 32, 64] + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'NB', 'IS_RMS_NORM', 'STORE_RESIDUAL_OUT', 'HAS_RESIDUAL', 'HAS_WEIGHT'], +) +@triton.jit +def layer_norm_gated_fwd_kernel( + x, # pointer to the input + g, # pointer to the gate + y, # pointer to 
the output + w, # pointer to the weights + b, # pointer to the biases + residual, # pointer to the residual + residual_out, # pointer to the residual + mean, # pointer to the mean + rstd, # pointer to the 1/std + eps, # epsilon to avoid division by zero + T, # number of rows in x + D: tl.constexpr, # number of columns in x + BT: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + ACTIVATION: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + i_t = tl.program_id(0) + + o_d = tl.arange(0, BD) + m_d = o_d < D + + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + if HAS_RESIDUAL: + p_res = tl.make_block_ptr(residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32) + if STORE_RESIDUAL_OUT: + p_res_out = tl.make_block_ptr(residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1)) + if not IS_RMS_NORM: + b_mean = tl.sum(b_x, axis=1) / D + p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,)) + b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=1) / D + else: + b_xbar = tl.where(m_d[None, :], b_x, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=1) / D + b_rstd = 1 / tl.sqrt(b_var + eps) + + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,)) + + if HAS_WEIGHT: + b_w = tl.load(w + o_d, mask=m_d).to(tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_d, mask=m_d).to(tl.float32) + b_x_hat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None] + b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat + if HAS_BIAS: + b_y = b_y + b_b[None, :] + + # swish/sigmoid output gate + p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + if ACTIVATION == 'swish': + b_y = b_y * b_g * tl.sigmoid(b_g) + elif ACTIVATION == 'silu': + b_y = b_y * b_g * tl.sigmoid(b_g) + elif ACTIVATION == 'sigmoid': + b_y = b_y * tl.sigmoid(b_g) + + # Write output + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_RESIDUAL_OUT': lambda args: args['residual_out'] is not None, + 'HAS_RESIDUAL': lambda args: args['residual'] is not None, + 'HAS_WEIGHT': lambda args: args['w'] is not None, + 'HAS_BIAS': lambda args: args['b'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'IS_RMS_NORM', 'STORE_RESIDUAL_OUT', 'HAS_RESIDUAL', 'HAS_WEIGHT'], +) +@triton.jit +def layer_norm_gated_fwd_kernel1( + x, # pointer to the input + g, # pointer to the gate + y, # pointer to the output + w, # pointer to the weights + b, # pointer to the biases + residual, # pointer to the residual + residual_out, # pointer to the residual + mean, # pointer to the mean + rstd, # pointer to the 1/std + eps, # epsilon to avoid division by zero + D: tl.constexpr, # number of columns in x + BD: tl.constexpr, + 
ACTIVATION: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + g += i_t * D + if HAS_RESIDUAL: + residual += i_t * D + if STORE_RESIDUAL_OUT: + residual_out += i_t * D + + o_d = tl.arange(0, BD) + m_d = o_d < D + b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32) + if STORE_RESIDUAL_OUT: + tl.store(residual_out + o_d, b_x, mask=m_d) + if not IS_RMS_NORM: + b_mean = tl.sum(b_x, axis=0) / D + tl.store(mean + i_t, b_mean) + b_xbar = tl.where(m_d, b_x - b_mean, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=0) / D + else: + b_xbar = tl.where(m_d, b_x, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=0) / D + b_rstd = 1 / tl.sqrt(b_var + eps) + tl.store(rstd + i_t, b_rstd) + + if HAS_WEIGHT: + b_w = tl.load(w + o_d, mask=m_d).to(tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_d, mask=m_d).to(tl.float32) + b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd + b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat + if HAS_BIAS: + b_y = b_y + b_b + + # swish/sigmoid output gate + b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32) + if ACTIVATION == 'swish': + b_y = b_y * b_g * tl.sigmoid(b_g) + elif ACTIVATION == 'silu': + b_y = b_y * b_g * tl.sigmoid(b_g) + elif ACTIVATION == 'sigmoid': + b_y = b_y * tl.sigmoid(b_g) + + # Write output + tl.store(y + o_d, b_y, mask=m_d) + + +@triton.heuristics({ + 'HAS_DRESIDUAL': lambda args: args['dresidual'] is not None, + 'HAS_WEIGHT': lambda args: args['w'] is not None, + 'HAS_BIAS': lambda args: args['b'] is not None, + 'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [8, 16, 32, 64] + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'NB', 'IS_RMS_NORM', 'HAS_DRESIDUAL', 'HAS_WEIGHT'], +) +@triton.jit +def layer_norm_gated_bwd_kernel( + x, # pointer to the input + g, # pointer to the gate + w, # pointer to the weights + b, # pointer to the biases + y, # pointer to the output to be recomputed + dy, # pointer to the output gradient + dx, # pointer to the input gradient + dg, # pointer to the gate gradient + dw, # pointer to the partial sum of weights gradient + db, # pointer to the partial sum of biases gradient + dresidual, + dresidual_in, + mean, + rstd, + T, + D: tl.constexpr, + BS: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + ACTIVATION: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + i_s = tl.program_id(0) + o_d = tl.arange(0, BD) + m_d = o_d < D + if HAS_WEIGHT: + b_w = tl.load(w + o_d, mask=m_d).to(tl.float32) + b_dw = tl.zeros((BT, BD), dtype=tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_d, mask=m_d, other=0.0).to(tl.float32) + b_db = tl.zeros((BT, BD), dtype=tl.float32) + + T = min(i_s * BS + BS, T) + for i_t in range(i_s * BS, T, BT): + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + p_dg = 
tl.make_block_ptr(dg, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + # [BT, BD] + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32) + + if not IS_RMS_NORM: + p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t,), (BT,), (0,)) + b_mean = tl.load(p_mean, boundary_check=(0,)) + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t,), (BT,), (0,)) + b_rstd = tl.load(p_rstd, boundary_check=(0,)) + # Compute dx + b_xhat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None] + b_xhat = tl.where(m_d[None, :], b_xhat, 0.0) + + b_y = b_xhat * b_w[None, :] if HAS_WEIGHT else b_xhat + if HAS_BIAS: + b_y = b_y + b_b[None, :] + if RECOMPUTE_OUTPUT: + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + b_sigmoid_g = tl.sigmoid(b_g) + if ACTIVATION == 'swish': + b_dg = b_dy * b_y * (b_sigmoid_g + b_g * b_sigmoid_g * (1 - b_sigmoid_g)) + b_dy = b_dy * b_g * b_sigmoid_g + elif ACTIVATION == 'silu': + b_dg = b_dy * b_y * (b_sigmoid_g + b_g * b_sigmoid_g * (1 - b_sigmoid_g)) + b_dy = b_dy * b_g * b_sigmoid_g + elif ACTIVATION == 'sigmoid': + b_dg = b_dy * b_y * b_sigmoid_g * (1 - b_sigmoid_g) + b_dy = b_dy * b_sigmoid_g + b_wdy = b_dy + + if HAS_WEIGHT or HAS_BIAS: + m_t = (i_t + tl.arange(0, BT)) < T + if HAS_WEIGHT: + b_wdy = b_dy * b_w + b_dw += tl.where(m_t[:, None], b_dy * b_xhat, 0.0) + if HAS_BIAS: + b_db += tl.where(m_t[:, None], b_dy, 0.0) + if not IS_RMS_NORM: + b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D + b_c2 = tl.sum(b_wdy, axis=1) / D + b_dx = (b_wdy - (b_xhat * b_c1[:, None] + b_c2[:, None])) * b_rstd[:, None] + else: + b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D + b_dx = (b_wdy - b_xhat * b_c1[:, None]) * b_rstd[:, None] + if HAS_DRESIDUAL: + p_dres = tl.make_block_ptr(dresidual, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + b_dres = tl.load(p_dres, boundary_check=(0, 1)).to(tl.float32) + b_dx += b_dres + # Write dx + if STORE_DRESIDUAL: + p_dres_in = tl.make_block_ptr(dresidual_in, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0)) + tl.store(p_dres_in, b_dx.to(p_dres_in.dtype.element_ty), boundary_check=(0, 1)) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + if HAS_WEIGHT: + tl.store(dw + i_s * D + o_d, tl.sum(b_dw, axis=0), mask=m_d) + if HAS_BIAS: + tl.store(db + i_s * D + o_d, tl.sum(b_db, axis=0), mask=m_d) + + +@triton.heuristics({ + 'HAS_DRESIDUAL': lambda args: args['dresidual'] is not None, + 'HAS_WEIGHT': lambda args: args['w'] is not None, + 'HAS_BIAS': lambda args: args['b'] is not None, + 'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'IS_RMS_NORM', 'STORE_DRESIDUAL', 'HAS_DRESIDUAL', 'HAS_WEIGHT'], +) +@triton.jit +def layer_norm_gated_bwd_kernel1( + x, # pointer to the input + g, # pointer to the gate + w, # pointer to the weights + b, # pointer to the biases + y, # pointer to the output to be recomputed + dy, # pointer to the output gradient + dx, # pointer to the input gradient + dg, # pointer to the gate gradient + dw, # pointer to the partial sum of weights gradient + db, # pointer to the partial sum of biases gradient + dresidual, + dresidual_in, + 
mean, + rstd, + T, + D: tl.constexpr, + BS: tl.constexpr, + BD: tl.constexpr, + ACTIVATION: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + i_s = tl.program_id(0) + o_d = tl.arange(0, BD) + mask = o_d < D + x += i_s * BS * D + g += i_s * BS * D + if HAS_DRESIDUAL: + dresidual += i_s * BS * D + if STORE_DRESIDUAL: + dresidual_in += i_s * BS * D + dy += i_s * BS * D + dx += i_s * BS * D + dg += i_s * BS * D + if RECOMPUTE_OUTPUT: + y += i_s * BS * D + if HAS_WEIGHT: + b_w = tl.load(w + o_d, mask=mask).to(tl.float32) + b_dw = tl.zeros((BD,), dtype=tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_d, mask=mask, other=0.0).to(tl.float32) + b_db = tl.zeros((BD,), dtype=tl.float32) + + for i_t in range(i_s * BS, min(i_s * BS + BS, T)): + # Load data to SRAM + b_x = tl.load(x + o_d, mask=mask, other=0).to(tl.float32) + b_g = tl.load(g + o_d, mask=mask, other=0).to(tl.float32) + b_dy = tl.load(dy + o_d, mask=mask, other=0).to(tl.float32) + + if not IS_RMS_NORM: + b_mean = tl.load(mean + i_t) + b_rstd = tl.load(rstd + i_t) + # Compute dx + b_xhat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd + b_xhat = tl.where(mask, b_xhat, 0.0) + + b_y = b_xhat * b_w if HAS_WEIGHT else b_xhat + if HAS_BIAS: + b_y = b_y + b_b + if RECOMPUTE_OUTPUT: + tl.store(y + o_d, b_y, mask=mask) + + b_sigmoid_g = tl.sigmoid(b_g) + if ACTIVATION == 'swish': + b_dg = b_dy * b_y * (b_sigmoid_g + b_g * b_sigmoid_g * (1 - b_sigmoid_g)) + b_dy = b_dy * b_g * b_sigmoid_g + elif ACTIVATION == 'silu': + b_dg = b_dy * b_y * (b_sigmoid_g + b_g * b_sigmoid_g * (1 - b_sigmoid_g)) + b_dy = b_dy * b_g * b_sigmoid_g + elif ACTIVATION == 'sigmoid': + b_dg = b_dy * b_y * b_sigmoid_g * (1 - b_sigmoid_g) + b_dy = b_dy * b_sigmoid_g + b_wdy = b_dy + if HAS_WEIGHT: + b_wdy = b_dy * b_w + b_dw += b_dy * b_xhat + if HAS_BIAS: + b_db += b_dy + if not IS_RMS_NORM: + b_c1 = tl.sum(b_xhat * b_wdy, axis=0) / D + b_c2 = tl.sum(b_wdy, axis=0) / D + b_dx = (b_wdy - (b_xhat * b_c1 + b_c2)) * b_rstd + else: + b_c1 = tl.sum(b_xhat * b_wdy, axis=0) / D + b_dx = (b_wdy - b_xhat * b_c1) * b_rstd + if HAS_DRESIDUAL: + b_dres = tl.load(dresidual + o_d, mask=mask, other=0).to(tl.float32) + b_dx += b_dres + # Write dx + if STORE_DRESIDUAL: + tl.store(dresidual_in + o_d, b_dx, mask=mask) + tl.store(dx + o_d, b_dx, mask=mask) + tl.store(dg + o_d, b_dg, mask=mask) + + x += D + g += D + if HAS_DRESIDUAL: + dresidual += D + if STORE_DRESIDUAL: + dresidual_in += D + if RECOMPUTE_OUTPUT: + y += D + dy += D + dx += D + dg += D + if HAS_WEIGHT: + tl.store(dw + i_s * D + o_d, b_dw, mask=mask) + if HAS_BIAS: + tl.store(db + i_s * D + o_d, b_db, mask=mask) + + +def layer_norm_gated_fwd( + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str = 'swish', + eps: float = 1e-5, + residual: torch.Tensor = None, + out_dtype: torch.dtype = None, + residual_dtype: torch.dtype = None, + is_rms_norm: bool = False +): + if residual is not None: + residual_dtype = residual.dtype + T, D = x.shape + if residual is not None: + assert residual.shape == (T, D) + if weight is not None: + assert weight.shape == (D,) + if bias is not None: + assert bias.shape == (D,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(T, D, 
device=x.device, dtype=residual_dtype) + else: + residual_out = None + mean = torch.empty((T,), dtype=torch.float, device=x.device) if not is_rms_norm else None + rstd = torch.empty((T,), dtype=torch.float, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + + if D <= 512: + NB = triton.cdiv(T, 2048) + def grid(meta): return (triton.cdiv(T, meta['BT']),) + layer_norm_gated_fwd_kernel[grid]( + x=x, + g=g, + y=y, + w=weight, + b=bias, + residual=residual, + residual_out=residual_out, + mean=mean, + rstd=rstd, + eps=eps, + T=T, + D=D, + BD=BD, + NB=NB, + ACTIVATION=activation, + IS_RMS_NORM=is_rms_norm, + ) + else: + layer_norm_gated_fwd_kernel1[(T,)]( + x=x, + g=g, + y=y, + w=weight, + b=bias, + residual=residual, + residual_out=residual_out, + mean=mean, + rstd=rstd, + eps=eps, + D=D, + BD=BD, + ACTIVATION=activation, + IS_RMS_NORM=is_rms_norm, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +def layer_norm_gated_bwd( + dy: torch.Tensor, + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str = 'swish', + eps: float = 1e-5, + mean: torch.Tensor = None, + rstd: torch.Tensor = None, + dresidual: torch.Tensor = None, + has_residual: bool = False, + is_rms_norm: bool = False, + x_dtype: torch.dtype = None, + recompute_output: bool = False, +): + T, D = x.shape + assert dy.shape == (T, D) + if dresidual is not None: + assert dresidual.shape == (T, D) + if weight is not None: + assert weight.shape == (D,) + if bias is not None: + assert bias.shape == (D,) + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(T, D, dtype=x_dtype, device=x.device) + dg = torch.empty_like(g) if x_dtype is None else torch.empty(T, D, dtype=x_dtype, device=x.device) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(T, D, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + NS = get_multiprocessor_count(x.device.index) + BS = math.ceil(T / NS) + + dw = torch.empty((NS, D), dtype=torch.float, device=weight.device) if weight is not None else None + db = torch.empty((NS, D), dtype=torch.float, device=bias.device) if bias is not None else None + grid = (NS,) + + if D <= 512: + NB = triton.cdiv(T, 2048) + layer_norm_gated_bwd_kernel[grid]( + x=x, + g=g, + w=weight, + b=bias, + y=y, + dy=dy, + dx=dx, + dg=dg, + dw=dw, + db=db, + dresidual=dresidual, + dresidual_in=dresidual_in, + mean=mean, + rstd=rstd, + T=T, + D=D, + BS=BS, + BD=BD, + NB=NB, + ACTIVATION=activation, + IS_RMS_NORM=is_rms_norm, + STORE_DRESIDUAL=dresidual_in is not None, + ) + else: + layer_norm_gated_bwd_kernel1[grid]( + x=x, + g=g, + w=weight, + b=bias, + y=y, + dy=dy, + dx=dx, + dg=dg, + dw=dw, + db=db, + dresidual=dresidual, + dresidual_in=dresidual_in, + mean=mean, + rstd=rstd, + T=T, + D=D, + BS=BS, + BD=BD, + ACTIVATION=activation, + IS_RMS_NORM=is_rms_norm, + STORE_DRESIDUAL=dresidual_in is not None, + ) + dw = 
dw.sum(0).to(weight.dtype) if weight is not None else None + db = db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dg, dw, db, dresidual_in) if not recompute_output else (dx, dg, dw, db, dresidual_in, y) + + +class LayerNormGatedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + ): + x_shape_og = x.shape + g_shape_og = g.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = layer_norm_gated_fwd( + x=x, + g=g, + weight=weight, + bias=bias, + activation=activation, + eps=eps, + residual=residual, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, g, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.g_shape_og = g_shape_og + ctx.activation = activation + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + @input_guard + def backward(ctx, dy, *args): + x, g, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dg, dw, db, dres_in = layer_norm_gated_bwd( + dy=dy, + x=x, + g=g, + weight=weight, + bias=bias, + activation=ctx.activation, + eps=ctx.eps, + mean=mean, + rstd=rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dg.reshape(ctx.g_shape_og), + dw, + db, + None, + dres_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +class LayerNormGatedLinearFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + g: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + ): + x_shape_og = x.shape + g_shape_og = g.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + g = g.reshape(-1, g.shape[-1]) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = layer_norm_gated_fwd( + x=x, + g=g, + weight=norm_weight, + bias=norm_bias, + eps=eps, + residual=residual, + residual_dtype=residual_dtype, + 
is_rms_norm=is_rms_norm + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, g, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.g_shape_og = g_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @input_guard + def backward(ctx, dout, *args): + x, g, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dg, dnorm_weight, dnorm_bias, dres_in, y = layer_norm_gated_bwd( + dy=dy, + x=x, + g=g, + weight=norm_weight, + bias=norm_bias, + eps=ctx.eps, + mean=mean, + rstd=rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dg.reshape(ctx.g_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dres_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_gated( + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str = 'swish', + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + eps: float = 1e-6 +): + return LayerNormGatedFunction.apply( + x, + g, + weight, + bias, + activation, + residual, + eps, + prenorm, + residual_in_fp32, + False + ) + + +def rms_norm_gated( + x: torch.Tensor, + g: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + activation: str = 'swish', + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + eps: float = 1e-6 +): + return LayerNormGatedFunction.apply( + x, + g, + weight, + bias, + activation, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +def layer_norm_swish_gate_linear( + x: torch.Tensor, + g: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + eps: float = 1e-6 +): + return LayerNormGatedLinearFunction.apply( + x, + g, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + False + ) + + +def rms_norm_swish_gate_linear( + x, + g: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False, + eps: float = 1e-6 +): + return 
LayerNormGatedLinearFunction.apply( + x, + g, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +class FusedLayerNormGated(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + activation: str = 'swish', + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedLayerNormGated: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + self.activation = activation + + if self.activation not in ['swish', 'silu', 'sigmoid']: + raise ValueError(f"Unsupported activation: {self.activation}") + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += f", activation={self.activation}" + s += ")" + return s + + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False + ) -> torch.Tensor: + return layer_norm_gated( + x, + g, + self.weight, + self.bias, + self.activation, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedRMSNormGated(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + activation: str = 'swish', + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedRMSNormGated: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + self.activation = activation + + if self.activation not in ['swish', 'silu', 'sigmoid']: + raise ValueError(f"Unsupported activation: {self.activation}") + + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += f", activation={self.activation}" + s += ")" + return s + + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False + ) -> torch.Tensor: + return rms_norm_gated( + x, + g, + self.weight, + self.bias, + self.activation, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedLayerNormSwishGate(FusedLayerNormGated): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, 
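+        # This subclass is a named convenience alias: it keeps the parent's default
+        # activation='swish', so FusedLayerNormSwishGate(h) behaves the same as
+        # FusedLayerNormGated(h, activation='swish').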
+ bias: bool = False, + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedLayerNormSwishGate: + super().__init__( + hidden_size=hidden_size, + elementwise_affine=elementwise_affine, + bias=bias, + eps=eps, + device=device, + dtype=dtype + ) + + +class FusedRMSNormSwishGate(FusedRMSNormGated): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedRMSNormSwishGate: + super().__init__( + hidden_size=hidden_size, + elementwise_affine=elementwise_affine, + eps=eps, + device=device, + dtype=dtype + ) + + +class FusedLayerNormGatedLinear(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedLayerNormGatedLinear: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward( + self, + x: torch.Tensor, + g: torch.Tensor, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False + ) -> torch.Tensor: + return layer_norm_swish_gate_linear( + x, + g, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedLayerNormSwishGateLinear(FusedLayerNormGatedLinear): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedLayerNormSwishGateLinear: + super().__init__( + hidden_size=hidden_size, + elementwise_affine=elementwise_affine, + eps=eps, + device=device, + dtype=dtype + ) + + +class FusedRMSNormGatedLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedRMSNormGatedLinear: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward( + self, + x: torch.Tensor, + 
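+        # g is the gate tensor (same trailing shape as x); the `weight`/`bias`
+        # arguments below are the output projection's parameters, while self.weight
+        # is the RMSNorm gain applied inside rms_norm_swish_gate_linear.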
g: torch.Tensor, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None, + prenorm: bool = False, + residual_in_fp32: bool = False + ) -> torch.Tensor: + return rms_norm_swish_gate_linear( + x, + g, + self.weight, + self.bias, + weight, + bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class FusedRMSNormSwishGateLinear(FusedRMSNormGatedLinear): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> FusedRMSNormSwishGateLinear: + super().__init__( + hidden_size=hidden_size, + elementwise_affine=elementwise_affine, + eps=eps, + device=device, + dtype=dtype + ) diff --git a/fla3/modules/grpo.py b/fla3/modules/grpo.py new file mode 100644 index 0000000000000000000000000000000000000000..4f12bb8e4fff346bee6e028193ec8e20709ece40 --- /dev/null +++ b/fla3/modules/grpo.py @@ -0,0 +1,409 @@ +# -*- coding: utf-8 -*- +# modified from https://github.com/mdy666/mdy_triton/blob/e0a856347bd988e05e0152332bba35f1d33c5b1f/others/grpo/grpo_loss.ipynb +# XHS ID: blueeeee + +# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py +""" +# Get the per-token log probabilities for the completions for the model and the reference model + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + + input_ids = input_ids[:, -logits_to_keep:] + # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
+ # See https://github.com/huggingface/trl/issues/2770 + logits = logits[:, -logits_to_keep:] + return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + # Compute the per-token log probabilities for the model + + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + + # Compute the KL divergence between the model and the reference model + ref_per_token_logps = inputs["ref_per_token_logps"] + per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + + # x - x.detach() allows for preserving gradients from x + advantages = inputs["advantages"] + per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * per_token_kl) + loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + + # Log the metrics + completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + self._metrics["completion_length"].append(completion_length) + + mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() + self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + + return loss +""" + + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log +from fla.utils import input_guard + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': BLOCK_SIZE}, num_warps=NUM_WARPS, num_stages=NUM_STAGES) + for BLOCK_SIZE in [1024, 2048, 4096, 8192] + for NUM_WARPS in [8, 16, 32] + for NUM_STAGES in [1, 2, 4] + ], + key=['B', 'N'] +) +@triton.jit +def grpo_fwd_kernel( + logits_ptr, + ref_logp_ptr, + input_ids_ptr, + advantages_ptr, + completion_mask_ptr, + loss_ptr, + lse_ptr, + beta, + save_kl: tl.constexpr, + B, + M, + N, + L, + start_idx, + BLOCK_SIZE: tl.constexpr +): + row_idx = tl.program_id(0) + + off_b = row_idx // L + N = tl.cast(N, tl.int64) + + loss_ptr += row_idx + + completion_mask_ptr += row_idx + not_skip = tl.load(completion_mask_ptr).to(tl.int1) + if not_skip == 1: + ref_logp_ptr += row_idx + lse_ptr += row_idx + advantages_ptr += off_b + logits_ptr += N * (row_idx + off_b) + input_ids_ptr += row_idx + (off_b+1) * start_idx + base_cols = tl.arange(0, BLOCK_SIZE) + + m_i = -float("inf") + l_i = 0.0 + for start_n in tl.range(0, N, BLOCK_SIZE): + cols = start_n + base_cols + mask = cols < N + logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32) + m_ij = tl.max(logits) + new_m_i = tl.maximum(m_i, m_ij) + l_i = l_i * exp(m_i - new_m_i) + tl.sum(exp(logits - new_m_i)) + m_i = new_m_i + lse = log(l_i) + m_i + + idx = tl.load(input_ids_ptr) + x = tl.load(logits_ptr+idx).to(tl.float32) + advantage = tl.load(advantages_ptr).to(tl.float32) + ref_logp = tl.load(ref_logp_ptr) + logp = x - lse + diff = ref_logp - logp + kl = 
exp(diff) - diff - 1
+        loss = kl * beta - advantage
+
+        tl.store(loss_ptr, loss.to(loss_ptr.dtype.element_ty))
+        tl.store(lse_ptr, lse.to(lse_ptr.dtype.element_ty))
+        if save_kl:
+            tl.store(loss_ptr+M, kl.to(loss_ptr.dtype.element_ty))
+    else:
+        # masked-out rows contribute zero loss (and zero KL)
+        tl.store(loss_ptr, 0.0)
+        if save_kl:
+            tl.store(loss_ptr+M, 0.0)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=NUM_WARPS, num_stages=NUM_STAGES)
+        for NUM_WARPS in [32]
+        for NUM_STAGES in [4]
+    ],
+    key=['B', 'N']
+)
+@triton.jit
+def grpo_bwd_kernel(
+    dloss_ptr,
+    dlogits_ptr,
+    logits_ptr,
+    ref_logp_ptr,
+    input_ids_ptr,
+    advantages_ptr,
+    completion_mask_ptr,
+    lse_ptr,
+    beta,
+    B,
+    N,
+    L,
+    start_idx,
+    BLOCK_SIZE: tl.constexpr
+):
+
+    row_idx = tl.program_id(0)  # B*L
+    off_b = row_idx // L
+
+    N = tl.cast(N, tl.int64)
+
+    dlogits_ptr += N * (row_idx + off_b)
+    base_cols = tl.arange(0, BLOCK_SIZE)
+    completion_mask_ptr += row_idx
+    not_skip = tl.load(completion_mask_ptr).to(tl.int1)
+
+    if not_skip == 1:
+        lse_ptr += row_idx
+        dloss_ptr += row_idx
+        advantages_ptr += off_b
+        ref_logp_ptr += row_idx
+        logits_ptr += N * (row_idx + off_b)
+        input_ids_ptr += row_idx + (off_b+1) * start_idx
+        dloss = tl.load(dloss_ptr).to(tl.float32)
+        lse = tl.load(lse_ptr).to(tl.float32)
+        idx = tl.load(input_ids_ptr)
+        x = tl.load(logits_ptr+idx).to(tl.float32)
+        advantage = tl.load(advantages_ptr).to(tl.float32)
+        ref_logp = tl.load(ref_logp_ptr)
+        # Barrier needed for correctness when the gradient is computed in place
+        # (dlogits may alias logits).
+        tl.debug_barrier()
+        logp = x - lse
+
+        dlogp = (beta * (-1.0 * exp(ref_logp - logp) + 1)
+                 - advantage) * dloss
+
+        for start_n in tl.range(0, N, BLOCK_SIZE):
+            cols = start_n + base_cols
+            mask = cols < N
+            logits = tl.load(logits_ptr+cols, mask=mask, other=-float('inf')).to(tl.float32)
+            probs = exp(logits - lse)
+            dlogits = tl.where(cols == idx, 1-probs, -probs) * dlogp
+
+            tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
+    else:
+        dlogits = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+        for start_n in tl.range(0, N, BLOCK_SIZE):
+            cols = start_n + base_cols
+            mask = cols < N
+
+            tl.store(dlogits_ptr+cols, dlogits.to(dlogits_ptr.dtype.element_ty), mask=mask)
+
+
+class GrpoLoss(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    def forward(ctx, logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl, inplace=True):
+        ctx.input_shape = logits.shape
+        B, L_ADD_1, N = ctx.input_shape
+        L = L_ADD_1 - 1
+        M = B * L
+        input_ids_start_index = input_ids.size(1) - L
+
+        if not save_kl:
+            loss = torch.empty(B, L, device=logits.device, dtype=torch.float32)
+        else:
+            loss = torch.empty(B*2, L, device=logits.device, dtype=torch.float32)
+
+        lse = torch.empty(B, L, device=logits.device, dtype=torch.float32)
+
+        if completion_mask is None:
+            completion_mask = torch.ones(B, L, device=logits.device, dtype=torch.int32)
+        else:
+            loss[:B].masked_fill_(completion_mask.logical_not(), 0.0)
+
+        grpo_fwd_kernel[(M,)](
+            logits_ptr=logits,
+            ref_logp_ptr=ref_logp,
+            input_ids_ptr=input_ids,
+            advantages_ptr=advantages,
+            completion_mask_ptr=completion_mask,
+            loss_ptr=loss,
+            lse_ptr=lse,
+            beta=beta,
+            save_kl=save_kl,
+            B=B, M=M, N=N, L=L,
+            start_idx=input_ids_start_index,
+        )
+        ctx.beta = beta
+        ctx.save_for_backward(lse, logits, input_ids, advantages, completion_mask)
+        ctx.ref_logp = ref_logp
+        ctx.inplace = inplace
+        return loss
+
+    @staticmethod
+    @input_guard
+    def backward(ctx, dloss):
+        # The grad of logits comes from two parts: the reward part and the KL part
+        lse, logits, input_ids, advantages, completion_mask = ctx.saved_tensors
+        inplace = ctx.inplace
+        B, L_ADD_1, N = ctx.input_shape
+        L = L_ADD_1 - 1
+        M = B * L
+
+        input_ids_start_index = input_ids.size(1) - L
+
+        # dlogits: [B, L_ADD_1, N]
+        dlogits = logits if inplace else torch.empty_like(logits)
+        BN = min(65536, triton.next_power_of_2(N))
+
+        grpo_bwd_kernel[(M,)](
+            dloss_ptr=dloss,
+            dlogits_ptr=dlogits,
+            logits_ptr=logits,
+            ref_logp_ptr=ctx.ref_logp,
+            input_ids_ptr=input_ids,
+            advantages_ptr=advantages,
+            completion_mask_ptr=completion_mask,
+            lse_ptr=lse,
+            beta=ctx.beta,
+            B=B, N=N, L=L,
+            BLOCK_SIZE=BN,
+            start_idx=input_ids_start_index,
+        )
+        # The last token in the completion is not used in the loss computation,
+        # so its gradient is explicitly zeroed.
+        dlogits[:, -1, :].fill_(0.0)
+        return dlogits.view(*ctx.input_shape), None, None, None, None, None, None, None
+
+
+def fused_grpo_loss(logits, ref_logp, input_ids, advantages,
+                    beta=0.1, completion_mask=None, save_kl=False, inplace=False) -> torch.Tensor:
+    '''
+    Compute the GRPO loss with no extra memory usage and a large speedup (~6x on A800).
+
+    Args:
+        logits: Tensor, [B, L+1, vocab_size], the raw model output; do NOT slice it to logits[:, :-1]
+        ref_logp: Tensor, [B, L], the per-token log-probs under the reference model (already gathered)
+        input_ids: Tensor, [B, K+L], the prompt_completion_ids, i.e. prompt ids followed by output ids
+        advantages: Tensor, [B], the advantage of each prompt
+        beta: float, the weight of the KL loss
+        completion_mask: Tensor, loss mask
+        save_kl: bool, if True the per-token KL is returned as well
+
+    Return:
+        loss: Tensor, [B, L], the GRPO loss, containing the advantage part and the KL part
+
+    NOTE: logits (and ref_logits) are computed by these steps
+        logits_to_keep = completion_ids.size(1)
+
+        def get_per_token_logits(model, input_ids, attention_mask, logits_to_keep):
+            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+            logits = model(
+                input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
+            ).logits
+            return logits
+
+        logits = get_per_token_logits(model, prompt_completion_ids, attention_mask, logits_to_keep)
+    '''
+    out = GrpoLoss.apply(logits, ref_logp, input_ids, advantages, beta, completion_mask, save_kl, inplace)
+    if not save_kl:
+        return out
+    else:
+        return out.chunk(2, dim=0)
+
+
+def grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1, completion_mask=None, save_kl=False):
+    def get_log_probs(logits, input_ids):
+        per_token_logps = []
+        for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1):]):
+            log_probs = logits_row.log_softmax(dim=-1)
+            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
+            per_token_logps.append(token_log_prob)
+        return torch.stack(per_token_logps)
+
+    logits = logits[:, :-1]
+    per_token_logps = get_log_probs(logits, input_ids)
+    ref_per_token_logps = ref_logp
+    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+
+    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+    per_token_loss = -(per_token_loss - beta * per_token_kl)
+    if completion_mask is not None:
+        per_token_loss *= completion_mask
+        if save_kl:
+            per_token_kl *= completion_mask
+    return per_token_loss if not save_kl else (per_token_loss, per_token_kl)
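+
+
+# A minimal usage sketch (shapes and values are hypothetical placeholders, kept
+# in comments so nothing runs at import time); `grpo_loss_torch` above serves as
+# the eager reference implementation for checking the fused kernel:
+#
+#     B, L, K, V = 2, 16, 8, 1024
+#     logits = torch.randn(B, L + 1, V, device='cuda', requires_grad=True)
+#     ref_logp = torch.randn(B, L, device='cuda')
+#     input_ids = torch.randint(0, V, (B, K + L), device='cuda')
+#     advantages = torch.randn(B, device='cuda')
+#     loss = fused_grpo_loss(logits, ref_logp, input_ids, advantages, beta=0.1)
+#     loss_ref = grpo_loss_torch(logits, ref_logp, input_ids, advantages, beta=0.1)
+#     torch.testing.assert_close(loss, loss_ref, rtol=1e-3, atol=1e-3)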
+
+
+@torch.compile(fullgraph=True)
+def grpo_loss_with_old_logps(
+    logps: torch.Tensor,
+    ref_logps: torch.Tensor,
+    old_logps: torch.Tensor,
+    pad_mask: torch.Tensor,
+    logits_to_keep: int,
+    rewards: torch.Tensor,
+    beta: float = 0.2,
+    epsilon: float = 0.2
+):
+    """
+    Compute the GRPO (Group Relative Policy Optimization) loss.
+
+    Args:
+        logps (torch.Tensor): [Batch, Token_length] Log probabilities of the current policy.
+        ref_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the reference policy.
+        old_logps (torch.Tensor): [Batch, Token_length] Log probabilities of the old policy.
+        pad_mask (torch.Tensor): [Batch, Token_length] Boolean mask marking non-padding completion positions.
+        logits_to_keep (int): Number of logits to keep for masking.
+        rewards (torch.Tensor): [Batch] Rewards for each generation.
+        beta (float) = 0.2: A hyperparameter weighting the KL divergence term.
+        epsilon (float) = 0.2: A float hyperparameter for clipping the importance weights.
+
+    Returns:
+        torch.Tensor: The computed GRPO loss.
+    """
+    B = logps.shape[0]
+    assert B > 1, "Batch * Num generations should be greater than 1"
+
+    rewards_shaped = rewards.view(-1, B)  # B, num_generations
+    advantages = (rewards_shaped - rewards_shaped.mean(dim=1, keepdim=True)) / \
+        (rewards_shaped.std(dim=1, keepdim=True) + 1e-8)
+    advantages = advantages.view(-1)  # B * num_generations
+    # Calculate the per-token KL divergence
+    per_token_kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1
+
+    # Calculate the ratio of probabilities (importance weights),
+    # computed as exp(log_pi_theta - log_pi_theta_old)
+    importance_weights = torch.exp(logps - old_logps)
+
+    # Clip the importance weights to the range [1 - epsilon, 1 + epsilon]
+    importance_weights_clipped = torch.clamp(importance_weights, 1 - epsilon, 1 + epsilon)
+
+    # Create a completion mask over the logits_to_keep positions
+    # (all-True by construction; the padding mask below does the real masking)
+    completion_mask = torch.arange(logits_to_keep, device=logps.device)[None, :] >= 0
+
+    # Combine the completion mask and padding mask
+    completion_mask = completion_mask & pad_mask  # Ensure matching shape
+
+    # Add an extra dimension to advantages for element-wise multiplication
+    advantages = advantages.unsqueeze(1)
+
+    # Calculate the per-token loss: take the minimum of the unclipped and clipped
+    # surrogate terms, subtract the beta-weighted KL term, then apply the mask
+    token_loss = -(torch.min(advantages * importance_weights, advantages *
+                   importance_weights_clipped) - beta * per_token_kl) * completion_mask
+
+    # Sum the token losses and normalize by the number of valid tokens;
+    # token_loss already carries the minus sign, so it is summed directly
+    loss = token_loss.sum() / completion_mask.sum()
+
+    return loss
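+
+
+# A small worked example of the group-relative advantage above, with illustrative
+# numbers: for one prompt with four generations and rewards [1, 0, 0, 1], the
+# group mean is 0.5 and the (unbiased) std is ~0.577, so the advantages come out
+# to roughly [0.87, -0.87, -0.87, 0.87]; each completion is scored against its
+# own group rather than a learned value baseline.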
diff --git a/fla3/modules/l2norm.py b/fla3/modules/l2norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..04e670eab4abf180ec649cf50de7d7551550c3bc
--- /dev/null
+++ b/fla3/modules/l2norm.py
@@ -0,0 +1,276 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import triton
+import triton.language as tl
+
+from fla.utils import input_guard
+
+BT_LIST = [8, 16, 32, 64, 128]
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16, 32]
+    ],
+    key=['D']
+)
+@triton.jit
+def l2norm_fwd_kernel1(
+    x,
+    y,
+    D,
+    BD: tl.constexpr,
+    eps,
+):
+    i_t = tl.program_id(0)
+    x += i_t * D
+    y += i_t * D
+    # Compute the squared norm of the row
+    cols = tl.arange(0, BD)
+    mask = cols < D
+    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
+    b_var = tl.sum(b_x * b_x, axis=0)
+    b_rstd = 1 / tl.sqrt(b_var + eps)
+    # Normalize
+    b_y = b_x * b_rstd
+    tl.store(y + cols, b_y, mask=mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16, 32]
+    ],
+    key=['D']
+)
+@triton.jit
+def l2norm_bwd_kernel1(
+    x,
+    dy,
+    dx,
+    eps,
+    D,
+    BD: tl.constexpr,
+):
+    i_t = tl.program_id(0)
+    x += i_t * D
+    dx += i_t * D
+    dy += i_t * D
+
+    cols = tl.arange(0, BD)
+    mask = cols < D
+    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
+    b_var = tl.sum(b_x * b_x)
+    b_rstd = 1 / tl.sqrt(b_var + eps)
+    b_dy = tl.load(dy + cols, mask=mask, other=0.0).to(tl.float32)
+    # Backward of y = x * rstd: dx = rstd * dy - (x . dy) / (var + eps) * rstd * x
+    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x) * (1 / (b_var+eps)) * b_rstd * b_x
+    tl.store(dx + cols, b_dx, mask=mask)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BT': BT}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16]
+        for BT in BT_LIST
+    ],
+    key=['D', 'NB']
+)
+@triton.jit
+def l2norm_fwd_kernel(
+    x,
+    y,
+    eps,
+    NB: tl.constexpr,
+    T: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+):
+    i_t = tl.program_id(0)
+    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+    b_var = tl.sum(b_x * b_x, axis=1)
+    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
+    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BT': BT}, num_warps=num_warps)
+        for num_warps in [1, 2, 4, 8, 16]
+        for BT in BT_LIST
+    ],
+    key=['D', 'NB']
+)
+@triton.jit
+def l2norm_bwd_kernel(
+    x,
+    dy,
+    dx,
+    eps,
+    NB: tl.constexpr,
+    T: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+):
+    i_t = tl.program_id(0)
+    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+    b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
+    b_var = tl.sum(b_x * b_x, axis=1)[:, None]
+    b_rstd = 1 / tl.sqrt(b_var + eps)
+    # Same backward formula as l2norm_bwd_kernel1, vectorized over BT rows
+    b_dx = b_dy * b_rstd - tl.sum(b_dy * b_x, axis=1)[:, None] / (b_var+eps) * b_rstd * b_x
+    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
+
+
+def l2norm_fwd(
+    x: torch.Tensor,
+    eps: float = 1e-6,
+    output_dtype: Optional[torch.dtype] = None
+):
+    x_shape_og = x.shape
+    x = x.view(-1, x.shape[-1])
+    # allocate output
+    if output_dtype is None:
+        y = torch.empty_like(x)
+    else:
+        y = torch.empty_like(x, dtype=output_dtype)
+    assert y.stride(-1) == 1
+    T, D = x.shape[0], x.shape[-1]
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
+    if D > BD:
+        raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.")
+
+    if D <= 512:
+        NB = triton.cdiv(T, 2048)
+        def grid(meta): return (triton.cdiv(T, meta['BT']), )
+        l2norm_fwd_kernel[grid](
+            x,
+            y,
+            eps,
+            NB=NB,
+            T=T,
+            D=D,
+            BD=BD,
+        )
+    else:
+        l2norm_fwd_kernel1[(T,)](
+            x,
+            y,
+            eps=eps,
+            D=D,
+            BD=BD,
+        )
+
+    return y.view(x_shape_og)
+
+
+def l2norm_bwd(
+    x: torch.Tensor,
+    dy: torch.Tensor,
+    eps: float = 1e-6
+):
+    x_shape_og = x.shape
+    x = x.view(-1, dy.shape[-1])
+    dy = dy.view(-1, dy.shape[-1])
+    assert dy.shape == x.shape
+    # allocate output
+    dx = torch.empty_like(x)
+    T, D = x.shape[0], x.shape[-1]
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
+    if D > BD:
+        raise RuntimeError("This l2 norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+
+    if D <= 512:
+        NB = triton.cdiv(T, 2048)
+        def grid(meta): return (triton.cdiv(T, meta['BT']), )
+        l2norm_bwd_kernel[grid](
+            x,
+            dy,
+            dx,
+            eps=eps,
+            NB=NB,
+            T=T,
+            D=D,
+            BD=BD,
+        )
+    else:
+        l2norm_bwd_kernel1[(T,)](
+            x,
+            dy,
+            dx,
+            eps=eps,
+            D=D,
+            BD=BD,
+        )
+
+    return dx.view(x_shape_og)
+
+
+class L2NormFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    def forward(
+        ctx,
+        x,
+        eps=1e-6,
+        output_dtype=None
+    ):
+        y = l2norm_fwd(x, eps, output_dtype)
+        ctx.eps = eps
+        ctx.x_dtype = x.dtype
+        ctx.save_for_backward(x)
+        return y
+
+    @staticmethod
+    @input_guard
+    def backward(ctx, dy):
+        x, = ctx.saved_tensors
+        dx = l2norm_bwd(x, dy, ctx.eps)
+        return dx, None, None
+
+
+def l2norm(
+    x: torch.Tensor,
+    eps: float = 1e-6,
+    output_dtype: Optional[torch.dtype] = None
+) -> torch.Tensor:
+    return L2NormFunction.apply(x, eps, output_dtype)
+
+
+l2_norm = l2norm
+
+
+class L2Norm(nn.Module):
+
+    def __init__(
+        self,
+        eps: float = 1e-6,
+        output_dtype: Optional[torch.dtype] = None
+    ):
+        super().__init__()
+        self.eps = eps
+        self.output_dtype = output_dtype
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return l2norm(x, self.eps, self.output_dtype)
diff --git a/fla3/modules/layernorm.py b/fla3/modules/layernorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3aa3bf2b523e64a4976cec8ea85bfd9daeca990
--- /dev/null
+++ b/fla3/modules/layernorm.py
@@ -0,0 +1,1432 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+# Copyright (c) 2023, Tri Dao
+# 
https://github.com/state-spaces/mamba/blob/fb7b5310fa865dbd62aa059b1e26f2b431363e2a/mamba_ssm/ops/triton/layernorm.py +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +from __future__ import annotations + +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module +from torch.distributed.tensor.parallel import ParallelStyle + +from fla.utils import get_multiprocessor_count, input_guard + + +def layer_norm_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + upcast: bool = False +): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + upcast: bool = False +): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +def group_norm_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + num_groups: int, + residual: torch.Tensor = None, + eps: float = 1e-5, + is_rms_norm: bool = False, + prenorm: bool = False, + upcast: bool = False +): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + residual = x + x, weight = [ + rearrange(data, "... (g d) -> ... g d", g=num_groups) for data in (x, weight) + ] + if bias is not None: + bias = rearrange(bias, '... (g d) -> ... g d', g=num_groups) + if not is_rms_norm: + mean = x.mean(dim=-1, keepdim=True) + x = x - mean + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = rearrange(out, "... g d -> ... 
(g d)") + out = out.to(dtype) + return out if not prenorm else (out, residual) + + +class GroupNormRef(nn.Module): + + def __init__( + self, + num_groups: int, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5, + is_rms_norm: bool = False + ) -> GroupNormRef: + super().__init__() + + if hidden_size % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + + self.num_groups = num_groups + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + self.is_rms_norm = is_rms_norm + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + if self.is_rms_norm: + s += f", is_rms_norm={self.is_rms_norm}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False): + return group_norm_ref( + x, + self.weight, + self.bias, + num_groups=self.num_groups, + residual=residual, + eps=self.eps, + is_rms_norm=self.is_rms_norm, + prenorm=prenorm, + upcast=True + ) + + +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [8, 16, 32, 64, 128] + for num_warps in [1, 2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'NB', 'HAS_RESIDUAL', 'STORE_RESIDUAL_OUT', 'IS_RMS_NORM'], +) +@triton.jit +def layer_norm_fwd_kernel( + x, # pointer to the input + y, # pointer to the output + w, # pointer to the weights + b, # pointer to the biases + res, # pointer to the res + res_out, # pointer to the res + mean, # pointer to the mean + rstd, # pointer to the 1/std + eps, # epsilon to avoid division by zero + T, + G: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + i_t = tl.program_id(0) + + o_t = i_t * BT + tl.arange(0, BT) + o_g = o_t % G + o_d = tl.arange(0, BD) + m_d = o_d < D + + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + if HAS_RESIDUAL: + p_res = tl.make_block_ptr(res, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32) + if STORE_RESIDUAL_OUT: + p_res_out = tl.make_block_ptr(res_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1)) + if not IS_RMS_NORM: + b_mean = tl.sum(b_x, axis=1) / D + p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,)) + b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=1) / D + else: + b_xbar = tl.where(m_d[None, :], b_x, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=1) / D + b_rstd = 1 / tl.sqrt(b_var + eps) + + p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), 
(BT,), (0,)) + tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,)) + + if HAS_WEIGHT: + b_w = tl.load(w + o_g[:, None] * D + o_d[None, :], mask=m_d[None, :]).to(tl.float32) + if HAS_BIAS: + b_b = tl.load(b + o_g[:, None] * D + o_d[None, :], mask=m_d[None, :]).to(tl.float32) + b_x_hat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None] + b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat + if HAS_BIAS: + b_y = b_y + b_b + + # Write output + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'HAS_RESIDUAL', 'STORE_RESIDUAL_OUT', 'IS_RMS_NORM'], +) +@triton.jit +def layer_norm_fwd_kernel1( + x, # pointer to the input + y, # pointer to the output + w, # pointer to the weights + b, # pointer to the biases + res, # pointer to the res + res_out, # pointer to the res + mean, # pointer to the mean + rstd, # pointer to the 1/std + eps, # epsilon to avoid division by zero + G: tl.constexpr, + D: tl.constexpr, + BD: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr +): + i_t = tl.program_id(0) + i_g = i_t % G + + x += i_t * D + y += i_t * D + if HAS_RESIDUAL: + res += i_t * D + if STORE_RESIDUAL_OUT: + res_out += i_t * D + + o_d = tl.arange(0, BD) + m_d = o_d < D + b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + b_x += tl.load(res + o_d, mask=m_d, other=0.0).to(tl.float32) + if STORE_RESIDUAL_OUT: + tl.store(res_out + o_d, b_x, mask=m_d) + if not IS_RMS_NORM: + b_mean = tl.sum(b_x, axis=0) / D + tl.store(mean + i_t, b_mean) + b_xbar = tl.where(m_d, b_x - b_mean, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=0) / D + else: + b_xbar = tl.where(m_d, b_x, 0.0) + b_var = tl.sum(b_xbar * b_xbar, axis=0) / D + b_rstd = 1 / tl.sqrt(b_var + eps) + tl.store(rstd + i_t, b_rstd) + + if HAS_WEIGHT: + b_w = tl.load(w + i_g * D + o_d, mask=m_d).to(tl.float32) + if HAS_BIAS: + b_b = tl.load(b + i_g * D + o_d, mask=m_d).to(tl.float32) + b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd + b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat + if HAS_BIAS: + b_y = b_y + b_b + + # Write output + tl.store(y + o_d, b_y, mask=m_d) + + +@triton.heuristics({ + 'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [8, 16, 32, 64] + for num_warps in [1, 2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'NB', 'HAS_DRESIDUAL', 'STORE_DRESIDUAL', 'IS_RMS_NORM'], +) +@triton.jit +def layer_norm_bwd_kernel( + x, # pointer to the input + w, # pointer to the weights + b, # pointer to the biases + y, # pointer to the output to be recomputed + dy, # pointer to the output gradient + dx, # pointer to the input gradient + dw, # pointer to the partial sum of weights gradient + db, # pointer to the partial sum of biases gradient + dres, + dres_in, + mean, + rstd, + T, + G: tl.constexpr, + D: tl.constexpr, + BS: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NB: tl.constexpr, + GS: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + 
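+    # Grid layout: program i_s owns one (group, row-slice) pair, where i_s // GS
+    # selects the group and i_s % GS the slice of rows; each program writes a
+    # partial dw/db tile that the host later reduces over slices.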
HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + i_s = tl.program_id(0) + i_g, i_sg = i_s // GS, i_s % GS + + o_d = tl.arange(0, BD) + m_d = o_d < D + if HAS_WEIGHT: + b_w = tl.load(w + i_g * D + o_d, mask=m_d).to(tl.float32) + b_dw = tl.zeros((BT, BD), dtype=tl.float32) + if HAS_BIAS: + b_b = tl.load(b + i_g * D + o_d, mask=m_d, other=0.0).to(tl.float32) + b_db = tl.zeros((BT, BD), dtype=tl.float32) + + T = min(i_sg * BS + BS, T // G) + for i_t in range(i_sg * BS, T, BT): + p_x = tl.make_block_ptr(x + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0)) + p_dy = tl.make_block_ptr(dy + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0)) + # [BT, BD] + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32) + + if not IS_RMS_NORM: + p_mean = tl.make_block_ptr(mean + i_g, (T,), (G,), (i_t,), (BT,), (0,)) + b_mean = tl.load(p_mean, boundary_check=(0,)) + p_rstd = tl.make_block_ptr(rstd + i_g, (T,), (G,), (i_t,), (BT,), (0,)) + b_rstd = tl.load(p_rstd, boundary_check=(0,)) + # Compute dx + b_xhat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None] + b_xhat = tl.where(m_d[None, :], b_xhat, 0.0) + + b_y = b_xhat * b_w[None, :] if HAS_WEIGHT else b_xhat + if HAS_BIAS: + b_y = b_y + b_b[None, :] + if RECOMPUTE_OUTPUT: + p_y = tl.make_block_ptr(y + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + b_wdy = b_dy + + if HAS_WEIGHT or HAS_BIAS: + m_t = (i_t + tl.arange(0, BT)) < T + if HAS_WEIGHT: + b_wdy = b_dy * b_w + b_dw += tl.where(m_t[:, None], b_dy * b_xhat, 0.0) + if HAS_BIAS: + b_db += tl.where(m_t[:, None], b_dy, 0.0) + if not IS_RMS_NORM: + b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D + b_c2 = tl.sum(b_wdy, axis=1) / D + b_dx = (b_wdy - (b_xhat * b_c1[:, None] + b_c2[:, None])) * b_rstd[:, None] + else: + b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D + b_dx = (b_wdy - b_xhat * b_c1[:, None]) * b_rstd[:, None] + if HAS_DRESIDUAL: + p_dres = tl.make_block_ptr(dres + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0)) + b_dres = tl.load(p_dres, boundary_check=(0, 1)).to(tl.float32) + b_dx += b_dres + # Write dx + if STORE_DRESIDUAL: + p_dres_in = tl.make_block_ptr(dres_in + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0)) + tl.store(p_dres_in, b_dx.to(p_dres_in.dtype.element_ty), boundary_check=(0, 1)) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + + if HAS_WEIGHT: + tl.store(dw + i_s * D + o_d, tl.sum(b_dw, axis=0), mask=m_d) + if HAS_BIAS: + tl.store(db + i_s * D + o_d, tl.sum(b_db, axis=0), mask=m_d) + + +@triton.heuristics({ + 'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['D', 'HAS_DRESIDUAL', 'STORE_DRESIDUAL', 'IS_RMS_NORM'], +) +@triton.jit +def layer_norm_bwd_kernel1( + x, # pointer to the input + w, # pointer to the weights + b, # pointer to the biases + y, # pointer to the output to be recomputed + dy, # pointer to the output gradient + dx, # pointer to the input gradient + dw, # pointer to the partial sum of weights gradient + db, # pointer to the partial sum of biases gradient + dres, + dres_in, + mean, + rstd, + T, + G: tl.constexpr, + D: tl.constexpr, + BS: tl.constexpr, + BD: 
tl.constexpr, + GS: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + i_s = tl.program_id(0) + i_g, i_sg = i_s // GS, i_s % GS + + o_d = tl.arange(0, BD) + mask = o_d < D + + if HAS_WEIGHT: + b_w = tl.load(w + i_g * D + o_d, mask=mask).to(tl.float32) + b_dw = tl.zeros((BD,), dtype=tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b_b = tl.load(b + i_g * D + o_d, mask=mask, other=0.0).to(tl.float32) + if HAS_BIAS: + b_db = tl.zeros((BD,), dtype=tl.float32) + + for i_t in range(i_sg * BS * G + i_g, min((i_sg * BS + BS) * G + i_g, T), G): + b_x = tl.load(x + i_t * D + o_d, mask=mask, other=0).to(tl.float32) + b_dy = tl.load(dy + i_t * D + o_d, mask=mask, other=0).to(tl.float32) + + if not IS_RMS_NORM: + b_mean = tl.load(mean + i_t) + b_rstd = tl.load(rstd + i_t) + # Compute dx + b_xhat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd + b_xhat = tl.where(mask, b_xhat, 0.0) + if RECOMPUTE_OUTPUT: + b_y = b_xhat * b_w if HAS_WEIGHT else b_xhat + if HAS_BIAS: + b_y = b_y + b_b + tl.store(y + i_t * D + o_d, b_y, mask=mask) + b_wdy = b_dy + if HAS_WEIGHT: + b_wdy = b_dy * b_w + b_dw += b_dy * b_xhat + if HAS_BIAS: + b_db += b_dy + if not IS_RMS_NORM: + b_c1 = tl.sum(b_xhat * b_wdy, axis=0) / D + b_c2 = tl.sum(b_wdy, axis=0) / D + b_dx = (b_wdy - (b_xhat * b_c1 + b_c2)) * b_rstd + else: + b_c1 = tl.sum(b_xhat * b_wdy, axis=0) / D + b_dx = (b_wdy - b_xhat * b_c1) * b_rstd + if HAS_DRESIDUAL: + b_dres = tl.load(dres + i_t * D + o_d, mask=mask, other=0).to(tl.float32) + b_dx += b_dres + # Write dx + b_dx = tl.cast(b_dx, dtype=dx.dtype.element_ty, fp_downcast_rounding='rtne') + if STORE_DRESIDUAL: + tl.store(dres_in + i_t * D + o_d, b_dx, mask=mask) + tl.store(dx + i_t * D + o_d, b_dx, mask=mask) + + if HAS_WEIGHT: + tl.store(dw + i_s * D + o_d, b_dw, mask=mask) + if HAS_BIAS: + tl.store(db + i_s * D + o_d, b_db, mask=mask) + + +def layer_norm_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-5, + residual: torch.Tensor = None, + out_dtype: torch.dtype = None, + residual_dtype: torch.dtype = None, + is_rms_norm: bool = False, + num_groups: int = 1, +): + if residual is not None: + residual_dtype = residual.dtype + T, D, G = *x.shape, num_groups + if residual is not None: + assert residual.shape == (T, D) + if weight is not None: + assert weight.shape == (G * D,) + if bias is not None: + assert bias.shape == (G * D,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + res_out = torch.empty(T, D, device=x.device, dtype=residual_dtype) + else: + res_out = None + mean = torch.empty((T,), dtype=torch.float, device=x.device) if not is_rms_norm else None + rstd = torch.empty((T,), dtype=torch.float, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + + if D <= 512: + NB = triton.cdiv(T, 2048) + def grid(meta): return (triton.cdiv(T, meta['BT']), ) + layer_norm_fwd_kernel[grid]( + x, + y, + weight, + bias, + residual, + res_out, + mean, + rstd, + eps, + T=T, + G=G, + D=D, + BD=BD, + NB=NB, + IS_RMS_NORM=is_rms_norm, + HAS_RESIDUAL=residual 
is not None, + STORE_RESIDUAL_OUT=res_out is not None, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + ) + else: + layer_norm_fwd_kernel1[(T,)]( + x, + y, + weight, + bias, + residual, + res_out, + mean, + rstd, + eps, + G=G, + D=D, + BD=BD, + IS_RMS_NORM=is_rms_norm, + HAS_RESIDUAL=residual is not None, + STORE_RESIDUAL_OUT=res_out is not None, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + ) + # res_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, res_out if res_out is not None else x + + +def layer_norm_bwd( + dy: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor = None, + rstd: torch.Tensor = None, + dres: torch.Tensor = None, + has_residual: bool = False, + is_rms_norm: bool = False, + x_dtype: torch.dtype = None, + recompute_output: bool = False, + num_groups: int = 1, +): + T, D, G = *x.shape, num_groups + assert dy.shape == (T, D) + if dres is not None: + assert dres.shape == (T, D) + if weight is not None: + assert weight.shape == (G * D,) + if bias is not None: + assert bias.shape == (G * D,) + # allocate output + dx = torch.empty_like(x) if x_dtype is None else torch.empty(T, D, dtype=x_dtype, device=x.device) + dres_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(T, D, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # each program handles one group only + NS = triton.cdiv(get_multiprocessor_count(x.device.index), G) * G + BS = triton.cdiv(T, NS) + GS = NS // G + + dw = torch.empty((NS, D), dtype=torch.float, device=weight.device) if weight is not None else None + db = torch.empty((NS, D), dtype=torch.float, device=bias.device) if bias is not None else None + grid = (NS,) + + if D <= 512: + NB = triton.cdiv(T, 2048) + layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + dw, + db, + dres, + dres_in, + mean, + rstd, + T=T, + G=G, + D=D, + BS=BS, + BD=BD, + NB=NB, + GS=GS, + IS_RMS_NORM=is_rms_norm, + HAS_DRESIDUAL=dres is not None, + STORE_DRESIDUAL=dres_in is not None, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + ) + else: + layer_norm_bwd_kernel1[grid]( + x, + weight, + bias, + y, + dy, + dx, + dw, + db, + dres, + dres_in, + mean, + rstd, + T=T, + G=G, + D=D, + BS=BS, + BD=BD, + GS=GS, + IS_RMS_NORM=is_rms_norm, + HAS_DRESIDUAL=dres is not None, + STORE_DRESIDUAL=dres_in is not None, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + ) + dw = dw.view(G, -1, D).sum(1).to(weight).view_as(weight) if weight is not None else None + db = db.view(G, -1, D).sum(1).to(bias).view_as(bias) if bias is not None else None + # Don't need to compute dres_in separately in this case + if has_residual and dx.dtype == x.dtype: + dres_in = dx + return (dx, dw, db, dres_in) if not recompute_output else (dx, dw, db, dres_in, y) + + +class LayerNormFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x, + weight, + bias, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + num_groups: int = 1 + ): + x_shape_og = x.shape + + if x.shape[-1] % num_groups != 0: + raise ValueError('num_channels must be 
divisible by num_groups') + # reshape input data into 2D tensor + x = x.reshape(-1, (x.shape[-1] // num_groups)) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape_as(x) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, res_out = layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + num_groups=num_groups + ) + ctx.save_for_backward(res_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.num_groups = num_groups + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, res_out.reshape(x_shape_og)) + + @staticmethod + @input_guard + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups)) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, x.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = layer_norm_bwd( + dy, + x, + weight, + bias, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + num_groups=ctx.num_groups + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + None + ) + + +def layer_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False +): + return LayerNormFunction.apply( + x, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm + ) + + +def group_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + num_groups: int = 1 +): + return LayerNormFunction.apply( + x, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + num_groups + ) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False +): + return LayerNormFunction.apply( + x, + weight, + bias, + residual, + eps, + prenorm, + residual_in_fp32, + True + ) + + +def layer_norm_linear( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + num_groups: int = 1 +): + return LayerNormLinearFunction.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + num_groups + ) + + +def rms_norm_linear( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False +): + return layer_norm_linear( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + linear_weight=linear_weight, + 
linear_bias=linear_bias, + residual=residual, + eps=eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True + ) + + +def group_norm_linear( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + linear_weight: torch.Tensor, + linear_bias: torch.Tensor, + residual: torch.Tensor = None, + eps: float = 1e-5, + prenorm: bool = False, + residual_in_fp32: bool = False, + is_rms_norm: bool = False, + num_groups: int = 1 +): + return layer_norm_linear( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + linear_weight=linear_weight, + linear_bias=linear_bias, + residual=residual, + eps=eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=is_rms_norm, + num_groups=num_groups + ) + + +class LayerNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> LayerNorm: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32 + ) + + +class GroupNorm(nn.Module): + + def __init__( + self, + num_groups: int, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5, + is_rms_norm: bool = False + ) -> GroupNorm: + super().__init__() + + if hidden_size % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + + self.num_groups = num_groups + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + self.is_rms_norm = is_rms_norm + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + if self.is_rms_norm: + s += f", is_rms_norm={self.is_rms_norm}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return group_norm( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=self.is_rms_norm, + num_groups=self.num_groups + ) + + +class RMSNorm(nn.Module): + + def __init__( + self, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, 
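+        # NOTE: a learnable bias is atypical for RMSNorm; the flag defaults to
+        # False and is kept for interface parity with LayerNorm above.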
+ eps: float = 1e-5 + ) -> RMSNorm: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-5, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + num_groups=1 + ): + x_shape_og = x.shape + + if x.shape[-1] % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + # reshape input data into 2D tensor + x = x.reshape(-1, (x.shape[-1] // num_groups)) + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape_as(x) + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, res_out = layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + num_groups=num_groups + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(res_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.num_groups = num_groups + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, res_out.reshape(x_shape_og)) + + @staticmethod + @input_guard + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups)) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, x.shape[-1]) + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + 
num_groups=ctx.num_groups + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y.view(-1, linear_weight.shape[-1])) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + None + ) + + +class LayerNormLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> LayerNormLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear( + x=x, + norm_weight=self.weight, + norm_bias=self.bias, + linear_weight=weight, + linear_bias=bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=False + ) + + +class GroupNormLinear(nn.Module): + + def __init__( + self, + num_groups: int, + hidden_size: int, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5, + is_rms_norm: bool = False + ) -> GroupNormLinear: + super().__init__() + + if hidden_size % num_groups != 0: + raise ValueError('num_channels must be divisible by num_groups') + + self.num_groups = num_groups + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + self.is_rms_norm = is_rms_norm + + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + if self.is_rms_norm: + s += f", is_rms_norm={self.is_rms_norm}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear( + x=x, + norm_weight=self.weight, + norm_bias=self.bias, + linear_weight=weight, + linear_bias=bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=self.is_rms_norm, + num_groups=self.num_groups + ) + + +class RMSNormLinear(nn.Module): + + def __init__( + self, + hidden_size, + elementwise_affine: bool = True, + bias: bool = False, + eps: float = 1e-5 + ) -> RMSNormLinear: + super().__init__() + + self.hidden_size = hidden_size + self.elementwise_affine = elementwise_affine + self.eps = eps + + 
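+        # Register placeholders first so that `weight`/`bias` always exist as attributes,
+        # even when the affine transform (or the bias) is disabled.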
self.register_parameter("weight", None) + self.register_parameter("bias", None) + if elementwise_affine: + self.weight = nn.Parameter(torch.empty(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.empty(hidden_size)) + + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + nn.init.ones_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.hidden_size}" + if not self.elementwise_affine: + s += f", elementwise_affine={self.elementwise_affine}" + s += f", eps={self.eps}" + s += ")" + return s + + def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False): + return layer_norm_linear( + x=x, + norm_weight=self.weight, + norm_bias=self.bias, + linear_weight=weight, + linear_bias=bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + is_rms_norm=True + ) + + +class NormParallel(ParallelStyle): + + def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): + super().__init__() + self.sequence_sharding = (Shard(sequence_dim),) + self.use_local_output = use_local_output + + def _replicate_module_fn( + self, name: str, module: nn.Module, device_mesh: DeviceMesh + ): + for p_name, param in module.named_parameters(): + # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow + # us to simply just use from_local + replicated_param = torch.nn.Parameter( + DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) + ) + module.register_parameter(p_name, replicated_param) + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True + ) + return input_tensor + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local( + input_tensor, device_mesh, sequence_sharding, run_check=False + ) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) + + @staticmethod + def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._replicate_module_fn, + partial(self._prepare_input_fn, self.sequence_sharding), + partial(self._prepare_output_fn, self.use_local_output), + ) diff --git a/fla3/modules/layernorm_gated.py b/fla3/modules/layernorm_gated.py new file mode 100644 index 0000000000000000000000000000000000000000..1a72ff839dc021e484d55486cc11a9c9f85863fe --- /dev/null +++ b/fla3/modules/layernorm_gated.py @@ -0,0 +1,528 @@ +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
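+# A minimal usage sketch of the gated norms defined below (shapes illustrative, not prescriptive):
+#   norm = RMSNormGated(hidden_size=512, norm_before_gate=False)
+#   y = norm(x, z)  # computes norm(x * silu(z)), since norm_before_gate=False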
+ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + +from fla.utils import get_multiprocessor_count, input_guard + + +def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True): + dtype = x.dtype + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + z = z.float() if z is not None else z + if z is not None and not norm_before_gate: + x = x * F.silu(z) + if group_size is None: + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + else: + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) + out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight + if bias is not None: + out = out + bias + if z is not None and norm_before_gate: + out *= F.silu(z) + return out.to(dtype) + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, +}) +@triton.jit +def layer_norm_fwd_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.) 
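+        # RMSNorm branch: no mean subtraction, so `var` below is the raw second moment E[x^2].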
+ var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def layer_norm_fwd( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor = None, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps + ) + return out, mean, rstd + + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, + "RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None, +}) +@triton.jit +def layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DZ, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_z_row, + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dz_row, + stride_dw_row, + stride_db_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
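+    # Each program owns `rows_per_program` consecutive rows of a single group, accumulating its
+    # weight/bias gradients in registers; the per-block partial sums written to DW/DB are
+    # reduced across row blocks on the host side afterwards.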
+ row_block_id = tl.program_id(0) + group = tl.program_id(1) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + group * N + if HAS_Z: + Z += row_start * stride_z_row + group * N + DZ += row_start * stride_dz_row + group * N + DY += row_start * stride_dy_row + group * N + DX += row_start * stride_dx_row + group * N + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: + B += group * N + b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + x_og = x + x = x_og * z * tl.sigmoid(z) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.) + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + y = xhat * w + b if HAS_BIAS else xhat * w + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y * z * z_sigmoid, mask=mask) + dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dy *= z * z_sigmoid + else: + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + c1 = tl.sum(xhat * wdy, axis=0) / N + if not IS_RMS_NORM: + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + dx = (wdy - xhat * c1) * rstd + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_Z and not NORM_BEFORE_GATE: + z_sigmoid = tl.sigmoid(z) + dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dx *= z * z_sigmoid + # Write dx + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_Z: + Z += stride_z_row + DZ += stride_dz_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) + + +def layer_norm_bwd( + dy: torch.Tensor, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + mean: torch.Tensor, + rstd: torch.Tensor, + z: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + is_rms_norm: bool = False, + recompute_output: bool = False, + dz: torch.Tensor = None, + out: torch.Tensor = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = torch.empty_like(x) + if dz is not None: + assert z is not None + assert dz.shape == 
z.shape + assert dz.stride(-1) == 1 + else: + dz = torch.empty_like(z) if z is not None else None + if recompute_output: + if out is None: + out = torch.empty_like(x) + assert out.shape == x.shape + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + sm_count = get_multiprocessor_count(x.device.index) + # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs + # would limit the occupancy. + nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) + _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / nrow_groups) + grid = (nrow_groups, ngroups) + layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + z, + out if recompute_output else None, + dy, + dx, + _dw, + _db, + dz, + mean, + rstd, + x.stride(0), + z.stride(0) if z is not None else 0, + 0 if not recompute_output else out.stride(0), + dy.stride(0), + dx.stride(0), + dz.stride(0) if dz is not None else 0, + _dw.stride(0), + _db.stride(0) if _db is not None else 0, + M, group_size, eps, + rows_per_program, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) + + +class LayerNormFn(torch.autograd.Function): + + @input_guard + @staticmethod + def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + @input_guard + @staticmethod + def backward(ctx, dy): + x, weight, bias, mean, rstd, z = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + dx, dw, db, dz = layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + z, + ctx.group_size, + ctx.norm_before_gate, + ctx.is_rms_norm + ) + dx = dx.reshape(ctx.x_shape_og) + dz = dz.reshape(ctx.x_shape_og) if dz is not None else None + return dx, dw, db, dz, None, None, None, None + + +def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, 
is_rms_norm) + + +def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True): + return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True) + + +class LayerNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps, + norm_before_gate=self.norm_before_gate) + + +class RMSNormGated(nn.Module): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size, + norm_before_gate=self.norm_before_gate) diff --git a/fla3/modules/mlp.py b/fla3/modules/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f35aa6910ad143eb632fc28684043704e8741a --- /dev/null +++ b/fla3/modules/mlp.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING, Any, Optional + +import torch +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor, Placement, Replicate, Shard, distribute_module +from torch.distributed.tensor.parallel import ParallelStyle + +from fla.modules.activations import swiglu, swiglu_linear + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + +class GatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + hidden_ratio: Optional[int] = None, + intermediate_size: Optional[int] = None, + hidden_act: str = 'swish', + fuse_swiglu: bool = True + ) -> GatedMLP: + super().__init__() + + self.hidden_size = hidden_size + # the final number of params is `hidden_ratio * 
hidden_size^2` + # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio` + if hidden_ratio is None: + hidden_ratio = 4 + if intermediate_size is None: + intermediate_size = int(hidden_size * hidden_ratio * 2 / 3) + intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256) + self.hidden_ratio = hidden_ratio + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.fuse_swiglu = fuse_swiglu + + if hidden_act != 'swish': + raise ValueError(f'Unsupported hidden_act: {hidden_act}') + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if self.fuse_swiglu: + self.swiglu_linear = SwiGLULinear() + + def forward( + self, + x: torch.Tensor, + **kwargs: Unpack[Any] + ) -> torch.Tensor: + gate, y = self.gate_proj(x), self.up_proj(x) + if self.fuse_swiglu: + return self.swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias) + else: + return self.down_proj(swiglu(gate, y)) + + +class SwiGLULinear(nn.Module): + + def forward(self, x, y, weight, bias): + return swiglu_linear(x, y, weight, bias) + + +class SwiGLULinearParallel(ParallelStyle): + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) + self.desired_input_layouts = (Shard(-1),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + x, y, weight, bias = inputs + if not isinstance(x, DTensor): + x = DTensor.from_local(x, device_mesh, input_layouts, run_check=False) + if x.placements != desired_input_layouts: + x = x.redistribute(placements=desired_input_layouts, async_op=True) + + if not isinstance(y, DTensor): + y = DTensor.from_local(y, device_mesh, input_layouts, run_check=False) + if y.placements != desired_input_layouts: + y = y.redistribute(placements=desired_input_layouts, async_op=True) + + if not isinstance(weight, DTensor): + weight = DTensor.from_local(weight, device_mesh, (Shard(1),)) + + if bias is not None and not isinstance(bias, DTensor): + bias = DTensor.from_local(bias, device_mesh, (Replicate(),)) + + return x, y, weight, bias + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. 
to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor if use_local_output is True + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=None, + input_fn=partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts), + output_fn=partial(self._prepare_output_fn, self.output_layouts, self.use_local_output) + ) diff --git a/fla3/modules/parallel.py b/fla3/modules/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..3561476adeed63be51edbee270106291d6a646ac --- /dev/null +++ b/fla3/modules/parallel.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch.nn as nn +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor, distribute_module +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +class PrepareModuleWeight(ParallelStyle): + def __init__(self, *, layouts: Optional[Placement] = None): + super().__init__() + self.layouts = layouts + + def _replicate_module_fn( + self, + name: str, + module: nn.Module, + device_mesh: DeviceMesh + ): + for p_name, param in module.named_parameters(): + replicated_param = nn.Parameter( + DTensor.from_local(param, device_mesh, [self.layouts], run_check=False) + ) + module.register_parameter(p_name, replicated_param) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._replicate_module_fn, + input_fn=None, + output_fn=None + ) diff --git a/fla3/modules/rotary.py b/fla3/modules/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..64dbfe487176fbe5101b85ccdc8716126bda1672 --- /dev/null +++ b/fla3/modules/rotary.py @@ -0,0 +1,497 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import triton +import triton.language as tl +from einops import rearrange, repeat + +from fla.ops.utils import prepare_chunk_indices +from fla.utils import get_multiprocessor_count, input_guard + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2) + + +def rotary_embedding_ref(x, cos, sin, interleaved=False): + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)') + sin = repeat(sin, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 
1 (d 2)')
+    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], -1)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16, 32]
+        for num_stages in [2, 3, 4]
+    ],
+    key=['B', 'H', 'D', 'INTERLEAVED'],
+)
+@triton.jit(do_not_specialize=['T'])
+def rotary_embedding_kernel(
+    x,
+    cos,
+    sin,
+    y,
+    cu_seqlens,
+    chunk_indices,
+    seq_offsets,
+    T,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    D: tl.constexpr,
+    R: tl.constexpr,
+    TR: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    INTERLEAVED: tl.constexpr,
+    CONJUGATE: tl.constexpr
+):
+    i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n), tl.load(cu_seqlens + i_n + 1)
+        T = eos - bos
+        x = x + bos * H*D + i_h * D
+        y = y + bos * H*D + i_h * D
+    else:
+        i_n = i_b
+        x = x + i_n * T*H*D + i_h * D
+        y = y + i_n * T*H*D + i_h * D
+
+    if i_t * BT >= T:
+        return
+
+    o_t = i_t * BT + tl.arange(0, BT)
+    if not IS_SEQLEN_OFFSETS_TENSOR:
+        o_cs = o_t + seq_offsets
+    else:
+        o_cs = o_t + tl.load(seq_offsets + i_n)
+
+    if not INTERLEAVED:
+        # Load the 1st and 2nd halves of x, do the calculation, then store to the 1st and 2nd halves of out
+        o_r = tl.arange(0, BD // 2)
+        p_x = x + o_t[:, None] * H*D + o_r[None, :]
+        p_cos = cos + (o_cs[:, None] * R + o_r[None, :])
+        p_sin = sin + (o_cs[:, None] * R + o_r[None, :])
+        mask = (o_t[:, None] >= 0) & (o_t[:, None] < T) & (o_r[None, :] < R)
+
+        b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
+        b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
+        b_x0 = tl.load(p_x, mask=mask, other=0.0).to(tl.float32)
+        b_x1 = tl.load(p_x + R, mask=mask, other=0.0).to(tl.float32)
+        if CONJUGATE:
+            b_sin = -b_sin
+        b_o0 = b_x0 * b_cos - b_x1 * b_sin
+        b_o1 = b_x0 * b_sin + b_x1 * b_cos
+        # write back result
+        p_y = y + (o_t[:, None] * H*D + o_r[None, :])
+        tl.store(p_y, b_o0, mask=mask)
+        tl.store(p_y + R, b_o1, mask=mask)
+    else:
+        # We don't want to load x[0, 2, 4, ...] and x[1, 3, 5, ...] separately since both are slow.
+        # Instead, we load x0 = x[0, 1, 2, 3, ...] and x1 = x[1, 0, 3, 2, ...].
+        # Loading x0 will be fast but x1 will be slow.
+        # Then we load cos = cos[0, 0, 1, 1, ...] and sin = sin[0, 0, 1, 1, ...].
+        # Then we do the calculation and use tl.where to pick out the right outputs for the even
+        # and for the odd indices.
+        o_d = tl.arange(0, BD)
+        o_d_swap = o_d + ((o_d + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
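+        # For an even lane d, x1[d] = x[d + 1]; for an odd lane, x1[d] = x[d - 1]. Combined with
+        # the tl.where below, even lanes yield x[d] * cos - x[d + 1] * sin and odd lanes yield
+        # x[d] * cos + x[d - 1] * sin, i.e. the pairwise (GPT-J-style) rotation.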
+ o_d_repeat = tl.arange(0, BD) // 2 + p_x0 = x + o_t[:, None] * H*D + o_d[None, :] + p_x1 = x + o_t[:, None] * H*D + o_d_swap[None, :] + p_cos = cos + (o_cs[:, None] * R + o_d_repeat[None, :]) + p_sin = sin + (o_cs[:, None] * R + o_d_repeat[None, :]) + mask = (o_cs[:, None] >= 0) & (o_cs[:, None] < TR) & (o_d_repeat[None, :] < R) + + b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32) + b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32) + b_x0 = tl.load(p_x0, mask=mask, other=0.0).to(tl.float32) + b_x1 = tl.load(p_x1, mask=mask, other=0.0).to(tl.float32) + if CONJUGATE: + b_sin = -b_sin + b_o0 = b_x0 * b_cos + b_o1 = b_x1 * b_sin + b_y = tl.where(o_d[None, :] % 2 == 0, b_o0 - b_o1, b_o0 + b_o1) + p_y = y + (o_t[:, None] * H*D + o_d[None, :]) + tl.store(p_y, b_y, mask=mask) + + +def rotary_embedding_fwdbwd( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False +) -> torch.Tensor: + """ + Args: + x: [B, T, H, D]. + cos: [TR, R / 2] + sin: [TR, R / 2] + seqlen_offsets: integer or integer tensor of size [N] + cu_seqlens: [N + 1,] or None + + Returns: + y: [B, T, H, D] + """ + is_varlen = cu_seqlens is not None + + B, T, H, D = x.shape + N = B if not is_varlen else cu_seqlens.shape[0] - 1 + TR, R = cos.shape + R2 = R * 2 + + assert D <= 256, "Only support D <= 256" + assert TR >= T, "TR must be >= T" + + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (N,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + else: + assert seqlen_offsets + T <= TR + + y = torch.empty_like(x) if not inplace else x + if R2 < D and not inplace: + y[..., R2:].copy_(x[..., R2:]) + + BD = triton.next_power_of_2(R2) + BT = min(128, triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index)))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if is_varlen else None + NT = len(chunk_indices) if is_varlen else triton.cdiv(T, BT) + + grid = (NT, B, H) + rotary_embedding_kernel[grid]( + x, + cos, + sin, + y, + cu_seqlens, + chunk_indices, + seqlen_offsets, + B=B, + T=T, + H=H, + D=D, + R=R, + TR=TR, + BT=BT, + BD=BD, + IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor), + IS_VARLEN=is_varlen, + INTERLEAVED=interleaved, + CONJUGATE=conjugate + ) + return y + + +class RotaryEmbeddingFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + ): + y = rotary_embedding_fwdbwd( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + # Can't save int with save_for_backward + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + return y if not inplace else x + + @staticmethod + @input_guard + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is 
None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = rotary_embedding_fwdbwd( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def rotary_embedding( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None +): + """ + Args: + x: [B, T, H, D] + cos, sin: [TR, R//2] + interleaved: + If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). + inplace: + If True, apply rotary embedding in-place. + seqlen_offsets: [N,] or int. + Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: [N + 1,] or None + + Returns: + out: [B, T, H, D] + """ + return RotaryEmbeddingFunction.apply( + x, + cos, + sin, + interleaved, + inplace, + seqlen_offsets, + cu_seqlens + ) + + +class RotaryEmbedding(nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base: float = 10000.0, + scale_base: Optional[float] = None, + interleaved: bool = False, + pos_idx_in_fp32: bool = True, + device: Optional[torch.device] = None, + ): + """ + interleaved: + If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: + If True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. + In most cases this would be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, we add this option. 
+ """ + super().__init__() + + self.dim = dim + self.base = float(base) + self.scale_base = scale_base + self.interleaved = interleaved + self.pos_idx_in_fp32 = pos_idx_in_fp32 + self.device = device + + # Generate and save the inverse frequency buffer (non trainable) + self.register_buffer("inv_freq", torch.empty(-(dim // -2), dtype=torch.float32, device=device), persistent=False) + + scale = None + if scale_base is not None: + scale = torch.empty(-(dim // -2), dtype=torch.float32, device=device) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + self.inv_freq.copy_(self._compute_inv_freq(device=self.inv_freq.device)) + if self.scale_base is not None: + self.scale.copy_(self._compute_scale(device=self.scale.device)) + + def __repr__(self): + s = f"{self.__class__.__name__}(" + s += f"dim={self.dim}, " + s += f"base={self.base}, " + s += f"interleaved={self.interleaved}, " + if self.scale_base is not None: + s += f"scale_base={self.scale_base}, " + s += f"pos_idx_in_fp32={self.pos_idx_in_fp32})" + return s + + def _compute_inv_freq(self, device=None): + return 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) + ) + + def _compute_scale(self, device=None): + return (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + 0.4 * self.dim) / (1.4 * self.dim) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + seqlen_offset: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + q: [B, T, H, D] + k: [B, T, H, D] + seqlen_offset: + [N] or int. + Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: [N + 1] or None + max_seqlen: int + """ + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype) + if self.scale is None: + q = rotary_embedding( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + cu_seqlens=cu_seqlens, + ) + k = rotary_embedding( + k, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + cu_seqlens=cu_seqlens, + ) + + else: + q = rotary_embedding( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + cu_seqlens=cu_seqlens, + ) + k = rotary_embedding( + k, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + cu_seqlens=cu_seqlens, + ) + + return q, k diff --git a/fla3/modules/token_shift.py b/fla3/modules/token_shift.py new file mode 100644 index 0000000000000000000000000000000000000000..07eb9c59dc77339cb9299fe2f468055d350a0af6 --- /dev/null +++ b/fla3/modules/token_shift.py @@ -0,0 +1,243 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import input_guard + + +def token_shift_ref( + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None +) -> torch.Tensor: + if cu_seqlens is not None: + # Variable length mode with cu_seqlens + assert x.dim() == 3, "Input must be [B, T, D]" + B, T, D = x.shape + assert B == 1, "Batch size must be 1 when using cu_seqlens" + + result = torch.zeros_like(x) + N = cu_seqlens.shape[0] - 1 + + for i in range(N): + start = cu_seqlens[i].item() + end = cu_seqlens[i+1].item() + seq_len = end - start + + if seq_len <= 1: + # For sequences of length 1 or 0, delta is simply -x + result[0, start:end] = -x[0, start:end] + else: + # For longer sequences, handle padding manually + 
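+                # Right-shift by one position: shifted[t] = x[t - 1] for t >= 1 and shifted[0] = 0,
+                # giving delta[t] = x[t - 1] - x[t] (and delta[0] = -x[0]).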
shifted = torch.zeros_like(x[0, start:end]) + shifted[1:] = x[0, start:end-1] + delta = shifted - x[0, start:end] + result[0, start:end] = delta + + return result + else: + time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1)) + shifted = time_shift(x) + delta = shifted - x + return delta + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [1, 2, 3, 4] + ], + key=['BD'], +) +@triton.jit +def token_shift_fwd_kernel( + x, + y, + cu_seqlens, + T, + D: tl.constexpr, + BD: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_b, i_t = tl.program_id(0), tl.program_id(1) + + if IS_VARLEN: + i_n = i_b + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + + if i_t < bos or i_t >= eos: + return + + is_first_pos = (i_t - bos == 0) + else: + is_first_pos = (i_t == 0) + + o_d = tl.arange(0, BD) + m_d = o_d < D + + if IS_VARLEN: + base_offset = i_t * D + o_d + else: + base_offset = i_b * T*D + i_t * D + o_d + + b_x = tl.load(x + base_offset, mask=m_d) + + if is_first_pos: + # First position in sequence: delta = -hidden_states + tl.store(y + base_offset, -b_x, mask=m_d) + else: + # Other positions: delta = prev - curr + if IS_VARLEN: + prev_offset = (i_t - 1) * D + o_d + else: + prev_offset = i_b * T*D + (i_t-1) * D + o_d + + prev_values = tl.load(x + prev_offset, mask=m_d) + delta = prev_values - b_x + tl.store(y + base_offset, delta, mask=m_d) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [1, 2, 3, 4] + ], + key=['D'], +) +@triton.jit +def token_shift_bwd_kernel( + dx, + dy, + cu_seqlens, + T, + D: tl.constexpr, + BD: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_b, i_t = tl.program_id(0), tl.program_id(1) + if IS_VARLEN: + i_n = i_b + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + + if i_t < bos or i_t >= eos: + return + + local_pos = i_t - bos + is_last_pos = (local_pos == eos - bos - 1) + else: + is_last_pos = (i_t == T - 1) + + o_d = tl.arange(0, BD) + m_d = o_d < D + + if IS_VARLEN: + base_offset = i_t * D + o_d + else: + base_offset = i_b * T*D + i_t * D + o_d + + b_dy = tl.load(dy + base_offset, mask=m_d) + + if is_last_pos: + # Last position: b_dx = -grad_delta[t] + b_dx = -b_dy + else: + # Other positions: b_dx = -grad_delta[t] + grad_delta[t+1] + if IS_VARLEN: + next_offset = (i_t+1) * D + o_d + else: + next_offset = i_b * T*D + (i_t+1) * D + o_d + + b_dx = -b_dy + tl.load(dy + next_offset, mask=m_d) + + tl.store(dx + base_offset, b_dx, mask=m_d) + + +def token_shift_fwd( + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None +) -> torch.Tensor: + B, T, D = x.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BD = triton.next_power_of_2(D) + + y = torch.empty_like(x) + + grid = (N, T) + token_shift_fwd_kernel[grid]( + x=x, + y=y, + cu_seqlens=cu_seqlens, + T=T, + D=D, + BD=BD, + ) + + return y + + +def token_shift_bwd( + dy: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None +) -> torch.Tensor: + B, T, D = dy.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BD = triton.next_power_of_2(D) + + dx = torch.empty_like(dy) + + grid = (N, T) + token_shift_bwd_kernel[grid]( + dy=dy, + dx=dx, + 
cu_seqlens=cu_seqlens, + T=T, + D=D, + BD=BD, + ) + return dx + + +class TokenShift(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None): + ctx.cu_seqlens = cu_seqlens + return token_shift_fwd(x, cu_seqlens) + + @staticmethod + @input_guard + def backward(ctx, dy: torch.Tensor): + cu_seqlens = ctx.cu_seqlens + dx = token_shift_bwd(dy, cu_seqlens) + return dx, None + + +def token_shift( + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None +): + """ + Implementation of token shift using Triton kernels + Args: + x: Input tensor of shape [B, T, D] + cu_seqlens: Cumulative sequence lengths (optional) + Returns: + Tensor of same shape as input with token shift applied + """ + if cu_seqlens is not None: + assert x.dim() == 3, "Input must be [B, T, D]" + assert x.shape[0] == 1, "Batch size must be 1 when using cu_seqlens" + + return TokenShift.apply(x, cu_seqlens) diff --git a/fla3/ops/__init__.py b/fla3/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53c1e84c1881d01e529665ca50b2f233a62802df --- /dev/null +++ b/fla3/ops/__init__.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +from .abc import chunk_abc +from .attn import parallel_attn +from .based import fused_chunk_based, parallel_based +from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule +from .forgetting_attn import parallel_forgetting_attn +from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +from .generalized_delta_rule import ( + chunk_dplr_delta_rule, + chunk_iplr_delta_rule, + fused_recurrent_dplr_delta_rule, + fused_recurrent_iplr_delta_rule +) +from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla +from .gsa import chunk_gsa, fused_recurrent_gsa +from .hgrn import fused_recurrent_hgrn +from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn +from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn +from .nsa import parallel_nsa +from .path_attn import parallel_path_attention +from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention +from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 +from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7 +from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla + +__all__ = [ + 'chunk_abc', + 'parallel_attn', + 'fused_chunk_based', 'parallel_based', + 'chunk_delta_rule', 'fused_chunk_delta_rule', 'fused_recurrent_delta_rule', + 'parallel_forgetting_attn', + 'chunk_gated_delta_rule', 'fused_recurrent_gated_delta_rule', + 'chunk_dplr_delta_rule', 'chunk_iplr_delta_rule', + 'fused_recurrent_dplr_delta_rule', 'fused_recurrent_iplr_delta_rule', + 'chunk_gla', 'fused_chunk_gla', 'fused_recurrent_gla', + 'chunk_gsa', 'fused_recurrent_gsa', + 'fused_recurrent_hgrn', + 'chunk_lightning_attn', 'fused_recurrent_lightning_attn', + 'chunk_linear_attn', 'fused_chunk_linear_attn', 'fused_recurrent_linear_attn', + 'parallel_nsa', + 'chunk_retention', 'fused_chunk_retention', 'fused_recurrent_retention', 'parallel_retention', + 'chunk_rwkv6', 'fused_recurrent_rwkv6', + 'chunk_rwkv7', 'fused_recurrent_rwkv7', + 'chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla', + 'parallel_path_attention', +] diff --git a/fla3/ops/__pycache__/__init__.cpython-310.pyc b/fla3/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..bf5e168c1a758f002bda687b0843e006160dc6e7 Binary files /dev/null and b/fla3/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/__pycache__/__init__.cpython-312.pyc b/fla3/ops/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2608095b08239dd603fad086ac5377d5c2854dec Binary files /dev/null and b/fla3/ops/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/abc/__init__.py b/fla3/ops/abc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdac8d900fc51485a55716443ee1f00424b522b9 --- /dev/null +++ b/fla3/ops/abc/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_abc + +__all__ = [ + 'chunk_abc' +] diff --git a/fla3/ops/abc/__pycache__/__init__.cpython-310.pyc b/fla3/ops/abc/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a7d841919b1783880483e97e92dbba6db001a3f Binary files /dev/null and b/fla3/ops/abc/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/abc/__pycache__/__init__.cpython-312.pyc b/fla3/ops/abc/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c392bb213d41ba86e06f348c22561d69dae858b1 Binary files /dev/null and b/fla3/ops/abc/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/abc/__pycache__/chunk.cpython-310.pyc b/fla3/ops/abc/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7077c839ffa4a6c1c68da3b5efe596f6005b4538 Binary files /dev/null and b/fla3/ops/abc/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/abc/__pycache__/chunk.cpython-312.pyc b/fla3/ops/abc/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba649da7f809cbf81684218fa416ce35439fd9d Binary files /dev/null and b/fla3/ops/abc/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/abc/chunk.py b/fla3/ops/abc/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bc45d1e7b3cc406e5d7503d1957c26a1dbd3dc --- /dev/null +++ b/fla3/ops/abc/chunk.py @@ -0,0 +1,1117 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import softmax_bwd, softmax_fwd +from fla.ops.utils.logcumsumexp import logcumsumexp_fwd_kernel +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_h( + k, + v, + z, + h, + h0, + ht, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if NORMK: + p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,)) + else: + p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_z0).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 
i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if NORMK: + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[:, None] + b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype) + else: + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[None, :] + b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_intra_K( + v, + z, + o, + A, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_o *= exp(b_zn[None, :] - b_z) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + # avoid 0 * inf = inf + m_i = o_i[:, None] >= j + b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_K( + q, + k, + z, + h, + o, + A, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + o_i = tl.arange(0, BT) + 
m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BV] + p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_o = b_o * exp(b_zp[None, :] - b_z) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_intra_V( + q, + k, + z, + A, + scale, + T, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i 
* BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) + tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_V( + q, + v, + z, + h, + o, + A, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BK] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_q, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_dh( + q, + z, + do, + dh, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + i_p = tl.maximum(i_t * BT - 1, 0) + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if NORMK: + p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zc - b_zp), b_zc + # [BK, BT] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * exp(b_zc[:, 
None] - b_z)).to(b_q.dtype) + # [BK, BV] + b_dh = b_dh * b_r[:, None] + else: + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zc - b_zp), b_zc + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype) + # [BK, BV] + b_dh = b_dh * b_r[None, :] + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_V( + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = exp(b_zp[None, :] - b_z) + # [BT, BK] + b_dq = b_dq * b_z + b_dk = b_dk * b_k + + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1),
(i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_V( + q, + k, + z, + dA, + dq, + dk, + T, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_zq = exp(b_zn[None, :] - b_z) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kz, allow_tf32=False) + b_dq *= b_zq + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.) 
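+ # b_dq now holds both the inter-sub-chunk contribution (the i_j loop above, rescaled by b_zq) + # and the masked diagonal contribution (the per-position j loop, whose tl.where mask keeps + # 0 * inf products out of the masked entries).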
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = exp(b_k - b_zn[None, :]) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False) + b_dk *= b_kz + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.) 
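+ # Symmetrically to the dq path above, b_dk gathers gradients flowing back from later + # sub-chunks (the i_j loop, rescaled by b_kz) plus the masked diagonal term from the + # per-position j loop.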
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_K( + v, + z, + do, + dA, + scale, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_v, allow_tf32=False) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1) + b_dA = tl.where(o_i >= j, b_dA, 0) + tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A) + + p_v = tl.advance(p_v, (V,)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_K( + q, + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) 
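+ # NOTE: the unnormalized intra-chunk score matrix A = tril(q @ k^T) is recomputed here rather + # than saved from the forward pass, trading a little recomputation for activation memory; each + # i_k split stores its own partial result, which the host code then reduces via Ak.sum(0).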
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = exp(b_zp[None, :] - b_z) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * b_z * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_KV( + v, + z, + A, + do, + dv, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 
1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= exp(b_v - b_zn[None, :]) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_rcum_inter( + s, + z, + ss, + doo, + T, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + + b_sp = tl.zeros([BS,], dtype=tl.float32) + b_zp = tl.full([BS,], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + # [BS,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + + b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :] + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + # [BS,] + b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0) + b_zp = b_zc + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_rcum_intra( + s, + z, + ss, + doo, + T, + S: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + NC: tl.constexpr +): + i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_i = tl.arange(0, BC) + m_o = tl.full([BC, BC], 1., dtype=tl.float32) + + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,)) + p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + # [BS,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + + b_doo = tl.zeros([BC, BS], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 
0)) + # [BC, BS] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + # [BC, BS] + b_doo += b_ss * exp(b_zn[None, :] - b_z) + b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False) + + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + # [BS,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_ss = tl.load(p_ss, boundary_check=(0,)) + # [BC, BS] + m_i = o_i[:, None] <= j + b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.) + b_doo += tl.load(p_doo, boundary_check=(0, 1)) + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkABCFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, q, k, v, s, initial_state, output_final_state): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_pre(s, B, H, T, S): + # keep the cumulative normalizer in fp32 + z = torch.empty_like(s, dtype=torch.float) + grid = (B * H,) + logcumsumexp_fwd_kernel[grid]( + s, z, + T=T, S=S + ) + return z + + def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_abc_fwd_kernel_h[grid]( + k, v, z, h, h0, ht, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = (q.new_empty(B, H, K, M, dtype=torch.float), + q.new_empty(B, H, M, V, dtype=torch.float)) + + z = fwd_pre(s, B, H, T, M) + scale = K ** -0.5 + hk = fwd_inner( + q=q, k=k, v=s, z=z, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + normk=False, + h0=initial_state[0] if initial_state is not None else None, + ht=final_state[0] if final_state is not None else None + ) + ok1 = torch.empty_like(s) + Ak = q.new_empty(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_fwd_kernel_K[grid]( + q, k, z, hk, ok1, Ak, + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + ok0 = torch.empty_like(s) + grid = (NM, NT * NC, B * H) + chunk_abc_fwd_kernel_intra_K[grid]( + s, z, ok0, Ak, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ok = ok0.add_(ok1) + + scale = 1. + # p is kept in fp32 for safe softmax backward + p = softmax_fwd(ok, dtype=torch.float) + qv = p.to(q.dtype)
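+ + # The two-pass structure, in PyTorch-like pseudocode (a sketch of the semantics only; + # `linear_attn` is a hypothetical stand-in for the chunked kernels used above and below): + # ok = linear_attn(q, k, s) # scores over the M memory slots + # p = softmax(ok, dim=-1) # slot read-out weights, kept in fp32 + # ov = linear_attn(p, s, v) # second pass, with p acting as the queries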
+ hv = fwd_inner( + q=qv, k=s, v=v, z=z, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + normk=True, + h0=initial_state[1] if initial_state is not None else None, + ht=final_state[1] if final_state is not None else None + ) + Av = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_fwd_kernel_intra_V[grid]( + qv, s, z, Av, + scale=scale, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + Av = Av.sum(0) + ov = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_abc_fwd_kernel_V[grid]( + qv, v, z, hv, ov, Av, + scale=scale, + T=T, + K=M, + V=V, + BT=BT, + BK=BM, + BV=BV, + NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av) + ctx.BT = BT + return ov, final_state + + @staticmethod + @input_guard + def backward(ctx, dov, dht=None): + q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_abc_bwd_kernel_dh[grid]( + q, z, do, dh, + scale=scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS): + doo = torch.empty_like(s) + grid = (NS, B * H) + chunk_abc_bwd_kernel_rcum_inter[grid]( + s, z, ss, doo, + T=T, S=S, BT=BT, BS=BS, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NS, NT * NC, B * H) + chunk_abc_bwd_kernel_rcum_intra[grid]( + s, z, ss, doo, + T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + return doo + + scale = 1. 
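+ # The backward pass mirrors the forward in reverse: first the V-pass gradients (dhv, dv and the + # softmax-input gradient dp), then the softmax backward producing dok, then the K-pass gradients + # (dhk, dq, dk), and finally the reverse-cumsum correction for the shared normalizer z via bwd_post.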
+ qv = p.to(q.dtype) + dhv = bwd_inner( + qv, z, dov, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + scale=scale, + normk=True + ) + dp1 = torch.empty_like(p) + dsv1 = torch.empty_like(s, dtype=torch.float) + dv = v.new_empty(NM, *v.shape) + dAv = q.new_zeros(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_bwd_kernel_V[grid]( + s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv, + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + dp0 = torch.empty_like(p) + dsv0 = s.new_zeros(s.shape, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_V[grid]( + qv, s, z, dAv, dp0, dsv0, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dp = dp1.add_(dp0) + dsv = dsv1.add_(dsv0) + + # softmax gradient, equivalent to: + # dok = p * (dp - (p * dp).sum(-1, True)) + dok = softmax_bwd(p, dp, dtype=ok.dtype) + + scale = K ** -0.5 + dhk = bwd_inner( + q, z, dok, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + scale=scale, + normk=False + ) + dAk = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_bwd_kernel_intra_K[grid]( + s, z, dok, dAk, + scale=scale, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dAk = dAk.sum(0) + + Ak = q.new_zeros(NK, B, H, T, BT) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float) + grid = (NK, NT, B * H) + chunk_abc_bwd_kernel_K[grid]( + q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk, + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + Ak = Ak.sum(0) + dsk1 = dsk1.sum(0) + dsk0 = torch.empty_like(s, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_KV[grid]( + s, z, Ak, dok, dsk0, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ds = dsv.add_(dsk1.add_(dsk0)) + ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM) + ds = ds.to(s.dtype) + return dq, dk, dv, ds, None, None + + +@torch.compiler.disable +def chunk_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: bool = False, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]` + s (torch.Tensor): + slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]` + initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]): + Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`. 
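+ + Example (a minimal shape-only sketch; assumes a CUDA device and that this repo's `fla3` package is importable):: + >>> import torch + >>> from fla3.ops.abc import chunk_abc + >>> q = torch.randn(2, 1024, 4, 64, device='cuda', dtype=torch.bfloat16) + >>> k = torch.randn(2, 1024, 4, 64, device='cuda', dtype=torch.bfloat16) + >>> v = torch.randn(2, 1024, 4, 64, device='cuda', dtype=torch.bfloat16) + >>> s = torch.randn(2, 1024, 4, 32, device='cuda', dtype=torch.bfloat16) + >>> o, _ = chunk_abc(q, k, v, s) # o: [2, 1024, 4, 64]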
+ """ + if not head_first: + q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s)) + o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla3/ops/abc/naive.py b/fla3/ops/abc/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f25c40db73bcf33d1599761be0008cc5be7c59 --- /dev/null +++ b/fla3/ops/abc/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import repeat + + +def naive_recurrent_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = q.dtype + + NG = q.shape[1]//k.shape[1] + # [batch_size, n_heads, seq_len, n_slots] + if g is None: + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z) + q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) + k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g)) + if initial_state is not None: + initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state)) + + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + + hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) + ok = torch.zeros_like(s) + + if scale is None: + scale = q.shape[-1] ** -0.5 + + final_state = None + if initial_state is not None: + hk += initial_state[0] + + for i in range(T): + q_i = q[:, :, i] * scale + k_i = k[:, :, i] + v_i = s[:, :, i] + g_i = g[:, :, i].exp() + hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] + ok[:, :, i] = (q_i[..., None] * hk).sum(-2) + + qv = ok.softmax(-1) + hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) + ov = torch.zeros_like(v) + if initial_state is not None: + hv += initial_state[1] + + for i in range(T): + q_i = qv[:, :, i] + k_i = s[:, :, i] + v_i = v[:, :, i] + g_i = g[:, :, i].exp() + hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] + ov[:, :, i] = (q_i[..., None] * hv).sum(-2) + + if output_final_state: + final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0]) + return ov.to(dtype), final_state + + +def naive_cumsum_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor +) -> torch.Tensor: + """ + A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. + This is just for demonstration purposes, with no numerical stabilities guaranteed. 
+ """ + + dtype = q.dtype + q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) + + scale = q.shape[-1] ** -0.5 + # [batch_size, n_heads, seq_len, n_slots] + s = (s - s.max(2, True)[0]).exp() + z = s.cumsum(2) + # [batch_size, n_heads, seq_len, n_slots, d_head] + K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + # [batch_size, n_heads, seq_len, n_slots] + p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) + # [batch_size, n_heads, seq_len, d_head] + o = torch.einsum('...m,...md->...d', p, V) + return o.to(dtype), None diff --git a/fla3/ops/attn/__init__.py b/fla3/ops/attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fbccbb5843dd331359e6c4d1c51c2f6b72701c9e --- /dev/null +++ b/fla3/ops/attn/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_attn + +__all__ = [ + 'parallel_attn' +] diff --git a/fla3/ops/attn/__pycache__/__init__.cpython-310.pyc b/fla3/ops/attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fd9bbd20438563b149810c7db8c972b49c36444 Binary files /dev/null and b/fla3/ops/attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/attn/__pycache__/__init__.cpython-312.pyc b/fla3/ops/attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c157ce9d6a5bc05ddd0033228a8e29d5187ebde0 Binary files /dev/null and b/fla3/ops/attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/attn/__pycache__/decoding.cpython-310.pyc b/fla3/ops/attn/__pycache__/decoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f6b448a566814e04f239e08a2cd0fe67e62caa7 Binary files /dev/null and b/fla3/ops/attn/__pycache__/decoding.cpython-310.pyc differ diff --git a/fla3/ops/attn/__pycache__/parallel.cpython-310.pyc b/fla3/ops/attn/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60d0bd7e935e4ec463dda020457b7783d578c13e Binary files /dev/null and b/fla3/ops/attn/__pycache__/parallel.cpython-310.pyc differ diff --git a/fla3/ops/attn/__pycache__/parallel.cpython-312.pyc b/fla3/ops/attn/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5926794f5e45e17c21ec0c484fd2bdd9ea980fd3 Binary files /dev/null and b/fla3/ops/attn/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla3/ops/attn/decoding.py b/fla3/ops/attn/decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..ce30df3da6e3e44ac6a3985e918e62aff4192225 --- /dev/null +++ b/fla3/ops/attn/decoding.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.cumsum import chunk_global_cumsum +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem + + +@triton.heuristics({ + 'USE_G': lambda args: args['g_cumsum'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([] if check_shared_mem('hopper') else [8]) + for num_stages in [2, 3, 4, 5] + ], + key=['H', 'G', 'K', 'V', 'BK', 'BV', 'USE_G'], +) +@triton.jit +def naive_attn_decoding_kernel( + q, + k, + v, + o, + g_cumsum, + scale, + gate_scale, + cu_seqlens, + T, + B: 
tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr +): + i_v, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + bos, eos = tl.load(cu_seqlens + i_b).to(tl.int32), tl.load(cu_seqlens + i_b + 1).to(tl.int32) + T = eos - bos + + p_q = tl.make_block_ptr(q + i_bh * K, (K,), (1, ), (0, ), (BK,), (0,)) + p_o = tl.make_block_ptr(o + i_bh * V, (V,), (1, ), (0, ), (BV,), (0,)) + + b_q = tl.load(p_q, boundary_check=(0,)) + b_q = (b_q * scale).to(b_q.dtype) + + b_o = tl.zeros([BV, ], dtype=tl.float32) + + b_m = tl.full([1,], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([1,], dtype=tl.float32) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (T-1,), (1,), (0,)) + b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + else: + b_gq = None + + for i_s in range(0, T, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS] + b_s = tl.sum(b_q[None, :] * b_k, 1) + + mask = i_s + tl.arange(0, BS) < T + b_s = tl.where(mask, b_s, float('-inf')) + + if USE_G: + p_gk = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_s += (b_gq - b_gk) * gate_scale + # [1] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s)), b_m + b_r = exp(b_mp - b_m) + # [BS] + b_p = exp(b_s - b_m) + + # [1] + b_acc = b_acc * b_r + tl.sum(b_p, 0) + # [BV] + b_o = b_o * b_r + tl.sum(b_p[:, None] * b_v, 0) + b_mp = b_m + b_o = b_o / b_acc + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, )) + + +def attn_decoding_one_step( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: torch.LongTensor = None, + do_gate_scale: bool = False, +): + r""" + Args: + q (torch.Tensor): + query of shape `[1, B, HQ, K]`. + k (torch.Tensor): + keys of shape `[1, T, H, K]`. + GQA is applied when HQ is divisible by H. T is the total length summed over all sequences in the batch. + v (torch.Tensor): + values of shape `[1, T, H, V]`. + g (Optional[torch.Tensor]): + log decay factors of shape `[1, T, H]`. Default: `None`. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + do_gate_scale (bool): + Whether to also apply the attention scale to the gating bias term (as in the Forgetting + Transformer or PaTH-FoX). Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[1, B, HQ, V]`, matching the layout of `q`.
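+ + Example (a minimal sketch; shapes, dtype and device are illustrative assumptions):: + >>> import torch + >>> N, HQ, H, K, V = 2, 8, 2, 64, 64 + >>> cu_seqlens = torch.tensor([0, 100, 300], dtype=torch.long, device='cuda') + >>> q = torch.randn(1, N, HQ, K, device='cuda', dtype=torch.bfloat16) + >>> k = torch.randn(1, 300, H, K, device='cuda', dtype=torch.bfloat16) + >>> v = torch.randn(1, 300, H, V, device='cuda', dtype=torch.bfloat16) + >>> o = attn_decoding_one_step(q, k, v, cu_seqlens=cu_seqlens) # o: [1, 2, 8, 64]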
+ """ + assert cu_seqlens is not None, "The cu_seqlens must be provided for varlen decoding" + B, T, H, K, V = *k.shape, v.shape[-1] + N = len(cu_seqlens) - 1 + HQ = q.shape[2] + G = HQ // H + if scale is None: + scale = K ** -0.5 + + BK = triton.next_power_of_2(K) + if check_shared_mem('hopper', q.device.index): + BS = min(64, max(16, triton.next_power_of_2(T))) + BV = min(256, max(16, triton.next_power_of_2(V))) + elif check_shared_mem('ampere', q.device.index): + BS = min(32, max(16, triton.next_power_of_2(T))) + BV = min(128, max(16, triton.next_power_of_2(V))) + else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BV = min(64, max(16, triton.next_power_of_2(V))) + g_cumsum = chunk_global_cumsum(g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) if g is not None else None + NV = triton.cdiv(V, BV) + o = torch.empty(*q.shape[:-1], V, dtype=v.dtype, device=q.device) + gate_scale = 1.0 if not do_gate_scale else scale + + grid = (NV, N * HQ) + naive_attn_decoding_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + g_cumsum=g_cumsum, + scale=scale, + gate_scale=gate_scale, + cu_seqlens=cu_seqlens, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BS=BS, + BK=BK, + BV=BV, + ) + return o diff --git a/fla3/ops/attn/parallel.py b/fla3/ops/attn/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..c367f8dd0da6b563a784e644ee0f477eba4546d4 --- /dev/null +++ b/fla3/ops/attn/parallel.py @@ -0,0 +1,738 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.utils import prepare_chunk_indices +from fla.ops.utils.cumsum import chunk_global_cumsum +from fla.ops.utils.op import exp, log, safe_exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + + +@triton.heuristics({ + 'USE_G': lambda args: args['g_cumsum'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'], +) +@triton.jit +def parallel_attn_fwd_kernel( + q, + k, + v, + o, + g_cumsum, + lse, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = 
tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + + b_m = tl.full([BT], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([BT], dtype=tl.float32) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + else: + b_gq = None + + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + + if USE_G: + p_gk = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_s += b_gq[:, None] - b_gk[None, :] + + # [BT, BS] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = safe_exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')) + + if USE_G: + p_gk = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_s += b_gq[:, None] - b_gk[None, :] + + # [BT] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = safe_exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + b_mp = b_m + + b_o = b_o / b_acc[:, None] + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit +def parallel_attn_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_G': lambda args: args['g_cumsum'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_attn_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + dg_cumsum, + g_cumsum, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + 
H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + if USE_G: + b_dg = tl.zeros([BT, ], dtype=tl.float32) + p_gq = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + else: + b_gq = None + b_dg = None + + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + if USE_G: + p_gk = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_s += b_gq[:, None] - b_gk[None, :] + + b_p = safe_exp(b_s - b_lse[:, None]) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + if USE_G: + b_dg += tl.sum(b_ds, 1) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + + if USE_G: + p_gk = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_s += b_gq[:, None] - b_gk[None, :] + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, -float('inf')) + + b_p = safe_exp(b_s - b_lse[:, None]) # SY: important to use safe_exp here to avoid NaN. 
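+ # For rows whose scores are entirely masked, b_s - b_lse is (-inf) - (-inf) = NaN; + # safe_exp maps such entries to 0 instead of letting the NaN propagate.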
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + if USE_G: + b_dg += tl.sum(b_ds, 1) + + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + if USE_G: + p_dg = tl.make_block_ptr(dg_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_G': lambda args: args['g_cumsum'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'HQ', 'G', 'K', 'V', 'BK', 'BV', 'USE_G', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_attn_bwd_kernel_dkv( + q, + k, + v, + g_cumsum, + lse, + delta, + do, + dk, + dv, + dg_cumsum, + cu_seqlens, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + o_k = i_t * BT + tl.arange(0, BT) + + if USE_G: + p_gk = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_dg = tl.zeros([BT,], dtype=tl.float32) + else: + b_gk = None + b_dg = None + + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = 
tl.dot(b_k, tl.trans(b_q)) + if USE_G: + p_gq = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + b_s += b_gq[None, :] - b_gk[:, None] + b_s = tl.where(o_k[:, None] <= o_q[None, :], b_s, -float('inf')) + b_p = safe_exp(b_s - b_lse[None, :]) + b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + if USE_G: + b_dg -= tl.sum(b_ds, 1) + + for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + if USE_G: + p_gq = tl.make_block_ptr(g_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + b_s += b_gq[None, :] - b_gk[:, None] + b_p = safe_exp(b_s - b_lse[None, :]) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + if USE_G: + b_dg -= tl.sum(b_ds, 1) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + if USE_G: + p_dg = tl.make_block_ptr(dg_cumsum + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +def parallel_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cumsum: torch.Tensor, + scale: float, + chunk_size: int = 128, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + if check_shared_mem('hopper', q.device.index): + BS = min(64, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(256, max(16, triton.next_power_of_2(V))) + elif check_shared_mem('ampere', q.device.index): + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(128, max(16, triton.next_power_of_2(V))) + else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(64, max(16, triton.next_power_of_2(V))) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert NK == 1, "The key dimension can not be larger than 256" + + o = 
torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + grid = (NV, NT, B * HQ) + parallel_attn_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + g_cumsum=g_cumsum, + lse=lse, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_attn_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float) + parallel_attn_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + g_cumsum: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + scale: float = None, + chunk_size: int = 128, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BS = max(16, triton.next_power_of_2(T)) + BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS) # SY:H100 should at least use BS=64 to use WGMMA + BK = max(16, triton.next_power_of_2(K)) + BV = max(16, triton.next_power_of_2(V)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NV = triton.cdiv(V, BV) + + delta = parallel_attn_bwd_preprocess(o, do) + + dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device) + grid = (NV, NT, B * HQ) + + dg_cumsum, dg_cumsum_k = None, None + if g_cumsum is not None: + dg_cumsum = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + dg_cumsum_k = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_attn_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + g_cumsum=g_cumsum, + lse=lse, + delta=delta, + do=do, + dq=dq, + dg_cumsum=dg_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + parallel_attn_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + g_cumsum=g_cumsum, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + dg_cumsum=dg_cumsum_k, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum') + if g_cumsum is not None: + dg_cumsum.add_(dg_cumsum_k) + return dq, dk, dv, dg_cumsum + + +@torch.compile +class ParallelAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, cu_seqlens): + ctx.dtype = q.dtype + + chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1]))) + + g_cumsum = chunk_global_cumsum(g, cu_seqlens=cu_seqlens) if g is not None else None + o, lse = parallel_attn_fwd( + q=q, + k=k, + v=v, + g_cumsum=g_cumsum, + scale=scale, + chunk_size=chunk_size, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q, k, v, o, g_cumsum, lse) + ctx.chunk_size = chunk_size + ctx.cu_seqlens = cu_seqlens + 
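# Editorial note: only (q, k, v, o, g_cumsum, lse) are saved; the backward pass never
# materializes the attention matrix but recomputes p = exp(s - lse) block by block,
# using delta_i = sum_j o_ij * do_ij supplied by parallel_attn_bwd_kernel_preprocess.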
ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, g_cumsum, lse = ctx.saved_tensors + dq, dk, dv, dg = parallel_attn_bwd( + q=q, + k=k, + v=v, + o=o, + g_cumsum=g_cumsum, + lse=lse, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + cu_seqlens=ctx.cu_seqlens, + ) + if dg is not None: + dg = chunk_global_cumsum(dg, cu_seqlens=ctx.cu_seqlens, reverse=True) + + return dq.to(q), dk.to(k), dv.to(v), dg, None, None + + + def parallel_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False + ) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA will be applied if HQ is divisible by H. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (Optional[torch.Tensor]): + log decay factors of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[float]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (Optional[torch.LongTensor]): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please pass inputs in the `[B, T, H, ...]` layout with head_first=False instead." + ) + q, k, v = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v)) + if g is not None: + g = rearrange(g, 'b h t ... -> b t h ...') + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens is provided" + + o = ParallelAttentionFunction.apply(q, k, v, g, scale, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h ...
-> b h t ...') + return o diff --git a/fla3/ops/based/__init__.py b/fla3/ops/based/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f20b31ba0ea4c7d345761fbd6ab5f6ced5136236 --- /dev/null +++ b/fla3/ops/based/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .fused_chunk import fused_chunk_based +from .parallel import parallel_based + +__all__ = [ + 'fused_chunk_based', + 'parallel_based' +] diff --git a/fla3/ops/based/__pycache__/__init__.cpython-310.pyc b/fla3/ops/based/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8995673274ff79ac9fec8bb404448245255d98c Binary files /dev/null and b/fla3/ops/based/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/based/__pycache__/__init__.cpython-312.pyc b/fla3/ops/based/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93719d76f21c1924cdf853014f20e812a997152a Binary files /dev/null and b/fla3/ops/based/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/based/__pycache__/fused_chunk.cpython-310.pyc b/fla3/ops/based/__pycache__/fused_chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9967e69f73d12596fe5f0143be39b4358d4f6ecf Binary files /dev/null and b/fla3/ops/based/__pycache__/fused_chunk.cpython-310.pyc differ diff --git a/fla3/ops/based/__pycache__/fused_chunk.cpython-312.pyc b/fla3/ops/based/__pycache__/fused_chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e3f2953134dbf9364df5dd1d3ef32786bfc315b Binary files /dev/null and b/fla3/ops/based/__pycache__/fused_chunk.cpython-312.pyc differ diff --git a/fla3/ops/based/__pycache__/parallel.cpython-310.pyc b/fla3/ops/based/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5b6640669afb76e83afb5c0bbccf8b1ee9b6131 Binary files /dev/null and b/fla3/ops/based/__pycache__/parallel.cpython-310.pyc differ diff --git a/fla3/ops/based/__pycache__/parallel.cpython-312.pyc b/fla3/ops/based/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca2e5dff55ee2792c776ccd92759c3f99bc1629 Binary files /dev/null and b/fla3/ops/based/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla3/ops/based/fused_chunk.py b/fla3/ops/based/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..1bba63a72d9468e03d22708fbd1114f2c55a1283 --- /dev/null +++ b/fla3/ops/based/fused_chunk.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + + # 
make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_0o = 0 + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK*BK, BT] + b_k_2o = b_k[:, None, :] * b_k[None, :, :] + b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_z = tl.zeros([BT], dtype=tl.float32) + + # interchunk + # zero-order + b_o += b_h_0o + b_z += k_0o + # first-order + b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) + b_z += tl.sum(b_q * k_1o, axis=1) + # second-order + b_q_2o = b_q[:, :, None] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype) + b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 + b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 + + # update running statistics + k_1o += tl.sum(b_k, axis=1)[None, :] + k_2o += tl.sum(b_k_2o, axis=1)[None, :] + k_0o += BT + + # intrachunk + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + # [BT, BV] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T) + + # update hidden state + # [BK*BK, BV] and [BK, BV] + b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) + b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) + b_h_0o = b_h_0o + tl.sum(b_v, axis=0) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_z += BT + + +# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_based_bwd_kernel( + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + # b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) + + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # load tensors + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # inter-chunk + b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) + if i_v == 0: + b_dq += b_dz[:, None] * k_1o + b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 + if i_v == 0: + b_dq_2o += (b_dz[:, None] * k_2o) * 0.5 + b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK]) + b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) + b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) + b_dq *= scale + + # intra-chunk + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) + + # store + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # update hidden state + # [BT, BK*BK] + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + # [BV, BK*BK] + b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) + # [BV, BK] + b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) + + if i_v == 0: + # update running statistics + k_1o += tl.sum(b_k, axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, 
i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale=1): + B, H, T, K, V = *k.shape, v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + z = q.new_empty(NK, B, H, T, dtype=torch.float32) + + grid = (NV, NK, B * H) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, 
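# do, dz below are the upstream gradients of the two forward outputs: the
# (unnormalized) output o and the normalizer z (editorial note).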
do, dz): + q, k, v = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +def fused_chunk_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, + head_first: bool = False +): + assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, z = FusedChunkBasedFunction.apply(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + if not head_first: + o = o.transpose(1, 2) + return o.to(q.dtype) diff --git a/fla3/ops/based/naive.py b/fla3/ops/based/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4de614137ed28567ebb1df39c0892f498b91fb5a --- /dev/null +++ b/fla3/ops/based/naive.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import rearrange + + +def naive_parallel_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True +): + if scale is None: + scale = q.shape[-1] ** -0.5 + q = q * scale + attn = q @ k.transpose(-2, -1) + attn = 1 + attn + 1/2 * (attn ** 2) + attn.masked_fill_(~torch.tril(torch.ones( + q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o + + +def naive_chunk_based(q, k, v, chunk_size=256): + q = q * (q.shape[-1] ** -0.5) + # compute normalizer. 
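# Editorial sketch of the derivation below: with the Taylor feature map
# phi(q, k) = 1 + q·k + (q·k)^2 / 2, the normalizer is
#     z_t = sum_{s<=t} phi(q_t, k_s)
#         = (t + 1) + q_t · (sum_{s<=t} k_s) + 0.5 * <q_t q_t^T, sum_{s<=t} k_s k_s^T>,
# i.e. exactly the zeroth-, first- and second-order cumulative sums computed next.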
+ k_cumsum = torch.cumsum(k, dim=-2) + kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) + # first-order term + z = (q * k_cumsum).sum(-1) + # second-order term + z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 + # zeroth-order term + z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] + + # compute o + # constant term + _o = v.cumsum(-2) + + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) + + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + + intra_chunk_attn = q @ k.transpose(-2, -1) + intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) + intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0) + o = intra_chunk_attn @ v + + # quadratic term + kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + + o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) + + # linear term + kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) + + o = rearrange(o, 'b h n c d -> b h (n c) d') + o = o + _o + return o / (z[..., None] + 1e-6) diff --git a/fla3/ops/based/parallel.py b/fla3/ops/based/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..573c4844df215fbd33d1cf42c2b1350e4475fe69 --- /dev/null +++ b/fla3/ops/based/parallel.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + +# Based: An Educational and Effective Sequence Mixer +# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based + + +@triton.jit(do_not_specialize=['T']) +def parallel_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # i_c: chunk index, 
used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_based_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + q, + k, + v, + do, + dz, + dq, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, 
boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_based_bwd_dkv( + i_bh, + i_c, + i_k, + i_v, + q, + k, + v, + do, + dz, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale # [BTL, BTS] + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, 
boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit(do_not_specialize=['T']) +def parallel_based_bwd_kernel( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % NV + _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, + q, k, v, do, dz, dq, + scale, T, B, H, BTL, BTS, BK, BV, K, V + ) + tl.debug_barrier() + _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, + q, k, v, do, dz, dk, dv, + scale, T, B, H, BTL, BTS, BK, BV, K, V + ) + + +class ParallelBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not." 
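# Editorial note: the assert above likely reflects that the quadratic feature map does
# not decompose over key-dimension splits: each i_k program would add its own constant
# term and the square of a *partial* dot product, so o and z could not simply be
# summed over NK on the host.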
+ + o = torch.empty(NK, B, H, T, V, device=q.device) + z = torch.empty(NK, B, H, T, device=q.device) + parallel_based_fwd_kernel[grid]( + q, k, v, o, z, + scale, + B=B, + H=H, + T=T, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + + parallel_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + scale, + B=B, + H=H, + T=T, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, + head_first: bool = False +): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, z = triton_parallel_based(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + if not head_first: + o = o.transpose(1, 2) + return o.to(q.dtype) diff --git a/fla3/ops/common/__init__.py b/fla3/ops/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/fla3/ops/common/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/fla3/ops/common/__pycache__/__init__.cpython-310.pyc b/fla3/ops/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6698cf852cdcc776bdc32efafbca6a22d9e12d06 Binary files /dev/null and b/fla3/ops/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/common/__pycache__/__init__.cpython-312.pyc b/fla3/ops/common/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..543dd159db98ba2729b8264d029a92dc7eaefdbd Binary files /dev/null and b/fla3/ops/common/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_delta_h.cpython-310.pyc b/fla3/ops/common/__pycache__/chunk_delta_h.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3e576c560bffd7c53d2c41dcebb9ee077db0bd2 Binary files /dev/null and b/fla3/ops/common/__pycache__/chunk_delta_h.cpython-310.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc b/fla3/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97557fb81f6131cf896dc4613f224ba2207110ae Binary files /dev/null and 
b/fla3/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_h.cpython-310.pyc b/fla3/ops/common/__pycache__/chunk_h.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64b934ff64f6970907bb0724e19c88c1477aa126 Binary files /dev/null and b/fla3/ops/common/__pycache__/chunk_h.cpython-310.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_o.cpython-310.pyc b/fla3/ops/common/__pycache__/chunk_o.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6517ac029b6a3fa48dad980c9018421380dedec Binary files /dev/null and b/fla3/ops/common/__pycache__/chunk_o.cpython-310.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_o.cpython-312.pyc b/fla3/ops/common/__pycache__/chunk_o.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..778a8cf0c5488b15790a3535d3b748a1e110d883 Binary files /dev/null and b/fla3/ops/common/__pycache__/chunk_o.cpython-312.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-310.pyc b/fla3/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..146ac2660b87a2e823ad0df572ad666d2c237c81 Binary files /dev/null and b/fla3/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-310.pyc differ diff --git a/fla3/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc b/fla3/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c39e9e38fe356b7e7d71660d2c1bc5ef8624a5ca Binary files /dev/null and b/fla3/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc differ diff --git a/fla3/ops/common/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/common/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd7f023372232a6f7d82db871e8e30f7ca18db29 Binary files /dev/null and b/fla3/ops/common/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/common/chunk_delta_h.py b/fla3/ops/common/chunk_delta_h.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe2021fd59f2c14b9f97e72c66de380fcf247b4 --- /dev/null +++ b/fla3/ops/common/chunk_delta_h.py @@ -0,0 +1,601 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ...ops.utils.op import exp +from ...utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [16, 32, 64] + ], + key=['H', 'K', 'V', 'BT', 'USE_G'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + d, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: 
tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K*V + v += (bos * H + i_h) * V + k += (bos * H + i_h) * K + d += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H*V + stride_h = H*K*V + stride_k = H*K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K*V + if STORE_FINAL_STATE: + ht = ht + i_nh * K*V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) if SAVE_NEW_VALUE else None + b_intermediate = tl.zeros([BT, BV], dtype=tl.float32) + p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_intermediate += tl.dot(b_d, b_h1.to(b_d.dtype)) + if K > 64: + p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_intermediate += tl.dot(b_d, b_h2.to(b_d.dtype)) + if K > 128: + p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_intermediate += tl.dot(b_d, b_h3.to(b_d.dtype)) + if K > 192: + p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_d = tl.load(p_d, 
boundary_check=(0, 1)) + b_intermediate += tl.dot(b_d, b_h4.to(b_d.dtype)) + b_intermediate = -b_intermediate + tl.load(p_v, boundary_check=(0, 1)) + b_intermediate = b_intermediate.to(k.dtype.element_ty) + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, b_intermediate, boundary_check=(0, 1)) + + if USE_G: + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_intermediate) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_intermediate) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_intermediate) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_intermediate) + + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] # SY: do not change this line + for num_stages in [4, 3, 2, 1] + for BV in [64, 32, 16] + ], + key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( + q, + k, + d, + g, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([64, BV], dtype=tl.float32) + if K 
> 128: + b_dh3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + dh += (boh * H + i_h) * K*V + dv += (bos * H + i_h) * V + dv2 += (bos * H + i_h) * V + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + d += (bos * H + i_h) * K + do += (bos * H + i_h) * V + stride_v = H*V + stride_h = H*K*V + stride_k = H*K + if USE_INITIAL_STATE: + dh0 += i_nh * K*V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K*V + + if USE_FINAL_STATE_GRADIENT: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if K > 64: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if K > 128: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if K > 192: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + p_dh1 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh2 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh3 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh4 = tl.make_block_ptr(dh + i_t*stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + # b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + if USE_G: + last_idx = min((i_t + 1) * BT, T) - 1 + bg_last = tl.load(g + (bos + last_idx) * H + i_h) + bg_last = exp(bg_last) + else: + bg_last = None + last_idx = None + + p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + b_dv += tl.load(p_dv, boundary_check=(0, 1)) + + # Update dv + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype)) + + if K > 64: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) + + if K > 128: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) + + if K > 192: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) + + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + # Update dh + p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_d = tl.load(p_d, 
boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + if USE_G: + b_dh1 *= bg_last + b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype)) + if K > 64: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if USE_G: + b_dh2 *= bg_last + b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype)) + if K > 128: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if USE_G: + b_dh3 *= bg_last + b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype)) + if K > 192: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if USE_G: + b_dh4 *= bg_last + b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_Q': lambda args: args['q'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + for BK in [16, 32, 64, 128] + ], + key=['H', 'K', 'BT', 'BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def preprocess_qkw( + q, + k, + w, + g, + q_new, + k_new, + w_new, + cu_seqlens, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + USE_Q: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_nh, i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + # calculate offset + k += (bos * H + i_h) * K + w += (bos * H + i_h) * K + k_new += (bos * H + i_h) * K + w_new += (bos * H + i_h) * K + if USE_Q: + q += (bos * H + i_h) * K + q_new += (bos * H + i_h) * K + g += bos * H + i_h + stride_k = H * K + stride_g = H + + # Get gate values + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + last_idx * stride_g).to(tl.float32) + + p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,),
(0,)) + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k_new = tl.make_block_ptr(k_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w_new = tl.make_block_ptr(w_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + b_d_last = exp(b_g_last - b_g) + b_d_begin = exp(b_g) + + b_k = b_k * b_d_last[:, None] + b_w = b_w * b_d_begin[:, None] + tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1)) + + if USE_Q: + p_q = tl.make_block_ptr(q, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q_new = tl.make_block_ptr(q_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_q = b_q * b_d_begin[:, None] + tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + if g is not None: + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT)) + preprocess_qkw[grid]( + q=None, + k=k, + w=w, + g=g, + q_new=None, + k_new=k_new, + w_new=w_new, + cu_seqlens=cu_seqlens, + T=T, + H=H, + K=K, + BT=BT, + ) + + v_new = torch.empty_like(u) if save_new_value else None + def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) # parallelize over BV blocks only + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k if g is None else k_new, + v=u, + d=w if g is None else w_new, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT + ) + return h, v_new, final_state + + +def chunk_gated_delta_rule_bwd_dhu( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + g: torch.Tensor, + h0: torch.Tensor, + dht: Optional[torch.Tensor], + do: torch.Tensor, + dv: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *q.shape, do.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + if g is not None: + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT)) + preprocess_qkw[grid]( + q=q, + k=k, + w=w, + g=g, + q_new=q_new, + k_new=k_new, + w_new=w_new, + cu_seqlens=cu_seqlens, + T=T, + H=H, + K=K, + BT=BT, + ) + + def grid(meta): return (triton.cdiv(V, meta['BV']), N*H) + chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid]( + q=q if g is None else q_new, + k=k if g is None else k_new, + d=w if g is None else w_new, + g=g, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dh, dh0, dv2 diff --git a/fla3/ops/common/chunk_h.py b/fla3/ops/common/chunk_h.py new file mode 100644 index 0000000000000000000000000000000000000000..e2307ee2732fecb0c9c4b90abf9589c4c61441e3 --- /dev/null +++ b/fla3/ops/common/chunk_h.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import prepare_chunk_offsets +from ...ops.utils.op import exp +from ...utils import check_shared_mem + +BKV_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h( + k, + v, + h, + g, + gk, + gv, + h0, + ht, + cu_seqlens, + split_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = i_n * NS + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 
+ i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + i_s = i_t // (BS // BT) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t % (BS // BT) == 0: + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + last_idx = min((i_t + 1) * BT, T) - 1 + + # scalar decay + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_h *= exp(b_g_last) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype) + + # vector decay, h = h @ Diag(gv) + if USE_GV: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) 
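+ # the two exponentials below implement the chunkwise decay trick: the carried
+ # state absorbs the full decay of this chunk, h <- h * Diag(exp(gv_last)), while
+ # each value row is pre-scaled by exp(gv_last - gv_t), the decay it still
+ # accumulates between position t and the chunk boundary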
+ b_h *= exp(b_gv_last)[None, :] + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) + + b_h += tl.dot(b_k, b_v) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh( + q, + g, + gk, + gv, + do, + dh, + dht, + dh0, + cu_seqlens, + split_offsets, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = i_n * NS + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT - 1, -1, -1): + i_s = i_t // (BS // BT) + o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t % (BS // BT) == 0: + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + last_idx = min(i_t * BT + BT, T) - 1 + # [BK, BT] + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + if USE_G: + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) + + b_dh *= exp(b_g_last) + + if USE_GK: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk)).to(b_q.dtype) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) 
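+ # mirror of the forward key decay: the queries were already scaled by exp(gk)
+ # above, and the running gradient state absorbs the whole-chunk factor
+ # exp(gk_last) before this chunk's q^T @ do contribution is added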
+ b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv)).to(b_do.dtype) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_dh *= exp(b_gv_last)[None, :] + + b_dh += tl.dot(b_q, b_do) + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + h0: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.Tensor] = None, + chunk_size: int = 64, + split_size: Optional[int] = None, + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T))) + assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NS, split_offsets = B, triton.cdiv(T, BS), None + else: + split_offsets = prepare_chunk_offsets(cu_seqlens, BS) + N, NS = len(cu_seqlens) - 1, split_offsets[-1].item() + + h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_fwd_kernel_h[grid]( + k=k, + v=v, + h=h, + g=g, + gk=gk, + gv=gv, + h0=h0, + ht=ht, + cu_seqlens=cu_seqlens, + split_offsets=split_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BS=BS, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + ) + return h, ht + + +def chunk_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.Tensor] = None, + chunk_size: int = 64, + split_size: Optional[int] = None, + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T))) + assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" + # N: the actual number of sequences in the batch with either equal or variable lengths + # NG: number of groups in GQA + if cu_seqlens is None: + N, NS, split_offsets = B, triton.cdiv(T, BS), None + else: + split_offsets = prepare_chunk_offsets(cu_seqlens, BS) + N, NS = len(cu_seqlens) - 1, split_offsets[-1].item() + NG = HQ // H + + dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ) + chunk_bwd_kernel_dh[grid]( + q=q, + g=g, + gk=gk, + gv=gv, + do=do, + dh=dh, + dht=dht, + dh0=dh0, +
cu_seqlens=cu_seqlens, + split_offsets=split_offsets, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + BS=BS, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + ) + return dh, dh0 diff --git a/fla3/ops/common/chunk_h_parallel.py b/fla3/ops/common/chunk_h_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..86cf943a34affd197a4d36f0ea57f9b525071736 --- /dev/null +++ b/fla3/ops/common/chunk_h_parallel.py @@ -0,0 +1,551 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +""" +Fully parallelized state passing. +""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ...ops.utils.op import exp + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_parallel( + k, + v, + h, + g, + gk, + gv, + h0, + ht, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + NV = tl.cdiv(V, BV) + # i_b: batch index + # i_h: head index + # i_n: sequence index + # i_t: chunk index within current sequence + # i_tg: (global) chunk index across all sequences + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + bos, eos = i_b * T, i_b * T + T + NT = tl.cdiv(T, BT) + i_n, i_tg = i_b, i_b * NT + i_t + i_nh = i_n * H + i_h + + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t == 0: + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + else: + b_h = tl.zeros([BK, BV], dtype=tl.float32) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) 
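+ # unlike the sequential kernel in chunk_h.py, no carried state is decayed here:
+ # each program only forms its own chunk's contribution k^T @ (v * exp(g_last - g)),
+ # and the cross-chunk decay exp(g_last) is applied later by the reduction kernel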
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype) + + # vector decay, h = h @ Diag(gv) + if USE_GV: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) + + b_h = tl.dot(b_k, b_v) + if i_t < NT - 1: + p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_reduction( + h, + g, + gk, + gv, + kvt, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT): + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if i_t > 0: + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_h *= exp(b_g_last) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + # vector decay, h = h @ Diag(gv) + if USE_GV: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) 
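+ # the reduction replays the same per-chunk decay factors on the running sum, so
+ # combining the independently computed partial states matches sequential state
+ # passing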
+ b_h *= exp(b_gv_last)[None, :] + + if STORE_FINAL_STATE: + p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_parallel( + q, + g, + gk, + gv, + do, + dh, + dht, + dh0, + cu_seqlens, + chunk_indices, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + NV = tl.cdiv(V, BV) + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // NG + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + bos, eos = i_b * T, i_b * T + T + NT = tl.cdiv(T, BT) + i_n, i_tg = i_b, i_b * NT + i_t + i_nh = i_n * HQ + i_hq + + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t == NT - 1: + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + else: + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + if USE_G: + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) 
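+ # only the in-chunk gate exp(g) is applied to the queries here; the
+ # chunk-boundary factor exp(g_last) is deferred to chunk_bwd_kernel_dh_reduction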
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) + + if USE_GK: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk)).to(b_q.dtype) + + if USE_GV: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv)).to(b_do.dtype) + + b_dh = tl.dot(b_q, b_do) + if i_t > 0: + p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_reduction( + g, + gk, + gv, + dh, + doq0, + dh0, + cu_seqlens, + chunk_offsets, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32) + if i_t < NT - 1: + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + if USE_G: + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_dh *= exp(b_g_last) + + if USE_GK: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) 
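+ # traversing chunks in reverse, the summed gradient state picks up each chunk's
+ # decay factors exactly once, mirroring the forward reduction above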
+ b_dh *= exp(b_gv_last)[None, :] + + if STORE_INITIAL_STATE_GRADIENT: + p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + h0: torch.Tensor, + output_final_state: bool, + states_in_fp32: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + h = k.new_empty(B, NT, H, K, V, dtype=torch.float) + ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H) + chunk_fwd_kernel_h_parallel[grid]( + k=k, + v=v, + h=h, + g=g, + gk=gk, + gv=gv, + h0=h0, + ht=ht, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None) + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_fwd_kernel_h_reduction[grid]( + h=h, + g=g, + gk=gk, + gv=gv, + kvt=kvt, + ht=ht, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + h = h.to(k.dtype) if not states_in_fp32 else h + return h, ht + + +def chunk_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + states_in_fp32: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + # NG: number of groups in GQA + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + NG = HQ // H + + dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ) + chunk_bwd_kernel_dh_parallel[grid]( + q=q, + g=g, + gk=gk, + gv=gv, + do=do, + dh=dh, + dht=dht, + dh0=dh0, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + 
V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + + doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None) + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ) + chunk_bwd_kernel_dh_reduction[grid]( + g=g, + gk=gk, + gv=gv, + dh=dh, + doq0=doq0, + dh0=dh0, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None + ) + dh = dh.to(q.dtype) if not states_in_fp32 else dh + return dh, dh0 diff --git a/fla3/ops/common/chunk_h_split.py b/fla3/ops/common/chunk_h_split.py new file mode 100644 index 0000000000000000000000000000000000000000..cda3e45784d9c6428128a0d1a6850d58f6e3361c --- /dev/null +++ b/fla3/ops/common/chunk_h_split.py @@ -0,0 +1,596 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils.op import exp + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_split( + k, + v, + g, + gk, + gv, + hs, + hr, + h0, + ht, + cu_seqlens, + split_indices, + T, + S: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + # handle one split at a time + # i_h: head index + # i_n: sequence index + # i_s: local split index inside a sequence + i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_ss, i_h = i_sh // H, i_sh % H + if IS_VARLEN: + i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + else: + NS = tl.cdiv(T, S) + i_n, i_s = i_ss // NS, i_ss % NS + bos, eos = i_n * T, i_n * T + T + i_nh = i_n * H + i_h + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # for the first split, we directly store the state as the final result + if i_s == 0: + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1)) + for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + last_idx = 
min(i_t * BT + BT, T) - 1 + + # scalar decay + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_h *= exp(b_g_last) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype) + + # vector decay, h = h @ Diag(gv) + if USE_GV: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_h *= exp(b_gv_last)[None, :] + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) + + b_h += tl.dot(b_k, b_v) + + # if there are more than one splits, we store the result to (unreduced) hs + # otherwise, we store the result to ht as the final state + if NS > 1: + p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_reduction( + g, + gk, + gv, + hs, + hr, + ht, + cu_seqlens, + split_offsets, + T, + S: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NS = tl.cdiv(T, S) + boh = i_n * NS + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # skip the first split + for i_s in range(1, NS): + p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_hr = tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1)) + + for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s 
* S + S, T), BT)): + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_h *= exp(b_g_last) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + # vector decay, h = h @ Diag(gv) + if USE_GV: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_h *= exp(b_gv_last)[None, :] + + if NS > 1: + if STORE_FINAL_STATE: + p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_split( + q, + g, + gk, + gv, + do, + dht, + dhs, + dhr, + dh0, + cu_seqlens, + split_indices, + scale, + T, + S: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + # handle one split at a time + # i_h: head index + # i_n: sequence index + # i_s: local split index inside a sequence + i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_ss, i_hq = i_sh // HQ, i_sh % HQ + if IS_VARLEN: + i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + else: + NS = tl.cdiv(T, S) + i_n, i_s = i_ss // NS, i_ss % NS + bos, eos = i_n * T, i_n * T + T + i_nh = i_n * HQ + i_hq + i_h = i_hq // NG + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if i_s == NS - 1: + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1)) + + for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1): + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * 
scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + if USE_G: + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) + b_dh *= exp(b_g_last) + + if USE_GK: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk)).to(b_q.dtype) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv)).to(b_do.dtype) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_dh *= exp(b_gv_last)[None, :] + + b_dh += tl.dot(b_q, b_do) + + if NS > 1: + p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_reduction( + g, + gk, + gv, + dhs, + dhr, + dh0, + cu_seqlens, + split_offsets, + T, + S: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NS = tl.cdiv(T, S) + boh = i_n * NS + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_s in range(NS - 2, -1, -1): + p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1)) + + for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1): + last_idx = min(i_t * BT + BT, T) - 1 + # scalar 
decay + if USE_G: + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_dh *= exp(b_g_last) + + if USE_GK: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_dh *= exp(b_gv_last)[None, :] + + if NS > 1: + if STORE_INITIAL_STATE_GRADIENT: + p_dhs = tl.make_block_ptr(dhs + (boh * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + h0: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + split_offsets: Optional[torch.LongTensor] = None, + split_indices: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + split_size: int = 256, + states_in_fp32: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + # B: batch size + # N: the actual number of sequences in the batch + # H: number of heads + # T: sequence length, can be variable across sequences + # S: split size, a multiple of chunk size + # BT: chunk size + S, BT = split_size, chunk_size + assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}" + if cu_seqlens is None: + N = B + NS = N * triton.cdiv(T, S) + else: + N = len(cu_seqlens) - 1 + NS = split_offsets[-1] + + # unreduced kv states per split + hs = k.new_empty(NS, H, K, V, dtype=torch.float) + # reduced states per split + hr = k.new_empty(NS, H, K, V, dtype=torch.float if states_in_fp32 else k.dtype) + ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + # parallelized over splits + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H) + chunk_fwd_kernel_h_split[grid]( + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + hs=hs, + hr=hr, + h0=h0, + ht=ht, + cu_seqlens=cu_seqlens, + split_indices=split_indices, + T=T, + S=S, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + ) + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_fwd_kernel_h_reduction[grid]( + g=g, + gk=gk, + gv=gv, + hs=hs, + hr=hr, + ht=ht, + cu_seqlens=cu_seqlens, + split_offsets=split_offsets, + T=T, + S=S, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + ) + return hr, ht + + +def chunk_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.Tensor] = None, + split_offsets: Optional[torch.Tensor] = None, + split_indices: Optional[torch.Tensor] = None, + chunk_size: int = 64, + split_size: int = 256, + states_in_fp32: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + # B: batch size + # N: the actual number of sequences in the batch + # H: number of heads + # T: 
sequence length, can be variable across sequences + # S: split size, a multiple of chunk size + # BT: chunk size + S, BT = max(chunk_size, min(split_size, triton.next_power_of_2(T))), chunk_size + assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}" + if cu_seqlens is None: + N = B + NS = N * triton.cdiv(T, S) + else: + N = len(cu_seqlens) - 1 + NS = split_offsets[-1] + # number of groups in GQA + NG = HQ // H + + dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float) + dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + # parallelized over splits + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ) + chunk_bwd_kernel_dh_split[grid]( + q=q, + g=g, + gk=gk, + gv=gv, + do=do, + dht=dht, + dhs=dhs, + dhr=dhr, + dh0=dh0, + cu_seqlens=cu_seqlens, + split_indices=split_indices, + scale=scale, + T=T, + S=S, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + ) + + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ) + chunk_bwd_kernel_dh_reduction[grid]( + g=g, + gk=gk, + gv=gv, + dhs=dhs, + dhr=dhr, + dh0=dh0, + cu_seqlens=cu_seqlens, + split_offsets=split_offsets, + T=T, + S=S, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + ) + return dhr, dh0 diff --git a/fla3/ops/common/chunk_o.py b/fla3/ops/common/chunk_o.py new file mode 100644 index 0000000000000000000000000000000000000000..8174d2328990805d3c9fe43c56b09015d5e74f09 --- /dev/null +++ b/fla3/ops/common/chunk_o.py @@ -0,0 +1,632 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import prepare_chunk_indices +from ...ops.utils.op import exp, safe_exp +from ...utils import check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K*V + + b_o = 
tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g'] is not None, + 'USE_DW': lambda args: args['dw'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dqkwg( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dg, + w, + dv, + dw, + cu_seqlens, + chunk_indices, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_DW: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_G: + dg += i_k * B * H * T + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += (bos * H + i_h) * V + do += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K*V + dh += (i_tg * H + i_h).to(tl.int64) * K*V + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + + # for delta rule only + if USE_DW: + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + w += (bos * H + i_h) * K + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None + b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = 
tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + if USE_G: + b_dg_last += (tl.sum(b_h * b_dh)) + # [BT, BV] @ [BV, BT] -> [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + if USE_DW: + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + if USE_DW and not USE_G: + p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + if USE_G: + b_dg = tl.zeros([BT,], dtype=tl.float32) + g += bos * H + i_h + dg += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * H) + b_dg_last *= exp(b_g_last) + + if USE_DW: + p_w = tl.make_block_ptr(w, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_dw = b_dw * exp(b_g)[:, None] + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + b_dg -= tl.sum(b_w * b_dw, axis=1) + + b_dq = b_dq * exp(b_g)[:, None] * scale + b_dg += tl.sum(b_dq * b_q, axis=1) + + b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None] + b_dg -= tl.sum(b_k * b_dk, axis=1) + b_dg_last += tl.sum(b_dk * b_k) + + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale + b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k)) + b_dg += tl.sum(b_ds2, axis=1) + b_dg -= tl.sum(b_ds2, axis=0) + + b_ds = b_ds.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) + p_dg = tl.make_block_ptr(dg, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue + # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last) + b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + else: + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) * scale + b_dq *= scale + 
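+        # note (editorial): in this ungated branch the decay terms vanish, so dq/dk reduce to
+        # plain masked-attention-style gradients; the inter-chunk h/dh contributions were
+        # already accumulated in the value loop above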
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dv( + q, + k, + g, + do, + dv, + dh, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + dh += (i_tg * H + i_h).to(tl.int64) * K*V + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * H) + b_dv *= safe_exp(-b_g + b_g_last)[:, None] + + mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) + if USE_G: + b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + else: + b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, 
i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + + mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) + if USE_G: + b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + else: + b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *q.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o + + +def chunk_bwd_dv( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # H100 can have larger block size + if check_shared_mem('hopper', k.device.index): + CONST_TILING = 128 + elif check_shared_mem(): + CONST_TILING = 64 + else: + CONST_TILING = 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NV = triton.cdiv(V, BV) + + dv = torch.empty_like(do) + grid = (NV, NT, B * H) + chunk_bwd_kernel_dv[grid]( + q, + k, + g, + do, + dv, + dh, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dv + + +def chunk_bwd_dv_local( + q: torch.Tensor, +
k: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # H100 can have larger block size + if check_shared_mem('hopper', k.device.index): + CONST_TILING = 128 + elif check_shared_mem(): + CONST_TILING = 64 + else: + CONST_TILING = 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dv = torch.empty_like(do) + grid = (NT, B * H) + chunk_bwd_kernel_dv_local[grid]( + q, + k, + g, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dv + + +def chunk_bwd_dqkwg( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dv: Optional[torch.Tensor] = None, + w: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NK = triton.cdiv(K, BK) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None + dw = torch.empty_like(w) if w is not None else None + + grid = (NK, NT, B * H) + chunk_bwd_kernel_dqkwg[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + do=do, + dh=dh, + dv=dv, + w=w, + dw=dw, + dq=dq, + dk=dk, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + + if dg is not None: + dg = dg.sum(0) + return dq, dk, dw, dg diff --git a/fla3/ops/common/chunk_scaled_dot_kkt.py b/fla3/ops/common/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc53e102404563d03ad86e82da7cf3cdc711737 --- /dev/null +++ b/fla3/ops/common/chunk_scaled_dot_kkt.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import prepare_chunk_indices +from ...ops.utils.op import safe_exp + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g_cumsum'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + Ag, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh =
tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = tl.arange(0, BT) + + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0) + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_Ag = b_A * safe_exp(b_g_diff) + p_Ag = tl.make_block_ptr(Ag + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_Ag, b_Ag.to(p_Ag.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32 +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + A (torch.Tensor): + beta * K * K^T of shape `[B, T, H, BT]`, where `BT` is the chunk size; strictly lower-triangular within each chunk. + Ag (torch.Tensor): + The decay-weighted variant of `A`, or `None` if `g_cumsum` is not provided.
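+
+    Examples::
+        >>> # a minimal usage sketch (hypothetical shapes; assumes a CUDA device)
+        >>> import torch
+        >>> k = torch.randn(2, 256, 4, 64, device='cuda', dtype=torch.bfloat16)
+        >>> beta = torch.rand(2, 256, 4, device='cuda', dtype=torch.bfloat16).sigmoid()
+        >>> A, Ag = chunk_scaled_dot_kkt_fwd(k, beta, chunk_size=64)  # Ag is None without g_cumsum
+        >>> A.shape
+        torch.Size([2, 256, 4, 64])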
+ """ + B, T, H, K = k.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + Ag = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) if g_cumsum is not None else None + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + Ag=Ag, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + ) + return A, Ag diff --git a/fla3/ops/common/fused_recurrent.py b/fla3/ops/common/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..1625d3d6f40e8a51b63326903347da2ed317dd1e --- /dev/null +++ b/fla3/ops/common/fused_recurrent.py @@ -0,0 +1,500 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ...ops.utils import chunk_global_cumsum +from ...ops.utils.op import exp +from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"], +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_fwd_kernel( + q, + k, + v, + g, + gk, + gv, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + mask_k = (i_k * BK + tl.arange(0, BK)) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + mask_h = mask_k[None, :] & mask_v[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = 
tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_h = b_h * exp(b_gk[None, :]) + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_h = b_h * exp(b_gv[:, None]) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * exp(b_g) + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * H*K + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_o += (-1 if REVERSE else 1) * H*V + if USE_GK: + p_gk += (-1 if REVERSE else 1) * H*K + if USE_GV: + p_gv += (-1 if REVERSE else 1) * H*V + if USE_G: + p_g += (-1 if REVERSE else 1) * H + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_bwd_kernel( + q, + k, + v, + g, + gk, + gv, + h0, + do, + dq, + dk, + dv, + dht, + dh0, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + mask_k = i_k * BK + tl.arange(0, BK) < K + mask_v = i_v * BV + tl.arange(0, BV) < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, 
other=0).to(tl.float32) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * exp(b_g) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_h = b_h * exp(b_gk[:, None]) + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_h = b_h * exp(b_gv[None, :]) + b_h += b_k[:, None] * b_v[None, :] + b_dq = b_h * b_do[None, :] + b_dq = tl.sum(b_dq, axis=1) * scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_do += (-1 if REVERSE else 1) * H*V + p_dq += (-1 if REVERSE else 1) * H*K + if USE_G: + p_g += (-1 if REVERSE else 1) * H + if USE_GK: + p_gk += (-1 if REVERSE else 1) * H*K + if USE_GV: + p_gv += (-1 if REVERSE else 1) * H*V + + # sync threads + tl.debug_barrier() + + p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_dh += b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + b_dv = tl.sum(b_dh * b_k[:, None], axis=0) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_dh *= exp(b_g) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_dh *= exp(b_gk)[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_dh *= exp(b_gv)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) + + p_q += (1 if REVERSE else -1) * H*K + p_k += (1 if REVERSE else -1) * H*K + p_v += (1 if REVERSE else -1) * H*V + p_do += (1 if REVERSE else -1) * H*V + p_dk += (1 if REVERSE else -1) * H*K + p_dv += (1 if REVERSE else -1) * H*V + if USE_G: + p_g += (1 if REVERSE else -1) * H + if USE_GK: + p_gk += (1 if REVERSE else -1) * H*K + if USE_GV: + p_gv += (1 if REVERSE else -1) * H*V + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: 
Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + h0 = initial_state + ht = q.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + o = q.new_empty(NK, *v.shape, dtype=torch.float32) + + grid = (NV, NK, N * H) + fused_recurrent_fwd_kernel[grid]( + q, + k, + v, + g, + gk, + gv, + o, + h0, + ht, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + ) + o = o.sum(0) + return o, ht + + +def fused_recurrent_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + o: Optional[torch.Tensor] = None, + do: Optional[torch.Tensor] = None, + dht: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, *q.shape, dtype=torch.float32) + dk = q.new_empty(NV, *k.shape, dtype=torch.float32) + dv = q.new_empty(NK, *v.shape, dtype=torch.float32) + h0 = initial_state + dh0 = torch.empty_like(initial_state) if initial_state is not None else None + + grid = (NV, NK, N * H) + fused_recurrent_bwd_kernel[grid]( + q, + k, + v, + g, + gk, + gv, + h0, + do, + dq, + dk, + dv, + dht, + dh0, + cu_seqlens, + scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dg, dgk, dgv = None, None, None + if g is not None: + dg = chunk_global_cumsum((dq * q.float() - dk * k.float()).sum(-1), reverse=not reverse, cu_seqlens=cu_seqlens) + if gk is not None: + dgk = chunk_global_cumsum(dq * q.float() - dk * k.float(), reverse=not reverse, cu_seqlens=cu_seqlens) + if gv is not None: + dgv = chunk_global_cumsum(do.float() * o.float() - dv * v.float(), reverse=not reverse, cu_seqlens=cu_seqlens) + + return dq, dk, dv, dg, dgk, dgv, dh0 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None + ): + o, ht = fused_recurrent_fwd( + q=q, + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o) + ctx.scale = scale + ctx.reverse = reverse + ctx.cu_seqlens = cu_seqlens + return o.to(q.dtype), ht + + @staticmethod + @input_guard + @autocast_custom_bwd + 
def backward(ctx, do, dht): + q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors + # not supported yet. + if dht is not None: + if not dht.eq(0).all(): + if g is not None: + assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + if gk is not None: + assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + if gv is not None: + assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd( + q=q, + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + o=o, + do=do, + dht=dht, + scale=ctx.scale, + initial_state=initial_state, + reverse=ctx.reverse, + cu_seqlens=ctx.cu_seqlens, + ) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None + + +def fused_recurrent( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +): + if scale is None: + scale = k.shape[-1] ** -0.5 + return FusedRecurrentFunction.apply( + q, + k, + v, + g, + gk, + gv, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + ) diff --git a/fla3/ops/delta_rule/README.md b/fla3/ops/delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..607b0d583c7ec2904c18c0f1d86fb0ec2dfdf583 --- /dev/null +++ b/fla3/ops/delta_rule/README.md @@ -0,0 +1,90 @@ +# Chunkwise-form Parallelism of DeltaNet + +This section expands on the formulation presented in Appendix B of the DeltaNet paper.[^1] + +To reduce notational clutter, we focus on the first chunk, denoting $\mathbf{S}^r=\mathbf{S}_{[1]}^r$. By partially expanding the recurrence, we have: +```math +\begin{equation} +\begin{aligned} +\mathbf{S}^r &= \underbrace{\left(\prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \right)}_{:= \mathbf{P}^r} \cdot\mathbf{S}^{0} + \overbrace{\sum_{i=1}^{r} \underbrace{\left(\prod_{j=i+1}^r \mathbf{I} - \beta^j \boldsymbol{k}^j \boldsymbol{k}^{j\top} \right)}_{:= \mathbf{P}_{i+1}^r}\beta^i \boldsymbol{k}^i\boldsymbol{v}^{i\top}}^{:=\mathbf{H}^r} \\ +&=\mathbf{P}^r \cdot \mathbf{S}^{0} + \mathbf{H}^r +\end{aligned} +\end{equation} +``` + +where $\mathbf{P}_i^r$ involves cumulative products of generalized Householder matrices. +We abbreviate $\mathbf{P}_1^r$ as $\mathbf{P}^r$. 
+This can be optimized using the classical WY representation: +```math +\begin{equation} +\mathbf{P}^{r} = \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top} \in \mathbb{R}^{d_k \times d_k};\qquad +\boldsymbol{w}^r = \beta^r \left(\boldsymbol{k}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i \right)\boldsymbol{w}^i \right) \in \mathbb{R}^{d_k} +\end{equation} +``` + +We prove this by induction: +```math +\begin{align*} +\mathbf{P}^{r} &= \prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \\ +&= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\mathbf{P}^{r-1} \\ +&= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\left(\mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\ +&= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} + \beta^r\boldsymbol{k}^r \boldsymbol{k}^{r\top} \left(\sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\ +&= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \left(\boldsymbol{k}^{r} - \left(\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top} \boldsymbol{k}^i\right)\boldsymbol{w}^{i}\right) \right)^\top \\ +&= \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top} +\end{align*} +``` + +Similarly, $\mathbf{H}^r$ can be represented as: +```math +\begin{equation} +\mathbf{H}^{r} = \sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top} \in \mathbb{R}^{d_k \times d_v};\qquad \boldsymbol{u}^r = \beta^r \left(\boldsymbol{v}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i\right) \boldsymbol{u}^i \right)\in \mathbb{R}^{d_v} +\end{equation} +``` + +This can also be proven by induction: +```math +\begin{align*} +\mathbf{H}^{r} &= \sum_{i=1}^{r} \mathbf{P}_{i+1}^r \beta^i \boldsymbol{k}^i \boldsymbol{v}^{i\top}\\ +&= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right) \mathbf{H}^{r-1} + \beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\ +&= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} +\beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\ +&= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \left(\beta^r \boldsymbol{v}^{r\top}-\beta^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top}\right) \\ +&= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \beta^r\left(\boldsymbol{v}^{r}-\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top}\boldsymbol{k}^{i}\right)\boldsymbol{u}^{i} \right)^\top \\ +&=\sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top} +\end{align*} +``` + +In matrix form, $\mathbf{P}$ and $\mathbf{H}$ can be written as: +```math +\begin{equation} +\mathbf{P}=\mathbf{I}-\mathbf{K}^\top\mathbf{W} \in \mathbb{R}^{d_k \times d_k}, \qquad\mathbf{H}=\mathbf{K}^\top\mathbf{U} \in \mathbb{R}^{d_k\times d_v} +\end{equation} +``` + +Now we can derive the matrix form of $\mathbf{W}$ and $\mathbf{U}$: +```math +\begin{align*} +\mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} - \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\mathbf{W}\\ +\left(\mathbf{I} + \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\right) \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} +\end{align*} +``` +A similar process holds for $\mathbf{U}$. 
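+As a concrete sanity check, the triangular system above can be solved directly with one triangular solve per chunk. Below is a minimal PyTorch sketch for a single chunk and a single head (`wy_repr_reference` is an illustrative name, not a kernel in this repo):
+```python
+import torch
+
+def wy_repr_reference(k, v, beta):
+    # k: [C, d_k], v: [C, d_v], beta: [C] for one chunk of length C
+    A = torch.tril(beta[:, None] * (k @ k.T), diagonal=-1)  # tril(diag(beta) K K^T, -1)
+    L = torch.eye(k.shape[0], dtype=k.dtype, device=k.device) + A
+    # solve (I + tril(diag(beta) K K^T, -1)) W = diag(beta) K, and likewise for U
+    w = torch.linalg.solve_triangular(L, beta[:, None] * k, upper=False)
+    u = torch.linalg.solve_triangular(L, beta[:, None] * v, upper=False)
+    return w, u
+```
+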
We can further write $\mathbf{W}$ and $\mathbf{U}$ in matrix form: +```math +\begin{align*} +\mathbf{T} &= \left(\mathbf{I} + \mathrm{tril}\left(\mathrm{diag}(\beta)\mathbf{K} \mathbf{K}^\top,-1\right)\right)^{-1}\mathrm{diag}\left(\beta\right)\in \mathbb{R}^{C \times C}\\ +\mathbf{W} &= \mathbf{T} \mathbf{K}\in \mathbb{R}^{C \times d_k}\\ +\mathbf{U} &= \mathbf{T}\mathbf{V}\in \mathbb{R}^{C \times d_v} +\end{align*} +``` + +Substituting these back into the original equations yields a hardware-efficient chunkwise algorithm for DeltaNet that leverages matrix multiplications, enabling tensor core based GPU optimization: +```math +\begin{equation} +\begin{aligned} +\mathbf{S} &= \mathbf{P}\cdot\mathbf{S}^0 + \mathbf{H} \\ +&= \mathbf{S}^0 + \mathbf{K}^\top (\mathbf{U} -\mathbf{W} \mathbf{S}^0) \in \mathbb{R}^{d_k \times d_v}\\ +\mathbf{O} &= \mathbf{Q} \mathbf{S}^0 + (\mathbf{Q} \mathbf{K}^{\top} \odot \mathbf{M}) \left(\mathbf{U} - \mathbf{W} \mathbf{S}^0\right) \in \mathbb{R}^{C \times d_v} +\end{aligned} +\end{equation} +``` + +[^1]: https://arxiv.org/abs/2406.06484 diff --git a/fla3/ops/delta_rule/__init__.py b/fla3/ops/delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0acb6a7d0e4eec9a8dc697615604783b8858d13 --- /dev/null +++ b/fla3/ops/delta_rule/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_delta_rule +from .fused_chunk import fused_chunk_delta_rule +from .fused_recurrent import fused_recurrent_delta_rule + +__all__ = [ + 'fused_chunk_delta_rule', + 'fused_recurrent_delta_rule', + 'chunk_delta_rule' +] diff --git a/fla3/ops/delta_rule/__pycache__/__init__.cpython-310.pyc b/fla3/ops/delta_rule/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24b0ea485300d214fd2d088fb313e9ad7ca6b8db Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/__init__.cpython-312.pyc b/fla3/ops/delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c07750e88eb9fdccc1591bff53c043259a6c1124 Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/chunk.cpython-310.pyc b/fla3/ops/delta_rule/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ce7276b3affafd22a1ec52d1ad521cbd7c89d03 Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/chunk.cpython-312.pyc b/fla3/ops/delta_rule/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17a2b474976ca21e3aa50d30a130aff5477ba9be Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/fused_chunk.cpython-310.pyc b/fla3/ops/delta_rule/__pycache__/fused_chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cbefaa1823cb6ed6e89137fd313f11a886a7c66 Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/fused_chunk.cpython-310.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc b/fla3/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56e84c255e597c6eba6db991e4c88987c309c02 Binary files /dev/null and 
b/fla3/ops/delta_rule/__pycache__/fused_chunk.cpython-312.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26fbaa661aafdce2504d116913c28f7286b13d78 Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6125266194d29a7c5d0edd5569d380318a001a3 Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/delta_rule/__pycache__/wy_fast.cpython-310.pyc b/fla3/ops/delta_rule/__pycache__/wy_fast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbd92f5dfa3451532022a92a0db36faac38fda39 Binary files /dev/null and b/fla3/ops/delta_rule/__pycache__/wy_fast.cpython-310.pyc differ diff --git a/fla3/ops/delta_rule/chunk.py b/fla3/ops/delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..05b1c5263bce9b52037d2415fc2510022288d15c --- /dev/null +++ b/fla3/ops/delta_rule/chunk.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from fla.ops.delta_rule.wy_fast import prepare_wy_repr_bwd, prepare_wy_repr_fwd, recompute_w_u_fwd +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + # obtain WY representation. u is actually the new v. 
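+    # per the WY derivation in fla3/ops/delta_rule/README.md: w = T K and u = T V with
+    # T = (I + tril(diag(beta) K K^T, -1))^{-1} diag(beta), so the cross-chunk state
+    # update reduces to matrix multiplications: S <- S + K^T (u - w S)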
+ w, u, A = prepare_wy_repr_fwd( + k=k, + v=v, + beta=beta, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=None, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens + ) + + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=None, + scale=scale, + cu_seqlens=cu_seqlens + ) + return o, A, final_state + + +def chunk_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, +): + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + cu_seqlens=cu_seqlens + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=None, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + do=do, + g=None, + scale=scale, + cu_seqlens=cu_seqlens + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=None, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens + ) + dq, dk, dw, _ = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + h=h, + w=w, + dv=dv, + do=do, + dh=dh, + g=None, + scale=scale, + cu_seqlens=cu_seqlens + ) + dk2, dv, db = prepare_wy_repr_bwd( + k=k, + v=v, + beta=beta, + A=A, + dw=dw, + du=dv, + cu_seqlens=cu_seqlens + ) + dk.add_(dk2) + return dq, dk, dv, db, dh0 + + +class ChunkDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = True + ): + q_orig = q + k_orig = k + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + o, A, final_state = chunk_delta_rule_fwd( + q=q, + k=k, + v=v, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state) + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, k, v, beta, A, initial_state = ctx.saved_tensors + use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel + if use_qk_l2norm_in_kernel: + q, q_orig = l2norm_fwd(q), q + k, k_orig = l2norm_fwd(k), k + + dq, dk, dv, db, dh0 = chunk_delta_rule_bwd( + q=q, + k=k, + v=v, + beta=beta, + A=A, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + cu_seqlens=ctx.cu_seqlens + ) + if use_qk_l2norm_in_kernel: + dq = l2norm_bwd(q_orig, dq) + dk = l2norm_bwd(k_orig, dk) + # one gradient per forward input: q, k, v, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None + + +@torch.compiler.disable +def chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if
`head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[float]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + use_qk_l2norm_in_kernel (Optional[bool]): + Whether to use qk l2norm within the kernel for saving GPU memory. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.delta_rule import chunk_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_delta_rule( + q, k, v, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_delta_rule( + q, k, v, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H]." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False instead." + ) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+ ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. " + "Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + scale = k.shape[-1] ** -0.5 if scale is None else scale + o, final_state = ChunkDeltaRuleFunction.apply( + q, + k, + v, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel + ) + return o, final_state diff --git a/fla3/ops/delta_rule/fused_chunk.py b/fla3/ops/delta_rule/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6347fb9af47d3d9f82c03ea9aedbfa09fc1bfbc1 --- /dev/null +++ b/fla3/ops/delta_rule/fused_chunk.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +def fused_chunk_delta_rule( + **kwargs +): + raise NotImplementedError("fused_chunk_delta_rule is deprecated. Please use chunk_delta_rule instead.") diff --git a/fla3/ops/delta_rule/fused_recurrent.py b/fla3/ops/delta_rule/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..20742730a21f19e0802890ce0238e7fb46e6b126 --- /dev/null +++ b/fla3/ops/delta_rule/fused_recurrent.py @@ -0,0 +1,538 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_delta_rule_fwd_kernel( + q, + k, + v, + u, + beta, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_BETA_HEADWISE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + if IS_BETA_HEADWISE: + p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + bos * H + i_h + p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + + mask_k = (i_k * BK + tl.arange(0, BK)) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + mask_h = mask_k[None, :] & mask_v[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for
_ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_v_minus = tl.sum(b_h * b_k[None, :], axis=1) + b_v -= b_v_minus + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v) + b_v *= b_beta + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H*K + p_k += H*K + p_o += H*V + p_v += H*V + p_u += H*V + p_beta += H * (V if IS_BETA_HEADWISE else 1) + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_delta_rule_bwd_kernel( + q, + k, + v, + beta, + h0, + dh0, + dht, + do, + dq, + dk, + dv, + db, + cu_seqlens, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NK: tl.constexpr, + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar + USE_INITIAL_STATE: tl.constexpr, # whether to use dh0 + USE_FINAL_STATE_GRADIENT: tl.constexpr, # whether to use dht + IS_VARLEN: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + mask_k = i_k * BK + tl.arange(0, BK) < K + mask_v = i_v * BV + tl.arange(0, BV) < V + + p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K + p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K + p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V + p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V + p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K + p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V + if IS_BETA_HEADWISE: + p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV) + else: + p_beta = beta + (bos + T - 1) * H + i_h + p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = 
tl.load(p_beta).to(tl.float32) + b_dh += b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1) + b_dv = tl.sum(b_dh * b_k[:, None], axis=0) + + b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v) + b_dv = b_dv * b_beta + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) + if IS_BETA_HEADWISE: + tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v) + else: + tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty)) + + b_dh -= b_k[:, None] * b_dv[None, :] + + p_q -= H*K + p_k -= H*K + p_v -= H*V + p_do -= H*V + p_dk -= H*K + p_dv -= H*V + p_dbeta -= H * (V if IS_BETA_HEADWISE else 1) + p_beta -= H * (V if IS_BETA_HEADWISE else 1) + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :]) + + tl.debug_barrier() + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + if IS_BETA_HEADWISE: + p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + bos * H + i_h + p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + + if USE_INITIAL_STATE: + mask_h = mask_k[:, None] & mask_v[None, :] + p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32) + b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32) + b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + + b_h += b_k[:, None] * b_v[None, :] + b_dq = b_h * b_do[None, :] + d_q = tl.sum(b_dq, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += H*K + p_v += H*V + p_do += H*V + p_dq += H*K + p_dk += H*K + p_dv += H*V + p_beta += H * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if output_final_state: + final_state = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + final_state = None + + grid = (NV, 
NK, N * H) + u = torch.empty_like(v) + fused_recurrent_delta_rule_fwd_kernel[grid]( + q, + k, + v, + u, + beta, + o, + initial_state, + final_state, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, u, final_state + + +def fused_recurrent_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + dht: torch.Tensor, + do: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + beta_vector = beta.ndim == v.ndim + + dq = q.new_empty(NV, *q.shape) + dk = q.new_empty(NV, *k.shape) + dv = q.new_empty(NK, *v.shape) + if beta_vector: + db = q.new_empty(NV, NK, B, T, H, V) + else: + db = q.new_empty(NV, B, T, H) + grid = (NV, NK, N * H) + + if initial_state is not None and initial_state.requires_grad: + dh0 = torch.empty_like(initial_state, dtype=torch.float32) + else: + dh0 = None + + fused_recurrent_delta_rule_bwd_kernel[grid]( + q, + k, + v, + beta, + initial_state, + dh0, + dht, + do, + dq, + dk, + dv, + db, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + NK=NK, + IS_BETA_HEADWISE=beta_vector, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + db = db.sum((0, 1)) if beta_vector else db.sum(0) + + return dq, dk, dv, db, dh0 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False + ): + q_orig = q + k_orig = k + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + o, u, final_state = fused_recurrent_delta_rule_fwd( + q=q, + k=k, + v=v, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + + ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state) + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + q, k, v, beta, initial_state = ctx.saved_tensors + if ctx.use_qk_l2norm_in_kernel: + q, q_orig = l2norm_fwd(q), q + k, k_orig = l2norm_fwd(k), k + dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd( + q=q, + k=k, + v=v, + beta=beta, + dht=dht, + do=do, + scale=ctx.scale, + initial_state=initial_state, + cu_seqlens=ctx.cu_seqlens, + ) + if ctx.use_qk_l2norm_in_kernel: + dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk) + return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None + + +@torch.compiler.disable +def fused_recurrent_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: 
Optional[torch.LongTensor] = None,
+    use_qk_l2norm_in_kernel: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, H]`.
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, device='cuda')
+        >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, H, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_delta_rule(
+            q, k, v, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_delta_rule(
+            q, k, v, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+        >>> assert o.allclose(o_var.view(o.shape))
+        >>> assert ht.allclose(ht_var)
+    """
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel + ) + return o, final_state diff --git a/fla3/ops/delta_rule/naive.py b/fla3/ops/delta_rule/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..6752aa89a3eac8c00726e09a730152af44343de4 --- /dev/null +++ b/fla3/ops/delta_rule/naive.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True): + orig_dtype = q.dtype + b, h, l, d_k = q.shape + q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta]) + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + + if beta.ndim < v.ndim: + beta = beta[..., None] + + if initial_state is not None: + S += initial_state + + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + S = None if output_final_state is False else S + return o.to(orig_dtype), S + + +def delta_rule_chunkwise(q, k, v, beta, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v * beta[..., None] + k_beta = k * beta[..., None] + + assert l % chunk_size == 0 + + # compute (I - tri(diag(beta) KK^T))^{-1} + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) + attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + + u = attn @ v + w = attn @ k_beta + S = k.new_zeros(b, h, d_k, d_v) + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i = q[:, :, i], k[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) + u_i = u[:, :, i] - w[:, :, i] @ S + o_inter = q_i @ S + o[:, :, i] = o_inter + attn @ u_i + S = S + k_i.transpose(-1, -2) @ u_i + + return rearrange(o, 'b h n c d -> b h (n c) d'), S + + +def delta_rule_parallel(q, k, v, beta, BM=128, BN=32): + b, h, l, d_k = q.shape + # d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v * beta[..., None] + k_beta = k * beta[..., None] + # compute (I - tri(diag(beta) KK^T))^{-1} + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta]) + mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0) + T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, BN): + T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2) + T = T + torch.eye(BN, dtype=torch.float, device=q.device) + + mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1) + A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T + o_intra = 
A_local @ v + + # apply cumprod transition matrices on k to the last position within the chunk + k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta + # apply cumprod transition matrices on q to the first position within the chunk + q = q - A_local @ k_beta + o_intra = A_local @ v + + A = torch.zeros(b, h, l, l, device=q.device) + + q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra]) + o = torch.empty_like(v) + for i in range(0, l, BM): + q_i = q[:, :, i:i+BM] + o_i = o_intra[:, :, i:i+BM] + # intra block + for j in range(i + BM - 2 * BN, i-BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + mask = torch.arange(i, i+BM) >= (j + BN) + A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + # inter block + for j in range(i - BN, -BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + o[:, :, i:i+BM] = o_i + + for i in range(0, l//BN): + A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i] + + return o, A diff --git a/fla3/ops/delta_rule/parallel.py b/fla3/ops/delta_rule/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..765f82a4033c03b17de766250ecebb553709a223 --- /dev/null +++ b/fla3/ops/delta_rule/parallel.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.delta_rule.wy_fast import fwd_prepare_T +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BT', 'K', 'V'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_transform_qk_fwd_kernel( + q, + k, + v, + beta, + o, + A, + q_new, + k_new, + A_local, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + BT: tl.constexpr, + OUTPUT_ATTENTIONS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(p_q.dtype.element_ty) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + p_T = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_T = tl.load(p_T, boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_t = o_i[:, None] >= o_i[None, :] + b_qk = tl.where(m_t, tl.dot(b_q, tl.trans(b_k), allow_tf32=False), 0).to(b_q.dtype) + m_t = o_i[:, None] > o_i[None, :] + b_kk = tl.where(m_t, tl.dot(b_k, tl.trans(b_k), allow_tf32=False), 0).to(b_k.dtype) + + p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + + b_qkT = tl.dot(b_qk, b_T, allow_tf32=False).to(b_k.dtype) + + if OUTPUT_ATTENTIONS: + p_a = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, 
BT), (1, 0)) + tl.store(p_a, b_qkT.to(p_a.dtype.element_ty), boundary_check=(0, 1)) + + b_kkT = tl.dot(b_kk, b_T, allow_tf32=False).to(b_k.dtype) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + tl.store(p_o, tl.dot(b_qkT, b_v).to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q_new = tl.make_block_ptr(q_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + tl.store(p_q_new, (b_q - tl.dot(b_qkT, b_k_beta, allow_tf32=False)).to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + p_k_new = tl.make_block_ptr(k_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + b_k_new = b_k - tl.dot(tl.trans(b_kkT), b_k_beta, allow_tf32=False) + tl.store(p_k_new, b_k_new.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_transform_qk_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + chunk_size: int, + output_attentions: bool +): + B, H, T, K = k.shape + BT = chunk_size + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + o = torch.empty_like(v) + grid = (triton.cdiv(T, BT), B*H) + V = v.shape[-1] + A_local = torch.empty_like(A) if output_attentions else None + chunk_transform_qk_fwd_kernel[grid]( + q, + k, + v, + beta, + o, + A, + q_new, + k_new, + A_local, + scale=scale, + T=T, + K=K, + V=V, + BT=BT, + BK=triton.next_power_of_2(K), + BV=triton.next_power_of_2(V), + OUTPUT_ATTENTIONS=output_attentions + ) + return q_new, k_new, o, A_local + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def save_intra_chunk_attn( + A, + A_local, + T, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BT), (1, 0)) + p_A_local = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A_local = tl.load(p_A_local, boundary_check=(0, 1)) + tl.store(p_A, b_A_local.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def parallel_delta_rule_fwd_kernel( + q, + k, + k2, # original k + v, + beta, + o, + o_new, + attn, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + OUTPUT_ATTENTIONS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.zeros([BT, BK], dtype=tl.float32) + b_q += tl.load(p_q, boundary_check=(0, 1)) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0)) + b_o += tl.load(p_o, boundary_check=(0, 1)) + + # As opposed to Flashattention, this kernel requires scanning the KV blocks from right to left + # Q block and K block have overlap. 
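+    # In effect, each iteration of the scan below computes, per KV block j
+    # (cf. `naive_delta_rule_parallel` at the end of this file):
+    #     A_ij = q_i @ k_j^T               # masked only on the overlapping block
+    #     o_i += A_ij @ v_j                # accumulate the inter-block output
+    #     q_i -= A_ij @ (beta_j * k_j)     # push q through the (I - beta k k^T) transitions
+    # so the query block is progressively transported through earlier chunks
+    # as the scan proceeds right to left.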
+    # masks required
+    for offset in range((i_t + 1) * BT - 2 * BS, i_t * BT - BS, -BS):
+        p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
+        p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
+        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
+        p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
+        # [BK, BS]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BS, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BS]
+        b_beta = tl.load(p_beta, boundary_check=(0,))
+        # [BT, BS]
+        m_s = tl.arange(0, BT) >= (offset - i_t*BT + BS)
+        b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)
+        b_s = tl.where(m_s[:, None], b_s, 0)
+
+        b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+        b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
+        b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False)
+
+        if OUTPUT_ATTENTIONS:
+            p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
+            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
+
+    # Q block and K block have no overlap
+    # no need for masks, thereby saving flops
+    for offset in range(i_t * BT - BS, -BS, -BS):
+        p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
+        p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
+        p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
+        p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
+
+        # [BK, BS]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BS, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BS]
+        b_beta = tl.load(p_beta, boundary_check=(0,))
+        # [BT, BS]
+        b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)
+        # [BT, BV]
+        b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+        b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
+        b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False).to(b_q.dtype)
+
+        if OUTPUT_ATTENTIONS:
+            p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
+            tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
+
+    p_o_new = tl.make_block_ptr(o_new + i_bh * T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
+    tl.store(p_o_new, b_o.to(p_o_new.dtype.element_ty), boundary_check=(0, 1))
+
+
+class ParallelDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, scale, output_attentions):
+        B, H, T, K, V = *k.shape, v.shape[-1]
+        assert q.shape[-1] <= 128, 'The maximum supported head dimension is 128.'
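+        # Tiling scheme: BT is the query-block size and BS the KV-block size,
+        # with BT % BS == 0 (asserted below) so each query block scans an
+        # integer number of KV blocks. `fwd_prepare_T` builds the per-chunk
+        # inverse (I - tril(diag(beta) K K^T))^{-1} (see the comment in
+        # `naive_delta_rule_parallel`), which `chunk_transform_qk_fwd` uses to
+        # fold the intra-chunk delta-rule updates into q, k and o before the
+        # inter-chunk scan in `parallel_delta_rule_fwd_kernel`.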
+        BT, BS = 128, 32
+        BK = triton.next_power_of_2(k.shape[-1])
+        BV = triton.next_power_of_2(v.shape[-1])
+        assert BT % BS == 0
+
+        A = fwd_prepare_T(k, beta, BS)
+        attn = q.new_zeros(B, H, T, T) if output_attentions else None
+        q_new, k_new, o, A_local = chunk_transform_qk_fwd(
+            q,
+            k,
+            v,
+            beta,
+            A,
+            scale,
+            BS,
+            output_attentions
+        )
+
+        num_stages = 3 if K <= 64 else 2
+        num_warps = 4
+        grid = (triton.cdiv(T, BT), B * H)
+        o_new = torch.empty_like(o)
+
+        parallel_delta_rule_fwd_kernel[grid](
+            q=q_new,
+            k=k_new,
+            k2=k,
+            v=v,
+            beta=beta,
+            o=o,
+            o_new=o_new,
+            attn=attn,
+            T=T,
+            K=K,
+            V=V,
+            BT=BT,
+            BS=BS,
+            BK=BK,
+            BV=BV,
+            num_stages=num_stages,
+            num_warps=num_warps
+        )
+
+        if output_attentions:
+            grid = (triton.cdiv(T, BS), B * H)
+            save_intra_chunk_attn[grid](
+                A=attn,
+                A_local=A_local,
+                T=T,
+                BT=BS
+            )
+        return o_new.to(q.dtype), attn
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_bwd
+    def backward(ctx, do, d_attn=None):
+        raise NotImplementedError('Backward pass is not implemented. Stay tuned!')
+
+
+def parallel_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float = None,
+    output_attentions: bool = False,
+    head_first: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        scale (Optional[float]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        output_attentions (bool):
+            Whether to output the materialized attention scores of shape `[B, H, T, T]`. Default: `False`.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        attn (torch.Tensor):
+            Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`.
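+
+    Examples::
+        A minimal forward-only sketch (the backward pass raises `NotImplementedError`);
+        the import path assumes `parallel_delta_rule` is re-exported from `fla.ops.delta_rule`:
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from fla.ops.delta_rule import parallel_delta_rule
+        >>> B, T, H, K, V = 4, 2048, 4, 64, 64  # head dims kept <= 128 per the kernel assertion
+        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> o, attn = parallel_delta_rule(q, k, v, beta, output_attentions=True)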
+ """ + if not head_first: + q, k, v, beta = map(lambda x: x.transpose(1, 2), (q, k, v, beta)) + o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions) + if not head_first: + o = o.transpose(1, 2) + return o, attn + + +def naive_delta_rule_parallel(q, k, v, beta, BM=128, BN=32): + b, h, l, d_k = q.shape + q = q * (d_k ** -0.5) + v = v * beta[..., None] + k_beta = k * beta[..., None] + # compute (I - tri(diag(beta) KK^T))^{-1} + q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta]) + mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0) + T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, BN): + T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2) + T = T + torch.eye(BN, dtype=q.dtype, device=q.device) + + mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1) + A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T + o_intra = A_local @ v + + # apply cumprod transition matrices on k to the last position within the chunk + k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta + # apply cumprod transition matrices on q to the first position within the chunk + q = q - A_local @ k_beta + o_intra = A_local @ v + + A = torch.zeros(b, h, l, l, device=q.device) + + q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra]) + o = torch.empty_like(v) + for i in range(0, l, BM): + q_i = q[:, :, i:i+BM] + o_i = o_intra[:, :, i:i+BM] + # intra block + for j in range(i + BM - 2 * BN, i-BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + mask = torch.arange(i, i+BM) >= (j + BN) + A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + # inter block + for j in range(i - BN, -BN, -BN): + k_j = k[:, :, j:j+BN] + A_ij = q_i @ k_j.transpose(-1, -2) + A[:, :, i:i+BM, j:j+BN] = A_ij + q_i = q_i - A_ij @ k_beta[:, :, j:j+BN] + o_i += A_ij @ v[:, :, j:j+BN] + o[:, :, i:i+BM] = o_i + + for i in range(0, l//BN): + A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i] + + return o, A diff --git a/fla3/ops/delta_rule/wy_fast.py b/fla3/ops/delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd48806e5fadafbef6ecf7c246d7afefca5be44 --- /dev/null +++ b/fla3/ops/delta_rule/wy_fast.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from fla.ops.utils import prepare_chunk_indices +from fla.ops.utils.solve_tril import solve_tril +from fla.utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: 
tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
+        b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+        b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
+        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in NUM_WARPS
+        for num_stages in [2, 3, 4]
+    ],
+    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
+)
+@triton.jit(do_not_specialize=['T'])
+def prepare_wy_repr_bwd_kernel(
+    k,
+    v,
+    beta,
+    A,
+    dw,
+    du,
+    dk,
+    dv,
+    dbeta,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+    p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
+
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+
b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + +def prepare_wy_repr_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A, _ = chunk_scaled_dot_kkt_fwd( + k=k, + beta=beta, + g_cumsum=None, + cu_seqlens=cu_seqlens, + chunk_size=64, + output_dtype=torch.float32, + ) + A = solve_tril( + A=A, + cu_seqlens=cu_seqlens, + output_dtype=k.dtype + ) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + ) + return w, u, A + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + u = torch.empty_like(v) + w = torch.empty_like(k) + recompute_w_u_fwd_kernel[(NT, B*H)]( + k, + v, + beta, + w, + u, + A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +def prepare_wy_repr_bwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + 
A: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.empty_like(beta) + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + k, + v, + beta, + A, + dw, + du, + dk, + dv, + dbeta, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dk, dv, dbeta + + +fwd_prepare_wy_repr = prepare_wy_repr_fwd + +bwd_prepare_wy_repr = prepare_wy_repr_bwd + +fwd_recompute_w_u = recompute_w_u_fwd diff --git a/fla3/ops/forgetting_attn/__init__.py b/fla3/ops/forgetting_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e62c741d464f01b5c0c6707671061293b9d48644 --- /dev/null +++ b/fla3/ops/forgetting_attn/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_forgetting_attn + +__all__ = [ + 'parallel_forgetting_attn' +] diff --git a/fla3/ops/forgetting_attn/__pycache__/__init__.cpython-310.pyc b/fla3/ops/forgetting_attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b35ff451652de775d4f358ffbf2d13d353447d2c Binary files /dev/null and b/fla3/ops/forgetting_attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/forgetting_attn/__pycache__/__init__.cpython-312.pyc b/fla3/ops/forgetting_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acc68a668209030b8ec1cae4a5f7d7989df5acb6 Binary files /dev/null and b/fla3/ops/forgetting_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/forgetting_attn/__pycache__/parallel.cpython-310.pyc b/fla3/ops/forgetting_attn/__pycache__/parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1461d8a3d8c18f8c57f4b65f0cf6570d2a8fb8e Binary files /dev/null and b/fla3/ops/forgetting_attn/__pycache__/parallel.cpython-310.pyc differ diff --git a/fla3/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc b/fla3/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8676a7806eedb7e2ff6f75b750d909a7258066df Binary files /dev/null and b/fla3/ops/forgetting_attn/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla3/ops/forgetting_attn/parallel.py b/fla3/ops/forgetting_attn/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..317929a3f0dce5b05d2dac64d4ec4d94e479a2be --- /dev/null +++ b/fla3/ops/forgetting_attn/parallel.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from fla.ops.attn.parallel import parallel_attn + + +def parallel_forgetting_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of 
shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+            GQA is applied if HQ is divisible by H.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g (torch.Tensor):
+            Log decay at each time step (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        scale (Optional[float]):
+            Scale factor for attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format. Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
+    """
+    assert (g <= 0).all(), "g must be in log space"
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    if cu_seqlens is not None:
+        assert q.shape[0] == 1, "batch size must be 1 when `cu_seqlens` is provided"
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False instead."
+        )
+        q, k, v, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, g))
+    if not head_first and q.shape[1] < q.shape[2]:
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        )
+    o = parallel_attn(q, k, v, g, scale, cu_seqlens)
+    if head_first:
+        o = rearrange(o, 'b t h ...
-> b h t ...') + return o diff --git a/fla3/ops/gated_delta_rule/__init__.py b/fla3/ops/gated_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7f86639b3482c78768cf0511d2eb2650305e7f --- /dev/null +++ b/fla3/ops/gated_delta_rule/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule" +] diff --git a/fla3/ops/gated_delta_rule/__pycache__/__init__.cpython-310.pyc b/fla3/ops/gated_delta_rule/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b11fe5de839a9cc47ef66175d415982753f6e936 Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc b/fla3/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e52f6db2c1f244a2c66cad234531989e7532a809 Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/chunk.cpython-310.pyc b/fla3/ops/gated_delta_rule/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71d623d0ebc58ae8c5322b7b8bb0661d715d60e0 Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc b/fla3/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c8c0301dba25080a25914ef2b47926e283ee64f Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-310.pyc b/fla3/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e35a3901f3a4ff0488d13af260b10068ae39ec9 Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc b/fla3/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d160a7b7229a9592f8a68b732454b831dbce7abd Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/wy_fast.cpython-310.pyc b/fla3/ops/gated_delta_rule/__pycache__/wy_fast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee797de7bf842d76a3becf7cdc4a0243cc485860 Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/wy_fast.cpython-310.pyc differ diff --git a/fla3/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc b/fla3/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a9f76a403dc9cf49b7b914cfae6168adc48993e Binary files /dev/null and b/fla3/ops/gated_delta_rule/__pycache__/wy_fast.cpython-312.pyc differ diff --git a/fla3/ops/gated_delta_rule/chunk.py b/fla3/ops/gated_delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6df7fce0579c5a0d783839581caeb2ebe4fd09 --- /dev/null +++ 
b/fla3/ops/gated_delta_rule/chunk.py @@ -0,0 +1,355 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from ...modules.l2norm import l2norm_bwd, l2norm_fwd +from ...ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from ...ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from ...ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from ...ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd +from ...ops.utils import chunk_local_cumsum, solve_tril +from ...utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + Aw, Au = chunk_scaled_dot_kkt_fwd( + k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32 + ) + Aw = solve_tril( + A=Aw, + cu_seqlens=cu_seqlens, + output_dtype=k.dtype + ) + Au = solve_tril( + A=Au, + cu_seqlens=cu_seqlens, + output_dtype=k.dtype + ) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + Aw=Aw, + Au=Au, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + return g, o, Aw, Au, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, +): + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + Aw=Aw, + Au=Au, + cu_seqlens=cu_seqlens, + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + cu_seqlens=cu_seqlens, + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + ) + + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dk2, dv, db, dg2 = prepare_wy_repr_bwd( + k=k, + v=v, + beta=beta, + g=g, + Aw=Aw, + Au=Au, + dw=dw, + du=dv, + cu_seqlens=cu_seqlens, + ) + dk.add_(dk2) + dg.add_(dg2) + assert dg.dtype == torch.float32, "dg should be fp32" + dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] 
= None,
+        use_qk_l2norm_in_kernel: bool = False
+    ):
+        q_orig = q
+        k_orig = k
+
+        if use_qk_l2norm_in_kernel:
+            q = l2norm_fwd(q)
+            k = l2norm_fwd(k)
+
+        g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            scale=scale,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+        )
+        ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, cu_seqlens)
+        ctx.scale = scale
+        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @input_guard
+    @autocast_custom_bwd
+    def backward(
+        ctx,
+        do: torch.Tensor,
+        dht: torch.Tensor
+    ):
+        q, k, v, g, beta, Aw, Au, initial_state, cu_seqlens = ctx.saved_tensors
+        if ctx.use_qk_l2norm_in_kernel:
+            q, q_orig = l2norm_fwd(q), q
+            k, k_orig = l2norm_fwd(k), k
+        dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            Aw=Aw,
+            Au=Au,
+            scale=ctx.scale,
+            initial_state=initial_state,
+            do=do,
+            dht=dht,
+            cu_seqlens=cu_seqlens,
+        )
+        if ctx.use_qk_l2norm_in_kernel:
+            dq = l2norm_bwd(q_orig, dq)
+            dk = l2norm_bwd(k_orig, dk)
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None
+
+
+@torch.compiler.disable
+def chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = False
+):
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        g (torch.Tensor):
+            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+        >>> o, ht = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = chunk_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+    assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+    if head_first:
+        raise DeprecationWarning(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False instead."
+        )
+        q, k, v, beta, g = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, beta, g))
+    if not head_first and q.shape[1] < q.shape[2]:
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        )
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    o, final_state = ChunkGatedDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        use_qk_l2norm_in_kernel
+    )
+    if head_first:
+        o = rearrange(o, 'b t h ...
-> b h t ...') + return o, final_state diff --git a/fla3/ops/gated_delta_rule/fused_recurrent.py b/fla3/ops/gated_delta_rule/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..4246484b10760d223a881385d37fd2fd7346a426 --- /dev/null +++ b/fla3/ops/gated_delta_rule/fused_recurrent.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils.op import exp +from ...utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * H + i_h) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * H + i_h) * V + o_v + else: + p_beta = beta + bos * H + i_h + p_g = g + bos * H + i_h + p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H*K + p_k += H*K + p_o += H*V + p_v += H*V + p_g += H + p_beta += H * (V if IS_BETA_HEADWISE else 1) + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + 
use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if output_final_state: + final_state = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + final_state = None + + grid = (NK, NV, N * H) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + scale=scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=initial_state, output_final_state=output_final_state, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, cu_seqlens=cu_seqlens ) + + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + raise NotImplementedError( "Backward pass is not implemented yet and we do not have plans to implement it " "because we haven't figured out how to compute dg without materializing the full " "hidden states for all time steps." ) + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + g (decays) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from ...ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda')) + >>> beta = torch.rand(B, T, H, device='cuda').sigmoid() + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, a `cu_seqlens` with 5 start/end positions is expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens ) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. " f"Please flatten variable-length inputs before processing." ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( f"The number of initial states is expected to be equal to the number of input sequences, " f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." 
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel + ) + return o, final_state diff --git a/fla3/ops/gated_delta_rule/wy_fast.py b/fla3/ops/gated_delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..991db6658095c3581b280a4c88420b4cdbafe298 --- /dev/null +++ b/fla3/ops/gated_delta_rule/wy_fast.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import prepare_chunk_indices +from ...ops.utils.op import safe_exp +from ...utils import check_shared_mem + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_bwd_kernel( + k, + v, + beta, + g, + Aw, + Au, + dw, + du, + dk, + dv, + dbeta, + dg, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA2 = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in 
range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0) + b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A) + b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype)) + b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty) + + p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :]) + b_dA += b_dA2 + b_dA = b_dA.to(k.dtype.element_ty) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_A += tl.dot(b_k_beta, tl.trans(b_k)) + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + b_dA2 *= b_A + b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0) + p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], +) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + Aw, + Au, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + 
b_beta = tl.load(p_beta, boundary_check=(0,)) + b_Au = tl.load(p_Au, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_Au, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + b_Au = None + p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_Aw = tl.load(p_Aw, boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_Aw, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = Aw.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + + u = torch.empty_like(v) + w = torch.empty_like(k) + recompute_w_u_fwd_kernel[(NT, B*H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + Aw=Aw, + Au=Au, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +def prepare_wy_repr_bwd( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.empty_like(beta) + dg = torch.empty_like(g) + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + g=g, + Aw=Aw, + Au=Au, + dw=dw, + du=du, + dk=dk, + dv=dv, + dbeta=dbeta, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dk, dv, dbeta, dg + + +bwd_prepare_wy_repr = prepare_wy_repr_bwd + +fwd_recompute_w_u = recompute_w_u_fwd diff --git a/fla3/ops/generalized_delta_rule/README.md b/fla3/ops/generalized_delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f96c22f44a51ad3e6fdeb824eb2aded660223600 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/README.md @@ -0,0 +1,37 @@ +# Generalized Delta Rule + +In delta rule we have the recurrence: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t 
\mathbf{v}_t\mathbf{k}_t^T + ``` + + This repository implements a delta rule variant where $\mathbf{I}$ is not necessarily an identity matrix; moreover, the $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ may differ from the input $\mathbf{k}_t$ in $\mathbf{v}_t\mathbf{k}_t^T$. + + ## IPLR (Identity Plus Low Rank) + + The first variant is IPLR, where we have: + + ```math + \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T + ``` + + When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$, and the value is scaled to $\beta_t \mathbf{v}_t$, we recover the original delta rule. Since the transition matrix here is identity-plus-low-rank, we refer to this variant as IPLR. + + ### Numerical Stability + + $\mathbf{a}_t$ and $\mathbf{b}_t$ must be in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ where $\lambda_t < 0$. To see why, note that $\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T$ has all but one eigenvalue equal to $1$, the remaining one being $1+\mathbf{b}_t^T\mathbf{a}_t$; keeping it within $[-1, 1]$ requires $\mathbf{b}_t^T\mathbf{a}_t \in [-2, 0]$, which in particular forces $\mathbf{a}_t$ and $\mathbf{b}_t$ to point in opposite directions. + + ## DPLR (Diagonal Plus Low Rank) + + The second variant is DPLR, where we have: + + ```math + \mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T + ``` + + Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7. + + ## Efficient Chunkwise Implementation + + For detailed information about the efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing). diff --git a/fla3/ops/generalized_delta_rule/__init__.py b/fla3/ops/generalized_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b4155a215ca8c44ea45d6b151b1e584872ed6c --- /dev/null +++ b/fla3/ops/generalized_delta_rule/__init__.py @@ -0,0 +1,9 @@ +from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule +from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule', + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla3/ops/generalized_delta_rule/__pycache__/__init__.cpython-310.pyc b/fla3/ops/generalized_delta_rule/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5f0711cb3e9201546bd053d531e067d3b43f750 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc b/fla3/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3271cf36afee6c28e2640c14399c86fbc993d872 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/generalized_delta_rule/dplr/__init__.py b/fla3/ops/generalized_delta_rule/dplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6de2928ca88abc25dc3156c4dc4fcb13ace180d --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_dplr_delta_rule +from .fused_recurrent import fused_recurrent_dplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule' +] diff --git a/fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-310.pyc b/fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-310.pyc new
file mode 100644 index 0000000000000000000000000000000000000000..eb610c7d4b6be41e7f2ddd1ce9703c39df228e4d Binary files /dev/null and b/fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc b/fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0ebe265ad98dd15f76219078e0165eebc73e93 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc b/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2668f9e2b01307da4dfcc0217eb799b18a007bd8 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc b/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8fb5e86bc8f45c1bbe1ce35d8f249a9851ad541 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc differ diff --git a/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc b/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df8f6f63fa20e37051f1a6d8ad411389c893e84 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-312.pyc differ diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_A_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8695b8a018bbc4317d7d093c6e51b585ee69d94b --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_A_fwd.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import exp, gather +from ....utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BK', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_A_kernel_intra_sub_intra( + q, + k, + a, + b, + gi, + ge, + qg, + kg, + ag, + bg, + Aqk, + Aqb, + Aab, + Aak, + cu_seqlens, + chunk_indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A 
= (i_t * BT + tl.arange(0, BC)) < T + last_idx = min((i_t+1) * BT, T) - 1 + o_A = (bos + i_t * BT + tl.arange(0, BC)) * H*BT + i_h * BT + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK) + b_g_last = tl.load(p_g_last, mask=m_k, other=0) + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = b_q * scale + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32) + b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32) + + # deal with decay term. + g_exp = exp(b_gi) + g_exp_inv = exp(-b_gi + b_g_last[None, :]) + b_qg = b_q * g_exp + b_kg = b_k * g_exp_inv + b_bg = b_b * g_exp_inv + b_ag = b_a * exp(b_ge) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # tl.debug_barrier() + + b_q = b_q.to(b_k.dtype) + # inner attn + for j in range(0, min(BC, T - i_t * BT)): + # a trick to index the j-th row of b_k, b_g, b_b + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, dtype=tl.int16) + # [1, BK] + b_k_j = gather(b_k, row_idx, axis=0) + b_gk_j = gather(b_gi, row_idx, axis=0) + b_b_j = gather(b_b, row_idx, axis=0) + else: + mask = tl.arange(0, BC) == j + b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :] + b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :] + b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :] + tmp = exp(b_gi - b_gk_j) + b_A_qk = tl.sum(b_q * b_k_j * tmp, 1) + m_i = (o_i >= j).to(tl.float32) + b_A_qk = b_A_qk * m_i + b_A_qb = tl.sum(b_q * b_b_j * tmp, 1) + b_A_qb = b_A_qb * m_i + tmp2 = exp(b_ge - b_gk_j) + b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1) + m_i2 = (o_i > j).to(tl.float32) + b_A_ak = b_A_ak * m_i2 + b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1) + b_A_ab = b_A_ab * m_i2 + + tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aak + o_A + j, 
b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + + +def chunk_dplr_fwd_intra( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + scale: float, + chunk_size: int, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype) + Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype) + # involving matrix inverse and it'd be better to use float here. + Aab = q.new_empty(B, T, H, BT, dtype=torch.float) + Aak = q.new_empty(B, T, H, BT, dtype=torch.float) + + grid = (NT, B, H) + BK = triton.next_power_of_2(K) + qg = torch.empty_like(q) + kg = torch.empty_like(k, dtype=q.dtype) + ag = torch.empty_like(a, dtype=q.dtype) + bg = torch.empty_like(b, dtype=q.dtype) + chunk_dplr_fwd_A_kernel_intra_sub_intra[grid]( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + Aqk=Aqk, + Aqb=Aqb, + Aab=Aab, + Aak=Aak, + qg=qg, + kg=kg, + ag=ag, + bg=bg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BT, + BK=BK, + GATHER_SUPPORTED=is_gather_supported + ) + return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..86e8cec2d47980d2ff26f7e904bbe39f0697fa07 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV', "V"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dhu( + qg, + bg, + w, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, 
boundary_check=(0, 1)) + + mask_k = tl.arange(0, BK) < K + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BT] + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + # [BT, BK] + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype)) + tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype)) + b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype)) + last_idx = min((i_t + 1) * BT, T) - 1 + bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k) + b_dh *= exp(bg_last)[:, None] + b_dh += b_dh_tmp + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dhu( + qg: torch.Tensor, + bg: torch.Tensor, + w: torch.Tensor, + gk: torch.Tensor, + h0: torch.Tensor, + dht: Optional[torch.Tensor], + do: torch.Tensor, + dv: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *qg.shape, do.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
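For intuition, the reversed chunkwise recurrence that `chunk_dplr_bwd_kernel_dhu` parallelizes can be written out naively in eager PyTorch. This is only a sketch under assumptions of ours, not the library's API: the function name is hypothetical, `gk` is assumed to hold within-chunk cumulative log-decays (so only its chunk-final slice enters the cross-chunk update), and the per-chunk `dh` snapshots the kernel also materializes are omitted.

```python
import torch

def dplr_bwd_dhu_reference(qg, bg, w, gk, do, dv, dht=None, chunk_size=64):
    # qg/bg/w/gk: [B, T, H, K]; do/dv: [B, T, H, V]
    B, T, H, K = qg.shape
    dh = qg.new_zeros(B, H, K, do.shape[-1], dtype=torch.float32) if dht is None else dht.float().clone()
    dv2 = torch.empty_like(dv)
    for s in reversed(range(0, T, chunk_size)):
        e = min(s + chunk_size, T)
        # [B, H, C, *] views of the current chunk
        qg_c, bg_c, w_c = (x[:, s:e].transpose(1, 2).float() for x in (qg, bg, w))
        do_c, dv_c = (x[:, s:e].transpose(1, 2).float() for x in (do, dv))
        # dv2 only sees the state gradient flowing back from strictly later chunks
        v2_grad = dv_c + bg_c @ dh
        dv2[:, s:e] = v2_grad.transpose(1, 2).to(dv.dtype)
        # decay the incoming gradient across this chunk, then add the local terms
        dh = gk[:, e - 1].float().exp().unsqueeze(-1) * dh \
            + qg_c.transpose(-1, -2) @ do_c \
            + w_c.transpose(-1, -2) @ v2_grad
    return dh, dv2
```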
+ # H100 + if check_shared_mem('hopper', qg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', qg.device.index): # A100 + BV = 32 + BC = 32 + else: # Etc: 4090 + BV = 16 + BC = 16 + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + BC = min(BT, BC) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = qg.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.zeros_like(dv) + + grid = (NK, NV, N * H) + chunk_dplr_bwd_kernel_dhu[grid]( + qg=qg, + bg=bg, + w=w, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + ) + return dh, dh0, dv2 diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_h_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..eee76b11c6ee3b71b20ff35fe4b6dfc1d6225380 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_h_fwd.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_h( + kg, + v, + w, + bg, + u, + v_new, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + o_k = i_k * BK + tl.arange(0, BK) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, 
b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + b_hc = tl.zeros([BK, BV], dtype=tl.float32) + # since we need to keep the whole K dimension in SRAM, we face a severe SRAM memory burden; subchunking alleviates that burden + for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)): + p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1)) + b_hc += tl.dot(b_kg, b_v) + b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2) + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k, mask=o_k < K).to(tl.float32) + b_h *= exp(b_g_last[:, None]) + b_h += b_hc + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_h( + kg: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + bg: torch.Tensor, + gk: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *kg.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256."
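The forward state pass admits an analogous eager reference, mirroring the kernel's per-chunk steps: form `v_new = w @ h + u`, accumulate the chunk's contribution, and decay the running state by the chunk-final gate. Again a minimal sketch under the same assumptions (hypothetical name, `gk` as within-chunk cumulative log-decay), useful mainly for checking the Triton kernel against PyTorch on small inputs:

```python
import torch

def dplr_fwd_h_reference(kg, v, w, u, bg, gk, initial_state=None, chunk_size=64):
    # kg/w/bg/gk: [B, T, H, K]; v/u: [B, T, H, V]
    B, T, H, K = kg.shape
    h = kg.new_zeros(B, H, K, v.shape[-1], dtype=torch.float32) if initial_state is None else initial_state.float().clone()
    v_new = torch.empty_like(u)
    for s in range(0, T, chunk_size):
        e = min(s + chunk_size, T)
        kg_c, w_c, bg_c = (x[:, s:e].transpose(1, 2).float() for x in (kg, w, bg))
        v_c, u_c = (x[:, s:e].transpose(1, 2).float() for x in (v, u))
        # u corrected by the inter-chunk state: the kernel's v_new / b_v2
        v2 = w_c @ h + u_c
        v_new[:, s:e] = v2.transpose(1, 2).to(u.dtype)
        # decay the old state over the chunk, then absorb this chunk's rank-1 updates
        h = gk[:, e - 1].float().exp().unsqueeze(-1) * h \
            + kg_c.transpose(-1, -2) @ v_c \
            + bg_c.transpose(-1, -2) @ v2
    return h, v_new
```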
+ # H100 can have larger block size + + if check_shared_mem('hopper', kg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', kg.device.index): # A100 + BV = 32 + BC = 32 + else: + BV = 16 + BC = 16 + + BC = min(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + h = kg.new_empty(B, NT, H, K, V) + final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new = torch.empty_like(u) + grid = (NK, NV, N * H) + chunk_dplr_fwd_kernel_h[grid]( + kg=kg, + v=v, + w=w, + bg=bg, + u=u, + v_new=v_new, + h=h, + gk=gk, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + ) + return h, v_new, final_state diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_o_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6cc0874dc229a1d8a1252c81801f555402b840 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_o_bwd.py @@ -0,0 +1,428 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BV', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dAu( + v, + do, + v_new, + A_qb, + dA_qk, + dA_qb, + dv_new, + cu_seqlens, + chunk_indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32) + b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32) + + p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1)) + # causal mask + b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, 
boundary_check=(0, 1)) + b_dA_qk += tl.dot(b_do, b_v) + b_dA_qb += tl.dot(b_do, b_v_new) + b_dv_new = tl.dot(tl.trans(b_A_qb), b_do) + # for recurrent + tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1)) + + p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.) + tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1)) + b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.) + tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_o_kernel( + v, + v_new, + h, + do, + dh, + dk, + db, + w, + dq, + dv, + dw, + gk, + dgk_last, + k, + b, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += (bos * H + i_h) * V + v_new += (bos * H + i_h) * V + do += (bos * H + i_h) * V + h += (i_tg * H + i_h) * K * V + dh += (i_tg * H + i_h) * K * V + dk += (bos * H + i_h) * K + k += (bos * H + i_h) * K + db += (bos * H + i_h) * K + b += (bos * H + i_h) * K + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + dq += (bos * H + i_h) * K + w += (bos * H + i_h) * K + + dgk_last += (i_tg * H + i_h) * K + gk += (bos * H + i_h) * K + + stride_qk = H*K + stride_vo = H*V + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_db = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk_last = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0) + + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, 
BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + m_k = (i_k*BK+tl.arange(0, BK)) < K + last_idx = min(i_t * BT + BT, T) - 1 + b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf')) + b_dgk_last *= exp(b_gk_last) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_dgk_last += tl.sum(b_k * b_dk, axis=0) + b_dgk_last += tl.sum(b_b * b_db, axis=0) + tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k) + + p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in BK_LIST + for BV in BK_LIST + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_kernel_dv( + A_qk, + kg, + do, + dv, + dh, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + # offset calculation + A_qk += (bos * H + i_h) * BT + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + kg += (bos * H + i_h) * K + dh += (i_tg * H + i_h) * K*V + + stride_qk = H*K + stride_vo = H*V + stride_A = H*BT + + for i_k in range(tl.cdiv(K, BK)): + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype)) + + p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0) + p_do = 
tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dv( + A_qk: torch.Tensor, + kg: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *kg.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dv = torch.empty_like(do) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_bwd_kernel_dv[grid]( + A_qk=A_qk, + kg=kg, + do=do, + dv=dv, + dh=dh, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dv + + +def chunk_dplr_bwd_o( + k: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + gk: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + w: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + B, T, H, K, V = *w.shape, v.shape[-1] + + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BK = min(triton.next_power_of_2(K), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32) + NK = triton.cdiv(K, BK) + dq = torch.empty_like(k) + dk = torch.empty_like(k) + dw = torch.empty_like(w) + db = torch.empty_like(b) + grid = (NK, NT, B * H) + + dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device) + + chunk_dplr_bwd_o_kernel[grid]( + k=k, + b=b, + v=v, + v_new=v_new, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + db=db, + dgk_last=dgk_last, + w=w, + dv=dv, + dw=dw, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dq, dk, dw, db, dgk_last + + +def chunk_dplr_bwd_dAu( + v: torch.Tensor, + v_new: torch.Tensor, + do: torch.Tensor, + A_qb: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + if check_shared_mem('ampere'): # A100 + BV = min(triton.next_power_of_2(V), 128) + elif check_shared_mem('ada'): # 4090 + BV = min(triton.next_power_of_2(V), 64) + else: + BV = min(triton.next_power_of_2(V), 32) + + grid = (NT, B * H) + dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dv_new = torch.empty_like(v_new) + chunk_dplr_bwd_kernel_dAu[grid]( + v=v, + do=do, + v_new=v_new, + A_qb=A_qb, + dA_qk=dA_qk, + dA_qb=dA_qb, + dv_new=dv_new, + cu_seqlens=cu_seqlens, +
chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + ) + return dv_new, dA_qk, dA_qb diff --git a/fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..66f5e823be6c20bfe6683d489cabde3b3816be7e --- /dev/null +++ b/fla3/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BK_LIST + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_o( + qg, + v, + v_new, + A_qk, + A_qb, + h, + o, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_qg, b_h) + + p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) + b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1)) + b_Aqk = tl.where(m_s, b_Aqk, 0) + b_Aqb = tl.where(m_s, b_Aqb, 0) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_o( + qg: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + A_qk: torch.Tensor, + A_qb: torch.Tensor, + h: torch.Tensor, + cu_seqlens: 
Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *qg.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_fwd_kernel_o[grid]( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, + h=h, + o=o, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/fla3/ops/generalized_delta_rule/iplr/__init__.py b/fla3/ops/generalized_delta_rule/iplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e44d2a773b31f43fce68c5a9d1e67a3b33f42411 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/iplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_iplr_delta_rule +from .fused_recurrent import fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-310.pyc b/fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f406e6c1f7fcdd2185d355be34f8da1c9a35fe56 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc b/fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f78d850e19fd73c9a9e32ba2c655f7b42af0532e Binary files /dev/null and b/fla3/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc b/fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f00fb347646d0687f042c5564cb00cd07f9ce3 Binary files /dev/null and b/fla3/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc differ diff --git a/fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py b/fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8bbc526e3c8a53c4abb1dc44fafec3847f6a81 --- /dev/null +++ b/fla3/ops/generalized_delta_rule/iplr/fused_recurrent.py @@ -0,0 +1,452 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=["BK"], +) +@triton.jit +def fused_recurrent_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V]. 
+ a, # a [B, H, L, K] + b, # b [B, H, L, K] + o, # output [B, H, L, V] + ha, # tmp variable [B, H, L, V] for storing intermediate results of (h * a[None, :]).sum(0) + h0, # initial hidden state [B, H, K, V] + ht, # final hidden state [B, H, K, V] + cu_seqlens, # varlen cu_seqlens + scale, # K ** -0.5 + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + p_q = q + (bos * H + i_h) * K + tl.arange(0, BK) + p_k = k + (bos * H + i_h) * K + tl.arange(0, BK) + p_a = a + (bos * H + i_h) * K + tl.arange(0, BK) + p_b = b + (bos * H + i_h) * K + tl.arange(0, BK) + p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + + mask_k = tl.arange(0, BK) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + mask_h = mask_k[None, :] & mask_v[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + # to store + tmp = tl.sum(b_h * b_a[None, :], axis=1) + b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None]) + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v) + p_q += K*H + p_k += K*H + p_o += V*H + p_v += V*H + p_ha += V*H + p_a += K*H + p_b += K*H + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_DHT': lambda args: args['dht'] is not None, + 'USE_DH0': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3] + ], + key=["BK", "BV"], +) +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: b_dhead + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + a, # a [B, H, L, K] + b, # b [B, H, L, K] + ha, # ha [B, H, L, V] + dht, # gradient of final state [B, H, K, V] + dh0, # gradient of initial state [B, H, K, V] + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + da, # gradient of a [NV, B, H, L, K] + db, # gradient of b [NV, B, H, L, K] + dha, # gradient of ha [NK, B, H, L, V] + h0, # initial state [B, H, K, V] + scale, # K ** -0.5 + cu_seqlens, # cu_seqlens + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0 + USE_DH0: tl.constexpr, # whether to use dh0 + USE_DHT: tl.constexpr, # whether to use dht + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + dk += i_v * B * H * K * T + db += i_v * B * H * K * T + dq += i_v * B * H * K * T + da += i_v * B * H * K * T + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + mask_k = tl.arange(0, BK) < K + mask_v = (tl.arange(0, BV) + i_v * BV) < V + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + i_v * BV + ha += (bos * H + i_h) * V + i_v * BV + a += (bos * H + i_h) * K + b += (bos * H + i_h) * K + do += (bos * H + i_h) * V + i_v * BV + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + i_v * BV + da += (bos * H + i_h) * K + db += (bos * H + i_h) * K + dha += (bos * H + i_h) * V + i_v * BV + + p_q = q + tl.arange(0, BK) + (T - 1) * H*K + p_k = k + tl.arange(0, BK) + (T - 1) * H*K + p_v = v + tl.arange(0, BV) + (T - 1) * H*V + p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V + p_a = a + tl.arange(0, BK) + (T - 1) * H*K + p_b = b + tl.arange(0, BK) + (T - 1) * H*K + p_do = do + tl.arange(0, BV) + (T - 1) * H*V + p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K + p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V + p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V + p_db = db + tl.arange(0, BK) + (T - 1) * H*K + p_da = da + tl.arange(0, BK) + (T - 1) * H*K + p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_DHT: + p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :]) + b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32) + + b_dh += b_q[:, None] * b_do[None, :] + d_k = tl.sum(b_dh * b_v[None, :], axis=1) + d_v = tl.sum(b_dh * b_k[:, None], axis=0) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v) + + b_dha = tl.sum(b_dh * b_b[:, None], axis=0) + tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), 
mask=mask_v) + b_db = tl.sum(b_dh * b_ha[None, :], axis=1) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k) + + b_dh += b_dha[None, :] * b_a[:, None] + p_do -= H*V + p_q -= H*K + p_k -= H*K + p_v -= H*V + p_dk -= H*K + p_dv -= H*V + p_b -= H*K + p_db -= H*K + p_a -= H*K + p_dha -= H*V + p_ha -= H*V + + if USE_DH0: + p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :]) + + tl.debug_barrier() + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_k[:, None] & mask_v[None, :] + p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + p_k = k + tl.arange(0, BK) + p_v = v + tl.arange(0, BV) + p_ha = ha + tl.arange(0, BV) + p_do = do + tl.arange(0, BV) + p_dha = dha + tl.arange(0, BV) + p_da = da + tl.arange(0, BK) + p_dq = dq + tl.arange(0, BK) + p_b = b + tl.arange(0, BK) + + for i in range(0, T): + b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32) + d_a = tl.sum(b_dha[None, :] * b_h, axis=1) + tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32) + b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :] + _d_q = b_h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += H*K + p_do += H*V + p_v += H*V + p_da += H*K + p_dha += H*V + p_ha += H*V + p_dq += H*K + p_b += H*K + + +class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None + ): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK = triton.next_power_of_2(K) + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32) + else: + final_state = None + + ha = torch.empty_like(v, dtype=torch.float32) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + N * H + ) + o = torch.empty_like(v) + fused_recurrent_fwd_kernel[grid]( + q=q, + k=k, + v=v, + a=a, + b=b, + o=o, + ha=ha, + h0=initial_state, + ht=final_state, + scale=scale, + cu_seqlens=cu_seqlens, + H=H, + T=T, + K=K, + V=V, + BK=BK, + ) + ctx.save_for_backward(q, k, v, a, b, ha, initial_state) + ctx.scale = scale + ctx.cu_seqlens = cu_seqlens + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + q, k, v, a, b, ha, initial_state = ctx.saved_tensors + B, T, H, K, V = *q.shape, v.shape[-1] + N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + scale = ctx.scale + + dq = q.new_empty(NV, *q.shape) + dk = k.new_empty(NV, *k.shape) + da = a.new_empty(NV, *a.shape) + db = b.new_empty(NV, *b.shape) + dv = torch.empty_like(v) + dha = torch.empty_like(ha) + grid = (NV, N * H) + + if 
initial_state is not None and initial_state.requires_grad:
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
+ else:
+ dh0 = None
+
+ fused_recurrent_bwd_kernel[grid](
+ q=q,
+ k=k,
+ v=v,
+ a=a,
+ b=b,
+ ha=ha,
+ dht=dht,
+ dh0=dh0,
+ do=do,
+ dq=dq,
+ dk=dk,
+ dv=dv,
+ da=da,
+ db=db,
+ dha=dha,
+ h0=initial_state,
+ scale=scale,
+ cu_seqlens=ctx.cu_seqlens,
+ B=B,
+ H=H,
+ T=T,
+ K=K,
+ V=V,
+ BK=BK,
+ BV=BV,
+ )
+ dq = dq.sum(0)
+ dk = dk.sum(0)
+ da = da.sum(0)
+ db = db.sum(0)
+ return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None
+
+
+def fused_recurrent_iplr_delta_rule(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ scale: Optional[float] = None,
+ initial_state: Optional[torch.Tensor] = None,
+ output_final_state: bool = False,
+ cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ This function computes the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T token by token.
+
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, H, K]`
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]`
+ v (torch.Tensor):
+ values of shape `[B, T, H, V]`
+ a (torch.Tensor):
+ `a` vectors of the rank-1 state update, of shape `[B, T, H, K]`
+ b (torch.Tensor):
+ `b` vectors of the rank-1 state update, of shape `[B, T, H, K]`
+ scale (Optional[float]):
+ Scale factor for the attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[torch.Tensor]):
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]`.
+ final_state (torch.Tensor):
+ Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`.
+
+ """
+ if cu_seqlens is not None:
+ if q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+ f"Please flatten variable-length inputs before processing."
+ )
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+ raise ValueError(
+ f"The number of initial states is expected to be equal to the number of input sequences, "
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
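For reference, the recurrence in the docstring above can be spelled out directly in PyTorch. Below is a minimal sketch for equal-length inputs under the documented `[B, T, H, *]` layouts; `iplr_delta_rule_ref` is an illustrative name, not part of the library, and it keeps the state as `[B, H, K, V]` (the transpose of the row-vector form used in the docstring):

import torch

def iplr_delta_rule_ref(q, k, v, a, b, scale=None, initial_state=None):
    # q/k/a/b: [B, T, H, K], v: [B, T, H, V]; state S: [B, H, K, V]
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    S = q.new_zeros(B, H, K, V, dtype=torch.float) if initial_state is None else initial_state.float()
    o = torch.empty_like(v)
    for t in range(T):
        a_t, b_t, k_t = a[:, t].float(), b[:, t].float(), k[:, t].float()
        v_t, q_t = v[:, t].float(), q[:, t].float() * scale
        # S_t = (I + b_t a_t^T) S_{t-1} + k_t v_t^T, i.e. the [K, V] transpose of
        # S_t = S_{t-1} (I + a_t b_t^T) + v_t k_t^T from the docstring
        tmp = torch.einsum('bhk,bhkv->bhv', a_t, S)  # a_t^T S_{t-1}
        S = S + torch.einsum('bhk,bhv->bhkv', b_t, tmp) + torch.einsum('bhk,bhv->bhkv', k_t, v_t)
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q_t, S).to(v.dtype)
    return o, S

The fused kernel performs exactly this per-step update, tiled over `BK`/`BV` blocks; `tmp` corresponds to the `ha` buffer it materializes for reuse in the backward pass.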
+ ) + if scale is None: + scale = q.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + scale, + initial_state, + output_final_state, + cu_seqlens + ) + return o, final_state diff --git a/fla3/ops/generalized_delta_rule/iplr/wy_fast.py b/fla3/ops/generalized_delta_rule/iplr/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..e895a8191b7ce6503db674c480ab7238b60ccc7b --- /dev/null +++ b/fla3/ops/generalized_delta_rule/iplr/wy_fast.py @@ -0,0 +1,300 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk32( + a, + b, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, # dummy placeholder + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_A += tl.dot(b_a, b_b) + + b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk64( + a, + b, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + 
bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_A2 = tl.zeros([BC, BC], dtype=tl.float32) + b_A3 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + b_a1 = tl.load(p_a1, boundary_check=(0, 1)) + b_a2 = tl.load(p_a2, boundary_check=(0, 1)) + b_b1 = tl.load(p_b1, boundary_check=(0, 1)) + b_b2 = tl.load(p_b2, boundary_check=(0, 1)) + b_A += tl.dot(b_a1, b_b1, allow_tf32=False) + b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False) + b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False) + + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False) + + p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1)) + # causal mask + tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS + ], + key=['BT', 'BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def wu_fwd_kernel( + w, + u, + a, + k, + v, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * 
BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_Aak = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_w = tl.dot(b_A, b_a) + b_Aak += tl.dot(b_a, tl.trans(b_k)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0) + b_Aak = b_Aak.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty) + b_u = tl.dot(b_A, b_v) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +def prepare_wy_repr_fwd( + a: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K = a.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BC = min(BT, 32) + BK = min(triton.next_power_of_2(K), 64) + + A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype) + fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32 + + fwd_fn[(NT, B * H)]( + a=a, + b=b, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BK=BK, + BC=BC, + ) + w, u = wu_fwd( + a=a, + v=v, + k=k, + A=A, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return w, u, A + + +def wu_fwd( + a: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *a.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + u = torch.empty_like(v) + w = torch.empty_like(a) + wu_fwd_kernel[(NT, B*H)]( + a=a, + v=v, + w=w, + u=u, + A=A, + k=k, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +fwd_prepare_wy_repr = prepare_wy_repr_fwd + +fwd_wu = wu_fwd diff --git a/fla3/ops/gla/__init__.py b/fla3/ops/gla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..367c85442a26fe56516716622433f8b6f87afd2c --- /dev/null +++ b/fla3/ops/gla/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_gla +from .fused_chunk import fused_chunk_gla +from 
.fused_recurrent import fused_recurrent_gla
+
+__all__ = [
+ 'chunk_gla',
+ 'fused_chunk_gla',
+ 'fused_recurrent_gla'
+] diff --git a/fla3/ops/gla/fused_recurrent.py b/fla3/ops/gla/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..346c47b5d50acf4befefce88d85865e52a256ca2 --- /dev/null +++ b/fla3/ops/gla/fused_recurrent.py @@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+
+from fla.ops.common.fused_recurrent import fused_recurrent
+
+
+def fused_recurrent_gla(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ gk: Optional[torch.Tensor] = None,
+ gv: Optional[torch.Tensor] = None,
+ scale: Optional[float] = None,
+ initial_state: Optional[torch.Tensor] = None,
+ output_final_state: bool = False,
+ reverse: bool = False,
+ cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, H, K]`.
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]`.
+ v (torch.Tensor):
+ values of shape `[B, T, H, V]`.
+ gk (Optional[torch.Tensor]):
+ Forget gates of shape `[B, T, H, K]` applied to keys.
+ gv (Optional[torch.Tensor]):
+ Forget gates of shape `[B, T, H, V]` applied to values.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[torch.Tensor]):
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+ reverse (Optional[bool]):
+ If `True`, process the state passing in reverse order. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]`.
+ final_state (torch.Tensor):
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
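The gated recurrence this wrapper dispatches to is compact enough to state directly. Below is a minimal PyTorch sketch of the `gk`-only case under the docstring's layouts; `gla_recurrent_ref` is a hypothetical reference, not the library implementation:

import torch

def gla_recurrent_ref(q, k, v, gk, scale=None, initial_state=None):
    # q/k/gk: [B, T, H, K], v: [B, T, H, V]; state h: [B, H, K, V]
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    h = q.new_zeros(B, H, K, V, dtype=torch.float) if initial_state is None else initial_state.float()
    o = torch.empty_like(v)
    for t in range(T):
        # h_t = diag(exp(gk_t)) h_{t-1} + k_t v_t^T; gates live in log space
        h = h * gk[:, t].float().exp().unsqueeze(-1)
        h = h + torch.einsum('bhk,bhv->bhkv', k[:, t].float(), v[:, t].float())
        o[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t].float() * scale, h).to(v.dtype)
    return o, h

The actual kernel additionally supports `gv` gating along the value dimension and reverse scans; the sketch covers only the `gk` path for clarity.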
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import fused_recurrent_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gla( + q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + >>> assert o.allclose(o_var.view(o.shape)) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = fused_recurrent( + q=q, + k=k, + v=v, + g=None, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + return o, final_state diff --git a/fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc b/fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecfefc966a3e01bb6b66fc76c8c11b515b5ecd22 Binary files /dev/null and b/fla3/ops/gsa/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla3/ops/gsa/fused_recurrent.py b/fla3/ops/gsa/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..6febf9932b7510bf106f6e8507b32e1519813daa --- /dev/null +++ b/fla3/ops/gsa/fused_recurrent.py @@ -0,0 +1,525 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.fused_recurrent import fused_recurrent_bwd_kernel, fused_recurrent_fwd_kernel +from fla.ops.utils import chunk_global_cumsum +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit +def fused_recurrent_gsa_inference_kernel( + q, + k, + v, + s, + g, + o, + hk0, + hv0, + hkt, + hvt, + scale, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr +): + i_bh = tl.program_id(0) + i_bg = i_bh // NG + + b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32) + b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32) + b_g = exp(b_g) + + b_ok = tl.zeros([M], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + + p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None] + # 
[BK,] + mask_k = o_k < K + # [M, BK] + mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :] + # [M, BK] + b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32) + # [BK,] + b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale + b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32) + b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None] + b_ok += tl.sum(b_hk * b_q[None, :], axis=1) + + if i_bh % NG == 0: + p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None] + tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk) + + b_qv = tl.softmax(b_ok) + for i_v in range(tl.cdiv(V, BV)): + o_v = i_v * BV + tl.arange(0, BV) + + p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] + # [BV,] + mask_v = o_v < V + # [BV, M] + mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :] + # [BV, M] + b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32) + # [BV,] + b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32) + b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None] + b_ov = tl.sum(b_hv * b_qv[None, :], axis=1) + + tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v) + + if i_bh % NG == 0: + p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None] + tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv) + + +def fused_recurrent_gsa_inference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_final_state: bool = False, + scale: float = 1., +) -> torch.Tensor: + B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + HQ = q.shape[2] + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NG = HQ // H + + if initial_state != (None, None) and initial_state is not None: + hk0, hv0 = initial_state + else: + hk0, hv0 = q.new_zeros(B, H, K, M, dtype=torch.float), q.new_zeros(B, H, M, V, dtype=torch.float) + + hkt, hvt = None, None + if output_final_state: + if NG == 1: + hkt, hvt = hk0, hv0 + else: + hkt, hvt = q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float) + + o = v.new_empty(B, T, HQ, V) + grid = (B * HQ,) + fused_recurrent_gsa_inference_kernel[grid]( + q, + k, + v, + s, + g, + o, + hk0, + hv0, + hkt, + hvt, + scale=scale, + K=K, + V=V, + M=M, + BK=BK, + BV=BV, + NG=NG + ) + return o, (hkt, hvt) + + +def fused_recurrent_gsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + output_final_state: bool = False, + scale: float = 1., + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: + B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + HQ = q.shape[2] + if HQ != H: + raise ValueError("GQA not supported yet.") + + BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + + hk0, hv0 = None, None + if initial_state != (None, None) and initial_state is not None: + hk0, hv0 = initial_state + hkt, hvt = None, None + if output_final_state: + hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float) + + ok = q.new_empty(NK, *s.shape, dtype=torch.float) + gk, gv = 
None, g + grid = (NM, NK, N * H) + fused_recurrent_fwd_kernel[grid]( + q=q, + k=k, + v=s, + g=None, + gk=gk, + gv=gv, + o=ok, + h0=hk0, + ht=hkt, + cu_seqlens=cu_seqlens, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=M, + BK=BK, + BV=BM, + USE_G=False, + USE_GK=False, + USE_GV=True, + REVERSE=reverse + ) + ok = ok.sum(0) + + qv = ok.softmax(-1, dtype=torch.float) + ov = q.new_empty(NM, *v.shape, dtype=torch.float) + gk, gv = g, None + grid = (NV, NM, N * H) + fused_recurrent_fwd_kernel[grid]( + q=qv, + k=s, + v=v, + g=None, + gk=gk, + gv=gv, + o=ov, + h0=hv0, + ht=hvt, + cu_seqlens=cu_seqlens, + scale=1., + B=B, + T=T, + H=H, + K=M, + V=V, + BK=BM, + BV=BV, + USE_G=False, + USE_GK=True, + USE_GV=False, + REVERSE=reverse, + ) + ov = ov.sum(0) + return ok, hkt, qv, ov, hvt + + +def fused_recurrent_gsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + qv: torch.Tensor, + hk0: Optional[torch.Tensor] = None, + hv0: Optional[torch.Tensor] = None, + ok: Optional[torch.Tensor] = None, + do: Optional[torch.Tensor] = None, + dhkt: Optional[torch.Tensor] = None, + dhvt: Optional[torch.Tensor] = None, + scale: float = 1., + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor]: + B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + + BK, BV, BM = min(K, 64), min(V, 64), min(M, 64) + NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM) + + dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float) + dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float) + dv = q.new_empty(NM, B, T, H, V, dtype=torch.float) + dhk0 = torch.empty_like(hk0)if hk0 is not None else None + dhv0 = torch.empty_like(hv0)if hv0 is not None else None + + gk, gv = g, None + grid = (NV, NM, N * H) + fused_recurrent_bwd_kernel[grid]( + q=qv, + k=s, + v=v, + g=None, + gk=gk, + gv=gv, + h0=hv0, + do=do, + dq=dqv, + dk=dsv, + dv=dv, + dht=dhvt, + dh0=dhv0, + cu_seqlens=cu_seqlens, + scale=1., + B=B, + T=T, + H=H, + K=M, + V=V, + BK=BM, + BV=BV, + USE_G=False, + USE_GK=True, + USE_GV=False, + REVERSE=reverse, + ) + dqv = dqv.sum(0) + dsv = dsv.sum(0) + dv = dv.sum(0) + dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens) + + dok = qv * (dqv - (qv * dqv).sum(-1, True)) + dq = q.new_empty(NM, B, T, H, K, dtype=torch.float) + dk = q.new_empty(NM, B, T, H, K, dtype=torch.float) + dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float) + gk, gv = None, g + grid = (NM, NK, N * H) + fused_recurrent_bwd_kernel[grid]( + q=q, + k=k, + v=s, + g=None, + gk=gk, + gv=gv, + h0=hk0, + do=dok, + dq=dq, + dk=dk, + dv=dsk, + dht=dhkt, + dh0=dhk0, + cu_seqlens=cu_seqlens, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=M, + BK=BK, + BV=BM, + USE_G=False, + USE_GK=False, + USE_GV=True, + REVERSE=reverse, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dsk = dsk.sum(0) + + dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(), reverse=not reverse, cu_seqlens=cu_seqlens) + + ds = dsk.add_(dsv) + dg = dgk.add_(dgv) + + return dq, dk, dv, ds, dg, dhk0, dhv0 + + +class FusedRecurrentGSAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + hk0: Optional[torch.Tensor] = None, + hv0: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: 
bool = False,
+ cu_seqlens: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+ T = q.shape[1]
+ if T == 1 and not q.requires_grad:
+ o, (hkt, hvt) = fused_recurrent_gsa_inference(
+ q=q,
+ k=k,
+ v=v,
+ s=s,
+ g=g,
+ initial_state=(hk0, hv0),
+ output_final_state=output_final_state,
+ scale=scale,
+ )
+ return o, hkt, hvt
+ ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd(
+ q=q,
+ k=k,
+ v=v,
+ s=s,
+ g=g,
+ initial_state=(hk0, hv0),
+ output_final_state=output_final_state,
+ scale=scale,
+ reverse=reverse,
+ cu_seqlens=cu_seqlens,
+ )
+ ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok)
+ ctx.scale = scale
+ ctx.reverse = reverse
+ ctx.cu_seqlens = cu_seqlens
+ return ov.to(q.dtype), hkt, hvt
+
+ @staticmethod
+ @input_guard
+ @autocast_custom_bwd
+ def backward(ctx, do, dhkt=None, dhvt=None):
+ q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
+ scale = ctx.scale
+ reverse = ctx.reverse
+ cu_seqlens = ctx.cu_seqlens
+
+ # final-state gradients are not supported together with trainable gates yet
+ if dhkt is not None or dhvt is not None:
+ if g is not None:
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
+ dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd(
+ q=q,
+ k=k,
+ v=v,
+ s=s,
+ g=g,
+ qv=qv,
+ hk0=hk0,
+ hv0=hv0,
+ ok=ok,
+ do=do,
+ dhkt=dhkt,
+ dhvt=dhvt,
+ scale=scale,
+ reverse=reverse,
+ cu_seqlens=cu_seqlens,
+ )
+ return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None
+
+
+def fused_recurrent_gsa(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ s: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ scale: Optional[float] = None,
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
+ output_final_state: Optional[bool] = False,
+ reverse: Optional[bool] = False,
+ cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, H, K]`.
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]`.
+ v (torch.Tensor):
+ values of shape `[B, T, H, V]`.
+ s (torch.Tensor):
+ slot representations of shape `[B, T, H, M]`.
+ g (Optional[torch.Tensor]):
+ Forget gates of shape `[B, T, H, M]` applied to the slot dimension.
+ scale (Optional[float]):
+ Scale factor for the attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[Tuple[torch.Tensor]]):
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state tuple of shapes `[N, H, K, M]` and `[N, H, M, V]`.
+ Default: `False`.
+ reverse (Optional[bool]):
+ If `True`, process the state passing in reverse order. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]`.
+ final_state (Tuple[torch.Tensor]):
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
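Conceptually, GSA runs two chained gated recurrences with a softmax over the `M` slots in between, which is how `fused_recurrent_gsa_fwd` above stages its two kernel launches. A minimal PyTorch sketch of that structure (an illustrative reference only; equal-length inputs, zero initial state):

import torch

def gsa_recurrent_ref(q, k, v, s, g, scale=None):
    # q/k: [B, T, H, K], v: [B, T, H, V], s/g: [B, T, H, M]
    B, T, H, K = q.shape
    V, M = v.shape[-1], s.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    hk = q.new_zeros(B, H, K, M, dtype=torch.float)
    hv = q.new_zeros(B, H, M, V, dtype=torch.float)
    o = torch.empty_like(v)
    for t in range(T):
        gt = g[:, t].float().exp()  # [B, H, M], gates in log space
        # pass 1: slot scores, state decayed along the M dimension
        hk = hk * gt.unsqueeze(-2) + torch.einsum('bhk,bhm->bhkm', k[:, t].float(), s[:, t].float())
        ok = torch.einsum('bhk,bhkm->bhm', q[:, t].float() * scale, hk)
        qv = ok.softmax(-1)  # attention over the M slots
        # pass 2: slot values, state decayed along the M dimension
        hv = hv * gt.unsqueeze(-1) + torch.einsum('bhm,bhv->bhmv', s[:, t].float(), v[:, t].float())
        o[:, t] = torch.einsum('bhm,bhmv->bhv', qv, hv).to(v.dtype)
    return o, (hk, hv)

This mirrors the fused path: the first launch materializes `ok` with `USE_GV=True`, and its softmax output `qv` becomes the query of the second launch, which runs with `USE_GK=True` and unit scale.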
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gsa import fused_recurrent_gsa + # inputs with equal lengths + >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> s = torch.randn(B, T, H, M, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda')) + >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda')) + >>> o, (hk, hv) = fused_recurrent_gsa( + q, k, v, s, g, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa( + q, k, v, s, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert hk.allclose(hk_var) + >>> assert hv.allclose(hv_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}." 
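As in the other wrappers, `cu_seqlens` follows the FlashAttention convention of cumulative sequence boundaries. A small helper for building it from per-sequence lengths (hypothetical, not part of the library):

import torch

def build_cu_seqlens(lengths, device=None):
    # e.g. [3, 5, 2] -> tensor([0, 3, 8, 10]); shape [N+1] for N sequences
    t = torch.tensor([0] + list(lengths), dtype=torch.long, device=device)
    return t.cumsum(0)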
+ ) + if scale is None: + scale = k.shape[-1] ** -0.5 + if initial_state is None: + initial_state = (None, None) + o, *final_state = FusedRecurrentGSAFunction.apply( + q, + k, + v, + s, + g, + scale, + *initial_state, + output_final_state, + reverse, + cu_seqlens, + ) + return o, final_state diff --git a/fla3/utils.py b/fla3/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57834cd13064170c0841840e4aa05c51ecaeff2a --- /dev/null +++ b/fla3/utils.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- + +import contextlib +import functools +import logging +import os +from enum import Enum +from functools import lru_cache +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +import torch +import triton +from packaging import version + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" + + +def get_abs_err(x, y): + return (x.detach()-y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach()-y.detach()).flatten().square().mean().sqrt().item() + base = (x.detach()).flatten().square().mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" + logger.info(msg) + error_rate = get_err_ratio(ref, tri) + if abs_atol <= err_atol: + return + if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)): + if error_rate > ratio: + import warnings + warnings.warn(msg) + else: + assert error_rate < ratio, msg + + +def tensor_cache( + fn: Callable[..., torch.Tensor] +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: Optional[Tuple] = None + last_kwargs: Optional[Dict] = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if last_args is not None and last_kwargs is not None: + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): + if all(a is b for a, b in zip(args, last_args)) and \ + all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +def input_guard( + fn: Callable[..., torch.Tensor] +) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
+ """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +contiguous = input_guard + + +def require_version(version, hint): + """ + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. + """ + def decorator(fn): + @functools.wraps(fn) + def wrapper(ctx, *args, **kwargs): + from transformers.utils.versions import require_version + require_version(version, hint) + return fn(ctx, + *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), + **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) + return wrapper + return decorator + + +def checkpoint(fn): + def wrapper(*args, **kwargs): + return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs) + return wrapper + + +@lru_cache(maxsize=None) +def check_pytorch_version(version_s: str = '2.4') -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +def _cpu_device_warning(): + import warnings + warnings.warn(('Triton is not supported on current platform, roll back to CPU.'), stacklevel=1) + + +@lru_cache(maxsize=None) +def get_multiprocessor_count(tensor_idx: int = 0) -> int: + try: + return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count'] + except BaseException: + _cpu_device_warning() + return -1 + + +@lru_cache(maxsize=None) +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + _cpu_device_warning() + return 'cpu' + + +@lru_cache(maxsize=None) +def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: + device = get_available_device() + if device == 'cuda': + return 'nvidia' + elif device == 'hip': + return 'amd' + elif device == 'xpu': + return 'intel' + else: + return device + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. +device = get_available_device() if get_available_device() != 'hip' else 'cuda' +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = (device_platform == 'amd') +is_intel = (device_platform == 'intel') +is_nvidia = (device_platform == 'nvidia') +is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) +is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)) +use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') + +# Nvidia Ampere or newer, haven't check AMD and intel yet. 
+is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8) +is_gather_supported = hasattr(triton.language, 'gather') + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)['max_shared_mem'] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + _cpu_device_warning() + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@lru_cache(maxsize=None) +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False + + +if check_pytorch_version('2.4'): + device = 'cuda' if device == 'cpu' else device + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) +else: + assert device == 'cuda', 'Only cuda device is supported for PyTorch version < 2.4.0.' + autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index)
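The decorators defined in this module are meant to be stacked on `torch.autograd.Function` entry points, as done throughout the kernels above. A minimal sketch of the intended pattern (`ScaledCopy` is an illustrative name, not part of the library):

import torch

class ScaledCopy(torch.autograd.Function):

    @staticmethod
    @input_guard           # makes tensor args contiguous and enters the right device context
    @autocast_custom_fwd   # version-portable torch.amp custom_fwd
    def forward(ctx, x: torch.Tensor, scale: float):
        ctx.scale = scale
        return x * scale

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do):
        return do * ctx.scale, None

Note that `tensor_cache` above compares arguments by identity (`a is b`), not by value: a repeated call hits the cache only when the very same tensor objects are passed again, which suits per-batch metadata such as `cu_seqlens`.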