from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from ...ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from ...ops.utils.op import exp
from ...utils import is_nvidia_hopper, use_cuda_graph

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
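

# Forward kernel of the chunked gated delta rule. Each program handles one
# (sequence, head) pair and one value block, materializing the recurrent state
# h for every chunk; the key dimension is tiled in 64-row blocks (hence
# "blockdim64"), which caps the supported head dimension at K <= 256.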
@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
        for BV in [16, 32, 64]
    ],
    key=['H', 'K', 'V', 'BT', 'USE_G'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    d,
    v_new,
    g,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
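
    # the running state, kept in fp32 and tiled into 64-row blocks along K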
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)
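
    # advance all pointers to the (sequence, head) pair handled by this program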
    h += (boh * H + i_h) * K*V
    v += (bos * H + i_h) * V
    k += (bos * H + i_h) * K
    d += (bos * H + i_h) * K
    if SAVE_NEW_VALUE:
        v_new += (bos * H + i_h) * V
    stride_v = H*V
    stride_h = H*K*V
    stride_k = H*K
    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * K*V
    if STORE_FINAL_STATE:
        ht = ht + i_nh * K*V
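
    # load the initial state, one 64-row block of K at a time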
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
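
    # iterate over chunks: store the state entering chunk i_t, then update it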
    for i_t in range(NT):
        p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
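
        # corrected values for this chunk: v_new = v - d @ h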
        p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_intermediate = tl.zeros([BT, BV], dtype=tl.float32)
        p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_d = tl.load(p_d, boundary_check=(0, 1))
        b_intermediate += tl.dot(b_d, b_h1.to(b_d.dtype))
        if K > 64:
            p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_intermediate += tl.dot(b_d, b_h2.to(b_d.dtype))
        if K > 128:
            p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_intermediate += tl.dot(b_d, b_h3.to(b_d.dtype))
        if K > 192:
            p_d = tl.make_block_ptr(d, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_d = tl.load(p_d, boundary_check=(0, 1))
            b_intermediate += tl.dot(b_d, b_h4.to(b_d.dtype))
        b_intermediate = -b_intermediate + tl.load(p_v, boundary_check=(0, 1))
        b_intermediate = b_intermediate.to(k.dtype.element_ty)
        if SAVE_NEW_VALUE:
            p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_v_new, b_intermediate, boundary_check=(0, 1))
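
        # decay the carried state by exp(g) at the last position of the chunk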
        if USE_G:
            last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last
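
        # delta-rule update: h += k^T @ (v - d @ h)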
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.dot(b_k, b_intermediate)
        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.dot(b_k, b_intermediate)
        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.dot(b_k, b_intermediate)
        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.dot(b_k, b_intermediate)
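
    # write the final state if requested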
    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
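

# Backward kernel: walks the chunks in reverse, materializing the state
# gradients dh and accumulating the corrected value gradients dv2 = dv + k @ dh.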
@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [4, 3, 2, 1]
        for BV in [64, 32, 16]
    ],
    key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
    q,
    k,
    d,
    g,
    dht,
    dh0,
    do,
    dh,
    dv,
    dv2,
    cu_seqlens,
    chunk_offsets,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
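
    # the running state gradient, kept in fp32 and tiled into 64-row blocks along K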
    b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
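
    # advance all pointers to the (sequence, head) pair handled by this program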
    dh += (boh * H + i_h) * K*V
    dv += (bos * H + i_h) * V
    dv2 += (bos * H + i_h) * V
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    d += (bos * H + i_h) * K
    do += (bos * H + i_h) * V
    stride_v = H*V
    stride_h = H*K*V
    stride_k = H*K
    if USE_INITIAL_STATE:
        dh0 += i_nh * K*V
    if USE_FINAL_STATE_GRADIENT:
        dht += i_nh * K*V
    if USE_FINAL_STATE_GRADIENT:
        p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_dh1 += tl.load(p_dht1, boundary_check=(0, 1))
        if K > 64:
            p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_dh2 += tl.load(p_dht2, boundary_check=(0, 1))
        if K > 128:
            p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_dh3 += tl.load(p_dht3, boundary_check=(0, 1))
        if K > 192:
            p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_dh4 += tl.load(p_dht4, boundary_check=(0, 1))
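
    # iterate over chunks in reverse: store dh entering chunk i_t, then propagate it backward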
    for i_t in range(NT - 1, -1, -1):
        p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1))

        if USE_G:
            last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
            b_g_last = exp(b_g_last)
        else:
            b_g_last = None
            last_idx = None
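
        # dv2 = dv + k @ dh, accumulated over 64-column blocks of K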
        p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_dv = tl.zeros([BT, BV], dtype=tl.float32)
        b_dv += tl.load(p_dv, boundary_check=(0, 1))
        p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype))
        if K > 64:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype))
        if K > 128:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype))
        if K > 192:
            p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype))
        tl.store(p_dv2, b_dv.to(p_dv2.dtype.element_ty), boundary_check=(0, 1))
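
        # propagate the state gradient: dh = exp(g_last) * dh + (scale * q)^T @ do - d^T @ dv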
        p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_d = tl.load(p_d, boundary_check=(0, 1))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        if USE_G:
            b_dh1 *= b_g_last
        b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype))
        if K > 64:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)
            b_d = tl.load(p_d, boundary_check=(0, 1))
            if USE_G:
                b_dh2 *= b_g_last
            b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype))
        if K > 128:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)
            b_d = tl.load(p_d, boundary_check=(0, 1))
            if USE_G:
                b_dh3 *= b_g_last
            b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype))
        if K > 192:
            p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)
            b_d = tl.load(p_d, boundary_check=(0, 1))
            if USE_G:
                b_dh4 *= b_g_last
            b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_d, b_dv.to(b_d.dtype))
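
    # gradient w.r.t. the initial state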
    if USE_INITIAL_STATE:
        p_dh0_1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_dh0_1, b_dh1.to(p_dh0_1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_dh0_2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh0_2, b_dh2.to(p_dh0_2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_dh0_3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh0_3, b_dh3.to(p_dh0_3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_dh0_4 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_dh0_4, b_dh4.to(p_dh0_4.dtype.element_ty), boundary_check=(0, 1))
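

# Elementwise preprocessing that folds the gates into q/k/w before the chunked
# kernels: k is rescaled by exp(g_last - g) and q/w by exp(g). Judging by the
# use of exp(g) alone, g is expected to hold chunk-local cumulative log gates.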
@triton.heuristics({
    'USE_Q': lambda args: args['q'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
        for BK in [16, 32, 64, 128]
    ],
    key=['H', 'K', 'BT', 'BK'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def preprocess_qkw(
    q,
    k,
    w,
    g,
    q_new,
    k_new,
    w_new,
    cu_seqlens,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    USE_Q: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_nh, i_t = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
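
    # advance all pointers to the (sequence, head) pair handled by this program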
    k += (bos * H + i_h) * K
    w += (bos * H + i_h) * K
    k_new += (bos * H + i_h) * K
    w_new += (bos * H + i_h) * K
    if USE_Q:
        q += (bos * H + i_h) * K
        q_new += (bos * H + i_h) * K
    g += bos * H + i_h
    stride_k = H * K
    stride_g = H
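
    # gate at the last position of this chunk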
    last_idx = min((i_t + 1) * BT, T) - 1
    b_g_last = tl.load(g + last_idx * stride_g).to(tl.float32)

    p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,))
    p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k_new = tl.make_block_ptr(k_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_w_new = tl.make_block_ptr(w_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
    b_w = tl.load(p_w, boundary_check=(0, 1)).to(tl.float32)
    b_g = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
    b_d_last = exp(b_g_last - b_g)
    b_d_begin = exp(b_g)
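
    # k carries the decay from each position to the chunk end; q/w carry the
    # decay from the chunk start up to each position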
    b_k = b_k * b_d_last[:, None]
    b_w = b_w * b_d_begin[:, None]
    tl.store(p_k_new, b_k.to(p_k_new.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_w_new, b_w.to(p_w_new.dtype.element_ty), boundary_check=(0, 1))

    if USE_Q:
        p_q = tl.make_block_ptr(q, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        p_q_new = tl.make_block_ptr(q_new, (T, K), (stride_k, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
        b_q = b_q * b_d_begin[:, None]
        tl.store(p_q_new, b_q.to(p_q_new.dtype.element_ty), boundary_check=(0, 1))


def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
    save_new_value: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
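    """
    Launch the chunkwise forward kernel and materialize the per-chunk states.

    A sketch of the expected shapes, inferred from the pointer arithmetic in the
    kernels above (not an authoritative spec):
        k, w: [B, T, H, K]; u: [B, T, H, V]; g: [B, T, H] log gates;
        initial_state: [N, H, K, V] with N the number of sequences.

    Returns (h, v_new, final_state): per-chunk states of shape [B, NT, H, K, V],
    the corrected values u - w @ h (None unless save_new_value), and the fp32
    final state (None unless output_final_state).

    Example (illustrative only):
        >>> B, T, H, K, V = 2, 256, 4, 128, 128
        >>> k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
        >>> w = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
        >>> u = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
        >>> h, v_new, ht = chunk_gated_delta_rule_fwd_h(k, w, u, output_final_state=True)
    """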
    B, T, H, K, V = *k.shape, u.shape[-1]
    BT = chunk_size

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
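
    # if gates are present, fold them into k and w first (see preprocess_qkw)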
    if g is not None:
        k_new = torch.empty_like(k)
        w_new = torch.empty_like(w)

        def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT))
        preprocess_qkw[grid](
            q=None,
            k=k,
            w=w,
            g=g,
            q_new=None,
            k_new=k_new,
            w_new=w_new,
            cu_seqlens=cu_seqlens,
            T=T,
            H=H,
            K=K,
            BT=BT,
        )
    v_new = torch.empty_like(u) if save_new_value else None

    def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k if g is None else k_new,
        v=u,
        d=w if g is None else w_new,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return h, v_new, final_state


def chunk_gated_delta_rule_bwd_dhu(
    q: torch.Tensor,
    k: torch.Tensor,
    w: torch.Tensor,
    g: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    scale: float,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
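    """
    Launch the chunkwise backward kernel for the state gradients.

    Returns (dh, dh0, dv2): per-chunk state gradients of shape [B, NT, H, K, V],
    the gradient w.r.t. the initial state (None if h0 is None), and the value
    gradients with the contribution flowing through the states added in.
    """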
    B, T, H, K, V = *q.shape, do.shape[-1]
    BT = chunk_size
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)

|
| dh = q.new_empty(B, NT, H, K, V) |
| dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None |
| dv2 = torch.empty_like(dv) |
|
|
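
    # if gates are present, fold them into q, k and w first (see preprocess_qkw)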
    if g is not None:
        q_new = torch.empty_like(q)
        k_new = torch.empty_like(k)
        w_new = torch.empty_like(w)

        def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT))
        preprocess_qkw[grid](
            q=q,
            k=k,
            w=w,
            g=g,
            q_new=q_new,
            k_new=k_new,
            w_new=w_new,
            cu_seqlens=cu_seqlens,
            T=T,
            H=H,
            K=K,
            BT=BT,
        )

    def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
    chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid](
        q=q if g is None else q_new,
        k=k if g is None else k_new,
        d=w if g is None else w_new,
        g=g,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return dh, dh0, dv2