|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Utilities for working with ``PyTree``\s. |
|
|
|
|
|
The :mod:`optree.pytree` namespace contains aliases of ``optree.tree_*`` utilities. |
|
|
|
|
|
>>> import optree.pytree as pytree |
|
|
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} |
|
|
>>> leaves, treespec = pytree.flatten(tree) |
|
|
>>> leaves, treespec # doctest: +IGNORE_WHITESPACE |
|
|
( |
|
|
[1, 2, 3, 4, 5], |
|
|
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *}) |
|
|
) |
|
|
>>> tree == pytree.unflatten(treespec, leaves) |
|
|
True |
|
|
|
|
|
.. versionadded:: 0.14.1 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import functools as _functools |
|
|
import inspect as _inspect |
|
|
import sys as _sys |
|
|
from builtins import all as _all |
|
|
from types import ModuleType as _ModuleType |
|
|
from typing import TYPE_CHECKING as _TYPE_CHECKING |
|
|
|
|
|
import optree.dataclasses as dataclasses |
|
|
import optree.functools as functools |
|
|
from optree.accessors import PyTreeEntry |
|
|
from optree.ops import tree_accessors as accessors |
|
|
from optree.ops import tree_all as all |
|
|
from optree.ops import tree_any as any |
|
|
from optree.ops import tree_broadcast_common as broadcast_common |
|
|
from optree.ops import tree_broadcast_map as broadcast_map |
|
|
from optree.ops import tree_broadcast_map_with_accessor as broadcast_map_with_accessor |
|
|
from optree.ops import tree_broadcast_map_with_path as broadcast_map_with_path |
|
|
from optree.ops import tree_broadcast_prefix as broadcast_prefix |
|
|
from optree.ops import tree_flatten as flatten |
|
|
from optree.ops import tree_flatten_one_level as flatten_one_level |
|
|
from optree.ops import tree_flatten_with_accessor as flatten_with_accessor |
|
|
from optree.ops import tree_flatten_with_path as flatten_with_path |
|
|
from optree.ops import tree_is_leaf as is_leaf |
|
|
from optree.ops import tree_iter as iter |
|
|
from optree.ops import tree_leaves as leaves |
|
|
from optree.ops import tree_map as map |
|
|
from optree.ops import tree_map_ as map_ |
|
|
from optree.ops import tree_map_with_accessor as map_with_accessor |
|
|
from optree.ops import tree_map_with_accessor_ as map_with_accessor_ |
|
|
from optree.ops import tree_map_with_path as map_with_path |
|
|
from optree.ops import tree_map_with_path_ as map_with_path_ |
|
|
from optree.ops import tree_max as max |
|
|
from optree.ops import tree_min as min |
|
|
from optree.ops import tree_partition as partition |
|
|
from optree.ops import tree_paths as paths |
|
|
from optree.ops import tree_reduce as reduce |
|
|
from optree.ops import tree_replace_nones as replace_nones |
|
|
from optree.ops import tree_structure as structure |
|
|
from optree.ops import tree_sum as sum |
|
|
from optree.ops import tree_transpose as transpose |
|
|
from optree.ops import tree_transpose_map as transpose_map |
|
|
from optree.ops import tree_transpose_map_with_accessor as transpose_map_with_accessor |
|
|
from optree.ops import tree_transpose_map_with_path as transpose_map_with_path |
|
|
from optree.ops import tree_unflatten as unflatten |
|
|
from optree.registry import dict_insertion_ordered |
|
|
from optree.registry import register_pytree_node as register_node |
|
|
from optree.registry import register_pytree_node_class as register_node_class |
|
|
from optree.registry import unregister_pytree_node as unregister_node |
|
|
from optree.typing import PyTreeKind, PyTreeSpec |
|
|
from optree.version import __version__ as __version__ |
|
|
|
|
|
|
|
|
# Public names of this module. ``ReexportedModule`` uses the ``__all__`` of the
# module it re-exports (minus 'reexport') as its default export list.
__all__ = [
    # Factory for namespaced copies of this module
    'reexport',
    # Types
    'PyTreeSpec',
    'PyTreeKind',
    'PyTreeEntry',
    # Aliases of the ``optree.ops.tree_*`` functions
    'flatten',
    'flatten_with_path',
    'flatten_with_accessor',
    'unflatten',
    'iter',
    'leaves',
    'structure',
    'paths',
    'accessors',
    'is_leaf',
    'map',
    'map_',
    'map_with_path',
    'map_with_path_',
    'map_with_accessor',
    'map_with_accessor_',
    'replace_nones',
    'partition',
    'transpose',
    'transpose_map',
    'transpose_map_with_path',
    'transpose_map_with_accessor',
    'broadcast_prefix',
    'broadcast_common',
    'broadcast_map',
    'broadcast_map_with_path',
    'broadcast_map_with_accessor',
    'reduce',
    'sum',
    'max',
    'min',
    'all',
    'any',
    'flatten_one_level',
    # Aliases imported from ``optree.registry``
    'register_node',
    'register_node_class',
    'unregister_node',
    'dict_insertion_ordered',
]
|
|
|
|
|
|
|
|
# Names that are only referenced from annotations. They are bound solely when
# ``typing.TYPE_CHECKING`` is true, i.e. for static type checkers.
if _TYPE_CHECKING:
    from collections.abc import Callable, Iterable
    from typing import Any, TypeVar
    from typing_extensions import ParamSpec

    # Type variables used in the signature of ``ReexportedModule.__reexport__``.
    _P = ParamSpec('_P')
    _T = TypeVar('_T')
