Spaces:
Running
Running
| from functools import partial, lru_cache | |
| from typing import Callable, Optional | |
| import numpy as np | |
| import ding | |
| from .default_helper import one_time_warning | |
def njit():
    """
    Overview:
        Return the decorator used to compile functions with numba. When numba is disabled through
        ``ding.enable_numba``, is not installed, or is older than 0.53.0, ``functools.partial`` is returned
        instead, so ``njit()(fn)`` merely wraps ``fn`` without compiling it.
    Returns:
        - decorator (:obj:`Callable`): ``numba.njit`` when numba can be used, otherwise ``functools.partial``.
    """
    try:
        if not ding.enable_numba:
            return partial
        import numba
        from numba import njit as _njit
    except ImportError:
        one_time_warning("If you want to use numba to speed up segment tree, please install numba first")
        return partial
    # Compare (major, minor) as a tuple: looking at the minor number alone would wrongly reject numba 1.x.
    major, minor = (int(part) for part in numba.__version__.split(".")[:2])
    if (major, minor) < (0, 53):
        # Implicit string concatenation (instead of a backslash continuation inside the literal) keeps
        # source indentation out of the message.
        one_time_warning(
            "Due to your numba version < 0.53.0, DI-engine disables it. And you can install "
            "numba==0.53.0 if you want to speed up something"
        )
        return partial
    return _njit
