joebruce1313's picture
Upload 38004 files
1f5470c verified
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Typing utilities for OpTree."""
from __future__ import annotations
import abc
import functools
import platform
import sys
import threading
import types
from builtins import dict as Dict # noqa: N812
from builtins import list as List # noqa: N812
from builtins import tuple as Tuple # noqa: N812
from collections import OrderedDict
from collections import defaultdict as DefaultDict # noqa: N812
from collections import deque as Deque # noqa: N812
from collections.abc import (
Collection,
Hashable,
ItemsView,
Iterable,
Iterator,
KeysView,
Sequence,
ValuesView,
)
from typing import (
Any,
Callable,
ClassVar,
Final,
ForwardRef,
Generic,
Optional,
Protocol,
TypeVar,
Union,
final,
get_origin,
runtime_checkable,
)
from typing_extensions import (
NamedTuple, # Generic NamedTuple: Python 3.11+
Never, # Python 3.11+
ParamSpec, # Python 3.10+
Self, # Python 3.11+
TypeAlias, # Python 3.10+
TypeAliasType, # Python 3.12+
)
from weakref import WeakKeyDictionary # pylint: disable=wrong-import-order
import optree._C as _C
from optree._C import PyTreeKind, PyTreeSpec
from optree.accessors import (
AutoEntry,
DataclassEntry,
FlattenedEntry,
GetAttrEntry,
GetItemEntry,
MappingEntry,
NamedTupleEntry,
PyTreeAccessor,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
# Names exported by `from optree.typing import *`, grouped by purpose.
__all__ = [
    # PyTree specification types (from `optree._C`) and the generic PyTree types
    'PyTreeSpec',
    'PyTreeDef',
    'PyTreeKind',
    'PyTree',
    'PyTreeTypeVar',
    # Custom node protocol and the flatten / unflatten signatures
    'CustomTreeNode',
    'Children',
    'MetaData',
    'FlattenFunc',
    'UnflattenFunc',
    # Entry and accessor classes re-exported from `optree.accessors`
    'PyTreeEntry',
    'GetItemEntry',
    'GetAttrEntry',
    'FlattenedEntry',
    'AutoEntry',
    'SequenceEntry',
    'MappingEntry',
    'NamedTupleEntry',
    'StructSequenceEntry',
    'DataclassEntry',
    'PyTreeAccessor',
    # namedtuple and PyStructSequence helpers
    'is_namedtuple',
    'is_namedtuple_class',
    'is_namedtuple_instance',
    'namedtuple_fields',
    'is_structseq',
    'is_structseq_class',
    'is_structseq_instance',
    'structseq_fields',
    # Type variables
    'T',
    'S',
    'U',
    'KT',
    'VT',
    'P',
    'F',
    # Container types re-exported from `builtins`, `collections`, `collections.abc`
    # and `typing_extensions`, plus the `StructSequence` stub defined in this module
    'Iterable',
    'Sequence',
    'Tuple',
    'List',
    'Dict',
    'NamedTuple',
    'OrderedDict',
    'DefaultDict',
    'Deque',
    'StructSequence',
]
PyTreeDef: TypeAlias = PyTreeSpec  # alias

# Type variables shared by the annotations in this module.
T = TypeVar('T')
S = TypeVar('S')
U = TypeVar('U')
KT = TypeVar('KT')
VT = TypeVar('VT')
P = ParamSpec('P')  # parameters of a wrapped callable (see the decorators in this module)
F = TypeVar('F', bound=Callable[..., Any])  # restricted to callables

# The two pieces every flatten function returns: the children and the metadata.
Children: TypeAlias = Iterable[T]
MetaData: TypeAlias = Optional[Hashable]
@runtime_checkable
class CustomTreeNode(Protocol[T]):
    """The abstract base class for custom pytree nodes.

    This is a runtime-checkable protocol whose members are a ``tree_flatten`` method and a
    ``tree_unflatten`` class method. Both bodies here are documentation-only stubs.
    """

    def tree_flatten(
        self,
        /,
    ) -> (
        # Use `range(num_children)` as path entries
        tuple[Children[T], MetaData]
        |
        # With optionally implemented path entries
        tuple[Children[T], MetaData, Iterable[Any] | None]
    ):
        """Flatten the custom pytree node into children and metadata.

        Returns:
            Either a ``(children, metadata)`` pair, or a ``(children, metadata, entries)``
            triple whose last item is an iterable of path entries or :data:`None`.
        """

    @classmethod
    def tree_unflatten(cls, metadata: MetaData, children: Children[T], /) -> Self:
        """Unflatten the children and metadata into the custom pytree node.

        Args:
            metadata: The metadata item of a flatten result.
            children: The children item of a flatten result.

        Returns:
            An instance of the class this method is called on.
        """
# Runtime class of a subscripted `typing.Union[...]` object, kept for `isinstance` checks.
_UnionType = type(Union[int, str])
try:  # pragma: no cover
    # Reuse the private memoizing decorator from `typing` when it can be imported.
    from typing import _tp_cache  # type: ignore[attr-defined] # pylint: disable=ungrouped-imports
except ImportError:  # pragma: no cover

    def _tp_cache(func: Callable[P, T], /) -> Callable[P, T]:
        """Memoize ``func`` with an LRU cache, bypassing the cache on :class:`TypeError`.

        Args:
            func: The callable to memoize.

        Returns:
            A wrapper carrying the metadata of ``func``. It first tries the cached call;
            when that raises :class:`TypeError` (e.g. unhashable arguments), it calls
            ``func`` directly with the same arguments.
        """
        memoized = functools.lru_cache(func)

        def cached_call(*args: P.args, **kwargs: P.kwargs) -> T:
            try:
                result = memoized(*args, **kwargs)  # type: ignore[arg-type]
            except TypeError:
                # All real errors (not unhashable args) are raised below.
                result = func(*args, **kwargs)
            return result

        return functools.wraps(func)(cached_call)
class PyTree(Generic[T]):  # pragma: no cover
    """Generic PyTree type.

    Subscripting the class as ``PyTree[param]`` or ``PyTree[param, name]`` builds a
    ``typing.Union`` of ``param`` and the tuple, list, dict, deque and
    :class:`CustomTreeNode` types whose element type is a :class:`ForwardRef` back to the
    pytree type itself. The class cannot be instantiated or subclassed.

    >>> import torch
    >>> TensorTree = PyTree[torch.Tensor]
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
                 list[ForwardRef('PyTree[torch.Tensor]')],
                 dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
                 collections.deque[ForwardRef('PyTree[torch.Tensor]')],
                 optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
    """

    __slots__: ClassVar[tuple[()]] = ()
    # Registry of the unions built by `__class_getitem__`: each union maps to the
    # `(param, name)` pair it was created from. Keys are held weakly.
    __instances__: ClassVar[
        WeakKeyDictionary[
            TypeAliasType,
            tuple[type | TypeAliasType, str | None],
        ]
    ] = WeakKeyDictionary()
    # Held around every read and write of `__instances__`.
    __instance_lock__: ClassVar[threading.Lock] = threading.Lock()

    @_tp_cache
    def __class_getitem__(  # noqa: C901 # pylint: disable=too-many-branches
        cls,
        item: (
            type[T]
            | TypeAliasType
            | tuple[type[T] | TypeAliasType]
            | tuple[type[T] | TypeAliasType, str | None]
        ),
    ) -> TypeAliasType:
        """Instantiate a PyTree type with the given type.

        Args:
            item: A bare type parameter, a 1-tuple ``(param,)``, or a 2-tuple
                ``(param, name)`` where ``name`` is :data:`None` or the string to use as
                the recursive forward reference.

        Returns:
            The union described in the class docstring. If ``param`` is a union that is
            already registered in ``__instances__``, it is returned unchanged.

        Raises:
            TypeError: If ``item`` is a tuple whose length is neither 1 nor 2, or if
                ``name`` is neither :data:`None` nor a string.
        """
        # Normalize every accepted spelling to a `(param, name)` pair.
        if not isinstance(item, tuple):
            item = (item, None)
        if len(item) == 1:
            item = (item[0], None)
        elif len(item) != 2:
            raise TypeError(
                f'{cls.__name__}[...] only supports a tuple of 2 items, '
                f'a parameter and a string of type name, got {item!r}.',
            )
        param, name = item
        if name is not None and not isinstance(name, str):
            raise TypeError(
                f'{cls.__name__}[...] only supports a tuple of 2 items, '
                f'a parameter and a string of type name, got {item!r}.',
            )

        # A union that this method produced earlier is already a pytree type.
        if isinstance(param, _UnionType) and get_origin(param) is Union:  # type: ignore[unreachable]
            with cls.__instance_lock__:  # type: ignore[unreachable]
                try:
                    if param in cls.__instances__:
                        return param  # PyTree[PyTree[T]] -> PyTree[T]
                except TypeError:
                    pass  # non-hashable type

        # Choose the text of the forward reference used for the recursive positions.
        if name is not None:
            recurse_ref = ForwardRef(name)
        elif isinstance(param, TypeVar):
            recurse_ref = ForwardRef(f'{cls.__name__}[{param.__name__}]')  # type: ignore[unreachable]
        elif isinstance(param, type):
            if param.__module__ == 'builtins':
                # Builtin types are spelled without the module prefix.
                typename = param.__qualname__
            else:
                try:
                    typename = f'{param.__module__}.{param.__qualname__}'
                except AttributeError:
                    typename = f'{param.__module__}.{param.__name__}'
            recurse_ref = ForwardRef(f'{cls.__name__}[{typename}]')
        else:
            recurse_ref = ForwardRef(f'{cls.__name__}[{param!r}]')

        pytree_alias = Union[
            param,  # type: ignore[valid-type]
            Tuple[recurse_ref, ...],  # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence
            List[recurse_ref],  # type: ignore[valid-type]
            Dict[Any, recurse_ref],  # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict
            Deque[recurse_ref],  # type: ignore[valid-type]
            CustomTreeNode[recurse_ref],  # type: ignore[valid-type]
        ]
        # Register the new union so that a later `PyTree[pytree_alias]` can recognize it.
        with cls.__instance_lock__:
            cls.__instances__[pytree_alias] = (param, name)  # type: ignore[index]
        return pytree_alias  # type: ignore[return-value]

    def __new__(cls, /) -> Never:  # pylint: disable=arguments-differ
        """Prohibit instantiation."""
        raise TypeError('Cannot instantiate special typing classes.')

    def __init_subclass__(cls, /, *args: Any, **kwargs: Any) -> Never:
        """Prohibit subclassing."""
        raise TypeError('Cannot subclass special typing classes.')

    # The remaining methods only declare signatures: `__new__` always raises, so no
    # instance can exist, and every body below raises `NotImplementedError`.
    # NOTE(review): presumably they exist so static type checkers accept container-style
    # operations on values annotated as `PyTree[...]` -- confirm.

    def __getitem__(self, key: Any, /) -> PyTree[T] | T:
        """Emulate collection-like behavior."""
        raise NotImplementedError

    def __getattr__(self, name: str, /) -> PyTree[T] | T:
        """Emulate dataclass-like behavior."""
        raise NotImplementedError

    def __contains__(self, key: Any | T, /) -> bool:
        """Emulate collection-like behavior."""
        raise NotImplementedError

    def __len__(self, /) -> int:
        """Emulate collection-like behavior."""
        raise NotImplementedError

    def __iter__(self, /) -> Iterator[PyTree[T] | T | Any]:
        """Emulate collection-like behavior."""
        raise NotImplementedError

    def index(self, key: Any | T, /) -> int:
        """Emulate sequence-like behavior."""
        raise NotImplementedError

    def count(self, key: Any | T, /) -> int:
        """Emulate sequence-like behavior."""
        raise NotImplementedError

    def get(self, key: Any, /, default: T | None = None) -> T | None:
        """Emulate mapping-like behavior."""
        raise NotImplementedError

    def keys(self, /) -> KeysView[Any]:
        """Emulate mapping-like behavior."""
        raise NotImplementedError

    def values(self, /) -> ValuesView[PyTree[T] | T]:
        """Emulate mapping-like behavior."""
        raise NotImplementedError

    def items(self, /) -> ItemsView[Any, PyTree[T] | T]:
        """Emulate mapping-like behavior."""
        raise NotImplementedError
# pylint: disable-next=too-few-public-methods
class PyTreeTypeVar:  # pragma: no cover
    """Type variable for PyTree.

    Calling ``PyTreeTypeVar(name, param)`` does not create an instance of this class; it
    returns ``PyTree[param, name]``, so the recursive references are spelled ``name``.

    >>> import torch
    >>> TensorTree = PyTreeTypeVar('TensorTree', torch.Tensor)
    >>> TensorTree  # doctest: +IGNORE_WHITESPACE
    typing.Union[torch.Tensor,
                 tuple[ForwardRef('TensorTree'), ...],
                 list[ForwardRef('TensorTree')],
                 dict[typing.Any, ForwardRef('TensorTree')],
                 collections.deque[ForwardRef('TensorTree')],
                 optree.typing.CustomTreeNode[ForwardRef('TensorTree')]]
    """

    @_tp_cache
    def __new__(cls, /, name: str, param: type | TypeAliasType) -> TypeAliasType:  # type: ignore[misc]
        """Build the named PyTree type for ``param``.

        Args:
            name: The type name used for the recursive forward references.
            param: The type parameter handed to ``PyTree[...]``.

        Returns:
            The result of ``PyTree[param, name]``.

        Raises:
            TypeError: If ``name`` is not a string.
        """
        if isinstance(name, str):
            return PyTree[param, name]  # type: ignore[misc,valid-type]
        raise TypeError(f'{cls.__name__} only supports a string of type name, got {name!r}.')

    def __init_subclass__(cls, /, *args: Any, **kwargs: Any) -> Never:
        """Reject every attempt to derive from this class."""
        raise TypeError('Cannot subclass special typing classes.')
class FlattenFunc(Protocol[T]):  # pylint: disable=too-few-public-methods
    """The type stub class for flatten functions.

    A protocol for callables that take one positional container argument and return its
    children together with its metadata.
    """

    @abc.abstractmethod
    def __call__(
        self,
        container: Collection[T],
        /,
    ) -> tuple[Children[T], MetaData] | tuple[Children[T], MetaData, Iterable[Any] | None]:
        """Flatten the container into children and metadata.

        Args:
            container: The collection to flatten.

        Returns:
            Either ``(children, metadata)`` or a triple with an extra third item that is
            an iterable or :data:`None`.
            NOTE(review): the third item is presumably the path entries -- confirm.
        """
class UnflattenFunc(Protocol[T]):  # pylint: disable=too-few-public-methods
    """The type stub class for unflatten functions.

    A protocol for callables that take metadata and children, both positional-only, and
    return a collection.
    """

    @abc.abstractmethod
    def __call__(self, metadata: MetaData, children: Children[T], /) -> Collection[T]:
        """Unflatten the children and metadata back into the container.

        Args:
            metadata: The metadata item of a flatten result.
            children: The children item of a flatten result.

        Returns:
            The rebuilt collection.
        """
def _override_with_(
    cxx_implementation: Callable[P, T],
    /,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorator factory that replaces a Python function with a C++ implementation.

    The function returned by the decorator forwards every call to
    ``cxx_implementation``; the decorated Python function is never invoked by it and only
    supplies the metadata copied by :mod:`functools`. Both callables stay reachable on the
    result as ``__cxx_implementation__`` and ``__python_implementation__``.

    >>> @_override_with_(any)
    ... def my_any(iterable):
    ...     for elem in iterable:
    ...         if elem:
    ...             return True
    ...     return False
    ...
    >>> my_any([False, False, True, False, False, True])  # run at C speed
    True

    Args:
        cxx_implementation: The callable that actually handles the calls.

    Returns:
        A decorator taking the Python implementation and returning the forwarding function.
    """

    def decorator(python_implementation: Callable[P, T], /) -> Callable[P, T]:
        def dispatch(*args: P.args, **kwargs: P.kwargs) -> T:
            return cxx_implementation(*args, **kwargs)

        # Copy the name, docstring, etc. of the Python implementation onto the forwarder.
        functools.update_wrapper(dispatch, python_implementation)
        dispatch.__cxx_implementation__ = cxx_implementation  # type: ignore[attr-defined]
        dispatch.__python_implementation__ = python_implementation  # type: ignore[attr-defined]
        return dispatch

    return decorator
