| """ | |
| Copyright 2023-2024 SGLang Team | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from sglang.srt.configs.mamba_utils import Mamba2CacheParams | |
| from sglang.srt.layers.attention.nsa import index_buf_accessor | |
| from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache | |
| from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter | |
| """ | |
| Memory pool. | |
| SGLang has two levels of memory pool. | |
| ReqToTokenPool maps a request to its token locations. | |
| TokenToKVPoolAllocator manages the indices to kv cache data. | |
| KVCache actually holds the physical kv cache. | |
| """ | |
| import abc | |
| import logging | |
| from contextlib import nullcontext | |
| from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE | |
| from sglang.srt.layers.radix_attention import RadixAttention | |
| from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2 | |
| if TYPE_CHECKING: | |
| from sglang.srt.managers.cache_controller import LayerDoneCounter | |
| from sglang.srt.managers.schedule_batch import Req | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of bytes in one GiB; used below to report buffer sizes in "GB".
GB = 1024 * 1024 * 1024

# Platform flags, evaluated once at import time.
_is_cuda = is_cuda()
_is_npu = is_npu()

if _is_npu:
    # Only imported when is_npu() reports an NPU platform.
    import torch_npu
def get_tensor_size_bytes(t: torch.Tensor) -> int:
    """Return the number of bytes taken by the elements of ``t``.

    The result is ``number of elements * bytes per element`` and is always a
    plain Python ``int`` (the previous ``np.prod``-based version returned a
    NumPy scalar, and a float for 0-dim tensors).

    Args:
        t: Tensor whose element storage should be measured.

    Returns:
        Size in bytes of ``t``'s elements, derived from its shape and dtype.
    """
    return t.numel() * t.element_size()
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations.

    The pool owns a ``(size, max_context_len)`` int32 tensor
    (``req_to_token``) and a Python list of free row indices
    (``free_slots``). Rows are handed out by ``alloc`` and returned by
    ``free``.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
    ):
        """Allocate the request-to-token table.

        Args:
            size: Number of request rows in the table.
            max_context_len: Number of token entries per request row.
            device: Device on which the table is allocated.
            enable_memory_saver: Forwarded to ``TorchMemorySaverAdapter.create``.
        """
        self.size = size
        self.max_context_len = max_context_len
        self.device = device

        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
        # The table is created inside the memory-saver region tagged as KV cache.
        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )

        self.free_slots = list(range(size))

    def write(self, indices, values):
        """Store ``values`` into ``req_to_token`` at ``indices``."""
        self.req_to_token[indices] = values

    def available_size(self):
        """Return the number of request rows that are currently free."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Take ``need_size`` rows from the front of the free list.

        Returns:
            The allocated row indices, or ``None`` when fewer than
            ``need_size`` rows are free (nothing is allocated in that case).
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one row index, or a list of row indices, to the free list."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Mark every row as free again."""
        self.free_slots = list(range(self.size))
