| import heapq | |
| import json | |
| import logging | |
| import threading | |
| import time | |
| from typing import List, Optional | |
| import torch | |
| from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation | |
| from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator | |
| from sglang.srt.mem_cache.base_prefix_cache import MatchResult | |
| from sglang.srt.mem_cache.memory_pool import ( | |
| MHATokenToKVPool, | |
| MLATokenToKVPool, | |
| ReqToTokenPool, | |
| ) | |
| from sglang.srt.mem_cache.memory_pool_host import ( | |
| MHATokenToKVPoolHost, | |
| MLATokenToKVPoolHost, | |
| ) | |
| from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode | |
| from sglang.srt.metrics.collector import StorageMetricsCollector | |
| logger = logging.getLogger(__name__) | |
| class HiRadixCache(RadixCache): | |
def __init__(
    self,
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
    tp_cache_group: torch.distributed.ProcessGroup,
    page_size: int,
    hicache_ratio: float,
    hicache_size: int,
    hicache_write_policy: str,
    hicache_io_backend: str,
    hicache_mem_layout: str,
    enable_metrics: bool,
    eviction_policy: str = "lru",
    hicache_storage_backend: Optional[str] = None,
    hicache_storage_prefetch_policy: Optional[str] = "best_effort",
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[str] = None,
    is_eagle: bool = False,
):
    """Create the host KV pool, the cache controller and the base radix cache.

    Args:
        req_to_token_pool: Forwarded unchanged to ``RadixCache.__init__``.
        token_to_kv_pool_allocator: Device allocator; the type of its KV
            cache selects the MHA or MLA host pool class.
        tp_cache_group: Process group used for the cross-rank all-reduces.
        page_size: Number of tokens per page.
        hicache_ratio: Passed to the host pool constructor.
        hicache_size: Passed to the host pool constructor.
        hicache_write_policy: Handed to the cache controller;
            ``"write_through"`` sets the write-through hit threshold to 1,
            any other value sets it to 2.
        hicache_io_backend: Handed to the cache controller.
        hicache_mem_layout: Host memory layout; ``"page_first"`` is replaced
            by ``"page_first_direct"`` when the IO backend is ``"direct"``.
        enable_metrics: Storage metrics are collected only when this is set
            and a storage backend is configured.
        eviction_policy: Forwarded to ``RadixCache.__init__``.
        hicache_storage_backend: Storage backend name; ``None`` disables the
            storage tier.
        hicache_storage_prefetch_policy: Stored as ``prefetch_stop_policy``
            and consulted by ``can_terminate_prefetch``.
        model_name: Handed to the cache controller.
        storage_backend_extra_config: JSON string parsed by
            ``_parse_storage_backend_extra_config``.
        is_eagle: Forwarded to ``RadixCache.__init__``.

    Raises:
        ValueError: If the device KV cache is neither an MHA nor an MLA pool.
    """
    # The direct IO backend cannot work with the plain page-first layout.
    if hicache_io_backend == "direct" and hicache_mem_layout == "page_first":
        hicache_mem_layout = "page_first_direct"
        logger.warning(
            "Page first layout is not supported with direct IO backend, switching to page first direct layout"
        )

    # Pick the host pool class matching the device KV cache flavour.
    self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
    if isinstance(self.kv_cache, MHATokenToKVPool):
        host_pool_cls = MHATokenToKVPoolHost
    elif isinstance(self.kv_cache, MLATokenToKVPool):
        host_pool_cls = MLATokenToKVPoolHost
    else:
        raise ValueError("HiRadixCache only supports MHA and MLA yet")
    self.token_to_kv_pool_host = host_pool_cls(
        self.kv_cache,
        hicache_ratio,
        hicache_size,
        page_size,
        hicache_mem_layout,
    )

    self.tp_group = tp_cache_group
    self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
    self.enable_storage = hicache_storage_backend is not None
    self.enable_storage_metrics = self.enable_storage and enable_metrics

    parsed_config = self._parse_storage_backend_extra_config(
        storage_backend_extra_config
    )
    (
        extra_config,
        self.prefetch_threshold,
        self.prefetch_timeout_base,
        timeout_per_ki_token,
        self.hicache_storage_pass_prefix_keys,
    ) = parsed_config
    # The configured timeout is per 1024 tokens; convert it to a per-page value.
    self.prefetch_timeout_per_page = page_size / 1024 * timeout_per_ki_token
    # TODO: support more timeout check functions
    self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
    self.prefetch_stop_policy = hicache_storage_prefetch_policy

    self.load_cache_event = threading.Event()
    self.cache_controller = HiCacheController(
        token_to_kv_pool_allocator,
        self.token_to_kv_pool_host,
        page_size,
        self.tp_group,
        load_cache_event=self.load_cache_event,
        write_policy=hicache_write_policy,
        io_backend=hicache_io_backend,
        storage_backend=hicache_storage_backend,
        prefetch_threshold=self.prefetch_threshold,
        model_name=model_name,
        storage_backend_extra_config=extra_config,
    )

    if self.enable_storage_metrics:
        # TODO: support pp
        self.metrics_collector = StorageMetricsCollector(
            labels={
                "storage_backend": hicache_storage_backend,
                "tp_rank": self.cache_controller.tp_rank,
                "dp_rank": self.cache_controller.dp_rank,
            }
        )

    # Bookkeeping for asynchronous work, keyed by node id / request id /
    # operation id depending on the dict.
    self.ongoing_write_through = {}  # nodes being written device -> host
    self.ongoing_load_back = {}  # node segments being loaded host -> device
    self.ongoing_prefetch = {}  # storage prefetch requests in flight
    self.ongoing_backup = {}  # host -> storage backups in flight

    # todo: dynamically adjust the threshold
    if hicache_write_policy == "write_through":
        self.write_through_threshold = 1
    else:
        self.write_through_threshold = 2
    self.load_back_threshold = 10

    # All attributes above are assigned before the base constructor runs.
    super().__init__(
        req_to_token_pool,
        token_to_kv_pool_allocator,
        page_size,
        disable=False,
        eviction_policy=eviction_policy,
        is_eagle=is_eagle,
    )
