# joebruce1313's picture
# Upload 38004 files
# 1f5470c verified
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for custom pytree node types."""
# pylint: disable=too-many-lines
from __future__ import annotations
import contextlib
import dataclasses
import inspect
import sys
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import itemgetter, methodcaller
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, overload
import optree._C as _C
from optree.accessors import (
AutoEntry,
MappingEntry,
NamedTupleEntry,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
from optree.typing import PyTreeKind, StructSequence, T, is_namedtuple_class, is_structseq_class
from optree.utils import safe_zip, total_order_sorted, unzip2
if TYPE_CHECKING:
    # Names needed only by static type checkers; this branch is not executed at runtime.
    import builtins
    from collections.abc import Collection, Generator, Iterable

    from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, UnflattenFunc

    # Type variable for classes implementing the ``CustomTreeNode`` protocol; it is used in the
    # ``register_pytree_node_class`` overload signatures.
    # pylint: disable-next=invalid-name
    CustomTreeNodeType = TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])
# Public API of this module (the names exported by a star-import).
__all__ = [
    'register_pytree_node',
    'register_pytree_node_class',
    'unregister_pytree_node',
    'dict_insertion_ordered',
]
# ``dataclasses.dataclass`` only accepts ``slots=True`` on Python 3.10+.
_DATACLASS_OPTIONS = {}
if sys.version_info >= (3, 10):  # Python 3.10+
    _DATACLASS_OPTIONS['slots'] = True


@dataclasses.dataclass(frozen=True, **_DATACLASS_OPTIONS)
class PyTreeNodeRegistryEntry(Generic[T]):
    """An immutable record describing how one pytree node type is flattened and unflattened.

    Fields:
        type: The registered node class.
        flatten_func: The function that splits an instance into children and metadata.
        unflatten_func: The function that rebuilds an instance from metadata and children.
        path_entry_type: The :class:`PyTreeEntry` subclass used for tree path entries.
        kind: The :class:`PyTreeKind` of the node (``PyTreeKind.CUSTOM`` by default).
        namespace: The registry namespace of the entry; ``''`` denotes the global namespace.
    """

    type: builtins.type[Collection[T]]
    flatten_func: FlattenFunc[T]
    unflatten_func: UnflattenFunc[T]
    if sys.version_info >= (3, 10):  # pragma: >=3.10 cover
        # On Python 3.10+, every field declared after this marker is keyword-only.
        _: dataclasses.KW_ONLY
    path_entry_type: builtins.type[PyTreeEntry] = AutoEntry
    kind: PyTreeKind = PyTreeKind.CUSTOM
    namespace: str = ''


del _DATACLASS_OPTIONS
# pylint: disable-next=too-few-public-methods
class GlobalNamespace:  # pragma: no cover
    """Private sentinel type whose only instance stands for the global registry namespace."""

    __slots__: ClassVar[tuple[()]] = ()

    def __repr__(self, /) -> str:
        return '<GLOBAL NAMESPACE>'


# Held by the functions of this module around updates of the shared registries.
__REGISTRY_LOCK: Lock = Lock()
# Annotated as ``str`` so it can be passed wherever a namespace string is accepted; the
# functions of this module translate it to the empty-string namespace ``''``.
__GLOBAL_NAMESPACE: str = GlobalNamespace()  # type: ignore[assignment]
del GlobalNamespace  # Only the sentinel instance stays reachable.
if TYPE_CHECKING:
    # Static-typing helpers for ``_add_get``; this branch is not executed at runtime.
    from typing_extensions import ParamSpec  # Python 3.10+

    _P = ParamSpec('_P')
    _T = TypeVar('_T')
    _GetP = ParamSpec('_GetP')
    _GetT = TypeVar('_GetT')

    class _CallableWithGet(Generic[_P, _T, _GetP, _GetT]):
        # Describes a callable with signature ``(*_P) -> _T`` that additionally exposes a
        # ``get`` attribute with signature ``(*_GetP) -> _GetT``.
        def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
            raise NotImplementedError

        # pylint: disable-next=missing-function-docstring
        def get(self, /, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT:
            raise NotImplementedError
def _add_get(
    get: Callable[_GetP, _GetT],
    /,
) -> Callable[
    [Callable[_P, _T]],
    _CallableWithGet[_P, _T, _GetP, _GetT],
]:
    """Build a decorator that attaches ``get`` to a function as its ``.get`` attribute.

    The decorated function is returned as-is (no wrapper is created); only the attribute is added.
    """

    def attach_get(func: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]:
        setattr(func, 'get', get)  # noqa: B010
        return func  # type: ignore[return-value]

    return attach_get
# Overload: look up a single class; returns its registry entry, or ``None`` for a leaf type.
@overload
def pytree_node_registry_get(
    cls: type,
    /,
    *,
    namespace: str = '',
) -> PyTreeNodeRegistryEntry | None: ...


# Overload: no class given; returns a snapshot of all entries visible from ``namespace``.
@overload
def pytree_node_registry_get(
    cls: None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry]: ...
# pylint: disable-next=too-many-return-statements,too-many-branches
def pytree_node_registry_get(  # noqa: C901
    cls: type | None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | None:
    """Look up the pytree node registry.

    >>> register_pytree_node.get()  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    {
        <class 'NoneType'>: PyTreeNodeRegistryEntry(
            type=<class 'NoneType'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.PyTreeEntry'>,
            kind=<PyTreeKind.NONE: 2>,
            namespace=''
        ),
        <class 'tuple'>: PyTreeNodeRegistryEntry(
            type=<class 'tuple'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.SequenceEntry'>,
            kind=<PyTreeKind.TUPLE: 3>,
            namespace=''
        ),
        <class 'list'>: PyTreeNodeRegistryEntry(
            type=<class 'list'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.SequenceEntry'>,
            kind=<PyTreeKind.LIST: 4>,
            namespace=''
        ),
        ...
    }
    >>> register_pytree_node.get(defaultdict)  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    PyTreeNodeRegistryEntry(
        type=<class 'collections.defaultdict'>,
        flatten_func=<function ...>,
        unflatten_func=<function ...>,
        path_entry_type=<class 'optree.MappingEntry'>,
        kind=<PyTreeKind.DEFAULTDICT: 8>,
        namespace=''
    )
    >>> print(register_pytree_node.get(frozenset))  # frozenset is considered as a leaf node
    None

    Args:
        cls (type or None, optional): The class of the pytree node to retrieve. If not provided, all
            the registered pytree nodes in the namespace are returned.
        namespace (str, optional): The namespace of the registry to retrieve. If not provided, the
            global namespace is used.

    Returns:
        If the ``cls`` is not provided, a dictionary of all the registered pytree nodes visible in
        the namespace (its own registrations plus the global ones) is returned. If the ``cls`` is
        provided, the corresponding registry entry is returned if the ``cls`` is registered as a
        pytree node. Otherwise, :data:`None` is returned, i.e., the ``cls`` is represented as a
        leaf node.

    Raises:
        TypeError: If ``cls`` is neither a class nor :data:`None`.
        TypeError: If the namespace is not a string.
    """
    if namespace is __GLOBAL_NAMESPACE:
        namespace = ''
    if (
        cls is not None
        and cls is not namedtuple  # noqa: PYI024
        and not inspect.isclass(cls)
    ):
        raise TypeError(f'Expected a class or None, got {cls!r}.')  # pragma: !=3.9 cover
    if not isinstance(namespace, str):
        raise TypeError(  # pragma: !=3.9 cover
            f'The namespace must be a string, got {namespace!r}.',
        )

    if cls is None:
        # Snapshot every entry visible from ``namespace``: its own plus the global ones.
        namespaces = frozenset({namespace, ''})
        with __REGISTRY_LOCK:
            registry = {
                handler.type: handler
                for handler in _NODETYPE_REGISTRY.values()
                if handler.namespace in namespaces
            }
        # In insertion-ordered mode, dict/defaultdict are served by the alternative entries.
        if _C.is_dict_insertion_ordered(namespace):
            registry[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
            registry[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
        return registry

    # A registration in a named namespace is checked before any global or built-in entry.
    if namespace != '':
        handler = _NODETYPE_REGISTRY.get((namespace, cls))
        if handler is not None:
            return handler

    if _C.is_dict_insertion_ordered(namespace):
        if cls is dict:
            return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
        if cls is defaultdict:
            return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY

    handler = _NODETYPE_REGISTRY.get(cls)
    if handler is not None:
        return handler

    # Structseq and namedtuple classes fall back to the generic entry registered for their kind.
    if is_structseq_class(cls):
        return _NODETYPE_REGISTRY.get(StructSequence)
    if is_namedtuple_class(cls):
        return _NODETYPE_REGISTRY.get(namedtuple)  # type: ignore[call-overload] # noqa: PYI024
    return None
@_add_get(pytree_node_registry_get)
def register_pytree_node(
    cls: type[Collection[T]],
    /,
    flatten_func: FlattenFunc[T],
    unflatten_func: UnflattenFunc[T],
    *,
    path_entry_type: type[PyTreeEntry] = AutoEntry,
    namespace: str,
) -> type[Collection[T]]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
            and returning a triple or optionally a pair, with (1) an iterable for the children to be
            flattened recursively, and (2) some hashable metadata to be stored in the treespec and
            to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
            entries to the corresponding children. If the entries are not provided or given by
            :data:`None`, then ``range(len(children))`` will be used.
        unflatten_func (callable): A function taking two arguments: the metadata that was returned
            by ``flatten_func`` and stored in the treespec, and the unflattened children. The
            function should return an instance of ``cls``.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: :class:`AutoEntry`)
        namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
            This is used to isolate the registry from other modules that might register a different
            custom behavior for the same type.

    Returns:
        The same type as the input ``cls``.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='set',
        ... )
        <class 'set'>

        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )
        <class 'torch.Tensor'>

        >>> # doctest: +SKIP
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
        ...
        ... def flatparam2tensor(metadata, children):
        ...     return children[0].reshape(metadata)
        ...
        ... register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=tensor2flatparam,
        ...     unflatten_func=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )
        <class 'torch.Tensor'>

        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    """  # pylint: disable=line-too-long
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Global registrations are keyed by the bare type, namespaced ones by ``(namespace, type)``.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    with __REGISTRY_LOCK:
        # Register with the C extension first: if that call raises, the Python-side registry
        # below is left untouched.
        _C.register_node(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type,
            namespace,
        )
        _NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type=path_entry_type,
            namespace=namespace,
        )
    return cls
