|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""OpTree: Optimized PyTree Utilities.""" |
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import difflib |
|
|
import functools |
|
|
import itertools |
|
|
import textwrap |
|
|
from collections import OrderedDict, defaultdict, deque |
|
|
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, overload |
|
|
|
|
|
import optree._C as _C |
|
|
from optree.accessors import PyTreeAccessor |
|
|
from optree.typing import NamedTuple, T, is_namedtuple_instance, is_structseq_instance |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
import builtins |
|
|
from collections.abc import Collection, Iterable, Mapping |
|
|
|
|
|
from optree.accessors import PyTreeEntry |
|
|
from optree.typing import ( |
|
|
MetaData, |
|
|
PyTree, |
|
|
PyTreeKind, |
|
|
PyTreeSpec, |
|
|
S, |
|
|
StructSequence, |
|
|
U, |
|
|
UnflattenFunc, |
|
|
) |
|
|
|
|
|
|
|
|
# Public API of this module. Names are grouped by theme; the ordering is the canonical one.
__all__ = [
    # Constants.
    'MAX_RECURSION_DEPTH', 'NONE_IS_NODE', 'NONE_IS_LEAF',
    # Flattening, unflattening, and leaf/structure queries.
    'tree_flatten', 'tree_flatten_with_path', 'tree_flatten_with_accessor', 'tree_unflatten',
    'tree_iter', 'tree_leaves', 'tree_structure', 'tree_paths', 'tree_accessors',
    'tree_is_leaf', 'all_leaves',
    # Mapping.
    'tree_map', 'tree_map_', 'tree_map_with_path', 'tree_map_with_path_',
    'tree_map_with_accessor', 'tree_map_with_accessor_',
    # Restructuring.
    'tree_replace_nones', 'tree_partition', 'tree_transpose', 'tree_transpose_map',
    'tree_transpose_map_with_path', 'tree_transpose_map_with_accessor',
    # Broadcasting.
    'tree_broadcast_prefix', 'broadcast_prefix', 'tree_broadcast_common', 'broadcast_common',
    'tree_broadcast_map', 'tree_broadcast_map_with_path', 'tree_broadcast_map_with_accessor',
    # Reductions.
    'tree_reduce', 'tree_sum', 'tree_max', 'tree_min', 'tree_all', 'tree_any',
    'tree_flatten_one_level',
    # Treespec inspection.
    'treespec_paths', 'treespec_accessors', 'treespec_entries', 'treespec_entry',
    'treespec_children', 'treespec_child', 'treespec_one_level', 'treespec_transform',
    'treespec_is_leaf', 'treespec_is_strict_leaf', 'treespec_is_one_level',
    'treespec_is_prefix', 'treespec_is_suffix',
    # Treespec construction.
    'treespec_leaf', 'treespec_none', 'treespec_tuple', 'treespec_list', 'treespec_dict',
    'treespec_namedtuple', 'treespec_ordereddict', 'treespec_defaultdict', 'treespec_deque',
    'treespec_structseq', 'treespec_from_collection',
    # Diagnostics.
    'prefix_errors',
]
|
|
|
|
|
# The recursion limit is owned by the compiled extension; re-exported here so callers can read it.
MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal.

The value is provided by the compiled extension module (``optree._C``) rather than being fixed
in Python; inspect this constant at runtime to get the limit of the installed build.

This limit prevents infinite recursion from causing an overflow of the C stack
and crashing Python.
"""

# Readable aliases for the boolean ``none_is_leaf`` argument accepted by the functions below.
NONE_IS_NODE: bool = False
"""Literal constant that treats :data:`None` as a pytree non-leaf node."""
NONE_IS_LEAF: bool = True
"""Literal constant that treats :data:`None` as a pytree leaf node."""
|
|
|
|
|
|
|
|
def tree_flatten(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[T], PyTreeSpec]:
    """Flatten a pytree into its leaves and a treespec describing its structure.

    See also :func:`tree_flatten_with_path` and :func:`tree_unflatten`.

    Leaves are collected by a deterministic, left-to-right, depth-first traversal of the tree,
    so the order of the returned list is stable across calls.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> tree_flatten(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, None, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    )
    >>> tree_flatten(1)
    ([1], PyTreeSpec(*))
    >>> tree_flatten(None)
    ([], PyTreeSpec(None))
    >>> tree_flatten(None, none_is_leaf=True)
    ([None], PyTreeSpec(*, NoneIsLeaf))

    Children of unordered dictionaries (:class:`dict` and :class:`collections.defaultdict`) are
    visited in **sorted** key order. Use :class:`collections.OrderedDict` to keep the insertion
    order instead.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
    )
    >>> tree_flatten(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [2, 3, 4, 1, None, 5],
        PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
    )

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and is therefore recorded in the treespec
            instead of in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A pair ``(leaves, treespec)``: the list of leaf values, followed by a treespec that
        represents the structure of the pytree.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return leaves, treespec
|
|
|
|
|
|
|
|
def tree_flatten_with_path(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[tuple[Any, ...]], list[T], PyTreeSpec]:
    """Flatten a pytree while also recording the path to every leaf.

    See also :func:`tree_flatten`, :func:`tree_paths`, and :func:`treespec_paths`.

    Leaves (and their paths) are produced by a deterministic, left-to-right, depth-first
    traversal of the tree.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten_with_path(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)],
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> tree_flatten_with_path(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)],
        [1, 2, 3, 4, None, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    )
    >>> tree_flatten_with_path(1)
    ([()], [1], PyTreeSpec(*))
    >>> tree_flatten_with_path(None)
    ([], [], PyTreeSpec(None))
    >>> tree_flatten_with_path(None, none_is_leaf=True)
    ([()], [None], PyTreeSpec(*, NoneIsLeaf))

    Children of unordered dictionaries (:class:`dict` and :class:`collections.defaultdict`) are
    visited in **sorted** key order. Use :class:`collections.OrderedDict` to keep the insertion
    order instead.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten_with_path(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('d',)],
        [2, 3, 4, 1, 5],
        PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
    )
    >>> tree_flatten_with_path(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [('b', 0), ('b', 1, 0), ('b', 1, 1), ('a',), ('c',), ('d',)],
        [2, 3, 4, 1, None, 5],
        PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
    )

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and is therefore recorded in the treespec
            instead of in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A triple ``(paths, leaves, treespec)``. ``paths`` lists, for each leaf, a tuple of the
        indices or keys leading to it; ``leaves`` is the list of leaf values; ``treespec``
        represents the structure of the pytree.
    """
    paths, leaves, treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    return paths, leaves, treespec
|
|
|
|
|
|
|
|
def tree_flatten_with_accessor(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[PyTreeAccessor], list[T], PyTreeSpec]:
    """Flatten a pytree while also recording an accessor for every leaf.

    See also :func:`tree_flatten`, :func:`tree_accessors`, and :func:`treespec_accessors`.

    Leaves (and their accessors) are produced by a deterministic, left-to-right, depth-first
    traversal of the tree.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten_with_accessor(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [
            PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
            PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
            PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
            PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
            PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
        ],
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> tree_flatten_with_accessor(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [
            PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
            PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
            PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
            PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
            PyTreeAccessor(*['c'], (MappingEntry(key='c', type=<class 'dict'>),)),
            PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
        ],
        [1, 2, 3, 4, None, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    )
    >>> tree_flatten_with_accessor(1)
    ([PyTreeAccessor(*, ())], [1], PyTreeSpec(*))
    >>> tree_flatten_with_accessor(None)
    ([], [], PyTreeSpec(None))
    >>> tree_flatten_with_accessor(None, none_is_leaf=True)
    ([PyTreeAccessor(*, ())], [None], PyTreeSpec(*, NoneIsLeaf))

    Children of unordered dictionaries (:class:`dict` and :class:`collections.defaultdict`) are
    visited in **sorted** key order. Use :class:`collections.OrderedDict` to keep the insertion
    order instead.

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten_with_accessor(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [
            PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
            PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
            PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
            PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>),)),
            PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'collections.OrderedDict'>),))
        ],
        [2, 3, 4, 1, 5],
        PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}))
    )
    >>> tree_flatten_with_accessor(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    (
        [
            PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
            PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
            PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'collections.OrderedDict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
            PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'collections.OrderedDict'>),)),
            PyTreeAccessor(*['c'], (MappingEntry(key='c', type=<class 'collections.OrderedDict'>),)),
            PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'collections.OrderedDict'>),))
        ],
        [2, 3, 4, 1, None, 5],
        PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf)
    )

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and is therefore recorded in the treespec
            instead of in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A triple ``(accessors, leaves, treespec)``: one accessor per leaf value, the list of leaf
        values, and a treespec representing the structure of the pytree.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    # The accessors are derived from the treespec, so they line up one-to-one with ``leaves``.
    accessors = treespec.accessors()
    return accessors, leaves, treespec
|
|
|
|
|
|
|
|
def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[T]) -> PyTree[T]:
    """Rebuild a pytree from a treespec and a sequence of leaves.

    This is the inverse operation of :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> leaves, treespec = tree_flatten(tree)
    >>> tree == tree_unflatten(treespec, leaves)
    True

    Args:
        treespec (PyTreeSpec): The structure to rebuild.
        leaves (iterable): The leaf values to place into the structure. Their count must equal
            the number of leaves described by ``treespec``.

    Returns:
        A pytree shaped like ``treespec`` whose leaf positions are filled with ``leaves``.
    """
    reconstructed = treespec.unflatten(leaves)
    return reconstructed