| def _parse_storage_backend_extra_config( | |
| self, storage_backend_extra_config: Optional[str] | |
| ): | |
| """ | |
| Parse storage backend extra config JSON and extract specific parameters. | |
| Args: | |
| storage_backend_extra_config: JSON string containing extra configuration | |
| Returns: | |
| tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys) | |
| """ | |
| # Parse extra config JSON if provided | |
| extra_config = {} | |
| if storage_backend_extra_config: | |
| try: | |
| extra_config = json.loads(storage_backend_extra_config) | |
| except Exception as e: | |
| logger.error(f"Invalid backend extra config JSON: {e}") | |
| raise e | |
| prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens | |
| prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds | |
| prefetch_timeout_per_ki_token = extra_config.pop( | |
| "prefetch_timeout_per_ki_token", 0.25 | |
| ) # seconds per 1024 tokens | |
| hicache_storage_pass_prefix_keys = extra_config.pop( | |
| "hicache_storage_pass_prefix_keys", False | |
| ) | |
| if not isinstance(prefetch_threshold, int): | |
| raise ValueError( | |
| f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}" | |
| ) | |
| if not isinstance(prefetch_timeout_base, (int, float)): | |
| raise ValueError( | |
| f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}" | |
| ) | |
| if not isinstance(prefetch_timeout_per_ki_token, (int, float)): | |
| raise ValueError( | |
| f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}" | |
| ) | |
| return ( | |
| extra_config, | |
| prefetch_threshold, | |
| float(prefetch_timeout_base), | |
| float(prefetch_timeout_per_ki_token), | |
| hicache_storage_pass_prefix_keys, | |
| ) | |
def reset(self):
    """Reset the cache controller, the host pool and the base radix cache."""
    # Reset the class-level TreeNode counter before the tree is rebuilt.
    TreeNode.counter = 0
    self.cache_controller.reset()
    self.token_to_kv_pool_host.clear()
    # NOTE(review): the ongoing_write_through / ongoing_load_back /
    # ongoing_prefetch / ongoing_backup dicts are not cleared here - confirm
    # that no entries can be outstanding when reset() is called.
    super().reset()
def get_height(self, node: TreeNode):
    """Return how many parent hops separate ``node`` from the root node."""
    hops = 0
    current = node
    while current != self.root_node:
        hops += 1
        current = current.parent
    return hops
def clear_storage_backend(self) -> bool:
    """Clear the storage backend if it is enabled and exposes ``clear``.

    Returns:
        True when the backend was cleared; False when storage is disabled,
        the backend has no ``clear`` attribute, or clearing raised.
    """
    if not self.enable_storage:
        logger.warning("Hierarchical cache storage backend is not enabled.")
        return False

    try:
        backend = self.cache_controller.storage_backend
        # Check if the storage backend has a clear method (for nixl backends)
        if not hasattr(backend, "clear"):
            logger.warning(
                f"Storage backend {type(backend).__name__} does not support clear operation."
            )
            return False
        backend.clear()
        logger.info("Hierarchical cache storage backend cleared successfully!")
        return True
    except Exception as e:
        # Best effort: report the failure to the caller instead of raising.
        logger.error(f"Failed to clear hierarchical cache storage backend: {e}")
        return False
def write_backup(self, node: TreeNode, write_back=False):
    """Start copying ``node``'s device KV indices to host memory.

    If the first write attempt returns ``None``, host entries are evicted
    and the write is retried once.

    Args:
        node: Node whose ``value`` is written to host.
        write_back: When False the node is additionally locked via
            ``inc_lock_ref``; when True no lock is taken.

    Returns:
        Number of host indices obtained, or 0 if both attempts returned None.
    """

    def _issue_write():
        return self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )

    host_indices = _issue_write()
    if host_indices is None:
        # Make room on the host side, then try exactly once more.
        self.evict_host(len(node.value))
        host_indices = _issue_write()
    if host_indices is None:
        return 0

    node.host_value = host_indices
    assert len(node.host_value) > 0
    self.ongoing_write_through[node.id] = node
    if not write_back:
        # no need to lock nodes if write back
        self.inc_lock_ref(node)
    return len(host_indices)
