# Hanrui / sglang / sgl-kernel / tests / test_kvcacheio.py
# Lekr0's picture
# Add files using upload-large-folder tool
# a402b9b verified
import pytest
import torch
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_direct_lf_pf,
transfer_kv_all_layer_lf_ph,
transfer_kv_all_layer_mla,
transfer_kv_direct,
transfer_kv_per_layer,
transfer_kv_per_layer_direct_pf_lf,
transfer_kv_per_layer_mla,
)
from sglang.srt.utils import is_hip
def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
    """Reference copy: scatter selected rows of ``src_pool`` into ``dst_pool``.

    Rows ``src_pool[src_indices]`` are gathered, moved to the device that
    holds ``dst_pool`` and written in place at ``dst_pool[dst_indices]``.
    Nothing is returned; ``dst_pool`` is modified.
    """
    gathered = src_pool[src_indices]
    dst_pool[dst_indices] = gathered.to(dst_pool.device)
def ref_copy_with_indices_pf_direct(
    src_pool, dst_pool, src_indices, dst_indices, page_size, layer_id, lf_to_pf=False
):
    """Reference page-wise copy of one layer between two pool layouts.

    The index lists are walked in strides of ``page_size``; the first index
    of each stride determines the page number (``index // page_size``).

    * ``lf_to_pf=True``: read ``src_pool[layer_id][token_indices]`` and write
      the whole page to ``dst_pool[page][layer_id]``.
    * ``lf_to_pf=False``: read the whole page ``src_pool[page][layer_id]`` and
      write it to ``dst_pool[layer_id][token_indices]``.

    Data is moved to ``dst_pool.device`` before being stored; ``dst_pool`` is
    modified in place.
    """
    target_device = dst_pool.device
    for start in range(0, len(src_indices), page_size):
        stop = start + page_size
        if lf_to_pf:
            dst_page = dst_indices[start] // page_size
            page_data = src_pool[layer_id][src_indices[start:stop]]
            dst_pool[dst_page][layer_id] = page_data.to(target_device)
        else:
            src_page = src_indices[start] // page_size
            page_data = src_pool[src_page][layer_id]
            dst_pool[layer_id][dst_indices[start:stop]] = page_data.to(target_device)
def ref_copy_with_indices_page_head(
    src_pool,
    dst_pool,
    src_indices,
    dst_indices,
    page_size,
    layer_id,
    head_num,
    lf_to_ph=False,
):
    """Reference per-token, per-head copy of one layer between two layouts.

    A token index is split into ``page = index // page_size`` and
    ``slot = index % page_size``.

    * ``lf_to_ph=True``: read ``src_pool[layer_id][token][head]`` and write to
      ``dst_pool[page][head][slot][layer_id]``.
    * ``lf_to_ph=False``: read ``src_pool[page][head][slot][layer_id]`` and
      write to ``dst_pool[layer_id][token][head]``.

    Each value is moved to ``dst_pool.device`` before being stored;
    ``dst_pool`` is modified in place.
    """
    target_device = dst_pool.device
    num_tokens = len(src_indices)
    for head_id in range(head_num):
        for pos in range(num_tokens):
            src_token = src_indices[pos]
            dst_token = dst_indices[pos]
            if lf_to_ph:
                page = dst_token // page_size
                slot = dst_token % page_size
                value = src_pool[layer_id][src_token][head_id]
                dst_pool[page][head_id][slot][layer_id] = value.to(target_device)
            else:
                page = src_token // page_size
                slot = src_token % page_size
                value = src_pool[page][head_id][slot][layer_id]
                dst_pool[layer_id][dst_token][head_id] = value.to(target_device)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
