|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTree integration with :mod:`dataclasses`. |
|
|
|
|
|
This module implements PyTree integration with :mod:`dataclasses` by redefining the :func:`field`, |
|
|
:func:`dataclass`, and :func:`make_dataclass` functions. Other APIs are re-exported from the |
|
|
original :mod:`dataclasses` module. |
|
|
|
|
|
The PyTree integration allows dataclasses to be flattened and unflattened recursively. The fields |
|
|
are stored in a special attribute named ``__optree_dataclass_fields__`` in the dataclass. |
|
|
|
|
|
>>> import math |
|
|
... import optree |
|
|
... |
|
|
>>> @optree.dataclasses.dataclass(namespace='my_module') |
|
|
... class Point: |
|
|
... x: float |
|
|
... y: float |
|
|
... z: float = 0.0 |
|
|
... norm: float = optree.dataclasses.field(init=False, pytree_node=False) |
|
|
... |
|
|
... def __post_init__(self) -> None: |
|
|
... self.norm = math.hypot(self.x, self.y, self.z) |
|
|
... |
|
|
>>> point = Point(2.0, 6.0, 3.0) |
|
|
>>> point |
|
|
Point(x=2.0, y=6.0, z=3.0, norm=7.0) |
|
|
>>> # Flatten without specifying the namespace |
|
|
>>> optree.tree_flatten(point) # `Point`s are leaf nodes |
|
|
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*)) |
|
|
>>> # Flatten with the namespace |
|
|
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module') |
|
|
>>> accessors, leaves, treespec # doctest: +IGNORE_WHITESPACE,ELLIPSIS |
|
|
( |
|
|
[ |
|
|
PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)), |
|
|
PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)), |
|
|
PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),)) |
|
|
], |
|
|
[2.0, 6.0, 3.0], |
|
|
PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module') |
|
|
) |
|
|
>>> point == optree.tree_unflatten(treespec, leaves) |
|
|
True |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import contextlib |
|
|
import dataclasses |
|
|
import functools |
|
|
import inspect |
|
|
import sys |
|
|
import types |
|
|
from dataclasses import * |
|
|
from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, TypeVar, overload |
|
|
from typing_extensions import dataclass_transform |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Iterable |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Export exactly the names the standard :mod:`dataclasses` module exports. A fresh list is built so
# that this module never shares (or mutates) the standard library's own ``__all__`` object.
__all__ = list(dataclasses.__all__)


# Name of the class attribute under which the PyTree-aware ``dataclass`` decorator stores the
# ``(children_fields, metadata_fields)`` pair.
_FIELDS = '__optree_dataclass_fields__'
# Value assumed for ``pytree_node`` when a field's metadata does not contain that key.
_PYTREE_NODE_DEFAULT: bool = True


# Generic placeholders used in the annotations of this module.
_T = TypeVar('_T')
_U = TypeVar('_U')
_TypeT = TypeVar('_TypeT', bound=type)
|
|
|
|
|
|
|
|
@overload
def field(
    *,
    default: _T,
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,
    doc: str | None = None,
    pytree_node: bool | None = None,
) -> _T: ...


@overload
def field(
    *,
    default_factory: Callable[[], _T],
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,
    doc: str | None = None,
    pytree_node: bool | None = None,
) -> _T: ...


@overload
def field(
    *,
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,
    doc: str | None = None,
    pytree_node: bool | None = None,
) -> Any: ...


