File size: 4,992 Bytes
d02d576 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | from typing import Optional
import torch
def fast_topk(values, topk, dim):
    """Return the ``topk`` largest entries of ``values`` along ``dim``.

    Args:
        values: Input tensor.
        topk: Number of entries to keep along ``dim``.
        dim: Dimension to reduce over.

    Returns:
        A ``(values, indices)`` pair. In both branches the reduced dimension is
        kept with size ``topk``. For ``topk == 1`` this is achieved with
        ``keepdim=True``.
    """
    if topk != 1:
        # General case: defer to torch.topk for k > 1.
        # TODO: implement faster cuda kernels for large vocab sizes
        return torch.topk(values, topk, dim=dim)

    # k == 1: a single max reduction already yields both the value and its
    # index; keepdim=True makes the output shape match the torch.topk branch.
    return torch.max(values, dim=dim, keepdim=True)
def fast_topk_v2(
    score: torch.Tensor,
    lengths: torch.Tensor,
    topk: int,
    row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Select the indices of the ``topk`` highest scores in every row.

    Thin wrapper around the ``sgl_kernel`` ``fast_topk`` custom op. The op
    writes into a preallocated int32 buffer, and that buffer is returned.

    Args:
        score: 2-D tensor of shape (B, L) holding the logits between the query
            and the key; the key layout is either ragged or paged.
        lengths: Tensor of shape (B) giving the number of valid scores per row.
        topk: Number of indices to select. Only 2048 is accepted (the kernel
            is tuned for the DeepSeek V3.2 model).
        row_starts: Optional tensor of shape (B). For row i, the selection is
            restricted to the ``lengths[i]`` scores that begin at column
            ``row_starts[i]``. Only required when the key is ragged.

    Returns:
        int32 tensor of shape (B, topk) with the selected indices.

    Raises:
        AssertionError: (when assertions are enabled) if ``topk != 2048`` or
            ``score`` is not 2-D.
    """
    assert (
        topk == 2048
    ), "fast_topk_v2 is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2

    num_rows = score.shape[0]
    # The custom op fills this buffer in place; new_empty places it on the
    # same device as ``score``.
    selected = score.new_empty((num_rows, topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk(score, selected, lengths, row_starts)
    return selected
def fast_topk_transform_fused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    page_table_size_1: torch.Tensor,  # NOTE: page size should be 1
    cu_seqlens_q: torch.Tensor,
    topk: int,
    row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Select the top-k scores per row and map them into page-table indices.

    Wraps the ``sgl_kernel`` ``fast_topk_transform_fused`` custom op. This op
    performs the top-k selection and the translation of the selected indices
    through a page table with page_size = 1 in a single call.

    Args:
        score: 2-D tensor of shape (B, L) holding the logits between the query
            and the key; the key layout is either ragged or paged.
        lengths: Tensor of shape (B) giving the number of valid scores per row.
        page_table_size_1: Source page table of shape (Batch, topk); its page
            size must be 1.
        cu_seqlens_q: Cumulative query sequence lengths of shape (Batch + 1).
        topk: Number of indices to select. Only 2048 is accepted (the kernel
            is tuned for the DeepSeek V3.2 model).
        row_starts: Optional tensor of shape (B). For row i, the selection is
            restricted to the ``lengths[i]`` scores that begin at column
            ``row_starts[i]``. Only used when the key is ragged, i.e. during
            extend and draft extend.

    Returns:
        int32 tensor of shape (B, topk) holding the transformed indices.

    Raises:
        AssertionError: (when assertions are enabled) if ``topk != 2048`` or
            ``score`` is not 2-D.
    """
    assert (
        topk == 2048
    ), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2

    # Destination table, filled in place by the custom op.
    out_page_table = score.new_empty((score.shape[0], topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk_transform_fused(
        score,
        lengths,
        out_page_table,
        page_table_size_1,
        cu_seqlens_q,
        row_starts,
    )
    return out_page_table
def fast_topk_transform_ragged_fused(
    score: torch.Tensor,
    lengths: torch.Tensor,
    topk_indices_offset: torch.Tensor,  # ragged kv
    topk: int,
    row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Select the top-k scores per row and map them into ragged (non-paged) KV indices.

    Wraps the ``sgl_kernel`` ``fast_topk_transform_ragged_fused`` custom op.
    Intended for extend only, not for draft extend.

    Args:
        score: 2-D tensor of shape (B, L) holding the logits between the query
            and the key, which can be ragged or paged.
        lengths: Tensor of shape (B) giving the number of valid scores per row.
        topk_indices_offset: Tensor of shape (B) holding the offset of each
            row's top-k indices within the ragged KV.
        topk: Number of indices to select. Only 2048 is accepted (the kernel
            is tuned for the DeepSeek V3.2 model).
        row_starts: Optional tensor of shape (B). For row i, the selection is
            restricted to the ``lengths[i]`` scores that begin at column
            ``row_starts[i]``. It may be None when only the fast path is
            triggered, i.e. every value in ``lengths`` is <= topk. The kernel
            does not check this; the caller must guarantee it.

    Returns:
        int32 tensor of shape (B, topk) holding indices into the ragged KV.

    Raises:
        AssertionError: (when assertions are enabled) if ``topk != 2048`` or
            ``score`` is not 2-D.
    """
    assert (
        topk == 2048
    ), "fast_topk_transform_ragged_fused is only optimized for deepseek v3.2 model, where topk=2048"
    assert score.dim() == 2

    batch = score.shape[0]
    # Output buffer, filled in place by the custom op.
    ragged_indices = score.new_empty((batch, topk), dtype=torch.int32)
    torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
        score,
        lengths,
        ragged_indices,
        topk_indices_offset,
        row_starts,
    )
    return ragged_indices
|