import torch
import triton
import triton.language as tl

from typing import Optional, Sequence, Tuple, Union
from HybridTensor.triton.triton_flashattn_decode import maybe_contiguous


def get_padded_headsize(size):
    # Round the head dimension up to the next power of two, with a floor of 16,
    # so the tl.dot tiles in the kernels stay within supported sizes.
    padded_d_model = 1 << (size - 1).bit_length()
    padded_d_model = max(padded_d_model, 16)
    return padded_d_model
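# For illustration: get_padded_headsize(12) == 16, get_padded_headsize(72) == 128,
# and powers of two are returned unchanged (get_padded_headsize(128) == 128).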
def _strides(x: torch.Tensor, *stride_names: str):
    # Build the stride_* keyword arguments expected by the Triton kernels; a missing
    # (None) tensor contributes zero strides so the kernel signature stays fixed.
    if x is None:
        return {f"stride_{s}": 0 for s in stride_names}
    assert x.ndim == len(stride_names)
    return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
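# For illustration, given a contiguous tensor t of shape (B, M, G, H, D):
#   _strides(t, "qz", "qm", "qg", "qh", "qd")
#   -> {"stride_qz": t.stride(0), ..., "stride_qd": t.stride(4)}
# _strides(None, ...) maps every requested name to a stride of 0.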
@triton.jit
def _fwd_kernel_splitK(
    Q,
    K,
    V,
    sm_scale,
    Out_splitK,
    Metadata,
    batch_group_index,
    K_new,
    V_new,
    Cache_seqlens,
    Cache_batch_idx,
    Alibi_slopes,
    stride_qz, stride_qm, stride_qg, stride_qh, stride_qd,
    stride_kz, stride_kn, stride_kg, stride_kh, stride_kd,
    stride_vz, stride_vn, stride_vg, stride_vh, stride_vd,
    stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d,
    stride_mzhg, stride_m2, stride_ms, stride_mm,
    stride_bgz, stride_bgk,
    stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d,
    stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d,
    stride_az, stride_ah,
    Z,
    N_CTX_Q,
    N_CTX_K,
    N_CTX_NEW,
    BLOCK_N_PER_SPLIT,
    H_q: tl.constexpr,
    H_kv: tl.constexpr,
    G_q: tl.constexpr,
    TOP_K_GROUPS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    USE_CACHE_SEQLENs: tl.constexpr,
    USE_CACHE_BATCH_IDX: tl.constexpr,
    NEW_KV: tl.constexpr,
    IS_GQA: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    USE_ALIBI: tl.constexpr,
):
    # The head dimension is padded to a power of two; mask out the padding lanes.
    PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL)
    if PADDED_HEAD:
        d_mask = tl.arange(0, BLOCK_DMODEL) < ACTUAL_BLOCK_DMODEL

    start_m = tl.program_id(0).to(tl.int32)
    off_zhg = tl.program_id(1).to(tl.int64)
    splitk_idx = tl.program_id(2).to(tl.int64)
    HEAD_RATIO: tl.constexpr = H_q // H_kv

    # Decompose the fused (batch, top-k slot, head-within-group) program id.
    off_z = off_zhg // (HEAD_RATIO * TOP_K_GROUPS)
    off_k = (off_zhg // HEAD_RATIO) % TOP_K_GROUPS
    off_g_q = off_zhg % G_q

    # Look up which KV group this (batch, top-k slot) pair selected.
    batch_group_index_ptr = batch_group_index + off_z * stride_bgz + off_k * stride_bgk
    group_idx = tl.load(batch_group_index_ptr, mask=True, other=0).to(tl.int64)

    # Map the selected group back to an absolute query head index.
    off_h_q = group_idx * HEAD_RATIO + off_zhg % HEAD_RATIO

    if USE_CACHE_BATCH_IDX:
        cache_batch_idx = tl.load(Cache_batch_idx + off_z)
    else:
        cache_batch_idx = off_z

    if USE_ALIBI:
        a_offset = off_z * stride_az + off_h_q * stride_ah
        alibi_slope = tl.load(Alibi_slopes + a_offset)
    else:
        alibi_slope = None

    # [lo, hi) is the slice of the key sequence handled by this split.
    lo = splitk_idx * BLOCK_N_PER_SPLIT
    if USE_CACHE_SEQLENs:
        cache_seqlen_last_idx = tl.load(Cache_seqlens + off_z)
        if NEW_KV:
            kv_len = cache_seqlen_last_idx + N_CTX_NEW
        else:
            kv_len = cache_seqlen_last_idx
    else:
        kv_len = N_CTX_K
    hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)

    if IS_GQA:
        k_head_idx = off_h_q // HEAD_RATIO
        v_head_idx = k_head_idx
    else:
        k_head_idx = off_h_q
        v_head_idx = off_h_q

    k_base = K + k_head_idx * stride_kh + cache_batch_idx * stride_kz + off_g_q * stride_kg
    v_base = V + v_head_idx * stride_vh + cache_batch_idx * stride_vz + off_g_q * stride_vg

    # Copy the new K/V blocks into the cache before attending over it.
    if NEW_KV:
        knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

        if USE_CACHE_SEQLENs:
            start_idx = tl.load(Cache_seqlens + off_z)
        else:
            start_idx = N_CTX_K - N_CTX_NEW

        for i in range(0, N_CTX_NEW, BLOCK_N):
            k_new_block = tl.load(
                knew_base +
                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d +
                (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n,
                mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
                other=0
            )
            tl.store(
                k_base +
                tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd +
                (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn,
                k_new_block,
                mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL),
            )

        vnew_base = V_new + v_head_idx * stride_vn_h + off_z * stride_vn_z + off_g_q * stride_vn_g
        for i in range(0, N_CTX_NEW, BLOCK_N):
            v_new_block = tl.load(
                vnew_base +
                (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n +
                tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d,
                mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
                other=0
            )
            tl.store(
                v_base +
                (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn +
                tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd,
                v_new_block,
                mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) &
                (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL),
            )

    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_h_q.to(tl.int64) * stride_qh + off_z.to(tl.int64) * stride_qz + off_g_q.to(tl.int64) * stride_qg,
        shape=(N_CTX_Q, ACTUAL_BLOCK_DMODEL),
        strides=(stride_qm, stride_qd),
        offsets=((start_m * BLOCK_M).to(tl.int32), 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )

    # K is loaded transposed (D x N) so the inner loop can use tl.dot(q, k) directly.
    K_block_ptr = tl.make_block_ptr(
        base=k_base,
        shape=(ACTUAL_BLOCK_DMODEL, hi.to(tl.int32)),
        strides=(stride_kd, stride_kn),
        offsets=(0, lo.to(tl.int32)),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    V_block_ptr = tl.make_block_ptr(
        base=v_base,
        shape=(hi.to(tl.int32), ACTUAL_BLOCK_DMODEL),
        strides=(stride_vn, stride_vd),
        offsets=(lo.to(tl.int32), 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )

    K_scale_shift_block_ptr = None
    V_scale_shift_block_ptr = None

    # Online-softmax accumulators: running max, running sum, and partial output.
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Fold the softmax scale into Q and work in base-2 exponentials (1.44269504 = log2(e)).
    qk_scale = sm_scale * 1.44269504
    q = tl.load(
        tl.advance(Q_block_ptr, (0, 0)), boundary_check=(0, ))
    q = (q * qk_scale).to(q.dtype)
    if PADDED_HEAD:
        q = tl.where(d_mask[None, :], q, 0.0)

    for start_n in range(lo, hi, BLOCK_N):
        k, v = load_k_v_group(
            K_block_ptr,
            V_block_ptr,
            K_scale_shift_block_ptr,
            V_scale_shift_block_ptr,
            BOUNDS_CHECKS_N,
            1,
            BLOCK_DMODEL,
            ACTUAL_BLOCK_DMODEL,
            Q.dtype.element_ty,
            0,
        )
        if PADDED_HEAD:
            k = tl.where(d_mask[:, None], k, 0.0)
            v = tl.where(d_mask[None, :], v, 0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)

        if USE_ALIBI:
            row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            col_idx = start_n + tl.arange(0, BLOCK_N)

            # Absolute distance between query and key positions; queries are aligned
            # to the end of the KV cache.
            relative_pos = row_idx[:, None] + kv_len - (N_CTX_Q + col_idx[None, :])
            relative_pos = tl.abs(relative_pos)

            # ALiBi bias, converted to base-2 units like the rest of the softmax.
            alibi_bias = -1 * alibi_slope * relative_pos
            qk += (alibi_bias * 1.44269504)

        if IS_CAUSAL:
            row_idx = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
            col_idx = start_n + tl.arange(0, BLOCK_N)

            # Queries are aligned to the end of the key sequence.
            col_offset = N_CTX_Q - kv_len
            causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :])
            qk = tl.where(causal_mask, qk, float("-inf"))

        # Mask out keys beyond this split's upper bound.
        if BOUNDS_CHECKS_N:
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))

        # Online softmax update.
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        if IS_CAUSAL:
            alpha = tl.math.exp2(tl.where(m_i > float("-inf"), m_i - m_i_new, float("-inf")))
        else:
            alpha = tl.math.exp2(m_i - m_i_new)

        if IS_CAUSAL:
            qk = tl.where(qk > float("-inf"), qk - m_i_new[:, None], float("-inf"))
        else:
            qk = qk - m_i_new[:, None]

        p = tl.math.exp2(qk)

        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        p = p.to(Q.dtype.element_ty)

        acc *= alpha[:, None]
        acc += tl.dot(p.to(v.dtype), v)

        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    # Write the unnormalized partial output for this split.
    O_block_ptr = tl.make_block_ptr(
        base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        shape=(N_CTX_Q, BLOCK_DMODEL),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    tl.store(
        tl.advance(O_block_ptr, (0, 0)),
        acc,
        boundary_check=(0, ),
    )

    # Store the per-split softmax statistics (running max and running sum); they are
    # combined across splits by _splitK_reduce.
    Metadata_ptr = (Metadata + off_zhg * stride_mzhg + splitk_idx * stride_ms + start_m * BLOCK_M + tl.arange(0, BLOCK_M))
    tl.store(Metadata_ptr, m_i)
    tl.store(Metadata_ptr + stride_m2, l_i)


@triton.jit
def load_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    dtype: tl.constexpr,
    group_id: tl.constexpr,
):
    # Advance to the requested group (a no-op for group_id == 0) and load one K/V tile.
    K_block_ptr = tl.advance(K_block_ptr, (ACTUAL_BLOCK_DMODEL * group_id, 0))
    V_block_ptr = tl.advance(V_block_ptr, (0, ACTUAL_BLOCK_DMODEL * group_id))

    # Only the key-sequence dimension needs a boundary check; the head dimension is padded.
    k = tl.load(K_block_ptr, boundary_check=(1, ) if BOUNDS_CHECKS_N else ())
    v = tl.load(V_block_ptr, boundary_check=(0, ) if BOUNDS_CHECKS_N else ())

    return k, v


@triton.jit
def cast_uint32_to_half2(scale_shift):
    # Unpack a uint32 that stores (scale, shift) as two packed fp16 values.
    scale = scale_shift & 0xFFFF
    shift = scale_shift >> 16
    scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
    shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
    return scale, shift


@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr = 8,
):
    # Unpack PACKED_PER_VAL int4 values from each element of x_ and dequantize to fp16.
    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    offsets = tl.arange(0, PACKED_PER_VAL) * 4
    quant_offset = (x_[:, None, :] >> offsets[None, :, None])

    quant_offset = tl.view(quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))

    # Bit-cast the 4-bit value into an fp16 subnormal, then rescale it back to an integer value.
    quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
    quant_offset = (quant_offset * 32768.0).to(tl.float16)
    scale_512 = scale * 512

    dequant = quant_offset * scale_512 + shift
    return dequant
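# Note on the constants above: bit-casting a 4-bit integer n (0..15) held in a uint16 to
# fp16 yields the subnormal value n * 2**-24. Multiplying by 32768 (= 2**15) gives
# n * 2**-9, and multiplying by scale * 512 (= scale * 2**9) recovers n * scale, so
# dequant == n * scale + shift without an explicit integer-to-float conversion.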
@triton.jit
def _splitK_reduce(
    Out_splitK,
    Metadata,
    Out,
    LSE,
    batch_group_index,
    stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_k,
    stride_mzhg, stride_m2, stride_ms, stride_mm,
    stride_bgz, stride_bgk,
    stride_oz, stride_oh, stride_og, stride_om, stride_ok,
    stride_lse_zhg, stride_lse_m,
    M_ceil: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
    H_kv: tl.constexpr,
    TOP_K_GROUPS: tl.constexpr,
    split_k: tl.constexpr,
    splitK_pow2: tl.constexpr,
    use_mask: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
):
    off_zhg = tl.program_id(0).to(tl.int64)
    off_m = tl.program_id(1).to(tl.int64)
    off_k = tl.program_id(2).to(tl.int64)
    HEAD_RATIO: tl.constexpr = H // H_kv

    # Decompose the fused program id the same way as the forward kernel.
    off_g = off_zhg % G
    off_z = off_zhg // (HEAD_RATIO * TOP_K_GROUPS)
    off_bk = (off_zhg // HEAD_RATIO) % TOP_K_GROUPS

    batch_group_index_ptr = batch_group_index + off_z * stride_bgz + off_bk * stride_bgk
    group_idx = tl.load(batch_group_index_ptr, mask=True, other=0).to(tl.int64)

    # Output head index for this (batch, selected group, head-within-group).
    off_h = group_idx * HEAD_RATIO + off_zhg % HEAD_RATIO

    spk_idx = tl.arange(0, splitK_pow2)
    kidx = tl.arange(0, BLOCK_SIZE)

    Metadata_ptr = (Metadata + stride_mzhg * off_zhg + spk_idx * stride_ms + off_m * stride_mm)

    o_ptr = (Out_splitK + off_zhg * stride_osk_zhg + stride_osk_m * off_m + off_k * BLOCK_SIZE +
             stride_osk_s * spk_idx[:, None] + kidx[None, :] * stride_osk_k)

    # Load the per-split running max, running sum, and partial outputs; mask padded
    # splits when split_k is not a power of two.
    if use_mask:
        spk_mask = spk_idx < split_k
        l_m = tl.load(Metadata_ptr, mask=spk_mask, other=float("-inf"))
        l_sum = tl.load(Metadata_ptr + stride_m2, mask=spk_mask, other=0.0)
        acc = tl.load(o_ptr, mask=spk_mask[:, None], other=0.0)
    else:
        l_m = tl.load(Metadata_ptr)
        l_sum = tl.load(Metadata_ptr + stride_m2)
        acc = tl.load(o_ptr)

    # Combine the splits with a numerically stable log-sum-exp merge.
    g_m = tl.max(l_m, axis=0)

    if IS_CAUSAL:
        l_m_offset = l_m - g_m
        alpha = tl.where(l_m_offset > float("-inf"), tl.math.exp2(l_m_offset), 0.0)
    else:
        alpha = tl.math.exp2(l_m - g_m)

    l_sum *= alpha
    g_sum = tl.sum(l_sum, axis=0)
    acc = acc * alpha[:, None]

    if IS_CAUSAL:
        # Fully masked rows have g_sum == 0; avoid dividing by zero.
        g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0)
        acc_out = tl.sum(acc, axis=0) / g_sum_safe
    else:
        acc_out = tl.sum(acc, axis=0) / g_sum

    Out_ptr = (Out + stride_oz * off_z + stride_oh * off_h + stride_og * off_g + stride_om * off_m +
               off_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE))
    tl.store(Out_ptr, acc_out)

    # Store the log-sum-exp, converted back to natural-log units.
    l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m

    if IS_CAUSAL:
        lse = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m)
        tl.store(l_ptrs, lse)
    else:
        tl.store(l_ptrs, (g_m + tl.math.log2(g_sum)) / 1.44269504)
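# Reference sketch (not used by the kernels; names are illustrative): how _splitK_reduce
# combines the per-split statistics for a single query row. Each split s contributes a
# running max m[s] and running sum l[s] in base-2 units plus an unnormalized partial
# output acc[s, :]; the merge is a standard streaming log-sum-exp.
def _merge_splits_reference(m: torch.Tensor, l: torch.Tensor, acc: torch.Tensor):
    g_m = m.max()
    alpha = torch.exp2(m - g_m)                     # rescale every split to the global max
    g_sum = (l * alpha).sum()
    out = (acc * alpha[:, None]).sum(dim=0) / g_sum
    lse = (g_m + torch.log2(g_sum)) / 1.44269504    # back to natural-log units
    return out, lse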
def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    # Per-group affine int4 quantization: each group of the last dimension is mapped
    # linearly from its (min, max) range onto the integers 0..15.
    k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups)
    max_vals = torch.max(k, dim=-1, keepdim=True).values
    min_vals = torch.min(k, dim=-1, keepdim=True).values
    scale_k: torch.Tensor = (max_vals - min_vals) / 15

    shift_k = torch.min(k, dim=-1, keepdim=True).values
    scale_k = scale_k.to(torch.float16)
    shift_k = shift_k.to(torch.float16)

    # Quantize, pack two int4 values per byte, and prepend the per-group fp16 scale/shift.
    in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5
    in_bytes = in_bytes.to(torch.uint8)
    in_int4 = in_bytes & 0xF
    in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4)
    scale_shift = torch.concat([scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1)
    k_quant = torch.concat(
        [
            scale_shift.flatten(start_dim=-2),
            in_int4_packed.flatten(start_dim=-2),
        ],
        dim=-1,
    ).view(torch.int16)
    return k_quant


def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor:
    k_i16 = quant_k.view(torch.int16)
    k_ui8 = k_i16.view(torch.uint8)

    # The first 4 bytes per group hold the fp16 scale and shift.
    ss_size = num_groups * 4
    scale_shift_ui8 = k_ui8[..., 0:ss_size]
    scale_shift_ui8 = scale_shift_ui8.reshape(*scale_shift_ui8.shape[:-1], num_groups, 4)
    scale = scale_shift_ui8[..., 0:2].view(torch.float16)
    shift = scale_shift_ui8[..., 2:4].view(torch.float16)

    # Unpack the two int4 values stored in each byte and dequantize them.
    kv_ui8 = k_ui8[..., ss_size:]
    k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1)
    k1_i4 = k_ui8 & 0xF
    k2_i4 = (k_ui8 & 0xF0) >> 4
    k_shape = k1_i4.shape
    k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)
    k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape)

    # Interleave the even/odd elements back into their original order.
    out = torch.empty((*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), dtype=torch.float16, device=quant_k.device)
    out[..., ::2] = k1_f16
    out[..., 1::2] = k2_f16
    out = out.reshape(*k_shape[:-2], -1)

    return out
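# Illustrative round trip (shapes and values here are assumptions, not part of the API):
#   k = torch.randn(1, 1024, 8, 128, dtype=torch.float16)
#   k_quant = quantize_kv_int4(k, num_groups=1)             # int16, last dim 2 + 128 // 4 == 34
#   k_restored = dequantize_kv_fp16(k_quant, num_groups=1)  # float16, last dim 128 again
# k_restored matches k up to the 4-bit quantization error of each row.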
def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
    """Heuristic for the number of splits"""
    bh = max(B * H, 1)
    split_k = max(Mk, 1024) // bh
    max_chunk_size = 64
    while split_k > 0 and Mk / split_k < max_chunk_size:
        split_k = split_k // 2
    while B * H * G * split_k >= 1024:
        split_k = split_k // 2
    split_k = min(split_k, 512)
    split_k = max(split_k, 1)
    return split_k
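# For illustration: get_split_k(B=1, G=1, H=32, Mk=8192) returns 16. The initial quotient
# of 256 is halved until each split covers at least max_chunk_size keys, then halved again
# until B * H * G * split_k drops below 1024.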
def select_gqa(q, k, v,
               sm_scale, causal,
               alibi_slopes, layout,
               cache_seqlens, cache_batch_idx,
               new_kv, k_new, v_new,
               batch_group_index):
    # Split-K decode attention that only attends with the KV groups selected per batch
    # in batch_group_index (one group index per (batch, top-k slot)).
    BLOCK_M = 16
    BLOCK_N = 64
    SPLIT_K = None
    NUM_QUANT_GROUPS = 1

    # Normalize the input layout to "bsghd" (batch, seq, group, head, dim).
    original_layout = layout
    if layout == "bshd":
        q = q.unsqueeze(2)
        k = k.unsqueeze(2)
        v = v.unsqueeze(2)
        if new_kv:
            k_new = k_new.unsqueeze(2)
            v_new = v_new.unsqueeze(2)
        layout = "bsghd"
    elif layout == "bhsd":
        q = q.permute(0, 2, 1, 3).unsqueeze(2)
        k = k.permute(0, 2, 1, 3).unsqueeze(2)
        v = v.permute(0, 2, 1, 3).unsqueeze(2)
        if new_kv:
            k_new = k_new.permute(0, 2, 1, 3).unsqueeze(2)
            v_new = v_new.permute(0, 2, 1, 3).unsqueeze(2)
        layout = "bsghd"
    elif layout == "bsghd":
        pass
    elif layout is None:
        raise ValueError("Layout not given")
    assert layout == "bsghd"

    batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
    _, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape
    _, seqlen_v, n_group_v, heads_per_group_v, dim_v = v.shape
    num_actual_groups = heads_per_group_k
    heads_per_actual_group = heads_per_group_q // num_actual_groups

    top_k_groups = batch_group_index.shape[1]

    assert dim_q == dim_k == dim_v, f"Dimensions must match: {dim_q}, {dim_k}, {dim_v}"

    # Pad the head dimension to a power of two for the kernels.
    dim_padded = get_padded_headsize(dim_k)

    if heads_per_group_q > heads_per_group_k:
        is_gqa = True
    elif heads_per_group_q < heads_per_group_k:
        raise ValueError("heads_per_group_q < heads_per_group_k")
    else:
        is_gqa = False

    assert dim_k == dim_q, f"Keys have head dim {dim_k} but queries have head dim {dim_q}"

    if SPLIT_K is not None:
        split_k = SPLIT_K
    else:
        split_k = get_split_k(batch_size, num_actual_groups, heads_per_actual_group, seqlen_k)

    # Workspace for the per-split partial outputs and softmax statistics.
    seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M
    out_splitk = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, split_k, seqlen_q_ceil, dim_padded], dtype=torch.float32, device=q.device)
    metadata = torch.empty([batch_size * num_actual_groups * heads_per_actual_group, 2, split_k, seqlen_q_ceil], dtype=torch.float32, device=q.device)

    lse = torch.empty((batch_size * num_actual_groups * heads_per_actual_group, seqlen_q), device=q.device, dtype=torch.float32)

    # One program per (query block, batch * selected group * head-within-group, split).
    grid_fwd = (triton.cdiv(seqlen_q, BLOCK_M), batch_size * top_k_groups * heads_per_actual_group, split_k)

    num_warps = 1
    split_size = (seqlen_k + split_k - 1) // split_k
    use_cache_seqlens = cache_seqlens is not None

    assert batch_group_index.is_contiguous()
    stride_bgz = batch_group_index.stride(0)
    stride_bgk = 1

    _fwd_kernel_splitK[grid_fwd](
        Q=q,
        K=k,
        V=v,
        sm_scale=sm_scale,
        Out_splitK=out_splitk,
        Metadata=metadata,
        batch_group_index=batch_group_index,
        K_new=k_new,
        V_new=v_new,
        Cache_seqlens=cache_seqlens,
        Cache_batch_idx=cache_batch_idx,
        Alibi_slopes=alibi_slopes,
        **_strides(q, "qz", "qm", "qg", "qh", "qd"),
        **_strides(k, "kz", "kn", "kg", "kh", "kd"),
        **_strides(v, "vz", "vn", "vg", "vh", "vd"),
        **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_d"),
        **_strides(metadata, "mzhg", "m2", "ms", "mm"),
        stride_bgz=stride_bgz, stride_bgk=stride_bgk,
        **_strides(k_new, "kn_z", "kn_n", "kn_g", "kn_h", "kn_d"),
        **_strides(v_new, "vn_z", "vn_n", "vn_g", "vn_h", "vn_d"),
        **_strides(alibi_slopes, "az", "ah"),
        Z=batch_size,
        H_q=heads_per_group_q,
        H_kv=heads_per_group_k,
        G_q=n_group_q,
        TOP_K_GROUPS=int(top_k_groups),
        N_CTX_Q=seqlen_q,
        N_CTX_K=seqlen_k,
        N_CTX_NEW=k_new.shape[1] if new_kv else None,
        BLOCK_N_PER_SPLIT=split_size,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=dim_padded,
        ACTUAL_BLOCK_DMODEL=dim_k,
        BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens,
        USE_CACHE_SEQLENs=use_cache_seqlens,
        USE_CACHE_BATCH_IDX=cache_batch_idx is not None,
        NEW_KV=new_kv,
        IS_GQA=is_gqa,
        IS_CAUSAL=causal,
        USE_ALIBI=False if alibi_slopes is None else True,
        num_warps=num_warps,
        num_stages=1,
    )

    # Heads that are not selected by batch_group_index keep their zero initialization.
    out = torch.zeros((batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_padded), device=q.device, dtype=q.dtype)

    splitK_pow2 = triton.next_power_of_2(split_k)
    use_mask = splitK_pow2 > split_k
    if batch_size * num_actual_groups * heads_per_actual_group * seqlen_q >= 512:
        k_block_num = 1
    else:
        k_block_num = 2
    assert dim_padded % k_block_num == 0
    k_block_size = dim_padded // k_block_num
    grid_reduce = (batch_size * top_k_groups * heads_per_actual_group, seqlen_q, k_block_num)

    _splitK_reduce[grid_reduce](
        out_splitk,
        metadata,
        out,
        lse,
        batch_group_index,
        **_strides(out_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
        **_strides(metadata, "mzhg", "m2", "ms", "mm"),
        stride_bgz=stride_bgz, stride_bgk=stride_bgk,
        **_strides(out, "oz", "om", "og", "oh", "ok"),
        **_strides(lse, "lse_zhg", "lse_m"),
        M_ceil=seqlen_q_ceil,
        BLOCK_SIZE=k_block_size,
        G=n_group_q,
        TOP_K_GROUPS=int(top_k_groups),
        H=heads_per_group_q,
        H_kv=heads_per_group_k,
        split_k=split_k,
        splitK_pow2=splitK_pow2,
        use_mask=use_mask,
        IS_CAUSAL=causal,
        num_warps=4)

    lse = lse.reshape([batch_size, n_group_q, heads_per_group_q, seqlen_q])
    if q.ndim == 4:
        # BMHK (n_group_q == 1): drop the group dimension again.
        assert n_group_q == 1
        out = out[:, :, 0]
        lse = lse[:, 0]
    if seqlen_k == 0:
        out.zero_()
    out = out.reshape(batch_size, heads_per_group_q * n_group_q, -1, dim_padded).contiguous()

    # Restore the caller's layout and strip the head-dimension padding.
    if original_layout == "bshd":
        out = out.reshape(batch_size, seqlen_q, -1, dim_padded)

    return out.narrow(-1, 0, dim_k), lse
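

if __name__ == "__main__":
    # Minimal smoke-test sketch. The shapes, dtypes, the CUDA device, and the
    # interpretation of batch_group_index (one selected KV group per (batch, top-k slot))
    # are assumptions for illustration only; adjust them to the calling code.
    torch.manual_seed(0)
    B, H_q, H_kv, D = 2, 32, 8, 128
    max_seqlen_k, top_k = 4096, 2

    q = torch.randn(B, 1, H_q, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, max_seqlen_k, H_kv, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, max_seqlen_k, H_kv, D, device="cuda", dtype=torch.float16)

    cache_seqlens = torch.full((B,), 1024, device="cuda", dtype=torch.int32)
    batch_group_index = torch.randint(0, H_kv, (B, top_k), device="cuda", dtype=torch.int32)

    out, lse = select_gqa(
        q, k, v,
        sm_scale=1.0 / (D ** 0.5), causal=True,
        alibi_slopes=None, layout="bshd",
        cache_seqlens=cache_seqlens, cache_batch_idx=None,
        new_kv=False, k_new=None, v_new=None,
        batch_group_index=batch_group_index,
    )
    print(out.shape, lse.shape)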