| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Registry for custom pytree node types.""" |
| |
|
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import contextlib |
| | import dataclasses |
| | import inspect |
| | import sys |
| | from collections import OrderedDict, defaultdict, deque, namedtuple |
| | from operator import itemgetter, methodcaller |
| | from threading import Lock |
| | from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, overload |
| |
|
| | import optree._C as _C |
| | from optree.accessors import ( |
| | AutoEntry, |
| | MappingEntry, |
| | NamedTupleEntry, |
| | PyTreeEntry, |
| | SequenceEntry, |
| | StructSequenceEntry, |
| | ) |
| | from optree.typing import PyTreeKind, StructSequence, T, is_namedtuple_class, is_structseq_class |
| | from optree.utils import safe_zip, total_order_sorted, unzip2 |
| |
|
| |
|
if TYPE_CHECKING:
    # These names exist only for static type checkers. At runtime they are never bound, which
    # is fine because ``from __future__ import annotations`` keeps annotations unevaluated.
    import builtins
    from collections.abc import Collection, Generator, Iterable

    from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, UnflattenFunc

    # Type variable for classes implementing ``CustomTreeNode``; the
    # ``register_pytree_node_class`` overloads use it to return the decorated class type unchanged.
    CustomTreeNodeType = TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])
| |
|
| |
|
# Public API of this module. ``PyTreeNodeRegistryEntry`` and the private flatten/unflatten
# helpers defined below are not listed here.
__all__ = [
    'register_pytree_node',
    'register_pytree_node_class',
    'unregister_pytree_node',
    'dict_insertion_ordered',
]
| |
|
| |
|
# ``dataclasses.dataclass(slots=...)`` is only accepted on Python 3.10+.
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {}


@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS)
class PyTreeNodeRegistryEntry(Generic[T]):
    """A dataclass that stores the information of a pytree node type.

    Instances are immutable (``frozen=True``). The first three fields may be passed positionally;
    on Python 3.10+ the remaining fields are keyword-only.

    Attributes:
        type: The registered node class.
        flatten_func: Function that splits an instance into its children and metadata (and
            optionally the path entries of the children).
        unflatten_func: Function that rebuilds an instance from the metadata and the children.
        path_entry_type: The :class:`PyTreeEntry` subclass used for the tree paths of children.
        kind: The :class:`PyTreeKind` of the node; ``CUSTOM`` unless given explicitly.
        namespace: The registry namespace; ``''`` denotes the global namespace.
    """

    type: builtins.type[Collection[T]]
    flatten_func: FlattenFunc[T]
    unflatten_func: UnflattenFunc[T]

    # ``dataclasses.KW_ONLY`` (3.10+) makes every field declared after it keyword-only.
    if sys.version_info >= (3, 10):
        _: dataclasses.KW_ONLY

    path_entry_type: builtins.type[PyTreeEntry] = AutoEntry
    kind: PyTreeKind = PyTreeKind.CUSTOM
    namespace: str = ''


# Only needed while building the dataclass above.
del SLOTS
| |
|
| |
|
| | |
class GlobalNamespace:
    """Sentinel type whose single instance identifies the internal global namespace."""

    __slots__: ClassVar[tuple[()]] = ()

    def __repr__(self, /) -> str:
        return '<GLOBAL NAMESPACE>'


# Sentinel for the global namespace. It is annotated as ``str`` so it can be passed wherever a
# ``namespace: str`` parameter is expected, but it is NOT a ``str`` instance: the functions in
# this module compare it by identity and translate it to ``''`` before talking to ``_C``.
__GLOBAL_NAMESPACE: str = GlobalNamespace()
# Guards access to the C++ registry (``_C``) together with ``_NODETYPE_REGISTRY``.
__REGISTRY_LOCK: Lock = Lock()
# Only the sentinel instance is kept; the class name is removed from the module namespace.
del GlobalNamespace
| |
|
| |
|
if TYPE_CHECKING:
    # Typing helpers used only by static type checkers to describe ``_add_get``.
    from typing_extensions import ParamSpec

    _P = ParamSpec('_P')
    _T = TypeVar('_T')
    _GetP = ParamSpec('_GetP')
    _GetT = TypeVar('_GetT')

    class _CallableWithGet(Generic[_P, _T, _GetP, _GetT]):
        """Static-typing-only type: a callable that also exposes a ``get`` method.

        ``_add_get`` is annotated to return this type for the functions it decorates.
        """

        def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
            raise NotImplementedError

        # pylint: disable-next=missing-function-docstring
        def get(self, /, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT:
            raise NotImplementedError
| |
|
| |
|
def _add_get(
    get: Callable[_GetP, _GetT],
    /,
) -> Callable[
    [Callable[_P, _T]],
    _CallableWithGet[_P, _T, _GetP, _GetT],
]:
    """Create a decorator that exposes ``get`` as the ``.get`` attribute of a function.

    The decorated function object is returned as-is (same identity and call behavior); the
    only change is the added ``get`` attribute.
    """

    def attach_get(target: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]:
        setattr(target, 'get', get)
        return target  # type: ignore[return-value]

    return attach_get
| |
|
| |
|
@overload
def pytree_node_registry_get(
    cls: type,
    /,
    *,
    namespace: str = '',
) -> PyTreeNodeRegistryEntry | None: ...


@overload
def pytree_node_registry_get(
    cls: None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry]: ...


def pytree_node_registry_get(
    cls: type | None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | None:
    """Look up entries in the pytree node registry.

    Without ``cls``, return a mapping from each node type visible in ``namespace`` (the types
    registered in that namespace plus the global ones) to its registry entry. With ``cls``,
    return the entry used for that class, or :data:`None` if the class is treated as a leaf.

    >>> register_pytree_node.get() # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    {
        <class 'NoneType'>: PyTreeNodeRegistryEntry(
            type=<class 'NoneType'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.PyTreeEntry'>,
            kind=<PyTreeKind.NONE: 2>,
            namespace=''
        ),
        <class 'tuple'>: PyTreeNodeRegistryEntry(
            type=<class 'tuple'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.SequenceEntry'>,
            kind=<PyTreeKind.TUPLE: 3>,
            namespace=''
        ),
        <class 'list'>: PyTreeNodeRegistryEntry(
            type=<class 'list'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.SequenceEntry'>,
            kind=<PyTreeKind.LIST: 4>,
            namespace=''
        ),
        ...
    }
    >>> register_pytree_node.get(defaultdict) # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    PyTreeNodeRegistryEntry(
        type=<class 'collections.defaultdict'>,
        flatten_func=<function ...>,
        unflatten_func=<function ...>,
        path_entry_type=<class 'optree.MappingEntry'>,
        kind=<PyTreeKind.DEFAULTDICT: 8>,
        namespace=''
    )
    >>> register_pytree_node.get(frozenset)  # frozenset is considered as a leaf node
    None

    Single-class lookup order:
        1. a registration of ``cls`` in ``namespace`` itself (non-global namespaces only);
        2. the insertion-ordered ``dict``/``defaultdict`` entries, if that mode is enabled;
        3. the global registration of ``cls``;
        4. the generic structseq or namedtuple entry, if ``cls`` is such a class.

    Args:
        cls (type or None, optional): The class whose registry entry is wanted. When omitted, all
            entries visible in the namespace are returned.
        namespace (str, optional): The registry namespace to search. Defaults to the global
            namespace.

    Returns:
        A ``{type: entry}`` dictionary when ``cls`` is omitted. Otherwise the entry for ``cls``,
        or :data:`None` if ``cls`` is not a registered node type (i.e., it is a leaf).

    Raises:
        TypeError: If ``cls`` is neither a class, ``namedtuple``, nor :data:`None`.
        TypeError: If ``namespace`` is not a string.
    """
    if namespace is __GLOBAL_NAMESPACE:
        namespace = ''
    # ``namedtuple`` is a factory function, not a class, but it is a valid registry key.
    acceptable_cls = cls is None or cls is namedtuple or inspect.isclass(cls)
    if not acceptable_cls:
        raise TypeError(f'Expected a class or None, got {cls!r}.')
    if not isinstance(namespace, str):
        raise TypeError(
            f'The namespace must be a string, got {namespace!r}.',
        )

    if cls is None:
        visible_namespaces = {namespace, ''}
        # Snapshot the registry under the lock so concurrent (un)registration cannot mutate it
        # while it is being iterated.
        with __REGISTRY_LOCK:
            snapshot = {
                entry.type: entry
                for entry in _NODETYPE_REGISTRY.values()
                if entry.namespace in visible_namespaces
            }
        if _C.is_dict_insertion_ordered(namespace):
            snapshot[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
            snapshot[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
        return snapshot

    if namespace:
        namespaced_entry = _NODETYPE_REGISTRY.get((namespace, cls))
        if namespaced_entry is not None:
            return namespaced_entry

    if _C.is_dict_insertion_ordered(namespace):
        if cls is dict:
            return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
        if cls is defaultdict:
            return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY

    entry = _NODETYPE_REGISTRY.get(cls)
    if entry is None:
        if is_structseq_class(cls):
            entry = _NODETYPE_REGISTRY.get(StructSequence)
        elif is_namedtuple_class(cls):
            entry = _NODETYPE_REGISTRY.get(namedtuple)
    return entry
| |
|
| |
|
@_add_get(pytree_node_registry_get)
def register_pytree_node(
    cls: type[Collection[T]],
    /,
    flatten_func: FlattenFunc[T],
    unflatten_func: UnflattenFunc[T],
    *,
    path_entry_type: type[PyTreeEntry] = AutoEntry,
    namespace: str,
) -> type[Collection[T]]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
            and returning a triple or optionally a pair, with (1) an iterable for the children to be
            flattened recursively, and (2) some hashable metadata to be stored in the treespec and
            to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
            entries to the corresponding children. If the entries are not provided or given by
            :data:`None`, then ``range(len(children))`` will be used.
        unflatten_func (callable): A function taking two arguments: the metadata that was returned
            by ``flatten_func`` and stored in the treespec, and the unflattened children. The
            function should return an instance of ``cls``.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: :class:`AutoEntry`)
        namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
            This is used to isolate the registry from other modules that might register a different
            custom behavior for the same type.

    Returns:
        The same type as the input ``cls``.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='set',
        ... )
        <class 'set'>

        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )
        <class 'torch.Tensor'>

        >>> # doctest: +SKIP
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
        ...
        >>> def flatparam2tensor(metadata, children):
        ...     return children[0].reshape(metadata)
        ...
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=tensor2flatparam,
        ...     unflatten_func=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )
        <class 'torch.Tensor'>

        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    """
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Global (internal-only) registrations are keyed by the bare class and stored with ``''``;
    # namespaced registrations are keyed by ``(namespace, cls)``.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    # Update the C++ registry first and the Python-side mirror second, under one lock. If
    # ``_C.register_node`` raises, the mirror is left untouched.
    with __REGISTRY_LOCK:
        _C.register_node(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type,
            namespace,
        )
        _NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type=path_entry_type,
            namespace=namespace,
        )
    return cls