def write_backup_storage(self, node: TreeNode):
    """Submit ``node``'s host copy for backup to the storage backend.

    The operation id returned by the controller is recorded in
    ``ongoing_backup`` and the node's host memory is protected.
    """
    if self.hicache_storage_pass_prefix_keys:
        prefix_keys = node.get_prefix_hash_values(node.parent)
    else:
        prefix_keys = None

    op_id = self.cache_controller.write_storage(
        node.host_value, node.key, node.hash_value, prefix_keys
    )
    self.ongoing_backup[op_id] = node
    node.protect_host()
| def _inc_hit_count(self, node: TreeNode, chunked=False): | |
| # skip the hit count update for chunked requests | |
| if self.cache_controller.write_policy == "write_back" or chunked: | |
| return | |
| node.hit_count += 1 | |
| if not node.backuped: | |
| if node.hit_count >= self.write_through_threshold: | |
| # write to host if the node is not backuped | |
| self.write_backup(node) | |
def writing_check(self, write_back=False):
    """Acknowledge finished device-to-host writes.

    Args:
        write_back: When True, synchronize on every queued write event,
            drop the acknowledged ids from ``ongoing_write_through`` and
            clear the queue. When False, only the leading run of already
            finished events is acknowledged (reduced with MIN across TP
            ranks); the acknowledged nodes are unlocked and, if storage is
            enabled, handed to ``write_backup_storage``.
    """
    controller = self.cache_controller

    if write_back:
        # blocking till all write back complete
        while len(self.ongoing_write_through) > 0:
            for _, done_event, acked_ids in controller.ack_write_queue:
                done_event.synchronize()
                for acked_id in acked_ids:
                    del self.ongoing_write_through[acked_id]
            controller.ack_write_queue.clear()
            assert len(self.ongoing_write_through) == 0
        return

    # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
    if len(self.ongoing_write_through) == 0:
        return

    # Count the finished events at the head of the queue.
    ready = 0
    for _, done_event, _acked_ids in controller.ack_write_queue:
        if not done_event.query():
            break
        ready += 1

    ready_tensor = torch.tensor(ready, dtype=torch.int, device="cpu")
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to radix cache
        torch.distributed.all_reduce(
            ready_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )

    for _ in range(int(ready_tensor.item())):
        _, done_event, acked_ids = controller.ack_write_queue.pop(0)
        done_event.synchronize()
        for acked_id in acked_ids:
            finished_node = self.ongoing_write_through.pop(acked_id)
            self.dec_lock_ref(finished_node)
            if self.enable_storage:
                self.write_backup_storage(finished_node)
def loading_check(self):
    """Acknowledge finished host-to-device loads and unlock their end nodes.

    Entries are consumed from the head of the controller's load ack queue
    until the first event that has not completed yet.
    """
    load_queue = self.cache_controller.ack_load_queue
    done = 0
    for _, load_event, acked_ids in load_queue:
        if not load_event.query():
            # the KV cache loading is still ongoing
            break
        done += 1
        # no need to sync across TP workers as batch forwarding is synced
        for acked_id in acked_ids:
            self.dec_lock_ref(self.ongoing_load_back.pop(acked_id))

    # ACK until all events are processed
    del load_queue[:done]
def evictable_size(self):
    """Return the current value of the ``evictable_size_`` counter."""
    return self.evictable_size_
def evict(self, num_tokens: int):
    """Evict device KV entries, lowest priority first, until the target is met.

    Candidates are the nodes returned by ``_collect_leaves_device``. Locked
    nodes are skipped. A node that already has a host copy only loses its
    device value. A node without one is either written to host first
    (``"write_back"`` policy) or deleted from the tree (any other policy).
    When all children of a parent are evicted or pending write-back, the
    parent becomes a candidate as well.

    Under ``"write_back"`` the device memory of written nodes is released
    only after ``writing_check(write_back=True)`` has returned.

    Args:
        num_tokens: Eviction stops once at least this many tokens have been
            counted as evicted or the candidate heap is empty.
    """
    leaves = self._collect_leaves_device()
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node) for node in leaves
    ]
    heapq.heapify(eviction_heap)

    is_write_back = self.cache_controller.write_policy == "write_back"
    num_evicted = 0
    write_back_nodes = []
    while num_evicted < num_tokens and eviction_heap:
        _priority, x = heapq.heappop(eviction_heap)

        if x.lock_ref > 0:
            continue

        if x.backuped:
            num_evicted += self._evict_backuped(x)
        elif is_write_back:
            # write to host if the node is not backuped
            written = self.write_backup(x, write_back=True)
            if written == 0:
                # The host write could not be issued, so the node has no
                # host copy. Leave it on the device; queueing it would trip
                # the `assert node.backuped` below.
                continue
            num_evicted += written
            write_back_nodes.append(x)
        else:
            num_evicted += self._evict_regular(x)

        for child in x.parent.children.values():
            if child in write_back_nodes:
                continue
            if not child.evicted:
                break
        else:
            # all children are evicted or no children
            new_priority = self.eviction_strategy.get_priority(x.parent)
            heapq.heappush(eviction_heap, (new_priority, x.parent))

    if is_write_back:
        self.writing_check(write_back=True)
        for node in write_back_nodes:
            assert node.backuped
            self._evict_backuped(node)
