diff --git a/fla2/ops/delta_rule/chunk.py b/fla2/ops/delta_rule/chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ee3fee316199a6106d2740356da7944aca8ba5
--- /dev/null
+++ b/fla2/ops/delta_rule/chunk.py
@@ -0,0 +1,543 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+
+import torch
+import triton
+import triton.language as tl
+
+from ...ops.delta_rule.wy_fast import (bwd_prepare_wy_repr,
+                                       fwd_prepare_wy_repr, fwd_recompute_w_u)
+from ...ops.utils import contiguous
+from ...utils import autocast_custom_bwd, autocast_custom_fwd
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+def fwd_prepare_dv(q, k, do, BT):
+    dv = torch.empty_like(do)
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B*H)](
+        q, k, do, dv,
+        k.stride(1), k.stride(2), k.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        T, K, V, K**-0.5, BT, BK, BV
+    )
+    return dv
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,
+    d,
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    # [BK, BV]
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        b_h_cumsum = tl.zeros([BK, BV], dtype=tl.float32)
+        # keeping all of D_head_K in SRAM puts severe pressure on shared memory;
+        # subchunking alleviates this burden
+        for i_c in range(tl.cdiv(BT, BC)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                    (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                        (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            # [BT, BK]
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            # [BT, BV]
+            b_v = tl.load(p_v, boundary_check=(0, 1))
+            b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)
+            # [BK, BV]
+            tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+            b_h_cumsum += tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
+        b_h += b_h_cumsum
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        # [BT, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        # [BK, BT]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BK, BV]
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)
+    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
+    p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    # [BK, BV]
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+    for i_t in range(NT - 1, -1, -1):
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            # [BK, BT]
+            b_q = tl.load(p_q, boundary_check=(0, 1))
+            b_q = (b_q * scale).to(b_q.dtype)
+            # [BT, BK]
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            # [BT, V]
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))
+            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d),
+                                      (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            # [BK, BV]
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr
+):
+    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    o_i = tl.arange(0, BT)
+
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        # [BV, BK]
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        # [BK, BV]
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))
+        # [BT, BT]
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        # [BT, BK]
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)
+        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
+
+        b_dv = tl.load(p_dv, boundary_check=(0, 1))
+        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False)
+
+    # [BT, BT]
+    # [BT, BK]
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
+
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
+
+
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimensions larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty_like(u)
+    chunk_delta_rule_fwd_kernel_h[grid](
+        k, u, w, v_new, h, initial_state, final_state,
+        k.stride(1), k.stride(2), k.stride(3),
+        u.stride(1), u.stride(2), u.stride(3),
+        h.stride(1), h.stride(2),
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, T, K, V = *q.shape, do.shape[-1]
+
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimensions larger than 256."
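+    # block-size heuristic: a larger key block BK leaves less SRAM for the
+    # value/subchunk tiles, so BV and BC shrink as BK grows
+    # (BK <= 64 -> 64, 64 < BK <= 128 -> 32, BK > 128 -> 16)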
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)
+    # dv_new = torch.empty_like(do)
+    grid = (NK, NV, B * H)
+    dv2 = torch.empty_like(dv)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        q.stride(1), q.stride(2), q.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        dh.stride(1), dh.stride(2),
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,
+    )
+    return dh, dv2
+
+
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+
+    o = torch.empty_like(v_new)
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        h.stride(1), h.stride(2),
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
+    )
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+
+    BK = min(triton.next_power_of_2(K), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H)
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dw = torch.empty_like(w)
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        dh.stride(1), dh.stride(2),
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=1):
+        # obtain the WY representation; u is effectively the updated v.
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, BT)
+        # forward pass for the hidden states h
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+        # obtain output
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        # save memory
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        w, u = fwd_recompute_w_u(k, v, beta, A, BT)
+        # checkpoint_level == 1: recomputation.
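+        # recomputation trades a second chunk_fwd_h_fn pass for not keeping
+        # the [B, H, NT*K, V] state buffer h alive between forward and backward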
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        dv = fwd_prepare_dv(q, k, do, BT)
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, dv, BT)
+        dk.add_(dk2)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
+    return o, final_state
diff --git a/fla2/ops/delta_rule/chunk_fuse.py b/fla2/ops/delta_rule/chunk_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b46b0b6b663d16d232607d8e2f1e60dac40cb2
--- /dev/null
+++ b/fla2/ops/delta_rule/chunk_fuse.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+import torch.nn.functional as F
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+def pad(x, chunk_size=16):
+    seq_len = x.shape[-2]
+    # b n l d
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
+    if x.shape[-2] % chunk_size != 0:
+        x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
+    if x.shape[-1] % 32 != 0:
+        x = F.pad(x, (0, 32 - x.shape[-1] % 32))
+    return x
+
+def pad_b(x, chunk_size=16):
+    seq_len = x.shape[-1]  # the sequence length l
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size  # length after padding
+    # pad only if the sequence length is not a multiple of chunk_size
+    if seq_len % chunk_size != 0:
+        x = F.pad(x, (0, padded_seq_len - seq_len), value=1.0)  # pad only the last dimension (l)
+    return x
+
+# on-the-fly computation without materializing hidden states into HBMs
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8)
+    ],
+    key=["BT", "BK"],
+)
+@triton.jit
+def fused_chunk_delta_rule_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, D_head_K]
+    k,  # key [B, H, L, D_head_K]
+    v,  # value [B, H, L, D_head_V]
+    v_new,
+    d,  # decay [B, H, L, D_head_K]
+    o,  # output [B, H, L, D_head_V]
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,  # stride size: L * D_head_K
+    s_qk_t,  # stride size: D_head_K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * D_head_V
+    s_vo_t,  # stride size: D_head_V
+    s_vo_d,  # stride size: 1
+    B,  # batch size
+    H,  # n_heads
+    T,  # seq_len
+    scale,  # D_head_K ** -0.5
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    DK: tl.constexpr,  # D_head_K
+    DV: tl.constexpr,  # D_head_V
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr,
+    CHECK: tl.constexpr
+):
+    # indices
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    o_i = tl.arange(0, BT)
+
+    # [BT, BT]
+    m_s = o_i[:, None] >= o_i[None, :]
+    # [BK, BV]
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+    # make block pointers
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
+    p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
+    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+    p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+    p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+
+    if USE_INITIAL_STATE:
+        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+
+    for i in range(0, tl.cdiv(T, BT)):
+        # [BK, BT]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BT, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_d = tl.load(p_d, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_k.dtype)
+
+        # [BT, BT]
+        b_s = tl.dot(b_q, b_k, allow_tf32=False)
+        b_s = tl.where(m_s, b_s, 0)
+        # [BT, BV]
+        b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
+        b_v = b_v - b_v_prime
+        tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
+
+        b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False)
+        if CHECK and i == 0:
+            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
+            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
+        else:
+            b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
+            b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False)
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+        p_q = tl.advance(p_q, (BT, 0))
+        p_k = tl.advance(p_k, (0, BT))
+        p_v = tl.advance(p_v, (BT, 0))
+        p_v_new = tl.advance(p_v_new, (BT, 0))
+        p_o = tl.advance(p_o, (BT, 0))
+        p_d = tl.advance(p_d, (BT, 0))
+
+    if STORE_FINAL_STATE:
+        p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1))
+
+
+# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fused_chunk_delta_rule_bwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    # NV: number of splits in the V dimension
+    # NK: number of splits in the K dimension
+    q,  # query [B, H, L, D_head_K]
+    k,  # key [B, H, L, D_head_K]
+    v,  # value [B, H, L, D_head_V]
+    d,  # decay [B, H, L, D_head_K]
+    do,  # gradient of output [B, H, L, D_head_V]
+    dq,  # gradient of query [NV, B, H, L, D_head_K]
+    dk,  # gradient of key [NV, B, H, L, D_head_K]
+    dv,  # gradient of value [NK, B, H, L, D_head_V]
+    dd,  # gradient of decay [NV, B, H, L, D_head_K]
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,  # stride size: L * D_head_K
+    s_qk_t,  # stride size: D_head_K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * D_head_V
+    s_vo_t,  # stride size: D_head_V
+    s_vo_d,  # stride size: 1
+    B,  # batch_size
+    H,  # n_heads
+    T,  # seq_len
+    scale,  # D_head_K ** -0.5
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    DK: tl.constexpr,  # D_head_K
+    DV: tl.constexpr,  # D_head_V
+    USE_INITIAL_STATE: tl.constexpr,
+    CHECK: tl.constexpr
+):
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    o_i = tl.arange(0, BT)
+
+    # first reverse
+    # [BK, BV]
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+    m_s = o_i[:, None] <= o_i[None, :]
+    for i in range(1, tl.cdiv(T, BT) + 1):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
+        p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
+
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
+        p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
+        # [DK, BT]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        # [BT, DK]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BT, DV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+
+        # [BT, BT]
+        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
+        b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
+        # [BT, BT]
+        b_s = tl.dot(b_k, b_q, allow_tf32=False)
+        b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
+        # [BT, DK]
+        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
+        # [BT, DV]
+        b_dv = tl.dot(b_s, b_do, allow_tf32=False)
+        b_d = tl.load(p_d, boundary_check=(0, 1))
+        if CHECK and i == 1:
+            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
+            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
+            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
+            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
+        else:
+            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
+            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
+            b_dh += tl.dot(b_q, b_do, allow_tf32=False)
+            b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False)
+
+        tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+    # sync threads
+    b_h = None
+    tl.debug_barrier()
+    m_s = o_i[:, None] >= o_i[None, :]
+    # [BV, BK]
+    b_h = tl.zeros([BV, BK], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
+        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+    NT = tl.cdiv(T, BT)
+    for i in range(0, NT):
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dq = tl.make_block_ptr(dq + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0))
+
+        # [BT, DK]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [DV, BT]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BT, DV]
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+
+        # [BT, BT]
+        b_ds = tl.dot(b_do, b_v, allow_tf32=False)
+        b_ds = tl.where(m_s, b_ds, 0)
+        # [BT, DK]
+        b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
+        # [DV, DK]
+        if CHECK and i == 0:
+            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
+            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
+        else:
+            b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
+            b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
+        b_dq *= scale
+        tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+
+        if i < (NT - 1):
+            p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0))
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))
+            b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False)
+            p_dd = tl.make_block_ptr(dd + (i_bh + i_v * B * H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d),
+                                     ((i + 1) * BT, i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1))
+
+
+def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state):
+    batch_size, n_heads, seq_len, d_head_qk = q.shape
+    d_head_v = v.shape[-1]
+    scale = d_head_qk ** -0.5
+    # ctx.BT = BT
+    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
+    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+    assert NK == 1, 'NK should be 1'
+    o = q.new_empty(batch_size, n_heads, seq_len, d_head_v)
+    if output_final_state:
+        final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False)
+    else:
+        final_state = None
+    CHECK = True
+    # if version.parse(triton.__version__) < version.parse('2.2.0'):
+    #     import warnings
+    #     warnings.warn(
+    #         "Triton<2.2.0 detected for running this kernel, "
+    #         "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
+    #         "that lead to significant precision loss. "
+    #         "We've added some initial condition checks to resolve this, sadly at the sacrifice of speed. "
+    #         "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
+    #     )
+    #     CHECK = True
+    grid = (NV, NK, batch_size * n_heads)
+    v_new = torch.empty_like(v)
+    fused_chunk_delta_rule_fwd_kernel[grid](
+        q, k, v, v_new, d, o, initial_state, final_state,
+        q.stride(1), q.stride(2), q.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        batch_size, n_heads, seq_len, scale,
+        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=output_final_state,
+        CHECK=CHECK,
+    )
+    return o, v_new, CHECK, final_state
+
+
+def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
+    batch_size, n_heads, seq_len, d_head_qk = q.shape
+    d_head_v = v.shape[-1]
+    scale = d_head_qk ** -0.5
+    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
+    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+    assert NK == 1
+    dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+    dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+    dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+    dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
+    grid = (NV, NK, batch_size * n_heads)
+    fused_chunk_delta_rule_bwd_kernel[grid](
+        q, k, v, d, do, dq, dk, dv, dd, initial_state,
+        q.stride(1), q.stride(2), q.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        batch_size, n_heads, seq_len, scale,
+        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+        USE_INITIAL_STATE=initial_state is not None,
+        CHECK=CHECK,
+        # num_warps=num_warps,
+        # num_stages=num_stages
+    )
+    dq = dq.sum(0)
+    dk = dk.sum(0)
+    dv = dv.sum(0)
+    dd = dd.sum(0)
+    dd[:, :, 0:BT] = 0
+    return dq, dk, dv, dd
+
+
+class FusedChunkDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
+        # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory.
+        assert checkpoint_level in [0, 1]
+        k_origin = k
+        # k = _l2_norm_fwd(k_origin)
+        k = k
+        d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
+        o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
+        if checkpoint_level == 1:
+            d, v_new = None, None
+        ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
+        ctx.CHECK = CHECK
+        ctx.chunk_size = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_final_state=None):
+        q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
+        chunk_size = ctx.chunk_size
+        k = k_origin
+        # k = _l2_norm_fwd(k_origin)
+        if d is None:
+            d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
+        dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
+        dk.add_(dk2)
+        # dk = _l2_norm_bwd(k_origin, dk)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def fused_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
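+    # `pad` rounds the sequence length up to a multiple of 16 (the default
+    # chunk size) and the head dim up to a multiple of 32; beta is padded with
+    # 1.0 so the padded tail stays well-defined. The output is sliced back to
+    # the original sequence length and head dim below.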
+
+    if initial_state is not None:
+        initial_state = initial_state.detach()
+    seq_len = v.shape[-2]
+    d_head_v = v.shape[-1]
+    q, k, v = map(lambda x: pad(x), [q, k, v])
+    beta = pad_b(beta)
+    o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state)
+    o = o[..., :seq_len, :d_head_v]
+    return o, final_state
+
+
+def delta_rule_recurrence(q, k, v, beta):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    k = torch.nn.functional.normalize(k, p=2, dim=-1)
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v - (S.clone() * _k[..., None]).sum(-2)
+        _v = _v * beta_i[..., None]
+        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
+
+if __name__ == "__main__":
+    import torch.nn.functional as F
+    seq_len = 128
+    b = 2
+    h = 4
+    q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
+    k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1)
+    v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1)
+    beta = torch.rand(b, h, seq_len).sigmoid()
+    q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta))
+    do = torch.rand_like(v)
+    o2 = delta_rule_recurrence(q, k, v.clone(), beta)
+    o2.backward(do, retain_graph=True)
+    q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad
+    q.grad = k.grad = v.grad = beta.grad = None
+    o, _ = fused_chunk_delta_rule(q, k, v, beta, 32)
+    o.backward(do, retain_graph=True)
+    q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad
+    q.grad = k.grad = v.grad = beta.grad = None
+    print((o - o2).abs().max())
+    print((q_grad - q_grad2).abs().max())
+    print((k_grad - k_grad2).abs().max())
+    print((v_grad - v_grad2).abs().max())
+    print((beta_grad - beta_grad2).abs().max())
\ No newline at end of file
diff --git a/fla2/ops/delta_rule/naive.py b/fla2/ops/delta_rule/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4c628f0472d00081386a121655df146b018bb0
--- /dev/null
+++ b/fla2/ops/delta_rule/naive.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from einops import rearrange
+
+
+def delta_rule_recurrence(q, k, v, beta):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v - (S.clone() * _k[..., None]).sum(-2)
+        _v = _v * beta_i
+        S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+
+    return o
+
+
+def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    q = q * (d_k ** -0.5)
+    v = v * beta[..., None]
+    k_beta = k * beta[..., None]
+
+    assert l % chunk_size == 0
+
+    # note that diagonal is masked.
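+    # the loop below computes T = (I + strict_tril(diag(beta) @ K @ K^T))^{-1}
+    # by forward substitution (the UT transform of the WY representation), so
+    # that u = T @ (beta*v) and w = T @ (beta*k) need no explicit inverse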
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
+    q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
+    attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
+
+    for i in range(1, chunk_size):
+        attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
+
+    attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
+    # u
+    k_cumsum = attn @ v
+    # w
+    k_cumdecay = attn @ k_beta
+
+    v = k_cumsum
+    S = k.new_zeros(b, h, d_k, d_v)
+    o = torch.zeros_like(v)
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
+    for i in range(0, l // chunk_size):
+        q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i]
+        attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
+        v_prime = k_cumdecay[:, :, i] @ S
+        v_new = v_i - v_prime
+        o_inter = q_i @ S
+        o[:, :, i] = o_inter + attn @ v_new
+        # chunk state update
+        S = S + k_i.transpose(-1, -2) @ v_new
+
+    return rearrange(o, 'b h n c d -> b h (n c) d')
+
+
+if __name__ == '__main__':
+    B = 2
+    H = 4
+    L = 256
+    DK = 128
+    DV = 128
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+
+    o = delta_rule_recurrence(q, k, v, beta)
+    do = torch.randn(B, H, L, DV).cuda()
+    o.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+
+    o2 = delta_rule_chunkwise(q, k, v, beta)
+    o2.backward(do)
+    assert torch.allclose(o, o2, atol=1e-4), breakpoint()
+    assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint()
+    assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint()
+    assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint()
+    assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint()
+    print("All passed!")
diff --git a/fla2/ops/delta_rule/recurrent_fuse.py b/fla2/ops/delta_rule/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..675cede2a2f363422b97b803b3820e9a150e809c
--- /dev/null
+++ b/fla2/ops/delta_rule/recurrent_fuse.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import contiguous
+
+# on-the-fly computation without materializing hidden states into HBMs
+
+
+@triton.jit
+def fused_recurrent_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
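+    # (v is overwritten in place with the pre-beta residual during the
+    # forward sweep; see the `in-place overwrite` store in the loop body)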
+ beta, # beta [B, H, L] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + B, # batch size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _v_minus = tl.sum(h * b_k[None, :], axis=1) + b_v -= _v_minus + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # in-place overwrite + tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv) + b_v *= b_beta + h += b_k[None, :] * b_v[:, None] + _o = h * b_q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + p_beta += V if IS_HEADWISE_BETA else 1 + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, K] + v, # value [B, H, L, V] + beta, # beta [B, H, L, (V)] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + dbeta, # gradient of beta [NV, (NK), B, H, L] + + # initial hidden state [B, H, K, V] + h0, + + s_qk_h, # stride size: L * K + + s_vo_h, # stride size: L * V + + NK, # number of splits along the K dimension + scale, # K ** -0.5 + + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + else: + p_beta = beta + i_bh * T + T - 1 + + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V + else: + p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1 + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v) + d_v = d_v * b_beta + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + if IS_HEADWISE_BETA: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv) + else: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty)) + + d_h -= b_k[:, None] * d_v[None, :] + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + p_dbeta -= V if IS_HEADWISE_BETA else 1 + p_beta -= V if IS_HEADWISE_BETA else 1 + + tl.debug_barrier() + + h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
+ V + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + + h += b_k[:, None] * b_v[None, :] + _d_q = h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + if i < T - 1: + d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32) + d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32) + d_k -= tl.sum(d_v[None, :] * h, axis=1) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dk += K + p_dv += V + p_dq += K + p_beta += V if IS_HEADWISE_BETA else 1 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @contiguous + @staticmethod + def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False): + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + assert NK == 1, "NK > 1 is not supported yet" + o = q.new_empty(NK, B, H, T, V) + + if output_final_state: + final_state = q.new_empty(B, H, K, V) + else: + final_state = None + + grid = (NV, NK, B * H) + fused_recurrent_fwd_kernel[grid]( + q, k, v, beta, o, initial_state, final_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + IS_HEADWISE_BETA=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, beta, initial_state) + ctx.scale = scale + return o, final_state + + @contiguous + @staticmethod + def backward(ctx, do, dht=None): + q, k, v, beta, initial_state = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + beta_vector = beta.ndim == v.ndim + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + if beta_vector: + dbeta = q.new_empty(NV, NK, B, H, T, V) + else: + dbeta = q.new_empty(NV, B, H, T) + grid = (NV, NK, B * H) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, beta, do, dq, dk, dv, dbeta, initial_state, + q.stride(1), + v.stride(1), + NK, scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + IS_HEADWISE_BETA=beta_vector, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0) + return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None + + +def fused_recurrent_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + scale: float = -1, + 
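# a scale of -1 is a sentinel meaning "use the default"; it is replaced by
+    # 1 / sqrt(K) at the top of the function body. +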
initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state) + return o, final_state diff --git a/fla2/ops/delta_rule/utils.py b/fla2/ops/delta_rule/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..173d6629c628bb6b5860a005cbc8ea85d7cf9b5e --- /dev/null +++ b/fla2/ops/delta_rule/utils.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from ...ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2 +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + o, + o2, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = tl.arange(0, BK) < K + mask_bv = tl.arange(0, BV) < V + mask_bk = mask_bk[None, :] & mask_bt[:, None] + mask_bv = mask_bv[None, :] & mask_bt[:, None] + # [BT, BK] + b_k = tl.load(p_k, mask=mask_bk, other=0) + # [BT,] + b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32) + # [BT, BV] + b_v = tl.load(p_v, mask=mask_bv, other=0) + b_v = (b_v * b_beta[:, None]).to(b_v.dtype) + # [BT, BK] + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + # [BT, BT] + b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + b_A = b_A.to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + b_u = tl.dot(b_A, b_v, allow_tf32=False) + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk) + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta, + 
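+    # o: w = A @ (beta * k) and o2: u = A @ (beta * v) from the forward kernel;
+    # do and do2 carry their incoming gradients.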
+ o, o2, do, do2, + dk, dv, dbeta, + NT, K, V, T, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None] + mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None] + b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt) + + b_beta = b_beta.to(tl.float32) + A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None] + A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0) + b_do = tl.load(p_do, mask=mask_bk).to(tl.float32) + b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32) + dA = tl.zeros([BT, BT], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i in range(BT-1, -1, -1): + mask = tl.arange(0, BT) == i + attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0) + do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0) + dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0) + b_do = b_do - attn[:, None] * do_[None, :] + b_dv = b_dv - attn[:, None] * dv_[None, :] + tl.debug_barrier() + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_v = tl.load(p_v, mask=mask_bv) + b_dk += b_do * b_beta[:, None] + b_dbeta = tl.sum(b_do * b_k, axis=1) + b_dbeta += tl.sum(b_dv * b_v, axis=1) + b_v = None + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_o = tl.load(p_o, mask=mask_bk) + b_o2 = tl.load(p_o2, mask=mask_bv) + + dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False) + dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype), + allow_tf32=False) + dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0) + b_dv *= b_beta[:, None] + p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + + b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1) + dA = dA * b_beta[:, None] + b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False) + b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False) + p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt) + + +def fwd_prepare_wy_repr(k, v, beta, chunk_size): + B, H, T, K, V = *k.shape, v.shape[-1] + v_new = torch.empty_like(v) + o_cumdecay = torch.empty_like(k) + BT = chunk_size + NT = triton.cdiv(T, BT) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, o_cumdecay, v_new, + T, K, V, BT, BK, BV + ) + return o_cumdecay, v_new + + +def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size): + b, h, l, d_k = do.shape + d_v = v.shape[-1] + BK = triton.next_power_of_2(d_k) + BV = 
triton.next_power_of_2(d_v) + c = chunk_size + BK = d_k + NT = triton.cdiv(l, c) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.zeros_like(beta) + bwd_prepare_wy_repr_kernel[(NT, b*h)]( + k, v, beta, + o_cumdecay, v_new, do, do2, + dk, dv, dbeta, + NT, d_k, d_v, l, chunk_size, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + @contiguous + @autocast_custom_fwd + @staticmethod + def forward(ctx, k, v, beta, chunk_size): + o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + ctx.chunk_size = chunk_size + ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new) + return o_cumdecay, v_new + + @contiguous + @autocast_custom_bwd + @staticmethod + def backward(ctx, do, do2): + k, v, beta, o_cumdecay, v_new = ctx.saved_tensors + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size) + return dk, dv, dbeta, None + + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + seq_len = 2048 + b = 4 + h = 8 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 256) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + print("Start warmup.") + o1, o2 = prepare_wy_repr(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + o3, o4 = prepare_wy_repr2(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + print((o1 - o3).abs().max()) + print((o2 - o4).abs().max()) + + for i in range(30): + o1, o2 = prepare_wy_repr(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + o1, o2 = prepare_wy_repr2(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + + print("Done warmup.") + + import time + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) + + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr2(k, v, beta, 64) + (o1 * do + 
o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) diff --git a/fla2/ops/delta_rule/wy_fast.py b/fla2/ops/delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..c56345de4e5bedd5684bfe79a688c1e60fb24327 --- /dev/null +++ b/fla2/ops/delta_rule/wy_fast.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_A += tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1)) + b_A = b_A.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + 
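# A: the per-chunk [BT, BT] transform cached by fwd_prepare_wy_repr_kernel,
+    # reused here to recompute w and u in the backward pass without redoing
+    # the forward-substitution loop. +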
K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta, A, + dw, du, + dk, dv, dbeta, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + # store + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = 
tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + # store + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + +def fwd_prepare_wy_repr(k, v, beta, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + u = torch.empty_like(v) + w = torch.empty_like(k) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B, H, T, BT, device=k.device, dtype=k.dtype) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, BT, BK, BV + ) + return w, u, A + + +def fwd_recompute_w_u(k, v, beta, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + u = torch.empty_like(v) + w = torch.empty_like(k) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, BT, BK, BV + ) + return w, u + + +def bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, A, + dw, du, + dk, dv, dbeta, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, BT, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, k, v, beta, chunk_size=64): + ctx.BT = chunk_size + w, u, A = fwd_prepare_wy_repr(k, v, beta, ctx.BT) + ctx.save_for_backward(k, v, beta, A) + return w, u + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, dw, du): + k, v, beta, A = ctx.saved_tensors + BT = ctx.BT + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, A, dw, du, BT) + return dk, dv, dbeta, None + + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def 
naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + seq_len = 1024 + b = 4 + h = 4 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 128) + beta = torch.rand(b, h, seq_len).sigmoid() + # beta = torch.ones(b, h, seq_len) + require_grad = True + + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + o1, o2 = naive(k.clone(), v.clone(), beta.clone(), 64) + if require_grad: + o1.backward(do, retain_graph=True) + o2.backward(do2, retain_graph=True) + + k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad + k.grad = v.grad = beta.grad = None + o3, o4 = prepare_wy_repr(k.clone(), v.clone(), beta.clone(), 64) + print((o1-o3).abs().max()) + print((o2-o4).abs().max()) + + if require_grad: + o3.backward(do, retain_graph=True) + o4.backward(do2, retain_graph=True) + k_grad, v_grad, beta_grad = k.grad, v.grad, beta.grad + print((k_grad2-k_grad).abs().max()) + print((v_grad2-v_grad).abs().max()) + print((beta_grad2-beta_grad).abs().max()) + breakpoint() diff --git a/fla2/ops/generalized_delta_rule/README.md b/fla2/ops/generalized_delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f96c22f44a51ad3e6fdeb824eb2aded660223600 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/README.md @@ -0,0 +1,37 @@ +# Generalized Delta Rule + +In delta rule we have the recurrence: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T +``` + +This repository implements a delta rule variant where $\mathbf{I}$ is not necessarily an identity matrix; $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ might be different from input $\mathbf{k}_t$ in $\mathbf{v}_t\mathbf{k}_t^T$. + +## IPLR (Identity Plus Low Rank) + +The first variant is IPLR, where we have: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T +``` + +When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$, $\mathbf{v}_t= \beta_t \mathbf{v}_t$, we recover the original delta rule. 
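+
+As a quick sanity check, this reduction can be verified numerically. The sketch below is illustrative only (not part of the library):
+
+```python
+import torch
+
+# Both states have shape [d_v, d_k]: v k^T is an outer product and the
+# transition matrix multiplies the state from the right, as in the equations above.
+torch.manual_seed(0)
+d_k, d_v, T = 8, 8, 16
+S_delta = torch.zeros(d_v, d_k)
+S_iplr = torch.zeros(d_v, d_k)
+I = torch.eye(d_k)
+for _ in range(T):
+    k = torch.nn.functional.normalize(torch.randn(d_k), dim=-1)
+    v = torch.randn(d_v)
+    beta = torch.rand(()).sigmoid()
+    # delta rule: S_t = S_{t-1} (I - beta k k^T) + beta v k^T
+    S_delta = S_delta @ (I - beta * torch.outer(k, k)) + beta * torch.outer(v, k)
+    # IPLR with a = -beta k, b = k, and beta-scaled values
+    a, b, v_in = -beta * k, k, beta * v
+    S_iplr = S_iplr @ (I + torch.outer(a, b)) + torch.outer(v_in, k)
+assert torch.allclose(S_delta, S_iplr, atol=1e-5)
+```
+
+Note that $\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T$ has one eigenvalue $1 + \mathbf{b}_t^T\mathbf{a}_t$ while the rest are $1$; with $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$ and normalized keys this equals $1-\beta_t \in (0, 1)$.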
Since here the transition matrix is identity-plus-low-rank, we refer to this variant as IPLR. + +### Numerical Stability + +$\mathbf{a}_t$ and $\mathbf{b}_t$ must be in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ where $\lambda_t < 0$. For an understanding of why this is necessary, you can derive the eigenvalues of the transition matrix. + +## DPLR (Diagonal Plus Low Rank) + +The second variant is DPLR, where we have: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T +``` + +Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7. + +## Efficient Chunkwise Implementation + +For detailed information about efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing). diff --git a/fla2/ops/generalized_delta_rule/__init__.py b/fla2/ops/generalized_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b4155a215ca8c44ea45d6b151b1e584872ed6c --- /dev/null +++ b/fla2/ops/generalized_delta_rule/__init__.py @@ -0,0 +1,9 @@ +from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule +from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule', + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla2/ops/generalized_delta_rule/__pycache__/__init__.cpython-310.pyc b/fla2/ops/generalized_delta_rule/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e05e80941a83eaef3736f2ffa6a96962204917 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__init__.py b/fla2/ops/generalized_delta_rule/dplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6de2928ca88abc25dc3156c4dc4fcb13ace180d --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_dplr_delta_rule +from .fused_recurrent import fused_recurrent_dplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule' +] diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e0f12a0676ce31f9a63a74823fb7209e80b42d2 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633f62ca610abd62c54ab7ad4afa8a03a2d960e8 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a964d6060c842068377d28f64796ce279bc22ad Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-310.pyc differ diff --git 
a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a613ef9e2aee2fed3410a0b53f34d4cda8e1648f Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61dbb55d85c7d8003ae740bf9088dd089b1b598f Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f94e4616646ba3782db2604f631f8bb59fbfb1f Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..387ac8b41e3792e9dcb4872eff09e5b36ea54524 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fffb5640b8f18a69332fd423c25321701ba2f98 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..121aceb879432d9bd6a17e28494d5f9d2101392f Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17dd0a74d63f0e3f560118358c13067beec6db85 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc b/fla2/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b440a7eaeac666269ede87c9b4aa5ea141c2d1e7 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk.py b/fla2/ops/generalized_delta_rule/dplr/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6804fbbee115b57470e4c7ba0a1c6d12d272e5eb --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import 
Optional + +import torch +import triton +from einops import rearrange + +from ....ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra +from ....ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_dplr_fwd_intra +from ....ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu +from ....ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h +from ....ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o +from ....ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o +from ....ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy +from ....ops.generalized_delta_rule.dplr.wy_fast_fwd import prepare_wy_repr_fwd +from ....ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum +from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_dplr_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + T = q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens) + + A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + ) + del ge + + # A_ab, A_ak, gi, ge torch.float32 + # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16 + w, u, _ = prepare_wy_repr_fwd( + ag=ag, + A_ab=A_ab, + A_ak=A_ak, + v=v, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del A_ab, A_ak + h, v_new, final_state = chunk_dplr_fwd_h( + kg=kg, + bg=bg, + v=v, + w=w, + u=u, + gk=gi, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del u, kg, bg, gi + + o = chunk_dplr_fwd_o( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, + h=h, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del v_new, h, A_qk, A_qb + + return o, final_state + + +class ChunkDPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + ): + chunk_size = 16 + o, final_state = chunk_dplr_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + gk=gk, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + ctx.save_for_backward(q, k, v, a, b, gk, initial_state) + ctx.cu_seqlens = cu_seqlens + ctx.scale = scale + ctx.chunk_size = chunk_size + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, k, v, a, b, gk, initial_state = ctx.saved_tensors + BT = ctx.chunk_size + cu_seqlens = ctx.cu_seqlens + scale = ctx.scale + + # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted ******* + gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, cu_seqlens=cu_seqlens) + + A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT, + ) + w, u, 
A_ab_inv = prepare_wy_repr_fwd( + ag=ag, + A_ab=A_ab, + A_ak=A_ak, + v=v, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del A_ab + h, v_new, _ = chunk_dplr_fwd_h( + kg=kg, + bg=bg, + v=v, + w=w, + u=u, + gk=gi, + initial_state=initial_state, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del u + # ******* end of recomputation ******* + # A_ak, A_ab_inv, gi, ge torch.float32 + # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16 + + dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu( + v=v, + v_new=v_new, + do=do, + A_qb=A_qb, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + + dh, dh0, dv_new = chunk_dplr_bwd_dhu( + qg=qg, + bg=bg, + w=w, + gk=gi, + h0=initial_state, + dht=dht, + do=do, + dv=dv_new_intra, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + + dv = chunk_dplr_bwd_dv( + A_qk=A_qk, + kg=kg, + do=do, + dh=dh, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del A_qk + + dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o( + k=kg, + b=bg, + v=v, + v_new=v_new, + do=do, + h=h, + dh=dh, + dv=dv_new, + w=w, + gk=gi, + cu_seqlens=cu_seqlens, + chunk_size=BT, + scale=scale, + ) + del v_new + + dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy( + A_ab_inv=A_ab_inv, + A_ak=A_ak, + v=v, + ag=ag, + dw=dw, + du=dv_new, + dv0=dv, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + del A_ak + + dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + dAqk=dA_qk, + dAqb=dA_qb, + dAak=dA_ak, + dAab=dA_ab, + dgk_last=dgk_last, + dqg=dqg, + dkg=dkg, + dag=dag, + dbg=dbg, + chunk_size=BT, + scale=scale, + cu_seqlens=cu_seqlens, + ) + + return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None + + +@torch.compiler.disable +def chunk_dplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + a (torch.Tensor): + low-rank transition factors `a` of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + b (torch.Tensor): + low-rank transition factors `b` of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + gk (torch.Tensor): + gk of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. The decay term in log space. + scale (Optional[float]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead." + ) + q, k, v, a, b, gk = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b, gk)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + ) + if q.dtype == torch.float32: + raise DeprecationWarning( + """ChunkDeltaRuleFunction does not support float32. Please use bfloat16. + If you want to use float32, please solve the issue by yourself.""" + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + scale = k.shape[-1] ** -0.5 if scale is None else scale + o, final_state = ChunkDPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + gk, + scale, + initial_state, + output_final_state, + cu_seqlens, + ) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') + return o, final_state diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk_A_bwd.py b/fla2/ops/generalized_delta_rule/dplr/chunk_A_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..0e2fc6773053cb204df033bd9c19a51080f6fb69 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk_A_bwd.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import exp, gather +from ....utils import check_shared_mem, is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BK', 'BT', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_intra( + q, + k, + a, + b, + gi, + ge, + dAqk, + dAqb, + dAak, + dAab, + dq, + dk, + da, + db, + dqg, + dkg, + dag, + dbg, + dgk, + dgk_offset, + cu_seqlens, + chunk_indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = (i_b * 
T).to(tl.int32), (i_b * T + T).to(tl.int32) + + if i_t * BT >= T: + return + + # offset calculation + ge += (bos*H + i_h) * K + gi += (bos*H + i_h) * K + q += (bos*H + i_h) * K + a += (bos*H + i_h) * K + b += (bos*H + i_h) * K + k += (bos*H + i_h) * K + dq += (bos*H + i_h) * K + dk += (bos*H + i_h) * K + da += (bos*H + i_h) * K + db += (bos*H + i_h) * K + dqg += (bos*H + i_h) * K + dag += (bos*H + i_h) * K + dkg += (bos*H + i_h) * K + dbg += (bos*H + i_h) * K + dgk += (bos*H + i_h) * K + dgk_offset += (bos*H + i_h) * K + dAqk += (bos*H + i_h) * BT + dAqb += (bos*H + i_h) * BT + dAak += (bos*H + i_h) * BT + dAab += (bos*H + i_h) * BT + + stride_qk = H*K + stride_A = H*BT + + p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_ge = tl.load(p_ge, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + b_da = tl.zeros([BC, BK], dtype=tl.float32) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + b_db = tl.zeros([BC, BK], dtype=tl.float32) + # intra chunk gradient calculation + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT, 0), (BC, BC), (1, 0)) + o_i = tl.arange(0, BC) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT, i_k*BK), (BC, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAab = tl.load(p_dAab, boundary_check=(0, 1)) + b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)) + b_dAak = tl.load(p_dAak, boundary_check=(0, 1)) + + # inter chunk gradient calculation + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + # intra chunk gradient calculation + for j in range(0, min(BC, T - i_t * BT)): + # trick to index the block + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, dtype=tl.int16) + col_idx = tl.full([BC, 1], j, dtype=tl.int16) + row_idx_bc = tl.full([1, BC], j, dtype=tl.int16) + # [1, BK] + b_kj = gather(b_k, row_idx, axis=0) + b_bj = gather(b_b, row_idx, axis=0) + b_gij = gather(b_gi, row_idx, axis=0) + b_gej = gather(b_ge, row_idx, axis=0) + b_qj = gather(b_q, row_idx, axis=0) + b_aj = gather(b_a, row_idx, axis=0) + # [BC, 1] + b_dAqk_j = gather(b_dAqk, col_idx, axis=1) + b_dAab_j = gather(b_dAab, col_idx, axis=1) + b_dAqb_j = gather(b_dAqb, col_idx, axis=1) + b_dAak_j = gather(b_dAak, col_idx, axis=1) + # [1, BC] -> [BC, 1] + b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None] + b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None] + b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None] + b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None] + else: + mask_idx = tl.arange(0, BC) == j + b_kj = 
tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :] + b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :] + b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :] + b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :] + b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None] + b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None] + b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None] + b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None] + b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None] + b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None] + b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None] + b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None] + # [1, BK] b_qj, b_aj + b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :] + b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :] + + m_e = o_i[:, None] > j + m_i = o_i[:, None] >= j + tmp1 = exp(b_gi - b_gij) + tmp2 = exp(b_ge - b_gij) + b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.) + b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.) + b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.) + b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.) + + m_i = o_i[:, None] <= j + m_e = o_i[:, None] < j + tmp1 = exp(b_gij - b_gi) + tmp2 = exp(b_gej - b_gi) + b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.) + b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.) + b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.) + b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.) + + # post processing + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge) + b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale + tmp = exp(b_gn[None, :] - b_gi) + b_dk += tl.load(p_dkg, boundary_check=(0, 1)).to(tl.float32) * tmp + b_db += tl.load(p_dbg, boundary_check=(0, 1)).to(tl.float32) * tmp + tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + b_dgk = (b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b).to(tl.float32) + b_dgk_offset = b_da * b_a + tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1)) + 
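# the da * a term flows through ge rather than gi, so it is stored
+    # separately in dgk_offset and subtracted back out in chunk_dplr_bwd_dgk_kernel
+    # after the reverse cumulative sum. +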
tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in [32, 64] + ], + key=['BK', 'BT', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_dgk_kernel( + dgk, + dgk_offset, + dgk_last, + dgk_output, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = (i_b * NT + i_t).to(tl.int32) + bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32) + + stride_qk = H * K + dgk += (bos * H + i_h) * K + dgk_offset += (bos * H + i_h) * K + dgk_last += (i_tg * H + i_h) * K + dgk_output += (bos * H + i_h) * K + p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK + m_k = tl.arange(0, BK) + i_k * BK < K + b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0) + p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_dgk = tl.load(p_dgk, boundary_check=(0, 1)) + b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1)) + # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32) + # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False) + b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True) + b_dgk_cumsum += b_dgk_last[None, :] + b_dgk_cumsum -= b_dgk_offset + p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + dAqk: torch.Tensor, + dAqb: torch.Tensor, + dAak: torch.Tensor, + dAab: torch.Tensor, + dqg: torch.Tensor, + dkg: torch.Tensor, + dag: torch.Tensor, + dbg: torch.Tensor, + dgk_last: torch.Tensor, + scale: float = 1.0, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +): + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + da = torch.empty_like(a) + db = torch.empty_like(b) + dgk = torch.empty_like(gi, dtype=torch.float) + dgk_offset = torch.empty_like(gi, dtype=torch.float) + + grid = (NK, NT, B * H) + chunk_dplr_bwd_kernel_intra[grid]( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + dAqk=dAqk, + dAqb=dAqb, + dAak=dAak, + dAab=dAab, + dq=dq, + dk=dk, + dgk=dgk, + 
dgk_offset=dgk_offset, + dqg=dqg, + dkg=dkg, + dag=dag, + dbg=dbg, + da=da, + db=db, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BT, + BK=BK, + GATHER_SUPPORTED=is_gather_supported + ) + + dgk_output = torch.empty_like(dgk) + + def grid(meta): return (NT, triton.cdiv(K, meta['BK']), B * H) + chunk_dplr_bwd_dgk_kernel[grid]( + dgk=dgk, + dgk_offset=dgk_offset, + dgk_last=dgk_last, + dgk_output=dgk_output, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + ) + return dq, dk, da, db, dgk_output diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla2/ops/generalized_delta_rule/dplr/chunk_A_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8695b8a018bbc4317d7d093c6e51b585ee69d94b --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk_A_fwd.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import exp, gather +from ....utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BK', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_A_kernel_intra_sub_intra( + q, + k, + a, + b, + gi, + ge, + qg, + kg, + ag, + bg, + Aqk, + Aqb, + Aab, + Aak, + cu_seqlens, + chunk_indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + tl.arange(0, BC)) < T + last_idx = min((i_t+1) * BT, T) - 1 + o_A = (bos + i_t * BT + tl.arange(0, BC)) * H*BT + i_h * BT + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK) + b_g_last = tl.load(p_g_last, mask=m_k, other=0) + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), 
(i_t * BT, 0), (BC, BK), (1, 0)) + p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = b_q * scale + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32) + b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32) + + # deal with decay term. + g_exp = exp(b_gi) + g_exp_inv = exp(-b_gi + b_g_last[None, :]) + b_qg = b_q * g_exp + b_kg = b_k * g_exp_inv + b_bg = b_b * g_exp_inv + b_ag = b_a * exp(b_ge) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # tl.debug_barrier() + + b_q = b_q.to(b_k.dtype) + # inner attn + for j in range(0, min(BC, T - i_t * BT)): + # a trick to index the j-th row of b_k, b_g, b_b + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, dtype=tl.int16) + # [1, BK] + b_k_j = gather(b_k, row_idx, axis=0) + b_gk_j = gather(b_gi, row_idx, axis=0) + b_b_j = gather(b_b, row_idx, axis=0) + else: + mask = tl.arange(0, BC) == j + b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :] + b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :] + b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :] + tmp = exp(b_gi - b_gk_j) + b_A_qk = tl.sum(b_q * b_k_j * tmp, 1) + m_i = (o_i >= j).to(tl.float32) + b_A_qk = b_A_qk * m_i + b_A_qb = tl.sum(b_q * b_b_j * tmp, 1) + b_A_qb = b_A_qb * m_i + tmp2 = exp(b_ge - b_gk_j) + b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1) + m_i2 = (o_i > j).to(tl.float32) + b_A_ak = b_A_ak * m_i2 + b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1) + b_A_ab = b_A_ab * m_i2 + + tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + + +def chunk_dplr_fwd_intra( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + scale: float, + chunk_size: int, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype) + Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype) + # involving matrix inverse and it'd be better to use float here. 
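+ # Reference semantics of the four intra-chunk matrices produced by the kernel
+ # above, with gi/ge the inclusive/exclusive cumulative log decays in a chunk:
+ #   Aqk[i, j] = sum_d q[i, d] * k[j, d] * exp(gi[i, d] - gi[j, d]),  i >= j
+ #   Aqb[i, j] = sum_d q[i, d] * b[j, d] * exp(gi[i, d] - gi[j, d]),  i >= j
+ #   Aak[i, j] = sum_d a[i, d] * k[j, d] * exp(ge[i, d] - gi[j, d]),  i >  j
+ #   Aab[i, j] = sum_d a[i, d] * b[j, d] * exp(ge[i, d] - gi[j, d]),  i >  j
+ # (the triangular inverse mentioned above happens in wy_fast_fwd)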
+ Aab = q.new_empty(B, T, H, BT, dtype=torch.float) + Aak = q.new_empty(B, T, H, BT, dtype=torch.float) + + grid = (NT, B, H) + BK = triton.next_power_of_2(K) + qg = torch.empty_like(q) + kg = torch.empty_like(k, dtype=q.dtype) + ag = torch.empty_like(a, dtype=q.dtype) + bg = torch.empty_like(b, dtype=q.dtype) + chunk_dplr_fwd_A_kernel_intra_sub_intra[grid]( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + Aqk=Aqk, + Aqb=Aqb, + Aab=Aab, + Aak=Aak, + qg=qg, + kg=kg, + ag=ag, + bg=bg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BT, + BK=BK, + GATHER_SUPPORTED=is_gather_supported + ) + return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..86e8cec2d47980d2ff26f7e904bbe39f0697fa07 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV', "V"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dhu( + qg, + bg, + w, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)) + + mask_k = tl.arange(0, BK) < K + for i_t in range(NT - 1, -1, -1): + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = 
tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BT] + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + # [BT, BK] + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype)) + tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype)) + b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype)) + last_idx = min((i_t + 1) * BT, T) - 1 + bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k) + b_dh *= exp(bg_last)[:, None] + b_dh += b_dh_tmp + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dhu( + qg: torch.Tensor, + bg: torch.Tensor, + w: torch.Tensor, + gk: torch.Tensor, + h0: torch.Tensor, + dht: Optional[torch.Tensor], + do: torch.Tensor, + dv: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *qg.shape, do.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." + # H100 + if check_shared_mem('hopper', qg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', qg.device.index): # A100 + BV = 32 + BC = 32 + else: # Etc: 4090 + BV = 16 + BC = 16 + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + BC = min(BT, BC) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = qg.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.zeros_like(dv) + + grid = (NK, NV, N * H) + chunk_dplr_bwd_kernel_dhu[grid]( + qg=qg, + bg=bg, + w=w, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + ) + return dh, dh0, dv2 diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..eee76b11c6ee3b71b20ff35fe4b6dfc1d6225380 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk_h_fwd.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from 
....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_h( + kg, + v, + w, + bg, + u, + v_new, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + o_k = i_k * BK + tl.arange(0, BK) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + b_hc = tl.zeros([BK, BV], dtype=tl.float32) + # Since we need to keep the whole key dimension (DK) in SRAM, we face a severe SRAM memory burden. By subchunking we alleviate this burden.
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)): + p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1)) + b_hc += tl.dot(b_kg, b_v) + b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2) + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(gk + (bos + last_idx) * H*K + i_h * K + o_k, mask=o_k < K).to(tl.float32) + b_h *= exp(b_g_last[:, None]) + b_h += b_hc + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_h( + kg: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + bg: torch.Tensor, + gk: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *kg.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256."
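+ # Reference form of the recurrence implemented by the kernel above, per chunk t
+ # (all sub-chunks read the state h from the start of the chunk):
+ #   v_new[t] = u[t] + w[t] @ h
+ #   h        = exp(g_last) * h + kg[t]^T @ v[t] + bg[t]^T @ v_new[t]
+ # where g_last is the cumulative log decay at the last position of chunk t.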
+ # H100 can have larger block size + + if check_shared_mem('hopper', kg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', kg.device.index): # A100 + BV = 32 + BC = 32 + else: + BV = 16 + BC = 16 + + BC = min(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + h = kg.new_empty(B, NT, H, K, V) + final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new = torch.empty_like(u) + grid = (NK, NV, N * H) + chunk_dplr_fwd_kernel_h[grid]( + kg=kg, + v=v, + w=w, + bg=bg, + u=u, + v_new=v_new, + h=h, + gk=gk, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + ) + return h, v_new, final_state diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..bf6cc0874dc229a1d8a1252c81801f555402b840 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk_o_bwd.py @@ -0,0 +1,428 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import exp +from ....utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BV', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dAu( + v, + do, + v_new, + A_qb, + dA_qk, + dA_qb, + dv_new, + cu_seqlens, + chunk_indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32) + b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32) + + p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1)) + # causal mask + b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, 
boundary_check=(0, 1)) + b_dA_qk += tl.dot(b_do, b_v) + b_dA_qb += tl.dot(b_do, b_v_new) + b_dv_new = tl.dot(tl.trans(b_A_qb), b_do) + # for recurrent + tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1)) + + p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.) + tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1)) + b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.) + tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_o_kernel( + v, + v_new, + h, + do, + dh, + dk, + db, + w, + dq, + dv, + dw, + gk, + dgk_last, + k, + b, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += (bos * H + i_h) * V + v_new += (bos * H + i_h) * V + do += (bos * H + i_h) * V + h += (i_tg * H + i_h) * K * V + dh += (i_tg * H + i_h) * K * V + dk += (bos * H + i_h) * K + k += (bos * H + i_h) * K + db += (bos * H + i_h) * K + b += (bos * H + i_h) * K + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + dq += (bos * H + i_h) * K + w += (bos * H + i_h) * K + + dgk_last += (i_tg * H + i_h) * K + gk += (bos * H + i_h) * K + + stride_qk = H*K + stride_vo = H*V + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_db = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk_last = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0) + + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, 
BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + m_k = (i_k*BK+tl.arange(0, BK)) < K + last_idx = min(i_t * BT + BT, T) - 1 + b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf')) + b_dgk_last *= exp(b_gk_last) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_dgk_last += tl.sum(b_k * b_dk, axis=0) + b_dgk_last += tl.sum(b_b * b_db, axis=0) + tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k) + + p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in BK_LIST + for BV in BK_LIST + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_kernel_dv( + A_qk, + kg, + do, + dv, + dh, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + # offset calculation + A_qk += (bos * H + i_h) * BT + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + kg += (bos * H + i_h) * K + dh += (i_tg * H + i_h) * K*V + + stride_qk = H*K + stride_vo = H*V + stride_A = H*BT + + for i_k in range(tl.cdiv(K, BK)): + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype)) + + p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0) + p_do = 
tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dv( + A_qk: torch.Tensor, + kg: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *kg.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dv = torch.empty_like(do) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_bwd_kernel_dv[grid]( + A_qk=A_qk, + kg=kg, + do=do, + dv=dv, + dh=dh, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dv + + +def chunk_dplr_bwd_o( + k: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + gk: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + w: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + B, T, H, K, V = *w.shape, v.shape[-1] + + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BK = min(triton.next_power_of_2(K), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32) + NK = triton.cdiv(K, BK) + dq = torch.empty_like(k) + dk = torch.empty_like(k) + dw = torch.empty_like(w) + db = torch.empty_like(b) + grid = (NK, NT, B * H) + + dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device) + + chunk_dplr_bwd_o_kernel[grid]( + k=k, + b=b, + v=v, + v_new=v_new, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + db=db, + dgk_last=dgk_last, + w=w, + dv=dv, + dw=dw, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dq, dk, dw, db, dgk_last + + +def chunk_dplr_bwd_dAu( + v: torch.Tensor, + v_new: torch.Tensor, + do: torch.Tensor, + A_qb: torch.Tensor, + scale: float, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + if check_shared_mem('ampere'): # A100 + BV = min(triton.next_power_of_2(V), 128) + elif check_shared_mem('ada'): # 4090 + BV = min(triton.next_power_of_2(V), 64) + else: + BV = min(triton.next_power_of_2(V), 32) + + grid = (NT, B * H) + dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dv_new = torch.empty_like(v_new) + chunk_dplr_bwd_kernel_dAu[grid]( + v=v, + do=do, + v_new=v_new, + A_qb=A_qb, + dA_qk=dA_qk, + dA_qb=dA_qb, + dv_new=dv_new, + cu_seqlens=cu_seqlens, +
chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + ) + return dv_new, dA_qk, dA_qb diff --git a/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..66f5e823be6c20bfe6683d489cabde3b3816be7e --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BK_LIST + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_o( + qg, + v, + v_new, + A_qk, + A_qb, + h, + o, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_qg, b_h) + + p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) + b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1)) + b_Aqk = tl.where(m_s, b_Aqk, 0) + b_Aqb = tl.where(m_s, b_Aqb, 0) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_o( + qg: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + A_qk: torch.Tensor, + A_qb: torch.Tensor, + h: torch.Tensor, + cu_seqlens: 
Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *qg.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_fwd_kernel_o[grid]( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, + h=h, + o=o, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..49400c1f7f0f6880ef98022e01dc156c00a6d0bf --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -0,0 +1,273 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils.op import exp +from ....utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [16, 32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_dplr_delta_rule_fwd_kernel( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + cu_seqlens, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_k = tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[None, :] & mask_v[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=mask_k, 
other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + tmp = tl.sum(b_h * b_a[None, :], axis=1) + b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None]) + b_o = tl.sum(b_h * b_q[None, :], axis=1) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * H*K + p_k += (-1 if REVERSE else 1) * H*K + p_a += (-1 if REVERSE else 1) * H*K + p_b += (-1 if REVERSE else 1) * H*K + p_gk += (-1 if REVERSE else 1) * H*K + p_v += (-1 if REVERSE else 1) * H*V + p_o += (-1 if REVERSE else 1) * H*V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_dplr_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + + h0 = initial_state + if output_final_state: + ht = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + ht = None + o = torch.empty_like(v) + + def grid(meta): return (triton.cdiv(V, meta['BV']), N * H) + fused_recurrent_dplr_delta_rule_fwd_kernel[grid]( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + cu_seqlens, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + REVERSE=reverse, + ) + return o, ht + + +class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + ): + o, ht = fused_recurrent_dplr_delta_rule_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + gk=gk, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + ) + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. " + "This kernel is only for inference. " + "For training, please use `chunk_dplr_delta_rule`." + ) + + +def fused_recurrent_dplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner. + + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + a (torch.Tensor): + a of shape `[B, T, H, K]`. + b (torch.Tensor): + b of shape `[B, T, H, K]`. + gk (torch.Tensor): + gk of shape `[B, T, H, K]`. decay term in log space! 
+ scale (Optional[float]): + Scale factor for the attention scores. + If `None`, it will default to `1 / sqrt(K)`. Default: `1.0`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the sequence in reverse order. Default: `False`. + cu_seqlens (Optional[torch.Tensor]): + Cumulative sequence lengths of shape `[N + 1]` used for variable-length training, + consistent with the FlashAttention API. + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. " + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = q.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + gk, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + ) + return o, final_state diff --git a/fla2/ops/generalized_delta_rule/dplr/naive.py b/fla2/ops/generalized_delta_rule/dplr/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ac253673e5361a375286347253f7d4e6f7a2f3 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + +# S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T +# q, k, alpha, beta [B, H, L, D_K] +# v [B, H, L, D_V] + + +def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True): + orig_dtype = q.dtype + b, h, l, d_k = q.shape + q, k, v, alpha, beta, gk = map(lambda x: x.float(), [q, k, v, alpha, beta, gk]) + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + + if initial_state is not None: + S += initial_state + + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i] + _alpha = alpha[:, :, i].clone() + _beta = beta[:, :, i].clone() + _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None] + S = S.clone() * gk[:, :, i].exp()[..., None] + _kv + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + S = None if output_final_state is False else S + return o.to(orig_dtype), S + + +def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v + assert l % chunk_size == 0 + + S = k.new_zeros(b, h, d_k, d_v).to(q) + if initial_state is not None: + S += initial_state + + # note that diagonal is masked.
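+ # The chunkwise schedule below mirrors what the Triton kernels above implement:
+ #   1. build the intra-chunk matrices A_qk, A_qb, A_ak, A_ab (chunk_A_fwd)
+ #   2. invert (I - lower(A_ab)) by forward substitution (wy_fast_fwd)
+ #   3. form u = A_ab_inv @ (A_ak @ v) and w = A_ab_inv @ ((gk_cumsum - gk).exp() * alpha)
+ #   4. run the inter-chunk state recurrence (chunk_h_fwd) and combine the
+ #      intra- and inter-chunk outputs (chunk_o_fwd)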
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', + c=chunk_size).float(), [q, k, v, alpha, beta, gk]) + + gk_cumsum = gk.cumsum(-2) + + # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v + A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + + for i in range(chunk_size): + alpha_i = alpha[:, :, :, i, None] + q_i = q[:, :, :, i, None] + gk_i = gk_cumsum[:, :, :, i, None] + mask = (torch.arange(chunk_size) <= i).to(q.device) + attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone() + A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone() + mask = (torch.arange(chunk_size) < i).to(q.device) + # shift by one. + attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone() + A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone() + + A_ab = A_ab + for i in range(1, chunk_size): + A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2) + + A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device) + u = A_ab @ (A_ak @ v) + w = A_ab @ ((gk_cumsum-gk).exp() * alpha) + + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] + v2_i = u_i + w_i @ S + + o_1 = A_qk[:, :, i] @ v_i + o_2 = A_qb[:, :, i] @ v2_i + o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S + o[:, :, i] = o_1 + o_2 + o_3 + decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp() + S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \ + (beta_i * decay).transpose(-1, -2) @ v2_i + S = None if output_final_state is False else S + return rearrange(o, 'b h n c d -> b h (n c) d'), S diff --git a/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..6855e7bfdac154365e2faf3a91d204caf3c6f647 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, is_intel_alchemist, use_cuda_graph + +# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def 
prepare_wy_repr_bwd_kernel( + A_ab_inv, + A_ak, + ag, + v, + dw, + du, + dv, + dv0, + dag, + dAak, + dAab, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1)) + b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1)) + b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0) + b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0) + b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty) + b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v)) + b_dv0 = tl.load(p_dv0, boundary_check=(0, 1)) + b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :] + b_dA_tmp = tl.where(m_i, b_dA_tmp, 0) + b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp) + b_dA_ak = tl.where(m_i, b_dA_ak, 0) + tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1)) + b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t) + + for i_k in range(tl.cdiv(K, BK)): + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag)) + b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw) + tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1)) + + # if we know dL/dA^(-1), for dL/dA, we can use the following formula: + # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T + # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1. + # denote A = I - lower(A_ab), B = A^-1 + # in the backward pass. 
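+ # (derivation: from A @ B = I we get dA @ B + A @ dB = 0, i.e. dB = -B @ dA @ B;
+ # transposing under the trace inner product yields the rule below, and the
+ # minus sign is absorbed because A = I - lower(A_ab).)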
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T + # dL/dA_ab = lower(B^T @ dL/dB @ B^T) + b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0) + b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv) + b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t) + b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0) + tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1)) + + +def chunk_dplr_bwd_wy( + A_ab_inv: torch.Tensor, + A_ak: torch.Tensor, + v: torch.Tensor, + ag: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + dv0: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du]) + B, T, H, K, V = *dw.shape, du.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32) + + dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float) + dA_ak = torch.empty_like(A_ak, dtype=torch.float) + dv = torch.empty_like(v) + dag = torch.empty_like(ag) + + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + A_ab_inv=A_ab_inv, + A_ak=A_ak, + ag=ag, + v=v, + dw=dw, + du=du, + dv=dv, + dv0=dv0, + dag=dag, + dAak=dA_ak, + dAab=dA_ab, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dA_ab, dA_ak, dv, dag diff --git a/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf14bd34ebde04d9e1a46784aa80dc6d72bd4fd --- /dev/null +++ b/fla2/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....ops.utils.op import gather +from ....utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk32( + A_ab, + A_ab_inv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, # placeholder, do not delete + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A_ab = tl.load(p_Aab, boundary_check=(0, 1)) + b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0) + for i in range(1, BT): +
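+ # forward substitution: rows j < i of b_A_ab already hold rows of the
+ # inverse, so row i becomes A[i] + sum_{j<i} A[i, j] * inv_row[j]
+ # (restricted to columns < i); the identity is added back after the loop.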
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
+        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
+        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
+    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
+    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8]
+        for num_stages in [2, 3, 4]
+    ],
+    key=['BC'],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=['T'])
+def prepare_wy_repr_fwd_kernel_chunk64(
+    A_ab,
+    A_ab_inv,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    GATHER_SUPPORTED: tl.constexpr = is_gather_supported
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
+    p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
+    p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
+    p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
+    p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
+    p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
+    p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
+
+    b_A = tl.load(p_A1, boundary_check=(0, 1))
+    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
+    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
+    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
+    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
+
+    for i in range(1, BC):
+        if GATHER_SUPPORTED:
+            row_idx = tl.full([1, BC], i, dtype=tl.int16)
+            # [1, BC] -> [BC]
+            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
+            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
+        else:
+            mask = tl.arange(0, BC) == i
+            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
+            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
+        mask = tl.arange(0, BC) == i
+        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
+        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
+        b_A = tl.where(mask[:, None], b_a, b_A)
+        b_A2 = tl.where(mask[:, None], b_a2, b_A2)
+
+    # blockwise computation of the lower triangular matrix's inverse,
+    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
+    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
+    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
+    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
+    tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    # zero the upper-right block so the stored inverse stays block lower-triangular
+    tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16]
+        for num_stages in [2, 3, 4]
+    ],
+    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=['T'])
+def wu_fwd_kernel(
+    w,
+    u,
+    ag,
+    v,
+    A_ab_inv,
+    A_ak,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    o_s = tl.arange(0, BT)
+
+    p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+    p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+
+    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
+    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
+    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
+    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
+    # let's use tf32 here
+    b_Aak = tl.dot(b_Aab_inv, b_Aak)
+    # (SY 01/04) should be bf16 or tf32? To verify.
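+    # For reference, what this kernel computes per chunk, as a hedged
+    # PyTorch-style sketch (tensor names mirror the host function `wu_fwd`
+    # below and are otherwise illustrative, not part of the kernel):
+    #   A_ab_inv = torch.tril(A_ab_inv)        # [BT, BT], inclusive diagonal
+    #   A_ak = torch.tril(A_ak, diagonal=-1)   # [BT, BT], strictly lower
+    #   w = A_ab_inv @ ag                      # [BT, K]
+    #   u = (A_ab_inv @ A_ak) @ v              # [BT, V]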
+ b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne") + b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne") + + for i_k in range(tl.cdiv(K, BK)): + p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16 + tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16 + tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def wu_fwd( + ag: torch.Tensor, + v: torch.Tensor, + A_ak: torch.Tensor, + A_ab_inv: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *ag.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + + w = torch.empty_like(ag) + u = torch.empty_like(v) + wu_fwd_kernel[(NT, B * H)]( + ag=ag, + v=v, + A_ak=A_ak, + A_ab_inv=A_ab_inv, + w=w, + u=u, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +def prepare_wy_repr_fwd( + ag: torch.Tensor, + v: torch.Tensor, + A_ak: torch.Tensor, + A_ab: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, _ = ag.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BC = min(BT, 32) + fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32 + A_ab_inv = torch.empty_like(A_ab) + fwd_fn[(NT, B * H)]( + A_ab=A_ab, + A_ab_inv=A_ab_inv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + BC=BC, + ) + w, u = wu_fwd( + ag=ag, + v=v, + A_ak=A_ak, + A_ab_inv=A_ab_inv, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + return w, u, A_ab_inv + + +fwd_prepare_wy_repr = prepare_wy_repr_fwd + +fwd_wu = wu_fwd diff --git a/fla2/ops/generalized_delta_rule/iplr/__init__.py b/fla2/ops/generalized_delta_rule/iplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e44d2a773b31f43fce68c5a9d1e67a3b33f42411 --- /dev/null +++ b/fla2/ops/generalized_delta_rule/iplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_iplr_delta_rule +from .fused_recurrent import fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla2/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-310.pyc b/fla2/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2b9276237e24c7eee7fec99dd6fde2e8cb0e3a02 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-310.pyc b/fla2/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598619ad17de0041ead7138ef304053334609c3c Binary files /dev/null and b/fla2/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-310.pyc b/fla2/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64f377cf73cb7a9754ce01088b7121306d06cb09 Binary files /dev/null and b/fla2/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc b/fla2/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4dabe4705c124b9b728ac31f6148f035d290c6e Binary files /dev/null and b/fla2/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-310.pyc differ diff --git a/fla2/ops/generalized_delta_rule/iplr/chunk.py b/fla2/ops/generalized_delta_rule/iplr/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..806c246303265a9aebd2b57e5efbb30b4a7c508b --- /dev/null +++ b/fla2/ops/generalized_delta_rule/iplr/chunk.py @@ -0,0 +1,500 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from ....ops.generalized_delta_rule.iplr.wy_fast import prepare_wy_repr_fwd +from ....ops.utils import prepare_chunk_indices, prepare_chunk_offsets +from ....utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [2, 4, 8, 16] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_generalized_iplr_delta_rule_fwd_kernel_h( + k, + v, + d, + b, + u, + v_new, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, 
BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
+        # keeping the entire K dimension resident imposes a severe SRAM burden;
+        # subchunking over BC alleviates it
+        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
+            p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
+            # [BK, BC]
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_v = tl.load(p_v, boundary_check=(0, 1))
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            b_b = tl.load(p_b, boundary_check=(0, 1))
+            b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1))
+            b_hc += tl.dot(b_k, b_v)
+            b_hc += tl.dot(b_b, b_v2.to(b_k.dtype))
+            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+        b_h += b_hc
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.heuristics({
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+        for BK in BKV_LIST
+        for BV in BKV_LIST
+        for num_warps in [2, 4, 8]
+        for num_stages in [2, 3]
+    ],
+    key=['BT'],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=['T'])
+def chunk_generalized_iplr_delta_rule_fwd_kernel_o(
+    q,
+    k,
+    v,
+    u,
+    b,
+    h,
+    o,
+    cu_seqlens,
+    chunk_indices,
+    scale,
+    T,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+
+    if IS_VARLEN:
+        i_tg = i_t
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+        NT = tl.cdiv(T, BT)
+    else:
+        NT = tl.cdiv(T, BT)
+        i_tg = i_b * NT + i_t
+        bos, eos = i_b * T, i_b * T + T
+
+    # offset calculation
+    q += (bos * H + i_h) * K
+    k += (bos * H + i_h) * K
+    b += (bos * H + i_h) * K
+    v += (bos * H + i_h) * V
+    u += (bos * H + i_h) * V
+    o += (bos * H + i_h) * V
+    h += (i_tg * H + i_h) * K * V
+    stride_qk = H*K
+    stride_vo = H*V
+
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_Aqk = tl.zeros([BT, BT], dtype=tl.float32)
+    b_Aqb = tl.zeros([BT, BT], dtype=tl.float32)
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), 
(BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_Aqk += tl.dot(b_q, b_k) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_Aqb += tl.dot(b_q, b_b) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_Aqk = tl.where(m_A, b_Aqk, 0) + b_Aqb = tl.where(m_A, b_Aqb, 0) + + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_u = tl.load(p_u, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_generalized_iplr_delta_rule_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> torch.Tensor: + B, T, H, K, V = *q.shape, v.shape[-1] + if scale is None: + scale = k.shape[-1] ** -0.5 + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + NT, + B * H + ) + chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid]( + q=q, + k=k, + v=v, + u=v_new, + b=b, + h=h, + o=o, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return o + + +def chunk_generalized_iplr_delta_rule_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + b: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." 
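+    # For reference, the chunkwise state recurrence materialized by
+    # chunk_generalized_iplr_delta_rule_fwd_kernel_h above, as a hedged
+    # PyTorch-style sketch per chunk (names follow `iplr_chunkwise` in
+    # iplr/naive.py; the chunk tensors here are illustrative only):
+    #   v2 = u + w @ S               # [BT, V], written out as v_new
+    #   S  = S + k.T @ v + b.T @ v2  # [K, V] running state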
+ # H100 can have larger block size + + if check_shared_mem('hopper', k.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', k.device.index): # A100 + BV = 32 + BC = 32 + else: + BV = 16 + BC = 16 + + BC = min(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) + grid = (NK, NV, N * H) + + chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid]( + k=k, + v=v, + d=w, + b=b, + u=u, + v_new=v_new, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + ) + return h, v_new, final_state + + +def chunk_generalized_iplr_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64 +): + T = q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + w, u, _ = prepare_wy_repr_fwd( + a=a, + b=b, + k=k, + v=v, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + + h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h( + k=k, + v=v, + b=b, + w=w, + u=u, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + o = chunk_generalized_iplr_delta_rule_fwd_o( + q=q, + k=k, + v=v, + v_new=v_new, + b=b, + h=h, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=BT + ) + return o, final_state + + +class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + ): + chunk_size = 64 + + o, final_state = chunk_generalized_iplr_delta_rule_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + raise NotImplementedError( + "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. " + "Stay tuned!" + ) + + +@torch.compiler.disable +def chunk_iplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + a (torch.Tensor): + activations of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + b (torch.Tensor): + betas of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. 
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, H, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+        head_first (Optional[bool]):
+            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+            Default: `False`.
+
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+    """
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkGeneralizedIPLRDeltaRuleFunction does not support float32. Please use bfloat16."
+
+    if head_first:
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False instead.",
+            DeprecationWarning
+        )
+        q, k, v, a, b = map(lambda x: rearrange(x, 'b h t ... -> b t h ...'), (q, k, v, a, b))
+    if not head_first and q.shape[1] < q.shape[2]:
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        )
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    scale = k.shape[-1] ** -0.5 if scale is None else scale
+    o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        a,
+        b,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+    )
+    if head_first:
+        o = rearrange(o, 'b t h ... -> b h t ...')
+    return o, final_state
diff --git a/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py b/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8bbc526e3c8a53c4abb1dc44fafec3847f6a81
--- /dev/null
+++ b/fla2/ops/generalized_delta_rule/iplr/fused_recurrent.py
@@ -0,0 +1,452 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ....utils import input_guard
+
+
+@triton.heuristics({
+    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
+        for BV in [32, 64]
+        for num_warps in [2, 4, 8, 16]
+        for num_stages in [2, 3, 4]
+    ],
+    key=["BK"],
+)
+@triton.jit
+def fused_recurrent_fwd_kernel(
+    q,  # query [B, T, H, K]
+    k,  # key [B, T, H, K]
+    v,  # value [B, T, H, V]
+    a,  # a [B, T, H, K]
+    b,  # b [B, T, H, K]
+    o,  # output [B, T, H, V]
+    ha,  # tmp variable [B, T, H, V] for storing intermediate results of (h * a[None, :]).sum(0)
+    h0,  # initial hidden state [B, H, K, V]
+    ht,  # final hidden state [B, H, K, V]
+    cu_seqlens,  # varlen cu_seqlens
+    scale,  # K ** -0.5
+    H,  # n_heads
+    T,  # seq_len
+    K: tl.constexpr,  # K
+    V: tl.constexpr,  # V
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
+    IS_VARLEN: tl.constexpr,
+):
+    i_v, i_nh = tl.program_id(0), tl.program_id(1)
+    i_n, i_h = i_nh // H, i_nh % H
+
+    if IS_VARLEN:
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+
+    p_q = q + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_k = k + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_a = a + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_b = b + (bos * H + i_h) * K + tl.arange(0, BK)
+    p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+    p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+    p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
+
+    mask_k = tl.arange(0, BK) < K
+    mask_v = (i_v * BV + tl.arange(0, BV)) < V
+    mask_h = mask_k[None, :] & mask_v[:, None]
+
+    b_h = tl.zeros([BV, BK], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
+        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+    for _ in range(0, T):
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
+        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
+        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
+        # to store
+        tmp = tl.sum(b_h * b_a[None, :], axis=1)
+        b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
+        b_o = b_h * b_q[None, :]
+        b_o = tl.sum(b_o, axis=1)
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+        tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v)
+        p_q += K*H
+        p_k += K*H
+        p_o += V*H
+        p_v += V*H
+        p_ha += V*H
+        p_a += K*H
+        p_b += K*H
+
+    if STORE_FINAL_STATE:
+        p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None])
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+@triton.heuristics({
+    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
+    'USE_DHT': lambda args: args['dht'] is not None,
+    'USE_DH0': lambda args: args['dh0'] is not None,
+    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
+})
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16]
+        for num_stages in [2, 3]
+    ],
+    key=["BK", "BV"],
+)
+@triton.jit
+def fused_recurrent_bwd_kernel(
+    # B: batch size, H: n_heads, T: seq_len, D: head dimension
+    # NV: number of splits along the V dimension. NK: number of splits along the K dimension
+    q,  # query [B, T, H, K]
+    k,  # key [B, T, H, K]
+    v,  # value [B, T, H, V]
+    a,  # a [B, T, H, K]
+    b,  # b [B, T, H, K]
+    ha,  # ha [B, T, H, V]
+    dht,  # gradient of final state [B, H, K, V]
+    dh0,  # gradient of initial state [B, H, K, V]
+    do,  # gradient of output [B, T, H, V]
+    dq,  # gradient of query [NV, B, T, H, K]
+    dk,  # gradient of key [NV, B, T, H, K]
+    dv,  # gradient of value [NK, B, T, H, V]
+    da,  # gradient of a [NV, B, T, H, K]
+    db,  # gradient of b [NV, B, T, H, K]
+    dha,  # gradient of ha [NK, B, T, H, V]
+    h0,  # initial state [B, H, K, V]
+    scale,  # K ** -0.5
+    cu_seqlens,  # cu_seqlens
+    B,  # batch_size
+    H,  # n_heads
+    T,  # seq_len
+    K: tl.constexpr,  # K
+    V: tl.constexpr,  # V
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state h0
+    USE_DH0: tl.constexpr,  # whether to use dh0
+    USE_DHT: tl.constexpr,  # whether to use dht
+    IS_VARLEN: tl.constexpr,
+):
+    i_v, i_nh = tl.program_id(0), tl.program_id(1)
+    i_n, i_h = i_nh // H, i_nh % H
+    dk += i_v * B * H * K * T
+    db += i_v * B * H * K * T
+    dq += i_v * B * H * K * T
+    da += i_v * B * H * K * T
+    if IS_VARLEN:
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+    mask_k = tl.arange(0, BK) < K
+    mask_v = (tl.arange(0, BV) + i_v * BV) < V
+
+    q += (bos * H + i_h) * K
+    k += (bos * H + i_h) * K
+    v += (bos * H + i_h) * V + i_v * BV
+    ha += (bos * H + i_h) * V + i_v * BV
+    a += (bos * H + i_h) * K
+    b += (bos * H + i_h) * K
+    do += (bos * H + i_h) * V + i_v * BV
+    dq += (bos * H + i_h) * K
+    dk += (bos * H + i_h) * K
+    dv += (bos * H + i_h) * V + i_v * BV
+    da += (bos * H + i_h) * K
+    db += (bos * H + i_h) * K
+    dha += (bos * H + i_h) * V + i_v * BV
+
+    p_q = q + tl.arange(0, BK) + (T - 1) * H*K
+    p_k = k + tl.arange(0, BK) + (T - 1) * H*K
+    p_v = v + tl.arange(0, BV) + (T - 1) * H*V
+    p_ha = ha + tl.arange(0, BV) + (T - 1) * H*V
+    p_a = a + tl.arange(0, BK) + (T - 1) * H*K
+    p_b = b + tl.arange(0, BK) + (T - 1) * H*K
+    p_do = do + tl.arange(0, BV) + (T - 1) * H*V
+    p_dk = dk + tl.arange(0, BK) + (T - 1) * H*K
+    p_dv = dv + tl.arange(0, BV) + (T - 1) * H*V
+    p_dha = dha + tl.arange(0, BV) + (T - 1) * H*V
+    p_db = db + tl.arange(0, BK) + (T - 1) * H*K
+    p_da = da + tl.arange(0, BK) + (T - 1) * H*K
+    p_dq = dq + tl.arange(0, BK) + (T - 1) * H*K
+
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_DHT:
+        p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
+        b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
+
+    for _ in range(T):
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
+        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
+        b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
+        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
+
+        b_dh += b_q[:, None] * b_do[None, :]
+        d_k = tl.sum(b_dh * b_v[None, :], axis=1)
+        d_v = tl.sum(b_dh * b_k[:, None], axis=0)
+        tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k)
+        tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v)
+
+        b_dha = tl.sum(b_dh * b_b[:, None], axis=0)
+        tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v)
+        b_db = tl.sum(b_dh * b_ha[None, :], axis=1)
+        tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k)
+
+        b_dh += b_dha[None, :] * b_a[:, None]
+        p_do -= H*V
+        p_q -= H*K
+        p_k -= H*K
+        p_v -= H*V
+        p_dk -= H*K
+        p_dv -= H*V
+        p_b -= H*K
+        p_db -= H*K
+        p_a -= H*K
+        p_dha -= H*V
+        p_ha -= H*V
+
+    if USE_DH0:
+        p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
+        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
+
+    tl.debug_barrier()
+
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        mask_kv = mask_k[:, None] & mask_v[None, :]
+        p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :])
+        b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
+
+    p_k = k + tl.arange(0, BK)
+    p_v = v + tl.arange(0, BV)
+    p_ha = ha + tl.arange(0, BV)
+    p_do = do + tl.arange(0, BV)
+    p_dha = dha + tl.arange(0, BV)
+    p_da = da + tl.arange(0, BK)
+    p_dq = dq + tl.arange(0, BK)
+    p_b = b + tl.arange(0, BK)
+
+    for i in range(0, T):
+        b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32)
+        d_a = tl.sum(b_dha[None, :] * b_h, axis=1)
+        tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k)
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
+        b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
+        b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32)
+        b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :]
+        _d_q = b_h * b_do[None, :]
+        d_q = tl.sum(_d_q, axis=1) * scale
+        tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
+
+        p_k += H*K
+        p_do += H*V
+        p_v += H*V
+        p_da += H*K
+        p_dha += H*V
+        p_ha += H*V
+        p_dq += H*K
+        p_b += H*K
+
+
+class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        a: torch.Tensor,
+        b: torch.Tensor,
+        scale: Optional[float] = None,
+        initial_state: Optional[torch.Tensor] = None,
+        output_final_state: bool = False,
+        cu_seqlens: Optional[torch.LongTensor] = None
+    ):
+        B, T, H, K, V = *k.shape, v.shape[-1]
+        N = B if cu_seqlens is None else len(cu_seqlens) - 1
+
+        BK = triton.next_power_of_2(K)
+        if output_final_state:
+            # one final state per sequence: N rows, not B, so varlen batches are not truncated
+            final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
+        else:
+            final_state = None
+
+        ha = torch.empty_like(v, dtype=torch.float32)
+
+        def grid(meta): return (
+            triton.cdiv(V, meta['BV']),
+            N * H
+        )
+        o = torch.empty_like(v)
+        fused_recurrent_fwd_kernel[grid](
+            q=q,
+            k=k,
+            v=v,
+            a=a,
+            b=b,
+            o=o,
+            ha=ha,
+            h0=initial_state,
+            ht=final_state,
+            scale=scale,
+            cu_seqlens=cu_seqlens,
+            H=H,
+            T=T,
+            K=K,
+            V=V,
+            BK=BK,
+        )
+        ctx.save_for_backward(q, k, v, a, b, ha, initial_state)
+        ctx.scale = scale
+        ctx.cu_seqlens = cu_seqlens
+        return o, final_state
+
+    @staticmethod
+    @input_guard
+    def backward(ctx, do, dht):
+        q, k, v, a, b, ha, initial_state = ctx.saved_tensors
+        B, T, H, K, V = *q.shape, v.shape[-1]
+        N = B if ctx.cu_seqlens is None else len(ctx.cu_seqlens) - 1
+        BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
+        NV = triton.cdiv(V, BV)
+        scale = ctx.scale
+
+        dq = q.new_empty(NV, *q.shape)
+        dk = k.new_empty(NV, *k.shape)
+        da = a.new_empty(NV, *a.shape)
+        db = b.new_empty(NV, *b.shape)
+        dv = torch.empty_like(v)
+        dha = torch.empty_like(ha)
+        grid = (NV, N * H)
+
+        if initial_state is not None and initial_state.requires_grad:
+            dh0 = torch.empty_like(initial_state, dtype=torch.float32)
+        else:
+            dh0 = None
+
+        fused_recurrent_bwd_kernel[grid](
+            q=q,
+            k=k,
+            v=v,
+            a=a,
+            b=b,
+            ha=ha,
+            dht=dht,
+            dh0=dh0,
+            do=do,
+            dq=dq,
+            dk=dk,
+            dv=dv,
+            da=da,
+            db=db,
+            dha=dha,
+            h0=initial_state,
+            scale=scale,
+            cu_seqlens=ctx.cu_seqlens,
+            B=B,
+            H=H,
+            T=T,
+            K=K,
+            V=V,
+            BK=BK,
+            BV=BV,
+        )
+        dq = dq.sum(0)
+        dk = dk.sum(0)
+        da = da.sum(0)
+        db = db.sum(0)
+        return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None
+
+
+def fused_recurrent_iplr_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    This function computes the recurrence S_t = S_{t-1} @ (I + a_t b_t^T) + v_t k_t^T token by token.
+
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]`
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]`
+        v (torch.Tensor):
+            values of shape `[B, T, H, V]`
+        a (torch.Tensor):
+            the `a` vectors of shape `[B, T, H, K]`
+        b (torch.Tensor):
+            the `b` vectors of shape `[B, T, H, K]`
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[B, H, K, V]`. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+
+    """
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = q.shape[-1] ** -0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply(
+        q,
+        k,
+        v,
+        a,
+        b,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens
+    )
+    return o, final_state
diff --git a/fla2/ops/generalized_delta_rule/iplr/naive.py b/fla2/ops/generalized_delta_rule/iplr/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..9da977011e943f7432be09b144c115d7661911ac
--- /dev/null
+++ b/fla2/ops/generalized_delta_rule/iplr/naive.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from einops import rearrange
+
+
+# S_t = S_{t-1} @ (I + alpha_t beta_t^T) + v_t k_t^T
+# q, k, alpha, beta [B, H, L, D_K]
+# v [B, H, L, D_V]
+def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True):
+    orig_dtype = q.dtype
+    b, h, l, d_k = q.shape
+    q, k, v, alpha, beta = map(lambda x: x.float(), [q, k, v, alpha, beta])
+    d_v = v.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+
+    if initial_state is not None:
+        S += initial_state
+
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i]
+        _alpha = alpha[:, :, i]
+        _beta = beta[:, :, i]
+        _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
+        S = S + _kv
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    S = None if output_final_state is False else S
+    return o.to(orig_dtype), S
+
+
+def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    q = q * (d_k ** -0.5)
+    assert l % chunk_size == 0
+
+    S = k.new_zeros(b, h, d_k, d_v)
+    if initial_state is not None:
+        S += initial_state
+
+    # note that the diagonal is masked: only strictly lower-triangular products are kept.
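+    # For reference (hedged): with A = tril(alpha @ beta^T, -1), the in-place
+    # row loop below performs forward substitution, so that after adding the
+    # identity, attn == (I - A)^{-1}. A dense equivalent, for illustration only:
+    #   A = (alpha @ beta.transpose(-1, -2)).tril(-1)
+    #   attn = torch.linalg.inv(torch.eye(chunk_size, device=q.device) - A)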
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta]) + + v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v + attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + u = attn @ v2 + w = attn @ alpha + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] + o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i + v2_i = u_i + w_i @ S + o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ (v2_i) + o_3 = q_i @ S + o[:, :, i] = o_1 + o_2 + o_3 + S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i + S = None if output_final_state is False else S + return rearrange(o, 'b h n c d -> b h (n c) d'), S diff --git a/fla2/ops/generalized_delta_rule/iplr/wy_fast.py b/fla2/ops/generalized_delta_rule/iplr/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..e895a8191b7ce6503db674c480ab7238b60ccc7b --- /dev/null +++ b/fla2/ops/generalized_delta_rule/iplr/wy_fast.py @@ -0,0 +1,300 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ....ops.utils import prepare_chunk_indices +from ....utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk32( + a, + b, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, # dummy placeholder + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_A += tl.dot(b_a, b_b) + + b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + 
p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def prepare_wy_repr_fwd_kernel_chunk64( + a, + b, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_A2 = tl.zeros([BC, BC], dtype=tl.float32) + b_A3 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + b_a1 = tl.load(p_a1, boundary_check=(0, 1)) + b_a2 = tl.load(p_a2, boundary_check=(0, 1)) + b_b1 = tl.load(p_b1, boundary_check=(0, 1)) + b_b2 = tl.load(p_b2, boundary_check=(0, 1)) + b_A += tl.dot(b_a1, b_b1, allow_tf32=False) + b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False) + b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False) + + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False) + + p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1)) + # causal mask + tl.store(p_A4, tl.zeros([BC, 
BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS + ], + key=['BT', 'BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def wu_fwd_kernel( + w, + u, + a, + k, + v, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_Aak = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_w = tl.dot(b_A, b_a) + b_Aak += tl.dot(b_a, tl.trans(b_k)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0) + b_Aak = b_Aak.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty) + b_u = tl.dot(b_A, b_v) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +def prepare_wy_repr_fwd( + a: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K = a.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BC = min(BT, 32) + BK = min(triton.next_power_of_2(K), 64) + + A = torch.empty(B, T, H, BT, device=a.device, dtype=a.dtype) + fwd_fn = prepare_wy_repr_fwd_kernel_chunk64 if BT == 64 else prepare_wy_repr_fwd_kernel_chunk32 + + fwd_fn[(NT, B * H)]( + a=a, + b=b, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BK=BK, + BC=BC, + ) + w, u = wu_fwd( + a=a, + v=v, + k=k, + A=A, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size + ) + return w, u, A + + +def wu_fwd( + a: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *a.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + 
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + u = torch.empty_like(v) + w = torch.empty_like(a) + wu_fwd_kernel[(NT, B*H)]( + a=a, + v=v, + w=w, + u=u, + A=A, + k=k, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u + + +fwd_prepare_wy_repr = prepare_wy_repr_fwd + +fwd_wu = wu_fwd diff --git a/fla2/ops/gla/__init__.py b/fla2/ops/gla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fdb9563ac716719cbe2cda45197d756f10f435 --- /dev/null +++ b/fla2/ops/gla/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_gla +from .chunk_fuse import fused_chunk_gla +from .recurrent_fuse import fused_recurrent_gla + +__all__ = [ + 'chunk_gla', + 'fused_chunk_gla', + 'fused_recurrent_gla' +] diff --git a/fla2/ops/gla/__pycache__/__init__.cpython-310.pyc b/fla2/ops/gla/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb4846d7bbe6826e0a6037dfe84f72ed4601d644 Binary files /dev/null and b/fla2/ops/gla/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/ops/gla/__pycache__/__init__.cpython-312.pyc b/fla2/ops/gla/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d12946e637625da7fa172d693518ad4c4fe6070 Binary files /dev/null and b/fla2/ops/gla/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/gla/__pycache__/__init__.cpython-38.pyc b/fla2/ops/gla/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbf38e5e1c0533f4364f3cf121754567cab154a4 Binary files /dev/null and b/fla2/ops/gla/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/gla/__pycache__/__init__.cpython-39.pyc b/fla2/ops/gla/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b0830d5fda5c1695d9ab07148084d35664f058a Binary files /dev/null and b/fla2/ops/gla/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk.cpython-310.pyc b/fla2/ops/gla/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d887cadafe188809cc918b8e7106a00df814d286 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk.cpython-312.pyc b/fla2/ops/gla/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d9cd56a0d277fe7b5af8c990100a754cfc0c0f8 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk.cpython-38.pyc b/fla2/ops/gla/__pycache__/chunk.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..293e7b83f2ce3fcd2a3b442f295f8016b849546c Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk.cpython-38.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk.cpython-39.pyc b/fla2/ops/gla/__pycache__/chunk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd06251369951e3c2996212560f959b6b2cca790 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk.cpython-39.pyc differ diff --git 
a/fla2/ops/gla/__pycache__/chunk_fuse.cpython-310.pyc b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e254ee941dcb2a04b272a832f3cd56d7af66312 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-310.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_fuse.cpython-312.pyc b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c2a378f3c0d1ec9e0725a0ff655ad9082ab8ded Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-312.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_fuse.cpython-38.pyc b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c45662426a6de37f055cd7231da094a46745b2c Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-38.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_fuse.cpython-39.pyc b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63890fcdc9d35b0e26ef4bdcac70d6d09437e0b9 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_fuse.cpython-39.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_util.cpython-310.pyc b/fla2/ops/gla/__pycache__/chunk_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff6f04deca9f429ddc0ec1d75316f11c6f2cfb93 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_util.cpython-310.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_util.cpython-312.pyc b/fla2/ops/gla/__pycache__/chunk_util.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7daef75300e1001361636bd6ffe4fe10d14ab334 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_util.cpython-312.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_util.cpython-38.pyc b/fla2/ops/gla/__pycache__/chunk_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f622ec0dfee7568172b9627da1e130000e03bd Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_util.cpython-38.pyc differ diff --git a/fla2/ops/gla/__pycache__/chunk_util.cpython-39.pyc b/fla2/ops/gla/__pycache__/chunk_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2e1003a9780b72c2588f175546bc2499cd2d839 Binary files /dev/null and b/fla2/ops/gla/__pycache__/chunk_util.cpython-39.pyc differ diff --git a/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-310.pyc b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fac0531a4d6c5cc3c0ee297ce1da4921aaf60638 Binary files /dev/null and b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-310.pyc differ diff --git a/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-312.pyc b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa7d6314c0cfdc89b596c4ee39fe81e12894ed1 Binary files /dev/null and b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-312.pyc differ diff --git a/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-38.pyc b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4b0b3961b0c301da46164c3d01e720a9161587e Binary files /dev/null and b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-38.pyc differ diff --git 
a/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-39.pyc b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d628e582c3dbdbc6793f5e5d6e3272bc7908826 Binary files /dev/null and b/fla2/ops/gla/__pycache__/recurrent_fuse.cpython-39.pyc differ diff --git a/fla2/ops/gla/chunk.py b/fla2/ops/gla/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2abbd69a73d06bf5e7295ce3e8c3cc258b1206 --- /dev/null +++ b/fla2/ops/gla/chunk.py @@ -0,0 +1,491 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2023-2024, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.utils import chunk_global_reversed_cumsum, chunk_local_cumsum +from ...ops.common.chunk_h import chunk_fwd_h_fn, chunk_bwd_dh_fn +from ...utils import contiguous + + +@triton.jit +def chunk_gla_fwd_kernel_intra( + q, + k, + g, + A, + s_k_h, + s_k_t, + s_k_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :]) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * tl.exp(b_gn[:, None] - b_gk)).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_qg, b_kg, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]) * 
scale, 1)
+            b_A = tl.where(o_i >= j, b_A, 0.)
+            tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A)
+
+            p_k = tl.advance(p_k, (K,))
+            p_gk = tl.advance(p_gk, (K,))
+
+
+@triton.jit
+def chunk_gla_fwd_kernel_inter(
+    q,
+    v,
+    g,
+    h,
+    o,
+    A,
+    s_k_h,
+    s_k_t,
+    s_k_d,
+    s_v_h,
+    s_v_t,
+    s_v_d,
+    s_h_h,
+    s_h_t,
+    s_h_d,
+    scale,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+
+        # [BT, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        # [BT, BK]
+        b_g = tl.load(p_g, boundary_check=(0, 1))
+        # [BT, BK]
+        b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)
+        # [BK, BV]
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        # the guard below is always true; it is kept as a workaround for a
+        # Triton compiler quirk (removing it changes codegen, reason unknown)
+        # [BT, BV]
+        if i_k >= 0:
+            b_o += tl.dot(b_qg, b_h, allow_tf32=False)
+    p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+    # [BT, BV]
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    # [BT, BT]
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+    b_o += tl.dot(b_A, b_v, allow_tf32=False)
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.jit
+def chunk_gla_bwd_kernel_intra(
+    q,
+    k,
+    g,
+    dA,
+    dq,
+    dk,
+    dg,
+    s_k_h,
+    s_k_t,
+    s_k_d,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    NC: tl.constexpr
+):
+    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_t, i_i = i_c // NC, i_c % NC
+
+    p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
+    p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,))
+    # [BK,]
+    b_gn = tl.load(p_gn, boundary_check=(0,))
+    # [BC, BK]
+    b_g = tl.load(p_g, boundary_check=(0, 1))
+    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
+    for i_j in range(0, i_i):
+        p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
+        p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
+        p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
+        # [BC, BK]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_gk = tl.load(p_gk, boundary_check=(0, 1))
+        b_kg = (b_k * tl.exp(b_gn[None, :] - b_gk)).to(b_k.dtype)
+        # [BC, BC]
+        b_dA = tl.load(p_dA, boundary_check=(0, 1))
+        # [BC, BK]
+        b_dq += tl.dot(b_dA, b_kg, allow_tf32=False)
+    b_dq *= tl.exp(b_g - b_gn[None, :])
+
+    o_i = tl.arange(0, BC)
+    o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
+    m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
+    for j in range(0, BC):
+        p_kj = tl.make_block_ptr(k + i_bh * s_k_h, (T * K,), 
(1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + p_gkj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + b_gkj = tl.load(p_gkj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * tl.exp(b_g - b_gkj[None, :]), 0.) + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_dq = b_dq + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T*K,), (s_k_d,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_gn = tl.load(p_gn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = (b_q * tl.exp(b_g - b_gn[None, :])).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qg, allow_tf32=False) + b_dk *= tl.exp(b_gn[None, :] - b_gk) + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_gqj = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_gqj = tl.load(p_gqj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * tl.exp(b_gqj[None, :] - b_gk), 0.) 
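+    # at this point b_dq/b_dk hold both the inter-sub-chunk contributions
+    # (block matmuls above) and the in-sub-chunk ones (masked scalar loops);
+    # the gate gradient stored next follows from differentiating
+    # q*exp(g) and k*exp(-g) w.r.t. g:  dg = q * dq - k * dk,
+    # which chunk_global_reversed_cumsum later turns into per-step gradients
+    # in ChunkGLAFunction.backward.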
+ + p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_dk = b_dk + tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_gla_bwd_kernel_inter( + k, + v, + h, + g, + A, + do, + dh, + dq, + dk, + dv, + dA, + s_k_h, + s_k_t, + s_k_d, + s_v_h, + s_v_t, + s_v_d, + s_h_h, + s_h_t, + s_h_d, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gn = tl.exp(tl.load(p_gn, boundary_check=(0,))[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * V * K, (V, K), (s_h_d, s_h_t), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K*V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + b_dq = b_dq * tl.exp(b_gk) + b_dk = b_dk * b_gn + + p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, s_k_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + 
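+    # b_dq/b_dk above are the inter-chunk gradients, already rescaled by the
+    # gates: dq is scaled by exp(g) (alignment with the chunk start) and dk by
+    # exp(g_last - g) (alignment with the chunk end); dA is stored by the
+    # i_k == 0 program only, after causal masking, so the K-split grid does
+    # not write it more than once.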
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + +class ChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state, checkpoint_level): + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + g_cumsum = chunk_local_cumsum(g, BT=BT) + g_org, g = g, g_cumsum + + h, ht = chunk_fwd_h_fn( + k=k, v=v, g=None, gk=g, gv=None, BT=BT, h0=initial_state, output_final_state=output_final_state + ) + A = q.new_zeros(NK, B, H, T, BT) + grid = (NK, NT * NC * NC, B * H) + chunk_gla_fwd_kernel_intra[grid]( + q, k, g, A, + k.stride(1), k.stride(2), k.stride(3), + scale, + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + A = A.sum(0, dtype=A.dtype) + o = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_gla_fwd_kernel_inter[grid]( + q, v, g, h, o, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + if checkpoint_level >= 1: + del g + g = g_org + if checkpoint_level > 1: + del h + h = None + + ctx.save_for_backward(q, k, v, g, h, initial_state, A) + ctx.BT = BT + ctx.scale = scale + ctx.checkpoint_level = checkpoint_level + return o, ht + + @staticmethod + @contiguous + def backward(ctx, do, dht): + q, k, v, g, h, initial_state, A = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + if ctx.checkpoint_level >= 1: + g_cumsum = chunk_local_cumsum(g, BT=BT) + g_org, g = g, g_cumsum + + if h is None: + h, _ = chunk_fwd_h_fn( + k=k, v=v, g=None, gk=g, gv=None, BT=BT, h0=initial_state, output_final_state=False + ) + + scale = ctx.scale + dh, dh0 = chunk_bwd_dh_fn(q=q, k=k, v=v, g=None, gk=g, gv=None, do=do, h0=initial_state, dht=dht, BT=BT, scale=scale) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dg = torch.empty_like(k, dtype=torch.float) + dv = v.new_empty(NK, *v.shape) + dA = q.new_zeros(B, H, T, BT) + grid = (NK, NT, B * H) + chunk_gla_bwd_kernel_inter[grid]( + k, v, h, g, A, do, dh, dq, dk, dv, dA, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), h.stride(3), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0, dtype=v.dtype) + grid = (NK, NT * NC, B * H) + chunk_gla_bwd_kernel_intra[grid]( + q, k, g, dA, dq, dk, dg, + k.stride(1), k.stride(2), k.stride(3), + T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + dg = chunk_global_reversed_cumsum(dg).to(k.dtype) + return dq, dk, dv, dg, None, dh0, None, None + + +def chunk_gla( + q: torch.Tensor, + k: 
torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
+    output_final_state: bool = False,
+    checkpoint_level: Optional[int] = 2
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `(B, H, T, K)`
+        k (torch.Tensor):
+            keys of shape `(B, H, T, K)`
+        v (torch.Tensor):
+            values of shape `(B, H, T, V)`
+        g (torch.Tensor):
+            Forget gates of shape `(B, H, T, K)` applied to keys.
+        scale (Optional[float]):
+            Scale factor for the GLA attention scores.
+            If not provided, it defaults to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `(B, H, K, V)`. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
+        checkpoint_level (Optional[int]):
+            Checkpointing level; higher levels save more memory at the cost of more recomputation in the backward pass.
+            Default: `2`:
+            - Level `0`: no memory saved, no recomputation.
+            - Level `1`: recompute the fp32 cumulative values during backward.
+            - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
+    """
+    assert checkpoint_level in [0, 1, 2]
+    if scale is None:
+        scale = q.shape[-1] ** -0.5
+    o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, checkpoint_level)
+    return o, final_state
diff --git a/fla2/ops/gla/chunk_fuse.py b/fla2/ops/gla/chunk_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..397c131390e0f66d3fc8340edfb27ba51c3c158a
--- /dev/null
+++ b/fla2/ops/gla/chunk_fuse.py
@@ -0,0 +1,575 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2023, Songlin Yang
+# Gated Linear Attention Transformers with Hardware-Efficient Training: https://arxiv.org/abs/2312.06635
+# on-the-fly computation without materializing hidden states into HBM
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+from einops import rearrange
+from packaging import version
+
+from .chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum,
+                         prepare_qg_kg)
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+@triton.jit
+def fused_chunk_gla_fwd_kernel(
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+    g,  # cumulative sum of log decay [B, H, L, K]
+    o,  # output [B, H, L, V]
+
+    h0,  # initial state of the chunk [B, H, K, V]
+    ht,  # final state of the chunk [B, H, K, V]
+
+    s_qk_h,  # stride size: L * K
+    s_qk_t,  # stride size: K
+    s_qk_d,  # stride size: 1
+
+    s_vo_h,  # stride size: L * V
+    s_vo_t,  # stride size: V
+    s_vo_d,  # stride size: 1
+
+    B: tl.constexpr,  # batch size
+    H: tl.constexpr,  # H
+    T: tl.constexpr,  # T
+    K: tl.constexpr,  # K
+    V: tl.constexpr,  # V
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a.
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32) + if CHECK and i == 0: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * tl.exp(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + else: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * tl.exp(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_db += BT * K + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_gla_bwd_kernel( + q, k, v, g, + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + + h0, # initial state of the chunk [B, H, K, V] + + s_qk_h, # stride size: L * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * V + s_vo_t, # stride size: V + s_vo_d, # stride size: 1 + scale, # K ** -0.5 + + B: tl.constexpr, # B + H: tl.constexpr, # H + T: tl.constexpr, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32) + + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [V, K] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * tl.exp(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * tl.exp(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + + # cum = tl.zeros([BK], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, V), + (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + # [K, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, V] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32) + + # inter-chunk + # [K, V] + if CHECK and i == 1: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * tl.exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + else: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + 
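+            # reverse-time state recurrence, mirroring the forward scan:
+            #   dh <- exp(g_chunk_last)[:, None] * dh + q_chunk^T @ do_chunk
+            # the CHECK branch above is identical to this one; pinning the
+            # first iteration only works around the Triton < 2.2 compiler
+            # issue described in the Python wrapper below.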
b_dh = b_dh * tl.exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def fwd_inner_chunk( + q, k, g, A, + s_qk_h, # stride size: L * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + B, # B + H, # H + T, # T + scale, # K ** -0.5 + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + K: tl.constexpr, # K +): + + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + + p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + o_i = tl.arange(0, BT) + + p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + for i in range(BT): + _q = tl.load(p_q, mask=mask, other=0) * scale + gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + s = _q[None, :] * b_k * tl.exp(gq[None, :] - b_g) + score = tl.sum(s, axis=1) + score = tl.where(o_i <= i, score, 0) + tl.store(p_A, score.to(p_A.dtype.element_ty)) + p_q += K + p_gq += K + p_A += BT + + +@triton.jit +def bwd_inner_chunk( + q, + k, + g, + dA, + dq, + dk, + s_qk_h, # stride size: L * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + T: tl.constexpr, # T + K: tl.constexpr, # K + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + o_i = tl.arange(0, BT) + + p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + + for i in range(BT): + _q = tl.load(p_q, mask=mask, other=0) + gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + score = tl.exp(gq[None, :] - b_g) + score = tl.where(o_i[:, None] <= i, score, 0) + _dA = tl.load(p_dA) + _dA = tl.where(o_i <= i, _dA, 0) + b_dk += (_dA[:, None] * score * _q[None, :]) + b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0) + tl.store(p_dq, b_dq, mask=mask) + p_q += K + p_dq += K + p_gq += K + p_dA += BT + + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): + ctx.g_dtype = g.dtype + g_original = g + # cumulative decay should be in float32, otherwise the err will be accumulated and amplified. 
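+        # reference semantics of the two preparation kernels launched below
+        # (see chunk_util.py); within each chunk of length BT:
+        #     g_cum[t] = sum of log-decays g[s] for s <= t in the chunk (fp32)
+        #     q_g[t]   = q[t] * exp(g_cum[t]) * scale
+        #     k_g[t]   = k[t] * exp(g_cum[last] - g_cum[t])
+        # a half-precision running sum over BT log-decay terms drifts enough
+        # to visibly perturb exp(g_cum), hence the dedicated float32 buffer.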
+        g = torch.empty_like(g, dtype=torch.float32)
+        B, H, T, K, V = *k.shape, v.shape[-1]
+        ctx.scale = scale
+
+        # inter-chunk
+        BT = 16  # chunk_size
+        BK, BV = min(K, 64), min(V, 64)
+        num_stages = 1
+        num_warps = 2
+
+        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+        o = q.new_empty(NK, B, H, T, V)
+        q_g = torch.empty_like(q)
+        k_g = torch.empty_like(k)
+        grid = (NK, triton.cdiv(T, BT), B * H)
+
+        fwd_decay_cumsum[grid](
+            g_original,
+            g,
+            T*K,  # == q.stride(1)
+            K=K,
+            BT=BT, BK=BK, num_warps=1
+        )
+        prepare_qg_kg[grid](
+            q, k, g, q_g, k_g,
+            T*K,  # == q.stride(1)
+            scale,
+            K=K, BT=BT, BK=BK, num_warps=1
+        )
+
+        if output_final_state:
+            final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False)
+        else:
+            final_state = None
+        # the bug still exists even for Triton 2.2 on H100 GPUs,
+        # so we always enable the initial condition checks
+        CHECK = True
+        if version.parse(triton.__version__) < version.parse('2.2.0'):
+            import warnings
+            warnings.warn(
+                "Triton<2.2.0 detected for running this kernel, "
+                "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
+                "that lead to significant precision loss. "
+                "We've added some initial condition checks to resolve this, sadly at some sacrifice of speed. "
+                "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
+            )
+
+        grid = (NV, NK, B * H)
+        fused_chunk_gla_fwd_kernel[grid](
+            q_g, k_g, v, g, o, initial_state, final_state,
+            T*K, K, 1,  # == q.stride(1), q.stride(2), q.stride(3)
+            T*V, V, 1,  # == v.stride(1), v.stride(2), v.stride(3)
+            B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
+            USE_INITIAL_STATE=initial_state is not None,
+            STORE_FINAL_STATE=output_final_state,
+            CHECK=CHECK,
+            num_warps=num_warps,
+            num_stages=num_stages
+        )
+
+        o = o.sum(0)  # sum over the NK partial outputs
+        # intra-chunk
+        chunk_size = 16
+        num_chunk = T // chunk_size
+        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
+        BK = min(K, 64)
+        NK = triton.cdiv(K, BK)
+        A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT)
+        grid = (NK, triton.cdiv(T, BT), B * H)
+        fwd_inner_chunk[grid](
+            q, k, g, A,
+            T*K, K, 1,  # == q.stride(1), q.stride(2), q.stride(3)
+            B, H, T, scale, BT=BT, BK=BK, K=K, num_stages=3,
+            num_warps=4
+        )
+        A = A.sum(0)
+        o2 = A @ v2
+        o2 = rearrange(o2, 'b h n c d -> b h (n c) d')
+        # combine inner and inter
+        o.add_(o2)
+        ctx.save_for_backward(q, k, v, g_original, A, initial_state)
+        ctx.CHECK = CHECK
+        return o.to(v), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, dht=None):
+        q, k, v, g_origin, A, initial_state = ctx.saved_tensors
+        B, H, T, K, V = *k.shape, v.shape[-1]
+        scale = ctx.scale
+
+        # recomputation
+        # inter-chunk
+        BT = 16  # chunk_size
+        g = torch.empty_like(g_origin, dtype=torch.float32)  # the full fp32 cumulative decay is recomputed here
+        BK, BV = min(K, 64), min(V, 64)
+        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+        q_g = torch.empty_like(q)
+        k_g = torch.empty_like(k)
+        grid = (NK, triton.cdiv(T, BT), B * H)
+        fwd_decay_cumsum[grid](
+            g_origin,
+            g,
+            T*K,  # == q.stride(1)
+            K=K,
+            BT=BT, BK=BK, num_warps=1
+        )
+        prepare_qg_kg[grid](
+            q, k, g, q_g, k_g,
+            T*K,  # == q.stride(1)
+            scale,
+            K=K, BT=BT, BK=BK, num_warps=1
+        )
+
+        # does reading these recomputed tensors back introduce the error, or is a large intermediate still present?
+        # inter-chunk
+        BT = 16
+        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
+        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+        num_stages = 1
+        num_warps = 2
+        dq = q.new_empty(NV, B, H, T, K)
+        dk = q.new_empty(NV, B, H, T, K)
+        dv = q.new_empty(NK, B, H, T, V)
+
+        grid = (NV, NK, B * H)
+
+        fused_chunk_gla_bwd_kernel[grid](
+            q_g, k_g, v, g, do, dq, dk, dv, initial_state,
+            T*K, K, 1,  # == q.stride(1), q.stride(2), q.stride(3)
+            T*V, V, 1,  # == v.stride(1), v.stride(2), v.stride(3)
+            scale,
+            B=B, H=H, T=T, K=K, V=V,
+            BT=BT, BK=BK, BV=BV,
+            USE_INITIAL_STATE=initial_state is not None,
+            CHECK=ctx.CHECK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+        dq = dq.sum(0)
+        dk = dk.sum(0)
+        dv = dv.sum(0)
+
+        # intra chunk
+        num_chunk = T // BT
+        v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk)
+        do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk)
+        dA2 = (do2 @ v2.transpose(-2, -1)) * scale
+        dv2 = A.transpose(-1, -2) @ do2
+        dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk)
+
+        BK = min(triton.next_power_of_2(K), 16)
+        NK = triton.cdiv(K, BK)
+        dk2 = torch.empty_like(k)
+        dq2 = torch.empty_like(q)
+
+        grid = (NK, triton.cdiv(T, BT), B * H)
+        bwd_inner_chunk[grid](
+            q, k, g,
+            dA2, dq2, dk2,
+            T*K, K, 1,  # == q.stride(1), q.stride(2), q.stride(3)
+            T=T, K=K, BT=BT, BK=BK,
+            num_warps=1,
+            num_stages=3
+        )
+
+        BK = min(triton.next_power_of_2(K), 32)
+        NK = triton.cdiv(K, BK)
+        dg = torch.empty_like(g, dtype=torch.float32)
+        grid = (NK, triton.cdiv(T, BT), B * H)
+        bwd_decay_global_cumsum[grid](
+            dq2, dq, dk2, dk, q, k, g, dg,
+            T*K, K, 1,  # == q.stride(1), q.stride(2), q.stride(3)
+            B, H, T, scale,
+            BT=BT, K=K, BK=BK,
+            num_warps=1,
+            num_stages=1
+        )
+        dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT)
+
+        def rev_cumsum_exclusive(x):
+            cumsum_x = x.cumsum(-2)
+            rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x
+            return rev_cumsum_x
+
+        rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :])
+        dg.add_(rev_cumsum_dg.unsqueeze(-2))
+        dv.add_(dv2)
+        dg = rearrange(dg, 'b h n c d -> b h (n c) d')
+
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None
+
+
+def pad(x, chunk_size=16):
+    T = x.shape[-2]
+    padded_seq_len = ceildiv(T, chunk_size) * chunk_size
+    if x.shape[-2] % chunk_size != 0:
+        x = F.pad(x, (0, 0, 0, padded_seq_len - T))
+    return x
+
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+
+# assumes the head-first layout (B, H, T, D) by default
+def fused_chunk_gla(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    scale: float = -1,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if scale == -1:
+        scale = q.shape[-1] ** -0.5
+    if initial_state is not None:
+        initial_state = initial_state.detach()
+    seq_len = q.shape[-2]
+    q, k, v, g = map(lambda x: pad(x), [q, k, v, g])
+    o, final_state = FusedChunkGLAFunction.apply(
+        q, k, v, g, scale, initial_state, output_final_state)
+    o = o[..., :seq_len, :]
+    return o, final_state
diff --git a/fla2/ops/gla/chunk_util.py b/fla2/ops/gla/chunk_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbc2835497e1b57b3e327fcfffcd797530f9b55
--- /dev/null
+++ b/fla2/ops/gla/chunk_util.py
@@ -0,0 +1,125 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def fwd_decay_cumsum(
+    g,
+    g_o,
+    s_qk_h,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr
+):
+    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    p_g = g + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+    p_go = g_o + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+    cum_decay = tl.zeros([BK], dtype=tl.float32)
+    mask = (i_k * BK + tl.arange(0, BK)) < K
+
+    for i in range(BT):
+        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
+        cum_decay += _g
+        tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)
+        p_g += K
+        p_go += K
+
+
+@triton.jit
+def prepare_qg_kg(
+    q,
+    k,
+    g,
+    qg,
+    kg,
+    s_qk_h,
+    scale,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr
+):
+
+    i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    p_q = q + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+    p_g = g + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+    p_k = k + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+    p_qg = qg + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+    p_kg = kg + i_bh * s_qk_h + i_c * BT * K + i_k * BK + tl.arange(0, BK)
+
+    mask = (i_k * BK + tl.arange(0, BK)) < K
+
+    last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK))
+
+    for i in range(BT):
+        _q = tl.load(p_q, mask=mask, other=0)
+        _k = tl.load(p_k, mask=mask, other=0)
+        _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
+        _q *= tl.exp(_g) * scale
+        _k *= tl.exp(last_decay - _g)
+        tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)
+        tl.store(p_qg, 
_q.to(p_qg.dtype.element_ty), mask=mask) + p_q += K + p_g += K + p_k += K + p_kg += K + p_qg += K + + +@triton.jit +def bwd_decay_global_cumsum( + dq_inner, + dq_inter, + dk_inner, + dk_inter, + q, k, g, dg, + s_qk_h, + s_qk_t, + s_qk_d, + B, + H, + T, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + K: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + cum_grad_dg = tl.zeros([BK], dtype=tl.float32) + mask = (i_k * BK + tl.arange(0, BK)) < K + last_g = tl.zeros([BK], dtype=tl.float32) + for j in range(BT-1, -1, -1): + _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + if j == (BT-1): + last_g = _g + _dq1 = tl.load(p_dq_inner, mask=mask, other=0) + _dq2 = tl.load(p_dq_inter, mask=mask, other=0) + _dq2 *= tl.exp(_g) + _dq = _dq1 + _dq2 + tl.store(p_dq_inter, _dq, mask=mask) + _dk1 = tl.load(p_dk_inner, mask=mask, other=0) + _dk2 = tl.load(p_dk_inter, mask=mask, other=0) + _dk2 *= tl.exp(last_g - _g) + _dk = _dk1 + _dk2 + tl.store(p_dk_inter, _dk, mask=mask) + _q = tl.load(p_q, mask=mask, other=0) + _k = tl.load(p_k, mask=mask, other=0) + _dg = _dq * _q - _dk * _k + cum_grad_dg += _dg + tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) + p_g -= K + p_k -= K + p_q -= K + p_dq_inner -= K + p_dk_inner -= K + p_dq_inter -= K + p_dk_inter -= K + p_dg -= K diff --git a/fla2/ops/gla/naive.py b/fla2/ops/gla/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..8b203c433be63ca83c93ea33d2f3f5c9496df283 --- /dev/null +++ b/fla2/ops/gla/naive.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn.functional as F + +from ...ops.gla.recurrent_fuse import fused_recurrent_gla + + +def ceildiv(a, b): + return -(a // -b) + + +def naive_recurrent_gla( + q, + k, + v, + gk, + initial_state=None, + output_final_state=False, + causal=True +): + orig_dtype = q.dtype + q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) + batch_size, n_heads, seq_len, d_head_k = q.shape + _, _, _, d_head_v = v.shape + h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + scale = d_head_k ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(seq_len): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o_i = (q_i[..., None] * h).sum(-2) + o[:, :, i] = o_i + + if causal: + return o.to(orig_dtype), h + else: + o_reverse = torch.zeros_like(v) + h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) + for i in range(seq_len-1, -1, -1): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, 
:]
+            h = h * gk_i[..., None] + kv_i
+            o_i = (q_i[..., None] * h).sum(-2)
+            o_reverse[:, :, i] = o_i
+
+        return o.to(orig_dtype), o_reverse.to(orig_dtype)
+
+
+if __name__ == "__main__":
+    B = 4
+    H = 4
+    L = 512
+    D = 128
+    dtype = torch.float32
+    q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
+    k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
+    v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
+    g = F.logsigmoid(torch.rand(B, H, L, D)).cuda(
+    ).clamp_min(-1).to(torch.float32).requires_grad_(True)
+
+    do = torch.rand_like(v).cuda()
+    do2 = torch.rand_like(v).cuda()
+
+    ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
+
+    ref.backward(do, retain_graph=True)
+    ref_rev.backward(do2, retain_graph=True)
+
+    ref_dq, q.grad = q.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dg, g.grad = g.grad.clone(), None
+
+    # the vendored wrapper exposes `reverse` rather than the old `causal=False`
+    # API, so the two scan directions are obtained with two separate calls;
+    # `reverse=True` should reproduce the right-to-left scan above
+    tri, _ = fused_recurrent_gla(
+        q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False)
+    tri_rev, _ = fused_recurrent_gla(
+        q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, reverse=True)
+    tri.backward(do, retain_graph=True)
+    tri_rev.backward(do2, retain_graph=True)
+    tri_dq, q.grad = q.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dg, g.grad = g.grad.clone(), None
+
+    assert ref.allclose(tri, 0, 1e-5)
+    assert ref_rev.allclose(tri_rev, 0, 1e-5)
+    assert ref_dq.allclose(tri_dq, 0, 1e-5)
+    assert ref_dk.allclose(tri_dk, 0, 1e-5)
+    assert ref_dv.allclose(tri_dv, 0, 1e-5)
+    assert ref_dg.allclose(tri_dg, 0, 1e-4)
+
+    print("Pass")
diff --git a/fla2/ops/gla/recurrent_fuse.py b/fla2/ops/gla/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f3553b8c18322ff1be30b78d29e6fd12bc6e115
--- /dev/null
+++ b/fla2/ops/gla/recurrent_fuse.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2023, Songlin Yang
+
+from typing import Optional, Tuple
+
+import torch
+
+from ...ops.common.fused_recurrent import fused_recurrent
+
+
+def fused_recurrent_gla(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    gk: torch.Tensor = None,
+    gv: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    reverse: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if scale is None:
+        scale = q.shape[-1] ** -0.5
+    o, final_state = fused_recurrent(q, k, v, None, gk, gv, scale, initial_state, output_final_state, reverse)
+    return o, final_state
\ No newline at end of file
diff --git a/fla2/ops/hgrn/__init__.py b/fla2/ops/hgrn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..96f24b1d286315351d41d4df104d1d9ba65c2d16
--- /dev/null
+++ b/fla2/ops/hgrn/__init__.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+from 
.chunk import chunk_hgrn +from .recurrent_fuse import fused_recurrent_hgrn + +__all__ = [ + 'chunk_hgrn', + 'fused_recurrent_hgrn' +] diff --git a/fla2/ops/hgrn/__pycache__/__init__.cpython-312.pyc b/fla2/ops/hgrn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce571b11894da1540920f5512067095b80d0b3d2 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/__init__.cpython-38.pyc b/fla2/ops/hgrn/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4022382bac94f17b24c29d2aefb13171cb7ab063 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/__init__.cpython-39.pyc b/fla2/ops/hgrn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f0a5677262128486175f657c50f705c4f071018 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/chunk.cpython-312.pyc b/fla2/ops/hgrn/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d813ab1d9869a492b2702e73804092c45e382712 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/chunk.cpython-38.pyc b/fla2/ops/hgrn/__pycache__/chunk.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..328df06b131652fd6344b63159b127cee7b70532 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/chunk.cpython-38.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/chunk.cpython-39.pyc b/fla2/ops/hgrn/__pycache__/chunk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cddd9b4dfbf19f054c6350e9dac1e41791bcd4d Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/chunk.cpython-39.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-312.pyc b/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1ce299636b6f33c86db552d23d08fbe8241073 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-312.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-38.pyc b/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49e5edacff37172b578c68b7140a84655912dc71 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-38.pyc differ diff --git a/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-39.pyc b/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04460cb1c6bab56c967a60d4cf3e220d2d7f3a49 Binary files /dev/null and b/fla2/ops/hgrn/__pycache__/recurrent_fuse.cpython-39.pyc differ diff --git a/fla2/ops/hgrn/chunk.py b/fla2/ops/hgrn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..01ab344d4ec10e8fd2ea41d97e27dd90732f6ca7 --- /dev/null +++ b/fla2/ops/hgrn/chunk.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- + +# Copyright (c) 2024, Yu Zhang, Songlin Yang + +# this function implements the chunkwise form of HGRN, inspired by +# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html) +# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan + +# from tests on H800, with B, 
H, D = 16, 4, 128, the chunk version is considerably faster than the recurrent one:
+#
+# Performance:
+#    seq_len     chunk  recurrent  chunk_bwd  recurrent_bwd
+# 0    128.0  0.039360   0.061056   0.312160       0.205008
+# 1    256.0  0.045824   0.123712   0.308784       0.297696
+# 2    512.0  0.058688   0.241952   0.310720       0.626528
+# 3   1024.0  0.088288   0.476992   0.313184       1.333152
+# 4   2048.0  0.169472   0.943264   0.452464       2.724864
+# 5   4096.0  0.329920   1.886144   0.881600       5.551520
+# 6   8192.0  0.647872   3.755040   1.740496      11.117184
+# 7  16384.0  1.272064   7.520576   3.446608      22.362528
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import contiguous
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BD': 32}, num_warps=1),
+        triton.Config({'BD': 32}, num_warps=2),
+        triton.Config({'BD': 32}, num_warps=4),
+        triton.Config({'BD': 32}, num_warps=8),
+        triton.Config({'BD': 64}, num_warps=1),
+        triton.Config({'BD': 64}, num_warps=2),
+        triton.Config({'BD': 64}, num_warps=4),
+        triton.Config({'BD': 64}, num_warps=8),
+        triton.Config({'BD': 128}, num_warps=1),
+        triton.Config({'BD': 128}, num_warps=2),
+        triton.Config({'BD': 128}, num_warps=4),
+        triton.Config({'BD': 128}, num_warps=8),
+    ],
+    key=['D']
+)
+@triton.jit
+def chunk_hgrn_fwd_kernel_h(
+    x,
+    g,
+    gc,
+    o,
+    h0,
+    T: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr
+):
+    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    o_d = i_d * BD + tl.arange(0, BD)
+    mask = o_d < D
+
+    p_x = x + i_bh * T * D + i_t * BT * D + o_d
+    p_g = g + i_bh * T * D + i_t * BT * D + o_d
+    p_gc = gc + i_bh * T * D + i_t * BT * D + o_d
+    p_o = o + i_bh * T * D + i_t * BT * D + o_d
+
+    b_h = tl.zeros([BD], dtype=tl.float32)
+    b_gc = tl.zeros([BD], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        if i_t == 0:
+            b_h += tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)
+    for i in range(0, BT):
+        mask_t = mask & ((i_t * BT + i) < T)
+        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
+        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
+        b_h = tl.exp(b_g) * b_h + b_x
+        b_gc = b_gc + b_g
+        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
+        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)
+
+        p_x += D
+        p_g += D
+        p_gc += D
+        p_o += D
+
+
+@triton.jit
+def chunk_hgrn_fwd_kernel_o(
+    gc,
+    o,
+    s_h,
+    s_t,
+    s_d,
+    T: tl.constexpr,
+    D: tl.constexpr,
+    BT: tl.constexpr,
+    BD: tl.constexpr
+):
+    i_d, i_bh = tl.program_id(0), tl.program_id(1)
+    o_d = i_d * BD + tl.arange(0, BD)
+    mask = o_d < D
+
+    for i_t in range(1, tl.cdiv(T, BT)):
+        p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
+        p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
+
+        # [BD,]
+        b_h0 = tl.load(o + i_bh * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
+        # [BT, BD]
+        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
+        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
+        b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({'BD': 32}, num_warps=1),
+        triton.Config({'BD': 32}, num_warps=2),
+        triton.Config({'BD': 32}, num_warps=4),
+        triton.Config({'BD': 32}, num_warps=8),
+        triton.Config({'BD': 64}, num_warps=1),
+        triton.Config({'BD': 64}, num_warps=2),
+        triton.Config({'BD': 64}, num_warps=4),
+        triton.Config({'BD': 64}, num_warps=8),
+        
triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def chunk_hgrn_bwd_kernel_h( + g, + gc, + dx, + do, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + BC = min(BT, T - i_t * BT) + NT = tl.num_programs(1) + + p_g = g + (i_bh * T + i_t * BT + BC - 1) * D + o_d + p_gc = gc + (i_bh * T + i_t * BT + BC - 1) * D + o_d + p_dx = dx + (i_bh * T + i_t * BT + BC - 1) * D + o_d + p_do = do + (i_bh * T + i_t * BT + BC - 1) * D + o_d + + if i_t == NT - 1: + b_gc = tl.zeros([BD], dtype=tl.float32) + else: + b_gc = tl.load(g + (i_bh * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32) + b_dh = tl.zeros([BD], dtype=tl.float32) + for _ in range(BC - 1, -1, -1): + tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask) + + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + + b_gc = b_gc + b_g + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * tl.exp(b_g) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + + p_g -= D + p_gc -= D + p_dx -= D + p_do -= D + + +@triton.jit +def chunk_hgrn_bwd_kernel_o( + g, + gc, + o, + dx, + dg, + s_h, + s_t, + s_d, + T: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_g = tl.make_block_ptr(g + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_gc = tl.make_block_ptr(gc + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * s_h, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + mask_t = mask & ((i_t + 1) * BT < T) + b_ht = tl.load(dx + i_bh * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32) + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32) + + b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :] + b_dg = b_o * b_dx * tl.exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkHGRNFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, H, T, D = x.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + o = torch.empty_like(x, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H) + chunk_hgrn_fwd_kernel_h[grid]( + x, g, gc, o, initial_state, + T=T, D=D, BT=BT, + USE_INITIAL_STATE=initial_state is not None + ) + def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) + chunk_hgrn_fwd_kernel_o[grid]( + gc, o, + o.stride(1), o.stride(2), o.stride(3), 
+            T=T, D=D, BT=BT, BD=BD,
+            num_warps=num_warps
+        )
+        final_state = None
+        if output_final_state:
+            final_state = o[:, :, -1].clone()
+        o = o.to(x.dtype)
+        ctx.save_for_backward(g, o, initial_state)
+        return o, final_state
+
+    @staticmethod
+    @contiguous
+    def backward(ctx, do, dht=None):
+        g, o, initial_state = ctx.saved_tensors
+        B, H, T, D = do.shape
+        BT, BD = 128, min(64, triton.next_power_of_2(D))
+        num_warps = 8 if BD == 64 else 4
+
+        gc = torch.empty_like(g, dtype=torch.float)
+        dx = torch.empty_like(o, dtype=torch.float)
+        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B * H)
+        chunk_hgrn_bwd_kernel_h[grid](
+            g, gc, dx, do,
+            T=T, D=D, BT=BT
+        )
+
+        dg = torch.empty_like(g, dtype=torch.float)
+        def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)
+        chunk_hgrn_bwd_kernel_o[grid](
+            g, gc, o, dx, dg,
+            o.stride(1), o.stride(2), o.stride(3),
+            T=T, D=D, BT=BT, BD=BD,
+            num_warps=num_warps
+        )
+        if initial_state is not None:
+            dg[:, :, 0] = (initial_state * dx[:, :, 0] * g[:, :, 0].float().exp()).to(dg.dtype)
+
+        return dx.to(o.dtype), dg, None, None
+
+
+def chunk_hgrn(
+    x: torch.Tensor,
+    g: torch.Tensor,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
diff --git a/fla2/ops/hgrn/naive.py b/fla2/ops/hgrn/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..04385bed23337c682cf04e8a3073889789892919
--- /dev/null
+++ b/fla2/ops/hgrn/naive.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+import torch
+
+
+def naive_recurrent_hgrn(
+    x: torch.Tensor,
+    g: torch.Tensor,
+    initial_state: Optional[torch.Tensor] = None,
+    output_final_state: Optional[bool] = False
+) -> torch.Tensor:
+    dtype = x.dtype
+    x, g = map(lambda i: i.float(), (x, g))
+    B, H, T, D = x.shape
+
+    h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
+    o = torch.zeros_like(x)
+
+    final_state = None
+    if initial_state is not None:
+        h += initial_state
+
+    for i in range(T):
+        h = g[:, :, i].exp() * h + x[:, :, i]
+        o[:, :, i] = h
+
+    if output_final_state:
+        final_state = h
+    return o.to(dtype), final_state
+
+
+def naive_chunk_hgrn(
+    x: torch.Tensor,
+    g: torch.Tensor,
+    initial_state: Optional[torch.Tensor] = None,
+    output_final_state: Optional[bool] = False,
+    chunk_size: int = 64
+) -> torch.Tensor:
+    dtype = x.dtype
+    x, g = map(lambda i: i.float(), (x, g))
+    B, H, T, D = x.shape
+
+    gc = g.view(B, H, -1, chunk_size, D).cumsum(-2).view_as(g)
+    h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
+    o = torch.zeros_like(x)
+
+    final_state = None
+    if initial_state is not None:
+        h += initial_state
+
+    for i in range(0, T, chunk_size):
+        hp = h
+        h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
+        for j in range(i, i + chunk_size):
+            h = g[:, :, j].exp() * h + x[:, :, j]
+            o[:, :, j] = hp * gc[:, :, j].exp() + h
+        h = o[:, :, j].clone()
+
+    if output_final_state:
+        final_state = h
+    return o.to(dtype), final_state
diff --git a/fla2/ops/hgrn/recurrent_fuse.py b/fla2/ops/hgrn/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ddab8f3cff752328819fdbedbdf930dd7f41c3c
--- /dev/null
+++ b/fla2/ops/hgrn/recurrent_fuse.py
@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+
+# Copyright (c) 2023, Songlin Yang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import contiguous
+
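+# reference recurrence implemented by the kernels below (elementwise, per channel):
+#     h[t] = exp(g[t]) * h[t-1] + x[t]
+# a minimal PyTorch equivalent, matching naive_recurrent_hgrn in ops/hgrn/naive.py:
+#     h = x.new_zeros(B, H, D)
+#     for t in range(T):
+#         h = g[:, :, t].exp() * h + x[:, :, t]
+#         o[:, :, t] = h
+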
+@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def fused_recurrent_hgrn_fwd_kernel( + x, + g, + o, + h0, + ht, + T: tl.constexpr, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_bh * T * D + o_d + p_g = g + i_bh * T * D + o_d + p_o = o + i_bh * T * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * D + o_d + b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) + for _ in range(0, T): + b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_h = tl.exp(b_g) * b_h + b_x + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) + + p_x += D + p_g += D + p_o += D + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * D + o_d + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit +def fused_recurrent_hgrn_bwd_kernel( + g, + o, + dx, + dg, + do, + h0, + T: tl.constexpr, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_d, i_bh = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_g = g + (i_bh * T + T - 1) * D + o_d + p_o = o + (i_bh * T + T - 2) * D + o_d + p_dx = dx + (i_bh * T + T - 1) * D + o_d + p_dg = dg + (i_bh * T + T - 1) * D + o_d + p_do = do + (i_bh * T + T - 1) * D + o_d + + b_dh = tl.zeros([BD], dtype=tl.float32) + for i in range(T - 1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + if i > 0: + b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) + elif USE_INITIAL_STATE: + b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32) + else: + b_o = tl.zeros([BD], dtype=tl.float32) + + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * tl.exp(b_g) + b_dg = b_dh * b_o + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) + + p_g -= D + p_o -= D + p_dx -= D + p_dg -= D + p_do -= D + + +class FusedRecurrentHGRNFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, H, T, D = x.shape + + final_state = None + if output_final_state: + final_state = x.new_empty(B, H, D) + + o = torch.empty_like(x) + def grid(meta): return 
(triton.cdiv(D, meta['BD']), B * H) + fused_recurrent_hgrn_fwd_kernel[grid]( + x, g, o, initial_state, final_state, + T, D, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None + ) + ctx.save_for_backward(g, o, initial_state) + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + B, H, T, D = do.shape + + dx = torch.empty_like(o, dtype=torch.float) + dg = torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) + fused_recurrent_hgrn_bwd_kernel[grid]( + g, o, dx, dg, do, initial_state, + T, D, + USE_INITIAL_STATE=initial_state is not None, + ) + + return dx, dg, None, None + + +def fused_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + return FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state) diff --git a/fla2/ops/linear_attn/__init__.py b/fla2/ops/linear_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dbeab9acd39a05fd4a234ffaff87f19ddcff7cdf --- /dev/null +++ b/fla2/ops/linear_attn/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_linear_attn +from .chunk_fuse import fused_chunk_linear_attn +from .recurrent_fuse import fused_recurrent_linear_attn + +__all__ = [ + 'chunk_linear_attn', + 'fused_chunk_linear_attn', + 'fused_recurrent_linear_attn' +] diff --git a/fla2/ops/linear_attn/__pycache__/__init__.cpython-312.pyc b/fla2/ops/linear_attn/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea3405d2f0acb8b10f6006dc47895d05ad62c9d3 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/__init__.cpython-38.pyc b/fla2/ops/linear_attn/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d0c173db18eea68e8695556f78dc62b659ef18e Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/__init__.cpython-39.pyc b/fla2/ops/linear_attn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92a305146e4b9a5a231960d6d9c538e7c170b6ca Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/chunk.cpython-312.pyc b/fla2/ops/linear_attn/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3390ccea1784e8d3a127aedc72194655f00768d Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/chunk.cpython-38.pyc b/fla2/ops/linear_attn/__pycache__/chunk.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49ef115a668084b194cef291d3a94a072bf84c5 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/chunk.cpython-38.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/chunk.cpython-39.pyc b/fla2/ops/linear_attn/__pycache__/chunk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b97ea26672f31e202d88dd3d00877c6549c52e56 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/chunk.cpython-39.pyc differ diff --git 
a/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-312.pyc b/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3863168af81c61825da55726d556ff2ce991dc3d Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-312.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-38.pyc b/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f82d9ef07fdaf0a3ab775570d3859575afa5cfd Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-38.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-39.pyc b/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94328305bb0f5e459fb587cadf26a372233d17b Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/chunk_fuse.cpython-39.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-312.pyc b/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cea3e9d9af72a26583d3d7d60f2ca76ec04b8a3 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-312.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-38.pyc b/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdcb14578759b6a18e7a06d08dbf6290f1ed5907 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-38.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-39.pyc b/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28f2cdc2f85d91e9315953ba46acfffa20bf9d01 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/recurrent_fuse.cpython-39.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/utils.cpython-312.pyc b/fla2/ops/linear_attn/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..decec81d57c265334cbdf57b4ecdce61951ba82c Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/utils.cpython-312.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/utils.cpython-38.pyc b/fla2/ops/linear_attn/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8991826fcc127811ba5c0a945eb4ee148fef9150 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/utils.cpython-38.pyc differ diff --git a/fla2/ops/linear_attn/__pycache__/utils.cpython-39.pyc b/fla2/ops/linear_attn/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18bed4b49bea73f415f5c193f6a45110957e5878 Binary files /dev/null and b/fla2/ops/linear_attn/__pycache__/utils.cpython-39.pyc differ diff --git a/fla2/ops/linear_attn/chunk.py b/fla2/ops/linear_attn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..be3727c8828cf01987608937ef2febbbd5e48e69 --- /dev/null +++ b/fla2/ops/linear_attn/chunk.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import 
autocast_custom_bwd, autocast_custom_fwd, contiguous + + +@triton.jit +def chunk_linear_attn_fwd_kernel_h( + k, + v, + h, + h0, + ht, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def chunk_linear_attn_bwd_kernel_dh( + q, + do, + dh, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + # [BK, BV] + b_dh = tl.zeros([BK, 
BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + + +@triton.jit +def chunk_linear_attn_bwd_kernel_dqkv( + q, + k, + v, + h, + do, + dh, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + n_bh = tl.num_programs(2) + o_i = tl.arange(0, BT) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +class 
ChunkLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + ctx.scale = scale + + final_state = None + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False) + + h = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_linear_attn_fwd_kernel_h[grid]( + k, v, h, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NV, NT, B * H) + o = torch.empty_like(v) + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v, h, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v, h) + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, h = ctx.saved_tensors + + B, H, T, K, V = *q.shape, v.shape[-1] + BT = 64 + BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V)) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + scale = ctx.scale + + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_linear_attn_bwd_kernel_dh[grid]( + q, do, dh, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + + grid = (NK, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + num_stages = 1 + num_warps = 4 if BK == 64 else 2 + chunk_linear_attn_bwd_kernel_dqkv[grid]( + q, k, v, h, do, dh, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + scale (Optional[float]): + Scale factor for the linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + """ + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + return o, final_state diff --git a/fla2/ops/linear_attn/chunk_fuse.py b/fla2/ops/linear_attn/chunk_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..040385e90ec802911d4810c7d5043fea906f903b --- /dev/null +++ b/fla2/ops/linear_attn/chunk_fuse.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from packaging import version + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +@triton.jit +def fused_chunk_linear_attn_fwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, K] + v, # value [B, H, T, V] + o, # output [B, H, T, V] + h0, + ht, + s_qk_h, # stride size: T * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: T * V + s_vo_t, # stride size: V + s_vo_d, # stride size: 1 + scale, + B, # batch size + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +
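+# A plain PyTorch restatement (exposition only; it assumes a single head with 2-D
+# q/k/v, T divisible by the chunk size, and no initial state, and is never invoked)
+# of the chunkwise rule the forward kernel above implements: per chunk,
+#   o_chunk = (q * scale) @ h + causal((q * scale) @ k^T) @ v
+# followed by advancing the running [K, V] state h by k^T @ v.
+def _reference_fused_chunk_fwd(q, k, v, scale, chunk_size=64):
+    T = q.shape[0]
+    h = q.new_zeros(q.shape[-1], v.shape[-1])
+    o = torch.empty_like(v)
+    mask = torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device).tril()
+    for i in range(0, T, chunk_size):
+        b_q = q[i:i + chunk_size] * scale
+        b_k, b_v = k[i:i + chunk_size], v[i:i + chunk_size]
+        b_s = (b_q @ b_k.t()).masked_fill(~mask, 0)   # intra-chunk, causal part
+        o[i:i + chunk_size] = b_q @ h + b_s @ b_v     # inter-chunk + intra-chunk
+        h = h + b_k.t() @ b_v                         # carry the state forward
+    return o
+
+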
+@triton.jit +def fused_chunk_linear_attn_bwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, K] + v, # value [B, H, T, V] + do, # gradient of output [B, H, T, V] + dq, # gradient of query [NV, B, H, T, K] + dk, # gradient of key [NV, B, H, T, K] + dv, # gradient of value [NK, B, H, T, V] + + h0, # initial state of the chunk [B, H, K, V] + + s_qk_h, # stride size: T * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: T * V + s_vo_t, # stride size: V + s_vo_d, # stride size: 1 + scale, # K ** -0.5 + B, # B + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, BK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [BV, BK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 + num_stages = 1 + + o = q.new_empty(NK, B, H, T, V) + final_state = q.new_empty(B, H, K, V, dtype=torch.float) if output_final_state else None + # the bug still exists even with Triton 2.2 on H100 GPUs, + # so we always enable the initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've added some initial condition checks to resolve this, sadly at some cost to speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
+ ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + o = o.sum(0) if NK > 1 else o[0] + + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 + num_stages = 1 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def fused_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + scale (Optional[float]): + Scale factor for linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`.
+ """ + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + return o, final_state diff --git a/fla2/ops/linear_attn/naive.py b/fla2/ops/linear_attn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ecf2718fcac8eef80f445ed02b95f36329f3c4 --- /dev/null +++ b/fla2/ops/linear_attn/naive.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from fla.ops.linear_attn.utils import normalize_output + + +def naive_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = (( + q @ k.transpose(-1, -2)).masked_fill_( + torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), + 0 + )) @ v + o = inter + intra + if normalize: + o = normalize_output(q * scale, k, o) + return rearrange(o, 'b h n c d -> b h (n c) d') diff --git a/fla2/ops/linear_attn/recurrent_fuse.py b/fla2/ops/linear_attn/recurrent_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..84aef018a1b24beec2a0d533489e18eae491bf54 --- /dev/null +++ b/fla2/ops/linear_attn/recurrent_fuse.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import contiguous + + +@triton.jit +def fused_recurrent_linear_attn_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + + scale, + B, # batch size + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) 
+ b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_linear_attn_bwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, K] + v, # value [B, H, L, V] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + h0, # initial hidden state [B, H, K, V] + + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + + B, # B + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + b_h += b_k[:, None] * b_v[None, :] + _d_q = b_h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dq += K + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * b_v[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + +
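+# For reference, a plain PyTorch sketch (exposition only, per head with 2-D q/k
+# and v; never invoked) of the recurrence realized by
+# fused_recurrent_linear_attn_fwd_kernel above:
+#   S_t = S_{t-1} + k_t^T v_t,  o_t = (q_t * scale) S_t
+def _reference_recurrent_fwd(q, k, v, scale, initial_state=None):
+    # q, k: [T, K]; v: [T, V]; initial_state: [K, V] or None
+    S = q.new_zeros(q.shape[-1], v.shape[-1]) if initial_state is None else initial_state.clone()
+    o = torch.empty_like(v)
+    for t in range(q.shape[0]):
+        S = S + torch.outer(k[t], v[t])   # rank-1 state update
+        o[t] = (q[t] * scale) @ S         # readout against the running state
+    return o, S
+
+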
+class FusedRecurrentLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, scale, initial_state=None, output_final_state=False): + B, H, T, K = q.shape + V = v.shape[-1] + + BK, BV = min(K, 32), min(V, 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 1 + num_stages = 1 + + o = q.new_empty(NK, B, H, T, V) + final_state = q.new_empty(B, H, K, V) if output_final_state else None + + grid = (NV, NK, B * H) + fused_recurrent_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), + v.stride(1), scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + return o, final_state + + @staticmethod + @contiguous + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K = q.shape + V = v.shape[-1] + scale = ctx.scale + + BK, BV = min(K, 32), min(V, 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 1 + num_stages = 1 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_recurrent_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None, None + + +def fused_recurrent_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + return o, final_state diff --git a/fla2/ops/linear_attn/utils.py b/fla2/ops/linear_attn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b444376833f5d512af6fc2db387db75a43a92e5d --- /dev/null +++ b/fla2/ops/linear_attn/utils.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +import torch + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.cumsum(-2) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-10) diff --git a/fla2/ops/mask_delta_rule/README.md b/fla2/ops/mask_delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1ab2d485a9552d70238c1f68288c72c62f9e0ef2 --- /dev/null +++ b/fla2/ops/mask_delta_rule/README.md @@ -0,0 +1,4 @@ +- Delta Rule + +The implementation of delta rule described in https://arxiv.org/abs/2102.11174 + diff --git a/fla2/ops/mask_delta_rule/__init__.py b/fla2/ops/mask_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f2150c06c3304962a1534a95fa49037b300eaa --- /dev/null +++ b/fla2/ops/mask_delta_rule/__init__.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +from .chunk import mask_chunk_delta_rule +from .chunk_non import mask_chunk_delta_rule2 +# from .chunk_fuse import mask_fused_chunk_delta_rule +# from .recurrent_fuse import mask_fused_recurrent_delta_rule + +__all__ = [ + # 'mask_fused_chunk_delta_rule', + # 
'mask_fused_recurrent_delta_rule', + 'mask_chunk_delta_rule', + 'mask_chunk_delta_rule2' + +] diff --git a/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-310.pyc b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e38239f74df42ddc4c9683610b0c0652d26ca406 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-312.pyc b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9f1806dd756fcb350ecf7a79a50de2b810d87b5 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-38.pyc b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1b00ff9d7724d80be01f2ba7dbd854969ebcae7 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-39.pyc b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a948f0ec52246f14b2ebf074fab6aca2eaac443c Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-310.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..886ef1586dfbff0fc951b6c1326240d65e5eea89 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-312.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..652d3663607648fb7a4a98f9a9b14481ec87fdd5 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-38.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cff0a4c76ae4d2740b47cd72b0628dd5bc871f41 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-38.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-39.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77f41739967c0262d0abcbdc71f48b649160f6fd Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk.cpython-39.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-310.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0dc06389b089882124142772d8dc71d994a7b2f Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-312.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77372b1c74dd1d7cb91abe5a5ab3dabbbcda1b89 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-312.pyc differ diff --git 
a/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-38.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c378431dcd84ae23fe1eb19b3a5d44d6928d95a4 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-38.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-39.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..686fba514ad2d79567e8dd2272d909e9960d8093 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk_fuse.cpython-39.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk_non.cpython-310.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk_non.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f5b86cf6b349d3e2c826ca9567846d7519f669a Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk_non.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/chunk_non.cpython-312.pyc b/fla2/ops/mask_delta_rule/__pycache__/chunk_non.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a57afaa316c27bf40d756ab89aedd7c40c5ea2af Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/chunk_non.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-310.pyc b/fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb78a7b7d1e4500d81253c0bff4bff00aef760a5 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-312.pyc b/fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1fec66ac79083ab540a8cfba3224a23827ac5a Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/recurrent_fuse.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/utils.cpython-310.pyc b/fla2/ops/mask_delta_rule/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a61283aa5a52f545b4b7458d5af93d29d46369c Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/utils.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule/__pycache__/utils.cpython-312.pyc b/fla2/ops/mask_delta_rule/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f33cec7f62d0145b3829a6a6e9088b5f46bd3412 Binary files /dev/null and b/fla2/ops/mask_delta_rule/__pycache__/utils.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule/chunk.py b/fla2/ops/mask_delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0d79459d3a3606cd89f1c7fe3698a66010c931 --- /dev/null +++ b/fla2/ops/mask_delta_rule/chunk.py @@ -0,0 +1,742 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from ...ops.mask_delta_rule.wy_fast import (bwd_prepare_wy_repr, + fwd_prepare_wy_repr, fwd_recompute_w_u) +from ...ops.utils import contiguous +from ...utils import autocast_custom_bwd, autocast_custom_fwd +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + 
        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # could perhaps also parallelize over r here
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K // r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+def fwd_prepare_dv(q, k, do, r, BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B, H, r, T, V, device=do.device, dtype=do.dtype)  # cannot use empty_like: dv carries an extra r dimension
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B * H * r)](
+        q, k, do, dv,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # u
+    d,  # w
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # each program owns one [BK, BV] block of the state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        # storing the pre-chunk state here is correct
+        b_h_cumsum = tl.zeros([r, BK // r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # BK is large, so the chunk is tiled along BC
+                r_mask = tl.arange(0, r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK // r), (BC, BK // r), (1, 0))  # load the matching rank slice
+                p_d = tl.make_block_ptr(d + i_bh * T * r * K, (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC, 1, BK), (2, 1, 0))
+                p_v = tl.make_block_ptr(v + i_bh * T * r * V, (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC, 1, BV), (2, 1, 0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r) * T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # [BC, BK//r]
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))
+                b_v = tl.reshape(b_v, (BC, BV))
+                b_d = tl.reshape(b_d, (BC, BK))
+                b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False)  # matches the reference up to here; note the cast assumes bf16 inputs
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))  # first-step results agree with the reference at least up to this point
+                bkv = tl.where(r_mask[:, None, None], tl.dot(tl.trans(b_k), b_v.to(b_k.dtype), allow_tf32=False)[None, :, :], 0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum, (BK, BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    rk = K // r
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K // r, BK)):  # mind the split here: K // BK == r
+        # open question: different r-blocks read the same q/k tiles; does that cost anything?
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)  # mask out future positions
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_h_h,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # state gradient; each program keeps the full [BK, BV] block
+    for i_t in range(NT - 1, -1, -1):  # shifted forward by one chunk; the computation order is correct
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        # iterate over the full column
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))  # read the full block
+            p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K, T * r), (1, K),
+                                    (i_k * BK, i_t * BT * r + i_c * BC * r), (BK, BC * r), (0, 1))
+            p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_q = tl.load(p_q, boundary_check=(0, 1))
+            b_q = (b_q * scale).to(b_q.dtype)
+
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T * r, V), (V, 1),
+                                     (i_t * BT * r + i_c * BC * r, i_v * BV), (BC * r, BV), (1, 0))  # load all r slices
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))  # [BT*r, BV]
+            b_dhtrans = tl.reshape(b_dh, (r, KR, BV))
+            for i_r in range(r):
+                rmask = tl.arange(0, r) == i_r  # the i_r-th rank slice
+                p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                        (i_t * BT + i_c * BC, i_r * KR + i_k * BK), (BC, KR), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # [BC, KR]
+                b_dhr = tl.sum(tl.where(rmask[:, None, None], b_dhtrans, 0), 0)  # [KR, BV]
+                dv_sum = tl.dot(b_k, b_dhr.to(b_k.dtype), allow_tf32=False)  # gives [BC, BV]
+                b_dv += tl.reshape((dv_sum[:, None, :] * rmask[None, :, None]).to(b_dv.dtype), (BC * r, BV))
+
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T * r, V), (V, 1),
+                                      (i_t * BT * r + i_c * BC * r, i_v * BV), (BC * r, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
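+    # Note (added for clarity): grid axis 2 enumerates (batch*head, rank) pairs;
+    # i_r above selects the rank slice and i_bh below recovers the batch-head index.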
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r * K // r + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT * r, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_h = tl.load(p_h, boundary_check=(0, 1))  # [BV, BK]
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))
+
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)  # ok
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # full do; b_h must contain this i_k part
+        b_dk += tl.dot(b_v, b_dh, allow_tf32=False)  # used for dk; rows are independent, so this is fine
+        b_dv = tl.load(p_dv, boundary_check=(0, 1))  # [BT*r, BV]
+        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))  # gives [BT*r, BK]
+
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)  # [BT, BT]
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))  # these should be fine
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T * r * K, (T * r, K), (K, 1), (i_t * BT * r, i_r * K // r + i_k * BK), (BT * r, BK), (1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))  # note the sign flip on dw
+
+#finish
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT // T
+    BK = triton.next_power_of_2(K)  # one block covers the whole K dimension
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
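+    # Note (added for clarity): the heuristics below shrink the value tile BV and the
+    # intra-chunk tile BC as BK grows, keeping the per-program tile footprint roughly
+    # constant; BC is additionally capped by the chunk size BT.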
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new is laid out rank-first
+    chunk_delta_rule_fwd_kernel_h[grid](  # the launch grid carries no loop over r
+        k, u, w, v_new, h, initial_state, final_state,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+#finish
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # these could probably feed the launch parallelism
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # same layout as h; the sum forces it to be computed in one pass
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v -> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout: (b, h*r, T, V)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        T * K, K, 1,
+        NT * K * V,
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K // r,
+    )
+    return dh, dv2
+
+#finish
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # hence shape (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape (B, H, NT * K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        T * K, K, 1,
+        r * T * V, T * V, V,
+        NT * K * V, V,
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last backward helper: computes dw, dq and dk
+    BK = min(triton.next_power_of_2(K // r), 64)  # a finer split is needed so that different rank positions never share a tile
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K // r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK controls the position within a rank slice
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # shaped like the original k
+    dw = torch.empty_like(w)  # (B, H, r*T, K)
+
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        T * K, K, 1,
+        T * V, V, 1,
+        NT * K * V, V,
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # computes the A matrix (and w, u) in one pass
+        r = mask.shape[-1]
+        # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # skipped: already returned above
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)  # needs changes
+
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)  # needs changes
+
+        if checkpoint_level == 1:
+            h, v_new = None, None  # dropped here; recomputed in backward
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # recompute w and u from the saved A
+        # checkpoint_level=1: recomputation.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # built from q, k, do; dv therefore has the shape of w
+        # dv: (B, H, r, T, V)
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # updated dv and dh; the final dv feeds the WY-repr backward
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        dk.add_(dk2)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # mask applied over rank blocks of the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk**-0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d) -> b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the stored result matches the kernel output
+        kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v -> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i + 1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d) -> b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d) -> b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask; shape (b, h, n, r, t, l)
+    o1 = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # (b, h, n, t, v)
+    o2 = torch.einsum('b h n t q, b h n q v -> b h n t v', q, state)  # checks the state contribution alone
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask, initial_state=None):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    if initial_state is None:
+        S = torch.zeros(b, h, d_k, d_v).to(v).float()
+    else:
+        S = initial_state
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d, b h v -> b h d v', _k * beta_i, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v) -> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v, r l -> b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v -> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k, b h k v -> b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+
+    B = 2
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    q = torch.randn(B, H, L, DK).cuda().requires_grad_(True)
+    k = torch.randn(B, H, L, DK).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = torch.randn(B, H, L, DV).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end - start)
+
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    print(k_grad)
+    print(k_grad0)
+

diff --git a/fla2/ops/mask_delta_rule/chunk_fuse.py b/fla2/ops/mask_delta_rule/chunk_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6979fa906c6706bb07f6318b284920365db9eff
--- /dev/null
+++ b/fla2/ops/mask_delta_rule/chunk_fuse.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+import torch.nn.functional as F
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+def pad(x, chunk_size=16):
+    seq_len = x.shape[-2]
+    # x: (b, n, l, d)
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
+    if x.shape[-2] % chunk_size != 0:
+        x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
+    if x.shape[-1] % 32 != 0:
+        x = F.pad(x, (0, 32 - x.shape[-1] % 32))
+    return x
+
+def pad_b(x, chunk_size=16):
+    seq_len = x.shape[-1]  # sequence length l
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size  # length after padding
+    # pad only if the sequence length is not a multiple of chunk_size
+    if seq_len % chunk_size != 0:
+        x = F.pad(x, (0, padded_seq_len - seq_len), value=1.0)  # pad only the last dimension (l)
+    return x
+
+# on-the-fly computation without materializing hidden states into HBMs
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8)
+    ],
+    key=["BT", "BK"],
+)
+@triton.jit
+def fused_chunk_delta_rule_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, D_head_K]
+    k,  # key [B, H, L, D_head_K]
+    v,  # value [B, H, L, D_head_V]
+    v_new,
+    d,  # decay [B, H, L, D_head_K]
+    o,  # output [B, H, L, D_head_V]
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,  # stride size: L * D_head_K
+    s_qk_t,  # stride size: D_head_K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * D_head_V
+    s_vo_t,  # stride size: D_head_V
+    s_vo_d,  # stride size: 1
+    B,  # batch size
+    H,  # n_heads
+    T,  # seq_len
+    scale,  # D_head_K ** -0.5
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    DK: tl.constexpr,  # D_head_K
+    DV: tl.constexpr,  # D_head_V
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr,
+    CHECK: tl.constexpr
+):
+    # indices
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    o_i = tl.arange(0, BT)
+
+    # [BT, BT]
+    m_s = o_i[:, None] >= o_i[None, :]
+    # [BK, BV]
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+
+    # make block pointers
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1))
+    p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0))
+    p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+    p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+    p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0))
+
+    if USE_INITIAL_STATE:
+        p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+
+    for i in range(0, tl.cdiv(T, BT)):
+        # [BK, BT]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        # [BT, BK]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_d = tl.load(p_d, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_k.dtype)
+
+        # [BT, BT]
+        b_s = tl.dot(b_q, b_k, allow_tf32=False)
+        b_s = tl.where(m_s, b_s, 0)
+        # [BT, BV]
+        b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False)
+        b_v = b_v - b_v_prime
+        tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
+
+        b_o =
tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_v_new = tl.advance(p_v_new, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_d = tl.advance(p_d, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fused_chunk_delta_rule_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + d, # decay [B, H, L, D_head_K] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + dd, # gradient of decay [NV, B, H, L, D_head_K] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + # first reverse + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + + tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + NT = tl.cdiv(T, BT) + for i in range(0, NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, 
boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, DK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + if i < (NT - 1): + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False) + p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + ((i+1) * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1)) + + +def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BT = BT + # ctx.BT = BT + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1, 'NK should be 1' + o = q.new_empty(batch_size, n_heads, seq_len, d_head_v) + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + CHECK = True + # if version.parse(triton.__version__) < version.parse('2.2.0'): + # import warnings + # warnings.warn( + # "Triton<2.2.0 detected for running this kernel, " + # "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + # "that lead to significant precision loss. " + # "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + # "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." 
+    #     )
+    # CHECK = True
+    grid = (NV, NK, batch_size * n_heads)
+    v_new = torch.empty_like(v)
+    fused_chunk_delta_rule_fwd_kernel[grid](
+        q, k, v, v_new, d, o, initial_state, final_state,
+        q.stride(1), q.stride(2), q.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        batch_size, n_heads, seq_len, scale,
+        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=output_final_state,
+        CHECK=CHECK,
+    )
+    return o, v_new, CHECK, final_state
+
+
+def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state):
+    batch_size, n_heads, seq_len, d_head_qk = q.shape
+    d_head_v = v.shape[-1]
+    scale = d_head_qk ** -0.5
+    BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32)
+    NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+    assert NK == 1
+    dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+    dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+    dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+    dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
+    grid = (NV, NK, batch_size * n_heads)
+    fused_chunk_delta_rule_bwd_kernel[grid](
+        q, k, v, d, do, dq, dk, dv, dd, initial_state,
+        q.stride(1), q.stride(2), q.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        batch_size, n_heads, seq_len, scale,
+        BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+        USE_INITIAL_STATE=initial_state is not None,
+        CHECK=CHECK,
+        # num_warps=num_warps,
+        # num_stages=num_stages
+    )
+    dq = dq.sum(0)
+    dk = dk.sum(0)
+    dv = dv.sum(0)
+    dd = dd.sum(0)
+    dd[:, :, 0:BT] = 0
+    return dq, dk, dv, dd
+
+
+class FusedChunkDeltaRuleFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0):
+        # checkpoint_level=1 will recompute ``fwd_prepare_wy_repr`` to save memory.
+        assert checkpoint_level in [0, 1]
+        k_origin = k
+        # k = _l2_norm_fwd(k_origin)
+        d, v_new = fwd_prepare_wy_repr(k, v, beta, BT)
+        o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state)
+        if checkpoint_level == 1:
+            d, v_new = None, None
+        ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state)
+        ctx.CHECK = CHECK
+        ctx.chunk_size = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_final_state=None):
+        q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors
+        chunk_size = ctx.chunk_size
+        k = k_origin
+        # k = _l2_norm_fwd(k_origin)
+        if d is None:
+            d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
+        dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state)
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size)
+        dk.add_(dk2)
+        # dk = _l2_norm_bwd(k_origin, dk)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None
+
+
+def mask_fused_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
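+    # Note (added for clarity): inputs are padded below to aligned lengths (multiples of
+    # 16 on the sequence and 32 on the head dimension, per the pad/pad_b defaults) so the
+    # kernels see aligned tiles; the output is sliced back to the original seq_len and
+    # d_head_v before returning.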
+ + if initial_state is not None: + initial_state = initial_state.detach() + seq_len = v.shape[-2] + d_head_v = v.shape[-1] + q, k, v = map(lambda x: pad(x), [q, k, v]) + beta = pad_b(beta) + o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state) + o = o[..., :seq_len, :d_head_v] + return o, final_state + + +def mask_delta_rule_recurrence(q, k, v, beta): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + k = torch.nn.functional.normalize(k, p=2, dim=-1) + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i[..., None] + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ == "__main__": + import torch.nn.functional as F + # seq_len = 128 + # b = 2 + # h = 4 + # q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1) + # beta = torch.rand(b, h, seq_len).sigmoid() + # q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta)) + # do = torch.rand_like(v) + # o2 = delta_rule_recurrence(q, k, v.clone(), beta) + # o2.backward(do, retain_graph=True) + # q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # o, _ = fused_chunk_delta_rule(q, k, v, beta, 32) + # o.backward(do, retain_graph=True) + # q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # print((o - o2).abs().max()) + # print((q_grad - q_grad2).abs().max()) + # print((k_grad - k_grad2).abs().max()) + # print((v_grad - v_grad2).abs().max()) + # print((beta_grad - beta_grad2).abs().max()) \ No newline at end of file diff --git a/fla2/ops/mask_delta_rule/chunk_non.py b/fla2/ops/mask_delta_rule/chunk_non.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9fb407bc976836b887cecaf5b05d948135e807 --- /dev/null +++ b/fla2/ops/mask_delta_rule/chunk_non.py @@ -0,0 +1,836 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from ...ops.mask_delta_rule.wy_fast_non import (bwd_prepare_wy_repr, + fwd_prepare_wy_repr, fwd_recompute_w_u) +from ...ops.utils import contiguous +from ...utils import autocast_custom_bwd, autocast_custom_fwd +import time + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)#或许也可以r并行 + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), 
(s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + #这里save是对的 + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + # b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + # b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + # b_v = tl.load(p_v, boundary_check=(0, 1, 2))#BC + # b_v = tl.reshape(b_v,(BC,BV)) + # b_d = tl.reshape(b_d,(BC,BK)) + # b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)#ok #到这相等的 这里BC + # tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 
1))#至少到这里第一步结果相同 + # bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + # b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2)) + b_v = tl.reshape(b_v,(BC,BV)) + # b_v = b_v.to(tl.float32)#BC + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False)#ok #到这相等的 这里BC + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r + #问题是不同r_block读取了同一份qk,有影响吗 + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: 
tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + #全列 + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))#全读取 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))# + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T*r,K), (K, 1), + (i_t * BT * r + i_c * BC *r,i_k * BK), (BC * r,BK), (1, 0))#读取 BC r BK的内容 + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_d = tl.trans(tl.load(p_d,boundary_check=(0, 1))) + b_k = tl.permute(tl.reshape(b_k,(BC,r,KR)),(1,0,2))#r BC KR + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + dv_sum = tl.sum(b_k[:,:,:,None]*b_dhtrans.to(b_k.dtype)[:,None,:,:],-2) #get r BC BV + b_dv += tl.reshape(tl.permute(dv_sum,(1,0,2)),(BC*r,BV)) + #bhtrv + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT,r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, 
s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r* K// r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))#BV BK + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)),(BT,r,BV))#BT*r BV + b_dw += tl.sum(b_dv.to(b_v.dtype)[:,:,:,None]*b_h.to(b_v.dtype)[None,None,:,:],-2)#get BT r BK + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False)) #这些应该没啥问题 + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT, 0 ,i_r*K//r + i_k * BK), (BT, r ,BK), (2, 1, 0)) + # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT ,i_r, i_k * BK), (BT, 1, BK), (2, 1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, (tl.reshape(-b_dw.to(p_dw.dtype.element_ty),(BT,r,BK))), boundary_check=(0, 1)) + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)#直接划分好 + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)#做了v_new的r_first + chunk_delta_rule_fwd_kernel_h[grid](#r没有for循环 + k, u, w, v_new, h, initial_state, final_state, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + +#finish +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
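+    # Note (added for clarity): same tile-size heuristic as in chunk_fwd_h_fn: BV and BC
+    # shrink as BK grows to bound the per-program tile size, with BC capped by BT.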
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # these could probably feed the launch parallelism
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # same layout as h; the sum forces it to be computed in one pass
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v -> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout: (b, h*r, T, V)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        T * K, K, 1,
+        NT * K * V,
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K // r,
+    )
+    return dh, dv2
+
+#finish
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # hence shape (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape (B, H, NT * K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        T * K, K, 1,
+        r * T * V, T * V, V,
+        NT * K * V, V,
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last backward helper: computes dw, dq and dk
+    BK = min(triton.next_power_of_2(K // r), 64)  # a finer split is needed so that different rank positions never share a tile
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K // r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK controls the position within a rank slice
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # shaped like the original k
+    dw = torch.empty_like(w)  # (B, H, r*T, K)
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        T * K, K, 1,
+        T * V, V, 1,
+        NT * K * V, V,
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # computes the A matrix (and w, u) in one pass
+        r = mask.shape[-1]
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # recompute w and u from A
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)  # needs changes
+
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)  # needs changes
+
+        if checkpoint_level == 1:
+            h, v_new = None, None  # dropped here; recomputed in backward
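+        # Note (added for clarity): under checkpoint_level == 1 the chunk states h and
+        # v_new are discarded above and rebuilt from (k, w, u) in backward, trading
+        # recomputation for activation memory.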
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # cheap recompute from the cached A
+        # checkpoint_level=1: recompute h and v_new.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: [B, H, r, T, V]
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # built from q, k and do; same rank-expanded layout as u
+        # dv: [B, H, r, T, V]
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # returns dh and the dv that feeds the WY-repr backward
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        dk.add_(dk2)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def mask_chunk_delta_rule2(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # rank-mixing mask applied to the original tensors
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
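+    # Expected shapes, inferred from the test harness at the bottom of this
+    # file: q, k, v are [B, H, T, D], beta is [B, H, T], and mask is a small
+    # [r, r] rank-mixing matrix with D divisible by r, e.g.
+    #   mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]])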
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk ** -0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d) -> b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # matches the result the kernel stores
+        kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v -> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i+1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d) -> b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d) -> b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask; [b h n r t l]
+    o1 = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # [b h n t v]
+    o2 = torch.einsum('b h n t q, b h n q v -> b h n t v', q, state)  # state term alone, to check it separately
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d, b h v -> b h d v', _k * beta_i, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v) -> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v, r l -> b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v -> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k, b h k v -> b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+
+    B = 2
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    q = torch.randn(B, H, L, DK).cuda().requires_grad_(True)
+    k = torch.randn(B, H, L, DK).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = torch.randn(B, H, L, DV).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end - start)
+
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)  # ~10s, hmm
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    print(k_grad)
+    print(k_grad0)
+
+
diff --git a/fla2/ops/mask_delta_rule/naive.py b/fla2/ops/mask_delta_rule/naive.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ab969e6e069c89e4da205ded25151baa2e8d111
--- /dev/null
+++ b/fla2/ops/mask_delta_rule/naive.py
@@ -0,0 +1,1480 @@
+# -*- coding: utf-8 -*-
+
+import time
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(  # speed needs work: consider splitting into three kernels that run separately, like the fla3 version does
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) matrix
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]
+        for i_k in range(tl.cdiv(dk, BK)):  # read and accumulate k blockwise
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)  # strictly lower part; [BT, BT, r, r]
+
+    # forward substitution (a finer 16 x BT split could be tried here)
+    for i in range(1, BT):  # the matrix is conceptually [BT, r, BT, r] here
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # row i: [BT, r, r]
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # does the matmul; [BT, BT, r, r]
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # [BT, r, r]
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row-by-row update, swapping results back in
+
+    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT*r, BT*r))  # [BT*r, BT*r]
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))  # the old layout needed many more multiplies
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    b_A = b_A.to(k.dtype.element_ty)  # inversion done; next compute w and u
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read column i_r of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # [BT, r, d]
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # [BT*r, BK]
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the mask does not apply to v, so no per-rank masking here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read column i_r of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # [BT, r, d]
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # [BT*r, BK]
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the mask does not apply to v, so no per-rank masking here
+        p_v = tl.make_block_ptr(v + i_bh
* s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "r"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r#列读,因而是行数目 + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]#行数 + + for i_k in range(tl.cdiv(dk, BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), 
+ triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + # p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0)) + # b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32) + # b_A21 = tl.permute(b_A21,(0,2,1,3)) + # b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,32*r),(32*r,1) ,((i_t * 32 + 16) *r, 0), (16*r, 16*r), (1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + # b_A21 = tl.permute(b_A21,(0,2,1,3)) + # b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0)) + p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0)) + p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0)) + p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0)) + p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0)) + p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0)) + + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32) + b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32) + b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32) + b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32) + b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32) + + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 
* r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0)) + p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0)) + p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0)) + p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0)) + p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0)) + + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32) + Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32) + + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + Ai32 = -tl.dot(tl.dot(Ai33,b_A32, input_precision='ieee'),Ai11,input_precision='ieee') + Ai43 = -tl.dot(tl.dot(Ai44,b_A43, input_precision='ieee'),Ai11,input_precision='ieee') + + Ai31 = -tl.dot( + Ai33, + tl.dot(b_A31,Ai11, input_precision='ieee')+ + tl.dot(b_A32,Ai21, input_precision='ieee'), + input_precision='ieee') + + Ai42 = -tl.dot( + Ai44, + tl.dot(b_A42,Ai22, input_precision='ieee')+ + tl.dot(b_A43,Ai32, input_precision='ieee'), + input_precision='ieee') + + Ai41 = -tl.dot( + Ai44, + tl.dot(b_A41, Ai11, input_precision='ieee') + + tl.dot(b_A42, Ai21, input_precision='ieee') + + tl.dot(b_A43, Ai31, input_precision='ieee'), + input_precision='ieee' + ) + + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai33,Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai44,Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai31,Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai32,Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai41,Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai42,Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + 
tl.store(p_Ai43,Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+
+def chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    A = torch.empty(B*H*NT, BT*BT, r*r, device=k.device, dtype=output_dtype).contiguous()
+    chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)](
+        k, beta, mask, A,
+        T*K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT*16*r, 16*r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B*H)](
+        A, Ad,
+        T*BT*r*r,   # s_A_bh
+        T*16*r*r,   # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    A = rearrange(A, 'b (t l) (c r) -> b (t c) (l r)', t=BT, c=r).contiguous()  # [BT*r, BT*r]
+    if BT == 32:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_32x32_inverse_kernel[(NT, B*H)](
+            A, Ad, Ai,
+            T*BT*r*r,   # s_A_bh and s_Ai_bh
+            T*16*r*r,   # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+    if BT == 64:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_64x64_inverse_kernel[(NT, B*H)](
+            A, Ad, Ai,
+            T*BT*r*r,   # s_A_bh and s_Ai_bh
+            T*16*r*r,   # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+
+
+#compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blockwise over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # [r*BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # [BT, r, BV]
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # [BT*r, BT*r]
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # [BT*r, BV]
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # reduce over r
+        b_dv = sum_dv * b_beta[:, None]
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K // r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read column i_r of the mask
+        b_mask = tl.load(p_mask)  # column i_r
+        rmask = tl.arange(0, r) == i_r  # column i_r selector
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # [BT, r, d]
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+
+            b_ss = tl.sum(tl.sum(b_dk_beta * b_beta[:, None, None] * b_k[:, None, :], 0), -1)
+            b_dmask += (b_ss[:, None] * rmask[None, :]).to(tl.float32)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to dA of the kk^T term; block-diagonal entries stay zero
+
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r))
+    # [BT, r, BT, r]
+
+    for i_r in range(r):  # take only the i_r-th slice
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read column i_r of the mask
+        b_mask = tl.load(p_mask)
+        rmask = tl.arange(0, r) == i_r
+        g = tl.sum(tl.where(rmask[None, None, None, :], b_dA, 0), -1)  # [BT, r, BT]: extract column i_r
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # [BT, BT]
+        # the corresponding correction term
+        for i_k in range(tl.cdiv(block_k, BK)):  # typically a single iteration
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)  # [BT, BK]
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+            beta_kkt = tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)  # [BT, BT]
+            betas = tl.sum(tl.sum(beta_kkt[:, None, :] * g, -1), 0)
+            b_dmask += (betas[:, None] * rmask[None, :]).to(tl.float32)
+
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+    p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t) * r * r, (r, r), (r, 1), (0, 0), (r, r), (1, 0))
+    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0, 1))
+
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
+    A = solve_tril(A=A, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+    w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dmask = torch.zeros([B*H*NT, r, r], device=k.device, dtype=k.dtype).contiguous()
+    assert BK == K//r
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta, dmask,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # r could also be parallelized here
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K // r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T,
V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)#assert ik=1 all use + b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + #这里save是对的 + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2))#BC + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)#ok #到这相等的 这里BC + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( 
+ configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r + #问题是不同r_block读取了同一份qk,有影响吗 + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + #全列 + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))#全读取 + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), + (i_k * BK, i_t * BT * r + i_c * BC *r), (BK, BC * r), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_d = 
(tl.load(p_d,boundary_check=(0, 1))) + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))#load r + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + for i_r in range(r): + rmask = tl.arange(0, r) == i_r #第ir列 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_r*KR + i_k * BK), (BC, KR), (1, 0))# + b_k = tl.load(p_k, boundary_check=(0, 1)) #BC KR + b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0)# KR BV + dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False)#get BC*BV + b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BC*r,BV)) + + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK + b_dh =(tl.load(p_dh, boundary_check=(0, 1))) + + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, b_dh, allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV + b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) 
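+    # b_ds accumulates the intra-chunk attention-score gradient (do @ v_new^T);
+    # the tl.where below applies the causal lower-triangular mask before b_ds
+    # is folded into b_dq and b_dk.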
+ b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) #这些应该没啥问题 + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, ((-b_dw.to(p_dw.dtype.element_ty))), boundary_check=(0, 1)) + + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)#直接划分好 + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)#做了v_new的r_first + chunk_delta_rule_fwd_kernel_h[grid](#r没有for循环 + k, u, w, v_new, h, initial_state, final_state, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + +#finish +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
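+    # Block-size heuristic (as used throughout this file): larger head dims
+    # get smaller value/chunk tiles so the [BK, BV] state block still fits in
+    # registers/shared memory; BK spans the whole K dim since NK must be 1.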
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)#感觉可以放并行度 + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B , H, NT * K,V)#一样的#need 求和 得一起算 + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)#一样的 #bhr T V + chunk_delta_rule_bwd_kernel_dhu[grid]( + q, k, w, do, dh, dv, dv2, + T*K,K,1, + NT*K*V, + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +#finish +def chunk_fwd_o_fn(q, k, v_new, h, BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new)#there_fore,bhr nT,bv + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + #h shape b h nk v + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, o, + T*K, K, 1 , + r*T*V,T*V,V, + NT*K*V,V, + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2)#沿着r维度求和 + return o + + +def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + _,_,RT,_ = w.shape + r = RT // T + #最后一个函数,计算dw,dq,dk + BK = triton.next_power_of_2(K//r)#需要更细粒度的划分,确保不会使得 不同位置的划到一起 + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K//r, BK) + NT = triton.cdiv(T, BT) + grid = (NK, NT, B * H * r)#通过NK控制位置 + dq = torch.empty_like(q) + dk = torch.empty_like(k)#k_org + dw = torch.empty_like(w)#bh nt k + + + chunk_delta_rule_bwd_kernel_dqkw[grid]( + q, k, v_new, w, h, do, dh, dq, dk, du, dw, + T*K,K,1, + T*V, V, 1, + NT*K*V,V, + scale=K ** -0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r = r + ) + return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + +class ChunkDeltaRuleFunction(torch.autograd.Function): + #前向写完了 + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta,mask,BT, initial_state, output_final_state, checkpoint_level=1): + start = time.time() + w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, BT)#compute for A matrix #compute all + final_state = None + if output_final_state: + final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], + dtype=torch.float32, requires_grad=False)#这部分不需要修正 + end = time.time() + print('compute_A:',end-start) + start = time.time() + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)#need change' + end = time.time() + print('compute_h_s:',end-start) + + start = time.time() + o = chunk_fwd_o_fn(q, k, v_new, h, BT)#need change + end = time.time() + print('compute_o:',end-start) + if checkpoint_level == 1: + h, v_new = None, None + ctx.save_for_backward(q, k, v, beta,mask, A, h, v_new, initial_state) + ctx.BT = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_ht=None): + q, k, v, beta,mask , A, h, v_new, initial_state = ctx.saved_tensors + BT = ctx.BT + r = mask.shape[-1] + start = time.time() + w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)#跳过 + end = time.time() + print('recompute_wu:',end-start) + # checkpont_level=1, recomputation. 
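+        # Backward pipeline: fwd_prepare_dv builds dv from the intra-chunk
+        # q k^T path, chunk_bwd_dhu_fn runs the reversed state recurrence for
+        # dh and the final dv, chunk_bwd_dqkw_fn produces dq/dk/dw, and
+        # bwd_prepare_wy_repr backpropagates through the WY representation,
+        # additionally yielding dbeta and dmask.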
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # the last backward helper: computes dw, dq and dk
+    BK = min(triton.next_power_of_2(K//r), 64)  # finer-grained partitioning is needed so that different r positions never land in the same block
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K//r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK controls the position within the head dimension
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # gradient of the original k
+    dw = torch.empty_like(w)  # (B, H, r*T, K)
+
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        T*K, K, 1,
+        T*V, V, 1,
+        NT*K*V, V,
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # the forward pass is complete
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the A matrix and everything derived from it
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+        end = time.time()
+        print('compute_A:', end-start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+        end = time.time()
+        print('compute_h_s:', end-start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        end = time.time()
+        print('compute_o:', end-start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+        end = time.time()
+        print('recompute_wu:', end-start)
+        # checkpoint_level=1: recompute h and v_new instead of storing them.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # from q, k and do; this dv therefore has the same layout as w
+        end = time.time()
+        print('pre:', end-start)
+        # dv: (B, H, r, T, V)
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # produces dh and the dv that finally feeds the WY-representation backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end-start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)  # this step is also very slow
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end-start)
+
+        start = time.time()
+        dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end-start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # used to mask the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
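A minimal usage sketch of the public entry point, assuming a CUDA device and the bfloat16 requirement asserted above (all sizes below are arbitrary choices, not requirements of the patch, except that `DK` must be divisible by `r` and `L` by `BT`):

```python
import torch

B, H, L, DK, DV, r = 2, 4, 256, 128, 128, 4
q = torch.randn(B, H, L, DK, device='cuda', dtype=torch.bfloat16)
k = torch.nn.functional.normalize(
    torch.randn(B, H, L, DK, device='cuda', dtype=torch.bfloat16), dim=-1, p=2)
v = torch.randn(B, H, L, DV, device='cuda', dtype=torch.bfloat16)
beta = torch.rand(B, H, L, device='cuda', dtype=torch.bfloat16)
mask = torch.randn(r, r, device='cuda', dtype=torch.bfloat16).contiguous()

o, final_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
```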
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk ** -0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d->b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d)->b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the values saved here match the kernel
+        kui = torch.einsum('b h t r k,b h t r v-> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v-> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i+1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d)->b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d)->b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask, giving (b, h, n, r, t, l)
+    o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # (b, h, n, t, v)
+    o2 = torch.einsum('b h n t q,b h n q v-> b h n t v', q, state)  # inspect the state term on its own to check its correctness
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d,b h v->b h d v', _k*beta_i, _k)
+        # kkt = torch.einsum('b h d,b h v->b h d v', _k, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k,b h k v->b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+    torch.manual_seed(42)
+
+    for i in range(200):
+        B = 16
+        H = 4
+        L = 2048
+        DK = 256
+        DV = 256
+        r = 4
+        q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+        k = (torch.randn(B, H, L, DK)).cuda()
+        k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+        v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+        beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+        mask = torch.randn([r, r])
+        mask = mask.cuda().requires_grad_(True).contiguous()
+
+        # start = time.time()
+        do = torch.randn(B, H, L, DV).cuda()
+        # o1 = delta_rule_recurrence(q, k, v, beta, mask)
+        # o1.backward(do, retain_graph=True)
+        q_grad, q.grad = q.grad, None
+        k_grad, k.grad = k.grad, None
+        v_grad, v.grad = v.grad, None
+        mask_grad, mask.grad = mask.grad, None
+        beta_grad, beta.grad = beta.grad, None
+        # end = time.time()
+        # print(end-start)
+        # start = time.time()
+        # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+        o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+        o.backward(do, retain_graph=True)
+        q_grad0, q.grad = q.grad, None
+        k_grad0, k.grad = k.grad, None
+        v_grad0, v.grad = v.grad, None
+        beta_grad0, beta.grad = beta.grad, None
+        mask_grad0, mask.grad = mask.grad, None
+        # # end = time.time()
+        # # print(end-start)
+        # print((o1-o).abs().max())
+        # print((q_grad-q_grad0).abs().max())
+        # print((k_grad-k_grad0).abs().max())  # large discrepancy here, up to ~1
+        # print((v_grad-v_grad0).abs().max())
+        # print((beta_grad-beta_grad0).abs().max())
+        # print((mask_grad-mask_grad0).abs().max())
+        # print('naive:', mask_grad)
+        # print('triton:', mask_grad0)
+        # print(k_grad)
+        # print(k_grad0)
+
+        # BT = 16
+        # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the A matrix and everything derived from it
+        # print('finish0')
+        # h, v_new = chunk_fwd_h_fn(k, w, u, BT, None, None)
+        # print('finish1')
+        # o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        # print('finish2')
+        # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+        # print('finish3')
+        # dv = fwd_prepare_dv(q, k, do, r, BT)
+        # print('finish4')
+        # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)
+        # print('finish5')
+        # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)  # this step is also very slow
+        # print('finish6')
+
+        # Ass = rearrange(A, 'b h (n t) l->b h n t l', n=L//BT)
+        # dwss = rearrange(dw, 'b h (n t) k->b h n t k', n=L//BT)
+        # dvss = rearrange(dv, 'b h (n t) k->b h n t k', n=L//BT)
+        # dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        # print('triton:', dmask)  # almost exactly equal
+
+        # vbeta = v*beta[..., None]
+        # vbeta = rearrange(vbeta, 'b h (n T) d->b h n T d', T=BT)
+        # vbeta = vbeta.unsqueeze(-2).expand(-1, -1, -1, -1, r, -1)
+        # vbeta = rearrange(vbeta, 'b h n t r d-> b h n (t r) d')
+
+        # kbeta = k*beta[..., None]
+        # kbeta = rearrange(kbeta, 'b h (n T) (r d)->b h n T r d', T=BT, r=r)
+        # kbeta = torch.einsum('b h n T r d,c r-> b h n T c r d', kbeta, mask)
+        # kbeta = rearrange(kbeta, 'b h n t c r d-> b h n (t c) (r d)')
+        # dA = dvss@vbeta.transpose(-1, -2)+dwss@kbeta.transpose(-1, -2)
+
+        # dorg = Ass.transpose(-1, -2)@dwss  # (b, h, n, BT*r, k)
+        # dorg = rearrange(dorg, 'b h n (t r) (c k)->b h n t r c k', r=r, c=r)
+        # betan = rearrange(beta, 'b h (n t)->b h n t', n=L//BT)
+        # kn = rearrange(k, 'b h (n t) (r d)->b h n t r d ', n=L//BT, r=r)
+
+        # dmask = torch.einsum('b h n t r c k,b h n t->b h n t r c k', dorg, betan)
+        # dmask = torch.einsum('b h n t r c k,b h n t c k->b h n t r c k', dmask, kn)
+        # dmask = rearrange(dmask, 'b h n t r c k-> (b h n) (t k) r c')
+        # dmaskss = dmask.sum(0).sum(0)
+
+        # i = torch.arange(0, BT * r)[:, None]
+        # j = torch.arange(0, BT * r)[None, :]
+        # iB = i // r
+        # jB = j // r
+        # da_mask = iB > jB
+        # da_mask = da_mask.cuda()
+        # b_dA = torch.where(da_mask, dA, 0)
+
+        # b_dA = b_dA @ Ass.transpose(-1, -2)
+        # b_dA = Ass.transpose(-1, -2)@b_dA
+
+        # b_dA = torch.where(da_mask, -b_dA, 0)  # equivalent to the dA of the kkt term: mostly zeros, concentrated at the diagonal blocks
+        # b_dA = rearrange(b_dA, 'b h n (t r) (l c)-> b h n t r l c', c=r, r=r)
+        # # print((dAss-b_dA).abs())  # everything matches exactly up to this point
+
+        # # betakkt = k*beta[..., None]
+        # kbeta = k*beta[..., None]
+        # kbeta = rearrange(kbeta, 'b h (n T) (r d)->b h n T r d', T=BT, r=r)
+        # kbeta2 = rearrange(k, 'b h (n T) (r d)->b h n T r d', T=BT, r=r)
+        # betakkt = torch.einsum('b h n T r d,b h n s r d->b h n r T s', kbeta, kbeta2)  # (r, BT, BT)
+        # betakkt = rearrange(betakkt, 'b h n r T s->b h n T s r')  # (BT, BT, r), laid out row-wise
+        # # print((dAss-b_dA).abs())
+
+        # # this shows that the computation below is where the error comes from
+        # dmask = torch.einsum('b h n t r l c,b h n t l c-> b h n t r l c', b_dA, betakkt)
+        # # print((dAss-dmask).abs().max())  # which means this result is also correct
+        # # print((dAss-dmask))
+
+        # dmask = rearrange(dmask, 'b h n t r l c->b h n (t l) r c')
+        # dmask = dmask.sum(-3)
+        # dmask = dmask.sum(0).sum(0).sum(0)
+        # print('matrix:', dmask)
+
+
diff --git a/fla2/ops/mask_delta_rule/naive_rmbeta copy.py b/fla2/ops/mask_delta_rule/naive_rmbeta copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aac72fd6ab3c1c7928194c488cb608129bf6fc0
--- /dev/null
+++ b/fla2/ops/mask_delta_rule/naive_rmbeta copy.py
@@ -0,0 +1,1102 @@
+# -*- coding: utf-8 -*-
+
+import time
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # will become the (r*BT) x (r*BT) block matrix
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]
+        for i_k in range(tl.cdiv(dk, BK)):  # read and accumulate k block by block
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_kb = (b_k).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    # save this first and inspect it
+    for i in range(1, BT):  # at this point the matrix is (BT, r, BT, r)
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # row i: (BT, r, r)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # a matrix product, giving (BT, BT, r, r)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # (BT, r, r)
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # forward substitution row by row, swapping in the results step by step
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT*r, BT*r))  # (BT*r, BT*r)
+    b_A += tl.arange(0, BT*r)[:, None] == tl.arange(0, BT*r)[None, :]
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))  # the old implementation needed many more multiplications
+    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # this settles the matrix inversion
+    b_A = b_A.to(k.dtype.element_ty)  # inversion done; now compute the outputs
+
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # no mask is involved here, so no extra loop over r is needed
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
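A dense per-chunk reference for what this kernel produces may help. The sketch below is mine and is only illustrative (computed in float32, since `torch.linalg.inv` does not cover bfloat16): per chunk it builds the block-expanded Gram matrix gated by `mask`, inverts `I + strict_block_tril(G)` directly instead of by forward substitution, and then forms `w` from the mask-expanded `k` and `u` from the `r`-replicated `beta * v`, matching this "rmbeta" variant where `beta` is applied to `v` only.

```python
import torch
from einops import rearrange

def fwd_prepare_wy_repr_ref(k, v, beta, mask, BT):
    # k: (B, H, T, K) with K = r*dk; beta: (B, H, T); mask: (r, r)
    B, H, T, K = k.shape
    r = mask.shape[-1]
    NT, n = T // BT, BT * r
    kc = rearrange(k.float(), 'b h (m t) (c d) -> b h m t c d', m=NT, c=r)
    vb = rearrange((v * beta[..., None]).float(), 'b h (m t) v -> b h m t v', m=NT)
    m = mask.float()
    # G[(t,a),(l,c)] = mask[a,c] * <k_t^(c), k_l^(c)>, kept only for block-row t > block-col l
    G = torch.einsum('bhmtcd,bhmlcd->bhmtlc', kc, kc)
    G = torch.einsum('bhmtlc,ac->bhmtalc', G, m)
    G = rearrange(G, 'b h m t a l c -> b h m (t a) (l c)')
    idx = torch.arange(n, device=k.device)
    G = G.masked_fill(~((idx[:, None] // r) > (idx[None, :] // r)), 0)
    A = torch.linalg.inv(torch.eye(n, device=k.device) + G)
    # mask-expanded k and r-replicated beta*v, in the same (t, a) row order as A
    k_exp = torch.einsum('bhmlcd,ac->bhmlacd', kc, m)
    k_exp = rearrange(k_exp, 'b h m l a c d -> b h m (l a) (c d)')
    v_rep = rearrange(vb[..., None, :].expand(-1, -1, -1, -1, r, -1),
                      'b h m l a v -> b h m (l a) v')
    w = rearrange(A @ k_exp, 'b h m x d -> b h (m x) d')
    u = rearrange(A @ v_rep, 'b h m x v -> b h (m x) v')
    return w, u, A
```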
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # no mask is involved here, so no extra loop over r is needed
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # this is where the result differed
+        b_dv = (sum_dv * b_beta[:, None])  # which step produced the mismatch?
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K//r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(block_k, BK)):  # assumes block_k == BK
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            # b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # (BT, r, d)
+            b_k_beta = ((b_k)[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)  # (BT*r, BK)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            # b_dk = sum_dk * b_beta[:, None]
+            b_dk = sum_dk
+            # b_dbeta += tl.sum(sum_dk * b_k, 1)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r)).to(k.dtype.element_ty)  # everything should be correct up to here
+
+    for i_r in range(r):  # take only the i_r term
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)  # the i_r-th column
+        mask = tl.arange(0, r) == i_r
+        g = tl.sum(tl.where(mask[None, None, None, :], b_dA, 0), -1)  # (BT, r, BT): take the i_r-th column
+        # this corresponds to the k_r part
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # (BT, BT)
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            # b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_k_beta = (b_k).to(b_k.dtype)
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            # b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta  # * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))  # this part should also be fine
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    assert BK == K//r
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT*BT*r, BT*r, device=k.device, dtype=torch.float32)
+    fwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    assert BK == K//r
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,  # da,
+        dw, du,
+        dk, dv, dbeta,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return dk, dv, dbeta  # , da
+
+
+# from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # this could perhaps also be parallelized over r
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K//r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+# finished
+def fwd_prepare_dv(q, k, do, r, BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B, H, r, T, V, device=do.device, dtype=do.dtype)  # empty_like does not fit here because of the extra r dimension
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B*H*r)](
+        q, k, do, dv,
+        k.stride(1), k.stride(2), k.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
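In terms of dense tensors, `fwd_prepare_dv` propagates `do` backwards through the intra-chunk attention scores per replica: `dv_r = triu(K_r Q_r^T * scale) @ dO`, where the upper triangle (including the diagonal) collects contributions from the current and future positions inside the chunk. A hypothetical reference with my own naming:

```python
import torch
from einops import rearrange

def fwd_prepare_dv_ref(q, k, do, r, BT):
    B, H, T, K = q.shape
    V = do.shape[-1]
    NT, scale = T // BT, K ** -0.5
    qc = rearrange(q * scale, 'b h (n t) (r d) -> b h n r t d', n=NT, r=r)
    kc = rearrange(k, 'b h (n t) (r d) -> b h n r t d', n=NT, r=r)
    doc = rearrange(do, 'b h (n t) v -> b h n t v', n=NT)
    A = torch.einsum('bhnrtd,bhnrsd->bhnrts', kc, qc)  # A[t, s] = k_t . q_s
    A = torch.triu(A)                                  # keep s >= t, incl. the diagonal
    dv = torch.einsum('bhnrts,bhnsv->bhnrtv', A, doc)
    return rearrange(dv, 'b h n r t v -> b h r (n t) v')
```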
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # u
+    d,  # w
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)  # i_k is always 0 (NK == 1), so the whole K dimension is covered
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one full horizontal band of the state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        # the store here is correct: it records the state entering the chunk
+        b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # BK is large, so split the chunk into BC sub-blocks
+                r_mask = tl.arange(0, r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC, BK//r), (1, 0))
+                p_d = tl.make_block_ptr((d + i_bh * T * r * K), (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC, 1, BK), (2, 1, 0))
+                p_v = tl.make_block_ptr((v + i_bh * T * r * V), (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC, 1, BV), (2, 1, 0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r) * T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, BK//r)
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))  # (BC, 1, BK)
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))  # (BC, 1, BV)
+                b_v = tl.reshape(b_v, (BC, BV))
+                b_d = tl.reshape(b_d, (BC, BK))
+                b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)  # matches the reference up to this point, per BC block
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))  # at least the first step agrees with the reference up to here
+                bkv = tl.where(r_mask[:, None, None], tl.dot(tl.trans(b_k), b_v.to(b_k.dtype), allow_tf32=False)[None, :, :], 0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum, (BK, BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    rk = K // r
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K//r, BK)):  # note the split here: K//BK == r, so each program only sees its own r slice
+        # different r blocks read the same q and k; does that matter?
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        # p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_r * rk + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)  # zero out the non-causal entries
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # this block stays fixed and reads everything
+    for i_t in range(NT - 1, -1, -1):  # shifted forward by one position; the overall flow is correct
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        # full column band
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))  # read the full band
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T*r, K), (K, 1),
+                                    (i_t * BT * r + i_c * BC * r, i_k * BK), (BC * r, BK), (1, 0))  # reads a (BC, r, BK) slab
+            p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V, 1),
+                                     (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))
+            p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_q = (tl.load(p_q, boundary_check=(0, 1)))
+            b_q = (b_q * scale).to(b_q.dtype)
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))  # (BC*r, BV)
+            b_d = tl.trans(tl.load(p_d, boundary_check=(0, 1)))
+            b_k = tl.permute(tl.reshape(b_k, (BC, r, KR)), (1, 0, 2))  # (r, BC, KR)
+            b_dhtrans = tl.reshape(b_dh, (r, KR, BV))
+            dv_sum = tl.sum(b_k[:, :, :, None] * b_dhtrans.to(b_k.dtype)[:, None, :, :], -2)  # (r, BC, BV)
+            b_dv += tl.reshape(tl.permute(dv_sum, (1, 0, 2)), (BC*r, BV))
+            # layout (b, h, t, r, v)
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V, 1),
+                                      (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT, r, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        # [BV, BK]
+        b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))
+        # [BK, BV]
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))
+        # [BT, BT]
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        # [BT, BK]
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # b_do is the full gradient; b_h carries the i_k slice
+        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)  # rows are independent, so this is fine for dk
+        b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)), (BT, r, BV))  # (BT*r, BV) reshaped
+        b_dw += tl.sum(b_dv.to(b_v.dtype)[:, :, :, None] * b_h.to(b_v.dtype)[None, None, :, :], -2)  # (BT, r, BK)
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)  # [BT, BT]
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False))  # these parts should be fine
+
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K, K, 1), (i_t * BT, 0, i_r*K//r + i_k * BK), (BT, r, BK), (2, 1, 0))
+    # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K, K, 1), (i_t * BT, i_r, i_k * BK), (BT, 1, BK), (2, 1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, (tl.reshape(-b_dw.to(p_dw.dtype.element_ty), (BT, r, BK))), boundary_check=(0, 1))
+
+# finished
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT // T
+    BK = triton.next_power_of_2(K)  # a single block covers the whole head dimension
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new is laid out with the r dimension first
+    chunk_delta_rule_fwd_kernel_h[grid](  # no Python-side loop over r; the kernel iterates internally
+        k, u, w, v_new, h, initial_state, final_state,
+        k.stride(1), k.stride(2), k.stride(3),
+        u.stride(1), u.stride(2), u.stride(3),  # r*T*V, V, 1
+        h.stride(1), h.stride(2),
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+# finished
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # more parallelism could probably be exposed here
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # same layout as h; the r blocks must be summed, so they are computed together
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v-> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout: (B, H, T*r, V)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        q.stride(1), q.stride(2), q.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        dh.stride(1), dh.stride(2),
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K//r,
+    )
+    return dh, dv2
+
+# finished
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # per-replica outputs: (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape (B, H, NT*K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        h.stride(1), h.stride(2),
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # the last backward helper: computes dw, dq and dk
+    BK = min(triton.next_power_of_2(K//r), 64)  # finer-grained partitioning is needed so that different r positions never land in the same block
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K//r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK controls the position within the head dimension
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # gradient of the original k
+    dw = torch.empty_like(w)  # (B, H, r*T, K)
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        q.stride(1), q.stride(2), q.stride(3),
+        T*V, V, 1,
+        dh.stride(1), dh.stride(2),
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # the forward pass is complete
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the A matrix and everything derived from it
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+        end = time.time()
+        print('compute_A:', end-start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+        end = time.time()
+        print('compute_h_s:', end-start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        end = time.time()
+        print('compute_o:', end-start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+        end = time.time()
+        print('recompute_wu:', end-start)
+        # checkpoint_level=1: recompute h and v_new instead of storing them.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # from q, k and do; this dv therefore has the same layout as w
+        end = time.time()
+        print('pre:', end-start)
+        # dv: (B, H, r, T, V)
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # produces dh and the dv that finally feeds the WY-representation backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end-start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end-start)
+
+        start = time.time()
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)  # this step shows a larger numerical error
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end-start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # used to mask the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk ** -0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d->b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d)->b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the values saved here match the kernel
+        kui = torch.einsum('b h t r k,b h t r v-> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v-> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i+1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d)->b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d)->b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask, giving (b, h, n, r, t, l)
+    o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # (b, h, n, t, v)
+    o2 = torch.einsum('b h n t q,b h n q v-> b h n t v', q, state)  # inspect the state term on its own to check its correctness
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        # kkt = torch.einsum('b h d,b h v->b h d v', _k*beta_i, _k)
+        kkt = torch.einsum('b h d,b h v->b h d v', _k, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k,b h k v->b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    # from einops import rearrange
+    # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    # seq_len = 128
+    # b = 2
+    # h = 2
+    # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d=128
+    # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d=128
+    # v = torch.randn(b, h, seq_len, 128)
+    # beta = torch.rand(b, h, seq_len).sigmoid()
+    # require_grad = True
+    # BT = 16
+    # r = 4
+    # scale = 128**-0.5
+    # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+    # w = torch.nn.functional.normalize(torch.randn(b, h, seq_len*r, 128).cuda())
+    # u = torch.nn.functional.normalize(torch.randn(b, h, seq_len*r, 128).cuda())  # (b, h, n*t*r, d)
+    # initial_state = torch.randn(b, h, 128, 128).cuda().contiguous()
+    # k, v, q, beta, w, u = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, q, beta, w, u))
+
+    # final_state = None
+    # if False:
+    #     final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+    #                               dtype=torch.float32, requires_grad=False)  # this part needs no correction
+
+    # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+    # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)
+    # o2 = rearrange(o2, 'b h (n t) v-> b h n t v', n=seq_len//BT)
+    # do = torch.rand_like(o2)
+    # do_naive = do
+    # do = rearrange(do, 'b h n t v-> b h (n t) v')
+    # dv0 = fwd_prepare_dv(q, k, do, r, BT)  # from q, k and do; this dv therefore has the same layout as w
+    # # (b, h, r, t, v)
+    # # identical results up to this point
+    # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)  # new dv and dh
+    # ### identical results up to here as well
+    # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)  # needs dh and dv
+
+    # # (b, h, t, r, v)
+    # # dv0 = rearrange(dv0, 'b h r (n t) v-> b h n t r v', n=seq_len//BT)
+    # # dv = rearrange(dv, 'b h (n t) r v->b h n t r v', n=seq_len//BT)  # the two should agree along the BT=1 dimension; yes
+    # # dh = rearrange(dh, 'b h (n k) v->b h n k v', n=seq_len//BT)
+    # # it should hold that
+    # # NT = seq_len//BT
+    # # state = torch.zeros(b, h, NT, 128, 128, device=q.device, dtype=q.dtype)
+    # # v_new = torch.zeros_like(u)
+    # # q_na, k_na, w_na, u_na, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    # # wr = rearrange(w_na, 'b h n (t r) k->b h n t r k', r=r)
+    # # dh0 = -torch.einsum('b h t r v,b h t r k-> b h k v ', dv[:,:,1,:,:,:], wr[:,:,1,:,:,:])
+    # # dh0 += torch.einsum('b h t v,b h t q->b h q v', do_naive[:,:,1,:,:], q_na[:,:,1,:,:])*scale
+    # # k_r = rearrange(k_na, 'b h n t (r d)->b h n t r d', r=r)
+    # # dh0 = rearrange(dh0, 'b h (r k) v->b h r k v', r=r)  # need (b, h, t, r, v)
+    # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v,b h t r k->b h t r v', dh0, k_r[:,:,0,:,:,:])
+    # # dh0 = rearrange(dh0, 'b h r k v->b h (r k) v')
+    # # # so the computation flow is correct
+    # # print((dv[:,:,0,:,:,:]-dvs).abs().max())
+    # # print((dh0-dh[:,:,0,:,:]).abs().max())
+
+    # ##### here is the naive version
+    # NT = seq_len//BT
+    # state = torch.zeros(b, h, NT, 128, 128, device=q.device, dtype=q.dtype)
+    # v_new = torch.zeros_like(u)
+    # from einops import rearrange
+    # q_na, k_na, w_na, u_na, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    # # u_na = u_na.detach().requires_grad_(True)  # (b, h, n, t*r, d)
+    # v_new = rearrange(v_new, 'b h n (t r) d->b h n t r d', r=r)
+    # if initial_state is not None:
+    #     state[:,:,0,:,:] = initial_state
+    # else:
+    #     state[:,:,0,:,:] = 0
+    # for i in range(NT):
+    #     ki = rearrange(k_na[:,:,i,:,:], 'b h t (r d)->b h t r d', r=r)
+    #     ui = rearrange(u_na[:,:,i,:,:], 'b h (t r) d->b h t r d', r=r)
+    #     wi = rearrange(w_na[:,:,i,:,:], 'b h (t r) d->b h t r d', r=r)
+    #     v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v', wi, state[:,:,i,:,:].clone())
+    #     v_new[:,:,i,:,:,:] = v_newi.clone()  # the values saved here match the kernel
+    #     kui = torch.einsum('b h t r k,b h t r v-> b h r k v', ki, v_newi)
+    #     if i+1 < seq_len//BT:
+    #         state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui, 'b h r k v-> b h (r k) v')
+    # q_r = rearrange(q_na, 'b h n t (r d)->b h n t r d', r=r)*scale
+    # k_r = rearrange(k_na, 'b h n t (r d)->b h n t r d', r=r)
+    # s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l', q_r, k_r)
+    # s_r = torch.tril(s_r, diagonal=0)  # causal mask, giving (b, h, n, r, t, l)
+    # v_newnew = v_new  # .detach().requires_grad_(True)
+    # os = torch.einsum('b h n r t l, b h n l r v-> b h n t r v', s_r, v_newnew)
+    # os = os.sum(dim=-2)  # (b, h, n, t, v)
+    # oss = torch.einsum('b h n t q,b h n q v-> b h n t v', q_na, state)*scale  # inspect the state term on its own to check its correctness
+    # o_naive = os + oss
+    # # o_naive = rearrange(o_naive, 'b h n t v->b h (n t) v')
+    # o_naive.backward(do_naive, retain_graph=True)
+    # una_grad = u.grad
+    # w_grad = w.grad
+    # k_grad = k.grad
+    # q_grad = q.grad
+    # print((una_grad-dv).abs().max())  # essentially equal
+    # print((k_grad-dk).abs().max())
+    # print((w_grad-dw).abs().max())
+    # print((q_grad-dq).abs().max())
+    # print(k_grad)
+    # print(dk)
+
+    B = 2
+    H = 1
+    L = 128
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+    mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end-start)
+
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    # end = time.time()
+    # print(end-start)
+    print((o1-o).abs().max())
+    print((q_grad-q_grad0).abs().max())
+    print((k_grad-k_grad0).abs().max())  # large discrepancy here, up to ~1
+    print((v_grad-v_grad0).abs().max())
+    print((beta_grad-beta_grad0).abs().max())
+    # print(beta_grad)
+    # print(beta_grad0)
+    print(k_grad)
+    print(k_grad0)
+
+
diff --git a/fla2/ops/mask_delta_rule/naive_rmbeta.py b/fla2/ops/mask_delta_rule/naive_rmbeta.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f29f3d3b93d378128a4dc0d3e8aba87ab67756
--- /dev/null
+++ b/fla2/ops/mask_delta_rule/naive_rmbeta.py
@@ -0,0 +1,1377 @@
+import pdb
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # will become the (r*BT) x (r*BT) block matrix
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column-wise read, so these are the row indices
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # row index
+
+        for i_k in range(tl.cdiv(dk, BK)):  # read and accumulate k block by block
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    # save this first and inspect it
+
+    for i in range(1, BT):  # at this point the matrix is (BT, r, BT, r)
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # row i: (BT, r, r)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # a matrix product, giving (BT, BT, r, r)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # (BT, r, r)
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # forward substitution row by row, swapping in the results step by step
+    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT*r, BT*r))  # (BT*r, BT*r)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))  # the old implementation needed many more multiplications
+    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # this settles the matrix inversion
+    b_A = b_A.to(k.dtype.element_ty)  # inversion done; now compute the outputs
+
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # no mask is involved here, so no extra loop over r is needed
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        # r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # no mask is involved here, so no extra loop over r is needed
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # this is where the result differed
+        b_dv = (sum_dv * b_beta[:, None])  # which step produced the mismatch?
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K//r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)  # the i_r-th column
+        rmask = tl.arange(0, r) == i_r  # the i_r-th column
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h, s_qk_t, s_qk_d,
+    s_vo_h, s_vo_t, s_vo_d,
+    T, K, V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # dev note: this step differed in debugging
+        b_dv = (sum_dv * b_beta[:, None])
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K//r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)  # the i_r-th column
+        rmask = tl.arange(0, r) == i_r  # the i_r-th column
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # (BT, r, BK)
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+
+            b_ss = b_dk_beta * b_beta[:, None, None] * b_k[:, None, :]
+            b_ss = tl.reshape(tl.permute(b_ss, (2, 0, 1)), (BT*BK, r))
+            b_ss = tl.sum(b_ss, 0)
+            # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1))
+            b_dmask += (b_ss[:, None] * rmask[None, :]).to(tl.float32)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to dA of the K K^T term; mostly zeros off the strict block-lower part
+
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r))
+    # (BT, r, BT, r)
+
+    for i_r in range(r):  # take only the i_r-th term
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read the i_r-th column of the mask
+        b_mask = tl.load(p_mask)  # the i_r-th column
+        rmask = tl.arange(0, r) == i_r  # the i_r-th column
+        g = tl.sum(tl.where(rmask[None, None, None, :], b_dA, 0), -1)  # (BT, r, BT); extract the i_r-th column
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # (BT, BT)
+        # the corresponding C part
+
+        for i_k in range(tl.cdiv(block_k, BK)):  # i_k == 1 in practice
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)  # (BT, BK)
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+            beta_kkt = (tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False))  # (BT, BT)
+
+            beta_y = (beta_kkt[:, None, :] * g)
+            beta_y = tl.reshape(tl.permute(beta_y, (2, 0, 1)), (BT*BT, r))
+            betas = tl.sum(beta_y, 0)
+            b_dmask += (betas[:, None] * rmask[None, :]).to(tl.float32)
+
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+    p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t) * r * r, (r, r), (r, 1), (0, 0), (r, r), (1, 0))
+    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0, 1))
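+
+# The backward above relies on the matrix-inverse identity d(A^{-1}) = -A^{-1} dA A^{-1},
+# i.e. for Ai = inv(A) the gradient is grad_A = -Ai^T @ grad_Ai @ Ai^T; the kernel then
+# masks it back onto the strictly block-lower-triangular support. A tiny self-check
+# (our own helper, assuming torch is imported at the top of this module):
+def _check_inverse_grad(n=8):
+    L = torch.randn(n, n, dtype=torch.float64).tril(-1).requires_grad_(True)
+    Ai = torch.linalg.inv(torch.eye(n, dtype=torch.float64) + L)
+    g = torch.randn_like(Ai)
+    (Ai * g).sum().backward()
+    Ai_d = Ai.detach()
+    manual = -(Ai_d.T @ g @ Ai_d.T)
+    return (L.grad - manual).abs().max()  # should be ~0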
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "r"],
+)
+@triton.jit
+def chunk_scaled_dot_kkt_fwd_kernel(
+    k, beta, mask_ij, A,
+    s_qk_h, s_qk_t, s_qk_d,
+    T, K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically (r*BT, r*BT)
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column read, hence strided by the row count
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # (r, r) outer pattern for this column
+
+        for i_k in range(tl.cdiv(dk, BK)):  # load and accumulate K in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh * (T//BT) + i_t)*BT*BT*r*r, (BT, BT, r, r), (BT*r*r, r*r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A, Ad,
+    s_A_bh, s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T, BT, r, r), (BT*r*r, r*r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = (tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2))
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # process row by row, swapping each finished row back in
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16*r, 16*r))  # (16*r, 16*r)
+    p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 16 * r, 0), (16*r, 16*r), (1, 0))
+    tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A, Ad, Ai,
+    s_A_bh, s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    # older 4D-layout variant, kept for reference:
+    # p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0))
+    # b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32)
+    # b_A21 = tl.permute(b_A21,(0,2,1,3))
+    # b_A21 = tl.reshape(b_A21,(16*r,16*r))
+
+    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 32 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), (i_t * 32 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    # block inverse of [[A11, 0], [A21, A22]]: Ai21 = -Ai22 @ A21 @ Ai11
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
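+
+# Reference for the 2x2 block merge above (our own check, not part of the kernels):
+# for a block lower-triangular [[A11, 0], [A21, A22]] with Ai11 = inv(A11) and
+# Ai22 = inv(A22) already known, the inverse is [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]].
+def _ref_merge_2x2(A11, A21, A22):
+    Ai11 = torch.linalg.inv(A11)
+    Ai22 = torch.linalg.inv(A22)
+    Ai21 = -Ai22 @ A21 @ Ai11
+    top = torch.cat([Ai11, torch.zeros_like(A11)], dim=-1)
+    bot = torch.cat([Ai21, Ai22], dim=-1)
+    return torch.cat([top, bot], dim=-2)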
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_64x64_inverse_kernel(
+    A, Ad, Ai,
+    s_A_bh, s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+    p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 32*r), (16*r, 16*r), (1, 0))
+
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+    b_A31 = tl.load(p_A31, boundary_check=(0, 1)).to(tl.float32)
+    b_A32 = tl.load(p_A32, boundary_check=(0, 1)).to(tl.float32)
+    b_A41 = tl.load(p_A41, boundary_check=(0, 1)).to(tl.float32)
+    b_A42 = tl.load(p_A42, boundary_check=(0, 1)).to(tl.float32)
+    b_A43 = tl.load(p_A43, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 64 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai33 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 32*r), (16*r, 16*r), (1, 0))
+    p_Ai44 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 48*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai31 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai32 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai41 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai42 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai43 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 32*r), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32)
+    Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32)
+
+    # first sub-diagonal: Ai_{i+1,i} = -Ai_{i+1,i+1} @ A_{i+1,i} @ Ai_{i,i}
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    Ai32 = -tl.dot(tl.dot(Ai33, b_A32, input_precision='ieee'), Ai22, input_precision='ieee')
+    Ai43 = -tl.dot(tl.dot(Ai44, b_A43, input_precision='ieee'), Ai33, input_precision='ieee')
+
+    Ai31 = -tl.dot(
+        Ai33,
+        tl.dot(b_A31, Ai11, input_precision='ieee') +
+        tl.dot(b_A32, Ai21, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai42 = -tl.dot(
+        Ai44,
+        tl.dot(b_A42, Ai22, input_precision='ieee') +
+        tl.dot(b_A43, Ai32, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai41 = -tl.dot(
+        Ai44,
+        tl.dot(b_A41, Ai11, input_precision='ieee') +
+        tl.dot(b_A42, Ai21, input_precision='ieee') +
+        tl.dot(b_A43, Ai31, input_precision='ieee'),
+        input_precision='ieee'
+    )
+
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    A = torch.empty(B*H*NT, BT*BT, r*r, device=k.device, dtype=output_dtype).contiguous()
+    chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)](
+        k, beta, mask, A,
+        T*K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT*16*r, 16*r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B*H)](
+        A, Ad,
+        T*BT*r*r,  # s_A_bh
+        T*16*r*r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    A = rearrange(A, 'b (t l) (c r)->b (t c) (l r)', t=BT, c=r).contiguous()  # (BT*r, BT*r) blocks
+    if BT == 32:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_32x32_inverse_kernel[(NT, B*H)](
+            A, Ad, Ai,
+            T*BT*r*r,  # s_A_bh and s_Ai_bh
+            T*16*r*r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+    if BT == 64:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_64x64_inverse_kernel[(NT, B*H)](
+            A, Ad, Ai,
+            T*BT*r*r,  # s_A_bh and s_Ai_bh
+            T*16*r*r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+# def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+#     B, H, T, K, V = *k.shape, v.shape[-1]
+#     r = mask.shape[-1]
+#     u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+#     w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+#     NT = triton.cdiv(T, BT)
+#     BK = min(triton.next_power_of_2(K//r), 64)
+#     BV = min(triton.next_power_of_2(V), 64)
+#     A = torch.empty(B, H, NT*BT*r, BT*r, device=k.device, dtype=k.dtype)
+#     fwd_prepare_wy_repr_kernel[(NT, B*H)](
+#         k, v, beta, mask, w, u, A,
+#         T*K, K, 1,
+#         T*V, V, 1,
+#         T, K, V, r, BT, BK, BV
+#     )
+#     return w, u, A
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
+    A = solve_tril(A=A, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+    w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)  # typically 32
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dmask = torch.zeros([B*H*NT, r, r], device=k.device, dtype=k.dtype).contiguous()
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta, dmask,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q, k, do, dv,
+    s_qk_h, s_qk_t, s_qk_d,
+    s_vo_h, s_vo_t, s_vo_d,
+    T, K, V, scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # could perhaps also parallelize over r
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K//r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+def fwd_prepare_dv(q, k, do, r, BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B, H, r, T, V, device=do.device, dtype=do.dtype)  # cannot use empty_like: the extra r dim
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B*H*r)](
+        q, k, do, dv,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
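+
+# A minimal sketch of the per-chunk recurrence the kernel below implements, written for
+# the r = 1 special case (the kernel additionally splits K into r groups and keeps one
+# v_new stream per group); q_c is assumed to be pre-scaled by K**-0.5:
+def _ref_delta_chunk_step(q_c, k_c, w_c, u_c, S):
+    v_new_c = u_c - w_c @ S                                # remove what the state already carries
+    attn = torch.tril(q_c @ k_c.transpose(-1, -2))         # intra-chunk causal scores
+    o_c = q_c @ S + attn @ v_new_c                         # inter-chunk + intra-chunk parts
+    return o_c, S + k_c.transpose(-1, -2) @ v_new_c        # absorb the chunk into the state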
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # u
+    d,  # w
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one horizontal strip of the state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        # storing here, before the update, is correct: h holds the state at the chunk start
+        b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # BT is large, so tile it over BC
+                r_mask = tl.arange(0, r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC, BK//r), (1, 0))  # load the matching K slice
+                p_d = tl.make_block_ptr((d + i_bh * T * r * K), (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC, 1, BK), (2, 1, 0))
+                p_v = tl.make_block_ptr((v + i_bh * T * r * V), (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC, 1, BV), (2, 1, 0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r) * T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, BK//r)
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))  # (BC, 1, BK)
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))
+                b_v = tl.reshape(b_v, (BC, BV))
+                b_d = tl.reshape(b_d, (BC, BK))
+                b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)  # subtract the part already carried by the state
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+                bkv = tl.where(r_mask[:, None, None], tl.dot(tl.trans(b_k), b_v.to(b_k.dtype), allow_tf32=False)[None, :, :], 0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum, (BK, BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
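+
+# Per r-block output reference (our own sketch): each r-block forms causal scores from
+# its own K-slice and its own v_new stream, plus a q-slice @ state term; the wrapper
+# `chunk_fwd_o_fn` further below then sums the per-block outputs with `o.sum(dim=2)`.
+def _ref_chunk_o(q_c, k_c, v_new_c, S, r):
+    # q_c, k_c: (BT, K); v_new_c: (r, BT, V); S: (K, V); K is split into r groups
+    BT, K = q_c.shape
+    dk = K // r
+    o = 0
+    for i_r in range(r):
+        qs = q_c[:, i_r*dk:(i_r+1)*dk]
+        ks = k_c[:, i_r*dk:(i_r+1)*dk]
+        o = o + torch.tril(qs @ ks.T) @ v_new_c[i_r] + qs @ S[i_r*dk:(i_r+1)*dk]
+    return o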
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q, k, v, h, o,
+    s_qk_h, s_qk_t, s_qk_d,
+    s_vo_h, s_vo_t, s_vo_d,
+    s_h_h, s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    rk = K//r
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K//r, BK)):  # note the split: K//BK == r here
+        # open question from development: different r-blocks read the same q/k slice; does that matter?
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)  # zero out the non-causal part of b_s
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q, k, d, do, dh, dv, dv2,
+    s_qk_h, s_qk_t, s_qk_d,
+    s_h_h,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # full state-gradient strip; stays resident
+    for i_t in range(NT - 1, -1, -1):  # stored one chunk ahead of the update; the order is correct
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        # all columns
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))  # full read
+            p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K, T*r), (1, K),
+                                    (i_k * BK, i_t * BT * r + i_c * BC * r), (BK, BC * r), (0, 1))
+            p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_q = (tl.load(p_q, boundary_check=(0, 1)))
+            b_q = (b_q * scale).to(b_q.dtype)
+
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_d = (tl.load(p_d, boundary_check=(0, 1)))
+            p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V, 1),
+                                     (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))  # load all r streams
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))  # (BT*r, BV)
+            b_dhtrans = tl.reshape(b_dh, (r, KR, BV))
+            for i_r in range(r):
+                rmask = tl.arange(0, r) == i_r  # the i_r-th stream
+                p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                        (i_t * BT + i_c * BC, i_r*KR + i_k * BK), (BC, KR), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, KR)
+                b_dhr = tl.sum(tl.where(rmask[:, None, None], b_dhtrans, 0), 0)  # (KR, BV)
+                dv_sum = tl.dot(b_k, b_dhr.to(b_k.dtype), allow_tf32=False)  # (BC, BV)
+                b_dv += tl.reshape((dv_sum[:, None, :] * rmask[None, :, None]).to(b_dv.dtype), (BC*r, BV))
+
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V, 1),
+                                      (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
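+
+# The reverse-time recurrence the kernel above runs, as a per-chunk sketch (r = 1 case;
+# q_c pre-scaled by K**-0.5, dS is the state gradient flowing in from the next chunk):
+def _ref_bwd_dh_step(q_c, k_c, w_c, do_c, dv_c, dS):
+    dv_new = dv_c + k_c @ dS                                            # what dv2 stores
+    dS_prev = dS + q_c.transpose(-1, -2) @ do_c - w_c.transpose(-1, -2) @ dv_new
+    return dS_prev, dv_new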
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q, k, v, w, h, do, dh,
+    dq, dk, dv, dw,
+    s_qk_h, s_qk_t, s_qk_d,
+    s_vo_h, s_vo_t, s_vo_d,
+    s_h_h, s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT*r, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_h = (tl.load(p_h, boundary_check=(0, 1)))  # (BV, BK)
+        b_dh = (tl.load(p_dh, boundary_check=(0, 1)))
+
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # b_do is full; b_h covers this i_k slice
+        b_dk += tl.dot(b_v, b_dh, allow_tf32=False)  # used for dk; rows are independent, so this is fine
+        b_dv = (tl.load(p_dv, boundary_check=(0, 1)))  # (BT*r, BV)
+        b_dw += (tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)))  # get (BT*r, BK)
+
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)  # (BT, BT)
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))  # these parts should be fine
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K, 1), (i_t * BT * r, i_r*K//r + i_k * BK), (BT*r, BK), (1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, (-b_dw.to(p_dw.dtype.element_ty)), boundary_check=(0, 1))
+
+#finish
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT//T
+    BK = triton.next_power_of_2(K)  # the whole K dim fits in one block
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new laid out with r first
+    chunk_delta_rule_fwd_kernel_h[grid](  # no python-side loop over r
+        k, u, w, v_new, h, initial_state, final_state,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+#finish
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # more parallelism could be exposed here
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # same layout as h; the sum has to be computed together
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v-> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same shape: (B, H, T*r, V)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        T*K, K, 1,
+        NT*K*V,
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K//r,
+    )
+    return dh, dv2
+
+#finish
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # per r-block outputs: (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h shape: (B, H, NT*K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        T*K, K, 1,
+        r*T*V, T*V, V,
+        NT*K*V, V,
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last kernel: computes dw, dq and dk
+    # a finer split over K is needed so that positions from different r-groups are never tiled together
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K//r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK controls the position within an r-group
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # original k layout
+    dw = torch.empty_like(w)  # (B, H, r*T, K)
+
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        T*K, K, 1,
+        T*V, V, 1,
+        NT*K*V, V,
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
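+
+# A hedged usage sketch of the pieces above, wired the same way ChunkDeltaRuleFunction
+# does below (assumes a CUDA device, bfloat16 inputs, and that BT divides T; the identity
+# mask is just the simplest valid (r, r) pattern, while the __main__ test uses a banded one):
+def _example_pipeline(B=1, H=2, T=64, K=64, V=64, r=4, BT=32):
+    dt = torch.bfloat16
+    q = torch.randn(B, H, T, K, device='cuda', dtype=dt)
+    k = torch.nn.functional.normalize(torch.randn(B, H, T, K, device='cuda', dtype=dt), p=2, dim=-1)
+    v = torch.randn(B, H, T, V, device='cuda', dtype=dt)
+    beta = torch.rand(B, H, T, device='cuda', dtype=dt).sigmoid()
+    mask = torch.eye(r, dtype=torch.long, device='cuda')
+    o, _ = mask_chunk_delta_rule(q, k, v, beta, mask, BT=BT)
+    return o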
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # computes the A matrix and everything derived from it
+        r = mask.shape[-1]
+        # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # skipped: recomputed in backward
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+
+        if checkpoint_level == 1:
+            h, v_new = None, None  # dropped here and recomputed in backward
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # w, u were not saved; recompute them
+        # checkpoint_level=1: recomputation.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # from q, k, do; this dv therefore has w's (r-expanded) shape
+        # dv: (B, H, r, T, V)
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # new dv and dh; the final dv for the WY-repr backward
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        dk.add_(dk2)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # the (r, r) pattern masking the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len//BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q*(dk**-0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d->b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d)->b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the stored result matches the kernel here
+        kui = torch.einsum('b h t r k,b h t r v-> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v-> b h (r k) v')
+        if i+1 < seq_len//BT:
+            state[:, :, i+1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d)->b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d)->b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask, giving (b, h, n, r, t, l)
+    o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # (b, h, n, t, v)
+    o2 = torch.einsum('b h n t q,b h n q v-> b h n t v', q, state)  # checks whether the state part alone is correct
+    o = o1 + o2
+    return state, v_new, o, g
+
+def delta_rule_recurrence(q, k, v, beta, mask, initial_state=None):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    if initial_state is None:
+        S = torch.zeros(b, h, d_k, d_v).to(v).float()
+    else:
+        S = initial_state
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d,b h v->b h d v', _k*beta_i, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k,b h k v->b h q v', iplr.float(), S.clone()) + _k.unsqueeze(-1).float() * _v.unsqueeze(-2).float()
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q.float(), S).to(k.dtype)
+    return o
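+
+# The masked rank-r transition used by delta_rule_recurrence above, for one time step
+# (a sketch of the same einsum/rearrange chain, collapsed to a single head):
+#   S_t = (I - masked(beta_t * k_t k_t^T)) S_{t-1} + beta_t * k_t v_t^T
+def _ref_masked_kkt(kt, beta_t, mask):
+    # kt: (K,); beta_t: scalar; mask: (r, r); returns the masked beta*k k^T term
+    r = mask.shape[-1]
+    K = kt.shape[-1]
+    kkt = torch.outer(kt * beta_t, kt).reshape(r, K // r, r, K // r)
+    kkt = kkt * mask[:, None, :, None].to(kkt)
+    return kkt.reshape(K, K)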
+
+if __name__ == "__main__":
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+
+    B = 2
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    mask_grad, mask.grad = mask.grad, None
+    end = time.time()
+    print(end-start)
+
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)  # ~10s, hmm
+    o.backward(do, retain_graph=True)
+    print((o-o1).abs().max())
+
+    print(o)
+    print(o1)
+    # q_grad0, q.grad = q.grad, None
+    # k_grad0, k.grad = k.grad, None
+    # v_grad0, v.grad = v.grad, None
+    # beta_grad0, beta.grad = beta.grad, None
+    # mask_grad0, mask.grad = mask.grad, None
+    # print((q_grad-q_grad0).abs().max())
+    # print((k_grad-k_grad0).abs().max())  # large mismatch here, up to ~1
+    # print((v_grad-v_grad0).abs().max())
+    # print((beta_grad-beta_grad0).abs().max())
+    # print((mask_grad-mask_grad0).abs().max())
diff --git a/fla2/ops/mask_delta_rule/recurrent_fuse.py b/fla2/ops/mask_delta_rule/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21470ff11d7e75df52b0c81dcb66bd40a44a0e5
--- /dev/null
+++ b/fla2/ops/mask_delta_rule/recurrent_fuse.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import contiguous
+
+# on-the-fly computation without materializing hidden states in HBM
+
+
+@triton.jit
+def fused_recurrent_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+ beta, # beta [B, H, L] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + B, # batch size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _v_minus = tl.sum(h * b_k[None, :], axis=1) + b_v -= _v_minus + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # in-place overwrite + tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv) + b_v *= b_beta + h += b_k[None, :] * b_v[:, None] + _o = h * b_q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + p_beta += V if IS_HEADWISE_BETA else 1 + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + beta, # beta [B, H, L, (V)] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + dbeta, # gradient of beta [NV, (NK), B, H, L] + + # initial hidden state initialization [B, H, K, V] + h0, + + s_qk_h, # stride size: L * K + + s_vo_h, # stride size: L * V + + NK, # NK block size + scale, # K ** -0.5 + + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + else: + p_beta = beta + i_bh * T + T - 1 + + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V + else: + p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1 + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v) + d_v = d_v * b_beta + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + if IS_HEADWISE_BETA: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv) + else: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty)) + + d_h -= b_k[:, None] * d_v[None, :] + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + p_dbeta -= V if IS_HEADWISE_BETA else 1 + p_beta -= V if IS_HEADWISE_BETA else 1 + + tl.debug_barrier() + + h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) 
+ V + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + + h += b_k[:, None] * b_v[None, :] + _d_q = h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + if i < T - 1: + d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32) + d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32) + d_k -= tl.sum(d_v[None, :] * h, axis=1) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dk += K + p_dv += V + p_dq += K + p_beta += V if IS_HEADWISE_BETA else 1 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @contiguous + @staticmethod + def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False): + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + assert NK == 1, "NK > 1 is not supported yet" + o = q.new_empty(NK, B, H, T, V) + + if output_final_state: + final_state = q.new_empty(B, H, K, V) + else: + final_state = None + + grid = (NV, NK, B * H) + fused_recurrent_fwd_kernel[grid]( + q, k, v, beta, o, initial_state, final_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + IS_HEADWISE_BETA=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, beta, initial_state) + ctx.scale = scale + return o, final_state + + @contiguous + @staticmethod + def backward(ctx, do, dht=None): + q, k, v, beta, initial_state = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + beta_vector = beta.ndim == v.ndim + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + if beta_vector: + dbeta = q.new_empty(NV, NK, B, H, T, V) + else: + dbeta = q.new_empty(NV, B, H, T) + grid = (NV, NK, B * H) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, beta, do, dq, dk, dv, dbeta, initial_state, + q.stride(1), + v.stride(1), + NK, scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + IS_HEADWISE_BETA=beta_vector, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0) + return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None + + +def mask_fused_recurrent_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + scale: float = -1, + 
initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state) + return o, final_state diff --git a/fla2/ops/mask_delta_rule/utils.py b/fla2/ops/mask_delta_rule/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..173d6629c628bb6b5860a005cbc8ea85d7cf9b5e --- /dev/null +++ b/fla2/ops/mask_delta_rule/utils.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from ...ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2 +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + o, + o2, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = tl.arange(0, BK) < K + mask_bv = tl.arange(0, BV) < V + mask_bk = mask_bk[None, :] & mask_bt[:, None] + mask_bv = mask_bv[None, :] & mask_bt[:, None] + # [BT, BK] + b_k = tl.load(p_k, mask=mask_bk, other=0) + # [BT,] + b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32) + # [BT, BV] + b_v = tl.load(p_v, mask=mask_bv, other=0) + b_v = (b_v * b_beta[:, None]).to(b_v.dtype) + # [BT, BK] + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + # [BT, BT] + b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + b_A = b_A.to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + b_u = tl.dot(b_A, b_v, allow_tf32=False) + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk) + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + 
k, v, beta, + o, o2, do, do2, + dk, dv, dbeta, + NT, K, V, T, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None] + mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None] + b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt) + + b_beta = b_beta.to(tl.float32) + A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None] + A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0) + b_do = tl.load(p_do, mask=mask_bk).to(tl.float32) + b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32) + dA = tl.zeros([BT, BT], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i in range(BT-1, -1, -1): + mask = tl.arange(0, BT) == i + attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0) + do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0) + dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0) + b_do = b_do - attn[:, None] * do_[None, :] + b_dv = b_dv - attn[:, None] * dv_[None, :] + tl.debug_barrier() + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_v = tl.load(p_v, mask=mask_bv) + b_dk += b_do * b_beta[:, None] + b_dbeta = tl.sum(b_do * b_k, axis=1) + b_dbeta += tl.sum(b_dv * b_v, axis=1) + b_v = None + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_o = tl.load(p_o, mask=mask_bk) + b_o2 = tl.load(p_o2, mask=mask_bv) + + dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False) + dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype), + allow_tf32=False) + dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0) + b_dv *= b_beta[:, None] + p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + + b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1) + dA = dA * b_beta[:, None] + b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False) + b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False) + p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt) + +
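# Editor's note (added, hedged reading of the kernel above): with T = (I + L)^{-1} + # and L the strictly lower part of diag(beta) K K^T, the reverse-order loop appears + # to apply T^T to the incoming do/dv in place (back-substitution through the unit + # lower-triangular factor); dA is then the strictly lower part of + # -(T^T do) o^T - (T^T dv) o2^T, as computed below the loop. +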
def fwd_prepare_wy_repr(k, v, beta, chunk_size): + B, H, T, K, V = *k.shape, v.shape[-1] + v_new = torch.empty_like(v) + o_cumdecay = torch.empty_like(k) + BT = chunk_size + NT = triton.cdiv(T, BT) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, o_cumdecay, v_new, + T, K, V, BT, BK, BV + ) + return o_cumdecay, v_new + + +def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size): + b, h, l, d_k = do.shape + d_v = v.shape[-1] + BK = triton.next_power_of_2(d_k) + BV = triton.next_power_of_2(d_v) + c = chunk_size + NT = triton.cdiv(l, c) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.zeros_like(beta) + bwd_prepare_wy_repr_kernel[(NT, b*h)]( + k, v, beta, + o_cumdecay, v_new, do, do2, + dk, dv, dbeta, + NT, d_k, d_v, l, chunk_size, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, k, v, beta, chunk_size): + o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + ctx.chunk_size = chunk_size + ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new) + return o_cumdecay, v_new + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, do2): + k, v, beta, o_cumdecay, v_new = ctx.saved_tensors + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size) + return dk, dv, dbeta, None + + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + seq_len = 2048 + b = 4 + h = 8 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 256) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + print("Start warmup.") + o1, o2 = prepare_wy_repr(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + o3, o4 = prepare_wy_repr2(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + print((o1 - o3).abs().max()) + print((o2 - o4).abs().max()) + + for i in range(30): + o1, o2 = prepare_wy_repr(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + o1, o2 = prepare_wy_repr2(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + + print("Done warmup.") + + import time + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) + + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr2(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) +
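# A minimal non-Triton reference for the forward pass above (an editor's + # sketch, not part of the original kernels; `wy_repr_reference` is a + # hypothetical name). It forms T = (I + strict_tril(diag(beta) K K^T))^{-1} + # per chunk with a triangular solve, then w = T (beta*k), u = T (beta*v). +def wy_repr_reference(k, v, beta, chunk_size): + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + kb = k * beta[..., None] + # strictly lower part of (beta*k) k^T, as built in fwd_prepare_wy_repr_kernel + A = torch.tril(kb @ k.transpose(-1, -2), diagonal=-1).float() + eye = torch.eye(chunk_size, dtype=A.dtype, device=A.device).expand_as(A).contiguous() + T = torch.linalg.solve_triangular(eye + A, eye, upper=False) + w = (T @ kb.float()).to(k.dtype) + u = (T @ (v * beta[..., None]).float()).to(v.dtype) + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), (w, u))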
diff --git a/fla2/ops/mask_delta_rule/wy_fast.py b/fla2/ops/mask_delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c3538e86c9273b88c04fb719c0a543c4bd0ea6 --- /dev/null +++ b/fla2/ops/mask_delta_rule/wy_fast.py @@ -0,0 +1,784 @@ +# -*- coding: utf-8 -*- +import torch +import triton +import triton.language as tl +from einops import rearrange +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)  # conceptually an (r*BT, r*BT) matrix + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r  # read column i_r of the row-major (r, r) mask; the arange walks over rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]  # scatter that column back into an (r, r) mask + + for i_k in range(tl.cdiv(dk, BK)):  # load and process the K dimension in blocks + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + # (dev note) save this first to inspect + + for i in range(1, BT):  # the matrix here is laid out as (BT, r, BT, r) + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)  # get b_a: (BT, r, r) + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)  # realised as a matmul; gives (BT, BT, r, r) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])  # (BT, r, r) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)  # processed row by row, swapping results in step by step + b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))  # (BT*r, BT*r) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))  # the old implementation needed many more multiplies + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + # matrix inversion done + b_A = b_A.to(k.dtype.element_ty)  # inversion solved; next, compute the outputs + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r  # read column i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d) + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb,
allow_tf32=False)  # get (BT*r, BK) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not used for v; no inner mask is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + # r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)*r+i_r  # read column i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d) + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # get (BT*r, BK) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not used for v; no inner mask is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +# compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk, dv, dbeta,dmask, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh =
tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_dmask = tl.zeros([r,r],dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)):  # process V in blocks + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)  # (BT, r, BV) + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r) + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV) + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV)) + sum_dv = tl.sum(b_dv_beta,-2)  # (dev note) this is where results differed + b_dv = (sum_dv * b_beta[:, None])  # (dev note) which step differs? + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r + i_r  # read column i_r + b_mask = tl.load(p_mask)  # column i_r + rmask = tl.arange(0, r) == i_r  # select column i_r + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)  # (BT, r, d) + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + b_dk = sum_dk* b_beta[:, None] + b_dbeta += tl.sum(sum_dk * b_k, 1) + + + b_ss = b_dk_beta * b_beta[:,None,None] * b_k[:,None,:] + b_ss = tl.reshape(tl.permute(b_ss,(2,0,1)),(BT*BK,r)) + b_ss = tl.sum(b_ss,0) + # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1)) + b_dmask += (b_ss[:,None]*rmask[None,:]).to(tl.float32) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to dA of K K^T; mostly zeros around the diagonal blocks + + + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + # (BT, r, BT, r) + + + for i_r in range(r):  # take only the i_r term + p_mask = mask_ij + tl.arange(0,r)*r+i_r  # read column i_r + b_mask = tl.load(p_mask)  # column i_r + rmask = tl.arange(0, r) == i_r  # select column i_r + 
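# (editor's comment) From here, column i_r of the blocked dA is contracted with + # the loaded mask column to form ir_A (BT x BT), which then feeds the dk, + # dbeta and dmask accumulations below. +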
g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)  # (BT, r, BT): extract column i_r + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)  # (BT, BT) + # the corresponding c term + + for i_k in range(tl.cdiv(block_k, BK)):  # i_k = 1 here + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)  # (BT, BK) + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))  # (BT, BT) + + beta_y = (beta_kkt[:,None,:]*g) + beta_y = tl.reshape(tl.permute(beta_y,(2,0,1)),(BT*BT,r)) + betas = tl.sum(beta_y,0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "r"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)  # conceptually an (r*BT, r*BT) matrix + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r  # read column i_r of the row-major (r, r) mask; the arange walks over rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]  # scatter that column back into an (r, r) mask + + for i_k in range(tl.cdiv(dk, BK)):  # load and process the K dimension in blocks + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A =
tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)  # processed row by row, swapping results in step by step + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))  # (16*r, 16*r) + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,32*r),(32*r,1) ,((i_t * 32 + 16) *r, 0), (16*r, 16*r), (1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + +
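# (editor's comment) merge_16x16_to_32x32_inverse_kernel assembles a 32x32 + # inverse from the two 16x16 diagonal-block inverses in Ad plus the + # off-diagonal block A21; the kernel below applies the same identity over + # a 4x4 grid of blocks to reach 64x64. +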
@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0)) + p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0)) + p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0)) + p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0)) + p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0)) + p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0)) + + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32) + b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32) + b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32) + b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32) + b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32) + + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0)) + p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0)) + p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0)) + p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0)) + p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0)) + + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32) + Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32) + + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + Ai32 = -tl.dot(tl.dot(Ai33,b_A32, input_precision='ieee'),Ai22,input_precision='ieee')  # (L^-1)_{32} = -Ai33 A32 Ai22 + Ai43 = -tl.dot(tl.dot(Ai44,b_A43, input_precision='ieee'),Ai33,input_precision='ieee')  # (L^-1)_{43} = -Ai44 A43 Ai33 + + Ai31 = -tl.dot( + Ai33, + tl.dot(b_A31,Ai11, input_precision='ieee')+ + tl.dot(b_A32,Ai21, input_precision='ieee'), + input_precision='ieee') + + Ai42 = -tl.dot( + Ai44, + tl.dot(b_A42,Ai22, input_precision='ieee')+ + tl.dot(b_A43,Ai32, input_precision='ieee'), + input_precision='ieee') + + Ai41 = -tl.dot( + Ai44, + tl.dot(b_A41, Ai11, input_precision='ieee') + + tl.dot(b_A42, Ai21, input_precision='ieee') + + tl.dot(b_A43, Ai31, input_precision='ieee'), + input_precision='ieee' + ) + + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai33,Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai44,Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai31,Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai32,Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai41,Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai42,Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai43,Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + +
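# Editor's sketch (hypothetical torch-only helper, not part of the kernels): + # the merge kernels realise the block lower-triangular inverse identity + #   [[A11, 0], [A21, A22]]^{-1} = [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]], + # applied pairwise across the 4x4 block grid above, which is why each + # off-diagonal term pairs the two neighbouring diagonal inverses. +def block_tril_inverse_2x2(A11, A21, A22): + Ai11 = torch.linalg.inv(A11) + Ai22 = torch.linalg.inv(A22) + return Ai11, -Ai22 @ A21 @ Ai11, Ai22 + +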
def chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, mask, A, + T*K, K, 1, + T, K, r, BT, BK + ) + return A + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,  # s_A_bh + T*16*r*r,  # s_Ad_bh + T, + r, BT + ) + if BT == 16: + return Ad + + A = rearrange(A,'b (t l) (c r)->b (t c) (l r)',t=BT,c=r).contiguous()  # (BT*r, BT*r) + if BT == 32: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,  # s_A_bh and s_Ai_bh + T*16*r*r,  # s_Ad_bh + T,r,BT + ) + return Ai + + if BT == 64: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_64x64_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,  # s_A_bh and s_Ai_bh + T*16*r*r,  # s_Ad_bh + T,r,BT + ) + return Ai + +# def fwd_prepare_wy_repr(k, v, beta,mask, BT): +# B, H, T, K, V = *k.shape, v.shape[-1] +# r = mask.shape[-1] +# u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) +# w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) +# NT = triton.cdiv(T, BT) +# BK = min(triton.next_power_of_2(K//r), 64) +# BV = min(triton.next_power_of_2(V), 64) +# A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=k.dtype) +# fwd_prepare_wy_repr_kernel[(NT, B*H)]( +# k, v, beta, mask, w, u, A, +# T*K, K, 1, +# T*V, V, 1, +# T, K, V, r, BT, BK, BV +# ) +# return w, u, A + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,torch.float32) + A = solve_tril(A=A,mask=mask,k = k ,BT=BT,output_dtype=k.dtype) + w, u = fwd_recompute_w_u(k, v, beta,mask, A, BT) + return w, u, A + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + +
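# (editor's comment) fwd_prepare_wy_repr above is a three-stage pipeline: + # chunk_scaled_dot_kkt_fwd builds the blocked, strictly-lower beta*K K^T in + # fp32; solve_tril inverts it per 16x16 tile and merges tiles up to BT; and + # fwd_recompute_w_u applies the inverse to beta*k and beta*v, yielding w and + # u with r*T rows per head. +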
def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + dmask = torch.zeros([B*H*NT,r,r],device=k.device,dtype=k.dtype).contiguous() + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A, + dw, du, + dk, dv, dbeta,dmask, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + dmask = dmask.sum(0) + return dk, dv, dbeta, dmask + + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, k, v, beta,mask,chunk_size=64): + ctx.BT = chunk_size + w, u, A = fwd_prepare_wy_repr(k, v,beta,mask, ctx.BT) + ctx.save_for_backward(k, v, beta,mask,A) + return w, u + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, dw, du): + k, v, beta,mask, A = ctx.saved_tensors + BT = ctx.BT + dk, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta,mask, A, dw, du, BT) + return dk, dv, dbeta, dmask, None + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta,maskij,chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + + b,h,nt,BT,dk = k.shape + dv = v.shape[-1] + r = maskij.shape[-1] + k_beta = k * beta[..., None] + k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r) + k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij) + k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')  # l = 1, (r k) = original K + v_beta = v * beta[..., None] + v_beta = v_beta + v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) + ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r) + + attn = (ki @ ki.transpose(-1, -2)) + attn = torch.tril(attn, diagonal=-1)  # (b, h, n, r, c, c) + attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)  # (b, h, n) with (r, r) and (c, c) blocks + attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta) + + o = torch.zeros_like(k_beta) + o2 = torch.zeros_like(v_beta) + + o[..., 0, :,:] = k_beta[..., 0,:,:].clone() + o2[..., 0,:, :] = v_beta[..., 0,:,:].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i,:,:]).clone()  # (b, h, n, :t, c, c) + o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:]) + o2_i = (o2[..., :i,:,:]).clone()  # one dimension fewer + o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:]) + return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2)) + + +if __name__ == "__main__": + # all compute here + import sys + sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + seq_len = 32 + b = 2 + h = 2 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128 + v = torch.randn(b, h, seq_len, 128) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + BT = 16 + k, v,
beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta)) + r = 4 + # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous() + mask = torch.randn([r,r]) + mask = mask.cuda().requires_grad_(require_grad).contiguous() + # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16) + # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16) + # from einops import rearrange + + k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = 16,r=r) + b2 = rearrange(beta,'b h (n t)-> b h n t',t = 16) + a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)  # (b, h, n, r, t, t) + qq = torch.tril(a1,diagonal=-1) + qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask) + sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)') + sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)  # this one + + + # the long block-diagonal identity + i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :])) + s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda() + s = rearrange(s,'b h n a d c r->b h n (a c) (d r)') + s = torch.linalg.inv(s.float()).to(k)  # matrix inverse: (b, h, n, t*r, t*r) + + + # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)  # (b*h, nt, BT, bt, r, r) + # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.float32) + # s = rearrange(s,'b h n a c->(b h) (n a) c') + # print(Ad) + # print(s) + # print((Ad-s).abs().max()) + + + w,u,As = fwd_prepare_wy_repr(k, v, beta,mask, 16) + As = rearrange(As,'b h (n t) l->(b h n) t l',t =BT*r) + # print((As-s).abs().max()) + # B*H*NT,BT*r,16*r + # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2) + # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask) + # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)') + # wc = s_copy@k_exp + + # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT) + # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2) + # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1) + # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v') + # uc = s_copy@v_exp + # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc)) + # do = torch.rand_like(wc) + # do2 = torch.rand_like(uc)  # (b, h, n, t, t) + # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)  # (dev note) this reference code has a problem + # do = torch.rand_like(o1) + # do2 = torch.rand_like(o2)  # (b, h, n, t, t) + # if require_grad: + # o1.backward(do, retain_graph=True) + # o2.backward(do2, retain_graph=True) + # k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad + + # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16) + # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT) + + # print((o1-w0).abs().max()) + # print((o2-u0).abs().max()) + # print((k_grad-k_grad2).abs().max()) + # print((v_grad-v_grad2).abs().max()) + # print((beta_grad-beta_grad2).abs().max()) + # print((mask_grad-mask_grad2).abs().max()) + # print(mask_grad) + # print(mask_grad2) + + diff --git a/fla2/ops/mask_delta_rule/wy_fast_non.py b/fla2/ops/mask_delta_rule/wy_fast_non.py new file mode 100644 index 0000000000000000000000000000000000000000..98b11f5743e8debffca59f9ce09c56ade7003d0d --- /dev/null +++ b/fla2/ops/mask_delta_rule/wy_fast_non.py @@ -0,0 +1,491 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +# from fla.utils import autocast_custom_bwd,
autocast_custom_fwd, contiguous +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)  # conceptually an (r*BT, r*BT) matrix + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r  # read column i_r of the row-major (r, r) mask; the arange walks over rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]  # scatter that column back into an (r, r) mask + + for i_k in range(tl.cdiv(dk, BK)):  # load and process the K dimension in blocks + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + # (dev note) save this first to inspect + + for i in range(1, BT):  # the matrix here is laid out as (BT, r, BT, r) + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)  # get b_a: (BT, r, r) + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)  # realised as a matmul; gives (BT, BT, r, r) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])  # (BT, r, r) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)  # processed row by row, swapping results in step by step + b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))  # (BT*r, BT*r) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))  # the old implementation needed many more multiplies + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + # matrix inversion done + b_A = b_A.to(k.dtype.element_ty)  # inversion solved; next, compute the outputs + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r  # read column i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d) + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # get (BT*r, BK) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not used for v; no inner mask is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u =
tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + # r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)*r+i_r  # read column i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d) + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # get (BT*r, BK) + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not used for v; no inner mask is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +# compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk, dv, dbeta, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_v in range(tl.cdiv(V, BV)):  # process V in blocks + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V),
(s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)  # (BT, r, BV) + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r) + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV) + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV)) + sum_dv = tl.sum(b_dv_beta,-2)  # (dev note) this is where results differed + b_dv = (sum_dv * b_beta[:, None])  # (dev note) which step differs? + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r + i_r  # read column i_r + b_mask = tl.load(p_mask)  # column i_r + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)  # (BT, r, d) + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + # here: (BT, r, BK) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + b_dk = sum_dk* b_beta[:, None] + b_dbeta += tl.sum(sum_dk * b_k, 1) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to dA of K K^T; mostly zeros around the diagonal blocks + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + + for i_r in range(r):  # take only the i_r term + p_mask = mask_ij + tl.arange(0,r)*r+i_r  # read column i_r + b_mask = tl.load(p_mask)  # column i_r + mask = tl.arange(0, r) == i_r  # select column i_r + g = tl.sum(tl.where(mask[None,None,None,:], b_dA, 0), -1)  # (BT, r, BT): extract column i_r + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)  # (BT, BT) + + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + +
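# (editor's comment) Unlike wy_fast.py, this variant does not accumulate a + # gradient for the (r, r) mask; backward() below returns None in its place. +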
def fwd_prepare_wy_repr(k, v, beta,mask, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=k.dtype) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + return w, u, A + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + +def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A,  # dA + dw, du, + dk, dv, dbeta, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, k, v, beta,mask,chunk_size=64): + ctx.BT = chunk_size + w, u, A = fwd_prepare_wy_repr(k, v,beta,mask, ctx.BT) + ctx.save_for_backward(k, v, beta,mask,A) + return w, u + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, dw, du): + k, v, beta,mask, A = ctx.saved_tensors + BT = ctx.BT + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta,mask, A, dw, du, BT) + return dk, dv, dbeta, None, None + +prepare_wy_repr = WYRepresentationPrepration.apply + + +# def naive(k, v, beta,mask,chunk_size): +# l_org = k.shape[2] +# l_new = triton.next_power_of_2(l_org) +# k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) +# v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) +# beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + +# k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) +# beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) +# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) +# k_beta = k * beta[..., None] +# v = v * beta[..., None] +# attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) +# attn = attn * beta[..., None] +# x = attn @ v + +# o = torch.zeros_like(k) +# o2 = torch.zeros_like(v) + +# o[..., 0, :] = k_beta[..., 0, :].clone() +# o2[..., 0, :] = x[..., 0, :].clone() +# for i in range(1, chunk_size): +# o_i = (o[..., :i, :]).clone() +# o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] +# o2_i = (o2[..., :i, :]).clone() +# o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] +# return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + +# use this naive version +# (dev note) this code has a problem +def naive(k, v, beta,maskij,chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + k =
torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + + b,h,nt,BT,dk = k.shape + dv = v.shape[-1] + r = maskij.shape[-1] + k_beta = k * beta[..., None] + k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r) + k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij) + k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')  # l = 1, (r k) = original K + v_beta = v * beta[..., None] + v_beta = v_beta + v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) + ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r) + attn = (ki @ ki.transpose(-1, -2)) + attn = torch.tril(attn, diagonal=-1)  # (b, h, n, r, c, c) + attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)  # (b, h, n) with (r, r) and (c, c) blocks + attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta) + + o = torch.zeros_like(k_beta) + o2 = torch.zeros_like(v_beta) + + o[..., 0, :,:] = k_beta[..., 0,:,:].clone() + o2[..., 0,:, :] = v_beta[..., 0,:,:].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i,:,:]).clone()  # (b, h, n, :t, c, c) + o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:]) + o2_i = (o2[..., :i,:,:]).clone()  # one dimension fewer + o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:]) + return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2)) + + +if __name__ == "__main__": + # all compute here + import sys + sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + seq_len = 32 + b = 2 + h = 2 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128 + v = torch.randn(b, h, seq_len, 128) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + BT = 16 + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta)) + r = 4 + # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous() + mask = torch.randn([r,r]) + mask = mask.cuda().requires_grad_(require_grad).contiguous() + w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16) + # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16) + # from einops import rearrange + + # k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = 16,r=r) + # b2 = rearrange(beta,'b h (n t)-> b h n t',t = 16) + # a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)  # (b, h, n, r, t, t) + # qq = torch.tril(a1,diagonal=-1) + # qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask) + # sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)') + # sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)  # this one + # # the long block-diagonal identity + # i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :])) + # s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda() + # s = rearrange(s,'b h n a d c r->b h n (a c) (d r)') + # s = torch.linalg.inv(s.float()).to(k)  # matrix inverse: (b, h, n, t*r, t*r) + # s_copy = s + + # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2) + # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask) + # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)') + # wc = s_copy@k_exp + + # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT) + # v_exp =
torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2) + # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1) + # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v') + # uc = s_copy@v_exp + # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc)) + # do = torch.rand_like(wc) + # do2 = torch.rand_like(uc)  # (b, h, n, t, t) + o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)  # (dev note) this reference code has a problem + do = torch.rand_like(o1) + do2 = torch.rand_like(o2)  # (b, h, n, t, t) + print((o1-w).abs().max()) + print((o2-u).abs().max()) + if require_grad: + o1.backward(do, retain_graph=True) + o2.backward(do2, retain_graph=True) + k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad + + # k.grad = v.grad = beta.grad = None + # # wc.backward(do, retain_graph=True) + # # uc.backward(do2, retain_graph=True) + # # k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad + # # k.grad = v.grad = beta.grad = None + w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16) + # print((wc-w0).abs().max()) + # print((uc-u0).abs().max()) + # print((wc-o1).abs().max()) + # print((uc-o2).abs().max()) + k_grad, v_grad, beta_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT) + + print((k_grad-k_grad2).abs().max()) + print((v_grad-v_grad2).abs().max()) + print((beta_grad-beta_grad2).abs().max()) + # this variant returns no mask gradient to compare against mask_grad2 + + diff --git a/fla2/ops/mask_delta_rule/wy_fast_test.py b/fla2/ops/mask_delta_rule/wy_fast_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e2a8be22392f019f48c280037b35a861e76a42 --- /dev/null +++ b/fla2/ops/mask_delta_rule/wy_fast_test.py @@ -0,0 +1,676 @@ +# -*- coding: utf-8 -*- +import torch +import triton +import triton.language as tl +from einops import rearrange +# from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)  # conceptually an (r*BT, r*BT) matrix + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r  # read column i_r of the row-major (r, r) mask; the arange walks over rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]  # scatter that column back into an (r, r) mask + + for i_k in range(tl.cdiv(dk, BK)):  # load and process the K dimension in blocks + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + # (dev note) save this first to inspect + + for i in range(1,
BT):# the matrix now has shape (BT, r, BT, r) + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)#get ba BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)# done via matmul; get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])#BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, swapping the results in step by step + b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))# the old implementation needed many more multiplications + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + # the matrix inversion is handled above + b_A = b_A.to(k.dtype.element_ty)# ok, the inverse is done # next, compute the outputs + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):# the arbitrary mask is never applied to v, so no mask (and no loop over r) is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + # r_mask = tl.arange(0, r) == i_r # + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, 
(T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):# the arbitrary mask is never applied to v, so no mask (and no loop over r) is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +#compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk, dv, dbeta,dmask, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_dmask = tl.zeros([r,r],dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)):# process V in blocks + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))#r*BT BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)##BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)#BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)#BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV)) + sum_dv = tl.sum(b_dv_beta,-2)# this step differs from the unmasked kernel: reduce over the r replicas first + b_dv = (sum_dv * b_beta[:, None])# ? which step gives a different result + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r + i_r# read the i_r-th column of mask_ij + b_mask = tl.load(p_mask)# column i_r + rmask = tl.arange(0, r) == i_r # column i_r + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)#BT*r*d + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), 
allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + b_dk = sum_dk* b_beta[:, None] + b_dbeta += tl.sum(sum_dk * b_k, 1) + + + b_ss = b_dk_beta * b_beta[:,None,None] * b_k[:,None,:] + b_ss = tl.reshape(tl.permute(b_ss,(2,0,1)),(BT*BK,r)) + b_ss = tl.sum(b_ss,0) + # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1)) + b_dmask += (b_ss[:,None]*rmask[None,:]).to(tl.float32) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) # equivalent to the dA of k@k^T; mostly zeros except near the diagonal + + + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + #bt r bt r + + + for i_r in range(r):# take only the i_r-th term + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask)# column i_r + rmask = tl.arange(0, r) == i_r # column i_r + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)#BT r BT # extract the i_r-th column + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + # the corresponding c part + + for i_k in range(tl.cdiv(block_k, BK)):#ik = 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT + + beta_y = (beta_kkt[:,None,:]*g) + beta_y = tl.reshape(tl.permute(beta_y,(2,0,1)),(BT*BT,r)) + betas = tl.sum(beta_y,0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r# column-wise read, so this indexes the rows + 
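# note: mask_ij is stored row-major as an r x r block, so element (row, col) sits at + # offset row*r + col; the pointer arithmetic above therefore gathers column i_r, + # i.e. b_mask[row] = mask_ij[row, i_r]. E.g. for r = 2 and mask_ij = [[a, b], [c, d]], + # i_r = 0 loads [a, c] and i_r = 1 loads [b, d]. + 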
b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]# row index + + for i_k in range(tl.cdiv(dk, BK)):# load and process k in blocks + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, swapping the results in step by step + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32) + b_A21 = tl.permute(b_A21,(0,2,1,3)) + b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, 
boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + +def chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, mask, A, + T*K, K, 1, + T, K, r, BT, BK + ) + return A + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,#s_abh + T*16*r*r,#s_adbh + T, + r, BT + ) + if BT == 16: + return Ad + + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + +def fwd_prepare_wy_repr2(k, v, beta,mask, BT): + A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,torch.float32) + A = solve_tril(A=A,mask=mask,k=k,BT=BT,output_dtype=k.dtype) + w, u = fwd_recompute_w_u(k, v, beta,mask, A, BT) + return w, u, A + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=k.dtype) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + return w, u, A + + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + +def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + dmask = torch.zeros([B*H*NT,r,r],device=k.device,dtype=k.dtype).contiguous() + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A, + dw, du, + dk, dv, dbeta,dmask, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + dmask = dmask.sum(0) + return dk, dv, dbeta, dmask + + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, k, v, beta,mask,chunk_size=64): + ctx.BT = chunk_size + w, u, A = 
fwd_prepare_wy_repr(k, v,beta,mask, ctx.BT) + ctx.save_for_backward(k, v, beta,mask,A) + return w, u + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, dw, du): + k, v, beta,mask, A = ctx.saved_tensors + BT = ctx.BT + dk, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta,mask, A, dw, du, BT) + return dk, dv, dbeta, dmask, None + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta,maskij,chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + + b,h,nt,BT,dk = k.shape + dv = v.shape[-1] + r = maskij.shape[-1] + k_beta = k * beta[..., None] + k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r) + k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij) + k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')#l=1 rk=org + v_beta = v * beta[..., None] + v_beta = v_beta + v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) + ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r) + + attn = (ki @ ki.transpose(-1, -2)) + attn = torch.tril(attn, diagonal=-1)#bhnr cc + attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)#bhn rr cc + attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta) + + o = torch.zeros_like(k_beta) + o2 = torch.zeros_like(v_beta) + + o[..., 0, :,:] = k_beta[..., 0,:,:].clone() + o2[..., 0,:, :] = v_beta[..., 0,:,:].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i,:,:]).clone()#bhn :t cc + o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:]) + o2_i = (o2[..., :i,:,:]).clone()#少一个维度 + o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:]) + return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2)) + + +if __name__ == "__main__": + #all compute here + import sys + torch.manual_seed(42) + sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + seq_len = 128 + b = 2 + h = 2 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 + v = torch.randn(b, h, seq_len, 128) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + BT = 32 + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta)) + r = 4 + # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous() + mask = torch.randn([r,r]) + mask = mask.cuda().requires_grad_(require_grad).contiguous() + # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16) + # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16) + # from einops import rearrange + + k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = BT,r=r) + b2 = rearrange(beta,'b h (n t)-> b h n t',t = BT) + a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)#bhnrtt + qq = torch.tril(a1,diagonal=-1) + qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask) + sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)') + sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)#这个 + + # #长条对角线 + i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == 
torch.arange(0, r)[None, None, None, :])) + s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda() + s = rearrange(s,'b h n a d c r->b h n (a c) (d r)') + s = torch.linalg.inv(s.float()).to(k)#矩阵逆#bhn tr tr + + + # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)#bh nt BT bt r r + # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.bfloat16) + # s = rearrange(s,'b h n a c->(b h n) a c') + # print(Ad.shape) + # print(s.shape) + + w,u,As = fwd_prepare_wy_repr(k, v, beta,mask, BT) + w2,u2,Ad2 = fwd_prepare_wy_repr(k, v, beta,mask, BT) + + print((w2-w).abs().max()) + print((u2-u).abs().max()) + print((As-Ad2).abs().max()) + + # print((Ad-s).abs().max()) + # print(Ad-s) + + # print((As-s).abs().max()) + # print(As-s) + # B*H*NT,BT*r,16*r + # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2) + # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask) + # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)') + # wc = s_copy@k_exp + + # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT) + # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2) + # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1) + # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v') + # uc = s_copy@v_exp + # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc)) + # do = torch.rand_like(wc) + # do2 = torch.rand_like(uc)#b h n t t + # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)#这个代码有问题 + # do = torch.rand_like(o1) + # do2 = torch.rand_like(o2)#b h n t t + # if require_grad: + # o1.backward(do, retain_graph=True) + # o2.backward(do2, retain_graph=True) + # k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad + + # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16) + # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT) + + # print((o1-w0).abs().max()) + # print((o2-u0).abs().max()) + # print((k_grad-k_grad2).abs().max()) + # print((v_grad-v_grad2).abs().max()) + # print((beta_grad-beta_grad2).abs().max()) + # print((mask_grad-mask_grad2).abs().max()) + # print(mask_grad) + # print(mask_grad2) + + diff --git a/fla2/ops/mask_delta_rule_t/README.md b/fla2/ops/mask_delta_rule_t/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1ab2d485a9552d70238c1f68288c72c62f9e0ef2 --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/README.md @@ -0,0 +1,4 @@ +- Delta Rule + +The implementation of delta rule described in https://arxiv.org/abs/2102.11174 + diff --git a/fla2/ops/mask_delta_rule_t/__init__.py b/fla2/ops/mask_delta_rule_t/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1087963f473d48ee4de9546b4699cd318d128fbb --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import mask_chunk_delta_rule +# from .chunk_fuse import mask_fused_chunk_delta_rule +# from .recurrent_fuse import mask_fused_recurrent_delta_rule + +__all__ = [ + 'mask_chunk_delta_rule', +] diff --git a/fla2/ops/mask_delta_rule_t/__pycache__/__init__.cpython-310.pyc b/fla2/ops/mask_delta_rule_t/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a17dd17528a64e826baf9485abf87cc3e90d879 Binary files /dev/null and b/fla2/ops/mask_delta_rule_t/__pycache__/__init__.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule_t/__pycache__/__init__.cpython-312.pyc b/fla2/ops/mask_delta_rule_t/__pycache__/__init__.cpython-312.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..c2a7cbed8871723275f2354542c84004cc1dbf2e Binary files /dev/null and b/fla2/ops/mask_delta_rule_t/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule_t/__pycache__/chunk.cpython-310.pyc b/fla2/ops/mask_delta_rule_t/__pycache__/chunk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14364232b2de88042251d14e290e4b6fabed5bd5 Binary files /dev/null and b/fla2/ops/mask_delta_rule_t/__pycache__/chunk.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule_t/__pycache__/chunk.cpython-312.pyc b/fla2/ops/mask_delta_rule_t/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2232a1b7390030b81bed35be53832a7b0744c75b Binary files /dev/null and b/fla2/ops/mask_delta_rule_t/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule_t/__pycache__/wy_fast.cpython-310.pyc b/fla2/ops/mask_delta_rule_t/__pycache__/wy_fast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a7e2f114edc7c82230aafd9a1a35d633448531 Binary files /dev/null and b/fla2/ops/mask_delta_rule_t/__pycache__/wy_fast.cpython-310.pyc differ diff --git a/fla2/ops/mask_delta_rule_t/__pycache__/wy_fast.cpython-312.pyc b/fla2/ops/mask_delta_rule_t/__pycache__/wy_fast.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a00341c70550b0fa1ab3ee41d81324a15307f740 Binary files /dev/null and b/fla2/ops/mask_delta_rule_t/__pycache__/wy_fast.cpython-312.pyc differ diff --git a/fla2/ops/mask_delta_rule_t/chunk.py b/fla2/ops/mask_delta_rule_t/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..b8354a24d70e7d801ba4e6738c5c4f9c2034057b --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/chunk.py @@ -0,0 +1,770 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from ...ops.mask_delta_rule_t.wy_fast import (bwd_prepare_wy_repr, + fwd_prepare_wy_repr, fwd_recompute_w_u) +from ...ops.utils import contiguous +from ...utils import autocast_custom_bwd, autocast_custom_fwd +#finish +import torch.nn.functional as F +def ceildiv(a, b): + return -(a // -b) + +def pad(x, chunk_size=16): + seq_len = x.shape[-2] + #b n l d + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) + if x.shape[-1] % 32 != 0: + x = F.pad(x, (0, 32 - x.shape[-1] % 32)) + return x + +def pad_b(x,val, chunk_size=16): + seq_len = x.shape[-1] # get the sequence length l + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size # compute the padded length + # pad if the sequence length is not a multiple of chunk_size + if seq_len % chunk_size != 0: + x = F.pad(x, (0, padded_seq_len - seq_len),value=val) # pad only along the last dimension (l) + return x + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)# could perhaps also be parallelized over r + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in 
range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)# empty_like will not work: an extra r dim is needed + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32)# holds one horizontal strip of the state + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # this store is correct + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):# the block is large, so process it in BC-sized chunks + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))# read the matching block + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + # b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + # b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + # b_v = tl.load(p_v, boundary_check=(0, 1, 2))#BC + # b_v = 
tl.reshape(b_v,(BC,BV)) + # b_d = tl.reshape(b_d,(BC,BK)) + # b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)#ok # equal up to this point; BC here + # tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))# at least up to here the first step matches + # bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + # b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2)) + b_v = tl.reshape(b_v,(BC,BV)) + # b_v = b_v.to(tl.float32)#BC + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False)#ok # equal up to this point; BC here + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))# at least up to here the first step matches + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):# note the partitioning here: K//BK = r + # open question: different r-blocks read the same q/k; does that cause any problem? + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)# zero out the non-causal part: b_s = 0 there + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + 
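# per-(batch, head) stride of dh in elements; the call site in chunk_bwd_dhu_fn passes NT*K*V here + 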
scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)# this stays fixed; reads everything + for i_t in range(NT - 1, -1, -1):# shifted forward by one step; the computation order is correct + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + # full columns + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))# read in full + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), + (i_k * BK, i_t * BT * r + i_c * BC *r), (BK, BC * r), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_d = (tl.load(p_d,boundary_check=(0, 1))) + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))#load r + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + for i_r in range(r): + rmask = tl.arange(0, r) == i_r # column i_r + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_r*KR + i_k * BK), (BC, KR), (1, 0))# + b_k = tl.load(p_k, boundary_check=(0, 1)) #BC KR + b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0)# KR BV + dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False)#get BC*BV + b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BC*r,BV)) + + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * 
s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK + b_dh =(tl.load(p_dh, boundary_check=(0, 1))) + + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + b_dq += tl.dot(b_do, b_h, allow_tf32=False)# b_do is the full gradient; b_h should contain the i_k part + b_dk += tl.dot(b_v, b_dh, allow_tf32=False)# used to compute dk; rows are independent, so this is fine + b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV + b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) # these terms should be fine + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, ((-b_dw.to(p_dw.dtype.element_ty))), boundary_check=(0, 1)) + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)# K is partitioned directly + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)# v_new is laid out r-first + chunk_delta_rule_fwd_kernel_h[grid](# no for-loop over r here + k, u, w, v_new, h, initial_state, final_state, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + +#finish +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
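+ # dv arrives as [B, H, r, T, V]; below it is flattened so the r replicas of each + # time step sit next to each other ('b h r t v -> b h (t r) v'), matching the + # (BC*r, BV) tiles the kernel loads. A minimal sketch of that reordering, assuming + # B = H = 1, r = 2, T = 2 and a scalar V: + # dv = torch.arange(4.).view(1, 1, 2, 2, 1) # dv[0, 0, r, t, 0] == 2*r + t + # rearrange(dv, 'b h r t v -> b h (t r) v') # time-major order: [0, 2, 1, 3]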
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)# more parallelism could probably be exposed here + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B , H, NT * K,V)# same layout # the sum requires computing these together + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)# same layout #bhr T V + chunk_delta_rule_bwd_kernel_dhu[grid]( + q, k, w, do, dh, dv, dv2, + T*K,K,1, + NT*K*V, + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +#finish +def chunk_fwd_o_fn(q, k, v_new, h, BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new)# therefore: b h r T bv + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + #h shape b h nk v + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, o, + T*K, K, 1 , + r*T*V,T*V,V, + NT*K*V,V, + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2)# sum over the r dimension + return o + + +def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + _,_,RT,_ = w.shape + r = RT // T + # the last backward helper: computes dw, dq and dk + BK = triton.next_power_of_2(K//r)# needs a finer-grained split so that different positions are never grouped into one block + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K//r, BK) + NT = triton.cdiv(T, BT) + grid = (NK, NT, B * H * r)# NK controls which position is handled + dq = torch.empty_like(q) + dk = torch.empty_like(k)#k_org + dw = torch.empty_like(w)#bh nt k + + chunk_delta_rule_bwd_kernel_dqkw[grid]( + q, k, v_new, w, h, do, dh, dq, dk, du, dw, + T*K,K,1, + T*V, V, 1, + NT*K*V,V, + scale=K ** -0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r = r + ) + return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + + +class ChunkDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta,mask,BT, initial_state, output_final_state, checkpoint_level=1): + w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, BT)#compute for A matrix #compute all + r = mask.shape[-1] + final_state = None + if output_final_state: + final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], + dtype=torch.float32, requires_grad=False)# this part needs no correction + + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)#need change + + o = chunk_fwd_o_fn(q, k, v_new, h, BT)#need change + + if checkpoint_level == 1: + h, v_new = None, None # recomputed in backward + ctx.save_for_backward(q, k, v, beta,mask, A, h, v_new, initial_state) + ctx.BT = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_ht=None): + q, k, v, beta,mask , A, h, v_new, initial_state = ctx.saved_tensors + BT = ctx.BT + r = mask.shape[-1] + w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)# w and u were not saved; rebuild them from A + # checkpoint_level=1, recomputation. 
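+ # When checkpoint_level == 1 the forward pass dropped h and v_new to save memory + # (they are O(NT*K*V) and O(r*T*V) per head), so they are rebuilt below from the + # saved (k, w, u) at the cost of one extra chunk_fwd_h_fn sweep.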
+ if h is None: + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None) + #v_new b h r T V + dv = fwd_prepare_dv(q, k, do, r, BT)# q,k,do -> dv; this dv therefore has w's shape. finish + #dv BHR T V + dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)#new_dv dh #final for wyper dv + dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT) + dk2, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT) + dk.add_(dk2) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), None, None, None + + +def mask_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + mask: torch.Tensor,# used to mask the original tensor + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16." + seq_len = v.shape[-2] + q, k, v = map(lambda x: pad(x,BT), [q, k, v]) + beta = pad_b(beta,0.0,BT) + q,k,v,beta = map(lambda x:x.contiguous(),[q,k,v,beta]) + o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta,mask, BT, initial_state, output_final_state) + o = o[..., :seq_len,:] + return o, final_state + + +def naive(q,k,w,u,initial_state,BT,r): + B,H,seq_len,dk = q.shape + dv = u.shape[-1] + NT = seq_len//BT + state = torch.empty(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + g = torch.zeros(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + v_new = torch.empty_like(u) + from einops import rearrange + q,k,w,u,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + q = q*(dk**-0.5) + v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + if initial_state is not None: + state[:,:,0,:,:] = initial_state + else: + state[:,:,0,:,:] = 0 + for i in range(NT): + ki = rearrange(k[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + ui = rearrange(u[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + wi = rearrange(w[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:]) + v_new[:,:,i,:,:,:] = v_newi# the stored results match here + kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + g[:,:,i,:,:] = rearrange(kui,'b h r k v-> b h (r k) v') + if i+1 < seq_len//BT: + state[:,:,i+1,:,:] = state[:,:,i,:,:] + g[:,:,i,:,:] + q_r = rearrange(q,'b h n t (r d)->b h n t r d',r = r) + k_r = rearrange(k,'b h n t (r d)->b h n t r d',r = r) + s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + s_r = torch.tril(s_r,diagonal=0)#mask get bhnrtl + o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_new) + o1 = o1.sum(dim=-2)#bhntv + o2 = torch.einsum('b h n t q,b h n q v-> b h n t v',q,state)# check whether the state term alone is computed correctly + o = o1 + o2 + return state,v_new,o,g + + +def delta_rule_recurrence(q, k, v, beta, mask,initial_state=None): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + r = mask.shape[-1] + o = torch.zeros_like(v) + if initial_state is None: + S = torch.zeros(b, h, d_k, d_v).to(v).float() + else: + S = initial_state + q = q * (d_k ** -0.5) + if beta.ndim < v.ndim: + beta = beta[..., None] + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v * beta_i + kkt = torch.einsum('b h d,b h v->b h d v',_k*beta_i,_k) + kkt = rearrange(kkt,' b h (r d) (l v)-> b h r d l v',r= r,l=r) + kkt = torch.einsum('b h r d l v,r l->b h r d l v',kkt,mask.to(kkt)) + kkt = rearrange(kkt,'b h r d l v-> b h (r d) (l v)') + iplr = 
torch.eye(d_k).to(q)-kkt + S = torch.einsum(' b h q k ,b h k v->b h q v',iplr,S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ =="__main__": + import sys + import time + # from einops import rearrange + # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + # seq_len = 128 + # b = 2 + # h = 2 + # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 + # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 + # v = torch.randn(b, h, seq_len, 128) + # beta = torch.rand(b, h, seq_len).sigmoid() + # require_grad = True + # BT = 16 + # r = 4 + # scale = 128**-0.5 + # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous() + # w = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda()) + # u = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())#bhn tr d + # initial_state = torch.randn(b,h,128,128).cuda().contiguous() + # k, v, q, beta,w, u= map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v,q, beta,w,u)) + + # final_state = None + # if False: + # final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], + # dtype=torch.float32, requires_grad=False)#这部分不需要修正 + + # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)#need change + # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)#need change + # o2 = rearrange(o2,'b h (n t) v-> b h n t v',n=seq_len//BT) + # do = torch.rand_like(o2) + # do_naive = do + # do = rearrange(do,'b h n t v-> b h (n t) v') + # dv0 = fwd_prepare_dv(q, k, do, r, BT)#qk do v_new#因此这个dv应该是一个w的shape finish + # #bhrtv + # #到这里计算结果相同 + # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)#new_dv dh + # ###到这里算的一样了 + # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)#需要dh和dv + + # #bhtrv + # # dv0 = rearrange(dv0,'b h r (n t) v-> b h n t r v',n=seq_len//BT) + # # dv = rearrange(dv,'b h (n t) r v->b h n t r v',n=seq_len//BT)#应该有二者在BT=1维度相等,yes + # # dh = rearrange(dh,'b h (n k) v->b h n k v',n=seq_len//BT) + # # 应该有 + # # NT = seq_len//BT + # # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype) + # # v_new = torch.zeros_like(u) + # # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + # # wr = rearrange(w_na,'b h n (t r) k->b h n t r k',r = r) + # # dh0 = -torch.einsum('b h t r v,b h t r k-> b h k v ',dv[:,:,1,:,:,:],wr[:,:,1,:,:,:]) + # # dh0 += torch.einsum('b h t v,b h t q->b h q v',do_naive[:,:,1,:,:],q_na[:,:,1,:,:])*scale + # # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r) + # # dh0 = rearrange(dh0,'b h (r k) v->b h r k v',r=r)#need b h t r v + # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v,b h t r k->b h t r v',dh0,k_r[:,:,0,:,:,:]) + # # dh0 = rearrange(dh0,'b h r k v->b h (r k) v') + # # #这样计算流程是对的 + # # print((dv[:,:,0,:,:,:]-dvs).abs().max()) + # # print((dh0-dh[:,:,0,:,:]).abs().max()) + + # #####here is naive + # NT = seq_len//BT + # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype) + # v_new = torch.zeros_like(u) + # from einops import rearrange + # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + # # u_na = u_na.detach().requires_grad_(True)#b h n (t r) d + # v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + # if initial_state is not 
None: + # state[:,:,0,:,:] = initial_state + # else: + # state[:,:,0,:,:] = 0 + # for i in range(NT): + # ki = rearrange(k_na[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + # ui = rearrange(u_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + # wi = rearrange(w_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + # v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:].clone()) + # v_new[:,:,i,:,:,:] = v_newi.clone()#这里保存的结果是相等 + # kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + # if i+1 < seq_len//BT: + # state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui,'b h r k v-> b h (r k ) v') + # q_r = rearrange(q_na,'b h n t (r d)->b h n t r d',r = r)*scale + # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r) + # s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + # s_r = torch.tril(s_r,diagonal=0)#mask get bhnrtl + # v_newnew = v_new#.detach().requires_grad_(True) + # os = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_newnew) + # os = os.sum(dim=-2)#bhntv + # oss = torch.einsum('b h n t q,b h n q v-> b h n t v',q_na,state)*scale#只看state 算的对不对 + # o_naive = os + oss + # # o_naive = rearrange(o_naive,'b h n t v->b h (n t) v') + # o_naive.backward(do_naive,retain_graph=True) + # una_grad = u.grad + # w_grad = w.grad + # k_grad = k.grad + # q_grad = q.grad + # print((una_grad-dv).abs().max())#基本相等 + # print((k_grad-dk).abs().max()) + # print((w_grad-dw).abs().max()) + # print((q_grad-dq).abs().max()) + # print(k_grad) + # print(dk) + + B = 2 + H = 4 + L = 128 + DK = 256 + DV = 256 + q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True) + k = (torch.randn(B, H, L, DK)).cuda() + k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True) + v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True) + beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True) + mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous() + + start = time.time() + o1 = delta_rule_recurrence(q,k,v,beta,mask) + do = torch.randn(B, H, L, DV).cuda() + o1.backward(do, retain_graph=True) + q_grad, q.grad = q.grad, None + k_grad, k.grad = k.grad, None + v_grad, v.grad = v.grad, None + beta_grad, beta.grad = beta.grad, None + end = time.time() + print(end-start) + + o,f_state = mask_chunk_delta_rule(q, k, v, beta,mask,BT=32)#10s嘛 额 + o.backward(do,retain_graph=True) + q_grad0, q.grad = q.grad, None + k_grad0, k.grad = k.grad, None + v_grad0, v.grad = v.grad, None + beta_grad0, beta.grad = beta.grad, None + print(k_grad) + print(k_grad0) + + diff --git a/fla2/ops/mask_delta_rule_t/chunk_fuse.py b/fla2/ops/mask_delta_rule_t/chunk_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..a6979fa906c6706bb07f6318b284920365db9eff --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/chunk_fuse.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from ...ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +import torch.nn.functional as F + +def ceildiv(a, b): + return -(a // -b) + +def pad(x, chunk_size=16): + seq_len = x.shape[-2] + #b n l d + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) + if x.shape[-1] % 32 != 0: + x = F.pad(x, (0, 32 - x.shape[-1] % 32)) + return x + +def pad_b(x, chunk_size=16): + 
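# Pads the last (sequence) dimension up to a multiple of chunk_size with 1.0, the + # neutral value for the decay term: e.g. a (B, H, 30) tensor with chunk_size=16 + # comes back as (B, H, 32) with the two new entries set to 1.0. + 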
seq_len = x.shape[-1] # 获取序列长度 l + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size # 计算填充后的长度 + # 如果序列长度不是 chunk_size 的倍数,则进行填充 + if seq_len % chunk_size != 0: + x = F.pad(x, (0, padded_seq_len - seq_len),value=1.0) # 只在最后一个维度(l)进行填充 + return x + +# on-the-fly computation without materializing hidden statets into HBMs +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8) + ], + key=["BT", "BK"], +) +@triton.jit +def fused_chunk_delta_rule_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + v_new, + d, # decay [B, H, L, D_head_K] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False) + b_v = b_v - b_v_prime + tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), 
allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_v_new = tl.advance(p_v_new, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_d = tl.advance(p_d, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fused_chunk_delta_rule_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + d, # decay [B, H, L, D_head_K] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + dd, # gradient of decay [NV, B, H, L, D_head_K] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + # first reverse + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + + tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + NT = tl.cdiv(T, BT) + for i in range(0, NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, 
boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, DK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + if i < (NT - 1): + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False) + p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + ((i+1) * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1)) + + +def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BT = BT + # ctx.BT = BT + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1, 'NK should be 1' + o = q.new_empty(batch_size, n_heads, seq_len, d_head_v) + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + CHECK = True + # if version.parse(triton.__version__) < version.parse('2.2.0'): + # import warnings + # warnings.warn( + # "Triton<2.2.0 detected for running this kernel, " + # "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + # "that lead to significant precision loss. " + # "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + # "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." 
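+    # CHECK stays hardcoded to True below: the `i == 0` special cases in the
+    # kernels work around the Triton < 2.2.0 compiler issue described in the
+    # commented-out warning above, at a small cost in speed.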
+ # ) + # CHECK = True + grid = (NV, NK, batch_size * n_heads) + v_new = torch.empty_like(v) + fused_chunk_delta_rule_fwd_kernel[grid]( + q, k, v, v_new, d, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + ) + return o, v_new, CHECK, final_state + + +def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1 + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + fused_chunk_delta_rule_bwd_kernel[grid]( + q, k, v, d, do, dq, dk, dv, dd, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=CHECK, + # num_warps=num_warps, + # num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dd = dd.sum(0) + dd[:, :, 0:BT] = 0 + return dq, dk, dv, dd + + +class FusedChunkDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0): + # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory. + assert checkpoint_level in [0, 1] + k_origin = k + # k = _l2_norm_fwd(k_origin) + k = k + d, v_new = fwd_prepare_wy_repr(k, v, beta, BT) + o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state) + if checkpoint_level == 1: + d, v_new = None, None + ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state) + ctx.CHECK = CHECK + ctx.chunk_size = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_final_state=None): + q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors + chunk_size = ctx.chunk_size + k = k_origin + # k = _l2_norm_fwd(k_origin) + if d is None: + d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state) + dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size) + dk.add_(dk2) + # dk = _l2_norm_bwd(k_origin, dk) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None + + +def mask_fused_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16." 
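+    # Inputs are padded to a multiple of the chunk size (and the head dim to a
+    # multiple of 32), then the output is cropped back, so arbitrary seq_len is
+    # fine. A hedged usage sketch (shapes illustrative, not from the tests):
+    #   q = torch.randn(2, 4, 100, 64, device='cuda', dtype=torch.bfloat16)
+    #   o, _ = mask_fused_chunk_delta_rule(q, k, v, beta, BT=32)  # o: [2, 4, 100, 64]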
+ + if initial_state is not None: + initial_state = initial_state.detach() + seq_len = v.shape[-2] + d_head_v = v.shape[-1] + q, k, v = map(lambda x: pad(x), [q, k, v]) + beta = pad_b(beta) + o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state) + o = o[..., :seq_len, :d_head_v] + return o, final_state + + +def mask_delta_rule_recurrence(q, k, v, beta): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + k = torch.nn.functional.normalize(k, p=2, dim=-1) + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i[..., None] + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ == "__main__": + import torch.nn.functional as F + # seq_len = 128 + # b = 2 + # h = 4 + # q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1) + # beta = torch.rand(b, h, seq_len).sigmoid() + # q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta)) + # do = torch.rand_like(v) + # o2 = delta_rule_recurrence(q, k, v.clone(), beta) + # o2.backward(do, retain_graph=True) + # q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # o, _ = fused_chunk_delta_rule(q, k, v, beta, 32) + # o.backward(do, retain_graph=True) + # q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # print((o - o2).abs().max()) + # print((q_grad - q_grad2).abs().max()) + # print((k_grad - k_grad2).abs().max()) + # print((v_grad - v_grad2).abs().max()) + # print((beta_grad - beta_grad2).abs().max()) \ No newline at end of file diff --git a/fla2/ops/mask_delta_rule_t/naive.py b/fla2/ops/mask_delta_rule_t/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..fac9d616a63cb418c91e4bc1c49c3f95483d32a3 --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/naive.py @@ -0,0 +1,1367 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0)) + b_mask = tl.load(p_mask)#BT r 1 + ij_mask = b_mask*r_mask[None,None,:]#行数 + + for i_k in range(tl.cdiv(dk, 
BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[:,None,:,:] + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + #先save这个看看 + + for i in range(1, BT):#此时矩阵为 BT,r,BT,r + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)#get ba BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)#矩阵乘法解决,get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])#BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))#旧版本实现需要很多乘法 + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + #解决矩阵求逆 + b_A = b_A.to(k.dtype.element_ty)#ok 解决求逆了 #下一步计算结果 + + for i_r in range(r): + p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0)) + b_mask = tl.load(p_mask)#BT r 1 + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask.to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + p_mask = tl.make_block_ptr(mask_ij + i_bh * 
T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
+        b_mask = tl.load(p_mask)  # BT, r, 1
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask.to(b_k.dtype)  # BT, r, d
+            b_kb = tl.reshape(b_kb,(BT*r,BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # gives BT*r, BK
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not used for u, so no masking (and no per-rank loop) is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
+        b_vb = tl.reshape(b_vb,(BT*r,BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+# backward pass for the WY representation
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)  # grid is (NT, B*H); a stray third program_id here would be an unpacking error
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([BT,r,r],dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # process V in blocks
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # r*BT, BV
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)  # BT, r, BV
+        b_v_beta = tl.reshape(b_v_beta,(BT*r,BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # BT*r, BT*r
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # BT*r, BV
+        b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))
+        sum_dv = tl.sum(b_dv_beta,-2)  # debug note: this is where the result starts to differ
+        b_dv = (sum_dv * b_beta[:, None])  # debug note: which step produces the mismatch?
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K//r
+    for i_r in range(r):
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
+        b_mask = tl.load(p_mask)  # BT, r, 1
+        rmask = tl.arange(0, r) == i_r  # the i_r-th column
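+        # For each rank slice i_r below: rebuild the beta- and mask-weighted
+        # keys, accumulate dA via dw @ k_beta^T, and push dw back through A^T
+        # to obtain dk, dbeta and the per-position mask gradient b_dmask.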
+ for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask).to(b_k.dtype)#BT*r*d + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask,1) + b_dk = sum_dk* b_beta[:, None] + b_dbeta += tl.sum(sum_dk * b_k, 1) + + b_ss = (tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],-1)) # BT r + b_dmask += (b_ss[:,:,None]*rmask[None,None,:]).to(tl.float32)#BT r r + + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) #等价于 kkt的 dA 很多0,对角处 + + + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + #bt r bt r + + + for i_r in range(r):#只取ir项 + p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0)) + b_mask = tl.load(p_mask)#BT r 1 + rmask = tl.arange(0, r) == i_r #第ir列 + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)#BT r BT #取出第ir列 + ir_A = tl.sum(g * b_mask,1).to(k.dtype.element_ty)#BT BT + #对应的c部分 + + for i_k in range(tl.cdiv(block_k, BK)):#ik = 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT + + betas = (tl.sum(beta_kkt[:,None,:]*g,-1))#BT r + b_dmask += (betas[:,:,None]*rmask[None,None,:]).to(tl.float32) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T) + i_t * BT)* r * r , (BT,r,r), (r*r,r,1), (0,0,0), (BT,r,r), (2,1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "r"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: 
tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0)) + b_mask = tl.load(p_mask)#BT r 1 + ij_mask = b_mask*r_mask[None,None,:]#行数 + + for i_k in range(tl.cdiv(dk, BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[:,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + # p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0)) + # b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32) + # b_A21 = tl.permute(b_A21,(0,2,1,3)) + # b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,32*r),(32*r,1) ,((i_t * 32 + 16) *r, 0), (16*r, 16*r), (1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + # b_A21 = 
tl.permute(b_A21,(0,2,1,3)) + # b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0)) + p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0)) + p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0)) + p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0)) + p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0)) + p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0)) + + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32) + b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32) + b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32) + b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32) + b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32) + + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0)) + p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0)) + p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0)) + p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, 
(T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0)) + p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0)) + + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32) + Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32) + + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + Ai32 = -tl.dot(tl.dot(Ai33,b_A32, input_precision='ieee'),Ai11,input_precision='ieee') + Ai43 = -tl.dot(tl.dot(Ai44,b_A43, input_precision='ieee'),Ai11,input_precision='ieee') + + Ai31 = -tl.dot( + Ai33, + tl.dot(b_A31,Ai11, input_precision='ieee')+ + tl.dot(b_A32,Ai21, input_precision='ieee'), + input_precision='ieee') + + Ai42 = -tl.dot( + Ai44, + tl.dot(b_A42,Ai22, input_precision='ieee')+ + tl.dot(b_A43,Ai32, input_precision='ieee'), + input_precision='ieee') + + Ai41 = -tl.dot( + Ai44, + tl.dot(b_A41, Ai11, input_precision='ieee') + + tl.dot(b_A42, Ai21, input_precision='ieee') + + tl.dot(b_A43, Ai31, input_precision='ieee'), + input_precision='ieee' + ) + + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai33,Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai44,Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai31,Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai32,Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai41,Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai42,Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai43,Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + + +def chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, mask, A, + T*K, K, 1, + T, K, r, BT, BK + ) + return A + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,#s_abh + T*16*r*r,#s_adbh 
+ T, + r, BT + ) + if BT == 16: + return Ad + + A = rearrange(A,'b (t l) (c r)->b (t c) (l r)',t=BT,c=r).contiguous()#BT*r BT*r + if BT == 32: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + if BT == 64: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_64x64_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,torch.float32) + A = solve_tril(A=A,mask=mask,k = k ,BT=BT,output_dtype=k.dtype) + w, u = fwd_recompute_w_u(k, v, beta,mask, A, BT) + return w, u, A + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + +def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + dmask = torch.zeros([B,H,T,r,r],device=k.device,dtype=k.dtype).contiguous() + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A, + dw, du, + dk, dv, dbeta,dmask, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + return dk, dv, dbeta, dmask + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)#或许也可以r并行 + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + 
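+            # b_dv here is only the intra-chunk term: the causal score block
+            # (k . q^T, kept where row <= col) applied to do. The cross-chunk
+            # contribution through dh is added later in chunk_delta_rule_bwd_kernel_dhu.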
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)#assert ik=1 all use + b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + #这里save是对的 + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2))#BC + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)#ok #到这相等的 这里BC + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + 
q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r + #问题是不同r_block读取了同一份qk,有影响吗 + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + #全列 + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))#全读取 + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), + (i_k * BK, i_t * BT * r + i_c * BC *r), (BK, BC * r), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_d = (tl.load(p_d,boundary_check=(0, 1))) + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))#load r + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + 
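+            # dh is viewed as r slices of KR = K//r rows so each rank's key
+            # block only feeds its own slice of dv, mirroring the per-rank
+            # forward update of h.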
for i_r in range(r): + rmask = tl.arange(0, r) == i_r #第ir列 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_r*KR + i_k * BK), (BC, KR), (1, 0))# + b_k = tl.load(p_k, boundary_check=(0, 1)) #BC KR + b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0)# KR BV + dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False)#get BC*BV + b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BC*r,BV)) + + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK + b_dh =(tl.load(p_dh, boundary_check=(0, 1))) + + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, b_dh, allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV + b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) #这些应该没啥问题 + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, 
s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, ((-b_dw.to(p_dw.dtype.element_ty))), boundary_check=(0, 1)) + + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)#直接划分好 + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)#做了v_new的r_first + chunk_delta_rule_fwd_kernel_h[grid](#r没有for循环 + k, u, w, v_new, h, initial_state, final_state, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + +#finish +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)#感觉可以放并行度 + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B , H, NT * K,V)#一样的#need 求和 得一起算 + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)#一样的 #bhr T V + chunk_delta_rule_bwd_kernel_dhu[grid]( + q, k, w, do, dh, dv, dv2, + T*K,K,1, + NT*K*V, + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +#finish +def chunk_fwd_o_fn(q, k, v_new, h, BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new)#there_fore,bhr nT,bv + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + #h shape b h nk v + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, o, + T*K, K, 1 , + r*T*V,T*V,V, + NT*K*V,V, + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2)#沿着r维度求和 + return o + + +def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + _,_,RT,_ = w.shape + r = RT // T + #最后一个函数,计算dw,dq,dk + BK = triton.next_power_of_2(K//r)#需要更细粒度的划分,确保不会使得 不同位置的划到一起 + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K//r, BK) + NT = triton.cdiv(T, BT) + grid = (NK, NT, B * H * r)#通过NK控制位置 + dq = torch.empty_like(q) + dk = torch.empty_like(k)#k_org + dw = torch.empty_like(w)#bh nt k + + + chunk_delta_rule_bwd_kernel_dqkw[grid]( + q, k, v_new, w, h, do, dh, dq, dk, du, dw, + T*K,K,1, + T*V, V, 1, + NT*K*V,V, + 
scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # forward pass is complete
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # compute the full A matrix
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+        end = time.time()
+        print('compute_A:', end-start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)  # needs change
+        end = time.time()
+        print('compute_h_s:', end-start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)  # needs change
+        end = time.time()
+        print('compute_o:', end-start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # skip
+        end = time.time()
+        print('recompute_wu:', end-start)
+        # checkpoint_level=1, recomputation.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: b h r T V
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # from q, k, do; this dv therefore has the shape of w. finished
+        end = time.time()
+        print('pre:', end-start)
+        # dv: B H r T V
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # new dv and dh; this dv is the final one for the WY-repr backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end-start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)  # this step is also very slow
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end-start)
+
+        start = time.time()
+        dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end-start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), None, None, None
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # mask applied to the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
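+    # mask is expected to be [B, H, L, r, r] (see the test below). Unlike
+    # mask_fused_chunk_delta_rule, no padding is applied here, so L is assumed
+    # to be a multiple of BT.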
+ o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta,mask, BT, initial_state, output_final_state) + return o, final_state + + +def naive(q,k,w,u,initial_state,BT,r): + B,H,seq_len,dk = q.shape + dv = u.shape[-1] + NT = seq_len//BT + state = torch.empty(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + g = torch.zeros(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + v_new = torch.empty_like(u) + from einops import rearrange + q,k,w,u,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + q = q*(dk**-0.5) + v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + if initial_state is not None: + state[:,:,0,:,:] = initial_state + else: + state[:,:,0,:,:] = 0 + for i in range(NT): + ki = rearrange(k[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + ui = rearrange(u[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + wi = rearrange(w[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:]) + v_new[:,:,i,:,:,:] = v_newi#这里保存的结果是相等 + kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + g[:,:,i,:,:] = rearrange(kui,'b h r k v-> b h (r k) v') + if i+1 < seq_len//BT: + state[:,:,i+1,:,:] = state[:,:,i,:,:] + g[:,:,i,:,:] + q_r = rearrange(q,'b h n t (r d)->b h n t r d',r = r) + k_r = rearrange(k,'b h n t (r d)->b h n t r d',r = r) + s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + s_r = torch.tril(s_r,diagonal=0)#mask get bhnrtl + o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_new) + o1 = o1.sum(dim=-2)#bhntv + o2 = torch.einsum('b h n t q,b h n q v-> b h n t v',q,state)#只看state 算的对不对 + o = o1 + o2 + return state,v_new,o,g + + +def delta_rule_recurrence(q, k, v, beta, mask): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + r = mask.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + if beta.ndim < v.ndim: + beta = beta[..., None] + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v * beta_i + kkt = torch.einsum('b h d,b h v->b h d v',_k*beta_i,_k) + # kkt = torch.einsum('b h d,b h v->b h d v',_k,_k) + kkt = rearrange(kkt,' b h (r d) (l v)-> b h r d l v',r= r,l=r) + kkt = torch.einsum('b h r d l v,b h r l->b h r d l v',kkt,mask[:,:,i,:,:].to(kkt)) + kkt = rearrange(kkt,'b h r d l v-> b h (r d) (l v)') + iplr = torch.eye(d_k).to(q)-kkt + S = torch.einsum(' b h q k ,b h k v->b h q v',iplr,S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ =="__main__": + import sys + import time + torch.set_default_dtype(torch.bfloat16) + torch.manual_seed(42) + + for i in range(1): + B = 2 + H = 4 + L = 128 + DK = 256 + DV = 256 + r = 4 + q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True) + k = (torch.randn(B, H, L, DK)).cuda() + k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True) + v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True) + beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True) + # mask = torch.randn([r,r]) + mask = torch.randn(B,H,L,r,r).cuda().requires_grad_(True) + # mask = mask.cuda().requires_grad_(True).contiguous() + + # start = time.time() + do = torch.randn(B, H, L, DV).cuda() + o1 = delta_rule_recurrence(q,k,v,beta,mask) + o1.backward(do, retain_graph=True) + q_grad, q.grad = q.grad, None + k_grad, k.grad = k.grad, None + v_grad, v.grad = v.grad, None + mask_grad, mask.grad = mask.grad, None + beta_grad, beta.grad = 
beta.grad, None + # end = time.time() + # print(end-start) + # start = time.time() + # w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, 64) + o,f_state = mask_chunk_delta_rule(q, k, v, beta,mask,BT=32) + o.backward(do,retain_graph=True) + q_grad0, q.grad = q.grad, None + k_grad0, k.grad = k.grad, None + v_grad0, v.grad = v.grad, None + beta_grad0, beta.grad = beta.grad, None + mask_grad0, mask.grad = mask.grad, None + # # end = time.time() + # # print(end-start) + print((o1-o).abs().max()) + print((q_grad-q_grad0).abs().max()) + print((k_grad-k_grad0).abs().max())# large discrepancy in results, up to ~1 + print((v_grad-v_grad0).abs().max()) + print((beta_grad-beta_grad0).abs().max()) + print((mask_grad-mask_grad0).abs().max()) + print('naive:',mask_grad) + print('triton:',mask_grad0) + # print(k_grad) + # print(k_grad0) + + # BT = 16 + # w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, BT)# compute the A matrix # computes everything + # print('finish0') + # h, v_new = chunk_fwd_h_fn(k, w, u, BT, None, None)# needs changes + # print('finish1') + # o = chunk_fwd_o_fn(q, k, v_new, h, BT)# needs changes + # print('finish2') + # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)# skipped + # print('finish3') + # dv = fwd_prepare_dv(q, k, do, r, BT)# q, k, do, v_new # so this dv should have w's shape; finished + # print('finish4') + # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)# new dv and dh # the final dv for the WY representation + # print('finish5') + # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)# this step is also very slow + # print('finish6') + + # Ass = rearrange(A,'b h (n t) l->b h n t l',n = L//BT) + # dwss = rearrange(dw,'b h (n t) k->b h n t k',n = L//BT) + # dvss = rearrange(dv,'b h (n t) k->b h n t k',n = L//BT) + # dk2, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT) + # print('triton:',dmask) # almost exactly equal + + # vbeta = v*beta[...,None] + # vbeta = rearrange(vbeta,'b h (n T) d->b h n T d',T=BT) + # vbeta = vbeta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) + # vbeta = rearrange(vbeta,'b h n t r d-> b h n (t r) d') + + # kbeta = k*beta[...,None] + # kbeta = rearrange(kbeta,'b h (n T) (r d)->b h n T r d',T=BT,r=r) + # kbeta = torch.einsum('b h n T r d,c r-> b h n T c r d',kbeta,mask) + # kbeta = rearrange(kbeta,'b h n t c r d-> b h n (t c) (r d)') + # dA = dvss@vbeta.transpose(-1,-2)+dwss@kbeta.transpose(-1,-2) + + + # dorg = Ass.transpose(-1,-2)@dwss#bhn bt*r k + # dorg = rearrange(dorg,'b h n (t r) (c k)->b h n t r c k',r=r,c=r) + # betan = rearrange(beta,'b h (n t)->b h n t',n=L//BT) + # kn = rearrange(k,'b h (n t) (r d)->b h n t r d ',n = L//BT,r=r) + + # dmask = torch.einsum('b h n t r c k,b h n t->b h n t r c k',dorg,betan) + # dmask = torch.einsum('b h n t r c k,b h n t c k->b h n t r c k',dmask,kn) + # dmask = rearrange(dmask,'b h n t r c k-> (b h n) (t k) r c') + # dmaskss = dmask.sum(0).sum(0) + + # i = torch.arange(0, BT * r)[:, None] + # j = torch.arange(0, BT * r)[None, :] + # iB = i // r + # jB = j // r + # da_mask = iB > jB + # da_mask = da_mask.cuda() + # b_dA = torch.where(da_mask, dA, 0) + + # b_dA = b_dA @ Ass.transpose(-1,-2) + # b_dA = Ass.transpose(-1,-2)@b_dA + + # b_dA = torch.where(da_mask, -b_dA, 0) # equivalent to the dA of kk^T: mostly zeros, at the diagonal + # b_dA = rearrange(b_dA,'b h n (t r) (l c)-> b h n t r l c',c=r,r=r) + # # print((dAss-b_dA).abs())# still exactly equal up to here + + + # # betakkt = k*beta[...,None] + # kbeta = k*beta[...,None] + # kbeta = rearrange(kbeta,'b h (n T) (r d)->b h n T r d',T=BT,r=r) + # kbeta2 = rearrange(k,'b h (n T) (r d)->b h n T r d',T=BT,r=r) + # betakkt = torch.einsum('b h n T r d,b h n s r d->b h n r T s',kbeta,kbeta2)#r Bt bt + # betakkt =
rearrange(betakkt,'b h n r T s->b h n T s r')# BT r BT ### horizontal + # # print((dAss-b_dA).abs()) + + # # shows the computation below is the one that is wrong + # dmask = torch.einsum('b h n t r l c,b h n t l c-> b h n t r l c',b_dA,betakkt) + # # print((dAss-dmask).abs().max())# means this result is also correct + # # print((dAss-dmask)) + + # dmask = rearrange(dmask,'b h n t r l c->b h n (t l) r c') + # dmask = dmask.sum(-3) + # dmask = dmask.sum(0).sum(0).sum(0) + # print('matrix:',dmask) + + + + + + + + diff --git a/fla2/ops/mask_delta_rule_t/naive_rmbeta copy.py b/fla2/ops/mask_delta_rule_t/naive_rmbeta copy.py new file mode 100644 index 0000000000000000000000000000000000000000..5aac72fd6ab3c1c7928194c488cb608129bf6fc0 --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/naive_rmbeta copy.py @@ -0,0 +1,1102 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)# r*BT x r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:] + for i_k in range(tl.cdiv(dk, BK)):# read and compute k block by block + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_kb = (b_k).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + # save this first and inspect it + for i in range(1, BT):# the matrix is now BT,r,BT,r + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)# get b_a: BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)# handled with a matmul; get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])# BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, swapping results in step by step + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))# BT*r x BT*r + b_A += tl.arange(0, BT*r)[:,None] == tl.arange(0, BT*r)[None,:] + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))# the old implementation needed many more multiplies + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + # matrix inversion handled + b_A = b_A.to(k.dtype.element_ty)# ok, inversion solved # next, compute the results + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)# BT*r*d + # b_kb =
(b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)# BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)# get BT*r x BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):# the arbitrary mask is unused here # no for-loop needed, and no mask applies here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)# BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)# BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)# get BT*r x BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):# the arbitrary mask is unused here # no for-loop needed, and no mask applies here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +#compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk,
dv, dbeta, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_v in range(tl.cdiv(V, BV)):# blockwise split + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))# r*BT x BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)## BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)# BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)# BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))# + sum_dv = tl.sum(b_dv_beta,-2)# this differs here; the result + b_dv = (sum_dv * b_beta[:, None])# ? which step produces the different result + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(block_k, BK)):# assert block_k == BK + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + # b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)# BT*r*d + b_k_beta = ((b_k)[:,None,:]*b_mask[None,:,None]).to(b_k.dtype) + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)# get BT*r x BT*r + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + # b_dk = sum_dk* b_beta[:, None] + b_dk = sum_dk + # b_dbeta += tl.sum(sum_dk * b_k, 1) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) + b_dA = tl.reshape(b_dA,(BT,r,BT,r)).to(k.dtype.element_ty)# everything up to here should be correct + + for i_r in range(r):# take only the i_r term + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask)# the i_r-th column + mask = tl.arange(0, r) == i_r + g = tl.sum(tl.where(mask[None,None,None,:], b_dA, 0), -1)# BT r BT; take the last column + # this corresponds to the k_r part + 
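# Side note (a sketch, not part of the diff): the three masked dots above implement
# backpropagation through the matrix inverse that the forward pass stored. With
# Y = X^{-1}, the standard identity is dX = -Y^T dY Y^T; the strictly-lower block
# mask is re-applied because only those entries of X vary with the inputs. A quick
# autograd check of the identity:
import torch

n = 8
X = (torch.eye(n, dtype=torch.float64) + torch.randn(n, n, dtype=torch.float64).tril(-1)).requires_grad_(True)
Y = torch.linalg.inv(X)
dY = torch.randn(n, n, dtype=torch.float64)
Y.backward(dY)
Yd = Y.detach()
print(torch.allclose(X.grad, -(Yd.T @ dY @ Yd.T)))  # True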
ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)# BT x BT + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + # b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_k_beta = (b_k).to(b_k.dtype) + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + # b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta #* b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))# this should be fine too + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + assert BK == K//r + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=torch.float32) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, r, BT, BK, BV + ) + return w, u, A + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)# 32 + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, r,BT, BK, BV + ) + return w, u + +def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + assert BK == K//r + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A,#da, + dw, du, + dk, dv, dbeta, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, r, BT, BK, BV + ) + return dk, dv, dbeta#,da + + +# from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)# could perhaps also parallelize over r + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh
* s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)# empty_like does not work here + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + k.stride(1), k.stride(2), k.stride(3), + do.stride(1), do.stride(2), do.stride(3), + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)# assert i_k == 1; all of it is used + b_h = tl.zeros([BK, BV], dtype=tl.float32)# reads one horizontal strip + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # this store is correct + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):# BK is large, split it via BC + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))# read the corresponding block + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))# BK//r,BC + b_d = tl.load(p_d,
boundary_check=(0, 1, 2))# BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2))# BC + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)# ok # equal up to here; BC here + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))# at least the first step matches up to here + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):# careful with the split here # K//BK = r here + # question: different r-blocks read the same q,k -- does that matter? + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + # p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_r * rk + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)# set to 0, b_s = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK,
BV], dtype=tl.float32)# this part is unchanged; reads everything + for i_t in range(NT - 1, -1, -1):# shifted forward by one position; the computation flow is correct + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (s_h_t, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + # all columns + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))# full read + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))# + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T*r,K), (K, 1), + (i_t * BT * r + i_c * BC *r,i_k * BK), (BC * r,BK), (1, 0))# read a BC x r x BK block + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1))# BT*r x BV + b_d = tl.trans(tl.load(p_d,boundary_check=(0, 1))) + b_k = tl.permute(tl.reshape(b_k,(BC,r,KR)),(1,0,2))# r BC KR + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + dv_sum = tl.sum(b_k[:,:,:,None]*b_dhtrans.to(b_k.dtype)[:,None,:,:],-2) # get r BC BV + b_dv += tl.reshape(tl.permute(dv_sum,(1,0,2)),(BC*r,BV)) + # b h t r v + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT,r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),
(s_h_t, 1), (i_t * K + i_r* K// r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))# BV BK + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)# ok + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False)# b_do is the full block; b_h should cover the i_k part + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)# used to compute dk; yes, rows are independent, no problem + b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)),(BT,r,BV))# BT*r x BV + b_dw += tl.sum(b_dv.to(b_v.dtype)[:,:,:,None]*b_h.to(b_v.dtype)[None,None,:,:],-2)# get BT r BK + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)# BT x BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False)) # these should be fine + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT, 0 ,i_r*K//r + i_k * BK), (BT, r ,BK), (2, 1, 0)) + # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT ,i_r, i_k * BK), (BT, 1, BK), (2, 1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, (tl.reshape(-b_dw.to(p_dw.dtype.element_ty),(BT,r,BK))), boundary_check=(0, 1)) + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)# partitioned directly + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)# v_new is laid out r-first + chunk_delta_rule_fwd_kernel_h[grid](# no for-loop over r + k, u, w, v_new, h, initial_state, final_state, + k.stride(1), k.stride(2), k.stride(3), + u.stride(1), u.stride(2), u.stride(3), # rT*V, V, 1 + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + +#finish +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256."
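# For reference (a sketch, not part of the diff): what chunk_fwd_h_fn computes,
# restated in plain PyTorch following the naive() checker earlier in this diff.
# Note the kernel stores h flattened as (B,H,NT*K,V) and v_new as (B,H,r,T,V);
# this sketch keeps them unflattened for clarity.
import torch
from einops import rearrange

def chunk_h_reference(k, w, u, BT, r):
    B, H, T, K = k.shape
    V = u.shape[-1]
    NT = T // BT
    S = k.new_zeros(B, H, K, V)                    # running K x V state
    h = k.new_empty(B, H, NT, K, V)                # state snapshot per chunk
    v_new = torch.empty(B, H, NT, BT * r, V, dtype=u.dtype, device=u.device)
    kc = rearrange(k, 'b h (n t) (r d) -> b h n t r d', n=NT, r=r)
    wc = rearrange(w, 'b h (n t r) d -> b h n t r d', n=NT, r=r)
    uc = rearrange(u, 'b h (n t r) d -> b h n t r d', n=NT, r=r)
    for i in range(NT):
        h[:, :, i] = S                             # state at chunk start
        vi = uc[:, :, i] - torch.einsum('bhtrd,bhdv->bhtrv', wc[:, :, i], S)
        v_new[:, :, i] = rearrange(vi, 'b h t r v -> b h (t r) v')
        S = S + rearrange(torch.einsum('bhtrk,bhtrv->bhrkv', kc[:, :, i], vi),
                          'b h r k v -> b h (r k) v')
    return h, v_new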
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)# feels like more parallelism could be added here + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B , H, NT * K,V)# same as before # needs a sum, so it must be computed together + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)# same as before # b h r T V + chunk_delta_rule_bwd_kernel_dhu[grid]( + q, k, w, do, dh, dv, dv2, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +#finish +def chunk_fwd_o_fn(q, k, v_new, h, BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new)# therefore: b h r, nT, BV + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + # h shape: b h nk v + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, o, + q.stride(1), q.stride(2), q.stride(3), + v_new.stride(1), v_new.stride(2), v_new.stride(3), + h.stride(1), h.stride(2), + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2)# sum over the r dimension + return o + + +def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + _,_,RT,_ = w.shape + r = RT // T + # the last function: computes dw, dq, dk + BK = triton.next_power_of_2(K//r)# needs a finer-grained partition to ensure entries from different positions are not grouped together + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K//r, BK) + NT = triton.cdiv(T, BT) + grid = (NK, NT, B * H * r)# position is controlled via NK + dq = torch.empty_like(q) + dk = torch.empty_like(k)# k_org + dw = torch.empty_like(w)# b h nt k + chunk_delta_rule_bwd_kernel_dqkw[grid]( + q, k, v_new, w, h, do, dh, dq, dk, du, dw, + q.stride(1), q.stride(2), q.stride(3), + T*V, V, 1, + dh.stride(1), dh.stride(2), + scale=K ** -0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r = r + ) + return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + + +class ChunkDeltaRuleFunction(torch.autograd.Function): + # forward pass is done + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta,mask,BT, initial_state, output_final_state, checkpoint_level=1): + start = time.time() + w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, BT)# compute the A matrix # computes everything + final_state = None + if output_final_state: + final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], + dtype=torch.float32, requires_grad=False)# this part needs no correction + end = time.time() + print('compute_A:',end-start) + start = time.time() + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)# needs changes + end = time.time() + print('compute_h_s:',end-start) + + start = time.time() + o = chunk_fwd_o_fn(q, k, v_new, h, BT)# needs changes + end = time.time() + print('compute_o:',end-start) + if checkpoint_level == 1: + h, v_new = None, None + ctx.save_for_backward(q, k, v, beta,mask, A, h, v_new, initial_state) + ctx.BT = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_ht=None): + q, k, v, beta,mask , A, h, v_new, initial_state = ctx.saved_tensors + BT = ctx.BT + r = mask.shape[-1] + start = time.time() + w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)# skipped + end =
time.time() + print('recompute_wu:',end-start) + # checkpoint_level=1, recomputation. + if h is None: + h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None) + # v_new: b h r T V + start = time.time() + dv = fwd_prepare_dv(q, k, do, r, BT)# q, k, do, v_new # so this dv should have w's shape; finished + end = time.time() + print('pre:',end-start) + # dv: B H R T V + + start = time.time() + dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)# new dv and dh # the final dv for the WY representation + end = time.time() + print('chunk_bwd_dhu_fn:',end-start) + + start = time.time() + dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT) + end = time.time() + print('chunk_bwd_dqkw_fn:',end-start) + + start = time.time() + dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)# this step has a relatively large error + dk.add_(dk2) + end = time.time() + print('bwd_prepare_wy_repr:',end-start) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None + + +def mask_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + mask: torch.Tensor,# used to mask the original tensor + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." + o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta,mask, BT, initial_state, output_final_state) + return o, final_state + + +def naive(q,k,w,u,initial_state,BT,r): + B,H,seq_len,dk = q.shape + dv = u.shape[-1] + NT = seq_len//BT + state = torch.empty(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + g = torch.zeros(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + v_new = torch.empty_like(u) + from einops import rearrange + q,k,w,u,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + q = q*(dk**-0.5) + v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + if initial_state is not None: + state[:,:,0,:,:] = initial_state + else: + state[:,:,0,:,:] = 0 + for i in range(NT): + ki = rearrange(k[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + ui = rearrange(u[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + wi = rearrange(w[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:]) + v_new[:,:,i,:,:,:] = v_newi# the results saved here match + kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + g[:,:,i,:,:] = rearrange(kui,'b h r k v-> b h (r k) v') + if i+1 < seq_len//BT: + state[:,:,i+1,:,:] = state[:,:,i,:,:] + g[:,:,i,:,:] + q_r = rearrange(q,'b h n t (r d)->b h n t r d',r = r) + k_r = rearrange(k,'b h n t (r d)->b h n t r d',r = r) + s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + s_r = torch.tril(s_r,diagonal=0)# causal mask, get b h n r t l + o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_new) + o1 = o1.sum(dim=-2)# b h n t v + o2 = torch.einsum('b h n t q,b h n q v-> b h n t v',q,state)# check whether the state term alone is correct + o = o1 + o2 + return state,v_new,o,g + + +def delta_rule_recurrence(q, k, v, beta, mask): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + r = mask.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + if beta.ndim < v.ndim: + beta = beta[..., None] + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v * beta_i + # kkt = torch.einsum('b h d,b h v->b h d v',_k*beta_i,_k) + kkt = torch.einsum('b h d,b h v->b h d v',_k,_k) + kkt = rearrange(kkt,' b h (r d) (l
v)-> b h r d l v',r= r,l=r) + kkt = torch.einsum('b h r d l v,r l->b h r d l v',kkt,mask.to(kkt)) + kkt = rearrange(kkt,'b h r d l v-> b h (r d) (l v)') + iplr = torch.eye(d_k).to(q)-kkt + S = torch.einsum(' b h q k ,b h k v->b h q v',iplr,S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ =="__main__": + import sys + import time + # from einops import rearrange + # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + # seq_len = 128 + # b = 2 + # h = 2 + # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)# d=128 + # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)# d=128 + # v = torch.randn(b, h, seq_len, 128) + # beta = torch.rand(b, h, seq_len).sigmoid() + # require_grad = True + # BT = 16 + # r = 4 + # scale = 128**-0.5 + # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous() + # w = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda()) + # u = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())# b h n, t r, d + # initial_state = torch.randn(b,h,128,128).cuda().contiguous() + # k, v, q, beta,w, u= map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v,q, beta,w,u)) + + # final_state = None + # if False: + # final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], + # dtype=torch.float32, requires_grad=False)# this part needs no correction + + # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)# needs changes + # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)# needs changes + # o2 = rearrange(o2,'b h (n t) v-> b h n t v',n=seq_len//BT) + # do = torch.rand_like(o2) + # do_naive = do + # do = rearrange(do,'b h n t v-> b h (n t) v') + # dv0 = fwd_prepare_dv(q, k, do, r, BT)# q, k, do, v_new # so this dv should have w's shape; finished + # # b h r t v + # # results match up to here + # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)# new dv and dh + # ### identical up to here + # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)# needs dh and dv + + # # b h t r v + # # dv0 = rearrange(dv0,'b h r (n t) v-> b h n t r v',n=seq_len//BT) + # # dv = rearrange(dv,'b h (n t) r v->b h n t r v',n=seq_len//BT)# the two should be equal along the BT=1 dimension, yes + # # dh = rearrange(dh,'b h (n k) v->b h n k v',n=seq_len//BT) + # # should have + # # NT = seq_len//BT + # # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype) + # # v_new = torch.zeros_like(u) + # # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + # # wr = rearrange(w_na,'b h n (t r) k->b h n t r k',r = r) + # # dh0 = -torch.einsum('b h t r v,b h t r k-> b h k v ',dv[:,:,1,:,:,:],wr[:,:,1,:,:,:]) + # # dh0 += torch.einsum('b h t v,b h t q->b h q v',do_naive[:,:,1,:,:],q_na[:,:,1,:,:])*scale + # # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r) + # # dh0 = rearrange(dh0,'b h (r k) v->b h r k v',r=r)# need b h t r v + # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v,b h t r k->b h t r v',dh0,k_r[:,:,0,:,:,:]) + # # dh0 = rearrange(dh0,'b h r k v->b h (r k) v') + # # # so the computation flow is correct + # # print((dv[:,:,0,:,:,:]-dvs).abs().max()) + # # print((dh0-dh[:,:,0,:,:]).abs().max()) + + # ##### here is the naive version + # NT = seq_len//BT + # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype) + # v_new = torch.zeros_like(u) + # from einops import rearrange + # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = 
NT),(q,k,w,u,v_new)) + # # u_na = u_na.detach().requires_grad_(True)# b h n (t r) d + # v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + # if initial_state is not None: + # state[:,:,0,:,:] = initial_state + # else: + # state[:,:,0,:,:] = 0 + # for i in range(NT): + # ki = rearrange(k_na[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + # ui = rearrange(u_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + # wi = rearrange(w_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + # v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:].clone()) + # v_new[:,:,i,:,:,:] = v_newi.clone()# the results saved here match + # kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + # if i+1 < seq_len//BT: + # state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui,'b h r k v-> b h (r k ) v') + # q_r = rearrange(q_na,'b h n t (r d)->b h n t r d',r = r)*scale + # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r) + # s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + # s_r = torch.tril(s_r,diagonal=0)# causal mask, get b h n r t l + # v_newnew = v_new#.detach().requires_grad_(True) + # os = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_newnew) + # os = os.sum(dim=-2)# b h n t v + # oss = torch.einsum('b h n t q,b h n q v-> b h n t v',q_na,state)*scale# check whether the state term alone is correct + # o_naive = os + oss + # # o_naive = rearrange(o_naive,'b h n t v->b h (n t) v') + # o_naive.backward(do_naive,retain_graph=True) + # una_grad = u.grad + # w_grad = w.grad + # k_grad = k.grad + # q_grad = q.grad + # print((una_grad-dv).abs().max())# basically equal + # print((k_grad-dk).abs().max()) + # print((w_grad-dw).abs().max()) + # print((q_grad-dq).abs().max()) + # print(k_grad) + # print(dk) + + B = 2 + H = 1 + L = 128 + DK = 256 + DV = 256 + q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True) + k = (torch.randn(B, H, L, DK)).cuda() + k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True) + v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True) + beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True) + # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous() + mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous() + + start = time.time() + o1 = delta_rule_recurrence(q,k,v,beta,mask) + do = torch.randn(B, H, L, DV).cuda() + o1.backward(do, retain_graph=True) + q_grad, q.grad = q.grad, None + k_grad, k.grad = k.grad, None + v_grad, v.grad = v.grad, None + beta_grad, beta.grad = beta.grad, None + end = time.time() + print(end-start) + + # start = time.time() + # w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, 64) + o,f_state = mask_chunk_delta_rule(q, k, v, beta,mask,BT=32) + o.backward(do,retain_graph=True) + q_grad0, q.grad = q.grad, None + k_grad0, k.grad = k.grad, None + v_grad0, v.grad = v.grad, None + beta_grad0, beta.grad = beta.grad, None + # end = time.time() + # print(end-start) + print((o1-o).abs().max()) + print((q_grad-q_grad0).abs().max()) + print((k_grad-k_grad0).abs().max())# large discrepancy in results, up to ~1 + print((v_grad-v_grad0).abs().max()) + print((beta_grad-beta_grad0).abs().max()) + # print(beta_grad) + # print(beta_grad0) + print(k_grad) + print(k_grad0) + + + + diff --git a/fla2/ops/mask_delta_rule_t/naive_rmbeta.py b/fla2/ops/mask_delta_rule_t/naive_rmbeta.py new file mode 100644 index 0000000000000000000000000000000000000000..33f29f3d3b93d378128a4dc0d3e8aba87ab67756 --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/naive_rmbeta.py @@ -0,0 +1,1377 @@ +import pdb +import torch
+import triton +import triton.language as tl +from einops import rearrange +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)# r*BT x r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r# read column-wise, hence this counts rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]# row count + + for i_k in range(tl.cdiv(dk, BK)):# read and compute k block by block + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + # save this first and inspect it + + for i in range(1, BT):# the matrix is now BT,r,BT,r + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)# get b_a: BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)# handled with a matmul; get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])# BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, swapping results in step by step + b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))# BT*r x BT*r + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))# the old implementation needed many more multiplies + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + # matrix inversion handled + b_A = b_A.to(k.dtype.element_ty)# ok, inversion solved # next, compute the results + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)# BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)# get BT*r x BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):# the arbitrary mask is unused here # no for-loop needed, and no mask applies here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1,
dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + # r_mask = tl.arange(0, r) == i_r # + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)# BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)# get BT*r x BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):# the arbitrary mask is unused here # no for-loop needed, and no mask applies here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +#compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk, dv, dbeta,dmask, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_dmask = tl.zeros([r,r],dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)):# blockwise split + p_v =
tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))# r*BT x BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)## BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)# BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)# BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))# + sum_dv = tl.sum(b_dv_beta,-2)# this differs here; the result + b_dv = (sum_dv * b_beta[:, None])# ? which step produces the different result + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r + i_r# read the i_r-th column + b_mask = tl.load(p_mask)# the i_r-th column + rmask = tl.arange(0, r) == i_r # the i_r-th column + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)# BT*r*d + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False) + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + b_dk = sum_dk* b_beta[:, None] + b_dbeta += tl.sum(sum_dk * b_k, 1) + + + b_ss = b_dk_beta * b_beta[:,None,None] * b_k[:,None,:] + b_ss = tl.reshape(tl.permute(b_ss,(2,0,1)),(BT*BK,r)) + b_ss = tl.sum(b_ss,0) + # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1)) + b_dmask += (b_ss[:,None]*rmask[None,:]).to(tl.float32) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) # equivalent to the dA of kk^T: mostly zeros, at the diagonal + + + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + # bt r bt r + + + for i_r in range(r):# take only the i_r term + p_mask = mask_ij + tl.arange(0,r)*r+i_r# read the i_r-th column + b_mask = tl.load(p_mask)# the i_r-th column + rmask = tl.arange(0, r) == i_r # the i_r-th column + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)# BT r BT # extract the i_r-th column + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)# BT x BT + # the corresponding c part + + for i_k in range(tl.cdiv(block_k, BK)):# i_k == 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk,
boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)# BT x BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))# BT x BT + + beta_y = (beta_kkt[:,None,:]*g) + beta_y = tl.reshape(tl.permute(beta_y,(2,0,1)),(BT*BT,r)) + betas = tl.sum(beta_y,0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "r"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)# r*BT x r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r# read column-wise, hence this counts rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]# row count + + for i_k in range(tl.cdiv(dk, BK)):# read and compute k block by block + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)# computed row by row, swapping results in step by step + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None,
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + i_bh * s_A_bh, (T, BT, r, r), (BT * r * r, r * r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row-by-row substitution, writing each finished row back in place
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16 * r, 16 * r))
+    p_Ad = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 16 * r, 0), (16 * r, 16 * r), (1, 0))
+    tl.store(p_Ad, b_A.to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_64x64_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_A31 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_A32 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_A41 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_A42 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_A43 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 32 * r), (16 * r, 16 * r), (1, 0))
+
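+    # Block forward substitution for a unit block-lower-triangular matrix:
+    # with Ai_jj = (A_jj)^{-1} taken from `Ad`, the off-diagonal blocks of the
+    # inverse satisfy
+    #     X_ij = -Ai_ii @ (sum over j <= m < i of A_im @ X_mj),  X_jj = Ai_jj,
+    # which is what the Ai21 ... Ai41 products below evaluate.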
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+    b_A31 = tl.load(p_A31, boundary_check=(0, 1)).to(tl.float32)
+    b_A32 = tl.load(p_A32, boundary_check=(0, 1)).to(tl.float32)
+    b_A41 = tl.load(p_A41, boundary_check=(0, 1)).to(tl.float32)
+    b_A42 = tl.load(p_A42, boundary_check=(0, 1)).to(tl.float32)
+    b_A43 = tl.load(p_A43, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 64 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 64 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad33 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 64 + 32) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad44 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 64 + 48) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), (i_t * 64 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 16) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai33 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 32 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai44 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 48 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai31 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai32 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 32) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai41 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai42 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai43 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 64 * r), (64 * r, 1), ((i_t * 64 + 48) * r, 32 * r), (16 * r, 16 * r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32)
+    Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32)
+
+    # first sub-diagonal: X_ij = -Ai_ii @ A_ij @ Ai_jj
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    Ai32 = -tl.dot(tl.dot(Ai33, b_A32, input_precision='ieee'), Ai22, input_precision='ieee')
+    Ai43 = -tl.dot(tl.dot(Ai44, b_A43, input_precision='ieee'), Ai33, input_precision='ieee')
+
+    Ai31 = -tl.dot(
+        Ai33,
+        tl.dot(b_A31, Ai11, input_precision='ieee') +
+        tl.dot(b_A32, Ai21, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai42 = -tl.dot(
+        Ai44,
+        tl.dot(b_A42, Ai22, input_precision='ieee') +
+        tl.dot(b_A43, Ai32, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai41 = -tl.dot(
+        Ai44,
+        tl.dot(b_A41, Ai11, input_precision='ieee') +
+        tl.dot(b_A42, Ai21, input_precision='ieee') +
+        tl.dot(b_A43, Ai31, input_precision='ieee'),
+        input_precision='ieee'
+    )
+
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    A = torch.empty(B * H * NT, BT * BT, r * r, device=k.device, dtype=output_dtype).contiguous()
+    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+        k, beta, mask, A,
+        T * K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT * 16 * r, 16 * r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B * H)](
+        A, Ad,
+        T * BT * r * r,  # s_A_bh
+        T * 16 * r * r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    A = rearrange(A, 'b (t l) (c r)->b (t c) (l r)', t=BT, c=r).contiguous()  # (BT*r, BT*r) blocks
+    if BT == 32:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_32x32_inverse_kernel[(NT, B * H)](
+            A, Ad, Ai,
+            T * BT * r * r,  # s_A_bh and s_Ai_bh
+            T * 16 * r * r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+    if BT == 64:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_64x64_inverse_kernel[(NT, B * H)](
+            A, Ad, Ai,
+            T * BT * r * r,  # s_A_bh and s_Ai_bh
+            T * 16 * r * r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
+    A = solve_tril(A=A, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+    w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
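+# A hedged reference for `chunk_scaled_dot_kkt_fwd` above, in plain einsum
+# form (shapes assumed: k (B, H, T, K), beta (B, H, T), mask (r, r)); the
+# name `_scaled_dot_kkt_ref` is illustrative and nothing below calls it.
+def _scaled_dot_kkt_ref(k, beta, mask, BT):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    kr = rearrange(k, 'b h (n t) (r d) -> b h n t r d', t=BT, r=r)
+    br = rearrange(beta, 'b h (n t) -> b h n t', t=BT)
+    # dots[t, s, c] = beta(t) * <k_c(t), k_c(s)> over the c-th K slice
+    dots = torch.einsum('bhnt,bhntcd,bhnscd->bhntsc', br, kr, kr)
+    A = mask.to(k.dtype) * dots[..., None, :]  # A[t, s, a, c] = mask[a, c] * dots[t, s, c]
+    tri = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=k.device), -1)
+    A = A * tri[None, None, None, :, :, None, None]  # strictly-lower causal mask in (t, s)
+    return rearrange(A, 'b h n t s a c -> (b h n) (t s) (a c)')
+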
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dmask = torch.zeros([B * H * NT, r, r], device=k.device, dtype=k.dtype).contiguous()
+    bwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta, dmask,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # grid is (NT, B*H*r): r is parallelized across programs
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K // r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+def fwd_prepare_dv(q, k, do, r, BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B, H, r, T, V, device=do.device, dtype=do.dtype)  # empty_like will not do: extra r dimension
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B * H * r)](
+        q, k, do, dv,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # u
+    d,  # w
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one (BK, BV) tile of the running state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))  # h[i_t] is stored before chunk i_t is absorbed
+        b_h_cumsum = tl.zeros([r, BK // r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # sweep the chunk in BC-row tiles
+                r_mask = tl.arange(0, r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK // r), (BC, BK // r), (1, 0))
+                p_d = tl.make_block_ptr(d + i_bh * T * r * K, (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC, 1, BK), (2, 1, 0))
+                p_v = tl.make_block_ptr(v + i_bh * T * r * V, (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC, 1, BV), (2, 1, 0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r) * T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, BK//r)
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))
+                b_v = tl.reshape(b_v, (BC, BV))
+                b_d = tl.reshape(b_d, (BC, BK))
+                b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False)  # v_new = u - w @ h
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+                bkv = tl.where(r_mask[:, None, None], tl.dot(tl.trans(b_k), b_v.to(b_k.dtype), allow_tf32=False)[None, :, :], 0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum, (BK, BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
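+# State recurrence implemented by the kernel above (per head, chunk index n):
+#     v_new[n] = u[n] - w[n] @ h[n]                    # w already carries beta and the mask
+#     h[n+1]   = h[n] + sum_r k_r[n]^T @ v_new_r[n]    # each r-slice of K updates its own rows of h
+# so h[n] is the state *before* chunk n has been absorbed.
+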
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    rk = K // r
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K // r, BK)):  # only the i_r-th K slice; different r-programs read the same q/k tiles
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)  # causal mask within the chunk
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_h_h,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # covers the whole K range (NK == 1)
+    for i_t in range(NT - 1, -1, -1):  # dh is stored shifted by one chunk: dh[i_t] excludes chunk i_t itself
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
+            p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K, T * r), (1, K),
+                                    (i_k * BK, i_t * BT * r + i_c * BC * r), (BK, BC * r), (0, 1))
+            p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_q = tl.load(p_q, boundary_check=(0, 1))
+            b_q = (b_q * scale).to(b_q.dtype)
+
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T * r, V), (V, 1),
+                                     (i_t * BT * r + i_c * BC * r, i_v * BV), (BC * r, BV), (1, 0))  # all r slices
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))  # (BC*r, BV)
+            b_dhtrans = tl.reshape(b_dh, (r, KR, BV))
+            for i_r in range(r):
+                rmask = tl.arange(0, r) == i_r  # slice i_r
+                p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                        (i_t * BT + i_c * BC, i_r * KR + i_k * BK), (BC, KR), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, KR)
+                b_dhr = tl.sum(tl.where(rmask[:, None, None], b_dhtrans, 0), 0)  # (KR, BV)
+                dv_sum = tl.dot(b_k, b_dhr.to(b_k.dtype), allow_tf32=False)  # (BC, BV)
+                b_dv += tl.reshape((dv_sum[:, None, :] * rmask[None, :, None]).to(b_dv.dtype), (BC * r, BV))
+
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T * r, V), (V, 1),
+                                      (i_t * BT * r + i_c * BC * r, i_v * BV), (BC * r, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
+
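+# Backward counterpart of the forward recurrence, run in reverse time:
+#     dh[n] = dh[n+1] + scale * q[n]^T @ do[n] - w[n]^T @ dv[n],
+# where dv[n] first picks up a k_r[n] @ dh[n+1] term per r-slice; dh is
+# stored shifted by one chunk, matching how h is stored in the forward.
+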
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r * K // r + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT * r, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_h = tl.load(p_h, boundary_check=(0, 1))  # (BV, BK)
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))
+
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # b_do spans the full V range; b_h carries the i_k slice
+        b_dk += tl.dot(b_v, b_dh, allow_tf32=False)  # contributes to dk; rows are independent
+        b_dv = tl.load(p_dv, boundary_check=(0, 1))  # (BT*r, BV)
+        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))  # (BT*r, BK)
+
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)  # (BT, BT)
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T * r * K, (T * r, K), (K, 1), (i_t * BT * r, i_r * K // r + i_k * BK), (BT * r, BK), (1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, (-b_dw.to(p_dw.dtype.element_ty)), boundary_check=(0, 1))
+
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT // T
+    BK = triton.next_power_of_2(K)  # the whole K range in one tile
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new keeps r ahead of the time axis
+    chunk_delta_rule_fwd_kernel_h[grid](  # r is looped inside the kernel, not over the grid
+        k, u, w, v_new, h, initial_state, final_state,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # more parallelism could be exposed here
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # accumulated across chunks, so it is computed in one pass
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v-> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # (B, H, T*r, V)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        T * K, K, 1,
+        NT * K * V,
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K // r,
+    )
+    return dh, dv2
+
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape (B, H, NT*K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        T * K, K, 1,
+        r * T * V, T * V, V,
+        NT * K * V, V,
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last backward piece: computes dw, dq and dk
+    BK = min(triton.next_power_of_2(K // r), 64)  # finer tiling so different r-slices never share a tile
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K // r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK picks the position inside one K slice
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dw = torch.empty_like(w)
+
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        T * K, K, 1,
+        T * V, V, 1,
+        NT * K * V, V,
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
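+# Dataflow assembled by the autograd Function below (a summary, not extra API):
+#     w, u, A  = fwd_prepare_wy_repr(k, v, beta, mask, BT)   # masked WY representation
+#     h, v_new = chunk_fwd_h_fn(k, w, u, BT, h0, ht)         # per-chunk recurrent state
+#     o        = chunk_fwd_o_fn(q, k, v_new, h, BT)          # q @ h plus causal intra-chunk term
+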
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the masked WY representation and the A matrix
+        r = mask.shape[-1]
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)
+
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+
+        if checkpoint_level == 1:
+            h, v_new = None, None  # recomputed in backward
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+        # checkpoint_level=1: recompute h and v_new.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # dv has w's layout: (B, H, r, T, V)
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # final dv for the WY-representation backward
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+        dk.add_(dk2)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # (r, r) mask applied blockwise to k k^T
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
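+# Minimal usage sketch (shapes assumed; the self-test under __main__ below is
+# a fuller example). BT must be one of {16, 32, 64} for the solve_tril path:
+#     q, k, v: (B, H, T, D) bfloat16 CUDA tensors, beta: (B, H, T) in (0, 1),
+#     mask: an (r, r) pattern with r dividing D, e.g. the banded 4x4 one below;
+#     o, final_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+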
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk**-0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d->b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d)->b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # matches the kernel's v_new
+        kui = torch.einsum('b h t r k,b h t r v-> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v-> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i + 1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d)->b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d)->b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask, shape (b h n r t l)
+    o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # (b h n t v)
+    o2 = torch.einsum('b h n t q,b h n q v-> b h n t v', q, state)  # inter-chunk term from the state only
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask, initial_state=None):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    if initial_state is None:
+        S = torch.zeros(b, h, d_k, d_v).to(v).float()
+    else:
+        S = initial_state
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d,b h v->b h d v', _k * beta_i, _k)
+        kkt = rearrange(kkt, ' b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum(' b h q k ,b h k v->b h q v', iplr.float(), S.clone()) + _k.unsqueeze(-1).float() * _v.unsqueeze(-2).float()
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q.float(), S).to(k.dtype)
+    return o
+
+
+if __name__ == "__main__":
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+
+    B = 2
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    q = torch.randn(B, H, L, DK).cuda().requires_grad_(True)
+    k = torch.randn(B, H, L, DK).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = torch.randn(B, H, L, DV).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    mask_grad, mask.grad = mask.grad, None
+    end = time.time()
+    print(end - start)
+
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    print((o - o1).abs().max())
+
+    print(o)
+    print(o1)
+    # q_grad0, q.grad = q.grad, None
+    # k_grad0, k.grad = k.grad, None
+    # v_grad0, v.grad = v.grad, None
+    # beta_grad0, beta.grad = beta.grad, None
+    # mask_grad0, mask.grad = mask.grad, None
+    # print((q_grad - q_grad0).abs().max())
+    # print((k_grad - k_grad0).abs().max())  # k grads still differ here, by up to ~1
+    # print((v_grad - v_grad0).abs().max())
+    # print((beta_grad - beta_grad0).abs().max())
+    # print((mask_grad - mask_grad0).abs().max())
+
+
diff --git a/fla2/ops/mask_delta_rule_t/recurrent_fuse.py b/fla2/ops/mask_delta_rule_t/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21470ff11d7e75df52b0c81dcb66bd40a44a0e5
--- /dev/null
+++ 
b/fla2/ops/mask_delta_rule_t/recurrent_fuse.py @@ -0,0 +1,330 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from ...utils import contiguous + +# on-the-fly computation without materializing hidden statets into HBMs + + +@triton.jit +def fused_recurrent_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V]. + beta, # beta [B, H, L] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + B, # batch size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _v_minus = tl.sum(h * b_k[None, :], axis=1) + b_v -= _v_minus + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # in-place overwrite + tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv) + b_v *= b_beta + h += b_k[None, :] * b_v[:, None] + _o = h * b_q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + p_beta += V if IS_HEADWISE_BETA else 1 + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + beta, # beta [B, H, L, (V)] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + dbeta, # gradient of beta [NV, (NK), B, H, L] + + # initial hidden state initialization [B, H, K, V] + h0, + + s_qk_h, # stride size: L * K + + s_vo_h, # stride size: L * V + + NK, # NK block size + scale, # K ** -0.5 + + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + else: + p_beta = beta + i_bh * T + T - 1 + + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V + else: + p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1 + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v) + d_v = d_v * b_beta + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + if IS_HEADWISE_BETA: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv) + else: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty)) + + d_h -= b_k[:, None] * d_v[None, :] + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + p_dbeta -= V if IS_HEADWISE_BETA else 1 + p_beta -= V if IS_HEADWISE_BETA else 1 + + tl.debug_barrier() + + h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) 
+ V + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + + h += b_k[:, None] * b_v[None, :] + _d_q = h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + if i < T - 1: + d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32) + d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32) + d_k -= tl.sum(d_v[None, :] * h, axis=1) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dk += K + p_dv += V + p_dq += K + p_beta += V if IS_HEADWISE_BETA else 1 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @contiguous + @staticmethod + def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False): + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + assert NK == 1, "NK > 1 is not supported yet" + o = q.new_empty(NK, B, H, T, V) + + if output_final_state: + final_state = q.new_empty(B, H, K, V) + else: + final_state = None + + grid = (NV, NK, B * H) + fused_recurrent_fwd_kernel[grid]( + q, k, v, beta, o, initial_state, final_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + IS_HEADWISE_BETA=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, beta, initial_state) + ctx.scale = scale + return o, final_state + + @contiguous + @staticmethod + def backward(ctx, do, dht=None): + q, k, v, beta, initial_state = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + beta_vector = beta.ndim == v.ndim + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + if beta_vector: + dbeta = q.new_empty(NV, NK, B, H, T, V) + else: + dbeta = q.new_empty(NV, B, H, T) + grid = (NV, NK, B * H) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, beta, do, dq, dk, dv, dbeta, initial_state, + q.stride(1), + v.stride(1), + NK, scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + IS_HEADWISE_BETA=beta_vector, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0) + return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None + + +def mask_fused_recurrent_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + scale: float = -1, + 
initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state) + return o, final_state diff --git a/fla2/ops/mask_delta_rule_t/utils.py b/fla2/ops/mask_delta_rule_t/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..173d6629c628bb6b5860a005cbc8ea85d7cf9b5e --- /dev/null +++ b/fla2/ops/mask_delta_rule_t/utils.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from ...ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2 +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + o, + o2, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = tl.arange(0, BK) < K + mask_bv = tl.arange(0, BV) < V + mask_bk = mask_bk[None, :] & mask_bt[:, None] + mask_bv = mask_bv[None, :] & mask_bt[:, None] + # [BT, BK] + b_k = tl.load(p_k, mask=mask_bk, other=0) + # [BT,] + b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32) + # [BT, BV] + b_v = tl.load(p_v, mask=mask_bv, other=0) + b_v = (b_v * b_beta[:, None]).to(b_v.dtype) + # [BT, BK] + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + # [BT, BT] + b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + b_A = b_A.to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + b_u = tl.dot(b_A, b_v, allow_tf32=False) + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk) + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def 
bwd_prepare_wy_repr_kernel(
+    k, v, beta,
+    o, o2, do, do2,
+    dk, dv, dbeta,
+    NT, K, V, T,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+
+    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
+    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
+    mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
+    mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
+    b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
+
+    b_beta = b_beta.to(tl.float32)
+    A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
+    A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
+    b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
+    b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
+    dA = tl.zeros([BT, BT], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    for i in range(BT - 1, -1, -1):
+        mask = tl.arange(0, BT) == i
+        attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
+        do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
+        dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
+        b_do = b_do - attn[:, None] * do_[None, :]
+        b_dv = b_dv - attn[:, None] * dv_[None, :]
+    tl.debug_barrier()
+    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    b_v = tl.load(p_v, mask=mask_bv)
+    b_dk += b_do * b_beta[:, None]
+    b_dbeta = tl.sum(b_do * b_k, axis=1)
+    b_dbeta += tl.sum(b_dv * b_v, axis=1)
+    b_v = None
+
+    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    b_o = tl.load(p_o, mask=mask_bk)
+    b_o2 = tl.load(p_o2, mask=mask_bv)
+
+    dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
+    dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype), allow_tf32=False)
+    dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
+    b_dv *= b_beta[:, None]
+    p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
+
+    b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
+    dA = dA * b_beta[:, None]
+    b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
+    b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
+    p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
+    p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
+
+
+def fwd_prepare_wy_repr(k, v, beta, chunk_size):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    v_new = torch.empty_like(v)
+    o_cumdecay = torch.empty_like(k)
+    BT = chunk_size
+    NT = triton.cdiv(T, BT)
+    BK = triton.next_power_of_2(K)
+    BV = triton.next_power_of_2(V)
+    fwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, o_cumdecay, v_new,
+        T, K, V, BT, BK, BV
+    )
+    return o_cumdecay, v_new
+
+
+def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size):
+    b, h, l, d_k = do.shape
+    d_v = v.shape[-1]
+ 
BK = triton.next_power_of_2(d_k) + BV = triton.next_power_of_2(d_v) + c = chunk_size + BK = d_k + NT = triton.cdiv(l, c) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.zeros_like(beta) + bwd_prepare_wy_repr_kernel[(NT, b*h)]( + k, v, beta, + o_cumdecay, v_new, do, do2, + dk, dv, dbeta, + NT, d_k, d_v, l, chunk_size, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + @contiguous + @autocast_custom_fwd + @staticmethod + def forward(ctx, k, v, beta, chunk_size): + o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + ctx.chunk_size = chunk_size + ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new) + return o_cumdecay, v_new + + @contiguous + @autocast_custom_bwd + @staticmethod + def backward(ctx, do, do2): + k, v, beta, o_cumdecay, v_new = ctx.saved_tensors + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size) + return dk, dv, dbeta, None + + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + seq_len = 2048 + b = 4 + h = 8 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 256) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + print("Start warmup.") + o1, o2 = prepare_wy_repr(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + o3, o4 = prepare_wy_repr2(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + print((o1 - o3).abs().max()) + print((o2 - o4).abs().max()) + + for i in range(30): + o1, o2 = prepare_wy_repr(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + o1, o2 = prepare_wy_repr2(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + + print("Done warmup.") + + import time + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) + + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = 
diff --git a/fla2/ops/mask_delta_rule_t/wy_fast.py b/fla2/ops/mask_delta_rule_t/wy_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..389c7ce5a173fbf651390f11dc3fbe4a58735a22
--- /dev/null
+++ b/fla2/ops/mask_delta_rule_t/wy_fast.py
@@ -0,0 +1,758 @@
+# -*- coding: utf-8 -*-
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) matrix
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r, (T, r, r), (r*r, r, 1), (i_t*BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # (BT, r, 1)
+        ij_mask = b_mask * r_mask[None, None, :]  # place the mask slice at column i_r
+
+        for i_k in range(tl.cdiv(dk, BK)):  # read and process k in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[:, None, :, :]
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    # (dev note: save this first to inspect)
+
+    for i in range(1, BT):  # the matrix now has shape (BT, r, BT, r)
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # get b_a: (BT, r, r)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # emulates a matmul: (BT, BT, r, r)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # (BT, r, r)
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row by row, swapping results in step by step
+    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT*r, BT*r))  # (BT*r, BT*r)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))  # the old implementation needed many more multiplications
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # the matrix inversion is done at this point; next, compute the outputs
+    b_A = b_A.to(k.dtype.element_ty)
+
+    for i_r in range(r):
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r, (T, r, r), (r*r, r, 1), (i_t*BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # (BT, r, 1)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask.to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v side, so no extra masking is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
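The `for i in range(1, BT)` sweep above is plain forward substitution: for `M = I + L` with `L` strictly lower (block-)triangular, the inverse is `I + N`, where row `i` of `N` is `-L[i] + (-L[i]) @ N` over the already-finished rows. A dense editorial sketch of the same recurrence (scalar blocks for clarity, not part of the patch):

import torch

def invert_unit_lower(L):
    # L: [n, n] strictly lower triangular; returns (I + L)^{-1}
    n = L.shape[0]
    N = -L.clone()
    for i in range(1, n):
        # rows < i of N are already final, and row i only has support there
        N[i] = N[i] + N[i] @ N
    return torch.eye(n, dtype=L.dtype, device=L.device) + N

L = torch.randn(8, 8, dtype=torch.float64).tril(-1)
I = torch.eye(8, dtype=torch.float64)
print(torch.allclose(invert_unit_lower(L) @ (I + L), I))  # expected: True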
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r, (T, r, r), (r*r, r, 1), (i_t*BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # (BT, r, 1)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask.to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v side, so no extra masking is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([BT, r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blockwise over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # (dev note: this reduction is where results diverged while debugging)
+        b_dv = sum_dv * b_beta[:, None]
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K // r
+    for i_r in range(r):
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r, (T, r, r), (r*r, r, 1), (i_t*BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # (BT, r, 1)
+        rmask = tl.arange(0, r) == i_r  # column i_r
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask).to(b_k.dtype)  # (BT, r, d)
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask, 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+
+            b_ss = tl.sum(b_dk_beta * b_beta[:, None, None] * b_k[:, None, :], -1)  # (BT, r)
+            b_dmask += (b_ss[:, :, None] * rmask[None, None, :]).to(tl.float32)  # (BT, r, r)
+
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to the dA of K K^T: mostly zeros, including the diagonal blocks
+
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r))
+    # now shaped (BT, r, BT, r)
+
+    for i_r in range(r):  # take only the i_r-th slice
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r, (T, r, r), (r*r, r, 1), (i_t*BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # (BT, r, 1)
+        rmask = tl.arange(0, r) == i_r  # column i_r
+        g = tl.sum(tl.where(rmask[None, None, None, :], b_dA, 0), -1)  # (BT, r, BT): extract column i_r
+        ir_A = tl.sum(g * b_mask, 1).to(k.dtype.element_ty)  # (BT, BT)
+        # the corresponding c-term
+
+        for i_k in range(tl.cdiv(block_k, BK)):  # usually a single iteration
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)  # (BT, BK)
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+            beta_kkt = tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)  # (BT, BT)
+
+            betas = tl.sum(beta_kkt[:, None, :] * g, -1)  # (BT, r)
+            b_dmask += (betas[:, :, None] * rmask[None, None, :]).to(tl.float32)
+
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+    p_dmask = tl.make_block_ptr(dmask + (i_bh * T + i_t * BT) * r * r, (BT, r, r), (r*r, r, 1), (0, 0, 0), (BT, r, r), (2, 1, 0))
+    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0, 1))
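The masked two-sided product near the end of this kernel is the standard gradient of a matrix inverse: for `T = M^{-1}`, `dM = -T^T dT T^T`, restricted to the strictly-lower block pattern that `M` actually carries. An editorial autograd check (not part of the patch):

import torch

n = 8
L = torch.randn(n, n, dtype=torch.float64, requires_grad=True)
T = torch.linalg.inv(torch.eye(n, dtype=torch.float64) + L.tril(-1))
dT = torch.randn_like(T)               # upstream gradient w.r.t. the inverse
T.backward(dT)
manual = -(T.mT @ dT @ T.mT).tril(-1)  # the kernel's masked two-sided product
print(torch.allclose(L.grad, manual))  # expected: True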
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "r"],
+)
+@triton.jit
+def chunk_scaled_dot_kkt_fwd_kernel(
+    k,
+    beta,
+    mask_ij,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    T,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) matrix
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r, (T, r, r), (r*r, r, 1), (i_t*BT, 0, i_r), (BT, r, 1), (2, 1, 0))
+        b_mask = tl.load(p_mask)  # (BT, r, 1)
+        ij_mask = b_mask * r_mask[None, None, :]  # place the mask slice at column i_r
+
+        for i_k in range(tl.cdiv(dk, BK)):  # read and process k in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[:, None, :, :]
+    b_A = tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh*T//BT + i_t)*BT*BT*r*r, (BT, BT, r, r), (BT*r*r, r*r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + i_bh*s_A_bh, (T, BT, r, r), (BT*r*r, r*r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row by row, swapping results in step by step
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16*r, 16*r))  # (16*r, 16*r)
+    p_Ad = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 16 * r, 0), (16*r, 16*r), (1, 0))
+    tl.store(p_Ad, b_A.to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 32 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 32*r), (32*r, 1), (i_t * 32 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_64x64_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_A31 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_A32 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_A41 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+    p_A42 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_A43 = tl.make_block_ptr(A + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 32*r), (16*r, 16*r), (1, 0))
+
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+    b_A31 = tl.load(p_A31, boundary_check=(0, 1)).to(tl.float32)
+    b_A32 = tl.load(p_A32, boundary_check=(0, 1)).to(tl.float32)
+    b_A41 = tl.load(p_A41, boundary_check=(0, 1)).to(tl.float32)
+    b_A42 = tl.load(p_A42, boundary_check=(0, 1)).to(tl.float32)
+    b_A43 = tl.load(p_A43, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 64 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad33 = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad44 = tl.make_block_ptr(Ad + i_bh*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), (i_t * 64 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai33 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 32*r), (16*r, 16*r), (1, 0))
+    p_Ai44 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 48*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai31 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai32 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai41 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai42 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai43 = tl.make_block_ptr(Ai + i_bh*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 32*r), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32)
+    Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32)
+
+    # block (i, j) of the inverse takes the (i, i) inverse on the left and the (j, j) inverse on the right
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    Ai32 = -tl.dot(tl.dot(Ai33, b_A32, input_precision='ieee'), Ai22, input_precision='ieee')
+    Ai43 = -tl.dot(tl.dot(Ai44, b_A43, input_precision='ieee'), Ai33, input_precision='ieee')
+
+    Ai31 = -tl.dot(
+        Ai33,
+        tl.dot(b_A31, Ai11, input_precision='ieee') +
+        tl.dot(b_A32, Ai21, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai42 = -tl.dot(
+        Ai44,
+        tl.dot(b_A42, Ai22, input_precision='ieee') +
+        tl.dot(b_A43, Ai32, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai41 = -tl.dot(
+        Ai44,
+        tl.dot(b_A41, Ai11, input_precision='ieee') +
+        tl.dot(b_A42, Ai21, input_precision='ieee') +
+        tl.dot(b_A43, Ai31, input_precision='ieee'),
+        input_precision='ieee'
+    )
+
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
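The two merge kernels apply the block-triangular inversion rule: for `L = [[A11, 0], [A21, A22]]`, `L^{-1} = [[A11^{-1}, 0], [-A22^{-1} A21 A11^{-1}, A22^{-1}]]`; the 64x64 variant chains the same rule over a 4x4 block grid. An editorial numeric check of the 2x2 case (not part of the patch):

import torch

n = 16
A11 = torch.eye(n) + torch.randn(n, n).tril(-1)
A22 = torch.eye(n) + torch.randn(n, n).tril(-1)
A21 = torch.randn(n, n)
L = torch.zeros(2 * n, 2 * n)
L[:n, :n], L[n:, n:], L[n:, :n] = A11, A22, A21

Ai11, Ai22 = torch.linalg.inv(A11), torch.linalg.inv(A22)
Li = torch.zeros_like(L)
Li[:n, :n], Li[n:, n:], Li[n:, :n] = Ai11, Ai22, -Ai22 @ A21 @ Ai11
print(torch.allclose(Li, torch.linalg.inv(L), atol=1e-4))  # expected: True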
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    A = torch.empty(B*H*NT, BT*BT, r*r, device=k.device, dtype=output_dtype).contiguous()
+    chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)](
+        k, beta, mask, A,
+        T*K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT*16*r, 16*r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B*H)](
+        A, Ad,
+        T*BT*r*r,  # s_A_bh
+        T*16*r*r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    A = rearrange(A, 'b (t l) (c r)->b (t c) (l r)', t=BT, c=r).contiguous()  # (BT*r, BT*r)
+    if BT == 32:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_32x32_inverse_kernel[(NT, B*H)](
+            A, Ad, Ai,
+            T*BT*r*r,  # s_A_bh and s_Ai_bh
+            T*16*r*r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+    if BT == 64:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_64x64_inverse_kernel[(NT, B*H)](
+            A, Ad, Ai,
+            T*BT*r*r,  # s_A_bh and s_Ai_bh
+            T*16*r*r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
+    A = solve_tril(A=A, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+    w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)  # typically 32
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dmask = torch.zeros([B, H, T, r, r], device=k.device, dtype=k.dtype).contiguous()
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta, dmask,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return dk, dv, dbeta, dmask
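`solve_tril` composes the pieces: invert every 16x16 diagonal block with `solve_tril_16x16_kernel`, then fill the strictly-lower blocks with the merge kernels. A dense editorial equivalent (hypothetical helper, any block size dividing `n`; not part of the patch):

import torch

def solve_tril_dense(M, blk=16):
    # M: [n, n] unit lower (block-)triangular; returns M^{-1} by block forward substitution
    n = M.shape[0]
    Mi = torch.zeros_like(M)
    for i in range(0, n, blk):
        Mi[i:i+blk, i:i+blk] = torch.linalg.inv(M[i:i+blk, i:i+blk])
    for i in range(0, n, blk):      # fill strictly-lower blocks row by row
        for j in range(0, i, blk):
            acc = M[i:i+blk, j:j+blk] @ Mi[j:j+blk, j:j+blk]
            for kk in range(j + blk, i, blk):
                acc = acc + M[i:i+blk, kk:kk+blk] @ Mi[kk:kk+blk, j:j+blk]
            Mi[i:i+blk, j:j+blk] = -Mi[i:i+blk, i:i+blk] @ acc
    return Mi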
+class WYRepresentationPrepration(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, k, v, beta, mask, chunk_size=64):
+        ctx.BT = chunk_size
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, ctx.BT)
+        ctx.save_for_backward(k, v, beta, mask, A)
+        return w, u
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, dw, du):
+        k, v, beta, mask, A = ctx.saved_tensors
+        BT = ctx.BT
+        dk, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT)
+        return dk, dv, dbeta, dmask, None
+
+prepare_wy_repr = WYRepresentationPrepration.apply
+
+
+def naive(k, v, beta, maskij, chunk_size):
+    l_org = k.shape[2]
+    l_new = triton.next_power_of_2(l_org)
+    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
+    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
+    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
+    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
+    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
+
+    b, h, nt, BT, dk = k.shape
+    dv = v.shape[-1]
+    r = maskij.shape[-1]
+    k_beta = k * beta[..., None]
+    k_beta = rearrange(k_beta, 'b h n t (r k)->b h n t r k', r=r)
+    k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k', k_beta, maskij)
+    k_beta = rearrange(k_beta, 'b h n t l r k->b h n t l (r k)')  # l=1, (r k) = original dim
+    v_beta = v * beta[..., None]
+    v_beta = v_beta.unsqueeze(-2).expand(-1, -1, -1, -1, r, -1)
+    ki = rearrange(k, 'b h n c (r k)-> b h n r c k', r=r)
+
+    attn = ki @ ki.transpose(-1, -2)
+    attn = torch.tril(attn, diagonal=-1)  # (b h n r, c, c)
+    attn = torch.einsum('b h n r t l,c r->b h n t l c r', attn, maskij)  # (b h n, r r, c c)
+    attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r', attn, beta)
+
+    o = torch.zeros_like(k_beta)
+    o2 = torch.zeros_like(v_beta)
+
+    o[..., 0, :, :] = k_beta[..., 0, :, :].clone()
+    o2[..., 0, :, :] = v_beta[..., 0, :, :].clone()
+    for i in range(1, chunk_size):
+        o_i = o[..., :i, :, :].clone()  # (b h n, :t, c c)
+        o[..., i, :, :] = -(attn[:, :, :, i, :i, :, :] @ o_i).sum(3) + k_beta[..., i, :, :]
+        o2_i = o2[..., :i, :, :].clone()  # one fewer dimension here
+        o2[..., i, :, :] = -(attn[:, :, :, i, :i, :, :] @ o2_i).sum(3) + v_beta[..., i, :, :]
+    return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2))
+
+
+if __name__ == "__main__":
+    # all checks happen here
+    import sys
+    sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    seq_len = 32
+    b = 2
+    h = 2
+    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d=128
+    v = torch.randn(b, h, seq_len, 128)
+    beta = torch.rand(b, h, seq_len).sigmoid()
+    require_grad = True
+    BT = 16
+    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta))
+    r = 4
+    # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous()
+    mask = torch.randn([r, r])
+    mask = mask.cuda().requires_grad_(require_grad).contiguous()
+    # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16)
+    # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16)
+    # from einops import rearrange
+
+    k2 = rearrange(k, 'b h (n t) (r k)-> b h n r t k', t=16, r=r)
+    b2 = rearrange(beta, 'b h (n t)-> b h n t', t=16)
+    a1 = (k2 * b2.unsqueeze(-2).unsqueeze(-1)) @ k2.transpose(-1, -2)  # (b h n r, t, t)
+    qq = torch.tril(a1, diagonal=-1)
+    qq = torch.einsum('b h n r t l,c r-> b h n t c l r', qq, mask)
+    sf = rearrange(qq, 'b h n t c l r->b h n (t c) (l r)')
+    sf = rearrange(sf, 'b h n (t c) (l r)->b h n t l c r', c=r, r=r)  # this one
+
+    # identity along the long block diagonal
+    i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :]))
+    s = sf + i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
+    s = rearrange(s, 'b h n a d c r->b h n (a c) (d r)')
+    s = torch.linalg.inv(s.float()).to(k)  # matrix inverse, (b h n, t*r, t*r)
+
+    # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)  # (b h nt, BT bt, r r)
+    # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.float32)
+    # s = rearrange(s,'b h n a c->(b h) (n a) c')
+    # print(Ad)
+    # print(s)
+    # print((Ad-s).abs().max())
+
+    w, u, As = fwd_prepare_wy_repr(k, v, beta, mask, 16)
+    As = rearrange(As, 'b h (n t) l->(b h n) t l', t=BT*r)
+    # print((As-s).abs().max())
+    # B*H*NT, BT*r, 16*r
+    # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2)
+    # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask)
+    # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)')
+    # wc = s_copy@k_exp
+
+    # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT)
+    # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2)
+    # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1)
+    # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v')
+    # uc = s_copy@v_exp
+    # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc))
+    # do = torch.rand_like(wc)
+    # do2 = torch.rand_like(uc)  # (b h n, t, t)
+    # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)  # (dev note: the naive reference is suspect)
+    # do = torch.rand_like(o1)
+    # do2 = torch.rand_like(o2)  # (b h n, t, t)
+    # if require_grad:
+    #     o1.backward(do, retain_graph=True)
+    #     o2.backward(do2, retain_graph=True)
+    #     k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad
+
+    # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16)
+    # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT)
+
+    # print((o1-w0).abs().max())
+    # print((o2-u0).abs().max())
+    # print((k_grad-k_grad2).abs().max())
+    # print((v_grad-v_grad2).abs().max())
+    # print((beta_grad-beta_grad2).abs().max())
+    # print((mask_grad-mask_grad2).abs().max())
+    # print(mask_grad)
+    # print(mask_grad2)
+
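A minimal usage sketch of the autograd wrapper (editorial; CUDA assumed, and the mask shape follows what the kernels index, i.e. a per-step `[B, H, T, r, r]` mask matching `dmask` in the backward):

import torch

B, H, T, D, r = 2, 2, 64, 128, 4
k = torch.nn.functional.normalize(torch.randn(B, H, T, D), dim=-1, p=2).cuda().requires_grad_()
v = torch.randn(B, H, T, D).cuda().requires_grad_()
beta = torch.rand(B, H, T).sigmoid().cuda().requires_grad_()
mask = torch.randn(B, H, T, r, r).cuda().requires_grad_()

w, u = prepare_wy_repr(k, v, beta, mask, 64)    # w: [B, H, r*T, D], u: [B, H, r*T, D]
(w.float().sum() + u.float().sum()).backward()  # grads reach k, v, beta and mask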
diff --git a/fla2/ops/mask_delta_rule_t/wy_fast_non.py b/fla2/ops/mask_delta_rule_t/wy_fast_non.py
new file mode 100644
index 0000000000000000000000000000000000000000..98b11f5743e8debffca59f9ce09c56ade7003d0d
--- /dev/null
+++ b/fla2/ops/mask_delta_rule_t/wy_fast_non.py
@@ -0,0 +1,491 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) matrix
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read down a column, so these are row indices
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # place the mask column at row position
+
+        for i_k in range(tl.cdiv(dk, BK)):  # read and process k in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    # (dev note: save this first to inspect)
+
+    for i in range(1, BT):  # the matrix now has shape (BT, r, BT, r)
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # get b_a: (BT, r, r)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # emulates a matmul: (BT, BT, r, r)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # (BT, r, r)
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row by row, swapping results in step by step
+    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT*r, BT*r))  # (BT*r, BT*r)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))  # the old implementation needed many more multiplications
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # the matrix inversion is done at this point; next, compute the outputs
+    b_A = b_A.to(k.dtype.element_ty)
+
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v side, so no extra masking is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        # r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v side, so no extra masking is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_v in range(tl.cdiv(V, BV)):  # blockwise over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # (dev note: this reduction is where results diverged while debugging)
+        b_dv = sum_dv * b_beta[:, None]
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K // r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)  # column i_r
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # (BT, r, d)
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            # here: (BT, r, BK)
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to the dA of K K^T: mostly zeros, including the diagonal blocks
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r))
+
+    for i_r in range(r):  # take only the i_r-th slice
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)  # column i_r
+        mask = tl.arange(0, r) == i_r  # column i_r
+        g = tl.sum(tl.where(mask[None, None, None, :], b_dA, 0), -1)  # (BT, r, BT): extract column i_r
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # (BT, BT)
+
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT*BT*r, BT*r, device=k.device, dtype=k.dtype)
+    fwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)  # typically 32
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,  # da,
+        dw, du,
+        dk, dv, dbeta,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return dk, dv, dbeta
+
+
+class WYRepresentationPrepration(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, k, v, beta, mask, chunk_size=64):
+        ctx.BT = chunk_size
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, ctx.BT)
+        ctx.save_for_backward(k, v, beta, mask, A)
+        return w, u
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, dw, du):
+        k, v, beta, mask, A = ctx.saved_tensors
+        BT = ctx.BT
+        dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT)
+        return dk, dv, dbeta, None, None
+
+prepare_wy_repr = WYRepresentationPrepration.apply
+
+
+# def naive(k, v, beta,mask,chunk_size):
+#     l_org = k.shape[2]
+#     l_new = triton.next_power_of_2(l_org)
+#     k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
+#     v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
+#     beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
+
+#     k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
+#     beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
+#     mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
+#     k_beta = k * beta[..., None]
+#     v = v * beta[..., None]
+#     attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
+#     attn = attn * beta[..., None]
+#     x = attn @ v
+
+#     o = torch.zeros_like(k)
+#     o2 = torch.zeros_like(v)
+
+#     o[..., 0, :] = k_beta[..., 0, :].clone()
+#     o2[..., 0, :] = x[..., 0, :].clone()
+#     for i in range(1, chunk_size):
+#         o_i = (o[..., :i, :]).clone()
+#         o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
+#         o2_i = (o2[..., :i, :]).clone()
+#         o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
+#     return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
+
+# use this naive reference
+# (dev note: this code is suspect)
+def naive(k, v, beta, maskij, chunk_size):
+    l_org = k.shape[2]
+    l_new = triton.next_power_of_2(l_org)
+    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
+    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
+    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
+    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
+    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
+
+    b, h, nt, BT, dk = k.shape
+    dv = v.shape[-1]
+    r = maskij.shape[-1]
+    k_beta = k * beta[..., None]
+    k_beta = rearrange(k_beta, 'b h n t (r k)->b h n t r k', r=r)
+    k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k', k_beta, maskij)
+    k_beta = rearrange(k_beta, 'b h n t l r k->b h n t l (r k)')  # l=1, (r k) = original dim
+    v_beta = v * beta[..., None]
+    v_beta = v_beta.unsqueeze(-2).expand(-1, -1, -1, -1, r, -1)
+    ki = rearrange(k, 'b h n c (r k)-> b h n r c k', r=r)
+    attn = ki @ ki.transpose(-1, -2)
+    attn = torch.tril(attn, diagonal=-1)  # (b h n r, c, c)
+    attn = torch.einsum('b h n r t l,c r->b h n t l c r', attn, maskij)  # (b h n, r r, c c)
+    attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r', attn, beta)
+
+    o = torch.zeros_like(k_beta)
+    o2 = torch.zeros_like(v_beta)
+
+    o[..., 0, :, :] = k_beta[..., 0, :, :].clone()
+    o2[..., 0, :, :] = v_beta[..., 0, :, :].clone()
+    for i in range(1, chunk_size):
+        o_i = o[..., :i, :, :].clone()  # (b h n, :t, c c)
+        o[..., i, :, :] = -(attn[:, :, :, i, :i, :, :] @ o_i).sum(3) + k_beta[..., i, :, :]
+        o2_i = o2[..., :i, :, :].clone()  # one fewer dimension here
+        o2[..., i, :, :] = -(attn[:, :, :, i, :i, :, :] @ o2_i).sum(3) + v_beta[..., i, :, :]
+    return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2))
+
+
+if __name__ == "__main__":
+    # all checks happen here
+    import sys
+    sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    seq_len = 32
+    b = 2
+    h = 2
+    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d=128
+    v = torch.randn(b, h, seq_len, 128)
+    beta = torch.rand(b, h, seq_len).sigmoid()
+    require_grad = True
+    BT = 16
+    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta))
+    r = 4
+    # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous()
+    mask = torch.randn([r, r])
+    mask = mask.cuda().requires_grad_(require_grad).contiguous()
+    w, u, a0 = fwd_prepare_wy_repr(k, v, beta, mask, 16)
+    # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16)
+    # from einops import rearrange
+
+    # k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = 16,r=r)
+    # b2 = rearrange(beta,'b h (n t)-> b h n t',t = 16)
+    # a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)  # (b h n r, t, t)
+    # qq = torch.tril(a1,diagonal=-1)
+    # qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask)
+    # sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)')
+    # sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)
+    # # identity along the long block diagonal
+    # i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :]))
+    # s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
+    # s = rearrange(s,'b h n a d c r->b h n (a c) (d r)')
+    # s = torch.linalg.inv(s.float()).to(k)  # matrix inverse, (b h n, t*r, t*r)
+    # s_copy = s
+
+    # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2)
+    # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask)
+    # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)')
+    # wc = s_copy@k_exp
+
+    # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT)
+    # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2)
+    # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1)
+    # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v')
+    # uc = s_copy@v_exp
+    # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc))
+    # do = torch.rand_like(wc)
+    # do2 = torch.rand_like(uc)  # (b h n, t, t)
+    o1, o2 = naive(k.clone(), v.clone(), beta.clone(), mask.clone(), BT)  # (dev note: the naive reference is suspect)
+    do = torch.rand_like(o1)
+    do2 = torch.rand_like(o2)  # (b h n, t, t)
+    print((o1-w).abs().max())
+    print((o2-u).abs().max())
+    if require_grad:
+        o1.backward(do, retain_graph=True)
+        o2.backward(do2, retain_graph=True)
+        k_grad2, v_grad2, beta_grad2, mask_grad2 = k.grad, v.grad, beta.grad, mask.grad
+
+    # k.grad = v.grad = beta.grad = None
+    # # wc.backward(do, retain_graph=True)
+    # # uc.backward(do2, retain_graph=True)
+    # # k_grad2, v_grad2, beta_grad2 = k.grad, v.grad, beta.grad
+    # # k.grad = v.grad = beta.grad = None
+    w0, u0, s0 = fwd_prepare_wy_repr(k, v, beta, mask, 16)
+    # print((wc-w0).abs().max())
+    # print((uc-u0).abs().max())
+    # print((wc-o1).abs().max())
+    # print((uc-o2).abs().max())
+    k_grad, v_grad, beta_grad = bwd_prepare_wy_repr(k, v, beta, mask, s0, do, do2, BT)
+
+    print((k_grad-k_grad2).abs().max())
+    print((v_grad-v_grad2).abs().max())
+    print((beta_grad-beta_grad2).abs().max())
+    # this variant does not produce a mask gradient
+    # print((mask_grad-mask_grad2).abs().max())
+    # print(mask_grad2)
+
diff --git a/fla2/ops/mask_delta_rule_t/wy_fast_test.py b/fla2/ops/mask_delta_rule_t/wy_fast_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7e2a8be22392f019f48c280037b35a861e76a42
--- /dev/null
+++ b/fla2/ops/mask_delta_rule_t/wy_fast_test.py
@@ -0,0 +1,676 @@
+# -*- coding: utf-8 -*-
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+# from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # logically an (r*BT) x (r*BT) matrix
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read down a column, so these are row indices
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # place the mask column at row position
+
+        for i_k in range(tl.cdiv(dk, BK)):  # read and process k in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    # (dev note: save this first to inspect)
+
+    for i in range(1, BT):  # the matrix now has shape (BT, r, BT, r)
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # get b_a: (BT, r, r)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # emulates a matmul: (BT, BT, r, r)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, BT) < i)[:, None, None])  # (BT, r, r)
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row by row, swapping results in step by step
+    b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (BT*r, BT*r))  # (BT*r, BT*r)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))  # the old implementation needed many more multiplications
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # the matrix inversion is done at this point; next, compute the outputs
+    b_A = b_A.to(k.dtype.element_ty)
+
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v side, so no extra masking is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        # r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v side, so no extra masking is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blockwise over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # (dev note: this reduction is where results diverged while debugging)
+        b_dv = sum_dv * b_beta[:, None]
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K // r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r
+        b_mask = tl.load(p_mask)  # column i_r
+        rmask = tl.arange(0, r) == i_r  # column i_r
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # (BT, r, d)
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+
+            b_ss = b_dk_beta * b_beta[:, None, None] * b_k[:, None, :]
+            b_ss = tl.reshape(tl.permute(b_ss, (2, 0, 1)), (BT*BK, r))
+            b_ss = tl.sum(b_ss, 0)
+            # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1))
+            b_dmask += (b_ss[:, None] * rmask[None, :]).to(tl.float32)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) #等价于 kkt的 dA 很多0,对角处 + + + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + #bt r bt r + + + for i_r in range(r):#只取ir项 + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask)#第ir列 + rmask = tl.arange(0, r) == i_r #第ir列 + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)#BT r BT #取出第ir列 + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + #对应的c部分 + + for i_k in range(tl.cdiv(block_k, BK)):#ik = 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT + + beta_y = (beta_kkt[:,None,:]*g) + beta_y = tl.reshape(tl.permute(beta_y,(2,0,1)),(BT*BT,r)) + betas = tl.sum(beta_y,0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r#列读,因而是行数目 + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]#行数 + + for i_k in range(tl.cdiv(dk, BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, 
(b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32) + b_A21 = tl.permute(b_A21,(0,2,1,3)) + b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + +def chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = 
torch.empty(B*H*NT, BT*BT, r*r, device=k.device, dtype=output_dtype).contiguous()
+    chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)](
+        k, beta, mask, A,
+        T*K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT*16*r, 16*r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B*H)](
+        A, Ad,
+        T*BT*r*r,  # s_A_bh
+        T*16*r*r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    NT = triton.cdiv(T, BT)
+    Ai = torch.zeros(B, H, NT*BT*r, BT*r, device=A.device, dtype=output_dtype)
+    merge_16x16_to_32x32_inverse_kernel[(NT, B*H)](
+        A, Ad, Ai,
+        T*BT*r*r,  # s_A_bh and s_Ai_bh
+        T*16*r*r,  # s_Ad_bh
+        T, r, BT
+    )
+    return Ai
+
+
+def fwd_prepare_wy_repr2(k, v, beta, mask, BT):
+    A = chunk_scaled_dot_kkt_fwd(k, beta, mask, BT, torch.float32)
+    A = solve_tril(A=A, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+    w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)
+    return w, u, A
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT*BT*r, BT*r, device=k.device, dtype=k.dtype)
+    fwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
+
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)  # 32
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dmask = torch.zeros([B*H*NT, r, r], device=k.device, dtype=k.dtype).contiguous()
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta, dmask,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask
+
+
+class WYRepresentationPrepration(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, k, v, beta, mask, chunk_size=64):
+        ctx.BT = chunk_size
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, ctx.BT)
+        ctx.save_for_backward(k, v, beta, mask, A)
+        return w, u
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, dw, du):
+        k, v, beta, mask, A = ctx.saved_tensors
+        BT = ctx.BT
+        dk, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT)
+        return dk, dv, dbeta, dmask, None
+
+prepare_wy_repr = WYRepresentationPrepration.apply
+
+
+def naive(k, v, beta, maskij, chunk_size):
+    l_org = k.shape[2]
+    l_new = triton.next_power_of_2(l_org)
+    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
+    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
+    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
+    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
+    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
+
+    b, h, nt, BT, dk = k.shape
+    dv = v.shape[-1]
+    r = maskij.shape[-1]
+    k_beta = k * beta[..., None]
+    k_beta = rearrange(k_beta, 'b h n t (r k) -> b h n t r k', r=r)
+    k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k', k_beta, maskij)
+    k_beta = rearrange(k_beta, 'b h n t l r k -> b h n t l (r k)')  # l=1, rk=org
+    v_beta = v * beta[..., None]
+    v_beta = v_beta.unsqueeze(-2).expand(-1, -1, -1, -1, r, -1)
+    ki = rearrange(k, 'b h n c (r k) -> b h n r c k', r=r)
+
+    attn = (ki @ ki.transpose(-1, -2))
+    attn = torch.tril(attn, diagonal=-1)  # bhnr cc
+    attn = torch.einsum('b h n r t l,c r->b h n t l c r', attn, maskij)  # bhn rr cc
+    attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r', attn, beta)
+
+    o = torch.zeros_like(k_beta)
+    o2 = torch.zeros_like(v_beta)
+
+    o[..., 0, :, :] = k_beta[..., 0, :, :].clone()
+    o2[..., 0, :, :] = v_beta[..., 0, :, :].clone()
+    for i in range(1, chunk_size):
+        o_i = (o[..., :i, :, :]).clone()  # bhn :t cc
+        o[..., i, :, :] = (-(attn[:, :, :, i, :i, :, :]@o_i).sum(3) + k_beta[..., i, :, :])
+        o2_i = (o2[..., :i, :, :]).clone()  # one fewer dimension
+        o2[..., i, :, :] = (-(attn[:, :, :, i, :i, :, :]@o2_i).sum(3) + v_beta[..., i, :, :])
+    return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2))
+
+
+if __name__ == "__main__":
+    # all compute here
+    import sys
+    torch.manual_seed(42)
+    sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    seq_len = 128
+    b = 2
+    h = 2
+    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d=128
+    v = torch.randn(b, h, seq_len, 128)
+    beta = torch.rand(b, h, seq_len).sigmoid()
+    require_grad = True
+    BT = 32
+    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta))
+    r = 4
+    # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous()
+    mask = torch.randn([r, r])
+    mask = mask.cuda().requires_grad_(require_grad).contiguous()
+    # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16)
+    # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16)
+    # from einops import rearrange
+
+    k2 = rearrange(k, 'b h (n t) (r k) -> b h n r t k', t=BT, r=r)
+    b2 = rearrange(beta, 'b h (n t) -> b h n t', t=BT)
+    a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1, -2)  # bhnrtt
+    qq = torch.tril(a1, diagonal=-1)
+    qq = torch.einsum('b h n r t l,c r-> b h n t c l r', qq, mask)
+    sf = rearrange(qq, 'b h n t c l r->b h n (t c) (l r)')
+    sf = rearrange(sf, 'b h n (t c) (l r)->b h n t l c r', c=r, r=r)  # this one
+
+    # long strip of diagonal identity blocks
+    i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :]))
+    s = sf + i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda()
+    s = rearrange(s, 'b h n a d c r->b h n (a c) (d r)')
+    s = torch.linalg.inv(s.float()).to(k)  # matrix inverse  # bhn tr tr
+
+
+    # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)#bh nt BT bt r r
+    # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.bfloat16)
+    # s = rearrange(s,'b h n a c->(b h n) a c')
+    # print(Ad.shape)
+    # print(s.shape)
+
+    w, u, As = fwd_prepare_wy_repr(k, v, beta, mask, BT)
+    w2, u2, Ad2 = fwd_prepare_wy_repr(k, v, beta, mask, BT)
+
+    print((w2-w).abs().max())
+    print((u2-u).abs().max())
+    print((As-Ad2).abs().max())
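+
+    # A hedged reference check (a sketch, assuming the layouts derived above):
+    # `As` returned by fwd_prepare_wy_repr stores the fused block inverses as
+    # (B, H, NT*BT*r, BT*r), while `s` is the dense torch.linalg.inv reference
+    # shaped (b, h, n, BT*r, BT*r); flattening the chunk axis should align them.
+    # `s_ref` is a hypothetical helper view introduced only for this comparison.
+    # s_ref = rearrange(s, 'b h n a c -> b h (n a) c')
+    # print((As - s_ref).abs().max())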
+
+    # print((Ad-s).abs().max())
+    # print(Ad-s)
+
+    # print((As-s).abs().max())
+    # print(As-s)
+    # B*H*NT,BT*r,16*r
+    # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2)
+    # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask)
+    # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)')
+    # wc = s_copy@k_exp
+
+    # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT)
+    # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2)
+    # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1)
+    # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v')
+    # uc = s_copy@v_exp
+    # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc))
+    # do = torch.rand_like(wc)
+    # do2 = torch.rand_like(uc)  # b h n t t
+    # o1, o2 = naive(k.clone(), v.clone(), beta.clone(), mask.clone(), BT)  # this code has issues
+    # do = torch.rand_like(o1)
+    # do2 = torch.rand_like(o2)  # b h n t t
+    # if require_grad:
+    #     o1.backward(do, retain_graph=True)
+    #     o2.backward(do2, retain_graph=True)
+    #     k_grad2, v_grad2, beta_grad2, mask_grad2 = k.grad, v.grad, beta.grad, mask.grad
+
+    # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16)
+    # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT)
+
+    # print((o1-w0).abs().max())
+    # print((o2-u0).abs().max())
+    # print((k_grad-k_grad2).abs().max())
+    # print((v_grad-v_grad2).abs().max())
+    # print((beta_grad-beta_grad2).abs().max())
+    # print((mask_grad-mask_grad2).abs().max())
+    # print(mask_grad)
+    # print(mask_grad2)
+
diff --git a/fla2/ops/mask_gated_delta_rule/README.md b/fla2/ops/mask_gated_delta_rule/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ab2d485a9552d70238c1f68288c72c62f9e0ef2
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/README.md
@@ -0,0 +1,4 @@
+- Delta Rule
+
+The implementation of the delta rule described in https://arxiv.org/abs/2102.11174
+
diff --git a/fla2/ops/mask_gated_delta_rule/__init__.py b/fla2/ops/mask_gated_delta_rule/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c675b3da981726a2b4a9919545e4f569682d710a
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/__init__.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+
+from .chunk import mask_gated_chunk_delta_rule
+# from .chunk_fuse import mask_fused_chunk_delta_rule
+# from .recurrent_fuse import mask_fused_recurrent_delta_rule
+
+__all__ = [
+    # 'mask_fused_chunk_delta_rule',
+    # 'mask_fused_recurrent_delta_rule',
+    'mask_gated_chunk_delta_rule',
+]
diff --git a/fla2/ops/mask_gated_delta_rule/__pycache__/__init__.cpython-310.pyc b/fla2/ops/mask_gated_delta_rule/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2320bff858e5e662428a0af522848b48f9d7c559
Binary files /dev/null and b/fla2/ops/mask_gated_delta_rule/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fla2/ops/mask_gated_delta_rule/__pycache__/chunk.cpython-310.pyc b/fla2/ops/mask_gated_delta_rule/__pycache__/chunk.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d99b8023ff467879e00a226fe055c6de6aedac7
Binary files /dev/null and b/fla2/ops/mask_gated_delta_rule/__pycache__/chunk.cpython-310.pyc differ
diff --git a/fla2/ops/mask_gated_delta_rule/__pycache__/wy_fast.cpython-310.pyc b/fla2/ops/mask_gated_delta_rule/__pycache__/wy_fast.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac2663a9e3ab32a47e175a2bdc70d396897a29e8
Binary files /dev/null and 
b/fla2/ops/mask_gated_delta_rule/__pycache__/wy_fast.cpython-310.pyc differ diff --git a/fla2/ops/mask_gated_delta_rule/chunk.py b/fla2/ops/mask_gated_delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e3310f7af7447c5741ef5980a7880114a95160 --- /dev/null +++ b/fla2/ops/mask_gated_delta_rule/chunk.py @@ -0,0 +1,1764 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from ...ops.mask_gated_delta_rule.wy_fast import (gated_chunk_scaled_dot_kkt_fwd,solve_tril, + gated_fwd_recompute_w_u) +from ...ops.utils import contiguous +from ...utils import autocast_custom_bwd, autocast_custom_fwd +from fla.ops.utils import chunk_local_cumsum +#finish +import torch.nn.functional as F +def ceildiv(a, b): + return -(a // -b) + +def pad(x, chunk_size=16): + seq_len = x.shape[-2] + #b n l d + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) + if x.shape[-1] % 32 != 0: + x = F.pad(x, (0, 32 - x.shape[-1] % 32)) + return x + +def pad_b(x, chunk_size=16): + seq_len = x.shape[-1] # 获取序列长度 l + padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size # 计算填充后的长度 + # 如果序列长度不是 chunk_size 的倍数,则进行填充 + if seq_len % chunk_size != 0: + x = F.pad(x, (0, padded_seq_len - seq_len),value=1.0) # 只在最后一个维度(l)进行填充 + return x + +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float('-inf'))) + + +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=1), +# triton.Config({}, num_warps=2), +# triton.Config({}, num_warps=4), +# triton.Config({}, num_warps=8), +# triton.Config({}, num_warps=16) +# ], +# key=["BT", "BK", "BV"], +# ) +# @triton.jit +# def fwd_prepare_dv_kernel( +# q, +# k, +# do, +# dv, +# s_qk_h, +# s_qk_t, +# s_qk_d, +# s_vo_h, +# s_vo_t, +# s_vo_d, +# T, +# K, +# V, +# scale, +# BT: tl.constexpr, +# BK: tl.constexpr, +# BV: tl.constexpr, +# r: tl.constexpr, +# ): +# i_t, i_bhr = tl.program_id(0), tl.program_id(1)#或许也可以r并行 +# i_bh = i_bhr//r +# i_r = i_bhr % r +# b_A = tl.zeros([BT, BT], dtype=tl.float32) +# block_r = K//r +# for i_k in range(tl.cdiv(block_r, BK)): +# p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) +# p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) +# b_k = tl.load(p_k, boundary_check=(0, 1)) +# b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) +# b_q = (b_q * scale).to(b_k.dtype) +# b_A += tl.dot(b_k, b_q, allow_tf32=False) +# b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) +# for i_v in range(tl.cdiv(V, BV)): +# p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# b_do = tl.load(p_do, boundary_check=(0, 1)) +# p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# b_dv = tl.dot(b_A, b_do, allow_tf32=False) +# tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +# #finish +# def fwd_prepare_dv(q, k, do, r,BT): +# B, H, T, K, V = *k.shape, do.shape[-1] +# dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like +# NT = triton.cdiv(T, BT) +# BK = min(triton.next_power_of_2(K//r),64) +# BV = min(triton.next_power_of_2(V), 64) +# fwd_prepare_dv_kernel[(NT, 
B*H*r)]( +# q, k, do, dv, +# T*K, K, 1, +# T*V, V, 1, +# T, K, V, K**-0.5, BT, BK, BV, r +# ) +# return dv + +# #finish +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=1), +# triton.Config({}, num_warps=2), +# triton.Config({}, num_warps=4), +# triton.Config({}, num_warps=8), +# triton.Config({}, num_warps=16) +# ], +# key=["BT", "BK", "BV"], +# ) +# @triton.jit +# def gated_chunk_delta_rule_fwd_kernel_h( +# k, +# v,#u +# d,#w +# v_new, +# g, +# h, +# initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] +# final_state, # final state of the chunk [B, H, D_head_K, D_head_V] +# H: tl.constexpr, +# T: tl.constexpr, +# K: tl.constexpr, +# V: tl.constexpr, +# BT: tl.constexpr, +# BC: tl.constexpr, +# BK: tl.constexpr, +# BV: tl.constexpr, +# NT: tl.constexpr, +# r: tl.constexpr, +# USE_INITIAL_STATE: tl.constexpr, +# STORE_FINAL_STATE: tl.constexpr +# ): +# i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) +# b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 +# if USE_INITIAL_STATE: +# p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) +# b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + +# for i_t in range(NT): +# p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) +# tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) +# #这里save是对的 +# b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) +# for i_r in range(r): +# for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 +# r_mask = tl.arange(0,r) == i_r +# p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), +# (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 +# p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), +# (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) +# p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), +# (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) +# p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), +# (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) +# b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC +# b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK +# b_v = tl.load(p_v, boundary_check=(0, 1, 2)) +# b_v = tl.reshape(b_v,(BC,BV)) +# b_d = tl.reshape(b_d,(BC,BK)) +# b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False)#ok #到这相等的 这里BC +# tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + +# last_idx = min((i_t + 1) * BT, T) - 1 +# b_g_last = tl.load(g + i_bh*T + last_idx) +# b_g_last = tl.exp(b_g_last) +# b_h = b_g_last * b_h + + +# bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) +# b_h_cumsum += bkv.to(b_h_cumsum.dtype) +# b_h += tl.reshape(b_h_cumsum,(BK,BV)) + +# if STORE_FINAL_STATE: +# p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) +# tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +# #finish +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=1), +# triton.Config({}, num_warps=2), +# triton.Config({}, num_warps=4), +# triton.Config({}, num_warps=8), +# triton.Config({}, num_warps=16) +# ], +# key=["BT", "BK", "BV"], +# ) +# @triton.jit +# def gated_chunk_linear_attn_fwd_kernel_o( +# q, +# k, +# v, +# h, +# g, +# o, +# s_qk_h, +# s_qk_t, +# s_qk_d, +# s_vo_h, +# s_vo_t, +# s_vo_d, +# s_h_h, +# s_h_t, +# scale, +# 
H: tl.constexpr, +# T: tl.constexpr, +# K: tl.constexpr, +# V: tl.constexpr, +# BT: tl.constexpr, +# BK: tl.constexpr, +# BV: tl.constexpr, +# r : tl.constexpr +# ): +# i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) +# i_bh = i_bhr//r +# i_r = i_bhr % r +# rk = K//r +# o_i = tl.arange(0, BT) +# m_s = o_i[:, None] >= o_i[None, :] +# b_o = tl.zeros([BT, BV], dtype=tl.float32) +# b_s = tl.zeros([BT, BT], dtype=tl.float32) +# for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r +# #问题是不同r_block读取了同一份qk,有影响吗 +# p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) +# p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) +# p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) +# b_q = tl.load(p_q, boundary_check=(0, 1)) +# b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) +# b_h = tl.load(p_h, boundary_check=(0, 1)) +# b_o += tl.dot(b_q, b_h, allow_tf32=False) +# b_s += tl.dot(b_q, b_k, allow_tf32=False) + +# p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) +# b_g = tl.load(p_g, boundary_check=(0,)) +# b_g_diff = b_g[:, None] - b_g[None, :] +# b_s = b_s * safe_exp(b_g_diff)[:,:]#BT BT + +# b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 +# p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) +# b_v = tl.load(p_v, boundary_check=(0, 1)) +# b_o = b_o * scale + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale +# p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +# #finish +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=1), +# triton.Config({}, num_warps=2), +# triton.Config({}, num_warps=4), +# triton.Config({}, num_warps=8), +# triton.Config({}, num_warps=16) +# ], +# key=["BT", "BK", "BV"], +# ) +# @triton.jit +# def chunk_delta_rule_bwd_kernel_dhu( +# q, +# k, +# d, +# do, +# dh, +# dv, +# dv2, +# s_qk_h, +# s_qk_t, +# s_qk_d, +# s_h_h, +# scale, +# H: tl.constexpr, +# T: tl.constexpr, +# K: tl.constexpr, +# V: tl.constexpr, +# BT: tl.constexpr, +# BC: tl.constexpr, +# BK: tl.constexpr, +# BV: tl.constexpr, +# NT: tl.constexpr, +# r: tl.constexpr, +# KR: tl.constexpr, +# ): +# i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) +# b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 +# for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 +# p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) +# tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) +# b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) +# #全列 +# for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): +# p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), +# (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))#全读取 +# p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), +# (i_k * BK, i_t * BT * r + i_c * BC *r), (BK, BC * r), (0, 1)) +# p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), +# (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) +# b_q = (tl.load(p_q, boundary_check=(0, 1))) +# b_q = (b_q * scale).to(b_q.dtype) + +# b_do = tl.load(p_do, boundary_check=(0, 1)) +# b_d = (tl.load(p_d,boundary_check=(0, 1))) +# p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V 
, 1), +# (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))#load r +# b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv +# b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) +# for i_r in range(r): +# rmask = tl.arange(0, r) == i_r #第ir列 +# p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), +# (i_t * BT + i_c * BC, i_r*KR + i_k * BK), (BC, KR), (1, 0))# +# b_k = tl.load(p_k, boundary_check=(0, 1)) #BC KR +# b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0)# KR BV +# dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False)#get BC*BV +# b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BC*r,BV)) + +# p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), +# (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) +# tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) +# b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) +# b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) +# b_dh += b_dh_tmp + + +# @triton.autotune( +# configs=[ +# triton.Config({}, num_warps=1), +# triton.Config({}, num_warps=2), +# triton.Config({}, num_warps=4), +# triton.Config({}, num_warps=8), +# triton.Config({}, num_warps=16) +# ], +# key=["BT", "BK", "BV"], +# ) +# @triton.jit +# def chunk_delta_rule_bwd_kernel_dqkw( +# q, +# k, +# v, +# w, +# h, +# do, +# dh, +# dq, +# dk, +# dv, +# dw, +# s_qk_h, +# s_qk_t, +# s_qk_d, +# s_vo_h, +# s_vo_t, +# s_vo_d, +# s_h_h, +# s_h_t, +# scale, +# H: tl.constexpr, +# T: tl.constexpr, +# K: tl.constexpr, +# V: tl.constexpr, +# BT: tl.constexpr, +# BK: tl.constexpr, +# BV: tl.constexpr, +# NT: tl.constexpr, +# r: tl.constexpr, +# ): +# i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) +# i_r = i_bhr%r +# i_bh = i_bhr//r +# o_i = tl.arange(0, BT) +# p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) +# p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) +# b_dq = tl.zeros([BT, BK], dtype=tl.float32) +# b_dk = tl.zeros([BT, BK], dtype=tl.float32) +# b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) +# b_ds = tl.zeros([BT, BT], dtype=tl.float32) +# for i_v in range(tl.cdiv(V, BV)): +# p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) +# p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) +# p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) +# p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + +# b_v = tl.load(p_v, boundary_check=(0, 1)) +# b_do = tl.load(p_do, boundary_check=(0, 1)) +# b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK +# b_dh =(tl.load(p_dh, boundary_check=(0, 1))) + +# b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok +# b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen +# b_dk += tl.dot(b_v, b_dh, allow_tf32=False)#用来计算dk,yes 行独立没问题 +# b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV +# b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + +# b_q = tl.load(p_q, boundary_check=(0, 1)) +# b_q = (b_q * scale).to(b_q.dtype) +# b_k = tl.load(p_k, boundary_check=(0, 1)) +# b_ds = 
tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT +# b_dq += tl.dot(b_ds, b_k, allow_tf32=False) +# b_dq *= scale +# b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) #这些应该没啥问题 +# p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) +# p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) +# p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) +# tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) +# tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) +# tl.store(p_dw, ((-b_dw.to(p_dw.dtype.element_ty))), boundary_check=(0, 1)) + + +# @triton.jit +# def preprocess_qkw(q, +# k, +# w, +# g, +# q_new, +# k_new, +# w_new, +# T, +# H, +# K, +# r, +# BT:tl.constexpr, +# BK:tl.constexpr, +# USE_Q:tl.constexpr, +# ): +# i_k,i_bh,i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2) + +# p_k = tl.make_block_ptr(k + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_w = tl.make_block_ptr(w +i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0)) +# p_g = tl.make_block_ptr(g+i_bh*T,(T,),(i_t*BT,),(BT,),(0,)) +# p_k_new = tl.make_block_ptr(k_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_w_new = tl.make_block_ptr(w_new +i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0)) + +# last_idx = min((i_t + 1) * BT, T) - 1 +# b_g_last = tl.load(g + last_idx * 1).to(tl.float32) #read BT 位置 + +# b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) +# b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32) +# b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) +# b_d_last = tl.exp(b_g_last - b_g) +# b_d_begin = tl.exp(b_g) +# b_k = b_k * b_d_last[:, None] +# b_w = b_w * b_d_begin[:, None] +# tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) +# tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1)) + + +# if USE_Q: +# p_q = tl.make_block_ptr(q + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# p_q_new = tl.make_block_ptr(q_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) +# b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) +# b_q = b_q * b_d_begin[:, None] +# tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + + +# #finish +# def gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state): +# B, H, T, K, V = *k.shape,u.shape[-1] +# _,_,rT,_ = w.shape +# r = rT//T +# BK = triton.next_power_of_2(K)#直接划分好 +# assert BK <= 256, "current kernel does not support head dimension larger than 256." 
+# BV = 16 if BK > 128 else 32 +# BV = 64 if BK <= 64 else BV +# BC = 16 if BK > 128 else 32 +# BC = 64 if BK <= 64 else BC +# BC = min(BT, BC) +# NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) +# assert NK == 1 +# h = k.new_empty(B, H, NT * K, V) + +# grid = (NK,B*H,NT) +# k_new = torch.empty_like(k) +# w_new = torch.empty_like(w) +# preprocess_qkw[grid]( +# q=None, +# k=k, +# w=w, +# g=g, +# q_new=None, +# k_new=k_new, +# w_new=w_new, +# T=T, +# H=H, +# K=K, +# r=r, +# BT=BT, +# BK=BK, +# USE_Q=False, +# ) +# grid = (NK, NV, B * H) +# v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)#做了v_new的r_first +# gated_chunk_delta_rule_fwd_kernel_h[grid](#r没有for循环 +# k_new,u,w_new, +# v_new,g,h, +# initial_state, +# final_state, +# H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, +# USE_INITIAL_STATE=initial_state is not None, +# STORE_FINAL_STATE=final_state is not None, +# ) +# return h, v_new + +# #finish +# def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): +# B,H,r,T,V,K = *dv.shape,q.shape[-1] +# BK = triton.next_power_of_2(K) +# assert BK <= 256, "current kernel does not support head dimension being larger than 256." +# BV = 16 if BK > 128 else 32 +# BV = 64 if BK <= 64 else BV +# BC = 16 if BK > 128 else 32 +# BC = 64 if BK <= 64 else BC +# BC = min(BT, BC) +# NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)#感觉可以放并行度 +# assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + +# dh = q.new_empty(B , H, NT * K,V)#一样的#need 求和 得一起算 +# grid = (NK, NV, B * H) +# dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() +# dv2 = torch.empty_like(dv)#一样的 #bhr T V +# chunk_delta_rule_bwd_kernel_dhu[grid]( +# q, k, w, do, dh, dv, dv2, +# T*K,K,1, +# NT*K*V, +# K**-0.5, +# H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, +# ) +# return dh, dv2 + +# #finish +# def gated_chunk_fwd_o_fn(q, k, v_new,h,g,BT): +# B,H,r,T,V,K = *v_new.shape,q.shape[-1] +# BK = triton.next_power_of_2(K//r) +# o = torch.empty_like(v_new)#there_fore,bhr nT,bv +# BK = min(triton.next_power_of_2(K//r), 64) +# BV = min(triton.next_power_of_2(V), 64) +# NV = triton.cdiv(V, BV) +# NT = triton.cdiv(T, BT) +# grid = (NV, NT, B * H * r) +# #h shape b h nk v +# gated_chunk_linear_attn_fwd_kernel_o[grid]( +# q, k, v_new, h, g, o, +# T*K, K, 1 , +# r*T*V,T*V,V, +# NT*K*V,V, +# scale=K**-0.5, +# H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, +# ) +# o = o.sum(dim=2)#沿着r维度求和 +# return o + + +# def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): +# B, H, T, K, V = *q.shape, v_new.shape[-1] +# _,_,RT,_ = w.shape +# r = RT // T +# #最后一个函数,计算dw,dq,dk +# BK = triton.next_power_of_2(K//r)#需要更细粒度的划分,确保不会使得 不同位置的划到一起 +# BK = min(triton.next_power_of_2(K//r), 64) +# BV = min(triton.next_power_of_2(V), 64) +# NK = triton.cdiv(K//r, BK) +# NT = triton.cdiv(T, BT) +# grid = (NK, NT, B * H * r)#通过NK控制位置 +# dq = torch.empty_like(q) +# dk = torch.empty_like(k)#k_org +# dw = torch.empty_like(w)#bh nt k + +# chunk_delta_rule_bwd_kernel_dqkw[grid]( +# q, k, v_new, w, h, do, dh, dq, dk, du, dw, +# T*K,K,1, +# T*V, V, 1, +# NT*K*V,V, +# scale=K ** -0.5, +# H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r = r +# ) +# return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + + +# class gated_ChunkDeltaRuleFunction(torch.autograd.Function): +# @staticmethod +# @contiguous +# @autocast_custom_fwd +# def forward(ctx, q, k, v, beta,g,mask,BT, initial_state, output_final_state, checkpoint_level=1): + +# g = chunk_local_cumsum(g,BT) +# Aw,Au = 
gated_chunk_scaled_dot_kkt_fwd(k=k,beta=beta,g_cumsum=g,mask=mask,BT=BT,output_dtype=torch.float32) +# Aw = solve_tril(A=Aw,output_dtype=k.dtype) +# Au = solve_tril(A=Au,output_dtype=k.dtype) +# r = mask.shape[-1] +# w, u = gated_fwd_recompute_w_u(k, v, beta, mask,Aw,Au,BT) +# final_state = None +# if output_final_state: +# final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], +# dtype=torch.float32, requires_grad=False)#这部分不需要修正 +# h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state)#need change' +# o = gated_chunk_fwd_o_fn(q, k, v_new, h, g, BT)#need change +# if checkpoint_level == 1: +# h, v_new = None, None #这里重新计算了? +# ctx.save_for_backward(q, k, v, beta,g, mask, Aw, Au , h, v_new, initial_state) +# ctx.BT = BT +# return o.to(q.dtype), final_state + +# # @staticmethod +# # @contiguous +# # @autocast_custom_bwd +# # def backward(ctx, do, d_ht=None): +# # q, k, v, beta,mask , A, h, v_new, initial_state = ctx.saved_tensors +# # BT = ctx.BT +# # r = mask.shape[-1] +# # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)#跳过 +# # # checkpont_level=1, recomputation. +# # if h is None: +# # h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None) +# # #v_new b h r T V +# # dv = fwd_prepare_dv(q, k, do, r, BT)#qk do v_new#因此这个dv应该是一个w的shape finish +# # #dv BHR T V +# # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)#new_dv dh #final for wyper dv +# # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT) +# # dk2, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT) +# # dk.add_(dk2) +# # return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dmask.to(mask.dtype), dg.to(mask.dtype), None, None, None + + +# def mask_gated_chunk_delta_rule( +# q: torch.Tensor, +# k: torch.Tensor, +# v: torch.Tensor, +# beta: torch.Tensor, +# g: torch.Tensor, +# mask: torch.Tensor,#use for mask org_tensor +# BT: int, +# initial_state: torch.Tensor = None, +# output_final_state: bool = False +# ): +# assert q.dtype == k.dtype == v.dtype +# assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16." 
+# seq_len = v.shape[-2] +# q, k, v = map(lambda x: pad(x,BT), [q, k, v]) +# beta = pad_b(beta,BT) +# g = pad_b(g,BT) +# o, final_state = gated_ChunkDeltaRuleFunction.apply(q, k, v, beta,g,mask, BT, initial_state, output_final_state) +# return o[..., :seq_len,:], final_state + + + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + g, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + #这里save是对的 + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2)) + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False)#ok #到这相等的 这里BC + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + i_bh*T + last_idx) + b_g_last = tl.exp(b_g_last) + b_h = b_g_last * b_h + + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + s_qk_h, + s_qk_t, + s_qk_d, 
+ s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r + #问题是不同r_block读取了同一份qk,有影响吗 + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h)#, allow_tf32=False) + b_s += tl.dot(b_q, b_k)#, allow_tf32=False) + + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * tl.exp(b_g)[:,None] + + b_g_diff = b_g[:, None] - b_g[None, :] + b_s = b_s * safe_exp(b_g_diff)#BT BT + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o * scale + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK"], +) +@triton.jit +def preprocess_qkw(q, + k, + w, + g, + q_new, + k_new, + w_new, + T, + H, + K, + r:tl.constexpr, + BT:tl.constexpr, + BK:tl.constexpr, + USE_Q:tl.constexpr, + ): + i_k,i_bh,i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_k = tl.make_block_ptr(k + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0)) + + p_g = tl.make_block_ptr(g+i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + p_k_new = tl.make_block_ptr(k_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w_new = tl.make_block_ptr(w_new +i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + i_bh*T + last_idx).to(tl.float32) #read BT 位置 + + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + b_d_last = tl.exp((b_g_last - b_g)) + b_d_begin = tl.exp(b_g) + b_k = b_k * b_d_last[:, None] + b_w = b_w * b_d_begin[:, None] + tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1)) + + + if USE_Q: + p_q = tl.make_block_ptr(q + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q_new = tl.make_block_ptr(q_new + 
i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_q = b_q * b_d_begin[:, None] + tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + +#finish +def gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state): + # k, w, u, g, BT, initial_state, final_state + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)#直接划分好 + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + + grid = (NK,B*H,NT) + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + preprocess_qkw[grid]( + q=None, + k=k, + w=w, + g=g, + q_new=None, + k_new=k_new, + w_new=w_new, + T=T, + H=H, + K=K, + r=r, + BT=BT, + BK=BK, + USE_Q=False, + ) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)#做了v_new的r_first + + gated_chunk_delta_rule_fwd_kernel_h[grid](#r没有for循环 + k_new,u,w_new, + v_new,g,h, + initial_state, + final_state, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + + +#finish +def gated_chunk_fwd_o_fn(q, k, v_new,h,g,BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new)#there_fore,bhr nT,bv + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + #h shape b h nk v + gated_chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, g, o, + T*K, K, 1 , + r*T*V,T*V,V, + NT*K*V,V, + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2)#沿着r维度求和 + return o + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_fwd_prepare_dv_kernel( + q, + k, + g, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)#或许也可以r并行 + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A* safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, 
BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def gated_fwd_prepare_dv(q, k, g, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + gated_fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, g , do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + + + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + g, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT), (BK, BT), (0, 1))#全读取 + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), + (i_k * BK, i_t * BT * r), (BK, BT * r), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_glast = tl.load(g + i_bh * T + last_idx) + b_glast = tl.exp(b_glast) + + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_d = (tl.load(p_d,boundary_check=(0, 1))) + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0))#load r + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + for i_r in range(r): + rmask = tl.arange(0, r) == i_r #第ir列 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT , i_r*KR + i_k * BK), (BT, KR), (1, 0))# + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0) + dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False) + b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BT*r,BV)) + + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dh *= b_glast + b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)-tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + + + +def gated_chunk_bwd_dhu_fn(q, k, w, g,h0, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)#感觉可以放并行度 + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B, H, NT * K,V)#一样的#need 求和 得一起算 + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + # grid = (NK,) + grid = (NK,B*H,NT) + preprocess_qkw[grid]( + q=q, + k=k, + w=w, + g=g, + q_new=q_new, + k_new=k_new, + w_new=w_new, + T=T, + H=H, + K=K, + r=r, + BT=BT, + BK=BK, + USE_Q=True, + ) + + + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)#一样的 #bhr T V + gated_chunk_delta_rule_bwd_kernel_dhu[grid]( + q_new, k_new, w_new, g, do, dh, dv, dv2, + T*K,K,1, + NT*K*V, + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + g, + h, + do, + dh, + dq, + dk, + dv, + dw, + dg, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + s_g_r, + s_g_k, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dg_last = tl.zeros([1,],dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK + b_dh = (tl.load(p_dh, boundary_check=(0, 1)))#需要额外添加r维度 + + b_dg_last += tl.sum(b_h * b_dh) #这里是存在r求和的 + + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, b_dh, allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV + b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + b_dg = tl.zeros([BT,], dtype=tl.float32) + p_g = tl.make_block_ptr(g 
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    g,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    dg,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    s_g_r,
+    s_g_k,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r * K // r + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT * r, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    b_dg_last = tl.zeros([1,], dtype=tl.float32)
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_h = tl.load(p_h, boundary_check=(0, 1))    # [BV, BK]
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))  # carries the extra r dimension via the K offset
+
+        b_dg_last += tl.sum(b_h * b_dh)  # this term is summed over r
+
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)   # b_do is full; b_h contributes the i_k slice
+        b_dk += tl.dot(b_v, b_dh, allow_tf32=False)   # rows are independent, so this is safe
+        b_dv = tl.load(p_dv, boundary_check=(0, 1))   # [BT*r, BV]
+        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))  # [BT*r, BK]
+
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+
+    b_dg = tl.zeros([BT,], dtype=tl.float32)
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_glast = tl.load(g + i_bh * T + (min(i_t * BT + BT, T) - 1))
+    b_dg_last *= tl.exp(b_glast)
+
+    p_w = tl.make_block_ptr(w + i_bh * T * r * K, (T * r, K), (K, 1), (i_t * BT * r, i_r * K // r + i_k * BK), (BT * r, BK), (1, 0))
+    b_w = tl.load(p_w, boundary_check=(0, 1))  # [BT*r, BK]
+    b_dw = b_dw * tl.reshape(tl.broadcast_to(tl.reshape(tl.exp(b_g), (BT, 1)), (BT, r)), (BT * r))[:, None]
+    b_dg -= tl.sum(tl.reshape(b_w * b_dw, (BT, r * BK)), -1)
+
+    b_dq = b_dq * scale * tl.exp(b_g)[:, None]
+    b_dg += tl.sum(b_dq * tl.trans(b_q), 1)  # [BT, BK] reduced over BK
+
+    b_dk = b_dk * safe_exp(b_glast - b_g)[:, None]
+    b_dg -= tl.sum(b_dk * b_k, 1)
+    b_dg_last += tl.sum(b_dk * b_k)
+
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]) * scale, 0)
+    b_ds2 = b_ds * (tl.dot(tl.trans(b_q), tl.trans(b_k)))
+
+    b_dg += tl.sum(b_ds2, axis=1)
+    b_dg -= tl.sum(b_ds2, axis=0)
+    b_ds = b_ds.to(b_k.dtype)
+
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
+
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T * r * K, (T * r, K), (K, 1), (i_t * BT * r, i_r * K // r + i_k * BK), (BT * r, BK), (1, 0))
+    p_dg = tl.make_block_ptr(dg + i_r * s_g_r + i_k * s_g_k + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_dg = tl.where(o_i < min(BT, T - i_t * BT) - 1, b_dg, b_dg + b_dg_last)
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
+
+
+def gated_chunk_bwd_dqkw_fn(q, k, v_new, w, g, h, dv, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    r = w.shape[2] // T
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K // r, BK)
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dw = torch.empty_like(w)
+    dg = torch.zeros(r, NK, B, H, T, device=g.device, dtype=torch.float)
+    grid = (NK, NT, B * H * r)
+    gated_chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, g, h, do, dh, dq, dk, dv, dw, dg,
+        T * K, K, 1,
+        T * V, V, 1,
+        NT * K * V, V,
+        NK * B * H * T, B * H * T,
+        K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r,
+    )
+    return dq, dk, dw, dg.sum((0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_bwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    g_cumsum,
+    Aw,
+    Au,
+    dw,
+    du,
+    dk,
+    dv,
+    dbeta,
+    dmask,
+    dg,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    block_k = K // r
+    # strictly-lower-block mask over the flattened [BT*r, BT*r] layout
+    da_mask = (tl.arange(0, BT * r)[:, None] // r) > (tl.arange(0, BT * r)[None, :] // r)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+
+    p_A = tl.make_block_ptr(Aw + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dA = tl.zeros([BT * r, BT * r], dtype=tl.float32)
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column i_r of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None].to(b_k.dtype)
+            b_k_beta = tl.reshape(b_k_beta, (BT * r, BK)).to(b_k.dtype)
+            b_dw_b = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw_b, tl.trans(b_k_beta), allow_tf32=False)
+
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # maps the inverse's gradient back onto the strictly lower blocks of the k k^T matrix
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r))
+
+    p_A = tl.make_block_ptr(Au + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dA2 = tl.zeros([BT * r, BT * r], dtype=tl.float32)
+
+    for i_v in range(tl.cdiv(V, BV)):  # blockwise over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # [r*BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # [BT, r, BV]
+        b_v_beta = tl.reshape(b_v_beta, (BT * r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # [BT*r, BT*r]
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)    # [BT*r, BV]
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # unlike the unmasked case, the r replicated rows are summed here
+        b_dv = sum_dv * b_beta[:, None]
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+    b_dA2 = tl.where(da_mask, b_dA2, 0)
+    b_dA2 = tl.dot(b_dA2.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA2 = tl.dot(tl.trans(b_A), b_dA2.to(b_A.dtype), allow_tf32=False)
+    b_dA2 = tl.where(da_mask, -b_dA2, 0)  # same strictly-lower-block gradient mapping as for b_dA
+    b_dA2 = tl.reshape(b_dA2, (BT, r, BT, r))
+
+    p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :])[:, None, :, None]
+    b_dA += b_dA2
+    b_dA2 = tl.permute(b_dA2, (0, 2, 1, 3))  # [BT, BT, r, r]
+
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)
+
+    for i_r in range(r):  # take only the i_r-th column slice
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column i_r of the mask
+        b_mask = tl.load(p_mask)
+        rmask = tl.arange(0, r) == i_r
+        g = tl.sum(tl.where(rmask[None, None, None, :], b_dA, 0), -1)  # [BT, r, BT]: column i_r of b_dA
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # [BT, BT]
+
+        for i_k in range(tl.cdiv(block_k, BK)):  # NK is 1 here
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)  # [BT, BK]
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+            beta_kkt = tl.dot(b_k_beta, tl.trans(b_k), allow_tf32=False)  # [BT, BT]
+            b_A += beta_kkt[:, :, None, None] * ((rmask[None, :] * b_mask[:, None])[None, None, :, :])  # FIXME: this broadcasts the full column, which may be incorrect
+
+        betas = tl.sum(tl.sum(beta_kkt[:, None, :] * g, -1), 0)
+        b_dmask += (betas[:, None] * rmask[None, :]).to(tl.float32)
+
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+    p_dmask = tl.make_block_ptr(dmask + (i_bh * (T // BT) + i_t) * r * r, (r, r), (r, 1), (0, 0), (r, r), (1, 0))
+    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0, 1))
+
+    b_dA2 *= b_A  # [BT, BT, r, r]
+    b_dA2 = tl.sum(tl.reshape(b_dA2, (BT, BT, r * r)), -1)
+
+    b_dg = tl.sum(b_dA2, 1) - tl.sum(b_dA2, 0)
+    p_dg = tl.make_block_ptr(dg + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
+
+
+def gated_bwd_prepare_wy_repr(k, v, beta, mask, g, Aw, Au, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dg = torch.empty_like(g)
+    dmask = torch.zeros([B * H * NT, r, r], device=k.device, dtype=k.dtype).contiguous()
+    assert BK <= K // r
+    gated_bwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, g, Aw, Au,
+        dw, du,
+        dk, dv, dbeta, dmask, dg,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask, dg
+
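+
+# For reference, the value-side WY recompute that forward() relies on can be
+# written in a few lines of PyTorch. A minimal sketch (illustrative only, not
+# called by the kernels; operates on a single [BT, V] chunk):
+def _recompute_u_reference(Au, v, beta, r):
+    # Au: [BT*r, BT*r] inverse from solve_tril; v: [BT, V]; beta: [BT].
+    # Mirrors gated_fwd_recompute_w_u_kernel's value path: each beta-weighted
+    # row of v is replicated r times before the triangular solve; the mask
+    # only enters on the key path.
+    vb = (v * beta[:, None]).repeat_interleave(r, dim=0)  # [BT*r, V]
+    return Au @ vb
+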
+class gated_ChunkDeltaRuleFunction(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, g, mask, BT, initial_state, output_final_state=False, checkpoint_level=1):
+        B, H, L, K = q.shape
+        g = chunk_local_cumsum(g, BT, head_first=True, output_dtype=torch.float)
+        Aw, Au = gated_chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, mask=mask, BT=BT, output_dtype=torch.float32)
+
+        Aw = solve_tril(A=Aw, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+        Au = solve_tril(A=Au, mask=mask, k=k, BT=BT, output_dtype=k.dtype)
+        # verified against the reference implementation up to this point
+        r = mask.shape[-1]
+        w, u = gated_fwd_recompute_w_u(k, v, beta, mask, Aw, Au, BT)
+
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)
+        h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state)
+        o = gated_chunk_fwd_o_fn(q, k, v_new, h, g, BT)
+        if checkpoint_level == 1:
+            h, v_new = None, None  # dropped here and recomputed in backward
+        ctx.save_for_backward(q, k, v, beta, g, mask, Aw, Au, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, g, mask, Aw, Au, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+
+        w, u = gated_fwd_recompute_w_u(k, v, beta, mask, Aw, Au, BT)
+        if h is None:
+            h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, None)
+
+        # dv here has w's row layout: [B, H, r, T, V]
+        dv = gated_fwd_prepare_dv(q, k, g, do, r, BT)
+
+        # produces dh together with the corrected dv used by the WY backward
+        dh, dv = gated_chunk_bwd_dhu_fn(q, k, w, g, initial_state, do, dv, BT)
+
+        dq, dk, dw, dg = gated_chunk_bwd_dqkw_fn(q, k, v_new, w, g, h, dv, do, dh, BT)  # this step is also a major bottleneck
+
+        # the only mask-dependent backward step; only the two dg contributions are suspect, the rest is verified
+        dk2, dv, dbeta, dmask, dg2 = gated_bwd_prepare_wy_repr(k, v, beta, mask, g, Aw, Au, dw, dv, BT)
+        dk.add_(dk2)
+        dg.add_(dg2)
+
+        dg = chunk_local_cumsum(dg, BT, reverse=True, head_first=True, output_dtype=torch.float)
+
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dg, dmask.to(mask.dtype), None, None, None, None
+
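+
+# Note on checkpointing: with checkpoint_level == 1 the forward pass above
+# drops `h` and `v_new` before saving, and backward() recomputes them from
+# (k, w, u, g); this trades one extra gated_chunk_fwd_h_fn call for not
+# keeping the O(B*H*NT*K*V) chunk states alive between forward and backward.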
+def mask_gated_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g: torch.Tensor,
+    mask: torch.Tensor,  # [r, r] mask applied to the k k^T sub-blocks
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    seq_len = v.shape[-2]
+    q, k, v = map(lambda x: pad(x, BT), [q, k, v])
+    beta = pad_b(beta, BT)
+    g = pad_b(g, BT)
+    o, final_state = gated_ChunkDeltaRuleFunction.apply(q, k, v, beta, g, mask, BT, initial_state, output_final_state)
+    return o[..., :seq_len, :], final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk ** -0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d) -> b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the stored v_new matches the kernel output
+        kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v -> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i + 1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d) -> b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d) -> b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask, [b h n r t l]
+    o1 = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # [b h n t v]
+    o2 = torch.einsum('b h n t q, b h n q v -> b h n t v', q, state)  # checks the state path in isolation
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, g, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d, b h v -> b h d v', _k * beta_i, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v) -> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v, r l -> b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v -> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        iplr = torch.einsum('b h q k, b h -> b h q k', iplr, g[:, :, i])
+        S = torch.einsum('b h q k, b h k v -> b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
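+
+# The tests below use a banded [r, r] mask. A small helper (illustrative only;
+# `width` is a hypothetical parameter, not part of the kernels) reproduces it:
+def _banded_mask(r, width=1):
+    # _banded_mask(4) gives [[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]]
+    i = torch.arange(r)
+    return ((i[:, None] - i[None, :]).abs() <= width).long()
+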
+
+if __name__ == "__main__":
+    import time
+
+    torch.set_default_dtype(torch.bfloat16)
+
+    B = 2
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    q = torch.randn(B, H, L, DK).cuda().requires_grad_(True)
+    k = torch.randn(B, H, L, DK).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = torch.randn(B, H, L, DV).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], requires_grad=False).cuda().contiguous()
+
+    g = torch.nn.functional.logsigmoid(torch.randn(B, H, L).cuda()).requires_grad_(True)
+    g_exp = torch.exp(g)
+
+    # start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, g_exp, mask)
+    # do = torch.randn(B, H, L, DV).cuda()
+    # o1.backward(do, retain_graph=True)
+    # q_grad, q.grad = q.grad, None
+    # k_grad, k.grad = k.grad, None
+    # v_grad, v.grad = v.grad, None
+    # beta_grad, beta.grad = beta.grad, None
+    # g_grad, g.grad = g.grad, None
+    # mask_grad, mask.grad = mask.grad, None
+    # end = time.time()
+    # print(end - start)
+
+    o, f_state = mask_gated_chunk_delta_rule(q, k, v, beta, g, mask, BT=32)  # first call is slow (~10 s) due to autotuning
+    # o.backward(do, retain_graph=True)
+    # q_grad0, q.grad = q.grad, None
+    # k_grad0, k.grad = k.grad, None
+    # v_grad0, v.grad = v.grad, None
+    # beta_grad0, beta.grad = beta.grad, None
+    # g_grad0, g.grad = g.grad, None
+    # mask_grad0, mask.grad = mask.grad, None
+
+    print((o - o1).abs().max())
+    # print((k_grad - k_grad0).abs().max())
+    # print((v_grad - v_grad0).abs().max())
+    # print((beta_grad - beta_grad0).abs().max())
+    # print((mask_grad - mask_grad0).abs().max())
+    # print((g_grad - g_grad0).abs().max())
+
diff --git a/fla2/ops/mask_gated_delta_rule/chunk_fuse.py b/fla2/ops/mask_gated_delta_rule/chunk_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6979fa906c6706bb07f6318b284920365db9eff
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/chunk_fuse.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+import triton
+import triton.language as tl
+
+from ...ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+
+def pad(x, chunk_size=16):
+    # x: [b, n, l, d]; pad the sequence length up to a multiple of chunk_size
+    # and the feature dimension up to a multiple of 32
+    seq_len = x.shape[-2]
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
+    if x.shape[-2] % chunk_size != 0:
+        x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
+    if x.shape[-1] % 32 != 0:
+        x = F.pad(x, (0, 32 - x.shape[-1] % 32))
+    return x
+
+
+def pad_b(x, chunk_size=16):
+    seq_len = x.shape[-1]  # sequence length l
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size  # length after padding
+    # pad only the last (sequence) dimension when l is not a multiple of chunk_size
+    if seq_len % chunk_size != 0:
+        x = F.pad(x, (0, padded_seq_len - seq_len), value=1.0)
+    return x
+
+
+# on-the-fly computation without materializing hidden states into HBM
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8) + ], + key=["BT", "BK"], +) +@triton.jit +def fused_chunk_delta_rule_fwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_K] + v, # value [B, H, L, D_head_V] + v_new, + d, # decay [B, H, L, D_head_K] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False) + b_v = b_v - b_v_prime + tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_v_new = tl.advance(p_v_new, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_d = tl.advance(p_d, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + 
tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fused_chunk_delta_rule_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + d, # decay [B, H, L, D_head_K] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + dd, # gradient of decay [NV, B, H, L, D_head_K] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + # first reverse + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh 
-= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + + tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + NT = tl.cdiv(T, BT) + for i in range(0, NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, DK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + if i < (NT - 1): + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False) + p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + ((i+1) * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1)) + + +def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BT = BT + # ctx.BT = BT + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1, 'NK should be 1' + o = q.new_empty(batch_size, n_heads, seq_len, d_head_v) + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + CHECK = True + # if version.parse(triton.__version__) < version.parse('2.2.0'): + # import warnings + # warnings.warn( + # "Triton<2.2.0 detected for running this kernel, " + # "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + # "that lead to significant precision loss. " + # "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. 
" + # "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + # ) + # CHECK = True + grid = (NV, NK, batch_size * n_heads) + v_new = torch.empty_like(v) + fused_chunk_delta_rule_fwd_kernel[grid]( + q, k, v, v_new, d, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + ) + return o, v_new, CHECK, final_state + + +def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1 + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + fused_chunk_delta_rule_bwd_kernel[grid]( + q, k, v, d, do, dq, dk, dv, dd, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=CHECK, + # num_warps=num_warps, + # num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dd = dd.sum(0) + dd[:, :, 0:BT] = 0 + return dq, dk, dv, dd + + +class FusedChunkDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0): + # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory. + assert checkpoint_level in [0, 1] + k_origin = k + # k = _l2_norm_fwd(k_origin) + k = k + d, v_new = fwd_prepare_wy_repr(k, v, beta, BT) + o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state) + if checkpoint_level == 1: + d, v_new = None, None + ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state) + ctx.CHECK = CHECK + ctx.chunk_size = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_final_state=None): + q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors + chunk_size = ctx.chunk_size + k = k_origin + # k = _l2_norm_fwd(k_origin) + if d is None: + d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state) + dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size) + dk.add_(dk2) + # dk = _l2_norm_bwd(k_origin, dk) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None + + +def mask_fused_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16." 
+ + if initial_state is not None: + initial_state = initial_state.detach() + seq_len = v.shape[-2] + d_head_v = v.shape[-1] + q, k, v = map(lambda x: pad(x), [q, k, v]) + beta = pad_b(beta) + o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state) + o = o[..., :seq_len, :d_head_v] + return o, final_state + + +def mask_delta_rule_recurrence(q, k, v, beta): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + k = torch.nn.functional.normalize(k, p=2, dim=-1) + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i[..., None] + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ == "__main__": + import torch.nn.functional as F + # seq_len = 128 + # b = 2 + # h = 4 + # q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1) + # beta = torch.rand(b, h, seq_len).sigmoid() + # q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta)) + # do = torch.rand_like(v) + # o2 = delta_rule_recurrence(q, k, v.clone(), beta) + # o2.backward(do, retain_graph=True) + # q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # o, _ = fused_chunk_delta_rule(q, k, v, beta, 32) + # o.backward(do, retain_graph=True) + # q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # print((o - o2).abs().max()) + # print((q_grad - q_grad2).abs().max()) + # print((k_grad - k_grad2).abs().max()) + # print((v_grad - v_grad2).abs().max()) + # print((beta_grad - beta_grad2).abs().max()) \ No newline at end of file diff --git a/fla2/ops/mask_gated_delta_rule/naive.py b/fla2/ops/mask_gated_delta_rule/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..af2f7bfc1fed0b6ae3519d97399bef1cd80b9470 --- /dev/null +++ b/fla2/ops/mask_gated_delta_rule/naive.py @@ -0,0 +1,1503 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange +from typing import Optional + +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from fla.ops.utils import chunk_local_cumsum + +from fla.ops import chunk_gated_delta_rule +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float('-inf'))) + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + Aw, + Au, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_Aw = tl.make_block_ptr(Aw + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + 
+    b_Aw = tl.load(p_Aw, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read column i_r of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :] * b_mask[None, :, None].to(b_k.dtype)  # [BT, r, BK]
+            b_kb = tl.reshape(b_kb, (BT * r, BK))
+            b_w = tl.dot(b_Aw, b_kb, allow_tf32=False)  # [BT*r, BK]
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * dk + i_k * BK), (BT * r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+    tl.debug_barrier()
+    b_Aw = None
+    p_Au = tl.make_block_ptr(Au + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_Au = tl.load(p_Au, boundary_check=(0, 1)).to(k.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):  # the u path does not involve the mask
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT * r, BV))
+        b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "r"],
+)
+@triton.jit
+def gated_chunk_scaled_dot_kkt_fwd_kernel(
+    k,
+    beta,
+    g_cumsum,
+    mask_ij,
+    A,
+    Ag,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    T,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # flattens to [r*BT, r*BT]
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column-wise read, so b_mask holds the row entries
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]  # rows gated by mask column i_r
+
+        for i_k in range(tl.cdiv(dk, BK)):  # blockwise over the K sub-dimension
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+    b_A = tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh * T // BT + i_t) * BT * BT * r * r, (BT, BT, r, r), (BT * r * r, r * r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+    p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_g_diff = b_g[:, None] - b_g[None, :]
+    b_g_diff = safe_exp(b_g_diff)
+
+    b_Ag = b_A * b_g_diff[:, :, None, None]  # decayed [BT, BT] variant
+    p_Ag = tl.make_block_ptr(Ag + (i_bh * T // BT + i_t) * BT * BT * r * r, (BT, BT, r, r), (BT * r * r, r * r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_Ag, b_Ag.to(p_Ag.dtype.element_ty), boundary_check=(0, 1, 2, 3))
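+
+# A minimal PyTorch sketch of what the kernel above builds per chunk
+# (illustrative only, not called by the kernels): A[t, s, a, b] is the
+# beta-weighted dot product of the b-th K sub-blocks of k_t and k_s, gated by
+# mask[a, b] and restricted to the strict lower triangle in (t, s).
+def _chunk_kkt_reference(k, beta, g, mask):
+    # k: [BT, K]; beta, g: [BT]; mask: [r, r]; K divisible by r.
+    BT, K = k.shape
+    r = mask.shape[-1]
+    kr = k.reshape(BT, r, K // r)
+    kb = kr * beta[:, None, None]
+    A = torch.einsum('tbd,sbd->tsb', kb, kr)[:, :, None, :] * mask[None, None].to(k.dtype)
+    tri = torch.arange(BT, device=k.device)
+    A = A * (tri[:, None] > tri[None, :])[:, :, None, None].to(k.dtype)
+    # decayed variant; the clamp mirrors safe_exp
+    Ag = A * torch.exp((g[:, None] - g[None, :]).clamp(max=0))[:, :, None, None]
+    return A, Ag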
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + i_bh * s_A_bh, (T, BT, r, r), (BT * r * r, r * r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)
+        b_a = b_a + tl.sum(q, 0) * ((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # row-by-row forward substitution, writing each solved row back in place
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16 * r, 16 * r))  # [BT*r, BT*r]
+    p_Ad = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 16 * r, 0), (16 * r, 16 * r), (1, 0))
+    tl.store(p_Ad, b_A.to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + i_bh * s_Ad_bh, (T * r, 16 * r), (16 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), (i_t * 32 * r, 0), (16 * r, 16 * r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 16 * r), (16 * r, 16 * r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + i_bh * s_A_bh, (T * r, 32 * r), (32 * r, 1), ((i_t * 32 + 16) * r, 0), (16 * r, 16 * r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_64x64_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0)) + p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0)) + p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0)) + p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0)) + p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0)) + p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0)) + + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32) + b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32) + b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32) + b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32) + b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32) + + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0)) + p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0)) + p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0)) + p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0)) + p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0)) + + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32) + Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32) + + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + Ai32 = -tl.dot(tl.dot(Ai33,b_A32, input_precision='ieee'),Ai11,input_precision='ieee') + Ai43 = -tl.dot(tl.dot(Ai44,b_A43, input_precision='ieee'),Ai11,input_precision='ieee') + + Ai31 = -tl.dot( + Ai33, + tl.dot(b_A31,Ai11, input_precision='ieee')+ + tl.dot(b_A32,Ai21, input_precision='ieee'), + input_precision='ieee') + + Ai42 = -tl.dot( + Ai44, + tl.dot(b_A42,Ai22, input_precision='ieee')+ + 
+        tl.dot(b_A43, Ai32, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai41 = -tl.dot(
+        Ai44,
+        tl.dot(b_A41, Ai11, input_precision='ieee') +
+        tl.dot(b_A42, Ai21, input_precision='ieee') +
+        tl.dot(b_A43, Ai31, input_precision='ieee'),
+        input_precision='ieee'
+    )
+
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def gated_chunk_scaled_dot_kkt_fwd(k: torch.Tensor,
+                                   beta: torch.Tensor,
+                                   mask: torch.Tensor,
+                                   g_cumsum: Optional[torch.Tensor] = None,
+                                   BT: int = 32,
+                                   output_dtype: torch.dtype = torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    A = torch.empty(B * H * NT, BT * BT, r * r, device=k.device, dtype=output_dtype).contiguous()
+    Ag = torch.empty(B * H * NT, BT * BT, r * r, device=k.device, dtype=output_dtype).contiguous()
+    gated_chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+        k, beta, g_cumsum, mask, A, Ag,
+        T * K, K, 1,
+        T, K, r, BT, BK
+    )
+    return A, Ag
+
+
+def solve_tril(A, mask, k, BT, output_dtype=torch.float32):
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, 16)
+    Ad = torch.empty(B, H, NT * 16 * r, 16 * r, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
+    solve_tril_16x16_kernel[(NT, B * H)](
+        A, Ad,
+        T * BT * r * r,  # s_A_bh
+        T * 16 * r * r,  # s_Ad_bh
+        T,
+        r, BT
+    )
+    if BT == 16:
+        return Ad
+
+    A = rearrange(A, 'b (t l) (c r) -> b (t c) (l r)', t=BT, c=r).contiguous()  # [BT*r, BT*r] blocks
+    if BT == 32:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_32x32_inverse_kernel[(NT, B * H)](
+            A, Ad, Ai,
+            T * BT * r * r,  # s_A_bh (shared with Ai)
+            T * 16 * r * r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+    if BT == 64:
+        NT = triton.cdiv(T, BT)
+        Ai = torch.zeros(B, H, NT * BT * r, BT * r, device=A.device, dtype=output_dtype)
+        merge_16x16_to_64x64_inverse_kernel[(NT, B * H)](
+            A, Ad, Ai,
+            T * BT * r * r,  # s_A_bh (shared with Ai)
+            T * 16 * r * r,  # s_Ad_bh
+            T, r, BT
+        )
+        return Ai
+
+
+def gated_fwd_recompute_w_u(k, v, beta, mask, Aw, Au, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    gated_fwd_recompute_w_u_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, Aw, Au,
+        T * K, K, 1,
+        T * V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
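+
+# A dense reference for what solve_tril computes blockwise (illustrative only,
+# not called by the kernels; this sketch ignores the r-block layout and simply
+# inverts a unit-lower-triangular matrix built from one flattened [n, n] tile):
+def _unit_lower_inverse_reference(A):
+    # the 16x16 kernel effectively inverts I + strict_tril(A) per tile
+    n = A.shape[-1]
+    L = torch.eye(n, dtype=torch.float32, device=A.device) + torch.tril(A.float(), -1)
+    return torch.linalg.inv(L)
+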
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # u
+    d,  # w
+    v_new,
+    g,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one full [BK, BV] stripe of the running state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))  # store the pre-update state of this chunk
+        b_h_cumsum = tl.zeros([r, BK // r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # BT is split into BC sub-blocks when BK is large
+                r_mask = tl.arange(0, r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK // r), (BC, BK // r), (1, 0))  # the matching K sub-block
+                p_d = tl.make_block_ptr(d + i_bh * T * r * K, (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC, 1, BK), (2, 1, 0))
+                p_v = tl.make_block_ptr(v + i_bh * T * r * V, (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC, 1, BV), (2, 1, 0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r) * T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # [BC, BK//r]
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))
+                b_v = tl.reshape(b_v, (BC, BV))
+                b_d = tl.reshape(b_d, (BC, BK))
+                b_v -= tl.dot(b_d, b_h.to(b_d.dtype), allow_tf32=False)  # remove the part of u already covered by the state
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))  # v_new agrees with the reference after this step
+
+                last_idx = min((i_t + 1) * BT, T) - 1
+                b_g_last = tl.load(g + i_bh * T + last_idx)
+                b_g_last = tl.exp(b_g_last)
+                b_h = b_g_last * b_h
+
+                bkv = tl.where(r_mask[:, None, None], tl.dot(tl.trans(b_k), b_v.to(b_k.dtype), allow_tf32=False)[None, :, :], 0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum, (BK, BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
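+
+# The intended per-chunk semantics of the kernel above, with the r-block
+# structure and BC tiling omitted (illustrative only, not called by the
+# kernels):
+def _gated_state_update_reference(h, k_chunk, w_chunk, u_chunk, g_last):
+    # h: [K, V]; k_chunk, w_chunk: [BT, K]; u_chunk: [BT, V]; g_last: scalar
+    v_new = u_chunk - w_chunk @ h  # subtract the part of u already in the state
+    h = torch.exp(g_last) * h + k_chunk.transpose(-1, -2) @ v_new
+    return h, v_new
+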
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    g,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    rk = K // r
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K // r, BK)):  # split over the r sub-blocks of K (K // BK == r here)
+        # different r-blocks read the same q/k tiles; the access is read-only, so this is safe
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h)
+        b_s += tl.dot(b_q, b_k)
+
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_o = b_o * tl.exp(b_g)[:, None]
+
+    b_g_diff = b_g[:, None] - b_g[None, :]
+    b_s = b_s * safe_exp(b_g_diff)  # [BT, BT]
+
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_s = tl.where(m_s, b_s, 0)  # zero out the upper triangle
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o * scale + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) * scale
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK"],
+)
+@triton.jit
+def preprocess_qkw(
+    q,
+    k,
+    w,
+    g,
+    q_new,
+    k_new,
+    w_new,
+    T,
+    H,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    USE_Q: tl.constexpr,
+):
+    i_k, i_bh, i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    p_k = tl.make_block_ptr(k + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_w = tl.make_block_ptr(w + i_bh * T * K * r, (T, r * K), (r * K, 1), (i_t * BT, i_k * r * BK), (BT, r * BK), (1, 0))
+
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    p_k_new = tl.make_block_ptr(k_new + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_w_new = tl.make_block_ptr(w_new + i_bh * T * K * r, (T, r * K), (r * K, 1), (i_t * BT, i_k * r * BK), (BT, r * BK), (1, 0))
+
+    last_idx = min((i_t + 1) * BT, T) - 1
+    b_g_last = tl.load(g + i_bh * T + last_idx).to(tl.float32)  # gate value at the last position of this chunk
+
+    b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
+    b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32)
+    b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
+    b_d_last = tl.exp(b_g_last - b_g)
+    b_d_begin = tl.exp(b_g)
+    b_k = b_k * b_d_last[:, None]
+    b_w = b_w * b_d_begin[:, None]
+    tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1))
+
+    if USE_Q:
+        p_q = tl.make_block_ptr(q + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_q_new = tl.make_block_ptr(q_new + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
+        b_q = b_q * b_d_begin[:, None]
+        tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1))
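+
+# A minimal PyTorch sketch of the per-chunk decay folding above (illustrative
+# only, not called by the kernels; g holds the within-chunk cumulative gates):
+def _preprocess_qkw_reference(q, k, w, g):
+    # q, k: [BT, K]; w: [BT, r*K]; g: [BT]
+    q_new = q * torch.exp(g)[:, None]          # decay measured from the chunk start
+    k_new = k * torch.exp(g[-1] - g)[:, None]  # decay up to the chunk end
+    w_new = w * torch.exp(g)[:, None]
+    return q_new, k_new, w_new
+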
g, BT, initial_state, final_state
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT // T
+    BK = triton.next_power_of_2(K)  # K is covered by a single key block
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+
+    grid = (NK, B * H, NT)
+    k_new = torch.empty_like(k)
+    w_new = torch.empty_like(w)
+    # fold the decay g into k/w up front so the recurrence kernel below is decay-free
+    preprocess_qkw[grid](
+        q=None,
+        k=k,
+        w=w,
+        g=g,
+        q_new=None,
+        k_new=k_new,
+        w_new=w_new,
+        T=T,
+        H=H,
+        K=K,
+        r=r,
+        BT=BT,
+        BK=BK,
+        USE_Q=False,
+    )
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new is laid out with the r dimension first
+
+    gated_chunk_delta_rule_fwd_kernel_h[grid](  # no for-loop over r inside the kernel
+        k_new, u, w_new,
+        v_new, g, h,
+        initial_state,
+        final_state,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+
+#finish
+def gated_chunk_fwd_o_fn(q, k, v_new, h, g, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # o: [B, H, r, T, V]; summed over r at the end
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape [B, H, NT * K, V]
+    gated_chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, g, o,
+        T * K, K, 1,
+        r * T * V, T * V, V,
+        NT * K * V, V,
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_fwd_prepare_dv_kernel(
+    q,
+    k,
+    g,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # parallel over (batch * head * r) via the second grid axis
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K // r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+def gated_fwd_prepare_dv(q,
k, g, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + gated_fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, g , do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + + + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + g, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT), (BK, BT), (0, 1))#全读取 + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), + (i_k * BK, i_t * BT * r), (BK, BT * r), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_glast = tl.load(g + i_bh * T + last_idx) + b_glast = tl.exp(b_glast) + + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_d = (tl.load(p_d,boundary_check=(0, 1))) + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0))#load r + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + for i_r in range(r): + rmask = tl.arange(0, r) == i_r #第ir列 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT , i_r*KR + i_k * BK), (BT, KR), (1, 0))# + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0) + dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False) + b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BT*r,BV)) + + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dh *= b_glast + b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)-tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + + + +def gated_chunk_bwd_dhu_fn(q, k, w, g,h0, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
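+    # Added sanity check (an assumption made explicit, not in the original):
+    # the kernels below slice q/k into r sub-blocks of width K // r (KR), so K
+    # must be divisible by the mask rank r for the block pointers to line up.
+    assert K % r == 0, "head dimension K must be divisible by the mask rank r"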
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)#感觉可以放并行度 + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B, H, NT * K,V)#一样的#need 求和 得一起算 + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + # grid = (NK,) + grid = (NK,B*H,NT) + preprocess_qkw[grid]( + q=q, + k=k, + w=w, + g=g, + q_new=q_new, + k_new=k_new, + w_new=w_new, + T=T, + H=H, + K=K, + r=r, + BT=BT, + BK=BK, + USE_Q=True, + ) + + + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)#一样的 #bhr T V + gated_chunk_delta_rule_bwd_kernel_dhu[grid]( + q_new, k_new, w_new, g, do, dh, dv, dv2, + T*K,K,1, + NT*K*V, + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + g, + h, + do, + dh, + dq, + dk, + dv, + dw, + dg, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + s_g_r, + s_g_k, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dg_last = tl.zeros([1,],dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK + b_dh = (tl.load(p_dh, boundary_check=(0, 1)))#需要额外添加r维度 + + b_dg_last += tl.sum(b_h * b_dh) #这里是存在r求和的 + + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, b_dh, allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV + b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + b_dg = tl.zeros([BT,], dtype=tl.float32) + p_g = tl.make_block_ptr(g 
+ i_bh * T ,(T,),(1,),(i_t*BT,),(BT,),(0,)) + b_g = tl.load(p_g,boundary_check=(0,)) + b_glast = tl.load(g +i_bh*T + (min(i_t * BT + BT, T) - 1)) + b_dg_last *= tl.exp(b_glast) + + + p_w = tl.make_block_ptr(w + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + b_w = tl.load(p_w,boundary_check=(0,1))#BT * r ,BK + b_dw = b_dw * tl.reshape(tl.broadcast_to(tl.reshape(tl.exp(b_g),(BT,1)),(BT,r)),(BT*r))[:,None] + b_dg -= tl.sum(tl.reshape(b_w*b_dw,(BT,r*BK)),-1) + + b_dq = b_dq*scale*tl.exp(b_g)[:,None] + b_dg += tl.sum(b_dq*tl.trans(b_q),1)#BT*BK + + b_dk = b_dk * safe_exp(b_glast-b_g)[:,None] + b_dg -= tl.sum(b_dk*b_k,1)#BT*BK + b_dg_last += tl.sum(b_dk*b_k) + + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds* safe_exp(b_g[:, None] - b_g[None, :]) * scale, 0) + b_ds2 = b_ds*(tl.dot(tl.trans(b_q),tl.trans(b_k))) + + b_dg += tl.sum(b_ds2,axis=1) + b_dg -= tl.sum(b_ds2,axis=0) + b_ds = b_ds.to(b_k.dtype) + + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) #这些应该没啥问题 + + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_r * s_g_r + i_k * s_g_k + i_bh * T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + b_dg = tl.where(o_i jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) #等价于 kkt的 dA 很多0,对角处 + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + + + p_A = tl.make_block_ptr(Au + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dA2 = tl.zeros([BT*r,BT*r], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)):#分块r + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))#r*BT BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)##BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)#BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)#BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))# + sum_dv = tl.sum(b_dv_beta,-2)#这里不一样,结果 + b_dv = (sum_dv * b_beta[:, None])#?哪一步结果不一样呢 + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA2 = tl.where(da_mask, b_dA2, 0) + b_dA2 = tl.dot(b_dA2.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA2 = tl.dot(tl.trans(b_A), b_dA2.to(b_A.dtype), allow_tf32=False) + b_dA2 = tl.where(da_mask, -b_dA2, 0) #等价于 kkt的 dA 很多0,对角处 + b_dA2 = tl.reshape(b_dA2,(BT,r,BT,r)) + + + p_g = tl.make_block_ptr(g_cumsum + i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + b_g = tl.load(p_g,boundary_check=(0,)) + b_dA2 *= 
safe_exp(b_g[:,None]-b_g[None,:])[:,None,:,None] + b_dA += b_dA2 + b_dA2 = tl.permute(b_dA2,(0,2,1,3))#Bt bt r r + + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32) + + for i_r in range(r):#只取ir项 + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask)#第ir列 + rmask = tl.arange(0, r) == i_r #第ir列 + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)#BT r BT #取出第ir列 + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + + for i_k in range(tl.cdiv(block_k, BK)):#ik = 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT + b_A += beta_kkt[:,:,None,None] * ((rmask[None,:] * b_mask[:,None])[None,None,:,:])#这列全广播了不对 + + betas = tl.sum(tl.sum(beta_kkt[:,None,:]*g,-1),0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + b_dA2 *= b_A #BT BT r r + b_dA2 = tl.sum(tl.reshape(b_dA2,(BT,BT,r*r)),-1) + + b_dg = tl.sum(b_dA2,1)-tl.sum(b_dA2,0) + p_dg = tl.make_block_ptr(dg+i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + + +def gated_bwd_prepare_wy_repr(k, v, beta, mask,g, Aw,Au, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + dg = torch.empty_like(g) + dmask = torch.zeros([B*H*NT,r,r],device=k.device,dtype=k.dtype).contiguous() + assert BK <= K//r + gated_bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, g, Aw,Au, + dw, du, + dk, dv, dbeta,dmask,dg, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + dmask = dmask.sum(0) + return dk, dv, dbeta, dmask,dg + + +class gated_ChunkDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta,g,mask,BT, initial_state, output_final_state=False, checkpoint_level=1): + B,H,L,K = q.shape + g = chunk_local_cumsum(g,BT,head_first=True,output_dtype=torch.float) + Aw,Au = gated_chunk_scaled_dot_kkt_fwd(k=k,beta=beta,g_cumsum=g,mask=mask,BT=BT,output_dtype=torch.float32) + + Aw = solve_tril(A=Aw,mask=mask,k=k,BT=BT,output_dtype=k.dtype) + Au = solve_tril(A=Au,mask=mask,k=k,BT=BT,output_dtype=k.dtype) + #到这里应该没啥问题 + r = mask.shape[-1] + w, u = gated_fwd_recompute_w_u(k, v, beta, mask,Aw,Au,BT)# + + final_state = None + if output_final_state: + final_state = q.new_empty(q.shape[0], q.shape[1], 
q.shape[-1], v.shape[-1],
+                                          dtype=torch.float32, requires_grad=False)  # no adjustment needed here
+        h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state)
+        # final_state agrees with the reference almost exactly
+        o = gated_chunk_fwd_o_fn(q, k, v_new, h, g, BT)
+        if checkpoint_level == 1:
+            h, v_new = None, None  # dropped here and recomputed in backward
+        ctx.save_for_backward(q, k, v, beta, g, mask, Aw, Au, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, g, mask, Aw, Au, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = gated_fwd_recompute_w_u(k, v, beta, mask, Aw, Au, BT)  # skipped in forward; recomputed here
+        end = time.time()
+        print('recompute_wu:', end - start)
+        if h is None:
+            h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, None)
+        start = time.time()
+
+        # the backward computation is rewritten from here on
+        dv = gated_fwd_prepare_dv(q, k, g, do, r, BT)  # uses q, k, do; the resulting dv has the same shape as w
+        end = time.time()
+        print('pre:', end - start)
+        # dv: [B, H, r, T, V]
+
+        start = time.time()
+        dh, dv = gated_chunk_bwd_dhu_fn(q, k, w, g, initial_state, do, dv, BT)  # dh plus the final dv fed to the WY-representation backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end - start)
+
+        start = time.time()
+        dq, dk, dw, dg = gated_chunk_bwd_dqkw_fn(q, k, v_new, w, g, h, dv, do, dh, BT)  # this step is also very slow
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end - start)
+        # only the two dg contributions are suspect; everything else checks out
+
+        start = time.time()
+        dk2, dv, dbeta, dmask, dg2 = gated_bwd_prepare_wy_repr(k, v, beta, mask, g, Aw, Au, dw, dv, BT)  # the only step that touches the mask
+        dk.add_(dk2)
+        dg.add_(dg2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end - start)
+        dg = chunk_local_cumsum(dg, BT, reverse=True, head_first=True, output_dtype=torch.float)
+
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), dg, dmask.to(mask.dtype), None, None, None
+
+
+def mask_gated_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g: torch.Tensor,
+    mask: torch.Tensor,  # [r, r] mixing mask applied to the key blocks
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
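+    # Minimal usage sketch (shapes inferred from the test at the bottom of this
+    # file; an assumption, not a stable API):
+    #   q, k: [B, H, T, K]; v: [B, H, T, V]; beta, g: [B, H, T]; mask: [r, r]
+    #   o, state = mask_gated_chunk_delta_rule(q, k, v, beta, g, mask, BT=32,
+    #                                          output_final_state=True)
+    assert k.shape[-1] % mask.shape[-1] == 0, "K must be divisible by the mask rank r"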
+    o, final_state = gated_ChunkDeltaRuleFunction.apply(q, k, v, beta, g, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def delta_rule_recurrence(q, k, v, beta, g, mask, initial_state=None):
+    # naive per-step reference recurrence used to validate the chunked kernels
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    if initial_state is None:
+        S = torch.zeros(b, h, d_k, d_v, device=k.device, dtype=torch.float32)
+    else:
+        S = initial_state
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d,b h v->b h d v', _k * beta_i, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        iplr = torch.einsum('b h q k,b h->b h q k', iplr, g[:, :, i])
+        S = torch.einsum('b h q k,b h k v->b h q v', iplr.float(), S.clone()) + _k.unsqueeze(-1).float() * _v.unsqueeze(-2).float()
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q.float(), S).to(k.dtype)
+    return o, S
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+    torch.manual_seed(42)
+
+    # for i in range(200):
+    B = 16
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    r = 4
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.randn([r, r])
+    mask = mask.cuda().requires_grad_(True).contiguous()
+
+    # mask = torch.ones([2,2])
+    # mask = mask.cuda().requires_grad_(True).contiguous()
+
+    g = torch.nn.functional.logsigmoid(torch.randn(B, H, L).cuda()).requires_grad_(True)
+    g_exp = (torch.exp(g))
+
+    do = torch.randn(B, H, L, DV).cuda()
+    o1, ss = delta_rule_recurrence(q, k, v, beta, g_exp, mask)
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    mask_grad, mask.grad = mask.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    g_grad, g.grad = g.grad, None
+    # end = time.time()
+    # print(end-start)
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, 64)
+    # o,f_state = mask_gated_chunk_delta_rule(q, k, v,beta,g,mask,BT=32,output_final_state=True)
+    # o2,f_state = mask_chunk_delta_rule(q, k, v,beta,mask,BT=32)
+
+    # qh,kh,vh,betah,gh = map(lambda x: rearrange(x, 'b h l ...
-> b l h ...'), (q, k, v, beta, g)) + # o,f_state = chunk_gated_delta_rule(qh,kh,vh,gh,(betah*rearrange(mask,'c r-> (c r)')).contiguous(),use_qk_l2norm_in_kernel=False,output_final_state=True) + # o = rearrange(o,'b l h d->b h l d') + o,f_state = mask_gated_chunk_delta_rule(q, k, v,beta,g,mask,BT=32,output_final_state=True) + o.backward(do,retain_graph=True) + q_grad0, q.grad = q.grad, None + k_grad0, k.grad = k.grad, None + v_grad0, v.grad = v.grad, None + beta_grad0, beta.grad = beta.grad, None + mask_grad0, mask.grad = mask.grad, None + g_grad0, g.grad = g.grad, None + print((o1-o).abs().max()) + print((f_state-ss).abs().max()) + print((q_grad-q_grad0).abs().max()) + print((k_grad-k_grad0).abs().max())#计算结果差距大 差距到1 + print((v_grad-v_grad0).abs().max()) + print((beta_grad-beta_grad0).abs().max()) + print((mask_grad-mask_grad0).abs().max()) + + print((g_grad-g_grad0).abs().max()) + print(g_grad) + print(g_grad0) + + + # o2,f_state2 = mask_gated_chunk_delta_rule(q, k, v,beta,g,mask,BT=32,output_final_state=True) + # o2.backward(do,retain_graph=True) + # q_grad2, q.grad = q.grad, None + # k_grad2, k.grad = k.grad, None + # v_grad2, v.grad = v.grad, None + # beta_grad2, beta.grad = beta.grad, None + # mask_grad2, mask.grad = mask.grad, None + + # print((o-o2).abs().max()) + # print((f_state-f_state2).abs().max()) + + # print((q_grad2-q_grad0).abs().max()) + # print((k_grad2-k_grad0).abs().max())#计算结果差距大 差距到1 + # print((v_grad2-v_grad0).abs().max()) + # print((beta_grad2-beta_grad0).abs().max()) + # print((mask_grad2-mask_grad0).abs().max()) + # print('naive:',mask_grad2) + # print('triton:',mask_grad0) + # print(k_grad2) + # print(k_grad0) + + + # BT = 16 + # w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, BT)#compute for A matrix #compute all + # print('finish0') + # h, v_new = chunk_fwd_h_fn(k, w, u, BT, None, None)#need change' + # print('finish1') + # o = chunk_fwd_o_fn(q, k, v_new, h, BT)#need change + # print('finish2') + # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)#跳过 + # print('finish3') + # dv = fwd_prepare_dv(q, k, do, r, BT)#qk do v_new#因此这个dv应该是一个w的shape finish + # print('finish4') + # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)#new_dv dh #final for wyper dv + # print('finish5') + # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)#这一步也巨慢 + # print('finish6') + + # Ass = rearrange(A,'b h (n t) l->b h n t l',n = L//BT) + # dwss = rearrange(dw,'b h (n t) k->b h n t k',n = L//BT) + # dvss = rearrange(dv,'b h (n t) k->b h n t k',n = L//BT) + # dk2, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT) + # print('triton:',dmask) #几乎完全相等 + + # vbeta = v*beta[...,None] + # vbeta = rearrange(vbeta,'b h (n T) d->b h n T d',T=BT) + # vbeta = vbeta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) + # vbeta = rearrange(vbeta,'b h n t r d-> b h n (t r) d') + + # kbeta = k*beta[...,None] + # kbeta = rearrange(kbeta,'b h (n T) (r d)->b h n T r d',T=BT,r=r) + # kbeta = torch.einsum('b h n T r d,c r-> b h n T c r d',kbeta,mask) + # kbeta = rearrange(kbeta,'b h n t c r d-> b h n (t c) (r d)') + # dA = dvss@vbeta.transpose(-1,-2)+dwss@kbeta.transpose(-1,-2) + + + # dorg = Ass.transpose(-1,-2)@dwss#bhn bt*r k + # dorg = rearrange(dorg,'b h n (t r) (c k)->b h n t r c k',r=r,c=r) + # betan = rearrange(beta,'b h (n t)->b h n t',n=L//BT) + # kn = rearrange(k,'b h (n t) (r d)->b h n t r d ',n = L//BT,r=r) + + # dmask = torch.einsum('b h n t r c k,b h n t->b h n t r c k',dorg,betan) + # dmask = torch.einsum('b h n t r c k,b h n t c k->b h n t r c k',dmask,kn) + # dmask = 
rearrange(dmask,'b h n t r c k-> (b h n) (t k) r c') + # dmaskss = dmask.sum(0).sum(0) + + # i = torch.arange(0, BT * r)[:, None] + # j = torch.arange(0, BT * r)[None, :] + # iB = i // r + # jB = j // r + # da_mask = iB > jB + # da_mask = da_mask.cuda() + # b_dA = torch.where(da_mask, dA, 0) + + # b_dA = b_dA @ Ass.transpose(-1,-2) + # b_dA = Ass.transpose(-1,-2)@b_dA + + # b_dA = torch.where(da_mask, -b_dA, 0) #等价于 kkt的 dA 很多0,对角处 + # b_dA = rearrange(b_dA,'b h n (t r) (l c)-> b h n t r l c',c=r,r=r) + # # print((dAss-b_dA).abs())#到这里也完全相等 + + + # # betakkt = k*beta[...,None] + # kbeta = k*beta[...,None] + # kbeta = rearrange(kbeta,'b h (n T) (r d)->b h n T r d',T=BT,r=r) + # kbeta2 = rearrange(k,'b h (n T) (r d)->b h n T r d',T=BT,r=r) + # betakkt = torch.einsum('b h n T r d,b h n s r d->b h n r T s',kbeta,kbeta2)#r Bt bt + # betakkt = rearrange(betakkt,'b h n r T s->b h n T s r')#BT r BT###横向 + # # print((dAss-b_dA).abs()) + + # #证明是下面的计算出错了 + # dmask = torch.einsum('b h n t r l c,b h n t l c-> b h n t r l c',b_dA,betakkt) + # # print((dAss-dmask).abs().max())#意味着这个计算结果也是对的 + # # print((dAss-dmask)) + + # dmask = rearrange(dmask,'b h n t r l c->b h n (t l) r c') + # dmask = dmask.sum(-3) + # dmask = dmask.sum(0).sum(0).sum(0) + # print('matrix:',dmask) + + + + + + + + diff --git a/fla2/ops/mask_gated_delta_rule/naive_rmbeta copy.py b/fla2/ops/mask_gated_delta_rule/naive_rmbeta copy.py new file mode 100644 index 0000000000000000000000000000000000000000..5aac72fd6ab3c1c7928194c488cb608129bf6fc0 --- /dev/null +++ b/fla2/ops/mask_gated_delta_rule/naive_rmbeta copy.py @@ -0,0 +1,1102 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:] + for i_k in range(tl.cdiv(dk, BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_kb = (b_k).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + #先save这个看看 + for i in range(1, BT):#此时矩阵为 BT,r,BT,r + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)#get ba BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)#矩阵乘法解决,get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < 
i)[:,None,None])#BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r + b_A += tl.arange(0, BT*r)[:,None] == tl.arange(0, BT*r)[None,:] + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))#旧版本实现需要很多乘法 + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + #解决矩阵求逆 + b_A = b_A.to(k.dtype.element_ty)#ok 解决求逆了 #下一步计算结果 + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 
,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +#compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk, dv, dbeta, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_v in range(tl.cdiv(V, BV)):#分块r + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))#r*BT BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)##BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)#BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)#BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))# + sum_dv = tl.sum(b_dv_beta,-2)#这里不一样,结果 + b_dv = (sum_dv * b_beta[:, None])#?哪一步结果不一样呢 + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(block_k, BK)):#assert block_k = BK + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + # b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)#BT*r*d + b_k_beta = ((b_k)[:,None,:]*b_mask[None,:,None]).to(b_k.dtype) + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)#get BT*r*BT*r + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + # b_dk = sum_dk* b_beta[:, None] + b_dk = sum_dk 
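+            # In this "rmbeta" variant the beta scaling on k is dropped, so dk
+            # is just the mask-weighted sum over the r replicas; the commented
+            # lines nearby keep the beta-scaled version for reference.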
+ # b_dbeta += tl.sum(sum_dk * b_k, 1) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) + b_dA = tl.reshape(b_dA,(BT,r,BT,r)).to(k.dtype.element_ty)#到这应该都是对的 + + for i_r in range(r):#只取ir项 + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask)#第ir列 + mask = tl.arange(0, r) == i_r + g = tl.sum(tl.where(mask[None,None,None,:], b_dA, 0), -1)#BT r BT 取最后一列, + #这里对应 kr 部分 + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + # b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_k_beta = (b_k).to(b_k.dtype) + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + # b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta #* b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))#这里也没问题吧 + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + assert BK == K//r + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=torch.float32) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, r, BT, BK, BV + ) + return w, u, A + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, r,BT, BK, BV + ) + return w, u + +def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + assert BK == K//r + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A,#da, + dw, du, + dk, dv, dbeta, + k.stride(1), k.stride(2), k.stride(3), + 
v.stride(1), v.stride(2), v.stride(3), + T, K, V, r, BT, BK, BV + ) + return dk, dv, dbeta#,da + + +# from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)#或许也可以r并行 + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + k.stride(1), k.stride(2), k.stride(3), + do.stride(1), do.stride(2), do.stride(3), + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)#assert ik=1 all use + b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * 
BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + #这里save是对的 + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2))#BC + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)#ok #到这相等的 这里BC + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r + #问题是不同r_block读取了同一份qk,有影响吗 + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + # p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_r * rk + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o 
+ i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (s_h_t, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + #全列 + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))#全读取 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))# + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T*r,K), (K, 1), + (i_t * BT * r + i_c * BC *r,i_k * BK), (BC * r,BK), (1, 0))#读取 BC r BK的内容 + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_d = tl.trans(tl.load(p_d,boundary_check=(0, 1))) + b_k = tl.permute(tl.reshape(b_k,(BC,r,KR)),(1,0,2))#r BC KR + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + dv_sum = tl.sum(b_k[:,:,:,None]*b_dhtrans.to(b_k.dtype)[:,None,:,:],-2) #get r BC BV + b_dv += tl.reshape(tl.permute(dv_sum,(1,0,2)),(BC*r,BV)) + #bhtrv + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = 
tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT,r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r* K// r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))#BV BK + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)),(BT,r,BV))#BT*r BV + b_dw += tl.sum(b_dv.to(b_v.dtype)[:,:,:,None]*b_h.to(b_v.dtype)[None,None,:,:],-2)#get BT r BK + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False)) #这些应该没啥问题 + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT, 0 ,i_r*K//r + i_k * BK), (BT, r ,BK), (2, 1, 0)) + # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT ,i_r, i_k * BK), (BT, 1, BK), (2, 1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, (tl.reshape(-b_dw.to(p_dw.dtype.element_ty),(BT,r,BK))), boundary_check=(0, 1)) + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)#直接划分好 + assert BK <= 256, "current kernel does not support head dimension larger than 256." 
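+    # Note on the heuristic below (added comment, my reading of the code): as
+    # BK grows, the value tile BV and the chunk tile BC shrink so that the
+    # per-program [BK, BV] state block still fits on-chip.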
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)#做了v_new的r_first + chunk_delta_rule_fwd_kernel_h[grid](#r没有for循环 + k, u, w, v_new, h, initial_state, final_state, + k.stride(1), k.stride(2), k.stride(3), + u.stride(1), u.stride(2), u.stride(3), #rt*v,v,1 + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + +#finish +def chunk_bwd_dhu_fn(q, k, w, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)#感觉可以放并行度 + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B , H, NT * K,V)#一样的#need 求和 得一起算 + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv)#一样的 #bhr T V + chunk_delta_rule_bwd_kernel_dhu[grid]( + q, k, w, do, dh, dv, dv2, + q.stride(1), q.stride(2), q.stride(3), + do.stride(1), do.stride(2), do.stride(3), + dh.stride(1), dh.stride(2), + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +#finish +def chunk_fwd_o_fn(q, k, v_new, h, BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new)#there_fore,bhr nT,bv + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + #h shape b h nk v + chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, o, + q.stride(1), q.stride(2), q.stride(3), + v_new.stride(1), v_new.stride(2), v_new.stride(3), + h.stride(1), h.stride(2), + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2)#沿着r维度求和 + return o + + +def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT): + B, H, T, K, V = *q.shape, v_new.shape[-1] + _,_,RT,_ = w.shape + r = RT // T + #最后一个函数,计算dw,dq,dk + BK = triton.next_power_of_2(K//r)#需要更细粒度的划分,确保不会使得 不同位置的划到一起 + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K//r, BK) + NT = triton.cdiv(T, BT) + grid = (NK, NT, B * H * r)#通过NK控制位置 + dq = torch.empty_like(q) + dk = torch.empty_like(k)#k_org + dw = torch.empty_like(w)#bh nt k + chunk_delta_rule_bwd_kernel_dqkw[grid]( + q, k, v_new, w, h, do, dh, dq, dk, du, dw, + q.stride(1), q.stride(2), q.stride(3), + T*V, V, 1, + dh.stride(1), dh.stride(2), + scale=K ** -0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r = r + ) + return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype) + + +class ChunkDeltaRuleFunction(torch.autograd.Function): + #前向写完了 + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta,mask,BT, initial_state, output_final_state, checkpoint_level=1): + start = time.time() + w, u, A = fwd_prepare_wy_repr(k, v,beta, mask, BT)#compute for A matrix #compute 
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # forward pass complete
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # compute the full A matrix for all chunks
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # no correction term needed here
+        end = time.time()
+        print('compute_A:', end - start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+        end = time.time()
+        print('compute_h_s:', end - start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        end = time.time()
+        print('compute_o:', end - start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # cheap: reuses the saved A
+        end = time.time()
+        print('recompute_wu:', end - start)
+        # checkpoint_level=1: recompute h and v_new instead of storing them.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: [B, H, r, T, V]
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # intra-chunk dv from q, k, do; shares w's layout family
+        end = time.time()
+        print('pre:', end - start)
+        # dv: [B, H, r, T, V]
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # dh plus the final dv fed into the WY backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end - start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end - start)
+
+        start = time.time()
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)  # this step has the largest numerical error
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end - start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # [r, r] pattern used to mask the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
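+
+# Usage sketch (hypothetical shapes; requires CUDA and K divisible by r = mask.shape[-1]):
+#   q = torch.randn(2, 4, 1024, 128, dtype=torch.bfloat16, device='cuda', requires_grad=True)
+#   k = torch.nn.functional.normalize(torch.randn(2, 4, 1024, 128, dtype=torch.bfloat16, device='cuda'), dim=-1, p=2).requires_grad_(True)
+#   v = torch.randn(2, 4, 1024, 128, dtype=torch.bfloat16, device='cuda', requires_grad=True)
+#   beta = torch.rand(2, 4, 1024, dtype=torch.bfloat16, device='cuda').sigmoid().requires_grad_(True)
+#   mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], device='cuda')
+#   o, final_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32, output_final_state=True)
+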
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk ** -0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d) -> b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the stored result matches the kernel here
+        kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v -> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i + 1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d) -> b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d) -> b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask: [b, h, n, r, t, l]
+    o1 = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # [b, h, n, t, v]
+    o2 = torch.einsum('b h n t q, b h n q v -> b h n t v', q, state)  # inter-chunk term; checks the state path alone
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        # kkt = torch.einsum('b h d, b h v -> b h d v', _k * beta_i, _k)
+        kkt = torch.einsum('b h d, b h v -> b h d v', _k, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v) -> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v, r l -> b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v -> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k, b h k v -> b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
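+
+# The loop above is the token-by-token masked delta rule. In matrix form, with
+# M the [r, r] mask broadcast over the K/V sub-blocks and "o" denoting the
+# elementwise product:
+#   S_t = (I - M o (k_t k_t^T)) S_{t-1} + k_t (beta_t v_t)^T,   o_t = q_t^T S_t
+# This O(L * K * V) recurrence is the correctness baseline for the chunked kernels.
+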
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    # from einops import rearrange
+    # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    # seq_len = 128
+    # b = 2
+    # h = 2
+    # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    # v = torch.randn(b, h, seq_len, 128)
+    # beta = torch.rand(b, h, seq_len).sigmoid()
+    # require_grad = True
+    # BT = 16
+    # r = 4
+    # scale = 128**-0.5
+    # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+    # w = torch.nn.functional.normalize(torch.randn(b, h, seq_len*r, 128).cuda())
+    # u = torch.nn.functional.normalize(torch.randn(b, h, seq_len*r, 128).cuda())  # [b, h, n*t*r, d]
+    # initial_state = torch.randn(b, h, 128, 128).cuda().contiguous()
+    # k, v, q, beta, w, u = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, q, beta, w, u))
+
+    # final_state = None
+    # if False:
+    #     final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+    #                               dtype=torch.float32, requires_grad=False)  # no correction needed here
+
+    # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+    # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)
+    # o2 = rearrange(o2, 'b h (n t) v -> b h n t v', n=seq_len//BT)
+    # do = torch.rand_like(o2)
+    # do_naive = do
+    # do = rearrange(do, 'b h n t v -> b h (n t) v')
+    # dv0 = fwd_prepare_dv(q, k, do, r, BT)  # intra-chunk dv from q, k, do; shares w's layout family
+    # # dv0: [b, h, r, t, v]
+    # # results agree up to this point
+    # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)
+    # # results still agree here
+    # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)  # needs dh and dv
+
+    # # dv layout: [b, h, t, r, v]
+    # # dv0 = rearrange(dv0, 'b h r (n t) v -> b h n t r v', n=seq_len//BT)
+    # # dv = rearrange(dv, 'b h (n t) r v -> b h n t r v', n=seq_len//BT)  # the two agree on the BT=1 slice, confirmed
+    # # dh = rearrange(dh, 'b h (n k) v -> b h n k v', n=seq_len//BT)
+    # # expected relations:
+    # # NT = seq_len//BT
+    # # state = torch.zeros(b, h, NT, 128, 128, device=q.device, dtype=q.dtype)
+    # # v_new = torch.zeros_like(u)
+    # # q_na, k_na, w_na, u_na, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    # # wr = rearrange(w_na, 'b h n (t r) k -> b h n t r k', r=r)
+    # # dh0 = -torch.einsum('b h t r v, b h t r k -> b h k v', dv[:,:,1,:,:,:], wr[:,:,1,:,:,:])
+    # # dh0 += torch.einsum('b h t v, b h t q -> b h q v', do_naive[:,:,1,:,:], q_na[:,:,1,:,:]) * scale
+    # # k_r = rearrange(k_na, 'b h n t (r d) -> b h n t r d', r=r)
+    # # dh0 = rearrange(dh0, 'b h (r k) v -> b h r k v', r=r)  # need [b, h, t, r, v]
+    # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v, b h t r k -> b h t r v', dh0, k_r[:,:,0,:,:,:])
+    # # dh0 = rearrange(dh0, 'b h r k v -> b h (r k) v')
+    # # # this confirms the computation flow is correct
+    # # print((dv[:,:,0,:,:,:]-dvs).abs().max())
+    # # print((dh0-dh[:,:,0,:,:]).abs().max())
+
+    # ##### naive reference below
+    # NT = seq_len//BT
+    # state = torch.zeros(b, h, NT, 128, 128, device=q.device, dtype=q.dtype)
+    # v_new = torch.zeros_like(u)
+    # from einops import rearrange
+    # q_na, k_na, w_na, u_na, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    # # u_na = u_na.detach().requires_grad_(True)  # [b, h, n, t*r, d]
+    # v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    # if initial_state is not None:
+    #     state[:,:,0,:,:] = initial_state
+    # else:
+    #     state[:,:,0,:,:] = 0
+    # for i in range(NT):
+    #     ki = rearrange(k_na[:,:,i,:,:], 'b h t (r d) -> b h t r d', r=r)
+    #     ui = rearrange(u_na[:,:,i,:,:], 'b h (t r) d -> b h t r d', r=r)
+    #     wi = rearrange(w_na[:,:,i,:,:], 'b h (t r) d -> b h t r d', r=r)
+    #     v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:,:,i,:,:].clone())
+    #     v_new[:,:,i,:,:,:] = v_newi.clone()  # the stored result matches the kernel here
+    #     kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+    #     if i + 1 < seq_len//BT:
+    #         state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui, 'b h r k v -> b h (r k) v')
+    # q_r = rearrange(q_na, 'b h n t (r d) -> b h n t r d', r=r) * scale
+    # k_r = rearrange(k_na, 'b h n t (r d) -> b h n t r d', r=r)
+    # s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    # s_r = torch.tril(s_r, diagonal=0)  # causal mask: [b, h, n, r, t, l]
+    # v_newnew = v_new  # .detach().requires_grad_(True)
+    # os = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_newnew)
+    # os = os.sum(dim=-2)  # [b, h, n, t, v]
+    # oss = torch.einsum('b h n t q, b h n q v -> b h n t v', q_na, state) * scale  # checks the state path alone
+    # o_naive = os + oss
+    # # o_naive = rearrange(o_naive, 'b h n t v -> b h (n t) v')
+    # o_naive.backward(do_naive, retain_graph=True)
+    # una_grad = u.grad
+    # w_grad = w.grad
+    # k_grad = k.grad
+    # q_grad = q.grad
+    # print((una_grad - dv).abs().max())  # essentially equal
+    # print((k_grad - dk).abs().max())
+    # print((w_grad - dw).abs().max())
+    # print((q_grad - dq).abs().max())
+    # print(k_grad)
+    # print(dk)
+
+    B = 2
+    H = 1
+    L = 128
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end - start)
+
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    # end = time.time()
+    # print(end - start)
+    print((o1 - o).abs().max())
+    print((q_grad - q_grad0).abs().max())
+    print((k_grad - k_grad0).abs().max())  # the deviation here is large, up to ~1
+    print((v_grad - v_grad0).abs().max())
+    print((beta_grad - beta_grad0).abs().max())
+    # print(beta_grad)
+    # print(beta_grad0)
+    print(k_grad)
+    print(k_grad0)
diff --git a/fla2/ops/mask_gated_delta_rule/naive_rmbeta.py b/fla2/ops/mask_gated_delta_rule/naive_rmbeta.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aac72fd6ab3c1c7928194c488cb608129bf6fc0
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/naive_rmbeta.py
@@ -0,0 +1,1102 @@
+# -*- coding: utf-8 -*-
+
+import time
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)  # logically an (r*BT, r*BT) matrix
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0,r)* r + i_r
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:,None]*r_mask[None,:]
+        for i_k in range(tl.cdiv(dk, BK)):  # read and accumulate k block by block
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_kb = (b_k).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:,:,None,None]*ij_mask[None,None,:,:]
+
+    b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0)
+    # invert the unit-lower-triangular system row by row; b_A is laid out as (BT, r, BT, r)
+    for i in range(1, BT):
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)  # row i: (BT, r, r)
+        q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)  # the matmul step: (BT, BT, r, r)
+        b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])  # (BT, r, r)
+        b_A = tl.where(mask[:,None,None,None],b_a,b_A)  # forward substitution: swap in the finished row
+    b_A = tl.permute(b_A,(0,2,1,3))
+    b_A = tl.reshape(b_A,(BT*r,BT*r))  # (BT*r, BT*r)
+    b_A += tl.arange(0, BT*r)[:,None] == tl.arange(0, BT*r)[None,:]
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r), (1,0))  # the old implementation needed many more multiplications
+    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1))
+    # matrix inversion done; next compute w and u from it
+    b_A = b_A.to(k.dtype.element_ty)
+
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0,r)*r+i_r  # load the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d)
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb,(BT*r,BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # the arbitrary mask is not applied on the v path, so no extra masking here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
+        b_vb = tl.reshape(b_vb,(BT*r,BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
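+
+
+# Hedged PyTorch reference for the row-substitution loop above (a hypothetical
+# helper for testing, not used by the kernels): for a strictly lower-triangular
+# A it returns (I - A)^{-1}, computed row by row exactly like the in-kernel
+# `for i in range(1, BT)` sweep after b_A has been negated and strictly masked.
+def _ref_inv_unit_lower(A: torch.Tensor) -> torch.Tensor:
+    n = A.shape[-1]
+    out = A.clone()
+    for i in range(1, n):
+        # row i of the inverse only needs the already-finalized rows < i
+        out[..., i:i+1, :] = out[..., i:i+1, :] + out[..., i:i+1, :i] @ out[..., :i, :]
+    return out + torch.eye(n, dtype=A.dtype, device=A.device)
+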
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r), (1,0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0,r)*r+i_r  # load the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d)
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)  # (BT, r, d)
+            b_kb = tl.reshape(b_kb,(BT*r,BK))
+            b_w = tl.dot(b_A, b_kb, allow_tf32=False)  # (BT*r, BK)
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):  # no masking on the v path
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
+        b_vb = tl.reshape(b_vb,(BT*r,BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r), (1,0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta,(BT*r,BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))
+        sum_dv = tl.sum(b_dv_beta,-2)  # this is where the rmbeta variant differs from the beta-on-k version
+        b_dv = (sum_dv * b_beta[:, None])
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K//r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0,r)*r+i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(block_k, BK)):  # assumes block_k == BK
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            # b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)  # (BT, r, d)
+            b_k_beta = ((b_k)[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)
+            b_k_beta = tl.reshape(b_k_beta,(BT*r,BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)  # (BT*r, BK)
+            b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1)
+            # b_dk = sum_dk * b_beta[:, None]
+            b_dk = sum_dk
+            # b_dbeta += tl.sum(sum_dk * b_k, 1)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)
+    b_dA = tl.reshape(b_dA,(BT,r,BT,r)).to(k.dtype.element_ty)  # correct up to this point
+
+    for i_r in range(r):  # take only the i_r-th slice
+        p_mask = mask_ij + tl.arange(0,r)*r+i_r  # load the i_r-th column of the mask
+        b_mask = tl.load(p_mask)
+        mask = tl.arange(0, r) == i_r
+        g = tl.sum(tl.where(mask[None,None,None,:], b_dA, 0), -1)  # (BT, r, BT): keep the i_r-th column
+        # this corresponds to the k_r part
+        ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)  # (BT, BT)
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            # b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_k_beta = (b_k).to(b_k.dtype)
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            # b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta  # * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    assert BK == K//r
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT*BT*r, BT*r, device=k.device, dtype=torch.float32)
+    fwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
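+
+# Hedged summary of the WY construction above: with A the stored block inverse
+# (I + tril(M o (K K^T), -1))^{-1}, the kernels materialize
+#   w = A @ (M o K_r)    # per r-column, shape [B, H, r*T, K]
+#   u = A @ (beta * V)   # broadcast over r, shape [B, H, r*T, V]
+# so that the chunkwise recurrence can apply rank-(r*BT) updates directly.
+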
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r*T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r*T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    assert BK == K//r
+    bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, A,
+        dw, du,
+        dk, dv, dbeta,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return dk, dv, dbeta
+
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # also parallelized over r
+    i_bh = i_bhr//r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K//r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+# finished
+def fwd_prepare_dv(q, k, do, r, BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B, H, r, T, V, device=do.device, dtype=do.dtype)  # empty_like won't do: the shape differs from do
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B*H*r)](
+        q, k, do, dv,
+        k.stride(1), k.stride(2), k.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
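+
+# Hedged note on fwd_prepare_dv: per r-slice it computes the intra-chunk part
+#   dv_r = triu(K_r @ (scale * Q_r)^T) @ dO
+# (the transpose of the causal attention score block), giving dv of shape
+# [B, H, r, T, V]; the inter-chunk part is added later in chunk_bwd_dhu_fn.
+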
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # u at the call site
+    d,  # w at the call site
+    v_new,
+    h,
+    initial_state,  # initial state of the sequence [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the sequence [B, H, D_head_K, D_head_V]
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)  # NK == 1, so each program covers the full K range
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one (BK, BV) tile of the running state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        # storing the pre-update (chunk-initial) state here is correct
+        b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # BK is large, so split the chunk along time with BC
+                r_mask = tl.arange(0,r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))  # the matching K slice
+                p_d = tl.make_block_ptr((d + i_bh * T * r * K), (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK), (2,1,0))
+                p_v = tl.make_block_ptr((v + i_bh * T * r * V), (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV), (2,1,0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, BK//r)
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))
+                b_v = tl.reshape(b_v,(BC,BV))
+                b_d = tl.reshape(b_d,(BC,BK))
+                b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)  # matches the reference up to here
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
+                bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum,(BK,BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
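+
+# Hedged note: h stores the chunk-initial states back-to-back, i.e.
+# h[:, :, i_t*K:(i_t+1)*K, :] is the state S_{i_t} *before* chunk i_t's update,
+# which is exactly what chunk_linear_attn_fwd_kernel_o below consumes for the
+# inter-chunk term.
+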
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr//r
+    i_r = i_bhr % r
+    rk = K//r
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K//r, BK)):  # note the split here: K//BK == r
+        # different r-blocks read the same q/k slice, which is harmless: they write disjoint outputs
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        # p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_r * rk + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)  # zero out the non-causal part
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False))
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+# finished
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # fixed tile; this program reads the full chunk
+    for i_t in range(NT - 1, -1, -1):  # dh is shifted forward by one chunk; the computation order is correct
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        # full columns
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))  # loaded in full
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T*r,K), (K, 1),
+                                    (i_t * BT * r + i_c * BC * r, i_k * BK), (BC * r,BK), (1, 0))  # a (BC*r, BK) block
+            p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V, 1),
+                                     (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))
+            p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_q = (tl.load(p_q, boundary_check=(0, 1)))
+            b_q = (b_q * scale).to(b_q.dtype)
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))  # (BC*r, BV)
+            b_d = tl.trans(tl.load(p_d, boundary_check=(0, 1)))
+            b_k = tl.permute(tl.reshape(b_k,(BC,r,KR)),(1,0,2))  # (r, BC, KR)
+            b_dhtrans = tl.reshape(b_dh,(r,KR,BV))
+            dv_sum = tl.sum(b_k[:,:,:,None]*b_dhtrans.to(b_k.dtype)[:,None,:,:],-2)  # (r, BC, BV)
+            b_dv += tl.reshape(tl.permute(dv_sum,(1,0,2)),(BC*r,BV))
+            # layout: [b, h, t, r, v]
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V, 1),
+                                      (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
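+
+# Hedged note on the reverse scan above: dh is accumulated back-to-front,
+#   dh_i = dh_{i+1} + (scale * q_i)^T @ dO_i - w_i^T @ dv_i
+# where dv_i is first augmented by the inter-chunk term k_i @ dh_{i+1} and
+# written out as dv2 for the WY backward.
+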
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT,r,BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        # [BV, BK]
+        b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))
+        # [BK, BV]
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))
+        # [BT, BT]
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        # [BT, BK]
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # b_do is full; b_h covers this program's i_k slice
+        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)  # rows are independent, so this is safe
+        b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)),(BT,r,BV))  # (BT*r, BV) -> (BT, r, BV)
+        b_dw += tl.sum(b_dv.to(b_v.dtype)[:,:,:,None]*b_h.to(b_v.dtype)[None,None,:,:],-2)  # (BT, r, BK)
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)  # (BT, BT)
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False))
+
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT, 0, i_r*K//r + i_k * BK), (BT, r, BK), (2, 1, 0))
+    # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT, i_r, i_k * BK), (BT, 1, BK), (2, 1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, (tl.reshape(-b_dw.to(p_dw.dtype.element_ty),(BT,r,BK))), boundary_check=(0, 1))
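+
+# Hedged summary of the gradients formed above, per (i_k, i_t, b*h*r) program,
+# with S the loaded h block and ds = tril(dO @ v_new^T):
+#   dq = scale * (dO @ S^T + ds @ k)
+#   dk = v_new @ dh^T + ds^T @ (scale * q)
+#   dw = -(dv @ S^T)   # negative sign: w enters v_new with a minus
+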
+
+# finished
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT // T
+    BK = triton.next_power_of_2(K)  # partitioned directly; NK == 1 below
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new is stored with the r dimension first
+    chunk_delta_rule_fwd_kernel_h[grid](  # the kernel loops over r internally; no Python-level loop needed
+        k, u, w, v_new, h, initial_state, final_state,
+        k.stride(1), k.stride(2), k.stride(3),
+        u.stride(1), u.stride(2), u.stride(3),  # strides of u: (r*T*V, V, 1)
+        h.stride(1), h.stride(2),
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+
+# finished
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # more parallelism could be exposed here
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # same layout as h; the r slices must be accumulated together
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v -> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout: [B, H, T*r, V]
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        q.stride(1), q.stride(2), q.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        dh.stride(1), dh.stride(2),
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K//r,
+    )
+    return dh, dv2
+
+
+# finished
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # per-r outputs: [B, H, r, T, V]
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape [B, H, NT*K, V]
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        h.stride(1), h.stride(2),
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last backward kernel: computes dw, dq and dk
+    # BK needs a fine-grained split so that blocks from different r positions never land in the same tile
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K//r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK indexes the position within one r slice
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # gradient w.r.t. the original k
+    dw = torch.empty_like(w)  # [B, H, T*r, K]
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        q.stride(1), q.stride(2), q.stride(3),
+        T*V, V, 1,
+        dh.stride(1), dh.stride(2),
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # forward pass complete
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # compute the full A matrix for all chunks
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # no correction term needed here
+        end = time.time()
+        print('compute_A:', end - start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+        end = time.time()
+        print('compute_h_s:', end - start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        end = time.time()
+        print('compute_o:', end - start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # cheap: reuses the saved A
+        end = time.time()
+        print('recompute_wu:', end - start)
+        # checkpoint_level=1: recompute h and v_new instead of storing them.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: [B, H, r, T, V]
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # intra-chunk dv from q, k, do; shares w's layout family
+        end = time.time()
+        print('pre:', end - start)
+        # dv: [B, H, r, T, V]
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # dh plus the final dv fed into the WY backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end - start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end - start)
+
+        start = time.time()
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)  # this step has the largest numerical error
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end - start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # [r, r] pattern used to mask the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk ** -0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d) -> b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d -> b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # the stored result matches the kernel here
+        kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v -> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i + 1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d) -> b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d) -> b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask: [b, h, n, r, t, l]
+    o1 = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # [b, h, n, t, v]
+    o2 = torch.einsum('b h n t q, b h n q v -> b h n t v', q, state)  # inter-chunk term; checks the state path alone
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        # kkt = torch.einsum('b h d, b h v -> b h d v', _k * beta_i, _k)
+        kkt = torch.einsum('b h d, b h v -> b h d v', _k, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v) -> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v, r l -> b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v -> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum('b h q k, b h k v -> b h q v', iplr, S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2)
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    # from einops import rearrange
+    # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    # seq_len = 128
+    # b = 2
+    # h = 2
+    # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    # v = torch.randn(b, h, seq_len, 128)
+    # beta = torch.rand(b, h, seq_len).sigmoid()
+    # require_grad = True
+    # BT = 16
+    # r = 4
+    # scale = 128**-0.5
+    # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+    # w = torch.nn.functional.normalize(torch.randn(b, h, seq_len*r, 128).cuda())
+    # u = torch.nn.functional.normalize(torch.randn(b, h, seq_len*r, 128).cuda())  # [b, h, n*t*r, d]
+    # initial_state = torch.randn(b, h, 128, 128).cuda().contiguous()
+    # k, v, q, beta, w, u = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, q, beta, w, u))
+
+    # final_state = None
+    # if False:
+    #     final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+    #                               dtype=torch.float32, requires_grad=False)  # no correction needed here
+
+    # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+    # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)
+    # o2 = rearrange(o2, 'b h (n t) v -> b h n t v', n=seq_len//BT)
+    # do = torch.rand_like(o2)
+    # do_naive = do
+    # do = rearrange(do, 'b h n t v -> b h (n t) v')
+    # dv0 = fwd_prepare_dv(q, k, do, r, BT)  # intra-chunk dv from q, k, do; shares w's layout family
+    # # dv0: [b, h, r, t, v]
+    # # results agree up to this point
+    # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)
+    # # results still agree here
+    # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)  # needs dh and dv
+
+    # # dv layout: [b, h, t, r, v]
+    # # dv0 = rearrange(dv0, 'b h r (n t) v -> b h n t r v', n=seq_len//BT)
+    # # dv = rearrange(dv, 'b h (n t) r v -> b h n t r v', n=seq_len//BT)  # the two agree on the BT=1 slice, confirmed
+    # # dh = rearrange(dh, 'b h (n k) v -> b h n k v', n=seq_len//BT)
+    # # expected relations:
+    # # NT = seq_len//BT
+    # # state = torch.zeros(b, h, NT, 128, 128, device=q.device, dtype=q.dtype)
+    # # v_new = torch.zeros_like(u)
+    # # q_na, k_na, w_na, u_na, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    # # wr = rearrange(w_na, 'b h n (t r) k -> b h n t r k', r=r)
+    # # dh0 = -torch.einsum('b h t r v, b h t r k -> b h k v', dv[:,:,1,:,:,:], wr[:,:,1,:,:,:])
+    # # dh0 += torch.einsum('b h t v, b h t q -> b h q v', do_naive[:,:,1,:,:], q_na[:,:,1,:,:]) * scale
+    # # k_r = rearrange(k_na, 'b h n t (r d) -> b h n t r d', r=r)
+    # # dh0 = rearrange(dh0, 'b h (r k) v -> b h r k v', r=r)  # need [b, h, t, r, v]
+    # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v, b h t r k -> b h t r v', dh0, k_r[:,:,0,:,:,:])
+    # # dh0 = rearrange(dh0, 'b h r k v -> b h (r k) v')
+    # # # this confirms the computation flow is correct
+    # # print((dv[:,:,0,:,:,:]-dvs).abs().max())
+    # # print((dh0-dh[:,:,0,:,:]).abs().max())
+
+    # ##### naive reference below
+    # NT = seq_len//BT
+    # state = torch.zeros(b, h, NT, 128, 128, device=q.device, dtype=q.dtype)
+    # v_new = torch.zeros_like(u)
+    # from einops import rearrange
+    # q_na, k_na, w_na, u_na, v_new = map(lambda x: rearrange(x, 'b h (n t) d -> b h n t d', n=NT), (q, k, w, u, v_new))
+    # # u_na = u_na.detach().requires_grad_(True)  # [b, h, n, t*r, d]
+    # v_new = rearrange(v_new, 'b h n (t r) d -> b h n t r d', r=r)
+    # if initial_state is not None:
+    #     state[:,:,0,:,:] = initial_state
+    # else:
+    #     state[:,:,0,:,:] = 0
+    # for i in range(NT):
+    #     ki = rearrange(k_na[:,:,i,:,:], 'b h t (r d) -> b h t r d', r=r)
+    #     ui = rearrange(u_na[:,:,i,:,:], 'b h (t r) d -> b h t r d', r=r)
+    #     wi = rearrange(w_na[:,:,i,:,:], 'b h (t r) d -> b h t r d', r=r)
+    #     v_newi = ui - torch.einsum('b h t r d, b h d v -> b h t r v', wi, state[:,:,i,:,:].clone())
+    #     v_new[:,:,i,:,:,:] = v_newi.clone()  # the stored result matches the kernel here
+    #     kui = torch.einsum('b h t r k, b h t r v -> b h r k v', ki, v_newi)
+    #     if i + 1 < seq_len//BT:
+    #         state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui, 'b h r k v -> b h (r k) v')
+    # q_r = rearrange(q_na, 'b h n t (r d) -> b h n t r d', r=r) * scale
+    # k_r = rearrange(k_na, 'b h n t (r d) -> b h n t r d', r=r)
+    # s_r = torch.einsum('b h n t r d, b h n l r d -> b h n r t l', q_r, k_r)
+    # s_r = torch.tril(s_r, diagonal=0)  # causal mask: [b, h, n, r, t, l]
+    # v_newnew = v_new  # .detach().requires_grad_(True)
+    # os = torch.einsum('b h n r t l, b h n l r v -> b h n t r v', s_r, v_newnew)
+    # os = os.sum(dim=-2)  # [b, h, n, t, v]
+    # oss = torch.einsum('b h n t q, b h n q v -> b h n t v', q_na, state) * scale  # checks the state path alone
+    # o_naive = os + oss
+    # # o_naive = rearrange(o_naive, 'b h n t v -> b h (n t) v')
+    # o_naive.backward(do_naive, retain_graph=True)
+    # una_grad = u.grad
+    # w_grad = w.grad
+    # k_grad = k.grad
+    # q_grad = q.grad
+    # print((una_grad - dv).abs().max())  # essentially equal
+    # print((k_grad - dk).abs().max())
+    # print((w_grad - dw).abs().max())
+    # print((q_grad - dq).abs().max())
+    # print(k_grad)
+    # print(dk)
+
+    B = 2
+    H = 1
+    L = 128
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end - start)
+
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    # end = time.time()
+    # print(end - start)
+    print((o1 - o).abs().max())
+    print((q_grad - q_grad0).abs().max())
+    print((k_grad - k_grad0).abs().max())  # the deviation here is large, up to ~1
+    print((v_grad - v_grad0).abs().max())
+    print((beta_grad - beta_grad0).abs().max())
+    # print(beta_grad)
+    # print(beta_grad0)
+    print(k_grad)
+    print(k_grad0)
diff --git a/fla2/ops/mask_gated_delta_rule/recurrent_fuse.py b/fla2/ops/mask_gated_delta_rule/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21470ff11d7e75df52b0c81dcb66bd40a44a0e5
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/recurrent_fuse.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import contiguous
+
+# on-the-fly computation without materializing hidden states into HBM
+
+
+@triton.jit
+def fused_recurrent_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+ beta, # beta [B, H, L] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + B, # batch size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _v_minus = tl.sum(h * b_k[None, :], axis=1) + b_v -= _v_minus + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # in-place overwrite + tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv) + b_v *= b_beta + h += b_k[None, :] * b_v[:, None] + _o = h * b_q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + p_beta += V if IS_HEADWISE_BETA else 1 + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + beta, # beta [B, H, L, (V)] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + dbeta, # gradient of beta [NV, (NK), B, H, L] + + # initial hidden state initialization [B, H, K, V] + h0, + + s_qk_h, # stride size: L * K + + s_vo_h, # stride size: L * V + + NK, # NK block size + scale, # K ** -0.5 + + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + else: + p_beta = beta + i_bh * T + T - 1 + + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V + else: + p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1 + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v) + d_v = d_v * b_beta + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + if IS_HEADWISE_BETA: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv) + else: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty)) + + d_h -= b_k[:, None] * d_v[None, :] + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + p_dbeta -= V if IS_HEADWISE_BETA else 1 + p_beta -= V if IS_HEADWISE_BETA else 1 + + tl.debug_barrier() + + h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) 
+ V + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + + h += b_k[:, None] * b_v[None, :] + _d_q = h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + if i < T - 1: + d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32) + d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32) + d_k -= tl.sum(d_v[None, :] * h, axis=1) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dk += K + p_dv += V + p_dq += K + p_beta += V if IS_HEADWISE_BETA else 1 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @contiguous + @staticmethod + def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False): + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + assert NK == 1, "NK > 1 is not supported yet" + o = q.new_empty(NK, B, H, T, V) + + if output_final_state: + final_state = q.new_empty(B, H, K, V) + else: + final_state = None + + grid = (NV, NK, B * H) + fused_recurrent_fwd_kernel[grid]( + q, k, v, beta, o, initial_state, final_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + IS_HEADWISE_BETA=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, beta, initial_state) + ctx.scale = scale + return o, final_state + + @contiguous + @staticmethod + def backward(ctx, do, dht=None): + q, k, v, beta, initial_state = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + beta_vector = beta.ndim == v.ndim + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + if beta_vector: + dbeta = q.new_empty(NV, NK, B, H, T, V) + else: + dbeta = q.new_empty(NV, B, H, T) + grid = (NV, NK, B * H) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, beta, do, dq, dk, dv, dbeta, initial_state, + q.stride(1), + v.stride(1), + NK, scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + IS_HEADWISE_BETA=beta_vector, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0) + return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None + + +def mask_fused_recurrent_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + scale: float = -1, + 
initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    normalize: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if scale == -1:
+        scale = q.shape[-1] ** -0.5
+    if initial_state is not None:
+        initial_state = initial_state.detach()
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state)
+    return o, final_state
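For reference, the delta-rule recurrence that this fused path parallelizes (Schlag et al., https://arxiv.org/abs/2102.11174) can be sketched in plain PyTorch. This is a minimal editor's illustration, assuming a scalar per-step beta and a float32 state, following the shape comments in the kernel; it is not part of the patch:

import torch

def naive_delta_rule(q, k, v, beta, scale=None, initial_state=None):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; beta: [B, H, T]
    B, H, T, K = q.shape
    scale = K ** -0.5 if scale is None else scale
    S = q.new_zeros(B, H, K, v.shape[-1]) if initial_state is None else initial_state.clone()
    o = torch.empty_like(v)
    for t in range(T):
        k_t, v_t, b_t = k[..., t, :], v[..., t, :], beta[..., t, None]
        # gated prediction error: beta_t * (v_t - S^T k_t)
        u_t = b_t * (v_t - torch.einsum('bhkv,bhk->bhv', S, k_t))
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, u_t)  # rank-1 delta-rule update
        o[..., t, :] = torch.einsum('bhkv,bhk->bhv', S, q[..., t, :]) * scale
    return o, S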
diff --git a/fla2/ops/mask_gated_delta_rule/utils.py b/fla2/ops/mask_gated_delta_rule/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..173d6629c628bb6b5860a005cbc8ea85d7cf9b5e
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/utils.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+
+from ...ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    o,
+    o2,
+    T,
+    K,
+    V,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
+    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
+    mask_bk = tl.arange(0, BK) < K
+    mask_bv = tl.arange(0, BV) < V
+    mask_bk = mask_bk[None, :] & mask_bt[:, None]
+    mask_bv = mask_bv[None, :] & mask_bt[:, None]
+    # [BT, BK]
+    b_k = tl.load(p_k, mask=mask_bk, other=0)
+    # [BT,]
+    b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32)
+    # [BT, BV]
+    b_v = tl.load(p_v, mask=mask_bv, other=0)
+    b_v = (b_v * b_beta[:, None]).to(b_v.dtype)
+    # [BT, BK]
+    b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+    # [BT, BT]
+    b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+    b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
+
+    for i in range(BT):
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
+        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
+        b_A = tl.where(mask[:, None], b_a, b_A)
+    b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
+    b_A = b_A.to(b_k.dtype)
+    b_w = tl.dot(b_A, b_kb, allow_tf32=False)
+    b_u = tl.dot(b_A, b_v, allow_tf32=False)
+
+    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk)
+    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta,
+    o, o2, do, do2,
+    dk, dv, dbeta,
+    NT, K, V, T,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+
+    p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT)
+    mask_bt = (tl.arange(0, BT) + i_t * BT) < T
+    mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None]
+    mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None]
+    b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt)
+
+    b_beta = b_beta.to(tl.float32)
+    A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None]
+    A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0)
+    b_do = tl.load(p_do, mask=mask_bk).to(tl.float32)
+    b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32)
+    dA = tl.zeros([BT, BT], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    for i in range(BT-1, -1, -1):
+        mask = tl.arange(0, BT) == i
+        attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0)
+        do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0)
+        dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0)
+        b_do = b_do - attn[:, None] * do_[None, :]
+        b_dv = b_dv - attn[:, None] * dv_[None, :]
+    tl.debug_barrier()
+    p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    b_v = tl.load(p_v, mask=mask_bv)
+    b_dk += b_do * b_beta[:, None]
+    b_dbeta = tl.sum(b_do * b_k, axis=1)
+    b_dbeta += tl.sum(b_dv * b_v, axis=1)
+    b_v = None
+
+    p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    b_o = tl.load(p_o, mask=mask_bk)
+    b_o2 = tl.load(p_o2, mask=mask_bv)
+
+    dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False)
+    dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype),
+                 allow_tf32=False)
+    dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0)
+    b_dv *= b_beta[:, None]
+    p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :]
+    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv)
+
+    b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1)
+    dA = dA * b_beta[:, None]
+    b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False)
+    b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False)
+    p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :]
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk)
+    p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT)
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt)
+
+
+def fwd_prepare_wy_repr(k, v, beta, chunk_size):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    v_new = torch.empty_like(v)
+    o_cumdecay = torch.empty_like(k)
+    BT = chunk_size
+    NT = triton.cdiv(T, BT)
+    BK = triton.next_power_of_2(K)
+    BV = triton.next_power_of_2(V)
+    fwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, o_cumdecay, v_new,
+        T, K, V, BT, BK, BV
+    )
+    return o_cumdecay, v_new
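It may help to see the per-chunk math behind fwd_prepare_wy_repr in dense form. The following is an editor's sketch for a single (batch, head) chunk, mirroring the forward-substitution loop in the kernel above (torch.linalg.inv stands in for the in-kernel substitution):

import torch

def wy_repr_chunk_reference(k, v, beta):
    # k: [C, K], v: [C, V], beta: [C] -- one chunk of one (batch, head)
    kb = k * beta[:, None]
    M = torch.tril(kb @ k.T, diagonal=-1)  # strictly lower triangular part
    Tinv = torch.linalg.inv(torch.eye(k.shape[0], device=k.device, dtype=k.dtype) + M)
    return Tinv @ kb, Tinv @ (v * beta[:, None])  # (o_cumdecay, v_new)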
+    BK = triton.next_power_of_2(d_k)
+    BV = triton.next_power_of_2(d_v)
+    NT = triton.cdiv(l, chunk_size)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v)
+    dbeta = torch.zeros_like(beta)
+    bwd_prepare_wy_repr_kernel[(NT, b*h)](
+        k, v, beta,
+        o_cumdecay, v_new, do, do2,
+        dk, dv, dbeta,
+        NT, d_k, d_v, l, chunk_size, BK, BV
+    )
+    return dk, dv, dbeta
+
+
+class WYRepresentationPreparation(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, k, v, beta, chunk_size):
+        o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size)
+        ctx.chunk_size = chunk_size
+        ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new)
+        return o_cumdecay, v_new
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, do2):
+        k, v, beta, o_cumdecay, v_new = ctx.saved_tensors
+        dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size)
+        return dk, dv, dbeta, None
+
+
+prepare_wy_repr = WYRepresentationPreparation.apply
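A quick usage sketch for prepare_wy_repr (editor's illustration; the shapes are assumptions mirroring the self-test below, and a CUDA device is required since the kernels are Triton):

import torch

k = torch.nn.functional.normalize(
    torch.randn(2, 4, 256, 64, device='cuda', dtype=torch.bfloat16), dim=-1, p=2)
v = torch.randn(2, 4, 256, 64, device='cuda', dtype=torch.bfloat16)
beta = torch.rand(2, 4, 256, device='cuda').sigmoid().to(torch.bfloat16)
w, u = prepare_wy_repr(k, v, beta, 32)  # both [2, 4, 256, 64], differentiable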
+
+
+def naive(k, v, beta, chunk_size):
+    l_org = k.shape[2]
+    l_new = triton.next_power_of_2(l_org)
+    # pad k, v, beta
+    k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2)
+    v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2)
+    beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2)
+
+    k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v))
+    # k = torch.nn.functional.normalize(k, dim=-1, p=2)
+    beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size)
+    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0)
+    k_beta = k * beta[..., None]
+    v = v * beta[..., None]
+    attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0)
+    attn = attn * beta[..., None]
+    x = attn @ v
+
+    o = torch.zeros_like(k)
+    o2 = torch.zeros_like(v)
+
+    o[..., 0, :] = k_beta[..., 0, :].clone()
+    o2[..., 0, :] = x[..., 0, :].clone()
+    for i in range(1, chunk_size):
+        o_i = (o[..., :i, :]).clone()
+        o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :]
+        o2_i = (o2[..., :i, :]).clone()
+        o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :]
+    return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2))
+
+
+if __name__ == "__main__":
+    torch.set_default_dtype(torch.bfloat16)
+    seq_len = 2048
+    b = 4
+    h = 8
+    k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2)
+    v = torch.randn(b, h, seq_len, 256)
+    beta = torch.rand(b, h, seq_len).sigmoid()
+    require_grad = True
+    k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta))
+    do = torch.rand_like(k)
+    do2 = torch.rand_like(v)
+
+    print("Start warmup.")
+    o1, o2 = prepare_wy_repr(k, v, beta, 32)
+    # (o1 * do + o2 * do2).sum().backward()
+    o3, o4 = prepare_wy_repr2(k, v, beta, 32)
+    # (o1 * do + o2 * do2).sum().backward()
+    print((o1 - o3).abs().max())
+    print((o2 - o4).abs().max())
+
+    for i in range(30):
+        o1, o2 = prepare_wy_repr(k, v, beta, 32)
+        (o1 * do + o2 * do2).sum().backward()
+        o1, o2 = prepare_wy_repr2(k, v, beta, 32)
+        (o1 * do + o2 * do2).sum().backward()
+
+    print("Done warmup.")
+
+    import time
+    torch.cuda.synchronize()
+    start = time.time()
+
+    for i in range(200):
+        o1, o2 = prepare_wy_repr(k, v, beta, 64)
+        (o1 * do + o2 * do2).sum().backward()
+
+    torch.cuda.synchronize()
+    print(time.time() - start)
+
+    torch.cuda.synchronize()
+    start = time.time()
+
+    for i in range(200):
+        o1, o2 = prepare_wy_repr2(k, v, beta, 64)
+        (o1 * do + o2 * do2).sum().backward()
+
+    torch.cuda.synchronize()
+    print(time.time() - start)
diff --git a/fla2/ops/mask_gated_delta_rule/wy_fast.py b/fla2/ops/mask_gated_delta_rule/wy_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf2e6b2e79e2a70f6ab19f6fd432225369d70857
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule/wy_fast.py
@@ -0,0 +1,539 @@
+# -*- coding: utf-8 -*-
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009
+# o: cumprod
+# o2: cumprodsum
+
+
+@triton.jit
+def safe_exp(x):
+    return tl.exp(tl.where(x <= 0, x, float('-inf')))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    Aw,
+    Au,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_Aw = tl.make_block_ptr(Aw + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_Aw = tl.load(p_Aw, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read column i_r of the r x r mask
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:, None, :]*b_mask[None, :, None].to(b_k.dtype)  # BT x r x d
+            b_kb = tl.reshape(b_kb, (BT*r, BK))
+            b_w = tl.dot(b_Aw, b_kb, allow_tf32=False)  # BT*r x BK
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+    tl.debug_barrier()
+    b_Aw = None
+    p_Au = tl.make_block_ptr(Au + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t*BT*r, 0), (BT*r, BT*r), (1, 0))
+    b_Au = tl.load(p_Au, boundary_check=(0, 1)).to(k.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):  # values are not split by the mask, so no mask load is needed here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :]*tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT*r, BV))
+        b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
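The operand layout in the W loop above is easier to read in dense form: group i_r of the keys is beta-scaled, replicated r times along a new axis, and weighted by column i_r of the r x r mixing mask before the (BT*r, BT*r) transform Aw is applied. A sketch with an illustrative helper name (an editor's reconstruction from the pointer math, not part of the patch):

import torch
from einops import rearrange

def expand_key_group(k_chunk, beta, mask, i_r, r):
    # k_chunk: [BT, K], beta: [BT], mask: [r, r]; returns [BT*r, K//r]
    dk = k_chunk.shape[-1] // r
    g = k_chunk[:, i_r * dk:(i_r + 1) * dk] * beta[:, None]  # beta-scaled group i_r
    col = mask[:, i_r]                                       # column i_r of the mask
    return rearrange(g[:, None, :] * col[None, :, None], 't r d -> (t r) d')

The value path uses the same expansion with an all-ones weighting, which is why no mask load appears in the V loop above.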
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "r"],
+)
+@triton.jit
+def gated_chunk_scaled_dot_kkt_fwd_kernel(
+    k,
+    beta,
+    g_cumsum,
+    mask_ij,
+    A,
+    Ag,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    T,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # (r*BT) x (r*BT), kept in blocked layout
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # column-wise read: column i_r gives the per-row weights
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None]*r_mask[None, :]  # row weights of the r x r block
+
+        for i_k in range(tl.cdiv(dk, BK)):  # process K in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None]*ij_mask[None, None, :, :]
+    b_A = tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r, (BT, BT, r, r), (BT*r*r, r*r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_A, (b_A).to(p_A.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+    p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_g_diff = b_g[:, None] - b_g[None, :]
+    b_g_diff = safe_exp(b_g_diff)
+
+    b_Ag = b_A * ((b_g_diff)[:, :, None, None])  # BT x BT decay weighting
+    p_Ag = tl.make_block_ptr(Ag + (i_bh*T//BT+i_t)*BT*BT*r*r, (BT, BT, r, r), (BT*r*r, r*r, r, 1), (0, 0, 0, 0), (BT, BT, r, r), (3, 2, 1, 0))
+    tl.store(p_Ag, (b_Ag).to(p_Ag.dtype.element_ty), boundary_check=(0, 1, 2, 3))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T, BT, r, r), (BT*r*r, r*r, r, 1), (i_t * 16, offset, 0, 0), (16, 16, r, r), (3, 2, 1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1, 2, 3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :])[:, :, None, None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)
+        q = (tl.sum(b_a[:, None, :, :, None]*b_A[:, :, None, :, :], -2))
+        b_a = b_a + tl.sum(q, 0)*((tl.arange(0, 16) < i)[:, None, None])
+        b_A = tl.where(mask[:, None, None, None], b_a, b_A)  # process row by row, substituting results step by step
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None]) & (tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A, (0, 2, 1, 3))
+    b_A = tl.reshape(b_A, (16*r, 16*r))  # (16*r) x (16*r)
+    p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 16 * r, 0), (16*r, 16*r), (1, 0))
+    tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty), boundary_check=(0, 1))
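solve_tril_16x16_kernel inverts, tile by tile, a unit lower (block-)triangular matrix by forward substitution; for strictly lower M the Neumann series (I + M)^{-1} = I - M + M^2 - ... terminates, which is what the row loop accumulates. A dense reference (editor's sketch):

import torch

def invert_unit_lower(M):
    # M: [n, n] strictly lower triangular; returns (I + M)^{-1}
    I = torch.eye(M.shape[-1], device=M.device, dtype=M.dtype)
    return torch.linalg.solve_triangular(I + M, I, upper=False)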
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 32 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), (i_t * 32 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
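The merge kernels assemble larger inverses from the 16x16 diagonal inverses via the block lower-triangular identity. For the 2x2 case just handled (editor's sketch):

import torch

def merge_block_inverse(Ai11, Ai22, A21):
    # [[A11, 0], [A21, A22]]^{-1} = [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]]
    top = torch.cat([Ai11, torch.zeros_like(A21)], dim=-1)
    bot = torch.cat([-Ai22 @ A21 @ Ai11, Ai22], dim=-1)
    return torch.cat([top, bot], dim=-2)

The 64x64 variant below chains the same identity across four diagonal tiles; note that each (i, i-1) off-diagonal block is right-multiplied by the inverse of diagonal tile i-1.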
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_64x64_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+    p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 32*r), (16*r, 16*r), (1, 0))
+
+    b_A21 = tl.load(p_A21, boundary_check=(0, 1)).to(tl.float32)
+    b_A31 = tl.load(p_A31, boundary_check=(0, 1)).to(tl.float32)
+    b_A32 = tl.load(p_A32, boundary_check=(0, 1)).to(tl.float32)
+    b_A41 = tl.load(p_A41, boundary_check=(0, 1)).to(tl.float32)
+    b_A42 = tl.load(p_A42, boundary_check=(0, 1)).to(tl.float32)
+    b_A43 = tl.load(p_A43, boundary_check=(0, 1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), (i_t * 64 * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh, (T*r, 16*r), (16*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+
+    p_Ai11 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai33 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 32*r), (16*r, 16*r), (1, 0))
+    p_Ai44 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 48*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 16) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai31 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai32 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 32) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai41 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 0), (16*r, 16*r), (1, 0))
+    p_Ai42 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 16*r), (16*r, 16*r), (1, 0))
+    p_Ai43 = tl.make_block_ptr(Ai + (i_bh)*s_A_bh, (T*r, 64*r), (64*r, 1), ((i_t * 64 + 48) * r, 32*r), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
+    Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32)
+    Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32)
+
+    # each (i, i-1) block is right-multiplied by the inverse of diagonal tile i-1
+    Ai21 = -tl.dot(tl.dot(Ai22, b_A21, input_precision='ieee'), Ai11, input_precision='ieee')
+    Ai32 = -tl.dot(tl.dot(Ai33, b_A32, input_precision='ieee'), Ai22, input_precision='ieee')
+    Ai43 = -tl.dot(tl.dot(Ai44, b_A43, input_precision='ieee'), Ai33, input_precision='ieee')
+
+    Ai31 = -tl.dot(
+        Ai33,
+        tl.dot(b_A31, Ai11, input_precision='ieee') +
+        tl.dot(b_A32, Ai21, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai42 = -tl.dot(
+        Ai44,
+        tl.dot(b_A42, Ai22, input_precision='ieee') +
+        tl.dot(b_A43, Ai32, input_precision='ieee'),
+        input_precision='ieee')
+
+    Ai41 = -tl.dot(
+        Ai44,
+        tl.dot(b_A41, Ai11, input_precision='ieee') +
+        tl.dot(b_A42, Ai21, input_precision='ieee') +
+        tl.dot(b_A43, Ai31, input_precision='ieee'),
+        input_precision='ieee'
+    )
+
+    tl.store(p_Ai11, Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai22, Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai33, Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai44, Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai21, Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai31, Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai32, Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai41, Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai42, Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+    tl.store(p_Ai43, Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
+
+
+def gated_chunk_scaled_dot_kkt_fwd(k: torch.Tensor,
+                                   beta: torch.Tensor,
+                                   mask: torch.Tensor,
+                                   g_cumsum: Optional[torch.Tensor] = None,
+                                   BT: int = 32,
+                                   output_dtype: torch.dtype = torch.float32):
+    # gated_chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, mask=mask, BT=BT, output_dtype=torch.float32)
+    B, H, T, K = k.shape
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    A = torch.empty(B*H*NT, BT*BT, r*r, device=k.device, dtype=output_dtype).contiguous()
+    Ag = 
torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + gated_chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, g_cumsum, mask, A,Ag, + T*K, K, 1, + T, K, r, BT, BK + ) + return A,Ag + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,#s_abh + T*16*r*r,#s_adbh + T, + r, BT + ) + if BT == 16: + return Ad + + A = rearrange(A,'b (t l) (c r)->b (t c) (l r)',t=BT,c=r).contiguous()#BT*r BT*r + if BT == 32: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + if BT == 64: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_64x64_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + +def gated_fwd_recompute_w_u(k, v, beta,mask, Aw,Au,BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + gated_fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, Aw,Au, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + + +# class WYRepresentationPrepration(torch.autograd.Function): +# @staticmethod +# @contiguous +# @autocast_custom_fwd +# def forward(ctx, k, v, beta,mask,chunk_size=64): +# ctx.BT = chunk_size +# w, u, A = fwd_prepare_wy_repr(k, v,beta,mask, ctx.BT) +# ctx.save_for_backward(k, v, beta,mask,A) +# return w, u +# @staticmethod +# @contiguous +# @autocast_custom_bwd +# def backward(ctx, dw, du): +# k, v, beta,mask, A = ctx.saved_tensors +# BT = ctx.BT +# dk, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta,mask, A, dw, du, BT) +# return dk, dv, dbeta, dmask, None + +# prepare_wy_repr = WYRepresentationPrepration.apply + + +# def naive(k, v, beta,maskij,chunk_size): +# l_org = k.shape[2] +# l_new = triton.next_power_of_2(l_org) +# k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) +# v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) +# beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) +# k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) +# beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + +# b,h,nt,BT,dk = k.shape +# dv = v.shape[-1] +# r = maskij.shape[-1] +# k_beta = k * beta[..., None] +# k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r) +# k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij) +# k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')#l=1 rk=org +# v_beta = v * beta[..., None] +# v_beta = v_beta +# v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) +# ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r) + +# attn = (ki @ ki.transpose(-1, -2)) +# attn = torch.tril(attn, diagonal=-1)#bhnr cc +# attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)#bhn rr cc +# attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta) + +# o = torch.zeros_like(k_beta) +# o2 = 
torch.zeros_like(v_beta) + +# o[..., 0, :,:] = k_beta[..., 0,:,:].clone() +# o2[..., 0,:, :] = v_beta[..., 0,:,:].clone() +# for i in range(1, chunk_size): +# o_i = (o[..., :i,:,:]).clone()#bhn :t cc +# o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:]) +# o2_i = (o2[..., :i,:,:]).clone()#少一个维度 +# o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:]) +# return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2)) + + +# if __name__ == "__main__": +# #all compute here +# import sys +# sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') +# torch.set_default_dtype(torch.bfloat16) +# seq_len = 32 +# b = 2 +# h = 2 +# k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 +# v = torch.randn(b, h, seq_len, 128) +# beta = torch.rand(b, h, seq_len).sigmoid() +# require_grad = True +# BT = 16 +# k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta)) +# r = 4 +# # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous() +# mask = torch.randn([r,r]) +# mask = mask.cuda().requires_grad_(require_grad).contiguous() +# # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16) +# # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16) +# # from einops import rearrange + +# k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = 16,r=r) +# b2 = rearrange(beta,'b h (n t)-> b h n t',t = 16) +# a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)#bhnrtt +# qq = torch.tril(a1,diagonal=-1) +# qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask) +# sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)') +# sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)#这个 + + +# # #长条对角线 +# i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :])) +# s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda() +# s = rearrange(s,'b h n a d c r->b h n (a c) (d r)') +# s = torch.linalg.inv(s.float()).to(k)#矩阵逆#bhn tr tr + + +# # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)#bh nt BT bt r r +# # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.float32) +# # s = rearrange(s,'b h n a c->(b h) (n a) c') +# # print(Ad) +# # print(s) +# # print((Ad-s).abs().max()) + +# w,u,As = fwd_prepare_wy_repr(k, v, beta,mask, 16) +# As = rearrange(As,'b h (n t) l->(b h n) t l',t =BT*r) +# # print((As-s).abs().max()) +# # B*H*NT,BT*r,16*r +# # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2) +# # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask) +# # k_exp = rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)') +# # wc = s_copy@k_exp + +# # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT) +# # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2) +# # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1) +# # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v') +# # uc = s_copy@v_exp +# # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc)) +# # do = torch.rand_like(wc) +# # do2 = torch.rand_like(uc)#b h n t t +# # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)#这个代码有问题 +# # do = torch.rand_like(o1) +# # do2 = torch.rand_like(o2)#b h n t t +# # if require_grad: +# # o1.backward(do, retain_graph=True) +# # o2.backward(do2, retain_graph=True) +# # k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, 
beta.grad, mask.grad + +# # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16) +# # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT) + +# # print((o1-w0).abs().max()) +# # print((o2-u0).abs().max()) +# # print((k_grad-k_grad2).abs().max()) +# # print((v_grad-v_grad2).abs().max()) +# # print((beta_grad-beta_grad2).abs().max()) +# # print((mask_grad-mask_grad2).abs().max()) +# # print(mask_grad) +# # print(mask_grad2) + + diff --git a/fla2/ops/mask_gated_delta_rule/wy_fast_test.py b/fla2/ops/mask_gated_delta_rule/wy_fast_test.py new file mode 100644 index 0000000000000000000000000000000000000000..22aba7278db186f6b7139b33d446813078728861 --- /dev/null +++ b/fla2/ops/mask_gated_delta_rule/wy_fast_test.py @@ -0,0 +1,676 @@ +# -*- coding: utf-8 -*- +import pdb +import torch +import triton +import triton.language as tl +from einops import rearrange +# from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r#列读,因而是行数目 + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]#行数 + + for i_k in range(tl.cdiv(dk, BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + #先save这个看看 + + for i in range(1, BT):#此时矩阵为 BT,r,BT,r + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)#get ba BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)#矩阵乘法解决,get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])#BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A += ((tl.arange(0, BT)[:, None, None, None] == tl.arange(0, BT)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))#旧版本实现需要很多乘法 + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + #解决矩阵求逆 + b_A = b_A.to(k.dtype.element_ty)#ok 解决求逆了 #下一步计算结果 + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for 
i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + # r_mask = tl.arange(0, r) == i_r # + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +#compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, 
num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta, dmask,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r, (T*r, BT*r), (BT*r, 1), (i_t * BT * r, 0), (BT*r, BT*r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT*r, BT*r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+
+    b_dmask = tl.zeros([r, r], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT) x BV
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :]*tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # BT x r x BV
+        b_v_beta = tl.reshape(b_v_beta, (BT*r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r) x (BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r) x BV
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # reduce over the r replicas
+        b_dv = (sum_dv * b_beta[:, None])
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K//r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r)*r + i_r  # read column i_r of the mask
+        b_mask = tl.load(p_mask)  # column i_r
+        rmask = tl.arange(0, r) == i_r  # one-hot for column i_r
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0))
+            b_k_beta = ((b_k * b_beta[:, None])[:, None, :]*b_mask[None, :, None]).to(b_k.dtype)  # BT x r x d
+            b_k_beta = tl.reshape(b_k_beta, (BT*r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            b_dk = sum_dk * b_beta[:, None]
+            b_dbeta += tl.sum(sum_dk * b_k, 1)
+
+            b_ss = b_dk_beta * b_beta[:, None, None] * b_k[:, None, :]
+            b_ss = tl.reshape(tl.permute(b_ss, (2, 0, 1)), (BT*BK, r))
+            b_ss = tl.sum(b_ss, 0)
+            # b_ss = (tl.sum(tl.sum(b_dk_beta * b_beta[:,None,None] * b_k[:,None,:],0),-1))
+            b_dmask += (b_ss[:, None]*rmask[None, :]).to(tl.float32)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype),
tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) #等价于 kkt的 dA 很多0,对角处 + + + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + #bt r bt r + + + for i_r in range(r):#只取ir项 + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask)#第ir列 + rmask = tl.arange(0, r) == i_r #第ir列 + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)#BT r BT #取出第ir列 + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + #对应的c部分 + + for i_k in range(tl.cdiv(block_k, BK)):#ik = 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT + + beta_y = (beta_kkt[:,None,:]*g) + beta_y = tl.reshape(tl.permute(beta_y,(2,0,1)),(BT*BT,r)) + betas = tl.sum(beta_y,0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "r"], +) +@triton.jit +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + mask_ij, + A, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r#列读,因而是行数目 + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:]#行数 + + for i_k in range(tl.cdiv(dk, BK)):#分块k读取计算 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, 
num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,32,r,r),(32*r*r,r*r,r,1) ,(i_t * 32 + 16, 0, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1,2,3)).to(tl.float32) + b_A21 = tl.permute(b_A21,(0,2,1,3)) + b_A21 = tl.reshape(b_A21,(16*r,16*r))#BT*r BT*r + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + +def chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, mask, A, + T*K, K, 1, + T, K, r, BT, BK + ) + return A + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = 
mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,#s_abh + T*16*r*r,#s_adbh + T, + r, BT + ) + if BT == 16: + return Ad + + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + +def fwd_prepare_wy_repr2(k, v, beta,mask, BT): + A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,torch.float32) + A = solve_tril(A=A,mask=mask,k=k,BT=BT,output_dtype=k.dtype) + w, u = fwd_recompute_w_u(k, v, beta,mask, A, BT) + return w, u, A + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=k.dtype) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + return w, u, A + + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + +def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + dmask = torch.zeros([B*H*NT,r,r],device=k.device,dtype=k.dtype).contiguous() + bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, A, + dw, du, + dk, dv, dbeta,dmask, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + dmask = dmask.sum(0) + return dk, dv, dbeta, dmask + + +class WYRepresentationPrepration(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, k, v, beta,mask,chunk_size=64): + ctx.BT = chunk_size + w, u, A = fwd_prepare_wy_repr(k, v,beta,mask, ctx.BT) + ctx.save_for_backward(k, v, beta,mask,A) + return w, u + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, dw, du): + k, v, beta,mask, A = ctx.saved_tensors + BT = ctx.BT + dk, dv, dbeta,dmask = bwd_prepare_wy_repr(k, v, beta,mask, A, dw, du, BT) + return dk, dv, dbeta, dmask, None + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta,maskij,chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + + b,h,nt,BT,dk = k.shape + dv = 
v.shape[-1] + r = maskij.shape[-1] + k_beta = k * beta[..., None] + k_beta = rearrange(k_beta,'b h n t (r k)->b h n t r k', r=r) + k_beta = torch.einsum('b h n t r k,l r-> b h n t l r k',k_beta,maskij) + k_beta = rearrange(k_beta,'b h n t l r k->b h n t l (r k)')#l=1 rk=org + v_beta = v * beta[..., None] + v_beta = v_beta + v_beta = v_beta.unsqueeze(-2).expand(-1,-1,-1,-1,r,-1) + ki = rearrange(k,'b h n c (r k)-> b h n r c k',r=r) + + attn = (ki @ ki.transpose(-1, -2)) + attn = torch.tril(attn, diagonal=-1)#bhnr cc + attn = torch.einsum('b h n r t l,c r->b h n t l c r',attn,maskij)#bhn rr cc + attn = torch.einsum('b h n t l c r,b h n t->b h n t l c r',attn,beta) + + o = torch.zeros_like(k_beta) + o2 = torch.zeros_like(v_beta) + + o[..., 0, :,:] = k_beta[..., 0,:,:].clone() + o2[..., 0,:, :] = v_beta[..., 0,:,:].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i,:,:]).clone()#bhn :t cc + o[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o_i).sum(3) + k_beta[..., i,:,:]) + o2_i = (o2[..., :i,:,:]).clone()#少一个维度 + o2[..., i,:,:] = (-(attn[:,:,:,i, :i,:,:]@o2_i).sum(3) + v_beta[..., i,:,:]) + return map(lambda x: rearrange(x, 'b h n c r k -> b h (n c r) k'), (o, o2)) + + +if __name__ == "__main__": + #all compute here + import sys + torch.manual_seed(42) + sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + seq_len = 128 + b = 2 + h = 2 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 + v = torch.randn(b, h, seq_len, 128) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + BT = 32 + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, beta)) + r = 4 + # mask = torch.tensor([[1,1,0,0],[0.5,1,0.5,0],[0,0.5,1,0.5],[0,0,1,1]]).cuda().contiguous() + mask = torch.randn([r,r]) + mask = mask.cuda().requires_grad_(require_grad).contiguous() + # w,u,a0 = fwd_prepare_wy_repr(k,v,beta,mask, 16) + # w2,u2 = fwd_recompute_w_u(k,v,beta,mask,a0,16) + # from einops import rearrange + + k2 = rearrange(k,'b h (n t) (r k)-> b h n r t k',t = BT,r=r) + b2 = rearrange(beta,'b h (n t)-> b h n t',t = BT) + a1 = (k2*b2.unsqueeze(-2).unsqueeze(-1))@k2.transpose(-1,-2)#bhnrtt + qq = torch.tril(a1,diagonal=-1) + qq = torch.einsum('b h n r t l,c r-> b h n t c l r',qq,mask) + sf = rearrange(qq,'b h n t c l r->b h n (t c) (l r)') + sf = rearrange(sf,'b h n (t c) (l r)->b h n t l c r',c=r ,r =r)#这个 + + # #长条对角线 + i_mask = ((torch.arange(0, BT)[:, None, None, None] == torch.arange(0, BT)[None, :, None, None]) & (torch.arange(0, r)[None, None, :, None] == torch.arange(0, r)[None, None, None, :])) + s = sf+i_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).cuda() + s = rearrange(s,'b h n a d c r->b h n (a c) (d r)') + s = torch.linalg.inv(s.float()).to(k)#矩阵逆#bhn tr tr + + + # A = chunk_scaled_dot_kkt_fwd(k,beta,mask,BT,output_dtype=torch.float32)#bh nt BT bt r r + # Ad = solve_tril(A,mask,k,BT,output_dtype=torch.bfloat16) + # s = rearrange(s,'b h n a c->(b h n) a c') + # print(Ad.shape) + # print(s.shape) + + w,u,As = fwd_prepare_wy_repr2(k, v, beta,mask, BT) + # w2,u2,Ad2 = fwd_prepare_wy_repr(k, v, beta,mask, BT) + + # print((w2-w).abs().max()) + # print((u2-u).abs().max()) + # print((As-Ad2).abs().max()) + + # print((Ad-s).abs().max()) + # print(Ad-s) + + # print((As-s).abs().max()) + # print(As-s) + # B*H*NT,BT*r,16*r + # k_exp = torch.einsum('b h n r t k,b h n t-> b h n r t k',k2,b2) + # k_exp = torch.einsum('b h n r t k,c r-> b h n r t k c',k_exp,mask) + # k_exp 
= rearrange(k_exp,'b h n r t k c->b h n (t c) (r k)')
+    # wc = s_copy@k_exp
+
+    # v_exp = rearrange(v,'b h (n t) v-> b h n t v',t = BT)
+    # v_exp = torch.einsum('b h n t v,b h n t-> b h n t v',v_exp,b2)
+    # v_exp = v_exp.unsqueeze(4).expand(-1,-1,-1,-1,r,-1)
+    # v_exp = rearrange(v_exp, ' b h n t r v-> b h n (t r) v')
+    # uc = s_copy@v_exp
+    # wc,uc = map(lambda x: rearrange(x,"b h n t r->b h (n t) r"), (wc,uc))
+    # do = torch.rand_like(wc)
+    # do2 = torch.rand_like(uc)#b h n t t
+    # o1, o2 = naive(k.clone(), v.clone(), beta.clone(),mask.clone(), BT)  # this code is buggy
+    # do = torch.rand_like(o1)
+    # do2 = torch.rand_like(o2)#b h n t t
+    # if require_grad:
+    #     o1.backward(do, retain_graph=True)
+    #     o2.backward(do2, retain_graph=True)
+    # k_grad2, v_grad2, beta_grad2,mask_grad2 = k.grad, v.grad, beta.grad, mask.grad
+
+    # w0,u0,s0 = fwd_prepare_wy_repr(k, v, beta,mask, 16)
+    # k_grad, v_grad, beta_grad,mask_grad = bwd_prepare_wy_repr(k,v,beta,mask,s0,do,do2,BT)
+
+    # print((o1-w0).abs().max())
+    # print((o2-u0).abs().max())
+    # print((k_grad-k_grad2).abs().max())
+    # print((v_grad-v_grad2).abs().max())
+    # print((beta_grad-beta_grad2).abs().max())
+    # print((mask_grad-mask_grad2).abs().max())
+    # print(mask_grad)
+    # print(mask_grad2)
+
diff --git a/fla2/ops/mask_gated_delta_rule_t/README.md b/fla2/ops/mask_gated_delta_rule_t/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1ab2d485a9552d70238c1f68288c72c62f9e0ef2
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/README.md
@@ -0,0 +1,4 @@
+- Delta Rule
+
+The implementation of the delta rule described in https://arxiv.org/abs/2102.11174
+
diff --git a/fla2/ops/mask_gated_delta_rule_t/__init__.py b/fla2/ops/mask_gated_delta_rule_t/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c675b3da981726a2b4a9919545e4f569682d710a
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/__init__.py
@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+
+from .chunk import mask_gated_chunk_delta_rule
+# from .chunk_fuse import mask_fused_chunk_delta_rule
+# from .recurrent_fuse import mask_fused_recurrent_delta_rule
+
+__all__ = [
+    # 'mask_fused_chunk_delta_rule',
+    # 'mask_fused_recurrent_delta_rule',
+    'mask_gated_chunk_delta_rule',
+]
diff --git a/fla2/ops/mask_gated_delta_rule_t/__pycache__/__init__.cpython-310.pyc b/fla2/ops/mask_gated_delta_rule_t/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35d513f4271be90fc2a35eea5dafbbf915a2b0e3
Binary files /dev/null and b/fla2/ops/mask_gated_delta_rule_t/__pycache__/__init__.cpython-310.pyc differ
diff --git a/fla2/ops/mask_gated_delta_rule_t/chunk.py b/fla2/ops/mask_gated_delta_rule_t/chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..039a1a6700d2aa89be8578f922f41c1de27c9813
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/chunk.py
@@ -0,0 +1,1521 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+import time
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd,contiguous
+from fla.ops.utils import chunk_local_cumsum
+import torch.nn.functional as F
+from typing import Optional
+from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+
+@triton.jit
+def safe_exp(x):
+    return tl.exp(tl.where(x <= 0, x, float('-inf')))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_fwd_recompute_w_u_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    Aw,
+    Au,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    p_Aw = tl.make_block_ptr(Aw + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))
+    b_Aw = tl.load(p_Aw, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    for i_r in range(r):
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
+        b_mask = tl.load(p_mask)#BT r 1
+        for i_k in range(tl.cdiv(dk, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask.to(b_k.dtype)#BT*r*d
+            b_kb = tl.reshape(b_kb,(BT*r,BK))
+            b_w = tl.dot(b_Aw, b_kb, allow_tf32=False)#get BT*r *BK
+            p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0))
+            tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+    tl.debug_barrier()
+    b_Aw = None
+    p_Au = tl.make_block_ptr(Au + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))
+    b_Au = tl.load(p_Au, boundary_check=(0, 1)).to(k.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):  # not needed for an arbitrary mask; the loop is unnecessary here and no mask is involved
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]
+        b_vb = tl.reshape(b_vb,(BT*r,BV))
+        b_u = tl.dot(b_Au, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK","r"],
+)
+@triton.jit
+def gated_chunk_scaled_dot_kkt_fwd_kernel(
+    k,
+    beta,
+    g_cumsum,
+    mask_ij,
+    A,
+    Ag,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    T,
+    K,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT
+    dk = K//r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
+        b_mask = tl.load(p_mask)#BT r 1
+        ij_mask = b_mask*r_mask[None,None,:]  # row selector  # BT [r,r]
+
+        for i_k in range(tl.cdiv(dk, BK)):  # load and process k block by block
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)#BT BT
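+            # dot is the [BT, BT] beta-scaled k k^T score for key sub-block i_r;
+            # broadcasting it against ij_mask scatters it into column i_r of the
+            # per-pair (r, r) mask plane, so b_A[t, s, a, b] only receives
+            # contributions with b == i_r, weighted by mask_ij[t, a, i_r].
+            b_A += 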
dot[:,:,None,None]*ij_mask[:,None,:,:]#BT r r
+
+    b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0)
+    p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0))
+    tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3))
+
+    p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_g_diff = b_g[:, None] - b_g[None, :]
+    b_g_diff = safe_exp(b_g_diff)
+
+    b_Ag = b_A * ((b_g_diff)[:,:,None,None])#BT BT
+    p_Ag = tl.make_block_ptr(Ag + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0))
+    tl.store(p_Ag, (b_Ag).to(p_Ag.dtype.element_ty),boundary_check=(0,1,2,3))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "r"],
+)
+@triton.jit
+def solve_tril_16x16_kernel(
+    A,
+    Ad,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    offset = (i_t * 16) % BT
+
+    p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0))
+    b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32)
+    b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0)
+
+    for i in range(1, 16):
+        mask = tl.arange(0, 16) == i
+        b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)
+        q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2))
+        b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None])
+        b_A = tl.where(mask[:,None,None,None],b_a,b_A)  # computed row by row, swapping partial results in step by step
+    b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :]))
+
+    b_A = tl.permute(b_A,(0,2,1,3))
+    b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r
+    p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0))
+    tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1))
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["r"],
+)
+@triton.jit
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ad,
+    Ai,
+    s_A_bh,
+    s_Ad_bh,
+    T,
+    r: tl.constexpr,
+    BT: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+
+    p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,32*r),(32*r,1) ,((i_t * 32 + 16) *r, 0), (16*r, 16*r), (1,0))
+    b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32)
+
+    p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0))
+    p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0))
+
+    p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0))
+    p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0))
+    p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0))
+
+    Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32)
+    Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32)
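+    # 2x2 block forward substitution: for a lower block-triangular matrix
+    # [[A11, 0], [A21, A22]], the inverse is [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]],
+    # so only the off-diagonal block needs fresh work once the two 16x16
+    # diagonal inverses from solve_tril_16x16_kernel are available.
+    Ai21 = 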
-tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0)) + p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0)) + p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0)) + p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0)) + p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0)) + p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0)) + + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32) + b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32) + b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32) + b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32) + b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32) + + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0)) + p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0)) + p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0)) + p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0)) + p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0)) + + + Ai11 = 
tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32) + Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32) + + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + Ai32 = -tl.dot(tl.dot(Ai33,b_A32, input_precision='ieee'),Ai11,input_precision='ieee') + Ai43 = -tl.dot(tl.dot(Ai44,b_A43, input_precision='ieee'),Ai11,input_precision='ieee') + + Ai31 = -tl.dot( + Ai33, + tl.dot(b_A31,Ai11, input_precision='ieee')+ + tl.dot(b_A32,Ai21, input_precision='ieee'), + input_precision='ieee') + + Ai42 = -tl.dot( + Ai44, + tl.dot(b_A42,Ai22, input_precision='ieee')+ + tl.dot(b_A43,Ai32, input_precision='ieee'), + input_precision='ieee') + + Ai41 = -tl.dot( + Ai44, + tl.dot(b_A41, Ai11, input_precision='ieee') + + tl.dot(b_A42, Ai21, input_precision='ieee') + + tl.dot(b_A43, Ai31, input_precision='ieee'), + input_precision='ieee' + ) + + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai33,Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai44,Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai31,Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai32,Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai41,Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai42,Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai43,Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + + +def gated_chunk_scaled_dot_kkt_fwd(k: torch.Tensor, + beta: torch.Tensor, + mask: torch.Tensor, + g_cumsum:Optional[torch.Tensor] = None, + BT:int = 32, + output_dtype: torch.dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] #B H T r r + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + Ag = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + gated_chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, g_cumsum, mask, A,Ag, + T*K, K, 1, + T, K, r, BT, BK + ) + return A,Ag + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,#s_abh + T*16*r*r,#s_adbh + T, + r, BT + ) + if BT == 16: + return Ad + + A = rearrange(A,'b (t l) (c r)->b (t c) (l r)',t=BT,c=r).contiguous()#BT*r BT*r + if BT == 32: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + if BT == 64: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_64x64_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + 
T*BT*r*r,#s_a_bh and s_ai_bh
+            T*16*r*r,#s_ad_bh
+            T,r,BT
+        )
+        return Ai
+
+
+def gated_fwd_recompute_w_u(k, v, beta,mask, Aw,Au,BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype)
+    w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)#32
+    BV = min(triton.next_power_of_2(V), 64)
+    gated_fwd_recompute_w_u_kernel[(NT, B*H)](
+        k, v, beta,mask, w, u, Aw,Au,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r,BT, BK, BV
+    )
+    return w, u
+
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+def pad(x, chunk_size=16):
+    seq_len = x.shape[-2]
+    #b n l d
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
+    if x.shape[-2] % chunk_size != 0:
+        x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
+    return x
+
+def pad_b(x,val, chunk_size=16):
+    seq_len = x.shape[-1]  # sequence length l
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size  # length after padding
+    # pad only if the sequence length is not a multiple of chunk_size
+    if seq_len % chunk_size != 0:
+        x = F.pad(x, (0, padded_seq_len - seq_len),value=val)  # pad only along the last dimension (l)
+    return x
+
+def pad_m(x,val, chunk_size=16):
+    seq_len = x.shape[-3]  # sequence length; layout is b h l r r
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size  # length after padding
+    # pad only if the sequence length is not a multiple of chunk_size
+    if seq_len % chunk_size != 0:
+        x = F.pad(x, (0,0,0,0,0,padded_seq_len - seq_len),value=val)  # pad along the l dimension
+    return x
+
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,#u
+    d,#w
+    v_new,
+    g,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one full horizontal slice of the state
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+    # B,H,NV,NT
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
+        b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            r_mask = tl.arange(0,r) == i_r
+            p_k = tl.make_block_ptr(k + i_bh * K * T, (K, T), (1, K),
+                                    (i_k * BK + i_r * BK//r, i_t * BT), (BK//r,BT), (0, 1))  # read the matching sub-block
+            p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1),
+                                        (i_t * BT , i_v * BV), (BT , BV), (1, 0))
+            p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r * K ),(r * K, 1),
+                                    (i_t * BT, i_r * K + i_k * BK), (BT,BK),(1,0))
+            p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r * V ),(r * V, 1),
+                                    (i_t * BT, i_r * V + i_v * BV), (BT,BV),(1,0))  # offset fixed to i_r * V (was i_r * K, which only coincides when K == V)
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_d = tl.load(p_d, boundary_check=(0, 1))
+            b_v = tl.load(p_v, boundary_check=(0, 1))
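+            # Chunkwise delta rule: subtract what the running inter-chunk state
+            # already explains, v_new = u - w @ h, before k^T @ v_new is
+            # accumulated into the state carried to the next chunk.
+            b_v -= tl.dot(b_d, 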
b_h.to(b_d.dtype)).to(b_v.dtype)
+            tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))  # at least up to here the first-step results match
+            kv = tl.dot((b_k),b_v)  #### precision issue when a small number multiplies a large one
+            b_h_cumsum = tl.where(r_mask[:,None,None],b_h_cumsum + kv[None,:,:] ,b_h_cumsum)
+
+        last_idx = min((i_t + 1) * BT, T) - 1
+        b_g_last = tl.load(g + i_bh*T + last_idx)
+        b_g_last = tl.exp(b_g_last)
+        b_h = b_g_last * b_h
+
+        b_h += tl.reshape(b_h_cumsum,(BK,BV))
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    g,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r : tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr//r
+    i_r = i_bhr % r
+    rk = K//r
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K//r, BK)):  # note the split here: K//BK = r
+        # open question: different r blocks read the same q/k; does that matter?
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h.to(b_q.dtype))
+        b_s += tl.dot(b_q, b_k)
+
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+    b_o = b_o * tl.exp(b_g)[:,None]
+
+    b_g_diff = b_g[:, None] - b_g[None, :]
+    b_s = b_s * safe_exp(b_g_diff)#BT BT
+
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_s = tl.where(m_s, b_s, 0)  # zero out the non-causal part of b_s
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o * scale + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK"],
+)
+@triton.jit
+def preprocess_qkw(q,
+                   k,
+                   w,
+                   g,
+                   q_new,
+                   k_new,
+                   w_new,
+                   T,
+                   H,
+                   K,
+                   r:tl.constexpr,
+                   BT:tl.constexpr,
+                   BK:tl.constexpr,
+                   USE_Q:tl.constexpr,
+                   ):
+    i_k,i_bh,i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    p_k = tl.make_block_ptr(k + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_w = tl.make_block_ptr(w + i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0))
+
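+    # Fold the gate into the operands once: k is scaled by the decay remaining
+    # to the chunk end, exp(g_last - g), while w (and q, when USE_Q) are scaled
+    # by the decay accumulated from the chunk start, exp(g).
+    p_g = 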
tl.make_block_ptr(g+i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,))
+    p_k_new = tl.make_block_ptr(k_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+    p_w_new = tl.make_block_ptr(w_new +i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0))
+
+    last_idx = min((i_t + 1) * BT, T) - 1
+    b_g_last = tl.load(g + i_bh*T + last_idx).to(tl.float32)  # read the gate at the chunk's last position
+
+    b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
+    b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32)
+    b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
+    b_d_last = tl.exp((b_g_last - b_g))
+    b_d_begin = tl.exp(b_g)
+    b_k = b_k * b_d_last[:, None]
+    b_w = b_w * b_d_begin[:, None]
+    tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1))
+
+
+    if USE_Q:
+        p_q = tl.make_block_ptr(q + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_q_new = tl.make_block_ptr(q_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
+        b_q = b_q * b_d_begin[:, None]
+        tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+def gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state):
+    # k, w, u, g, BT, initial_state, final_state
+    B, H, T, K, V = *k.shape,u.shape[-1]
+    _,_,rT,_ = w.shape
+    r = rT//T
+    BK = triton.next_power_of_2(K)  # partitioned directly
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = torch.empty(B, H, NT * K, V,device=k.device,dtype=k.dtype)
+    grid = (NK,B*H,NT)
+    k_new = torch.empty_like(k)
+    w_new = torch.empty_like(w)
+    preprocess_qkw[grid](
+        q=None,
+        k=k,
+        w=w,
+        g=g,
+        q_new=None,
+        k_new=k_new,
+        w_new=w_new,
+        T=T,
+        H=H,
+        K=K,
+        r=r,
+        BT=BT,
+        BK=BK,
+        USE_Q=False,
+    )
+
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device)  # v_new is laid out r-first
+    gated_chunk_delta_rule_fwd_kernel_h[grid](  # no for loop over r here
+        k_new,u,w_new,
+        v_new,
+        g,h,
+        initial_state,
+        final_state,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+#finish
+def gated_chunk_fwd_o_fn(q, k, v_new,h,g,BT):
+    B,H,r,T,V,K = *v_new.shape,q.shape[-1]
+    BK = triton.next_power_of_2(K//r)
+    o = torch.empty_like(v_new)  # therefore laid out as b h r, T, bv
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    #h shape b h nk v
+    gated_chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, g, o,
+        T*K, K, 1 ,
+        r*T*V,T*V,V,
+        NT*K*V,V,
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_fwd_prepare_dv_kernel(
+    q,
+    k,
+    g,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # could perhaps also parallelize over r
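+    # One program per (chunk, batch*head*sub-block): the flat index i_bhr is
+    # decoded into the (batch, head) slot i_bh and the key sub-block i_r it handles.
+    i_bh = 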
i_bhr//r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K//r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+
+    p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_g = tl.load(p_g, boundary_check=(0,))
+
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A* safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+def gated_fwd_prepare_dv(q, k, g, do, r,BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)  # can't use empty_like here
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r),64)
+    BV = min(triton.next_power_of_2(V), 64)
+    gated_fwd_prepare_dv_kernel[(NT, B*H*r)](
+        q, k, g , do, dv,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
+
+
+
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    g,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_h_h,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # this stays resident; reads everything
+    for i_t in range(NT - 1, -1, -1):  # shifted forward by one position; the computation order is correct
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                (i_k * BK, i_t * BT), (BK, BT), (0, 1))  # full read
+        p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K),
+                                (i_k * BK, i_t * BT * r), (BK, BT * r), (0, 1))
+        p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                 (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+
+        last_idx = min((i_t + 1) * BT, T) - 1
+        b_glast = tl.load(g + i_bh * T + last_idx)
+        b_glast = tl.exp(b_glast)
+
+        b_q = (tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_q.dtype)
+
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_d = (tl.load(p_d,boundary_check=(0, 1)))
+        p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1),
+                                 (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0))  # load r
+        b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv
+        b_dhtrans = tl.reshape(b_dh,(r,KR,BV))
+        for i_r in range(r):
+            rmask = tl.arange(0, r) == i_r  # the i_r-th column
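+            # For sub-block i_r, fold k @ dh into the value gradient; afterwards the
+            # chunk-level reverse recurrence is dh <- exp(g_last)*dh + q^T do - w^T dv.
+            p_k = 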
tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT , i_r*KR + i_k * BK), (BT, KR), (1, 0))#
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0)
+            dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False)
+            b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BT*r,BV))
+
+        p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1),
+                                  (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0))
+        tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+        b_dh *= b_glast
+        b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)-tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False)
+
+
+
+def gated_chunk_bwd_dhu_fn(q, k, w, g,h0, do, dv, BT):
+    B,H,r,T,V,K = *dv.shape,q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # could probably expose more parallelism here
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K,V)  # same layout; needs a sum, so it has to be computed together
+    q_new = torch.empty_like(q)
+    k_new = torch.empty_like(k)
+    w_new = torch.empty_like(w)
+    # grid = (NK,)
+    grid = (NK,B*H,NT)
+    preprocess_qkw[grid](
+        q=q,
+        k=k,
+        w=w,
+        g=g,
+        q_new=q_new,
+        k_new=k_new,
+        w_new=w_new,
+        T=T,
+        H=H,
+        K=K,
+        r=r,
+        BT=BT,
+        BK=BK,
+        USE_Q=True,
+    )
+
+
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout  # bhr T V
+    gated_chunk_delta_rule_bwd_kernel_dhu[grid](
+        q_new, k_new, w_new, g, do, dh, dv, dv2,
+        T*K,K,1,
+        NT*K*V,
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r=r,KR = K//r,
+    )
+    return dh, dv2
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def gated_chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    g,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    dg,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    s_g_r,
+    s_g_k,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr%r
+    i_bh = i_bhr//r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT*r,BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    b_dg_last = tl.zeros([1,],dtype=tl.float32)
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1))
+
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK
+        b_dh = (tl.load(p_dh, boundary_check=(0, 1)))  # needs the extra r dimension
+
+        b_dg_last += tl.sum(b_h * b_dh)  # this does sum over r
+
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # b_do is full; b_h covers the i_k portion
+        b_dk += tl.dot(b_v, b_dh, allow_tf32=False)  # used to compute dk; rows are independent, so this is fine
+        b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV
+        b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype)))  # get BT*r BK
+
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+
+    b_dg = tl.zeros([BT,], dtype=tl.float32)
+    p_g = tl.make_block_ptr(g + i_bh * T ,(T,),(1,),(i_t*BT,),(BT,),(0,))
+    b_g = tl.load(p_g,boundary_check=(0,))
+    b_glast = tl.load(g +i_bh*T + (min(i_t * BT + BT, T) - 1))
+    b_dg_last *= tl.exp(b_glast)
+
+
+    p_w = tl.make_block_ptr(w + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0))
+    b_w = tl.load(p_w,boundary_check=(0,1))#BT * r ,BK
+    b_dw = b_dw * tl.reshape(tl.broadcast_to(tl.reshape(tl.exp(b_g),(BT,1)),(BT,r)),(BT*r))[:,None]
+    b_dg -= tl.sum(tl.reshape(b_w*b_dw,(BT,r*BK)),-1)
+
+    b_dq = b_dq*scale*tl.exp(b_g)[:,None]
+    b_dg += tl.sum(b_dq*tl.trans(b_q),1)#BT*BK
+
+    b_dk = b_dk * safe_exp(b_glast-b_g)[:,None]
+    b_dg -= tl.sum(b_dk*b_k,1)#BT*BK
+    b_dg_last += tl.sum(b_dk*b_k)
+
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds* safe_exp(b_g[:, None] - b_g[None, :]) * scale, 0)
+    b_ds2 = b_ds*(tl.dot(tl.trans(b_q),tl.trans(b_k)))
+
+    b_dg += tl.sum(b_ds2,axis=1)
+    b_dg -= tl.sum(b_ds2,axis=0)
+    b_ds = b_ds.to(b_k.dtype)
+
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))  # these should be fine
+
+
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0))
+    p_dg = tl.make_block_ptr(dg + i_r * s_g_r + i_k * s_g_k + i_bh * T,(T,),(1,),(i_t*BT,),(BT,),(0,))
+    b_dg = tl.where(o_i jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)  # equivalent to the dA of kkt; mostly zeros, concentrated near the diagonal
+    b_dA = tl.reshape(b_dA,(BT,r,BT,r))
+
+
+    p_A = tl.make_block_ptr(Au + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dA2 = tl.zeros([BT*r,BT*r], dtype=tl.float32)
+
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over r
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))#r*BT BV
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)##BT*r*BV
+        b_v_beta = tl.reshape(b_v_beta,(BT*r,BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
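+        # Gradient through Au on the value side: accumulate du @ (beta*v)^T per value
+        # tile; the masked -A^T (dA) A^T propagation below is the usual inverse rule
+        # d(X^{-1}) = -X^{-1} dX X^{-1} specialized to this unit-lower block system.
+        b_dA2 += tl.dot(b_du, 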
tl.trans(b_v_beta), allow_tf32=False)#BT*r,BT*r
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)#BT*r,BV
+        b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))#
+        sum_dv = tl.sum(b_dv_beta,-2)  # this part differs; the result
+        b_dv = (sum_dv * b_beta[:, None])  # ? which step produces the mismatch?
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+    b_dA2 = tl.where(da_mask, b_dA2, 0)
+    b_dA2 = tl.dot(b_dA2.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA2 = tl.dot(tl.trans(b_A), b_dA2.to(b_A.dtype), allow_tf32=False)
+    b_dA2 = tl.where(da_mask, -b_dA2, 0)  # equivalent to the dA of kkt; mostly zeros, concentrated near the diagonal
+    b_dA2 = tl.reshape(b_dA2,(BT,r,BT,r))
+
+
+    p_g = tl.make_block_ptr(g_cumsum + i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,))
+    b_g = tl.load(p_g,boundary_check=(0,))
+    b_dA2 *= safe_exp(b_g[:,None]-b_g[None,:])[:,None,:,None]
+    b_dA += b_dA2
+    b_dA2 = tl.permute(b_dA2,(0,2,1,3))#Bt bt r r
+
+    b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)
+
+    for i_r in range(r):  # take only the i_r-th term
+        p_mask = tl.make_block_ptr(mask_ij + i_bh * T*r*r,(T,r,r),(r*r,r,1),(i_t*BT,0,i_r),(BT,r,1),(2,1,0))
+        b_mask = tl.load(p_mask)#BT r 1
+        rmask = tl.arange(0, r) == i_r  # the i_r-th column
+        g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)  # BT r BT  # extract the i_r-th column
+        ir_A = tl.sum(g * b_mask,1).to(k.dtype.element_ty)#BT BT
+
+        for i_k in range(tl.cdiv(block_k, BK)):#ik = 1
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+            beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT
+            b_A += beta_kkt[:,:,None,None] * ((rmask[None,None,:] * b_mask)[:,None,:,:])  # this broadcasts across the whole column, which looks wrong
+
+        betas = (tl.sum(beta_kkt[:,None,:]*g,-1))#BT r
+        b_dmask += (betas[:,:,None]*rmask[None,None,:]).to(tl.float32)
+
+
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+    p_dmask = tl.make_block_ptr(dmask + (i_bh * (T) + i_t * BT)* r * r , (BT,r,r), (r*r,r,1), (0,0,0), (BT,r,r), (2,1,0))
+    tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1))
+
+    b_dA2 *= b_A #BT BT r r
+    b_dA2 = tl.sum(tl.reshape(b_dA2,(BT,BT,r*r)),-1)
+
+    b_dg = tl.sum(b_dA2,1)-tl.sum(b_dA2,0)
+    p_dg = tl.make_block_ptr(dg+i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,))
+    tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
+
+def gated_bwd_prepare_wy_repr(k, v, beta, mask,g, Aw,Au, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K//r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    dg = torch.empty_like(g)
+    dmask = torch.zeros([B,H,T,r,r],device=k.device,dtype=k.dtype).contiguous()
+    assert BK <= K//r
+    gated_bwd_prepare_wy_repr_kernel[(NT, B*H)](
+        k, v, beta, mask, g, Aw,Au,
+        dw, du,
+        dk, dv, dbeta,dmask,dg,
+        T*K, K, 1,
+        T*V, V, 1,
+        T, K, V, r, BT, BK, BV
+    )
+    # dmask = dmask.sum(0)
+    return dk, dv, dbeta, dmask,dg
+
+
+
+class gated_ChunkDeltaRuleFunction(torch.autograd.Function):
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta,g,mask,BT, initial_state, output_final_state=False, checkpoint_level=1):
+        B,H,L,K = q.shape
+        r = mask.shape[-1]
+        g = chunk_local_cumsum(g,BT,head_first=True,output_dtype=torch.float)  # no change needed
+        # note: the mask becomes B H T r d
+        Aw,Au = gated_chunk_scaled_dot_kkt_fwd(k=k,beta=beta,g_cumsum=g,mask=mask,BT=BT,output_dtype=torch.float32)
+        Aw = solve_tril(A=Aw,mask=mask,k=k,BT=BT,output_dtype=k.dtype)
+        Au = solve_tril(A=Au,mask=mask,k=k,BT=BT,output_dtype=k.dtype)#bh
+        w, u = gated_fwd_recompute_w_u(k, v, beta, mask,Aw,Au,BT)
+
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no correction
+        h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state)  # need change
+        o = gated_chunk_fwd_o_fn(q, k, v_new, h, g, BT)
+        if checkpoint_level == 1:
+            h, v_new = None, None  # recomputed in backward
+        ctx.save_for_backward(q, k, v, beta,g, mask, Aw, Au , h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+        # return o.to(q.dtype), h_s, final_state
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, g, mask , Aw,Au, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+
+        w, u = gated_fwd_recompute_w_u(k, v, beta, mask, Aw,Au,BT)  # skip
+        if h is None:
+            h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, None)
+
+        # the backward computation is rewritten from here on
+        dv = gated_fwd_prepare_dv(q, k, g, do, r, BT)  # q, k, do, v_new; this dv therefore has the shape of w. finish
+        dh, dv = gated_chunk_bwd_dhu_fn(q, k, w, g,initial_state,do, dv, BT)  # new dv and dh  # final dv for the WY repr
+        dq, dk, dw , dg = gated_chunk_bwd_dqkw_fn(q, k, v_new, w, g, h, dv, do, dh, BT)  # this step is also very slow
+
+        dk2, dv, dbeta,dmask,dg2 = gated_bwd_prepare_wy_repr(k, v, beta, mask,g, Aw,Au, dw, dv, BT)  # only this part involves the mask
+        dk.add_(dk2)
+        dg.add_(dg2)
+        dg = chunk_local_cumsum(dg, BT, reverse=True,head_first=True,output_dtype=torch.float)
+
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype),dg,dmask.to(mask.dtype),None, None, None
+
+
+
+def delta_rule_recurrence(q, k, v, beta, g, mask,initial_state=None,output_final_state=True):
+    g_exp = torch.exp(g).float()
+    BT = 32
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    if l%BT==0:
+        S_t = torch.zeros(b, h, l//BT, d_k, d_v,device=k.device,dtype=torch.float32)
+    else:
+        S_t = torch.zeros(b, h, l//BT + 1, d_k, d_v,device=k.device,dtype=torch.float32)
+    if initial_state is None:
+        S = torch.zeros(b, h, d_k, d_v,device=k.device,dtype=torch.float32)
+    else:
+        S = initial_state
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        if i%BT==0:
+            S_t[:,:,i//BT,:,:] = S
+        _k = k[:, :, i].float()
+        _q = q[:, :, i].float()*(d_k ** -0.5)
+        _v = v[:, :, i].float()
+        beta_i = beta[:, :, i].float()
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d,b h v->b h d v',_k*beta_i,_k)
+        kkt = rearrange(kkt,' b h (r d) (l v)-> b h r d l v',r= r,l=r)
+        kkt = torch.einsum('b h r d l v,b h r l->b h r d l v',kkt,mask[:,:,i,:,:].float())  # 16 extra mask parameters, almost negligible
+        kkt = rearrange(kkt,'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q)-kkt
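+        # Naive reference step: S <- exp(g_i) * (I - masked beta k k^T) @ S + k (beta v)^T,
+        # where the mask mixes the r key groups inside the rank-1 removal term.
+        iplr = torch.einsum(' b h q k ,b 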
h->b h q k',iplr,g_exp[:,:,i])
+        S = torch.einsum('b h q k ,b h k v->b h q v',iplr.float(),S) + _k.unsqueeze(-1).float() * _v.unsqueeze(-2).float()
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q.float(), S).to(k.dtype)
+    return o,S_t,S
+
+
+def mask_gated_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    g: torch.Tensor,
+    mask: torch.Tensor,  # used to mask the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "mask_gated_chunk_delta_rule does not support float32. Please use bfloat16."
+    seq_len = v.shape[-2]
+    q, k, v = map(lambda x: pad(x,BT), [q, k, v])
+    dim = v.shape[-1]
+    r = mask.shape[-1]
+    if dim < r*16:
+        q,k,v = map(lambda x:rearrange(x,'b h l (r d)->b h l r d',r=r),[q,k,v])
+        q,k,v = map(lambda x:F.pad(x, (0, 16 - dim//r),value=0),[q,k,v])  # essentially only 32 is meaningful here
+        q,k,v = map(lambda x:rearrange(x,'b h l r d->b h l (r d)',r=r),[q,k,v])
+    beta = pad_b(beta,0,BT)#bhl
+    g = pad_b(g,0,BT)#bhl
+    mask = pad_m(mask,0,BT)
+    q,k,v,g,beta,mask = map(lambda x:x.contiguous(),[q,k,v,g,beta,mask])
+    o, final_state = gated_ChunkDeltaRuleFunction.apply(q, k, v, beta,g,mask, BT, initial_state, output_final_state)
+    o = o[..., :seq_len,:]
+    if dim < r*16:
+        o = rearrange(o,'b h l (r d)->b h l r d',r=r)
+        o = o[...,:dim//r]  # keep only dim
+        o = rearrange(o,'b h l r d->b h l (r d)')
+    return o, final_state
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    from fla.modules.l2norm import l2_norm as l2_norm_fn
+    # from einops import rearrange
+    # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    # seq_len = 128
+    # b = 2
+    # h = 2
+    # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128
+    # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128
+    # v = torch.randn(b, h, seq_len, 128)
+    # beta = torch.rand(b, h, seq_len).sigmoid()
+    # require_grad = True
+    # BT = 16
+    # r = 4
+    # scale = 128**-0.5
+    # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous()
+    # w = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())
+    # u = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())#bhn tr d
+    # initial_state = torch.randn(b,h,128,128).cuda().contiguous()
+    # k, v, q, beta,w, u= map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v,q, beta,w,u))
+
+    # final_state = None
+    # if False:
+    #     final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+    #                               dtype=torch.float32, requires_grad=False)  # this part needs no correction
+
+    # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)  # need change
+    # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)  # need change
+    # o2 = rearrange(o2,'b h (n t) v-> b h n t v',n=seq_len//BT)
+    # do = torch.rand_like(o2)
+    # do_naive = do
+    # do = rearrange(do,'b h n t v-> b h (n t) v')
+    # dv0 = fwd_prepare_dv(q, k, do, r, BT)  # q, k, do, v_new; this dv therefore has the shape of w. finish
+    # #bhrtv
+    # # results match up to this point
+    # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)  # new dv, dh
+    # ### computed identically up to here
+    # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)  # needs dh and dv
+
+    # #bhtrv
+    # # dv0 = rearrange(dv0,'b h r (n t) v-> b h n t r v',n=seq_len//BT)
+    # # dv = rearrange(dv,'b h (n t) r v->b h n t r v',n=seq_len//BT)  # the two should match along the BT=1 dimension, yes
+    # # dh = rearrange(dh,'b h (n k) v->b h n k v',n=seq_len//BT)
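+    # A minimal smoke test of the public entry point, kept here as a commented sketch
+    # (shapes are illustrative; all values below are hypothetical):
+    # q, k, v = (torch.randn(2, 2, 128, 64).cuda() for _ in range(3))
+    # beta = torch.rand(2, 2, 128).cuda().sigmoid()
+    # g = torch.nn.functional.logsigmoid(torch.randn(2, 2, 128).cuda())
+    # m = torch.softmax(torch.randn(2, 2, 128, 4, 4).cuda(), dim=-1)
+    # o, state = mask_gated_chunk_delta_rule(q, k, v, beta, g, m, BT=32, output_final_state=True)
+    # 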
# it should hold that:
+    # # NT = seq_len//BT
+    # # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype)
+    # # v_new = torch.zeros_like(u)
+    # # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new))
+    # # wr = rearrange(w_na,'b h n (t r) k->b h n t r k',r = r)
+    # # dh0 = -torch.einsum('b h t r v,b h t r k-> b h k v ',dv[:,:,1,:,:,:],wr[:,:,1,:,:,:])
+    # # dh0 += torch.einsum('b h t v,b h t q->b h q v',do_naive[:,:,1,:,:],q_na[:,:,1,:,:])*scale
+    # # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r)
+    # # dh0 = rearrange(dh0,'b h (r k) v->b h r k v',r=r)#need b h t r v
+    # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v,b h t r k->b h t r v',dh0,k_r[:,:,0,:,:,:])
+    # # dh0 = rearrange(dh0,'b h r k v->b h (r k) v')
+    # # # this computation order is correct
+    # # print((dv[:,:,0,:,:,:]-dvs).abs().max())
+    # # print((dh0-dh[:,:,0,:,:]).abs().max())
+
+    # #####here is naive
+    # NT = seq_len//BT
+    # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype)
+    # v_new = torch.zeros_like(u)
+    # from einops import rearrange
+    # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new))
+    # # u_na = u_na.detach().requires_grad_(True)#b h n (t r) d
+    # v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r)
+    # if initial_state is not None:
+    #     state[:,:,0,:,:] = initial_state
+    # else:
+    #     state[:,:,0,:,:] = 0
+    # for i in range(NT):
+    #     ki = rearrange(k_na[:,:,i,:,:],'b h t (r d)->b h t r d',r = r)
+    #     ui = rearrange(u_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r)
+    #     wi = rearrange(w_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r)
+    #     v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:].clone())
+    #     v_new[:,:,i,:,:,:] = v_newi.clone()  # the results stored here are equal
+    #     kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi)
+    #     if i+1 < seq_len//BT:
+    #         state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui,'b h r k v-> b h (r k ) v')
+    # q_r = rearrange(q_na,'b h n t (r d)->b h n t r d',r = r)*scale
+    # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r)
+    # s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r)
+    # s_r = torch.tril(s_r,diagonal=0)#mask get bhnrtl
+    # v_newnew = v_new#.detach().requires_grad_(True)
+    # os = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_newnew)
+    # os = os.sum(dim=-2)#bhntv
+    # oss = torch.einsum('b h n t q,b h n q v-> b h n t v',q_na,state)*scale  # check only whether the state is computed correctly
+    # o_naive = os + oss
+    # # o_naive = rearrange(o_naive,'b h n t v->b h (n t) v')
+    # o_naive.backward(do_naive,retain_graph=True)
+    # una_grad = u.grad
+    # w_grad = w.grad
+    # k_grad = k.grad
+    # q_grad = q.grad
+    # print((una_grad-dv).abs().max())  # essentially equal
+    # print((k_grad-dk).abs().max())
+    # print((w_grad-dw).abs().max())
+    # print((q_grad-dq).abs().max())
+    # print(k_grad)
+    # print(dk)
+    print('test')
+    B = 8
+    H = 8
+    L = 227
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+
+    r = 4
+    mask = torch.randn(1,r).cuda()  # if all ones, this returns 0 directly
+    mask = mask.expand(r,r).requires_grad_(True)
+
+    target_matrix = torch.softmax(mask,dim=-1)#h r c
+    eye_mask = torch.eye(r, dtype=torch.bool, device=target_matrix.device).unsqueeze(0)
+    target_matrix = torch.where(eye_mask, torch.tensor(1.0, device=target_matrix.device), target_matrix)
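+    # Build a row-softmax mixing matrix with a forced unit diagonal, then broadcast
+    # it to the (B, H, L, r, r) layout the kernels expect.
+    target_matrix = 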
(target_matrix.unsqueeze(1).unsqueeze(0).expand(B,H,L,r,r))
+
+
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    g = torch.nn.functional.logsigmoid(torch.randn(B, H, L).cuda()).requires_grad_(True)
+    # # dict = {"q":q,"k":k,"v":v,'beta':beta,"g":g,"mask":target_matrix}
+    # # torch.save(dict,'/mnt/jfzn/msj/log.pth')
+
+    # dicts= torch.load('/mnt/jfzn/msj/log.pth')
+    # q = dicts["q"]
+    # k = dicts["k"]
+    # v = dicts["v"]
+    # beta = dicts["beta"]
+    # g = dicts["g"]
+    # mask = target_matrix = dicts["mask"]
+    # B,H,L,DV = v.shape
+    # g_exp = torch.exp(g)
+    q_trans,k_trans,v_trans = map(lambda x:rearrange(x,' b h l d->b l h d'),[q,k,v])
+    beta_trans,g_trans = map(lambda x:rearrange(x,' b h l->b l h'),[beta,g])
+
+    o11,ss = chunk_gated_delta_rule(q_trans,k_trans,v_trans,g_trans,beta_trans)
+    o11 = rearrange(o11,'b h l d-> b l h d')
+    do = torch.randn(B, H, L, DV).cuda()
+    o11.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    g_grad, g.grad = g.grad, None
+    print('done')
+    o22,f_state = mask_gated_chunk_delta_rule(q, k, v, beta, g,target_matrix,BT=32,output_final_state=True)  # takes ~10 s, hmm
+    o22.backward(do,retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    g_grad0, g.grad = g.grad, None
+    mask_grad0, mask.grad = mask.grad, None
+
+    print(mask_grad0)
+    # print((o11-o22).abs().max())
+    # print(o11-o22)
+    # # print(o22)
+    # print((k_grad-k_grad0).abs().max())
+    # # print(k_grad-k_grad0)
+    # print((v_grad-v_grad0).abs().max())
+    # print(v_grad-v_grad0)
+    # print(q_grad-q_grad0)
+    # print((beta_grad-beta_grad0).abs().max())
+    # # print((mask_grad-mask_grad0).abs().max())
+    # print((g_grad-g_grad0).abs().max())
+
+
diff --git a/fla2/ops/mask_gated_delta_rule_t/chunk_fuse.py b/fla2/ops/mask_gated_delta_rule_t/chunk_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6979fa906c6706bb07f6318b284920365db9eff
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/chunk_fuse.py
@@ -0,0 +1,448 @@
+# -*- coding: utf-8 -*-
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...ops.delta_rule.utils import bwd_prepare_wy_repr, fwd_prepare_wy_repr
+from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+import torch.nn.functional as F
+
+def ceildiv(a, b):
+    return -(a // -b)
+
+def pad(x, chunk_size=16):
+    seq_len = x.shape[-2]
+    #b n l d
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size
+    if x.shape[-2] % chunk_size != 0:
+        x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len))
+    if x.shape[-1] % 32 != 0:
+        x = F.pad(x, (0, 32 - x.shape[-1] % 32))
+    return x
+
+def pad_b(x, chunk_size=16):
+    seq_len = x.shape[-1]  # sequence length l
+    padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size  # length after padding
+    # pad only if the sequence length is not a multiple of chunk_size
+    if seq_len % chunk_size != 0:
+        x = F.pad(x, (0, padded_seq_len - seq_len),value=1.0)  # pad only along the last dimension (l)
+    return x
+
+# on-the-fly computation without materializing hidden states into HBMs
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8)
+    ],
+    key=["BT", "BK"],
+)
+@triton.jit
+def fused_chunk_delta_rule_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, D_head_K]
+    k,  # key [B, H, L, D_head_K]
+    v, 
# value [B, H, L, D_head_V] + v_new, + d, # decay [B, H, L, D_head_K] + o, # output [B, H, L, D_head_V] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_v_prime = tl.dot(b_d, b_h.to(b_q.dtype), allow_tf32=False) + b_v = b_v - b_v_prime + tl.store(p_v_new, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + b_o = tl.dot(b_s.to(b_q.dtype), b_v.to(b_q.dtype), allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v.to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_v_new = tl.advance(p_v_new, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_d = tl.advance(p_d, (BT, 0)) + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, 
num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fused_chunk_delta_rule_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + d, # decay [B, H, L, D_head_K] + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + dd, # gradient of decay [NV, B, H, L, D_head_K] + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + # first reverse + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_d = tl.make_block_ptr(d + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [DK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, DV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, DK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, DV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + b_d = tl.load(p_d, boundary_check=(0, 1)) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, b_dv.to(b_d.dtype), allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + b_dh -= tl.dot(b_d, 
b_dv.to(b_d.dtype), allow_tf32=False) + + tl.store(p_dk, (b_dk).to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + NT = tl.cdiv(T, BT) + for i in range(0, NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, DK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [DV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, DV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, DK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [DV, DK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + if i < (NT - 1): + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), ((i + 1) * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dd = tl.dot(b_dv.to(b_k.dtype), b_h.to(b_k.dtype), allow_tf32=False) + p_dd = tl.make_block_ptr(dd + (i_bh + i_v*B*H) * s_qk_h, (T, DK), (s_qk_t, s_qk_d), + ((i+1) * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dd, -b_dd.to(p_dd.dtype.element_ty), boundary_check=(0, 1)) + + +def fused_chunk_delta_rule_fwd(q, k, v, d, BT, initial_state, output_final_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BT = BT + # ctx.BT = BT + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1, 'NK should be 1' + o = q.new_empty(batch_size, n_heads, seq_len, d_head_v) + if output_final_state: + final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float32, requires_grad=False) + else: + final_state = None + CHECK = True + # if version.parse(triton.__version__) < version.parse('2.2.0'): + # import warnings + # warnings.warn( + # "Triton<2.2.0 detected for running this kernel, " + # "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + # "that lead to significant precision loss. " + # "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + # "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." 
+ # ) + # CHECK = True + grid = (NV, NK, batch_size * n_heads) + v_new = torch.empty_like(v) + fused_chunk_delta_rule_fwd_kernel[grid]( + q, k, v, v_new, d, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + ) + return o, v_new, CHECK, final_state + + +def fused_chunk_delta_rule_bwd(q, k, v, d, do, BT, CHECK, initial_state): + batch_size, n_heads, seq_len, d_head_qk = q.shape + d_head_v = v.shape[-1] + scale = d_head_qk ** -0.5 + BK, BV = triton.next_power_of_2(d_head_qk), min(triton.next_power_of_2(d_head_v), 32) + NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) + assert NK == 1 + dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dd = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) + dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) + grid = (NV, NK, batch_size * n_heads) + fused_chunk_delta_rule_bwd_kernel[grid]( + q, k, v, d, do, dq, dk, dv, dd, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + batch_size, n_heads, seq_len, scale, + BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=CHECK, + # num_warps=num_warps, + # num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dd = dd.sum(0) + dd[:, :, 0:BT] = 0 + return dq, dk, dv, dd + + +class FusedChunkDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta, BT, initial_state, output_final_state, checkpoint_level=0): + # lvl=1 will recompute ``fwd_prepare_wy_repr`` for saving memory. + assert checkpoint_level in [0, 1] + k_origin = k + # k = _l2_norm_fwd(k_origin) + k = k + d, v_new = fwd_prepare_wy_repr(k, v, beta, BT) + o, v_new2, CHECK, final_state = fused_chunk_delta_rule_fwd(q, k, v_new, d, BT, initial_state, output_final_state) + if checkpoint_level == 1: + d, v_new = None, None + ctx.save_for_backward(q, k_origin, v, v_new, v_new2, d, beta, initial_state) + ctx.CHECK = CHECK + ctx.chunk_size = BT + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_final_state=None): + q, k_origin, v, v_new, v_new2, d, beta, initial_state = ctx.saved_tensors + chunk_size = ctx.chunk_size + k = k_origin + # k = _l2_norm_fwd(k_origin) + if d is None: + d, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + dq, dk, dv, dd = fused_chunk_delta_rule_bwd(q, k, v_new2, d, do, chunk_size, ctx.CHECK, initial_state) + dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, d, v_new, dd, dv, chunk_size) + dk.add_(dk2) + # dk = _l2_norm_bwd(k_origin, dk) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(d.dtype), None, None, None + + +def mask_fused_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16." 
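+ # The helpers ceildiv/pad/pad_b defined above round the sequence length (and, for q/k/v, the feature dim) up to kernel-friendly multiples; the output is cropped back to seq_len right after the autograd call below. A minimal sketch of that contract (illustrative only, assumed toy shapes, not executed by this module):
+ #   x = torch.randn(1, 1, 100, 64)    # seq_len=100, head dim 64
+ #   y = pad(x, chunk_size=16)         # -> (1, 1, 112, 64), since 112 = ceildiv(100, 16) * 16 and 64 % 32 == 0
+ #   b = pad_b(torch.rand(1, 1, 100))  # -> (1, 1, 112), padded positions filled with value=1.0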
+ + if initial_state is not None: + initial_state = initial_state.detach() + seq_len = v.shape[-2] + d_head_v = v.shape[-1] + q, k, v = map(lambda x: pad(x), [q, k, v]) + beta = pad_b(beta) + o, final_state = FusedChunkDeltaRuleFunction.apply(q, k, v, beta, BT, initial_state, output_final_state) + o = o[..., :seq_len, :d_head_v] + return o, final_state + + +def mask_delta_rule_recurrence(q, k, v, beta): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + k = torch.nn.functional.normalize(k, p=2, dim=-1) + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v - (S.clone() * _k[..., None]).sum(-2) + _v = _v * beta_i[..., None] + S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ == "__main__": + import torch.nn.functional as F + # seq_len = 128 + # b = 2 + # h = 4 + # q = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # k = F.normalize(torch.randn(b, h, seq_len, 64), 2, -1) + # v = F.normalize(torch.randn(b, h, seq_len, 128), 2, -1) + # beta = torch.rand(b, h, seq_len).sigmoid() + # q, k, v, beta = map(lambda x: x.cuda().to(torch.float32).requires_grad_(True), (q, k, v, beta)) + # do = torch.rand_like(v) + # o2 = delta_rule_recurrence(q, k, v.clone(), beta) + # o2.backward(do, retain_graph=True) + # q_grad2, k_grad2, v_grad2, beta_grad2 = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # o, _ = fused_chunk_delta_rule(q, k, v, beta, 32) + # o.backward(do, retain_graph=True) + # q_grad, k_grad, v_grad, beta_grad = q.grad, k.grad, v.grad, beta.grad + # q.grad = k.grad = v.grad = beta.grad = None + # print((o - o2).abs().max()) + # print((q_grad - q_grad2).abs().max()) + # print((k_grad - k_grad2).abs().max()) + # print((v_grad - v_grad2).abs().max()) + # print((beta_grad - beta_grad2).abs().max()) \ No newline at end of file diff --git a/fla2/ops/mask_gated_delta_rule_t/naive.py b/fla2/ops/mask_gated_delta_rule_t/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3aa76636f031a3ef132850fcd7851795399e1d --- /dev/null +++ b/fla2/ops/mask_gated_delta_rule_t/naive.py @@ -0,0 +1,1503 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange +from typing import Optional + +import time +import torch +import triton +import triton.language as tl +from einops import rearrange +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +from fla.ops.utils import chunk_local_cumsum + +from fla.ops import chunk_gated_delta_rule +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float('-inf'))) + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + Aw, + Au, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_Aw = tl.make_block_ptr(Aw + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), 
(BT*r,BT*r),(1,0)) + b_Aw = tl.load(p_Aw, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r # read column i_r of the mask + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_Aw, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + tl.debug_barrier() + b_Aw = None + p_Au = tl.make_block_ptr(Au + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_Au = tl.load(p_Au, boundary_check=(0, 1)).to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): # no mask is applied on the V side, so no per-r loop is needed here + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_Au, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK","r"], +) +@triton.jit +def gated_chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + mask_ij, + A, + Ag, + s_qk_h, + s_qk_t, + s_qk_d, + T, + K, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32)#r*BT r*BT + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_r in range(r): + r_mask = tl.arange(0, r) == i_r + p_mask = mask_ij + tl.arange(0,r)* r + i_r # column-wise read, so these index the rows + b_mask = tl.load(p_mask) + ij_mask = b_mask[:,None]*r_mask[None,:] # row selector + + for i_k in range(tl.cdiv(dk, BK)): # load and process k block by block + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + b_A = tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + p_A = tl.make_block_ptr(A + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0,1,2,3)) + + p_g = tl.make_block_ptr(g_cumsum + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_g_diff = safe_exp(b_g_diff) + + b_Ag = b_A * ((b_g_diff)[:,:,None,None])#BT BT + p_Ag = tl.make_block_ptr(Ag + (i_bh*T//BT+i_t)*BT*BT*r*r ,(BT,BT,r,r), (BT*r*r,r*r,r,1), (0,0,0,0), (BT,BT,r,r),(3,2,1,0)) + tl.store(p_Ag,
(b_Ag).to(p_Ag.dtype.element_ty),boundary_check=(0,1,2,3)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "r"], +) +@triton.jit +def solve_tril_16x16_kernel( + A, + Ad, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + offset = (i_t * 16) % BT + + p_A = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T,BT,r,r),(BT*r*r,r*r,r,1) ,(i_t * 16, offset, 0, 0), (16, 16,r,r), (3,2,1,0)) + b_A = tl.load(p_A, boundary_check=(0,1,2,3)).to(tl.float32) + b_A = -tl.where((tl.arange(0, 16)[:,None] > tl.arange(0, 16)[None,:])[:,:,None,None], b_A, 0) + + for i in range(1, 16): + mask = tl.arange(0, 16) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0) + q = (tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)) + b_a = b_a + tl.sum(q,0)*((tl.arange(0, 16) < i)[:,None,None]) + b_A = tl.where(mask[:,None,None,None],b_a,b_A) # computed row by row, swapping in results step by step + b_A += ((tl.arange(0, 16)[:, None, None, None] == tl.arange(0, 16)[None, :, None, None])&(tl.arange(0, r)[None, None, :, None] == tl.arange(0, r)[None, None, None, :])) + + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(16*r,16*r))#BT*r BT*r + p_Ad = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 16 * r, 0), (16*r,16*r), (1,0)) + tl.store(p_Ad, (b_A).to(p_Ad.dtype.element_ty),boundary_check=(0,1)) + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, + BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,32*r),(32*r,1) ,((i_t * 32 + 16) *r, 0), (16*r, 16*r), (1,0)) + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 32 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t *32 +16) * r, 0), (16*r,16*r), (1,0)) + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), (i_t * 32 * r , 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r , 16*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,32*r), (32*r, 1), ((i_t * 32 + 16) * r, 0), (16*r, 16*r), (1, 0)) + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["r"], +) +@triton.jit +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + s_A_bh, + s_Ad_bh, + T, + r: tl.constexpr, +
BT: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_A21 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1,0)) + p_A31 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1,0)) + p_A32 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1,0)) + p_A41 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 0), (16*r, 16*r), (1,0)) + p_A42 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1,0)) + p_A43 = tl.make_block_ptr(A + (i_bh)*s_A_bh, (T*r,64*r),(64*r,1) ,((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1,0)) + + b_A21 = tl.load(p_A21, boundary_check=(0,1)).to(tl.float32) + b_A31 = tl.load(p_A31, boundary_check=(0,1)).to(tl.float32) + b_A32 = tl.load(p_A32, boundary_check=(0,1)).to(tl.float32) + b_A41 = tl.load(p_A41, boundary_check=(0,1)).to(tl.float32) + b_A42 = tl.load(p_A42, boundary_check=(0,1)).to(tl.float32) + b_A43 = tl.load(p_A43, boundary_check=(0,1)).to(tl.float32) + + + p_Ad11 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), (i_t * 64 * r, 0), (16*r,16*r), (1,0)) + p_Ad22 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 16) * r, 0), (16*r,16*r), (1,0)) + p_Ad33 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 32) * r, 0), (16*r,16*r), (1,0)) + p_Ad44 = tl.make_block_ptr(Ad + (i_bh)*s_Ad_bh,(T*r,16*r),(16*r,1), ((i_t * 64 + 48) * r, 0), (16*r,16*r), (1,0)) + + + p_Ai11 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 ) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai22 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai33 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 32*r), (16*r, 16*r), (1, 0)) + p_Ai44 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 48*r), (16*r, 16*r), (1, 0)) + p_Ai21 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 16) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai31 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 0), (16*r, 16*r), (1, 0)) + p_Ai32 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 32) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai41 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r ,0), (16*r, 16*r), (1, 0)) + p_Ai42 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 16*r), (16*r, 16*r), (1, 0)) + p_Ai43 = tl.make_block_ptr(Ai+ (i_bh)*s_A_bh, (T*r,64*r), (64*r, 1), ((i_t * 64 + 48) *r, 32*r), (16*r, 16*r), (1, 0)) + + + Ai11 = tl.load(p_Ad11, boundary_check=(0, 1)).to(tl.float32) + Ai22 = tl.load(p_Ad22, boundary_check=(0, 1)).to(tl.float32) + Ai33 = tl.load(p_Ad33, boundary_check=(0, 1)).to(tl.float32) + Ai44 = tl.load(p_Ad44, boundary_check=(0, 1)).to(tl.float32) + + Ai21 = -tl.dot(tl.dot(Ai22,b_A21, input_precision='ieee'),Ai11,input_precision='ieee') + Ai32 = -tl.dot(tl.dot(Ai33,b_A32, input_precision='ieee'),Ai11,input_precision='ieee') + Ai43 = -tl.dot(tl.dot(Ai44,b_A43, input_precision='ieee'),Ai11,input_precision='ieee') + + Ai31 = -tl.dot( + Ai33, + tl.dot(b_A31,Ai11, input_precision='ieee')+ + tl.dot(b_A32,Ai21, input_precision='ieee'), + input_precision='ieee') + + Ai42 = -tl.dot( + Ai44, + tl.dot(b_A42,Ai22, input_precision='ieee')+ + 
tl.dot(b_A43,Ai32, input_precision='ieee'), + input_precision='ieee') + + Ai41 = -tl.dot( + Ai44, + tl.dot(b_A41, Ai11, input_precision='ieee') + + tl.dot(b_A42, Ai21, input_precision='ieee') + + tl.dot(b_A43, Ai31, input_precision='ieee'), + input_precision='ieee' + ) + + tl.store(p_Ai11,Ai11.to(p_Ai11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai22,Ai22.to(p_Ai22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai33,Ai33.to(p_Ai33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai44,Ai44.to(p_Ai44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai21,Ai21.to(p_Ai21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai31,Ai31.to(p_Ai31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai32,Ai32.to(p_Ai32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai41,Ai41.to(p_Ai41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai42,Ai42.to(p_Ai42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai43,Ai43.to(p_Ai43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + + +def gated_chunk_scaled_dot_kkt_fwd(k: torch.Tensor, + beta: torch.Tensor, + mask: torch.Tensor, + g_cumsum:Optional[torch.Tensor] = None, + BT:int = 32, + output_dtype: torch.dtype=torch.float32): + # gated_chunk_scaled_dot_kkt_fwd(k=k,beta=beta,g_cumsum=g,mask=mask,BT=BT,output_dtype=torch.float32) + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + A = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + Ag = torch.empty(B*H*NT,BT*BT,r*r,device=k.device, dtype=output_dtype).contiguous() + gated_chunk_scaled_dot_kkt_fwd_kernel[(NT, B*H)]( + k, beta, g_cumsum, mask, A,Ag, + T*K, K, 1, + T, K, r, BT, BK + ) + return A,Ag + +def solve_tril(A,mask,k,BT,output_dtype=torch.float32): + B, H, T, K = k.shape + r = mask.shape[-1] + NT = triton.cdiv(T, 16) + Ad = torch.empty(B,H,NT*16*r,16*r,device=A.device, dtype=torch.float if BT != 16 else output_dtype) + solve_tril_16x16_kernel[(NT, B*H)]( + A,Ad, + T*BT*r*r,#s_abh + T*16*r*r,#s_adbh + T, + r, BT + ) + if BT == 16: + return Ad + + A = rearrange(A,'b (t l) (c r)->b (t c) (l r)',t=BT,c=r).contiguous()#BT*r BT*r + if BT == 32: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_32x32_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + if BT == 64: + NT = triton.cdiv(T, BT) + Ai = torch.zeros(B,H,NT*BT*r,BT*r,device=A.device, dtype=output_dtype) + merge_16x16_to_64x64_inverse_kernel[(NT, B*H)]( + A,Ad,Ai, + T*BT*r*r,#s_a_bh and s_ai_bh + T*16*r*r,#s_ad_bh + T,r,BT + ) + return Ai + + +def gated_fwd_recompute_w_u(k, v, beta,mask, Aw,Au,BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + gated_fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, Aw,Au, + T*K, K, 1, + T*V, V, 1, + T, K, V, r,BT, BK, BV + ) + return w, u + + + + +#finish +@triton.autotune( + configs=[ + 
triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + g, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_h = tl.zeros([BK, BV], dtype=tl.float32) # one horizontal strip of the state + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # storing here (before the update) is correct + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)): # BK is large, so split the chunk via BC + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0)) # read the corresponding block + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2)) + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(tl.bfloat16), allow_tf32=False) # matches the reference up to this point (BC rows) + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) # at least up to here the first-step result agrees + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + i_bh*T + last_idx) + b_g_last = tl.exp(b_g_last) + b_h = b_g_last * b_h + + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r
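+ # Grid-decoding note: the host launches this kernel with grid = (NV, NT, B * H * r) (see gated_chunk_fwd_o_fn below), so the last program axis packs (batch*head, r) into one index, and the two divmod lines above unpack it. A hedged host-side sketch using only names from that launcher:
+ #   NV, NT = triton.cdiv(V, BV), triton.cdiv(T, BT)
+ #   grid = (NV, NT, B * H * r)
+ #   gated_chunk_linear_attn_fwd_kernel_o[grid](...)  # one program per (v-block, t-block, batch-head, r) tuple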
+ rk = K//r + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)): # note the split here: K//BK == r + # question: different r-blocks read the same q/k; does that matter? + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (V, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h)#, allow_tf32=False) + b_s += tl.dot(b_q, b_k)#, allow_tf32=False) + + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * tl.exp(b_g)[:,None] + + b_g_diff = b_g[:, None] - b_g[None, :] + b_s = b_s * safe_exp(b_g_diff)#BT BT + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_s = tl.where(m_s, b_s, 0) # zero out the upper triangle + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o * scale + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK"], +) +@triton.jit +def preprocess_qkw(q, + k, + w, + g, + q_new, + k_new, + w_new, + T, + H, + K, + r:tl.constexpr, + BT:tl.constexpr, + BK:tl.constexpr, + USE_Q:tl.constexpr, + ): + i_k,i_bh,i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_k = tl.make_block_ptr(k + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0)) + + p_g = tl.make_block_ptr(g+i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + p_k_new = tl.make_block_ptr(k_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w_new = tl.make_block_ptr(w_new +i_bh*T*K*r,(T,r*K),(r * K, 1),(i_t * BT, i_k * r * BK) ,(BT,r*BK),(1,0)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + i_bh*T + last_idx).to(tl.float32) # read the last position within the BT block + + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + b_d_last = tl.exp((b_g_last - b_g)) + b_d_begin = tl.exp(b_g) + b_k = b_k * b_d_last[:, None] + b_w = b_w * b_d_begin[:, None] + tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1)) + + + if USE_Q: + p_q = tl.make_block_ptr(q + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q_new = tl.make_block_ptr(q_new + i_bh*T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_q = b_q * b_d_begin[:, None] + tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1)) + + +#finish +def gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state): + # k, w, u,
g, BT, initial_state, final_state + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K) # partitioned in one shot + assert BK <= 256, "current kernel does not support head dimension larger than 256." + BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + BC = 16 if BK > 128 else 32 + BC = 64 if BK <= 64 else BC + BC = min(BT, BC) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1 + h = k.new_empty(B, H, NT * K, V) + + grid = (NK,B*H,NT) + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + preprocess_qkw[grid]( + q=None, + k=k, + w=w, + g=g, + q_new=None, + k_new=k_new, + w_new=w_new, + T=T, + H=H, + K=K, + r=r, + BT=BT, + BK=BK, + USE_Q=False, + ) + grid = (NK, NV, B * H) + v_new = torch.empty(B,H,r,T,V,dtype=u.dtype,device=u.device) # v_new with the r dimension first + + gated_chunk_delta_rule_fwd_kernel_h[grid]( # no for-loop over r here + k_new,u,w_new, + v_new,g,h, + initial_state, + final_state, + H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,r=r, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + ) + return h, v_new + + +#finish +def gated_chunk_fwd_o_fn(q, k, v_new,h,g,BT): + B,H,r,T,V,K = *v_new.shape,q.shape[-1] + BK = triton.next_power_of_2(K//r) + o = torch.empty_like(v_new) # shape (B, H, r, T, V) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H * r) + #h shape b h nk v + gated_chunk_linear_attn_fwd_kernel_o[grid]( + q, k, v_new, h, g, o, + T*K, K, 1 , + r*T*V,T*V,V, + NT*K*V,V, + scale=K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,r = r, + ) + o = o.sum(dim=2) # sum over the r dimension + return o + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_fwd_prepare_dv_kernel( + q, + k, + g, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1) # r could perhaps be parallelized here as well + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A* safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def gated_fwd_prepare_dv(q,
k, g, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype) # cannot use empty_like: the shape differs + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + gated_fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, g , do, dv, + T*K, K, 1, + T*V, V, 1, + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + + + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + g, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_h_h, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32) # this tile is fixed; it reads everything + for i_t in range(NT - 1, -1, -1): # shifted forward by one step; the computation order is correct + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (V, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT), (BK, BT), (0, 1)) # full read + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (K,T*r), (1, K), + (i_k * BK, i_t * BT * r), (BK, BT * r), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + last_idx = min((i_t + 1) * BT, T) - 1 + b_glast = tl.load(g + i_bh * T + last_idx) + b_glast = tl.exp(b_glast) + + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_d = (tl.load(p_d,boundary_check=(0, 1))) + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) # load all r strips + b_dv = tl.load(p_dv, boundary_check=(0, 1)) # [BT*r, BV] + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + for i_r in range(r): + rmask = tl.arange(0, r) == i_r # column i_r + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT , i_r*KR + i_k * BK), (BT, KR), (1, 0))# + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dhr = tl.sum(tl.where(rmask[:,None,None],b_dhtrans,0), 0) + dv_sum = tl.dot(b_k,b_dhr.to(b_k.dtype),allow_tf32=False) + b_dv += tl.reshape((dv_sum[:,None,:]*rmask[None,:,None]).to(b_dv.dtype),(BT*r,BV)) + + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dh *= b_glast + b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)-tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + + + +def gated_chunk_bwd_dhu_fn(q, k, w, g,h0, do, dv, BT): + B,H,r,T,V,K = *dv.shape,q.shape[-1] + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimensions larger than 256."
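+ # Block-size note: each program keeps a [BK, BV] fp32 gradient tile resident, so a larger head dimension (BK) forces a smaller value block to stay within register/shared-memory budgets. With the assignments below this works out to BK > 128 -> BV = 16, 64 < BK <= 128 -> BV = 32, and BK <= 64 -> BV = 64.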
+ BV = 16 if BK > 128 else 32 + BV = 64 if BK <= 64 else BV + + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) # more parallelism could probably be exposed here + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + dh = q.new_empty(B, H, NT * K,V) # same layout; the sum forces these to be computed together + q_new = torch.empty_like(q) + k_new = torch.empty_like(k) + w_new = torch.empty_like(w) + # grid = (NK,) + grid = (NK,B*H,NT) + preprocess_qkw[grid]( + q=q, + k=k, + w=w, + g=g, + q_new=q_new, + k_new=k_new, + w_new=w_new, + T=T, + H=H, + K=K, + r=r, + BT=BT, + BK=BK, + USE_Q=True, + ) + + + grid = (NK, NV, B * H) + dv = rearrange(dv,'b h r t v-> b h (t r) v').contiguous() + dv2 = torch.empty_like(dv) # same shape: (B, H, r*T, V) + gated_chunk_delta_rule_bwd_kernel_dhu[grid]( + q_new, k_new, w_new, g, do, dh, dv, dv2, + T*K,K,1, + NT*K*V, + K**-0.5, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,r=r,KR = K//r, + ) + return dh, dv2 + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def gated_chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + g, + h, + do, + dh, + dq, + dk, + dv, + dw, + dg, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + s_g_r, + s_g_k, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (1, K), (i_r*K//r + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT*r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dg_last = tl.zeros([1,],dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V,NT * K), (1,s_h_t), (i_v * BV,i_t * K + i_r * K // r + i_k * BK), (BV, BK), (0, 1)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_h = (tl.load(p_h, boundary_check=(0, 1)))#BV BK + b_dh = (tl.load(p_dh, boundary_check=(0, 1))) # needs the extra r dimension + + b_dg_last += tl.sum(b_h * b_dh) # this also sums over r + + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + b_dq += tl.dot(b_do, b_h, allow_tf32=False) # b_do is full; b_h should cover the i_k part + b_dk += tl.dot(b_v, b_dh, allow_tf32=False) # used for dk; rows are independent, so this is fine + b_dv = (tl.load(p_dv, boundary_check=(0, 1)))#BT*r BV + b_dw += (tl.dot(b_dv.to(b_v.dtype),b_h.to(b_v.dtype))) #get BT*r BK + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + b_dg = tl.zeros([BT,], dtype=tl.float32) + p_g = tl.make_block_ptr(g
+ i_bh * T ,(T,),(1,),(i_t*BT,),(BT,),(0,)) + b_g = tl.load(p_g,boundary_check=(0,)) + b_glast = tl.load(g +i_bh*T + (min(i_t * BT + BT, T) - 1)) + b_dg_last *= tl.exp(b_glast) + + + p_w = tl.make_block_ptr(w + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + b_w = tl.load(p_w,boundary_check=(0,1))#BT * r ,BK + b_dw = b_dw * tl.reshape(tl.broadcast_to(tl.reshape(tl.exp(b_g),(BT,1)),(BT,r)),(BT*r))[:,None] + b_dg -= tl.sum(tl.reshape(b_w*b_dw,(BT,r*BK)),-1) + + b_dq = b_dq*scale*tl.exp(b_g)[:,None] + b_dg += tl.sum(b_dq*tl.trans(b_q),1)#BT*BK + + b_dk = b_dk * safe_exp(b_glast-b_g)[:,None] + b_dg -= tl.sum(b_dk*b_k,1)#BT*BK + b_dg_last += tl.sum(b_dk*b_k) + + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds* safe_exp(b_g[:, None] - b_g[None, :]) * scale, 0) + b_ds2 = b_ds*(tl.dot(tl.trans(b_q),tl.trans(b_k))) + + b_dg += tl.sum(b_ds2,axis=1) + b_dg -= tl.sum(b_ds2,axis=0) + b_ds = b_ds.to(b_k.dtype) + + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) # these should be fine + + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T*r, K), (K,1), (i_t * BT * r,i_r*K//r + i_k * BK), (BT*r ,BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_r * s_g_r + i_k * s_g_k + i_bh * T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + b_dg = tl.where(o_i jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) # equivalent to the dA of kk^T; mostly zeros, at the diagonal blocks + b_dA = tl.reshape(b_dA,(BT,r,BT,r)) + + + p_A = tl.make_block_ptr(Au + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dA2 = tl.zeros([BT*r,BT*r], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): # blockwise over V + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))#r*BT BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)##BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)#BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)#BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))# + sum_dv = tl.sum(b_dv_beta,-2) # this differs here; check the result + b_dv = (sum_dv * b_beta[:, None]) # which step gives a different result? + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA2 = tl.where(da_mask, b_dA2, 0) + b_dA2 = tl.dot(b_dA2.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA2 = tl.dot(tl.trans(b_A), b_dA2.to(b_A.dtype), allow_tf32=False) + b_dA2 = tl.where(da_mask, -b_dA2, 0) # equivalent to the dA of kk^T; mostly zeros, at the diagonal blocks + b_dA2 = tl.reshape(b_dA2,(BT,r,BT,r)) + + + p_g = tl.make_block_ptr(g_cumsum + i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + b_g = tl.load(p_g,boundary_check=(0,)) + b_dA2 *=
safe_exp(b_g[:,None]-b_g[None,:])[:,None,:,None] + b_dA += b_dA2 + b_dA2 = tl.permute(b_dA2,(0,2,1,3))#Bt bt r r + + b_A = tl.zeros([BT,BT,r,r], dtype=tl.float32) + + for i_r in range(r): # take only the i_r term + p_mask = mask_ij + tl.arange(0,r)*r+i_r # read column i_r + b_mask = tl.load(p_mask) # column i_r + rmask = tl.arange(0, r) == i_r # column i_r + g = tl.sum(tl.where(rmask[None,None,None,:], b_dA, 0), -1)#BT r BT # extract column i_r + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + + for i_k in range(tl.cdiv(block_k, BK)):#ik = 1 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)#BT*BK + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + beta_kkt = (tl.dot(b_k_beta,tl.trans(b_k), allow_tf32=False))#BT BT + b_A += beta_kkt[:,:,None,None] * ((rmask[None,:] * b_mask[:,None])[None,None,:,:]) # this broadcasts the whole column, which looks wrong + + betas = tl.sum(tl.sum(beta_kkt[:,None,:]*g,-1),0) + b_dmask += (betas[:,None]*rmask[None,:]).to(tl.float32) + + + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + p_dmask = tl.make_block_ptr(dmask + (i_bh * (T//BT) + i_t)* r * r , (r,r), (r,1), (0,0), (r,r), (1,0)) + tl.store(p_dmask, b_dmask.to(p_dmask.dtype.element_ty), boundary_check=(0,1)) + + b_dA2 *= b_A #BT BT r r + b_dA2 = tl.sum(tl.reshape(b_dA2,(BT,BT,r*r)),-1) + + b_dg = tl.sum(b_dA2,1)-tl.sum(b_dA2,0) + p_dg = tl.make_block_ptr(dg+i_bh*T,(T,),(1,),(i_t*BT,),(BT,),(0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + + +def gated_bwd_prepare_wy_repr(k, v, beta, mask,g, Aw,Au, dw, du, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + BV = min(triton.next_power_of_2(V), 64) + NT = triton.cdiv(T, BT) + dk = torch.empty_like(k) + dv = torch.empty_like(v).contiguous() + dbeta = torch.zeros_like(beta) + dg = torch.empty_like(g) + dmask = torch.zeros([B*H*NT,r,r],device=k.device,dtype=k.dtype).contiguous() + assert BK <= K//r + gated_bwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, g, Aw,Au, + dw, du, + dk, dv, dbeta,dmask,dg, + T*K, K, 1, + T*V, V, 1, + T, K, V, r, BT, BK, BV + ) + dmask = dmask.sum(0) + return dk, dv, dbeta, dmask,dg + + +class gated_ChunkDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, beta,g,mask,BT, initial_state, output_final_state=False, checkpoint_level=1): + B,H,L,K = q.shape + g = chunk_local_cumsum(g,BT,head_first=True,output_dtype=torch.float) + Aw,Au = gated_chunk_scaled_dot_kkt_fwd(k=k,beta=beta,g_cumsum=g,mask=mask,BT=BT,output_dtype=torch.float32) + + Aw = solve_tril(A=Aw,mask=mask,k=k,BT=BT,output_dtype=k.dtype) + Au = solve_tril(A=Au,mask=mask,k=k,BT=BT,output_dtype=k.dtype) + # should be fine up to this point + r = mask.shape[-1] + w, u = gated_fwd_recompute_w_u(k, v, beta, mask,Aw,Au,BT)# + + final_state = None + if output_final_state: + final_state = q.new_empty(q.shape[0], q.shape[1],
q.shape[-1], v.shape[-1], + dtype=torch.float32, requires_grad=False) # this part needs no correction + h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, final_state) # needs change + # final_state matches almost exactly + o = gated_chunk_fwd_o_fn(q, k, v_new, h, g, BT) # needs change + if checkpoint_level == 1: + h, v_new = None, None # recomputed in backward + ctx.save_for_backward(q, k, v, beta,g, mask, Aw, Au , h, v_new, initial_state) + ctx.BT = BT + return o.to(q.dtype), final_state + + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_ht=None): + q, k, v, beta, g, mask , Aw,Au, h, v_new, initial_state = ctx.saved_tensors + BT = ctx.BT + r = mask.shape[-1] + start = time.time() + w, u = gated_fwd_recompute_w_u(k, v, beta, mask, Aw,Au,BT) # recomputed here rather than saved in ctx + end = time.time() + print('recompute_wu:',end-start) + if h is None: + h, v_new = gated_chunk_fwd_h_fn(k, w, u, g, BT, initial_state, None) + start = time.time() + + # the backward computation is rewritten from here on + dv = gated_fwd_prepare_dv(q, k, g, do, r, BT) # uses q, k, do; dv therefore has w's shape + end = time.time() + print('pre:',end-start) + #dv BHR T V + + start = time.time() + dh, dv = gated_chunk_bwd_dhu_fn(q, k, w, g,initial_state,do, dv, BT) # new dv and dh; this dv is the final one for the wy-repr backward + end = time.time() + print('chunk_bwd_dhu_fn:',end-start) + + start = time.time() + dq, dk, dw , dg = gated_chunk_bwd_dqkw_fn(q, k, v_new, w, g, h, dv, do, dh, BT) # this step is also very slow + end = time.time() + print('chunk_bwd_dqkw_fn:',end-start) + # only the two dg contributions might be wrong, nothing else + + start = time.time() + dk2, dv, dbeta,dmask,dg2 = gated_bwd_prepare_wy_repr(k, v, beta, mask,g, Aw,Au, dw, dv, BT) # only this step involves the mask + dk.add_(dk2) + dg.add_(dg2) + end = time.time() + print('bwd_prepare_wy_repr:',end-start) + # only the two dg contributions might be wrong, nothing else + dg = chunk_local_cumsum(dg, BT, reverse=True,head_first=True,output_dtype=torch.float) + + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype),dg,dmask.to(mask.dtype),None, None, None + + +def mask_gated_chunk_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g: torch.Tensor, + mask: torch.Tensor, # [r, r] mask mixing the key groups of the original tensor + BT: int, + initial_state: torch.Tensor = None, + output_final_state: bool = False +): + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "gated_ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
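+ # Minimal usage sketch mirroring the self-test in __main__ below (the shapes are illustrative assumptions; r must divide the head dim K, and mask is the [r, r] mixing matrix):
+ #   B, H, L, K, V, r = 2, 4, 128, 256, 256, 4
+ #   q = torch.randn(B, H, L, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
+ #   k = torch.nn.functional.normalize(torch.randn(B, H, L, K, device='cuda', dtype=torch.bfloat16), p=2, dim=-1)
+ #   v = torch.randn(B, H, L, V, device='cuda', dtype=torch.bfloat16)
+ #   beta = torch.rand(B, H, L, device='cuda', dtype=torch.bfloat16)
+ #   g = torch.nn.functional.logsigmoid(torch.randn(B, H, L, device='cuda', dtype=torch.bfloat16))
+ #   mask = torch.randn(r, r, device='cuda', dtype=torch.bfloat16)
+ #   o, state = mask_gated_chunk_delta_rule(q, k, v, beta, g, mask, BT=32, output_final_state=True)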
+    o, final_state = gated_ChunkDeltaRuleFunction.apply(q, k, v, beta, g, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def delta_rule_recurrence(q, k, v, beta, g, mask, initial_state=None):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    if initial_state is None:
+        S = torch.zeros(b, h, d_k, d_v, device=k.device, dtype=torch.float32)
+    else:
+        S = initial_state
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        kkt = torch.einsum('b h d,b h v->b h d v', _k * beta_i, _k)
+        kkt = rearrange(kkt, 'b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        iplr = torch.einsum('b h q k,b h->b h q k', iplr, g[:, :, i])
+        S = torch.einsum('b h q k,b h k v->b h q v', iplr.float(), S.clone()) + _k.unsqueeze(-1).float() * _v.unsqueeze(-2).float()
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q.float(), S).to(k.dtype)
+    return o, S
+
+
+if __name__ == "__main__":
+    import sys
+    import time
+    torch.set_default_dtype(torch.bfloat16)
+    torch.manual_seed(42)
+
+    # for i in range(200):
+    B = 16
+    H = 4
+    L = 128
+    DK = 256
+    DV = 256
+    r = 4
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.randn([r, r])
+    mask = mask.cuda().requires_grad_(True).contiguous()
+
+    # mask = torch.ones([2,2])
+    # mask = mask.cuda().requires_grad_(True).contiguous()
+
+    g = torch.nn.functional.logsigmoid(torch.randn(B, H, L).cuda()).requires_grad_(True)
+    g_exp = (torch.exp(g))
+
+    do = torch.randn(B, H, L, DV).cuda()
+    o1, ss = delta_rule_recurrence(q, k, v, beta, g_exp, mask)
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    mask_grad, mask.grad = mask.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    g_grad, g.grad = g.grad, None
+    # end = time.time()
+    # print(end-start)
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+    # o, f_state = mask_gated_chunk_delta_rule(q, k, v, beta, g, mask, BT=32, output_final_state=True)
+    # o2, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+
+    # qh, kh, vh, betah, gh = map(lambda x: rearrange(x, 'b h l ... -> b l h ...'), (q, k, v, beta, g))
+    # o, f_state = chunk_gated_delta_rule(qh, kh, vh, gh, (betah*rearrange(mask, 'c r-> (c r)')).contiguous(), use_qk_l2norm_in_kernel=False, output_final_state=True)
+    # o = rearrange(o, 'b l h d->b h l d')
+    o, f_state = mask_gated_chunk_delta_rule(q, k, v, beta, g, mask, BT=32, output_final_state=True)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    mask_grad0, mask.grad = mask.grad, None
+    g_grad0, g.grad = g.grad, None
+    print((o1 - o).abs().max())
+    print((f_state - ss).abs().max())
+    print((q_grad - q_grad0).abs().max())
+    print((k_grad - k_grad0).abs().max())  # large gap in the results, up to ~1
+    print((v_grad - v_grad0).abs().max())
+    print((beta_grad - beta_grad0).abs().max())
+    print((mask_grad - mask_grad0).abs().max())
+
+    print((g_grad - g_grad0).abs().max())
+    print(mask_grad)
+    print(mask_grad0)
+
+
+    # o2, f_state2 = mask_gated_chunk_delta_rule(q, k, v, beta, g, mask, BT=32, output_final_state=True)
+    # o2.backward(do, retain_graph=True)
+    # q_grad2, q.grad = q.grad, None
+    # k_grad2, k.grad = k.grad, None
+    # v_grad2, v.grad = v.grad, None
+    # beta_grad2, beta.grad = beta.grad, None
+    # mask_grad2, mask.grad = mask.grad, None
+
+    # print((o - o2).abs().max())
+    # print((f_state - f_state2).abs().max())
+
+    # print((q_grad2 - q_grad0).abs().max())
+    # print((k_grad2 - k_grad0).abs().max())  # large gap in the results, up to ~1
+    # print((v_grad2 - v_grad0).abs().max())
+    # print((beta_grad2 - beta_grad0).abs().max())
+    # print((mask_grad2 - mask_grad0).abs().max())
+    # print('naive:', mask_grad2)
+    # print('triton:', mask_grad0)
+    # print(k_grad2)
+    # print(k_grad0)
+
+
+    # BT = 16
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the A matrix (computes everything)
+    # print('finish0')
+    # h, v_new = chunk_fwd_h_fn(k, w, u, BT, None, None)  # needs changes
+    # print('finish1')
+    # o = chunk_fwd_o_fn(q, k, v_new, h, BT)  # needs changes
+    # print('finish2')
+    # w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # skip (cheap recompute)
+    # print('finish3')
+    # dv = fwd_prepare_dv(q, k, do, r, BT)  # uses q, k, do; this dv has the shape of w -- done
+    # print('finish4')
+    # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # new dv and dh; dv is final for the WY-repr step
+    # print('finish5')
+    # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)  # this step is also very slow
+    # print('finish6')
+
+    # Ass = rearrange(A, 'b h (n t) l->b h n t l', n=L // BT)
+    # dwss = rearrange(dw, 'b h (n t) k->b h n t k', n=L // BT)
+    # dvss = rearrange(dv, 'b h (n t) k->b h n t k', n=L // BT)
+    # dk2, dv, dbeta, dmask = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)
+    # print('triton:', dmask)  # almost exactly equal
+
+    # vbeta = v * beta[..., None]
+    # vbeta = rearrange(vbeta, 'b h (n T) d->b h n T d', T=BT)
+    # vbeta = vbeta.unsqueeze(-2).expand(-1, -1, -1, -1, r, -1)
+    # vbeta = rearrange(vbeta, 'b h n t r d-> b h n (t r) d')
+
+    # kbeta = k * beta[..., None]
+    # kbeta = rearrange(kbeta, 'b h (n T) (r d)->b h n T r d', T=BT, r=r)
+    # kbeta = torch.einsum('b h n T r d,c r-> b h n T c r d', kbeta, mask)
+    # kbeta = rearrange(kbeta, 'b h n t c r d-> b h n (t c) (r d)')
+    # dA = dvss @ vbeta.transpose(-1, -2) + dwss @ kbeta.transpose(-1, -2)
+
+
+    # dorg = Ass.transpose(-1, -2) @ dwss  # (b h n, BT*r, k)
+    # dorg = rearrange(dorg, 'b h n (t r) (c k)->b h n t r c k', r=r, c=r)
+    # betan = rearrange(beta, 'b h (n t)->b h n t', n=L // BT)
+    # kn = rearrange(k, 'b h (n t) (r d)->b h n t r d ', n=L // BT, r=r)
+
+    # dmask = torch.einsum('b h n t r c k,b h n t->b h n t r c k', dorg, betan)
+    # dmask = torch.einsum('b h n t r c k,b h n t c k->b h n t r c k', dmask, kn)
+    # dmask = rearrange(dmask, 'b h n t r c k-> (b h n) (t k) r c')
+    # dmaskss = dmask.sum(0).sum(0)
+
+    # i = torch.arange(0, BT * r)[:, None]
+    # j = torch.arange(0, BT * r)[None, :]
+    # iB = i // r
+    # jB = j // r
+    # da_mask = iB > jB
+    # da_mask = da_mask.cuda()
+    # b_dA = torch.where(da_mask, dA, 0)
+
+    # b_dA = b_dA @ Ass.transpose(-1, -2)
+    # b_dA = Ass.transpose(-1, -2) @ b_dA
+
+    # b_dA = torch.where(da_mask, -b_dA, 0)  # equivalent to the dA of kkt: mostly zeros, nonzero near the diagonal
+    # b_dA = rearrange(b_dA, 'b h n (t r) (l c)-> b h n t r l c', c=r, r=r)
+    # # print((dAss - b_dA).abs())  # still exactly equal up to here
+
+
+    # # betakkt = k * beta[..., None]
+    # kbeta = k * beta[..., None]
+    # kbeta = rearrange(kbeta, 'b h (n T) (r d)->b h n T r d', T=BT, r=r)
+    # kbeta2 = rearrange(k, 'b h (n T) (r d)->b h n T r d', T=BT, r=r)
+    # betakkt = torch.einsum('b h n T r d,b h n s r d->b h n r T s', kbeta, kbeta2)  # (r, BT, BT)
+    # betakkt = rearrange(betakkt, 'b h n r T s->b h n T s r')  # (BT, BT, r), laid out row-wise
+    # # print((dAss - b_dA).abs())
+
+    # # this shows that the computation below is where the error comes from
+    # dmask = torch.einsum('b h n t r l c,b h n t l c-> b h n t r l c', b_dA, betakkt)
+    # # print((dAss - dmask).abs().max())  # meaning this result is correct as well
+    # # print((dAss - dmask))
+
+    # dmask = rearrange(dmask, 'b h n t r l c->b h n (t l) r c')
+    # dmask = dmask.sum(-3)
+    # dmask = dmask.sum(0).sum(0).sum(0)
+    # print('matrix:', dmask)
+
+
diff --git a/fla2/ops/mask_gated_delta_rule_t/naive_rmbeta copy.py b/fla2/ops/mask_gated_delta_rule_t/naive_rmbeta copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aac72fd6ab3c1c7928194c488cb608129bf6fc0
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/naive_rmbeta copy.py
@@ -0,0 +1,1102 @@
+# -*- coding: utf-8 -*-
+
+import time
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_wy_repr_kernel(
+    k,
+    v,
+    beta,
+    mask_ij,
+    w,
+    u,
+    A,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    b_A = tl.zeros([BT, BT, r, r], dtype=tl.float32)  # (r*BT, r*BT) once flattened
+    dk = K // r
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_r in range(r):
+        r_mask = tl.arange(0, r) == i_r
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r
+        b_mask = tl.load(p_mask)
+        ij_mask = b_mask[:, None] * r_mask[None, :]
+        for i_k in range(tl.cdiv(dk, BK)):  # read and process k in blocks
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * dk + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_kb = (b_k).to(b_k.dtype)
+            dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False)
+            b_A += dot[:, :, None, None] * ij_mask[None, None, :, :]
+
+    b_A = -tl.where((tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :])[:, :, None, None], b_A, 0)
+    # store this first to inspect it
+    for i in range(1, BT):  # the matrix is (BT, r, BT, r) at this point
+        mask = tl.arange(0, BT) == i
+        b_a = tl.sum(tl.where(mask[:, None, None, None], b_A, 0), 0)  # get b_a, (BT, r, r)
+        q = tl.sum(b_a[:, None, :, :, None] * b_A[:, :, None, :, :], -2)  # done via matmul, get (BT, BT, r, r)
+        b_a = b_a + tl.sum(q,0)*((tl.arange(0,
BT) < i)[:,None,None])#BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r + b_A += tl.arange(0, BT*r)[:,None] == tl.arange(0, BT*r)[None,:] + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))#旧版本实现需要很多乘法 + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + #解决矩阵求逆 + b_A = b_A.to(k.dtype.element_ty)#ok 解决求逆了 #下一步计算结果 + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 
, and no mask is applied here
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]
+        b_vb = tl.reshape(b_vb, (BT * r, BV))
+        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
+        p_u = tl.make_block_ptr(u + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+# compute this
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def bwd_prepare_wy_repr_kernel(
+    k, v, beta, mask_ij, A,
+    dw, du,
+    dk, dv, dbeta,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    r: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)  # the launch grid is 2D; the stray third tl.program_id was a bug
+    p_A = tl.make_block_ptr(A + i_bh * T * BT * r * r, (T * r, BT * r), (BT * r, 1), (i_t * BT * r, 0), (BT * r, BT * r), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty)
+    b_dbeta = tl.zeros([BT], dtype=tl.float32)
+    b_dA = tl.zeros([BT * r, BT * r], dtype=tl.float32)
+    p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    b_beta = tl.load(p_beta, boundary_check=(0,))
+    for i_v in range(tl.cdiv(V, BV)):  # blocked over V
+        p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))  # (r*BT, BV)
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_v_beta = ((b_v * b_beta[:, None])[:, None, :] * tl.full([r], 1, dtype=b_v.dtype)[None, :, None]).to(b_v.dtype)  # (BT, r, BV)
+        b_v_beta = tl.reshape(b_v_beta, (BT * r, BV))
+        b_du = tl.load(p_du, boundary_check=(0, 1))
+        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)  # (BT*r, BT*r)
+        b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)  # (BT*r, BV)
+        b_dv_beta = tl.reshape(b_dv_beta, (BT, r, BV))
+        sum_dv = tl.sum(b_dv_beta, -2)  # this differs from the reference; check the result
+        b_dv = (sum_dv * b_beta[:, None])  # which step produces the mismatch?
+        b_dbeta += tl.sum(sum_dv * b_v, 1)
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+    block_k = K // r
+    for i_r in range(r):
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r
+        b_mask = tl.load(p_mask)
+        for i_k in range(tl.cdiv(block_k, BK)):  # assumes block_k == BK
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h * r, (T * r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r * block_k + i_k * BK), (BT * r, BK), (1, 0))
+            # b_k_beta = ((b_k * b_beta[:, None])[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)  # (BT, r, d)
+            b_k_beta = ((b_k)[:, None, :] * b_mask[None, :, None]).to(b_k.dtype)
+            b_k_beta = tl.reshape(b_k_beta, (BT * r, BK))
+            b_dw = tl.load(p_dw, boundary_check=(0, 1))
+            b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
+            b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)  # (BT*r, BK)
+            b_dk_beta = tl.reshape(b_dk_beta, (BT, r, BK))
+            sum_dk = tl.sum(b_dk_beta * b_mask[None, :, None], 1)
+            # b_dk = sum_dk * b_beta[:, None]
+            b_dk = sum_dk
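+            # NB: sum_dk folds the r replicas of trans(A) @ dw back onto the BT
+            # original rows, weighted by mask column i_r; this "rmbeta" variant
+            # drops the beta scaling of k, which is why the k-path dbeta term
+            # below stays commented out.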
+            # b_dbeta += tl.sum(sum_dk * b_k, 1)
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+
+    i = tl.arange(0, BT * r)[:, None]
+    j = tl.arange(0, BT * r)[None, :]
+    iB = i // r
+    jB = j // r
+    da_mask = iB > jB
+    b_dA = tl.where(da_mask, b_dA, 0)
+    b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False)
+    b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False)
+    b_dA = tl.where(da_mask, -b_dA, 0)
+    b_dA = tl.reshape(b_dA, (BT, r, BT, r)).to(k.dtype.element_ty)  # everything should be correct up to this point
+
+    for i_r in range(r):  # take only the i_r term
+        p_mask = mask_ij + tl.arange(0, r) * r + i_r  # read mask column i_r
+        b_mask = tl.load(p_mask)  # column i_r
+        mask = tl.arange(0, r) == i_r
+        g = tl.sum(tl.where(mask[None, None, None, :], b_dA, 0), -1)  # (BT, r, BT): take the last axis entry
+        # this corresponds to the k_r part
+        ir_A = tl.sum(g * b_mask[None, :, None], 1).to(k.dtype.element_ty)  # (BT, BT)
+        for i_k in range(tl.cdiv(block_k, BK)):
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_k + i_k * BK), (BT, BK), (1, 0))
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_dk = tl.load(p_dk, boundary_check=(0, 1))
+            # b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
+            b_k_beta = (b_k).to(b_k.dtype)
+
+            b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False)
+            # b_dbeta += tl.sum(b_dk_beta * b_k, 1)
+            b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False)
+            b_dk += b_dk_beta  # * b_beta[:, None]
+            tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))  # this should be fine as well
+    p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
+
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    assert BK == K // r
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT * BT * r, BT * r, device=k.device, dtype=torch.float32)
+    fwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)  # 32 for the test shapes
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    assert BK == K // r
+    bwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, A,  # da,
+        dw, du,
+        dk, dv, dbeta,
+        k.stride(1), k.stride(2), k.stride(3),
v.stride(1), v.stride(2), v.stride(3), + T, K, V, r, BT, BK, BV + ) + return dk, dv, dbeta#,da + + +# from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_dv_kernel( + q, + k, + do, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + scale, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r: tl.constexpr, +): + i_t, i_bhr = tl.program_id(0), tl.program_id(1)#或许也可以r并行 + i_bh = i_bhr//r + i_r = i_bhr % r + b_A = tl.zeros([BT, BT], dtype=tl.float32) + block_r = K//r + for i_k in range(tl.cdiv(block_r, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_k.dtype) + b_A += tl.dot(b_k, b_q, allow_tf32=False) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h , (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.dot(b_A, b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + +#finish +def fwd_prepare_dv(q, k, do, r,BT): + B, H, T, K, V = *k.shape, do.shape[-1] + dv = torch.empty(B,H,r,T,V,device = do.device, dtype= do.dtype)#没法like + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r),64) + BV = min(triton.next_power_of_2(V), 64) + fwd_prepare_dv_kernel[(NT, B*H*r)]( + q, k, do, dv, + k.stride(1), k.stride(2), k.stride(3), + do.stride(1), do.stride(2), do.stride(3), + T, K, V, K**-0.5, BT, BK, BV, r + ) + return dv + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_fwd_kernel_h( + k, + v,#u + d,#w + v_new, + h, + initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V] + final_state, # final state of the chunk [B, H, D_head_K, D_head_V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)#assert ik=1 all use + b_h = tl.zeros([BK, BV], dtype=tl.float32)#读取一横行 + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * 
BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + #这里save是对的 + b_h_cumsum = tl.zeros([r, BK//r, BV], dtype=tl.float32) + for i_r in range(r): + for i_c in range(tl.cdiv(BT, BC)):#BK 大,通过BC 分块 + r_mask = tl.arange(0,r) == i_r + p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1), + (i_t * BT + i_c * BC, i_k * BK + i_r * BK//r), (BC,BK//r), (1, 0))#读取对应 + p_d = tl.make_block_ptr((d + i_bh * T * r * K),(T, r, K ),(r * K, K, 1), + (i_t * BT + i_c * BC, i_r, i_k * BK), (BC,1,BK),(2,1,0)) + p_v = tl.make_block_ptr((v + i_bh * T * r * V),(T, r, V ),(r * V, V, 1), + (i_t * BT + i_c * BC, i_r, i_v * BV), (BC,1,BV),(2,1,0)) + p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r)* T * V, (T , V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC , BV), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1))#BK//r,BC + b_d = tl.load(p_d, boundary_check=(0, 1, 2))#BK + b_v = tl.load(p_v, boundary_check=(0, 1, 2))#BC + b_v = tl.reshape(b_v,(BC,BV)) + b_d = tl.reshape(b_d,(BC,BK)) + b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)#ok #到这相等的 这里BC + tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))#至少到这里第一步结果相同 + bkv = tl.where(r_mask[:,None,None],tl.dot(tl.trans(b_k),b_v.to(b_k.dtype),allow_tf32=False)[None,:,:],0) + b_h_cumsum += bkv.to(b_h_cumsum.dtype) + b_h += tl.reshape(b_h_cumsum,(BK,BV)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_linear_attn_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + r : tl.constexpr +): + i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bh = i_bhr//r + i_r = i_bhr % r + rk = K//r + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K//r, BK)):#这里需要注意拆分#这里K//BK = r + #问题是不同r_block读取了同一份qk,有影响吗 + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0)) + # p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_r * rk + i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1))) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_q, b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s = tl.where(m_s, b_s, 0)#置为0 Bs = 0 + p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT , i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = b_o + (tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) + p_o = tl.make_block_ptr(o 
+ i_bhr * T * V, (T, V), (V,1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + +#finish +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dhu( + q, + k, + d, + do, + dh, + dv, + dv2, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, + KR: tl.constexpr, +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + b_dh = tl.zeros([BK, BV], dtype=tl.float32)#这个不变 读取所有 + for i_t in range(NT - 1, -1, -1):# 向前偏移了一位,计算流程是对的 + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V , (K, V), (s_h_t, 1), (i_k * BK , i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + #全列 + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), + (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))#全读取 + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))# + p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T*r,K), (K, 1), + (i_t * BT * r + i_c * BC *r,i_k * BK), (BC * r,BK), (1, 0))#读取 BC r BK的内容 + p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1), + (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_q = (tl.load(p_q, boundary_check=(0, 1))) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1))#BT*r Bv + b_d = tl.trans(tl.load(p_d,boundary_check=(0, 1))) + b_k = tl.permute(tl.reshape(b_k,(BC,r,KR)),(1,0,2))#r BC KR + b_dhtrans = tl.reshape(b_dh,(r,KR,BV)) + dv_sum = tl.sum(b_k[:,:,:,None]*b_dhtrans.to(b_k.dtype)[:,None,:,:],-2) #get r BC BV + b_dv += tl.reshape(tl.permute(dv_sum,(1,0,2)),(BC*r,BV)) + #bhtrv + p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T*r, V), (V , 1), + (i_t * BT * r + i_c * BC * r, i_v * BV), (BC*r, BV), (1, 0)) + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d,b_dv.to(b_q.dtype),allow_tf32=False) + b_dh += b_dh_tmp + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_delta_rule_bwd_kernel_dqkw( + q, + k, + v, + w, + h, + do, + dh, + dq, + dk, + dv, + dw, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + r: tl.constexpr, +): + i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_r = i_bhr%r + i_bh = i_bhr//r + o_i = tl.arange(0, BT) + p_q = 
tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT,r,BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r* K// r + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))#BV BK + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)#ok + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False)#d_do 全, bh应该包含 i_Kbufen + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)#用来计算dk,yes 行独立没问题 + b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)),(BT,r,BV))#BT*r BV + b_dw += tl.sum(b_dv.to(b_v.dtype)[:,:,:,None]*b_h.to(b_v.dtype)[None,None,:,:],-2)#get BT r BK + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)#BT*BT + b_dq += tl.dot(b_ds, b_k, allow_tf32=False) + b_dq *= scale + b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False)) #这些应该没啥问题 + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*K//r + i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT, 0 ,i_r*K//r + i_k * BK), (BT, r ,BK), (2, 1, 0)) + # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT ,i_r, i_k * BK), (BT, 1, BK), (2, 1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, (tl.reshape(-b_dw.to(p_dw.dtype.element_ty),(BT,r,BK))), boundary_check=(0, 1)) + +#finish +def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state): + B, H, T, K, V = *k.shape,u.shape[-1] + _,_,rT,_ = w.shape + r = rT//T + BK = triton.next_power_of_2(K)#直接划分好 + assert BK <= 256, "current kernel does not support head dimension larger than 256." 
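+    # NB (assumed rationale, mirroring the upstream FLA chunk kernels): the larger
+    # the key block BK, the smaller the value tile BV and inner chunk BC are chosen,
+    # presumably to keep the fp32 accumulator b_h (BK x BV) within on-chip limits.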
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # v_new is laid out r-first
+    chunk_delta_rule_fwd_kernel_h[grid](  # no for-loop over r in the launch
+        k, u, w, v_new, h, initial_state, final_state,
+        k.stride(1), k.stride(2), k.stride(3),
+        u.stride(1), u.stride(2), u.stride(3),  # r*T*V, V, 1
+        h.stride(1), h.stride(2),
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+# finished
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # could expose more parallelism here
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # same as before; the contributions must be summed, so they are computed together
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v-> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout: b h (t r) v
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        q.stride(1), q.stride(2), q.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        dh.stride(1), dh.stride(2),
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K // r,
+    )
+    return dh, dv2
+
+# finished
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # hence (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h has shape (B, H, NT*K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        h.stride(1), h.stride(2),
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # sum over the r dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last kernel: computes dw, dq, dk
+    # a finer split is needed so that different r-positions never land in the same block
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K // r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK indexes the position within a head
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)  # shaped like the original k
+    dw = torch.empty_like(w)  # (B, H, r*T, K)
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        q.stride(1), q.stride(2), q.stride(3),
+        T * V, V, 1,
+        dh.stride(1), dh.stride(2),
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # forward pass is done
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the A matrix (computes everything)
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # this part needs no fix
+        end = time.time()
+        print('compute_A:', end - start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)  # needs changes
+        end = time.time()
+        print('compute_h_s:', end - start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)  # needs changes
+        end = time.time()
+        print('compute_o:', end - start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # skip (cheap recompute)
+        end = time.time()
+        print('recompute_wu:', end - start)
+        # checkpoint_level=1: recomputation.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # uses q, k, do; this dv has the shape of w -- done
+        end = time.time()
+        print('pre:', end - start)
+        # dv: (B, H, r, T, V)
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # new dv and dh; dv is final for the WY-repr step
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end - start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end - start)
+
+        start = time.time()
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)  # this step carries the largest numerical error
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end - start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # (r, r) mask applied to the original tensor
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "FusedChunkDeltaRuleFunction does not support float32. Please use bfloat16."
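+    # Equivalence-check sketch (hedged: mirrors the test in __main__ below; bf16
+    # tolerances are loose, so the bound is an assumption):
+    #   o_ref = delta_rule_recurrence(q, k, v, beta, mask)
+    #   o_tri, _ = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    #   assert (o_ref - o_tri).abs().max() < 1e-1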
+ o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta,mask, BT, initial_state, output_final_state) + return o, final_state + + +def naive(q,k,w,u,initial_state,BT,r): + B,H,seq_len,dk = q.shape + dv = u.shape[-1] + NT = seq_len//BT + state = torch.empty(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + g = torch.zeros(B,H,NT,dk,dv,device=q.device,dtype=q.dtype) + v_new = torch.empty_like(u) + from einops import rearrange + q,k,w,u,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + q = q*(dk**-0.5) + v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + if initial_state is not None: + state[:,:,0,:,:] = initial_state + else: + state[:,:,0,:,:] = 0 + for i in range(NT): + ki = rearrange(k[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + ui = rearrange(u[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + wi = rearrange(w[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:]) + v_new[:,:,i,:,:,:] = v_newi#这里保存的结果是相等 + kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + g[:,:,i,:,:] = rearrange(kui,'b h r k v-> b h (r k) v') + if i+1 < seq_len//BT: + state[:,:,i+1,:,:] = state[:,:,i,:,:] + g[:,:,i,:,:] + q_r = rearrange(q,'b h n t (r d)->b h n t r d',r = r) + k_r = rearrange(k,'b h n t (r d)->b h n t r d',r = r) + s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + s_r = torch.tril(s_r,diagonal=0)#mask get bhnrtl + o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_new) + o1 = o1.sum(dim=-2)#bhntv + o2 = torch.einsum('b h n t q,b h n q v-> b h n t v',q,state)#只看state 算的对不对 + o = o1 + o2 + return state,v_new,o,g + + +def delta_rule_recurrence(q, k, v, beta, mask): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + r = mask.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + if beta.ndim < v.ndim: + beta = beta[..., None] + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i].clone() + beta_i = beta[:, :, i] + _v = _v * beta_i + # kkt = torch.einsum('b h d,b h v->b h d v',_k*beta_i,_k) + kkt = torch.einsum('b h d,b h v->b h d v',_k,_k) + kkt = rearrange(kkt,' b h (r d) (l v)-> b h r d l v',r= r,l=r) + kkt = torch.einsum('b h r d l v,r l->b h r d l v',kkt,mask.to(kkt)) + kkt = rearrange(kkt,'b h r d l v-> b h (r d) (l v)') + iplr = torch.eye(d_k).to(q)-kkt + S = torch.einsum(' b h q k ,b h k v->b h q v',iplr,S.clone()) + _k.unsqueeze(-1) * _v.unsqueeze(-2) + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + return o + + +if __name__ =="__main__": + import sys + import time + # from einops import rearrange + # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy') + torch.set_default_dtype(torch.bfloat16) + # seq_len = 128 + # b = 2 + # h = 2 + # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 + # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)#d=128 + # v = torch.randn(b, h, seq_len, 128) + # beta = torch.rand(b, h, seq_len).sigmoid() + # require_grad = True + # BT = 16 + # r = 4 + # scale = 128**-0.5 + # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous() + # w = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda()) + # u = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())#bhn tr d + # initial_state = torch.randn(b,h,128,128).cuda().contiguous() + # k, v, q, beta,w, u= map(lambda x: 
x.cuda().requires_grad_(require_grad).contiguous(), (k, v,q, beta,w,u)) + + # final_state = None + # if False: + # final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1], + # dtype=torch.float32, requires_grad=False)#这部分不需要修正 + + # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)#need change + # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)#need change + # o2 = rearrange(o2,'b h (n t) v-> b h n t v',n=seq_len//BT) + # do = torch.rand_like(o2) + # do_naive = do + # do = rearrange(do,'b h n t v-> b h (n t) v') + # dv0 = fwd_prepare_dv(q, k, do, r, BT)#qk do v_new#因此这个dv应该是一个w的shape finish + # #bhrtv + # #到这里计算结果相同 + # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)#new_dv dh + # ###到这里算的一样了 + # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)#需要dh和dv + + # #bhtrv + # # dv0 = rearrange(dv0,'b h r (n t) v-> b h n t r v',n=seq_len//BT) + # # dv = rearrange(dv,'b h (n t) r v->b h n t r v',n=seq_len//BT)#应该有二者在BT=1维度相等,yes + # # dh = rearrange(dh,'b h (n k) v->b h n k v',n=seq_len//BT) + # # 应该有 + # # NT = seq_len//BT + # # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype) + # # v_new = torch.zeros_like(u) + # # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + # # wr = rearrange(w_na,'b h n (t r) k->b h n t r k',r = r) + # # dh0 = -torch.einsum('b h t r v,b h t r k-> b h k v ',dv[:,:,1,:,:,:],wr[:,:,1,:,:,:]) + # # dh0 += torch.einsum('b h t v,b h t q->b h q v',do_naive[:,:,1,:,:],q_na[:,:,1,:,:])*scale + # # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r) + # # dh0 = rearrange(dh0,'b h (r k) v->b h r k v',r=r)#need b h t r v + # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v,b h t r k->b h t r v',dh0,k_r[:,:,0,:,:,:]) + # # dh0 = rearrange(dh0,'b h r k v->b h (r k) v') + # # #这样计算流程是对的 + # # print((dv[:,:,0,:,:,:]-dvs).abs().max()) + # # print((dh0-dh[:,:,0,:,:]).abs().max()) + + # #####here is naive + # NT = seq_len//BT + # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype) + # v_new = torch.zeros_like(u) + # from einops import rearrange + # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new)) + # # u_na = u_na.detach().requires_grad_(True)#b h n (t r) d + # v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r) + # if initial_state is not None: + # state[:,:,0,:,:] = initial_state + # else: + # state[:,:,0,:,:] = 0 + # for i in range(NT): + # ki = rearrange(k_na[:,:,i,:,:],'b h t (r d)->b h t r d',r = r) + # ui = rearrange(u_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + # wi = rearrange(w_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r) + # v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:].clone()) + # v_new[:,:,i,:,:,:] = v_newi.clone()#这里保存的结果是相等 + # kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi) + # if i+1 < seq_len//BT: + # state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui,'b h r k v-> b h (r k ) v') + # q_r = rearrange(q_na,'b h n t (r d)->b h n t r d',r = r)*scale + # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r) + # s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r) + # s_r = torch.tril(s_r,diagonal=0)#mask get bhnrtl + # v_newnew = v_new#.detach().requires_grad_(True) + # os = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_newnew) + # os = os.sum(dim=-2)#bhntv + # oss = torch.einsum('b h n t q,b h n q v-> b h n t v',q_na,state)*scale#只看state 算的对不对 + # o_naive = os + oss + 
# # o_naive = rearrange(o_naive, 'b h n t v->b h (n t) v')
+    # o_naive.backward(do_naive, retain_graph=True)
+    # una_grad = u.grad
+    # w_grad = w.grad
+    # k_grad = k.grad
+    # q_grad = q.grad
+    # print((una_grad - dv).abs().max())  # essentially equal
+    # print((k_grad - dk).abs().max())
+    # print((w_grad - dw).abs().max())
+    # print((q_grad - dq).abs().max())
+    # print(k_grad)
+    # print(dk)
+
+    B = 2
+    H = 1
+    L = 128
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end - start)
+
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    # end = time.time()
+    # print(end-start)
+    print((o1 - o).abs().max())
+    print((q_grad - q_grad0).abs().max())
+    print((k_grad - k_grad0).abs().max())  # large gap in the results, up to ~1
+    print((v_grad - v_grad0).abs().max())
+    print((beta_grad - beta_grad0).abs().max())
+    # print(beta_grad)
+    # print(beta_grad0)
+    print(k_grad)
+    print(k_grad0)
+
+
diff --git a/fla2/ops/mask_gated_delta_rule_t/naive_rmbeta.py b/fla2/ops/mask_gated_delta_rule_t/naive_rmbeta.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aac72fd6ab3c1c7928194c488cb608129bf6fc0
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/naive_rmbeta.py
tl.load(p_k, boundary_check=(0, 1)) + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_kb = (b_k).to(b_k.dtype) + dot = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A += dot[:,:,None,None]*ij_mask[None,None,:,:] + + b_A = -tl.where((tl.arange(0, BT)[:,None] > tl.arange(0, BT)[None,:])[:,:,None,None], b_A, 0) + #先save这个看看 + for i in range(1, BT):#此时矩阵为 BT,r,BT,r + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:,None,None,None], b_A, 0), 0)#get ba BT*r*r + q = tl.sum(b_a[:,None,:,:,None]*b_A[:,:,None,:,:],-2)#矩阵乘法解决,get BT,BT*r*r + b_a = b_a + tl.sum(q,0)*((tl.arange(0, BT) < i)[:,None,None])#BT*r*r + b_A = tl.where(mask[:,None,None,None],b_a,b_A)#按行计算 ,逐步交换结果 + b_A = tl.permute(b_A,(0,2,1,3)) + b_A = tl.reshape(b_A,(BT*r,BT*r))#BT*r BT*r + b_A += tl.arange(0, BT*r)[:,None] == tl.arange(0, BT*r)[None,:] + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0))#旧版本实现需要很多乘法 + tl.store(p_A, (b_A).to(p_A.dtype.element_ty),boundary_check=(0, 1)) + #解决矩阵求逆 + b_A = b_A.to(k.dtype.element_ty)#ok 解决求逆了 #下一步计算结果 + + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_recompute_w_u_kernel( + k, + v, + beta, + mask_ij, + w, + u, + A, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + dk = K//r + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t*BT*r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(dk, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*dk + i_k * BK), (BT, BK), 
(1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + # b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)[:,None,:]*b_mask[None,:,None].to(b_k.dtype)#BT*r*d + b_kb = tl.reshape(b_kb,(BT*r,BK)) + b_w = tl.dot(b_A, b_kb, allow_tf32=False)#get BT*r *BK + p_w = tl.make_block_ptr(w + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), (i_t * BT * r, i_r*dk + i_k * BK), (BT*r, BK), (1, 0)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)):#no need for 任意mask不使用 #无需for 循环 ,这里也不存在mask + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None] + b_vb = tl.reshape(b_vb,(BT*r,BV)) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + p_u = tl.make_block_ptr(u + i_bh * s_vo_h*r, (T*r, V), (s_vo_t, s_vo_d), (i_t * BT*r, i_v * BV), (BT*r, BV), (1, 0)) + tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1)) + +#compute this +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def bwd_prepare_wy_repr_kernel( + k, v, beta,mask_ij,A, + dw, du, + dk, dv, dbeta, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + T, + K, + V, + r: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_A = tl.make_block_ptr(A + i_bh*T*BT*r*r ,(T*r,BT*r), (BT*r,1), (i_t * BT * r,0), (BT*r,BT*r),(1,0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(k.dtype.element_ty) + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT*r,BT*r], dtype=tl.float32) + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + for i_v in range(tl.cdiv(V, BV)):#分块r + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))#r*BT BV + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = ((b_v * b_beta[:, None])[:,None,:]*tl.full([r],1, dtype=b_v.dtype)[None,:,None]).to(b_v.dtype)##BT*r*BV + b_v_beta = tl.reshape(b_v_beta,(BT*r,BV)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)#BT*r,BT*r + b_dv_beta = tl.dot(tl.trans(b_A), b_du, allow_tf32=False)#BT*r,BV + b_dv_beta = tl.reshape(b_dv_beta,(BT,r,BV))# + sum_dv = tl.sum(b_dv_beta,-2)#这里不一样,结果 + b_dv = (sum_dv * b_beta[:, None])#?哪一步结果不一样呢 + b_dbeta += tl.sum(sum_dv * b_v, 1) + p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + block_k = K//r + for i_r in range(r): + p_mask = mask_ij + tl.arange(0,r)*r+i_r + b_mask = tl.load(p_mask) + for i_k in range(tl.cdiv(block_k, BK)):#assert block_k = BK + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_dw = tl.make_block_ptr(dw + i_bh * s_qk_h*r, (T*r, K), (s_qk_t, s_qk_d), 
(i_t * BT * r, i_r*block_k + i_k * BK), (BT * r, BK), (1, 0)) + # b_k_beta = ((b_k * b_beta[:, None])[:,None,:]*b_mask[None,:,None]).to(b_k.dtype)#BT*r*d + b_k_beta = ((b_k)[:,None,:]*b_mask[None,:,None]).to(b_k.dtype) + b_k_beta = tl.reshape(b_k_beta,(BT*r,BK)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(tl.trans(b_A), b_dw, allow_tf32=False)#get BT*r*BT*r + b_dk_beta = tl.reshape(b_dk_beta,(BT,r,BK)) + sum_dk = tl.sum(b_dk_beta * b_mask[None,:,None],1) + # b_dk = sum_dk* b_beta[:, None] + b_dk = sum_dk + # b_dbeta += tl.sum(sum_dk * b_k, 1) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + i = tl.arange(0, BT * r)[:, None] + j = tl.arange(0, BT * r)[None, :] + iB = i // r + jB = j // r + da_mask = iB > jB + b_dA = tl.where(da_mask, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), tl.trans(b_A), allow_tf32=False) + b_dA = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype), allow_tf32=False) + b_dA = tl.where(da_mask, -b_dA, 0) + b_dA = tl.reshape(b_dA,(BT,r,BT,r)).to(k.dtype.element_ty)#到这应该都是对的 + + for i_r in range(r):#只取ir项 + p_mask = mask_ij + tl.arange(0,r)*r+i_r#读取第ir列 + b_mask = tl.load(p_mask)#第ir列 + mask = tl.arange(0, r) == i_r + g = tl.sum(tl.where(mask[None,None,None,:], b_dA, 0), -1)#BT r BT 取最后一列, + #这里对应 kr 部分 + ir_A = tl.sum(g * b_mask[None,:,None],1).to(k.dtype.element_ty)#BT BT + for i_k in range(tl.cdiv(block_k, BK)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r*block_k + i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + # b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_k_beta = (b_k).to(b_k.dtype) + + b_dk_beta = tl.dot(ir_A, b_k, allow_tf32=False) + # b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(ir_A), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta #* b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))#这里也没问题吧 + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + +def fwd_prepare_wy_repr(k, v, beta,mask, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64) + assert BK == K//r + BV = min(triton.next_power_of_2(V), 64) + A = torch.empty(B,H,NT*BT*r,BT*r,device=k.device, dtype=torch.float32) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, mask, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + T, K, V, r, BT, BK, BV + ) + return w, u, A + +def fwd_recompute_w_u(k, v, beta,mask, A, BT): + B, H, T, K, V = *k.shape, v.shape[-1] + r = mask.shape[-1] + u = torch.empty(B,H,r*T,V,device=k.device, dtype=k.dtype) + w = torch.empty(B,H,r*T,K,device=k.device, dtype=k.dtype) + NT = triton.cdiv(T, BT) + BK = min(triton.next_power_of_2(K//r), 64)#32 + BV = min(triton.next_power_of_2(V), 64) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k, v, beta,mask, w, u, A, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), 
+def fwd_prepare_wy_repr(k, v, beta, mask, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    assert BK == K // r
+    BV = min(triton.next_power_of_2(V), 64)
+    A = torch.empty(B, H, NT * BT * r, BT * r, device=k.device, dtype=torch.float32)
+    fwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u, A
+
+
+def fwd_recompute_w_u(k, v, beta, mask, A, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    u = torch.empty(B, H, r * T, V, device=k.device, dtype=k.dtype)
+    w = torch.empty(B, H, r * T, K, device=k.device, dtype=k.dtype)
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_recompute_w_u_kernel[(NT, B * H)](
+        k, v, beta, mask, w, u, A,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return w, u
+
+
+def bwd_prepare_wy_repr(k, v, beta, mask, A, dw, du, BT):
+    B, H, T, K, V = *k.shape, v.shape[-1]
+    r = mask.shape[-1]
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    dk = torch.empty_like(k)
+    dv = torch.empty_like(v).contiguous()
+    dbeta = torch.zeros_like(beta)
+    assert BK == K // r
+    bwd_prepare_wy_repr_kernel[(NT, B * H)](
+        k, v, beta, mask, A,  # da,
+        dw, du,
+        dk, dv, dbeta,
+        k.stride(1), k.stride(2), k.stride(3),
+        v.stride(1), v.stride(2), v.stride(3),
+        T, K, V, r, BT, BK, BV
+    )
+    return dk, dv, dbeta  # , da
+
+
+# from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def fwd_prepare_dv_kernel(
+    q,
+    k,
+    do,
+    dv,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    T,
+    K,
+    V,
+    scale,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_t, i_bhr = tl.program_id(0), tl.program_id(1)  # the rank dimension is folded into the second grid axis
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    b_A = tl.zeros([BT, BT], dtype=tl.float32)
+    block_r = K // r
+    for i_k in range(tl.cdiv(block_r, BK)):
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * block_r + i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_q = tl.trans(tl.load(p_q, boundary_check=(0, 1)))
+        b_q = (b_q * scale).to(b_k.dtype)
+        b_A += tl.dot(b_k, b_q, allow_tf32=False)
+    b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0).to(do.dtype.element_ty)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        p_dv = tl.make_block_ptr(dv + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_dv = tl.dot(b_A, b_do, allow_tf32=False)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+#finish
+def fwd_prepare_dv(q, k, do, r, BT):
+    B, H, T, K, V = *k.shape, do.shape[-1]
+    dv = torch.empty(B, H, r, T, V, device=do.device, dtype=do.dtype)  # cannot use empty_like: an extra rank dim is needed
+    NT = triton.cdiv(T, BT)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    fwd_prepare_dv_kernel[(NT, B * H * r)](
+        q, k, do, dv,
+        k.stride(1), k.stride(2), k.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        T, K, V, K**-0.5, BT, BK, BV, r
+    )
+    return dv
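+
+# --- editor's sketch (added for exposition): a dense reference for
+# fwd_prepare_dv above; `prepare_dv_ref` is a hypothetical helper. Per chunk
+# and rank block, dv_r = triu(K_r Q_r^T * scale) @ dO (diagonal included).
+def prepare_dv_ref(q, k, do, r, BT):
+    B, H, T, K = q.shape
+    scale = K ** -0.5
+    NT = T // BT
+    q_r = rearrange(q, 'b h (n t) (r d) -> b h n r t d', n=NT, r=r) * scale
+    k_r = rearrange(k, 'b h (n t) (r d) -> b h n r t d', n=NT, r=r)
+    do_c = rearrange(do, 'b h (n t) e -> b h n t e', n=NT)
+    attn = torch.einsum('bhnrtd,bhnrsd->bhnrts', k_r, q_r).triu()
+    dv = torch.einsum('bhnrts,bhnse->bhnrte', attn, do_c)
+    return rearrange(dv, 'b h n r t e -> b h r (n t) e')  # (B, H, r, T, V)
+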
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_fwd_kernel_h(
+    k,
+    v,  # takes u
+    d,  # takes w
+    v_new,
+    h,
+    initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
+    final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+    STORE_FINAL_STATE: tl.constexpr
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)  # NK == 1, so i_k spans the whole key dim
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)  # one (BK, BV) state tile
+    if USE_INITIAL_STATE:
+        p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+
+    for i_t in range(NT):
+        p_h = tl.make_block_ptr(h + i_bh * NT * K * V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
+        # the state is stored before the in-chunk update
+        b_h_cumsum = tl.zeros([r, BK // r, BV], dtype=tl.float32)
+        for i_r in range(r):
+            for i_c in range(tl.cdiv(BT, BC)):  # BT is processed in BC-sized sub-chunks since BK is large
+                r_mask = tl.arange(0, r) == i_r
+                p_k = tl.make_block_ptr(k + i_bh * K * T, (T, K), (K, 1),
+                                        (i_t * BT + i_c * BC, i_k * BK + i_r * BK // r), (BC, BK // r), (1, 0))  # key slice for rank i_r
+                p_d = tl.make_block_ptr((d + i_bh * T * r * K), (T, r, K), (r * K, K, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_k * BK), (BC, 1, BK), (2, 1, 0))
+                p_v = tl.make_block_ptr((v + i_bh * T * r * V), (T, r, V), (r * V, V, 1),
+                                        (i_t * BT + i_c * BC, i_r, i_v * BV), (BC, 1, BV), (2, 1, 0))
+                p_v_new = tl.make_block_ptr(v_new + (i_bh * r + i_r) * T * V, (T, V), (V, 1),
+                                            (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+                b_k = tl.load(p_k, boundary_check=(0, 1))  # (BC, BK//r)
+                b_d = tl.load(p_d, boundary_check=(0, 1, 2))  # (BC, 1, BK)
+                b_v = tl.load(p_v, boundary_check=(0, 1, 2))  # (BC, 1, BV)
+                b_v = tl.reshape(b_v, (BC, BV))
+                b_d = tl.reshape(b_d, (BC, BK))
+                b_v -= tl.dot(b_d, b_h.to(b_k.dtype), allow_tf32=False)  # v_new = u - w @ h
+                tl.store(p_v_new, b_v.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))  # matches the reference up to here
+                bkv = tl.where(r_mask[:, None, None], tl.dot(tl.trans(b_k), b_v.to(b_k.dtype), allow_tf32=False)[None, :, :], 0)
+                b_h_cumsum += bkv.to(b_h_cumsum.dtype)
+        b_h += tl.reshape(b_h_cumsum, (BK, BV))
+
+    if STORE_FINAL_STATE:
+        p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
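+
+# --- editor's sketch (added for exposition): a dense reference for the chunk
+# recurrence implemented by the kernel above; `chunk_h_ref` is hypothetical.
+# Per chunk: v_new = u - w @ S, then S += sum_r k_r^T @ v_new_r. The real h is
+# laid out as (B, H, NT*K, V); here the per-chunk states are just stacked.
+def chunk_h_ref(k, w, u, BT, r, initial_state=None):
+    B, H, T, K = k.shape
+    NT = T // BT
+    S = k.new_zeros(B, H, K, u.shape[-1]) if initial_state is None else initial_state.to(k.dtype)
+    k_c = rearrange(k, 'b h (n t) (r d) -> n b h t r d', n=NT, r=r)
+    w_c = rearrange(w, 'b h (n t r) d -> n b h t r d', n=NT, r=r)
+    u_c = rearrange(u, 'b h (n t r) e -> n b h t r e', n=NT, r=r)
+    hs, v_new = [], []
+    for i in range(NT):
+        hs.append(S.clone())
+        vn = u_c[i] - torch.einsum('bhtrd,bhdv->bhtrv', w_c[i], S)
+        v_new.append(vn)
+        # rank block r of v_new updates the matching K rows of the state
+        S = S + rearrange(torch.einsum('bhtrd,bhtrv->bhrdv', k_c[i], vn),
+                          'b h r d v -> b h (r d) v')
+    return torch.stack(hs), torch.stack(v_new)  # (NT,B,H,K,V), (NT,B,H,BT,r,V)
+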
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_linear_attn_fwd_kernel_o(
+    q,
+    k,
+    v,
+    h,
+    o,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    r: tl.constexpr
+):
+    i_v, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_bh = i_bhr // r
+    i_r = i_bhr % r
+    rk = K // r
+    o_i = tl.arange(0, BT)
+    m_s = o_i[:, None] >= o_i[None, :]
+    b_o = tl.zeros([BT, BV], dtype=tl.float32)
+    b_s = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_k in range(tl.cdiv(K // r, BK)):  # careful with the split: here K // BK == r, so each program sees only its rank slice
+        # note: different rank blocks read the same q/k tiles; the loads are read-only, so this is safe
+        p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * rk + i_k * BK), (BT, BK), (1, 0))
+        # p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_r * rk + i_k * BK, i_t * BT), (BK, BT), (0, 1))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_r * rk + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        b_q = (b_q * scale).to(b_q.dtype)
+        b_k = tl.trans(tl.load(p_k, boundary_check=(0, 1)))
+        b_h = tl.load(p_h, boundary_check=(0, 1))
+        b_o += tl.dot(b_q, b_h, allow_tf32=False)
+        b_s += tl.dot(b_q, b_k, allow_tf32=False)
+
+    b_s = tl.where(m_s, b_s, 0)  # zero out the non-causal entries
+    p_v = tl.make_block_ptr(v + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    b_v = tl.load(p_v, boundary_check=(0, 1))
+    b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
+    p_o = tl.make_block_ptr(o + i_bhr * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
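+
+# --- editor's sketch (added for exposition): a dense reference for the output
+# kernel above plus the sum over r done in chunk_fwd_o_fn; `chunk_o_ref` is a
+# hypothetical helper. h: (B, H, NT*K, V); v_new: (B, H, r, T, V).
+def chunk_o_ref(q, k, v_new, h, BT, r):
+    B, H, T, K = q.shape
+    NT = T // BT
+    scale = K ** -0.5
+    q_r = rearrange(q, 'b h (n t) (r d) -> b h n t r d', n=NT, r=r) * scale
+    k_r = rearrange(k, 'b h (n t) (r d) -> b h n t r d', n=NT, r=r)
+    h_r = rearrange(h, 'b h (n r d) v -> b h n r d v', n=NT, r=r)
+    vn = rearrange(v_new, 'b h r (n t) v -> b h n t r v', n=NT)
+    s = torch.einsum('bhntrd,bhnlrd->bhnrtl', q_r, k_r).tril()  # causal, per rank
+    o_intra = torch.einsum('bhnrtl,bhnlrv->bhntv', s, vn)
+    o_inter = torch.einsum('bhntrd,bhnrdv->bhntv', q_r, h_r)
+    return rearrange(o_intra + o_inter, 'b h n t v -> b h (n t) v')
+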
+#finish
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dhu(
+    q,
+    k,
+    d,
+    do,
+    dh,
+    dv,
+    dv2,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+    KR: tl.constexpr,
+):
+    i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)  # accumulated over the whole sequence
+    for i_t in range(NT - 1, -1, -1):  # dh is stored before the in-chunk update, i.e. shifted one chunk ahead; the dataflow is correct
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        # full column range
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t),
+                                    (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))  # reads the full tile
+            p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d),
+                                    (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
+            p_d = tl.make_block_ptr(d + i_bh * (T * K * r), (T * r, K), (K, 1),
+                                    (i_t * BT * r + i_c * BC * r, i_k * BK), (BC * r, BK), (1, 0))  # reads a (BC*r, BK) tile
+            p_dv = tl.make_block_ptr(dv + i_bh * r * T * V, (T * r, V), (V, 1),
+                                     (i_t * BT * r + i_c * BC * r, i_v * BV), (BC * r, BV), (1, 0))
+            p_do = tl.make_block_ptr(do + i_bh * T * V, (T, V), (V, 1),
+                                     (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
+            b_q = (tl.load(p_q, boundary_check=(0, 1)))
+            b_q = (b_q * scale).to(b_q.dtype)
+            b_k = tl.load(p_k, boundary_check=(0, 1))
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))  # (BC*r, BV)
+            b_d = tl.trans(tl.load(p_d, boundary_check=(0, 1)))
+            b_k = tl.permute(tl.reshape(b_k, (BC, r, KR)), (1, 0, 2))  # (r, BC, KR)
+            b_dhtrans = tl.reshape(b_dh, (r, KR, BV))
+            dv_sum = tl.sum(b_k[:, :, :, None] * b_dhtrans.to(b_k.dtype)[:, None, :, :], -2)  # (r, BC, BV)
+            b_dv += tl.reshape(tl.permute(dv_sum, (1, 0, 2)), (BC * r, BV))
+            # back to (t, r) row ordering
+            p_dv2 = tl.make_block_ptr(dv2 + i_bh * r * T * V, (T * r, V), (V, 1),
+                                      (i_t * BT * r + i_c * BC * r, i_v * BV), (BC * r, BV), (1, 0))
+            tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
+            b_dh_tmp -= tl.dot(b_d, b_dv.to(b_q.dtype), allow_tf32=False)
+        b_dh += b_dh_tmp
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16)
+    ],
+    key=["BT", "BK", "BV"],
+)
+@triton.jit
+def chunk_delta_rule_bwd_kernel_dqkw(
+    q,
+    k,
+    v,
+    w,
+    h,
+    do,
+    dh,
+    dq,
+    dk,
+    dv,
+    dw,
+    s_qk_h,
+    s_qk_t,
+    s_qk_d,
+    s_vo_h,
+    s_vo_t,
+    s_vo_d,
+    s_h_h,
+    s_h_t,
+    scale,
+    H: tl.constexpr,
+    T: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    NT: tl.constexpr,
+    r: tl.constexpr,
+):
+    i_k, i_t, i_bhr = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_r = i_bhr % r
+    i_bh = i_bhr // r
+    o_i = tl.arange(0, BT)
+    p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
+    b_dw = tl.zeros([BT, r, BK], dtype=tl.float32)
+    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + i_bhr * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_h = tl.make_block_ptr(h + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_r * K // r + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv + i_bh * s_vo_h * r, (T * r, V), (s_vo_t, s_vo_d), (i_t * BT * r, i_v * BV), (BT * r, BV), (1, 0))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        # [BV, BK]
+        b_h = tl.trans(tl.load(p_h, boundary_check=(0, 1)))
+        # [BK, BV]
+        b_dh = tl.load(p_dh, boundary_check=(0, 1))
+        # [BT, BT]
+        b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)
+        # [BT, BK]
+        b_dq += tl.dot(b_do, b_h, allow_tf32=False)  # b_do is full-width; b_h holds this rank block's slice of the state
+        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)  # rows are independent, so dk accumulates per rank block
+        b_dv = tl.reshape(tl.load(p_dv, boundary_check=(0, 1)), (BT, r, BV))  # (BT*r, BV) -> (BT, r, BV)
+        b_dw += tl.sum(b_dv.to(b_v.dtype)[:, :, :, None] * b_h.to(b_v.dtype)[None, None, :, :], -2)  # (BT, r, BK)
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = (b_q * scale).to(b_q.dtype)
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0).to(b_q.dtype)  # (BT, BT)
+    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
+    b_dq *= scale
+    b_dk += tl.trans(tl.dot(tl.trans(b_q), b_ds, allow_tf32=False))  # intra-chunk terms
+
+    p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_r * K // r + i_k * BK), (BT, BK), (1, 0))
+    p_dw = tl.make_block_ptr(dw + i_bh * T * r * K, (T, r, K), (r * K, K, 1), (i_t * BT, 0, i_r * K // r + i_k * BK), (BT, r, BK), (2, 1, 0))
+    # p_dw = tl.make_block_ptr(dw + i_bh * T*r*K, (T, r, K), (r*K,K,1), (i_t * BT ,i_r, i_k * BK), (BT, 1, BK), (2, 1, 0))
+    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dw, tl.reshape(-b_dw.to(p_dw.dtype.element_ty), (BT, r, BK)), boundary_check=(0, 1))
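+
+# --- editor's sketch (added for exposition): the dw term above as a dense
+# reference; `dw_ref` is hypothetical. Note that every rank row j of dw is
+# paired with the state rows of the rank block i_r it is stored into.
+def dw_ref(dv, h, K, r):
+    # dv: (B, H, T*r, V) with rows ordered (t, j); h: (B, H, NT*K, V)
+    NT = h.shape[2] // K
+    dv_c = rearrange(dv, 'b h (n t j) v -> b h n t j v', n=NT, j=r)
+    h_r = rearrange(h, 'b h (n c d) v -> b h n c d v', n=NT, c=r)
+    # dw[t, j, (c, d)] = -sum_v dv[t, j, v] * h[(c, d), v]
+    dw = -torch.einsum('bhntjv,bhncdv->bhntjcd', dv_c, h_r)
+    return rearrange(dw, 'b h n t j c d -> b h (n t j) (c d)')  # (B, H, T*r, K)
+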
+
+#finish
+def chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state):
+    B, H, T, K, V = *k.shape, u.shape[-1]
+    _, _, rT, _ = w.shape
+    r = rT // T
+    BK = triton.next_power_of_2(K)  # a single block spans the whole key dim
+    assert BK <= 256, "current kernel does not support head dimension larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1
+    h = k.new_empty(B, H, NT * K, V)
+    grid = (NK, NV, B * H)
+    v_new = torch.empty(B, H, r, T, V, dtype=u.dtype, device=u.device)  # one v_new slice per rank
+    chunk_delta_rule_fwd_kernel_h[grid](  # the loop over r lives inside the kernel
+        k, u, w, v_new, h, initial_state, final_state,
+        k.stride(1), k.stride(2), k.stride(3),
+        u.stride(1), u.stride(2), u.stride(3),  # r*T*V, V, 1
+        h.stride(1), h.stride(2),
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r,
+        USE_INITIAL_STATE=initial_state is not None,
+        STORE_FINAL_STATE=final_state is not None,
+    )
+    return h, v_new
+
+#finish
+def chunk_bwd_dhu_fn(q, k, w, do, dv, BT):
+    B, H, r, T, V, K = *dv.shape, q.shape[-1]
+    BK = triton.next_power_of_2(K)
+    assert BK <= 256, "current kernel does not support head dimension being larger than 256."
+    BV = 16 if BK > 128 else 32
+    BV = 64 if BK <= 64 else BV
+    BC = 16 if BK > 128 else 32
+    BC = 64 if BK <= 64 else BC
+    BC = min(BT, BC)
+    NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)  # NV could expose more parallelism
+    assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
+
+    dh = q.new_empty(B, H, NT * K, V)  # dh is shared across ranks, so the reductions must be fused here
+    grid = (NK, NV, B * H)
+    dv = rearrange(dv, 'b h r t v-> b h (t r) v').contiguous()
+    dv2 = torch.empty_like(dv)  # same layout as dv: (B, H, T*r, V)
+    chunk_delta_rule_bwd_kernel_dhu[grid](
+        q, k, w, do, dh, dv, dv2,
+        q.stride(1), q.stride(2), q.stride(3),
+        do.stride(1), do.stride(2), do.stride(3),
+        dh.stride(1), dh.stride(2),
+        K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT, r=r, KR=K//r,
+    )
+    return dh, dv2
+
+#finish
+def chunk_fwd_o_fn(q, k, v_new, h, BT):
+    B, H, r, T, V, K = *v_new.shape, q.shape[-1]
+    o = torch.empty_like(v_new)  # one output slice per rank: (B, H, r, T, V)
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NV = triton.cdiv(V, BV)
+    NT = triton.cdiv(T, BT)
+    grid = (NV, NT, B * H * r)
+    # h shape: (B, H, NT * K, V)
+    chunk_linear_attn_fwd_kernel_o[grid](
+        q, k, v_new, h, o,
+        q.stride(1), q.stride(2), q.stride(3),
+        v_new.stride(1), v_new.stride(2), v_new.stride(3),
+        h.stride(1), h.stride(2),
+        scale=K**-0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, r=r,
+    )
+    o = o.sum(dim=2)  # reduce over the rank dimension
+    return o
+
+
+def chunk_bwd_dqkw_fn(q, k, v_new, w, h, du, do, dh, BT):
+    B, H, T, K, V = *q.shape, v_new.shape[-1]
+    _, _, RT, _ = w.shape
+    r = RT // T
+    # last kernel of the backward: computes dq, dk and dw
+    # finer-grained K split so that different rank blocks never share a tile
+    BK = min(triton.next_power_of_2(K // r), 64)
+    BV = min(triton.next_power_of_2(V), 64)
+    NK = triton.cdiv(K // r, BK)
+    NT = triton.cdiv(T, BT)
+    grid = (NK, NT, B * H * r)  # NK indexes the position inside a rank block
+    dq = torch.empty_like(q)
+    dk = torch.empty_like(k)
+    dw = torch.empty_like(w)  # (B, H, T*r, K)
+    chunk_delta_rule_bwd_kernel_dqkw[grid](
+        q, k, v_new, w, h, do, dh, dq, dk, du, dw,
+        q.stride(1), q.stride(2), q.stride(3),
+        T * V, V, 1,
+        dh.stride(1), dh.stride(2),
+        scale=K ** -0.5,
+        H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, r=r
+    )
+    return dq.to(q.dtype), dk.to(k.dtype), dw.to(w.dtype)
+
+
+class ChunkDeltaRuleFunction(torch.autograd.Function):
+    # forward path
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, beta, mask, BT, initial_state, output_final_state, checkpoint_level=1):
+        start = time.time()
+        w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, BT)  # builds the WY representation and the block matrix A
+        final_state = None
+        if output_final_state:
+            final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+                                      dtype=torch.float32, requires_grad=False)  # no correction needed here
+        end = time.time()
+        print('compute_A:', end - start)
+        start = time.time()
+        h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)  # masked multi-rank variant
+        end = time.time()
+        print('compute_h_s:', end - start)
+
+        start = time.time()
+        o = chunk_fwd_o_fn(q, k, v_new, h, BT)
+        end = time.time()
+        print('compute_o:', end - start)
+        if checkpoint_level == 1:
+            h, v_new = None, None
+        ctx.save_for_backward(q, k, v, beta, mask, A, h, v_new, initial_state)
+        ctx.BT = BT
+        return o.to(q.dtype), final_state
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_bwd
+    def backward(ctx, do, d_ht=None):
+        q, k, v, beta, mask, A, h, v_new, initial_state = ctx.saved_tensors
+        BT = ctx.BT
+        r = mask.shape[-1]
+        start = time.time()
+        w, u = fwd_recompute_w_u(k, v, beta, mask, A, BT)  # recompute w, u from the saved A
+        end = time.time()
+        print('recompute_wu:', end - start)
+        # checkpoint_level=1: recomputation.
+        if h is None:
+            h, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, None)
+        # v_new: (B, H, r, T, V)
+        start = time.time()
+        dv = fwd_prepare_dv(q, k, do, r, BT)  # one dv slice per rank block: (B, H, r, T, V)
+        end = time.time()
+        print('pre:', end - start)
+
+        start = time.time()
+        dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv, BT)  # updated dv and dh; dv now feeds the WY backward
+        end = time.time()
+        print('chunk_bwd_dhu_fn:', end - start)
+
+        start = time.time()
+        dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h, dv, do, dh, BT)
+        end = time.time()
+        print('chunk_bwd_dqkw_fn:', end - start)
+
+        start = time.time()
+        dk2, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, mask, A, dw, dv, BT)  # note: this step currently shows the largest numerical error
+        dk.add_(dk2)
+        end = time.time()
+        print('bwd_prepare_wy_repr:', end - start)
+        return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dbeta.to(beta.dtype), None, None, None, None
+
+
+def mask_chunk_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    mask: torch.Tensor,  # (r, r) mask applied to the rank blocks of the original tensors
+    BT: int,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+):
+    assert q.dtype == k.dtype == v.dtype
+    assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+    o, final_state = ChunkDeltaRuleFunction.apply(q, k, v, beta, mask, BT, initial_state, output_final_state)
+    return o, final_state
+
+
+def naive(q, k, w, u, initial_state, BT, r):
+    B, H, seq_len, dk = q.shape
+    dv = u.shape[-1]
+    NT = seq_len // BT
+    state = torch.empty(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    g = torch.zeros(B, H, NT, dk, dv, device=q.device, dtype=q.dtype)
+    v_new = torch.empty_like(u)
+    from einops import rearrange
+    q, k, w, u, v_new = map(lambda x: rearrange(x, 'b h (n t) d->b h n t d', n=NT), (q, k, w, u, v_new))
+    q = q * (dk**-0.5)
+    v_new = rearrange(v_new, 'b h n (t r) d->b h n t r d', r=r)
+    if initial_state is not None:
+        state[:, :, 0, :, :] = initial_state
+    else:
+        state[:, :, 0, :, :] = 0
+    for i in range(NT):
+        ki = rearrange(k[:, :, i, :, :], 'b h t (r d)->b h t r d', r=r)
+        ui = rearrange(u[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        wi = rearrange(w[:, :, i, :, :], 'b h (t r) d->b h t r d', r=r)
+        v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v', wi, state[:, :, i, :, :])
+        v_new[:, :, i, :, :, :] = v_newi  # matches the kernel's v_new
+        kui = torch.einsum('b h t r k,b h t r v-> b h r k v', ki, v_newi)
+        g[:, :, i, :, :] = rearrange(kui, 'b h r k v-> b h (r k) v')
+        if i + 1 < seq_len // BT:
+            state[:, :, i + 1, :, :] = state[:, :, i, :, :] + g[:, :, i, :, :]
+    q_r = rearrange(q, 'b h n t (r d)->b h n t r d', r=r)
+    k_r = rearrange(k, 'b h n t (r d)->b h n t r d', r=r)
+    s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l', q_r, k_r)
+    s_r = torch.tril(s_r, diagonal=0)  # causal mask, (b, h, n, r, t, l)
+    o1 = torch.einsum('b h n r t l, b h n l r v-> b h n t r v', s_r, v_new)
+    o1 = o1.sum(dim=-2)  # (b, h, n, t, v)
+    o2 = torch.einsum('b h n t q,b h n q v-> b h n t v', q, state)  # inter-chunk term from the state
+    o = o1 + o2
+    return state, v_new, o, g
+
+
+def delta_rule_recurrence(q, k, v, beta, mask):
+    b, h, l, d_k = q.shape
+    d_v = v.shape[-1]
+    r = mask.shape[-1]
+    o = torch.zeros_like(v)
+    S = torch.zeros(b, h, d_k, d_v).to(v)
+    q = q * (d_k ** -0.5)
+    if beta.ndim < v.ndim:
+        beta = beta[..., None]
+    for i in range(l):
+        _k = k[:, :, i]
+        _q = q[:, :, i]
+        _v = v[:, :, i].clone()
+        beta_i = beta[:, :, i]
+        _v = _v * beta_i
+        # kkt = torch.einsum('b h d,b h v->b h d v', _k * beta_i, _k)
+        kkt = torch.einsum('b h d,b h v->b h d v', _k, _k)
+        kkt = rearrange(kkt, ' b h (r d) (l v)-> b h r d l v', r=r, l=r)
+        kkt = torch.einsum('b h r d l v,r l->b h r d l v', kkt, mask.to(kkt))
+        kkt = rearrange(kkt, 'b h r d l v-> b h (r d) (l v)')
+        iplr = torch.eye(d_k).to(q) - kkt
+        S = torch.einsum(' b h q k ,b h k v->b h q v', iplr, S.clone())
+        S = S + _k.unsqueeze(-1) * _v.unsqueeze(-2)  # rank-1 delta-rule update
+        o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
+    return o
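+
+# --- editor's note: a minimal usage sketch of the public entry point above
+# (illustrative shapes; `_example_usage` is not part of the original patch).
+def _example_usage():
+    B, H, T, D, r = 2, 4, 256, 128, 4
+    q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
+    k = torch.nn.functional.normalize(
+        torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16), dim=-1, p=2).requires_grad_(True)
+    v = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16, requires_grad=True)
+    beta = torch.randn(B, H, T, device='cuda', dtype=torch.bfloat16).sigmoid().requires_grad_(True)
+    # banded (r, r) mask coupling neighbouring rank blocks, as in the test below
+    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], device='cuda').contiguous()
+    o, final_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.sum().backward()
+    return o, final_state
+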
+
+if __name__ == "__main__":
+    import sys
+    import time
+    # from einops import rearrange
+    # sys.path.append('/mnt/jfzn/msj/flash-linear-attention-main/legacy/training/fla2-copy')
+    torch.set_default_dtype(torch.bfloat16)
+    # seq_len = 128
+    # b = 2
+    # h = 2
+    # k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    # q = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 128), dim=-1, p=2)  # d = 128
+    # v = torch.randn(b, h, seq_len, 128)
+    # beta = torch.rand(b, h, seq_len).sigmoid()
+    # require_grad = True
+    # BT = 16
+    # r = 4
+    # scale = 128**-0.5
+    # mask = torch.tensor([[1,1,0,0],[1,1,1,0],[0,1,1,1],[0,0,1,1]],requires_grad=False).cuda().contiguous()
+    # w = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())
+    # u = torch.nn.functional.normalize(torch.randn(b,h,seq_len*r,128).cuda())  # (b, h, T*r, d)
+    # initial_state = torch.randn(b,h,128,128).cuda().contiguous()
+    # k, v, q, beta, w, u = map(lambda x: x.cuda().requires_grad_(require_grad).contiguous(), (k, v, q, beta, w, u))
+
+    # final_state = None
+    # if False:
+    #     final_state = q.new_empty(q.shape[0], q.shape[1], q.shape[-1], v.shape[-1],
+    #                               dtype=torch.float32, requires_grad=False)  # no correction needed here
+
+    # h_state, v_new = chunk_fwd_h_fn(k, w, u, BT, initial_state, final_state)
+    # o2 = chunk_fwd_o_fn(q, k, v_new, h_state, BT)
+    # o2 = rearrange(o2,'b h (n t) v-> b h n t v',n=seq_len//BT)
+    # do = torch.rand_like(o2)
+    # do_naive = do
+    # do = rearrange(do,'b h n t v-> b h (n t) v')
+    # dv0 = fwd_prepare_dv(q, k, do, r, BT)  # uses q, k, do; dv gets one slice per rank
+    # # (b, h, r, t, v)
+    # # results match the kernels up to here
+    # dh, dv = chunk_bwd_dhu_fn(q, k, w, do, dv0, BT)  # updated dv and dh
+    # # still matching at this point
+    # dq, dk, dw = chunk_bwd_dqkw_fn(q, k, v_new, w, h_state, dv, do, dh, BT)  # needs dh and dv
+
+    # # (b, h, t, r, v)
+    # # dv0 = rearrange(dv0,'b h r (n t) v-> b h n t r v',n=seq_len//BT)
+    # # dv = rearrange(dv,'b h (n t) r v->b h n t r v',n=seq_len//BT)  # the two agree on the BT=1 slice
+    # # dh = rearrange(dh,'b h (n k) v->b h n k v',n=seq_len//BT)
+    # # expected:
+    # # NT = seq_len//BT
+    # # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype)
+    # # v_new = torch.zeros_like(u)
+    # # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new))
+    # # wr = rearrange(w_na,'b h n (t r) k->b h n t r k',r = r)
+    # # dh0 = -torch.einsum('b h t r v,b h t r k-> b h k v ',dv[:,:,1,:,:,:],wr[:,:,1,:,:,:])
+    # # dh0 += torch.einsum('b h t v,b h t q->b h q v',do_naive[:,:,1,:,:],q_na[:,:,1,:,:])*scale
+    # # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r)
+    # # dh0 = rearrange(dh0,'b h (r k) v->b h r k v',r=r)  # need (b, h, t, r, v)
+    # # dvs = dv0[:,:,0,:,:,:] + torch.einsum('b h r k v,b h t r k->b h t r v',dh0,k_r[:,:,0,:,:,:])
+    # # dh0 = rearrange(dh0,'b h r k v->b h (r k) v')
+    # # # this dataflow is correct
+    # # print((dv[:,:,0,:,:,:]-dvs).abs().max())
+    # # print((dh0-dh[:,:,0,:,:]).abs().max())
+
+    # ##### dense reference below
+    # NT = seq_len//BT
+    # state = torch.zeros(b,h,NT,128,128,device=q.device,dtype=q.dtype)
+    # v_new = torch.zeros_like(u)
+    # from einops import rearrange
+    # q_na,k_na,w_na,u_na,v_new = map(lambda x:rearrange(x,'b h (n t) d->b h n t d',n = NT),(q,k,w,u,v_new))
+    # # u_na = u_na.detach().requires_grad_(True)  # (b, h, n, t*r, d)
+    # v_new = rearrange(v_new,'b h n (t r) d->b h n t r d',r = r)
+    # if initial_state is not None:
+    #     state[:,:,0,:,:] = initial_state
+    # else:
+    #     state[:,:,0,:,:] = 0
+    # for i in range(NT):
+    #     ki = rearrange(k_na[:,:,i,:,:],'b h t (r d)->b h t r d',r = r)
+    #     ui = rearrange(u_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r)
+    #     wi = rearrange(w_na[:,:,i,:,:],'b h (t r) d->b h t r d',r = r)
+    #     v_newi = ui - torch.einsum('b h t r d, b h d v-> b h t r v',wi,state[:,:,i,:,:].clone())
+    #     v_new[:,:,i,:,:,:] = v_newi.clone()  # stored result matches the kernel
+    #     kui = torch.einsum('b h t r k,b h t r v-> b h r k v',ki,v_newi)
+    #     if i+1 < seq_len//BT:
+    #         state[:,:,i+1,:,:] = state[:,:,i,:,:].clone() + rearrange(kui,'b h r k v-> b h (r k) v')
+    # q_r = rearrange(q_na,'b h n t (r d)->b h n t r d',r = r)*scale
+    # k_r = rearrange(k_na,'b h n t (r d)->b h n t r d',r = r)
+    # s_r = torch.einsum('b h n t r d,b h n l r d->b h n r t l',q_r,k_r)
+    # s_r = torch.tril(s_r,diagonal=0)  # causal mask, (b, h, n, r, t, l)
+    # v_newnew = v_new  # .detach().requires_grad_(True)
+    # os = torch.einsum('b h n r t l, b h n l r v-> b h n t r v',s_r,v_newnew)
+    # os = os.sum(dim=-2)  # (b, h, n, t, v)
+    # oss = torch.einsum('b h n t q,b h n q v-> b h n t v',q_na,state)*scale  # checks only the state term
+    # o_naive = os + oss
+    # # o_naive = rearrange(o_naive,'b h n t v->b h (n t) v')
+    # o_naive.backward(do_naive,retain_graph=True)
+    # una_grad = u.grad
+    # w_grad = w.grad
+    # k_grad = k.grad
+    # q_grad = q.grad
+    # print((una_grad-dv).abs().max())  # nearly identical
+    # print((k_grad-dk).abs().max())
+    # print((w_grad-dw).abs().max())
+    # print((q_grad-dq).abs().max())
+    # print(k_grad)
+    # print(dk)
+
+    B = 2
+    H = 1
+    L = 128
+    DK = 256
+    DV = 256
+    q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True)
+    k = (torch.randn(B, H, L, DK)).cuda()
+    k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True)
+    v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True)
+    beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True)
+    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], requires_grad=False).cuda().contiguous()
+
+    start = time.time()
+    o1 = delta_rule_recurrence(q, k, v, beta, mask)
+    do = torch.randn(B, H, L, DV).cuda()
+    o1.backward(do, retain_graph=True)
+    q_grad, q.grad = q.grad, None
+    k_grad, k.grad = k.grad, None
+    v_grad, v.grad = v.grad, None
+    beta_grad, beta.grad = beta.grad, None
+    end = time.time()
+    print(end - start)
+
+    # start = time.time()
+    # w, u, A = fwd_prepare_wy_repr(k, v, beta, mask, 64)
+    o, f_state = mask_chunk_delta_rule(q, k, v, beta, mask, BT=32)
+    o.backward(do, retain_graph=True)
+    q_grad0, q.grad = q.grad, None
+    k_grad0, k.grad = k.grad, None
+    v_grad0, v.grad = v.grad, None
+    beta_grad0, beta.grad = beta.grad, None
+    # end = time.time()
+    # print(end - start)
+    print((o1 - o).abs().max())
+    print((q_grad - q_grad0).abs().max())
+    print((k_grad - k_grad0).abs().max())  # note: dk still differs significantly here (errors up to ~1)
+    print((v_grad - v_grad0).abs().max())
+    print((beta_grad - beta_grad0).abs().max())
+    # print(beta_grad)
+    # print(beta_grad0)
+    print(k_grad)
+    print(k_grad0)
diff --git a/fla2/ops/mask_gated_delta_rule_t/recurrent_fuse.py b/fla2/ops/mask_gated_delta_rule_t/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21470ff11d7e75df52b0c81dcb66bd40a44a0e5
--- /dev/null
+++ b/fla2/ops/mask_gated_delta_rule_t/recurrent_fuse.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from ...utils import contiguous
+
+# on-the-fly computation without materializing hidden states into HBMs
+
+
+@triton.jit
+def fused_recurrent_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+ beta, # beta [B, H, L] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + s_qk_h, # stride size: L * K + s_vo_h, # stride size: L * V + scale, # K ** -0.5 + B, # batch size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _v_minus = tl.sum(h * b_k[None, :], axis=1) + b_v -= _v_minus + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # in-place overwrite + tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv) + b_v *= b_beta + h += b_k[None, :] * b_v[:, None] + _o = h * b_q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + p_beta += V if IS_HEADWISE_BETA else 1 + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: d_head + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + beta, # beta [B, H, L, (V)] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + dbeta, # gradient of beta [NV, (NK), B, H, L] + + # initial hidden state initialization [B, H, K, V] + h0, + + s_qk_h, # stride size: L * K + + s_vo_h, # stride size: L * V + + NK, # NK block size + scale, # K ** -0.5 + + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + IS_HEADWISE_BETA: tl.constexpr, # whether beta is headwise vector or scalar +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + else: + p_beta = beta + i_bh * T + T - 1 + + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + if IS_HEADWISE_BETA: + p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V + else: + p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1 + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v) + d_v = d_v * b_beta + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + if IS_HEADWISE_BETA: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv) + else: + tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty)) + + d_h -= b_k[:, None] * d_v[None, :] + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + p_dbeta -= V if IS_HEADWISE_BETA else 1 + p_beta -= V if IS_HEADWISE_BETA else 1 + + tl.debug_barrier() + + h = tl.zeros([BK, BV], dtype=tl.float32) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + if IS_HEADWISE_BETA: + p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + else: + p_beta = beta + i_bh * T + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) 
+ V + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + K + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + if IS_HEADWISE_BETA: + b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + + h += b_k[:, None] * b_v[None, :] + _d_q = h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + if i < T - 1: + d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32) + d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32) + d_k -= tl.sum(d_v[None, :] * h, axis=1) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dk += K + p_dv += V + p_dq += K + p_beta += V if IS_HEADWISE_BETA else 1 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @contiguous + @staticmethod + def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False): + B, H, T, K, V = *q.shape, v.shape[-1] + + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 1 + assert NK == 1, "NK > 1 is not supported yet" + o = q.new_empty(NK, B, H, T, V) + + if output_final_state: + final_state = q.new_empty(B, H, K, V) + else: + final_state = None + + grid = (NV, NK, B * H) + fused_recurrent_fwd_kernel[grid]( + q, k, v, beta, o, initial_state, final_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + IS_HEADWISE_BETA=beta.ndim == v.ndim, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.sum(0) + ctx.save_for_backward(q, k, v, beta, initial_state) + ctx.scale = scale + return o, final_state + + @contiguous + @staticmethod + def backward(ctx, do, dht=None): + q, k, v, beta, initial_state = ctx.saved_tensors + B, H, T, K, V = *q.shape, v.shape[-1] + scale = ctx.scale + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 1 + num_warps = 2 + + beta_vector = beta.ndim == v.ndim + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + if beta_vector: + dbeta = q.new_empty(NV, NK, B, H, T, V) + else: + dbeta = q.new_empty(NV, B, H, T) + grid = (NV, NK, B * H) + + fused_recurrent_bwd_kernel[grid]( + q, k, v, beta, do, dq, dk, dv, dbeta, initial_state, + q.stride(1), + v.stride(1), + NK, scale, + B=B, H=H, T=T, K=K, V=V, + BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + IS_HEADWISE_BETA=beta_vector, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0) + return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, None, None + + +def mask_fused_recurrent_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor = None, + scale: float = -1, + 
initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if initial_state is not None: + initial_state = initial_state.detach() + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state) + return o, final_state diff --git a/fla2/ops/mask_gated_delta_rule_t/utils.py b/fla2/ops/mask_gated_delta_rule_t/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..173d6629c628bb6b5860a005cbc8ea85d7cf9b5e --- /dev/null +++ b/fla2/ops/mask_gated_delta_rule_t/utils.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from ...ops.delta_rule.wy_fast import prepare_wy_repr as prepare_wy_repr2 +from ...utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +# Inspired by "THE WY REPRESENTATION FOR PRODUCTS OF HOUSEHOLDER MATRICES" https://epubs.siam.org/doi/pdf/10.1137/0908009 +# o: cumprod +# o2: cumprodsum +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def fwd_prepare_wy_repr_kernel( + k, + v, + beta, + o, + o2, + T, + K, + V, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = tl.arange(0, BK) < K + mask_bv = tl.arange(0, BV) < V + mask_bk = mask_bk[None, :] & mask_bt[:, None] + mask_bv = mask_bv[None, :] & mask_bt[:, None] + # [BT, BK] + b_k = tl.load(p_k, mask=mask_bk, other=0) + # [BT,] + b_beta = tl.load(p_beta, mask=mask_bt, other=0).to(tl.float32) + # [BT, BV] + b_v = tl.load(p_v, mask=mask_bv, other=0) + b_v = (b_v * b_beta[:, None]).to(b_v.dtype) + # [BT, BK] + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + # [BT, BT] + b_A = tl.dot(b_kb, tl.trans(b_k), allow_tf32=False) + b_A = -tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + + for i in range(BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + b_A = b_A.to(b_k.dtype) + b_w = tl.dot(b_A, b_kb, allow_tf32=False) + b_u = tl.dot(b_A, b_v, allow_tf32=False) + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_o, b_w.to(p_o.dtype.element_ty), mask=mask_bk) + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_o2, b_u.to(p_o2.dtype.element_ty), mask=mask_bv) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def 
bwd_prepare_wy_repr_kernel( + k, v, beta, + o, o2, do, do2, + dk, dv, dbeta, + NT, K, V, T, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + p_k = k + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do = do + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_do2 = do2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + + p_beta = beta + i_bh * T + i_t * BT + tl.arange(0, BT) + mask_bt = (tl.arange(0, BT) + i_t * BT) < T + mask_bk = (tl.arange(0, BK) < K)[None, :] & mask_bt[:, None] + mask_bv = (tl.arange(0, BV) < V)[None, :] & mask_bt[:, None] + b_k, b_beta = tl.load(p_k, mask=mask_bk), tl.load(p_beta, mask=mask_bt) + + b_beta = b_beta.to(tl.float32) + A = tl.dot(b_k, tl.trans(b_k), allow_tf32=False) * b_beta[:, None] + A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], A, 0) + b_do = tl.load(p_do, mask=mask_bk).to(tl.float32) + b_dv = tl.load(p_do2, mask=mask_bv).to(tl.float32) + dA = tl.zeros([BT, BT], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i in range(BT-1, -1, -1): + mask = tl.arange(0, BT) == i + attn = tl.sum(tl.where(mask[:, None], A, 0), axis=0) + do_ = tl.sum(tl.where(mask[:, None], b_do, 0), axis=0) + dv_ = tl.sum(tl.where(mask[:, None], b_dv, 0), axis=0) + b_do = b_do - attn[:, None] * do_[None, :] + b_dv = b_dv - attn[:, None] * dv_[None, :] + tl.debug_barrier() + p_v = v + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_v = tl.load(p_v, mask=mask_bv) + b_dk += b_do * b_beta[:, None] + b_dbeta = tl.sum(b_do * b_k, axis=1) + b_dbeta += tl.sum(b_dv * b_v, axis=1) + b_v = None + + p_o = o + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + p_o2 = o2 + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + b_o = tl.load(p_o, mask=mask_bk) + b_o2 = tl.load(p_o2, mask=mask_bv) + + dA = -tl.dot(b_do.to(b_o.dtype), tl.trans(b_o), allow_tf32=False) + dA -= tl.dot(b_dv.to(b_o2.dtype), tl.trans(b_o2).to(b_o.dtype), + allow_tf32=False) + dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], dA, 0) + b_dv *= b_beta[:, None] + p_dv = dv + i_bh * T * V + (i_t * BT + tl.arange(0, BT)[:, None]) * V + tl.arange(0, BV)[None, :] + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_bv) + + b_dbeta += tl.sum(dA * tl.dot(b_k, tl.trans(b_k), allow_tf32=False), axis=1) + dA = dA * b_beta[:, None] + b_dk += tl.dot(tl.trans(dA.to(b_k.dtype)), b_k, allow_tf32=False) + b_dk += tl.dot(dA.to(b_k.dtype), b_k, allow_tf32=False) + p_dk = dk + i_bh * T * K + (i_t * BT + tl.arange(0, BT)[:, None]) * K + tl.arange(0, BK)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_bk) + p_dbeta = dbeta + i_bh * T + i_t * BT + tl.arange(0, BT) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), mask=mask_bt) + + +def fwd_prepare_wy_repr(k, v, beta, chunk_size): + B, H, T, K, V = *k.shape, v.shape[-1] + v_new = torch.empty_like(v) + o_cumdecay = torch.empty_like(k) + BT = chunk_size + NT = triton.cdiv(T, BT) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + fwd_prepare_wy_repr_kernel[(NT, B*H)]( + k, v, beta, o_cumdecay, v_new, + T, K, V, BT, BK, BV + ) + return o_cumdecay, v_new + + +def bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, chunk_size): + b, h, l, d_k = do.shape + d_v = v.shape[-1] +
BK = triton.next_power_of_2(d_k) + BV = triton.next_power_of_2(d_v) + c = chunk_size + BK = d_k + NT = triton.cdiv(l, c) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.zeros_like(beta) + bwd_prepare_wy_repr_kernel[(NT, b*h)]( + k, v, beta, + o_cumdecay, v_new, do, do2, + dk, dv, dbeta, + NT, d_k, d_v, l, chunk_size, BK, BV + ) + return dk, dv, dbeta + + +class WYRepresentationPrepration(torch.autograd.Function): + @contiguous + @autocast_custom_fwd + @staticmethod + def forward(ctx, k, v, beta, chunk_size): + o_cumdecay, v_new = fwd_prepare_wy_repr(k, v, beta, chunk_size) + ctx.chunk_size = chunk_size + ctx.save_for_backward(k.to(v), v, beta, o_cumdecay, v_new) + return o_cumdecay, v_new + + @contiguous + @autocast_custom_bwd + @staticmethod + def backward(ctx, do, do2): + k, v, beta, o_cumdecay, v_new = ctx.saved_tensors + dk, dv, dbeta = bwd_prepare_wy_repr(k, v, beta, o_cumdecay, v_new, do, do2, ctx.chunk_size) + return dk, dv, dbeta, None + + +prepare_wy_repr = WYRepresentationPrepration.apply + + +def naive(k, v, beta, chunk_size): + l_org = k.shape[2] + l_new = triton.next_power_of_2(l_org) + # pad k, v, beta + k = torch.cat([k, torch.zeros_like(k)[:, :, :l_new-l_org, :]], dim=2) + v = torch.cat([v, torch.zeros_like(v)[:, :, :l_new-l_org, :]], dim=2) + beta = torch.cat([beta, torch.zeros_like(beta)[:, :, :l_new-l_org]], dim=2) + + k, v = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), (k, v)) + # k = torch.nn.functional.normalize(k, dim=-1, p=2) + beta = rearrange(beta, 'b h (n c) -> b h n c', c=chunk_size) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), diagonal=0) + k_beta = k * beta[..., None] + v = v * beta[..., None] + attn = (k @ k.transpose(-1, -2)).masked_fill_(mask, 0) + attn = attn * beta[..., None] + x = attn @ v + + o = torch.zeros_like(k) + o2 = torch.zeros_like(v) + + o[..., 0, :] = k_beta[..., 0, :].clone() + o2[..., 0, :] = x[..., 0, :].clone() + for i in range(1, chunk_size): + o_i = (o[..., :i, :]).clone() + o[..., i, :] = -(attn[..., i, :i, None] * o_i).sum(3) + k_beta[..., i, :] + o2_i = (o2[..., :i, :]).clone() + o2[..., i, :] = -(attn[..., i, :i, None] * o2_i).sum(3) + x[..., i, :] + return map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d')[:, :, :l_org], (o, v-o2)) + + +if __name__ == "__main__": + torch.set_default_dtype(torch.bfloat16) + seq_len = 2048 + b = 4 + h = 8 + k = torch.nn.functional.normalize(torch.randn(b, h, seq_len, 256), dim=-1, p=2) + v = torch.randn(b, h, seq_len, 256) + beta = torch.rand(b, h, seq_len).sigmoid() + require_grad = True + k, v, beta = map(lambda x: x.cuda().requires_grad_(require_grad), (k, v, beta)) + do = torch.rand_like(k) + do2 = torch.rand_like(v) + + print("Start warmup.") + o1, o2 = prepare_wy_repr(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + o3, o4 = prepare_wy_repr2(k, v, beta, 32) + # (o1 * do + o2 * do2).sum().backward() + print((o1 - o3).abs().max()) + print((o2 - o4).abs().max()) + + for i in range(30): + o1, o2 = prepare_wy_repr(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + o1, o2 = prepare_wy_repr2(k, v, beta, 32) + (o1 * do + o2 * do2).sum().backward() + + print("Done warmup.") + + import time + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = prepare_wy_repr(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) + + torch.cuda.synchronize() + start = time.time() + + for i in range(200): + o1, o2 = 
prepare_wy_repr2(k, v, beta, 64) + (o1 * do + o2 * do2).sum().backward() + + torch.cuda.synchronize() + print(time.time() - start) diff --git a/fla2/ops/rebased/__init__.py b/fla2/ops/rebased/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec6a0cb31f7f635aa528cad753d5e19196a2028 --- /dev/null +++ b/fla2/ops/rebased/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_rebased + +__all__ = [ + 'parallel_rebased' +] diff --git a/fla2/ops/rebased/__pycache__/__init__.cpython-312.pyc b/fla2/ops/rebased/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6db79130931b4134ead3f3bbe3120c0606ebe6f3 Binary files /dev/null and b/fla2/ops/rebased/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/rebased/__pycache__/__init__.cpython-38.pyc b/fla2/ops/rebased/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..651306136350d13f1d33d25e03b9359972bf0d54 Binary files /dev/null and b/fla2/ops/rebased/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/rebased/__pycache__/__init__.cpython-39.pyc b/fla2/ops/rebased/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39614e3b318ff3bd9395e979b4fab35d70661ef3 Binary files /dev/null and b/fla2/ops/rebased/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/rebased/__pycache__/parallel.cpython-312.pyc b/fla2/ops/rebased/__pycache__/parallel.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2bbfdb2b17fe78683f579218d26ad0142fc8b7a Binary files /dev/null and b/fla2/ops/rebased/__pycache__/parallel.cpython-312.pyc differ diff --git a/fla2/ops/rebased/__pycache__/parallel.cpython-38.pyc b/fla2/ops/rebased/__pycache__/parallel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edfdec554c79cb15531e5f52dedd7e30e81a935f Binary files /dev/null and b/fla2/ops/rebased/__pycache__/parallel.cpython-38.pyc differ diff --git a/fla2/ops/rebased/__pycache__/parallel.cpython-39.pyc b/fla2/ops/rebased/__pycache__/parallel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d06dee2f858a5c83fef4921556d441f1303e8e7 Binary files /dev/null and b/fla2/ops/rebased/__pycache__/parallel.cpython-39.pyc differ diff --git a/fla2/ops/rebased/naive.py b/fla2/ops/rebased/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..e9436a0802c964485354082dcc9fbcd437e5d7f7 --- /dev/null +++ b/fla2/ops/rebased/naive.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +import torch + +from fla.ops.rebased.parallel import parallel_rebased + + +def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True): + if use_scale: + q = q * (q.shape[-1] ** -0.5) + attn = q @ k.transpose(-2, -1) + attn = (attn ** 2) + attn.masked_fill_(~torch.tril(torch.ones( + q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o + + +if __name__ == "__main__": + B = 4 + H = 4 + L = 128 + # D = 15 + dtype = torch.float32 + q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) + v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) + + do = torch.randn_like(v).cuda() + ref = naive_parallel_rebased(q, k, v, True, True) + ref.backward(do, 
retain_graph=True) + ref_dq, q.grad = q.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dv, v.grad = v.grad.clone(), None + + tri = parallel_rebased(q, k, v, 1e-6, True, True) + tri.backward(do, retain_graph=True) + tri_dq, q.grad = q.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dv, v.grad = v.grad.clone(), None + print((ref-tri).abs().max()) + print((ref_dq-tri_dq).abs().max()) + print((ref_dk-tri_dk).abs().max()) + print((ref_dv-tri_dv).abs().max()) diff --git a/fla2/ops/rebased/parallel.py b/fla2/ops/rebased/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb8541f20761cc91f54974e9b92755350ba8aca --- /dev/null +++ b/fla2/ops/rebased/parallel.py @@ -0,0 +1,428 @@ + +# -*- coding: utf-8 -*- + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + +# Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models +# https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py + + +@triton.jit +def parallel_rebased_fwd_kernel( + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + o, # output [B, H, L, D_head_V] + z, # normalizer [B, H, L] + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + scale, # D_head_K ** -0.5 + B, # batch size + H, # H + T, # T + K: tl.constexpr, # D_head_K + V: tl.constexpr, # D_head_V + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension +): + # i_c: chunk index. 
used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), + mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_rebased_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + i_h, + q, + k, + v, + do, + dz, + dq, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_q = (b_q * scale).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), + (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T 
+ i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), + (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype), + b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), + (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), + (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load( + p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros( + [BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \ + scale # [BTL, BTS] + b_s2 = b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, 
BTS): + p_q = tl.make_block_ptr( + q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr( + do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), + tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, + (T, K), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, + (T, V), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_rebased_bwd_kernel( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_rebased_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, scale, + B=B, H=H, T=T, K=K, V=V, BTL=BTL, BTS=BTS, BK=BK, BV=BV + ) + tl.debug_barrier() + _parallel_rebased_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, + scale, + B=B, H=H, T=T, K=K, V=V, BTL=BTL, BTS=BTS, BK=BK, BV=BV + ) + + +class ParallelBasedFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not." 
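Note: as a sanity check for these kernels, the computation they parallelize admits a simple quadratic-time reference. The sketch below is a hypothetical helper (not part of this diff; the name `naive_rebased_reference` is invented), assuming the same squared-dot-product feature map and the row-sum normalizer `z` accumulated above. On short sequences it should agree with `parallel_rebased(q, k, v)` (defined below) up to floating-point error.

import torch

def naive_rebased_reference(q, k, v, eps=1e-5):
    # O(T^2) reference: causal scores are the *squared* scaled dot products;
    # z is the row-wise normalizer that the kernels accumulate alongside o.
    B, H, T, K = q.shape
    scale = K ** -0.5
    s = torch.einsum('bhik,bhjk->bhij', q * scale, k) ** 2
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    s = s.masked_fill(~mask, 0)
    z = s.sum(-1)
    o = torch.einsum('bhij,bhjv->bhiv', s, v)
    return o / (z[..., None] + eps)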
+ + o = torch.empty(NK, B, H, T, V, device=q.device) + z = torch.empty(NK, B, H, T, device=q.device) + parallel_rebased_fwd_kernel[grid]( + q, k, v, o, z, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + + parallel_rebased_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if use_scale: + scale = q.shape[-1] ** -0.5 + else: + scale = 1 + o, z = triton_parallel_based(q, k, v, scale) + if return_both: + return o, z + if use_normalize: + o = o / (z[..., None] + eps) + else: + o = o + return o.to(q.dtype) diff --git a/fla2/ops/retention/__init__.py b/fla2/ops/retention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f29d7fbf5f36c7a2ba6a3b8c6bfa9f7ea19096 --- /dev/null +++ b/fla2/ops/retention/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_retention +from .chunk_fuse import fused_chunk_retention +from .parallel import parallel_retention +from .recurrent_fuse import fused_recurrent_retention + +__all__ = [ + 'chunk_retention', + 'fused_chunk_retention', + 'parallel_retention', + 'fused_recurrent_retention' +] diff --git a/fla2/ops/retention/__pycache__/__init__.cpython-312.pyc b/fla2/ops/retention/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..454882497ad4f48720b3fa3096109af0c60a6aa7 Binary files /dev/null and b/fla2/ops/retention/__pycache__/__init__.cpython-312.pyc differ diff --git a/fla2/ops/retention/__pycache__/__init__.cpython-38.pyc b/fla2/ops/retention/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47494dd8b581d96e74d59703468d25990a36056b Binary files /dev/null and b/fla2/ops/retention/__pycache__/__init__.cpython-38.pyc differ diff --git a/fla2/ops/retention/__pycache__/__init__.cpython-39.pyc b/fla2/ops/retention/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f68f499f0109e08c7d2099adf711a7bdd33e4f10 Binary files /dev/null and 
b/fla2/ops/retention/__pycache__/__init__.cpython-39.pyc differ diff --git a/fla2/ops/retention/__pycache__/chunk.cpython-312.pyc b/fla2/ops/retention/__pycache__/chunk.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..996d3989c6e9a808f49809cabcc11c537441463e Binary files /dev/null and b/fla2/ops/retention/__pycache__/chunk.cpython-312.pyc differ diff --git a/fla2/ops/retention/__pycache__/chunk.cpython-38.pyc b/fla2/ops/retention/__pycache__/chunk.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3ada5d16b7bf86fcf6f970a18d8034f7d62524 Binary files /dev/null and b/fla2/ops/retention/__pycache__/chunk.cpython-38.pyc differ diff --git a/fla2/ops/retention/__pycache__/chunk.cpython-39.pyc b/fla2/ops/retention/__pycache__/chunk.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2024cf7c45bf94112741d9b7c70726d4a52f3dcb Binary files /dev/null and b/fla2/ops/retention/__pycache__/chunk.cpython-39.pyc differ diff --git a/fla2/ops/retention/chunk.py b/fla2/ops/retention/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5e6d42e88787e4d3ee882f22d8eb229443ba41 --- /dev/null +++ b/fla2/ops/retention/chunk.py @@ -0,0 +1,438 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_retention_fwd_kernel_h( + k, + v, + h, + h0, # initial state of the chunk [B, H, K, V] + ht, # final state of the chunk [B, H, K, V] + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BK, BV] + if i_t == NT - 1 and (T % BT) != 0: + d_b = tl.math.exp2((T % BT) * b_b) + d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b) + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, 
b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_retention_fwd_kernel_o( + q, + k, + v, + h, + o, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_i = tl.math.exp2((o_i + 1) * b_b) + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_s = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False) + b_s += tl.dot(b_q, b_k, allow_tf32=False) + + b_s *= d_s + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale + p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_retention_bwd_kernel_dh( + q, + do, + dh, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b) + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False) + + +@triton.autotune( + configs=[ + 
triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=["BT", "BK", "BV"], +) +@triton.jit +def chunk_retention_bwd_kernel_dqkv( + q, + k, + v, + h, + do, + dh, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + s_h_h, + s_h_t, + scale, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + n_bh = tl.num_programs(2) + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + o_i = tl.arange(0, BT) + d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + d_q = (d_q * scale).to(d_q.dtype) + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_ds = (b_ds * d_s).to(b_q.dtype) + # [BT, BK] + b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False) + b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False)) + + p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h_fn(k, v, BT, initial_state, output_final_state): + B, H, T, K, V = *k.shape, v.shape[-1] + final_state = None + if output_final_state: + final_state = k.new_empty(B, H, K, V, dtype=torch.float32) + + BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V)) + NT, NK, NV = 
triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + h = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_retention_fwd_kernel_h[grid]( + k, v, h, initial_state, final_state, + k.stride(1), k.stride(2), k.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state + ) + return h, final_state + + +def chunk_fwd_o_fn(h, q, k, v, BT, scale): + B, H, T, K, V = *k.shape, v.shape[-1] + o = torch.empty_like(v) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) + grid = (NV, NT, B * H) + chunk_retention_fwd_kernel_o[grid]( + q, k, v, h, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV + ) + return o + + +def chunk_bwd_dh_fn(do, q, k, v, BT, scale): + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 64 + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = k.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_retention_bwd_kernel_dh[grid]( + q, do, dh, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + dh.stride(1), dh.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT + ) + return dh + + +def chunk_bwd_dqkv_fn(do, q, k, v, h, dh, scale): + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 64 + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + NT, NK = triton.cdiv(T, BT), triton.cdiv(K, BK) + grid = (NK, NT, B * H) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = v.new_empty(NK, *v.shape) + chunk_retention_bwd_kernel_dqkv[grid]( + q, k, v, h, do, dh, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + h.stride(1), h.stride(2), + scale, + H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT + ) + dv = dv.sum(0) + return dq, dk, dv + + +class ChunkRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, initial_state, output_final_state, scale, checkpoint_level): + BT = 64 + h, final_state = chunk_fwd_h_fn(k, v, BT, initial_state, output_final_state) + o = chunk_fwd_o_fn(h, q, k, v, BT, scale) + if checkpoint_level == 1: + h = None + ctx.save_for_backward(q, k, v, h, initial_state) + ctx.BT, ctx.scale = BT, scale + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, d_ht=None): + BT, scale = ctx.BT, ctx.scale + q, k, v, h, initial_state = ctx.saved_tensors + if h is None: + h, _ = chunk_fwd_h_fn(k, v, BT, initial_state, False) + dh = chunk_bwd_dh_fn(do, q, k, v, BT, scale) + dq, dk, dv = chunk_bwd_dqkv_fn(do, q, k, v, h, dh, scale) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None, None + + +def chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + scale: float = None, + checkpoint_level: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + scale (Optional[int]): + Scale factor 
for the RetNet attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `(B, H, K, V)`. Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
+        checkpoint_level (Optional[int]):
+            Checkpointing level; higher values save more memory at the cost of more recomputation during the backward pass.
+            Default: `1` (recommended):
+            - Level `0`: no memory saved, no recomputation.
+            - Level `1`: recompute the chunk-level hidden state `h` during the backward pass.
+    """
+    assert checkpoint_level in [0, 1], "checkpoint_level must be 0 or 1"
+    assert q.dim() == k.dim() == v.dim() == 4, "q, k, v must have 4 dimensions (b, h, l, d)"
+    assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype"
+    if scale is None:
+        scale = q.size(-1) ** -0.5
+    o, final_state = ChunkRetentionFunction.apply(
+        q, k, v, initial_state, output_final_state, scale, checkpoint_level)
+    return o, final_state
diff --git a/fla2/ops/retention/chunk_fuse.py b/fla2/ops/retention/chunk_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca98bfe97bf46407da9756d8eb8a91db114a44be
--- /dev/null
+++ b/fla2/ops/retention/chunk_fuse.py
@@ -0,0 +1,327 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023, Yu Zhang, Songlin Yang
+
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+from packaging import version
+
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
+
+# on-the-fly computation without materializing hidden states into HBM
+
+
+@triton.jit
+def fused_chunk_retention_fwd_kernel(
+    q,  # query [B, H, L, K]
+    k,  # key [B, H, L, K]
+    v,  # value [B, H, L, V]
+    o,  # output [B, H, L, V]
+    h0,  # initial state of the chunk [B, H, K, V]
+    ht,  # final state of the chunk [B, H, K, V]
+    s_qk_h,  # stride size: L * K
+    s_qk_t,  # stride size: K
+    s_qk_d,  # stride size: 1
+    s_vo_h,  # stride size: L * V
+    s_vo_t,  # stride size: V
+    s_vo_d,  # stride size: 1
+    scale,  # K ** -0.5
+    B: tl.constexpr,  # batch size
+    H: tl.constexpr,  # n_heads
+    T: tl.constexpr,  # seq_len
+    K: tl.constexpr,  # K
+    V: tl.constexpr,  # V
+    BT: tl.constexpr,  # BLOCK SIZE along the sequence dimension, a.k.a.
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + + # d_b: overall decay for the entire chunk + # d_o: cumulative decay from the start of the chunk + # d_h: cumulative decay from the end of the chunk + d_b, d_o, d_h = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + NT = tl.cdiv(T, BT) + for i in range(0, NT): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) * d_o[:, None] + if i == NT - 1 and (T % BT) != 0: + d_b = tl.math.exp2((T % BT) * b_b) + d_h = tl.math.exp2(((T % BT) - o_i - 1) * b_b) + b_h = d_b * b_h + tl.dot(b_k, (b_v * d_h[:, None]).to(b_k.dtype), allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_retention_bwd_kernel( + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + + h0, # initial state of the chunk [B, H, K, V] + + s_qk_h, # stride size: L * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * V + s_vo_t, # stride size: V + s_vo_d, # stride size: 1 + scale, # K ** -0.5 + B: tl.constexpr, # B + H: tl.constexpr, # H + T: tl.constexpr, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + o_i = tl.arange(0, BT) + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + d_q, d_k = tl.math.exp2((o_i+1) * b_b) * scale, tl.math.exp2((BT - o_i - 1) * b_b) + d_b = tl.math.exp2(BT * b_b) + + m_s = o_i[:, None] >= o_i[None, :] + d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dd = (b_do * d_q[:, None]).to(b_do.dtype) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = (b_ds * d_s).to(b_k.dtype) + # [BT, K] + b_dq = tl.dot(b_ds, b_k, allow_tf32=False) + # [V, K] + if CHECK and i == 0: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_dd, b_h.to(b_k.dtype), allow_tf32=False) + b_h = d_b * b_h + tl.dot((b_v * d_k[None, :]).to(b_k.dtype), b_k, allow_tf32=False) + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + d_s = tl.trans(d_s) + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * s_qk_h, (T, K), (s_qk_t, s_qk_d), (T - i*BT, i_k*BK), (BT, BK), 
(1, 0))
+        p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
+        # [BK, BT]
+        b_q = tl.load(p_q, boundary_check=(0, 1))
+        # [BT, BK]
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        # [BT, BV]
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_dd = (b_do * d_q[:, None]).to(b_do.dtype)
+
+        # [BT, BT]
+        b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
+        b_ds = (b_ds * d_s).to(b_k.dtype)
+
+        # [BT, BT]
+        b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s
+        # [BT, BK]
+        b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
+        # [BT, BV]
+        b_dv = tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
+        if CHECK and i == 1:
+            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
+            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
+            b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
+        else:
+            b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) * d_k[:, None]
+            b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) * d_k[:, None]
+            b_dh = d_b * b_dh + tl.dot(b_q, b_dd, allow_tf32=False)
+
+        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+class FusedChunkRetentionFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    @autocast_custom_fwd
+    def forward(ctx, q, k, v, initial_state, output_final_state):
+        B, H, T, K, V = *k.shape, v.shape[-1]
+
+        scale = K ** -0.5
+        BT = 64
+        BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
+        NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+        num_stages = 1
+        num_warps = 4
+
+        o = q.new_empty(NK, B, H, T, V)
+
+        if output_final_state:
+            final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)
+        else:
+            final_state = None
+        # the bug still exists even with Triton 2.2 on H100 GPUs,
+        # so the initial condition checks are always enabled
+        CHECK = True
+        if version.parse(triton.__version__) < version.parse('2.2.0'):
+            import warnings
+            warnings.warn(
+                "Triton<2.2.0 detected; this version is known to have compiler issues "
+                "(see https://github.com/openai/triton/issues/2852) that can cause significant precision loss. "
+                "We've added initial condition checks to work around this, at some cost in speed. "
+                "For optimal performance, it is recommended to install Triton>=2.2.0 if possible."
+ ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_retention_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = K ** -0.5 + + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None + + +def fused_chunk_retention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + o, final_state = FusedChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state) + return o, final_state diff --git a/fla2/ops/retention/naive.py b/fla2/ops/retention/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..15611bf649779d2d956d2ab390b7d72dbb12201d --- /dev/null +++ b/fla2/ops/retention/naive.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- + +import torch + + +def naive_retention(q, k, v): + orig_type = q.dtype + q, k, v = q.float(), k.float(), v.float() + _, n_heads, seq_len, d_head = q.shape + s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. 
- q.new_tensor(range(n_heads), dtype=torch.float))).log2() + n = q.new_tensor(range(seq_len), dtype=torch.float) + n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) + s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) + o = torch.einsum('bhqk,bhkd->bhqd', s, v) + return o.to(orig_type) diff --git a/fla2/ops/retention/parallel.py b/fla2/ops/retention/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a62a8d9c785d855bc772895adf11b71903baf6 --- /dev/null +++ b/fla2/ops/retention/parallel.py @@ -0,0 +1,354 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous + + +@triton.jit +def parallel_retention_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + o, # output [B, H, L, V] + s_qk_h, # stride size: L * K + s_qk_t, # stride size: K + s_qk_d, # stride size: 1 + s_vo_h, # stride size: L * V + s_vo_t, # stride size: V + s_vo_d, # stride size: 1 + scale, # K ** -0.5 + B: tl.constexpr, # batch size + H: tl.constexpr, # H + T: tl.constexpr, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q + BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension +): + # i_c: chunk index. used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # cumulative decay from the end of the chunk + o_k = tl.arange(0, BTS) + d_h = tl.math.exp2((BTS - o_k) * b_b) + + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :] + # [BQ, BD] + b_o = b_o * tl.math.exp2(b_b * BTS) + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + d_q = tl.math.exp2(tl.arange(0, BTL) * b_b) + b_o *= d_q[:, None] + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. 
masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (o_q[:, None] - o_k[None, :]) * b_b), 0) + b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _parallel_retention_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1)) + # decay rate given the head index + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # cumulative decay from the end of the chunk + d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b) + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :] + # [BQ, BD] + b_dq *= d_b + b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. 
masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (o_q[:, None] - o_k[None, :]) * b_b), 0) + b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale + # [BTL, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_retention_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, + s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # no overlap. no need for mask. + b_b = tl.math.log2(1 - tl.math.pow(2, -5 - i_h * 1.0)) + # overall decay rate for an entire block + d_b = tl.math.exp2(b_b * BTS) + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32) + d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b) + b_kd = (b_k * d_h[:, None]).to(b_k.dtype) + d_q = tl.math.exp2(tl.arange(0, BTS) * b_b) + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS] + b_do = (b_do * d_q[None, :]).to(b_do.dtype) + + b_dv *= d_b + b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS] + b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + + b_dk *= d_b + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + b_dk *= d_h[:, None] * scale + b_dv *= scale + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (V, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + d_s = tl.where(m_s, tl.math.exp2( + (-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale + b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s + # [BK, BD] + b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + o_q += BTS + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h, (T, K), + (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), 
(BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h, (T, V), + (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def parallel_retention_bwd_kernel( + q, + k, + v, + do, + dq, + dk, + dv, + s_qk_h, + s_qk_t, + s_qk_d, + s_vo_h, + s_vo_t, + s_vo_d, + scale, + B: tl.constexpr, + H: tl.constexpr, + T: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + i_h = i_bh % H + _parallel_retention_bwd_dq( + i_bh, i_c, i_k, i_v, i_h, + k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, scale, + B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV + ) + tl.debug_barrier() + _parallel_retention_bwd_dkv( + i_bh, i_c, i_k, i_v, i_h, + q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, + s_vo_t, s_vo_d, scale, + B, H, T, K, V, + BTL, BTS, BK, BV + ) + + +class ParallelRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 3 if K <= 64 else 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + scale = K ** -0.5 + o = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + parallel_retention_fwd_kernel[grid]( + q, k, v, o, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + return o.sum(0).to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v = ctx.saved_tensors + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 3 if K <= 64 else 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + scale = K ** -0.5 + + dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + + parallel_retention_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, + q.stride(1), q.stride(2), q.stride(3), + v.stride(1), v.stride(2), v.stride(3), + scale, + B=B, H=H, T=T, K=K, V=V, + BTL=BTL, BTS=BTS, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype) + + +parallel_retention = ParallelRetentionFunction.apply diff --git a/fla2/ops/retention/recurrent_fuse.py b/fla2/ops/retention/recurrent_fuse.py new file mode 100644 index 0000000000000000000000000000000000000000..d529ea6e51ff98ec112ab12a8d7ad9bb2d77cb60 --- /dev/null +++ b/fla2/ops/retention/recurrent_fuse.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023, Yu Zhang, Songlin Yang + +from typing import Tuple + +import torch 
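Note: the kernels in recurrent_fuse.py, which begins here, implement the per-head retention recurrence h_t = gamma_h * h_{t-1} + k_t^T v_t with o_t = (scale * q_t) h_t, where gamma_h = 1 - 2^(-5-h). The following hypothetical PyTorch sketch (the name `naive_recurrent_retention_reference` is invented; a zero initial state is assumed) can be used to check `fused_recurrent_retention` on small inputs:

import torch

def naive_recurrent_retention_reference(q, k, v):
    # Per-head decay gamma_h = 1 - 2**(-5 - h); the state h is a (K, V)
    # outer-product accumulator, decayed once per time step.
    B, H, T, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5
    gamma = 1 - torch.pow(2.0, -5 - torch.arange(H, dtype=torch.float32, device=q.device))
    h = q.new_zeros(B, H, K, V, dtype=torch.float32)
    o = torch.empty_like(v)
    for t in range(T):
        h = gamma[None, :, None, None] * h + k[:, :, t, :, None].float() * v[:, :, t, None, :].float()
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q[:, :, t].float() * scale, h).to(v.dtype)
    return o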
+import triton
+import triton.language as tl
+
+from fla.utils import contiguous
+
+# on-the-fly computation without materializing hidden states into HBM
+
+
+@triton.jit
+def fused_recurrent_retention_fwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    q,  # query [B, H, L, D_head_K]
+    k,  # key [B, H, L, D_head_K]
+    v,  # value [B, H, L, D_head_V]
+    o,  # output [B, H, L, D_head_V]
+    initial_state,
+    final_state,  # final hidden state [B, H, D_head_K, D_head_V]
+
+    s_qk_h,  # stride size: L * D_head_K
+    s_qk_t,  # stride size: D_head_K
+    s_qk_d,  # stride size: 1
+
+    s_vo_h,  # stride size: L * D_head_V
+    s_vo_t,  # stride size: D_head_V
+    s_vo_d,  # stride size: 1
+
+    B,  # batch size
+    H,  # n_heads
+    T,  # seq_len
+    scale,  # D_head_K ** -0.5
+    BK: tl.constexpr,  # BLOCK SIZE along the K dimension
+    BV: tl.constexpr,  # BLOCK SIZE along the V dimension
+    DK: tl.constexpr,  # D_head_K
+    DV: tl.constexpr,  # D_head_V
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
+):
+    # indices
+    i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_h = i_bh % H
+
+    # decay rate given the head index
+    b_b = (1 - tl.math.pow(2, -5 - i_h * 1.0))
+
+    p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
+    p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)
+    p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)
+    p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)
+
+    mask_bk = (i_k * BK + tl.arange(0, BK)) < DK
+    mask_bv = (i_v * BV + tl.arange(0, BV)) < DV
+    mask_kv = mask_bk[None, :] & mask_bv[:, None]
+
+    h = tl.zeros([BV, BK], dtype=tl.float32)
+
+    if USE_INITIAL_STATE:
+        p_init_s = initial_state + i_bh * DK * DV + \
+            (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])
+        h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)
+
+    for _ in range(0, T):
+        _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
+        _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
+        _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
+
+        h = b_b * h + _k[None, :] * _v[:, None]
+        _o = h * _q[None, :]
+        _o = tl.sum(_o, axis=1)
+        tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)
+
+        p_q += DK
+        p_k += DK
+        p_o += DV
+        p_v += DV
+
+    if STORE_FINAL_STATE:
+        p_final_s = final_state + i_bh * DK * DV + \
+            (i_k * BK + tl.arange(0, BK)[None, :]) * DV + (i_v * BV + tl.arange(0, BV)[:, None])
+        tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)
+
+
+# Similar to Algorithm 1 of https://arxiv.org/abs/2006.16236
+@triton.jit
+def fused_recurrent_retention_bwd_kernel(
+    # B: batch_size, H: n_heads, T: seq_len, D: d_head
+    # NV: number of split in the V dimension.
NK: number of split in the K dimension + q, # query [B, H, L, D_head_K] + k, # key [B, H, L, D_head_V] + v, # value [B, H, L, D_head_V] + + do, # gradient of output [B, H, L, D_head_V] + dq, # gradient of query [NV, B, H, L, D_head_K] + dk, # gradient of key [NV, B, H, L, D_head_K] + dv, # gradient of value [NK, B, H, L, D_head_V] + + # initial hidden state initialization [B, H, D_head_K, D_head_V] + initial_state, + + s_qk_h, # stride size: L * D_head_K + s_qk_t, # stride size: D_head_K + s_qk_d, # stride size: 1 + + s_vo_h, # stride size: L * D_head_V + s_vo_t, # stride size: D_head_V + s_vo_d, # stride size: 1 + + B, # batch_size + H, # n_heads + T, # seq_len + scale, # D_head_K ** -0.5 + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + DK: tl.constexpr, # D_head_K + DV: tl.constexpr, # D_head_V + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h = i_bh % H + + b_b = 1 - tl.math.pow(2, -5 - i_h * 1.0) + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < DK + mask_bv = i_v * BV + tl.arange(0, BV) < DV + + h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_init_s = initial_state + i_bh * DK * DV + \ + (i_k * BK + tl.arange(0, BK)[:, None]) * \ + DV + (i_v * BV + tl.arange(0, BV)[None, :]) + h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32) + + for i in range(0, T): + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + h = b_b * h + _k[:, None] * _v[None, :] + _d_q = h * _do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += DK + p_do += DV + p_v += DV + p_dq += DK + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK + p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV + p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \ + BK + tl.arange(0, BK) + (T - 1) * DK + p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \ + BV + tl.arange(0, BV) + (T - 1) * DV + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += _q[:, None] * _do[None, :] + d_k = tl.sum(d_h * _v[None, :], axis=1) + d_v = tl.sum(d_h * _k[:, None], axis=0) + + d_h *= b_b + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= DV + p_q -= DK + p_k -= DK + p_v -= DV + p_dk -= DK + p_dv -= DV + + +class FusedRecurrentRetentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + def forward(ctx, q, k, v, initial_state=None, output_final_state=False): + 
+
+class FusedRecurrentRetentionFunction(torch.autograd.Function):
+
+    @staticmethod
+    @contiguous
+    def forward(ctx, q, k, v, initial_state=None, output_final_state=False):
+        batch_size, n_heads, seq_len, d_head_qk = q.shape
+        d_head_v = v.shape[-1]
+
+        scale = d_head_qk ** -0.5
+        BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
+        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+        num_stages = 1
+        num_warps = 1
+
+        o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
+
+        if output_final_state:
+            final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
+        else:
+            final_state = None
+
+        grid = (NV, NK, batch_size * n_heads)
+        fused_recurrent_retention_fwd_kernel[grid](
+            q, k, v, o, initial_state, final_state,
+            q.stride(1), q.stride(2), q.stride(3),
+            v.stride(1), v.stride(2), v.stride(3),
+            batch_size, n_heads, seq_len, scale,
+            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            USE_INITIAL_STATE=initial_state is not None,
+            STORE_FINAL_STATE=final_state is not None
+        )
+
+        # reduce over the NK splits of the K dimension
+        o = o.sum(0)
+        ctx.save_for_backward(q, k, v, initial_state)
+        return o, final_state
+
+    @staticmethod
+    @contiguous
+    def backward(ctx, do, d_final_state=None):
+        q, k, v, initial_state = ctx.saved_tensors
+        batch_size, n_heads, seq_len, d_head_qk = q.shape
+        d_head_v = v.shape[-1]
+        scale = d_head_qk ** -0.5
+
+        BK, BV = min(d_head_qk, 32), min(d_head_v, 32)
+        NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
+        num_stages = 1
+        num_warps = 1
+
+        dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+        dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)
+        dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)
+        grid = (NV, NK, batch_size * n_heads)
+
+        fused_recurrent_retention_bwd_kernel[grid](
+            q, k, v, do, dq, dk, dv, initial_state,
+            q.stride(1), q.stride(2), q.stride(3),
+            v.stride(1), v.stride(2), v.stride(3),
+            batch_size, n_heads, seq_len, scale,
+            DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            USE_INITIAL_STATE=initial_state is not None
+        )
+        # reduce over the split dimensions
+        dq = dq.sum(0)
+        dk = dk.sum(0)
+        dv = dv.sum(0)
+        return dq, dk, dv, None, None
+
+
+# fused_recurrent_retention = FusedRecurrentRetentionFunction.apply
+
+def fused_recurrent_retention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if initial_state is not None:
+        initial_state = initial_state.detach()
+    o, final_state = FusedRecurrentRetentionFunction.apply(q, k, v, initial_state, output_final_state)
+    return o, final_state
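+
+
+# Editor's note: a hedged usage sketch (not part of the original file). It
+# assumes a CUDA device and the [B, H, T, D] layout documented above; the
+# final state of one chunk can seed the next chunk via `initial_state`.
+if __name__ == '__main__':
+    B, H, T, D = 2, 4, 128, 64
+    q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))
+    o1, state = fused_recurrent_retention(q, k, v, output_final_state=True)
+    # chunked/streaming processing: feed the final state back in
+    o2, state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=True)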
diff --git a/fla2/ops/rwkv4/__init__.py b/fla2/ops/rwkv4/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae23a00c1673d1b3f60611d781c66dc8c0e83095
--- /dev/null
+++ b/fla2/ops/rwkv4/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from .recurrent_fuse import fused_recurrent_rwkv4
+
+__all__ = [
+    'fused_recurrent_rwkv4'
+]
diff --git a/fla2/ops/rwkv4/recurrent_fuse.py b/fla2/ops/rwkv4/recurrent_fuse.py
new file mode 100644
index 0000000000000000000000000000000000000000..3232087af98dd9dd84957afdd709ec292956a809
--- /dev/null
+++ b/fla2/ops/rwkv4/recurrent_fuse.py
@@ -0,0 +1,484 @@
+# -*- coding: utf-8 -*-
+# adapted from https://github.com/codekansas/rwkv
+
+from typing import Any, cast
+
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+from torch.autograd.function import Function, FunctionCtx, once_differentiable
+
+
+def get_block_size_c(chans: int) -> int:
+    if chans < 32:
+        return 32
+    if chans < 64:
+        return 64
+    return 128
+
+
+@triton.jit
+def fused_recurrent_rwkv4_forward_kernel(
+    # W
+    w_ptr,
+    w_s_c,
+    # U
+    u_ptr,
+    u_s_c,
+    # K
+    k_ptr,
+    k_s_b,
+    k_s_t,
+    k_s_c,
+    # V
+    v_ptr,
+    v_s_b,
+    v_s_t,
+    v_s_c,
+    # State
+    state_ptr,
+    state_s_b,
+    state_s_abe,
+    state_s_c,
+    # WKV
+    wkv_ptr,
+    wkv_s_b,
+    wkv_s_t,
+    wkv_s_c,
+    # Output state
+    state_out_ptr,
+    state_out_s_b,
+    state_out_s_abe,
+    state_out_s_t,
+    state_out_s_c,
+    # Params
+    chans,
+    tsz,
+    BLOCK_SIZE_C: tl.constexpr,
+):
+    # Parallelize over the batch and channel-block dimensions.
+    b_idx = tl.program_id(0)
+    c_idx = tl.program_id(1)
+
+    cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
+    cmask = cs < chans
+
+    # Pointers to the batch (and possibly channel) for the input tensors.
+    k_ptr = k_ptr + b_idx * k_s_b
+    v_ptr = v_ptr + b_idx * v_s_b
+    alpha_ptr = state_ptr + b_idx * state_s_b
+    beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
+    eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
+
+    # Pointers to the batch (and possibly channel) for the output tensors.
+    wkv_ptr = wkv_ptr + b_idx * wkv_s_b
+    alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
+    beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
+    eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe
+
+    # Loads parameters.
+    alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
+    beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
+    eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
+    w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
+    u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)
+
+    for t in range(tsz):
+        kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
+        vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)
+
+        ukt = u + kt
+        tau = tl.maximum(ukt, eps)
+        e1a = tl.exp(eps - tau)
+        e2a = tl.exp(ukt - tau)
+        wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
+        tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)
+
+        w_eps = w + eps
+        eps = tl.maximum(w_eps, kt)
+        e1b = tl.exp(w_eps - eps)
+        e2b = tl.exp(kt - eps)
+        alpha = e1b * alpha + e2b * vt
+        beta = e1b * beta + e2b
+        tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
+        tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
+        tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)
+
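+
+# Editor's note: an illustrative, numerically *unstable* PyTorch version of the
+# WKV recurrence, to show what the kernel computes; the kernel's running
+# maximum `eps` exists precisely to keep these exponentials bounded. The name
+# `naive_rwkv4_forward` is an assumption, not part of this file.
+def naive_rwkv4_forward(w, u, k, v):
+    # w, u: [C] (w is the already-negated decay, i.e. w = -exp(w_raw)); k, v: [B, T, C]
+    B, T, C = k.shape
+    a = k.new_zeros(B, C)  # running sum of exp(k_s) * v_s, decayed by exp(w)
+    b = k.new_zeros(B, C)  # running sum of exp(k_s), decayed by exp(w)
+    out = torch.empty_like(v)
+    for t in range(T):
+        ek = torch.exp(u + k[:, t])
+        out[:, t] = (a + ek * v[:, t]) / (b + ek)
+        a = torch.exp(w) * a + torch.exp(k[:, t]) * v[:, t]
+        b = torch.exp(w) * b + torch.exp(k[:, t])
+    return out
+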
+
+def fused_recurrent_rwkv4_forward(
+    w: Tensor,
+    u: Tensor,
+    k: Tensor,
+    v: Tensor,
+    state: Tensor,
+) -> tuple[Tensor, Tensor]:
+    (bsz, tsz, chans) = k.shape
+
+    # New tensors to output.
+    wkvs = k.new_empty(bsz, tsz, chans)
+    state_out = k.new_empty(bsz, 3, tsz, chans)
+
+    # Constants.
+    block_size_c = get_block_size_c(chans)
+
+    def grid(meta: dict[str, Any]) -> tuple[int, ...]:
+        return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
+
+    fused_recurrent_rwkv4_forward_kernel[grid](
+        # W
+        w,
+        w.stride(0),
+        # U
+        u,
+        u.stride(0),
+        # K
+        k,
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        # V
+        v,
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        # State
+        state,
+        state.stride(0),
+        state.stride(1),
+        state.stride(3),
+        # WKV
+        wkvs,
+        wkvs.stride(0),
+        wkvs.stride(1),
+        wkvs.stride(2),
+        # Output state
+        state_out,
+        state_out.stride(0),
+        state_out.stride(1),
+        state_out.stride(2),
+        state_out.stride(3),
+        # Params
+        chans,
+        tsz,
+        BLOCK_SIZE_C=block_size_c,
+    )
+
+    # Prepend the initial state so the backward pass can read state t-1.
+    state_out = torch.cat((state, state_out), dim=2)
+
+    return wkvs, state_out
+
+
+@triton.jit
+def fused_recurrent_rwkv4_backward_kernel(
+    # W
+    w_ptr,
+    w_s_c,
+    # U
+    u_ptr,
+    u_s_c,
+    # K
+    k_ptr,
+    k_s_b,
+    k_s_t,
+    k_s_c,
+    # V
+    v_ptr,
+    v_s_b,
+    v_s_t,
+    v_s_c,
+    # State
+    state_ptr,
+    state_s_b,
+    state_s_abe,
+    state_s_t,
+    state_s_c,
+    # WKV grad
+    gwkv_ptr,
+    gwkv_s_b,
+    gwkv_s_t,
+    gwkv_s_c,
+    # Output state grad
+    gstate_out_ptr,
+    gstate_out_s_b,
+    gstate_out_s_abe,
+    gstate_out_s_c,
+    # W grad
+    gw_ptr,
+    gw_s_c,
+    # U grad
+    gu_ptr,
+    gu_s_c,
+    # K grad
+    gk_ptr,
+    gk_s_b,
+    gk_s_t,
+    gk_s_c,
+    # V grad
+    gv_ptr,
+    gv_s_b,
+    gv_s_t,
+    gv_s_c,
+    # State grad
+    gstate_ptr,
+    gstate_s_b,
+    gstate_s_abe,
+    gstate_s_c,
+    # Params
+    tsz,
+    chans,
+    BLOCK_SIZE_C: tl.constexpr,
+):
+    # Parallelize over the batch and channel-block dimensions.
+    b_idx = tl.program_id(0)
+    c_idx = tl.program_id(1)
+
+    cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
+    cmask = cs < chans
+
+    # Pointers to the batch (and possibly channel) for the input tensors.
+    k_ptr = k_ptr + b_idx * k_s_b
+    v_ptr = v_ptr + b_idx * v_s_b
+    alpha_ptr = state_ptr + b_idx * state_s_b
+    beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
+    eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
+
+    # Pointers to the batch (and possibly channel) for the output tensors.
+    gk_ptr = gk_ptr + b_idx * gk_s_b
+    gv_ptr = gv_ptr + b_idx * gv_s_b
+
+    # Pointers to gradients which were received by the function.
+    gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
+    galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
+    gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
+    geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe
+
+    # Loads parameters.
+    galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
+    gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
+    geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
+    w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
+    u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)
+
+    # Gradient accumulators.
+    gw = tl.zeros_like(w)
+    gu = tl.zeros_like(u)
+
+    alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
+    beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
+    eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
+
+    for t in range(tsz):
+        tc = tsz - t - 1
+
+        kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
+        vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)
+
+        alpha_curr = alpha_prev
+        beta_curr = beta_prev
+        eps_curr = eps_prev
+
+        alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
+        beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
+        eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
+
+        ukt = u + kt
+        tau = tl.maximum(ukt, eps_prev)
+        e1 = tl.exp(eps_prev - tau)
+        e2 = tl.exp(ukt - tau)
+
+        euke = tl.exp(ukt + eps_prev - 2 * tau)
+
+        denom = e1 * beta_prev + e2
+        denom_sq = denom * denom
+
+        gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)
+
+        # Backpropagates wkv gradients.
+        guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
+        gu += guk
+        gk = guk
+        gv = gwkvt * e2 / denom
+
+        galpha_wkv = gwkvt * e1 / denom
+        gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
+        geps_wkv_denom = e1 * beta_prev + e2
+        geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)
+
+        e1 = tl.exp(w + eps_prev - eps_curr)
+        e2 = tl.exp(kt - eps_curr)
+
+        # Backpropagates alpha gradients.
+        galpha_we = galpha * e1 * alpha_prev
+        gw += galpha_we
+        gk += galpha * e2 * vt
+        gv += galpha * e2
+        geps += galpha * -alpha_curr
+
+        # Backpropagates beta gradients.
+        gbeta_we = gbeta * e1 * beta_prev
+        gw += gbeta_we
+        gk += gbeta * e2
+        geps += gbeta * -beta_curr
+
+        # Backpropagates epsilon gradients.
+        geps_mask = w + eps_prev > kt
+        geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
+        gw += geps_we
+        gk += tl.where(geps_mask, tl.zeros_like(geps), geps)
+
+        # Stores the gradients for k and v.
+        tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
+        tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)
+
+        # Computes the gradients for alpha, beta and epsilon to carry to step t-1.
+        galpha = galpha * e1 + galpha_wkv
+        gbeta = gbeta * e1 + gbeta_wkv
+        geps = galpha_we + gbeta_we + geps_we + geps_wkv
+
+    # Stores final gradients for alpha, beta and epsilon.
+    galpha_ptr = gstate_ptr + b_idx * gstate_s_b
+    gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
+    geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
+    tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
+    tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
+    tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)
+
+    # Stores final gradients for w and u (accumulated across the batch).
+    gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)
+    gw_temp += gw
+    tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)
+    gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)
+    gu_temp += gu
+    tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)
+
+
+def fused_recurrent_rwkv4_backward(
+    w: Tensor,
+    u: Tensor,
+    k: Tensor,
+    v: Tensor,
+    state: Tensor,
+    grad_wkv: Tensor,
+    grad_state: Tensor,
+) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    bsz, tsz, chans = k.shape
+
+    # New tensors to output.
+    gw = torch.zeros_like(w)
+    gu = torch.zeros_like(u)
+    gk = torch.empty_like(k)
+    gv = torch.empty_like(v)
+    gstate = k.new_empty(bsz, 3, 1, chans)
+
+    # Constants.
+    block_size_c = get_block_size_c(chans)
+
+    def grid(meta: dict[str, Any]) -> tuple[int, ...]:
+        return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
+
+    fused_recurrent_rwkv4_backward_kernel[grid](
+        # W
+        w,
+        w.stride(0),
+        # U
+        u,
+        u.stride(0),
+        # K
+        k,
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        # V
+        v,
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        # State
+        state,
+        state.stride(0),
+        state.stride(1),
+        state.stride(2),
+        state.stride(3),
+        # WKV grad
+        grad_wkv,
+        grad_wkv.stride(0),
+        grad_wkv.stride(1),
+        grad_wkv.stride(2),
+        # Output state grad
+        grad_state,
+        grad_state.stride(0),
+        grad_state.stride(1),
+        grad_state.stride(3),
+        # W grad
+        gw,
+        gw.stride(0),
+        # U grad
+        gu,
+        gu.stride(0),
+        # K grad
+        gk,
+        gk.stride(0),
+        gk.stride(1),
+        gk.stride(2),
+        # V grad
+        gv,
+        gv.stride(0),
+        gv.stride(1),
+        gv.stride(2),
+        # State grad
+        gstate,
+        gstate.stride(0),
+        gstate.stride(1),
+        gstate.stride(3),
+        # Params
+        tsz,
+        chans,
+        BLOCK_SIZE_C=block_size_c,
+    )
+
+    return gw, gu, gk, gv, gstate
+
+
+class FusedRecurrentRWKV4Function(Function):
+    @staticmethod
+    def forward(
+        ctx: FunctionCtx,
+        w: Tensor,
+        u: Tensor,
+        k: Tensor,
+        v: Tensor,
+        state: Tensor,
+    ) -> tuple[Tensor, Tensor]:
+        ctx.input_dtype = k.dtype
+
+        if (
+            w.device.type != "cuda"
+            or u.device.type != "cuda"
+            or k.device.type != "cuda"
+            or v.device.type != "cuda"
+        ):
+            raise ValueError(
+                "Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices."
+            )
+
+        w = -torch.exp(w.float().contiguous())
+        if k.dtype == torch.float16:
+            u = u.float()
+            k = k.float()
+            v = v.float()
+        u = u.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+        wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
+        ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
+        return wkv, state_out[:, :, -1:]
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+        w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
+        gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
+        return gw, gu, gk, gv, gstate
+
+
+def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
+    return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)
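+
+
+# Editor's note: a hedged usage sketch (not part of the original file). The
+# shapes follow the launchers above: w/u are per-channel parameter vectors,
+# k/v are [batch, time, channels], and the recurrent state packs
+# (alpha, beta, eps) as [batch, 3, 1, channels]; a zero state is a valid start.
+if __name__ == '__main__':
+    B, T, C = 2, 64, 128
+    w = torch.randn(C, device='cuda')  # decay logits; negated/exponentiated inside forward
+    u = torch.randn(C, device='cuda')
+    k = torch.randn(B, T, C, device='cuda')
+    v = torch.randn(B, T, C, device='cuda')
+    state = torch.zeros(B, 3, 1, C, device='cuda')
+    wkv, state = fused_recurrent_rwkv4(w, u, k, v, state)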
diff --git a/fla2/ops/rwkv6/__init__.py b/fla2/ops/rwkv6/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f9fe7ea317f30e1bd78f3a13914e9c8774bfff
--- /dev/null
+++ b/fla2/ops/rwkv6/__init__.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+from .chunk import chunk_rwkv6
+from .recurrent_fuse import fused_recurrent_rwkv6
+
+__all__ = [
+    'chunk_rwkv6',
+    'fused_recurrent_rwkv6'
+]
diff --git a/fla3/__init__.py b/fla3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b282dde8d9a76df4c565788e7c36469e379a5d0b
--- /dev/null
+++ b/fla3/__init__.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+
+from fla.layers import (
+    ABCAttention,
+    Attention,
+    BasedLinearAttention,
+    BitAttention,
+    DeltaNet,
+    GatedDeltaNet,
+    GatedDeltaProduct,
+    GatedLinearAttention,
+    GatedSlotAttention,
+    HGRN2Attention,
+    HGRNAttention,
+    LightNetAttention,
+    LinearAttention,
+    MultiScaleRetention,
+    NativeSparseAttention,
+    PaTHAttention,
+    ReBasedLinearAttention,
+    RWKV6Attention,
+    RWKV7Attention
+)
+from fla.models import (
+    ABCForCausalLM,
+    ABCModel,
+    BitNetForCausalLM,
+    BitNetModel,
+    DeltaNetForCausalLM,
+    DeltaNetModel,
+    GatedDeltaNetForCausalLM,
+    GatedDeltaNetModel,
+    GatedDeltaProductForCausalLM,
+    GatedDeltaProductModel,
+    GLAForCausalLM,
+    GLAModel,
+    GSAForCausalLM,
+    GSAModel,
+    HGRN2ForCausalLM,
+    HGRN2Model,
+    HGRNForCausalLM,
+    HGRNModel,
+    LightNetForCausalLM,
+    LightNetModel,
+    LinearAttentionForCausalLM,
+    LinearAttentionModel,
+    NSAForCausalLM,
+    NSAModel,
+    PaTHAttentionForCausalLM,
+    PaTHAttentionModel,
+    RetNetForCausalLM,
+    RetNetModel,
+    RWKV6ForCausalLM,
+    RWKV6Model,
+    RWKV7ForCausalLM,
+    RWKV7Model,
+    TransformerForCausalLM,
+    TransformerModel
+)
+
+__all__ = [
+    'ABCAttention',
+    'Attention',
+    'BasedLinearAttention',
+    'BitAttention',
+    'DeltaNet',
+    'GatedDeltaNet',
+    'GatedDeltaProduct',
+    'GatedLinearAttention',
+    'GatedSlotAttention',
+    'HGRNAttention',
+    'HGRN2Attention',
+    'LightNetAttention',
+    'LinearAttention',
+    'MultiScaleRetention',
+    'NativeSparseAttention',
+    'PaTHAttention',
+    'ReBasedLinearAttention',
+    'RWKV6Attention',
+    'RWKV7Attention',
+    'ABCForCausalLM',
+    'ABCModel',
+    'BitNetForCausalLM',
+    'BitNetModel',
+    'DeltaNetForCausalLM',
+    'DeltaNetModel',
+    'GatedDeltaNetForCausalLM',
+    'GatedDeltaNetModel',
+    'GatedDeltaProductForCausalLM',
+    'GatedDeltaProductModel',
+    'GLAForCausalLM',
+    'GLAModel',
+    'GSAForCausalLM',
+    'GSAModel',
+    'HGRNForCausalLM',
+    'HGRNModel',
+    'HGRN2ForCausalLM',
+    'HGRN2Model',
+    'LightNetForCausalLM',
+    'LightNetModel',
+    'LinearAttentionForCausalLM',
+    'LinearAttentionModel',
+    'NSAForCausalLM',
+    'NSAModel',
+    'PaTHAttentionForCausalLM',
+    'PaTHAttentionModel',
+    'RetNetForCausalLM',
+    'RetNetModel',
+    'RWKV6ForCausalLM',
+    'RWKV6Model',
+    'RWKV7ForCausalLM',
+    'RWKV7Model',
+    'TransformerForCausalLM',
+    'TransformerModel',
+]
+
+__version__ = '0.2.2'