File size: 14,789 Bytes
1f5470c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utilities for working with ``PyTree``\s.

The :mod:`optree.pytree` namespace contains aliases of ``optree.tree_*`` utilities.

>>> import optree.pytree as pytree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = pytree.flatten(tree)
>>> leaves, treespec  # doctest: +IGNORE_WHITESPACE
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree == pytree.unflatten(treespec, leaves)
True

.. versionadded:: 0.14.1
"""

from __future__ import annotations

import functools as _functools
import inspect as _inspect
import sys as _sys
from builtins import all as _all
from types import ModuleType as _ModuleType
from typing import TYPE_CHECKING as _TYPE_CHECKING

import optree.dataclasses as dataclasses
import optree.functools as functools
from optree.accessors import PyTreeEntry
from optree.ops import tree_accessors as accessors
from optree.ops import tree_all as all  # pylint: disable=redefined-builtin
from optree.ops import tree_any as any  # pylint: disable=redefined-builtin
from optree.ops import tree_broadcast_common as broadcast_common
from optree.ops import tree_broadcast_map as broadcast_map
from optree.ops import tree_broadcast_map_with_accessor as broadcast_map_with_accessor
from optree.ops import tree_broadcast_map_with_path as broadcast_map_with_path
from optree.ops import tree_broadcast_prefix as broadcast_prefix
from optree.ops import tree_flatten as flatten
from optree.ops import tree_flatten_one_level as flatten_one_level
from optree.ops import tree_flatten_with_accessor as flatten_with_accessor
from optree.ops import tree_flatten_with_path as flatten_with_path
from optree.ops import tree_is_leaf as is_leaf
from optree.ops import tree_iter as iter  # pylint: disable=redefined-builtin
from optree.ops import tree_leaves as leaves
from optree.ops import tree_map as map  # pylint: disable=redefined-builtin
from optree.ops import tree_map_ as map_
from optree.ops import tree_map_with_accessor as map_with_accessor
from optree.ops import tree_map_with_accessor_ as map_with_accessor_
from optree.ops import tree_map_with_path as map_with_path
from optree.ops import tree_map_with_path_ as map_with_path_
from optree.ops import tree_max as max  # pylint: disable=redefined-builtin
from optree.ops import tree_min as min  # pylint: disable=redefined-builtin
from optree.ops import tree_partition as partition
from optree.ops import tree_paths as paths
from optree.ops import tree_reduce as reduce
from optree.ops import tree_replace_nones as replace_nones
from optree.ops import tree_structure as structure
from optree.ops import tree_sum as sum  # pylint: disable=redefined-builtin
from optree.ops import tree_transpose as transpose
from optree.ops import tree_transpose_map as transpose_map
from optree.ops import tree_transpose_map_with_accessor as transpose_map_with_accessor
from optree.ops import tree_transpose_map_with_path as transpose_map_with_path
from optree.ops import tree_unflatten as unflatten
from optree.registry import dict_insertion_ordered
from optree.registry import register_pytree_node as register_node
from optree.registry import register_pytree_node_class as register_node_class
from optree.registry import unregister_pytree_node as unregister_node
from optree.typing import PyTreeKind, PyTreeSpec
from optree.version import __version__ as __version__  # pylint: disable=useless-import-alias


# Public API of this module. ``reexport()`` passes this module as ``original`` to
# ``ReexportedModule``, which uses this list (minus ``'reexport'``) as the default set of
# names exposed by a re-exported module.
__all__ = [
    'reexport',
    'PyTreeSpec',
    'PyTreeKind',
    'PyTreeEntry',
    'flatten',
    'flatten_with_path',
    'flatten_with_accessor',
    'unflatten',
    'iter',
    'leaves',
    'structure',
    'paths',
    'accessors',
    'is_leaf',
    'map',
    'map_',
    'map_with_path',
    'map_with_path_',
    'map_with_accessor',
    'map_with_accessor_',
    'replace_nones',
    'partition',
    'transpose',
    'transpose_map',
    'transpose_map_with_path',
    'transpose_map_with_accessor',
    'broadcast_prefix',
    'broadcast_common',
    'broadcast_map',
    'broadcast_map_with_path',
    'broadcast_map_with_accessor',
    'reduce',
    'sum',
    'max',
    'min',
    'all',
    'any',
    'flatten_one_level',
    'register_node',
    'register_node_class',
    'unregister_node',
    'dict_insertion_ordered',
]


# Names below are only referenced inside annotations. With ``from __future__ import annotations``
# those annotations are never evaluated at runtime, so the imports are needed by type checkers only.
if _TYPE_CHECKING:
    from collections.abc import Callable, Iterable
    from typing import Any, TypeVar  # pylint: disable=ungrouped-imports
    from typing_extensions import ParamSpec  # available as ``typing.ParamSpec`` on Python 3.10+

    _P = ParamSpec('_P')  # parameters of the functions wrapped by ``ReexportedModule.__reexport__``
    _T = TypeVar('_T')  # return type of those wrapped functions


class ReexportedModule(_ModuleType):
    """Module object that proxies the public API of another module.

    Names listed in ``__all__`` are fetched lazily from the original module on first attribute
    access and then cached on this module. Plain Python functions are wrapped on the way out: a
    function with a ``namespace`` parameter gets that parameter defaulted to the namespace given at
    construction time, and any other function gets a transparent pass-through wrapper.
    """

    __doc__: str

    def __init__(
        self,
        name: str,
        *,
        namespace: str,
        original: _ModuleType,
        doc: str | None = None,
        __all__: Iterable[str] | None = None,
        __dir__: Iterable[str] | None = None,
        extra_members: dict[str, Any] | None = None,
    ) -> None:
        """Create the proxy module.

        Args:
            name: The ``__name__`` of the new module.
            namespace: The default namespace injected into wrapped functions.
            original: The module whose attributes are re-exported.
            doc: The module docstring. A generated description is used when it is empty or omitted.
            __all__: Names to re-export. Defaults to ``original.__all__`` without ``'reexport'``.
            __dir__: Names reported by :func:`dir`. Defaults to the non-underscore names of
                ``original`` other than ``'reexport'``; always intersected with ``__all__``.
            extra_members: Attributes set eagerly on the module; their names are added to
                :func:`dir` but not to ``__all__``.
        """
        if not doc:
            doc = (
                f'Re-exports :mod:`{original.__name__}` as :mod:`{name}` '
                f'with namespace :const:`{namespace!r}`.'
            )
        super().__init__(name, doc)

        exported: set[str] = (
            {n for n in original.__all__ if n != 'reexport'} if __all__ is None else set(__all__)
        )
        if __dir__ is None:
            candidates: Iterable[str] = (
                n for n in original.__dir__() if not n.startswith('_') and n != 'reexport'
            )
        else:
            candidates = __dir__
        visible = exported.intersection(candidates)

        # Extra members become real instance attributes, so ``__getattr__`` is never consulted.
        for key, value in (extra_members or {}).items():
            setattr(self, key, value)
            visible.add(key)

        self.__namespace = namespace
        self.__original = original
        self.__all_set = exported
        self.__all = sorted(exported)
        self.__dir = sorted(visible)

    @property
    def __all__(self) -> list[str]:
        """The sorted names re-exported by this module."""
        return self.__all

    def __dir__(self) -> list[str]:
        """Return a fresh list of the names reported by :func:`dir` for this module."""
        return list(self.__dir)

    def __getattr__(self, name: str, /) -> Any:
        """Fetch ``name`` from the original module on first access and cache the result."""
        if name not in self.__all_set:
            raise AttributeError(f'module {self.__name__!r} has no attribute {name!r}')
        member = getattr(self.__original, name)
        if _inspect.isfunction(member):
            member = self.__reexport__(member)
        # Cached in the instance dict, so later lookups no longer reach ``__getattr__``.
        setattr(self, name, member)
        return member

    def __reexport__(self, func: Callable[_P, _T], /) -> Callable[_P, _T]:
        """Wrap ``func`` so that its ``namespace`` argument defaults to this module's namespace.

        Functions without a ``namespace`` parameter still receive a pass-through wrapper, so the
        ``get`` attribute handled below is attached to the wrapper and never to ``func`` itself.
        A callable ``func.get`` is wrapped recursively and exposed as ``wrapper.get``.
        """
        default_namespace = self.__namespace
        signature = _inspect.signature(func)

        if 'namespace' in signature.parameters:

            @_functools.wraps(func)
            def wrapped(  # type: ignore[valid-type]
                *args: _P.args,
                namespace: str = default_namespace,
                **kwargs: _P.kwargs,
            ) -> _T:
                return func(*args, namespace=namespace, **kwargs)  # type: ignore[arg-type]

            if func.__doc__:  # pragma: no branch
                # Advertise the new default in the rendered documentation as well.
                wrapped.__doc__ = func.__doc__.replace(
                    "(default: :const:`''`, i.e., the global namespace)",
                    f'(default: :const:`{default_namespace!r}`)',
                )
            # Make ``inspect.signature`` report the injected default.
            wrapped.__signature__ = signature.replace(  # type: ignore[attr-defined]
                parameters=[
                    param.replace(default=default_namespace) if param.name == 'namespace' else param
                    for param in signature.parameters.values()
                ],
            )
        else:

            @_functools.wraps(func)
            def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                return func(*args, **kwargs)

        getter = getattr(func, 'get', None)
        if callable(getter):
            wrapped.get = self.__reexport__(getter)  # type: ignore[attr-defined]

        return wrapped


if _TYPE_CHECKING:
    # pylint: disable-next=too-few-public-methods
    class ReexportedPyTreeModule(ReexportedModule):
        """Static-typing view of the module returned by :func:`reexport`.

        This class only exists under ``TYPE_CHECKING`` and is never defined at runtime. It declares
        every re-exported member so that type checkers can resolve attribute access on the module
        returned by :func:`reexport`.
        """

        __version__: str
        functools: _ModuleType
        dataclasses: _ModuleType

        PyTreeSpec: type[PyTreeSpec] = PyTreeSpec
        PyTreeKind: type[PyTreeKind] = PyTreeKind
        PyTreeEntry: type[PyTreeEntry] = PyTreeEntry
        # ``staticmethod`` keeps type checkers from treating these aliases as bound methods.
        flatten = staticmethod(flatten)
        flatten_with_path = staticmethod(flatten_with_path)
        flatten_with_accessor = staticmethod(flatten_with_accessor)
        unflatten = staticmethod(unflatten)
        iter = staticmethod(iter)
        leaves = staticmethod(leaves)
        structure = staticmethod(structure)
        paths = staticmethod(paths)
        accessors = staticmethod(accessors)
        is_leaf = staticmethod(is_leaf)
        map = staticmethod(map)
        map_ = staticmethod(map_)
        map_with_path = staticmethod(map_with_path)
        map_with_path_ = staticmethod(map_with_path_)
        map_with_accessor = staticmethod(map_with_accessor)
        map_with_accessor_ = staticmethod(map_with_accessor_)
        replace_nones = staticmethod(replace_nones)
        partition = staticmethod(partition)
        transpose = staticmethod(transpose)
        transpose_map = staticmethod(transpose_map)
        transpose_map_with_path = staticmethod(transpose_map_with_path)
        transpose_map_with_accessor = staticmethod(transpose_map_with_accessor)
        broadcast_prefix = staticmethod(broadcast_prefix)
        broadcast_common = staticmethod(broadcast_common)
        broadcast_map = staticmethod(broadcast_map)
        broadcast_map_with_path = staticmethod(broadcast_map_with_path)
        broadcast_map_with_accessor = staticmethod(broadcast_map_with_accessor)
        reduce = staticmethod(reduce)
        sum = staticmethod(sum)
        max = staticmethod(max)
        min = staticmethod(min)
        all = staticmethod(all)
        any = staticmethod(any)
        flatten_one_level = staticmethod(flatten_one_level)
        register_node = staticmethod(register_node)
        register_node_class = staticmethod(register_node_class)
        unregister_node = staticmethod(unregister_node)
        dict_insertion_ordered = staticmethod(dict_insertion_ordered)

    def reexport(*, namespace: str, module: str | None = None) -> ReexportedPyTreeModule:
        """Type-checking stub of :func:`reexport` that declares its precise return type.

        The runtime implementation lives in the ``else`` branch; calling this stub always raises
        :class:`NotImplementedError`.
        """
        message = 'reexport() is not available in type checking mode'
        raise NotImplementedError(message)

else:

    def reexport(*, namespace: str, module: str | None = None) -> _ModuleType:  # type: ignore[misc]
        """Re-export a pytree utility module with the given namespace as default.

        >>> import optree
        >>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
        >>> pytree.flatten({'a': 1, 'b': 2})
        ([1, 2], PyTreeSpec({'a': *, 'b': *}))

        This function is useful for downstream libraries that want to re-export the pytree utilities
        with their own namespace::

            # foo/__init__.py
            import optree
            pytree = optree.pytree.reexport(namespace='foo')

            # foo/bar.py
            from foo import pytree

            @pytree.dataclasses.dataclass
            class Bar:
                a: int
                b: float

            print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
            # Output:
            #   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))

        The returned module and its ``dataclasses`` and ``functools`` submodules are registered in
        :data:`sys.modules` as ``<module>``, ``<module>.dataclasses`` and ``<module>.functools``.

        Args:
            namespace (str): The namespace used as the default value of the ``namespace`` argument
                of every re-exported function that accepts one.
            module (str, optional): The name of the module to re-export.
                If not provided, defaults to ``<caller_module>.pytree``. The caller module is determined
                by inspecting the stack frame.

        Returns:
            The re-exported module.

        Raises:
            TypeError: If ``namespace`` or ``module`` is not a string.
            ValueError: If ``module`` is not a valid dotted module name, or if the module or one of
                its ``dataclasses``/``functools`` submodules is already present in :data:`sys.modules`.
        """
        # pylint: disable-next=import-outside-toplevel
        from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

        # The registry's global-namespace sentinel is accepted and mapped to the empty namespace.
        if namespace is GLOBAL_NAMESPACE:
            namespace = ''
        elif not isinstance(namespace, str):
            raise TypeError(f'The namespace must be a string, got {namespace!r}.')

        if module is None:
            # Frame depth 1 is the direct caller of ``reexport()``.
            try:
                # pylint: disable-next=protected-access
                caller_module = _sys._getframemodulename(1) or '__main__'  # type: ignore[attr-defined]
            except AttributeError:  # pragma: no cover
                # ``sys._getframemodulename`` only exists on Python 3.12+; fall back to frame globals.
                try:
                    # pylint: disable-next=protected-access
                    caller_module = _sys._getframe(1).f_globals.get('__name__', '__main__')
                except (AttributeError, ValueError):
                    caller_module = '__main__'
            module = f'{caller_module}.pytree'
        elif not isinstance(module, str):
            # Mirror the ``namespace`` check instead of failing later on ``module.split``.
            raise TypeError(f'The module name must be a string, got {module!r}.')
        if not module or not _all(part.isidentifier() for part in module.split('.')):
            raise ValueError(f'invalid module name: {module!r}')

        # Refuse to overwrite anything already registered, checking every name before mutating.
        for module_name in (module, f'{module}.dataclasses', f'{module}.functools'):
            if module_name in _sys.modules:
                raise ValueError(f'module {module_name!r} already exists')

        reexported_dataclasses = ReexportedModule(
            f'{module}.dataclasses',
            namespace=namespace,
            original=dataclasses,
        )
        reexported_functools = ReexportedModule(
            f'{module}.functools',
            namespace=namespace,
            original=functools,
        )
        mod: ReexportedPyTreeModule = ReexportedModule(  # type: ignore[assignment]
            module,
            namespace=namespace,
            original=_sys.modules[__name__],
            extra_members={
                '__version__': __version__,
                'dataclasses': reexported_dataclasses,
                'functools': reexported_functools,
            },
        )
        # Register all three modules so ``import <module>.dataclasses`` etc. resolve to the proxies.
        _sys.modules[module] = mod
        _sys.modules[f'{module}.dataclasses'] = reexported_dataclasses
        _sys.modules[f'{module}.functools'] = reexported_functools
        return mod