| import abc | |
| import logging | |
| import threading | |
| from functools import wraps | |
| from typing import Optional | |
| import psutil | |
| import torch | |
| from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool | |
| from sglang.srt.utils import is_npu, is_xpu | |
| _is_npu = is_npu() | |
| _is_xpu = is_xpu() | |
| if not (_is_npu or _is_xpu): | |
| from sgl_kernel.kvcacheio import ( | |
| transfer_kv_all_layer, | |
| transfer_kv_all_layer_direct_lf_pf, | |
| transfer_kv_all_layer_lf_pf, | |
| transfer_kv_all_layer_mla, | |
| transfer_kv_all_layer_mla_lf_pf, | |
| transfer_kv_direct, | |
| transfer_kv_per_layer, | |
| transfer_kv_per_layer_direct_pf_lf, | |
| transfer_kv_per_layer_mla, | |
| transfer_kv_per_layer_mla_pf_lf, | |
| transfer_kv_per_layer_pf_lf, | |
| ) | |
| logger = logging.getLogger(__name__) | |
def synchronized(func):
    """Decorator that runs a method while holding the instance's ``self.lock``.

    The decorated callable must be an instance method of an object that
    exposes a context-manager ``lock`` attribute. The lock is acquired before
    the call and released when the call returns or raises.

    Args:
        func: The method to wrap.

    Returns:
        A wrapper with the same name, docstring and signature metadata as
        ``func``.
    """

    # Preserve __name__/__doc__ so logs, debuggers and introspection still see
    # the wrapped method rather than a generic "wrapper".
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper
class HostKVCache(abc.ABC):
    """Base class for a host-side KV cache pool paired with a device KV pool.

    The constructor sizes the pool (in tokens), checks that enough host memory
    is available, asks the subclass to allocate the backing buffer and then
    initialises a simple free-slot allocator. Subclasses provide the buffer
    layout and the host<->device transfer routines.
    """

    def __init__(
        self,
        device_pool: "KVCache",
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        layout: str,
        pin_memory: bool,
        device: str,
    ):
        """Size, validate and allocate the host pool.

        Args:
            device_pool: The device KV pool this host pool backs.
            host_to_device_ratio: Host pool size as a multiple of
                ``device_pool.size``; used only when ``host_size <= 0``.
            host_size: Requested host pool size in units of 1e9 bytes; when
                positive it takes precedence over ``host_to_device_ratio``.
            page_size: Number of tokens per page; the pool size is rounded
                down to a multiple of this value.
            layout: Buffer layout name, interpreted by the subclass.
            pin_memory: Passed through to the subclass buffer allocation.
            device: Device on which host-side tensors are created.

        Raises:
            AssertionError: If the resulting pool is not larger than the
                device pool.
            ValueError: If the requested bytes exceed the available host
                memory minus a 10 GiB reserve.
        """
        self.device_pool = device_pool
        self.page_size = page_size
        self.layout = layout
        self.pin_memory = pin_memory
        self.device = device

        self.dtype = device_pool.store_dtype
        # Subclass hook; it may also cache model dimensions on ``self``.
        self.size_per_token = self.get_size_per_token()
        if host_size > 0:
            self.size = int(host_size * 1e9 // self.size_per_token)
        else:
            self.size = int(device_pool.size * host_to_device_ratio)
        # Align the host memory pool size to the page size
        self.size = self.size - (self.size % self.page_size)
        self.page_num = self.size // self.page_size
        self.start_layer = device_pool.start_layer
        self.end_layer = device_pool.end_layer

        assert (
            self.size > device_pool.size
        ), "The host memory should be larger than the device memory with the current protocol"

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        ten_gb = 10 * (1024**3)
        available_bytes = host_mem.available - ten_gb
        if requested_bytes > available_bytes:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{available_bytes / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        else:
            logger.info(
                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
            )

        self.kv_buffer = self.init_kv_buffer()

        # A lock for synchronized operations on memory allocation and state transitions.
        # It is re-entrant because ``alloc`` calls ``available_size`` while
        # already holding it.
        self.lock = threading.RLock()
        self.clear()

    def get_size_per_token(self):
        """Return the number of bytes one token occupies across all layers."""
        raise NotImplementedError()

    def init_kv_buffer(self):
        """Allocate and return the host tensor that backs the pool."""
        raise NotImplementedError()

    def load_to_device_per_layer(
        self, device_pool, host_indices, device_indices, layer_id, io_backend
    ) -> None:
        """
        Load KV data from the host memory pool to the device memory pool for a specific layer.
        """
        raise NotImplementedError()

    def backup_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ) -> None:
        """
        Backup KV data from the device memory pool to the host memory pool for all layers.
        """
        raise NotImplementedError()

    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
        """
        Get a flat data page from the host memory pool.
        """
        raise NotImplementedError()

    def get_dummy_flat_data_page(self) -> torch.Tensor:
        """
        Get a dummy flat data page from the host memory pool.
        This is used for prefetching or initializing empty pages.
        """
        raise NotImplementedError()

    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
        """
        Set a flat data page to the host memory pool.
        """
        raise NotImplementedError()

    def clear(self):
        """Reset the allocator so that every slot in the pool is free."""
        with self.lock:
            # Initialize memory states and tracking structures.
            self.mem_state = torch.zeros(
                (self.size,), dtype=torch.uint8, device=self.device
            )
            self.free_slots = torch.arange(self.size, dtype=torch.int64)

    def available_size(self):
        """Return the number of token slots currently free."""
        with self.lock:
            return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        """Take ``need_size`` slots from the front of the free list.

        Args:
            need_size: Number of token slots; must be a multiple of
                ``page_size``.

        Returns:
            A tensor of slot indices, or ``None`` when fewer than
            ``need_size`` slots are free.
        """
        assert (
            need_size % self.page_size == 0
        ), "The requested size should be a multiple of the page size."
        # The check and the two slices must happen atomically; otherwise two
        # concurrent callers could be handed the same slots.
        with self.lock:
            if need_size > self.available_size():
                return None

            select_index = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return select_index

    def free(self, indices: torch.Tensor) -> int:
        """Return ``indices`` to the free list and report how many were freed."""
        # Read-modify-write of ``free_slots``; guard it like ``alloc``.
        with self.lock:
            self.free_slots = torch.cat([self.free_slots, indices])
            return len(indices)
class MHATokenToKVPoolHost(HostKVCache):
    """Host KV pool for multi-head attention, storing separate K and V planes.

    The backing tensor ``kv_buffer`` has a leading dimension of 2 (index 0 is
    K, index 1 is V). The remaining dimensions depend on ``layout``:

    * ``"layer_first"``: ``(layer_num, size, head_num, head_dim)``
    * ``"page_first"``: ``(size, layer_num, head_num, head_dim)``
    * ``"page_first_direct"``:
      ``(page_num, layer_num, page_size, head_num, head_dim)``
    """

    device_pool: MHATokenToKVPool

    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        layout: str,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Allocate the pool and cache per-index K/V views and their addresses.

        See ``HostKVCache.__init__`` for the meaning of the arguments.
        """
        super().__init__(
            device_pool,
            host_to_device_ratio,
            host_size,
            page_size,
            layout,
            pin_memory,
            device,
        )
        # Views obtained by indexing the first dimension of the K/V planes
        # with 0..layer_num-1, plus their raw addresses packed into uint64
        # tensors placed on the device pool's device.
        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
        self.k_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_data_refs],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )
        self.v_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.v_data_refs],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )

    def get_size_per_token(self):
        """Return bytes per token for K and V together across all layers.

        Also caches ``head_num``, ``head_dim`` and ``layer_num`` from the
        device pool on ``self``.
        """
        self.head_num = self.device_pool.head_num
        self.head_dim = self.device_pool.head_dim
        self.layer_num = self.device_pool.layer_num

        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

    def get_ksize_per_token(self):
        """Return bytes per token for the K plane only (half of K+V)."""
        return self.get_size_per_token() // 2

    def init_kv_buffer(self):
        """Allocate the host K/V tensor for the configured layout.

        Also sets ``token_stride_size`` (bytes of one token in one layer) and
        ``layout_dim`` (``token_stride_size * layer_num``).

        Raises:
            ValueError: If ``layout`` is not one of the supported names.
        """
        if self.layout == "layer_first":
            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
        elif self.layout == "page_first":
            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
        elif self.layout == "page_first_direct":
            dims = (
                2,
                self.page_num,
                self.layer_num,
                self.page_size,
                self.head_num,
                self.head_dim,
            )
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")
        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
        self.layout_dim = self.token_stride_size * self.layer_num
        return torch.empty(
            dims,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    # ``k_buffer``/``v_buffer`` are subscripted and passed as tensors throughout
    # this class, so they must be properties rather than plain methods.
    @property
    def k_buffer(self):
        """The K plane of ``kv_buffer`` (index 0 of the leading dimension)."""
        return self.kv_buffer[0]

    @property
    def v_buffer(self):
        """The V plane of ``kv_buffer`` (index 1 of the leading dimension)."""
        return self.kv_buffer[1]

    def load_to_device_per_layer(
        self,
        device_pool,
        host_indices,
        device_indices,
        layer_id,
        io_backend,
    ):
        """Copy one layer's K/V entries from this host pool to ``device_pool``.

        Args:
            device_pool: Destination pool exposing ``k_buffer``/``v_buffer``
                indexable by layer.
            host_indices: Source slot indices in the host pool.
            device_indices: Destination slot indices in the device pool.
            layer_id: Layer to transfer.
            io_backend: ``"kernel"`` (layouts ``layer_first``/``page_first``)
                or ``"direct"`` (layouts ``layer_first``/``page_first_direct``).

        Raises:
            ValueError: For an unsupported backend or backend/layout pair.
        """
        if io_backend == "kernel":
            if self.layout == "layer_first":
                transfer_kv_per_layer(
                    src_k=self.k_buffer[layer_id],
                    dst_k=device_pool.k_buffer[layer_id],
                    src_v=self.v_buffer[layer_id],
                    dst_v=device_pool.v_buffer[layer_id],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    item_size=self.token_stride_size,
                )
            elif self.layout == "page_first":
                transfer_kv_per_layer_pf_lf(
                    src_k=self.k_buffer,
                    dst_k=device_pool.k_buffer[layer_id],
                    src_v=self.v_buffer,
                    dst_v=device_pool.v_buffer[layer_id],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    layer_id=layer_id,
                    item_size=self.token_stride_size,
                    src_layout_dim=self.layout_dim,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        elif io_backend == "direct":
            if self.layout == "layer_first":
                transfer_kv_direct(
                    src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
                    dst_layers=[
                        device_pool.k_buffer[layer_id],
                        device_pool.v_buffer[layer_id],
                    ],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    page_size=self.page_size,
                )
            elif self.layout == "page_first_direct":
                transfer_kv_per_layer_direct_pf_lf(
                    src_ptrs=[self.k_buffer, self.v_buffer],
                    dst_ptrs=[
                        device_pool.k_buffer[layer_id],
                        device_pool.v_buffer[layer_id],
                    ],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    layer_id=layer_id,
                    page_size=self.page_size,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        else:
            raise ValueError(f"Unsupported IO backend: {io_backend}")

    def backup_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ):
        """Copy K/V entries for all layers from ``device_pool`` into this pool.

        Args:
            device_pool: Source pool; the kernel backend reads its
                ``k_data_ptrs``/``v_data_ptrs``, the direct backend its
                ``k_buffer``/``v_buffer`` lists.
            host_indices: Destination slot indices in the host pool.
            device_indices: Source slot indices in the device pool.
            io_backend: ``"kernel"`` (layouts ``layer_first``/``page_first``)
                or ``"direct"`` (layouts ``layer_first``/``page_first_direct``).

        Raises:
            ValueError: For an unsupported backend or backend/layout pair.
        """
        if io_backend == "kernel":
            if self.layout == "layer_first":
                transfer_kv_all_layer(
                    src_k_layers=device_pool.k_data_ptrs,
                    dst_k_layers=self.k_data_ptrs,
                    src_v_layers=device_pool.v_data_ptrs,
                    dst_v_layers=self.v_data_ptrs,
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    item_size=self.token_stride_size,
                    num_layers=self.layer_num,
                )
            elif self.layout == "page_first":
                transfer_kv_all_layer_lf_pf(
                    src_k_layers=device_pool.k_data_ptrs,
                    dst_k=self.k_buffer,
                    src_v_layers=device_pool.v_data_ptrs,
                    dst_v=self.v_buffer,
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    item_size=self.token_stride_size,
                    dst_layout_dim=self.layout_dim,
                    num_layers=self.layer_num,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        elif io_backend == "direct":
            if self.layout == "layer_first":
                transfer_kv_direct(
                    src_layers=device_pool.k_buffer + device_pool.v_buffer,
                    dst_layers=self.k_data_refs + self.v_data_refs,
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    page_size=self.page_size,
                )
            elif self.layout == "page_first_direct":
                transfer_kv_all_layer_direct_lf_pf(
                    src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
                    dst_ptrs=[self.k_buffer, self.v_buffer],
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    page_size=self.page_size,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        else:
            raise ValueError(f"Unsupported IO backend: {io_backend}")

    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
        """Return the page that starts at token slot ``index``.

        Args:
            index: First token slot of the page.
            flat: When true, the page is flattened to one dimension.

        Raises:
            ValueError: If ``layout`` is unsupported.
        """
        if self.layout == "layer_first":
            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
        elif self.layout == "page_first":
            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
        elif self.layout == "page_first_direct":
            # This layout is indexed by page number, not token slot.
            real_index = index // self.page_size
            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")
        if flat:
            data_page = data_page.flatten()
        return data_page

    def get_dummy_flat_data_page(self) -> torch.Tensor:
        """Return a zero-filled, flattened tensor with one page worth of K+V."""
        return torch.zeros(
            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        ).flatten()

    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
        """Write a flat page into the pool at token slot ``index``.

        ``data_page`` is reshaped to the page shape of the current layout
        before assignment.

        Raises:
            ValueError: If ``layout`` is unsupported.
        """
        if self.layout == "layer_first":
            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
                data_page.reshape(
                    2,
                    self.layer_num,
                    self.page_size,
                    self.head_num,
                    self.head_dim,
                )
            )
        elif self.layout == "page_first":
            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
                data_page.reshape(
                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
                )
            )
        elif self.layout == "page_first_direct":
            real_index = index // self.page_size
            self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
                data_page.reshape(
                    2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
                )
            )
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")

    def get_page_buffer_meta(self, indices):
        """
        Metadata for zero copy: raw addresses and byte lengths of pages.

        Args:
            indices: Token slot indices (an object with ``tolist()``) whose
                length is a multiple of ``page_size``; only the first index of
                each page is used.

        Returns:
            ``(ptr_list, element_size_list)``. For ``layer_first`` there is one
            K and one V pointer per (page, layer); for the page-first layouts
            one K and one V pointer per page covering all layers.

        Raises:
            ValueError: If ``layout`` is unsupported.
        """
        assert len(indices) % self.page_size == 0
        ptr_list = []
        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
        indices = indices.tolist()
        # Bytes of one token in one layer; shared by every offset below.
        token_bytes = self.head_num * self.head_dim * self.dtype.itemsize
        # Byte distance from a K entry to the matching V entry (size of the
        # whole K plane).
        v_offset = self.layer_num * self.size * token_bytes
        if self.layout == "layer_first":
            layer_bytes = self.size * token_bytes
            for index in range(0, len(indices), self.page_size):
                page_ptr = kv_buffer_data_ptr + indices[index] * token_bytes
                for layer_id in range(self.layer_num):
                    k_ptr = page_ptr + layer_id * layer_bytes
                    ptr_list.append(k_ptr)
                    ptr_list.append(k_ptr + v_offset)
            element_size = self.page_size * token_bytes
        elif self.layout in ["page_first", "page_first_direct"]:
            page_token_bytes = self.layer_num * token_bytes
            for index in range(0, len(indices), self.page_size):
                k_ptr = kv_buffer_data_ptr + indices[index] * page_token_bytes
                ptr_list.append(k_ptr)
                ptr_list.append(k_ptr + v_offset)
            element_size = self.page_size * page_token_bytes
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")
        element_size_list = [element_size] * len(ptr_list)
        return ptr_list, element_size_list
class MLATokenToKVPoolHost(HostKVCache):
    """Host KV pool for MLA, storing one combined entry per token and layer.

    Each entry has ``kv_lora_rank + qk_rope_head_dim`` elements. The shape of
    ``kv_buffer`` depends on ``layout``:

    * ``"layer_first"``: ``(layer_num, size, 1, entry_dim)``
    * ``"page_first"``: ``(size, layer_num, 1, entry_dim)``
    * ``"page_first_direct"``:
      ``(page_num, layer_num, page_size, 1, entry_dim)``
    """

    device_pool: MLATokenToKVPool

    def __init__(
        self,
        device_pool: MLATokenToKVPool,
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
        layout: str,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        """Allocate the pool and cache per-index views and their addresses.

        See ``HostKVCache.__init__`` for the meaning of the arguments.
        """
        super().__init__(
            device_pool,
            host_to_device_ratio,
            host_size,
            page_size,
            layout,
            pin_memory,
            device,
        )
        # Views obtained by indexing the first dimension of ``kv_buffer`` with
        # 0..layer_num-1, plus their raw addresses packed into a uint64 tensor
        # placed on the device pool's device.
        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
        self.data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.data_refs],
            dtype=torch.uint64,
            device=self.device_pool.device,
        )

    def get_size_per_token(self):
        """Return bytes per token across all layers.

        Also caches ``kv_lora_rank``, ``qk_rope_head_dim`` and ``layer_num``
        from the device pool on ``self``.
        """
        self.kv_lora_rank = self.device_pool.kv_lora_rank
        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
        self.layer_num = self.device_pool.layer_num

        return (
            (self.kv_lora_rank + self.qk_rope_head_dim)
            * 1
            * self.dtype.itemsize
            * self.layer_num
        )

    def get_ksize_per_token(self):
        """Return bytes per token; identical to ``get_size_per_token`` here."""
        return self.get_size_per_token()

    def init_kv_buffer(self):
        """Allocate the host tensor for the configured layout.

        Also sets ``token_stride_size`` (bytes of one token in one layer) and
        ``layout_dim`` (``token_stride_size * layer_num``).

        Raises:
            ValueError: If ``layout`` is not one of the supported names.
        """
        if self.layout == "layer_first":
            dims = (
                self.layer_num,
                self.size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )
        elif self.layout == "page_first":
            dims = (
                self.size,
                self.layer_num,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )
        elif self.layout == "page_first_direct":
            dims = (
                self.page_num,
                self.layer_num,
                self.page_size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")
        self.token_stride_size = (
            self.kv_lora_rank + self.qk_rope_head_dim
        ) * self.dtype.itemsize
        self.layout_dim = self.token_stride_size * self.layer_num

        return torch.empty(
            dims,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

    def load_to_device_per_layer(
        self, device_pool, host_indices, device_indices, layer_id, io_backend
    ):
        """Copy one layer's entries from this host pool to ``device_pool``.

        Args:
            device_pool: Destination pool exposing ``kv_buffer`` indexable by
                layer.
            host_indices: Source slot indices in the host pool.
            device_indices: Destination slot indices in the device pool.
            layer_id: Layer to transfer.
            io_backend: ``"kernel"`` (layouts ``layer_first``/``page_first``)
                or ``"direct"`` (layouts ``layer_first``/``page_first_direct``).

        Raises:
            ValueError: For an unsupported backend or backend/layout pair.
        """
        if io_backend == "kernel":
            if self.layout == "layer_first":
                transfer_kv_per_layer_mla(
                    src=self.kv_buffer[layer_id],
                    dst=device_pool.kv_buffer[layer_id],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    item_size=self.token_stride_size,
                )
            elif self.layout == "page_first":
                transfer_kv_per_layer_mla_pf_lf(
                    src=self.kv_buffer,
                    dst=device_pool.kv_buffer[layer_id],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    layer_id=layer_id,
                    item_size=self.token_stride_size,
                    src_layout_dim=self.layout_dim,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        elif io_backend == "direct":
            if self.layout == "layer_first":
                transfer_kv_direct(
                    src_layers=[self.kv_buffer[layer_id]],
                    dst_layers=[device_pool.kv_buffer[layer_id]],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    page_size=self.page_size,
                )
            elif self.layout == "page_first_direct":
                transfer_kv_per_layer_direct_pf_lf(
                    src_ptrs=[self.kv_buffer],
                    dst_ptrs=[device_pool.kv_buffer[layer_id]],
                    src_indices=host_indices,
                    dst_indices=device_indices,
                    layer_id=layer_id,
                    page_size=self.page_size,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        else:
            # Fail loudly instead of silently skipping the copy, matching
            # ``backup_from_device_all_layer``.
            raise ValueError(f"Unsupported IO backend: {io_backend}")

    def backup_from_device_all_layer(
        self, device_pool, host_indices, device_indices, io_backend
    ):
        """Copy entries for all layers from ``device_pool`` into this pool.

        Args:
            device_pool: Source pool; the kernel backend reads its
                ``data_ptrs``, the direct backend its ``kv_buffer``.
            host_indices: Destination slot indices in the host pool.
            device_indices: Source slot indices in the device pool.
            io_backend: ``"kernel"`` (layouts ``layer_first``/``page_first``)
                or ``"direct"`` (layouts ``layer_first``/``page_first_direct``).

        Raises:
            ValueError: For an unsupported backend or backend/layout pair.
        """
        if io_backend == "kernel":
            if self.layout == "layer_first":
                transfer_kv_all_layer_mla(
                    src_layers=device_pool.data_ptrs,
                    dst_layers=self.data_ptrs,
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    item_size=self.token_stride_size,
                    num_layers=self.layer_num,
                )
            elif self.layout == "page_first":
                transfer_kv_all_layer_mla_lf_pf(
                    src_layers=device_pool.data_ptrs,
                    dst=self.kv_buffer,
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    item_size=self.token_stride_size,
                    dst_layout_dim=self.layout_dim,
                    num_layers=self.layer_num,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        elif io_backend == "direct":
            if self.layout == "layer_first":
                transfer_kv_direct(
                    src_layers=device_pool.kv_buffer,
                    dst_layers=self.data_refs,
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    page_size=self.page_size,
                )
            elif self.layout == "page_first_direct":
                transfer_kv_all_layer_direct_lf_pf(
                    src_ptrs=device_pool.kv_buffer,
                    dst_ptrs=[self.kv_buffer],
                    src_indices=device_indices,
                    dst_indices=host_indices,
                    page_size=self.page_size,
                )
            else:
                raise ValueError(f"Unsupported layout: {self.layout}")
        else:
            raise ValueError(f"Unsupported IO backend: {io_backend}")

    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
        """Return the page that starts at token slot ``index``.

        Args:
            index: First token slot of the page.
            flat: When true, the page is flattened to one dimension.

        Raises:
            ValueError: If ``layout`` is unsupported.
        """
        if self.layout == "layer_first":
            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
        elif self.layout == "page_first":
            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
        elif self.layout == "page_first_direct":
            # This layout is indexed by page number, not token slot.
            real_index = index // self.page_size
            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")
        if flat:
            data_page = data_page.flatten()
        return data_page

    def get_dummy_flat_data_page(self) -> torch.Tensor:
        """Return a zero-filled, flattened tensor with one page worth of data."""
        return torch.zeros(
            (
                self.layer_num,
                self.page_size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            ),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        ).flatten()

    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
        """Write a flat page into the pool at token slot ``index``.

        ``data_page`` is reshaped to the page shape of the current layout
        before assignment.

        Raises:
            ValueError: If ``layout`` is unsupported.
        """
        if self.layout == "layer_first":
            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
                self.layer_num,
                self.page_size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )
        elif self.layout == "page_first":
            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
                self.page_size,
                self.layer_num,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )
        elif self.layout == "page_first_direct":
            real_index = index // self.page_size
            self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape(
                1,
                self.layer_num,
                self.page_size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
            )
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")

    def get_page_buffer_meta(self, indices):
        """
        Metadata for zero copy: raw addresses and byte lengths of pages.

        Args:
            indices: Token slot indices (an object with ``tolist()``) whose
                length is a multiple of ``page_size``; only the first index of
                each page is used.

        Returns:
            ``(ptr_list, element_size_list)``. For ``layer_first`` there is one
            pointer per (page, layer); for the page-first layouts one pointer
            per page covering all layers.

        Raises:
            ValueError: If ``layout`` is unsupported.
        """
        assert len(indices) % self.page_size == 0
        ptr_list = []
        kv_buffer_data_ptr = self.kv_buffer.data_ptr()
        indices = indices.tolist()
        # Bytes of one token in one layer; shared by every offset below.
        token_bytes = (self.kv_lora_rank + self.qk_rope_head_dim) * self.dtype.itemsize
        if self.layout == "layer_first":
            layer_bytes = self.size * token_bytes
            for index in range(0, len(indices), self.page_size):
                page_ptr = kv_buffer_data_ptr + indices[index] * token_bytes
                for layer_id in range(self.layer_num):
                    ptr_list.append(page_ptr + layer_id * layer_bytes)
            element_size = self.page_size * token_bytes
        elif self.layout in ["page_first", "page_first_direct"]:
            page_token_bytes = self.layer_num * token_bytes
            for index in range(0, len(indices), self.page_size):
                ptr_list.append(kv_buffer_data_ptr + indices[index] * page_token_bytes)
            element_size = self.page_size * page_token_bytes
        else:
            raise ValueError(f"Unsupported layout: {self.layout}")
        element_size_list = [element_size] * len(ptr_list)
        return ptr_list, element_size_list
Xet Storage Details
- Size:
- 27.6 kB
- Xet hash:
- 48a4e80f9a30db0ce800983b560d6653af71bb44ecc76009bee68d23ad475f6e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.