@_override_with_(_C.is_namedtuple)
def is_namedtuple(obj: object | type, /) -> bool:
    """Tell whether ``obj`` is a namedtuple class or an instance of one.

    A class is checked as-is; for any other object its type is checked.
    """
    if isinstance(obj, type):
        return is_namedtuple_class(obj)
    return is_namedtuple_class(type(obj))
@_override_with_(_C.is_namedtuple_instance)
def is_namedtuple_instance(obj: object, /) -> bool:
    """Tell whether the type of ``obj`` is a namedtuple class."""
    obj_type = type(obj)
    return is_namedtuple_class(obj_type)
@_override_with_(_C.is_namedtuple_class)
def is_namedtuple_class(cls: type, /) -> bool:
    """Tell whether ``cls`` looks like a class created by ``namedtuple``.

    The class must derive from :class:`tuple`, expose ``_fields`` as a tuple whose items
    are all exactly of type :class:`str`, and provide callable ``_make`` and ``_asdict``.
    """
    if not isinstance(cls, type) or not issubclass(cls, tuple):
        return False
    if not isinstance(getattr(cls, '_fields', None), tuple):
        return False
    # pylint: disable-next=unidiomatic-typecheck
    if not all(type(field) is str for field in cls._fields):  # type: ignore[attr-defined]
        return False
    if not callable(getattr(cls, '_make', None)):
        return False
    return callable(getattr(cls, '_asdict', None))