|
|
|
|
|
|
|
|
def tree_iter(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> Iterable[T]:
    """Return an iterator that yields the leaves of a pytree.

    See also :func:`tree_flatten` and :func:`tree_leaves`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> list(tree_iter(tree))
    [1, 2, 3, 4, 5]
    >>> list(tree_iter(tree, none_is_leaf=True))
    [1, 2, 3, 4, None, 5]
    >>> list(tree_iter(1))
    [1]
    >>> list(tree_iter(None))
    []
    >>> list(tree_iter(None, none_is_leaf=True))
    [None]

    Args:
        tree (pytree): The pytree whose leaves are iterated.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and is therefore not yielded as a leaf.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        An iterator over the leaf values.
    """
    leaf_iterator = _C.PyTreeIter(tree, is_leaf, none_is_leaf, namespace)
    return leaf_iterator
|
|
|
|
|
|
|
|
def tree_leaves(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[T]:
    """Return the list of leaves of a pytree.

    See also :func:`tree_flatten` and :func:`tree_iter`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_leaves(tree)
    [1, 2, 3, 4, 5]
    >>> tree_leaves(tree, none_is_leaf=True)
    [1, 2, 3, 4, None, 5]
    >>> tree_leaves(1)
    [1]
    >>> tree_leaves(None)
    []
    >>> tree_leaves(None, none_is_leaf=True)
    [None]

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and is therefore recorded in the treespec
            instead of in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The list of leaf values.
    """
    leaves, _ = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return leaves
|
|
|
|
|
|
|
|
def tree_structure(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Return the treespec describing the structure of a pytree.

    See also :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_structure(tree)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    >>> tree_structure(tree, none_is_leaf=True)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    >>> tree_structure(1)
    PyTreeSpec(*)
    >>> tree_structure(None)
    PyTreeSpec(None)
    >>> tree_structure(None, none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and is therefore recorded in the treespec
            instead of in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing the structure of the pytree.
    """
    _, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return treespec
|
|
|
|
|
|
|
|
def tree_paths(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[tuple[Any, ...]]:
    """Return the path to every leaf of a pytree.

    See also :func:`tree_flatten`, :func:`tree_flatten_with_path`, and :func:`treespec_paths`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_paths(tree)
    [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('d',)]
    >>> tree_paths(tree, none_is_leaf=True)
    [('a',), ('b', 0), ('b', 1, 0), ('b', 1, 1), ('c',), ('d',)]
    >>> tree_paths(1)
    [()]
    >>> tree_paths(None)
    []
    >>> tree_paths(None, none_is_leaf=True)
    [()]

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and therefore has no path of its own.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A list with one path per leaf; each path is a tuple of the indices or keys leading to it.
    """
    paths, _, _ = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    return paths
|
|
|
|
|
|
|
|
def tree_accessors(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[PyTreeAccessor]:
    """Return an accessor for every leaf of a pytree.

    See also :func:`tree_flatten`, :func:`tree_flatten_with_accessor`, and
    :func:`treespec_accessors`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_accessors(tree)  # doctest: +IGNORE_WHITESPACE
    [
        PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
        PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
        PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
        PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
    ]
    >>> tree_accessors(tree, none_is_leaf=True)  # doctest: +IGNORE_WHITESPACE
    [
        PyTreeAccessor(*['a'], (MappingEntry(key='a', type=<class 'dict'>),)),
        PyTreeAccessor(*['b'][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=0, type=<class 'tuple'>))),
        PyTreeAccessor(*['b'][1][0], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=0, type=<class 'list'>))),
        PyTreeAccessor(*['b'][1][1], (MappingEntry(key='b', type=<class 'dict'>), SequenceEntry(index=1, type=<class 'tuple'>), SequenceEntry(index=1, type=<class 'list'>))),
        PyTreeAccessor(*['c'], (MappingEntry(key='c', type=<class 'dict'>),)),
        PyTreeAccessor(*['d'], (MappingEntry(key='d', type=<class 'dict'>),))
    ]
    >>> tree_accessors(1)
    [PyTreeAccessor(*, ())]
    >>> tree_accessors(None)
    []
    >>> tree_accessors(None, none_is_leaf=True)
    [PyTreeAccessor(*, ())]

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and therefore has no accessor of its own.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A list with one accessor per leaf value.
    """
    _, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    return treespec.accessors()
|
|
|
|
|
|
|
|
def tree_is_leaf(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Check whether an object is a leaf node of a pytree.

    See also :func:`tree_flatten`, :func:`tree_leaves`, and :func:`all_leaves`.

    >>> tree_is_leaf(1)
    True
    >>> tree_is_leaf(None)
    False
    >>> tree_is_leaf(None, none_is_leaf=True)
    True
    >>> tree_is_leaf({'a': 1, 'b': (2, 3)})
    False

    Args:
        tree (pytree): The object to classify.
        is_leaf (callable, optional): An optional predicate invoked at every flattening step.
            Returning :data:`True` stops the traversal and treats the whole subtree as a single
            leaf; returning :data:`False` lets the traversal descend into the current object.
        none_is_leaf (bool, optional): Whether :data:`None` counts as a leaf. When :data:`False`,
            :data:`None` is a non-leaf node of arity 0 and so is not reported as a leaf.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace consulted for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        :data:`True` if ``tree`` is a leaf node, :data:`False` otherwise.
    """
    # Classification is delegated entirely to the compiled extension.
    return _C.is_leaf(tree, is_leaf, none_is_leaf, namespace)
|
|
|
|
|
|
|
|
def all_leaves(
    iterable: Iterable[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Test whether every element of the given iterable is a leaf.

    See also :func:`tree_flatten`, :func:`tree_leaves`, and :func:`tree_is_leaf`.

    >>> tree = {'a': [1, 2, 3]}
    >>> all_leaves(tree_leaves(tree))
    True
    >>> all_leaves([tree])
    False
    >>> all_leaves([1, 2, None, 3])
    False
    >>> all_leaves([1, 2, None, 3], none_is_leaf=True)
    True

    Note that this function checks the elements produced by iterating the input with
    :func:`iter`. For a dictionary ``d``, ``iter(d)`` yields the **keys** of the dictionary, not
    its values.

    >>> list({'a': 1, 'b': (2, 3)})
    ['a', 'b']
    >>> all_leaves({'a': 1, 'b': (2, 3)})
    True

    This function is useful in advanced cases. For example, a library that allows arbitrary map
    operations on a flat list of leaves may want to check whether the result is still a flat list
    of leaves.

    Args:
        iterable (iterable): An iterable whose elements are each tested for being a leaf.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than a leaf. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A boolean indicating whether all elements in the input iterable are leaves.
    """
    return _C.all_leaves(iterable, is_leaf, none_is_leaf, namespace)
|
|
|
|
|
|
|
|
def tree_map(
    func: Callable[..., U],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`, :func:`tree_map_with_path`, :func:`tree_map_with_path_`,
    and :func:`tree_broadcast_map`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (43, 65), 'z': None}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': None}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list, and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    # The extra trees are only flattened up to ``tree``'s structure, so a deeper subtree at a leaf
    # position of ``tree`` is passed to ``func`` as a single argument (see the example above).
    flat_args = [leaves, *(treespec.flatten_up_to(r) for r in rests)]
    return treespec.unflatten(map(func, *flat_args))
|
|
|
|
|
|
|
|
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:
    """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`, :func:`tree_map_with_path`, and :func:`tree_map_with_path_`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. Its return values are discarded.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list, so ``func`` is not called on it.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The original ``tree`` object itself (not a copy). ``func(x, *xs)`` is called only for its
        side effects, where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` is
        the tuple of values at corresponding nodes in ``rests``; its return values are discarded.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # Drain the lazy ``map`` without keeping any results: a zero-length deque consumes the
    # iterator, so ``func`` runs once per leaf purely for its side effects.
    deque(map(func, *flat_args), maxlen=0)
    return tree
|
|
|
|
|
|
|
|
def tree_map_with_path(
    func: Callable[..., U],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.

    See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path_`.

    >>> tree_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
    {'x': (1, 7), 'y': ((2, 42), (2, 64))}
    >>> tree_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (44, 66), 'z': None}
    >>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
    >>> tree_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(p, x, *xs)`` where ``(p, x)`` are the path and value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    paths, leaves, treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # ``map`` zips ``paths`` positionally with the leaves, so each path is the first argument.
    return treespec.unflatten(map(func, paths, *flat_args))
|
|
|
|
|
|
|
|
def tree_map_with_path_(
    func: Callable[..., Any],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:
    """Like :func:`tree_map_with_path`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_path`.

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths. Its return values are discarded.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list, so ``func`` is not called on it.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The original ``tree`` object itself (not a copy). ``func(p, x, *xs)`` is called only for
        its side effects, where ``(p, x)`` are the path and value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``; its return
        values are discarded.
    """
    paths, leaves, treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # Drain the lazy ``map`` without keeping any results: a zero-length deque consumes the
    # iterator, so ``func`` runs once per leaf purely for its side effects.
    deque(map(func, paths, *flat_args), maxlen=0)
    return tree
|
|
|
|
|
|
|
|
def tree_map_with_accessor(
    func: Callable[..., U],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree.

    See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_accessor_`.

    >>> tree_map_with_accessor(lambda a, x: f'{a.codify("tree")} = {x!r}', {'x': 7, 'y': (42, 64)})
    {'x': "tree['x'] = 7", 'y': ("tree['y'][0] = 42", "tree['y'][1] = 64")}
    >>> tree_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (44, 66), 'z': None}
    >>> tree_map_with_accessor(  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    ...     lambda a, x: a,
    ...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
    ... )
    {
        'x': PyTreeAccessor(*['x'], ...),
        'y': (
            PyTreeAccessor(*['y'][0], ...),
            PyTreeAccessor(*['y'][1], ...)
        ),
        'z': {1.5: None}
    }
    >>> tree_map_with_accessor(  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    ...     lambda a, x: a,
    ...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
    ...     none_is_leaf=True,
    ... )
    {
        'x': PyTreeAccessor(*['x'], ...),
        'y': (
            PyTreeAccessor(*['y'][0], ...),
            PyTreeAccessor(*['y'][1], ...)
        ),
        'z': {
            1.5: PyTreeAccessor(*['z'][1.5], ...)
        }
    }

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra accessors.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding accessor providing the first positional argument to
            function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(a, x, *xs)`` where ``(a, x)`` are the accessor and value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # ``map`` zips the treespec's accessors positionally with the leaves.
    return treespec.unflatten(map(func, treespec.accessors(), *flat_args))
|
|
|
|
|
|
|
|
def tree_map_with_accessor_(
    func: Callable[..., Any],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:
    """Like :func:`tree_map_with_accessor`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`, :func:`tree_map_`, and :func:`tree_map_with_accessor`.

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra accessors. Its return values are
            discarded.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding accessor providing the first positional argument to
            function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list, so ``func`` is not called on it.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The original ``tree`` object itself (not a copy). ``func(a, x, *xs)`` is called only for
        its side effects, where ``(a, x)`` are the accessor and value at the corresponding leaf in
        ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``; its return
        values are discarded.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # Drain the lazy ``map`` without keeping any results: a zero-length deque consumes the
    # iterator, so ``func`` runs once per leaf purely for its side effects.
    deque(map(func, treespec.accessors(), *flat_args), maxlen=0)
    return tree
|
|
|
|
|
|
|
|
def tree_replace_nones(
    sentinel: S,
    tree: PyTree[T | None],
    /,
    namespace: str = '',
) -> PyTree[T | S]:
    """Return a copy of ``tree`` in which every :data:`None` is swapped for ``sentinel``.

    See also :func:`tree_flatten` and :func:`tree_map`.

    >>> tree_replace_nones(0, {'a': 1, 'b': None, 'c': (2, None)})
    {'a': 1, 'b': 0, 'c': (2, 0)}
    >>> tree_replace_nones(0, None)
    0

    Args:
        sentinel (object): The value to replace :data:`None` with.
        tree (pytree): A pytree to be transformed.
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the same structure as ``tree`` but with :data:`None` replaced.
    """
    # Short-circuit: a bare ``None`` root maps straight to the sentinel.
    if tree is None:
        return sentinel

    def _swap_none(value):
        # ``none_is_leaf=True`` below makes every ``None`` reach this function as a leaf.
        return sentinel if value is None else value

    return tree_map(_swap_none, tree, none_is_leaf=True, namespace=namespace)
|
|
|
|
|
|
|
|
@overload
def tree_partition(
    predicate: Callable[[T], bool],
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[PyTree[T | None], PyTree[T | None]]: ...


@overload
def tree_partition(
    predicate: Callable[[T], bool],
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    fillvalue: S,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[PyTree[T | S], PyTree[T | S]]: ...


def tree_partition(
    predicate: Callable[[T], bool],
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    fillvalue: S = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[PyTree[T | S], PyTree[T | S]]:
    """Partition a tree into the left and right part by the given predicate function.

    See also :func:`tree_transpose_map`.

    >>> left, right = tree_partition(lambda x: x > 10, {'x': 7, 'y': (42, 64)})
    >>> left
    {'x': None, 'y': (42, 64)}
    >>> right
    {'x': 7, 'y': (None, None)}

    Instead of :data:`None`, one can also use a different sentinel value:

    >>> sentinel = object()
    >>> left, right = tree_partition(lambda x: x > 10, {'x': 7, 'y': (42, 64)}, fillvalue=sentinel)
    >>> left  # doctest: +ELLIPSIS
    {'x': <object object at ...>, 'y': (42, 64)}
    >>> right  # doctest: +ELLIPSIS
    {'x': 7, 'y': (<object object at ...>, <object object at ...>)}

    Args:
        predicate (callable): A function that takes a leaf value as argument, and splits/partitions
            it into the left or right tree based on the predicate's return value.
        tree (pytree): A pytree to be split, with each leaf providing the first positional
            argument to function ``predicate``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        fillvalue (object, optional): A sentinel value to retain the tree structure.
            (default: :data:`None`)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytrees. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        Two pytrees with the same structure as ``tree`` but with orthogonal leaves based on the
        ``predicate`` function. The first pytree contains all leaves where ``predicate`` evaluates
        to ``True``, the second for ``False``. The removed nodes in both trees are filled with
        ``fillvalue`` to keep the original tree structure.
    """
    # Each leaf becomes a ``(left, right)`` pair. The inner structure ``(*, *)`` is given
    # explicitly instead of being inferred from the first pair: a ``fillvalue`` that is itself
    # a pytree node (e.g. ``None`` with ``none_is_leaf=False``) would otherwise be mistaken for
    # part of the inner structure.
    return tree_transpose_map(
        lambda x: (x, fillvalue) if predicate(x) else (fillvalue, x),
        tree,
        inner_treespec=tree_structure((0, 0), none_is_leaf=none_is_leaf),
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
|
|
|
|
|
|
|
|
def tree_transpose(
    outer_treespec: PyTreeSpec,
    inner_treespec: PyTreeSpec,
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
) -> PyTree[T]:
    """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer).

    See also :func:`tree_flatten`, :func:`tree_structure`, and :func:`tree_transpose_map`.

    >>> outer_treespec = tree_structure({'a': 1, 'b': 2, 'c': (3, 4)})
    >>> outer_treespec
    PyTreeSpec({'a': *, 'b': *, 'c': (*, *)})
    >>> inner_treespec = tree_structure((1, 2))
    >>> inner_treespec
    PyTreeSpec((*, *))
    >>> tree = {'a': (1, 2), 'b': (3, 4), 'c': ((5, 6), (7, 8))}
    >>> tree_transpose(outer_treespec, inner_treespec, tree)
    ({'a': 1, 'b': 3, 'c': (5, 7)}, {'a': 2, 'b': 4, 'c': (6, 8)})

    For performance reasons, this function only checks the number of leaves in the input pytree,
    not its structure. The leaves of ``tree`` are taken in their flattening order and split into
    ``outer_treespec.num_leaves`` consecutive groups of ``inner_treespec.num_leaves`` leaves each,
    which are then transposed. The caller is responsible for ensuring that the input pytree has a
    prefix structure of ``outer_treespec`` followed by a prefix structure of ``inner_treespec``.
    Otherwise, the result may be incorrect.

    >>> tree_transpose(outer_treespec, inner_treespec, list(range(1, 9)))
    ({'a': 1, 'b': 3, 'c': (5, 7)}, {'a': 2, 'b': 4, 'c': (6, 8)})

    Args:
        outer_treespec (PyTreeSpec): A treespec object representing the outer structure of the pytree.
        inner_treespec (PyTreeSpec): A treespec object representing the inner structure of the pytree.
        tree (pytree): A pytree to be transposed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.

    Returns:
        A new pytree with the same structure as ``inner_treespec`` where the value at each leaf is
        a pytree with the same structure as ``outer_treespec``.

    Raises:
        ValueError: If ``outer_treespec`` and ``inner_treespec`` have different ``none_is_leaf``
            values, if either of them has no leaves, or if both have non-empty namespaces that
            differ.
        TypeError: If the number of leaves in ``tree`` is not the product of the numbers of leaves
            of ``outer_treespec`` and ``inner_treespec``.
    """
    if outer_treespec.none_is_leaf != inner_treespec.none_is_leaf:
        raise ValueError('Tree structures must have the same none_is_leaf value.')
    outer_size = outer_treespec.num_leaves
    inner_size = inner_treespec.num_leaves
    if outer_size == 0 or inner_size == 0:
        raise ValueError('Tree structures must have at least one leaf.')
    # An empty namespace is compatible with any other; only two distinct non-empty ones conflict.
    if (
        outer_treespec.namespace
        and inner_treespec.namespace
        and outer_treespec.namespace != inner_treespec.namespace
    ):
        raise ValueError(
            f'Tree structures must have the same namespace, '
            f'got {outer_treespec.namespace!r} vs. {inner_treespec.namespace!r}.',
        )

    leaves, treespec = tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=outer_treespec.none_is_leaf,
        # Use whichever treespec carries a non-empty namespace (they agree if both do).
        namespace=outer_treespec.namespace or inner_treespec.namespace,
    )
    # Only the leaf count is validated here (see the docstring), not the structure.
    if treespec.num_leaves != outer_size * inner_size:
        expected_treespec = outer_treespec.compose(inner_treespec)
        raise TypeError(f'Tree structure mismatch; expected: {expected_treespec}, got: {treespec}.')

    # Split the flat leaves into ``outer_size`` consecutive chunks of ``inner_size`` leaves.
    grouped = [
        leaves[offset : offset + inner_size]
        for offset in range(0, outer_size * inner_size, inner_size)
    ]
    # ``zip(*grouped)`` yields one tuple per inner leaf position, holding the values across all
    # outer positions; each tuple is rebuilt into an outer-shaped subtree.
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)
|
|
|
|
|
|
|
|
def tree_transpose_map(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args to produce a new pytree with transposed structure.

    See also :func:`tree_map`, :func:`tree_map_with_path`, and :func:`tree_transpose`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    >>> tree_transpose_map(  # doctest: +IGNORE_WHITESPACE
    ...     lambda x: {'identity': x, 'double': 2 * x},
    ...     tree,
    ... )
    {
        'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        'double': {'b': (4, [6, 8]), 'a': 2, 'c': (10, 12)}
    }
    >>> tree_transpose_map(  # doctest: +IGNORE_WHITESPACE
    ...     lambda x: {'identity': x, 'double': (x, x)},
    ...     tree,
    ... )
    {
        'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        'double': (
            {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
            {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
        )
    }
    >>> tree_transpose_map(  # doctest: +IGNORE_WHITESPACE
    ...     lambda x: {'identity': x, 'double': (x, x)},
    ...     tree,
    ...     inner_treespec=tree_structure({'identity': 0, 'double': 0}),
    ... )
    {
        'identity': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)},
        'double': {'b': ((2, 2), [(3, 3), (4, 4)]), 'a': (1, 1), 'c': ((5, 5), (6, 6))}
    }

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        inner_treespec (PyTreeSpec, optional): The treespec object representing the inner structure
            of the result pytree. If not specified, the inner structure is inferred from the result
            of the function ``func`` on the first leaf. (default: :data:`None`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object. It is also used when inferring the
            inner structure from the first result.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new nested pytree with the same structure as ``inner_treespec`` where the value at each
        leaf is a pytree with the same structure as ``tree``. The subtree at each leaf is given by
        the result of function ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf
        in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.

    Raises:
        ValueError: If ``tree`` has no leaves, or if the inner structure has no leaves.
    """
    leaves, outer_treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')
    flat_args = [leaves] + [outer_treespec.flatten_up_to(r) for r in rests]
    # Materialize the outputs: they are needed both for inference and for the transposition.
    outputs = list(map(func, *flat_args))

    if inner_treespec is None:
        # Safe to index: the check above guarantees at least one output.
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')

    # One row per outer leaf, one column per inner leaf; transpose and rebuild both levels.
    grouped = [inner_treespec.flatten_up_to(o) for o in outputs]
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)
|
|
|
|
|
|
|
|
def tree_transpose_map_with_path(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree with transposed structure.

    See also :func:`tree_map_with_path`, :func:`tree_transpose_map`, and :func:`tree_transpose`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    >>> tree_transpose_map_with_path(  # doctest: +IGNORE_WHITESPACE
    ...     lambda p, x: {'depth': len(p), 'value': x},
    ...     tree,
    ... )
    {
        'depth': {'b': (2, [3, 3]), 'a': 1, 'c': (2, 2)},
        'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    }
    >>> tree_transpose_map_with_path(  # doctest: +IGNORE_WHITESPACE
    ...     lambda p, x: {'path': p, 'value': x},
    ...     tree,
    ...     inner_treespec=tree_structure({'path': 0, 'value': 0}),
    ... )
    {
        'path': {
            'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]),
            'a': ('a',),
            'c': (('c', 0), ('c', 1))
        },
        'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    }

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to function
            ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        inner_treespec (PyTreeSpec, optional): The treespec object representing the inner structure
            of the result pytree. If not specified, the inner structure is inferred from the result
            of the function ``func`` on the first leaf. (default: :data:`None`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object. It is also used when inferring the
            inner structure from the first result.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new nested pytree with the same structure as ``inner_treespec`` where the value at each
        leaf is a pytree with the same structure as ``tree``. The subtree at each leaf is given by
        the result of function ``func(p, x, *xs)`` where ``(p, x)`` are the path and value at the
        corresponding leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
        ``rests``.

    Raises:
        ValueError: If ``tree`` has no leaves, or if the inner structure has no leaves.
    """
    paths, leaves, outer_treespec = _C.flatten_with_path(tree, is_leaf, none_is_leaf, namespace)
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')
    flat_args = [leaves] + [outer_treespec.flatten_up_to(r) for r in rests]
    # Materialize the outputs: they are needed both for inference and for the transposition.
    outputs = list(map(func, paths, *flat_args))

    if inner_treespec is None:
        # Safe to index: the check above guarantees at least one output.
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')

    # One row per outer leaf, one column per inner leaf; transpose and rebuild both levels.
    grouped = [inner_treespec.flatten_up_to(o) for o in outputs]
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)
|
|
|
|
|
|
|
|
def tree_transpose_map_with_accessor(
    func: Callable[..., PyTree[U]],
    tree: PyTree[T],
    /,
    *rests: PyTree[S],
    inner_treespec: PyTreeSpec | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree with transposed structure.

    See also :func:`tree_map_with_accessor`, :func:`tree_transpose_map`, and :func:`tree_transpose`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    >>> tree_transpose_map_with_accessor(  # doctest: +IGNORE_WHITESPACE
    ...     lambda a, x: {'depth': len(a), 'code': a.codify('tree'), 'value': x},
    ...     tree,
    ... )
    {
        'depth': {
            'b': (2, [3, 3]),
            'a': 1,
            'c': (2, 2)
        },
        'code': {
            'b': ("tree['b'][0]", ["tree['b'][1][0]", "tree['b'][1][1]"]),
            'a': "tree['a']",
            'c': ("tree['c'][0]", "tree['c'][1]")
        },
        'value': {
            'b': (2, [3, 4]),
            'a': 1,
            'c': (5, 6)
        }
    }
    >>> tree_transpose_map_with_accessor(  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    ...     lambda a, x: {'path': a.path, 'accessor': a, 'value': x},
    ...     tree,
    ...     inner_treespec=tree_structure({'path': 0, 'accessor': 0, 'value': 0}),
    ... )
    {
        'path': {
            'b': (('b', 0), [('b', 1, 0), ('b', 1, 1)]),
            'a': ('a',),
            'c': (('c', 0), ('c', 1))
        },
        'accessor': {
            'b': (
                PyTreeAccessor(*['b'][0], ...),
                [
                    PyTreeAccessor(*['b'][1][0], ...),
                    PyTreeAccessor(*['b'][1][1], ...)
                ]
            ),
            'a': PyTreeAccessor(*['a'], ...),
            'c': (
                PyTreeAccessor(*['c'][0], ...),
                PyTreeAccessor(*['c'][1], ...)
            )
        },
        'value': {'b': (2, [3, 4]), 'a': 1, 'c': (5, 6)}
    }

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra accessors.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding accessor providing the first positional argument to
            function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        inner_treespec (PyTreeSpec, optional): The treespec object representing the inner structure
            of the result pytree. If not specified, the inner structure is inferred from the result
            of the function ``func`` on the first leaf. (default: :data:`None`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new nested pytree with the same structure as ``inner_treespec``, where the value at each
        leaf is a pytree with the same structure as ``tree``. The subtree at each leaf is given by
        the result of function ``func(a, x, *xs)`` where ``(a, x)`` are the accessor and value at
        the corresponding leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes
        in ``rests``.

    Raises:
        ValueError: If ``tree`` has no leaves (empty outer structure), or if the inner structure
            (given or inferred) has no leaves.
    """
    leaves, outer_treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    if outer_treespec.num_leaves == 0:
        raise ValueError(f'The outer structure must have at least one leaf. Got: {outer_treespec}.')
    # Align every tree in ``rests`` with the leaves of ``tree`` (each is expected to have
    # ``tree`` as a prefix), so that ``func`` receives one value per input at every leaf.
    flat_args = [leaves] + [outer_treespec.flatten_up_to(r) for r in rests]
    # One output per outer leaf; the leaf accessor is passed as the first argument.
    outputs = list(map(func, outer_treespec.accessors(), *flat_args))

    if inner_treespec is None:
        # Infer the inner structure from the first output; the check above guarantees one exists.
        inner_treespec = tree_structure(
            outputs[0],
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
    if inner_treespec.num_leaves == 0:
        raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.')

    # Transpose: split every output by the inner structure, regroup the i-th pieces across all
    # outer leaves, rebuild each group with the outer structure, then wrap with the inner one.
    grouped = [inner_treespec.flatten_up_to(o) for o in outputs]
    transposed = zip(*grouped)
    subtrees = map(outer_treespec.unflatten, transposed)
    return inner_treespec.unflatten(subtrees)
|
|
|
|
|
|
|
|
def tree_broadcast_prefix(
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[T]:
    """Broadcast the leaves of ``prefix_tree`` so the result has the same structure as ``full_tree``.

    See also :func:`broadcast_prefix`, :func:`tree_broadcast_common`, and :func:`treespec_is_prefix`.

    ``prefix_tree`` is a prefix of ``full_tree`` when ``full_tree`` can be obtained from
    ``prefix_tree`` by replacing each of its leaves with a suitable **subtree**.

    Each leaf of ``prefix_tree`` is copied into every leaf position of the matching subtree of
    ``full_tree``. The result therefore has exactly the structure (and size) of ``full_tree``.

    >>> tree_broadcast_prefix(1, [2, 3, 4])
    [1, 1, 1]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6])
    [1, 2, 3]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
    [1, 2, (3, 3)]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
    [1, 2, {'a': 3, 'b': 3, 'c': (None, 3)}]
    >>> tree_broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
    [1, 2, {'a': 3, 'b': 3, 'c': (3, 3)}]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of the structure of
            ``full_tree``.
        full_tree (pytree): A pytree whose structure has the structure of ``prefix_tree`` as a
            prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A pytree with the same structure as ``full_tree`` whose leaves are copies of the
        corresponding leaves of ``prefix_tree``.
    """
    structure_options = {'is_leaf': is_leaf, 'none_is_leaf': none_is_leaf, 'namespace': namespace}

    def replicate(value: T, target: PyTree[S]) -> PyTree[T]:
        # Fill every leaf slot of ``target`` with the same prefix leaf ``value``.
        target_spec = tree_structure(target, **structure_options)
        copies = itertools.repeat(value, target_spec.num_leaves)
        return target_spec.unflatten(copies)

    # ``tree_map`` walks ``prefix_tree`` and hands ``replicate`` the matching subtree of
    # ``full_tree``; a structural mismatch surfaces as a ``ValueError`` (see the examples above).
    return tree_map(replicate, prefix_tree, full_tree, **structure_options)
|
|
|
|
|
|
|
|
def broadcast_prefix(
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[T]:
    """Return the leaves of ``prefix_tree`` repeated to match the number of leaves in ``full_tree``.

    See also :func:`tree_broadcast_prefix`, :func:`broadcast_common`, and :func:`treespec_is_prefix`.

    ``prefix_tree`` is a prefix of ``full_tree`` when ``full_tree`` can be obtained from
    ``prefix_tree`` by replacing each of its leaves with a suitable **subtree**.

    Each leaf of ``prefix_tree`` is repeated once per leaf of the matching subtree of
    ``full_tree``. The returned flat list therefore has as many entries as ``full_tree`` has
    leaves, in left-to-right depth-first order.

    >>> broadcast_prefix(1, [2, 3, 4])
    [1, 1, 1]
    >>> broadcast_prefix([1, 2, 3], [4, 5, 6])
    [1, 2, 3]
    >>> broadcast_prefix([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [4, 5, 6, 7].
    >>> broadcast_prefix([1, 2, 3], [4, 5, (6, 7)])
    [1, 2, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}])
    [1, 2, 3, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [4, 5, {'a': 6, 'b': 7, 'c': (None, 8)}], none_is_leaf=True)
    [1, 2, 3, 3, 3, 3]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of the structure of
            ``full_tree``.
        full_tree (pytree): A pytree whose structure has the structure of ``prefix_tree`` as a
            prefix.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A flat list of the leaves of ``prefix_tree``, each repeated to cover the leaves of the
        corresponding subtree in ``full_tree``.
    """
    structure_options = {'is_leaf': is_leaf, 'none_is_leaf': none_is_leaf, 'namespace': namespace}
    broadcasted: list[T] = []

    def collect(value: T, target: PyTree[S]) -> None:
        # Append one copy of ``value`` per leaf of the matching subtree of ``full_tree``.
        num_copies = tree_structure(target, **structure_options).num_leaves
        broadcasted.extend(itertools.repeat(value, num_copies))

    # Called only for the side effect of ``collect``; entries accumulate in traversal order.
    tree_map_(collect, prefix_tree, full_tree, **structure_options)
    return broadcasted
|
|
|
|
|
|
|
|
def tree_broadcast_common(
    tree: PyTree[T],
    other_tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[PyTree[T], PyTree[T]]:
    """Broadcast ``tree`` and ``other_tree`` to their common suffix structure and return both.

    See also :func:`broadcast_common`, :func:`tree_broadcast_prefix`, and :func:`treespec_is_prefix`.

    A pytree ``suffix_tree`` is a suffix of ``tree`` when ``suffix_tree`` can be obtained from
    ``tree`` by replacing each leaf of ``tree`` with a suitable **subtree**.

    Both returned pytrees share one structure: the common suffix structure of ``tree`` and
    ``other_tree``. Every leaf of either input is copied into each leaf position of the subtree
    of that suffix structure lying beneath it.

    >>> tree_broadcast_common(1, [2, 3, 4])
    ([1, 1, 1], [2, 3, 4])
    >>> tree_broadcast_common([1, 2, 3], [4, 5, 6])
    ([1, 2, 3], [4, 5, 6])
    >>> tree_broadcast_common([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4.
    >>> tree_broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)])
    ([1, (2, 3), (4, 4)], [5, (6, 6), (7, 8)])
    >>> tree_broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}])
    ([1, {'a': (2, 3)}, {'a': 4, 'b': 4, 'c': (None, 4)}],
     [5, {'a': (6, 6)}, {'a': 7, 'b': 8, 'c': (None, 9)}])
    >>> tree_broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], none_is_leaf=True)
    ([1, {'a': (2, 3)}, {'a': 4, 'b': 4, 'c': (4, 4)}],
     [5, {'a': (6, 6)}, {'a': 7, 'b': 8, 'c': (None, 9)}])
    >>> tree_broadcast_common([1, None], [None, 2])
    ([None, None], [None, None])
    >>> tree_broadcast_common([1, None], [None, 2], none_is_leaf=True)
    ([1, None], [None, 2])

    Args:
        tree (pytree): A pytree that shares a common suffix structure with ``other_tree``.
        other_tree (pytree): A pytree that shares a common suffix structure with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A pair of pytrees, both with the common suffix structure of ``tree`` and ``other_tree``,
        holding the broadcasted leaves of ``tree`` and of ``other_tree`` respectively.
    """
    leaves, treespec = _C.flatten(tree, is_leaf, none_is_leaf, namespace)
    other_leaves, other_treespec = _C.flatten(other_tree, is_leaf, none_is_leaf, namespace)
    suffix_treespec = treespec.broadcast_to_common_suffix(other_treespec)

    # Materialize the common suffix structure with throwaway placeholder leaves so that each
    # input treespec can cut it into the subtrees lying under its own leaves.
    placeholder: T = object()
    suffix_tree: PyTree[T] = suffix_treespec.unflatten(
        itertools.repeat(placeholder, suffix_treespec.num_leaves),
    )

    def fill(value: T, region: PyTree[T]) -> PyTree[T]:
        # Replace every leaf of ``region`` with ``value``.
        region_spec = tree_structure(
            region,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
        return region_spec.unflatten(itertools.repeat(value, region_spec.num_leaves))

    def expand(flat_leaves: list[T], spec: PyTreeSpec) -> PyTree[T]:
        # Broadcast each leaf over the suffix subtree beneath it, then rebuild the outer structure.
        regions = spec.flatten_up_to(suffix_tree)
        return spec.unflatten(map(fill, flat_leaves, regions))

    # The left operand is fully built before the right one, matching a sequential evaluation.
    return expand(leaves, treespec), expand(other_leaves, other_treespec)
|
|
|
|
|
|
|
|
def broadcast_common(
    tree: PyTree[T],
    other_tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[list[T], list[T]]:
    """Return two lists of leaves in ``tree`` and ``other_tree`` broadcasted to match the number of leaves in the common suffix structure.

    See also :func:`tree_broadcast_common`, :func:`broadcast_prefix`, and :func:`treespec_is_prefix`.

    If a ``suffix_tree`` is a suffix of a ``tree``, this means the ``suffix_tree`` can be
    constructed by replacing the leaves of ``tree`` with appropriate **subtrees**.

    This function broadcasts ``tree`` and ``other_tree`` to their common suffix structure (see
    :func:`tree_broadcast_common`) and returns the leaves of the two broadcasted pytrees as two
    flat lists of equal length. The leaves are replicated from ``tree`` and ``other_tree``. The
    number of replicas is determined by the corresponding subtree in the suffix structure.

    >>> broadcast_common(1, [2, 3, 4])
    ([1, 1, 1], [2, 3, 4])
    >>> broadcast_common([1, 2, 3], [4, 5, 6])
    ([1, 2, 3], [4, 5, 6])
    >>> broadcast_common([1, 2, 3], [4, 5, 6, 7])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4.
    >>> broadcast_common([1, (2, 3), 4], [5, 6, (7, 8)])
    ([1, 2, 3, 4, 4], [5, 6, 6, 7, 8])
    >>> broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}])
    ([1, 2, 3, 4, 4, 4], [5, 6, 6, 7, 8, 9])
    >>> broadcast_common([1, {'a': (2, 3)}, 4], [5, 6, {'a': 7, 'b': 8, 'c': (None, 9)}], none_is_leaf=True)
    ([1, 2, 3, 4, 4, 4, 4], [5, 6, 6, 7, 8, None, 9])
    >>> broadcast_common([1, None], [None, 2])
    ([], [])
    >>> broadcast_common([1, None], [None, 2], none_is_leaf=True)
    ([1, None], [None, 2])

    Args:
        tree (pytree): A pytree that shares a common suffix structure with ``other_tree``.
        other_tree (pytree): A pytree that shares a common suffix structure with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        Two lists of leaves in ``tree`` and ``other_tree`` broadcasted to match the number of leaves
        in the common suffix structure. Both lists have the same length.
    """
    broadcasted_tree, other_broadcasted_tree = tree_broadcast_common(
        tree,
        other_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )

    broadcasted_leaves: list[T] = []
    other_broadcasted_leaves: list[T] = []

    def add_leaves(x: T, y: T) -> None:
        # Both broadcasted trees share one structure, so leaves are collected pairwise in order.
        broadcasted_leaves.append(x)
        other_broadcasted_leaves.append(y)

    tree_map_(
        add_leaves,
        broadcasted_tree,
        other_broadcasted_tree,
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
    return broadcasted_leaves, other_broadcasted_leaves
|
|
|
|
|
|
|
|
def _tree_broadcast_common(
    tree: PyTree[T],
    /,
    *rests: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> tuple[PyTree[T], ...]:
    """Broadcast ``tree`` and every pytree in ``rests`` to their shared common suffix structure.

    This generalizes :func:`tree_broadcast_common` to any number of pytrees. The result is a
    tuple holding the broadcasted ``tree`` followed by the broadcasted ``rests`` in their
    original order; with no ``rests`` it is simply ``(tree,)``.
    """
    options = {'is_leaf': is_leaf, 'none_is_leaf': none_is_leaf, 'namespace': namespace}
    if not rests:
        return (tree,)
    if len(rests) == 1:
        return tree_broadcast_common(tree, rests[0], **options)

    common = tree
    results = list(rests)
    # During the first sweep ``common`` grows toward the suffix shared by every input, so the
    # earlier entries of ``results`` were broadcast against a not-yet-final ``common``. The
    # second sweep broadcasts every input again against the final ``common``.
    for _sweep in range(2):
        for position, other in enumerate(rests):
            common, results[position] = tree_broadcast_common(common, other, **options)

    return (common, *results)
|
|
|
|
|
|
|
|
|
|
|
def tree_broadcast_map(
    func: Callable[..., U],
    tree: PyTree[T],
    /,
    *rests: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args after broadcasting them to a common structure.

    See also :func:`tree_broadcast_map_with_path`, :func:`tree_map`, :func:`tree_map_`,
    and :func:`tree_map_with_path`.

    With a single input this behaves exactly like :func:`tree_map`:

    >>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_broadcast_map(lambda x: x + 1, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (43, 65), 'z': None}
    >>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': None}
    >>> tree_broadcast_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}, none_is_leaf=True)
    {'x': False, 'y': (False, False), 'z': True}

    With several inputs, every input pytree is first broadcast to the common suffix structure
    of all inputs:

    >>> tree_broadcast_map(lambda x, y: x * y, [5, 6, (3, 4)], [{'a': 7, 'b': 9}, [1, 2], 8])
    [{'a': 35, 'b': 45}, [6, 12], (24, 32)]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, they should have a common suffix structure with
            each other and with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree whose structure is the common suffix structure of ``tree`` and ``rests``,
        with the value at each leaf given by ``func(x, *xs)``, where ``x`` is the (possibly
        broadcasted) leaf of ``tree`` and ``xs`` are the (possibly broadcasted) leaves of
        ``rests`` at the same position.
    """
    options = {'is_leaf': is_leaf, 'none_is_leaf': none_is_leaf, 'namespace': namespace}
    # Align all inputs to one structure first; the leaf-wise map then cannot mismatch.
    aligned_trees = _tree_broadcast_common(tree, *rests, **options)
    return tree_map(func, *aligned_trees, **options)
|
|
|
|
|
|
|
|
|
|
|
def tree_broadcast_map_with_path(
    func: Callable[..., U],
    tree: PyTree[T],
    /,
    *rests: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree paths to produce a new pytree.

    See also :func:`tree_broadcast_map`, :func:`tree_map`, :func:`tree_map_`,
    and :func:`tree_map_with_path`.

    If only one input is provided, this function is the same as :func:`tree_map_with_path`:

    >>> tree_broadcast_map_with_path(lambda p, x: (len(p), x), {'x': 7, 'y': (42, 64)})
    {'x': (1, 7), 'y': ((2, 42), (2, 64))}
    >>> tree_broadcast_map_with_path(lambda p, x: x + len(p), {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (44, 66), 'z': None}
    >>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}})
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: None}}
    >>> tree_broadcast_map_with_path(lambda p, x: p, {'x': 7, 'y': (42, 64), 'z': {1.5: None}}, none_is_leaf=True)
    {'x': ('x',), 'y': (('y', 0), ('y', 1)), 'z': {1.5: ('z', 1.5)}}

    If multiple inputs are given, all input trees will be broadcasted to the common suffix structure
    of all inputs:

    >>> tree_broadcast_map_with_path(  # doctest: +IGNORE_WHITESPACE
    ...     lambda p, x, y: (p, x * y),
    ...     [5, 6, (3, 4)],
    ...     [{'a': 7, 'b': 9}, [1, 2], 8],
    ... )
    [
        {'a': ((0, 'a'), 35), 'b': ((0, 'b'), 45)},
        [((1, 0), 6), ((1, 1), 12)],
        (((2, 0), 24), ((2, 1), 32))
    ]

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra paths.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding path providing the first positional argument to
            function ``func``.
        rests (tuple of pytree): A tuple of pytrees, they should have a common suffix structure with
            each other and with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the structure as the common suffix structure of ``tree`` and ``rests`` but
        with the value at each leaf given by ``func(p, x, *xs)`` where ``(p, x)`` are the path and
        value at the corresponding leaf (may be broadcasted) in ``tree`` and ``xs`` is the tuple of
        values at corresponding leaves (may be broadcasted) in ``rests``.
    """
    # Broadcast every input to the common suffix structure first; the paths passed to ``func``
    # are therefore paths in that common structure, not in the original ``tree``.
    return tree_map_with_path(
        func,
        *_tree_broadcast_common(
            tree,
            *rests,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ),
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
|
|
|
|
|
|
|
|
def tree_broadcast_map_with_accessor(
    func: Callable[..., U],
    tree: PyTree[T],
    /,
    *rests: PyTree[T],
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTree[U]:
    """Map a multi-input function over pytree args as well as the tree accessors to produce a new pytree.

    See also :func:`tree_broadcast_map`, :func:`tree_map`, :func:`tree_map_`,
    and :func:`tree_map_with_accessor`.

    If only one input is provided, this function is the same as :func:`tree_map_with_accessor`:

    >>> tree_broadcast_map_with_accessor(lambda a, x: (len(a), x), {'x': 7, 'y': (42, 64)})
    {'x': (1, 7), 'y': ((2, 42), (2, 64))}
    >>> tree_broadcast_map_with_accessor(lambda a, x: x + len(a), {'x': 7, 'y': (42, 64), 'z': None})
    {'x': 8, 'y': (44, 66), 'z': None}
    >>> tree_broadcast_map_with_accessor(  # doctest: +IGNORE_WHITESPACE
    ...     lambda a, x: a.codify('tree'),
    ...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
    ... )
    {
        'x': "tree['x']",
        'y': ("tree['y'][0]", "tree['y'][1]"),
        'z': {1.5: None}
    }
    >>> tree_broadcast_map_with_accessor(  # doctest: +IGNORE_WHITESPACE
    ...     lambda a, x: a.codify('tree'),
    ...     {'x': 7, 'y': (42, 64), 'z': {1.5: None}},
    ...     none_is_leaf=True,
    ... )
    {
        'x': "tree['x']",
        'y': ("tree['y'][0]", "tree['y'][1]"),
        'z': {1.5: "tree['z'][1.5]"}
    }

    If multiple inputs are given, all input trees will be broadcasted to the common suffix structure
    of all inputs:

    >>> tree_broadcast_map_with_accessor(  # doctest: +IGNORE_WHITESPACE
    ...     lambda a, x, y: f'{a.codify("tree")} = {x * y}',
    ...     [5, 6, (3, 4)],
    ...     [{'a': 7, 'b': 9}, [1, 2], 8],
    ... )
    [
        {'a': "tree[0]['a'] = 35", 'b': "tree[0]['b'] = 45"},
        ['tree[1][0] = 6', 'tree[1][1] = 12'],
        ('tree[2][0] = 24', 'tree[2][1] = 32')
    ]

    Args:
        func (callable): A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees with extra accessors.
        tree (pytree): A pytree to be mapped over, with each leaf providing the second positional
            argument and the corresponding accessor providing the first positional argument to
            function ``func``.
        rests (tuple of pytree): A tuple of pytrees, they should have a common suffix structure with
            each other and with ``tree``.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A new pytree with the structure as the common suffix structure of ``tree`` and ``rests`` but
        with the value at each leaf given by ``func(a, x, *xs)`` where ``(a, x)`` are the accessor
        and value at the corresponding leaf (may be broadcasted) in ``tree`` and ``xs`` is the tuple
        of values at corresponding leaves (may be broadcasted) in ``rests``.
    """
    # Broadcast every input to the common suffix structure first; the accessors passed to
    # ``func`` are therefore accessors into that common structure, not into the original ``tree``.
    return tree_map_with_accessor(
        func,
        *_tree_broadcast_common(
            tree,
            *rests,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ),
        is_leaf=is_leaf,
        none_is_leaf=none_is_leaf,
        namespace=namespace,
    )
|
|
|
|
|
|
|
|
|
|
|
class MissingSentinel:
    """Private marker type whose single instance means "no argument was supplied".

    A dedicated sentinel is needed because :data:`None` can be a legitimate argument value.
    The class itself is deleted right after its one instance is created.
    """

    # No per-instance attributes are needed, so instances carry no ``__dict__``.
    __slots__: ClassVar[tuple[()]] = ()

    def __repr__(self) -> str:
        """Return the fixed marker text shown when the sentinel is displayed."""
        return '<MISSING>'


# Module-private singleton used as the default for optional parameters (e.g. ``initial``).
__MISSING: T = MissingSentinel()
# Drop the class so no second instance can be created from this module's namespace.
del MissingSentinel
|
|
|
|
|
|
|
|
# Typing overload for calls without ``initial``: the first leaf seeds the reduction, so ``func``
# must accept two leaf values of type ``T`` and return a ``T``.
@overload
def tree_reduce(
    func: Callable[[T, T], T],
    tree: PyTree[T],
    /,
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T: ...
|
|
|
|
|
|
|
|
# Typing overload for calls with an ``initial`` accumulator of type ``T`` folded with leaves of
# type ``S``, so ``func`` maps ``(T, S) -> T``. The ``__MISSING`` default mirrors the runtime
# implementation's signature; it is the module-private "not supplied" sentinel.
@overload
def tree_reduce(
    func: Callable[[T, S], T],
    tree: PyTree[S],
    /,
    initial: T = __MISSING,
    *,
    is_leaf: Callable[[S], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T: ...
|
|
|
|
|
|
|
|
def tree_reduce(
    func: Callable[[T, S], T],
    tree: PyTree[S],
    /,
    initial: T = __MISSING,
    *,
    is_leaf: Callable[[S], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Traverse a pytree and reduce its leaves in left-to-right depth-first order.

    See also :func:`tree_leaves` and :func:`tree_sum`.

    >>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, 3)})
    6
    >>> tree_reduce(lambda x, y: x + y, {'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
    6
    >>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3})
    3
    >>> tree_reduce(lambda x, y: x and y, {'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
    None

    Args:
        func (callable): A function that takes two arguments (the running accumulator and the next
            leaf) and returns the new accumulator.
        tree (pytree): A pytree to be traversed.
        initial (object, optional): An initial value to be used for the reduction. If not provided,
            the first leaf value is used as the initial value.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The result of reducing the leaves of the pytree using ``func``.

    Raises:
        TypeError: If ``tree`` has no leaves and ``initial`` is not given (raised by
            :func:`functools.reduce`).
    """
    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    # Forward ``initial`` only when the caller supplied one: ``functools.reduce`` treats an
    # explicit initializer (even ``None``) differently from an omitted one.
    reduce_args = (leaves,) if initial is __MISSING else (leaves, initial)
    return functools.reduce(func, *reduce_args)
|
|
|
|
|
|
|
|
def tree_sum(
    tree: PyTree[T],
    /,
    start: T = 0,
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Sum ``start`` and leaf values in ``tree`` in left-to-right depth-first order and return the total.

    See also :func:`tree_leaves` and :func:`tree_reduce`.

    String-like ``start`` values (:class:`str`, :class:`bytes`, :class:`bytearray`) are concatenated
    with ``join`` because the built-in :func:`sum` rejects them. The result then has the base type of
    ``start`` (:class:`str`, :class:`bytes`, or :class:`bytearray` respectively), mirroring
    ``start + leaf0 + leaf1 + ...``.

    >>> tree_sum({'x': 1, 'y': (2, 3)})
    6
    >>> tree_sum({'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node with arity 0 by default
    6
    >>> tree_sum({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
    Traceback (most recent call last):
        ...
    TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
    >>> tree_sum({'x': 'a', 'y': ('b', None), 'z': 'c'}, start='')
    'abc'
    >>> tree_sum({'x': b'a', 'y': (bytearray(b'b'),), 'z': b'c'}, start=bytearray())
    bytearray(b'abc')
    >>> tree_sum({'x': [1], 'y': ([2], [None]), 'z': [3]}, start=[], is_leaf=lambda x: isinstance(x, list))
    [1, 2, None, 3]

    Args:
        tree (pytree): A pytree to be traversed.
        start (object, optional): An initial value to be used for the sum. (default: :data:`0`)
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0, so it is kept in the treespec rather
            than in the leaves list and does not take part in the sum. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The total sum of ``start`` and leaf values in ``tree``.
    """
    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)

    # The built-in ``sum`` refuses string-like start values, so concatenate those with ``join``.
    # ``str.join`` / ``bytearray.join`` / ``bytes.join`` each return an instance of the joiner's
    # type, so the joiner is chosen to match ``start`` (``bytearray + bytes`` is a bytearray).
    if isinstance(start, str):
        return ''.join([start, *leaves])
    if isinstance(start, bytearray):
        return bytearray().join([start, *leaves])
    if isinstance(start, bytes):
        return b''.join([start, *leaves])
    return sum(leaves, start)
|
|
|
|
|
|
|
|
# Overload without ``default``: an empty ``tree`` makes ``tree_max`` raise ``ValueError``.
# Keyword-only parameters are listed in the same order as the other ``tree_max`` / ``tree_min``
# signatures (``key`` before ``is_leaf``); callers are unaffected because they are keyword-only.
@overload
def tree_max(
    tree: PyTree[T],
    /,
    *,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T: ...
|
|
|
|
|
|
|
|
@overload
def tree_max(
    tree: PyTree[T],
    /,
    *,
    default: T = __MISSING,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Typing overload: when ``default`` is given, an empty ``tree`` yields it instead of raising."""
|
|
|
|
|
|
|
|
def tree_max(
    tree: PyTree[T],
    /,
    *,
    default: T = __MISSING,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Return the largest leaf value in ``tree``.

    See also :func:`tree_leaves` and :func:`tree_min`.

    >>> tree_max({})
    Traceback (most recent call last):
        ...
    ValueError: max() iterable argument is empty
    >>> tree_max({}, default=0)
    0
    >>> tree_max({'x': 0, 'y': (2, 1)})
    2
    >>> tree_max({'x': 0, 'y': (2, 1)}, key=lambda x: -x)
    0
    >>> tree_max({'a': None})  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: max() iterable argument is empty
    >>> tree_max({'a': None}, default=0)  # `None` is a non-leaf node with arity 0 by default
    0
    >>> tree_max({'a': None}, none_is_leaf=True)
    None
    >>> tree_max(None)  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: max() iterable argument is empty
    >>> tree_max(None, default=0)
    0
    >>> tree_max(None, none_is_leaf=True)
    None

    Args:
        tree (pytree): A pytree to be traversed.
        default (object, optional): The value to return when ``tree`` has no leaves. If it is
            omitted and ``tree`` has no leaves, a :exc:`ValueError` is raised.
        key (callable or None, optional): A one-argument ordering function like that used for
            :meth:`list.sort`.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0, so it is kept in the treespec rather
            than in the leaves list and is never compared. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The maximum leaf value in ``tree``, or ``default`` if ``tree`` has no leaves and a
        ``default`` was supplied.
    """
    max_kwargs = {'key': key}
    if default is not __MISSING:
        # Forward ``default`` only when the caller supplied one, so ``max`` still raises on
        # empty input otherwise.
        max_kwargs['default'] = default

    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    return max(leaves, **max_kwargs)
|
|
|
|
|
|
|
|
@overload
def tree_min(
    tree: PyTree[T],
    /,
    *,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Typing overload: without ``default``, an empty ``tree`` makes the call raise ``ValueError``."""
|
|
|
|
|
|
|
|
@overload
def tree_min(
    tree: PyTree[T],
    /,
    *,
    default: T = __MISSING,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Typing overload: when ``default`` is given, an empty ``tree`` yields it instead of raising."""
|
|
|
|
|
|
|
|
def tree_min(
    tree: PyTree[T],
    /,
    *,
    default: T = __MISSING,
    key: Callable[[T], Any] | None = None,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> T:
    """Return the smallest leaf value in ``tree``.

    See also :func:`tree_leaves` and :func:`tree_max`.

    >>> tree_min({})
    Traceback (most recent call last):
        ...
    ValueError: min() iterable argument is empty
    >>> tree_min({}, default=0)
    0
    >>> tree_min({'x': 0, 'y': (2, 1)})
    0
    >>> tree_min({'x': 0, 'y': (2, 1)}, key=lambda x: -x)
    2
    >>> tree_min({'a': None})  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: min() iterable argument is empty
    >>> tree_min({'a': None}, default=0)  # `None` is a non-leaf node with arity 0 by default
    0
    >>> tree_min({'a': None}, none_is_leaf=True)
    None
    >>> tree_min(None)  # `None` is a non-leaf node with arity 0 by default
    Traceback (most recent call last):
        ...
    ValueError: min() iterable argument is empty
    >>> tree_min(None, default=0)
    0
    >>> tree_min(None, none_is_leaf=True)
    None

    Args:
        tree (pytree): A pytree to be traversed.
        default (object, optional): The value to return when ``tree`` has no leaves. If it is
            omitted and ``tree`` has no leaves, a :exc:`ValueError` is raised.
        key (callable or None, optional): A one-argument ordering function like that used for
            :meth:`list.sort`.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0, so it is kept in the treespec rather
            than in the leaves list and is never compared. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        The minimum leaf value in ``tree``, or ``default`` if ``tree`` has no leaves and a
        ``default`` was supplied.
    """
    min_kwargs = {'key': key}
    if default is not __MISSING:
        # Forward ``default`` only when the caller supplied one, so ``min`` still raises on
        # empty input otherwise.
        min_kwargs['default'] = default

    leaves = tree_leaves(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    return min(leaves, **min_kwargs)
|
|
|
|
|
|
|
|
def tree_all(
    tree: PyTree[T],
    /,
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Test whether every leaf in ``tree`` is true (vacuously :data:`True` if ``tree`` is empty).

    See also :func:`tree_leaves` and :func:`tree_any`.

    >>> tree_all({})
    True
    >>> tree_all({'x': 1, 'y': (2, 3)})
    True
    >>> tree_all({'x': 1, 'y': (2, None), 'z': 3})  # `None` is a non-leaf node by default
    True
    >>> tree_all({'x': 1, 'y': (2, None), 'z': 3}, none_is_leaf=True)
    False
    >>> tree_all(None)  # `None` is a non-leaf node by default
    True
    >>> tree_all(None, none_is_leaf=True)
    False

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0, so it is kept in the treespec rather
            than in the leaves list and is not tested. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        :data:`True` if all leaves in ``tree`` are true, or if ``tree`` has no leaves.
        Otherwise, :data:`False`.
    """
    leaves = tree_iter(tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace)
    # ``all`` short-circuits on the first falsy leaf.
    return all(leaves)
|
|
|
|
|
|
|
|
def tree_any(
    tree: PyTree[T],
    /,
    *,
    is_leaf: Callable[[T], bool] | None = None,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> bool:
    """Test whether any leaf in ``tree`` is true (or :data:`False` if ``tree`` is empty).

    See also :func:`tree_leaves` and :func:`tree_all`.

    >>> tree_any({})
    False
    >>> tree_any({'x': 0, 'y': (2, 0)})
    True
    >>> tree_any({'a': None})  # `None` is a non-leaf node with arity 0 by default
    False
    >>> tree_any({'a': None}, none_is_leaf=True)  # `None` is evaluated as false
    False
    >>> tree_any(None)  # `None` is a non-leaf node with arity 0 by default
    False
    >>> tree_any(None, none_is_leaf=True)  # `None` is evaluated as false
    False

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0, so it is kept in the treespec rather
            than in the leaves list and is not tested. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        :data:`True` if any leaf in ``tree`` is true, otherwise :data:`False`. If ``tree`` has
        no leaves, return :data:`False`.
    """
    # ``any`` short-circuits on the first truthy leaf.
    return any(
        tree_iter(
            tree,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ),
    )
|
|
|
|
|
|
|
|
class FlattenOneLevelOutput(NamedTuple, Generic[T]):
    """The output of :func:`tree_flatten_one_level`.

    A 4-tuple whose fields are, in order: ``children``, ``metadata``, ``entries``, and
    ``unflatten_func``.
    """

    #: A list of one-level children of the pytree node.
    children: list[PyTree[T]]

    #: The metadata used to reconstruct the pytree node.
    metadata: MetaData

    #: A tuple of path entries to the children.
    entries: tuple[Any, ...]

    #: A function that can be used to unflatten the metadata and children back to the pytree node.
    unflatten_func: UnflattenFunc[PyTree[T]]
|
|
|
|
|
|
|
|
|
|
|
class FlattenOneLevelOutputEx(FlattenOneLevelOutput[T]):
    """The output of :func:`tree_flatten_one_level`, extended with information about the node type.

    Only the four inherited fields are tuple items. The attributes declared below are assigned
    on the instance after it has been constructed.
    """

    #: The type of the pytree node.
    type: builtins.type[Collection[T]]

    #: The type of the path entry for the pytree node.
    path_entry_type: builtins.type[PyTreeEntry]

    #: The kind of the pytree node.
    kind: PyTreeKind
|
|
|
|
|
|
|
|
def tree_flatten_one_level(
    tree: PyTree[T],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> FlattenOneLevelOutputEx[T]:
    """Flatten the pytree one level, returning a 4-tuple of children, metadata, path entries, and an unflatten function.

    See also :func:`tree_flatten`, :func:`tree_flatten_with_path`.

    >>> children, metadata, entries, unflatten_func = tree_flatten_one_level({'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5})
    >>> children, metadata, entries
    ([1, (2, [3, 4]), None, 5], ['a', 'b', 'c', 'd'], ('a', 'b', 'c', 'd'))
    >>> unflatten_func(metadata, children)
    {'a': 1, 'b': (2, [3, 4]), 'c': None, 'd': 5}
    >>> children, metadata, entries, unflatten_func = tree_flatten_one_level([{'a': 1, 'b': (2, 3)}, (4, 5)])
    >>> children, metadata, entries
    ([{'a': 1, 'b': (2, 3)}, (4, 5)], None, (0, 1))
    >>> unflatten_func(metadata, children)
    [{'a': 1, 'b': (2, 3)}, (4, 5)]

    Args:
        tree (pytree): A pytree to be traversed.
        is_leaf (callable, optional): An optionally specified function that will be called at each
            flattening step. It should return a boolean, with :data:`True` stopping the traversal
            and the whole subtree being treated as a leaf, and :data:`False` indicating the
            flattening should traverse the current object.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0 and can be flattened into zero children.
            If :data:`True`, flattening :data:`None` raises :exc:`ValueError`.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A 4-tuple ``(children, metadata, entries, unflatten_func)``. The first element is a list of
        one-level children of the pytree node. The second element is the metadata used to
        reconstruct the pytree node. The third element is a tuple of path entries to the children.
        The fourth element is a function that can be used to unflatten the metadata and
        children back to the pytree node. The returned object also carries ``type``,
        ``path_entry_type``, and ``kind`` attributes describing the node.

    Raises:
        ValueError: If ``tree`` is a leaf, i.e., it is :data:`None` with ``none_is_leaf=True``,
            ``is_leaf(tree)`` is true, or its type has no registered pytree node handler.
        RuntimeError: If the registered flatten function does not return a 2- or 3-tuple, or
            returns a different number of children and path entries.
    """
    node_type = type(tree)
    if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
        raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

    from optree.registry import register_pytree_node

    handler = register_pytree_node.get(node_type, namespace=namespace)
    if handler is None:
        raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

    # Custom flatten functions may omit the path entries (2-tuple) or provide them (3-tuple).
    flattened = tuple(handler.flatten_func(tree))
    arity = len(flattened)
    entries: Iterable[Any] | None
    if arity == 2:
        children, metadata = flattened
        entries = None
    elif arity == 3:
        children, metadata, entries = flattened
    else:
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
            f'got {arity}.',
        )

    children = list(children)
    if entries is None:
        # Without explicit entries, children are addressed by their position.
        entries = tuple(range(len(children)))
    else:
        entries = tuple(entries)

    num_children, num_entries = len(children), len(entries)
    if num_children != num_entries:
        raise RuntimeError(
            f'PyTree custom flatten function for type {node_type} returned inconsistent '
            f'number of children ({num_children}) and number of entries ({num_entries}).',
        )

    output = FlattenOneLevelOutputEx(
        children=children,
        metadata=metadata,
        entries=entries,
        unflatten_func=handler.unflatten_func,
    )
    # These are instance attributes, not tuple fields, so they are attached after construction.
    output.type = node_type
    output.path_entry_type = handler.path_entry_type
    output.kind = handler.kind
    return output
|
|
|
|
|
|
|
|
def treespec_paths(treespec: PyTreeSpec, /) -> list[tuple[Any, ...]]:
    """Return a list of paths to the leaves of a treespec.

    This is the functional form of :meth:`PyTreeSpec.paths`.

    See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec
    PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
    >>> treespec_paths(treespec)
    [('a', 0), ('a', 1, 0), ('a', 1, 1), ('b',), ('c', 0)]

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        A list with one tuple of path entries for each leaf of ``treespec``.
    """
    paths = treespec.paths()
    return paths
|
|
|
|
|
|
|
|
def treespec_accessors(treespec: PyTreeSpec, /) -> list[PyTreeAccessor]:
    """Return a list of accessors to the leaves of a treespec.

    This is the functional form of :meth:`PyTreeSpec.accessors`.

    See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors`,
    and :meth:`PyTreeSpec.accessors`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec
    PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
    >>> treespec_accessors(treespec)  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    [
        PyTreeAccessor(*['a'][0], ...),
        PyTreeAccessor(*['a'][1][0], ...),
        PyTreeAccessor(*['a'][1][1], ...),
        PyTreeAccessor(*['b'], ...),
        PyTreeAccessor(*['c'][0], ...)
    ]
    >>> treespec_accessors(treespec_leaf())
    [PyTreeAccessor(*, ())]
    >>> treespec_accessors(treespec_none())
    []

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        A list with one accessor for each leaf of ``treespec``.
    """
    accessors = treespec.accessors()
    return accessors
|
|
|
|
|
|
|
|
def treespec_entries(treespec: PyTreeSpec, /) -> list[Any]:
    """Return a list of one-level entries of a treespec to its children.

    This is the functional form of :meth:`PyTreeSpec.entries`.

    See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`,
    and :meth:`PyTreeSpec.entries`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec
    PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
    >>> treespec_entries(treespec)
    ['a', 'b', 'c']

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        A list with one entry for each direct child of ``treespec``.
    """
    entries = treespec.entries()
    return entries
|
|
|
|
|
|
|
|
def treespec_entry(treespec: PyTreeSpec, index: int, /) -> Any:
    """Return the entry of a treespec at the given index.

    This is the functional form of :meth:`PyTreeSpec.entry`.

    See also :func:`treespec_entries`, :func:`treespec_children`, and :meth:`PyTreeSpec.entry`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec_entry(treespec, 1)
    'b'

    Args:
        treespec (PyTreeSpec): A treespec.
        index (int): The position of the direct child whose entry is requested.

    Returns:
        The entry that addresses the ``index``-th direct child of ``treespec``.
    """
    entry = treespec.entry(index)
    return entry
|
|
|
|
|
|
|
|
def treespec_children(treespec: PyTreeSpec, /) -> list[PyTreeSpec]:
    """Return a list of treespecs for the children of a treespec.

    This is the functional form of :meth:`PyTreeSpec.children`.

    See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`,
    :func:`treespec_one_level`, and :meth:`PyTreeSpec.children`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec
    PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
    >>> treespec_children(treespec)
    [PyTreeSpec((*, [*, *])), PyTreeSpec(*), PyTreeSpec((*, None))]

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        A list with the treespec of each direct child of ``treespec``.
    """
    children = treespec.children()
    return children
|
|
|
|
|
|
|
|
def treespec_child(treespec: PyTreeSpec, index: int, /) -> PyTreeSpec:
    """Return the treespec of the child of a treespec at the given index.

    This is the functional form of :meth:`PyTreeSpec.child`.

    See also :func:`treespec_children`, :func:`treespec_entries`, and :meth:`PyTreeSpec.child`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec_child(treespec, 0)
    PyTreeSpec((*, [*, *]))

    Args:
        treespec (PyTreeSpec): A treespec.
        index (int): The position of the direct child whose treespec is requested.

    Returns:
        The treespec of the ``index``-th direct child of ``treespec``.
    """
    child = treespec.child(index)
    return child
|
|
|
|
|
|
|
|
def treespec_one_level(treespec: PyTreeSpec, /) -> PyTreeSpec | None:
    """Return the one-level tree structure of the treespec or :data:`None` if the treespec is a leaf.

    This is the functional form of :meth:`PyTreeSpec.one_level`.

    See also :func:`treespec_children`, :func:`treespec_is_one_level`, and :meth:`PyTreeSpec.one_level`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec
    PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
    >>> treespec_one_level(treespec)
    PyTreeSpec({'a': *, 'b': *, 'c': *})

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        The root node of ``treespec`` with every direct child replaced by a leaf, or
        :data:`None` if ``treespec`` is a leaf.
    """
    one_level = treespec.one_level()
    return one_level
|
|
|
|
|
|
|
|
def treespec_transform(
    treespec: PyTreeSpec,
    /,
    f_node: Callable[[PyTreeSpec], PyTreeSpec] | None = None,
    f_leaf: Callable[[PyTreeSpec], PyTreeSpec] | None = None,
) -> PyTreeSpec:
    """Transform a treespec by applying functions to its nodes and leaves.

    This is the functional form of :meth:`PyTreeSpec.transform`.

    See also :func:`treespec_children`, :func:`treespec_is_leaf`, and :meth:`PyTreeSpec.transform`.

    >>> treespec = tree_structure({'b': 3, 'a': (0, [1, 2]), 'c': (4, None)})
    >>> treespec
    PyTreeSpec({'a': (*, [*, *]), 'b': *, 'c': (*, None)})
    >>> treespec_transform(treespec, lambda spec: treespec_dict(zip(spec.entries(), spec.children())))
    PyTreeSpec({'a': {0: *, 1: {0: *, 1: *}}, 'b': *, 'c': {0: *, 1: {}}})
    >>> treespec_transform(
    ...     treespec,
    ...     lambda spec: (
    ...         treespec_ordereddict(zip(spec.entries(), spec.children()))
    ...         if spec.type is dict
    ...         else spec
    ...     ),
    ... )
    PyTreeSpec(OrderedDict({'a': (*, [*, *]), 'b': *, 'c': (*, None)}))
    >>> treespec_transform(
    ...     treespec,
    ...     lambda spec: (
    ...         treespec_ordereddict(tree_unflatten(spec, spec.children()))
    ...         if spec.type is dict
    ...         else spec
    ...     ),
    ... )
    PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': (*, None)}))
    >>> treespec_transform(treespec, lambda spec: treespec_tuple(spec.children()))
    PyTreeSpec(((*, (*, *)), *, (*, ())))
    >>> treespec_transform(
    ...     treespec,
    ...     lambda spec: (
    ...         treespec_list(spec.children())
    ...         if spec.type is tuple
    ...         else spec
    ...     ),
    ... )
    PyTreeSpec({'a': [*, [*, *]], 'b': *, 'c': [*, None]})
    >>> treespec_transform(treespec, None, lambda spec: tree_structure((1, [2])))
    PyTreeSpec({'a': ((*, [*]), [(*, [*]), (*, [*])]), 'b': (*, [*]), 'c': ((*, [*]), None)})

    Args:
        treespec (PyTreeSpec): The treespec to transform.
        f_node (callable or None, optional): A function mapping the treespec of each non-leaf
            node to a replacement treespec. If :data:`None`, nodes are kept unchanged.
        f_leaf (callable or None, optional): A function mapping the treespec of each leaf to a
            replacement treespec. If :data:`None`, leaves are kept unchanged.

    Returns:
        The transformed treespec.
    """
    transformed = treespec.transform(f_node, f_leaf)
    return transformed
|
|
|
|
|
|
|
|
def treespec_is_leaf(treespec: PyTreeSpec, /, *, strict: bool = True) -> bool:
    """Return whether the treespec is a leaf that has no children.

    See also :func:`treespec_is_strict_leaf` and :meth:`PyTreeSpec.is_leaf`.

    This function is equivalent to ``treespec.is_leaf(strict=strict)``. If ``strict=True``, it will
    return :data:`True` if and only if the treespec represents a strict leaf. If ``strict=False``,
    it will return :data:`True` if the treespec represents a strict leaf or :data:`None` or an empty
    container (e.g., an empty tuple).

    >>> treespec_is_leaf(tree_structure(1))
    True
    >>> treespec_is_leaf(tree_structure((1, 2)))
    False
    >>> treespec_is_leaf(tree_structure(None))
    False
    >>> treespec_is_leaf(tree_structure(None), strict=False)
    True
    >>> treespec_is_leaf(tree_structure(None, none_is_leaf=False))
    False
    >>> treespec_is_leaf(tree_structure(None, none_is_leaf=True))
    True
    >>> treespec_is_leaf(tree_structure(()))
    False
    >>> treespec_is_leaf(tree_structure(()), strict=False)
    True
    >>> treespec_is_leaf(tree_structure([]))
    False
    >>> treespec_is_leaf(tree_structure([]), strict=False)
    True

    Args:
        treespec (PyTreeSpec): A treespec.
        strict (bool, optional): If :data:`True`, only a strict leaf counts as a leaf. If
            :data:`False`, :data:`None` and empty containers (e.g., an empty tuple) also count
            as leaves. (default: :data:`True`)

    Returns:
        :data:`True` if the treespec represents a leaf that has no children, otherwise, :data:`False`.
    """
    if strict:
        # A strict leaf is a single node that is itself a leaf.
        return treespec.num_nodes == 1 and treespec.num_leaves == 1
    # A single node with no leaves is ``None`` or an empty container.
    return treespec.num_nodes == 1
|
|
|
|
|
|
|
|
def treespec_is_strict_leaf(treespec: PyTreeSpec, /) -> bool:
    """Return whether the treespec is a strict leaf.

    See also :func:`treespec_is_leaf` and :meth:`PyTreeSpec.is_leaf`.

    This function respects the ``none_is_leaf`` setting in the treespec and is equivalent to
    ``treespec.is_leaf(strict=True)``: it returns :data:`True` only when the treespec represents
    a strict leaf, never for :data:`None` nodes (unless ``none_is_leaf=True``) or empty containers.

    >>> treespec_is_strict_leaf(tree_structure(1))
    True
    >>> treespec_is_strict_leaf(tree_structure((1, 2)))
    False
    >>> treespec_is_strict_leaf(tree_structure(None))
    False
    >>> treespec_is_strict_leaf(tree_structure(None, none_is_leaf=False))
    False
    >>> treespec_is_strict_leaf(tree_structure(None, none_is_leaf=True))
    True
    >>> treespec_is_strict_leaf(tree_structure(()))
    False
    >>> treespec_is_strict_leaf(tree_structure([]))
    False

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        :data:`True` if the treespec represents a strict leaf, otherwise, :data:`False`.
    """
    # More than one node means there is an internal node; bail out before checking leaves.
    if treespec.num_nodes != 1:
        return False
    # A single node is a strict leaf only if it is itself counted as a leaf.
    return treespec.num_leaves == 1
|
|
|
|
|
|
|
|
def treespec_is_one_level(treespec: PyTreeSpec, /) -> bool:
    """Return whether the treespec is a one-level tree structure.

    A treespec is one level when it is a single node whose direct children are all leaves.

    See also :func:`treespec_is_leaf`, :func:`treespec_one_level`, and :meth:`PyTreeSpec.is_one_level`.

    >>> treespec_is_one_level(tree_structure(1))
    False
    >>> treespec_is_one_level(tree_structure((1, 2)))
    True
    >>> treespec_is_one_level(tree_structure({'a': 1, 'b': 2, 'c': 3}))
    True
    >>> treespec_is_one_level(tree_structure({'a': 1, 'b': (2, 3), 'c': 4}))
    False
    >>> treespec_is_one_level(tree_structure(None))
    True

    Args:
        treespec (PyTreeSpec): A treespec.

    Returns:
        :data:`True` if the treespec is a one-level tree structure, otherwise, :data:`False`.
    """
    arity = treespec.num_children
    # Exactly one root plus one node per child, and every child node is a leaf.
    has_only_direct_children = treespec.num_nodes == arity + 1
    return has_only_direct_children and treespec.num_leaves == arity
|
|
|
|
|
|
|
|
def treespec_is_prefix(
    treespec: PyTreeSpec,
    other_treespec: PyTreeSpec,
    /,
    *,
    strict: bool = False,
) -> bool:
    """Return whether ``treespec`` is a prefix of ``other_treespec``.

    This is the functional form of :meth:`PyTreeSpec.is_prefix`.

    See also :func:`treespec_is_suffix` and :meth:`PyTreeSpec.is_prefix`.

    Args:
        treespec (PyTreeSpec): The candidate prefix treespec.
        other_treespec (PyTreeSpec): The treespec to compare against.
        strict (bool, optional): Passed through unchanged to :meth:`PyTreeSpec.is_prefix`; see
            that method for its exact meaning. (default: :data:`False`)

    Returns:
        :data:`True` if ``treespec`` is a prefix of ``other_treespec``, otherwise, :data:`False`.
    """
    return treespec.is_prefix(other_treespec, strict=strict)
|
|
|
|
|
|
|
|
def treespec_is_suffix(
    treespec: PyTreeSpec,
    other_treespec: PyTreeSpec,
    /,
    *,
    strict: bool = False,
) -> bool:
    """Return whether ``treespec`` is a suffix of ``other_treespec``.

    This is the functional form of :meth:`PyTreeSpec.is_suffix`.

    See also :func:`treespec_is_prefix` and :meth:`PyTreeSpec.is_suffix`.

    Args:
        treespec (PyTreeSpec): The candidate suffix treespec.
        other_treespec (PyTreeSpec): The treespec to compare against.
        strict (bool, optional): Passed through unchanged to :meth:`PyTreeSpec.is_suffix`; see
            that method for its exact meaning. (default: :data:`False`)

    Returns:
        :data:`True` if ``treespec`` is a suffix of ``other_treespec``, otherwise, :data:`False`.
    """
    return treespec.is_suffix(other_treespec, strict=strict)
|
|
|
|
|
|
|
|
def treespec_leaf(
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a treespec representing a leaf node.

    See also :func:`tree_structure`, :func:`treespec_none`, and :func:`treespec_tuple`.

    >>> treespec_leaf()
    PyTreeSpec(*)
    >>> treespec_leaf(none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)
    >>> treespec_leaf(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
    False
    >>> treespec_leaf() == tree_structure(1)
    True
    >>> treespec_leaf(none_is_leaf=True) == tree_structure(1, none_is_leaf=True)
    True
    >>> treespec_leaf(none_is_leaf=True) == tree_structure(None, none_is_leaf=True)
    True
    >>> treespec_leaf(none_is_leaf=True) == tree_structure(None, none_is_leaf=False)
    False
    >>> treespec_leaf(none_is_leaf=True) == treespec_none(none_is_leaf=True)
    True
    >>> treespec_leaf(none_is_leaf=True) == treespec_none(none_is_leaf=False)
    False
    >>> treespec_leaf(none_is_leaf=False) == treespec_none(none_is_leaf=True)
    False
    >>> treespec_leaf(none_is_leaf=False) == treespec_none(none_is_leaf=False)
    False

    Args:
        none_is_leaf (bool, optional): Whether the resulting treespec treats :data:`None` as a
            leaf. If :data:`False`, :data:`None` is a non-leaf node with arity 0 and is kept in
            the treespec rather than in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a leaf node.
    """
    return _C.make_leaf(none_is_leaf, namespace)
|
|
|
|
|
|
|
|
def treespec_none(
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a treespec representing a :data:`None` node.

    With ``none_is_leaf=True`` the result equals :func:`treespec_leaf` with the same setting.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_tuple`.

    >>> treespec_none()
    PyTreeSpec(None)
    >>> treespec_none(none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)
    >>> treespec_none(none_is_leaf=False) == treespec_none(none_is_leaf=True)
    False
    >>> treespec_none() == tree_structure(None)
    True
    >>> treespec_none() == tree_structure(1)
    False
    >>> treespec_none(none_is_leaf=True) == tree_structure(1, none_is_leaf=True)
    True
    >>> treespec_none(none_is_leaf=True) == tree_structure(None, none_is_leaf=True)
    True
    >>> treespec_none(none_is_leaf=True) == tree_structure(None, none_is_leaf=False)
    False
    >>> treespec_none(none_is_leaf=True) == treespec_leaf(none_is_leaf=True)
    True
    >>> treespec_none(none_is_leaf=False) == treespec_leaf(none_is_leaf=True)
    False
    >>> treespec_none(none_is_leaf=True) == treespec_leaf(none_is_leaf=False)
    False
    >>> treespec_none(none_is_leaf=False) == treespec_leaf(none_is_leaf=False)
    False

    Args:
        none_is_leaf (bool, optional): Whether the resulting treespec treats :data:`None` as a
            leaf. If :data:`False`, :data:`None` is a non-leaf node with arity 0 and is kept in
            the treespec rather than in the leaves list. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a :data:`None` node.
    """
    return _C.make_none(none_is_leaf, namespace)
|
|
|
|
|
|
|
|
def treespec_tuple(
    iterable: Iterable[PyTreeSpec] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a tuple treespec from an iterable of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_tuple([treespec_leaf(), treespec_leaf()])
    PyTreeSpec((*, *))
    >>> treespec_tuple([treespec_leaf(), treespec_leaf(), treespec_none()])
    PyTreeSpec((*, *, None))
    >>> treespec_tuple()
    PyTreeSpec(())
    >>> treespec_tuple([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
    PyTreeSpec((*, (*, *)))
    >>> treespec_tuple([treespec_leaf(), tree_structure({'a': 1, 'b': 2})])
    PyTreeSpec((*, {'a': *, 'b': *}))
    >>> treespec_tuple([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        iterable (iterable of PyTreeSpec, optional): An iterable of child treespecs. They must have
            the same ``none_is_leaf`` and ``namespace`` values. (default: no children)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a tuple node with the given children.

    Raises:
        ValueError: If a child treespec was not built with the given ``none_is_leaf`` value.
    """
    # The node type is taken from the container that is passed in, so wrap the children in a
    # tuple. This also materializes one-shot iterators such as generators.
    return _C.make_from_collection(
        tuple(iterable),
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_list(
    iterable: Iterable[PyTreeSpec] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a list treespec from an iterable of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_list([treespec_leaf(), treespec_leaf()])
    PyTreeSpec([*, *])
    >>> treespec_list([treespec_leaf(), treespec_leaf(), treespec_none()])
    PyTreeSpec([*, *, None])
    >>> treespec_list()
    PyTreeSpec([])
    >>> treespec_list([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
    PyTreeSpec([*, (*, *)])
    >>> treespec_list([treespec_leaf(), tree_structure({'a': 1, 'b': 2})])
    PyTreeSpec([*, {'a': *, 'b': *}])
    >>> treespec_list([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        iterable (iterable of PyTreeSpec, optional): An iterable of child treespecs. They must have
            the same ``none_is_leaf`` and ``namespace`` values. (default: no children)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a list node with the given children.

    Raises:
        ValueError: If a child treespec was not built with the given ``none_is_leaf`` value.
    """
    # The node type is taken from the container that is passed in, so wrap the children in a
    # list. This also materializes one-shot iterators such as generators.
    return _C.make_from_collection(
        list(iterable),
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_dict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Make a dict treespec from a dict of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_dict({'a': treespec_leaf(), 'b': treespec_leaf()})
    PyTreeSpec({'a': *, 'b': *})
    >>> treespec_dict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
    PyTreeSpec({'a': None, 'b': *, 'c': *})
    >>> treespec_dict()
    PyTreeSpec({})
    >>> treespec_dict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
    PyTreeSpec({'a': *, 'b': (*, *)})
    >>> treespec_dict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
    PyTreeSpec({'a': *, 'b': [*, *]})
    >>> treespec_dict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        mapping (mapping of PyTreeSpec, optional): A mapping of child treespecs, or an iterable of
            ``(key, treespec)`` pairs. The treespecs must have the same ``none_is_leaf`` and
            ``namespace`` values. (default: no children)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)
        **kwargs (PyTreeSpec, optional): Additional child treespecs to add to the mapping. As with
            the :class:`dict` constructor, they override entries of ``mapping`` with the same key.

    Returns:
        A treespec representing a dict node with the given children.

    Raises:
        ValueError: If a child treespec was not built with the given ``none_is_leaf`` value.
    """
    # ``dict(mapping, **kwargs)`` accepts a mapping or an iterable of pairs, and keyword
    # arguments are applied last.
    return _C.make_from_collection(
        dict(mapping, **kwargs),
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_namedtuple(
    namedtuple: NamedTuple[PyTreeSpec],
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a namedtuple treespec from a namedtuple of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> from collections import namedtuple
    >>> Point = namedtuple('Point', ['x', 'y'])
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=treespec_leaf()))
    PyTreeSpec(Point(x=*, y=*))
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=treespec_tuple([treespec_leaf(), treespec_leaf()])))
    PyTreeSpec(Point(x=*, y=(*, *)))
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=tree_structure([1, 2])))
    PyTreeSpec(Point(x=*, y=[*, *]))
    >>> treespec_namedtuple(Point(x=treespec_leaf(), y=tree_structure([1, 2], none_is_leaf=True)))
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        namedtuple (namedtuple of PyTreeSpec): A namedtuple of child treespecs. They must have the
            same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a namedtuple node with the given children.

    Raises:
        ValueError: If ``namedtuple`` is not a namedtuple instance, or if a child treespec was not
            built with the given ``none_is_leaf`` value.
    """
    # Reject plain tuples and other sequences up front. Otherwise they would be accepted as a
    # different node type.
    if not is_namedtuple_instance(namedtuple):
        raise ValueError(f'Expected a namedtuple of PyTreeSpec(s), got {namedtuple!r}.')
    return _C.make_from_collection(
        namedtuple,
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_ordereddict(
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Make an OrderedDict treespec from an OrderedDict of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    Unlike :func:`treespec_dict`, the insertion order of the keys is kept (see the second example).

    >>> treespec_ordereddict({'a': treespec_leaf(), 'b': treespec_leaf()})
    PyTreeSpec(OrderedDict({'a': *, 'b': *}))
    >>> treespec_ordereddict([('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
    PyTreeSpec(OrderedDict({'b': *, 'c': *, 'a': None}))
    >>> treespec_ordereddict()
    PyTreeSpec(OrderedDict())
    >>> treespec_ordereddict(a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
    PyTreeSpec(OrderedDict({'a': *, 'b': (*, *)}))
    >>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2])})
    PyTreeSpec(OrderedDict({'a': *, 'b': [*, *]}))
    >>> treespec_ordereddict({'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        mapping (mapping of PyTreeSpec, optional): A mapping of child treespecs, or an iterable of
            ``(key, treespec)`` pairs. The treespecs must have the same ``none_is_leaf`` and
            ``namespace`` values. (default: no children)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)
        **kwargs (PyTreeSpec, optional): Additional child treespecs to add to the mapping. They
            are inserted after ``mapping`` and override entries of ``mapping`` with the same key.

    Returns:
        A treespec representing an OrderedDict node with the given children.

    Raises:
        ValueError: If a child treespec was not built with the given ``none_is_leaf`` value.
    """
    # Build the OrderedDict first so that the key order of ``mapping`` and ``kwargs`` is kept.
    return _C.make_from_collection(
        OrderedDict(mapping, **kwargs),
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_defaultdict(
    default_factory: Callable[[], Any] | None = None,
    mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
    **kwargs: PyTreeSpec,
) -> PyTreeSpec:
    """Make a defaultdict treespec from a defaultdict of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': treespec_leaf()})
    PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': *}))
    >>> treespec_defaultdict(int, [('b', treespec_leaf()), ('c', treespec_leaf()), ('a', treespec_none())])
    PyTreeSpec(defaultdict(<class 'int'>, {'a': None, 'b': *, 'c': *}))
    >>> treespec_defaultdict()
    PyTreeSpec(defaultdict(None, {}))
    >>> treespec_defaultdict(int)
    PyTreeSpec(defaultdict(<class 'int'>, {}))
    >>> treespec_defaultdict(int, a=treespec_leaf(), b=treespec_tuple([treespec_leaf(), treespec_leaf()]))
    PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': (*, *)}))
    >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': tree_structure([1, 2])})
    PyTreeSpec(defaultdict(<class 'int'>, {'a': *, 'b': [*, *]}))
    >>> treespec_defaultdict(int, {'a': treespec_leaf(), 'b': tree_structure([1, 2], none_is_leaf=True)})
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        default_factory (callable or None, optional): A factory function that will be used to create
            a missing value. (default: :data:`None`)
        mapping (mapping of PyTreeSpec, optional): A mapping of child treespecs, or an iterable of
            ``(key, treespec)`` pairs. The treespecs must have the same ``none_is_leaf`` and
            ``namespace`` values. (default: no children)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)
        **kwargs (PyTreeSpec, optional): Additional child treespecs to add to the mapping. They
            override entries of ``mapping`` with the same key.

    Returns:
        A treespec representing a defaultdict node with the given children.

    Raises:
        ValueError: If a child treespec was not built with the given ``none_is_leaf`` value.
    """
    # ``default_factory`` is stored on the container itself and shows up in the treespec repr.
    return _C.make_from_collection(
        defaultdict(default_factory, mapping, **kwargs),
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_deque(
    iterable: Iterable[PyTreeSpec] = (),
    /,
    maxlen: int | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a deque treespec from a deque of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    >>> treespec_deque([treespec_leaf(), treespec_leaf()])
    PyTreeSpec(deque([*, *]))
    >>> treespec_deque([treespec_leaf(), treespec_leaf(), treespec_none()], maxlen=5)
    PyTreeSpec(deque([*, *, None], maxlen=5))
    >>> treespec_deque()
    PyTreeSpec(deque([]))
    >>> treespec_deque([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
    PyTreeSpec(deque([*, (*, *)]))
    >>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5)
    PyTreeSpec(deque([*, {'a': *, 'b': *}], maxlen=5))
    >>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)], maxlen=5)
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        iterable (iterable of PyTreeSpec, optional): An iterable of child treespecs. They must have
            the same ``none_is_leaf`` and ``namespace`` values. (default: no children)
        maxlen (int or None, optional): The maximum size of a deque or :data:`None` if unbounded.
            As with :class:`collections.deque`, if ``iterable`` yields more than ``maxlen`` items,
            only the last ``maxlen`` child treespecs are kept. (default: :data:`None`)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a deque node with the given children.

    Raises:
        ValueError: If a child treespec was not built with the given ``none_is_leaf`` value.
    """
    # ``maxlen`` is part of the container, so it is recorded in the treespec (see examples).
    return _C.make_from_collection(
        deque(iterable, maxlen=maxlen),
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_structseq(
    structseq: StructSequence[PyTreeSpec],
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a PyStructSequence treespec from a PyStructSequence of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    Args:
        structseq (PyStructSequence of PyTreeSpec): A PyStructSequence of child treespecs. They must
            have the same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing a PyStructSequence node with the given children.

    Raises:
        ValueError: If ``structseq`` is not a PyStructSequence instance.
    """
    # Reject other tuple-like inputs up front. Otherwise they would be accepted as a different
    # node type.
    if not is_structseq_instance(structseq):
        raise ValueError(f'Expected a PyStructSequence of PyTreeSpec(s), got {structseq!r}.')
    return _C.make_from_collection(
        structseq,
        none_is_leaf,
        namespace,
    )
|
|
|
|
|
|
|
|
def treespec_from_collection(
    collection: Collection[PyTreeSpec],
    /,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> PyTreeSpec:
    """Make a treespec from a collection of child treespecs.

    See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

    The node type is taken from ``collection`` itself. As the examples show, :data:`None` gives a
    :data:`None` node (or a leaf when ``none_is_leaf=True``), and a non-collection object gives a
    leaf treespec. Every direct child of a collection must already be a treespec.

    >>> treespec_from_collection(None)
    PyTreeSpec(None)
    >>> treespec_from_collection(None, none_is_leaf=True)
    PyTreeSpec(*, NoneIsLeaf)
    >>> treespec_from_collection(object())
    PyTreeSpec(*)
    >>> treespec_from_collection([treespec_leaf(), treespec_none()])
    PyTreeSpec([*, None])
    >>> treespec_from_collection({'a': treespec_leaf(), 'b': treespec_none()})
    PyTreeSpec({'a': *, 'b': None})
    >>> treespec_from_collection(deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5))
    PyTreeSpec(deque([*, {'a': *, 'b': *}], maxlen=5))
    >>> treespec_from_collection({'a': treespec_leaf(), 'b': (treespec_leaf(), treespec_none())})
    Traceback (most recent call last):
        ...
    ValueError: Expected a(n) dict of PyTreeSpec(s), got {'a': PyTreeSpec(*), 'b': (PyTreeSpec(*), PyTreeSpec(None))}.
    >>> treespec_from_collection([treespec_leaf(), tree_structure({'a': 1, 'b': 2}, none_is_leaf=True)])
    Traceback (most recent call last):
        ...
    ValueError: Expected treespec(s) with `none_is_leaf=False`.

    Args:
        collection (collection of PyTreeSpec): A collection of child treespecs. They must have the
            same ``none_is_leaf`` and ``namespace`` values.
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
            :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
            treespec rather than in the leaves list and :data:`None` will remain in the result
            pytree. (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A treespec representing the same structure of the collection with the given children.

    Raises:
        ValueError: If a direct child of ``collection`` is not a treespec, or if a child treespec
            was not built with the given ``none_is_leaf`` value.
    """
    return _C.make_from_collection(collection, none_is_leaf, namespace)
|
|
|
|
|
|
|
|
# The built-in dictionary container types. When diagnosing prefix mismatches (e.g., in
# :func:`prefix_errors`), any two of these are treated as compatible node types and are compared by
# their key sets instead of by exact type.
STANDARD_DICT_TYPES: frozenset[type] = frozenset((dict, OrderedDict, defaultdict))
|
|
|
|
|
|
|
|
def prefix_errors(
    prefix_tree: PyTree[T],
    full_tree: PyTree[S],
    /,
    is_leaf: Callable[[T], bool] | None = None,
    *,
    none_is_leaf: bool = False,
    namespace: str = '',
) -> list[Callable[[str], ValueError]]:
    """Return a list of errors that would be raised by :func:`broadcast_prefix`.

    The two trees are walked in lockstep. Wherever ``prefix_tree`` has a leaf, the corresponding
    subtree of ``full_tree`` is accepted as-is. Otherwise the two nodes are compared by type, by
    key set (for the dictionary types in :data:`STANDARD_DICT_TYPES`), by number of children, and
    by node metadata, in that order. At most one error is reported per subtree (the first mismatch
    found there), but sibling subtrees are still examined.

    Args:
        prefix_tree (pytree): A pytree that is expected to be a prefix of ``full_tree``.
        full_tree (pytree): The pytree to check ``prefix_tree`` against.
        is_leaf (callable, optional): An extra leaf predicate passed to :func:`tree_is_leaf` when
            deciding whether a subtree of ``prefix_tree`` is a leaf. (default: :data:`None`)
        none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf.
            (default: :data:`False`)
        namespace (str, optional): The registry namespace used for custom pytree node types.
            (default: :const:`''`, i.e., the global namespace)

    Returns:
        A list of error factories, in traversal order. Each factory takes the display name of the
        prefix tree (used in the message) and returns a :exc:`ValueError` describing one mismatch.
        An empty list means no mismatch was found.
    """

    def sorted_if_comparable(keys: list[Any]) -> list[Any]:
        # Dictionary keys need not be mutually comparable (e.g., ``1`` and ``'a'``). Sorting such
        # keys raises ``TypeError``, which would mask the structure error being reported, so fall
        # back to the order in which the keys appear in the tree.
        try:
            return sorted(keys)
        except TypeError:
            return keys

    def helper(
        accessor: PyTreeAccessor,
        prefix_subtree: PyTree[T],
        full_subtree: PyTree[S],
    ) -> Iterable[Callable[[str], ValueError]]:
        # A leaf is a valid prefix of any tree.
        if tree_is_leaf(
            prefix_subtree,
            is_leaf=is_leaf,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        ):
            return

        # The subtrees may disagree because their roots are of different types. The standard
        # dictionary types are interchangeable here; they are compared by their keys below.
        prefix_tree_type = type(prefix_subtree)
        full_tree_type = type(full_subtree)
        both_standard_dict = (
            prefix_tree_type in STANDARD_DICT_TYPES and full_tree_type in STANDARD_DICT_TYPES
        )
        both_deque = prefix_tree_type is deque and full_tree_type is deque
        if prefix_tree_type is not full_tree_type and not both_standard_dict:
            yield lambda name: ValueError(
                f'pytree structure error: different types at key path\n'
                f' {accessor.codify(name) if accessor else name + " tree root"}\n'
                f'At that key path, the prefix pytree {name} has a subtree of type\n'
                f' {type(prefix_subtree)}\n'
                f'but at the same key path the full pytree has a subtree of different type\n'
                f' {type(full_subtree)}.',
            )
            return  # don't look for more errors in this subtree

        # Both roots now have the same (or compatible dictionary) type, and the prefix root is not
        # a leaf, so each of them can be flattened by one level.
        prefix_tree_one_level_output = (
            prefix_tree_children,
            prefix_tree_metadata,
            prefix_tree_entries,
            _,
        ) = tree_flatten_one_level(
            prefix_subtree,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )
        full_tree_one_level_output = (
            full_tree_children,
            full_tree_metadata,
            full_tree_entries,
            _,
        ) = tree_flatten_one_level(
            full_subtree,
            none_is_leaf=none_is_leaf,
            namespace=namespace,
        )

        # Dictionaries may disagree on their key sets. A ``defaultdict`` stores its keys as the
        # second item of its metadata; the other dictionary types use the metadata directly.
        if both_standard_dict:
            prefix_tree_keys: list[Any] = (
                prefix_tree_metadata
                if prefix_tree_type is not defaultdict
                else prefix_tree_metadata[1]
            )
            full_tree_keys: list[Any] = (
                full_tree_metadata
                if full_tree_type is not defaultdict
                else full_tree_metadata[1]
            )
            prefix_tree_keys_set = set(prefix_tree_keys)
            full_tree_keys_set = set(full_tree_keys)
            if prefix_tree_keys_set != full_tree_keys_set:
                missing_keys = sorted_if_comparable(
                    [key for key in prefix_tree_keys if key not in full_tree_keys_set],
                )
                extra_keys = sorted_if_comparable(
                    [key for key in full_tree_keys if key not in prefix_tree_keys_set],
                )
                key_difference = ''
                if missing_keys:
                    key_difference += f'\nmissing key(s):\n {missing_keys}'
                if extra_keys:
                    key_difference += f'\nextra key(s):\n {extra_keys}'
                yield lambda name: ValueError(
                    f'pytree structure error: different pytree keys at key path\n'
                    f' {accessor.codify(name) if accessor else name + " tree root"}\n'
                    f'At that key path, the prefix pytree {name} has a subtree of type\n'
                    f' {prefix_tree_type}\n'
                    f'with {len(prefix_tree_keys)} key(s)\n'
                    f' {prefix_tree_keys}\n'
                    f'but at the same key path the full pytree has a subtree of type\n'
                    f' {full_tree_type}\n'
                    f'but with {len(full_tree_keys)} key(s)\n'
                    f' {full_tree_keys}{key_difference}',
                )
                return  # don't look for more errors in this subtree

            # Same key set: reorder the full tree's children to follow the prefix tree's key order
            # so that the children are paired by key below.
            full_tree_children = [full_subtree[k] for k in prefix_tree_keys]

        # Or they may disagree if their roots have different numbers of children.
        if len(prefix_tree_children) != len(full_tree_children):
            yield lambda name: ValueError(
                f'pytree structure error: different numbers of pytree children at key path\n'
                f' {accessor.codify(name) if accessor else name + " tree root"}\n'
                f'At that key path, the prefix pytree {name} has a subtree of type\n'
                f' {prefix_tree_type}\n'
                f'with {len(prefix_tree_children)} children, '
                f'but at the same key path the full pytree has a subtree of the same '
                f'type but with {len(full_tree_children)} children.',
            )
            return  # don't look for more errors in this subtree

        # Or they may disagree if their roots have different metadata. Metadata differences
        # between two deques (presumably ``maxlen``) are ignored, and dictionary keys were already
        # compared above.
        if (
            prefix_tree_metadata != full_tree_metadata
            and not both_deque
            and not both_standard_dict
        ):
            prefix_tree_metadata_repr = repr(prefix_tree_metadata)
            full_tree_metadata_repr = repr(full_tree_metadata)
            metadata_diff = textwrap.indent(
                '\n'.join(
                    difflib.ndiff(
                        prefix_tree_metadata_repr.splitlines(),
                        full_tree_metadata_repr.splitlines(),
                    ),
                ),
                prefix=' ',
            )
            yield lambda name: ValueError(
                f'pytree structure error: different pytree metadata at key path\n'
                f' {accessor.codify(name) if accessor else name + " tree root"}\n'
                f'At that key path, the prefix pytree {name} has a subtree of type\n'
                f' {prefix_tree_type}\n'
                f'with metadata\n'
                f' {prefix_tree_metadata_repr}\n'
                f'but at the same key path the full pytree has a subtree of the same '
                f'type but with metadata\n'
                f' {full_tree_metadata_repr}\n'
                f'so the diff in the metadata at these pytree nodes is\n'
                f'{metadata_diff}',
            )
            return  # don't look for more errors in this subtree

        # The roots agree, so any remaining mismatch lies deeper: recurse into the paired children.
        entries = [
            prefix_tree_one_level_output.path_entry_type(
                e,
                prefix_tree_type,
                prefix_tree_one_level_output.kind,
            )
            for e in prefix_tree_entries
        ]
        entries_ = [
            full_tree_one_level_output.path_entry_type(
                e,
                full_tree_type,
                full_tree_one_level_output.kind,
            )
            for e in full_tree_entries
        ]
        # Internal invariant: nodes that passed every check above have identical path entries.
        # Dictionaries are exempt because their children were re-paired by key.
        assert (
            both_standard_dict
            or entries == entries_
        ), f'equal pytree nodes gave different keys: {entries} and {entries_}'
        for e, t1, t2 in zip(entries, prefix_tree_children, full_tree_children):
            yield from helper(accessor + e, t1, t2)

    return list(helper(PyTreeAccessor(), prefix_tree, full_tree))
|
|
|