| from __future__ import annotations | |
| """ | |
| Copyright 2023-2024 SGLang Team | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| """ | |
| The radix tree data structure for managing the hybrid (full and Mamba) KV cache. | |
| """ | |
| import heapq | |
| import time | |
| from collections import defaultdict | |
| from typing import TYPE_CHECKING, List, Optional, Tuple | |
| import torch | |
| from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator | |
| from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult | |
| from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool | |
| from sglang.srt.mem_cache.radix_cache import ( | |
| RadixKey, | |
| _key_match_page_size1, | |
| get_child_key, | |
| ) | |
| if TYPE_CHECKING: | |
| from sglang.srt.managers.schedule_batch import Req | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class TreeNode:
    """One node of the hybrid radix tree.

    A node carries a key segment, the full KV indices for that segment
    (``value``) and, optionally, a Mamba state slot (``mamba_value``). It also
    stores the link fields used by the two LRU lists (full and Mamba).
    """

    # Number of nodes constructed so far; used as the default node id.
    counter = 0

    def __init__(self, id: Optional[int] = None):
        """Create a detached node.

        Args:
            id: Explicit node id. When omitted, the current value of
                ``TreeNode.counter`` is used. The counter is advanced in
                either case.
        """
        # --- tree structure and payload ---
        self.children = defaultdict(TreeNode)
        self.parent: Optional[TreeNode] = None
        self.key: RadixKey = None
        self.value: Optional[torch.Tensor] = None
        self.mamba_value: Optional[torch.Tensor] = None

        # --- lock reference counts ---
        # Invariant: whenever mamba_lock_ref is held, full_lock_ref must be held
        # too; the reverse is not required. Hence full_lock_ref >= mamba_lock_ref.
        # A full lock on a node implies a full lock on its parent as well,
        # whereas a mamba lock only concerns the node itself.
        self.full_lock_ref = 0
        self.mamba_lock_ref = 0

        # Only used for sanity checks; the LRU order itself lives in the LRU lists.
        self.last_access_time = time.monotonic()
        self.hit_count = 0

        # Host indices of the KV cache.
        self.host_value = None

        # --- LRU list links ---
        # Invariant: prev has a greater last_access_time, next has a smaller one.
        self.prev = None
        self.next = None
        self.mamba_prev = None
        self.mamba_next = None

        if id is None:
            self.id = TreeNode.counter
        else:
            self.id = id
        TreeNode.counter += 1

    def evicted(self):
        """Return True when the node holds no full KV value."""
        return self.value is None

    def backuped(self):
        """Return True when the node has host indices stored."""
        return self.host_value is not None

    def __lt__(self, other: "TreeNode"):
        """Order nodes by ``last_access_time`` (older compares as smaller)."""
        return other.last_access_time > self.last_access_time
class LRUList:
    """Doubly linked recency list threaded through ``TreeNode`` objects.

    The links are stored on the nodes themselves, using ``prev``/``next`` for
    the full list or ``mamba_prev``/``mamba_next`` for the Mamba list, so one
    node can be a member of both lists. ``head`` is the most-recently-used end
    and ``tail`` the least-recently-used end; both are dummy sentinel nodes.
    """

    def __init__(self, mamba: bool = False):
        """Create an empty list.

        Args:
            mamba: When true, operate on the Mamba link/lock attributes;
                otherwise operate on the full-cache ones.
        """
        self.mamba = mamba
        # Names of the TreeNode attributes this list reads and writes.
        if mamba:
            self.prv, self.nxt = "mamba_prev", "mamba_next"
            self.lock_ref = "mamba_lock_ref"
        else:
            self.prv, self.nxt = "prev", "next"
            self.lock_ref = "full_lock_ref"

        # Dummy sentinels: head is the MRU side, tail is the LRU side.
        self.head = TreeNode()
        self.tail = TreeNode()
        setattr(self.head, self.nxt, self.tail)
        setattr(self.tail, self.prv, self.head)

        # Maps node id -> node for every member of the list.
        self.cache = {}

    def _add_node(self, node):
        """Link ``node`` directly behind the head (most recently used slot)."""
        self._add_node_after(self.head, node)

    def _add_node_after(self, old_node, new_node):
        """Link ``new_node`` between ``old_node`` and its current successor."""
        successor = getattr(old_node, self.nxt)
        setattr(new_node, self.prv, old_node)
        setattr(new_node, self.nxt, successor)
        setattr(successor, self.prv, new_node)
        setattr(old_node, self.nxt, new_node)

    def _remove_node(self, node):
        """Unlink ``node`` by joining its two neighbours together."""
        before = getattr(node, self.prv)
        after = getattr(node, self.nxt)
        setattr(before, self.nxt, after)
        setattr(after, self.prv, before)

    def _get_lru(self) -> Optional[TreeNode]:
        """Return the least recently used node, or None for an empty list."""
        if not self.cache:
            return None
        return getattr(self.tail, self.prv)

    def reset_node_mru(self, node):
        """Move a node that is already in the list to the MRU position."""
        assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
        assert (
            not self.mamba or node.mamba_value is not None
        ), f"Resetting mamba tombstone node in mamba lru list: {node.id=}"
        self._remove_node(node)
        self._add_node(node)

    def reset_node_and_parents_mru(self, node, root_node):
        """Move ``node`` and its ancestors (excluding ``root_node``) to the MRU end.

        The node itself ends up closest to the head and each ancestor is placed
        right after its descendant, so children are more recently used than
        their parents. In the Mamba list, nodes without a ``mamba_value`` are
        skipped.
        """
        anchor = self.head
        while node != root_node:
            skip = self.mamba and node.mamba_value is None
            if not skip:
                assert (
                    node.id in self.cache
                ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
                self._remove_node(node)
                self._add_node_after(anchor, node)
                anchor = node
            node = node.parent

    def insert_mru(self, node):
        """Register a node that is not yet in the list at the MRU position."""
        assert (
            not self.mamba or node.mamba_value is not None
        ), f"Inserting mamba tombstone node in mamba lru list: {node.id=}"
        assert (
            node.id not in self.cache
        ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
        self.cache[node.id] = node
        self._add_node(node)

    def remove_node(self, node: TreeNode):
        """Drop a node from the list and from the id index."""
        assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
        assert (
            not self.mamba or node.mamba_value is not None
        ), f"Removing mamba tombstone node from mamba lru list: {node.id=}"
        del self.cache[node.id]
        self._remove_node(node)

    def get_lru_no_lock(self) -> Optional[TreeNode]:
        """Return the least recently used node whose lock count is zero."""
        return self.get_prev_no_lock(self.tail, check_id=False)

    def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
        """Return the least recently used unlocked node that has no children."""
        return self.get_prev_leaf_no_lock(self.tail, check_id=False)

    def get_prev_no_lock(
        self, node: TreeNode, check_id: bool = True
    ) -> Optional[TreeNode]:
        """Return the next more-recently-used node that is not locked.

        Args:
            node: Starting point; the search begins at its predecessor.
            check_id: When true, assert that ``node`` is a member of the list.

        Returns:
            The first unlocked node towards the head, or None when the walk
            reaches the head sentinel.
        """
        if check_id:
            assert (
                node.id in self.cache
            ), f"Getting prev of node {node.id=} not in lru list"
        candidate = getattr(node, self.prv)
        # The head sentinel has a zero lock count, so the walk stops there at the latest.
        while getattr(candidate, self.lock_ref) > 0:
            candidate = getattr(candidate, self.prv)
        return None if candidate == self.head else candidate

    def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
        """Return the next more-recently-used unlocked node without children.

        Same contract as ``get_prev_no_lock`` except that nodes with children
        are skipped as well.
        """
        if check_id:
            assert (
                node.id in self.cache
            ), f"Getting prev of node {node.id=} not in lru list"
        candidate = getattr(node, self.prv)
        while True:
            unlocked = not getattr(candidate, self.lock_ref) > 0
            is_leaf = not len(candidate.children) > 0
            if unlocked and is_leaf:
                break
            candidate = getattr(candidate, self.prv)
        # Reaching the head means no unlocked leaf exists in the list.
        return None if candidate == self.head else candidate

    def in_list(self, node: Optional[TreeNode]):
        """Return whether ``node`` is a member of this list (False for a falsy node)."""
        return bool(node) and node.id in self.cache

    # Note: this is expensive, only use for debug
    def sanity_check_evictable_size(self):
        """Sum the payload sizes of all unlocked nodes by walking the list."""
        total = 0
        node = self.get_lru_no_lock()
        while self.in_list(node):
            total += len(node.mamba_value) if self.mamba else len(node.value)
            node = self.get_prev_no_lock(node)
        return total

    # Note: this is expensive, only use for debug or idle check
    def sanity_check(self, tree_cache: "MambaRadixCache"):
        """Validate the list against the tree.

        Collects the relevant tree nodes, pops them in ``last_access_time``
        order from a heap and checks that this order matches the list walked
        from the LRU end, that no node is locked, and that the tracked
        evictable size equals the size computed from the list.

        Raises:
            Exception: when any check fails; the failure is logged first.
        """
        try:
            nodes = (
                tree_cache._collect_nontombstone_nodes()
                if self.mamba
                else tree_cache._collect_all_nodes()
            )
            total_nodes = len(nodes)
            total_lru = len(self.cache)
            # Order the nodes by last_access_time.
            heapq.heapify(nodes)

            # The full node collection includes the root, which is never in the list.
            expected_count = total_lru if self.mamba else total_lru + 1
            assert (
                len(nodes) == expected_count
            ), f"len(nodes): {len(nodes)}, total_lru: {total_lru}"

            x_lru = self._get_lru()
            while nodes:
                x = heapq.heappop(nodes)
                if x == tree_cache.root_node:
                    # root node is not in the lru list
                    continue
                assert (
                    x == x_lru
                ), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}"
                assert (
                    x_lru.full_lock_ref == 0
                ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}"
                assert (
                    x_lru.mamba_lock_ref == 0
                ), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}"
                x_lru = getattr(x, self.prv)

            if self.mamba:
                evictable_size = tree_cache.mamba_evictable_size()
                lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size()
            else:
                evictable_size = tree_cache.full_evictable_size()
                lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()

            assert (
                evictable_size == lru_list_evictable_size
            ), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
        except Exception as e:
            msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}"
            logger.error(msg)
            raise Exception(msg)
class MambaRadixCache(BasePrefixCache):
    """Radix-tree prefix cache for the hybrid (full and Mamba) KV cache.

    Every non-root node stores a slice of full KV indices (``value``) and may
    store a Mamba state slot (``mamba_value``). A node whose ``mamba_value`` is
    ``None`` is referred to as a tombstone. Two ``LRUList`` instances keep the
    eviction order for full KV indices and for Mamba slots separately.
    """

    def __init__(
        self,
        req_to_token_pool: HybridReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
    ):
        """Build an empty cache.

        Args:
            req_to_token_pool: Pool mapping requests to token slots; its
                ``mamba_pool`` is used for Mamba state slots.
            token_to_kv_pool_allocator: Allocator whose ``free`` releases full
                KV indices.
            page_size: Must be 1.
            disable: When true, the tree is bypassed by the public methods.
        """
        assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator)
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        assert page_size == 1, "Only support page_size=1 in mamba radix cache now."
        self.page_size = page_size
        self.disable = disable

        if self.token_to_kv_pool_allocator:
            self.device = self.token_to_kv_pool_allocator.device
        else:
            self.device = torch.device("cpu")

        self.key_match_fn = _key_match_page_size1
        self.get_child_key_fn = get_child_key
        self.reset()

    ##### Public API #####

    def reset(self) -> None:
        """Discard the tree and start over with a fresh, locked root node."""
        self.root_node = TreeNode()
        self.root_node.key = []
        self.root_node.value = []
        # The root carries both locks from the start.
        self.root_node.full_lock_ref = 1
        self.root_node.mamba_lock_ref = 1
        self.full_evictable_size_ = 0
        self.mamba_evictable_size_ = 0
        self.full_protected_size_ = 0
        self.mamba_protected_size_ = 0
        # LRU lists are used to maintain the order of eviction of the nodes in the tree
        self.full_lru_list = LRUList(mamba=False)
        self.mamba_lru_list = LRUList(mamba=True)

    def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
        """Find the longest cached prefix of ``key`` that ends on a node with a Mamba state.

        Note that this call can modify the tree: a node is split when ``key``
        only matches part of it, and the matched nodes are moved to the MRU end
        of the LRU lists.

        Args:
            key: A RadixKey containing the token IDs to look up.
            **kwargs:
                cow_mamba (bool): When true and the matched node has a Mamba
                    state, copy that state into the request's own Mamba slot
                    (allocating one, and evicting if needed, when the request
                    has none).
                req (Req): The request; only accessed when ``cow_mamba`` is true.

        Returns:
            A ``MatchResult`` whose ``device_indices`` holds the concatenated
            KV indices of the matched prefix and whose ``last_device_node`` /
            ``last_host_node`` are the last matched node (the root when
            nothing matches, the cache is disabled, or ``key`` is empty).
        """
        cow_mamba: bool = kwargs.get("cow_mamba", False)
        req: Req = kwargs.get("req", None)

        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=torch.empty(
                    (0,),
                    dtype=torch.int64,
                    device=self.device,
                ),
                last_device_node=self.root_node,
                last_host_node=self.root_node,
            )

        value, last_node = self._match_prefix_helper(key)

        # copy mamba state to req local space if cow is true
        if cow_mamba and last_node.mamba_value is not None:
            assert req.req_pool_idx is None  # req_pool_idx is uninitialized
            # for reqs without mamba cache
            if req.mamba_pool_idx is None:
                dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
                # try to alloc again, protect last_node from eviction
                if dst_index is None:
                    self.inc_lock_ref(last_node)
                    self.evict_mamba(1)
                    dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
                    self.dec_lock_ref(last_node)
                    assert dst_index is not None, "Can not alloc mamba cache"
                src_index = last_node.mamba_value
                self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
                req.mamba_pool_idx = dst_index[0]
            else:
                src_index = last_node.mamba_value
                dst_index = req.mamba_pool_idx.unsqueeze(0)
                self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)

        if value:
            value = torch.cat(value)
        else:
            value = torch.empty((0,), dtype=torch.int64, device=self.device)

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_node,
        )

    def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]:
        """Insert ``key`` with its KV indices and Mamba slot into the tree.

        Args:
            key: Token IDs to insert.
            value: KV indices for ``key``; defaults to a tensor of the token
                IDs themselves.
            mamba_value: Mamba slot for the final node; must not be None
                unless the cache is disabled.

        Returns:
            ``(prefix_len, mamba_exist)``: the number of leading tokens that
            were already present in the tree, and whether the final node
            already had a Mamba state (in which case ``mamba_value`` was not
            stored). Returns ``(0, False)`` when the cache is disabled.
        """
        if self.disable:
            # Keep the tuple shape promised by the signature so callers can unpack.
            return 0, False

        if value is None:
            value = torch.tensor(list(key.token_ids), dtype=torch.int64)
        return self._insert_helper(self.root_node, key, value, mamba_value)

    def cache_finished_req(self, req: Req) -> None:
        """Cache request when it finishes."""
        if self.disable:
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx,
                : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
            ]
            self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return

        # The last token is dropped from the cached sequence.
        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]

        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

        # Radix Cache takes one ref in memory pool
        # insert the token_ids and kv_indices into the radix tree
        # Note: the insert function already frees the overlapped kv_indices
        mamba_value = (
            self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
            .unsqueeze(-1)
            .clone()
        )

        new_prefix_len, mamba_exist = self.insert(
            RadixKey(token_ids[:page_aligned_len], req.extra_key),
            page_aligned_kv_indices,
            mamba_value,
        )

        # Indices between the request's old prefix and the prefix found in the
        # tree duplicate entries the tree already holds, so release them.
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )

        # Free the request's mamba slot only if the tree did not take it over.
        self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
        self.dec_lock_ref(req.last_node)

    def cache_unfinished_req(self, req: Req, chunked=False) -> None:
        """Cache request when it is unfinished."""
        if self.disable:
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, : len(req.fill_ids)
            ]
            # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
            req.prefix_indices = kv_indices
            return

        token_ids = req.fill_ids
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, : len(token_ids)
        ]
        page_aligned_len = len(kv_indices)
        page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
        page_aligned_token_ids = token_ids[:page_aligned_len]

        mamba_value = self.req_to_token_pool.get_mamba_indices(
            req.req_pool_idx
        ).unsqueeze(-1)
        # radix tree mamba value is forked from req space
        mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)

        # if alloc mamba cache failed, do evict and alloc again
        if mamba_value_forked is None:
            self.evict_mamba(1)
            mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(
                mamba_value
            )
            assert mamba_value_forked is not None, "Can not alloc mamba cache"

        new_prefix_len, mamba_exist = self.insert(
            RadixKey(page_aligned_token_ids, req.extra_key),
            page_aligned_kv_indices,
            mamba_value_forked,
        )
        self.token_to_kv_pool_allocator.free(
            kv_indices[len(req.prefix_indices) : new_prefix_len]
        )
        # there is a mamba cache in radix cache, release it
        if mamba_exist:
            self.req_to_token_pool.mamba_pool.free(mamba_value_forked)

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node, _, _ = self.match_prefix(
            RadixKey(page_aligned_token_ids, req.extra_key)
        )

        if not mamba_exist:
            assert torch.equal(new_last_node.mamba_value, mamba_value_forked)

        assert len(req.prefix_indices) <= len(
            new_indices
        ), f"{req.prefix_indices=}, {new_indices=}"
        assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"

        self.req_to_token_pool.write(
            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
            new_indices[len(req.prefix_indices) :],
        )

        # Move the request's lock from the old last node to the new one.
        self.dec_lock_ref(req.last_node)
        self.inc_lock_ref(new_last_node)

        # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
        req.prefix_indices = new_indices
        req.last_node = new_last_node

    def pretty_print(self) -> None:
        """Print the tree followed by the total full-token and mamba counts."""
        self._print_helper(self.root_node, 0)
        total_size, total_mamba_size = self._total_size_helper()
        print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}")

    def total_size(self) -> Tuple[int, int]:
        """Return ``(full_token_count, mamba_slot_count)`` over the whole tree."""
        return self._total_size_helper()

    def _evict_leaf_node(
        self, x: TreeNode, is_evict_mamba: bool
    ) -> Tuple[int, int, TreeNode, TreeNode]:
        """Evict an unlocked leaf, then prune tombstone ancestors left childless.

        Args:
            x: Leaf to evict; must be unlocked and hold a Mamba state.
            is_evict_mamba: Selects which LRU list supplies the next candidate
                (Mamba list when true, full list restricted to leaves otherwise).

        Returns:
            ``(full_num_evicted, mamba_num_evicted, last_deleted_node, x_next)``
            where ``x_next`` is the next eviction candidate (may be None).
        """
        assert (
            x.full_lock_ref == 0 and x.mamba_lock_ref == 0
        ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}"
        assert x.mamba_value is not None, f"leaf node mamba value is None, {x.id=}"

        # 1. a leaf node, free full tokens and mamba
        self.token_to_kv_pool_allocator.free(x.value)
        full_num_evicted = len(x.value)
        self.req_to_token_pool.mamba_pool.free(x.mamba_value)
        mamba_num_evicted = len(x.mamba_value)

        # 2. get the next node, update the lru lists
        # (the successor must be looked up while x is still linked in the list)
        if is_evict_mamba:
            x_next = self.mamba_lru_list.get_prev_no_lock(x)
        else:
            x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
        self.full_lru_list.remove_node(x)
        self.mamba_lru_list.remove_node(x)

        # 3. delete the leaf node
        self._delete_leaf(x)

        # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
        x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
        full_num_evicted += leaf_full_num_evicted
        return full_num_evicted, mamba_num_evicted, x, x_next

    def evict_mamba(self, mamba_num: int) -> None:
        """Release at least ``mamba_num`` Mamba slots if enough unlocked nodes exist.

        Internal nodes only lose their Mamba state (they become tombstones);
        leaves are removed from the tree entirely.
        """
        if self.disable or mamba_num <= 0:
            return
        # get the least recently used node that is not locked, doesn't have to be a leaf
        x = self.mamba_lru_list.get_lru_no_lock()
        mamba_num_evicted = 0
        # evict lru nodes until mamba_num is reached
        while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)):
            assert x.mamba_value is not None, f"node has no mamba value, {x.id=}"
            assert (
                len(x.mamba_value) == 1
            ), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}"
            assert x != self.root_node, f"root node is not evictable, {x.id=}"
            assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}"

            if len(x.children) > 0:
                # 1. an internal node, free mamba tokens.
                self.req_to_token_pool.mamba_pool.free(x.mamba_value)
                mamba_num_evicted += len(x.mamba_value)

                # 2. get the next node, update the lru lists
                x_next = self.mamba_lru_list.get_prev_no_lock(x)
                self.mamba_lru_list.remove_node(x)

                # 3. tombstone the node
                self._tombstone_internal_node(x)
            else:
                _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True)
                mamba_num_evicted += mamba_evicted_delta

            x = x_next

    def evict(self, full_num_tokens: int) -> None:
        """Release at least ``full_num_tokens`` full KV indices if enough unlocked leaves exist."""
        if self.disable or full_num_tokens <= 0:
            return

        full_num_evicted = 0
        # get the least recently used leaf node that is not locked
        x = self.full_lru_list.get_leaf_lru_no_lock()

        while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
            assert (
                x != self.root_node
            ), f"root node should not exist in full lru list, {x.id=}"
            full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False)
            full_num_evicted += full_num_evicted_delta

            # if parent has no more children, it is a leaf. It is possible that this node is lru, so
            # we need to get the first leaf node in the lru list
            if len(x.parent.children) == 0:
                x_next = self.full_lru_list.get_leaf_lru_no_lock()

            x = x_next

    def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
        """
        Increment the lock reference count for the node.
        It locks the full_lock_ref for nodes between the [last node, root), exclusive.
        It locks the mamba_lock_ref for current node if its mamba_value exists.

        Sizes move from the evictable to the protected counters when a count
        goes from 0 to 1. Always returns None.
        """
        if self.disable:
            return None

        # protect mamba value in current node if it exists
        if node.mamba_value is not None:
            if node.mamba_lock_ref == 0:
                self.mamba_evictable_size_ -= len(node.mamba_value)
                self.mamba_protected_size_ += len(node.mamba_value)
            node.mamba_lock_ref += 1

        while node != self.root_node:
            # lock full from node to root
            assert (
                node.full_lock_ref >= 0
            ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
            if node.full_lock_ref == 0:
                self.full_evictable_size_ -= len(node.value)
                self.full_protected_size_ += len(node.value)
            node.full_lock_ref += 1
            node = node.parent
        return None

    def dec_lock_ref(self, node: TreeNode):
        """
        Decrement the lock reference count for the node.
        It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
        It unlocks the mamba_lock_ref for current node if its mamba_value exists.

        Sizes move from the protected back to the evictable counters when a
        count goes from 1 to 0.
        """
        if self.disable:
            return

        if node.mamba_value is not None:
            assert (
                node.mamba_lock_ref > 0
            ), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}"
            if node.mamba_lock_ref == 1:
                self.mamba_evictable_size_ += len(node.mamba_value)
                self.mamba_protected_size_ -= len(node.mamba_value)
            node.mamba_lock_ref -= 1

        while node != self.root_node:
            assert (
                node.full_lock_ref > 0
            ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
            if node.full_lock_ref == 1:
                self.full_evictable_size_ += len(node.value)
                self.full_protected_size_ -= len(node.value)
            node.full_lock_ref -= 1
            node = node.parent

    def sanity_check(self):
        """Run the (expensive) consistency checks of both LRU lists."""
        self.full_lru_list.sanity_check(self)
        self.mamba_lru_list.sanity_check(self)

    def evictable_size(self) -> Tuple[int, int]:
        # Note: use full_evictable_size() and mamba_evictable_size() instead.
        raise NotImplementedError

    def full_evictable_size(self) -> int:
        """Return the tracked size of unlocked full KV indices."""
        return self.full_evictable_size_

    def mamba_evictable_size(self) -> int:
        """Return the tracked size of unlocked Mamba slots."""
        return self.mamba_evictable_size_

    # Note: this is expensive, only use for debug
    def full_lru_list_evictable_size(self) -> int:
        """Recompute the full evictable size by walking the full LRU list."""
        return self.full_lru_list.sanity_check_evictable_size()

    # Note: this is expensive, only use for debug
    def mamba_lru_list_evictable_size(self) -> int:
        """Recompute the Mamba evictable size by walking the Mamba LRU list."""
        return self.mamba_lru_list.sanity_check_evictable_size()

    def protected_size(self) -> Tuple[int, int]:
        # Note: use full_protected_size() and mamba_protected_size() instead.
        raise NotImplementedError

    def full_protected_size(self) -> int:
        # protected size refers to the size of the full cache that is locked
        return self.full_protected_size_

    def mamba_protected_size(self) -> int:
        # protected size refers to the size of the mamba cache that is locked
        return self.mamba_protected_size_

    def all_values_flatten(self) -> torch.Tensor:
        """Concatenate the ``value`` of every non-root node in pre-order.

        Returns an empty int64 tensor on ``self.device`` when the tree has no
        nodes besides the root (``torch.cat`` rejects an empty list).
        """
        values = []
        # Explicit stack instead of recursion so deep trees cannot exhaust the
        # interpreter's recursion limit. Children are pushed in reverse so they
        # are visited in insertion order, as a recursive pre-order walk would.
        stack = list(reversed(list(self.root_node.children.values())))
        while stack:
            node = stack.pop()
            values.append(node.value)
            stack.extend(reversed(list(node.children.values())))

        if not values:
            return torch.empty((0,), dtype=torch.int64, device=self.device)
        return torch.cat(values)

    ##### Internal Helper Functions #####

    def _match_prefix_helper(
        self, key: RadixKey
    ) -> Tuple[List[torch.Tensor], TreeNode]:
        """
        Mamba prefix matching helper.

        Walks down the tree along ``key`` (splitting a node on a partial
        match) and remembers the deepest node on the path that has a Mamba
        state. Only the values up to that node are returned, so the result
        always ends on a node with a Mamba state, or on the root.

        Returns:
            ``(values, best_last_node)``: the per-node value tensors of the
            usable prefix and the node that prefix ends on.
        """
        node = self.root_node
        child_key = self.get_child_key_fn(key)

        value = []
        best_value_len = 0
        best_last_node = node
        while len(key) > 0 and child_key in node.children:
            child = node.children[child_key]
            # update best_value_len and best_last_node if needed
            if node.mamba_value is not None:
                best_value_len = len(value)
                best_last_node = node

            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                # Partial match: split so the matched part becomes its own node.
                new_node = self._split_node(child.key, child, prefix_len)
                value.append(new_node.value)
                node = new_node
                break
            else:
                value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)
        # handle best_value_len and best_last_node, for the case that last node is fully matched
        if node.mamba_value is not None:
            best_value_len = len(value)
            best_last_node = node

        # update time for matched nodes, and make nodes closer to root to be least recently used
        # this allows mamba to evict nodes closer to root first
        self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
        self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)

        # This last_access_time is for sanity check, can be deleted after validation in production
        # NOTE(review): timestamps are refreshed starting from the deepest
        # matched node, while the LRU lists above are refreshed starting from
        # best_last_node -- confirm this difference is intended.
        cur_time = time.monotonic()
        while node:
            node.last_access_time = cur_time
            cur_time -= 0.0001
            node = node.parent

        return value[:best_value_len], best_last_node

    def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
        """Split ``child`` at ``split_len`` and return the new parent node.

        The new node takes the first ``split_len`` entries of the key and
        value, inherits ``child``'s full lock count, and starts as a tombstone
        (no Mamba state); ``child`` keeps the remainder and its Mamba state.
        """
        # new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.mamba_value = None  # mamba cache can not be split
        new_node.full_lock_ref = child.full_lock_ref
        new_node.mamba_lock_ref = 0
        new_node.key = child.key[:split_len]
        new_node.value = child.value[:split_len]

        # child time should be later than parent's time for mamba tombstone
        child.last_access_time = time.monotonic()

        # Take child out of the lists; it is re-inserted after the new parent below.
        self.full_lru_list.remove_node(child)
        if child.mamba_value is not None:
            self.mamba_lru_list.remove_node(child)
        child.parent = new_node
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node

        # insert the new node and child into the lru lists, insert
        # parent first so that parent is after child in the lru list
        self.full_lru_list.insert_mru(new_node)
        self.full_lru_list.insert_mru(child)
        if child.mamba_value is not None:
            self.mamba_lru_list.insert_mru(child)
        return new_node

    def _insert_helper(
        self,
        node: TreeNode,
        key: RadixKey,
        value,
        mamba_value,
    ) -> Tuple[int, bool]:
        """Insert ``key`` below ``node``; see ``insert`` for the return value."""
        # Update the last access time from root to leaf, so that
        # mamba will tombstone the node closer to root first
        assert mamba_value is not None, "Mamba value should not be None here."
        node.last_access_time = time.monotonic()
        if node != self.root_node:
            self.full_lru_list.reset_node_mru(node)
            if node.mamba_value is not None:
                self.mamba_lru_list.reset_node_mru(node)
        if len(key) == 0:
            return 0, True

        child_key = self.get_child_key_fn(key)

        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children:
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            self.full_lru_list.reset_node_mru(node)
            if node.mamba_value is not None:
                self.mamba_lru_list.reset_node_mru(node)
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = key[prefix_len:]
            value = value[prefix_len:]

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        mamba_value_exist = False
        if len(key):
            # Unmatched remainder: attach it as a new leaf carrying the mamba state.
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            new_node.mamba_value = mamba_value
            self.full_lru_list.insert_mru(new_node)
            self.full_evictable_size_ += len(value)
            self.mamba_evictable_size_ += len(mamba_value)
            self.mamba_lru_list.insert_mru(new_node)
            node.children[child_key] = new_node
        elif node.mamba_value is None:  # add for mamba tombstone
            node.mamba_value = mamba_value
            self.mamba_evictable_size_ += len(mamba_value)
            self.mamba_lru_list.insert_mru(node)
        else:
            # The final node already has a mamba state; the caller keeps ownership
            # of the mamba_value it passed in.
            mamba_value_exist = True
            self.mamba_lru_list.reset_node_mru(node)

        return total_prefix_length, mamba_value_exist

    def _iteratively_delete_tombstone_leaf(
        self, node: TreeNode
    ) -> Tuple[TreeNode, int]:
        """Delete ancestors of a removed node while they are unlocked, childless tombstones.

        Returns:
            ``(last_deleted_node, full_num_evicted)`` where the first element
            is ``node`` itself when nothing further was deleted.
        """
        full_num_evicted = 0
        while node.parent.mamba_value is None and len(node.parent.children) == 0:
            # root node is not evictable
            if node.parent == self.root_node:
                break
            # if locked, means node is in use, skip
            if node.parent.full_lock_ref > 0:
                break
            assert (
                node.parent.mamba_lock_ref == 0
            ), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}"
            # delete tombstone node evicts full tokens
            self.token_to_kv_pool_allocator.free(node.parent.value)
            full_num_evicted += len(node.parent.value)
            self.full_lru_list.remove_node(node.parent)
            self._delete_tombstone_leaf(node.parent)
            node = node.parent

        return node, full_num_evicted

    def _remove_child_entry(self, node: TreeNode) -> None:
        """Remove ``node`` from its parent's ``children`` mapping.

        Raises:
            RuntimeError: if ``node`` is not registered under its parent, so a
                different child is never removed by mistake.
        """
        siblings = node.parent.children
        for child_key, child in siblings.items():
            if child is node:
                # Safe: the loop is left immediately after mutating the dict.
                del siblings[child_key]
                return
        raise RuntimeError(f"node is not a child of its parent, {node.id=}")

    def _delete_leaf(self, node: TreeNode) -> None:
        """Detach a leaf that holds a Mamba state and update both evictable sizes."""
        assert (
            node.mamba_value is not None
        ), f"Invariant violated: leaf node is a tombstone, {node.id=}"
        assert len(node.children) == 0, f"leaf node has children, {node.id=}"
        self._remove_child_entry(node)
        self.full_evictable_size_ -= len(node.key)
        self.mamba_evictable_size_ -= len(node.mamba_value)

    def _tombstone_internal_node(self, node: TreeNode) -> None:
        """Drop the Mamba state of an internal node, keeping its full KV value."""
        assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
        self.mamba_evictable_size_ -= len(node.mamba_value)
        node.mamba_value = None

    def _delete_tombstone_leaf(self, node: TreeNode) -> None:
        """Detach a childless tombstone and update the full evictable size."""
        assert (
            node.mamba_value is None
        ), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
        assert len(node.children) == 0, f"leaf node has children, {node.id=}"
        self._remove_child_entry(node)
        self.full_evictable_size_ -= len(node.key)

    def _collect_leaves(self) -> List[TreeNode]:
        """Return every node without children (the root itself if the tree is empty)."""
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if len(cur_node.children) == 0:
                ret_list.append(cur_node)
            else:
                stack.extend(cur_node.children.values())

        return ret_list

    def _collect_nontombstone_nodes(self) -> List[TreeNode]:
        """Return every node, root included, whose ``mamba_value`` is not None."""
        ret_list = []
        stack = [self.root_node]

        while stack:
            cur_node = stack.pop()
            if cur_node.mamba_value is not None:
                ret_list.append(cur_node)
            stack.extend(cur_node.children.values())

        return ret_list

    def _collect_all_nodes(self) -> List[TreeNode]:
        """Return every node of the tree, root included."""
        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            ret_list.append(cur_node)
            stack.extend(cur_node.children.values())
        return ret_list

    def _print_helper(self, node: TreeNode, indent: int) -> None:
        """Prints the radix tree in a human-readable format."""
        stack = [(node, indent)]
        while stack:
            current_node, current_indent = stack.pop()
            print(
                " " * current_indent,
                f"[{current_node.id}]",
                len(current_node.key),
                f"fr={current_node.full_lock_ref}",
                f"mr={current_node.mamba_lock_ref}",
                f"fll={self.full_lru_list.in_list(current_node)}",
                f"mll={self.mamba_lru_list.in_list(current_node)}",
                f"mv={current_node.mamba_value}",
            )
            for key, child in current_node.children.items():
                stack.append((child, current_indent + 2))

                # Each child must be registered under the key derived from its own key.
                assert key == self.get_child_key_fn(
                    child.key
                ), f"{key=}, {self.get_child_key_fn(child.key)=}"

    def _total_size_helper(self) -> Tuple[int, int]:
        """Sum full-value lengths and Mamba-value lengths over the tree.

        Children whose ``value`` is None are skipped together with their
        subtrees.
        """
        total_size = 0
        total_mamba_size = 0
        stack = [self.root_node]
        while stack:
            current_node = stack.pop()
            total_size += len(current_node.value)
            if current_node.mamba_value is not None:
                total_mamba_size += len(current_node.mamba_value)
            for child in current_node.children.values():
                # Test the field directly: `child.evicted` without a call is a
                # bound method and therefore always truthy.
                if child.value is None:
                    continue
                stack.append(child)
        return total_size, total_mamba_size
Xet Storage Details
- Size:
- 38.7 kB
- Xet hash:
- 44cf46ca6271bef30f9585c2116c04ac63978981845721ca2b7e5ae603425bfc
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.