| def _evict_backuped(self, node: TreeNode): | |
| # evict a node already written to host | |
| num_evicted = self.cache_controller.evict_device(node.value) | |
| assert num_evicted > 0 | |
| self.evictable_size_ -= num_evicted | |
| node.value = None | |
| return num_evicted | |
| def _evict_regular(self, node: TreeNode): | |
| # evict a node not initiated write to host | |
| self.cache_controller.mem_pool_device_allocator.free(node.value) | |
| num_evicted = len(node.value) | |
| self._delete_leaf(node) | |
| return num_evicted | |
def evict_host(self, num_tokens: int):
    """Evict host copies of device-evicted leaves, lowest priority first.

    Candidates come from ``_collect_leaves``. Nodes that still have a device
    value and nodes with ``host_ref_counter > 0`` are skipped. An evicted
    node is unlinked from its parent; a parent that becomes childless and is
    itself device-evicted is added to the candidates.

    Args:
        num_tokens: Eviction stops once at least this many host tokens have
            been released, the heap is empty, or the root node is popped.
    """
    leaves = self._collect_leaves()
    eviction_heap = [
        (self.eviction_strategy.get_priority(node), node) for node in leaves
    ]
    heapq.heapify(eviction_heap)

    num_evicted = 0
    while num_evicted < num_tokens and eviction_heap:
        _priority, x = heapq.heappop(eviction_heap)
        if x == self.root_node:
            break
        # only evict the host value of evicted nodes
        if not x.evicted:
            continue

        # node is protected from eviction as it has ongoing prefetch or backup to storage
        if x.host_ref_counter > 0:
            continue

        num_evicted += self.cache_controller.evict_host(x.host_value)

        siblings = x.parent.children
        for k, v in siblings.items():
            if v == x:
                break
        else:
            # x is not registered under its parent. Without this branch the
            # code would delete whichever sibling was iterated last, or fail
            # on an unbound `k` when the dict is empty.
            continue
        del siblings[k]

        if len(siblings) == 0 and x.parent.evicted:
            new_priority = self.eviction_strategy.get_priority(x.parent)
            heapq.heappush(eviction_heap, (new_priority, x.parent))
def load_back(
    self, node: TreeNode, mem_quota: Optional[int] = None
) -> Optional[torch.Tensor]:
    """Load the evicted chain ending at ``node`` from host back to device.

    The chain consists of ``node`` and its consecutive evicted ancestors.
    It is loaded as a whole or not at all.

    Args:
        node: Last hit node; the loaded device indices are tracked under
            its id in ``ongoing_load_back`` and the node is locked.
        mem_quota: Optional limit; the load is skipped when the chain is
            longer than ``mem_quota`` plus the value returned by locking the
            first non-evicted ancestor.

    Returns:
        The device indices of the loaded chain, or ``None`` when the chain
        is shorter than ``load_back_threshold``, exceeds the quota, or no
        device memory could be obtained even after evicting.
    """
    # todo: more loading policies
    last_hit_node = node

    # Walk upwards collecting evicted nodes, then put them in root-to-leaf order.
    chain = []
    cursor = node
    while cursor.evicted:
        assert (
            cursor.backuped
        ), "No backup available on evicted nodes, should not happen"
        chain.append(cursor)
        cursor = cursor.parent
    chain.reverse()
    ancester_node = cursor

    # protect the ancestor nodes from eviction
    delta = self.inc_lock_ref(ancester_node)

    # load it all or not at all
    host_indices = torch.cat([n.host_value for n in chain])
    total = len(host_indices)
    if total < self.load_back_threshold or (
        mem_quota is not None and total > mem_quota + delta
    ):
        # skip loading back if the total size is too small or exceeding the memory quota
        self.dec_lock_ref(ancester_node)
        return None

    device_indices = self.cache_controller.load(
        host_indices=host_indices, node_id=last_hit_node.id
    )
    if device_indices is None:
        # Free device memory and retry once.
        self.evict(total)
        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
    self.dec_lock_ref(ancester_node)
    if device_indices is None:
        # no sufficient GPU memory to load back KV caches
        return None

    self.ongoing_load_back[last_hit_node.id] = last_hit_node
    # Hand each node its slice of the freshly loaded device indices.
    start = 0
    for loaded_node in chain:
        span = len(loaded_node.host_value)
        loaded_node.value = device_indices[start : start + span]
        start += span
    self.evictable_size_ += len(device_indices)
    self.inc_lock_ref(last_hit_node)

    return device_indices