@_override_with_(_C.namedtuple_fields)
def namedtuple_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]:
    """Return the field names of a namedtuple class or instance.

    Raises:
        TypeError: If ``obj`` is a class that is not a namedtuple class, or an object
            whose type is not a namedtuple class.
    """
    given_a_class = isinstance(obj, type)
    cls = obj if given_a_class else type(obj)
    if not is_namedtuple_class(cls):
        # The message depends on whether a class or an instance was passed in.
        if given_a_class:
            raise TypeError(f'Expected a collections.namedtuple type, got {cls!r}.')
        raise TypeError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
    return cls._fields  # type: ignore[attr-defined]
# Covariant element type used by the `StructSequence` stub.
_T_co = TypeVar('_T_co', covariant=True)


class StructSequenceMeta(type):
    """The metaclass for PyStructSequence stub type.

    It overrides the ``issubclass()`` and ``isinstance()`` hooks so that both checks are
    answered by :func:`is_structseq_class` and :func:`is_structseq_instance`.
    """

    def __subclasscheck__(cls, subclass: type, /) -> bool:
        """Return whether the class is a PyStructSequence type.

        >>> import time
        >>> issubclass(time.struct_time, StructSequence)
        True
        >>> class MyTuple(tuple):
        ...     n_fields = 2
        ...     n_sequence_fields = 2
        ...     n_unnamed_fields = 0
        >>> issubclass(MyTuple, StructSequence)
        False
        """
        # Delegates entirely to the module-level predicate.
        return is_structseq_class(subclass)

    def __instancecheck__(cls, instance: Any, /) -> bool:
        """Return whether the object is a PyStructSequence instance.

        >>> import sys
        >>> isinstance(sys.float_info, StructSequence)
        True
        >>> isinstance((1, 2), StructSequence)
        False
        """
        # Delegates entirely to the module-level predicate.
        return is_structseq_instance(instance)
