| from typing import Optional, Tuple |
|
|
| import torch |
|
|
|
|
def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine two ``(v, s)`` state pairs via the ``sgl_kernel.merge_state`` op.

    Args:
        v_a: Value tensor of the first state.
        s_a: Score tensor of the first state; cast to float32 before the call.
        v_b: Value tensor of the second state.
        s_b: Score tensor of the second state; cast to float32 before the call.
        v_merged: Optional preallocated output for the merged values. When
            omitted, a buffer shaped like ``v_a`` is allocated.
        s_merged: Optional preallocated output for the merged scores. When
            omitted, a buffer shaped like the float32 ``s_a`` is allocated.

    Returns:
        The ``(v_merged, s_merged)`` pair handed to the op.

    NOTE(review): the op is presumably expected to fill both outputs in
    place (they are allocated uninitialised here) — confirm against the
    kernel implementation.
    """
    scores_a = s_a.to(torch.float32)
    scores_b = s_b.to(torch.float32)

    # Reuse caller-provided buffers; allocate only what is missing.
    out_v = torch.empty_like(v_a) if v_merged is None else v_merged
    out_s = torch.empty_like(scores_a) if s_merged is None else s_merged

    torch.ops.sgl_kernel.merge_state.default(
        v_a, scores_a, v_b, scores_b, out_v, out_s
    )
    return out_v, out_s
|
|
|
|
def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine two ``(v, s)`` state pairs via the ``sgl_kernel.merge_state_v2`` op.

    Both score tensors are converted to float32 first. Output buffers are
    taken from the caller when supplied, otherwise allocated to match
    ``v_a`` and the float32 ``s_a`` respectively.

    Args:
        v_a: Value tensor of the first state.
        s_a: Score tensor of the first state.
        v_b: Value tensor of the second state.
        s_b: Score tensor of the second state.
        v_merged: Optional preallocated output for the merged values.
        s_merged: Optional preallocated output for the merged scores.

    Returns:
        The ``(v_merged, s_merged)`` pair handed to the op.

    NOTE(review): the op is presumably expected to write its results into
    the two output tensors — confirm against the kernel implementation.
    """
    s_a_fp32, s_b_fp32 = s_a.to(torch.float32), s_b.to(torch.float32)

    # Skip allocation for any output the caller already provided.
    values_out = v_merged
    if values_out is None:
        values_out = torch.empty_like(v_a)

    scores_out = s_merged
    if scores_out is None:
        scores_out = torch.empty_like(s_a_fp32)

    kernel = torch.ops.sgl_kernel.merge_state_v2.default
    kernel(v_a, s_a_fp32, v_b, s_b_fp32, values_out, scores_out)
    return values_out, scores_out