@pytest.mark.parametrize("page_size", [1, 16, 64])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [10240])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("all_layers", [False, True])
def test_transfer_kv(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    all_layers: bool,
):
    """
    Tests the per-layer and all-layer transfer functions, treating tensors as
    memory pools.

    Pinned host pools of shape ``(num_layers, total_items_in_pool, item_size)``
    are copied into CUDA pools of the same shape through:

    * ``transfer_kv_per_layer[_mla]`` or ``transfer_kv_all_layer[_mla]``
      (depending on ``all_layers`` / ``is_mla``), and
    * ``transfer_kv_direct``.

    Both results are compared against ``ref_copy_with_indices``. Source and
    destination indices are whole pages drawn from disjoint halves of one
    random page permutation. Parameter combinations that request less than
    one full page return immediately.
    """
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        # Less than one full page requested: nothing to exercise. Checked
        # before any global state is touched so there is nothing to restore.
        return

    device = "cuda"
    num_layers = 4  # A small number of layers for pool creation
    item_bytes = item_size * dtype.itemsize

    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        # The page permutation and the source pools are drawn on the CPU, so
        # seed via torch.manual_seed rather than torch.cuda.manual_seed.
        torch.manual_seed(42)

        page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
        src_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[:num_pages_to_transfer]
            ]
        )
        src_indices_device = src_indices_host.to(device)
        dst_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
            ]
        )
        dst_indices_device = dst_indices_host.to(device)

        # Prepare memory pools based on whether it's an MLA case.
        if is_mla:
            src_pool_host = torch.randn(
                num_layers, total_items_in_pool, item_size
            ).pin_memory()
            dst_pool_ref = torch.zeros_like(src_pool_host).to(device)
            dst_pool_kernel = torch.zeros_like(dst_pool_ref)
            dst_pool_direct = torch.zeros_like(dst_pool_ref)
        else:
            src_k_pool = torch.randn(
                num_layers, total_items_in_pool, item_size
            ).pin_memory()
            src_v_pool = torch.randn(
                num_layers, total_items_in_pool, item_size
            ).pin_memory()
            dst_k_pool_ref = torch.zeros_like(src_k_pool).to(device)
            dst_v_pool_ref = torch.zeros_like(src_v_pool).to(device)
            dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
        torch.cuda.synchronize()

        # We will test the per-layer function on the first layer (index 0) of the pool.
        layer_idx_to_test = 0

        if is_mla:
            if not all_layers:
                ref_copy_with_indices(
                    src_pool_host[layer_idx_to_test],
                    dst_pool_ref[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                )
                transfer_kv_per_layer_mla(
                    src_pool_host[layer_idx_to_test],
                    dst_pool_kernel[layer_idx_to_test],
                    src_indices_device,
                    dst_indices_device,
                    item_size=item_bytes,
                )
                transfer_kv_direct(
                    [src_pool_host[layer_idx_to_test]],
                    [dst_pool_direct[layer_idx_to_test]],
                    src_indices_host,
                    dst_indices_device,
                    page_size=page_size,
                )
            else:
                for layer_id in range(num_layers):
                    ref_copy_with_indices(
                        src_pool_host[layer_id],
                        dst_pool_ref[layer_id],
                        src_indices_host,
                        dst_indices_device,
                    )
                # The all-layer kernel receives per-layer base addresses.
                src_layers_device = torch.tensor(
                    [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
                    dtype=torch.uint64,
                    device=device,
                )
                dst_layers_device = torch.tensor(
                    [
                        dst_pool_kernel[layer_id].data_ptr()
                        for layer_id in range(num_layers)
                    ],
                    dtype=torch.uint64,
                    device=device,
                )
                transfer_kv_all_layer_mla(
                    src_layers_device,
                    dst_layers_device,
                    src_indices_device,
                    dst_indices_device,
                    item_size=item_bytes,
                    num_layers=num_layers,
                )
                transfer_kv_direct(
                    [src_pool_host[layer_id] for layer_id in range(num_layers)],
                    [dst_pool_direct[layer_id] for layer_id in range(num_layers)],
                    src_indices_host,
                    dst_indices_device,
                    page_size=page_size,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
        else:
            if not all_layers:
                ref_copy_with_indices(
                    src_k_pool[layer_idx_to_test],
                    dst_k_pool_ref[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                )
                ref_copy_with_indices(
                    src_v_pool[layer_idx_to_test],
                    dst_v_pool_ref[layer_idx_to_test],
                    src_indices_host,
                    dst_indices_device,
                )
                transfer_kv_per_layer(
                    src_k_pool[layer_idx_to_test],
                    dst_k_pool_kernel[layer_idx_to_test],
                    src_v_pool[layer_idx_to_test],
                    dst_v_pool_kernel[layer_idx_to_test],
                    src_indices_device,
                    dst_indices_device,
                    item_size=item_bytes,
                )
                transfer_kv_direct(
                    [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
                    [
                        dst_k_pool_direct[layer_idx_to_test],
                        dst_v_pool_direct[layer_idx_to_test],
                    ],
                    src_indices_host,
                    dst_indices_device,
                    page_size=page_size,
                )
            else:
                for layer_id in range(num_layers):
                    ref_copy_with_indices(
                        src_k_pool[layer_id],
                        dst_k_pool_ref[layer_id],
                        src_indices_host,
                        dst_indices_device,
                    )
                    ref_copy_with_indices(
                        src_v_pool[layer_id],
                        dst_v_pool_ref[layer_id],
                        src_indices_host,
                        dst_indices_device,
                    )
                # The all-layer kernel receives per-layer base addresses.
                src_k_layers_device = torch.tensor(
                    [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
                    dtype=torch.uint64,
                    device=device,
                )
                src_v_layers_device = torch.tensor(
                    [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
                    dtype=torch.uint64,
                    device=device,
                )
                dst_k_layers_device = torch.tensor(
                    [
                        dst_k_pool_kernel[layer_id].data_ptr()
                        for layer_id in range(num_layers)
                    ],
                    dtype=torch.uint64,
                    device=device,
                )
                dst_v_layers_device = torch.tensor(
                    [
                        dst_v_pool_kernel[layer_id].data_ptr()
                        for layer_id in range(num_layers)
                    ],
                    dtype=torch.uint64,
                    device=device,
                )
                transfer_kv_all_layer(
                    src_k_layers_device,
                    dst_k_layers_device,
                    src_v_layers_device,
                    dst_v_layers_device,
                    src_indices_device,
                    dst_indices_device,
                    item_size=item_bytes,
                    num_layers=num_layers,
                )
                transfer_kv_direct(
                    [src_k_pool[layer_id] for layer_id in range(num_layers)]
                    + [src_v_pool[layer_id] for layer_id in range(num_layers)],
                    [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
                    + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
                    src_indices_host,
                    dst_indices_device,
                    page_size=page_size,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    finally:
        # Restore the global default dtype even when an assertion fails, so a
        # failing case cannot change the dtype seen by later tests.
        torch.set_default_dtype(original_dtype)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [128, 1024, 8192])
@pytest.mark.parametrize("page_size", [16, 64, 128])
@pytest.mark.parametrize("item_size", [256])
@pytest.mark.parametrize("total_items_in_pool", [20480])
@pytest.mark.parametrize("is_mla", [False, True])
@pytest.mark.parametrize("lf_to_pf", [False, True])
def test_transfer_kv_pf_direct(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    item_size: int,
    page_size: int,
    total_items_in_pool: int,
    is_mla: bool,
    lf_to_pf: bool,
):
    """
    Tests the direct transfers between the two pool layouts used here.

    * ``lf_to_pf=True``: CUDA pools of shape
      ``(num_layers, total_items_in_pool, item_size)`` are copied for all
      layers into pinned host pools of shape
      ``(total_pages_in_pool, num_layers, page_size, item_size)`` with
      ``transfer_kv_all_layer_direct_lf_pf``.
    * ``lf_to_pf=False``: the pinned host pools in the paged shape are copied
      for layer 0 into the CUDA pools with
      ``transfer_kv_per_layer_direct_pf_lf``.

    The transfers are issued on a dedicated CUDA stream and compared against
    ``ref_copy_with_indices_pf_direct``.
    """
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        # Less than one full page requested: nothing to exercise. Checked
        # before any global state is touched so there is nothing to restore.
        return

    device = "cuda"
    num_layers = 4

    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        # The page permutation and the pools are drawn on the CPU, so seed
        # via torch.manual_seed rather than torch.cuda.manual_seed.
        torch.manual_seed(42)
        test_stream = torch.cuda.Stream()

        page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
        src_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[:num_pages_to_transfer]
            ]
        )
        src_indices_device = src_indices_host.to(device)
        dst_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
            ]
        )
        dst_indices_device = dst_indices_host.to(device)

        # We will test the per-layer function on the first layer (index 0) of the pool.
        layer_idx_to_test = 0

        if lf_to_pf:
            if is_mla:
                src_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
                    device
                )
                src_pool_ptrs = [src_pool[i] for i in range(num_layers)]
                dst_pool_ref = torch.zeros(
                    total_pages_in_pool, num_layers, page_size, item_size
                ).pin_memory()
                dst_pool_direct = torch.zeros_like(dst_pool_ref)
                torch.cuda.synchronize()

                with torch.cuda.stream(test_stream):
                    transfer_kv_all_layer_direct_lf_pf(
                        src_pool_ptrs,
                        [dst_pool_direct],
                        src_indices_host,
                        dst_indices_host,
                        page_size,
                    )
                test_stream.synchronize()

                for i in range(num_layers):
                    ref_copy_with_indices_pf_direct(
                        src_pool,
                        dst_pool_ref,
                        src_indices_device,
                        dst_indices_host,
                        page_size,
                        i,
                        lf_to_pf=True,
                    )
                torch.cuda.synchronize()
                torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
            else:
                src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
                    device
                )
                src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
                src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(
                    device
                )
                src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
                dst_k_pool_ref = torch.zeros(
                    total_pages_in_pool, num_layers, page_size, item_size
                ).pin_memory()
                dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
                dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
                dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
                torch.cuda.synchronize()

                with torch.cuda.stream(test_stream):
                    transfer_kv_all_layer_direct_lf_pf(
                        src_k_pool_ptrs + src_v_pool_ptrs,
                        [dst_k_pool_direct, dst_v_pool_direct],
                        src_indices_host,
                        dst_indices_host,
                        page_size,
                    )
                test_stream.synchronize()

                for i in range(num_layers):
                    ref_copy_with_indices_pf_direct(
                        src_k_pool,
                        dst_k_pool_ref,
                        src_indices_device,
                        dst_indices_host,
                        page_size,
                        i,
                        lf_to_pf=True,
                    )
                    ref_copy_with_indices_pf_direct(
                        src_v_pool,
                        dst_v_pool_ref,
                        src_indices_device,
                        dst_indices_host,
                        page_size,
                        i,
                        lf_to_pf=True,
                    )
                torch.cuda.synchronize()
                torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
                torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
        else:
            if is_mla:
                src_pool = torch.randn(
                    total_pages_in_pool, num_layers, page_size, item_size
                ).pin_memory()
                dst_pool_ref = torch.zeros(
                    num_layers, total_items_in_pool, item_size
                ).to(device)
                dst_pool_direct = torch.zeros_like(dst_pool_ref)
                dst_pool_direct_ptrs = [dst_pool_direct[i] for i in range(num_layers)]
                torch.cuda.synchronize()

                with torch.cuda.stream(test_stream):
                    transfer_kv_per_layer_direct_pf_lf(
                        [src_pool],
                        [dst_pool_direct_ptrs[layer_idx_to_test]],
                        src_indices_host,
                        dst_indices_host,
                        layer_idx_to_test,
                        page_size,
                    )
                test_stream.synchronize()

                ref_copy_with_indices_pf_direct(
                    src_pool,
                    dst_pool_ref,
                    src_indices_host,
                    dst_indices_device,
                    page_size,
                    layer_idx_to_test,
                    lf_to_pf=False,
                )
                torch.cuda.synchronize()
                torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
            else:
                src_k_pool = torch.randn(
                    total_pages_in_pool, num_layers, page_size, item_size
                ).pin_memory()
                src_v_pool = torch.randn(
                    total_pages_in_pool, num_layers, page_size, item_size
                ).pin_memory()
                dst_k_pool_ref = torch.zeros(
                    num_layers, total_items_in_pool, item_size
                ).to(device)
                dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
                dst_k_pool_direct_ptrs = [
                    dst_k_pool_direct[i] for i in range(num_layers)
                ]
                dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
                dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
                dst_v_pool_direct_ptrs = [
                    dst_v_pool_direct[i] for i in range(num_layers)
                ]
                torch.cuda.synchronize()

                with torch.cuda.stream(test_stream):
                    transfer_kv_per_layer_direct_pf_lf(
                        [src_k_pool, src_v_pool],
                        [
                            dst_k_pool_direct_ptrs[layer_idx_to_test],
                            dst_v_pool_direct_ptrs[layer_idx_to_test],
                        ],
                        src_indices_host,
                        dst_indices_host,
                        layer_idx_to_test,
                        page_size,
                    )
                test_stream.synchronize()

                ref_copy_with_indices_pf_direct(
                    src_k_pool,
                    dst_k_pool_ref,
                    src_indices_host,
                    dst_indices_device,
                    page_size,
                    layer_idx_to_test,
                    lf_to_pf=False,
                )
                ref_copy_with_indices_pf_direct(
                    src_v_pool,
                    dst_v_pool_ref,
                    src_indices_host,
                    dst_indices_device,
                    page_size,
                    layer_idx_to_test,
                    lf_to_pf=False,
                )
                torch.cuda.synchronize()
                torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
                torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
    finally:
        # Restore the global default dtype even when an assertion fails, so a
        # failing case cannot change the dtype seen by later tests.
        torch.set_default_dtype(original_dtype)
@pytest.mark.skipif(is_hip(), reason="HIP is not supported for this test")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_items_to_transfer", [256, 1024])
@pytest.mark.parametrize("page_size", [16, 64, 128])
@pytest.mark.parametrize("item_size", [1024])
@pytest.mark.parametrize("head_num", [8, 16])
@pytest.mark.parametrize("total_items_in_pool", [4096])
@pytest.mark.parametrize("lf_to_ph", [False, True])
def test_transfer_kv_page_head(
    dtype: torch.dtype,
    num_items_to_transfer: int,
    page_size: int,
    item_size: int,
    head_num: int,
    total_items_in_pool: int,
    lf_to_ph: bool,
):
    """
    Tests the transfers between the per-layer and the page/head pool shapes.

    * ``lf_to_ph=True``: CUDA pools of shape
      ``(num_layers, total_items_in_pool, head_num, head_dim)`` are copied for
      all layers into pinned host pools of shape
      ``(total_pages_in_pool, head_num, page_size, num_layers, head_dim)`` with
      ``transfer_kv_all_layer_lf_ph``.
    * ``lf_to_ph=False``: the pinned host pools in the page/head shape are
      copied for layer 0 into the CUDA pools with
      ``transfer_kv_per_layer_ph_lf``.

    Results are compared against ``ref_copy_with_indices_page_head``.
    """
    total_pages_in_pool = total_items_in_pool // page_size
    num_pages_to_transfer = num_items_to_transfer // page_size
    if num_pages_to_transfer == 0:
        # Less than one full page requested: nothing to exercise. Checked
        # before any global state is touched so there is nothing to restore.
        return

    assert item_size % head_num == 0
    head_dim = item_size // head_num
    device = "cuda"
    num_layers = 4
    item_bytes = item_size * dtype.itemsize

    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        # The page permutation and the pools are drawn on the CPU, so seed
        # via torch.manual_seed rather than torch.cuda.manual_seed.
        torch.manual_seed(42)

        page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
        src_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[:num_pages_to_transfer]
            ]
        )
        src_indices_device = src_indices_host.to(device)
        dst_indices_host = torch.cat(
            [
                torch.arange(p * page_size, (p + 1) * page_size)
                for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
            ]
        )
        dst_indices_device = dst_indices_host.to(device)

        # We will test the per-layer function on the first layer (index 0) of the pool.
        layer_idx_to_test = 0

        if lf_to_ph:
            src_k_pool = torch.randn(
                num_layers, total_items_in_pool, head_num, head_dim
            ).to(device)
            src_v_pool = torch.randn(
                num_layers, total_items_in_pool, head_num, head_dim
            ).to(device)
            # The all-layer kernel receives per-layer base addresses.
            src_k_pool_ptrs = torch.tensor(
                [src_k_pool[i].data_ptr() for i in range(num_layers)],
                dtype=torch.uint64,
                device=device,
            )
            src_v_pool_ptrs = torch.tensor(
                [src_v_pool[i].data_ptr() for i in range(num_layers)],
                dtype=torch.uint64,
                device=device,
            )
            dst_k_pool_ref = torch.zeros(
                total_pages_in_pool, head_num, page_size, num_layers, head_dim
            ).pin_memory()
            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref).pin_memory()
            dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref).pin_memory()
            dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref).pin_memory()
            torch.cuda.synchronize()

            transfer_kv_all_layer_lf_ph(
                src_k_pool_ptrs,
                dst_k_pool_kernel,
                src_v_pool_ptrs,
                dst_v_pool_kernel,
                src_indices_device,
                dst_indices_device,
                item_bytes,
                item_bytes * num_layers,
                num_layers,
                page_size,
                head_num,
            )
            torch.cuda.synchronize()

            for i in range(num_layers):
                ref_copy_with_indices_page_head(
                    src_k_pool,
                    dst_k_pool_ref,
                    src_indices_device,
                    dst_indices_host,
                    page_size,
                    i,
                    head_num,
                    lf_to_ph=True,
                )
                ref_copy_with_indices_page_head(
                    src_v_pool,
                    dst_v_pool_ref,
                    src_indices_device,
                    dst_indices_host,
                    page_size,
                    i,
                    head_num,
                    lf_to_ph=True,
                )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
        else:
            from sgl_kernel.kvcacheio import transfer_kv_per_layer_ph_lf

            src_k_pool = torch.randn(
                total_pages_in_pool, head_num, page_size, num_layers, head_dim
            ).pin_memory()
            src_v_pool = torch.randn(
                total_pages_in_pool, head_num, page_size, num_layers, head_dim
            ).pin_memory()
            dst_k_pool_ref = torch.zeros(
                num_layers, total_items_in_pool, head_num, head_dim
            ).to(device)
            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
            dst_k_pool_kernel = torch.zeros_like(dst_k_pool_ref)
            dst_v_pool_kernel = torch.zeros_like(dst_v_pool_ref)
            dst_k_pool_kernel_ptrs = [dst_k_pool_kernel[i] for i in range(num_layers)]
            dst_v_pool_kernel_ptrs = [dst_v_pool_kernel[i] for i in range(num_layers)]
            torch.cuda.synchronize()

            transfer_kv_per_layer_ph_lf(
                src_k_pool,
                dst_k_pool_kernel_ptrs[layer_idx_to_test],
                src_v_pool,
                dst_v_pool_kernel_ptrs[layer_idx_to_test],
                src_indices_device,
                dst_indices_device,
                layer_idx_to_test,
                item_bytes,
                item_bytes * num_layers,
                page_size,
                head_num,
            )

            ref_copy_with_indices_page_head(
                src_k_pool,
                dst_k_pool_ref,
                src_indices_host,
                dst_indices_device,
                page_size,
                layer_idx_to_test,
                head_num,
                lf_to_ph=False,
            )
            ref_copy_with_indices_page_head(
                src_v_pool,
                dst_v_pool_ref,
                src_indices_host,
                dst_indices_device,
                page_size,
                layer_idx_to_test,
                head_num,
                lf_to_ph=False,
            )
            torch.cuda.synchronize()
            torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
            torch.testing.assert_close(dst_v_pool_kernel, dst_v_pool_ref)
    finally:
        # Restore the global default dtype even when an assertion fails, so a
        # failing case cannot change the dtype seen by later tests.
        torch.set_default_dtype(original_dtype)
if __name__ == "__main__":
    # Propagate pytest's return code so that running this file as a script
    # exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))