def field(
    *,
    default: Any = dataclasses.MISSING,
    default_factory: Any = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,
    hash: bool | None = None,
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,
    doc: str | None = None,
    pytree_node: bool | None = None,
) -> Any:
    """Field factory for :func:`dataclass`.

    This factory function is used to define the fields in a dataclass. It is similar to the field
    factory :func:`dataclasses.field`, but with an additional ``pytree_node`` parameter. If
    ``pytree_node`` is :data:`True` (default), the field will be considered a child node in the
    PyTree structure which can be recursively flattened and unflattened. Otherwise, the field will
    be considered as PyTree metadata.

    Setting ``pytree_node`` in the field factory is equivalent to setting a key ``'pytree_node'`` in
    ``metadata`` in the original field factory. The ``pytree_node`` value can be accessed using
    ``field.metadata['pytree_node']``. If ``pytree_node`` is :data:`None`, the value
    ``metadata.get('pytree_node', True)`` will be used.

    .. note::
        If a field is considered a child node, it must be included in the argument list of the
        :meth:`__init__` method, i.e., pass ``init=True`` in the field factory.

    Args:
        default (optional): Forwarded unchanged to :func:`dataclasses.field`.
        default_factory (optional): Forwarded unchanged to :func:`dataclasses.field`.
        init (bool, optional): Forwarded unchanged to :func:`dataclasses.field`. Must be true
            for PyTree node fields.
        repr (bool, optional): Forwarded unchanged to :func:`dataclasses.field`.
        hash (bool or None, optional): Forwarded unchanged to :func:`dataclasses.field`.
        compare (bool, optional): Forwarded unchanged to :func:`dataclasses.field`.
        metadata (dict or None, optional): Extra field metadata. It is copied into a new dict
            (the argument itself is never mutated) and the key ``'pytree_node'`` is then set on
            the copy. Any mapping is accepted at runtime.
        kw_only (bool, optional): Forwarded to :func:`dataclasses.field` on Python 3.10+ only.
        doc (str or None, optional): Forwarded to :func:`dataclasses.field` on Python 3.14+ only.
        pytree_node (bool or None, optional): Whether the field is a PyTree node. If
            :data:`None`, falls back to ``metadata['pytree_node']`` and then to :data:`True`.

    Returns:
        dataclasses.Field: The field defined using the provided arguments with
        ``field.metadata['pytree_node']`` set.

    Raises:
        TypeError: If ``kw_only`` is given on Python < 3.10, if ``doc`` is given on
            Python < 3.14, or if ``init`` is false while the resolved ``pytree_node`` is true.
    """
    # Build a private dict so the caller's mapping is left untouched. ``dict(...)`` is used rather
    # than ``.copy()`` because the latter only exists on some mapping types, whereas the standard
    # ``dataclasses.field`` accepts any mapping as ``metadata``.
    metadata = dict(metadata or {})
    if pytree_node is None:
        pytree_node = metadata.get('pytree_node', _PYTREE_NODE_DEFAULT)
    # An explicit ``pytree_node`` argument wins over a value already present in ``metadata``.
    metadata['pytree_node'] = pytree_node

    kwargs = {
        'default': default,
        'default_factory': default_factory,
        'init': init,
        'repr': repr,
        'hash': hash,
        'compare': compare,
        'metadata': metadata,
    }

    # Only forward keywords the running interpreter's ``dataclasses.field`` is given here; on
    # older interpreters a non-default value is rejected with the usual unexpected-keyword error.
    if sys.version_info >= (3, 10):
        kwargs['kw_only'] = kw_only
    elif kw_only is not dataclasses.MISSING:
        raise TypeError("field() got an unexpected keyword argument 'kw_only'")

    if sys.version_info >= (3, 14):
        kwargs['doc'] = doc
    elif doc is not None:
        raise TypeError("field() got an unexpected keyword argument 'doc'")

    # A PyTree child must be passed back through ``__init__`` when unflattening, so a field that
    # is excluded from ``__init__`` cannot be a child node.
    if not init and pytree_node:
        raise TypeError(
            '`pytree_node=True` is not allowed for non-init fields. '
            f'Please explicitly set `{__name__}.field(init=False, pytree_node=False)`.',
        )

    return dataclasses.field(**kwargs)
|
|
|
|
|
|
|
|
@overload
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
    namespace: str,
) -> Callable[[_TypeT], _TypeT]: ...


@overload
def dataclass(
    cls: _TypeT,
    /,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
    namespace: str,
) -> _TypeT: ...


