| import logging | |
| import os | |
| import torch | |
| import torch.distributed | |
| from aibrix_kvcache.common.absl_logging import log_every_n_seconds | |
| from aibrix_kvcache_storage import AibrixKVCacheStorage | |
| from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig | |
| from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool | |
| from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost | |
# Configure the root logger once at import time so INFO-and-above records are
# emitted with a timestamp and level prefix, then bind this module's logger.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def setup():
    """Export the fixed process-group settings into ``os.environ``.

    Sets ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``,
    overwriting any values that are already present.
    """
    os.environ.update(
        {
            "RANK": "0",
            "WORLD_SIZE": "1",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": "63886",
        }
    )
class AIBrixKVCacheStorageTest:
    """Round-trip smoke test for ``AibrixKVCacheStorage``."""

    def test_with_page_size(self):
        """Write, read back and probe key existence for page sizes 1 and 2.

        For each page size a fresh ``MHATokenToKVPool``, host pool and
        ``AibrixKVCacheStorage`` are built. Random float16 tensors are stored
        under the keys ``hash0`` .. ``hash3``, read back in full and then for
        the second half of the keys only, and compared with ``torch.equal``.

        The storage under test is kept on ``self.aibrix_kvcache``.

        Raises:
            AssertionError: If an existence count, the ``batch_set`` result or
                a read-back tensor differs from what was written.
        """
        config = HiCacheStorageConfig(
            tp_rank=0,
            tp_size=1,
            is_mla_model=False,
            is_page_first_layout=True,
            model_name="test",
        )

        # These do not depend on the page size, so compute them once.
        batch_size = 2
        head_num = 1
        layer_num = 64
        head_dim = 128
        query_length = batch_size * 2
        partial = batch_size

        for page_size in range(1, 3):
            logger.info("page_size: %s", page_size)

            # Positional arguments are passed exactly as before; their meaning
            # is defined by MHATokenToKVPool / MHATokenToKVPoolHost.
            kv_cache = MHATokenToKVPool(
                1024,
                page_size,
                torch.float16,
                head_num,
                head_dim,
                layer_num,
                "cpu",
                False,
                0,
                layer_num,
            )
            mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
            self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)

            target_shape = (2, layer_num, page_size, head_num, head_dim)
            rand_tensor = [
                torch.rand(target_shape, dtype=torch.float16)
                for _ in range(query_length)
            ]
            keys = ["hash" + str(i) for i in range(query_length)]
            partial_keys = keys[batch_size:query_length]

            # Nothing has been written yet, so no key may be reported.
            assert self.aibrix_kvcache.batch_exists(keys) == 0
            assert self.aibrix_kvcache.batch_set(keys, rand_tensor)

            # Destination buffers start with unrelated random data so a
            # silent no-op read cannot pass the comparison below.
            get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(query_length)
            ]
            self.aibrix_kvcache.batch_get(keys, get_tensor)
            for fetched, written in zip(get_tensor, rand_tensor):
                assert torch.equal(fetched, written.flatten())

            # Assert on the stored result instead of discarding it and
            # querying the storage a second time.
            ret = self.aibrix_kvcache.batch_exists(keys)
            assert ret == query_length
            assert self.aibrix_kvcache.batch_exists(partial_keys) == partial

            partial_get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(partial)
            ]
            self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
            # partial_keys is the second half of keys, so compare against the
            # second half of the written tensors.
            for fetched, written in zip(partial_get_tensor, rand_tensor[partial:]):
                assert torch.equal(fetched, written.flatten())

            # NOTE(review): the source indentation was lost; this call is
            # assumed to run once per page size - confirm.
            log_every_n_seconds(
                logger,
                logging.INFO,
                self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
                1,
            )
if __name__ == "__main__":
    # The environment variables must be in place before the test builds
    # its storage objects.
    setup()
    AIBrixKVCacheStorageTest().test_with_page_size()
Xet Storage Details
- Size: 3.42 kB
- Xet hash: dd5cd8bcc8dbdfcccfded65541a5c68d5806d91a8a07ee5a94767dfbc129117a
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.