# The lookup implementation is now reachable only as ``register_pytree_node.get``; drop the
# module-level helper names.
del _add_get
del pytree_node_registry_get
# Overload: decorator-factory form, e.g. ``@register_pytree_node_class('ns')`` or
# ``@register_pytree_node_class(namespace='ns')``; returns a decorator.
@overload
def register_pytree_node_class(
    cls: str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> Callable[[CustomTreeNodeType], CustomTreeNodeType]: ...


# Overload: direct form, e.g. ``register_pytree_node_class(Cls, namespace='ns')``; returns ``Cls``.
# ``path_entry_type`` defaults to ``None`` here as in the implementation, so it may be omitted.
@overload
def register_pytree_node_class(
    cls: CustomTreeNodeType,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str,
) -> CustomTreeNodeType: ...
def register_pytree_node_class(  # noqa: C901
    cls: CustomTreeNodeType | str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type, optional): A Python type to treat as an internal pytree node. A string may be
            passed instead; it is then used as the ``namespace`` and a decorator is returned.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: the class attribute ``TREE_PATH_ENTRY_TYPE`` if present, otherwise
            :class:`AutoEntry`)
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
            type registry. This is used to isolate the registry from other modules that might
            register a different custom behavior for the same type.

    Returns:
        The same type as the input ``cls`` if it is provided. Otherwise, a decorator function that
        registers the class as a pytree node.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is not specified, or is given both positionally and as a
            keyword argument.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    This function is a thin wrapper around :func:`register_pytree_node`, and provides a
    class-oriented interface::

        @register_pytree_node_class(namespace='foo')
        class Special:
            TREE_PATH_ENTRY_TYPE = GetAttrEntry

            def __init__(self, x, y):
                self.x = x
                self.y = y

            def tree_flatten(self):
                return ((self.x, self.y), None, ('x', 'y'))

            @classmethod
            def tree_unflatten(cls, metadata, children):
                return cls(*children)

        @register_pytree_node_class('mylist')
        class MyList(UserList):
            TREE_PATH_ENTRY_TYPE = SequenceEntry

            def tree_flatten(self):
                return self.data, None, None

            @classmethod
            def tree_unflatten(cls, metadata, children):
                return cls(children)
    """
    # A string (or the global sentinel) as first argument is the namespace of the
    # decorator-factory form ``@register_pytree_node_class('namespace')``.
    if cls is __GLOBAL_NAMESPACE or isinstance(cls, str):
        if namespace is not None:
            raise ValueError('Cannot specify `namespace` when the first argument is a string.')
        if cls == '':
            raise ValueError('The namespace cannot be an empty string.')
        cls, namespace = None, cls
    if namespace is None:
        raise ValueError('Must specify `namespace` when the first argument is a class.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    if cls is None:
        # No class yet: defer the registration until the decorator is applied.
        def decorator(cls: CustomTreeNodeType, /) -> CustomTreeNodeType:
            return register_pytree_node_class(
                cls,
                path_entry_type=path_entry_type,
                namespace=namespace,
            )

        return decorator

    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if path_entry_type is None:
        # Fall back to the class-level declaration, then to ``AutoEntry``.
        path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry)
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
    register_pytree_node(
        cls,
        methodcaller('tree_flatten'),
        cls.tree_unflatten,
        path_entry_type=path_entry_type,
        namespace=namespace,
    )
    return cls
def unregister_pytree_node(cls: type, /, *, namespace: str) -> PyTreeNodeRegistryEntry:
    """Remove a type from the pytree node registry.

    See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.

    This function is the inverse operation of function :func:`register_pytree_node`.

    Args:
        cls (type): A Python type to remove from the pytree node registry.
        namespace (str): The namespace of the pytree node registry to remove the type from.

    Returns:
        The removed registry entry.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is a built-in type that cannot be unregistered.
        ValueError: If the type is not found in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='temp',
        ... )
        <class 'set'>

        >>> # Unregister the Python type (the removed entry is returned)
        >>> unregister_pytree_node(set, namespace='temp')  # doctest: +ELLIPSIS
        PyTreeNodeRegistryEntry(type=<class 'set'>, ..., namespace='temp')
    """
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Same key scheme as registration: bare type for global, ``(namespace, type)`` otherwise.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    with __REGISTRY_LOCK:
        # Unregister from the C extension first; if that call raises, the Python-side entry is
        # not removed.
        _C.unregister_node(cls, namespace)
        return _NODETYPE_REGISTRY.pop(registration_key)
@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]:
    """Context manager to temporarily set the dictionary ordering mode of a namespace.

    While the context is active, ``dict`` and ``defaultdict`` nodes flattened in ``namespace``
    keep their keys in insertion order when ``mode`` is true; when ``mode`` is false, the
    namespace's own insertion-ordered flag is cleared (keys are then sorted, unless insertion
    order is inherited from the global namespace). On exit, the namespace's previous own setting
    is restored, even if the body raises.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> with dict_insertion_ordered(True, namespace='some-namespace'):  # doctest: +IGNORE_WHITESPACE
    ...     tree_flatten(tree, namespace='some-namespace')
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
    )

    .. warning::
        The dictionary ordering mode is a global setting and is **not thread-safe**. It is
        recommended to use this context manager in a single-threaded environment.

    Args:
        mode (bool): If true, keep dictionary keys in insertion order inside the context.
        namespace (str): The namespace to set the dictionary ordering mode for.

    Raises:
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
    """
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    target_namespace = '' if namespace is __GLOBAL_NAMESPACE else namespace

    with __REGISTRY_LOCK:
        # Save the namespace's own flag (not the inherited one) so it can be restored exactly.
        previous_mode = _C.is_dict_insertion_ordered(
            target_namespace,
            inherit_global_namespace=False,
        )
        _C.set_dict_insertion_ordered(bool(mode), target_namespace)
    try:
        yield
    finally:
        with __REGISTRY_LOCK:
            _C.set_dict_insertion_ordered(previous_mode, target_namespace)
def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]:
    """Return the ``(key, value)`` pairs of ``items`` as a list ordered by key.

    Ordering is delegated to ``total_order_sorted`` (presumably tolerant of keys that are not
    mutually comparable; see ``optree.utils``).
    """
    key_of = itemgetter(0)
    return total_order_sorted(items, key=key_of)
def _none_flatten(_: None, /) -> tuple[tuple[()], None]:
    """Flatten ``None``: it has no children and no metadata."""
    no_children: tuple[()] = ()
    return no_children, None


def _none_unflatten(_: None, children: Iterable[Any], /) -> None:
    """Rebuild ``None`` after checking that ``children`` is empty.

    Raises:
        ValueError: If ``children`` yields at least one element.
    """
    for _unexpected in children:
        raise ValueError('Expected no children.')
    return None
def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]:
    """Flatten a tuple: the tuple itself is the children sequence; there is no metadata."""
    metadata = None
    return tup, metadata


def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]:
    """Rebuild a tuple from ``children``; the metadata argument is ignored."""
    rebuilt = tuple(children)
    return rebuilt
def _list_flatten(lst: list[T], /) -> tuple[list[T], None]:
    """Flatten a list: the list itself is the children sequence; there is no metadata."""
    metadata = None
    return lst, metadata


def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]:
    """Rebuild a new list from ``children``; the metadata argument is ignored."""
    return [*children]
def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
    """Flatten a dict with its keys in sorted order.

    Returns the values (children), the keys as a list (metadata), and the keys as a tuple (path
    entries), all in the same key-sorted order.
    """
    sorted_keys, sorted_values = unzip2(_sorted_items(dct.items()))
    metadata = [*sorted_keys]
    return sorted_values, metadata, sorted_keys


def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a dict by pairing ``keys`` with ``values`` (via ``safe_zip``)."""
    pairs = safe_zip(keys, values)
    return dict(pairs)