@dataclass_transform(field_specifiers=(field,))
def dataclass(
    cls: _TypeT | None = None,
    /,
    *,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
    namespace: str,
) -> _TypeT | Callable[[_TypeT], _TypeT]:
    """Dataclass decorator with PyTree integration.

    The class is first turned into a regular dataclass with :func:`dataclasses.dataclass`. Its
    fields are then split into two groups according to ``field.metadata['pytree_node']`` (a
    missing key counts as :data:`True`):

    - *children*: PyTree node fields; these must have ``init=True``.
    - *metadata*: non-node fields that have ``init=True``.

    Fields that are neither PyTree nodes nor part of ``__init__`` belong to neither group. The two
    groups are stored on the class as a pair of read-only mappings under the attribute
    ``__optree_dataclass_fields__``, and the class is registered as a PyTree node type.

    Args:
        cls (type or None, optional): The class to decorate. If :data:`None`, return a decorator.
        namespace (str): The registry namespace used for the PyTree registration. An empty
            string selects the global namespace.
        init, repr, eq, order, unsafe_hash, frozen (bool, optional): Passed to
            :func:`dataclasses.dataclass`.
        match_args, kw_only, slots (bool, optional): Passed to :func:`dataclasses.dataclass` on
            Python 3.10+; a non-default value on older versions raises :class:`TypeError`.
        weakref_slot (bool, optional): Passed to :func:`dataclasses.dataclass` on Python 3.11+;
            a non-default value on older versions raises :class:`TypeError`.

    Returns:
        type or callable: The decorated class with PyTree integration or decorator function.

    Raises:
        TypeError: If an option is not supported by the running Python version, if ``cls`` is
            not a class, if the class was already processed by this decorator, if ``namespace``
            is not a string, or if a PyTree node field has ``init=False``.
    """
    from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

    # Options understood by every supported version of ``dataclasses.dataclass``.
    options = {
        'init': init,
        'repr': repr,
        'eq': eq,
        'order': order,
        'unsafe_hash': unsafe_hash,
        'frozen': frozen,
    }

    # Version-dependent options: forward them when available, otherwise reject any non-default
    # value the same way an unknown keyword would be rejected.
    if sys.version_info >= (3, 10):
        options.update(match_args=match_args, kw_only=kw_only, slots=slots)
    else:
        if match_args is not True:
            raise TypeError("dataclass() got an unexpected keyword argument 'match_args'")
        if kw_only is not False:
            raise TypeError("dataclass() got an unexpected keyword argument 'kw_only'")
        if slots is not False:
            raise TypeError("dataclass() got an unexpected keyword argument 'slots'")

    if sys.version_info >= (3, 11):
        options['weakref_slot'] = weakref_slot
    else:
        if weakref_slot is not False:
            raise TypeError("dataclass() got an unexpected keyword argument 'weakref_slot'")

    if cls is None:
        # Used as ``@dataclass(namespace=...)``: hand back a decorator that re-enters this
        # function once the class is available.
        def decorator(cls: _TypeT) -> _TypeT:
            return dataclass(cls, namespace=namespace, **options)

        return decorator

    if not inspect.isclass(cls):
        raise TypeError(f'@{__name__}.dataclass() can only be used with classes, not {cls!r}.')
    # Look in the class' own ``__dict__`` so that a subclass of a decorated class is still allowed.
    if _FIELDS in cls.__dict__:
        raise TypeError(
            f'@{__name__}.dataclass() cannot be applied to {cls.__name__} more than once.',
        )
    if not (namespace is GLOBAL_NAMESPACE or isinstance(namespace, str)):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        namespace = GLOBAL_NAMESPACE

    # Rebind ``cls``: the closures below must refer to whatever class object the standard
    # decorator returns, which is also the object that gets registered.
    cls = dataclasses.dataclass(cls, **options)

    child_specs = {}
    meta_specs = {}
    for spec in dataclasses.fields(cls):
        if not spec.metadata.get('pytree_node', _PYTREE_NODE_DEFAULT):
            # Non-node fields are only tracked when ``__init__`` accepts them.
            if spec.init:
                meta_specs[spec.name] = spec
            continue
        if not spec.init:
            raise TypeError(
                f'PyTree node field {spec.name!r} must be included in `__init__()`. '
                f'Or you can explicitly set `{__name__}.field(init=False, pytree_node=False)`.',
            )
        child_specs[spec.name] = spec

    children_field_names = tuple(child_specs)
    metadata_field_names = tuple(meta_specs)
    # Expose both groups on the class as read-only views.
    setattr(
        cls,
        _FIELDS,
        (types.MappingProxyType(child_specs), types.MappingProxyType(meta_specs)),
    )

    def flatten_func(
        obj: _T,
        /,
    ) -> tuple[
        tuple[_U, ...],
        tuple[tuple[str, Any], ...],
        tuple[str, ...],
    ]:
        # Children are the node-field values; metadata keeps ``(name, value)`` pairs so that
        # ``unflatten_func`` can pass them back as keyword arguments.
        leaves = tuple(getattr(obj, name) for name in children_field_names)
        extras = tuple((name, getattr(obj, name)) for name in metadata_field_names)
        return leaves, extras, children_field_names

    def unflatten_func(metadata: tuple[tuple[str, Any], ...], children: tuple[_U, ...], /) -> _T:
        # Rebuild the instance through ``__init__``; metadata entries are applied last.
        init_kwargs = {**dict(zip(children_field_names, children)), **dict(metadata)}
        return cls(**init_kwargs)

    from optree.accessors import DataclassEntry
    from optree.registry import register_pytree_node

    return register_pytree_node(
        cls,
        flatten_func,
        unflatten_func,
        path_entry_type=DataclassEntry,
        namespace=namespace,
    )