def init_load_back(
    self,
    last_node: TreeNode,
    host_hit_length: int,
    mem_quota: Optional[int] = None,
):
    """Try to load an evicted ``last_node`` back to the device.

    Args:
        last_node: Deepest matched node.
        host_hit_length: Unused, kept for compatibility.
        mem_quota: Forwarded to ``load_back``.

    Returns:
        ``(device_indices, last_node)`` when the load was started; otherwise
        an empty int64 tensor together with the closest ancestor that is not
        evicted (or ``last_node`` itself if it was not evicted).
    """
    _ = host_hit_length  # unused, but kept for compatibility

    if last_node.evicted:
        loaded = self.load_back(last_node, mem_quota)
        if loaded is not None:
            logger.debug(
                f"loading back {len(loaded)} tokens for node {last_node.id}"
            )
            return loaded, last_node

        # Loading was skipped or failed: fall back to the nearest ancestor
        # that still has a device value.
        while last_node.evicted:
            last_node = last_node.parent

    no_indices = torch.empty((0,), dtype=torch.int64, device=self.device)
    return (no_indices, last_node)
def ready_to_load_host_cache(self) -> int:
    """
    Notify the cache controller to start the KV cache loading.
    Return the consumer index for the schedule batch manager to track.

    Returns:
        The value returned by ``cache_controller.start_loading()``.
    """
    return self.cache_controller.start_loading()
def check_hicache_events(self):
    """Poll the asynchronous cache activity once.

    Runs the write and load checks, drains the storage control queues when
    storage is enabled, and logs storage backend stats when storage metrics
    are enabled.
    """
    self.writing_check()
    self.loading_check()
    if self.enable_storage:
        self.drain_storage_control_queues()
    if not self.enable_storage_metrics:
        return
    backend_stats = self.cache_controller.storage_backend.get_stats()
    self.metrics_collector.log_storage_metrics(backend_stats)
def drain_storage_control_queues(self):
    """
    Combine prefetch revoke, backup ack, and host mem release checks
    to minimize TP synchronization and Python overhead.

    The three queue sizes are reduced with MIN across TP ranks in a single
    all-reduce, and exactly that many items are taken from each queue.
    """
    cc = self.cache_controller
    revoke_queue = cc.prefetch_revoke_queue
    backup_queue = cc.ack_backup_queue
    release_queue = cc.host_mem_release_queue

    qsizes = torch.tensor(
        [revoke_queue.qsize(), backup_queue.qsize(), release_queue.qsize()],
        dtype=torch.int,
    )
    if self.tp_world_size > 1:
        torch.distributed.all_reduce(
            qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
        )
    n_revoke, n_backup, n_release = (int(size) for size in qsizes.tolist())

    # process prefetch revokes
    for _ in range(n_revoke):
        revoked_id = revoke_queue.get()
        info = self.ongoing_prefetch.pop(revoked_id, None)
        if info is None:
            # the revoked operation already got terminated, nothing to do
            continue
        last_host_node, token_ids, _, _ = info
        last_host_node.release_host()
        cc.prefetch_tokens_occupied -= len(token_ids)

    # process backup acks
    for _ in range(n_backup):
        operation = backup_queue.get()
        backed_up_node = self.ongoing_backup.pop(operation.id, None)
        if backed_up_node is not None:
            backed_up_node.release_host()
        if self.enable_storage_metrics:
            self.metrics_collector.log_backuped_tokens(operation.completed_tokens)

    # release host memory
    released = [release_queue.get() for _ in range(n_release)]
    if released:
        cc.mem_pool_host.free(torch.cat(released, dim=0))
| # Timeout is linearly increasing with the number of pages | |
| def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation): | |
| # If hash_value has not been computed in timeout_base seconds, terminate it. | |
| return ( | |
| time.monotonic() - operation.start_time | |
| > self.prefetch_timeout_base | |
| + len(operation.hash_value) * self.prefetch_timeout_per_page | |
| ) | |
def can_terminate_prefetch(self, operation: PrefetchOperation):
    """Decide whether ``operation`` may be terminated under the stop policy.

    * ``"best_effort"`` and unknown policies: always True.
    * ``"wait_complete"``: True once all pages in ``hash_value`` are fetched.
    * ``"timeout"``: True once complete or timed out.

    With more than one TP rank the result is synchronized: the operation may
    terminate if every rank agrees, or if any rank already terminated it.
    """
    policy = self.prefetch_stop_policy
    if policy == "best_effort":
        return True

    # An operation without any page hash is never considered complete.
    expected_pages = len(operation.hash_value)
    completed = (
        expected_pages != 0
        and operation.completed_tokens == expected_pages * self.page_size
    )

    if policy == "wait_complete":
        can_terminate = completed
    elif policy == "timeout":
        can_terminate = completed or self.is_prefetch_timeout(operation)
    else:
        # unknown prefetch stop policy, just return True
        return True

    operation_terminated = operation.is_terminated()
    if self.tp_world_size > 1:
        # Slot 0 holds "cannot terminate" so that MAX yields 0 only when
        # every rank can terminate; slot 1 is 1 if any rank terminated.
        states = torch.tensor(
            [1 - int(can_terminate), int(operation_terminated)],
            dtype=torch.int,
        )
        torch.distributed.all_reduce(
            states,
            op=torch.distributed.ReduceOp.MAX,
            group=self.tp_group,
        )
        can_terminate = states[0].item() == 0
        operation_terminated = states[1].item() == 1

    # the operation should be terminated if it is already terminated on any TP worker
    # or it meets the termination condition on all TP workers
    return can_terminate or operation_terminated