def _dict_insertion_ordered_flatten(
    dct: dict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten a dict keeping its keys in insertion order.

    Returns the values (children), the keys as a list (metadata), and the keys as a tuple (path
    entries), all in insertion order.
    """
    ordered_keys, ordered_values = unzip2(dct.items())
    metadata = [*ordered_keys]
    return ordered_values, metadata, ordered_keys


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a dict by pairing ``keys`` with ``values``, preserving the given key order."""
    pairs = safe_zip(keys, values)
    return dict(pairs)
def _ordereddict_flatten(
    dct: OrderedDict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten an OrderedDict keeping its key order.

    Returns the values (children), the keys as a list (metadata), and the keys as a tuple (path
    entries), all in the dict's own order.
    """
    ordered_keys, ordered_values = unzip2(dct.items())
    metadata = [*ordered_keys]
    return ordered_values, metadata, ordered_keys


def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]:
    """Rebuild an OrderedDict by pairing ``keys`` with ``values`` in the given order."""
    pairs = safe_zip(keys, values)
    return OrderedDict(pairs)
def _defaultdict_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a defaultdict with its keys in sorted order.

    The metadata pairs the ``default_factory`` (which may be :data:`None`) with the sorted keys;
    children and path entries are those produced by ``_dict_flatten``.
    """
    values, keys, entries = _dict_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a defaultdict from ``(default_factory, keys)`` metadata and the child values.

    ``default_factory`` may be :data:`None`, exactly as produced by ``_defaultdict_flatten``.
    """
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_unflatten(keys, values))
def _defaultdict_insertion_ordered_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a defaultdict keeping its keys in insertion order.

    The metadata pairs the ``default_factory`` (which may be :data:`None`) with the keys;
    children and path entries are those produced by ``_dict_insertion_ordered_flatten``.
    """
    values, keys, entries = _dict_insertion_ordered_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_insertion_ordered_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a defaultdict from ``(default_factory, keys)`` metadata, preserving key order.

    ``default_factory`` may be :data:`None`, exactly as produced by the matching flatten function.
    """
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))
def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]:
    """Flatten a deque: the deque itself holds the children; its ``maxlen`` is the metadata."""
    max_length = deq.maxlen
    return deq, max_length


def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]:
    """Rebuild a deque from ``children`` with the recorded ``maxlen``."""
    rebuilt: deque[T] = deque(children, maxlen)
    return rebuilt
def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]:  # type: ignore[type-arg]
    """Flatten a namedtuple: the instance holds the children; its exact class is the metadata."""
    node_class = type(tup)
    return tup, node_class


# pylint: disable-next=line-too-long
def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]:  # type: ignore[type-arg]
    """Rebuild a namedtuple of class ``cls`` by passing ``children`` as positional fields."""
    fields = tuple(children)
    return cls(*fields)  # type: ignore[call-overload]
def _structseq_flatten(seq: StructSequence[T], /) -> tuple[tuple[T, ...], type[StructSequence[T]]]:
    """Flatten a structseq: the instance holds the children; its exact class is the metadata."""
    node_class = type(seq)
    return seq, node_class


def _structseq_unflatten(
    cls: type[StructSequence[T]],
    children: Iterable[T],
    /,
) -> StructSequence[T]:
    """Rebuild a structseq of class ``cls``, passing ``children`` as its single argument."""
    rebuilt = cls(children)
    return rebuilt
# Built-in pytree node types, registered in the global namespace and keyed by their bare type.
# The insertion order below is the iteration order of registry snapshots.
_NODETYPE_REGISTRY: dict[type | tuple[str, type], PyTreeNodeRegistryEntry] = {
    entry.type: entry
    for entry in (
        PyTreeNodeRegistryEntry(
            type(None),  # type: ignore[arg-type]
            _none_flatten,
            _none_unflatten,
            path_entry_type=PyTreeEntry,
            kind=PyTreeKind.NONE,
        ),
        PyTreeNodeRegistryEntry(
            tuple,
            _tuple_flatten,
            _tuple_unflatten,
            path_entry_type=SequenceEntry,
            kind=PyTreeKind.TUPLE,
        ),
        PyTreeNodeRegistryEntry(
            list,
            _list_flatten,
            _list_unflatten,
            path_entry_type=SequenceEntry,
            kind=PyTreeKind.LIST,
        ),
        PyTreeNodeRegistryEntry(
            dict,
            _dict_flatten,
            _dict_unflatten,
            path_entry_type=MappingEntry,
            kind=PyTreeKind.DICT,
        ),
        # ``namedtuple`` (the factory function) is the key for the generic namedtuple entry.
        PyTreeNodeRegistryEntry(
            namedtuple,  # type: ignore[arg-type] # noqa: PYI024
            _namedtuple_flatten,
            _namedtuple_unflatten,
            path_entry_type=NamedTupleEntry,
            kind=PyTreeKind.NAMEDTUPLE,
        ),
        PyTreeNodeRegistryEntry(
            OrderedDict,
            _ordereddict_flatten,
            _ordereddict_unflatten,
            path_entry_type=MappingEntry,
            kind=PyTreeKind.ORDEREDDICT,
        ),
        PyTreeNodeRegistryEntry(
            defaultdict,
            _defaultdict_flatten,
            _defaultdict_unflatten,
            path_entry_type=MappingEntry,
            kind=PyTreeKind.DEFAULTDICT,
        ),
        PyTreeNodeRegistryEntry(
            deque,
            _deque_flatten,
            _deque_unflatten,
            path_entry_type=SequenceEntry,
            kind=PyTreeKind.DEQUE,
        ),
        PyTreeNodeRegistryEntry(
            StructSequence,
            _structseq_flatten,
            _structseq_unflatten,
            path_entry_type=StructSequenceEntry,
            kind=PyTreeKind.STRUCTSEQUENCE,
        ),
    )
}
# Alternative entries for ``dict`` and ``defaultdict`` that keep keys in insertion order. They
# are used instead of the built-in entries when a namespace is in insertion-ordered mode.
_DICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    type=dict,
    flatten_func=_dict_insertion_ordered_flatten,
    unflatten_func=_dict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DICT,
)
_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    type=defaultdict,
    flatten_func=_defaultdict_insertion_ordered_flatten,
    unflatten_func=_defaultdict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DEFAULTDICT,
)