# ``pytree_node_registry_get`` remains reachable as ``register_pytree_node.get``; remove the
# module-level helper names.
del pytree_node_registry_get, _add_get
| |
|
| |
|
@overload
def register_pytree_node_class(
    cls: str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> Callable[[CustomTreeNodeType], CustomTreeNodeType]: ...


@overload
def register_pytree_node_class(
    cls: CustomTreeNodeType,
    /,
    *,
    # A default is required here so that ``register_pytree_node_class(Cls, namespace=...)``
    # (without ``path_entry_type``) matches this overload for static type checkers.
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str,
) -> CustomTreeNodeType: ...


def register_pytree_node_class(
    cls: CustomTreeNodeType | str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type or str, optional): A Python type to treat as an internal pytree node. If a string
            is given instead, it is used as the namespace and a decorator is returned.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            If not provided, the class attribute ``TREE_PATH_ENTRY_TYPE`` is used when present,
            otherwise :class:`AutoEntry`.
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
            type registry. This is used to isolate the registry from other modules that might
            register a different custom behavior for the same type.

    Returns:
        The same type as the input ``cls`` if it is provided. Otherwise, a decorator function that
        registers the class as a pytree node.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is missing or is an empty string.
        ValueError: If the first argument is a string and ``namespace`` is also given.
        ValueError: If the type is already registered in the registry.

    This function is a thin wrapper around :func:`register_pytree_node`, and provides a
    class-oriented interface. The class's ``tree_flatten`` method and ``tree_unflatten``
    classmethod are used as the flatten and unflatten functions::

        @register_pytree_node_class(namespace='foo')
        class Special:
            TREE_PATH_ENTRY_TYPE = GetAttrEntry

            def __init__(self, x, y):
                self.x = x
                self.y = y

            def tree_flatten(self):
                return ((self.x, self.y), None, ('x', 'y'))

            @classmethod
            def tree_unflatten(cls, metadata, children):
                return cls(*children)

        @register_pytree_node_class('mylist')
        class MyList(UserList):
            TREE_PATH_ENTRY_TYPE = SequenceEntry

            def tree_flatten(self):
                return self.data, None, None

            @classmethod
            def tree_unflatten(cls, metadata, children):
                # ``UserList.__init__`` takes a single iterable, not unpacked items.
                return cls(children)
    """
    # Decorator-factory form: ``@register_pytree_node_class('namespace')``.
    if cls is __GLOBAL_NAMESPACE or isinstance(cls, str):
        if namespace is not None:
            raise ValueError('Cannot specify `namespace` when the first argument is a string.')
        if cls == '':
            raise ValueError('The namespace cannot be an empty string.')
        cls, namespace = None, cls

    if namespace is None:
        raise ValueError('Must specify `namespace` when the first argument is a class.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    if cls is None:

        def decorator(cls: CustomTreeNodeType, /) -> CustomTreeNodeType:
            return register_pytree_node_class(
                cls,
                path_entry_type=path_entry_type,
                namespace=namespace,
            )

        return decorator

    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if path_entry_type is None:
        path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry)
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')

    register_pytree_node(
        cls,
        methodcaller('tree_flatten'),
        cls.tree_unflatten,
        path_entry_type=path_entry_type,
        namespace=namespace,
    )
    return cls
| |
|
| |
|
def unregister_pytree_node(cls: type, /, *, namespace: str) -> PyTreeNodeRegistryEntry:
    """Remove a type from the pytree node registry.

    See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.

    This function is the inverse operation of function :func:`register_pytree_node`.

    Args:
        cls (type): A Python type to remove from the pytree node registry.
        namespace (str): The namespace of the pytree node registry to remove the type from.

    Returns:
        The removed registry entry.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is a built-in type that cannot be unregistered.
        ValueError: If the type is not found in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='temp',
        ... )
        <class 'set'>

        >>> # Unregister the Python type; the removed registry entry is returned
        >>> unregister_pytree_node(set, namespace='temp')  # doctest: +ELLIPSIS
        PyTreeNodeRegistryEntry(type=<class 'set'>, ..., namespace='temp')
    """
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Same key scheme as ``register_pytree_node``: bare class for the global namespace,
    # ``(namespace, cls)`` otherwise.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    # Remove from the C++ registry first; if that raises, the Python-side mirror is untouched.
    with __REGISTRY_LOCK:
        _C.unregister_node(cls, namespace)
        return _NODETYPE_REGISTRY.pop(registration_key)
| |
|
| |
|
@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]:
    """Temporarily choose between sorted-key and insertion-order dictionary flattening.

    Inside the ``with`` block, dictionaries flattened under ``namespace`` either keep their
    insertion order (``mode=True``) or have their keys sorted (``mode=False``). On exit, the
    previous setting for that namespace is restored, even if the block raises.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> with dict_insertion_ordered(True, namespace='some-namespace'):  # doctest: +IGNORE_WHITESPACE
    ...     tree_flatten(tree, namespace='some-namespace')
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
    )

    .. warning::
        The dictionary sorting mode is a global setting and is **not thread-safe**. It is
        recommended to use this context manager in a single-threaded environment.

    Args:
        mode (bool): Whether to keep insertion order (``True``) or sort keys (``False``).
        namespace (str): The namespace whose dictionary sorting mode is changed.

    Raises:
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
    """
    if namespace is __GLOBAL_NAMESPACE:
        target_namespace = ''
    elif not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    elif namespace == '':
        raise ValueError('The namespace cannot be an empty string.')
    else:
        target_namespace = namespace

    # Read the namespace's own setting (not the inherited global one) so that exactly that
    # value is restored on exit.
    with __REGISTRY_LOCK:
        saved_mode = _C.is_dict_insertion_ordered(target_namespace, inherit_global_namespace=False)
        _C.set_dict_insertion_ordered(bool(mode), target_namespace)

    try:
        yield
    finally:
        with __REGISTRY_LOCK:
            _C.set_dict_insertion_ordered(saved_mode, target_namespace)
| |
|
| |
|
def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]:
    """Return ``(key, value)`` pairs ordered by key, via ``total_order_sorted``."""
    by_key = itemgetter(0)
    return total_order_sorted(items, key=by_key)
| |
|
| |
|
def _none_flatten(_: None, /) -> tuple[tuple[()], None]:
    """Flatten ``None``: it has no children and no metadata."""
    no_children: tuple[()] = ()
    return no_children, None


def _none_unflatten(_: None, children: Iterable[Any], /) -> None:
    """Rebuild ``None``; raise :exc:`ValueError` if ``children`` yields anything."""
    # Only the first item is pulled from ``children``; one is enough to reject it.
    for _unexpected in children:
        raise ValueError('Expected no children.')
    return None
| |
|
| |
|
def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]:
    """Flatten a tuple: the tuple itself is the children sequence; there is no metadata."""
    metadata = None
    return tup, metadata


def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]:
    """Rebuild a tuple from ``children``."""
    rebuilt = tuple(children)
    return rebuilt


def _list_flatten(lst: list[T], /) -> tuple[list[T], None]:
    """Flatten a list: the list itself is the children sequence; there is no metadata."""
    metadata = None
    return lst, metadata


def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]:
    """Rebuild a new list from ``children``."""
    return [*children]
| |
|
| |
|
def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
    """Flatten a dict in sorted-key order.

    Returns the values as the children, a list of the keys as the metadata, and the keys again
    as the path entries of the children.
    """
    ordered_keys, ordered_values = unzip2(_sorted_items(dct.items()))
    key_list = list(ordered_keys)
    return ordered_values, key_list, ordered_keys


def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a dict by pairing ``keys`` with ``values`` through ``safe_zip``."""
    pairs = safe_zip(keys, values)
    return dict(pairs)
