# Hanrui/sglang: sgl-kernel/tests/test_rotary_embedding.py
# (added by Lekr0 using the upload-large-folder tool, commit a402b9b)
from typing import Any, Dict, List, Optional, Tuple, Union
import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
from sgl_kernel.testing.rotary_embedding import (
FlashInferRotaryEmbedding,
MHATokenToKVPool,
RotaryEmbedding,
SglKernelRotaryEmbedding,
create_inputs,
)
@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
    [
        # GPT-OSS cases
        *[
            (
                64,
                64,
                4096,
                8000,
                True,
                torch.bfloat16,
                "cuda",
                batch_size,
                seq_len,
                64,
                8,
                save_kv_cache,
            )
            for batch_size, seq_len in (
                (1, 1),
                (32, 1),
                (128, 1),
                (512, 1),
                (2, 512),
                (4, 4096),
            )
            for save_kv_cache in (False, True)
        ],
        # Other cases
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
        (64, 64, 32, 8000, True, torch.float32, "cuda", 32, 32, 1, 1, False),
        (256, 128, 4096, 10000, True, torch.float32, "cuda", 2, 512, 4, 2, False),
        (512, 128, 311, 10000, True, torch.float32, "cuda", 3, 39, 4, 2, False),
        (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 32, 8, False),
        (128, 128, 2048, 10000, False, torch.float32, "cuda", 2, 512, 16, 4, False),
        (512, 128, 311, 10000, False, torch.float32, "cuda", 3, 39, 4, 2, False),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    save_kv_cache: bool,
):
    """Check the FlashInfer and sgl-kernel RoPE kernels against the native reference.

    The same inputs from ``create_inputs`` are rotated by three implementations
    built from one shared config:

    * ``RotaryEmbedding.forward_native`` is the reference. It runs twice: once on
      float32 copies of query/key, to compare against FlashInfer, and once in
      ``dtype``, to compare against sgl-kernel.
    * ``FlashInferRotaryEmbedding.forward_cuda``, optionally with a fused KV-cache
      write (``FusedSetKVBufferArg``).
    * ``SglKernelRotaryEmbedding.forward_cuda``.

    When ``save_kv_cache`` is true, the reference rotated keys and the raw values
    are stored in a reference ``MHATokenToKVPool`` via ``set_kv_buffer``.
    FlashInfer is asked to store the same data through the fused path. The ``[0]``
    entries of both pools' ``k_buffer``/``v_buffer`` (presumably layer 0 -- confirm
    against MHATokenToKVPool) must then agree numerically and in their zero
    pattern, i.e. the same cache slots were written.
    """
    tol = dict(atol=1e-2, rtol=1e-2)

    config = dict(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )
    rope_ref = RotaryEmbedding(**config).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
    rope_sglkernel = SglKernelRotaryEmbedding(**config).to(device)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )

    if save_kv_cache:
        pool_ref_for_flashinfer = MHATokenToKVPool(
            head_num=num_kv_heads, head_dim=head_size
        )
        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    # Give each implementation its own copy of the inputs, so that one that
    # (possibly) rotates in place cannot change what another one sees.
    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
    query_sglkernel, key_sglkernel = inputs["query"].clone(), inputs["key"].clone()

    # This is to align with the flashinfer implementation, flashinfer uses float32 cos/sin cache
    query_ref_for_flashinfer_out, key_ref_for_flashinfer_out = rope_ref.forward_native(
        inputs["pos_ids"], query_ref.to(torch.float32), key_ref.to(torch.float32)
    )
    query_ref_for_sglkernel_out, key_ref_for_sglkernel_out = rope_ref.forward_native(
        inputs["pos_ids"], query_ref, key_ref
    )

    if save_kv_cache:
        pool_ref_for_flashinfer.set_kv_buffer(
            loc=inputs["out_cache_loc"],
            cache_k=key_ref_for_flashinfer_out.view(-1, num_kv_heads, head_size),
            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
        )

    # FlashInfer writes K/V straight into the pool's buffers, viewed as
    # (-1, num_kv_heads * head_size), at the same locations the reference used.
    fused_set_kv_buffer_arg = (
        FusedSetKVBufferArg(
            value=inputs["value"],
            k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
            v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
            k_scale=None,
            v_scale=None,
            cache_loc=inputs["out_cache_loc"],
        )
        if save_kv_cache
        else None
    )
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
    )
    query_sglkernel_out, key_sglkernel_out = rope_sglkernel.forward_cuda(
        inputs["pos_ids"],
        query_sglkernel,
        key_sglkernel,
    )

    torch.testing.assert_close(
        query_ref_for_flashinfer_out, query_flashinfer_out, **tol
    )
    torch.testing.assert_close(key_ref_for_flashinfer_out, key_flashinfer_out, **tol)
    torch.testing.assert_close(
        query_ref_for_sglkernel_out, query_sglkernel_out, **tol
    )
    torch.testing.assert_close(key_ref_for_sglkernel_out, key_sglkernel_out, **tol)

    if save_kv_cache:
        for field in ["k_buffer", "v_buffer"]:
            x_ref = getattr(pool_ref_for_flashinfer, field)[0]
            x_flashinfer = getattr(pool_flashinfer, field)[0]
            torch.testing.assert_close(x_ref, x_flashinfer, **tol)
            # With atol=1e-2, assert_close can accept an exact zero on one side
            # against a small value on the other. Comparing the zero patterns
            # verifies that both pools had exactly the same slots written.
            nonzero_ref = x_ref != 0
            nonzero_flashinfer = x_flashinfer != 0
            assert torch.equal(nonzero_ref, nonzero_flashinfer)
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))