|
|
|
|
|
|
|
|
class ReexportedModule(_ModuleType):
    """A module that re-exports APIs from another module.

    Exported attributes are looked up on the ``original`` module the first time they are
    accessed. Plain functions are wrapped so that a ``namespace`` keyword argument, if the
    function has one, defaults to the namespace given at construction time. The resolved
    attribute is then stored on this module so later accesses bypass :meth:`__getattr__`.

    Args:
        name (str): The name of the new module.
        namespace (str): The namespace used as the default for re-exported functions.
        original (module): The module whose attributes are re-exported.
        doc (str, optional): The module docstring. If empty or :data:`None`, a docstring
            describing the re-export is generated.
        __all__ (iterable of str, optional): The names to export. If :data:`None`, defaults to
            ``original.__all__`` without ``'reexport'``.
        __dir__ (iterable of str, optional): Candidate names for :func:`dir`. If :data:`None`,
            defaults to the names in ``original.__dir__()`` that do not start with an
            underscore, without ``'reexport'``. Only candidates that are also exported are kept.
        extra_members (dict, optional): Extra attributes that are set on the module directly.
            Their names are listed by :func:`dir` as well.
    """

    __doc__: str

    def __init__(
        self,
        name: str,
        *,
        namespace: str,
        original: _ModuleType,
        doc: str | None = None,
        __all__: Iterable[str] | None = None,
        __dir__: Iterable[str] | None = None,
        extra_members: dict[str, Any] | None = None,
    ) -> None:
        doc = doc or (
            f'Re-exports :mod:`{original.__name__}` as :mod:`{name}` '
            f'with namespace :const:`{namespace!r}`.'
        )
        super().__init__(name, doc)

        if __all__ is None:
            __all__ = {n for n in original.__all__ if n != 'reexport'}
        __all__ = set(__all__)
        if __dir__ is None:
            __dir__ = {n for n in original.__dir__() if not n.startswith('_') and n != 'reexport'}
        # Only list names that ``__getattr__`` is able to resolve.
        __dir__ = set(__dir__).intersection(__all__)

        if extra_members:
            # Extra members are real attributes, so they never go through ``__getattr__``.
            for key, value in extra_members.items():
                setattr(self, key, value)
            __dir__.update(extra_members)

        self.__namespace = namespace
        self.__original = original
        self.__all_set = __all__  # membership test used by ``__getattr__``
        self.__all = sorted(__all__)
        self.__dir = sorted(__dir__)

    @property
    def __all__(self) -> list[str]:
        """Return the sorted list of names exported by this module.

        A new list is returned on every access so that callers cannot mutate the internal
        state, which mirrors the behavior of :meth:`__dir__`.
        """
        return self.__all.copy()

    def __dir__(self) -> list[str]:
        """Return the sorted list of attribute names reported by :func:`dir`."""
        return self.__dir.copy()

    def __getattr__(self, name: str, /) -> Any:
        """Get an attribute from the re-exported module.

        Only called when regular attribute lookup fails. Exported functions are wrapped by
        :meth:`__reexport__`, other exported objects are returned as they are.

        Raises:
            AttributeError: If ``name`` is not one of the exported names.
        """
        if name in self.__all_set:
            attr = getattr(self.__original, name)
            if _inspect.isfunction(attr):
                attr = self.__reexport__(attr)
            # Cache the result as a real attribute for subsequent accesses.
            setattr(self, name, attr)
            return attr
        raise AttributeError(f'module {self.__name__!r} has no attribute {name!r}')

    def __reexport__(self, func: Callable[_P, _T], /) -> Callable[_P, _T]:
        """Re-export a function with the default namespace.

        Args:
            func (callable): The function to wrap.

        Returns:
            A wrapper of ``func``. If ``func`` has a ``namespace`` parameter, the wrapper passes
            it as a keyword argument that defaults to the namespace of this module, and the
            wrapper's docstring and signature are updated to show that default. If ``func`` has
            a callable ``get`` attribute, that attribute is re-exported the same way.
        """
        sig = _inspect.signature(func)
        if 'namespace' not in sig.parameters:

            @_functools.wraps(func)
            def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                return func(*args, **kwargs)

        else:

            @_functools.wraps(func)
            def wrapped(
                *args: _P.args,
                namespace: str = self.__namespace,
                **kwargs: _P.kwargs,
            ) -> _T:
                return func(*args, namespace=namespace, **kwargs)

            # Keep the documentation and the introspectable signature in sync with the new
            # default value of ``namespace``.
            if func.__doc__:
                wrapped.__doc__ = func.__doc__.replace(
                    "(default: :const:`''`, i.e., the global namespace)",
                    f'(default: :const:`{self.__namespace!r}`)',
                )
            wrapped.__signature__ = sig.replace(
                parameters=[
                    p if p.name != 'namespace' else p.replace(default=self.__namespace)
                    for p in sig.parameters.values()
                ],
            )

        if callable(getattr(func, 'get', None)):
            # ``functools.wraps`` copied the original ``get``; replace it with a wrapped one.
            wrapped.get = self.__reexport__(func.get)

        return wrapped
