File size: 8,252 Bytes
a402b9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 | from typing import Any, Optional
import pytest
import torch
from sgl_kernel import (
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)
def _ref_torch_impl(
score: torch.Tensor,
seq_len: int,
topk: int,
row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert score.dim() == 2
if row_starts is None:
return torch.topk(score[:, :seq_len], topk, dim=-1, sorted=False).indices
else:
ks = row_starts.cpu().tolist()
ke = (row_starts + seq_len).tolist()
scores = []
for i, (start, end) in enumerate(zip(ks, ke)):
scores.append(score[i, start:end].unsqueeze(0))
score = torch.cat(scores, dim=0)
return torch.topk(score, topk, dim=-1, sorted=False).indices
def _ref_torch_transform_decode_impl(
    score: torch.Tensor,
    seq_len: int,
    src_page_table: torch.Tensor,
    topk: int,
    row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference for the fused top-k + page-table transform.

    Computes the reference top-k positions per score row (see
    ``_ref_torch_impl``) and maps each position through that row of
    ``src_page_table``.

    Returns:
        An int32 tensor of shape ``(score.shape[0], topk)`` on ``score.device``.
    """
    assert score.dim() == 2
    assert score.shape[0] == src_page_table.shape[0]
    assert seq_len >= topk
    indices = _ref_torch_impl(score, seq_len, topk, row_starts=row_starts)
    # A single gather over the whole batch replaces a Python loop that issued
    # one indexing op per row (bs is up to 4096 in these tests).
    topk_indices = torch.gather(src_page_table, 1, indices)
    return topk_indices.to(dtype=torch.int32, device=score.device)
def _ref_torch_transform_ragged_impl(
    score: torch.Tensor,
    seq_len: int,
    topk_indices_offset: torch.Tensor,
    topk: int,
    row_starts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference for the ragged fused transform.

    Computes the reference top-k positions per score row (see
    ``_ref_torch_impl``) and adds that row's ``topk_indices_offset``.
    ``row_starts`` may be ``None``; the ragged test passes ``None`` when it
    has no row starts.
    """
    assert score.shape[0] == topk_indices_offset.shape[0]
    assert seq_len >= topk
    indices = _ref_torch_impl(score, seq_len, topk, row_starts=row_starts)
    # torch.topk never yields -1, and seq_len >= topk guarantees `topk` real
    # entries per row, so there is no -1 padding to preserve: every index
    # simply gets its row's offset.
    return indices + topk_indices_offset.unsqueeze(1)
# Column count of every score matrix the tests allocate (128K). Each tested
# seq_len, plus any random row start (< 2048), must fit inside it.
MAX_SEQ_LEN = 128 * 1024
def assert_equal(
    score: torch.Tensor,
    indices_ref: torch.Tensor,
    indices_our: torch.Tensor,
    bs: int,
    k: int,
    seq_len: int,
    topk_indices_offset: Optional[torch.Tensor] = None,
    max_permit_error: int = 0,
    row_starts: Optional[torch.Tensor] = None,
):
    """Assert two top-k index tensors agree per row, tolerating score ties.

    The first ``bs`` rows are compared as index sets. Indices present on only
    one side are still accepted when the sorted scores behind them are
    identical on both sides (a tie, where either choice is valid). Otherwise
    the row contributes ``max(len(more), len(less))`` mismatches, and the
    total must not exceed ``max_permit_error``.

    Args:
        score: 2-D score matrix the indices were computed from.
        indices_ref: reference indices, one row per score row.
        indices_our: kernel indices, same layout as ``indices_ref``.
        bs: number of rows to compare.
        k: only used in the failure message.
        seq_len: only used in the failure message.
        topk_indices_offset: optional per-row offset already added to the
            indices; it is subtracted before looking up scores.
        max_permit_error: total number of mismatched indices tolerated.
        row_starts: optional per-row window start. Pass it when the indices
            are relative to ``row_starts[i]`` so that scores are read from
            the correct columns. Defaults to no shift.

    Raises:
        AssertionError: if more than ``max_permit_error`` indices mismatch.
    """
    indices_our_cpu = indices_our.cpu().tolist()
    indices_ref_cpu = indices_ref.cpu().tolist()
    # Hoist per-row shifts once instead of one .item() sync per row.
    offsets = (
        topk_indices_offset.cpu().tolist()
        if topk_indices_offset is not None
        else [0] * bs
    )
    starts = row_starts.cpu().tolist() if row_starts is not None else [0] * bs
    wrong_values = 0
    for i in range(bs):
        indices_ref_set_i = set(indices_ref_cpu[i])
        indices_our_set_i = set(indices_our_cpu[i])
        more = indices_our_set_i - indices_ref_set_i
        less = indices_ref_set_i - indices_our_set_i
        if not more and not less:
            continue
        # Map a reported index back to its absolute column in `score`.
        shift = starts[i] - offsets[i]
        # If the differing indices carry the same scores, either choice is
        # acceptable, since their values are the same.
        more_values = sorted(score[i, idx + shift].item() for idx in more)
        less_values = sorted(score[i, idx + shift].item() for idx in less)
        if more_values != less_values:
            # A side that repeats an index yields a smaller set, so `more`
            # can be empty while `less` is not; count the larger side so such
            # rows are not scored as zero errors.
            wrong_values += max(len(more), len(less))
            print(
                f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
            )
    assert wrong_values <= max_permit_error, f"{wrong_values=}, {max_permit_error=}"
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@pytest.mark.parametrize("has_row_starts", [True, False])
@torch.inference_mode()
def test_topk_kernel(bs: int, k: int, seq_len: int, has_row_starts: bool) -> None:
    """Compare fast_topk_v2 against the torch.topk reference as index sets."""
    torch.manual_seed(42)
    # Run on a fresh stream but restore the previous current stream on exit;
    # torch.cuda.set_stream would leave it current for every later test.
    with torch.cuda.stream(torch.cuda.Stream()):
        score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
        lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
        if has_row_starts:
            row_starts = torch.randint(
                0, 2048, (bs,), dtype=torch.int32, device="cuda"
            )
        else:
            row_starts = None
        indices_ref = _ref_torch_impl(score, seq_len, k, row_starts=row_starts)
        indices_our = fast_topk_v2(score, lengths, k, row_starts=row_starts)
        # sort and compare
        indices_ref = torch.sort(indices_ref, dim=-1).values
        indices_our = torch.sort(indices_our, dim=-1).values
        # Tests can pass with max_permit_error=3, set to 5 for safety
        assert_equal(
            score, indices_ref, indices_our, bs, k, seq_len, max_permit_error=5
        )
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@pytest.mark.parametrize("mode", ["extend", "decode", "target_verify"])
@torch.inference_mode()
def test_topk_transform_kernel(bs: int, k: int, seq_len: int, mode: str) -> None:
    """Check fast_topk_transform_fused against the page-table reference.

    ``mode`` selects how ``cu_seqlens_q`` is built (step 1 for decode,
    otherwise step 4 when ``bs`` is divisible by 4) and whether random
    ``row_starts`` are used ("extend" only).
    """
    torch.manual_seed(42)
    # Run on a fresh stream but restore the previous current stream on exit;
    # torch.cuda.set_stream would leave it current for every later test.
    with torch.cuda.stream(torch.cuda.Stream()):
        # NOTE: for decode, cumulative seqlens_q is just 0..=bs
        # NOTE: since page table is arange, they equal topk indices
        if mode == "decode":
            step = 1
        else:
            step = 4 if bs % 4 == 0 else 1
        num_tokens = bs
        bs = bs // step
        # NOTE(review): score/lengths get bs // step rows while cu_seqlens_q
        # spans num_tokens query tokens. Because every page-table row is the
        # same arange, this test cannot detect a wrong row -> request mapping.
        # TODO confirm the intended number of score rows.
        if mode == "extend":
            row_starts = torch.randint(
                0, 2048, (bs,), dtype=torch.int32, device="cuda"
            )
        else:
            row_starts = None
        score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
        lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
        cu_seqlens_q = torch.arange(
            0, num_tokens + 1, step=step, dtype=torch.int32, device="cuda"
        )
        src_page_table = torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
        src_page_table = src_page_table.unsqueeze(0).expand(bs, -1)
        dst_page_table_ref = _ref_torch_transform_decode_impl(
            score=score,
            seq_len=seq_len,
            src_page_table=src_page_table,
            topk=k,
            row_starts=row_starts,
        )
        dst_page_table_our = fast_topk_transform_fused(
            score=score,
            lengths=lengths,
            page_table_size_1=src_page_table,
            cu_seqlens_q=cu_seqlens_q,
            topk=k,
            row_starts=row_starts,
        )
        # sort and compare
        dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
        dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
        assert_equal(
            score,
            dst_page_table_ref,
            dst_page_table_our,
            bs,
            k,
            seq_len,
            max_permit_error=5,
        )
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
@pytest.mark.parametrize("k", [2048])  # we only support 2048 now
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
@pytest.mark.parametrize("has_row_starts", [True, False])
@torch.inference_mode()
def test_topk_transform_ragged_kernel(
    bs: int, k: int, seq_len: int, has_row_starts: bool
) -> None:
    """Check fast_topk_transform_ragged_fused against the torch reference.

    Used in prefill only. ``bs`` is the number of query tokens (one score row
    each). The expected output is each row's reference top-k positions plus a
    random per-row ``topk_indices_offset``.
    """
    torch.manual_seed(42)
    # Run on a fresh stream but restore the previous current stream on exit;
    # torch.cuda.set_stream would leave it current for every later test.
    with torch.cuda.stream(torch.cuda.Stream()):
        # bs: # of q tokens
        score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
        if has_row_starts:
            row_starts = torch.randint(
                0, 2048, (bs,), dtype=torch.int32, device="cuda"
            )
        else:
            row_starts = None
        # kv_len
        lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
        topk_indices_offset = torch.randint(
            0, 1024, (bs,), dtype=torch.int32, device="cuda"
        )
        dst_page_table_ref = _ref_torch_transform_ragged_impl(
            score=score,
            seq_len=seq_len,
            topk_indices_offset=topk_indices_offset,
            topk=k,
            row_starts=row_starts,
        )
        dst_page_table_our = fast_topk_transform_ragged_fused(
            score=score,
            lengths=lengths,
            topk_indices_offset=topk_indices_offset,
            topk=k,
            row_starts=row_starts,
        )
        # sort and compare
        dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
        dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
        assert_equal(
            score,
            dst_page_table_ref,
            dst_page_table_our,
            bs,
            k,
            seq_len,
            topk_indices_offset,
            max_permit_error=5,
        )
if __name__ == "__main__":
    # Propagate pytest's exit status so `python <this file>` fails when tests fail.
    raise SystemExit(pytest.main([__file__]))
|