class SegmentTree:
    """
    Overview:
        Segment tree data structure, implemented by the tree-like array. Only the leaf nodes hold real values;
        every non-leaf node stores ``operation`` applied to its left and right child.
    Interfaces:
        ``__init__``, ``reduce``, ``__setitem__``, ``__getitem__``
    """

    def __init__(self, capacity: int, operation: str, neutral_element: Optional[float] = None) -> None:
        """
        Overview:
            Initialize the segment tree. Tree's root node is at index 1.
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes), must be a power of 2.
            - operation (:obj:`str`): Name of the operation used to build the tree: ``'sum'``, ``'min'`` or \
                ``'max'``. It is compared as a string, so builtin functions such as ``sum`` are not accepted.
            - neutral_element (:obj:`float` or :obj:`None`): The value of the neutral element, which is used to \
                init all nodes value in the tree. If ``None``, it is derived from ``operation``: ``0.`` for \
                ``'sum'``, ``np.inf`` for ``'min'``, ``-np.inf`` for ``'max'``.
        Raises:
            - ValueError: If ``neutral_element`` is ``None`` and ``operation`` is none of the names above.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, capacity
        self.capacity = capacity
        self.operation = operation
        # Set neutral value (initial value) for all elements.
        if neutral_element is None:
            if operation == 'sum':
                neutral_element = 0.
            elif operation == 'min':
                neutral_element = np.inf
            elif operation == 'max':
                neutral_element = -np.inf
            else:
                raise ValueError(
                    "operation argument should be one of 'min', 'max', 'sum' (as strings), "
                    "but got {!r}.".format(operation)
                )
        self.neutral_element = neutral_element
        # Index 1 is the root; Index ranging in [capacity, 2 * capacity - 1] are the leaf nodes.
        # For each parent node with index i, left child is value[2*i] and right child is value[2*i+1].
        self.value = np.full([capacity * 2], neutral_element)
        self._compile()

    def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
        """
        Overview:
            Reduce the tree in range ``[start, end)``.
        Arguments:
            - start (:obj:`int`): Start index (relative index, the first leaf node is 0), default set to 0.
            - end (:obj:`int` or :obj:`None`): End index (relative, exclusive), default set to ``self.capacity``.
        Returns:
            - reduce_result (:obj:`float`): The reduce result value, which is dependent on data type and operation.
        Raises:
            - AssertionError: If ``0 <= start < end <= capacity`` does not hold.
        """
        # TODO(nyz) check if directly reduce from the array(value) can be faster
        if end is None:
            end = self.capacity
        # Out-of-range bounds would otherwise silently read internal nodes (e.g. a negative start)
        # or index past the leaves.
        assert 0 <= start < end <= self.capacity, (start, end)
        # Change to absolute leaf index by adding capacity.
        start += self.capacity
        end += self.capacity
        return _reduce(self.value, start, end, self.neutral_element, self.operation)

    def __setitem__(self, idx: int, val: float) -> None:
        """
        Overview:
            Set ``leaf[idx] = val``; then update the related ancestor nodes.
        Arguments:
            - idx (:obj:`int`): Leaf node index (relative index); ``capacity`` is added to get the absolute index.
            - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity), idx
        # ``idx`` should add ``capacity`` to change to absolute index.
        _setitem(self.value, idx + self.capacity, val, self.operation)

    def __getitem__(self, idx: int) -> float:
        """
        Overview:
            Get ``leaf[idx]``.
        Arguments:
            - idx (:obj:`int`): Leaf node index (relative index); ``capacity`` is added to get the absolute index.
        Returns:
            - val (:obj:`float`): The value of ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity), idx
        return self.value[idx + self.capacity]

    def _compile(self) -> None:
        """
        Overview:
            Call the module-level helpers once on tiny float64 / float32 / int64 arrays. Presumably this makes
            numba compile those dtype specializations at construction time (when the helpers are jitted) rather
            than on first real use — TODO confirm against how the helpers are decorated.
        """
        f64 = np.array([0, 1], dtype=np.float64)
        f32 = np.array([0, 1], dtype=np.float32)
        i64 = np.array([0, 1], dtype=np.int64)
        for d in [f64, f32, i64]:
            _setitem(d, 0, 3.0, 'sum')
            _reduce(d, 0, 1, 0.0, 'min')
            _find_prefixsum_idx(d, 1, 0.5, 0.0)
class SumSegmentTree(SegmentTree):
    """
    Overview:
        Segment tree whose internal nodes hold the sum of their two children (``SegmentTree`` built with
        ``operation='sum'``). Besides range sums it can locate a leaf by cumulative sum.
    Interfaces:
        ``__init__``, ``find_prefixsum_idx``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Build a sum segment tree by delegating to ``SegmentTree`` with ``operation='sum'``.
        Arguments:
            - capacity (:obj:`int`): Number of leaf nodes in the tree.
        """
        super().__init__(capacity, operation='sum')

    def find_prefixsum_idx(self, prefixsum: float, trust_caller: bool = True) -> int:
        """
        Overview:
            Locate the leaf ``i`` whose cumulative interval contains ``prefixsum``, i.e. \
            ``sum(leaf[:i]) <= prefixsum < sum(leaf[:i + 1])``.
        Arguments:
            - prefixsum (:obj:`float`): The target prefix sum.
            - trust_caller (:obj:`bool`): When False, first assert ``0 <= prefixsum <= total + 1e-5``, where \
                ``total`` is this tree's ``reduce()`` result. Default True, which skips the check.
        Returns:
            - idx (:obj:`int`): The matching leaf index (relative, the first leaf is 0).
        """
        if not trust_caller:
            assert 0 <= prefixsum <= self.reduce() + 1e-5, prefixsum
        tree, leaf_count = self.value, self.capacity
        return _find_prefixsum_idx(tree, leaf_count, prefixsum, self.neutral_element)
class MinSegmentTree(SegmentTree):
    """
    Overview:
        Segment tree whose internal nodes hold the minimum of their two children (``SegmentTree`` built with
        ``operation='min'``).
    Interfaces:
        ``__init__``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Build a min segment tree by delegating to ``SegmentTree`` with ``operation='min'``.
        Arguments:
            - capacity (:obj:`int`): Number of leaf nodes in the tree.
        """
        super().__init__(capacity, operation='min')
def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None:
    """
    Overview:
        Set ``tree[idx] = val``, then recompute every ancestor of ``idx`` up to the root (index 1).
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array; the children of node ``i`` are ``2 * i`` and ``2 * i + 1``.
        - idx (:obj:`int`): The absolute index of the node to set (for a leaf: relative index + capacity).
        - val (:obj:`float`): The value that will be assigned to ``tree[idx]``.
        - operation (:obj:`str`): How a parent combines its children: ``'sum'``, ``'min'`` or ``'max'``.
    Raises:
        - ValueError: If ``operation`` is unknown and at least one ancestor has to be recomputed.
    """
    tree[idx] = val
    # Update from specified node to the root node
    while idx > 1:
        idx = idx >> 1  # To parent node idx
        left, right = tree[2 * idx], tree[2 * idx + 1]
        if operation == 'sum':
            tree[idx] = left + right
        elif operation == 'min':
            tree[idx] = min(left, right)
        elif operation == 'max':
            # 'max' is accepted by SegmentTree.__init__, so it must also be maintained here.
            tree[idx] = max(left, right)
        else:
            # Fail loudly instead of silently leaving the ancestors stale.
            raise ValueError("operation should be one of 'sum', 'min' and 'max'")
def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, operation: str) -> float:
    """
    Overview:
        Reduce the nodes ``tree[start:end]`` (bottom-up over the tree) with ``operation``.
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array; the children of node ``i`` are ``2 * i`` and ``2 * i + 1``.
        - start (:obj:`int`): Absolute start index into ``tree`` (for leaves: relative index + capacity).
        - end (:obj:`int`): Absolute end index into ``tree``, exclusive.
        - neutral_element (:obj:`float`): The value of the neutral element, used as the initial result \
            (returned unchanged for an empty range).
        - operation (:obj:`str`): The combining operation: ``'sum'``, ``'min'`` or ``'max'``.
    Returns:
        - result (:obj:`float`): The reduction of the range.
    Raises:
        - ValueError: If ``operation`` is not one of ``'sum'``, ``'min'``, ``'max'``.
    """
    # An unknown operation used to be silently ignored, returning the neutral element.
    if operation != 'sum' and operation != 'min' and operation != 'max':
        raise ValueError("operation should be one of 'sum', 'min' and 'max'")
    # Nodes in [start, end) will be aggregated
    result = neutral_element
    while start < end:
        if start & 1:
            # ``start`` is a right child: its parent also covers ``start - 1`` (outside the range),
            # so fold ``tree[start]`` in individually and move ``start`` one node to the right.
            if operation == 'sum':
                result = result + tree[start]
            elif operation == 'min':
                result = min(result, tree[start])
            else:
                result = max(result, tree[start])
            start += 1
        if end & 1:
            # ``end - 1`` is a left child: its parent also covers ``end`` (outside the range),
            # so step ``end`` back and fold ``tree[end]`` in individually.
            end -= 1
            if operation == 'sum':
                result = result + tree[end]
            elif operation == 'min':
                result = min(result, tree[end])
            else:
                result = max(result, tree[end])
        # Both start and end transform to respective parent node
        start = start >> 1
        end = end >> 1
    return result
def _find_prefixsum_idx(tree: np.ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int:
    """
    Overview:
        Find the leaf index ``i`` (relative, the first leaf is 0) with
        ``sum(leaf[:i]) <= prefixsum < sum(leaf[:i + 1])``, i.e. the leaf whose cumulative interval contains
        ``prefixsum``. If the descent lands on the last leaf and that leaf is the neutral element (e.g. the
        caller passed the root value), the nearest preceding non-neutral leaf is returned instead.
    Arguments:
        - tree (:obj:`np.ndarray`): The sum tree array; root at index 1, leaves at ``[capacity, 2 * capacity)``.
        - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        - prefixsum (:obj:`float`): The target prefixsum.
        - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \
            all nodes value in the tree.
    Returns:
        - idx (:obj:`int`): The relative leaf index.
    Raises:
        - ValueError: If the descent ends on the last leaf and every leaf equals ``neutral_element``.
    """
    # Descend from the root to a leaf: go left when the left subtree's total exceeds the remaining
    # prefixsum, otherwise subtract that total and go right. The loop stops once ``idx`` is a leaf
    # (``idx >= capacity``).
    idx = 1  # start from root node
    while idx < capacity:
        child_base = 2 * idx
        if tree[child_base] > prefixsum:
            idx = child_base
        else:
            prefixsum -= tree[child_base]
            idx = child_base + 1
    # Special case: The last element of ``self.value`` is neutral_element(0),
    # and caller wants to ``find_prefixsum_idx(root_value)``.
    # However, input prefixsum should be smaller than root_value.
    if idx == 2 * capacity - 1 and tree[idx] == neutral_element:
        tmp = idx
        while tmp >= capacity and tree[tmp] == neutral_element:
            tmp -= 1
        # ``tmp >= capacity`` means a non-neutral leaf was found (``tmp == capacity``, the first leaf,
        # included); otherwise the scan ran past the leaves into the internal nodes.
        if tmp >= capacity:
            idx = tmp
        else:
            raise ValueError("All elements in tree are the neutral_element(0), can't find non-zero element")
    assert (tree[idx] != neutral_element)
    return idx - capacity