|
|
|
|
|
|
|
|
if _TYPE_CHECKING: |
|
|
|
|
|
class ReexportedPyTreeModule(ReexportedModule):
    """Typing-only description of the module object returned by :func:`reexport`.

    This class is defined only when ``typing.TYPE_CHECKING`` is true. It declares the members
    of a re-exported pytree module statically so that type checkers can see them.
    """

    # Members that ``reexport()`` passes to ``ReexportedModule`` as ``extra_members``.
    __version__: str
    functools: _ModuleType
    dataclasses: _ModuleType

    # Types exported by this module.
    PyTreeSpec: type[PyTreeSpec] = PyTreeSpec
    PyTreeKind: type[PyTreeKind] = PyTreeKind
    PyTreeEntry: type[PyTreeEntry] = PyTreeEntry
    # Module-level functions, declared as static methods so that they are not bound to the
    # module instance when accessed as attributes.
    flatten = staticmethod(flatten)
    flatten_with_path = staticmethod(flatten_with_path)
    flatten_with_accessor = staticmethod(flatten_with_accessor)
    unflatten = staticmethod(unflatten)
    iter = staticmethod(iter)
    leaves = staticmethod(leaves)
    structure = staticmethod(structure)
    paths = staticmethod(paths)
    accessors = staticmethod(accessors)
    is_leaf = staticmethod(is_leaf)
    map = staticmethod(map)
    map_ = staticmethod(map_)
    map_with_path = staticmethod(map_with_path)
    map_with_path_ = staticmethod(map_with_path_)
    map_with_accessor = staticmethod(map_with_accessor)
    map_with_accessor_ = staticmethod(map_with_accessor_)
    replace_nones = staticmethod(replace_nones)
    partition = staticmethod(partition)
    transpose = staticmethod(transpose)
    transpose_map = staticmethod(transpose_map)
    transpose_map_with_path = staticmethod(transpose_map_with_path)
    transpose_map_with_accessor = staticmethod(transpose_map_with_accessor)
    broadcast_prefix = staticmethod(broadcast_prefix)
    broadcast_common = staticmethod(broadcast_common)
    broadcast_map = staticmethod(broadcast_map)
    broadcast_map_with_path = staticmethod(broadcast_map_with_path)
    broadcast_map_with_accessor = staticmethod(broadcast_map_with_accessor)
    reduce = staticmethod(reduce)
    sum = staticmethod(sum)
    max = staticmethod(max)
    min = staticmethod(min)
    all = staticmethod(all)
    any = staticmethod(any)
    flatten_one_level = staticmethod(flatten_one_level)
    register_node = staticmethod(register_node)
    register_node_class = staticmethod(register_node_class)
    unregister_node = staticmethod(unregister_node)
    dict_insertion_ordered = staticmethod(dict_insertion_ordered)