| |
|
| |
|
def _dict_insertion_ordered_flatten(
    dct: dict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten a dict keeping its insertion order (no key sorting).

    Returns the values as the children, a list of the keys as the metadata, and the keys again
    as the path entries of the children.
    """
    in_order_keys, in_order_values = unzip2(dct.items())
    key_list = list(in_order_keys)
    return in_order_values, key_list, in_order_keys


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a dict by pairing ``keys`` with ``values`` through ``safe_zip``."""
    pairs = safe_zip(keys, values)
    return dict(pairs)
| |
|
| |
|
def _ordereddict_flatten(
    dct: OrderedDict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten an OrderedDict in its own key order.

    Returns the values as the children, a list of the keys as the metadata, and the keys again
    as the path entries of the children.
    """
    in_order_keys, in_order_values = unzip2(dct.items())
    key_list = list(in_order_keys)
    return in_order_values, key_list, in_order_keys


def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]:
    """Rebuild an OrderedDict by pairing ``keys`` with ``values`` through ``safe_zip``."""
    pairs = safe_zip(keys, values)
    return OrderedDict(pairs)
| |
|
| |
|
def _defaultdict_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a defaultdict like a dict (sorted keys), also recording its ``default_factory``.

    The metadata is ``(default_factory, keys)``; ``default_factory`` may be :data:`None`.
    """
    values, keys, entries = _dict_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_unflatten(
    # ``default_factory`` may be ``None`` (as produced by ``_defaultdict_flatten``).
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a defaultdict from ``(default_factory, keys)`` metadata and the values."""
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_unflatten(keys, values))
| |
|
| |
|
def _defaultdict_insertion_ordered_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a defaultdict in insertion order, also recording its ``default_factory``.

    The metadata is ``(default_factory, keys)``; ``default_factory`` may be :data:`None`.
    """
    values, keys, entries = _dict_insertion_ordered_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_insertion_ordered_unflatten(
    # ``default_factory`` may be ``None`` (as produced by the matching flatten function).
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a defaultdict from ``(default_factory, keys)`` metadata and the values."""
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))
| |
|
| |
|
def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]:
    """Flatten a deque: the deque itself is the children; its ``maxlen`` is the metadata."""
    bound = deq.maxlen
    return deq, bound


def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]:
    """Rebuild a deque from ``children`` with the recorded ``maxlen``."""
    rebuilt: deque[T] = deque(children, maxlen=maxlen)
    return rebuilt
| |
|
| |
|
def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]:
    """Flatten a namedtuple: the fields are the children and its class is the metadata."""
    namedtuple_type = type(tup)
    return tup, namedtuple_type


def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]:
    """Instantiate the namedtuple class ``cls`` with ``children`` as positional fields."""
    field_values = tuple(children)
    return cls(*field_values)
| |
|
| |
|
def _structseq_flatten(seq: StructSequence[T], /) -> tuple[tuple[T, ...], type[StructSequence[T]]]:
    """Flatten a structseq: the sequence itself is the children and its class is the metadata."""
    seq_type = type(seq)
    return seq, seq_type


def _structseq_unflatten(
    cls: type[StructSequence[T]],
    children: Iterable[T],
    /,
) -> StructSequence[T]:
    """Instantiate ``cls`` with ``children`` passed as a single sequence argument."""
    instance = cls(children)
    return instance
| |
|
| |
|
# Python-side mirror of the pytree node registry, pre-populated with the built-in node types.
# Keys are the bare class for global-namespace entries and ``(namespace, cls)`` tuples for
# namespaced registrations added by ``register_pytree_node``. ``namedtuple`` and
# ``StructSequence`` are stand-in keys that ``pytree_node_registry_get`` falls back to for
# namedtuple and structseq classes. Insertion order determines the order of entries returned
# by ``register_pytree_node.get()``.
_NODETYPE_REGISTRY: dict[type | tuple[str, type], PyTreeNodeRegistryEntry] = {
    type(None): PyTreeNodeRegistryEntry(
        type(None),
        _none_flatten,
        _none_unflatten,
        path_entry_type=PyTreeEntry,
        kind=PyTreeKind.NONE,
    ),
    tuple: PyTreeNodeRegistryEntry(
        tuple,
        _tuple_flatten,
        _tuple_unflatten,
        path_entry_type=SequenceEntry,
        kind=PyTreeKind.TUPLE,
    ),
    list: PyTreeNodeRegistryEntry(
        list,
        _list_flatten,
        _list_unflatten,
        path_entry_type=SequenceEntry,
        kind=PyTreeKind.LIST,
    ),
    # Sorted-key behavior; see ``_DICT_INSERTION_ORDERED_REGISTRY_ENTRY`` for the alternative.
    dict: PyTreeNodeRegistryEntry(
        dict,
        _dict_flatten,
        _dict_unflatten,
        path_entry_type=MappingEntry,
        kind=PyTreeKind.DICT,
    ),
    namedtuple: PyTreeNodeRegistryEntry(
        namedtuple,
        _namedtuple_flatten,
        _namedtuple_unflatten,
        path_entry_type=NamedTupleEntry,
        kind=PyTreeKind.NAMEDTUPLE,
    ),
    OrderedDict: PyTreeNodeRegistryEntry(
        OrderedDict,
        _ordereddict_flatten,
        _ordereddict_unflatten,
        path_entry_type=MappingEntry,
        kind=PyTreeKind.ORDEREDDICT,
    ),
    defaultdict: PyTreeNodeRegistryEntry(
        defaultdict,
        _defaultdict_flatten,
        _defaultdict_unflatten,
        path_entry_type=MappingEntry,
        kind=PyTreeKind.DEFAULTDICT,
    ),
    deque: PyTreeNodeRegistryEntry(
        deque,
        _deque_flatten,
        _deque_unflatten,
        path_entry_type=SequenceEntry,
        kind=PyTreeKind.DEQUE,
    ),
    StructSequence: PyTreeNodeRegistryEntry(
        StructSequence,
        _structseq_flatten,
        _structseq_unflatten,
        path_entry_type=StructSequenceEntry,
        kind=PyTreeKind.STRUCTSEQUENCE,
    ),
}
| |
|
| |
|
# Alternative entries for ``dict`` and ``defaultdict`` that keep insertion order instead of
# sorting keys. They are not stored in ``_NODETYPE_REGISTRY``; ``pytree_node_registry_get``
# returns them in place of the sorted-key entries when ``_C.is_dict_insertion_ordered(namespace)``
# is true.
_DICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    dict,
    _dict_insertion_ordered_flatten,
    _dict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DICT,
)
_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    defaultdict,
    _defaultdict_insertion_ordered_flatten,
    _defaultdict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DEFAULTDICT,
)
| |
|