# Reference: https://github.com/python/typeshed/blob/main/stdlib/_typeshed/__init__.pyi
# This is an internal CPython type that is like, but subtly different from a NamedTuple.
# `StructSequence` classes are unsubclassable, so are all decorated with `@final`.
# pylint: disable-next=invalid-name,missing-class-docstring
@final
class StructSequence(tuple[_T_co, ...], metaclass=StructSequenceMeta):  # type: ignore[misc]
    """A generic type stub for CPython's ``PyStructSequence`` type.

    ``isinstance()`` and ``issubclass()`` checks against this class are answered by its
    metaclass. The class cannot be subclassed, and its constructor always raises
    :class:`NotImplementedError`.
    """

    __slots__: ClassVar[tuple[()]] = ()

    # Annotation-only declarations of the three integer class attributes; no value is
    # assigned on this stub.
    n_fields: Final[ClassVar[int]]  # type: ignore[misc] # pylint: disable=invalid-name
    n_sequence_fields: Final[ClassVar[int]]  # type: ignore[misc] # pylint: disable=invalid-name
    n_unnamed_fields: Final[ClassVar[int]]  # type: ignore[misc] # pylint: disable=invalid-name

    def __init_subclass__(cls, /) -> Never:
        """Prohibit subclassing."""
        raise TypeError("type 'StructSequence' is not an acceptable base type")

    # pylint: disable-next=unused-argument,redefined-builtin
    def __new__(cls, /, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self:
        """Create a new :class:`StructSequence` instance.

        Raises:
            NotImplementedError: Always; this stub only declares the signature.
        """
        raise NotImplementedError


# Lowercase alias of the stub class.
structseq: TypeAlias = StructSequence  # noqa: PYI042

# Remove only the module-level name; `StructSequence` keeps its metaclass.
del StructSequenceMeta
@_override_with_(_C.is_structseq)
def is_structseq(obj: object | type, /) -> bool:
    """Tell whether ``obj`` is a PyStructSequence class or an instance of one.

    A class is checked as-is; for any other object its type is checked.
    """
    if isinstance(obj, type):
        return is_structseq_class(obj)
    return is_structseq_class(type(obj))
@_override_with_(_C.is_structseq_instance)
def is_structseq_instance(obj: object, /) -> bool:
    """Tell whether the type of ``obj`` is a PyStructSequence class."""
    obj_type = type(obj)
    return is_structseq_class(obj_type)
# Set if the type allows subclassing (see CPython's Include/object.h)
Py_TPFLAGS_BASETYPE: int = _C.Py_TPFLAGS_BASETYPE  # (1UL << 10)


@_override_with_(_C.is_structseq_class)
def is_structseq_class(cls: type, /) -> bool:
    """Tell whether ``cls`` is a PyStructSequence class.

    A candidate must be a class whose only base is :class:`tuple` and whose ``n_fields``,
    ``n_sequence_fields`` and ``n_unnamed_fields`` attributes are integers. It is accepted
    only if it also cannot be subclassed: on PyPy this is probed by trying to create a
    subclass, elsewhere by testing the ``Py_TPFLAGS_BASETYPE`` bit of ``cls.__flags__``.
    """
    has_structseq_shape = (
        isinstance(cls, type)
        # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)`
        and cls.__bases__ == (tuple,)
        # Check PyStructSequence members
        and all(
            isinstance(getattr(cls, member, None), int)
            for member in ('n_fields', 'n_sequence_fields', 'n_unnamed_fields')
        )
    )
    if not has_structseq_shape:
        return False

    # Check the type does not allow subclassing
    if platform.python_implementation() == 'PyPy':  # pragma: pypy cover
        try:
            types.new_class('subclass', bases=(cls,))
        except (AssertionError, TypeError):
            return True
        else:
            return False
    return not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE)  # pragma: pypy no cover