| class MambaPool: | |
| class State: | |
| conv: torch.Tensor | |
| temporal: torch.Tensor | |
| def at_layer_idx(self, layer: int): | |
| return type(self)(**{k: v[layer] for k, v in vars(self).items()}) | |
| def mem_usage_bytes(self): | |
| return sum(get_tensor_size_bytes(t) for t in vars(self).values()) | |
| class SpeculativeState(State): | |
| intermediate_ssm: torch.Tensor | |
| intermediate_conv_window: torch.Tensor | |
| def __init__( | |
| self, | |
| *, | |
| size: int, | |
| cache_params: "Mamba2CacheParams", | |
| device: str, | |
| speculative_num_draft_tokens: Optional[int] = None, | |
| ): | |
| conv_state_shape = cache_params.shape.conv | |
| temporal_state_shape = cache_params.shape.temporal | |
| conv_dtype = cache_params.dtype.conv | |
| ssm_dtype = cache_params.dtype.temporal | |
| num_mamba_layers = len(cache_params.layers) | |
| # for disagg with nvlink | |
| self.enable_custom_mem_pool = get_bool_env_var( | |
| "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" | |
| ) | |
| if self.enable_custom_mem_pool: | |
| # TODO(shangming): abstract custom allocator class for more backends | |
| from mooncake.allocator import NVLinkAllocator | |
| allocator = NVLinkAllocator.get_allocator(self.device) | |
| self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) | |
| else: | |
| self.custom_mem_pool = None | |
| with ( | |
| torch.cuda.use_mem_pool(self.custom_mem_pool) | |
| if self.enable_custom_mem_pool | |
| else nullcontext() | |
| ): | |
| # assume conv_state = (dim, state_len) | |
| assert conv_state_shape[0] > conv_state_shape[1] | |
| conv_state = torch.zeros( | |
| size=(num_mamba_layers, size + 1) + conv_state_shape, | |
| dtype=conv_dtype, | |
| device=device, | |
| ) | |
| temporal_state = torch.zeros( | |
| size=(num_mamba_layers, size + 1) + temporal_state_shape, | |
| dtype=ssm_dtype, | |
| device=device, | |
| ) | |
| if speculative_num_draft_tokens is not None: | |
| # Cache intermediate SSM states per draft token during target verify | |
| # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] | |
| intermediate_ssm_state_cache = torch.zeros( | |
| size=( | |
| num_mamba_layers, | |
| size + 1, | |
| speculative_num_draft_tokens, | |
| temporal_state_shape[0], | |
| temporal_state_shape[1], | |
| temporal_state_shape[2], | |
| ), | |
| dtype=ssm_dtype, | |
| device="cuda", | |
| ) | |
| # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify | |
| # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] | |
| intermediate_conv_window_cache = torch.zeros( | |
| size=( | |
| num_mamba_layers, | |
| size + 1, | |
| speculative_num_draft_tokens, | |
| conv_state_shape[0], | |
| conv_state_shape[1], | |
| ), | |
| dtype=conv_dtype, | |
| device="cuda", | |
| ) | |
| self.mamba_cache = self.SpeculativeState( | |
| conv=conv_state, | |
| temporal=temporal_state, | |
| intermediate_ssm=intermediate_ssm_state_cache, | |
| intermediate_conv_window=intermediate_conv_window_cache, | |
| ) | |
| logger.info( | |
| f"Mamba Cache is allocated. " | |
| f"max_mamba_cache_size: {size}, " | |
| f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " | |
| f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " | |
| f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " | |
| f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB " | |
| ) | |
| else: | |
| self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) | |
| logger.info( | |
| f"Mamba Cache is allocated. " | |
| f"max_mamba_cache_size: {size}, " | |
| f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " | |
| f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " | |
| ) | |
| self.size = size | |
| self.device = device | |
| self.free_slots = torch.arange( | |
| self.size, dtype=torch.int64, device=self.device | |
| ) | |
| self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB | |
| self.num_mamba_layers = num_mamba_layers | |
| def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: | |
| assert isinstance(self.mamba_cache, self.SpeculativeState) | |
| return self.mamba_cache | |
| def mamba2_layer_cache(self, layer_id: int): | |
| return self.mamba_cache.at_layer_idx(layer_id) | |
| def available_size(self): | |
| return len(self.free_slots) | |
| def alloc(self, need_size: int) -> Optional[torch.Tensor]: | |
| if need_size > len(self.free_slots): | |
| return None | |
| select_index = self.free_slots[:need_size] | |
| self.free_slots = self.free_slots[need_size:] | |
| return select_index | |
| def free(self, free_index: torch.Tensor): | |
| if free_index.numel() == 0: | |
| return | |
| self.free_slots = torch.cat((self.free_slots, free_index)) | |
| self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[ | |
| :, free_index | |
| ] = 0 | |
| def clear(self): | |
| self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) | |
| def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): | |
| self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index] | |
| self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[ | |
| :, src_index | |
| ] | |
| return | |
| def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]: | |
| dst_index = self.alloc(1) | |
| if dst_index == None: | |
| return None | |
| self.copy_from(src_index, dst_index) | |
| return dst_index | |
| def get_contiguous_buf_infos(self): | |
| state_tensors = [ | |
| getattr(self.mamba_cache, field) for field in vars(self.mamba_cache) | |
| ] | |
| data_ptrs, data_lens, item_lens = [], [], [] | |
| for _, state_tensor in enumerate(state_tensors): | |
| data_ptrs += [ | |
| state_tensor[i].data_ptr() for i in range(self.num_mamba_layers) | |
| ] | |
| data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)] | |
| item_lens += [ | |
| state_tensor[i][0].nbytes for i in range(self.num_mamba_layers) | |
| ] | |
| return data_ptrs, data_lens, item_lens | |
class HybridReqToTokenPool(ReqToTokenPool):
    """Request-to-token pool that also assigns a Mamba cache slot per request.

    On top of the base request table, it owns a ``MambaPool`` and keeps
    ``req_index_to_mamba_index_mapping``, a tensor indexed by request slot
    that stores the Mamba slot used by that request.
    """

    def __init__(
        self,
        *,
        size: int,
        mamba_size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        cache_params: "Mamba2CacheParams",
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        """Create the request table and the Mamba pool.

        Args:
            size: Number of request slots.
            mamba_size: Number of Mamba cache slots.
            max_context_len: Token entries per request row.
            device: Device for all buffers.
            enable_memory_saver: Forwarded to the base class.
            cache_params: Mamba cache layout, forwarded to ``MambaPool``.
            speculative_num_draft_tokens: Forwarded to ``MambaPool``.
        """
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )

        self._init_mamba_pool(
            size=mamba_size,
            cache_params=cache_params,
            device=device,
            speculative_num_draft_tokens=speculative_num_draft_tokens,
        )

    def _init_mamba_pool(
        self,
        size: int,
        cache_params: "Mamba2CacheParams",
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        """Build the Mamba pool, the layer-id map and the request→Mamba mapping."""
        self.mamba_pool = MambaPool(
            size=size,
            cache_params=cache_params,
            device=device,
            speculative_num_draft_tokens=speculative_num_draft_tokens,
        )
        # Global layer id -> index of that layer inside the Mamba pool.
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}

        self.device = device
        # The mapping is indexed by *request slot* (see alloc/free/get_mamba_indices),
        # so it must have one entry per request slot, not per Mamba slot.
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
            self.size, dtype=torch.int32, device=self.device
        )

    # For chunk prefill req, we do not need to allocate mamba cache,
    # We could use allocated mamba cache instead.
    def alloc(
        self, need_size: int, reqs: Optional[List[Req]] = None
    ) -> Optional[List[int]]:
        """Allocate request slots and make sure every request has a Mamba slot.

        Requests whose ``mamba_pool_idx`` is already set keep that slot; the
        others get a fresh slot from the Mamba pool, which is stored back on
        the request.

        Args:
            need_size: Number of request slots to allocate.
            reqs: Requests receiving the slots.

        Returns:
            The allocated request slot indices, or ``None`` if the request
            table does not have ``need_size`` free rows.

        Raises:
            TypeError: If ``reqs`` is not provided.
            AssertionError: If the number of Mamba slots obtained does not
                match the number of request slots (e.g. the Mamba pool is
                exhausted). The request slots are released before raising.
        """
        if reqs is None:
            raise TypeError("HybridReqToTokenPool.alloc() requires `reqs`")

        select_index = super().alloc(need_size)
        if select_index is None:
            return None

        mamba_index = []
        for req in reqs:
            mid = req.mamba_pool_idx  # reused when already set (for radix cache)
            if mid is None:
                allocated = self.mamba_pool.alloc(1)
                # ``alloc`` returns None when the Mamba pool is exhausted; do
                # not index into it so the check below reports the real cause.
                if allocated is not None:
                    mid = allocated[0]
                    req.mamba_pool_idx = mid
            if mid is not None:
                mamba_index.append(mid)

        if len(select_index) != len(mamba_index):
            # Give the request rows back so the failed call does not leak them.
            super().free(select_index)
            raise AssertionError(
                "Not enough space for mamba cache, try to increase --max-mamba-cache-size."
            )

        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index

    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        """Return the Mamba slot recorded for each request slot in ``req_indices``."""
        return self.req_index_to_mamba_index_mapping[req_indices]

    def mamba2_layer_cache(self, layer_id: int):
        """Return the Mamba state of global layer ``layer_id`` (must be a Mamba layer)."""
        assert layer_id in self.mamba_map
        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])

    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
        """Return the Mamba pool's speculative state for all layers."""
        return self.mamba_pool.get_speculative_mamba2_params_all_layers()

    # For chunk prefill, we can not free mamba cache, we need use it in the future
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        """Release request slots and, optionally, their Mamba slots.

        Args:
            free_index: One request slot or a list of request slots.
            free_mamba_cache: When true, the Mamba slots mapped to
                ``free_index`` are returned to the Mamba pool as well.
        """
        if isinstance(free_index, int):
            free_index = [free_index]
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            self.mamba_pool.free(mamba_index)

    def clear(self):
        """Reset both the request table free list and the Mamba pool free list."""
        super().clear()
        self.mamba_pool.clear()
class KVCache(abc.ABC):
    """Base class for the objects that hold the physical KV cache.

    It stores the common configuration (size, page size, dtypes, layer range),
    sets up the optional custom CUDA memory pool, and declares the buffer
    accessors that concrete pools implement.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Record the cache configuration.

        Args:
            size: Stored as ``self.size``; concrete pools use it to size buffers.
            page_size: Stored as ``self.page_size``.
            dtype: Logical dtype of the cache. The two float8 dtypes are
                stored as ``torch.uint8`` (see ``store_dtype``).
            layer_num: Number of layers held by this cache.
            device: Device on which buffers are allocated.
            enable_memory_saver: Forwarded to ``TorchMemorySaverAdapter.create``.
            start_layer: First layer id; defaults to ``0``.
            end_layer: Last layer id; defaults to ``layer_num - 1``.
        """
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.layer_num = layer_num
        # Use explicit None checks: ``end_layer or ...`` would silently replace
        # a legitimate ``end_layer=0`` with the default.
        self.start_layer = start_layer if start_layer is not None else 0
        self.end_layer = end_layer if end_layer is not None else layer_num - 1
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
        self.mem_usage = 0

        # used for chunked cpu-offloading
        self.cpu_offloading_chunk_size = 8192

        # default state for optional layer-wise transfer control
        self.layer_transfer_counter = None

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            # TODO(shangming): abstract custom allocator class for more backends
            from mooncake.allocator import NVLinkAllocator

            allocator = NVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

    def _finalize_allocation_log(self, num_tokens: int):
        """Common logging and mem_usage computation for KV cache allocation.

        Supports both tuple (K, V) size returns and single KV size returns
        from ``get_kv_size_bytes``; ``self.mem_usage`` is set in GB.
        """
        kv_size_bytes = self.get_kv_size_bytes()
        if isinstance(kv_size_bytes, tuple):
            k_size, v_size = kv_size_bytes
            k_size_GB = k_size / GB
            v_size_GB = v_size / GB
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
            )
            self.mem_usage = k_size_GB + v_size_GB
        else:
            kv_size_GB = kv_size_bytes / GB
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
            )
            self.mem_usage = kv_size_GB

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the key buffer of ``layer_id``; implemented by subclasses."""
        raise NotImplementedError()

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the value buffer of ``layer_id``; implemented by subclasses."""
        raise NotImplementedError()

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(key_buffer, value_buffer)`` of ``layer_id``; implemented by subclasses."""
        raise NotImplementedError()

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        """Write ``cache_k`` / ``cache_v`` at ``loc`` for ``layer``; implemented by subclasses."""
        raise NotImplementedError()

    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
        """Install the counter used for optional layer-wise transfer control."""
        self.layer_transfer_counter = layer_transfer_counter

    def get_cpu_copy(self, indices):
        """Return a CPU copy of the cache at ``indices``; implemented by subclasses."""
        raise NotImplementedError()

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Load a CPU copy back into the cache at ``indices``; implemented by subclasses."""
        raise NotImplementedError()

    def maybe_get_custom_mem_pool(self):
        """Return the custom memory pool, or ``None`` when it is disabled."""
        return self.custom_mem_pool