|
|
|
|
def cutlass_mla_decode(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    sm_scale: float,
    num_kv_splits: int = 1,
) -> torch.Tensor:
    """Validate the inputs and run the ``sgl_kernel.cutlass_mla_decode`` op.

    Args:
        q_nope: ``(B, H, 512)`` fp16/bf16 tensor with ``H <= 128``.
        q_pe: ``(B, H, 64)`` tensor with the same dtype as ``q_nope``.
        kv_c_and_k_pe_cache: ``(_, PAGE_SIZE, 576)`` tensor with the same
            dtype as ``q_nope``; the leading dimension is not checked.
        seq_lens: int32 tensor, forwarded to the op unchanged.
        page_table: ``(B, block_num)`` int32 tensor. ``block_num`` must be
            positive and ``block_num * PAGE_SIZE`` a multiple of 128.
        workspace: Scratch tensor forwarded to the op unchanged (presumably
            sized with ``cutlass_mla_get_workspace_size`` — confirm with the
            caller).
        sm_scale: Scale factor forwarded to the op.
        num_kv_splits: Split count forwarded to the op.

    Returns:
        A contiguous ``(B, H, 512)`` tensor with the dtype of ``q_nope``.

    Raises:
        AssertionError: If any of the shape or dtype requirements above is
            violated.
    """
    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"

    B_q, H, D_q_nope = q_nope.shape
    B_q_2, H_2, D_q_pe = q_pe.shape
    assert (B_q == B_q_2) and (H == H_2), (
        f"q_nope and q_pe must agree on batch and head dims, "
        f"got {tuple(q_nope.shape)} and {tuple(q_pe.shape)}"
    )

    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    # Fixed head dimensions accepted by this wrapper.
    D_latent = 512
    D_rope = 64
    assert D_q_nope == D_latent, f"q_nope last dim must be {D_latent}, got {D_q_nope}"
    assert D_q_pe == D_rope, f"q_pe last dim must be {D_rope}, got {D_q_pe}"
    assert (
        D_ckv == D_latent + D_rope
    ), f"kv_c_and_k_pe_cache last dim must be {D_latent + D_rope}, got {D_ckv}"

    MAX_HEADS = 128
    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
    if H < MAX_HEADS:
        # The op is always called with MAX_HEADS heads. The padding rows are
        # left uninitialised; their outputs are sliced off before returning.
        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
        q_nope_padded[:, :H] = q_nope
        q_nope = q_nope_padded

        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
        q_pe_padded[:, :H] = q_pe
        q_pe = q_pe_padded

    assert len(page_table.shape) == 2, (
        f"page_table must be a 2D tensor, but got {len(page_table.shape)}"
    )
    B_block_table, block_num = page_table.shape
    assert B_block_table == B_q, (
        f"page_table batch dim must be {B_q}, got {B_block_table}"
    )
    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
    assert PAGE_SIZE > 0, f"page size must be greater than 0, got {PAGE_SIZE}"
    # Exact integer form of `block_num % (128 / PAGE_SIZE) == 0`. The float
    # division is inexact when PAGE_SIZE does not divide 128, which made the
    # float modulo reject valid sizes.
    KV_ALIGNMENT = 128  # presumably the kernel's KV tile size — confirm
    assert (block_num * PAGE_SIZE) % KV_ALIGNMENT == 0, (
        f"block_num * PAGE_SIZE must be a multiple of {KV_ALIGNMENT}, "
        f"got block_num={block_num}, PAGE_SIZE={PAGE_SIZE}"
    )

    assert q_nope.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype, (
        f"q_nope, q_pe and kv_c_and_k_pe_cache must share a dtype, got "
        f"{q_nope.dtype}, {q_pe.dtype}, {kv_c_and_k_pe_cache.dtype}"
    )
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    # Output is allocated at the padded head count to match the padded queries.
    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))

    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out,
        q_nope,
        q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        sm_scale,
        num_kv_splits,
    )
    # Drop the padded heads; .contiguous() compacts the sliced view.
    return out[:, :H].contiguous()
|
|
|
|
def cutlass_mla_get_workspace_size(
    max_seq_len: int,
    num_batches: int,
    sm_count: int = 0,
    num_kv_splits: int = 1,
) -> int:
    """Query ``sgl_kernel.cutlass_mla_get_workspace_size`` for a workspace size.

    Args:
        max_seq_len: Must be greater than 0.
        num_batches: Must be greater than 0.
        sm_count: Forwarded to the op unchanged.
        num_kv_splits: Forwarded to the op unchanged.

    Returns:
        Whatever the op reports (presumably a size in bytes — confirm
        against the kernel).

    Raises:
        AssertionError: If ``max_seq_len`` or ``num_batches`` is not positive.
    """
    # Both size arguments must be strictly positive; check them in order.
    positive_args = (("max_seq_len", max_seq_len), ("num_batches", num_batches))
    for label, value in positive_args:
        assert value > 0, f"{label} must be greater than 0, got {value}"

    query_size = torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default
    return query_size(max_seq_len, num_batches, sm_count, num_kv_splits)
|
|