def check_prefetch_progress(self, req_id: str) -> bool:
    """Finish the storage prefetch of ``req_id`` if it may be terminated.

    Args:
        req_id: Request id used as key in ``ongoing_prefetch``.

    Returns:
        False while ``can_terminate_prefetch`` says the operation must keep
        running. True otherwise: no prefetch is tracked for the request, the
        prefetch was never issued, or it was terminated here and its fetched
        prefix was inserted into the host tree.
    """
    if req_id not in self.ongoing_prefetch:
        # there is no ongoing prefetch for this request or it has been revoked
        return True
    # todo: more policies for prefetch progress such as timeout
    # the current policy is to prefetch with best effort and terminate when queuing is over
    last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
        req_id
    ]
    if operation.host_indices is None:
        # prefetch has not been issued due to insufficient host memory
        return True
    if not self.can_terminate_prefetch(operation):
        return False
    completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
        operation
    )
    logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
    min_completed_tokens = completed_tokens
    if self.tp_world_size > 1:
        # synchronize TP workers to make the same update to hiradix cache
        completed_tokens_tensor = torch.tensor(
            min_completed_tokens, dtype=torch.int
        )
        torch.distributed.all_reduce(
            completed_tokens_tensor,
            op=torch.distributed.ReduceOp.MIN,
            group=self.tp_group,
        )
        min_completed_tokens = completed_tokens_tensor.item()
    # Only the prefix completed on every rank is inserted into the tree.
    fetched_token_ids = token_ids[:min_completed_tokens]
    written_indices = host_indices[:min_completed_tokens]
    matched_length = self._insert_helper_host(
        last_host_node,
        RadixKey(
            token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
        ),
        written_indices,
        hash_value[: min_completed_tokens // self.page_size],
    )
    # The first `matched_length` tokens were already present in the tree, so
    # the host slots fetched for them are freed right away.
    self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
    # Slots completed locally but beyond the cross-rank minimum are handed
    # to the controller for release.
    self.cache_controller.append_host_mem_release(
        host_indices[min_completed_tokens:completed_tokens]
    )
    last_host_node.release_host()
    del self.ongoing_prefetch[req_id]
    self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
    if self.enable_storage_metrics:
        self.metrics_collector.log_prefetched_tokens(
            min_completed_tokens - matched_length
        )
    return True
def match_prefix(self, key: RadixKey, **kwargs):
    """Find the longest cached prefix of ``key`` on device and on host.

    ``key.token_ids`` is converted in place with ``key_convert_fn``, and the
    key is truncated to a multiple of ``page_size`` before matching.

    Returns:
        A ``MatchResult`` holding the concatenated device indices of the
        matched non-evicted nodes, the deepest matched node that is not
        evicted, the deepest matched node that has a host backup, and the
        number of host tokens on the evicted tail of the match.
    """
    empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
    key.token_ids = self.key_convert_fn(key.token_ids)

    if self.disable or len(key) == 0:
        return MatchResult(
            device_indices=empty_value,
            last_device_node=self.root_node,
            last_host_node=self.root_node,
            host_hit_length=0,
        )

    if self.page_size != 1:
        aligned_len = len(key) // self.page_size * self.page_size
        key = key[:aligned_len]

    matched_values, deepest_node = self._match_prefix_helper(self.root_node, key)
    value = torch.cat(matched_values) if matched_values else empty_value

    # Climb from the deepest match to the first node that still has a device
    # value, summing the host tokens of the evicted nodes on the way.
    host_hit_length = 0
    device_node = deepest_node
    while device_node.evicted:
        host_hit_length += len(device_node.host_value)
        device_node = device_node.parent

    # Independently, climb to the first node that has a host backup.
    host_node = deepest_node
    while not host_node.backuped:
        host_node = host_node.parent

    return MatchResult(
        device_indices=value,
        last_device_node=device_node,
        last_host_node=host_node,
        host_hit_length=host_hit_length,
    )
def prefetch_from_storage(
    self,
    req_id: str,
    last_host_node: TreeNode,
    new_input_tokens: List[int],
    last_hash: Optional[str] = None,
    prefix_keys: Optional[List[str]] = None,
):
    """Start prefetching KV data for ``new_input_tokens`` from storage.

    Nothing is done when storage is disabled, the page-aligned token count
    is below ``prefetch_threshold``, the controller reports rate limiting,
    or host memory cannot be allocated even after evicting host entries.

    Args:
        req_id: Request id; key of the ``ongoing_prefetch`` entry.
        last_host_node: Node protected on host while the prefetch runs.
        new_input_tokens: Tokens to fetch; truncated to whole pages.
        last_hash: Forwarded to ``cache_controller.prefetch``.
        prefix_keys: Forwarded to ``cache_controller.prefetch``.
    """
    # align the number of fetching tokens to the page size
    prefetch_length = len(new_input_tokens) // self.page_size * self.page_size
    new_input_tokens = new_input_tokens[:prefetch_length]

    if not self.enable_storage:
        return
    if prefetch_length < self.prefetch_threshold:
        return
    if self.cache_controller.prefetch_rate_limited():
        return

    host_pool = self.cache_controller.mem_pool_host
    last_host_node.protect_host()
    host_indices = host_pool.alloc(prefetch_length)
    if host_indices is None:
        # Free host entries and retry the allocation once.
        self.evict_host(prefetch_length)
        host_indices = host_pool.alloc(prefetch_length)
    if host_indices is None:
        last_host_node.release_host()
        # no sufficient host memory for prefetch
        return

    operation = self.cache_controller.prefetch(
        req_id, host_indices, new_input_tokens, last_hash, prefix_keys
    )
    self.ongoing_prefetch[req_id] = (
        last_host_node,
        new_input_tokens,
        host_indices,
        operation,
    )
    self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
| def _insert_helper_host( | |
| self, node: TreeNode, key: RadixKey, host_value, hash_value | |
| ): | |
| node.last_access_time = time.monotonic() | |
| if len(key) == 0: | |
| return 0 | |
| child_key = self.get_child_key_fn(key) | |
| matched_length = 0 | |
| while len(key) > 0 and child_key in node.children.keys(): | |
| node = node.children[child_key] | |
| node.last_access_time = time.monotonic() | |
| prefix_len = self.key_match_fn(node.key, key) | |
| key = key[prefix_len:] | |
| host_value = host_value[prefix_len:] | |
| hash_value = hash_value[prefix_len // self.page_size :] | |
| matched_length += prefix_len | |
| if prefix_len < len(node.key): | |
| new_node = self._split_node(node.key, node, prefix_len) | |
| node = new_node | |
| if len(key): | |
| child_key = self.get_child_key_fn(key) | |
| if len(key): | |
| new_node = TreeNode() | |
| new_node.parent = node | |
| new_node.key = key | |
| new_node.value = None | |
| new_node.host_value = host_value | |
| new_node.hash_value = hash_value | |
| node.children[child_key] = new_node | |
| return matched_length | |
| def _match_prefix_helper(self, node: TreeNode, key: RadixKey): | |
| node.last_access_time = time.monotonic() | |
| child_key = self.get_child_key_fn(key) | |
| value = [] | |
| while len(key) > 0 and child_key in node.children.keys(): | |
| child = node.children[child_key] | |
| child.last_access_time = time.monotonic() | |
| prefix_len = self.key_match_fn(child.key, key) | |
| if prefix_len < len(child.key): | |
| new_node = self._split_node(child.key, child, prefix_len) | |
| if not new_node.evicted: | |
| value.append(new_node.value) | |
| node = new_node | |
| break | |
| else: | |
| if not child.evicted: | |
| value.append(child.value) | |
| node = child | |
| key = key[prefix_len:] | |
| if len(key): | |
| child_key = self.get_child_key_fn(key) | |
| return value, node | |
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
    """Split ``child`` at ``split_len`` and return the new upper node.

    Args:
        key: Key used to compute the dict keys under which ``child`` (below
            the new node) and the new node (below the old parent) are stored.
            Callers in this class pass ``child.key``.
        child: Node to split; it keeps the part after ``split_len``.
        split_len: Number of leading tokens moved to the new node.

    Returns:
        The new node, inserted between ``child``'s former parent and
        ``child``.
    """
    # child node split into new_node -> child
    new_node = TreeNode()
    new_node.children = {self.get_child_key_fn(key[split_len:]): child}
    new_node.parent = child.parent
    # The new node starts with the same lock and hit counts as the child.
    new_node.lock_ref = child.lock_ref
    new_node.key = child.key[:split_len]
    new_node.hit_count = child.hit_count
    # split value and host value if exists
    if child.evicted:
        new_node.value = None
    else:
        new_node.value = child.value[:split_len]
        child.value = child.value[split_len:]
    if child.backuped:
        new_node.host_value = child.host_value[:split_len]
        child.host_value = child.host_value[split_len:]
    if child.hash_value:
        # hash_value has one entry per page, hence the division by page_size.
        new_node.hash_value = child.hash_value[: split_len // self.page_size]
        child.hash_value = child.hash_value[split_len // self.page_size :]
    child.parent = new_node
    child.key = child.key[split_len:]
    # `key` still refers to the unsplit key here, so this overwrites the
    # parent's entry that used to point at `child`.
    new_node.parent.children[self.get_child_key_fn(key)] = new_node
    return new_node
def insert(self, key: RadixKey, value=None, chunked=False):
    """Insert ``key`` with its device KV indices ``value`` into the tree.

    Existing nodes along the key are reused: an evicted node gets its device
    value replaced by the matching slice of ``value``, a live node only has
    its hit count updated. A partially matched node is split. Any unmatched
    tail becomes a new leaf; with storage enabled the leaf also gets one
    hash per page, chained from the parent's last hash.

    Args:
        key: Key to insert; ``key.token_ids`` is converted in place with
            ``key_convert_fn``.
        value: Device indices aligned with ``key``. Required whenever the
            converted key is non-empty.
        chunked: Forwarded to ``_inc_hit_count``, which skips the hit count
            update for chunked requests.

    Returns:
        Number of leading tokens that matched nodes still holding a device
        value. 0 for an empty key.

    Raises:
        TypeError: If ``value`` is None for a non-empty key.
    """
    key.token_ids = self.key_convert_fn(key.token_ids)

    if len(key) == 0:
        return 0

    if value is None:
        # Every path below slices or measures `value`, so None always ended
        # in a TypeError - but only after the tree had been touched (child
        # linked with value=None, hit counts bumped). Fail before mutating.
        raise TypeError(
            "HiRadixCache.insert() requires `value` for a non-empty key"
        )

    if self.is_eagle:
        # Make sure the value len equal to the EAGLE bigram key len
        value = value[: len(key)]

    node = self.root_node
    child_key = self.get_child_key_fn(key)
    total_prefix_length = 0

    while len(key) > 0 and child_key in node.children:
        node = node.children[child_key]
        node.last_access_time = time.monotonic()
        prefix_len = self.key_match_fn(node.key, key)

        if prefix_len == len(node.key):
            if node.evicted:
                # change the reference if the node is evicted
                # this often happens in the case of KV cache recomputation
                node.value = value[:prefix_len]
                self.evictable_size_ += len(node.value)
            else:
                self._inc_hit_count(node, chunked)
                total_prefix_length += prefix_len
        else:
            # partial match, split the node
            new_node = self._split_node(node.key, node, prefix_len)
            if new_node.evicted:
                new_node.value = value[:prefix_len]
                self.evictable_size_ += len(new_node.value)
            else:
                self._inc_hit_count(new_node, chunked)
                total_prefix_length += prefix_len
            node = new_node

        key = key[prefix_len:]
        value = value[prefix_len:]

        if len(key):
            child_key = self.get_child_key_fn(key)

    if len(key):
        new_node = TreeNode()
        new_node.parent = node
        new_node.key = key
        new_node.value = value
        node.children[child_key] = new_node
        self.evictable_size_ += len(value)

        if self.enable_storage:
            last_hash = node.get_last_hash_value()
            assert (node == self.root_node) or (
                last_hash is not None
            ), "Parent node must have a hash value with storage enabled"
            # One hash per page, each chained to the previous one.
            new_node.hash_value = []
            for idx in range(0, len(key), self.page_size):
                new_node.hash_value.append(
                    self.cache_controller.get_hash_str(
                        key.token_ids[idx : idx + self.page_size],
                        prior_hash=last_hash,
                    )
                )
                last_hash = new_node.hash_value[-1]

        if self.cache_controller.write_policy != "write_back":
            self._inc_hit_count(new_node, chunked)
    return total_prefix_length
| def _collect_leaves_device(self): | |
| def is_leaf(node): | |
| if node.evicted: | |
| return False | |
| if node == self.root_node: | |
| return False | |
| if len(node.children) == 0: | |
| return True | |
| for child in node.children.values(): | |
| if not child.evicted: | |
| return False | |
| return True | |
| ret_list = [] | |
| stack = [self.root_node] | |
| while stack: | |
| cur_node = stack.pop() | |
| if is_leaf(cur_node): | |
| ret_list.append(cur_node) | |
| else: | |
| for cur_child in cur_node.children.values(): | |
| if not cur_child.evicted: | |
| stack.append(cur_child) | |
| return ret_list | |
def release_aborted_request(self, rid: str):
    """Terminate and clean up the storage prefetch of an aborted request.

    Does nothing when no prefetch is tracked for ``rid`` or when the tracked
    operation has no host indices. Otherwise the operation is terminated,
    TP ranks are synchronized with a barrier, the host node protection is
    released and the completed host slots are queued for release.
    """
    entry = self.ongoing_prefetch.get(rid)
    if entry is None:
        return

    last_host_node, token_ids, host_indices, operation = entry
    if operation.host_indices is None:
        return

    controller = self.cache_controller
    completed_tokens, _ = controller.terminate_prefetch(operation)
    if self.tp_world_size > 1:
        torch.distributed.barrier(group=self.tp_group)

    last_host_node.release_host()
    del self.ongoing_prefetch[rid]
    controller.append_host_mem_release(host_indices[:completed_tokens])
    controller.prefetch_tokens_occupied -= len(token_ids)
Xet Storage Details
- Size:
- 36 kB
- Xet hash:
- 134d945a1c1fe1bc352e086ac335a07a386bb2835c648c84dfa30f543db22cfa
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.