class MHATokenToKVPool(KVCache):
    """KV cache that keeps one key buffer and one value buffer per layer.

    Each buffer has shape ``(size + page_size, head_num, head_dim)`` and is
    stored with ``store_dtype``.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
        enable_kv_cache_copy: bool = False,
    ):
        """Allocate the per-layer buffers and optionally prepare the KV copy kernel."""
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )
        self.head_num = head_num
        self.head_dim = head_dim

        self._create_buffers()

        self.device_module = torch.get_device_module(self.device)
        # A secondary stream is only created on CUDA.
        self.alt_stream = self.device_module.Stream() if _is_cuda else None

        if not enable_kv_cache_copy:
            self._kv_copy_config = None
        else:
            self._init_kv_copy_and_warmup()

        self._finalize_allocation_log(size)

    def _init_kv_copy_and_warmup(self):
        """Pick the tiling parameters for the KV copy kernel and launch it once."""
        # Heuristics for KV copy tiling
        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
        _KV_COPY_TILE_SIZE_LARGE = 512
        _KV_COPY_TILE_SIZE_MEDIUM = 256
        _KV_COPY_TILE_SIZE_SMALL = 128
        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
        _KV_COPY_NUM_WARPS_SMALL_TILE = 4

        stride_bytes = int(self.data_strides[0].item())

        # Largest threshold first; fall back to the small tile.
        bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
        for min_stride, tile_size in (
            (_KV_COPY_STRIDE_THRESHOLD_LARGE, _KV_COPY_TILE_SIZE_LARGE),
            (_KV_COPY_STRIDE_THRESHOLD_MEDIUM, _KV_COPY_TILE_SIZE_MEDIUM),
        ):
            if stride_bytes >= min_stride:
                bytes_per_tile = tile_size
                break

        if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM:
            num_warps = _KV_COPY_NUM_WARPS_SMALL_TILE
        else:
            num_warps = _KV_COPY_NUM_WARPS_LARGE_TILE

        self._kv_copy_config = {
            "bytes_per_tile": bytes_per_tile,
            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
            "num_warps": num_warps,
        }

        # Launch once with a single dummy location (0 -> 0).
        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])

        copy_all_layer_kv_cache_tiled[grid](
            self.data_ptrs,
            self.data_strides,
            dummy_loc,
            dummy_loc,
            1,
            1,
            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
            num_warps=self._kv_copy_config["num_warps"],
            num_stages=2,
        )

    def _create_buffers(self):
        """Allocate K/V buffers and the pointer/stride tables describing them."""
        # [size, head_num, head_dim] for each layer
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        buffer_shape = (self.size + self.page_size, self.head_num, self.head_dim)

        def _new_layer_buffers():
            return [
                torch.zeros(buffer_shape, dtype=self.store_dtype, device=self.device)
                for _ in range(self.layer_num)
            ]

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.enable_custom_mem_pool
                else nullcontext()
            ):
                self.k_buffer = _new_layer_buffers()
                self.v_buffer = _new_layer_buffers()

        # Raw device addresses of every layer buffer: all K layers, then all V layers.
        self.k_data_ptrs = torch.tensor(
            [buf.data_ptr() for buf in self.k_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.v_data_ptrs = torch.tensor(
            [buf.data_ptr() for buf in self.v_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
        # Bytes per token slot for each buffer, in the same order as data_ptrs.
        self.data_strides = torch.tensor(
            [
                np.prod(buf.shape[1:]) * buf.dtype.itemsize
                for buf in self.k_buffer + self.v_buffer
            ],
            device=self.device,
        )

    def _clear_buffers(self):
        """Drop the references to the K and V buffers."""
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        """Return ``(k_size_bytes, v_size_bytes)`` summed over all layers."""
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = sum(get_tensor_size_bytes(k_cache) for k_cache in self.k_buffer)
        v_size_bytes = sum(get_tensor_size_bytes(v_cache) for v_cache in self.v_buffer)
        return k_size_bytes, v_size_bytes

    # for disagg
    def get_contiguous_buf_infos(self):
        """Return data pointers, byte lengths and per-page byte lengths.

        Each list contains the entries of all key buffers followed by those
        of all value buffers.
        """
        # layer_num x [seq_len, head_num, head_dim]
        # layer_num x [page_num, page_size, head_num, head_dim]
        layer_ids = range(self.start_layer, self.start_layer + self.layer_num)
        buffers = [self._get_key_buffer(i) for i in layer_ids]
        buffers += [self._get_value_buffer(i) for i in layer_ids]

        kv_data_ptrs = [buf.data_ptr() for buf in buffers]
        kv_data_lens = [buf.nbytes for buf in buffers]
        kv_item_lens = [buf[0].nbytes * self.page_size for buf in buffers]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def get_cpu_copy(self, indices):
        """Copy the K/V entries at ``indices`` to CPU, chunk by chunk.

        Returns:
            A list with one entry per layer; each entry is a list of
            ``[k_cpu, v_cpu]`` pairs, one pair per chunk of ``indices``.
        """
        torch.cuda.synchronize()
        step = self.cpu_offloading_chunk_size
        kv_cache_cpu = []
        for layer_id in range(self.layer_num):
            layer_chunks = []
            kv_cache_cpu.append(layer_chunks)
            for begin in range(0, len(indices), step):
                chunk_indices = indices[begin : begin + step]
                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                layer_chunks.append([k_cpu, v_cpu])
        # Wait for the non-blocking copies before handing the tensors out.
        torch.cuda.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Write a CPU copy produced by ``get_cpu_copy`` back at ``indices``."""
        torch.cuda.synchronize()
        step = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            for begin in range(0, len(indices), step):
                chunk_indices = indices[begin : begin + step]
                chunk_entry = kv_cache_cpu[layer_id][begin // step]
                k_cpu = chunk_entry[0]
                v_cpu = chunk_entry[1]
                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
                self.k_buffer[layer_id][chunk_indices] = k_chunk
                self.v_buffer[layer_id][chunk_indices] = v_chunk
        torch.cuda.synchronize()

    def _get_key_buffer(self, layer_id: int):
        """Return the key buffer of ``layer_id`` viewed as ``self.dtype`` (no waiting)."""
        # for internal use of referencing
        key_buf = self.k_buffer[layer_id - self.start_layer]
        if self.store_dtype == self.dtype:
            return key_buf
        return key_buf.view(self.dtype)

    def get_key_buffer(self, layer_id: int):
        """Return the key buffer of ``layer_id`` after waiting on the transfer counter, if any."""
        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
        # it is supposed to be used only by attention backend not for information purpose
        # same applies to get_value_buffer and get_kv_buffer
        counter = self.layer_transfer_counter
        if counter is not None:
            counter.wait_until(layer_id - self.start_layer)
        return self._get_key_buffer(layer_id)

    def _get_value_buffer(self, layer_id: int):
        """Return the value buffer of ``layer_id`` viewed as ``self.dtype`` (no waiting)."""
        # for internal use of referencing
        value_buf = self.v_buffer[layer_id - self.start_layer]
        if self.store_dtype == self.dtype:
            return value_buf
        return value_buf.view(self.dtype)

    def get_value_buffer(self, layer_id: int):
        """Return the value buffer of ``layer_id`` after waiting on the transfer counter, if any."""
        counter = self.layer_transfer_counter
        if counter is not None:
            counter.wait_until(layer_id - self.start_layer)
        return self._get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        """Return ``(key_buffer, value_buffer)`` for ``layer_id``."""
        key_buf = self.get_key_buffer(layer_id)
        value_buf = self.get_value_buffer(layer_id)
        return key_buf, value_buf

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
        layer_id_override: Optional[int] = None,
    ):
        """Write ``cache_k`` / ``cache_v`` into the layer buffers at ``loc``.

        Args:
            layer: Attention layer; its ``layer_id`` is used unless overridden.
            loc: Token slots to write.
            cache_k: Keys to store. Divided in place by ``k_scale`` when its
                dtype differs from ``self.dtype`` and a scale is given.
            cache_v: Values to store; handled like ``cache_k`` with ``v_scale``.
            k_scale: Optional divisor applied before the dtype conversion.
            v_scale: Optional divisor applied before the dtype conversion.
            layer_id_override: When not ``None``, used instead of ``layer.layer_id``.
        """
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        layer_id = layer.layer_id if layer_id_override is None else layer_id_override

        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        buffer_idx = layer_id - self.start_layer
        if not (get_is_capture_mode() and self.alt_stream is not None):
            self.k_buffer[buffer_idx][loc] = cache_k
            self.v_buffer[buffer_idx][loc] = cache_v
            return

        # Overlap the copy of K and V cache for small batch size
        current_stream = self.device_module.current_stream()
        self.alt_stream.wait_stream(current_stream)
        self.k_buffer[buffer_idx][loc] = cache_k
        with self.device_module.stream(self.alt_stream):
            self.v_buffer[buffer_idx][loc] = cache_v
        current_stream.wait_stream(self.alt_stream)

    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
        """Copy the K/V entries of every layer from ``src_loc`` to ``tgt_loc``.

        Requires the pool to have been created with ``enable_kv_cache_copy=True``.
        """
        num_locs = tgt_loc.numel()
        if num_locs == 0:
            return

        assert (
            self._kv_copy_config is not None
        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"

        copy_cfg = self._kv_copy_config
        num_locs_upper = next_power_of_2(num_locs)
        launch_grid = (self.data_ptrs.numel(), copy_cfg["byte_tiles"])

        copy_all_layer_kv_cache_tiled[launch_grid](
            self.data_ptrs,
            self.data_strides,
            tgt_loc,
            src_loc,
            num_locs,
            num_locs_upper,
            BYTES_PER_TILE=copy_cfg["bytes_per_tile"],
            num_warps=copy_cfg["num_warps"],
            num_stages=2,
        )
class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers.

    Full-attention layers are served by an inner token-to-KV pool; the state
    of the linear-attention layers comes from the supplied ``mamba_pool``.
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        page_size: int,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
        mamba_pool: MambaPool,
    ):
        """Create the inner pool for the full-attention layers.

        Args:
            size: Size passed to the inner pool.
            dtype: Cache dtype.
            page_size: Page size passed to the inner pool.
            head_num: Number of heads per layer buffer.
            head_dim: Head dimension per layer buffer.
            full_attention_layer_ids: Global ids of the full-attention layers.
            enable_kvcache_transpose: Must be false (not supported yet).
            device: Device for the buffers.
            mamba_pool: Pool holding the linear-attention state.
        """
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose

        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = page_size
        # TODO support pp?
        self.start_layer = 0
        self.head_num = head_num
        self.head_dim = head_dim
        self.mamba_pool = mamba_pool

        # Pick the pool implementation for the current platform.
        TokenToKVPoolClass = AscendTokenToKVPool if _is_npu else MHATokenToKVPool
        self.full_kv_pool = TokenToKVPoolClass(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        # Global layer id -> index of that layer inside ``full_kv_pool``.
        self.full_attention_layer_id_mapping = {
            global_id: local_idx
            for local_idx, global_id in enumerate(full_attention_layer_ids)
        }
        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        """Return the ``(k_bytes, v_bytes)`` of the inner full-attention pool."""
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        """Return the buffer infos of the inner full-attention pool."""
        return self.full_kv_pool.get_contiguous_buf_infos()

    def get_state_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` of the Mamba pool buffers."""
        mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
            self.mamba_pool.get_contiguous_buf_infos()
        )
        return mamba_data_ptrs, mamba_data_lens, mamba_item_lens

    def maybe_get_custom_mem_pool(self):
        """Return the custom memory pool of the inner pool, if any."""
        return self.full_kv_pool.maybe_get_custom_mem_pool()

    def _transfer_full_attention_id(self, layer_id: int):
        """Translate a global layer id into the inner pool's layer index.

        Raises:
            ValueError: If ``layer_id`` is not a full-attention layer.
        """
        mapping = self.full_attention_layer_id_mapping
        if layer_id in mapping:
            return mapping[layer_id]
        raise ValueError(
            f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
        )

    def get_key_buffer(self, layer_id: int):
        """Return the key buffer of global full-attention layer ``layer_id``."""
        local_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_key_buffer(local_id)

    def get_value_buffer(self, layer_id: int):
        """Return the value buffer of global full-attention layer ``layer_id``."""
        local_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_value_buffer(local_id)

    def get_kv_buffer(self, layer_id: int):
        """Return ``(key, value)`` buffers of global full-attention layer ``layer_id``."""
        local_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_kv_buffer(local_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        """Write K/V for ``layer`` into the inner pool at ``loc``.

        The inner pool receives ``None`` as the layer object together with
        the translated layer index as ``layer_id_override``.
        """
        local_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=local_id,
        )

    def get_v_head_dim(self):
        """Return the last dimension of the first value buffer of the inner pool."""
        first_value_buffer = self.full_kv_pool.get_value_buffer(0)
        return first_value_buffer.shape[-1]
| class SWAKVPool(KVCache): | |
| """KV cache with separate pools for full and SWA attention layers.""" | |
| def __init__( | |
| self, | |
| size: int, | |
| size_swa: int, | |
| dtype: torch.dtype, | |
| head_num: int, | |
| head_dim: int, | |
| swa_attention_layer_ids: List[int], | |
| full_attention_layer_ids: List[int], | |
| enable_kvcache_transpose: bool, | |
| device: str, | |
| token_to_kv_pool_class: KVCache = MHATokenToKVPool, | |
| **kwargs, | |
| ): | |
| self.size = size | |
| self.size_swa = size_swa | |
| self.dtype = dtype | |
| self.head_num = head_num | |
| self.head_dim = head_dim | |
| self.device = device | |
| self.swa_layer_nums = len(swa_attention_layer_ids) | |
| self.full_layer_nums = len(full_attention_layer_ids) | |
| self.start_layer = 0 | |
| self.page_size = 1 | |
| kwargs["page_size"] = 1 | |
| kwargs["enable_memory_saver"] = False | |
| kwargs["head_num"] = head_num | |
| kwargs["head_dim"] = head_dim | |
| kwargs["device"] = device | |
| # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True | |
| assert not enable_kvcache_transpose | |
| # for disagg with nvlink | |
| self.enable_custom_mem_pool = get_bool_env_var( | |
| "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" | |
| ) | |
| if self.enable_custom_mem_pool: | |
| # TODO(shangming): abstract custom allocator class for more backends | |
| from mooncake.allocator import NVLinkAllocator | |
| allocator = NVLinkAllocator.get_allocator(self.device) | |
| self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) | |
| else: | |
| self.custom_mem_pool = None | |
| self.swa_kv_pool = token_to_kv_pool_class( | |
| size=size_swa, | |
| dtype=dtype, | |
| layer_num=self.swa_layer_nums, | |
| **kwargs, | |
| ) | |
| self.full_kv_pool = token_to_kv_pool_class( | |
| size=size, | |
| dtype=dtype, | |
| layer_num=self.full_layer_nums, | |
| **kwargs, | |
| ) | |
| self.layers_mapping: Dict[int, Tuple[int, bool]] = {} | |
| for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids): | |
| self.layers_mapping[global_layer_id] = (full_attn_layer_id, False) | |
| for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids): | |
| self.layers_mapping[global_layer_id] = (swa_layer_id, True) | |
| self.full_to_swa_index_mapping: Optional[torch.Tensor] = None | |
| k_size, v_size = self.get_kv_size_bytes() | |
| self.mem_usage = (k_size + v_size) / GB | |
| logger.info( | |
| f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}" | |
| ) | |
def get_kv_size_bytes(self):
    """Return ``(k_bytes, v_bytes)`` summed over the full and SWA sub-pools."""
    full_k_bytes, full_v_bytes = self.full_kv_pool.get_kv_size_bytes()
    swa_k_bytes, swa_v_bytes = self.swa_kv_pool.get_kv_size_bytes()
    total_k_bytes = full_k_bytes + swa_k_bytes
    total_v_bytes = full_v_bytes + swa_v_bytes
    return total_k_bytes, total_v_bytes
def get_contiguous_buf_infos(self):
    """Return ``(data_ptrs, data_lens, item_lens)`` of the full-attention pool.

    Only the full-attention sub-pool is reported here; the SWA sub-pool is
    exposed separately through ``get_state_buf_infos``.
    """
    data_ptrs, data_lens, item_lens = self.full_kv_pool.get_contiguous_buf_infos()
    return data_ptrs, data_lens, item_lens
def get_state_buf_infos(self):
    """Return ``(data_ptrs, data_lens, item_lens)`` of the SWA sub-pool."""
    data_ptrs, data_lens, item_lens = self.swa_kv_pool.get_contiguous_buf_infos()
    return data_ptrs, data_lens, item_lens
def get_key_buffer(self, layer_id: int):
    """Return the key buffer of global layer ``layer_id``.

    The global id is translated to a pool-local id and routed to either the
    SWA or the full-attention sub-pool.
    """
    pool_layer_id, uses_swa = self.layers_mapping[layer_id]
    owning_pool = self.swa_kv_pool if uses_swa else self.full_kv_pool
    return owning_pool.get_key_buffer(pool_layer_id)
def get_value_buffer(self, layer_id: int):
    """Return the value buffer of global layer ``layer_id``.

    The global id is translated to a pool-local id and routed to either the
    SWA or the full-attention sub-pool.
    """
    pool_layer_id, uses_swa = self.layers_mapping[layer_id]
    owning_pool = self.swa_kv_pool if uses_swa else self.full_kv_pool
    return owning_pool.get_value_buffer(pool_layer_id)
def get_kv_buffer(self, layer_id: int):
    """Return whatever the owning sub-pool's ``get_kv_buffer`` yields for this layer.

    The global ``layer_id`` is translated to a pool-local id and routed to
    either the SWA or the full-attention sub-pool.
    """
    pool_layer_id, uses_swa = self.layers_mapping[layer_id]
    owning_pool = self.swa_kv_pool if uses_swa else self.full_kv_pool
    return owning_pool.get_kv_buffer(pool_layer_id)
def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
    """Map full-pool slot indices to SWA-pool slot indices (as int32).

    Requires ``self.full_to_swa_index_mapping`` to have been set; it is used
    as a lookup table indexed by ``kv_indices``.
    """
    lookup_table = self.full_to_swa_index_mapping
    assert lookup_table is not None
    swa_indices = lookup_table[kv_indices]
    return swa_indices.to(torch.int32)
def set_kv_buffer(
    self,
    layer: RadixAttention,
    loc: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
):
    """Write K/V for ``layer`` into whichever sub-pool owns that layer.

    The sub-pool is called with ``layer=None`` and the pool-local layer id
    passed through ``layer_id_override``. For SWA layers, ``loc`` is first
    translated from full-pool to SWA-pool indices when a mapping is set.
    """
    pool_layer_id, uses_swa = self.layers_mapping[layer.layer_id]

    if not uses_swa:
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=pool_layer_id,
        )
        return

    if self.full_to_swa_index_mapping is not None:
        loc = self.translate_loc_from_full_to_swa(loc)
    self.swa_kv_pool.set_kv_buffer(
        None,
        loc,
        cache_k,
        cache_v,
        k_scale,
        v_scale,
        layer_id_override=pool_layer_id,
    )
class AscendTokenToKVPool(MHATokenToKVPool):
    """MHA KV pool that keeps K and V for all layers in one paged tensor.

    ``kv_buffer`` has shape
    ``(2, layer_num, num_pages, page_size, head_num, head_dim)``;
    ``k_buffer`` / ``v_buffer`` are views of its first and second halves.
    """

    def _create_buffers(self):
        """Allocate the single backing tensor and expose the K and V views."""
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # One extra page is allocated on top of ``size // page_size``.
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            # Continuous memory improves the efficiency of Ascend's transmission
            # backend, while other backends remain unchanged.
            num_pages = self.size // self.page_size + 1
            self.kv_buffer = torch.zeros(
                (
                    2,
                    self.layer_num,
                    num_pages,
                    self.page_size,
                    self.head_num,
                    self.head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.k_buffer = self.kv_buffer[0]
            self.v_buffer = self.kv_buffer[1]

    # for disagg
    def get_contiguous_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` for every layer buffer.

        Each list holds the K entries of all layers first, followed by the V
        entries. ``item_lens`` is the byte size of index 0 along the first
        dimension of each per-layer buffer.
        """
        layer_ids = range(self.start_layer, self.start_layer + self.layer_num)
        # Fetch each per-layer buffer once and derive all three lists from it.
        buffers = [self.get_key_buffer(i) for i in layer_ids] + [
            self.get_value_buffer(i) for i in layer_ids
        ]
        kv_data_ptrs = [buf.data_ptr() for buf in buffers]
        kv_data_lens = [buf.nbytes for buf in buffers]
        kv_item_lens = [buf[0].nbytes for buf in buffers]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
        layer_id_override: Optional[int] = None,
    ):
        """Write ``cache_k`` / ``cache_v`` into the slots given by ``loc``.

        Args:
            layer: attention layer whose ``layer_id`` selects the buffer; not
                accessed when ``layer_id_override`` is given.
            loc: slot indices forwarded to ``_npu_reshape_and_cache``.
            cache_k: key tensor to store.
            cache_v: value tensor to store.
            k_scale: optional divisor applied to ``cache_k`` before a dtype cast.
            v_scale: optional divisor applied to ``cache_v`` before a dtype cast.
            layer_id_override: layer id to use instead of ``layer.layer_id``.

        Note:
            When a dtype cast is needed and a scale is given, the caller's
            tensors are divided in place (``div_``) before being cast.
        """
        layer_id = (
            layer_id_override if layer_id_override is not None else layer.layer_id
        )
        # The buffers only hold this pool's ``layer_num`` layers, so the id has
        # to be rebased by ``start_layer`` like every other accessor does.
        local_layer_id = layer_id - self.start_layer

        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            # Reinterpret the bits as the storage dtype; no value conversion.
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        paged_shape = (-1, self.page_size, self.head_num, self.head_dim)
        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[local_layer_id].view(*paged_shape),
            value_cache=self.v_buffer[local_layer_id].view(*paged_shape),
            slot_indices=loc,
        )
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
    cache_k_nope_ptr,
    cache_k_rope_ptr,
    loc_ptr,
    buffer_stride: tl.constexpr,
    nope_stride: tl.constexpr,
    rope_stride: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Scatter ``[nope | rope]`` rows into the KV buffer at the rows in ``loc``.

    Program axis 0 selects the source row / ``loc`` entry, axis 1 selects a
    tile of ``BLOCK`` columns of the concatenated ``nope_dim + rope_dim`` row.
    Columns ``[0, nope_dim)`` come from ``cache_k_nope_ptr`` and columns
    ``[nope_dim, nope_dim + rope_dim)`` from ``cache_k_rope_ptr``.
    """
    # NOTE(review): this function is launched as ``kernel[grid](...)``, which
    # presumably requires it to be wrapped by ``@triton.jit``; no decorator is
    # visible on this definition - confirm it is present in the real file.
    pid_loc = tl.program_id(0)
    pid_blk = tl.program_id(1)

    base = pid_blk * BLOCK
    offs = base + tl.arange(0, BLOCK)
    total_dim = nope_dim + rope_dim
    mask = offs < total_dim

    # Widen to int64 so ``loc * buffer_stride`` cannot wrap for large buffers
    # when ``loc`` is stored as int32.
    loc = tl.load(loc_ptr + pid_loc).to(tl.int64)
    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs

    if base + BLOCK <= nope_dim:
        # Fast path: the whole tile lies in the nope region.
        src = tl.load(
            cache_k_nope_ptr + pid_loc * nope_stride + offs,
            mask=mask,
        )
    elif base >= nope_dim:
        # Fast path: the whole tile lies in the rope region.
        offs_rope = offs - nope_dim
        src = tl.load(
            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
            mask=mask,
        )
    else:
        # The tile straddles the nope/rope boundary (nope_dim is not a multiple
        # of BLOCK). Load each side under its own mask so no lane reads the
        # rope row at a negative offset, then merge per lane.
        is_nope = offs < nope_dim
        is_rope = mask & (offs >= nope_dim)
        src_nope = tl.load(
            cache_k_nope_ptr + pid_loc * nope_stride + offs,
            mask=is_nope,
            other=0,
        )
        src_rope = tl.load(
            cache_k_rope_ptr + pid_loc * rope_stride + (offs - nope_dim),
            mask=is_rope,
            other=0,
        )
        src = tl.where(is_nope, src_nope, src_rope)

    tl.store(dst_ptr, src, mask=mask)
def set_mla_kv_buffer_triton(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
    """Launch ``set_mla_kv_buffer_kernel`` to write nope+rope rows into ``kv_buffer``.

    The launch grid has one program per ``loc`` entry along axis 0 and one per
    128-column tile of the concatenated row along axis 1.
    """
    BLOCK = 128
    nope_dim = cache_k_nope.shape[-1]
    rope_dim = cache_k_rope.shape[-1]
    num_col_tiles = triton.cdiv(nope_dim + rope_dim, BLOCK)
    launch_grid = (loc.numel(), num_col_tiles)

    set_mla_kv_buffer_kernel[launch_grid](
        kv_buffer,
        cache_k_nope,
        cache_k_rope,
        loc,
        kv_buffer.stride(0),
        cache_k_nope.stride(0),
        cache_k_rope.stride(0),
        nope_dim,
        rope_dim,
        BLOCK=BLOCK,
    )
class MLATokenToKVPool(KVCache):
    """KV cache pool with a single combined buffer per layer.

    Each layer owns one tensor of shape ``(size + page_size, 1, kv_cache_dim)``.
    The key view is the whole row; the value view is the first
    ``kv_lora_rank`` columns of the same row.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
        use_nsa: bool = False,
        override_kv_cache_dim: Optional[int] = None,
    ):
        """Allocate the per-layer buffers.

        Args:
            size: number of token slots (``page_size`` extra rows are added).
            page_size: tokens per page.
            dtype: logical dtype of the cached values.
            kv_lora_rank: width of the leading (nope) part of each row.
            qk_rope_head_dim: width of the trailing (rope) part of each row.
            layer_num: number of layers held by this pool.
            device: device on which the buffers are allocated.
            enable_memory_saver: forwarded to the base class.
            start_layer: forwarded to the base class.
            end_layer: forwarded to the base class.
            use_nsa: enables the NSA code paths; allocation logging is then
                left to the subclass.
            override_kv_cache_dim: when not None, used as the per-row width
                instead of the derived value.
        """
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.use_nsa = use_nsa
        self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn

        if override_kv_cache_dim is not None:
            # Previously this argument was accepted but silently ignored.
            self.kv_cache_dim = override_kv_cache_dim
        elif self.use_nsa and self.nsa_kv_cache_store_fp8:
            # TODO do not hardcode
            # presumably the row width produced by quantize_k_cache - confirm
            self.kv_cache_dim = 656
        else:
            self.kv_cache_dim = kv_lora_rank + qk_rope_head_dim

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.custom_mem_pool
                else nullcontext()
            ):
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.kv_buffer = [
                    torch.zeros(
                        (size + page_size, 1, self.kv_cache_dim),
                        dtype=self.store_dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]

        # Device-side table of the raw base addresses of every layer buffer.
        self.data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.kv_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        if not use_nsa:
            # NSA will allocate indexer KV cache later and then log the total size
            self._finalize_allocation_log(size)

    def get_kv_size_bytes(self):
        """Return the total byte size of all per-layer buffers."""
        assert hasattr(self, "kv_buffer")
        return sum(get_tensor_size_bytes(kv_cache) for kv_cache in self.kv_buffer)

    # for disagg
    def get_contiguous_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` with one entry per layer.

        ``item_lens`` is the byte size of one row multiplied by ``page_size``.
        """
        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
        buffers = [self.kv_buffer[i] for i in range(self.layer_num)]
        kv_data_ptrs = [buf.data_ptr() for buf in buffers]
        kv_data_lens = [buf.nbytes for buf in buffers]
        kv_item_lens = [buf[0].nbytes * self.page_size for buf in buffers]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def get_key_buffer(self, layer_id: int):
        """Return the full buffer of ``layer_id``, viewed as ``self.dtype``."""
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            # Block until the counter reports this layer as ready.
            self.layer_transfer_counter.wait_until(local_id)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[local_id].view(self.dtype)
        return self.kv_buffer[local_id]

    def get_value_buffer(self, layer_id: int):
        """Return the first ``kv_lora_rank`` columns of the layer buffer."""
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            # Block until the counter reports this layer as ready.
            self.layer_transfer_counter.wait_until(local_id)

        value_view = self.kv_buffer[local_id][..., : self.kv_lora_rank]
        if self.store_dtype != self.dtype:
            return value_view.view(self.dtype)
        return value_view

    def get_kv_buffer(self, layer_id: int):
        """Return ``(key_buffer, value_buffer)`` for ``layer_id``."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Store ``cache_k`` rows at ``loc``; ``cache_v`` is not used.

        Not supported for the NSA FP8 layout (asserted).
        """
        local_id = layer.layer_id - self.start_layer
        assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            # Reinterpret the bits as the storage dtype; no value conversion.
            cache_k = cache_k.view(self.store_dtype)
        self.kv_buffer[local_id][loc] = cache_k

    def set_mla_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k_nope: torch.Tensor,
        cache_k_rope: torch.Tensor,
    ):
        """Store the nope and rope parts of the key at ``loc`` as one row."""
        local_id = layer.layer_id - self.start_layer

        if self.use_nsa and self.nsa_kv_cache_store_fp8:
            # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
            # TODO no need to cat
            cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
            cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
            cache_k = cache_k.view(self.store_dtype)
            self.kv_buffer[local_id][loc] = cache_k
            return

        if cache_k_nope.dtype != self.dtype:
            cache_k_nope = cache_k_nope.to(self.dtype)
            cache_k_rope = cache_k_rope.to(self.dtype)
        if self.store_dtype != self.dtype:
            cache_k_nope = cache_k_nope.view(self.store_dtype)
            cache_k_rope = cache_k_rope.view(self.store_dtype)

        set_mla_kv_buffer_triton(
            self.kv_buffer[local_id],
            loc,
            cache_k_nope,
            cache_k_rope,
        )

    def get_cpu_copy(self, indices):
        """Copy the rows at ``indices`` of every layer to the CPU in chunks.

        Returns a list with one entry per layer, each a list of CPU tensors
        (one per chunk of ``cpu_offloading_chunk_size`` indices).
        """
        torch.cuda.synchronize()
        kv_cache_cpu = []
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            layer_chunks = []
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                layer_chunks.append(
                    self.kv_buffer[layer_id][chunk_indices].to(
                        "cpu", non_blocking=True
                    )
                )
            kv_cache_cpu.append(layer_chunks)
        # The copies above are non-blocking; wait for them before returning.
        torch.cuda.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Write chunks produced by ``get_cpu_copy`` back to the rows at ``indices``."""
        torch.cuda.synchronize()
        chunk_size = self.cpu_offloading_chunk_size
        target_device = self.kv_buffer[0].device
        for layer_id in range(self.layer_num):
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
                assert kv_cpu.shape[0] == len(chunk_indices)
                kv_chunk = kv_cpu.to(target_device, non_blocking=True)
                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
        torch.cuda.synchronize()
class NSATokenToKVPool(MLATokenToKVPool):
    """MLA pool extended with a per-layer paged index-K buffer that embeds scales.

    Every layer gets a ``uint8`` tensor with one row per page. Within a row,
    the first ``page_size * index_head_dim`` bytes hold the fp8 data and the
    remaining bytes hold the float32 scales.
    """

    quant_block_size = 128
    index_k_with_scale_buffer_dtype = torch.uint8

    def __init__(
        self,
        size: int,
        page_size: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        index_head_dim: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Build the MLA buffers (with ``use_nsa=True``) and the index-K buffers.

        Asserts ``index_head_dim == 128`` and a page size of 64.
        """
        super().__init__(
            size,
            page_size,
            dtype,
            kv_lora_rank,
            qk_rope_head_dim,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
            use_nsa=True,
        )
        self.index_head_dim = index_head_dim
        # num head == 1 and head dim == 128 for index_k in NSA
        assert index_head_dim == 128
        assert self.page_size == 64

        # Layout:
        #     ref: test_attention.py :: kv_cache_cast_to_fp8
        #     shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
        #     data: for page i,
        #         * buf[i, :page_size * head_dim] for fp8 data
        #         * buf[i, page_size * head_dim:].view(float32) for scale
        num_pages = (size + page_size + 1) // self.page_size
        scale_bytes_per_token = index_head_dim // self.quant_block_size * 4
        bytes_per_page = self.page_size * (index_head_dim + scale_bytes_per_token)

        if self.custom_mem_pool:
            alloc_ctx = torch.cuda.use_mem_pool(self.custom_mem_pool)
        else:
            alloc_ctx = nullcontext()

        with alloc_ctx:
            per_layer_buffers = []
            for _ in range(layer_num):
                per_layer_buffers.append(
                    torch.zeros(
                        (num_pages, bytes_per_page),
                        dtype=self.index_k_with_scale_buffer_dtype,
                        device=device,
                    )
                )
            self.index_k_with_scale_buffer = per_layer_buffers

        self._finalize_allocation_log(size)

    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the raw index-K buffer of ``layer_id``, waiting for transfers first."""
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(local_id)
        return self.index_k_with_scale_buffer[local_id]

    def get_index_k_continuous(
        self,
        layer_id: int,
        seq_len: int,
        page_indices: torch.Tensor,
    ):
        """Delegate to ``index_buf_accessor.GetK`` on this layer's buffer."""
        layer_buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        return index_buf_accessor.GetK.execute(
            self, layer_buf, seq_len=seq_len, page_indices=page_indices
        )

    def get_index_k_scale_continuous(
        self,
        layer_id: int,
        seq_len: int,
        page_indices: torch.Tensor,
    ):
        """Delegate to ``index_buf_accessor.GetS`` on this layer's buffer."""
        layer_buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        return index_buf_accessor.GetS.execute(
            self, layer_buf, seq_len=seq_len, page_indices=page_indices
        )

    # TODO rename later (currently use diff name to avoid confusion)
    def set_index_k_and_scale_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        index_k: torch.Tensor,
        index_k_scale: torch.Tensor,
    ) -> None:
        """Delegate to ``index_buf_accessor.SetKAndS`` on this layer's buffer."""
        layer_buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        index_buf_accessor.SetKAndS.execute(
            pool=self,
            buf=layer_buf,
            loc=loc,
            index_k=index_k,
            index_k_scale=index_k_scale,
        )

    def get_state_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` of the index-K buffers.

        ``item_lens`` is the byte size of one row (one page) of each buffer.
        """
        layer_bufs = [
            self.index_k_with_scale_buffer[i] for i in range(self.layer_num)
        ]
        data_ptrs = [buf.data_ptr() for buf in layer_bufs]
        data_lens = [buf.nbytes for buf in layer_bufs]
        item_lens = [buf[0].nbytes for buf in layer_bufs]
        return data_ptrs, data_lens, item_lens

    def get_kv_size_bytes(self):
        """Return the MLA buffer bytes plus the bytes of all index-K buffers."""
        total_bytes = super().get_kv_size_bytes()
        for index_k_cache in self.index_k_with_scale_buffer:
            total_bytes += get_tensor_size_bytes(index_k_cache)
        return total_bytes
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
    """Paged MLA pool that keeps separate K, V and optional index-K tensors.

    ``k_buffer`` has last dimension ``kv_lora_rank`` and ``v_buffer`` has last
    dimension ``qk_rope_head_dim``. Both are shaped
    ``(layer_num, num_pages, page_size, 1, dim)``. ``index_k_buffer`` exists
    only when ``index_head_dim`` is not None.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        index_head_dim: Optional[int],
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Allocate the paged buffers.

        ``MLATokenToKVPool.__init__`` is deliberately bypassed: the base-class
        initializer above it is called directly, so the combined MLA buffer is
        never allocated.
        """
        super(MLATokenToKVPool, self).__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.index_head_dim = index_head_dim

        self.custom_mem_pool = None

        # One extra page is allocated on top of ``size // page_size``.
        num_pages = self.size // self.page_size + 1

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = torch.zeros(
                (layer_num, num_pages, self.page_size, 1, self.kv_lora_rank),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.v_buffer = torch.zeros(
                (layer_num, num_pages, self.page_size, 1, self.qk_rope_head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            if self.index_head_dim is not None:
                self.index_k_buffer = torch.zeros(
                    (layer_num, num_pages, self.page_size, 1, self.index_head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )

        self._finalize_allocation_log(size)

    def get_kv_size_bytes(self):
        """Return the total bytes of K, V and (if present) index-K buffers."""
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        kv_size_bytes = 0
        # Iterating a stacked tensor yields one per-layer slice at a time.
        for k_cache in self.k_buffer:
            kv_size_bytes += get_tensor_size_bytes(k_cache)
        for v_cache in self.v_buffer:
            kv_size_bytes += get_tensor_size_bytes(v_cache)
        if self.index_head_dim is not None:
            assert hasattr(self, "index_k_buffer")
            for index_k_cache in self.index_k_buffer:
                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
        return kv_size_bytes

    def get_kv_buffer(self, layer_id: int):
        """Return ``(k_buffer, v_buffer)`` slices of ``layer_id`` in storage dtype."""
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(local_id)
        return (
            self.k_buffer[local_id],
            self.v_buffer[local_id],
        )

    def get_key_buffer(self, layer_id: int):
        """Return the K slice of ``layer_id``, viewed as ``self.dtype``."""
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(local_id)

        if self.store_dtype != self.dtype:
            return self.k_buffer[local_id].view(self.dtype)
        return self.k_buffer[local_id]

    def get_value_buffer(self, layer_id: int):
        """Return the V slice of ``layer_id``, viewed as ``self.dtype``."""
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(local_id)

        if self.store_dtype != self.dtype:
            return self.v_buffer[local_id].view(self.dtype)
        return self.v_buffer[local_id]

    def get_index_k_buffer(self, layer_id: int):
        """Return the index-K slice of ``layer_id``, viewed as ``self.dtype``.

        Only valid when the pool was built with a non-None ``index_head_dim``.
        """
        local_id = layer_id - self.start_layer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(local_id)

        if self.store_dtype != self.dtype:
            return self.index_k_buffer[local_id].view(self.dtype)
        return self.index_k_buffer[local_id]

    # for disagg
    def get_contiguous_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` for all layer buffers.

        Each list holds the K entries of every layer, then the V entries, then
        (when ``index_head_dim`` is set) the index-K entries. ``item_lens`` is
        the byte size of one page of the corresponding buffer.
        """
        layer_ids = range(self.layer_num)
        buffers = [self.k_buffer[i] for i in layer_ids] + [
            self.v_buffer[i] for i in layer_ids
        ]
        if self.index_head_dim is not None:
            buffers += [self.index_k_buffer[i] for i in layer_ids]

        kv_data_ptrs = [buf.data_ptr() for buf in buffers]
        kv_data_lens = [buf.nbytes for buf in buffers]
        kv_item_lens = [buf[0].nbytes for buf in buffers]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: Optional[torch.Tensor],
    ):
        """Scatter K and V rows into the paged buffers at ``loc``.

        When ``cache_v`` is None, ``cache_k`` is treated as the packed
        ``[kv_lora_rank | qk_rope_head_dim]`` tensor and split along its last
        dimension.
        """
        local_id = layer.layer_id - self.start_layer

        # Unpack first: the dtype handling below touches ``cache_v`` and would
        # fail on None if the split happened afterwards.
        if cache_v is None:
            cache_k, cache_v = cache_k.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            # Reinterpret the bits as the storage dtype; no value conversion.
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        slot_indices = loc.view(-1, 1)
        torch_npu.npu_scatter_nd_update_(
            self.k_buffer[local_id].view(-1, 1, self.kv_lora_rank),
            slot_indices,
            cache_k.view(-1, 1, self.kv_lora_rank),
        )
        torch_npu.npu_scatter_nd_update_(
            self.v_buffer[local_id].view(-1, 1, self.qk_rope_head_dim),
            slot_indices,
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )

    def set_index_k_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        """Scatter ``index_k`` rows into the index-K buffer of ``layer_id`` at ``loc``."""
        if index_k.dtype != self.dtype:
            index_k = index_k.to(self.dtype)

        if self.store_dtype != self.dtype:
            index_k = index_k.view(self.store_dtype)

        torch_npu.npu_scatter_nd_update_(
            self.index_k_buffer[layer_id - self.start_layer].view(
                -1, 1, self.index_head_dim
            ),
            loc.view(-1, 1),
            index_k.view(-1, 1, self.index_head_dim),
        )
class DoubleSparseTokenToKVPool(KVCache):
    """KV pool that stores a per-token label tensor next to K and V.

    Every layer owns three tensors indexed by the same slot ids: ``k_buffer``
    and ``v_buffer`` with ``head_dim`` channels, and ``label_buffer`` with
    ``heavy_channel_num`` channels.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Allocate K, V and label buffers for ``layer_num`` layers."""
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        # All three buffers are written with the same ``loc`` in
        # ``set_kv_buffer``, so they must expose the same number of slots.
        num_slots = size + page_size

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.enable_custom_mem_pool
                else nullcontext()
            ):
                # [size, head_num, head_dim] for each layer
                self.k_buffer = [
                    torch.zeros(
                        (num_slots, head_num, head_dim),
                        dtype=dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]
                self.v_buffer = [
                    torch.zeros(
                        (num_slots, head_num, head_dim),
                        dtype=dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]

                # [size, head_num, heavy_channel_num] for each layer
                # Previously sized ``size + 1``, which is too small for the
                # slot ids that fit K/V when ``page_size > 1``.
                self.label_buffer = [
                    torch.zeros(
                        (num_slots, head_num, heavy_channel_num),
                        dtype=dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]

    def get_key_buffer(self, layer_id: int):
        """Return the K tensor of ``layer_id``."""
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        """Return the V tensor of ``layer_id``."""
        return self.v_buffer[layer_id - self.start_layer]

    def get_label_buffer(self, layer_id: int):
        """Return the label tensor of ``layer_id``."""
        return self.label_buffer[layer_id - self.start_layer]

    def get_kv_buffer(self, layer_id: int):
        """Return ``(k_buffer, v_buffer)`` of ``layer_id``."""
        local_id = layer_id - self.start_layer
        return (
            self.k_buffer[local_id],
            self.v_buffer[local_id],
        )

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        """Write K, V and label rows for ``layer`` at the slots in ``loc``."""
        # NOTE(Andy): ignore the dtype check
        local_id = layer.layer_id - self.start_layer
        self.k_buffer[local_id][loc] = cache_k
        self.v_buffer[local_id][loc] = cache_v
        self.label_buffer[local_id][loc] = cache_label
def copy_all_layer_kv_cache_tiled(
    data_ptrs,
    strides,
    tgt_loc_ptr,
    src_loc_ptr,
    num_locs,
    num_locs_upper: tl.constexpr,
    BYTES_PER_TILE: tl.constexpr,
):
    """2D tiled kernel. Safe for in-place copy.

    Program axis 0 picks one buffer (its base address from ``data_ptrs`` and
    its per-row byte stride from ``strides``); axis 1 picks a tile of
    ``BYTES_PER_TILE`` bytes within a row. For every valid location pair the
    tile is copied from row ``src`` to row ``tgt`` of that buffer. All rows
    are loaded before any row is stored.
    """
    buf_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    row_bytes = tl.load(strides + buf_id)
    # The stored address is reinterpreted as a byte pointer.
    buf_base = tl.load(data_ptrs + buf_id)
    buf_base = tl.cast(buf_base, tl.pointer_type(tl.uint8))

    byte_off = tile_id * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
    in_row = byte_off < row_bytes
    tl.multiple_of(byte_off, 16)

    # ``num_locs_upper`` is the compile-time upper bound on ``num_locs``.
    lane = tl.arange(0, num_locs_upper)
    lane_valid = lane < num_locs

    src_rows = tl.load(src_loc_ptr + lane, mask=lane_valid, other=0)
    tgt_rows = tl.load(tgt_loc_ptr + lane, mask=lane_valid, other=0)

    read_ptr = buf_base + src_rows[:, None] * row_bytes + byte_off[None, :]
    write_ptr = buf_base + tgt_rows[:, None] * row_bytes + byte_off[None, :]

    tile_mask = lane_valid[:, None] & in_row[None, :]
    payload = tl.load(read_ptr, mask=tile_mask)
    tl.store(write_ptr, payload, mask=tile_mask)
Xet Storage Details
- Size:
- 64.9 kB
- Xet hash:
- fb54988c9b8d94b9d261d387fb03a90fb0159645620ba30b00015951f13d5052
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.