# joebruce1313's picture
# Upload 38004 files
# 1f5470c verified
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utilities for working with ``PyTree``\s.

The :mod:`optree.pytree` namespace contains aliases of ``optree.tree_*`` utilities.

>>> import optree.pytree as pytree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = pytree.flatten(tree)
>>> leaves, treespec  # doctest: +IGNORE_WHITESPACE
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree == pytree.unflatten(treespec, leaves)
True

.. versionadded:: 0.14.1
"""
from __future__ import annotations
import functools as _functools
import inspect as _inspect
import sys as _sys
from builtins import all as _all
from types import ModuleType as _ModuleType
from typing import TYPE_CHECKING as _TYPE_CHECKING
import optree.dataclasses as dataclasses
import optree.functools as functools
from optree.accessors import PyTreeEntry
from optree.ops import tree_accessors as accessors
from optree.ops import tree_all as all # pylint: disable=redefined-builtin
from optree.ops import tree_any as any # pylint: disable=redefined-builtin
from optree.ops import tree_broadcast_common as broadcast_common
from optree.ops import tree_broadcast_map as broadcast_map
from optree.ops import tree_broadcast_map_with_accessor as broadcast_map_with_accessor
from optree.ops import tree_broadcast_map_with_path as broadcast_map_with_path
from optree.ops import tree_broadcast_prefix as broadcast_prefix
from optree.ops import tree_flatten as flatten
from optree.ops import tree_flatten_one_level as flatten_one_level
from optree.ops import tree_flatten_with_accessor as flatten_with_accessor
from optree.ops import tree_flatten_with_path as flatten_with_path
from optree.ops import tree_is_leaf as is_leaf
from optree.ops import tree_iter as iter # pylint: disable=redefined-builtin
from optree.ops import tree_leaves as leaves
from optree.ops import tree_map as map # pylint: disable=redefined-builtin
from optree.ops import tree_map_ as map_
from optree.ops import tree_map_with_accessor as map_with_accessor
from optree.ops import tree_map_with_accessor_ as map_with_accessor_
from optree.ops import tree_map_with_path as map_with_path
from optree.ops import tree_map_with_path_ as map_with_path_
from optree.ops import tree_max as max # pylint: disable=redefined-builtin
from optree.ops import tree_min as min # pylint: disable=redefined-builtin
from optree.ops import tree_partition as partition
from optree.ops import tree_paths as paths
from optree.ops import tree_reduce as reduce
from optree.ops import tree_replace_nones as replace_nones
from optree.ops import tree_structure as structure
from optree.ops import tree_sum as sum # pylint: disable=redefined-builtin
from optree.ops import tree_transpose as transpose
from optree.ops import tree_transpose_map as transpose_map
from optree.ops import tree_transpose_map_with_accessor as transpose_map_with_accessor
from optree.ops import tree_transpose_map_with_path as transpose_map_with_path
from optree.ops import tree_unflatten as unflatten
from optree.registry import dict_insertion_ordered
from optree.registry import register_pytree_node as register_node
from optree.registry import register_pytree_node_class as register_node_class
from optree.registry import unregister_pytree_node as unregister_node
from optree.typing import PyTreeKind, PyTreeSpec
from optree.version import __version__ as __version__ # pylint: disable=useless-import-alias
# Public API of this module, grouped by purpose. The order is kept stable on purpose;
# ``ReexportedModule`` sorts its own copy, so grouping here is purely for readability.
__all__ = [
    # Downstream re-export entry point (defined below) and the core pytree types.
    'reexport', 'PyTreeSpec', 'PyTreeKind', 'PyTreeEntry',
    # Flattening and unflattening.
    'flatten', 'flatten_with_path', 'flatten_with_accessor', 'unflatten',
    # Leaf and structure queries.
    'iter', 'leaves', 'structure', 'paths', 'accessors', 'is_leaf',
    # Mapping over leaves; the trailing-underscore names alias the ``tree_map*_`` functions.
    'map', 'map_', 'map_with_path', 'map_with_path_', 'map_with_accessor', 'map_with_accessor_',
    # Structural transformations.
    'replace_nones', 'partition',
    'transpose', 'transpose_map', 'transpose_map_with_path', 'transpose_map_with_accessor',
    # Broadcasting.
    'broadcast_prefix', 'broadcast_common',
    'broadcast_map', 'broadcast_map_with_path', 'broadcast_map_with_accessor',
    # Reduction helpers (aliases of ``tree_reduce``, ``tree_sum``, ``tree_max``, ...).
    'reduce', 'sum', 'max', 'min', 'all', 'any',
    # One-level flattening and the registry aliases from ``optree.registry``.
    'flatten_one_level',
    'register_node', 'register_node_class', 'unregister_node', 'dict_insertion_ordered',
]
# Names used only inside annotations. ``from __future__ import annotations`` keeps the annotations
# in this module unevaluated at import time, and ``_TYPE_CHECKING`` is ``False`` at runtime, so
# these imports and type variables are only seen by static type checkers.
if _TYPE_CHECKING:
    from collections.abc import Callable, Iterable
    from typing import Any, TypeVar  # pylint: disable=ungrouped-imports

    from typing_extensions import ParamSpec  # ``typing.ParamSpec`` itself requires Python 3.10+

    # Parameter-spec / return-type variables for the signature-preserving wrappers that
    # ``ReexportedModule.__reexport__`` builds around re-exported functions.
    _P = ParamSpec('_P')
    _T = TypeVar('_T')
class ReexportedModule(_ModuleType):
    """A module object that lazily mirrors the public API of another module.

    Public names are resolved from the original module on first access and then cached on this
    module. Plain Python functions are wrapped so that their ``namespace`` argument (when they
    have one) defaults to the namespace this module was created with.
    """

    __doc__: str

    def __init__(
        self,
        name: str,
        *,
        namespace: str,
        original: _ModuleType,
        doc: str | None = None,
        __all__: Iterable[str] | None = None,
        __dir__: Iterable[str] | None = None,
        extra_members: dict[str, Any] | None = None,
    ) -> None:
        """Initialize the re-exporting module.

        Args:
            name: The ``__name__`` of the new module.
            namespace: The default ``namespace`` passed to re-exported functions.
            original: The module whose public names are re-exported.
            doc: The module docstring. When empty or ``None``, a summary naming ``original``,
                ``name`` and ``namespace`` is generated.
            __all__: The names to re-export. Defaults to ``original.__all__`` without
                ``'reexport'``.
            __dir__: The names reported by :func:`dir`, always restricted to ``__all__``.
                Defaults to the names of ``original`` that do not start with ``_``, without
                ``'reexport'``.
            extra_members: Attributes set eagerly on the new module. Their names are added to
                :func:`dir` but not to ``__all__``.
        """
        if not doc:
            doc = (
                f'Re-exports :mod:`{original.__name__}` as :mod:`{name}` '
                f'with namespace :const:`{namespace!r}`.'
            )
        super().__init__(name, doc)

        exported = (
            set(__all__)
            if __all__ is not None
            else {n for n in original.__all__ if n != 'reexport'}
        )

        if __dir__ is None:  # pragma: no branch
            __dir__ = [n for n in original.__dir__() if not n.startswith('_') and n != 'reexport']
        listed = exported.intersection(__dir__)

        # Extra members become real attributes right away, so ``__getattr__`` never sees them.
        for key, value in (extra_members or {}).items():
            setattr(self, key, value)
            listed.add(key)

        self.__namespace = namespace
        self.__original = original
        self.__all_set = exported
        self.__all = sorted(exported)
        self.__dir = sorted(listed)

    @property
    def __all__(self) -> list[str]:
        """Return the sorted list of names re-exported by this module."""
        return self.__all

    def __dir__(self) -> list[str]:
        """Return a fresh copy of the sorted names reported by :func:`dir`."""
        return list(self.__dir)

    def __getattr__(self, name: str, /) -> Any:
        """Resolve ``name`` from the original module and cache the result on this module.

        Raises:
            AttributeError: If ``name`` is not one of the re-exported names. Errors raised while
                looking ``name`` up on the original module propagate unchanged.
        """
        if name not in self.__all_set:
            raise AttributeError(f'module {self.__name__!r} has no attribute {name!r}')
        value = getattr(self.__original, name)
        if _inspect.isfunction(value):
            value = self.__reexport__(value)
        # Cache on the module so later lookups are plain attribute hits.
        setattr(self, name, value)
        return value

    def __reexport__(self, func: Callable[_P, _T], /) -> Callable[_P, _T]:
        """Wrap ``func`` so that its ``namespace`` argument defaults to this module's namespace.

        Functions without a ``namespace`` parameter get a plain pass-through wrapper. In both cases
        the wrapper copies ``func``'s metadata via :func:`functools.wraps` and advertises ``func``'s
        signature, with only the ``namespace`` default replaced. The global-namespace default phrase
        in the docstring is rewritten. A callable ``func.get`` attribute is re-exported recursively
        as ``wrapper.get``.
        """
        default_namespace = self.__namespace
        signature = _inspect.signature(func)

        if 'namespace' in signature.parameters:

            @_functools.wraps(func)
            def wrapped(  # type: ignore[valid-type]
                *args: _P.args,
                namespace: str = default_namespace,
                **kwargs: _P.kwargs,
            ) -> _T:
                return func(*args, namespace=namespace, **kwargs)  # type: ignore[arg-type]

        else:

            @_functools.wraps(func)
            def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                return func(*args, **kwargs)

        original_doc = func.__doc__
        if original_doc:  # pragma: no branch
            wrapped.__doc__ = original_doc.replace(
                "(default: :const:`''`, i.e., the global namespace)",
                f'(default: :const:`{default_namespace!r}`)',
            )

        wrapped.__signature__ = signature.replace(  # type: ignore[attr-defined]
            parameters=[
                param.replace(default=default_namespace) if param.name == 'namespace' else param
                for param in signature.parameters.values()
            ],
        )

        getter = getattr(func, 'get', None)
        if callable(getter):
            wrapped.get = self.__reexport__(getter)  # type: ignore[attr-defined]
        return wrapped
if _TYPE_CHECKING:
    # Static view of the object returned by ``reexport()``: every public member of this module
    # plus the ``__version__``, ``functools`` and ``dataclasses`` extra members. It exists for
    # type checkers only; at runtime ``reexport()`` returns a plain ``ReexportedModule``.
    # pylint: disable-next=missing-class-docstring,too-few-public-methods
    class ReexportedPyTreeModule(ReexportedModule):
        __version__: str
        functools: _ModuleType
        dataclasses: _ModuleType

        PyTreeSpec: type[PyTreeSpec] = PyTreeSpec
        PyTreeKind: type[PyTreeKind] = PyTreeKind
        PyTreeEntry: type[PyTreeEntry] = PyTreeEntry

        flatten = staticmethod(flatten)
        flatten_with_path = staticmethod(flatten_with_path)
        flatten_with_accessor = staticmethod(flatten_with_accessor)
        unflatten = staticmethod(unflatten)
        iter = staticmethod(iter)
        leaves = staticmethod(leaves)
        structure = staticmethod(structure)
        paths = staticmethod(paths)
        accessors = staticmethod(accessors)
        is_leaf = staticmethod(is_leaf)
        map = staticmethod(map)
        map_ = staticmethod(map_)
        map_with_path = staticmethod(map_with_path)
        map_with_path_ = staticmethod(map_with_path_)
        map_with_accessor = staticmethod(map_with_accessor)
        map_with_accessor_ = staticmethod(map_with_accessor_)
        replace_nones = staticmethod(replace_nones)
        partition = staticmethod(partition)
        transpose = staticmethod(transpose)
        transpose_map = staticmethod(transpose_map)
        transpose_map_with_path = staticmethod(transpose_map_with_path)
        transpose_map_with_accessor = staticmethod(transpose_map_with_accessor)
        broadcast_prefix = staticmethod(broadcast_prefix)
        broadcast_common = staticmethod(broadcast_common)
        broadcast_map = staticmethod(broadcast_map)
        broadcast_map_with_path = staticmethod(broadcast_map_with_path)
        broadcast_map_with_accessor = staticmethod(broadcast_map_with_accessor)
        reduce = staticmethod(reduce)
        sum = staticmethod(sum)
        max = staticmethod(max)
        min = staticmethod(min)
        all = staticmethod(all)
        any = staticmethod(any)
        flatten_one_level = staticmethod(flatten_one_level)
        register_node = staticmethod(register_node)
        register_node_class = staticmethod(register_node_class)
        unregister_node = staticmethod(unregister_node)
        dict_insertion_ordered = staticmethod(dict_insertion_ordered)

    def reexport(*, namespace: str, module: str | None = None) -> ReexportedPyTreeModule:
        """Re-export a pytree utility module with the given namespace as default."""
        raise NotImplementedError('reexport() is not available in type checking mode')

else:

    def reexport(*, namespace: str, module: str | None = None) -> _ModuleType:  # type: ignore[misc]
        """Re-export a pytree utility module with the given namespace as default.

        >>> import optree
        >>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
        >>> pytree.flatten({'a': 1, 'b': 2})
        ([1, 2], PyTreeSpec({'a': *, 'b': *}))

        This function is useful for downstream libraries that want to re-export the pytree utilities
        with their own namespace::

            # foo/__init__.py
            import optree
            pytree = optree.pytree.reexport(namespace='foo')

            # foo/bar.py
            from foo import pytree

            @pytree.dataclasses.dataclass
            class Bar:
                a: int
                b: float

            print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
            # Output:
            # ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))

        Args:
            namespace (str): The namespace used as the default ``namespace`` argument of the
                re-exported functions. The registry's global-namespace sentinel is accepted and
                treated as ``''``.
            module (str, optional): The name of the module to re-export.
                If not provided, defaults to ``<caller_module>.pytree``. The caller module is
                determined by inspecting the stack frame.

        Returns:
            The re-exported module. It is registered in :data:`sys.modules` together with its
            ``<module>.dataclasses`` and ``<module>.functools`` submodules.

        Raises:
            TypeError: If ``namespace`` is not a string, or ``module`` is neither ``None`` nor a
                string.
            ValueError: If ``module`` is not a valid dotted module name, or if it (or one of its
                ``dataclasses`` / ``functools`` submodules) already exists in :data:`sys.modules`.
        """
        # pylint: disable-next=import-outside-toplevel
        from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

        # The registry's global-namespace sentinel is accepted as an alias for ``''``.
        if namespace is GLOBAL_NAMESPACE:
            namespace = ''
        elif not isinstance(namespace, str):
            raise TypeError(f'The namespace must be a string, got {namespace!r}.')

        if module is None:
            # Derive ``<caller_module>.pytree``. Depth 1 is the frame that called ``reexport()``,
            # so this lookup must stay in this function body (a helper would shift the depth).
            try:
                # pylint: disable-next=protected-access
                caller_module = _sys._getframemodulename(1) or '__main__'  # type: ignore[attr-defined]
            except AttributeError:  # pragma: no cover
                # Interpreters without ``sys._getframemodulename``: read the caller's globals.
                try:
                    # pylint: disable-next=protected-access
                    caller_module = _sys._getframe(1).f_globals.get('__name__', '__main__')
                except (AttributeError, ValueError):
                    caller_module = '__main__'
            module = f'{caller_module}.pytree'
        elif not isinstance(module, str):
            # Reject non-strings up front, mirroring the ``namespace`` check above, instead of
            # failing later with an opaque ``AttributeError`` from ``module.split``.
            raise TypeError(f'The module name must be a string, got {module!r}.')

        # ``_all`` is the builtin ``all``; the module-level name ``all`` is ``tree_all``.
        if not module or not _all(part.isidentifier() for part in module.split('.')):
            raise ValueError(f'invalid module name: {module!r}')
        for module_name in (module, f'{module}.dataclasses', f'{module}.functools'):
            if module_name in _sys.modules:
                raise ValueError(f'module {module_name!r} already exists')

        reexported_dataclasses = ReexportedModule(
            f'{module}.dataclasses',
            namespace=namespace,
            original=dataclasses,
        )
        reexported_functools = ReexportedModule(
            f'{module}.functools',
            namespace=namespace,
            original=functools,
        )
        # Re-export this very module; the two submodules above are attached as extra members.
        mod: ReexportedPyTreeModule = ReexportedModule(  # type: ignore[assignment]
            module,
            namespace=namespace,
            original=_sys.modules[__name__],
            extra_members={
                '__version__': __version__,
                'dataclasses': reexported_dataclasses,
                'functools': reexported_functools,
            },
        )

        # Register only after all three modules were built, so a failure above leaves
        # ``sys.modules`` untouched.
        _sys.modules[module] = mod
        _sys.modules[f'{module}.dataclasses'] = reexported_dataclasses
        _sys.modules[f'{module}.functools'] = reexported_functools
        return mod