|
|
|
|
|
def reexport(*, namespace: str, module: str | None = None) -> ReexportedPyTreeModule:
    """Declare the signature of ``reexport()`` for static type checkers.

    This variant only exists in type checking mode and must never be called.

    Raises:
        NotImplementedError: Always.
    """
    message = 'reexport() is not available in type checking mode'
    raise NotImplementedError(message)
|
|
|
|
|
else: |
|
|
|
|
|
def reexport(*, namespace: str, module: str | None = None) -> _ModuleType:
    """Build and register a copy of this module whose functions default to ``namespace``.

    >>> import optree
    >>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
    >>> pytree.flatten({'a': 1, 'b': 2})
    ([1, 2], PyTreeSpec({'a': *, 'b': *}))

    This function is useful for downstream libraries that want to re-export the pytree utilities
    with their own namespace::

        # foo/__init__.py
        import optree
        pytree = optree.pytree.reexport(namespace='foo')

        # foo/bar.py
        from foo import pytree

        @pytree.dataclasses.dataclass
        class Bar:
            a: int
            b: float

        print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
        # Output:
        #   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))

    Args:
        namespace (str): The namespace that the re-exported utilities use by default.
        module (str, optional): The name under which the re-exported module is registered.
            If not provided, defaults to ``<caller_module>.pytree``. The caller module is
            determined by inspecting the stack frame.

    Returns:
        The re-exported module. It is also registered in :data:`sys.modules`, together with its
        ``dataclasses`` and ``functools`` submodules.

    Raises:
        TypeError: If ``namespace`` is not a string.
        ValueError: If the module name is not a dotted sequence of identifiers, or if a module
            with that name (or one of its two submodule names) already exists.
    """
    from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

    # The global-namespace sentinel is normalized to the empty string.
    if namespace is GLOBAL_NAMESPACE:
        namespace = ''
    elif not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')

    if module is None:
        # NOTE: keep these frame lookups inline. Depth 1 means "the caller of reexport()",
        # which would no longer hold if the lookup moved into a helper function.
        try:
            caller = _sys._getframemodulename(1) or '__main__'
        except AttributeError:
            # ``sys._getframemodulename`` is unavailable; read the caller's globals instead.
            try:
                caller = _sys._getframe(1).f_globals.get('__name__', '__main__')
            except (AttributeError, ValueError):
                caller = '__main__'
        module = f'{caller}.pytree'
    if not (module and _all(part.isidentifier() for part in module.split('.'))):
        raise ValueError(f'invalid module name: {module!r}')

    dataclasses_name = f'{module}.dataclasses'
    functools_name = f'{module}.functools'

    # Refuse to shadow anything that is already registered.
    for candidate in (module, dataclasses_name, functools_name):
        if candidate in _sys.modules:
            raise ValueError(f'module {candidate!r} already exists')

    dataclasses_module = ReexportedModule(
        dataclasses_name,
        namespace=namespace,
        original=dataclasses,
    )
    functools_module = ReexportedModule(
        functools_name,
        namespace=namespace,
        original=functools,
    )
    pytree_module: ReexportedPyTreeModule = ReexportedModule(
        module,
        namespace=namespace,
        original=_sys.modules[__name__],
        extra_members={
            '__version__': __version__,
            'dataclasses': dataclasses_module,
            'functools': functools_module,
        },
    )

    # Each module was created with its registration name, so ``__name__`` is the key to use.
    for created in (pytree_module, dataclasses_module, functools_module):
        _sys.modules[created.__name__] = created
    return pytree_module
|
|
|