|
|
|
|
|
|
|
|
class _DataclassDecorator(Protocol[_TypeT]):
    """Call signature of a dataclass decorator, used to annotate ``make_dataclass(decorator=...)``.

    This is a typing-only protocol: it describes a callable that takes a class positionally plus
    the keyword options below and returns a class of the same type.
    """

    def __call__(
        self,
        cls: _TypeT,
        /,
        *,
        init: bool = True,
        repr: bool = True,
        eq: bool = True,
        order: bool = False,
        unsafe_hash: bool = False,
        frozen: bool = False,
        match_args: bool = True,
        kw_only: bool = False,
        slots: bool = False,
        weakref_slot: bool = False,
    ) -> _TypeT:
        """Decorate ``cls``; the protocol itself provides no implementation."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
def make_dataclass(
    cls_name: str,
    fields: Iterable[str | tuple[str, Any] | tuple[str, Any, Any]],
    *,
    bases: tuple[type, ...] = (),
    ns: dict[str, Any] | None = None,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,
    kw_only: bool = False,
    slots: bool = False,
    weakref_slot: bool = False,
    module: str | None = None,
    decorator: _DataclassDecorator[_TypeT] = dataclasses.dataclass,
    namespace: str,
) -> _TypeT:
    """Make a new dynamically created dataclass with PyTree integration.

    The dataclass name will be ``cls_name``. ``fields`` is an iterable of either (name),
    (name, type), or (name, type, Field) objects; see :func:`dataclasses.make_dataclass` for how
    each form is interpreted.

    The ``namespace`` parameter is the PyTree registration namespace which should be a string. The
    ``namespace`` in the original :func:`dataclasses.make_dataclass` function is renamed to ``ns``
    to avoid conflicts. If the two appear to have been swapped (``namespace`` is a dict or
    :data:`None` while ``ns`` is a string), they are exchanged before use.

    The remaining parameters are passed to :func:`dataclasses.make_dataclass`.
    See :func:`dataclasses.make_dataclass` for more information.

    Args:
        cls_name: The name of the dataclass.
        fields (Iterable[str | tuple[str, Any] | tuple[str, Any, Any]]): An iterable of either
            (name), (name, type), or (name, type, Field) objects.
        namespace (str): The registry namespace used for the PyTree registration. An empty
            string selects the global namespace.
        ns (dict or None, optional): The namespace used in dynamic type creation.
            See :func:`dataclasses.make_dataclass` and the builtin :func:`type` function for more
            information.
        bases, init, repr, eq, order, unsafe_hash, frozen (optional): Passed to
            :func:`dataclasses.make_dataclass`.
        match_args, kw_only, slots (bool, optional): Passed through on Python 3.10+ only.
        weakref_slot (bool, optional): Passed through on Python 3.11+ only.
        module (str or None, optional): Passed through on Python 3.12+ only. If :data:`None`
            there, the module name is looked up from the calling frame.
        decorator (callable, optional): Passed through on Python 3.14+ only. The standard
            :func:`dataclasses.dataclass` and this module's :func:`dataclass` are both replaced
            by this module's :func:`dataclass` bound to ``namespace``.

    Returns:
        type: The dynamically created dataclass with PyTree integration.

    Raises:
        TypeError: If the PyTree namespace is missing or not a string, or if an option that the
            running Python version does not support is given a non-default value.
    """
    from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

    # Tolerate callers that used the standard library's meaning of ``namespace`` (a class-body
    # dict) and supplied the PyTree namespace through ``ns`` instead.
    if namespace is None or isinstance(namespace, dict):
        if isinstance(ns, str) or ns is GLOBAL_NAMESPACE:
            ns, namespace = namespace, ns
        elif ns is None:
            raise TypeError("make_dataclass() missing 1 required keyword-only argument: 'ns'")
    if not (namespace is GLOBAL_NAMESPACE or isinstance(namespace, str)):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        namespace = GLOBAL_NAMESPACE

    # ``decorator_options`` are shared by the stdlib factory and our ``dataclass`` decorator;
    # ``factory_options`` are understood by ``dataclasses.make_dataclass`` only.
    decorator_options = {
        'init': init,
        'repr': repr,
        'eq': eq,
        'order': order,
        'unsafe_hash': unsafe_hash,
        'frozen': frozen,
    }
    factory_options = {
        'bases': bases,
        'namespace': ns,
    }

    if sys.version_info >= (3, 10):
        decorator_options.update(match_args=match_args, kw_only=kw_only, slots=slots)
    else:
        if match_args is not True:
            raise TypeError("make_dataclass() got an unexpected keyword argument 'match_args'")
        if kw_only is not False:
            raise TypeError("make_dataclass() got an unexpected keyword argument 'kw_only'")
        if slots is not False:
            raise TypeError("make_dataclass() got an unexpected keyword argument 'slots'")

    if sys.version_info >= (3, 11):
        decorator_options['weakref_slot'] = weakref_slot
    else:
        if weakref_slot is not False:
            raise TypeError("make_dataclass() got an unexpected keyword argument 'weakref_slot'")

    if sys.version_info >= (3, 12):
        if module is None:
            # Resolve the caller's module here, in this frame: the depth argument ``1`` is
            # relative to ``make_dataclass`` itself, so this must not move into a helper.
            try:
                module = sys._getframemodulename(1) or '__main__'
            except AttributeError:
                try:
                    module = sys._getframe(1).f_globals.get('__name__', '__main__')
                except (AttributeError, ValueError):
                    # Best effort only: leave ``module`` as ``None``.
                    pass
        factory_options['module'] = module
    else:
        if module is not None:
            raise TypeError("make_dataclass() got an unexpected keyword argument 'module'")

    if sys.version_info >= (3, 14):
        # Let the stdlib factory apply our PyTree-aware decorator directly when the caller did
        # not ask for a custom one.
        registered_by_decorator = decorator in (dataclasses.dataclass, dataclass)
        if registered_by_decorator:
            decorator = functools.partial(dataclass, namespace=namespace)
        factory_options['decorator'] = decorator
    else:
        registered_by_decorator = False
        if decorator is not dataclasses.dataclass:
            raise TypeError("make_dataclass() got an unexpected keyword argument 'decorator'")

    cls: _TypeT = dataclasses.make_dataclass(
        cls_name,
        fields=fields,
        **decorator_options,
        **factory_options,
    )
    if registered_by_decorator:
        return cls

    # The slot-related options were already applied when the class was created above, so they
    # are not repeated for the PyTree registration pass.
    for applied_option in ('slots', 'weakref_slot'):
        decorator_options.pop(applied_option, None)
    return dataclass(cls, **decorator_options, namespace=namespace)
|
|
|