| from __future__ import annotations | |
| """ | |
| Copyright 2023-2025 SGLang Team | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| import logging | |
| import threading | |
| import time | |
| from queue import Empty, Full, Queue | |
| from typing import TYPE_CHECKING, List, NamedTuple, Optional | |
| import torch | |
| from sglang.srt.mem_cache.hicache_storage import ( | |
| HiCacheStorageConfig, | |
| HiCacheStorageExtraInfo, | |
| ) | |
| if TYPE_CHECKING: | |
| from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator | |
| from sglang.srt.mem_cache.memory_pool_host import HostKVCache | |
| from sglang.srt.distributed import ( | |
| get_tensor_model_parallel_rank, | |
| get_tensor_model_parallel_world_size, | |
| ) | |
| from sglang.srt.layers.dp_attention import ( | |
| get_attention_dp_rank, | |
| get_attention_tp_rank, | |
| get_attention_tp_size, | |
| is_dp_attention_enabled, | |
| ) | |
| from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool | |
| logger = logging.getLogger(__name__) | |
class LayerLoadingEvent:
    """Set of CUDA events tracking the per-layer progress of one load.

    Holds one event per layer plus a separate ``start_event``.
    """

    def __init__(self, num_layers: int):
        self._num_layers = num_layers
        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
        self.start_event = torch.cuda.Event()  # start event on controller stream

    def complete(self, layer_index: int):
        """Record the event that belongs to ``layer_index``."""
        assert 0 <= layer_index < self._num_layers
        self.load_events[layer_index].record()

    def wait(self, layer_index: int):
        """Make the current CUDA stream wait on the event of ``layer_index``."""
        torch.cuda.current_stream().wait_event(self.load_events[layer_index])

    @property
    def finish_event(self):
        """Event of the last layer.

        Exposed as a property because users read it as an attribute
        (``counter.finish_event.query()``) rather than calling it.
        """
        return self.load_events[-1]
class LayerDoneCounter:
    """Ring of ``LayerLoadingEvent`` slots shared by a producer and a consumer.

    The producer advances through the slots in round-robin order; the consumer
    is pointed at one slot and waits on that slot's per-layer events.
    """

    def __init__(self, num_layers: int):
        self.num_layers = num_layers
        # extra producer and consumer counters for overlap mode
        self.num_counters = 3
        self.events = [
            LayerLoadingEvent(num_layers) for _ in range(self.num_counters)
        ]
        self.producer_index = -1
        self.consumer_index = -1

    def update_producer(self):
        """Advance to the next slot (wrapping around) and return its index."""
        next_slot = (self.producer_index + 1) % self.num_counters
        self.producer_index = next_slot
        assert self.events[
            next_slot
        ].finish_event.query(), (
            "Producer finish event should be ready before being reused."
        )
        return next_slot

    def set_consumer(self, index: int):
        """Select the slot that ``wait_until`` waits on."""
        self.consumer_index = index

    def wait_until(self, threshold: int):
        """Wait on layer ``threshold`` of the consumer slot; no-op if none is set."""
        if self.consumer_index >= 0:
            self.events[self.consumer_index].wait(threshold)

    def reset(self):
        """Forget both the producer and the consumer position."""
        self.producer_index = self.consumer_index = -1
class CacheOperation:
    """A pending transfer between host indices and device indices.

    Operations order by ``priority`` (smaller sorts first); when no priority
    is given, the creation order is used.
    """

    # process-wide counter used to hand out operation ids
    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        device_indices: torch.Tensor,
        node_id: int,
        priority: Optional[int] = None,
    ):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.data = None

        self.id = CacheOperation.counter
        CacheOperation.counter += 1
        # default priority is the order of creation
        self.priority = priority if priority is not None else self.id

    @staticmethod
    def merge_ops(ops: List["CacheOperation"]) -> "CacheOperation":
        """Combine several operations into a single one.

        A single-element list is returned as-is. Otherwise the index tensors
        are concatenated in list order, the node ids are collected in the same
        order and the smallest priority is kept.

        Declared static because it takes no instance; without the decorator
        the function only works when accessed through the class.
        """
        assert len(ops) > 0
        if len(ops) == 1:
            return ops[0]

        host_indices = torch.cat([op.host_indices for op in ops])
        device_indices = torch.cat([op.device_indices for op in ops])
        priority = min(op.priority for op in ops)
        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
        # replace the placeholder node id with the ids of all merged operations
        merged_op.node_ids = [node_id for op in ops for node_id in op.node_ids]
        return merged_op

    def __lt__(self, other: "CacheOperation"):
        return self.priority < other.priority
class HiCacheAck(NamedTuple):
    """Acknowledgement record queued by the controller for a started transfer."""

    # event recorded before the transfer is issued
    start_event: torch.cuda.Event
    # event whose completion marks the end of the transfer
    finish_event: torch.cuda.Event
    # node ids of the cache operations covered by the transfer
    node_ids: List[int]
class TransferBuffer:
    """
    Overlapping buffer preparation and transfer operations to improve throughput.

    Thin wrapper around a bounded ``queue.Queue`` whose blocking ``put`` gives
    up once ``stop_event`` is set.
    """

    def __init__(
        self, stop_event, buffer_count: int = 3, max_buffer_size: int = 1024
    ) -> None:
        self.stop_event = stop_event
        self.buffers = Queue(maxsize=buffer_count)
        # todo: adjust the buffer size based on throughput profile of the system
        self.max_buffer_size = max_buffer_size

    def full(self) -> bool:
        return self.buffers.full()

    def empty(self) -> bool:
        return self.buffers.empty()

    def put(self, item, block=True, timeout=1) -> None:
        """Enqueue ``item``.

        In blocking mode the call retries every ``timeout`` seconds until the
        item fits or ``stop_event`` is set. In non-blocking mode the item is
        dropped when the buffer is full. Any other error is logged and the
        item is dropped; retrying would only repeat the same failure in a
        tight loop.
        """
        while not self.stop_event.is_set():
            try:
                self.buffers.put(item, block=block, timeout=timeout)
            except Full:
                if not block:
                    return
                # blocking mode: loop so that stop_event is checked again
            except Exception as e:
                logger.error(e)
                return
            else:
                return

    def get(self, block=True, timeout=1) -> Optional[CacheOperation]:
        """Dequeue one item, or return ``None`` if none is available or on error."""
        try:
            return self.buffers.get(block=block, timeout=timeout)
        except Empty:
            return None
        except Exception as e:
            logger.error(e)
            return None

    def clear(self):
        """Drop every buffered item."""
        # Hold the queue's own lock while touching its internal deque and wake
        # producers blocked in put(), since there is room again.
        with self.buffers.mutex:
            self.buffers.queue.clear()
            self.buffers.not_full.notify_all()
class StorageOperation:
    """Bookkeeping for one request against the storage backend.

    Instances receive increasing ids from a class-level counter and compare
    by id.
    """

    counter = 0

    def __init__(
        self,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        hash_value: Optional[List[str]] = None,
        prefix_keys: Optional[List[str]] = None,
    ):
        # take the next id from the shared counter
        self.id = StorageOperation.counter
        StorageOperation.counter = self.id + 1

        self.host_indices = host_indices
        self.token_ids = token_ids
        self.last_hash = last_hash
        self.prefix_keys = prefix_keys
        self.hash_value = [] if hash_value is None else hash_value
        # number of tokens finished so far
        self.completed_tokens = 0

    def __lt__(self, other: "StorageOperation"):
        return self.id < other.id
class PrefetchOperation(StorageOperation):
    """Storage operation for a prefetch that can be terminated while running.

    Progress updates and termination are serialized through an internal lock.
    """

    def __init__(
        self,
        request_id: str,
        host_indices: torch.Tensor,
        token_ids: List[int],
        last_hash: Optional[str] = None,
        prefix_keys: Optional[List[str]] = None,
    ):
        self.request_id = request_id

        self._lock = threading.Lock()
        self._terminated_flag = False
        self.start_time = time.monotonic()

        super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys)

    def increment(self, num_tokens: int):
        """Add ``num_tokens`` to the progress; return False once terminated."""
        with self._lock:
            still_active = not self._terminated_flag
            if still_active:
                self.completed_tokens += num_tokens
            return still_active

    def mark_terminate(self):
        """Flag the operation as terminated so later increments are rejected."""
        with self._lock:
            self._terminated_flag = True

    def is_terminated(self) -> bool:
        return self._terminated_flag
| class HiCacheController: | |
def __init__(
    self,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    mem_pool_host: HostKVCache,
    page_size: int,
    tp_group: torch.distributed.ProcessGroup,
    load_cache_event: threading.Event,
    write_policy: str = "write_through_selective",
    io_backend: str = "",
    storage_backend: Optional[str] = None,
    prefetch_threshold: int = 256,
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[dict] = None,
):
    """Set up pools, queues, CUDA streams and (optionally) the storage backend.

    Args:
        token_to_kv_pool_allocator: device-side allocator; its ``get_kvcache()``
            result is kept as the device memory pool.
        mem_pool_host: host-side KV cache pool.
        page_size: number of tokens per page.
        tp_group: process group used to size and build the gloo group for
            storage synchronization (only when a storage backend is given).
        load_cache_event: not referenced inside this constructor.
        write_policy: one of "write_through", "write_through_selective" or
            "write_back".
        io_backend: stored and later passed to the host pool transfer calls.
        storage_backend: name of the storage backend; ``None`` disables storage.
        prefetch_threshold: lower bound for prefetching, raised to at least
            ``page_size``.
        model_name: forwarded to the storage configuration.
        storage_backend_extra_config: forwarded to the storage configuration.

    Raises:
        ValueError: if the storage backend cannot be created or the write
            policy is not one of the accepted values.
    """
    self.mem_pool_device_allocator = token_to_kv_pool_allocator
    self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
    self.mem_pool_host = mem_pool_host
    self.write_policy = write_policy
    self.page_size = page_size
    self.io_backend = io_backend
    self.enable_storage = False

    # everything in this branch exists only when a storage backend is configured
    if storage_backend is not None:
        self.storage_backend_type = storage_backend
        from sglang.srt.mem_cache.hicache_storage import get_hash_str

        self.get_hash_str = get_hash_str
        self.storage_config = self._generate_storage_config(
            model_name, storage_backend_extra_config
        )
        # for MLA models, only one rank needs to backup the KV cache
        self.backup_skip = (
            self.storage_config.is_mla_model
            # todo: load balancing
            and self.storage_config.tp_rank != 0
        )

        # Use storage backend factory for dynamic backend creation
        from sglang.srt.mem_cache.storage import StorageBackendFactory

        try:
            self.storage_backend = StorageBackendFactory.create_backend(
                storage_backend, self.storage_config, self.mem_pool_host
            )
        except ValueError as e:
            raise ValueError(f"Failed to create storage backend: {e}") from e

        self.storage_backend.register_mem_pool_host(self.mem_pool_host)

        self.enable_storage = True
        # todo: threshold policy for prefetching
        self.prefetch_threshold = max(prefetch_threshold, self.page_size)
        self.prefetch_capacity_limit = int(
            0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
        )
        # granularity of batch storage IO operations, in number of pages
        self.storage_batch_size = 128
        # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
        self.prefetch_tokens_occupied = 0

        # create a new communication group for synchronizing storage operations across TP workers
        self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
        if self.tp_world_size > 1:
            group_ranks = torch.distributed.get_process_group_ranks(tp_group)
            self.prefetch_tp_group = torch.distributed.new_group(
                group_ranks, backend="gloo"
            )

        # Select the get and set functions
        self.page_get_func = self._generic_page_get
        self.page_set_func = self._generic_page_set

        # these backends use the zero-copy (v1) get/set path instead
        if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
            self.page_get_func = self._page_get_zero_copy
            self.page_set_func = self._page_set_zero_copy

    self.device = self.mem_pool_device.device
    self.layer_num = self.mem_pool_device.layer_num
    self.layer_done_counter = LayerDoneCounter(self.layer_num)
    self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

    if write_policy not in [
        "write_through",
        "write_through_selective",
        "write_back",
    ]:
        raise ValueError(f"Invalid write policy: {write_policy}")

    # self.write_queue = PriorityQueue[CacheOperation]()
    self.load_queue: List[CacheOperation] = []
    self.write_queue: List[CacheOperation] = []
    self.ack_load_queue: List[HiCacheAck] = []
    self.ack_write_queue: List[HiCacheAck] = []

    self.stop_event = threading.Event()
    self.write_buffer = TransferBuffer(self.stop_event)
    self.load_buffer = TransferBuffer(
        self.stop_event, buffer_count=10, max_buffer_size=100
    )

    # dedicated streams for host<->device copies
    self.write_stream = torch.cuda.Stream()
    self.load_stream = torch.cuda.Stream()

    if self.enable_storage:
        self.prefetch_thread = threading.Thread(
            target=self.prefetch_thread_func, daemon=True
        )
        self.backup_thread = threading.Thread(
            target=self.backup_thread_func, daemon=True
        )
        # the queues must exist before the worker threads are started
        self.prefetch_queue = Queue()
        self.backup_queue = Queue()

        self.prefetch_revoke_queue = Queue()
        self.ack_backup_queue = Queue()
        self.host_mem_release_queue = Queue()

        self.prefetch_thread.start()
        self.backup_thread.start()
def _generate_storage_config(
    self,
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[dict] = None,
):
    """Record this worker's rank information and build the storage config.

    Sets ``self.tp_rank``, ``self.tp_size`` and ``self.dp_rank`` from the
    attention-parallel helpers when DP attention is enabled, and from the
    tensor-model-parallel helpers (with ``dp_rank`` 0) otherwise.

    Returns:
        A ``HiCacheStorageConfig`` describing rank, model and host layout.
    """
    if not is_dp_attention_enabled():
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.dp_rank = 0
    else:
        self.tp_rank = get_attention_tp_rank()
        self.tp_size = get_attention_tp_size()
        self.dp_rank = get_attention_dp_rank()

    # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
    uses_mla_pool = isinstance(self.mem_pool_device, MLATokenToKVPool)
    page_first = self.mem_pool_host.layout == "page_first"

    return HiCacheStorageConfig(
        tp_rank=self.tp_rank,
        tp_size=self.tp_size,
        is_mla_model=uses_mla_pool,
        is_page_first_layout=page_first,
        model_name=model_name,
        extra_config=storage_backend_extra_config,
    )
def reset(self):
    """Stop the workers, drop all pending work, then start the workers again."""
    self.stop_event.set()

    # in-memory transfer queues, buffers and their acknowledgements
    for pending in (
        self.write_queue,
        self.load_queue,
        self.write_buffer,
        self.load_buffer,
        self.ack_write_queue,
        self.ack_load_queue,
    ):
        pending.clear()

    if self.enable_storage:
        # wait for the storage workers to observe the stop signal
        for worker in (self.prefetch_thread, self.backup_thread):
            worker.join()
        for storage_queue in (
            self.prefetch_queue,
            self.backup_queue,
            self.prefetch_revoke_queue,
            self.ack_backup_queue,
        ):
            storage_queue.queue.clear()

    self.stop_event.clear()

    if self.enable_storage:
        # threads cannot be restarted, so create fresh ones
        self.prefetch_thread = threading.Thread(
            target=self.prefetch_thread_func, daemon=True
        )
        self.backup_thread = threading.Thread(
            target=self.backup_thread_func, daemon=True
        )
        self.prefetch_thread.start()
        self.backup_thread.start()
def write(
    self,
    device_indices: torch.Tensor,
    priority: Optional[int] = None,
    node_id: int = -1,
) -> Optional[torch.Tensor]:
    """
    Back up KV caches from device memory to host memory.

    Allocates host slots for ``device_indices``, queues the copy and starts
    it right away. Returns the allocated host indices, or ``None`` when the
    host pool cannot provide them.
    """
    num_tokens = len(device_indices)
    host_indices = self.mem_pool_host.alloc(num_tokens)
    if host_indices is not None:
        pending = CacheOperation(host_indices, device_indices, node_id, priority)
        self.write_queue.append(pending)
        self.start_writing()
    return host_indices
def start_writing(self) -> None:
    """Issue all queued device-to-host backups as one copy on the write stream.

    Does nothing when the write queue is empty. Otherwise the queued
    operations are merged, the queue is cleared, and an acknowledgement with
    the start/finish events is appended to ``ack_write_queue``.
    """
    if len(self.write_queue) == 0:
        return

    op = CacheOperation.merge_ops(self.write_queue)
    host_indices, device_indices = self.move_indices(op)
    self.write_queue.clear()

    start_event = torch.cuda.Event()
    finish_event = torch.cuda.Event()

    # recorded outside the write-stream context, i.e. on the caller's stream
    start_event.record()
    with torch.cuda.stream(self.write_stream):
        # the write stream must not run ahead of the work recorded above
        start_event.wait(self.write_stream)
        self.mem_pool_host.backup_from_device_all_layer(
            self.mem_pool_device, host_indices, device_indices, self.io_backend
        )
        finish_event.record()
        # NOTE: We must save the host indices and device indices here,
        # this is because we need to guarantee that these tensors are
        # still alive when the write stream is executing.
        if host_indices.is_cuda:
            host_indices.record_stream(self.write_stream)
        if device_indices.is_cuda:
            device_indices.record_stream(self.write_stream)

    self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
def load(
    self,
    host_indices: torch.Tensor,
    priority: Optional[int] = None,
    node_id: int = -1,
) -> Optional[torch.Tensor]:
    """
    Load KV caches from host memory to device memory.

    Allocates device slots for ``host_indices`` and queues the copy; the
    transfer itself is not started here. Returns the allocated device
    indices, or ``None`` when the device allocator cannot provide them.
    """
    num_tokens = len(host_indices)
    device_indices = self.mem_pool_device_allocator.alloc(num_tokens)
    if device_indices is not None:
        pending = CacheOperation(host_indices, device_indices, node_id, priority)
        self.load_queue.append(pending)
    return device_indices
def move_indices(self, op: "CacheOperation"):
    """Place the index tensors of ``op`` where the IO backend needs them.

    * "kernel": host indices are moved to ``self.device`` unless they are
      already CUDA tensors; device indices are returned unchanged.
    * "direct" with the "layer_first" host layout: device indices are moved
      to the CPU, host indices are sorted and the device indices are
      reordered with the same permutation.
    * "direct" with the "page_first_direct" host layout: device indices are
      moved to the CPU.

    Returns:
        The ``(host_indices, device_indices)`` pair to use for the transfer.

    Raises:
        ValueError: for an unknown io backend, or for a host layout that the
            "direct" backend does not handle. The latter used to fall through
            and return ``None``, which made callers fail while unpacking.
    """
    host_indices, device_indices = op.host_indices, op.device_indices

    # move indices to GPU if using kernels, to host if using direct indexing
    if self.io_backend == "kernel":
        if not host_indices.is_cuda:
            host_indices = host_indices.to(self.device, non_blocking=True)
        return host_indices, device_indices

    if self.io_backend == "direct":
        layout = self.mem_pool_host.layout
        if layout == "layer_first":
            device_indices = device_indices.cpu()
            host_indices, idx = host_indices.sort()
            return host_indices, device_indices.index_select(0, idx)
        if layout == "page_first_direct":
            return host_indices, device_indices.cpu()
        raise ValueError(
            f"Unsupported host layout for direct io backend: {layout}"
        )

    raise ValueError(f"Unsupported io backend: {self.io_backend}")
def start_loading(self) -> int:
    """Issue all queued host-to-device loads, layer by layer, on the load stream.

    Returns:
        ``-1`` when the load queue is empty, otherwise the index of the
        producer slot in ``layer_done_counter`` used for this load.
    """
    if len(self.load_queue) == 0:
        return -1

    producer_id = self.layer_done_counter.update_producer()
    op = CacheOperation.merge_ops(self.load_queue)
    host_indices, device_indices = self.move_indices(op)
    self.load_queue.clear()
    producer_event = self.layer_done_counter.events[producer_id]
    # recorded outside the load-stream context, i.e. on the caller's stream
    producer_event.start_event.record()

    with torch.cuda.stream(self.load_stream):
        # the load stream must not run ahead of the work recorded above
        producer_event.start_event.wait(self.load_stream)
        for i in range(self.layer_num):
            self.mem_pool_host.load_to_device_per_layer(
                self.mem_pool_device,
                host_indices,
                device_indices,
                i,
                self.io_backend,
            )
            # mark layer i as loaded so that waiters on this layer can proceed
            producer_event.complete(i)
        # NOTE: We must save the host indices and device indices here,
        # this is because we need to guarantee that these tensors are
        # still alive when the load stream is executing.
        if host_indices.is_cuda:
            host_indices.record_stream(self.load_stream)
        if device_indices.is_cuda:
            device_indices.record_stream(self.load_stream)

    self.ack_load_queue.append(
        HiCacheAck(
            start_event=producer_event.start_event,
            finish_event=producer_event.finish_event,
            node_ids=op.node_ids,
        )
    )
    return producer_id
def evict_device(self, device_indices: torch.Tensor) -> int:
    """Free ``device_indices`` in the device allocator; return how many were freed."""
    allocator = self.mem_pool_device_allocator
    allocator.free(device_indices)
    num_freed = len(device_indices)
    return num_freed
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
    """Free ``host_indices`` in the host pool; return how many were freed.

    Raises:
        ValueError: if ``backup_only`` is false; no other policy exists yet.
    """
    if backup_only:
        self.mem_pool_host.free(host_indices)
        return len(host_indices)
    raise ValueError("Other eviction policies are not supported yet.")
def prefetch(
    self,
    request_id: str,
    host_indices: torch.Tensor,
    new_input_tokens: List[int],
    last_hash: Optional[str] = None,
    prefix_keys: Optional[List[str]] = None,
) -> PrefetchOperation:
    """
    Prefetch KV caches from storage backend to host memory.

    Wraps the arguments in a ``PrefetchOperation``, hands it to the prefetch
    queue and returns it so the caller can track or terminate it.
    """
    operation = PrefetchOperation(
        request_id=request_id,
        host_indices=host_indices,
        token_ids=new_input_tokens,
        last_hash=last_hash,
        prefix_keys=prefix_keys,
    )
    self.prefetch_queue.put(operation)
    return operation
def terminate_prefetch(self, operation):
    """Terminate ``operation`` and report ``(completed_tokens, hash_value)``."""
    operation.mark_terminate()
    completed = operation.completed_tokens
    hashes = operation.hash_value
    return completed, hashes
def append_host_mem_release(self, host_indices: torch.Tensor):
    """Queue ``host_indices`` for release, split into chunks of one host page."""
    if host_indices.numel() != 0:
        for page in host_indices.split(self.mem_pool_host.page_size):
            self.host_mem_release_queue.put(page)
| def _page_get_zero_copy( | |
| self, operation, hash_values, host_indices, extra_info=None | |
| ): | |
| results = self.storage_backend.batch_get_v1( | |
| hash_values, host_indices, extra_info | |
| ) | |
| inc = 0 | |
| for i in range(len(hash_values)): | |
| if not results[i]: | |
| logger.warning( | |
| f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}." | |
| ) | |
| break | |
| inc += self.page_size | |
| operation.increment(inc) | |
| # todo: deprecate | |
| def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None): | |
| dummy_page_dst = [ | |
| self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values | |
| ] | |
| page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst) | |
| if page_data is None: | |
| return | |
| for i in range(len(hash_values)): | |
| if page_data[i] is None: | |
| logger.warning( | |
| f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}." | |
| ) | |
| break | |
| # Must set the data before increasing the completed tokens. | |
| # Otherwise this page may be read before being set. | |
| self.mem_pool_host.set_from_flat_data_page( | |
| host_indices[i * self.page_size], | |
| page_data[i], | |
| ) | |
| if not operation.increment(self.page_size): | |
| break # Operation terminated by controller | |
| def _page_transfer(self, operation): | |
| # Transfer batch by batch | |
| prefix_keys = operation.prefix_keys | |
| for i in range(0, len(operation.hash_value), self.storage_batch_size): | |
| batch_hashes = operation.hash_value[i : i + self.storage_batch_size] | |
| batch_host_indices = operation.host_indices[ | |
| i * self.page_size : (i + len(batch_hashes)) * self.page_size | |
| ] | |
| prev_completed_tokens = operation.completed_tokens | |
| # Get one batch token, and update the completed_tokens if succeed | |
| extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) | |
| self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info) | |
| # Check termination | |
| if ( | |
| operation.completed_tokens | |
| != prev_completed_tokens + len(batch_hashes) * self.page_size | |
| ): | |
| operation.mark_terminate() | |
| break # Some operations fail or operation terminated by controller | |
| if prefix_keys and len(prefix_keys) > 0: | |
| prefix_keys += batch_hashes | |
| # release pre-allocated memory | |
| self.append_host_mem_release( | |
| operation.host_indices[operation.completed_tokens :] | |
| ) | |
def prefetch_io_aux_func(self):
    """
    Auxiliary function conducting IO operations for prefetching.

    Runs until ``stop_event`` is set: takes one operation at a time from
    ``prefetch_buffer`` and transfers its pages.
    """
    while not self.stop_event.is_set():
        try:
            operation = self.prefetch_buffer.get(block=True, timeout=1)
            # _page_transfer already queues the host indices beyond the
            # completed tokens for release. Queuing the same tail again here
            # would hand every unused page to the release queue twice.
            self._page_transfer(operation)
        except Empty:
            continue
def prefetch_rate_limited(self) -> bool:
    """
    Rate limit the prefetching operations to avoid overwhelming the storage backend.

    Returns True once the tokens occupied by prefetching reach the capacity limit.
    """
    # cancel prefetch if too much memory is occupied
    # todo: more sophisticated rate limiting based on storage backend performance
    over_capacity = self.prefetch_tokens_occupied >= self.prefetch_capacity_limit
    return True if over_capacity else False
| def _storage_hit_query(self, operation) -> tuple[list[str], int]: | |
| last_hash = operation.last_hash | |
| tokens_to_fetch = operation.token_ids | |
| prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None | |
| storage_query_count = 0 | |
| hash_value = [] | |
| for start in range( | |
| 0, len(tokens_to_fetch), self.page_size * self.storage_batch_size | |
| ): | |
| end = min( | |
| start + self.page_size * self.storage_batch_size, len(tokens_to_fetch) | |
| ) | |
| batch_tokens = tokens_to_fetch[start:end] | |
| batch_hashes = [] | |
| for i in range(0, len(batch_tokens), self.page_size): | |
| last_hash = self.get_hash_str( | |
| batch_tokens[i : i + self.page_size], last_hash | |
| ) | |
| batch_hashes.append(last_hash) | |
| extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) | |
| hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info) | |
| hash_value.extend(batch_hashes[:hit_page_num]) | |
| storage_query_count += hit_page_num * self.page_size | |
| if hit_page_num < len(batch_hashes): | |
| break | |
| if prefix_keys and len(prefix_keys) > 0: | |
| prefix_keys += batch_hashes | |
| return hash_value, storage_query_count | |
def prefetch_thread_func(self):
    """
    Manage prefetching operations from storage backend to host memory.

    For every queued operation the storage backend is asked how many leading
    pages it holds. With more than one TP worker, the minimum across workers
    is used. Operations below ``prefetch_threshold`` are revoked; the others
    are trimmed to the hit range and handed to the auxiliary IO thread.
    """
    self.prefetch_buffer = Queue()
    aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
    aux_thread.start()
    while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
        try:
            operation = self.prefetch_queue.get(block=True, timeout=1)
            if operation is None:
                continue

            hash_value, storage_hit_count = self._storage_hit_query(operation)
            if self.tp_world_size > 1:
                storage_hit_count_tensor = torch.tensor(
                    storage_hit_count, dtype=torch.int
                )
                torch.distributed.all_reduce(
                    storage_hit_count_tensor,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.prefetch_tp_group,
                )
                storage_hit_count = storage_hit_count_tensor.item()

            if storage_hit_count < self.prefetch_threshold:
                # not to prefetch if not enough benefits
                self.prefetch_revoke_queue.put(operation.request_id)
                self.append_host_mem_release(operation.host_indices)
                logger.debug(
                    f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                )
            else:
                operation.hash_value = hash_value[
                    : (storage_hit_count // self.page_size)
                ]
                # free the pre-allocated memory for pages that are not hit
                self.append_host_mem_release(
                    operation.host_indices[storage_hit_count:]
                )
                operation.host_indices = operation.host_indices[:storage_hit_count]
                logger.debug(
                    f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                )
                self.prefetch_buffer.put(operation)

        except Empty:
            continue

    # The loop above only ends once stop_event is set, so the auxiliary worker
    # leaves its own loop too. Wait for it before returning: otherwise it can
    # still be blocked on the buffer when stop_event is cleared again and keep
    # running alongside the worker started by the next prefetch thread.
    aux_thread.join()
def write_storage(
    self,
    host_indices: torch.Tensor,
    token_ids: List[int],
    hash_value: Optional[List[str]] = None,
    prefix_keys: Optional[List[str]] = None,
) -> int:
    """
    Write KV caches from host memory to storage backend.

    Queues a ``StorageOperation`` for the backup thread and returns its id.
    """
    backup_op = StorageOperation(
        host_indices=host_indices,
        token_ids=token_ids,
        hash_value=hash_value,
        prefix_keys=prefix_keys,
    )
    self.backup_queue.put(backup_op)
    return backup_op.id
| # todo: deprecate | |
| def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool: | |
| data = [ | |
| self.mem_pool_host.get_data_page(host_indices[i * self.page_size]) | |
| for i in range(len(hash_values)) | |
| ] | |
| return self.storage_backend.batch_set(hash_values, data) | |
| def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool: | |
| return all( | |
| self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info) | |
| ) | |
| # Backup batch by batch | |
| def _page_backup(self, operation): | |
| # Backup batch by batch | |
| prefix_keys = operation.prefix_keys | |
| for i in range(0, len(operation.hash_value), self.storage_batch_size): | |
| batch_hashes = operation.hash_value[i : i + self.storage_batch_size] | |
| batch_host_indices = operation.host_indices[ | |
| i * self.page_size : (i + len(batch_hashes)) * self.page_size | |
| ] | |
| # Set one batch token, and record if success. | |
| # todo: allow partial success | |
| extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys) | |
| success = self.page_set_func(batch_hashes, batch_host_indices, extra_info) | |
| if not success: | |
| logger.warning( | |
| f"Write page to storage: {len(batch_hashes)} pages failed." | |
| ) | |
| break | |
| if prefix_keys and len(prefix_keys) > 0: | |
| prefix_keys += batch_hashes | |
| operation.completed_tokens += self.page_size * len(batch_hashes) | |
def backup_thread_func(self):
    """
    Manage backup operations from host memory to storage backend.

    Runs until ``stop_event`` is set. Every operation taken from the backup
    queue is acknowledged, whether or not this worker performs the backup.
    """
    while not self.stop_event.is_set():
        try:
            operation = self.backup_queue.get(block=True, timeout=1)
            if operation is not None:
                if not self.backup_skip:
                    self._page_backup(operation)
                self.ack_backup_queue.put(operation)
        except Empty:
            continue
Xet Storage Details
- Size:
- 30.5 kB
- Xet hash:
- 3e2e6b401564213164146a6f1a148fba33e12b640026608b6314180057eadd70
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.