| namespace radix_tree_v2 { | |
| using node_iterator_t = typename TreeNode::iterator_t; | |
| struct RadixTree::Impl { | |
| public: | |
| Impl(bool disabled, bool use_hicache, std::size_t page_size, std::size_t host_size, std::size_t threshold) | |
| : m_root(/*node_id_=*/0), | |
| m_evictable_size(0), | |
| m_protected_size(0), | |
| m_cached_vec(), | |
| m_node_map(), | |
| m_node_counter(1), // start from 1 to avoid confusion with root node | |
| disabled(disabled), | |
| use_hicache(use_hicache), | |
| page_size(page_size), | |
| threshold(threshold) { | |
| _assert(page_size > 0, "Page size must be greater than zero"); | |
| _assert(use_hicache == (host_size > 0), "Hierarchical cache is enabled iff host size > 0"); | |
| m_root.ref_count = 1; // root node is always protected | |
| m_cached_vec.reserve(page_size); // to avoid repeated allocations | |
| m_node_map[m_root.node_id] = &m_root; // add root to the map | |
| } | |
| TreeNode* split_node(node_iterator_t iterator, std::size_t prefix_length) { | |
| // from `parent -> old_node` to `parent-> new_node -> old_node` | |
| // the prefix part of the old node is moved to the new node | |
| auto old_node_ptr = std::move(iterator->second); | |
| auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++); | |
| auto* old_node = old_node_ptr.get(); | |
| auto* new_node = new_node_ptr.get(); | |
| auto* parent = old_node->parent(); | |
| // set up data structures | |
| split_prefix(new_node, old_node, prefix_length); | |
| // set up parent-child relationship | |
| add_child(new_node, std::move(old_node_ptr)); | |
| add_child(parent, std::move(new_node_ptr), iterator); | |
| m_node_map[new_node->node_id] = new_node; // add to the map | |
| return new_node; | |
| } | |
| // node: x -> [GPU] | |
| TreeNode* create_device_node(TreeNode* parent, token_vec_t vec, at::Tensor indices) { | |
| auto new_node_ptr = std::make_unique<TreeNode>(m_node_counter++); | |
| auto new_node = new_node_ptr.get(); | |
| new_node_ptr->_unsafe_tokens() = std::move(vec); | |
| new_node_ptr->_unsafe_device_indices() = std::move(indices); | |
| m_evictable_size += new_node_ptr->length(); | |
| add_child(parent, std::move(new_node_ptr)); | |
| m_node_map[new_node->node_id] = new_node; // add to the map | |
| return new_node; | |
| } | |
| // node: [GPU] -> x | |
| void remove_device_node(TreeNode* node) { | |
| _assert(node->on_gpu_only() && node->ref_count == 0); | |
| m_evictable_size -= node->length(); | |
| node->parent()->erase_child(get_key(node)); | |
| m_node_map.erase(node->node_id); // remove from the map | |
| } | |
| /** | |
| * @brief Walk the tree to find the node that matches the key. | |
| * If the key partially matches a node, it will split that node. | |
| * @return A pair containing the last node that matches the key and | |
| * the total prefix length matched (on gpu and cpu) so far. | |
| */ | |
| std::pair<TreeNode*, std::size_t> tree_walk(token_slice key) { | |
| _assert(key.size() % page_size == 0, "Key should be page-aligned"); | |
| std::size_t total_prefix_length = 0; | |
| TreeNode* node = &m_root; | |
| const auto now = std::chrono::steady_clock::now(); | |
| while (key.size() > 0) { | |
| const auto iterator = node->find_child(get_key(key)); | |
| if (iterator == node->end()) break; | |
| // walk to the child node | |
| node = iterator->second.get(); | |
| // at least `page_size` tokens are matched, and there may be more tokens to match | |
| // the return value prefix_length is no less than `page_size` | |
| const auto prefix_length = align(node->diff_key(key, page_size) + page_size); | |
| total_prefix_length += prefix_length; | |
| // split the node if the prefix is not the whole token vector | |
| if (prefix_length < node->length()) { | |
| return {split_node(iterator, prefix_length), total_prefix_length}; | |
| } | |
| // we have matched the whole key, continue to the next node | |
| node->access(now); | |
| key = key.subspan(prefix_length); | |
| } | |
| return {node, total_prefix_length}; | |
| } | |
| std::vector<TreeNode*> collect_leaves() const { | |
| std::vector<TreeNode*> leaves; | |
| std::vector<TreeNode*> stack = {}; | |
| for (const auto& [_, child] : m_root) { | |
| stack.push_back(child.get()); | |
| } | |
| while (!stack.empty()) { | |
| const auto node = stack.back(); | |
| stack.pop_back(); | |
| if (node->is_leaf()) { | |
| if (node->ref_count == 0) { | |
| leaves.push_back(node); | |
| } | |
| } else { | |
| for (const auto& [_, child] : *node) { | |
| stack.push_back(child.get()); | |
| } | |
| } | |
| } | |
| return leaves; | |
| } | |
| std::vector<TreeNode*> collect_leaves_device() const { | |
| // for non-hicache, every leaf device node is a leaf node (since no backup on host) | |
| if (!use_hicache) return collect_leaves(); | |
| std::vector<TreeNode*> leaves; | |
| std::vector<TreeNode*> stack = {}; | |
| for (const auto& [_, child] : m_root) { | |
| stack.push_back(child.get()); | |
| } | |
| while (!stack.empty()) { | |
| const auto node = stack.back(); | |
| stack.pop_back(); | |
| if (!node->on_gpu()) continue; // skip nodes that are not on GPU | |
| if (node->is_leaf_device()) { | |
| if (node->ref_count == 0) { | |
| leaves.push_back(node); | |
| } | |
| } else { | |
| for (const auto& [_, child] : *node) { | |
| stack.push_back(child.get()); | |
| } | |
| } | |
| } | |
| return leaves; | |
| } | |
| void lock_ref(TreeNode* node, bool increment) { | |
| if (node->is_root()) return; // skip root node | |
| _assert(node->on_gpu(), "Cannot lock reference on an evicted node"); | |
| if (increment) | |
| walk_to_root(node, [this](TreeNode* n) { | |
| if (n->ref_count == 0) { | |
| m_evictable_size -= n->length(); | |
| m_protected_size += n->length(); | |
| } | |
| n->ref_count++; | |
| }); | |
| else | |
| walk_to_root(node, [this](TreeNode* n) { | |
| _assert(n->ref_count != 0, "Cannot decrement reference count = zero"); | |
| n->ref_count--; | |
| if (n->ref_count == 0) { | |
| m_protected_size -= n->length(); | |
| m_evictable_size += n->length(); | |
| } | |
| }); | |
| } | |
| void lock_ref(NodeHandle node_ptr, bool increment) { | |
| return lock_ref(id2node(node_ptr), increment); | |
| } | |
| void lock(TreeNode* node) { | |
| return lock_ref(node, /*increment=*/true); | |
| } | |
| void unlock(TreeNode* node) { | |
| return lock_ref(node, /*increment=*/false); | |
| } | |
| std::size_t total_size() const { | |
| std::size_t size = 0; | |
| std::vector<const TreeNode*> stack = {&m_root}; | |
| while (!stack.empty()) { | |
| auto* node = stack.back(); | |
| stack.pop_back(); | |
| size += node->length(); | |
| for (const auto& [_, child] : *node) | |
| stack.push_back(child.get()); | |
| } | |
| return size; | |
| } | |
| std::size_t evictable_size() const { | |
| return m_evictable_size; | |
| } | |
| std::size_t protected_size() const { | |
| return m_protected_size; | |
| } | |
| std::size_t align(std::size_t size) const { | |
| return (size / page_size) * page_size; // align to page size | |
| } | |
| TreeNode* id2node(NodeHandle node_id) const { | |
| const auto iterator = m_node_map.find(node_id); | |
| _assert(iterator != m_node_map.end(), "Node not found in the map"); | |
| return iterator->second; | |
| } | |
| void reset() { | |
| _assert(m_root.ref_count == 1, "Root node must be protected during reset"); | |
| m_node_counter = 1; // reset node counter | |
| m_root.root_reset(); | |
| m_evictable_size = 0; | |
| m_protected_size = 0; | |
| m_node_map.clear(); | |
| m_node_map[m_root.node_id] = &m_root; // re-add root to the map | |
| } | |
| void debug_print(std::ostream& os) const; | |
| private: | |
| // some auxiliary functions | |
| token_vec_t& get_key(token_slice tokens) { | |
| _assert(tokens.size() >= page_size, "Key should be at least page-sized"); | |
| tokens = tokens.subspan(0, page_size); | |
| m_cached_vec.assign(tokens.begin(), tokens.end()); | |
| return m_cached_vec; | |
| } | |
| // justify for _unsafe call: we need to read the key part of the tokens | |
| token_vec_t& get_key(TreeNode* node) { | |
| return get_key(node->_unsafe_tokens()); | |
| } | |
| void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child) { | |
| parent->add_child(get_key(child.get()), std::move(child)); | |
| } | |
| void add_child(TreeNode* parent, std::unique_ptr<TreeNode>&& child, node_iterator_t it) { | |
| parent->add_child(it, std::move(child)); | |
| } | |
| TreeNode m_root; // root node of the tree | |
| std::size_t m_evictable_size; // number of evictable tokens on GPU (lock ref = 0) | |
| std::size_t m_protected_size; // number of protected tokens on GPU (lock ref > 0) | |
| token_vec_t m_cached_vec; // cached vector of tokens for the current operation | |
| std::unordered_map<std::size_t, TreeNode*> m_node_map; // map of node keys to nodes | |
| std::size_t m_node_counter; // counter for node IDs | |
| public: | |
| // some public constant configurations (without m_ prefix) | |
| const bool disabled; // whether the cache is enabled, or just a temporary cache | |
| const bool use_hicache; // whether to use the HiCache for this tree | |
| const std::size_t page_size; // size of each page in the cache | |
| const std::size_t threshold; // threshold for write_through | |
| }; | |
| } // namespace radix_tree_v2 | |
Xet Storage Details
- Size:
- 9.4 kB
- Xet hash:
- f726e6df4c82484464e1bc948506840e48e6eba8aa01be08bb8df900eb30aa9f
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.