| |
|
|
| |
| |
| |
|
|
| from typing import Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
| from einops import rearrange |
| from packaging import version |
| from torch.cuda.amp import custom_bwd, custom_fwd |
|
|
| from fla.ops.gla.chunk_util import (bwd_decay_global_cumsum, fwd_decay_cumsum, |
| prepare_qg_kg) |
| from fla.utils import contiguous |
|
|
| inv_ln2 = 1.44269504 |
|
|
| @triton.jit |
| def fused_chunk_gla_fwd_kernel( |
| |
| q, |
| k, |
| v, |
| g, |
| o, |
| |
| initial_state, |
| final_state, |
| |
| s_qk_h, |
| s_qk_t, |
| s_qk_d, |
| |
| s_vo_h, |
| s_vo_t, |
| s_vo_d, |
| |
| B, |
| H, |
| T, |
| scale, |
| BT: tl.constexpr, |
| BK: tl.constexpr, |
| BV: tl.constexpr, |
| DK: tl.constexpr, |
| DV: tl.constexpr, |
| USE_INITIAL_STATE: tl.constexpr, |
| STORE_FINAL_STATE: tl.constexpr, |
| CHECK: tl.constexpr |
| ): |
| |
| i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
|
|
| b_h = tl.zeros([BK, BV], dtype=tl.float32) |
|
|
| |
| p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (0, i_k * BK), (BT, BK), (1, 0)) |
| p_db = g + i_bh * s_qk_h + (BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) |
| p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BT), (0, 1)) |
| p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) |
| p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (0, i_v * BV), (BT, BV), (1, 0)) |
|
|
| if USE_INITIAL_STATE: |
| p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) |
| b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) |
| |
| mask = (i_k * BK + tl.arange(0, BK)) < DK |
|
|
| for i in range(0, tl.cdiv(T, BT)): |
| |
| b_k = tl.load(p_k, boundary_check=(0, 1)) |
| |
| b_o = tl.zeros([BT, BV], dtype=tl.float32) |
| b_v = tl.load(p_v, boundary_check=(0, 1)) |
| |
| b_q = tl.load(p_q, boundary_check=(0, 1)) |
| d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32) |
| if CHECK and i == 0: |
| b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) |
| b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) |
| else: |
| b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) |
| b_h = b_h * tl.math.exp2(d_b)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) |
|
|
| tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) |
| p_q = tl.advance(p_q, (BT, 0)) |
| p_k = tl.advance(p_k, (0, BT)) |
| p_v = tl.advance(p_v, (BT, 0)) |
| p_o = tl.advance(p_o, (BT, 0)) |
| p_db += BT * DK |
|
|
| if STORE_FINAL_STATE: |
| p_final = tl.make_block_ptr(final_state + i_bh * DK * DV, (DK, DV), (DV, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) |
| tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
| |
| @triton.jit |
| def fused_chunk_gla_bwd_kernel( |
| q, k, v, g, |
| do, |
| dq, |
| dk, |
| dv, |
| |
| initial_state, |
| |
| s_qk_h, |
| s_qk_t, |
| s_qk_d, |
| |
| s_vo_h, |
| s_vo_t, |
| s_vo_d, |
| |
| B, |
| H, |
| T, |
| scale, |
| |
| BT: tl.constexpr, |
| BK: tl.constexpr, |
| BV: tl.constexpr, |
| DK: tl.constexpr, |
| DV: tl.constexpr, |
| USE_INITIAL_STATE: tl.constexpr, |
| CHECK: tl.constexpr |
| ): |
| i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
| |
| b_h = tl.zeros([BV, BK], dtype=tl.float32) |
|
|
| if USE_INITIAL_STATE: |
| p_h = tl.make_block_ptr(initial_state + i_bh * DK * DV, (DV, DK), (1, DV), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) |
| b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) |
| |
| mask = (i_k * BK + tl.arange(0, BK)) < DK |
| for i in range(0, tl.cdiv(T, BT)): |
| p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) |
| p_db = g + i_bh * s_qk_h + ((i+1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) |
| p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i * BT), (BV, BT), (0, 1)) |
| p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (i * BT, i_v * BV), (BT, BV), (1, 0)) |
| p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i * BT, i_k * BK), (BT, BK), (1, 0)) |
| b_dq = tl.zeros([BT, BK], dtype=tl.float32) |
| |
| b_k = tl.load(p_k, boundary_check=(0, 1)) |
| |
| d_b = tl.load(p_db, mask=mask, other=0).to(tl.float32) |
|
|
| |
| b_v = tl.load(p_v, boundary_check=(0, 1)) |
| |
| b_do = tl.load(p_do, boundary_check=(0, 1)) |
| |
| if CHECK and i == 0: |
| b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) |
| b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) |
| else: |
| b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) |
| b_h = b_h * tl.math.exp2(d_b)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) |
| b_dq *= scale |
| tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) |
|
|
| |
| b_h = None |
| tl.debug_barrier() |
| |
| b_dh = tl.zeros([BK, BV], dtype=tl.float32) |
|
|
| |
| for i in range(1, tl.cdiv(T, BT) + 1): |
| p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) |
| p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) |
| p_db = g + i_bh * s_qk_h + (T - (i-1) * BT - 1) * s_qk_t + i_k * BK + tl.arange(0, BK) |
| p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) |
| p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) |
| p_dk = tl.make_block_ptr(dk + (i_bh + i_v * B * H) * s_qk_h, (T, DK), |
| (s_qk_t, s_qk_d), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) |
| p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * s_vo_h, (T, DV), |
| (s_vo_t, s_vo_d), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) |
| |
| b_q = tl.load(p_q, boundary_check=(0, 1)) |
| |
| b_k = tl.load(p_k, boundary_check=(0, 1)) |
| |
| b_v = tl.load(p_v, boundary_check=(0, 1)) |
| b_do = tl.load(p_do, boundary_check=(0, 1)) |
| b_db = tl.load(p_db, mask=mask, other=0).to(tl.float32) |
|
|
| |
| |
| if CHECK and i == 1: |
| b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) |
| b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) |
| b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) |
| else: |
| b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) |
| b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) |
| b_dh = b_dh * tl.math.exp2(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) |
|
|
| tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) |
| tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
| @triton.jit |
| def fwd_inner_chunk( |
| q, k, g, A, |
| s_qk_h, |
| s_qk_t, |
| s_qk_d, |
| B, |
| H, |
| T, |
| scale, |
| |
| BT: tl.constexpr, |
| BK: tl.constexpr, |
| DK: tl.constexpr, |
| ): |
|
|
| i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
|
|
| p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
| b_k = tl.load(p_k, boundary_check=(0, 1)) |
|
|
| p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
|
|
| b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) |
|
|
| mask = (i_k * BK + tl.arange(0, BK)) < DK |
| o_i = tl.arange(0, BT) |
|
|
| p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) |
| p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) |
| p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) |
|
|
| for i in range(BT): |
| _q = tl.load(p_q, mask=mask, other=0) * scale |
| gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) |
| s = _q[None, :] * b_k * tl.math.exp2(gq[None, :] - b_g) |
| score = tl.sum(s, axis=1) |
| score = tl.where(o_i <= i, score, 0) |
| tl.store(p_A, score.to(p_A.dtype.element_ty)) |
| p_q += DK |
| p_gq += DK |
| p_A += BT |
|
|
|
|
| @triton.jit |
| def bwd_inner_chunk( |
| q, |
| k, |
| g, |
| dA, |
| dq, |
| dk, |
| s_qk_h, |
| s_qk_t, |
| s_qk_d, |
| B, |
| H, |
| T, |
| scale, |
| |
| BT: tl.constexpr, |
| BK: tl.constexpr, |
| DK: tl.constexpr, |
| ): |
| i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) |
| p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
| b_k = tl.load(p_k, boundary_check=(0, 1)) |
| p_g = tl.make_block_ptr(g + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
| b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) |
|
|
| mask = (i_k * BK + tl.arange(0, BK)) < DK |
| o_i = tl.arange(0, BT) |
|
|
| p_q = q + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) |
| p_dq = dq + (i_bh) * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) |
| p_gq = g + i_bh * s_qk_h + i_k * BK + i_t * BT * DK + tl.arange(0, BK) |
| p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) |
|
|
| b_dk = tl.zeros([BT, BK], dtype=tl.float32) |
|
|
| for i in range(BT): |
| _q = tl.load(p_q, mask=mask, other=0) |
| gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) |
| score = tl.math.exp2(gq[None, :] - b_g) |
| score = tl.where(o_i[:, None] <= i, score, 0) |
| _dA = tl.load(p_dA) |
| _dA = tl.where(o_i <= i, _dA, 0) |
| b_dk += (_dA[:, None] * score * _q[None, :]) |
| b_dq = tl.sum(_dA[:, None] * score * b_k, axis=0) |
| tl.store(p_dq, b_dq, mask=mask) |
| p_q += DK |
| p_dq += DK |
| p_gq += DK |
| p_dA += BT |
|
|
| p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) |
| tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1)) |
|
|
|
|
| class FusedChunkGLAFunction(torch.autograd.Function): |
|
|
| @staticmethod |
| @contiguous |
| @custom_fwd |
| def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): |
| ctx.g_dtype = g.dtype |
| g_original = g |
| |
| g = torch.empty_like(g, dtype=torch.float32) |
| batch_size, n_heads, seq_len, d_head_qk = q.shape |
| d_head_v = v.shape[-1] |
| ctx.scale = scale |
|
|
| |
| BT = 16 |
| BK, BV = min(d_head_qk, 64), min(d_head_v, 64) |
| num_stages = 1 |
| num_warps = 2 |
|
|
| NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) |
| o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) |
| q_g = torch.empty_like(q) |
| k_g = torch.empty_like(k) |
| grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) |
| fwd_decay_cumsum[grid]( |
| g_original, |
| g, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, BK=BK, DK=d_head_qk, num_warps=1 |
| ) |
| prepare_qg_kg[grid]( |
| q, k, g, q_g, k_g, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, BK=BK, DK=d_head_qk, num_warps=1 |
| ) |
|
|
| if output_final_state: |
| final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v, dtype=torch.float, requires_grad=False) |
| else: |
| final_state = None |
| |
| |
| CHECK = True |
| if version.parse(triton.__version__) < version.parse('2.2.0'): |
| import warnings |
| warnings.warn( |
| "Triton<2.2.0 detected for running this kernel, " |
| "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " |
| "that lead to significant precision loss. " |
| "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " |
| "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." |
| ) |
| CHECK = True |
|
|
| grid = (NV, NK, batch_size * n_heads) |
| fused_chunk_gla_fwd_kernel[grid]( |
| q_g, k_g, v, g, o, initial_state, final_state, |
| q.stride(1), q.stride(2), q.stride(3), |
| v.stride(1), v.stride(2), v.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, |
| USE_INITIAL_STATE=initial_state is not None, |
| STORE_FINAL_STATE=output_final_state, |
| CHECK=CHECK, |
| num_warps=num_warps, |
| num_stages=num_stages |
| ) |
|
|
| o = o.sum(0) |
|
|
| |
| chunk_size = 16 |
| num_chunk = seq_len // chunk_size |
| v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) |
| BK = min(d_head_qk, 64) |
| NK = triton.cdiv(d_head_qk, BK) |
| A = q.new_empty(NK, batch_size, n_heads, triton.cdiv(seq_len, BT), BT, BT) |
| grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) |
| fwd_inner_chunk[grid]( |
| q, k, g, A, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, BT=BT, BK=BK, DK=d_head_qk, num_stages=3, |
| num_warps=4 |
| ) |
| A = A.sum(0) |
| o2 = A @ v2 |
| o2 = rearrange(o2, 'b h n c d -> b h (n c) d') |
| |
| o.add_(o2) |
| ctx.save_for_backward(q, k, v, g_original, A, initial_state) |
| ctx.CHECK = CHECK |
| return o.to(v), final_state |
|
|
| @staticmethod |
| @contiguous |
| @custom_bwd |
| def backward(ctx, do, d_final_state=None): |
| q, k, v, g_origin, A, initial_state = ctx.saved_tensors |
| batch_size, n_heads, seq_len, d_head_qk = q.shape |
| d_head_v = v.shape[-1] |
| scale = ctx.scale |
|
|
| |
| |
| BT = 16 |
| g = torch.empty_like(g_origin, dtype=torch.float32) |
| BK, BV = min(d_head_qk, 64), min(d_head_v, 64) |
| NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) |
| q_g = torch.empty_like(q) |
| k_g = torch.empty_like(k) |
| grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) |
| fwd_decay_cumsum[grid]( |
| g_origin, |
| g, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, BK=BK, DK=d_head_qk, num_warps=1 |
| ) |
| prepare_qg_kg[grid]( |
| q, k, g, q_g, k_g, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, BK=BK, DK=d_head_qk, num_warps=1 |
| ) |
|
|
| |
| BT = 16 |
| BK, BV = min(triton.next_power_of_2(d_head_qk), 64), min(triton.next_power_of_2(d_head_v), 64) |
| NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV) |
| num_stages = 1 |
| num_warps = 2 |
| dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) |
| dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk) |
| dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v) |
|
|
| grid = (NV, NK, batch_size * n_heads) |
|
|
| fused_chunk_gla_bwd_kernel[grid]( |
| q_g, k_g, v, g, do, dq, dk, dv, initial_state, |
| q.stride(1), q.stride(2), q.stride(3), |
| v.stride(1), v.stride(2), v.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| |
| BT=BT, DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV, |
| USE_INITIAL_STATE=initial_state is not None, |
| CHECK=ctx.CHECK, |
| num_warps=num_warps, |
| num_stages=num_stages, |
| ) |
| dq = dq.sum(0) |
| dk = dk.sum(0) |
| dv = dv.sum(0) |
|
|
| |
| num_chunk = seq_len // BT |
| v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) |
| do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=num_chunk) |
| dA2 = (do2 @ v2.transpose(-2, -1)) * scale |
| dv2 = A.transpose(-1, -2) @ do2 |
| dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=num_chunk) |
|
|
| BK = min(triton.next_power_of_2(d_head_qk), 16) |
| NK = triton.cdiv(d_head_qk, BK) |
| dk2 = torch.empty_like(k) |
| dq2 = torch.empty_like(q) |
|
|
| grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) |
| bwd_inner_chunk[grid]( |
| q, k, g, |
| dA2, dq2, dk2, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, DK=d_head_qk, BK=BK, |
| num_warps=1, |
| num_stages=3 |
| ) |
|
|
| BK = min(triton.next_power_of_2(d_head_qk), 32) |
| NK = triton.cdiv(d_head_qk, BK) |
| dg = torch.empty_like(g, dtype=torch.float32) |
| grid = (NK, triton.cdiv(seq_len, BT), batch_size * n_heads) |
| bwd_decay_global_cumsum[grid]( |
| dq2, dq, dk2, dk, q, k, g, dg, |
| q.stride(1), q.stride(2), q.stride(3), |
| batch_size, n_heads, seq_len, scale, |
| BT=BT, DK=d_head_qk, BK=BK, |
| num_warps=1, |
| num_stages=1 |
| ) |
| dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT) |
|
|
| def rev_cumsum_exclusive(x): |
| cumsum_x = x.cumsum(-2) |
| rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x |
| return rev_cumsum_x |
|
|
| rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :]) |
| dg.add_(rev_cumsum_dg.unsqueeze(-2)) |
| dv.add_(dv2) |
| dg = rearrange(dg, 'b h n c d -> b h (n c) d') |
|
|
| return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None |
|
|
|
|
| def pad(x, chunk_size=16): |
| seq_len = x.shape[-2] |
| padded_seq_len = ceildiv(seq_len, chunk_size) * chunk_size |
| if x.shape[-2] % chunk_size != 0: |
| x = F.pad(x, (0, 0, 0, padded_seq_len - seq_len)) |
| |
| return x |
|
|
|
|
| def ceildiv(a, b): |
| return -(a // -b) |
|
|
|
|
| def fused_chunk_gla( |
| q: torch.Tensor, |
| k: torch.Tensor, |
| v: torch.Tensor, |
| g: torch.Tensor, |
| scale: int = -1, |
| initial_state: torch.Tensor = None, |
| output_final_state: bool = False |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if scale == -1: |
| scale = q.shape[-1] ** -0.5 |
| if initial_state is not None: |
| initial_state = initial_state.detach() |
| seq_len = q.shape[-2] |
| q, k, v, g = map(lambda x: pad(x), [q, k, v, g]) |
| o, final_state = FusedChunkGLAFunction.apply( |
| q, k, v, g, scale, initial_state, output_final_state) |
| o = o[..., :seq_len, :] |
| return o, final_state |
|
|