# pylint: disable-next=line-too-long
StructSequenceFieldType: type[types.MemberDescriptorType] = type(type(sys.version_info).major)  # type: ignore[assignment]


@_override_with_(_C.structseq_fields)
def structseq_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]:
    """Return the field names of a PyStructSequence class or instance.

    Only the first ``n_sequence_fields`` names are returned.

    Raises:
        TypeError: If ``obj`` is a class that is not a PyStructSequence class, or an
            object whose type is not a PyStructSequence class.
    """
    given_a_class = isinstance(obj, type)
    cls = obj if given_a_class else type(obj)
    if not is_structseq_class(cls):
        # The message depends on whether a class or an instance was passed in.
        if given_a_class:
            raise TypeError(f'Expected a PyStructSequence type, got {cls!r}.')
        raise TypeError(f'Expected an instance of PyStructSequence type, got {obj!r}.')

    # Field descriptors found in the class namespace, in namespace order.
    members = {
        name: member
        for name, member in vars(cls).items()
        if isinstance(member, StructSequenceFieldType)
    }
    if platform.python_implementation() == 'PyPy':  # pragma: pypy cover
        # Order the names by the `index` attribute of each descriptor.
        fields = sorted(
            members,
            key=lambda name: members[name].index,  # type: ignore[attr-defined]
        )
    else:  # pragma: pypy no cover
        fields = list(members)
    return tuple(fields[: cls.n_sequence_fields])  # type: ignore[attr-defined]
# `_tp_cache` is only applied as a decorator while the classes above are being defined;
# drop the name from the module namespace afterwards.
del _tp_cache