| from typing import Optional, Tuple |
|
|
| import torch |
|
|
# The extension module is imported without ever being referenced by name in
# this file; presumably the import registers the ``torch.ops.sgl_kernel.*``
# operators used below (TODO confirm). A failure is recorded instead of raised
# so that importing this module never fails on its own.
try:
    from sgl_kernel import flashmla_ops
except Exception as _e:
    _flashmla_import_error = _e
else:
    _flashmla_import_error = None


# Pre-built error describing a failed load of the extension.
_IMPORT_ERROR = ImportError(
    "Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
# Chain the recorded root cause so that ``raise _IMPORT_ERROR`` reports why the
# import actually failed instead of only the generic driver hint.
if _flashmla_import_error is not None:
    _IMPORT_ERROR.__cause__ = _flashmla_import_error
|
|
|
|
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
    num_heads_q: Optional[int] = None,
    is_fp8_kvcache: bool = False,
    topk: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the scheduling metadata that `flash_mla_with_kvcache` expects.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
        num_heads_k: The number of k heads.
        num_heads_q: The number of q heads. Optional when sparse attention is not enabled; required when `topk` is given.
        is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
        topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache` will be attended to.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.

    Raises:
        ValueError: If `topk` is given (sparse attention) but `num_heads_q` is None.
    """
    sparse_attention = topk is not None

    # Reject the undocumented combination here, before dispatching, so the
    # caller gets a clear Python-level message.
    if sparse_attention and num_heads_q is None:
        raise ValueError(
            "num_heads_q must be provided when sparse attention is enabled "
            "(topk is not None)."
        )

    if is_fp8_kvcache and not sparse_attention:
        # Dense FP8 KV cache uses a dedicated op that takes neither
        # num_heads_q nor topk.
        return torch.ops.sgl_kernel.get_mla_decoding_metadata_dense_fp8.default(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
        )

    return torch.ops.sgl_kernel.get_mla_decoding_metadata.default(
        cache_seqlens,
        num_q_tokens_per_head_k,
        num_heads_k,
        num_heads_q,
        is_fp8_kvcache,
        topk,
    )
|
|
|
|
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run MLA attention against a paged KV cache.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim), where head_dim is the last dimension of `q`.
        causal: bool. Whether to apply causal attention mask. Must be False when `indices` is given.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. Only forwarded on the FP8 path (see below).
        descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. Only forwarded on the FP8 path (see below).
        is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md. Only forwarded on the general path.
        indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.

    Dispatch:
        When `indices` is None and `q` has 1-byte elements, the
        `fwd_kvcache_mla_fp8` op is used and receives `descale_q`/`descale_k`.
        Otherwise the `fwd_kvcache_mla` op is used and receives
        `is_fp8_kvcache` and `indices`; the descale factors are not passed.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        AssertionError: If `causal` is set together with `indices`, or if
            exactly one of `descale_q`/`descale_k` is provided.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # Argument checks stay as asserts so callers keep seeing AssertionError.
    if indices is not None:
        assert not causal, "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    if indices is None and q.element_size() == 1:
        # Dense attention with 1-byte query elements: dedicated FP8 op, the
        # only one that takes the descale factors.
        out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla_fp8.default(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    else:
        # General op: also handles sparse attention through `indices`.
        out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla.default(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            is_fp8_kvcache,
            indices,
        )
    return out, softmax_lse
|
|
|
|
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the sparse attention prefill kernel.

    This is a direct dispatch to the `sparse_prefill_fwd` op; the arguments
    are forwarded unchanged and the op's result is returned as-is.

    Args:
        q: [s_q, h_q, d_qk], bfloat16
        kv: [s_kv, h_kv, d_qk], bfloat16
        indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
        sm_scale: float
        d_v: The dimension of value vectors. Can only be 512

    Returns:
        (output, max_logits, lse)
        For the definition of output, max_logits and lse, please refer to README.md
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits:  [s_q, h_q], float
        - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    sparse_prefill = torch.ops.sgl_kernel.sparse_prefill_fwd.default
    return sparse_prefill(q, kv, indices, sm_scale, d_v)
|
|