import torch


def maybe_contiguous(x):
    # Return a contiguous copy when the last-dim stride is not 1; None passes through.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
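
# Example (illustrative): a transposed tensor has a non-unit last stride and is copied,
# while an already-contiguous tensor (or None) is returned unchanged:
#
#   x = torch.randn(4, 8).t()       # x.stride(-1) == 8
#   y = maybe_contiguous(x)         # y.is_contiguous() is True
#   maybe_contiguous(None)          # None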


def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,
    kv_seqlens: torch.Tensor,
    vertical_indexes: torch.Tensor,
    slash_indexes: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert per-head vertical and slash indices into block-sparse metadata.

    Returns (block_count, block_offset, column_count, column_index), which describe,
    for every row block of size ``block_size_M``, the slash blocks and vertical
    columns the sparse attention kernel has to visit.
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index
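
# Illustrative sketch (not executed on import): build block-sparse metadata for a single
# sequence. Assumes the sgl_kernel CUDA extension is built and a CUDA device is available;
# the concrete index values and int32 dtypes here are made up for demonstration.
#
#   ctx, bm, bn = 1024, 64, 64
#   q_seqlens = torch.tensor([ctx], dtype=torch.int32, device="cuda")
#   kv_seqlens = torch.tensor([ctx], dtype=torch.int32, device="cuda")
#   # (batch, heads, NNZ_V) column indices and (batch, heads, NNZ_S) slash offsets
#   vertical = torch.arange(0, 128, dtype=torch.int32, device="cuda").view(1, 1, -1)
#   slash = torch.arange(0, 256, 4, dtype=torch.int32, device="cuda").view(1, 1, -1)
#   block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
#       q_seqlens, kv_seqlens, vertical, slash, ctx, bm, bn, causal=True
#   )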


def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,
    kv_seqlens: torch.Tensor,
    vertical_indexes: torch.Tensor,
    slash_indexes: torch.Tensor,
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Merge-head variant of ``convert_vertical_slash_indexes``.

    ``vertical_indices_count`` and ``slash_indices_count`` give the number of valid
    entries in ``vertical_indexes`` / ``slash_indexes``, so padded index tensors can be
    shared across heads with different sparsity budgets.
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        vertical_indices_count,
        slash_indices_count,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index
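
# Illustrative sketch (not executed on import): as above, but with explicit per-head
# counts of valid vertical / slash indices. The count tensor shapes shown here are an
# assumption for illustration only; they must match what the sgl_kernel op expects.
#
#   heads = 2
#   vertical = torch.zeros(1, heads, 128, dtype=torch.int32, device="cuda")
#   slash = torch.zeros(1, heads, 64, dtype=torch.int32, device="cuda")
#   v_count = torch.tensor([128, 96], dtype=torch.int32, device="cuda")
#   s_count = torch.tensor([64, 32], dtype=torch.int32, device="cuda")
#   meta = convert_vertical_slash_indexes_mergehead(
#       q_seqlens, kv_seqlens, vertical, slash, v_count, s_count, ctx, bm, bn
#   )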


def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details, please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
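
# Illustrative sketch (not executed on import): dense q/k/v in the fixed-length layout,
# with sparse metadata produced by convert_vertical_slash_indexes as above. Assumes the
# sgl_kernel CUDA extension is installed; shapes and dtypes are for demonstration only.
#
#   batch, seqlen, nheads, headdim = 1, 1024, 1, 64
#   q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
#   k = torch.randn_like(q)
#   v = torch.randn_like(q)
#   out = sparse_attn_func(
#       q, k, v, block_count, block_offset, column_count, column_index, causal=True
#   )
#   out, lse = sparse_attn_func(
#       q, k, v, block_count, block_offset, column_count, column_index,
#       causal=True, return_softmax_lse=True,
#   )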


def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns (variable-length layout).

    Most arguments are the same as in the flash_attn_varlen_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details, please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
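
# Illustrative sketch (not executed on import): two packed sequences in the varlen
# layout. Assumes the sgl_kernel CUDA extension is installed; the sparse metadata is
# built per batch/head as in the fixed-length example above.
#
#   lens = [512, 256]
#   cu = torch.tensor([0, 512, 768], dtype=torch.int32, device="cuda")
#   nheads, headdim = 1, 64
#   q = torch.randn(sum(lens), nheads, headdim, dtype=torch.float16, device="cuda")
#   k = torch.randn_like(q)
#   v = torch.randn_like(q)
#   out = sparse_attn_varlen_func(
#       q, k, v, block_count, block_offset, column_count, column_index,
#       cu, cu, max(lens), max(lens), causal=True,
#   )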